1#[cfg(feature = "alloc")]
4use alloc::{
5 format,
6 string::{String, ToString},
7 vec::Vec,
8};
9
10use hmac::{Hmac, Mac};
11use k256::{Scalar, ecdsa::SigningKey, elliptic_curve::PrimeField};
12use kobe_core::Wallet;
13use sha2::Sha512;
14use zeroize::Zeroizing;
15
16use crate::Error;
17use crate::utils::{public_key_to_address, to_checksum_address};
18
19type HmacSha512 = Hmac<Sha512>;
20
21#[derive(Debug)]
26pub struct Deriver<'a> {
27 wallet: &'a Wallet,
29}
30
31#[derive(Debug)]
33pub struct DerivedAddress {
34 pub path: String,
36 pub private_key_hex: Zeroizing<String>,
38 pub public_key_hex: String,
40 pub address: String,
42}
43
44impl<'a> Deriver<'a> {
45 #[must_use]
47 pub const fn new(wallet: &'a Wallet) -> Self {
48 Self { wallet }
49 }
50
51 pub fn derive(
65 &self,
66 account: u32,
67 change: bool,
68 address_index: u32,
69 ) -> Result<DerivedAddress, Error> {
70 let change_val = if change { 1 } else { 0 };
71 let path = format!("m/44'/60'/{account}'/{change_val}/{address_index}");
72 self.derive_at_path(&path)
73 }
74
75 pub fn derive_at_path(&self, path: &str) -> Result<DerivedAddress, Error> {
81 let private_key = self.derive_key(path)?;
82
83 let public_key = private_key.verifying_key();
84 let public_key_bytes = public_key.to_encoded_point(false);
85 let address = public_key_to_address(public_key_bytes.as_bytes());
86
87 Ok(DerivedAddress {
88 path: path.to_string(),
89 private_key_hex: Zeroizing::new(hex::encode(private_key.to_bytes())),
90 public_key_hex: hex::encode(public_key_bytes.as_bytes()),
91 address: to_checksum_address(&address),
92 })
93 }
94
95 pub fn derive_many(
101 &self,
102 account: u32,
103 change: bool,
104 start_index: u32,
105 count: u32,
106 ) -> Result<Vec<DerivedAddress>, Error> {
107 (start_index..start_index + count)
108 .map(|index| self.derive(account, change, index))
109 .collect()
110 }
111
112 fn derive_key(&self, path: &str) -> Result<SigningKey, Error> {
114 let path_str = path.strip_prefix("m/").unwrap_or(path);
116 let indices: Result<Vec<u32>, _> = path_str
117 .split('/')
118 .filter(|s| !s.is_empty())
119 .map(|component| {
120 let (num_str, hardened) = if let Some(s) = component.strip_suffix('\'') {
121 (s, true)
122 } else if let Some(s) = component.strip_suffix('h') {
123 (s, true)
124 } else {
125 (component, false)
126 };
127
128 num_str
129 .parse::<u32>()
130 .map(|n| if hardened { n | 0x8000_0000 } else { n })
131 .map_err(|_| Error::Derivation(format!("invalid path component: {component}")))
132 })
133 .collect();
134
135 let indices = indices?;
136
137 let mut mac =
139 HmacSha512::new_from_slice(b"Bitcoin seed").expect("HMAC can take key of any size");
140 mac.update(self.wallet.seed());
141 let result = mac.finalize().into_bytes();
142
143 let mut key = result[..32].to_vec();
144 let mut chain_code = result[32..].to_vec();
145
146 for index in indices {
148 let mut mac =
149 HmacSha512::new_from_slice(&chain_code).expect("HMAC can take key of any size");
150
151 if index & 0x8000_0000 != 0 {
152 mac.update(&[0u8]);
154 mac.update(&key);
155 } else {
156 let signing_key =
158 SigningKey::from_slice(&key).map_err(|_| Error::InvalidPrivateKey)?;
159 let public_key = signing_key.verifying_key().to_encoded_point(true);
160 mac.update(public_key.as_bytes());
161 }
162 mac.update(&index.to_be_bytes());
163
164 let result = mac.finalize().into_bytes();
165 let il = &result[..32];
166
167 let il_arr: [u8; 32] = (*il)
169 .try_into()
170 .map_err(|_| Error::Derivation("invalid IL length".to_string()))?;
171 let key_arr: [u8; 32] = key
172 .as_slice()
173 .try_into()
174 .map_err(|_| Error::Derivation("invalid key length".to_string()))?;
175
176 let il_scalar = Scalar::from_repr(il_arr.into());
177 let key_scalar = Scalar::from_repr(key_arr.into());
178
179 if il_scalar.is_none().into() || key_scalar.is_none().into() {
180 return Err(Error::Derivation("invalid scalar".to_string()));
181 }
182
183 let new_key = il_scalar.unwrap() + key_scalar.unwrap();
184 key = new_key.to_bytes().to_vec();
185 chain_code = result[32..].to_vec();
186 }
187
188 SigningKey::from_slice(&key).map_err(|_| Error::InvalidPrivateKey)
189 }
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::collections::BTreeSet;

    const TEST_MNEMONIC: &str = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about";

    fn test_wallet() -> Wallet {
        Wallet::from_mnemonic(TEST_MNEMONIC, None).unwrap()
    }

    #[test]
    fn test_derive_address() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);
        let addr = deriver.derive(0, false, 0).unwrap();

        assert!(addr.address.starts_with("0x"));
        assert_eq!(addr.address.len(), 42);
        assert_eq!(addr.path, "m/44'/60'/0'/0/0");
        // 32-byte private key and 65-byte uncompressed SEC1 public key, hex-encoded.
        assert_eq!(addr.private_key_hex.len(), 64);
        assert_eq!(addr.public_key_hex.len(), 130);
        assert!(addr.public_key_hex.starts_with("04"));
    }

    #[test]
    fn test_known_answer_vector() {
        // Widely published first address for the all-"abandon ... about" BIP-39 test
        // mnemonic with no passphrase. This pins the whole seed -> BIP-32 -> address
        // pipeline, which the shape-only checks above cannot catch.
        let wallet = test_wallet();
        let addr = Deriver::new(&wallet).derive(0, false, 0).unwrap();

        assert_eq!(addr.address, "0x9858EfFD232B4033E47d90003D41EC34EcaEda94");
    }

    #[test]
    fn test_hardened_suffixes_are_equivalent() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let from_parts = deriver.derive(0, false, 3).unwrap();
        let apostrophe = deriver.derive_at_path("m/44'/60'/0'/0/3").unwrap();
        let h_suffix = deriver.derive_at_path("m/44h/60h/0h/0/3").unwrap();

        assert_eq!(from_parts.address, apostrophe.address);
        assert_eq!(apostrophe.address, h_suffix.address);
    }

    #[test]
    fn test_derive_multiple() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);
        let addrs = deriver.derive_many(0, false, 0, 5).unwrap();

        assert_eq!(addrs.len(), 5);
        assert_eq!(addrs[2].path, "m/44'/60'/0'/0/2");

        let unique: BTreeSet<&str> = addrs.iter().map(|a| a.address.as_str()).collect();
        assert_eq!(unique.len(), 5);
    }

    #[test]
    fn test_derive_many_empty() {
        let wallet = test_wallet();
        let addrs = Deriver::new(&wallet).derive_many(0, false, 7, 0).unwrap();
        assert!(addrs.is_empty());
    }

    #[test]
    fn test_deterministic_derivation() {
        let wallet1 = Wallet::from_mnemonic(TEST_MNEMONIC, None).unwrap();
        let wallet2 = Wallet::from_mnemonic(TEST_MNEMONIC, None).unwrap();

        let deriver1 = Deriver::new(&wallet1);
        let deriver2 = Deriver::new(&wallet2);

        let addr1 = deriver1.derive(0, false, 0).unwrap();
        let addr2 = deriver2.derive(0, false, 0).unwrap();

        assert_eq!(addr1.address, addr2.address);
    }

    #[test]
    fn test_passphrase_changes_addresses() {
        let wallet1 = Wallet::from_mnemonic(TEST_MNEMONIC, None).unwrap();
        let wallet2 = Wallet::from_mnemonic(TEST_MNEMONIC, Some("password")).unwrap();

        let deriver1 = Deriver::new(&wallet1);
        let deriver2 = Deriver::new(&wallet2);

        let addr1 = deriver1.derive(0, false, 0).unwrap();
        let addr2 = deriver2.derive(0, false, 0).unwrap();

        assert_ne!(addr1.address, addr2.address);
    }
}