1#[cfg(feature = "alloc")]
4use alloc::{
5 format,
6 string::{String, ToString},
7 vec::Vec,
8};
9
10use bip32::{DerivationPath, XPrv};
11use k256::ecdsa::SigningKey;
12use kobe::Wallet;
13use zeroize::Zeroizing;
14
15use crate::Error;
16use crate::address::{public_key_to_address, to_checksum_address};
17use crate::derivation_style::DerivationStyle;
18
19#[derive(Debug)]
24pub struct Deriver<'a> {
25 wallet: &'a Wallet,
27}
28
29#[derive(Debug, Clone)]
31#[non_exhaustive]
32pub struct DerivedAddress {
33 pub path: String,
35 pub private_key_hex: Zeroizing<String>,
37 pub public_key_hex: String,
39 pub address: String,
41}
42
43impl<'a> Deriver<'a> {
44 #[must_use]
46 pub const fn new(wallet: &'a Wallet) -> Self {
47 Self { wallet }
48 }
49
50 #[inline]
62 pub fn derive(&self, index: u32) -> Result<DerivedAddress, Error> {
63 self.derive_with(DerivationStyle::Standard, index)
64 }
65
66 #[inline]
82 pub fn derive_with(&self, style: DerivationStyle, index: u32) -> Result<DerivedAddress, Error> {
83 self.derive_path(&style.path(index))
84 }
85
86 #[inline]
97 pub fn derive_many(&self, start: u32, count: u32) -> Result<Vec<DerivedAddress>, Error> {
98 self.derive_many_with(DerivationStyle::Standard, start, count)
99 }
100
101 pub fn derive_many_with(
113 &self,
114 style: DerivationStyle,
115 start: u32,
116 count: u32,
117 ) -> Result<Vec<DerivedAddress>, Error> {
118 let end = start.checked_add(count).ok_or_else(|| {
119 Error::Derivation("index overflow: start + count exceeds u32::MAX".into())
120 })?;
121 (start..end)
122 .map(|index| self.derive_with(style, index))
123 .collect()
124 }
125
126 pub fn derive_path(&self, path: &str) -> Result<DerivedAddress, Error> {
139 let private_key = self.derive_key(path)?;
140
141 let public_key = private_key.verifying_key();
142 let public_key_bytes = public_key.to_encoded_point(false);
143 let address = public_key_to_address(public_key_bytes.as_bytes())?;
144
145 Ok(DerivedAddress {
146 path: path.to_string(),
147 private_key_hex: Zeroizing::new(hex::encode(private_key.to_bytes())),
148 public_key_hex: hex::encode(public_key_bytes.as_bytes()),
149 address: to_checksum_address(&address),
150 })
151 }
152
153 fn derive_key(&self, path: &str) -> Result<SigningKey, Error> {
155 let derivation_path: DerivationPath = path
157 .parse()
158 .map_err(|e| Error::Derivation(format!("invalid derivation path: {e}")))?;
159
160 let derived = XPrv::derive_from_path(self.wallet.seed(), &derivation_path)
162 .map_err(|e| Error::Derivation(format!("key derivation failed: {e}")))?;
163
164 Ok(derived.private_key().clone())
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    const TEST_MNEMONIC: &str = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about";

    /// Widely published first address (`m/44'/60'/0'/0/0`, empty passphrase) of
    /// `TEST_MNEMONIC`; pins the full seed -> key -> address pipeline.
    const TEST_MNEMONIC_ADDRESS_0: &str = "0x9858EfFD232B4033E47d90003D41EC34EcaEda94";

    fn test_wallet() -> Wallet {
        Wallet::from_mnemonic(TEST_MNEMONIC, None).unwrap()
    }

    #[test]
    fn test_derive_address() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);
        let addr = deriver.derive(0).unwrap();

        assert!(addr.address.starts_with("0x"));
        assert_eq!(addr.address.len(), 42);
        assert_eq!(addr.path, "m/44'/60'/0'/0/0");
    }

    #[test]
    fn test_known_address_vector() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let addr = deriver.derive(0).unwrap();
        assert_eq!(addr.address, TEST_MNEMONIC_ADDRESS_0);

        // The explicit-path entry point must agree with the index-based one.
        let by_path = deriver.derive_path("m/44'/60'/0'/0/0").unwrap();
        assert_eq!(by_path.address, TEST_MNEMONIC_ADDRESS_0);
    }

    #[test]
    fn test_key_encodings() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);
        let addr = deriver.derive(0).unwrap();

        // 32-byte secret scalar -> 64 hex characters.
        assert_eq!(addr.private_key_hex.len(), 64);
        assert!(addr.private_key_hex.chars().all(|c| c.is_ascii_hexdigit()));

        // Uncompressed SEC1 point: 0x04 tag + 64 bytes -> 130 hex characters.
        assert_eq!(addr.public_key_hex.len(), 130);
        assert!(addr.public_key_hex.starts_with("04"));
    }

    #[test]
    fn test_derive_multiple() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);
        let addrs = deriver.derive_many(0, 5).unwrap();

        assert_eq!(addrs.len(), 5);

        let mut seen = Vec::new();
        for addr in &addrs {
            assert!(!seen.contains(&addr.address));
            seen.push(addr.address.clone());
        }
        assert_eq!(seen.len(), 5);
    }

    #[test]
    fn test_derive_many_empty_range() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let addrs = deriver.derive_many(7, 0).unwrap();
        assert!(addrs.is_empty());
    }

    #[test]
    fn test_derive_many_rejects_index_overflow() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        assert!(matches!(
            deriver.derive_many(u32::MAX, 2),
            Err(Error::Derivation(_))
        ));
    }

    #[test]
    fn test_deterministic_derivation() {
        let wallet1 = Wallet::from_mnemonic(TEST_MNEMONIC, None).unwrap();
        let wallet2 = Wallet::from_mnemonic(TEST_MNEMONIC, None).unwrap();

        let deriver1 = Deriver::new(&wallet1);
        let deriver2 = Deriver::new(&wallet2);

        let addr1 = deriver1.derive(0).unwrap();
        let addr2 = deriver2.derive(0).unwrap();

        assert_eq!(addr1.address, addr2.address);
    }

    #[test]
    fn test_passphrase_changes_addresses() {
        let wallet1 = Wallet::from_mnemonic(TEST_MNEMONIC, None).unwrap();
        let wallet2 = Wallet::from_mnemonic(TEST_MNEMONIC, Some("password")).unwrap();

        let deriver1 = Deriver::new(&wallet1);
        let deriver2 = Deriver::new(&wallet2);

        let addr1 = deriver1.derive(0).unwrap();
        let addr2 = deriver2.derive(0).unwrap();

        assert_ne!(addr1.address, addr2.address);
    }

    #[test]
    fn test_derive_with_standard() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let addr = deriver.derive_with(DerivationStyle::Standard, 0).unwrap();
        assert_eq!(addr.path, "m/44'/60'/0'/0/0");

        let addr = deriver.derive_with(DerivationStyle::Standard, 5).unwrap();
        assert_eq!(addr.path, "m/44'/60'/0'/0/5");
    }

    #[test]
    fn test_derive_with_ledger_live() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let addr = deriver.derive_with(DerivationStyle::LedgerLive, 0).unwrap();
        assert_eq!(addr.path, "m/44'/60'/0'/0/0");

        let addr = deriver.derive_with(DerivationStyle::LedgerLive, 1).unwrap();
        assert_eq!(addr.path, "m/44'/60'/1'/0/0");
    }

    #[test]
    fn test_ledger_live_index_zero_matches_standard() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        // Both styles map index 0 to the same path, so the accounts must match.
        let standard = deriver.derive_with(DerivationStyle::Standard, 0).unwrap();
        let ledger_live = deriver.derive_with(DerivationStyle::LedgerLive, 0).unwrap();
        assert_eq!(standard.address, ledger_live.address);
    }

    #[test]
    fn test_derive_with_ledger_legacy() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let addr = deriver
            .derive_with(DerivationStyle::LedgerLegacy, 0)
            .unwrap();
        assert_eq!(addr.path, "m/44'/60'/0'/0");

        let addr = deriver
            .derive_with(DerivationStyle::LedgerLegacy, 3)
            .unwrap();
        assert_eq!(addr.path, "m/44'/60'/0'/3");
    }

    #[test]
    fn test_different_styles_produce_different_addresses() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let standard = deriver.derive_with(DerivationStyle::Standard, 1).unwrap();
        let ledger_live = deriver.derive_with(DerivationStyle::LedgerLive, 1).unwrap();
        let ledger_legacy = deriver
            .derive_with(DerivationStyle::LedgerLegacy, 1)
            .unwrap();

        assert_ne!(standard.address, ledger_live.address);
        assert_ne!(standard.address, ledger_legacy.address);
        assert_ne!(ledger_live.address, ledger_legacy.address);
    }

    #[test]
    fn test_derive_many_with() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let addrs = deriver
            .derive_many_with(DerivationStyle::LedgerLive, 0, 3)
            .unwrap();

        assert_eq!(addrs.len(), 3);
        assert_eq!(addrs[0].path, "m/44'/60'/0'/0/0");
        assert_eq!(addrs[1].path, "m/44'/60'/1'/0/0");
        assert_eq!(addrs[2].path, "m/44'/60'/2'/0/0");
    }

    #[test]
    fn test_derive_path() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let addr = deriver.derive_path("m/44'/60'/0'/0/0").unwrap();
        assert_eq!(addr.path, "m/44'/60'/0'/0/0");
        assert!(addr.address.starts_with("0x"));
    }

    #[test]
    fn test_derive_path_rejects_invalid_path() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        assert!(matches!(
            deriver.derive_path("not/a/path"),
            Err(Error::Derivation(_))
        ));
    }
}
317}