1#[cfg(feature = "alloc")]
4use alloc::{
5 format,
6 string::{String, ToString},
7 vec::Vec,
8};
9
10use bip32::{DerivationPath, XPrv};
11use k256::ecdsa::SigningKey;
12use kobe_core::Wallet;
13use zeroize::Zeroizing;
14
15use crate::Error;
16use crate::utils::{public_key_to_address, to_checksum_address};
17
18#[derive(Debug)]
23pub struct Deriver<'a> {
24 wallet: &'a Wallet,
26}
27
28#[derive(Debug)]
30pub struct DerivedAddress {
31 pub path: String,
33 pub private_key_hex: Zeroizing<String>,
35 pub public_key_hex: String,
37 pub address: String,
39}
40
41impl<'a> Deriver<'a> {
42 #[must_use]
44 pub const fn new(wallet: &'a Wallet) -> Self {
45 Self { wallet }
46 }
47
48 pub fn derive(
62 &self,
63 account: u32,
64 change: bool,
65 address_index: u32,
66 ) -> Result<DerivedAddress, Error> {
67 let change_val = if change { 1 } else { 0 };
68 let path = format!("m/44'/60'/{account}'/{change_val}/{address_index}");
69 self.derive_at_path(&path)
70 }
71
72 pub fn derive_at_path(&self, path: &str) -> Result<DerivedAddress, Error> {
78 let private_key = self.derive_key(path)?;
79
80 let public_key = private_key.verifying_key();
81 let public_key_bytes = public_key.to_encoded_point(false);
82 let address = public_key_to_address(public_key_bytes.as_bytes());
83
84 Ok(DerivedAddress {
85 path: path.to_string(),
86 private_key_hex: Zeroizing::new(hex::encode(private_key.to_bytes())),
87 public_key_hex: hex::encode(public_key_bytes.as_bytes()),
88 address: to_checksum_address(&address),
89 })
90 }
91
92 pub fn derive_many(
98 &self,
99 account: u32,
100 change: bool,
101 start_index: u32,
102 count: u32,
103 ) -> Result<Vec<DerivedAddress>, Error> {
104 (start_index..start_index + count)
105 .map(|index| self.derive(account, change, index))
106 .collect()
107 }
108
109 fn derive_key(&self, path: &str) -> Result<SigningKey, Error> {
111 let derivation_path: DerivationPath = path
113 .parse()
114 .map_err(|e| Error::Derivation(format!("invalid derivation path: {e}")))?;
115
116 let derived = XPrv::derive_from_path(self.wallet.seed(), &derivation_path)
118 .map_err(|e| Error::Derivation(format!("key derivation failed: {e}")))?;
119
120 Ok(derived.private_key().clone())
122 }
123}
124
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed twelve-word mnemonic shared by every test in this module.
    const TEST_MNEMONIC: &str = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about";

    /// Builds a wallet from the shared mnemonic and an optional passphrase.
    fn wallet_with(passphrase: Option<&str>) -> Wallet {
        Wallet::from_mnemonic(TEST_MNEMONIC, passphrase).unwrap()
    }

    /// Builds the passphrase-less wallet used by most tests.
    fn test_wallet() -> Wallet {
        wallet_with(None)
    }

    /// A single derivation yields a 42-character `0x` address and the expected path.
    #[test]
    fn test_derive_address() {
        let wallet = test_wallet();
        let derived = Deriver::new(&wallet).derive(0, false, 0).unwrap();

        assert!(derived.address.starts_with("0x"));
        assert_eq!(derived.address.len(), 42);
        assert_eq!(derived.path, "m/44'/60'/0'/0/0");
    }

    /// Batch derivation returns the requested number of pairwise-distinct addresses.
    #[test]
    fn test_derive_multiple() {
        let wallet = test_wallet();
        let derived = Deriver::new(&wallet).derive_many(0, false, 0, 5).unwrap();

        assert_eq!(derived.len(), 5);

        let mut unique: alloc::vec::Vec<&str> = alloc::vec::Vec::new();
        for entry in &derived {
            let address = entry.address.as_str();
            assert!(!unique.contains(&address));
            unique.push(address);
        }
        assert_eq!(unique.len(), 5);
    }

    /// Two wallets built from the same mnemonic derive the same address.
    #[test]
    fn test_deterministic_derivation() {
        let first_wallet = wallet_with(None);
        let second_wallet = wallet_with(None);

        let first = Deriver::new(&first_wallet).derive(0, false, 0).unwrap();
        let second = Deriver::new(&second_wallet).derive(0, false, 0).unwrap();

        assert_eq!(first.address, second.address);
    }

    /// Adding a passphrase to the same mnemonic changes the derived address.
    #[test]
    fn test_passphrase_changes_addresses() {
        let plain_wallet = wallet_with(None);
        let protected_wallet = wallet_with(Some("password"));

        let plain = Deriver::new(&plain_wallet).derive(0, false, 0).unwrap();
        let protected = Deriver::new(&protected_wallet).derive(0, false, 0).unwrap();

        assert_ne!(plain.address, protected.address);
    }
}