1#[cfg(feature = "alloc")]
4use alloc::{
5 string::{String, ToString},
6 vec::Vec,
7};
8
9use bitcoin::{
10 Address, PrivateKey, PublicKey, bip32::Xpriv, key::CompressedPublicKey, secp256k1::Secp256k1,
11};
12use core::marker::PhantomData;
13use kobe_core::Wallet;
14use zeroize::Zeroizing;
15
16use crate::{AddressType, DerivationPath, Error, Network};
17
18#[derive(Debug)]
23pub struct Deriver<'a> {
24 master_key: Xpriv,
26 network: Network,
28 _wallet: PhantomData<&'a Wallet>,
30}
31
32#[derive(Debug)]
34pub struct DerivedAddress {
35 pub path: DerivationPath,
37 pub private_key_wif: Zeroizing<String>,
39 pub public_key_hex: String,
41 pub address: String,
43 pub address_type: AddressType,
45}
46
47impl<'a> Deriver<'a> {
48 pub fn new(wallet: &'a Wallet, network: Network) -> Result<Self, Error> {
54 let master_key = Xpriv::new_master(network.to_bitcoin_network(), wallet.seed())?;
55
56 Ok(Self {
57 master_key,
58 network,
59 _wallet: PhantomData,
60 })
61 }
62
63 pub fn derive(
76 &self,
77 address_type: AddressType,
78 account: u32,
79 change: bool,
80 address_index: u32,
81 ) -> Result<DerivedAddress, Error> {
82 let path = DerivationPath::bip_standard(
83 address_type,
84 self.network,
85 account,
86 change,
87 address_index,
88 );
89 self.derive_at_path(&path, address_type)
90 }
91
92 pub fn derive_at_path(
98 &self,
99 path: &DerivationPath,
100 address_type: AddressType,
101 ) -> Result<DerivedAddress, Error> {
102 let secp = bitcoin::secp256k1::Secp256k1::new();
103 let derived = self.master_key.derive_priv(&secp, path.inner())?;
104
105 let private_key = PrivateKey::new(derived.private_key, self.network.to_bitcoin_network());
106 let public_key = CompressedPublicKey::from_private_key(&secp, &private_key)
107 .expect("valid private key always produces valid public key");
108
109 let address = Self::create_address(&public_key, self.network, address_type);
110
111 Ok(DerivedAddress {
112 path: path.clone(),
113 private_key_wif: Zeroizing::new(private_key.to_wif()),
114 public_key_hex: public_key.to_string(),
115 address: address.to_string(),
116 address_type,
117 })
118 }
119
120 pub fn derive_many(
126 &self,
127 address_type: AddressType,
128 account: u32,
129 change: bool,
130 start_index: u32,
131 count: u32,
132 ) -> Result<Vec<DerivedAddress>, Error> {
133 (start_index..start_index + count)
134 .map(|index| self.derive(address_type, account, change, index))
135 .collect()
136 }
137
138 fn create_address(
140 public_key: &CompressedPublicKey,
141 network: Network,
142 address_type: AddressType,
143 ) -> Address {
144 let btc_network = network.to_bitcoin_network();
145
146 match address_type {
147 AddressType::P2pkh => Address::p2pkh(PublicKey::from(*public_key), btc_network),
148 AddressType::P2shP2wpkh => Address::p2shwpkh(public_key, btc_network),
149 AddressType::P2wpkh => Address::p2wpkh(public_key, btc_network),
150 AddressType::P2tr => {
151 let secp = Secp256k1::verification_only();
152 let internal_key = public_key.0.x_only_public_key().0;
153 Address::p2tr(&secp, internal_key, None, btc_network)
154 }
155 }
156 }
157
158 #[must_use]
160 pub const fn network(&self) -> Network {
161 self.network
162 }
163}
164
#[cfg(test)]
mod tests {
    use super::*;

    /// Standard BIP39 test mnemonic. With an empty passphrase it reproduces the
    /// published BIP44/49/84/86 test vectors.
    const TEST_MNEMONIC: &str = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about";

    fn test_wallet() -> Wallet {
        Wallet::from_mnemonic(TEST_MNEMONIC, None).unwrap()
    }

    fn mainnet_deriver(wallet: &Wallet) -> Deriver<'_> {
        Deriver::new(wallet, Network::Mainnet).unwrap()
    }

    #[test]
    fn test_derive_p2wpkh() {
        let wallet = test_wallet();
        let deriver = mainnet_deriver(&wallet);
        let addr = deriver.derive(AddressType::P2wpkh, 0, false, 0).unwrap();

        assert_eq!(addr.path.to_string(), "m/84'/0'/0'/0/0");
        // BIP84 vector: account 0, first receiving address.
        assert_eq!(addr.address, "bc1qcr8te4kr609gcawutmrza0j4xv80jy8z306fyu");
        assert_eq!(
            addr.public_key_hex,
            "0330d54fd0dd420a6e5f8d3624f5f3482cae350f79d5f0753bf5beef9c2d91af3c"
        );
        assert_eq!(
            addr.private_key_wif.as_str(),
            "KyZpNDKnfs94vbrwhJneDi77V6jF64PWPF8x5cdJb8ifgg2DUc9d"
        );
    }

    #[test]
    fn test_derive_p2wpkh_change() {
        let wallet = test_wallet();
        let deriver = mainnet_deriver(&wallet);
        let addr = deriver.derive(AddressType::P2wpkh, 0, true, 0).unwrap();

        assert_eq!(addr.path.to_string(), "m/84'/0'/0'/1/0");
        // BIP84 vector: account 0, first change address.
        assert_eq!(addr.address, "bc1q8c6fshw2dlwun7ekn9qwf37cu2rn755upcp6el");
    }

    #[test]
    fn test_derive_p2pkh() {
        let wallet = test_wallet();
        let deriver = mainnet_deriver(&wallet);
        let addr = deriver.derive(AddressType::P2pkh, 0, false, 0).unwrap();

        assert_eq!(addr.path.to_string(), "m/44'/0'/0'/0/0");
        assert_eq!(addr.address, "1LqBGSKuX5yYUonjxT5qGfpUsXKYYWeabA");
    }

    #[test]
    fn test_derive_p2sh() {
        let wallet = test_wallet();
        let deriver = mainnet_deriver(&wallet);
        let addr = deriver
            .derive(AddressType::P2shP2wpkh, 0, false, 0)
            .unwrap();

        assert_eq!(addr.path.to_string(), "m/49'/0'/0'/0/0");
        assert_eq!(addr.address, "37VucYSaXLCAsxYyAPfbSi9eh4iEcbShgf");
    }

    #[test]
    fn test_derive_p2tr() {
        let wallet = test_wallet();
        let deriver = mainnet_deriver(&wallet);
        let addr = deriver.derive(AddressType::P2tr, 0, false, 0).unwrap();

        assert_eq!(addr.path.to_string(), "m/86'/0'/0'/0/0");
        // BIP86 vector: account 0, first receiving address.
        assert_eq!(
            addr.address,
            "bc1p5cyxnuxmeuwuvkwfem96lqzszd02n6xdcjrs20cac6yqjjwudpxqkedrcr"
        );
    }

    #[test]
    fn test_derive_testnet() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet, Network::Testnet).unwrap();
        let addr = deriver.derive(AddressType::P2wpkh, 0, false, 0).unwrap();

        assert!(addr.address.starts_with("tb1q"));
        assert_eq!(addr.path.to_string(), "m/84'/1'/0'/0/0");
    }

    #[test]
    fn test_derive_multiple() {
        let wallet = test_wallet();
        let deriver = mainnet_deriver(&wallet);
        let addrs = deriver
            .derive_many(AddressType::P2wpkh, 0, false, 0, 5)
            .unwrap();

        assert_eq!(addrs.len(), 5);

        // Each batch entry must match the corresponding single derivation.
        for (index, addr) in (0u32..).zip(addrs.iter()) {
            let single = deriver
                .derive(AddressType::P2wpkh, 0, false, index)
                .unwrap();
            assert_eq!(addr.address, single.address);
            assert_eq!(
                addr.path.to_string(),
                alloc::format!("m/84'/0'/0'/0/{}", index)
            );
        }

        let mut seen = alloc::vec::Vec::new();
        for addr in &addrs {
            assert!(!seen.contains(&addr.address));
            seen.push(addr.address.clone());
        }
        assert_eq!(seen.len(), 5);
    }

    #[test]
    fn test_derive_many_empty() {
        let wallet = test_wallet();
        let deriver = mainnet_deriver(&wallet);
        let addrs = deriver
            .derive_many(AddressType::P2wpkh, 0, false, 7, 0)
            .unwrap();

        assert!(addrs.is_empty());
    }

    #[test]
    fn test_passphrase_changes_addresses() {
        let wallet1 = Wallet::from_mnemonic(TEST_MNEMONIC, None).unwrap();
        let wallet2 = Wallet::from_mnemonic(TEST_MNEMONIC, Some("password")).unwrap();

        let deriver1 = Deriver::new(&wallet1, Network::Mainnet).unwrap();
        let deriver2 = Deriver::new(&wallet2, Network::Mainnet).unwrap();

        let addr1 = deriver1.derive(AddressType::P2wpkh, 0, false, 0).unwrap();
        let addr2 = deriver2.derive(AddressType::P2wpkh, 0, false, 0).unwrap();

        assert_ne!(addr1.address, addr2.address);
    }
}