Skip to main content

kobe_evm/
deriver.rs

1//! Ethereum address derivation from a unified wallet.
2
3#[cfg(feature = "alloc")]
4use alloc::{
5    format,
6    string::{String, ToString},
7    vec::Vec,
8};
9
10use bip32::{DerivationPath, XPrv};
11use k256::ecdsa::SigningKey;
12use kobe::Wallet;
13use zeroize::Zeroizing;
14
15use crate::Error;
16use crate::address::{public_key_to_address, to_checksum_address};
17use crate::derivation_style::DerivationStyle;
18
19/// Ethereum address deriver from a unified wallet seed.
20///
21/// This deriver takes a seed from [`kobe::Wallet`] and derives
22/// Ethereum addresses following BIP32/44 standards.
23#[derive(Debug)]
24pub struct Deriver<'a> {
25    /// Reference to the wallet for seed access.
26    wallet: &'a Wallet,
27}
28
29/// A derived Ethereum address with associated keys.
30#[derive(Debug, Clone)]
31#[non_exhaustive]
32pub struct DerivedAddress {
33    /// Derivation path used (e.g., `m/44'/60'/0'/0/0`).
34    pub path: String,
35    /// Private key in hex format without 0x prefix (zeroized on drop).
36    pub private_key_hex: Zeroizing<String>,
37    /// Public key in uncompressed hex format.
38    pub public_key_hex: String,
39    /// Checksummed Ethereum address (EIP-55).
40    pub address: String,
41}
42
43impl<'a> Deriver<'a> {
44    /// Create a new Ethereum deriver from a wallet.
45    #[must_use]
46    pub const fn new(wallet: &'a Wallet) -> Self {
47        Self { wallet }
48    }
49
50    /// Derive an address using the Standard derivation style.
51    ///
52    /// Uses path: `m/44'/60'/0'/0/{index}` (MetaMask/Trezor compatible)
53    ///
54    /// # Arguments
55    ///
56    /// * `index` - The address index
57    ///
58    /// # Errors
59    ///
60    /// Returns an error if derivation fails.
61    #[inline]
62    pub fn derive(&self, index: u32) -> Result<DerivedAddress, Error> {
63        self.derive_with(DerivationStyle::Standard, index)
64    }
65
66    /// Derive an address using a specific derivation style.
67    ///
68    /// This method supports different hardware/software wallet path formats:
69    /// - **Standard** (MetaMask/Trezor): `m/44'/60'/0'/0/{index}`
70    /// - **Ledger Live**: `m/44'/60'/{index}'/0/0`
71    /// - **Ledger Legacy**: `m/44'/60'/0'/{index}`
72    ///
73    /// # Arguments
74    ///
75    /// * `style` - The derivation style to use
76    /// * `index` - The address/account index
77    ///
78    /// # Errors
79    ///
80    /// Returns an error if derivation fails.
81    #[inline]
82    pub fn derive_with(&self, style: DerivationStyle, index: u32) -> Result<DerivedAddress, Error> {
83        self.derive_path(&style.path(index))
84    }
85
86    /// Derive multiple addresses using the Standard derivation style.
87    ///
88    /// # Arguments
89    ///
90    /// * `start` - Starting address index
91    /// * `count` - Number of addresses to derive
92    ///
93    /// # Errors
94    ///
95    /// Returns an error if any derivation fails.
96    #[inline]
97    pub fn derive_many(&self, start: u32, count: u32) -> Result<Vec<DerivedAddress>, Error> {
98        self.derive_many_with(DerivationStyle::Standard, start, count)
99    }
100
101    /// Derive multiple addresses using a specific derivation style.
102    ///
103    /// # Arguments
104    ///
105    /// * `style` - The derivation style to use
106    /// * `start` - Starting index
107    /// * `count` - Number of addresses to derive
108    ///
109    /// # Errors
110    ///
111    /// Returns an error if any derivation fails.
112    pub fn derive_many_with(
113        &self,
114        style: DerivationStyle,
115        start: u32,
116        count: u32,
117    ) -> Result<Vec<DerivedAddress>, Error> {
118        let end = start.checked_add(count).ok_or_else(|| {
119            Error::Derivation("index overflow: start + count exceeds u32::MAX".into())
120        })?;
121        (start..end)
122            .map(|index| self.derive_with(style, index))
123            .collect()
124    }
125
126    /// Derive an address at a custom derivation path.
127    ///
128    /// This is the lowest-level derivation method, allowing full control
129    /// over the derivation path.
130    ///
131    /// # Arguments
132    ///
133    /// * `path` - BIP-32 derivation path (e.g., `m/44'/60'/0'/0/0`)
134    ///
135    /// # Errors
136    ///
137    /// Returns an error if derivation fails.
138    pub fn derive_path(&self, path: &str) -> Result<DerivedAddress, Error> {
139        let private_key = self.derive_key(path)?;
140
141        let public_key = private_key.verifying_key();
142        let public_key_bytes = public_key.to_encoded_point(false);
143        let address = public_key_to_address(public_key_bytes.as_bytes())?;
144
145        Ok(DerivedAddress {
146            path: path.to_string(),
147            private_key_hex: Zeroizing::new(hex::encode(private_key.to_bytes())),
148            public_key_hex: hex::encode(public_key_bytes.as_bytes()),
149            address: to_checksum_address(&address),
150        })
151    }
152
153    /// Derive a private key at the given path using bip32 crate.
154    fn derive_key(&self, path: &str) -> Result<SigningKey, Error> {
155        // Parse derivation path
156        let derivation_path: DerivationPath = path
157            .parse()
158            .map_err(|e| Error::Derivation(format!("invalid derivation path: {e}")))?;
159
160        // Derive from seed directly using path
161        let derived = XPrv::derive_from_path(self.wallet.seed(), &derivation_path)
162            .map_err(|e| Error::Derivation(format!("key derivation failed: {e}")))?;
163
164        // Get signing key (XPrv wraps k256::ecdsa::SigningKey)
165        Ok(derived.private_key().clone())
166    }
167}
168
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::shadow_unrelated)]
mod tests {
    use super::*;

    const TEST_MNEMONIC: &str = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about";

    /// Widely published reference address for `TEST_MNEMONIC` (empty
    /// passphrase) at `m/44'/60'/0'/0/0`, in EIP-55 checksum form.
    const TEST_ADDRESS_0: &str = "0x9858EfFD232B4033E47d90003D41EC34EcaEda94";

    fn test_wallet() -> Wallet {
        Wallet::from_mnemonic(TEST_MNEMONIC, None).unwrap()
    }

    #[test]
    fn test_derive_address() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);
        let addr = deriver.derive(0).unwrap();

        assert!(addr.address.starts_with("0x"));
        assert_eq!(addr.address.len(), 42);
        assert_eq!(addr.path, "m/44'/60'/0'/0/0");
    }

    #[test]
    fn test_known_vector() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);
        let addr = deriver.derive(0).unwrap();

        // Pins the actual derivation result rather than just its shape. The
        // comparison is case-sensitive, so it also checks EIP-55 casing.
        assert_eq!(addr.address, TEST_ADDRESS_0);
        // 32-byte secret key -> 64 hex chars, without a 0x prefix.
        assert_eq!(addr.private_key_hex.len(), 64);
        assert!(!addr.private_key_hex.starts_with("0x"));
        // 65-byte uncompressed SEC1 point -> 130 hex chars with the 0x04 tag.
        assert_eq!(addr.public_key_hex.len(), 130);
        assert!(addr.public_key_hex.starts_with("04"));
    }

    #[test]
    fn test_derive_multiple() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);
        let addrs = deriver.derive_many(0, 5).unwrap();

        assert_eq!(addrs.len(), 5);

        // All addresses should be unique
        let mut seen = Vec::new();
        for addr in &addrs {
            assert!(!seen.contains(&addr.address));
            seen.push(addr.address.clone());
        }
        assert_eq!(seen.len(), 5);
    }

    #[test]
    fn test_derive_many_zero_count() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        assert!(deriver.derive_many(7, 0).unwrap().is_empty());
    }

    #[test]
    fn test_derive_many_overflow() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        // `start + count` does not fit in a u32, so the range cannot be built.
        let result = deriver.derive_many(u32::MAX, 1);
        assert!(matches!(result, Err(Error::Derivation(_))));
    }

    #[test]
    fn test_deterministic_derivation() {
        let wallet1 = Wallet::from_mnemonic(TEST_MNEMONIC, None).unwrap();
        let wallet2 = Wallet::from_mnemonic(TEST_MNEMONIC, None).unwrap();

        let deriver1 = Deriver::new(&wallet1);
        let deriver2 = Deriver::new(&wallet2);

        let addr1 = deriver1.derive(0).unwrap();
        let addr2 = deriver2.derive(0).unwrap();

        assert_eq!(addr1.address, addr2.address);
    }

    #[test]
    fn test_passphrase_changes_addresses() {
        let wallet1 = Wallet::from_mnemonic(TEST_MNEMONIC, None).unwrap();
        let wallet2 = Wallet::from_mnemonic(TEST_MNEMONIC, Some("password")).unwrap();

        let deriver1 = Deriver::new(&wallet1);
        let deriver2 = Deriver::new(&wallet2);

        let addr1 = deriver1.derive(0).unwrap();
        let addr2 = deriver2.derive(0).unwrap();

        // Same mnemonic with different passphrase should produce different addresses
        assert_ne!(addr1.address, addr2.address);
    }

    #[test]
    fn test_derive_with_standard() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let addr = deriver.derive_with(DerivationStyle::Standard, 0).unwrap();
        assert_eq!(addr.path, "m/44'/60'/0'/0/0");

        let addr = deriver.derive_with(DerivationStyle::Standard, 5).unwrap();
        assert_eq!(addr.path, "m/44'/60'/0'/0/5");
    }

    #[test]
    fn test_derive_with_ledger_live() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let addr = deriver.derive_with(DerivationStyle::LedgerLive, 0).unwrap();
        assert_eq!(addr.path, "m/44'/60'/0'/0/0");

        let addr = deriver.derive_with(DerivationStyle::LedgerLive, 1).unwrap();
        assert_eq!(addr.path, "m/44'/60'/1'/0/0");
    }

    #[test]
    fn test_ledger_live_index_zero_matches_standard() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        // Both styles resolve index 0 to `m/44'/60'/0'/0/0`, so the keys must agree.
        let standard = deriver.derive_with(DerivationStyle::Standard, 0).unwrap();
        let ledger_live = deriver.derive_with(DerivationStyle::LedgerLive, 0).unwrap();
        assert_eq!(standard.address, ledger_live.address);
    }

    #[test]
    fn test_derive_with_ledger_legacy() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let addr = deriver
            .derive_with(DerivationStyle::LedgerLegacy, 0)
            .unwrap();
        assert_eq!(addr.path, "m/44'/60'/0'/0");

        let addr = deriver
            .derive_with(DerivationStyle::LedgerLegacy, 3)
            .unwrap();
        assert_eq!(addr.path, "m/44'/60'/0'/3");
    }

    #[test]
    fn test_different_styles_produce_different_addresses() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let standard = deriver.derive_with(DerivationStyle::Standard, 1).unwrap();
        let ledger_live = deriver.derive_with(DerivationStyle::LedgerLive, 1).unwrap();
        let ledger_legacy = deriver
            .derive_with(DerivationStyle::LedgerLegacy, 1)
            .unwrap();

        // Different paths should produce different addresses
        assert_ne!(standard.address, ledger_live.address);
        assert_ne!(standard.address, ledger_legacy.address);
        assert_ne!(ledger_live.address, ledger_legacy.address);
    }

    #[test]
    fn test_derive_many_with() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let addrs = deriver
            .derive_many_with(DerivationStyle::LedgerLive, 0, 3)
            .unwrap();

        assert_eq!(addrs.len(), 3);
        assert_eq!(addrs[0].path, "m/44'/60'/0'/0/0");
        assert_eq!(addrs[1].path, "m/44'/60'/1'/0/0");
        assert_eq!(addrs[2].path, "m/44'/60'/2'/0/0");
    }

    #[test]
    fn test_derive_path() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let addr = deriver.derive_path("m/44'/60'/0'/0/0").unwrap();
        assert_eq!(addr.path, "m/44'/60'/0'/0/0");
        assert!(addr.address.starts_with("0x"));

        // The explicit path must agree with the Standard-style shortcut.
        let via_index = deriver.derive(0).unwrap();
        assert_eq!(addr.address, via_index.address);
    }

    #[test]
    fn test_derive_path_rejects_invalid_path() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let result = deriver.derive_path("not/a/path");
        assert!(matches!(result, Err(Error::Derivation(_))));
    }
}