Skip to main content

kobe_evm/
deriver.rs

//! Ethereum address derivation from a unified wallet.
2
3#[cfg(feature = "alloc")]
4use alloc::{
5    format,
6    string::{String, ToString},
7    vec::Vec,
8};
9
10use bip32::{DerivationPath, XPrv};
11use k256::ecdsa::SigningKey;
12use kobe::Wallet;
13use zeroize::Zeroizing;
14
15use crate::Error;
16use crate::address::{public_key_to_address, to_checksum_address};
17use crate::derivation_style::DerivationStyle;
18
19/// Ethereum address deriver from a unified wallet seed.
20///
21/// This deriver takes a seed from [`kobe::Wallet`] and derives
22/// Ethereum addresses following BIP32/44 standards.
23#[derive(Debug)]
24pub struct Deriver<'a> {
25    /// Reference to the wallet for seed access.
26    wallet: &'a Wallet,
27}
28
29/// A derived Ethereum address with associated keys.
30#[derive(Debug, Clone)]
31pub struct DerivedAddress {
32    /// Derivation path used (e.g., `m/44'/60'/0'/0/0`).
33    pub path: String,
34    /// Private key in hex format without 0x prefix (zeroized on drop).
35    pub private_key_hex: Zeroizing<String>,
36    /// Public key in uncompressed hex format.
37    pub public_key_hex: String,
38    /// Checksummed Ethereum address (EIP-55)..
39    pub address: String,
40}
41
42impl<'a> Deriver<'a> {
43    /// Create a new Ethereum deriver from a wallet.
44    #[must_use]
45    pub const fn new(wallet: &'a Wallet) -> Self {
46        Self { wallet }
47    }
48
49    /// Derive an address using the Standard derivation style.
50    ///
51    /// Uses path: `m/44'/60'/0'/0/{index}` (MetaMask/Trezor compatible)
52    ///
53    /// # Arguments
54    ///
55    /// * `index` - The address index
56    ///
57    /// # Errors
58    ///
59    /// Returns an error if derivation fails.
60    #[inline]
61    pub fn derive(&self, index: u32) -> Result<DerivedAddress, Error> {
62        self.derive_with(DerivationStyle::Standard, index)
63    }
64
65    /// Derive an address using a specific derivation style.
66    ///
67    /// This method supports different hardware/software wallet path formats:
68    /// - **Standard** (MetaMask/Trezor): `m/44'/60'/0'/0/{index}`
69    /// - **Ledger Live**: `m/44'/60'/{index}'/0/0`
70    /// - **Ledger Legacy**: `m/44'/60'/0'/{index}`
71    ///
72    /// # Arguments
73    ///
74    /// * `style` - The derivation style to use
75    /// * `index` - The address/account index
76    ///
77    /// # Errors
78    ///
79    /// Returns an error if derivation fails.
80    #[inline]
81    pub fn derive_with(&self, style: DerivationStyle, index: u32) -> Result<DerivedAddress, Error> {
82        self.derive_path(&style.path(index))
83    }
84
85    /// Derive multiple addresses using the Standard derivation style.
86    ///
87    /// # Arguments
88    ///
89    /// * `start` - Starting address index
90    /// * `count` - Number of addresses to derive
91    ///
92    /// # Errors
93    ///
94    /// Returns an error if any derivation fails.
95    #[inline]
96    pub fn derive_many(&self, start: u32, count: u32) -> Result<Vec<DerivedAddress>, Error> {
97        self.derive_many_with(DerivationStyle::Standard, start, count)
98    }
99
100    /// Derive multiple addresses using a specific derivation style.
101    ///
102    /// # Arguments
103    ///
104    /// * `style` - The derivation style to use
105    /// * `start` - Starting index
106    /// * `count` - Number of addresses to derive
107    ///
108    /// # Errors
109    ///
110    /// Returns an error if any derivation fails.
111    pub fn derive_many_with(
112        &self,
113        style: DerivationStyle,
114        start: u32,
115        count: u32,
116    ) -> Result<Vec<DerivedAddress>, Error> {
117        (start..start + count)
118            .map(|index| self.derive_with(style, index))
119            .collect()
120    }
121
122    /// Derive an address at a custom derivation path.
123    ///
124    /// This is the lowest-level derivation method, allowing full control
125    /// over the derivation path.
126    ///
127    /// # Arguments
128    ///
129    /// * `path` - BIP-32 derivation path (e.g., `m/44'/60'/0'/0/0`)
130    ///
131    /// # Errors
132    ///
133    /// Returns an error if derivation fails.
134    pub fn derive_path(&self, path: &str) -> Result<DerivedAddress, Error> {
135        let private_key = self.derive_key(path)?;
136
137        let public_key = private_key.verifying_key();
138        let public_key_bytes = public_key.to_encoded_point(false);
139        let address = public_key_to_address(public_key_bytes.as_bytes());
140
141        Ok(DerivedAddress {
142            path: path.to_string(),
143            private_key_hex: Zeroizing::new(hex::encode(private_key.to_bytes())),
144            public_key_hex: hex::encode(public_key_bytes.as_bytes()),
145            address: to_checksum_address(&address),
146        })
147    }
148
149    /// Derive a private key at the given path using bip32 crate.
150    fn derive_key(&self, path: &str) -> Result<SigningKey, Error> {
151        // Parse derivation path
152        let derivation_path: DerivationPath = path
153            .parse()
154            .map_err(|e| Error::Derivation(format!("invalid derivation path: {e}")))?;
155
156        // Derive from seed directly using path
157        let derived = XPrv::derive_from_path(self.wallet.seed(), &derivation_path)
158            .map_err(|e| Error::Derivation(format!("key derivation failed: {e}")))?;
159
160        // Get signing key (XPrv wraps k256::ecdsa::SigningKey)
161        Ok(derived.private_key().clone())
162    }
163}
164
#[cfg(test)]
mod tests {
    use super::*;

    const TEST_MNEMONIC: &str = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about";

    fn test_wallet() -> Wallet {
        Wallet::from_mnemonic(TEST_MNEMONIC, None).unwrap()
    }

    #[test]
    fn test_derive_address() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);
        let addr = deriver.derive(0).unwrap();

        assert!(addr.address.starts_with("0x"));
        assert_eq!(addr.address.len(), 42);
        assert_eq!(addr.path, "m/44'/60'/0'/0/0");
    }

    #[test]
    fn test_known_vector_standard_index_0() {
        // Widely published vector for the "abandon ... about" mnemonic with an
        // empty passphrase at m/44'/60'/0'/0/0, in EIP-55 checksum form.
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);
        let addr = deriver.derive(0).unwrap();

        assert_eq!(addr.address, "0x9858EfFD232B4033E47d90003D41EC34EcaEda94");
    }

    #[test]
    fn test_key_encodings() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);
        let addr = deriver.derive(0).unwrap();

        // 32-byte secret key as hex, no 0x prefix.
        assert_eq!(addr.private_key_hex.len(), 64);
        assert!(!addr.private_key_hex.starts_with("0x"));
        // 65-byte uncompressed SEC1 point (0x04 || X || Y) as hex.
        assert_eq!(addr.public_key_hex.len(), 130);
        assert!(addr.public_key_hex.starts_with("04"));
    }

    #[test]
    fn test_derive_multiple() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);
        let addrs = deriver.derive_many(0, 5).unwrap();

        assert_eq!(addrs.len(), 5);

        // All addresses should be unique
        let mut seen = Vec::new();
        for addr in &addrs {
            assert!(!seen.contains(&addr.address));
            seen.push(addr.address.clone());
        }
        assert_eq!(seen.len(), 5);
    }

    #[test]
    fn test_deterministic_derivation() {
        let wallet1 = Wallet::from_mnemonic(TEST_MNEMONIC, None).unwrap();
        let wallet2 = Wallet::from_mnemonic(TEST_MNEMONIC, None).unwrap();

        let deriver1 = Deriver::new(&wallet1);
        let deriver2 = Deriver::new(&wallet2);

        let addr1 = deriver1.derive(0).unwrap();
        let addr2 = deriver2.derive(0).unwrap();

        assert_eq!(addr1.address, addr2.address);
    }

    #[test]
    fn test_passphrase_changes_addresses() {
        let wallet1 = Wallet::from_mnemonic(TEST_MNEMONIC, None).unwrap();
        let wallet2 = Wallet::from_mnemonic(TEST_MNEMONIC, Some("password")).unwrap();

        let deriver1 = Deriver::new(&wallet1);
        let deriver2 = Deriver::new(&wallet2);

        let addr1 = deriver1.derive(0).unwrap();
        let addr2 = deriver2.derive(0).unwrap();

        // Same mnemonic with different passphrase should produce different addresses
        assert_ne!(addr1.address, addr2.address);
    }

    #[test]
    fn test_derive_with_standard() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let addr = deriver.derive_with(DerivationStyle::Standard, 0).unwrap();
        assert_eq!(addr.path, "m/44'/60'/0'/0/0");

        let addr = deriver.derive_with(DerivationStyle::Standard, 5).unwrap();
        assert_eq!(addr.path, "m/44'/60'/0'/0/5");
    }

    #[test]
    fn test_derive_with_ledger_live() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let addr = deriver.derive_with(DerivationStyle::LedgerLive, 0).unwrap();
        assert_eq!(addr.path, "m/44'/60'/0'/0/0");

        let addr = deriver.derive_with(DerivationStyle::LedgerLive, 1).unwrap();
        assert_eq!(addr.path, "m/44'/60'/1'/0/0");
    }

    #[test]
    fn test_standard_and_ledger_live_agree_at_index_0() {
        // Both styles resolve index 0 to the same path, so the keys must match.
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let standard = deriver.derive_with(DerivationStyle::Standard, 0).unwrap();
        let ledger_live = deriver.derive_with(DerivationStyle::LedgerLive, 0).unwrap();

        assert_eq!(standard.address, ledger_live.address);
        assert_eq!(standard.public_key_hex, ledger_live.public_key_hex);
    }

    #[test]
    fn test_derive_with_ledger_legacy() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let addr = deriver
            .derive_with(DerivationStyle::LedgerLegacy, 0)
            .unwrap();
        assert_eq!(addr.path, "m/44'/60'/0'/0");

        let addr = deriver
            .derive_with(DerivationStyle::LedgerLegacy, 3)
            .unwrap();
        assert_eq!(addr.path, "m/44'/60'/0'/3");
    }

    #[test]
    fn test_different_styles_produce_different_addresses() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let standard = deriver.derive_with(DerivationStyle::Standard, 1).unwrap();
        let ledger_live = deriver.derive_with(DerivationStyle::LedgerLive, 1).unwrap();
        let ledger_legacy = deriver
            .derive_with(DerivationStyle::LedgerLegacy, 1)
            .unwrap();

        // Different paths should produce different addresses
        assert_ne!(standard.address, ledger_live.address);
        assert_ne!(standard.address, ledger_legacy.address);
        assert_ne!(ledger_live.address, ledger_legacy.address);
    }

    #[test]
    fn test_derive_many_with() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let addrs = deriver
            .derive_many_with(DerivationStyle::LedgerLive, 0, 3)
            .unwrap();

        assert_eq!(addrs.len(), 3);
        assert_eq!(addrs[0].path, "m/44'/60'/0'/0/0");
        assert_eq!(addrs[1].path, "m/44'/60'/1'/0/0");
        assert_eq!(addrs[2].path, "m/44'/60'/2'/0/0");
    }

    #[test]
    fn test_derive_path() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let addr = deriver.derive_path("m/44'/60'/0'/0/0").unwrap();
        assert_eq!(addr.path, "m/44'/60'/0'/0/0");
        assert!(addr.address.starts_with("0x"));
    }

    #[test]
    fn test_derive_path_rejects_invalid_path() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        assert!(deriver.derive_path("not a derivation path").is_err());
    }
}