1#[cfg(feature = "alloc")]
4use alloc::{
5 format,
6 string::{String, ToString},
7 vec::Vec,
8};
9
10use bip32::{DerivationPath, XPrv};
11use k256::ecdsa::SigningKey;
12use kobe_core::Wallet;
13use zeroize::Zeroizing;
14
15use crate::Error;
16use crate::address::{public_key_to_address, to_checksum_address};
17use crate::derivation_style::DerivationStyle;
18
19#[derive(Debug)]
36pub struct Deriver<'a> {
37 wallet: &'a Wallet,
39}
40
41#[derive(Debug, Clone)]
43pub struct DerivedAddress {
44 pub path: String,
46 pub private_key_hex: Zeroizing<String>,
48 pub public_key_hex: String,
50 pub address: String,
52}
53
54impl<'a> Deriver<'a> {
55 #[must_use]
57 pub const fn new(wallet: &'a Wallet) -> Self {
58 Self { wallet }
59 }
60
61 #[inline]
75 pub fn derive(
76 &self,
77 account: u32,
78 change: bool,
79 address_index: u32,
80 ) -> Result<DerivedAddress, Error> {
81 let change_val = i32::from(change);
82 let path = format!("m/44'/60'/{account}'/{change_val}/{address_index}");
83 self.derive_at_path(&path)
84 }
85
86 pub fn derive_at_path(&self, path: &str) -> Result<DerivedAddress, Error> {
92 let private_key = self.derive_key(path)?;
93
94 let public_key = private_key.verifying_key();
95 let public_key_bytes = public_key.to_encoded_point(false);
96 let address = public_key_to_address(public_key_bytes.as_bytes());
97
98 Ok(DerivedAddress {
99 path: path.to_string(),
100 private_key_hex: Zeroizing::new(hex::encode(private_key.to_bytes())),
101 public_key_hex: hex::encode(public_key_bytes.as_bytes()),
102 address: to_checksum_address(&address),
103 })
104 }
105
106 pub fn derive_many(
119 &self,
120 account: u32,
121 change: bool,
122 start_index: u32,
123 count: u32,
124 ) -> Result<Vec<DerivedAddress>, Error> {
125 (start_index..start_index + count)
126 .map(|index| self.derive(account, change, index))
127 .collect()
128 }
129
130 #[inline]
160 pub fn derive_with_style(
161 &self,
162 style: DerivationStyle,
163 index: u32,
164 ) -> Result<DerivedAddress, Error> {
165 self.derive_at_path(&style.path(index))
166 }
167
168 pub fn derive_many_with_style(
180 &self,
181 style: DerivationStyle,
182 start_index: u32,
183 count: u32,
184 ) -> Result<Vec<DerivedAddress>, Error> {
185 (start_index..start_index + count)
186 .map(|index| self.derive_with_style(style, index))
187 .collect()
188 }
189
190 fn derive_key(&self, path: &str) -> Result<SigningKey, Error> {
192 let derivation_path: DerivationPath = path
194 .parse()
195 .map_err(|e| Error::Derivation(format!("invalid derivation path: {e}")))?;
196
197 let derived = XPrv::derive_from_path(self.wallet.seed(), &derivation_path)
199 .map_err(|e| Error::Derivation(format!("key derivation failed: {e}")))?;
200
201 Ok(derived.private_key().clone())
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared mnemonic for every test in this module.
    const TEST_MNEMONIC: &str = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about";

    fn test_wallet() -> Wallet {
        Wallet::from_mnemonic(TEST_MNEMONIC, None).unwrap()
    }

    #[test]
    fn test_derive_address() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);
        let addr = deriver.derive(0, false, 0).unwrap();

        assert!(addr.address.starts_with("0x"));
        assert_eq!(addr.address.len(), 42);
        assert_eq!(addr.path, "m/44'/60'/0'/0/0");
    }

    /// Known-answer test: pins the first standard address of the well-known
    /// test mnemonic, so a broken derivation cannot pass on determinism alone.
    #[test]
    fn test_derive_known_vector() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);
        let addr = deriver.derive(0, false, 0).unwrap();

        assert_eq!(addr.address, "0x9858EfFD232B4033E47d90003D41EC34EcaEda94");
        // 32-byte private key and 65-byte uncompressed SEC1 point, hex-encoded.
        assert_eq!(addr.private_key_hex.len(), 64);
        assert_eq!(addr.public_key_hex.len(), 130);
        assert!(addr.public_key_hex.starts_with("04"));
    }

    #[test]
    fn test_derive_multiple() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);
        let addrs = deriver.derive_many(0, false, 0, 5).unwrap();

        assert_eq!(addrs.len(), 5);

        let unique: alloc::collections::BTreeSet<&str> =
            addrs.iter().map(|a| a.address.as_str()).collect();
        assert_eq!(unique.len(), 5);
    }

    #[test]
    fn test_deterministic_derivation() {
        let wallet1 = Wallet::from_mnemonic(TEST_MNEMONIC, None).unwrap();
        let wallet2 = Wallet::from_mnemonic(TEST_MNEMONIC, None).unwrap();

        let deriver1 = Deriver::new(&wallet1);
        let deriver2 = Deriver::new(&wallet2);

        let addr1 = deriver1.derive(0, false, 0).unwrap();
        let addr2 = deriver2.derive(0, false, 0).unwrap();

        assert_eq!(addr1.address, addr2.address);
    }

    #[test]
    fn test_passphrase_changes_addresses() {
        let wallet1 = Wallet::from_mnemonic(TEST_MNEMONIC, None).unwrap();
        let wallet2 = Wallet::from_mnemonic(TEST_MNEMONIC, Some("password")).unwrap();

        let deriver1 = Deriver::new(&wallet1);
        let deriver2 = Deriver::new(&wallet2);

        let addr1 = deriver1.derive(0, false, 0).unwrap();
        let addr2 = deriver2.derive(0, false, 0).unwrap();

        assert_ne!(addr1.address, addr2.address);
    }

    #[test]
    fn test_derive_with_style_standard() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let addr = deriver
            .derive_with_style(DerivationStyle::Standard, 0)
            .unwrap();
        assert_eq!(addr.path, "m/44'/60'/0'/0/0");

        let addr = deriver
            .derive_with_style(DerivationStyle::Standard, 5)
            .unwrap();
        assert_eq!(addr.path, "m/44'/60'/0'/0/5");
    }

    #[test]
    fn test_derive_with_style_ledger_live() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let addr = deriver
            .derive_with_style(DerivationStyle::LedgerLive, 0)
            .unwrap();
        assert_eq!(addr.path, "m/44'/60'/0'/0/0");

        let addr = deriver
            .derive_with_style(DerivationStyle::LedgerLive, 1)
            .unwrap();
        assert_eq!(addr.path, "m/44'/60'/1'/0/0");
    }

    #[test]
    fn test_derive_with_style_ledger_legacy() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let addr = deriver
            .derive_with_style(DerivationStyle::LedgerLegacy, 0)
            .unwrap();
        assert_eq!(addr.path, "m/44'/60'/0'/0");

        let addr = deriver
            .derive_with_style(DerivationStyle::LedgerLegacy, 3)
            .unwrap();
        assert_eq!(addr.path, "m/44'/60'/0'/3");
    }

    #[test]
    fn test_different_styles_produce_different_addresses() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let standard = deriver
            .derive_with_style(DerivationStyle::Standard, 1)
            .unwrap();
        let ledger_live = deriver
            .derive_with_style(DerivationStyle::LedgerLive, 1)
            .unwrap();
        let ledger_legacy = deriver
            .derive_with_style(DerivationStyle::LedgerLegacy, 1)
            .unwrap();

        assert_ne!(standard.address, ledger_live.address);
        assert_ne!(standard.address, ledger_legacy.address);
        assert_ne!(ledger_live.address, ledger_legacy.address);
    }

    #[test]
    fn test_derive_many_with_style() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let addrs = deriver
            .derive_many_with_style(DerivationStyle::LedgerLive, 0, 3)
            .unwrap();

        assert_eq!(addrs.len(), 3);
        assert_eq!(addrs[0].path, "m/44'/60'/0'/0/0");
        assert_eq!(addrs[1].path, "m/44'/60'/1'/0/0");
        assert_eq!(addrs[2].path, "m/44'/60'/2'/0/0");
    }
}