Skip to main content

kobe_svm/
deriver.rs

1//! Solana address derivation from HD wallet.
2
3use alloc::string::{String, ToString};
4use alloc::vec::Vec;
5
6use ed25519_dalek::VerifyingKey;
7use kobe_core::slip10::DerivedKey;
8use kobe_core::{Derive, DerivedAccount, Wallet};
9use zeroize::Zeroizing;
10
11use crate::Error;
12use crate::derivation_style::DerivationStyle;
13
14/// A derived Solana address with associated keys.
15#[derive(Debug, Clone)]
16#[non_exhaustive]
17pub struct DerivedAddress {
18    /// Derivation path used (e.g., `m/44'/501'/0'/0'`).
19    pub path: String,
20    /// Private key in hex format (zeroized on drop).
21    pub private_key_hex: Zeroizing<String>,
22    /// Full keypair in base58 format (64 bytes: secret 32B + public 32B, zeroized on drop).
23    ///
24    /// This is the standard format used by Phantom, Backpack, Solflare wallets.
25    pub keypair_base58: Zeroizing<String>,
26    /// Public key in hex format.
27    pub public_key_hex: String,
28    /// Solana address (Base58 encoded public key).
29    pub address: String,
30}
31
32/// Solana address deriver from a unified wallet seed.
33///
34/// This deriver takes a seed from [`kobe::Wallet`] and derives
35/// Solana addresses following BIP44/SLIP-0010 standards.
36#[derive(Debug)]
37pub struct Deriver<'a> {
38    /// Reference to the wallet for seed access.
39    wallet: &'a Wallet,
40}
41
42impl<'a> Deriver<'a> {
43    /// Create a new Solana deriver from a wallet.
44    #[inline]
45    #[must_use]
46    pub const fn new(wallet: &'a Wallet) -> Self {
47        Self { wallet }
48    }
49
50    /// Derive a Solana address using the Standard derivation style.
51    ///
52    /// Uses path: `m/44'/501'/index'/0'` (Phantom, Backpack, etc.)
53    ///
54    /// # Arguments
55    ///
56    /// * `index` - The address index
57    ///
58    /// # Errors
59    ///
60    /// Returns an error if derivation fails.
61    #[inline]
62    pub fn derive(&self, index: u32) -> Result<DerivedAddress, Error> {
63        self.derive_with(DerivationStyle::Standard, index)
64    }
65
66    /// Derive a Solana address with a specific derivation style.
67    ///
68    /// This method supports different wallet path formats:
69    /// - **Standard** (Phantom/Backpack): `m/44'/501'/index'/0'`
70    /// - **Trust**: `m/44'/501'/index'`
71    /// - **Ledger Live**: `m/44'/501'/index'/0'/0'`
72    /// - **Legacy**: `m/501'/{index}'/0'/0'`
73    ///
74    /// # Arguments
75    ///
76    /// * `style` - The derivation style to use
77    /// * `index` - The address/account index
78    ///
79    /// # Errors
80    ///
81    /// Returns an error if derivation fails.
82    #[allow(deprecated)]
83    pub fn derive_with(&self, style: DerivationStyle, index: u32) -> Result<DerivedAddress, Error> {
84        let path = style.path(index);
85        let derived = DerivedKey::derive_path(self.wallet.seed(), &path)?;
86        Ok(build_derived_address(&derived, path))
87    }
88
89    /// Derive multiple addresses using the Standard derivation style.
90    ///
91    /// # Arguments
92    ///
93    /// * `start` - Starting address index
94    /// * `count` - Number of addresses to derive
95    ///
96    /// # Errors
97    ///
98    /// Returns an error if any derivation fails.
99    #[inline]
100    pub fn derive_many(&self, start: u32, count: u32) -> Result<Vec<DerivedAddress>, Error> {
101        self.derive_many_with(DerivationStyle::Standard, start, count)
102    }
103
104    /// Derive multiple addresses with a specific derivation style.
105    ///
106    /// # Arguments
107    ///
108    /// * `style` - The derivation style to use
109    /// * `start` - Starting index
110    /// * `count` - Number of addresses to derive
111    ///
112    /// # Errors
113    ///
114    /// Returns an error if any derivation fails.
115    pub fn derive_many_with(
116        &self,
117        style: DerivationStyle,
118        start: u32,
119        count: u32,
120    ) -> Result<Vec<DerivedAddress>, Error> {
121        let end = start
122            .checked_add(count)
123            .ok_or(kobe_core::Error::IndexOverflow)?;
124        (start..end)
125            .map(|index| self.derive_with(style, index))
126            .collect()
127    }
128
129    /// Derive an address at a custom derivation path.
130    ///
131    /// This is the lowest-level derivation method, allowing full control
132    /// over the derivation path.
133    ///
134    /// **Note**: Ed25519 (Solana) only supports hardened derivation.
135    /// All path components will be treated as hardened.
136    ///
137    /// # Arguments
138    ///
139    /// * `path` - SLIP-0010 derivation path (e.g., `m/44'/501'/0'/0'`)
140    ///
141    /// # Errors
142    ///
143    /// Returns an error if derivation fails.
144    pub fn derive_path(&self, path: &str) -> Result<DerivedAddress, Error> {
145        let derived = DerivedKey::derive_path(self.wallet.seed(), path)?;
146        Ok(build_derived_address(&derived, path.to_string()))
147    }
148}
149
150impl Derive for Deriver<'_> {
151    type Error = Error;
152
153    fn derive(&self, index: u32) -> Result<DerivedAccount, Error> {
154        let da = self.derive_with(DerivationStyle::Standard, index)?;
155        Ok(DerivedAccount::new(
156            da.path,
157            da.private_key_hex,
158            da.public_key_hex,
159            da.address,
160        ))
161    }
162
163    fn derive_path(&self, path: &str) -> Result<DerivedAccount, Error> {
164        let da = Deriver::derive_path(self, path)?;
165        Ok(DerivedAccount::new(
166            da.path,
167            da.private_key_hex,
168            da.public_key_hex,
169            da.address,
170        ))
171    }
172}
173
174/// Build a [`DerivedAddress`] from a raw [`DerivedKey`] and path string.
175fn build_derived_address(derived: &DerivedKey, path: String) -> DerivedAddress {
176    let signing_key = derived.to_signing_key();
177    let verifying_key: VerifyingKey = signing_key.verifying_key();
178    let public_key_bytes = verifying_key.as_bytes();
179
180    let mut keypair_bytes = Zeroizing::new([0u8; 64]);
181    keypair_bytes[..32].copy_from_slice(derived.private_key.as_slice());
182    keypair_bytes[32..].copy_from_slice(public_key_bytes);
183    let keypair_b58 = bs58::encode(&*keypair_bytes).into_string();
184
185    DerivedAddress {
186        path,
187        private_key_hex: Zeroizing::new(hex::encode(derived.private_key.as_slice())),
188        keypair_base58: Zeroizing::new(keypair_b58),
189        public_key_hex: hex::encode(public_key_bytes),
190        address: bs58::encode(public_key_bytes).into_string(),
191    }
192}
193
#[cfg(test)]
// `deprecated`: retained from the original; presumably some kobe_core API exercised here
// is deprecated (NOTE(review): confirm which). `unwrap_used`: a failed unwrap is a test failure.
#[allow(deprecated, clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Wallet from the BIP-39 "abandon ... about" test mnemonic, no passphrase.
    fn test_wallet() -> Wallet {
        Wallet::from_mnemonic(
            "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about",
            None,
        )
        .unwrap()
    }

    /// Solana addresses (Base58 of 32 bytes) are 32-44 characters long.
    fn assert_plausible_address(address: &str) {
        assert!(
            (32..=44).contains(&address.len()),
            "unexpected Base58 address length {}: {}",
            address.len(),
            address
        );
    }

    #[test]
    fn test_derive_address() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);
        let addr = deriver.derive(0).unwrap();

        assert_plausible_address(&addr.address);
        assert_eq!(addr.path, "m/44'/501'/0'/0'");
    }

    #[test]
    fn test_derive_many() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);
        let addresses = deriver.derive_many(0, 3).unwrap();

        assert_eq!(addresses.len(), 3);
        assert_eq!(addresses[0].path, "m/44'/501'/0'/0'");
        assert_eq!(addresses[1].path, "m/44'/501'/1'/0'");
        assert_eq!(addresses[2].path, "m/44'/501'/2'/0'");

        // All addresses should be pairwise unique.
        assert_ne!(addresses[0].address, addresses[1].address);
        assert_ne!(addresses[0].address, addresses[2].address);
        assert_ne!(addresses[1].address, addresses[2].address);
    }

    #[test]
    fn test_derive_many_matches_single_derivation() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);
        let batch = deriver.derive_many(5, 3).unwrap();

        assert_eq!(batch.len(), 3);
        for (offset, addr) in (0u32..).zip(&batch) {
            assert_eq!(addr.address, deriver.derive(5 + offset).unwrap().address);
        }
    }

    #[test]
    fn test_derive_many_zero_count_is_empty() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        assert!(deriver.derive_many(7, 0).unwrap().is_empty());
    }

    #[test]
    fn test_derive_many_rejects_index_overflow() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        // `start + count` does not fit in u32, so no derivation is attempted.
        assert!(deriver.derive_many(u32::MAX, 1).is_err());
        assert!(deriver
            .derive_many_with(DerivationStyle::Trust, 1, u32::MAX)
            .is_err());
    }

    #[test]
    fn test_deterministic_derivation() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let addr1 = deriver.derive(0).unwrap();
        let addr2 = deriver.derive(0).unwrap();

        assert_eq!(addr1.address, addr2.address);
        assert_eq!(*addr1.private_key_hex, *addr2.private_key_hex);
        assert_eq!(*addr1.keypair_base58, *addr2.keypair_base58);
    }

    #[test]
    fn test_derive_path_matches_derive() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let via_path = deriver.derive_path("m/44'/501'/0'/0'").unwrap();
        let via_index = deriver.derive(0).unwrap();

        assert_eq!(via_path.path, "m/44'/501'/0'/0'");
        assert_eq!(via_path.address, via_index.address);
        assert_eq!(*via_path.private_key_hex, *via_index.private_key_hex);
    }

    #[test]
    fn test_keypair_base58_layout() {
        let wallet = test_wallet();
        let addr = Deriver::new(&wallet).derive(0).unwrap();
        let bytes = bs58::decode(addr.keypair_base58.as_str())
            .into_vec()
            .unwrap();

        // 64 bytes: secret key followed by public key, consistent with the other fields.
        assert_eq!(bytes.len(), 64);
        assert_eq!(hex::encode(&bytes[..32]), *addr.private_key_hex);
        assert_eq!(hex::encode(&bytes[32..]), addr.public_key_hex);
        assert_eq!(bs58::encode(&bytes[32..]).into_string(), addr.address);
    }

    #[test]
    fn test_derive_with_trust() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);
        let addr = deriver.derive_with(DerivationStyle::Trust, 0).unwrap();

        assert_eq!(addr.path, "m/44'/501'/0'");
        assert_plausible_address(&addr.address);
    }

    #[test]
    fn test_derive_with_ledger_live() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);
        let addr = deriver.derive_with(DerivationStyle::LedgerLive, 0).unwrap();

        assert_eq!(addr.path, "m/44'/501'/0'/0'/0'");
        assert_plausible_address(&addr.address);
    }

    #[test]
    fn test_different_styles_produce_different_addresses() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let standard = deriver.derive_with(DerivationStyle::Standard, 0).unwrap();
        let trust = deriver.derive_with(DerivationStyle::Trust, 0).unwrap();
        let ledger_live = deriver.derive_with(DerivationStyle::LedgerLive, 0).unwrap();
        let legacy = deriver.derive_with(DerivationStyle::Legacy, 0).unwrap();

        // Every pair of styles must yield a different address.
        assert_ne!(standard.address, trust.address);
        assert_ne!(standard.address, ledger_live.address);
        assert_ne!(standard.address, legacy.address);
        assert_ne!(trust.address, ledger_live.address);
        assert_ne!(trust.address, legacy.address);
        assert_ne!(ledger_live.address, legacy.address);
    }

    #[test]
    fn kat_solana_standard_index0() {
        // Cross-verified with Python SLIP-10 + nacl.signing + base58
        let wallet = test_wallet();
        let addr = Deriver::new(&wallet).derive(0).unwrap();
        assert_eq!(addr.address, "HAgk14JpMQLgt6rVgv7cBQFJWFto5Dqxi472uT3DKpqk");
    }
}