Skip to main content

kobe_evm/
deriver.rs

1//! Ethereum address derivation from a unified wallet.
2
3#[cfg(feature = "alloc")]
4use alloc::{
5    format,
6    string::{String, ToString},
7    vec::Vec,
8};
9
10use bip32::{DerivationPath, XPrv};
11use k256::ecdsa::SigningKey;
12use kobe::Wallet;
13use zeroize::Zeroizing;
14
15use crate::Error;
16use crate::address::{public_key_to_address, to_checksum_address};
17use crate::derivation_style::DerivationStyle;
18
19/// Ethereum address deriver from a unified wallet seed.
20///
21/// This deriver takes a seed from [`kobe::Wallet`] and derives
22/// Ethereum addresses following BIP32/44 standards.
23#[derive(Debug)]
24pub struct Deriver<'a> {
25    /// Reference to the wallet for seed access.
26    wallet: &'a Wallet,
27}
28
29/// A derived Ethereum address with associated keys.
30#[derive(Debug, Clone)]
31#[non_exhaustive]
32pub struct DerivedAddress {
33    /// Derivation path used (e.g., `m/44'/60'/0'/0/0`).
34    pub path: String,
35    /// Private key in hex format without 0x prefix (zeroized on drop).
36    pub private_key_hex: Zeroizing<String>,
37    /// Public key in uncompressed hex format.
38    pub public_key_hex: String,
39    /// Checksummed Ethereum address (EIP-55).
40    pub address: String,
41}
42
43impl<'a> Deriver<'a> {
44    /// Create a new Ethereum deriver from a wallet.
45    #[must_use]
46    pub const fn new(wallet: &'a Wallet) -> Self {
47        Self { wallet }
48    }
49
50    /// Derive an address using the Standard derivation style.
51    ///
52    /// Uses path: `m/44'/60'/0'/0/{index}` (MetaMask/Trezor compatible)
53    ///
54    /// # Arguments
55    ///
56    /// * `index` - The address index
57    ///
58    /// # Errors
59    ///
60    /// Returns an error if derivation fails.
61    #[inline]
62    pub fn derive(&self, index: u32) -> Result<DerivedAddress, Error> {
63        self.derive_with(DerivationStyle::Standard, index)
64    }
65
66    /// Derive an address using a specific derivation style.
67    ///
68    /// This method supports different hardware/software wallet path formats:
69    /// - **Standard** (MetaMask/Trezor): `m/44'/60'/0'/0/{index}`
70    /// - **Ledger Live**: `m/44'/60'/{index}'/0/0`
71    /// - **Ledger Legacy**: `m/44'/60'/0'/{index}`
72    ///
73    /// # Arguments
74    ///
75    /// * `style` - The derivation style to use
76    /// * `index` - The address/account index
77    ///
78    /// # Errors
79    ///
80    /// Returns an error if derivation fails.
81    #[inline]
82    pub fn derive_with(&self, style: DerivationStyle, index: u32) -> Result<DerivedAddress, Error> {
83        self.derive_path(&style.path(index))
84    }
85
86    /// Derive multiple addresses using the Standard derivation style.
87    ///
88    /// # Arguments
89    ///
90    /// * `start` - Starting address index
91    /// * `count` - Number of addresses to derive
92    ///
93    /// # Errors
94    ///
95    /// Returns an error if any derivation fails.
96    #[inline]
97    pub fn derive_many(&self, start: u32, count: u32) -> Result<Vec<DerivedAddress>, Error> {
98        self.derive_many_with(DerivationStyle::Standard, start, count)
99    }
100
101    /// Derive multiple addresses using a specific derivation style.
102    ///
103    /// # Arguments
104    ///
105    /// * `style` - The derivation style to use
106    /// * `start` - Starting index
107    /// * `count` - Number of addresses to derive
108    ///
109    /// # Errors
110    ///
111    /// Returns an error if any derivation fails.
112    pub fn derive_many_with(
113        &self,
114        style: DerivationStyle,
115        start: u32,
116        count: u32,
117    ) -> Result<Vec<DerivedAddress>, Error> {
118        let end = start.checked_add(count).ok_or_else(|| {
119            Error::Derivation("index overflow: start + count exceeds u32::MAX".into())
120        })?;
121        (start..end)
122            .map(|index| self.derive_with(style, index))
123            .collect()
124    }
125
126    /// Derive an address at a custom derivation path.
127    ///
128    /// This is the lowest-level derivation method, allowing full control
129    /// over the derivation path.
130    ///
131    /// # Arguments
132    ///
133    /// * `path` - BIP-32 derivation path (e.g., `m/44'/60'/0'/0/0`)
134    ///
135    /// # Errors
136    ///
137    /// Returns an error if derivation fails.
138    pub fn derive_path(&self, path: &str) -> Result<DerivedAddress, Error> {
139        let private_key = self.derive_key(path)?;
140
141        let public_key = private_key.verifying_key();
142        let public_key_bytes = public_key.to_encoded_point(false);
143        let address = public_key_to_address(public_key_bytes.as_bytes())?;
144
145        Ok(DerivedAddress {
146            path: path.to_string(),
147            private_key_hex: Zeroizing::new(hex::encode(private_key.to_bytes())),
148            public_key_hex: hex::encode(public_key_bytes.as_bytes()),
149            address: to_checksum_address(&address),
150        })
151    }
152
153    /// Derive a private key at the given path using bip32 crate.
154    fn derive_key(&self, path: &str) -> Result<SigningKey, Error> {
155        // Parse derivation path
156        let derivation_path: DerivationPath = path
157            .parse()
158            .map_err(|e| Error::Derivation(format!("invalid derivation path: {e}")))?;
159
160        // Derive from seed directly using path
161        let derived = XPrv::derive_from_path(self.wallet.seed(), &derivation_path)
162            .map_err(|e| Error::Derivation(format!("key derivation failed: {e}")))?;
163
164        // Get signing key (XPrv wraps k256::ecdsa::SigningKey)
165        Ok(derived.private_key().clone())
166    }
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    const TEST_MNEMONIC: &str = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about";

    fn test_wallet() -> Wallet {
        Wallet::from_mnemonic(TEST_MNEMONIC, None).unwrap()
    }

    #[test]
    fn test_derive_address() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);
        let addr = deriver.derive(0).unwrap();

        assert!(addr.address.starts_with("0x"));
        assert_eq!(addr.address.len(), 42);
        assert_eq!(addr.path, "m/44'/60'/0'/0/0");
    }

    #[test]
    fn test_known_vector_standard_index_0() {
        // Widely published BIP-44 Ethereum vector for the "abandon ... about"
        // mnemonic with no passphrase at m/44'/60'/0'/0/0 (EIP-55 checksummed).
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);
        let addr = deriver.derive(0).unwrap();

        assert_eq!(addr.address, "0x9858EfFD232B4033E47d90003D41EC34EcaEda94");
    }

    #[test]
    fn test_key_material_shape() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);
        let addr = deriver.derive(0).unwrap();

        // 32-byte private key, hex without 0x prefix.
        assert_eq!(addr.private_key_hex.len(), 64);
        assert!(!addr.private_key_hex.starts_with("0x"));
        // 65-byte uncompressed SEC1 point, which always starts with 0x04.
        assert_eq!(addr.public_key_hex.len(), 130);
        assert!(addr.public_key_hex.starts_with("04"));
    }

    #[test]
    fn test_derive_multiple() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);
        let addrs = deriver.derive_many(0, 5).unwrap();

        assert_eq!(addrs.len(), 5);

        // All addresses should be unique.
        let mut seen: Vec<&str> = addrs.iter().map(|a| a.address.as_str()).collect();
        seen.sort_unstable();
        seen.dedup();
        assert_eq!(seen.len(), 5);
    }

    #[test]
    fn test_derive_many_matches_single_derivation() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);
        let addrs = deriver.derive_many(2, 3).unwrap();

        for (offset, addr) in (0u32..).zip(&addrs) {
            let single = deriver.derive(2 + offset).unwrap();
            assert_eq!(addr.path, single.path);
            assert_eq!(addr.address, single.address);
        }
    }

    #[test]
    fn test_derive_many_empty_range() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        assert!(deriver.derive_many(0, 0).unwrap().is_empty());
    }

    #[test]
    fn test_derive_many_overflow_is_error() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        assert!(deriver.derive_many(u32::MAX, 1).is_err());
        assert!(
            deriver
                .derive_many_with(DerivationStyle::LedgerLive, u32::MAX, 2)
                .is_err()
        );
    }

    #[test]
    fn test_invalid_path_is_error() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        assert!(deriver.derive_path("not/a/path").is_err());
        assert!(deriver.derive_path("").is_err());
    }

    #[test]
    fn test_deterministic_derivation() {
        let wallet1 = Wallet::from_mnemonic(TEST_MNEMONIC, None).unwrap();
        let wallet2 = Wallet::from_mnemonic(TEST_MNEMONIC, None).unwrap();

        let deriver1 = Deriver::new(&wallet1);
        let deriver2 = Deriver::new(&wallet2);

        let addr1 = deriver1.derive(0).unwrap();
        let addr2 = deriver2.derive(0).unwrap();

        assert_eq!(addr1.address, addr2.address);
    }

    #[test]
    fn test_passphrase_changes_addresses() {
        let wallet1 = Wallet::from_mnemonic(TEST_MNEMONIC, None).unwrap();
        let wallet2 = Wallet::from_mnemonic(TEST_MNEMONIC, Some("password")).unwrap();

        let deriver1 = Deriver::new(&wallet1);
        let deriver2 = Deriver::new(&wallet2);

        let addr1 = deriver1.derive(0).unwrap();
        let addr2 = deriver2.derive(0).unwrap();

        // Same mnemonic with different passphrase should produce different addresses
        assert_ne!(addr1.address, addr2.address);
    }

    #[test]
    fn test_derive_with_standard() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let addr = deriver.derive_with(DerivationStyle::Standard, 0).unwrap();
        assert_eq!(addr.path, "m/44'/60'/0'/0/0");

        let addr = deriver.derive_with(DerivationStyle::Standard, 5).unwrap();
        assert_eq!(addr.path, "m/44'/60'/0'/0/5");
    }

    #[test]
    fn test_derive_with_ledger_live() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let addr = deriver.derive_with(DerivationStyle::LedgerLive, 0).unwrap();
        assert_eq!(addr.path, "m/44'/60'/0'/0/0");

        let addr = deriver.derive_with(DerivationStyle::LedgerLive, 1).unwrap();
        assert_eq!(addr.path, "m/44'/60'/1'/0/0");
    }

    #[test]
    fn test_derive_with_ledger_legacy() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let addr = deriver
            .derive_with(DerivationStyle::LedgerLegacy, 0)
            .unwrap();
        assert_eq!(addr.path, "m/44'/60'/0'/0");

        let addr = deriver
            .derive_with(DerivationStyle::LedgerLegacy, 3)
            .unwrap();
        assert_eq!(addr.path, "m/44'/60'/0'/3");
    }

    #[test]
    fn test_different_styles_produce_different_addresses() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let standard = deriver.derive_with(DerivationStyle::Standard, 1).unwrap();
        let ledger_live = deriver.derive_with(DerivationStyle::LedgerLive, 1).unwrap();
        let ledger_legacy = deriver
            .derive_with(DerivationStyle::LedgerLegacy, 1)
            .unwrap();

        // Different paths should produce different addresses
        assert_ne!(standard.address, ledger_live.address);
        assert_ne!(standard.address, ledger_legacy.address);
        assert_ne!(ledger_live.address, ledger_legacy.address);
    }

    #[test]
    fn test_derive_many_with() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let addrs = deriver
            .derive_many_with(DerivationStyle::LedgerLive, 0, 3)
            .unwrap();

        assert_eq!(addrs.len(), 3);
        assert_eq!(addrs[0].path, "m/44'/60'/0'/0/0");
        assert_eq!(addrs[1].path, "m/44'/60'/1'/0/0");
        assert_eq!(addrs[2].path, "m/44'/60'/2'/0/0");
    }

    #[test]
    fn test_derive_path() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let addr = deriver.derive_path("m/44'/60'/0'/0/0").unwrap();
        assert_eq!(addr.path, "m/44'/60'/0'/0/0");
        assert!(addr.address.starts_with("0x"));
    }
}