1#[cfg(feature = "alloc")]
4use alloc::{
5 format,
6 string::{String, ToString},
7 vec::Vec,
8};
9
10use bip32::{DerivationPath, XPrv};
11use k256::ecdsa::SigningKey;
12use kobe::Wallet;
13use zeroize::Zeroizing;
14
15use crate::Error;
16use crate::address::{public_key_to_address, to_checksum_address};
17use crate::derivation_style::DerivationStyle;
18
19#[derive(Debug)]
24pub struct Deriver<'a> {
25 wallet: &'a Wallet,
27}
28
29#[derive(Debug, Clone)]
31#[non_exhaustive]
32pub struct DerivedAddress {
33 pub path: String,
35 pub private_key_hex: Zeroizing<String>,
37 pub public_key_hex: String,
39 pub address: String,
41}
42
43impl<'a> Deriver<'a> {
44 #[must_use]
46 pub const fn new(wallet: &'a Wallet) -> Self {
47 Self { wallet }
48 }
49
50 #[inline]
62 pub fn derive(&self, index: u32) -> Result<DerivedAddress, Error> {
63 self.derive_with(DerivationStyle::Standard, index)
64 }
65
66 #[inline]
82 pub fn derive_with(&self, style: DerivationStyle, index: u32) -> Result<DerivedAddress, Error> {
83 self.derive_path(&style.path(index))
84 }
85
86 #[inline]
97 pub fn derive_many(&self, start: u32, count: u32) -> Result<Vec<DerivedAddress>, Error> {
98 self.derive_many_with(DerivationStyle::Standard, start, count)
99 }
100
101 pub fn derive_many_with(
113 &self,
114 style: DerivationStyle,
115 start: u32,
116 count: u32,
117 ) -> Result<Vec<DerivedAddress>, Error> {
118 let end = start.checked_add(count).ok_or_else(|| {
119 Error::Derivation("index overflow: start + count exceeds u32::MAX".into())
120 })?;
121 (start..end)
122 .map(|index| self.derive_with(style, index))
123 .collect()
124 }
125
126 pub fn derive_path(&self, path: &str) -> Result<DerivedAddress, Error> {
139 let private_key = self.derive_key(path)?;
140
141 let public_key = private_key.verifying_key();
142 let public_key_bytes = public_key.to_encoded_point(false);
143 let address = public_key_to_address(public_key_bytes.as_bytes())?;
144
145 Ok(DerivedAddress {
146 path: path.to_string(),
147 private_key_hex: Zeroizing::new(hex::encode(private_key.to_bytes())),
148 public_key_hex: hex::encode(public_key_bytes.as_bytes()),
149 address: to_checksum_address(&address),
150 })
151 }
152
153 fn derive_key(&self, path: &str) -> Result<SigningKey, Error> {
155 let derivation_path: DerivationPath = path
157 .parse()
158 .map_err(|e| Error::Derivation(format!("invalid derivation path: {e}")))?;
159
160 let derived = XPrv::derive_from_path(self.wallet.seed(), &derivation_path)
162 .map_err(|e| Error::Derivation(format!("key derivation failed: {e}")))?;
163
164 Ok(derived.private_key().clone())
166 }
167}
168
#[cfg(test)]
// `unwrap` is the intended failure mode inside tests, and several tests rebind
// `addr` for successive derivations, hence the two lint exemptions.
#[allow(clippy::unwrap_used, clippy::shadow_unrelated)]
mod tests {
    use super::*;

    const TEST_MNEMONIC: &str = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about";

    /// Widely published BIP-39/BIP-44 test vector: the EIP-55 address at
    /// `m/44'/60'/0'/0/0` for `TEST_MNEMONIC` with an empty passphrase.
    const TEST_ADDRESS_0: &str = "0x9858EfFD232B4033E47d90003D41EC34EcaEda94";

    fn test_wallet() -> Wallet {
        Wallet::from_mnemonic(TEST_MNEMONIC, None).unwrap()
    }

    #[test]
    fn test_derive_address() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);
        let addr = deriver.derive(0).unwrap();

        assert!(addr.address.starts_with("0x"));
        assert_eq!(addr.address.len(), 42);
        assert_eq!(addr.path, "m/44'/60'/0'/0/0");
    }

    #[test]
    fn test_known_vector() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let addr = deriver.derive(0).unwrap();
        assert_eq!(addr.address, TEST_ADDRESS_0);

        // The explicit-path entry point must agree with the index-based one.
        let by_path = deriver.derive_path("m/44'/60'/0'/0/0").unwrap();
        assert_eq!(by_path.address, TEST_ADDRESS_0);
    }

    #[test]
    fn test_key_encodings() {
        let wallet = test_wallet();
        let addr = Deriver::new(&wallet).derive(0).unwrap();

        // 32-byte secret scalar -> 64 hex characters.
        assert_eq!(addr.private_key_hex.len(), 64);
        // Uncompressed SEC1 point (0x04 || X || Y), 65 bytes -> 130 hex characters.
        assert_eq!(addr.public_key_hex.len(), 130);
        assert!(addr.public_key_hex.starts_with("04"));
    }

    #[test]
    fn test_derive_multiple() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);
        let addrs = deriver.derive_many(0, 5).unwrap();

        assert_eq!(addrs.len(), 5);
        assert_eq!(addrs[4].path, "m/44'/60'/0'/0/4");

        // Every pair of derived addresses must differ.
        for (i, a) in addrs.iter().enumerate() {
            for b in &addrs[i + 1..] {
                assert_ne!(a.address, b.address);
            }
        }
    }

    #[test]
    fn test_derive_many_matches_single_derivation() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);
        let addrs = deriver.derive_many(2, 3).unwrap();

        assert_eq!(addrs.len(), 3);
        for (offset, addr) in (0u32..).zip(&addrs) {
            assert_eq!(addr.address, deriver.derive(2 + offset).unwrap().address);
        }
    }

    #[test]
    fn test_derive_many_zero_count_is_empty() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        assert!(deriver.derive_many(7, 0).unwrap().is_empty());
    }

    #[test]
    fn test_derive_many_overflow_is_error() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        assert!(matches!(
            deriver.derive_many(u32::MAX, 1),
            Err(Error::Derivation(_))
        ));
    }

    #[test]
    fn test_invalid_path_is_error() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        assert!(matches!(
            deriver.derive_path("not a path"),
            Err(Error::Derivation(_))
        ));
    }

    #[test]
    fn test_deterministic_derivation() {
        let wallet1 = Wallet::from_mnemonic(TEST_MNEMONIC, None).unwrap();
        let wallet2 = Wallet::from_mnemonic(TEST_MNEMONIC, None).unwrap();

        let deriver1 = Deriver::new(&wallet1);
        let deriver2 = Deriver::new(&wallet2);

        let addr1 = deriver1.derive(0).unwrap();
        let addr2 = deriver2.derive(0).unwrap();

        assert_eq!(addr1.address, addr2.address);
    }

    #[test]
    fn test_passphrase_changes_addresses() {
        let wallet1 = Wallet::from_mnemonic(TEST_MNEMONIC, None).unwrap();
        let wallet2 = Wallet::from_mnemonic(TEST_MNEMONIC, Some("password")).unwrap();

        let deriver1 = Deriver::new(&wallet1);
        let deriver2 = Deriver::new(&wallet2);

        let addr1 = deriver1.derive(0).unwrap();
        let addr2 = deriver2.derive(0).unwrap();

        assert_ne!(addr1.address, addr2.address);
    }

    #[test]
    fn test_derive_with_standard() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let addr = deriver.derive_with(DerivationStyle::Standard, 0).unwrap();
        assert_eq!(addr.path, "m/44'/60'/0'/0/0");

        let addr = deriver.derive_with(DerivationStyle::Standard, 5).unwrap();
        assert_eq!(addr.path, "m/44'/60'/0'/0/5");
    }

    #[test]
    fn test_derive_with_ledger_live() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let addr = deriver.derive_with(DerivationStyle::LedgerLive, 0).unwrap();
        assert_eq!(addr.path, "m/44'/60'/0'/0/0");

        let addr = deriver.derive_with(DerivationStyle::LedgerLive, 1).unwrap();
        assert_eq!(addr.path, "m/44'/60'/1'/0/0");
    }

    #[test]
    fn test_derive_with_ledger_legacy() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let addr = deriver
            .derive_with(DerivationStyle::LedgerLegacy, 0)
            .unwrap();
        assert_eq!(addr.path, "m/44'/60'/0'/0");

        let addr = deriver
            .derive_with(DerivationStyle::LedgerLegacy, 3)
            .unwrap();
        assert_eq!(addr.path, "m/44'/60'/0'/3");
    }

    #[test]
    fn test_different_styles_produce_different_addresses() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let standard = deriver.derive_with(DerivationStyle::Standard, 1).unwrap();
        let ledger_live = deriver.derive_with(DerivationStyle::LedgerLive, 1).unwrap();
        let ledger_legacy = deriver
            .derive_with(DerivationStyle::LedgerLegacy, 1)
            .unwrap();

        assert_ne!(standard.address, ledger_live.address);
        assert_ne!(standard.address, ledger_legacy.address);
        assert_ne!(ledger_live.address, ledger_legacy.address);
    }

    #[test]
    fn test_derive_many_with() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let addrs = deriver
            .derive_many_with(DerivationStyle::LedgerLive, 0, 3)
            .unwrap();

        assert_eq!(addrs.len(), 3);
        assert_eq!(addrs[0].path, "m/44'/60'/0'/0/0");
        assert_eq!(addrs[1].path, "m/44'/60'/1'/0/0");
        assert_eq!(addrs[2].path, "m/44'/60'/2'/0/0");
    }

    #[test]
    fn test_derive_path() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let addr = deriver.derive_path("m/44'/60'/0'/0/0").unwrap();
        assert_eq!(addr.path, "m/44'/60'/0'/0/0");
        assert!(addr.address.starts_with("0x"));
    }
}