1use alloc::string::{String, ToString};
4use alloc::vec::Vec;
5
6use ed25519_dalek::VerifyingKey;
7use kobe_core::slip10::DerivedKey;
8use kobe_core::{Derive, DerivedAccount, Wallet};
9use zeroize::Zeroizing;
10
11use crate::Error;
12use crate::derivation_style::DerivationStyle;
13
14#[derive(Debug, Clone)]
16#[non_exhaustive]
17pub struct DerivedAddress {
18 pub path: String,
20 pub private_key_hex: Zeroizing<String>,
22 pub keypair_base58: Zeroizing<String>,
26 pub public_key_hex: String,
28 pub address: String,
30}
31
32#[derive(Debug)]
37pub struct Deriver<'a> {
38 wallet: &'a Wallet,
40}
41
42impl<'a> Deriver<'a> {
43 #[inline]
45 #[must_use]
46 pub const fn new(wallet: &'a Wallet) -> Self {
47 Self { wallet }
48 }
49
50 #[inline]
62 pub fn derive(&self, index: u32) -> Result<DerivedAddress, Error> {
63 self.derive_with(DerivationStyle::Standard, index)
64 }
65
66 #[allow(deprecated)]
83 pub fn derive_with(&self, style: DerivationStyle, index: u32) -> Result<DerivedAddress, Error> {
84 let path = style.path(index);
85 let derived = DerivedKey::derive_path(self.wallet.seed(), &path)?;
86 Ok(build_derived_address(&derived, path))
87 }
88
89 #[inline]
100 pub fn derive_many(&self, start: u32, count: u32) -> Result<Vec<DerivedAddress>, Error> {
101 self.derive_many_with(DerivationStyle::Standard, start, count)
102 }
103
104 pub fn derive_many_with(
116 &self,
117 style: DerivationStyle,
118 start: u32,
119 count: u32,
120 ) -> Result<Vec<DerivedAddress>, Error> {
121 let end = start
122 .checked_add(count)
123 .ok_or(kobe_core::Error::IndexOverflow)?;
124 (start..end)
125 .map(|index| self.derive_with(style, index))
126 .collect()
127 }
128
129 pub fn derive_path(&self, path: &str) -> Result<DerivedAddress, Error> {
145 let derived = DerivedKey::derive_path(self.wallet.seed(), path)?;
146 Ok(build_derived_address(&derived, path.to_string()))
147 }
148}
149
150impl Derive for Deriver<'_> {
151 type Error = Error;
152
153 fn derive(&self, index: u32) -> Result<DerivedAccount, Error> {
154 let da = self.derive_with(DerivationStyle::Standard, index)?;
155 Ok(DerivedAccount::new(
156 da.path,
157 da.private_key_hex,
158 da.public_key_hex,
159 da.address,
160 ))
161 }
162
163 fn derive_path(&self, path: &str) -> Result<DerivedAccount, Error> {
164 let da = Deriver::derive_path(self, path)?;
165 Ok(DerivedAccount::new(
166 da.path,
167 da.private_key_hex,
168 da.public_key_hex,
169 da.address,
170 ))
171 }
172}
173
174fn build_derived_address(derived: &DerivedKey, path: String) -> DerivedAddress {
176 let signing_key = derived.to_signing_key();
177 let verifying_key: VerifyingKey = signing_key.verifying_key();
178 let public_key_bytes = verifying_key.as_bytes();
179
180 let mut keypair_bytes = Zeroizing::new([0u8; 64]);
181 keypair_bytes[..32].copy_from_slice(derived.private_key.as_slice());
182 keypair_bytes[32..].copy_from_slice(public_key_bytes);
183 let keypair_b58 = bs58::encode(&*keypair_bytes).into_string();
184
185 DerivedAddress {
186 path,
187 private_key_hex: Zeroizing::new(hex::encode(derived.private_key.as_slice())),
188 keypair_base58: Zeroizing::new(keypair_b58),
189 public_key_hex: hex::encode(public_key_bytes),
190 address: bs58::encode(public_key_bytes).into_string(),
191 }
192}
193
#[cfg(test)]
#[allow(deprecated, clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Wallet built from the all-`abandon ... about` test mnemonic with no passphrase.
    fn test_wallet() -> Wallet {
        Wallet::from_mnemonic(
            "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about",
            None,
        )
        .unwrap()
    }

    #[test]
    fn test_derive_address() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);
        let addr = deriver.derive(0).unwrap();

        // Base58 of a 32-byte public key is 32 to 44 characters long.
        assert!(addr.address.len() >= 32 && addr.address.len() <= 44);
        assert_eq!(addr.path, "m/44'/501'/0'/0'");
    }

    #[test]
    fn test_derive_many() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);
        let addresses = deriver.derive_many(0, 3).unwrap();

        assert_eq!(addresses.len(), 3);
        assert_eq!(addresses[0].path, "m/44'/501'/0'/0'");
        assert_eq!(addresses[1].path, "m/44'/501'/1'/0'");
        assert_eq!(addresses[2].path, "m/44'/501'/2'/0'");

        assert_ne!(addresses[0].address, addresses[1].address);
        assert_ne!(addresses[1].address, addresses[2].address);
    }

    #[test]
    fn test_deterministic_derivation() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let addr1 = deriver.derive(0).unwrap();
        let addr2 = deriver.derive(0).unwrap();

        assert_eq!(addr1.address, addr2.address);
        assert_eq!(*addr1.private_key_hex, *addr2.private_key_hex);
    }

    #[test]
    fn test_derive_with_trust() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);
        let addr = deriver.derive_with(DerivationStyle::Trust, 0).unwrap();

        assert_eq!(addr.path, "m/44'/501'/0'");
        assert!(addr.address.len() >= 32 && addr.address.len() <= 44);
    }

    #[test]
    fn test_derive_with_ledger_live() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);
        let addr = deriver.derive_with(DerivationStyle::LedgerLive, 0).unwrap();

        assert_eq!(addr.path, "m/44'/501'/0'/0'/0'");
        assert!(addr.address.len() >= 32 && addr.address.len() <= 44);
    }

    #[test]
    fn test_different_styles_produce_different_addresses() {
        let wallet = test_wallet();
        let deriver = Deriver::new(&wallet);

        let standard = deriver.derive_with(DerivationStyle::Standard, 0).unwrap();
        let trust = deriver.derive_with(DerivationStyle::Trust, 0).unwrap();
        let ledger_live = deriver.derive_with(DerivationStyle::LedgerLive, 0).unwrap();
        let legacy = deriver.derive_with(DerivationStyle::Legacy, 0).unwrap();

        assert_ne!(standard.address, trust.address);
        assert_ne!(standard.address, ledger_live.address);
        assert_ne!(standard.address, legacy.address);
        assert_ne!(trust.address, ledger_live.address);
    }

    /// Checks that every encoding in `DerivedAddress` describes the same key pair.
    #[test]
    fn test_key_encodings_are_consistent() {
        let wallet = test_wallet();
        let addr = Deriver::new(&wallet).derive(0).unwrap();

        let secret = hex::decode(addr.private_key_hex.as_str()).unwrap();
        let public = hex::decode(&addr.public_key_hex).unwrap();
        assert_eq!(secret.len(), 32);
        assert_eq!(public.len(), 32);

        // The base58 keypair is the 32-byte secret followed by the 32-byte public key.
        let keypair = bs58::decode(addr.keypair_base58.as_str()).into_vec().unwrap();
        assert_eq!(keypair.len(), 64);
        assert_eq!(&keypair[..32], secret.as_slice());
        assert_eq!(&keypair[32..], public.as_slice());

        // The address is the base58 encoding of the public key.
        assert_eq!(addr.address, bs58::encode(&public).into_string());
    }

    /// Known-answer test pinning the Standard-path address for the test mnemonic.
    #[test]
    fn kat_solana_standard_index0() {
        let wallet = test_wallet();
        let addr = Deriver::new(&wallet).derive(0).unwrap();
        assert_eq!(addr.address, "HAgk14JpMQLgt6rVgv7cBQFJWFto5Dqxi472uT3DKpqk");
    }
}