1#[cfg(feature = "alloc")]
4use alloc::{
5 string::{String, ToString},
6 vec::Vec,
7};
8use core::marker::PhantomData;
9
10use bitcoin::{PrivateKey, bip32::Xpriv, key::CompressedPublicKey, secp256k1::Secp256k1};
11use kobe_core::{Derive, DerivedAccount, Wallet};
12use zeroize::Zeroizing;
13
14use crate::address::create_address;
15use crate::{AddressType, DerivationPath, Error, Network};
16
/// Derives Bitcoin keys and addresses from a [`Wallet`] seed for one [`Network`].
///
/// The deriver is built by [`Deriver::new`], which turns the wallet seed into a
/// master extended private key. The `'a` lifetime ties the deriver to the wallet
/// it was created from, even though no reference to the wallet is stored.
#[derive(Debug)]
pub struct Deriver<'a> {
    /// Master extended private key created from the wallet seed.
    master_key: Xpriv,
    /// secp256k1 context used for child-key derivation and public-key computation.
    secp: Secp256k1<bitcoin::secp256k1::All>,
    /// Network used to build standard paths, private keys and addresses.
    network: Network,
    /// Zero-sized marker that binds `'a` to the borrowed wallet.
    _wallet: PhantomData<&'a Wallet>,
}
32
33#[derive(Debug, Clone)]
35#[non_exhaustive]
36pub struct DerivedAddress {
37 pub path: DerivationPath,
39 pub private_key_hex: Zeroizing<String>,
41 pub private_key_wif: Zeroizing<String>,
43 pub public_key_hex: String,
45 pub address: String,
47 pub address_type: AddressType,
49}
50
51impl<'a> Deriver<'a> {
52 #[inline]
58 pub fn new(wallet: &'a Wallet, network: Network) -> Result<Self, Error> {
59 let master_key = Xpriv::new_master(network.to_bitcoin_network(), wallet.seed())?;
60
61 Ok(Self {
62 master_key,
63 secp: Secp256k1::new(),
64 network,
65 _wallet: PhantomData,
66 })
67 }
68
69 #[inline]
81 pub fn derive(&self, index: u32) -> Result<DerivedAddress, Error> {
82 self.derive_with(AddressType::P2wpkh, index)
83 }
84
85 #[inline]
102 pub fn derive_with(
103 &self,
104 address_type: AddressType,
105 index: u32,
106 ) -> Result<DerivedAddress, Error> {
107 let path = DerivationPath::bip_standard(address_type, self.network, 0, false, index)?;
108 self.derive_path(&path, address_type)
109 }
110
111 #[inline]
122 pub fn derive_many(&self, start: u32, count: u32) -> Result<Vec<DerivedAddress>, Error> {
123 self.derive_many_with(AddressType::P2wpkh, start, count)
124 }
125
126 pub fn derive_many_with(
138 &self,
139 address_type: AddressType,
140 start: u32,
141 count: u32,
142 ) -> Result<Vec<DerivedAddress>, Error> {
143 let end = start.checked_add(count).ok_or_else(|| {
144 Error::InvalidDerivationPath("index overflow: start + count exceeds u32::MAX".into())
145 })?;
146 (start..end)
147 .map(|index| self.derive_with(address_type, index))
148 .collect()
149 }
150
151 pub fn derive_path(
170 &self,
171 path: &DerivationPath,
172 address_type: AddressType,
173 ) -> Result<DerivedAddress, Error> {
174 let derived = self.master_key.derive_priv(&self.secp, path.inner())?;
175
176 let private_key = PrivateKey::new(derived.private_key, self.network.to_bitcoin_network());
177 let public_key = CompressedPublicKey::from_private_key(&self.secp, &private_key)
178 .map_err(|_| Error::InvalidPrivateKey)?;
179
180 let address = create_address(&public_key, self.network, address_type);
181
182 let private_key_bytes = Zeroizing::new(derived.private_key.secret_bytes());
183
184 Ok(DerivedAddress {
185 path: path.clone(),
186 private_key_hex: Zeroizing::new(hex::encode(private_key_bytes)),
187 private_key_wif: Zeroizing::new(private_key.to_wif()),
188 public_key_hex: public_key.to_string(),
189 address: address.to_string(),
190 address_type,
191 })
192 }
193
194 #[must_use]
196 pub const fn network(&self) -> Network {
197 self.network
198 }
199
200 fn derive_account_at_path(&self, path_str: &str) -> Result<DerivedAccount, Error> {
202 let path = DerivationPath::from_path_str(path_str)?;
203 let da = self.derive_path(&path, AddressType::P2wpkh)?;
204 Ok(DerivedAccount::new(
205 da.path.to_string(),
206 da.private_key_hex,
207 da.public_key_hex,
208 da.address,
209 ))
210 }
211}
212
213impl Derive for Deriver<'_> {
214 type Error = Error;
215
216 fn derive(&self, index: u32) -> Result<DerivedAccount, Error> {
217 let da = self.derive_with(AddressType::P2wpkh, index)?;
218 Ok(DerivedAccount::new(
219 da.path.to_string(),
220 da.private_key_hex,
221 da.public_key_hex,
222 da.address,
223 ))
224 }
225
226 fn derive_path(&self, path: &str) -> Result<DerivedAccount, Error> {
227 self.derive_account_at_path(path)
228 }
229}
230
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Mnemonic shared by every test in this module.
    const TEST_MNEMONIC: &str = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about";

    /// Builds the passphrase-less wallet for [`TEST_MNEMONIC`].
    fn test_wallet() -> Wallet {
        Wallet::from_mnemonic(TEST_MNEMONIC, None).unwrap()
    }

    /// Derives index 0 of `address_type` on `network` from the test wallet.
    fn first_address(network: Network, address_type: AddressType) -> DerivedAddress {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet, network).unwrap();
        let derived = deriver.derive_with(address_type, 0).unwrap();
        derived
    }

    /// Known-answer test for the first native SegWit address.
    #[test]
    fn kat_p2wpkh() {
        let first = first_address(Network::Mainnet, AddressType::P2wpkh);
        assert_eq!(first.address, "bc1qcr8te4kr609gcawutmrza0j4xv80jy8z306fyu");
        assert_eq!(first.path.to_string(), "m/84'/0'/0'/0/0");
    }

    /// Known-answer test for the first legacy address.
    #[test]
    fn kat_p2pkh() {
        let first = first_address(Network::Mainnet, AddressType::P2pkh);
        assert_eq!(first.address, "1LqBGSKuX5yYUonjxT5qGfpUsXKYYWeabA");
        assert_eq!(first.path.to_string(), "m/44'/0'/0'/0/0");
    }

    /// Known-answer test for the first nested SegWit address.
    #[test]
    fn kat_p2sh_p2wpkh() {
        let first = first_address(Network::Mainnet, AddressType::P2shP2wpkh);
        assert_eq!(first.address, "37VucYSaXLCAsxYyAPfbSi9eh4iEcbShgf");
        assert_eq!(first.path.to_string(), "m/49'/0'/0'/0/0");
    }

    /// Taproot addresses carry the `bc1p` prefix and use purpose 86.
    #[test]
    fn p2tr_prefix() {
        let first = first_address(Network::Mainnet, AddressType::P2tr);
        assert!(first.address.starts_with("bc1p"));
        assert_eq!(first.path.to_string(), "m/86'/0'/0'/0/0");
    }

    /// `derive` must agree with an explicit P2WPKH derivation.
    #[test]
    fn derive_default_is_p2wpkh() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet, Network::Mainnet).unwrap();
        let implicit = deriver.derive(0).unwrap();
        let explicit = deriver.derive_with(AddressType::P2wpkh, 0).unwrap();
        assert_eq!(implicit.address, explicit.address);
    }

    /// Testnet addresses carry the `tb1q` prefix and coin type 1.
    #[test]
    fn testnet_prefix() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet, Network::Testnet).unwrap();
        let first = deriver.derive(0).unwrap();
        assert!(first.address.starts_with("tb1q"));
        assert_eq!(first.path.to_string(), "m/84'/1'/0'/0/0");
    }

    /// Consecutive indices must not produce duplicate addresses.
    #[test]
    fn derive_many_unique() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet, Network::Mainnet).unwrap();
        let batch = deriver.derive_many(0, 5).unwrap();
        assert_eq!(batch.len(), 5);

        let mut distinct = std::collections::HashSet::new();
        for entry in &batch {
            distinct.insert(&entry.address);
        }
        assert_eq!(distinct.len(), 5);
    }

    /// A BIP39 passphrase must lead to a different address set.
    #[test]
    fn passphrase_changes_addresses() {
        let plain_wallet = Wallet::from_mnemonic(TEST_MNEMONIC, None).unwrap();
        let protected_wallet = Wallet::from_mnemonic(TEST_MNEMONIC, Some("password")).unwrap();
        let plain = Deriver::new(&plain_wallet, Network::Mainnet).unwrap();
        let protected = Deriver::new(&protected_wallet, Network::Mainnet).unwrap();
        assert_ne!(
            plain.derive(0).unwrap().address,
            protected.derive(0).unwrap().address
        );
    }
}