1use alloc::{
4 format,
5 string::{String, ToString},
6 vec::Vec,
7};
8use core::fmt;
9use core::str::FromStr;
10
11use alloy_primitives::{Address, keccak256};
12pub use kobe_core::DerivedAccount;
13use kobe_core::{Derive, Wallet};
14
15use crate::Error;
16
/// Layout of the BIP-32 path used to derive Ethereum accounts.
///
/// Every style starts with `m/44'/60'`; they differ in which path level the
/// account index occupies. See [`DerivationStyle::path`] for the exact templates.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
#[non_exhaustive]
pub enum DerivationStyle {
    /// `m/44'/60'/0'/0/{index}` — labelled "Standard (MetaMask/Trezor)".
    /// This is the [`Default`] style.
    #[default]
    Standard,
    /// `m/44'/60'/{index}'/0/0` — the index sits at the hardened account level.
    LedgerLive,
    /// `m/44'/60'/0'/{index}` — also accepted as `legacy` or `mew` when parsing.
    LedgerLegacy,
}
32
33impl DerivationStyle {
34 #[must_use]
36 pub fn path(self, index: u32) -> String {
37 match self {
38 Self::Standard => format!("m/44'/60'/0'/0/{index}"),
39 Self::LedgerLive => format!("m/44'/60'/{index}'/0/0"),
40 Self::LedgerLegacy => format!("m/44'/60'/0'/{index}"),
41 }
42 }
43
44 #[must_use]
46 pub const fn name(self) -> &'static str {
47 match self {
48 Self::Standard => "Standard (MetaMask/Trezor)",
49 Self::LedgerLive => "Ledger Live",
50 Self::LedgerLegacy => "Ledger Legacy",
51 }
52 }
53}
54
55impl fmt::Display for DerivationStyle {
56 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57 f.write_str(self.name())
58 }
59}
60
61impl FromStr for DerivationStyle {
62 type Err = Error;
63
64 fn from_str(s: &str) -> Result<Self, Self::Err> {
65 match s.to_lowercase().as_str() {
66 "standard" | "metamask" | "trezor" | "bip44" => Ok(Self::Standard),
67 "ledger-live" | "ledgerlive" | "live" => Ok(Self::LedgerLive),
68 "ledger-legacy" | "ledgerlegacy" | "legacy" | "mew" => Ok(Self::LedgerLegacy),
69 _ => Err(
70 kobe_core::Error::Bip32Derivation(format!("unknown derivation style: {s}")).into(),
71 ),
72 }
73 }
74}
75
76#[derive(Debug)]
78pub struct Deriver<'a> {
79 wallet: &'a Wallet,
81}
82
83impl<'a> Deriver<'a> {
84 #[must_use]
86 pub const fn new(wallet: &'a Wallet) -> Self {
87 Self { wallet }
88 }
89
90 pub fn derive_with(&self, style: DerivationStyle, index: u32) -> Result<DerivedAccount, Error> {
92 self.derive_at_path(&style.path(index))
93 }
94
95 pub fn derive_many_with(
97 &self,
98 style: DerivationStyle,
99 start: u32,
100 count: u32,
101 ) -> Result<Vec<DerivedAccount>, Error> {
102 let end = start
103 .checked_add(count)
104 .ok_or(kobe_core::Error::IndexOverflow)?;
105 (start..end).map(|i| self.derive_with(style, i)).collect()
106 }
107
108 fn derive_at_path(&self, path: &str) -> Result<DerivedAccount, Error> {
110 let key = kobe_core::bip32::DerivedSecp256k1Key::derive(self.wallet.seed(), path)?;
111 let uncompressed = key.uncompressed_pubkey();
112
113 let addr_hash = keccak256(&uncompressed[1..]);
114 let address = Address::from_slice(&addr_hash[12..]);
115
116 Ok(DerivedAccount::new(
117 path.to_string(),
118 key.private_key_hex(),
119 key.uncompressed_pubkey_hex(),
120 address.to_checksum(None),
121 ))
122 }
123}
124
125impl Derive for Deriver<'_> {
126 type Error = Error;
127
128 fn derive(&self, index: u32) -> Result<DerivedAccount, Error> {
129 self.derive_with(DerivationStyle::Standard, index)
130 }
131
132 fn derive_path(&self, path: &str) -> Result<DerivedAccount, Error> {
133 self.derive_at_path(path)
134 }
135}
136
// Unit and known-answer tests for `Deriver` and `DerivationStyle`.
#[cfg(test)]
#[allow(clippy::unwrap_used)] // tests may panic on unexpected failures
mod tests {
    use kobe_core::DeriveExt;

    use super::*;

    // Standard BIP-39 test mnemonic; the KAT expectations below are derived from it.
    const MNEMONIC: &str = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about";

    // Wallet built from `MNEMONIC` with no passphrase.
    fn wallet() -> Wallet {
        Wallet::from_mnemonic(MNEMONIC, None).unwrap()
    }

    // Default derivation yields a 0x-prefixed, 42-character address on the Standard path.
    #[test]
    fn derive_standard_address() {
        let w = wallet();
        let d = Deriver::new(&w);
        let a = d.derive(0).unwrap();
        assert!(a.address.starts_with("0x"));
        assert_eq!(a.address.len(), 42);
        assert_eq!(a.path, "m/44'/60'/0'/0/0");
    }

    // Two derivers over the same wallet produce the same address.
    #[test]
    fn deterministic() {
        let w = wallet();
        let a = Deriver::new(&w).derive(0).unwrap();
        let b = Deriver::new(&w).derive(0).unwrap();
        assert_eq!(a.address, b.address);
    }

    #[test]
    fn different_indices() {
        let w = wallet();
        let d = Deriver::new(&w);
        assert_ne!(d.derive(0).unwrap().address, d.derive(1).unwrap().address);
    }

    // A BIP-39 passphrase changes the seed and therefore the address.
    #[test]
    fn passphrase_changes_address() {
        let w1 = Wallet::from_mnemonic(MNEMONIC, None).unwrap();
        let w2 = Wallet::from_mnemonic(MNEMONIC, Some("pass")).unwrap();
        assert_ne!(
            Deriver::new(&w1).derive(0).unwrap().address,
            Deriver::new(&w2).derive(0).unwrap().address,
        );
    }

    // Index 1 is used because at index 0 Standard and LedgerLive share the same path.
    #[test]
    fn derivation_styles_produce_different_addresses() {
        let w = wallet();
        let d = Deriver::new(&w);
        let standard = d.derive_with(DerivationStyle::Standard, 1).unwrap();
        let live = d.derive_with(DerivationStyle::LedgerLive, 1).unwrap();
        let legacy = d.derive_with(DerivationStyle::LedgerLegacy, 1).unwrap();
        assert_ne!(standard.address, live.address);
        assert_ne!(standard.address, legacy.address);
        assert_ne!(live.address, legacy.address);
    }

    #[test]
    fn style_paths() {
        assert_eq!(DerivationStyle::Standard.path(0), "m/44'/60'/0'/0/0");
        assert_eq!(DerivationStyle::LedgerLive.path(1), "m/44'/60'/1'/0/0");
        assert_eq!(DerivationStyle::LedgerLegacy.path(2), "m/44'/60'/0'/2");
    }

    // Aliases parse to the expected style; unknown names are rejected.
    #[test]
    fn style_from_str() {
        assert_eq!(
            "standard".parse::<DerivationStyle>().unwrap(),
            DerivationStyle::Standard
        );
        assert_eq!(
            "metamask".parse::<DerivationStyle>().unwrap(),
            DerivationStyle::Standard
        );
        assert_eq!(
            "ledger-live".parse::<DerivationStyle>().unwrap(),
            DerivationStyle::LedgerLive
        );
        assert_eq!(
            "legacy".parse::<DerivationStyle>().unwrap(),
            DerivationStyle::LedgerLegacy
        );
        assert!("bad".parse::<DerivationStyle>().is_err());
    }

    // `derive_many` comes from `DeriveExt`; this checks its count and the Standard paths.
    #[test]
    fn derive_many_returns_correct_count() {
        let w = wallet();
        let d = Deriver::new(&w);
        let accounts = d.derive_many(0, 5).unwrap();
        assert_eq!(accounts.len(), 5);
        for (i, a) in accounts.iter().enumerate() {
            assert_eq!(a.path, format!("m/44'/60'/0'/0/{i}"));
        }
    }

    #[test]
    fn derive_path_custom() {
        let w = wallet();
        let d = Deriver::new(&w);
        let a = d.derive_path("m/44'/60'/0'/0/42").unwrap();
        assert_eq!(a.path, "m/44'/60'/0'/0/42");
        assert!(a.address.starts_with("0x"));
    }

    // Sanity check that alloy's `to_checksum(None)` reproduces a mixed-case EIP-55 address.
    #[test]
    fn eip55_checksum_via_alloy() {
        let addr: Address = "0x5aAeb6053F3E94C9b9A09f33669435E7Ef1BeAed"
            .parse()
            .unwrap();
        assert_eq!(
            addr.to_checksum(None),
            "0x5aAeb6053F3E94C9b9A09f33669435E7Ef1BeAed"
        );
    }

    // Known-answer test: Standard path index 0 for `MNEMONIC` (address and private key).
    #[test]
    fn kat_evm_standard_index0() {
        let w = wallet();
        let a = Deriver::new(&w).derive(0).unwrap();
        assert_eq!(a.address, "0x9858EfFD232B4033E47d90003D41EC34EcaEda94");
        assert_eq!(
            a.private_key.as_str(),
            "1ab42cc412b618bdea3a599e3c9bae199ebf030895b039e9db1e30dafb12b727"
        );
    }

    // Known-answer test: Standard path index 1 for `MNEMONIC`.
    #[test]
    fn kat_evm_standard_index1() {
        let w = wallet();
        let a = Deriver::new(&w).derive(1).unwrap();
        assert_eq!(a.address, "0x6Fac4D18c912343BF86fa7049364Dd4E424Ab9C0");
    }
}