1#[cfg(feature = "alloc")]
4use alloc::{
5 format,
6 string::{String, ToString},
7 vec::Vec,
8};
9
10use bip32::{DerivationPath, XPrv};
11use k256::ecdsa::SigningKey;
12use kobe::Wallet;
13use zeroize::Zeroizing;
14
15use crate::Error;
16use crate::address::{public_key_to_address, to_checksum_address};
17use crate::derivation_style::DerivationStyle;
18
/// Address deriver bound to a borrowed [`Wallet`].
///
/// Every derivation reads the seed from the wallet it was constructed with;
/// the deriver itself holds no key material.
#[derive(Debug)]
pub struct Deriver<'a> {
    /// Wallet whose seed is used as the root for key derivation.
    wallet: &'a Wallet,
}
28
/// Result of deriving a single key/address pair at a given path.
///
/// NOTE(review): `Debug` is derived, so formatting this struct with `{:?}`
/// presumably includes `private_key_hex` — confirm before logging instances.
#[derive(Debug, Clone)]
pub struct DerivedAddress {
    /// Derivation path the key was derived at, exactly as it was supplied
    /// (e.g. `m/44'/60'/0'/0/0`).
    pub path: String,
    /// Hex encoding of the private key bytes, wrapped in [`Zeroizing`]
    /// (intended to wipe the string's memory on drop).
    pub private_key_hex: Zeroizing<String>,
    /// Hex encoding of the public key, encoded via `to_encoded_point(false)`
    /// (presumably the uncompressed SEC1 form — confirm against `k256`).
    pub public_key_hex: String,
    /// Address derived from the public key, passed through
    /// `to_checksum_address`.
    pub address: String,
}
41
42impl<'a> Deriver<'a> {
43 #[must_use]
45 pub const fn new(wallet: &'a Wallet) -> Self {
46 Self { wallet }
47 }
48
49 #[inline]
61 pub fn derive(&self, index: u32) -> Result<DerivedAddress, Error> {
62 self.derive_with(DerivationStyle::Standard, index)
63 }
64
65 #[inline]
81 pub fn derive_with(&self, style: DerivationStyle, index: u32) -> Result<DerivedAddress, Error> {
82 self.derive_path(&style.path(index))
83 }
84
85 #[inline]
96 pub fn derive_many(&self, start: u32, count: u32) -> Result<Vec<DerivedAddress>, Error> {
97 self.derive_many_with(DerivationStyle::Standard, start, count)
98 }
99
100 pub fn derive_many_with(
112 &self,
113 style: DerivationStyle,
114 start: u32,
115 count: u32,
116 ) -> Result<Vec<DerivedAddress>, Error> {
117 (start..start + count)
118 .map(|index| self.derive_with(style, index))
119 .collect()
120 }
121
122 pub fn derive_path(&self, path: &str) -> Result<DerivedAddress, Error> {
135 let private_key = self.derive_key(path)?;
136
137 let public_key = private_key.verifying_key();
138 let public_key_bytes = public_key.to_encoded_point(false);
139 let address = public_key_to_address(public_key_bytes.as_bytes());
140
141 Ok(DerivedAddress {
142 path: path.to_string(),
143 private_key_hex: Zeroizing::new(hex::encode(private_key.to_bytes())),
144 public_key_hex: hex::encode(public_key_bytes.as_bytes()),
145 address: to_checksum_address(&address),
146 })
147 }
148
149 fn derive_key(&self, path: &str) -> Result<SigningKey, Error> {
151 let derivation_path: DerivationPath = path
153 .parse()
154 .map_err(|e| Error::Derivation(format!("invalid derivation path: {e}")))?;
155
156 let derived = XPrv::derive_from_path(self.wallet.seed(), &derivation_path)
158 .map_err(|e| Error::Derivation(format!("key derivation failed: {e}")))?;
159
160 Ok(derived.private_key().clone())
162 }
163}
164
#[cfg(test)]
mod tests {
    use super::*;

    /// Mnemonic shared by every test in this module.
    const TEST_MNEMONIC: &str = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about";

    /// Builds a wallet from [`TEST_MNEMONIC`] without a passphrase.
    fn test_wallet() -> Wallet {
        Wallet::from_mnemonic(TEST_MNEMONIC, None).unwrap()
    }

    /// The default derivation yields a 42-character `0x` address at the
    /// standard path.
    #[test]
    fn test_derive_address() {
        let wallet = test_wallet();
        let derived = Deriver::new(&wallet).derive(0).unwrap();

        assert!(derived.address.starts_with("0x"));
        assert_eq!(derived.address.len(), 42);
        assert_eq!(derived.path, "m/44'/60'/0'/0/0");
    }

    /// A batch of five derivations returns five pairwise-distinct addresses.
    #[test]
    fn test_derive_multiple() {
        let wallet = test_wallet();
        let batch = Deriver::new(&wallet).derive_many(0, 5).unwrap();

        assert_eq!(batch.len(), 5);

        // Each address must differ from every address that precedes it.
        for (position, current) in batch.iter().enumerate() {
            let earlier = &batch[..position];
            assert!(earlier.iter().all(|prev| prev.address != current.address));
        }
    }

    /// Two wallets built from the same mnemonic derive the same address.
    #[test]
    fn test_deterministic_derivation() {
        let first_wallet = test_wallet();
        let second_wallet = test_wallet();

        let first = Deriver::new(&first_wallet).derive(0).unwrap();
        let second = Deriver::new(&second_wallet).derive(0).unwrap();

        assert_eq!(first.address, second.address);
    }

    /// Adding a passphrase changes the derived address.
    #[test]
    fn test_passphrase_changes_addresses() {
        let plain_wallet = test_wallet();
        let protected_wallet = Wallet::from_mnemonic(TEST_MNEMONIC, Some("password")).unwrap();

        let plain = Deriver::new(&plain_wallet).derive(0).unwrap();
        let protected = Deriver::new(&protected_wallet).derive(0).unwrap();

        assert_ne!(plain.address, protected.address);
    }

    /// The standard style varies the last path component.
    #[test]
    fn test_derive_with_standard() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        for (index, expected_path) in [(0, "m/44'/60'/0'/0/0"), (5, "m/44'/60'/0'/0/5")] {
            let derived = deriver.derive_with(DerivationStyle::Standard, index).unwrap();
            assert_eq!(derived.path, expected_path);
        }
    }

    /// The Ledger Live style varies the hardened account component.
    #[test]
    fn test_derive_with_ledger_live() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        for (index, expected_path) in [(0, "m/44'/60'/0'/0/0"), (1, "m/44'/60'/1'/0/0")] {
            let derived = deriver.derive_with(DerivationStyle::LedgerLive, index).unwrap();
            assert_eq!(derived.path, expected_path);
        }
    }

    /// The legacy Ledger style uses a four-component path.
    #[test]
    fn test_derive_with_ledger_legacy() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        for (index, expected_path) in [(0, "m/44'/60'/0'/0"), (3, "m/44'/60'/0'/3")] {
            let derived = deriver
                .derive_with(DerivationStyle::LedgerLegacy, index)
                .unwrap();
            assert_eq!(derived.path, expected_path);
        }
    }

    /// At index 1 the three styles produce three different addresses.
    #[test]
    fn test_different_styles_produce_different_addresses() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);
        let address_for = |style: DerivationStyle| deriver.derive_with(style, 1).unwrap().address;

        let standard = address_for(DerivationStyle::Standard);
        let ledger_live = address_for(DerivationStyle::LedgerLive);
        let ledger_legacy = address_for(DerivationStyle::LedgerLegacy);

        assert_ne!(standard, ledger_live);
        assert_ne!(standard, ledger_legacy);
        assert_ne!(ledger_live, ledger_legacy);
    }

    /// Batch derivation honours the requested style and index order.
    #[test]
    fn test_derive_many_with() {
        let wallet = test_wallet();
        let batch = Deriver::new(&wallet)
            .derive_many_with(DerivationStyle::LedgerLive, 0, 3)
            .unwrap();

        let expected_paths = ["m/44'/60'/0'/0/0", "m/44'/60'/1'/0/0", "m/44'/60'/2'/0/0"];

        assert_eq!(batch.len(), 3);
        for (derived, expected_path) in batch.iter().zip(expected_paths) {
            assert_eq!(derived.path, expected_path);
        }
    }

    /// An explicit path is echoed back and yields a `0x` address.
    #[test]
    fn test_derive_path() {
        let wallet = test_wallet();
        let explicit_path = "m/44'/60'/0'/0/0";

        let derived = Deriver::new(&wallet).derive_path(explicit_path).unwrap();

        assert_eq!(derived.path, explicit_path);
        assert!(derived.address.starts_with("0x"));
    }
}