1#[cfg(feature = "alloc")]
4use alloc::{
5 format,
6 string::{String, ToString},
7 vec::Vec,
8};
9
10use bip32::{DerivationPath, XPrv};
11use k256::ecdsa::SigningKey;
12use kobe::Wallet;
13use zeroize::Zeroizing;
14
15use crate::Error;
16use crate::address::{public_key_to_address, to_checksum_address};
17use crate::derivation_style::DerivationStyle;
18
19#[derive(Debug)]
36pub struct Deriver<'a> {
37 wallet: &'a Wallet,
39}
40
41#[derive(Debug, Clone)]
43pub struct DerivedAddress {
44 pub path: String,
46 pub private_key_hex: Zeroizing<String>,
48 pub public_key_hex: String,
50 pub address: String,
52}
53
54impl<'a> Deriver<'a> {
55 #[must_use]
57 pub const fn new(wallet: &'a Wallet) -> Self {
58 Self { wallet }
59 }
60
61 #[inline]
73 pub fn derive(&self, index: u32) -> Result<DerivedAddress, Error> {
74 self.derive_with(DerivationStyle::Standard, index)
75 }
76
77 #[inline]
107 pub fn derive_with(&self, style: DerivationStyle, index: u32) -> Result<DerivedAddress, Error> {
108 self.derive_path(&style.path(index))
109 }
110
111 #[inline]
122 pub fn derive_many(&self, start: u32, count: u32) -> Result<Vec<DerivedAddress>, Error> {
123 self.derive_many_with(DerivationStyle::Standard, start, count)
124 }
125
126 pub fn derive_many_with(
138 &self,
139 style: DerivationStyle,
140 start: u32,
141 count: u32,
142 ) -> Result<Vec<DerivedAddress>, Error> {
143 (start..start + count)
144 .map(|index| self.derive_with(style, index))
145 .collect()
146 }
147
148 pub fn derive_path(&self, path: &str) -> Result<DerivedAddress, Error> {
161 let private_key = self.derive_key(path)?;
162
163 let public_key = private_key.verifying_key();
164 let public_key_bytes = public_key.to_encoded_point(false);
165 let address = public_key_to_address(public_key_bytes.as_bytes());
166
167 Ok(DerivedAddress {
168 path: path.to_string(),
169 private_key_hex: Zeroizing::new(hex::encode(private_key.to_bytes())),
170 public_key_hex: hex::encode(public_key_bytes.as_bytes()),
171 address: to_checksum_address(&address),
172 })
173 }
174
175 fn derive_key(&self, path: &str) -> Result<SigningKey, Error> {
177 let derivation_path: DerivationPath = path
179 .parse()
180 .map_err(|e| Error::Derivation(format!("invalid derivation path: {e}")))?;
181
182 let derived = XPrv::derive_from_path(self.wallet.seed(), &derivation_path)
184 .map_err(|e| Error::Derivation(format!("key derivation failed: {e}")))?;
185
186 Ok(derived.private_key().clone())
188 }
189}
190
#[cfg(test)]
mod tests {
    use super::*;

    const TEST_MNEMONIC: &str = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about";

    /// Widely published address for `TEST_MNEMONIC` (no passphrase) at
    /// `m/44'/60'/0'/0/0`. It pins the full seed -> key -> address ->
    /// checksum pipeline, not just the output's shape.
    const TEST_ADDRESS_0: &str = "0x9858EfFD232B4033E47d90003D41EC34EcaEda94";

    fn test_wallet() -> Wallet {
        Wallet::from_mnemonic(TEST_MNEMONIC, None).unwrap()
    }

    #[test]
    fn test_derive_address() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);
        let addr = deriver.derive(0).unwrap();

        assert!(addr.address.starts_with("0x"));
        assert_eq!(addr.address.len(), 42);
        assert_eq!(addr.path, "m/44'/60'/0'/0/0");
    }

    #[test]
    fn test_known_answer_vector() {
        let wallet = test_wallet();
        let addr = Deriver::new(&wallet).derive(0).unwrap();

        assert_eq!(addr.address, TEST_ADDRESS_0);
        // 32-byte secret scalar and 65-byte uncompressed SEC1 point, hex-encoded.
        assert_eq!(addr.private_key_hex.len(), 64);
        assert_eq!(addr.public_key_hex.len(), 130);
        assert!(addr.public_key_hex.starts_with("04"));
    }

    #[test]
    fn test_derive_path_matches_standard_derive() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let via_index = deriver.derive(0).unwrap();
        let via_path = deriver.derive_path("m/44'/60'/0'/0/0").unwrap();

        assert_eq!(via_index.address, via_path.address);
        assert_eq!(via_index.public_key_hex, via_path.public_key_hex);
    }

    #[test]
    fn test_derive_multiple() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);
        let addrs = deriver.derive_many(0, 5).unwrap();

        assert_eq!(addrs.len(), 5);

        let mut seen = Vec::new();
        for addr in &addrs {
            assert!(!seen.contains(&addr.address));
            seen.push(addr.address.clone());
        }
        assert_eq!(seen.len(), 5);
    }

    #[test]
    fn test_deterministic_derivation() {
        let wallet1 = Wallet::from_mnemonic(TEST_MNEMONIC, None).unwrap();
        let wallet2 = Wallet::from_mnemonic(TEST_MNEMONIC, None).unwrap();

        let deriver1 = Deriver::new(&wallet1);
        let deriver2 = Deriver::new(&wallet2);

        let addr1 = deriver1.derive(0).unwrap();
        let addr2 = deriver2.derive(0).unwrap();

        assert_eq!(addr1.address, addr2.address);
    }

    #[test]
    fn test_passphrase_changes_addresses() {
        let wallet1 = Wallet::from_mnemonic(TEST_MNEMONIC, None).unwrap();
        let wallet2 = Wallet::from_mnemonic(TEST_MNEMONIC, Some("password")).unwrap();

        let deriver1 = Deriver::new(&wallet1);
        let deriver2 = Deriver::new(&wallet2);

        let addr1 = deriver1.derive(0).unwrap();
        let addr2 = deriver2.derive(0).unwrap();

        assert_ne!(addr1.address, addr2.address);
    }

    #[test]
    fn test_derive_with_standard() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let addr = deriver.derive_with(DerivationStyle::Standard, 0).unwrap();
        assert_eq!(addr.path, "m/44'/60'/0'/0/0");

        let addr = deriver.derive_with(DerivationStyle::Standard, 5).unwrap();
        assert_eq!(addr.path, "m/44'/60'/0'/0/5");
    }

    #[test]
    fn test_derive_with_ledger_live() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let addr = deriver.derive_with(DerivationStyle::LedgerLive, 0).unwrap();
        assert_eq!(addr.path, "m/44'/60'/0'/0/0");

        let addr = deriver.derive_with(DerivationStyle::LedgerLive, 1).unwrap();
        assert_eq!(addr.path, "m/44'/60'/1'/0/0");
    }

    #[test]
    fn test_derive_with_ledger_legacy() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let addr = deriver
            .derive_with(DerivationStyle::LedgerLegacy, 0)
            .unwrap();
        assert_eq!(addr.path, "m/44'/60'/0'/0");

        let addr = deriver
            .derive_with(DerivationStyle::LedgerLegacy, 3)
            .unwrap();
        assert_eq!(addr.path, "m/44'/60'/0'/3");
    }

    #[test]
    fn test_different_styles_produce_different_addresses() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let standard = deriver.derive_with(DerivationStyle::Standard, 1).unwrap();
        let ledger_live = deriver.derive_with(DerivationStyle::LedgerLive, 1).unwrap();
        let ledger_legacy = deriver
            .derive_with(DerivationStyle::LedgerLegacy, 1)
            .unwrap();

        assert_ne!(standard.address, ledger_live.address);
        assert_ne!(standard.address, ledger_legacy.address);
        assert_ne!(ledger_live.address, ledger_legacy.address);
    }

    #[test]
    fn test_derive_many_with() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let addrs = deriver
            .derive_many_with(DerivationStyle::LedgerLive, 0, 3)
            .unwrap();

        assert_eq!(addrs.len(), 3);
        assert_eq!(addrs[0].path, "m/44'/60'/0'/0/0");
        assert_eq!(addrs[1].path, "m/44'/60'/1'/0/0");
        assert_eq!(addrs[2].path, "m/44'/60'/2'/0/0");
    }

    #[test]
    fn test_derive_path() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let addr = deriver.derive_path("m/44'/60'/0'/0/0").unwrap();
        assert_eq!(addr.path, "m/44'/60'/0'/0/0");
        assert!(addr.address.starts_with("0x"));
    }
}