1use alloc::format;
4use alloc::string::String;
5use alloc::vec::Vec;
6use core::fmt;
7use core::str::FromStr;
8
9use alloy_primitives::{Address, keccak256};
10use kobe_primitives::{
11 DerivationStyle as _,
15 Derive,
16 DeriveError,
17 DerivedAccount,
18 DerivedPublicKey,
19 ParseDerivationStyleError,
20 Wallet,
21 derive_range,
22};
23
/// Ethereum (coin type 60) derivation-path layouts supported by [`Deriver`].
///
/// Each variant places the account index at a different level of the BIP-32
/// path; see `kobe_primitives::DerivationStyle::path` for the exact shapes.
/// Parsed case-insensitively from strings via [`FromStr`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
#[non_exhaustive]
pub enum DerivationStyle {
    /// `m/44'/60'/0'/0/{index}` — the default layout, displayed as
    /// "Standard (MetaMask/Trezor)". Parsed from `standard`, `metamask`,
    /// `trezor` or `bip44`.
    #[default]
    Standard,
    /// `m/44'/60'/{index}'/0/0` — index in the hardened account level.
    /// Parsed from `ledger-live`, `ledgerlive` or `live`.
    LedgerLive,
    /// `m/44'/60'/0'/{index}` — four-level path with the index last.
    /// Parsed from `ledger-legacy`, `ledgerlegacy`, `legacy` or `mew`.
    LedgerLegacy,
}
45
/// Every [`DerivationStyle`] variant in declaration order.
///
/// Returned by `kobe_primitives::DerivationStyle::all`; extend this list
/// whenever a variant is added to the enum.
const ALL_STYLES: &[DerivationStyle] = &[
    DerivationStyle::Standard,
    DerivationStyle::LedgerLive,
    DerivationStyle::LedgerLegacy,
];
53
/// Spellings accepted by `FromStr for DerivationStyle`, passed to
/// [`ParseDerivationStyleError`] so a failed parse can list the valid options.
///
/// Must stay in sync with the spellings `from_str` actually recognises.
const ACCEPTED_TOKENS: &[&str] = &[
    // DerivationStyle::Standard
    "standard",
    "metamask",
    "trezor",
    "bip44",
    // DerivationStyle::LedgerLive
    "ledger-live",
    "ledgerlive",
    "live",
    // DerivationStyle::LedgerLegacy
    "ledger-legacy",
    "ledgerlegacy",
    "legacy",
    "mew",
];
69
70impl kobe_primitives::DerivationStyle for DerivationStyle {
71 fn path(self, index: u32) -> String {
72 match self {
73 Self::Standard => format!("m/44'/60'/0'/0/{index}"),
74 Self::LedgerLive => format!("m/44'/60'/{index}'/0/0"),
75 Self::LedgerLegacy => format!("m/44'/60'/0'/{index}"),
76 }
77 }
78
79 fn name(self) -> &'static str {
80 match self {
81 Self::Standard => "Standard (MetaMask/Trezor)",
82 Self::LedgerLive => "Ledger Live",
83 Self::LedgerLegacy => "Ledger Legacy",
84 }
85 }
86
87 fn all() -> &'static [Self] {
88 ALL_STYLES
89 }
90}
91
92impl fmt::Display for DerivationStyle {
93 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
94 f.write_str(<Self as kobe_primitives::DerivationStyle>::name(*self))
95 }
96}
97
98impl FromStr for DerivationStyle {
99 type Err = ParseDerivationStyleError;
100
101 fn from_str(s: &str) -> Result<Self, Self::Err> {
102 match s.to_lowercase().as_str() {
103 "standard" | "metamask" | "trezor" | "bip44" => Ok(Self::Standard),
104 "ledger-live" | "ledgerlive" | "live" => Ok(Self::LedgerLive),
105 "ledger-legacy" | "ledgerlegacy" | "legacy" | "mew" => Ok(Self::LedgerLegacy),
106 _ => Err(ParseDerivationStyleError::new(
107 "ethereum",
108 s,
109 ACCEPTED_TOKENS,
110 )),
111 }
112 }
113}
114
115#[derive(Debug)]
117pub struct Deriver<'a> {
118 wallet: &'a Wallet,
120}
121
122impl<'a> Deriver<'a> {
123 #[must_use]
125 pub const fn new(wallet: &'a Wallet) -> Self {
126 Self { wallet }
127 }
128
129 pub fn derive_with(
135 &self,
136 style: DerivationStyle,
137 index: u32,
138 ) -> Result<DerivedAccount, DeriveError> {
139 self.derive_at(&style.path(index))
140 }
141
142 pub fn derive_many_with(
148 &self,
149 style: DerivationStyle,
150 start: u32,
151 count: u32,
152 ) -> Result<Vec<DerivedAccount>, DeriveError> {
153 derive_range(start, count, |i| self.derive_with(style, i))
154 }
155
156 pub fn derive_at(&self, path: &str) -> Result<DerivedAccount, DeriveError> {
162 let key = self.wallet.derive_secp256k1(path)?;
163 let uncompressed = key.uncompressed_pubkey();
164
165 let addr_hash = keccak256(&uncompressed[1..]);
166 let (_, addr_bytes) = addr_hash.split_at(12);
167 let address = Address::from_slice(addr_bytes);
168
169 Ok(DerivedAccount::new(
170 String::from(path),
171 key.private_key_bytes(),
172 DerivedPublicKey::Secp256k1Uncompressed(uncompressed),
173 address.to_checksum(None),
174 ))
175 }
176}
177
178impl Derive for Deriver<'_> {
179 type Account = DerivedAccount;
180 type Error = DeriveError;
181
182 fn derive(&self, index: u32) -> Result<DerivedAccount, DeriveError> {
183 self.derive_with(DerivationStyle::Standard, index)
184 }
185
186 fn derive_path(&self, path: &str) -> Result<DerivedAccount, DeriveError> {
187 self.derive_at(path)
188 }
189}
190
#[cfg(test)]
mod tests {
    use kobe_primitives::DeriveExt;

    use super::*;

    /// BIP-39 "abandon ... about" test mnemonic.
    const MNEMONIC_ABANDON: &str = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about";

    /// Hardhat / Anvil default development mnemonic.
    const MNEMONIC_HARDHAT: &str = "test test test test test test test test test test test junk";

    fn wallet() -> Wallet {
        Wallet::from_mnemonic(MNEMONIC_ABANDON, None).unwrap()
    }

    /// Known-answer test against Hardhat's default account #0.
    #[test]
    fn kat_evm_hardhat_default_index0() {
        let w = Wallet::from_mnemonic(MNEMONIC_HARDHAT, None).unwrap();
        let a = Deriver::new(&w).derive(0).unwrap();
        assert_eq!(a.path(), "m/44'/60'/0'/0/0");
        assert_eq!(a.address(), "0xf39Fd6e51aad88F6F4ce6aB8827279cffFb92266");
        assert_eq!(
            a.private_key_hex().as_str(),
            "ac0974bec39a17e36ba4a6b4d238ff944bacb478cbed5efcae784d7bf4f2ff80"
        );
    }

    /// Known-answer test against Hardhat's default account #1.
    #[test]
    fn kat_evm_hardhat_default_index1() {
        let w = Wallet::from_mnemonic(MNEMONIC_HARDHAT, None).unwrap();
        let a = Deriver::new(&w).derive(1).unwrap();
        assert_eq!(a.path(), "m/44'/60'/0'/0/1");
        assert_eq!(a.address(), "0x70997970C51812dc3A010C7d01b50e0d17dc79C8");
        assert_eq!(
            a.private_key_hex().as_str(),
            "59c6995e998f97a5a0044966f0945389dc9e86dae88c7a8412f4603b6b78690d"
        );
    }

    /// Known-answer test for the "abandon ... about" mnemonic, index 0.
    #[test]
    fn kat_evm_abandon_index0() {
        let a = Deriver::new(&wallet()).derive(0).unwrap();
        assert_eq!(a.path(), "m/44'/60'/0'/0/0");
        assert_eq!(a.address(), "0x9858EfFD232B4033E47d90003D41EC34EcaEda94");
        assert_eq!(
            a.private_key_hex().as_str(),
            "1ab42cc412b618bdea3a599e3c9bae199ebf030895b039e9db1e30dafb12b727"
        );
    }

    /// Known-answer test for the "abandon ... about" mnemonic, index 1.
    #[test]
    fn kat_evm_abandon_index1() {
        let a = Deriver::new(&wallet()).derive(1).unwrap();
        assert_eq!(a.path(), "m/44'/60'/0'/0/1");
        assert_eq!(a.address(), "0x6Fac4D18c912343BF86fa7049364Dd4E424Ab9C0");
        assert_eq!(
            a.private_key_hex().as_str(),
            "9a983cb3d832fbde5ab49d692b7a8bf5b5d232479c99333d0fc8e1d21f1b55b6"
        );
    }

    /// The three layouts place the index at different path levels, so the
    /// same index must yield three different accounts.
    #[test]
    fn derivation_styles_produce_distinct_addresses() {
        let w = wallet();
        let d = Deriver::new(&w);
        let standard = d.derive_with(DerivationStyle::Standard, 1).unwrap();
        let live = d.derive_with(DerivationStyle::LedgerLive, 1).unwrap();
        let legacy = d.derive_with(DerivationStyle::LedgerLegacy, 1).unwrap();
        assert_eq!(standard.path(), "m/44'/60'/0'/0/1");
        assert_eq!(live.path(), "m/44'/60'/1'/0/0");
        assert_eq!(legacy.path(), "m/44'/60'/0'/1");
        assert_ne!(standard.address(), live.address());
        assert_ne!(standard.address(), legacy.address());
        assert_ne!(live.address(), legacy.address());
    }

    #[test]
    fn derivation_style_path_shapes() {
        assert_eq!(DerivationStyle::Standard.path(0), "m/44'/60'/0'/0/0");
        assert_eq!(DerivationStyle::LedgerLive.path(1), "m/44'/60'/1'/0/0");
        assert_eq!(DerivationStyle::LedgerLegacy.path(2), "m/44'/60'/0'/2");
    }

    #[test]
    fn derivation_style_from_str_accepts_aliases() {
        assert_eq!(
            "standard".parse::<DerivationStyle>().unwrap(),
            DerivationStyle::Standard
        );
        assert_eq!(
            "metamask".parse::<DerivationStyle>().unwrap(),
            DerivationStyle::Standard
        );
        assert_eq!(
            "ledger-live".parse::<DerivationStyle>().unwrap(),
            DerivationStyle::LedgerLive
        );
        assert_eq!(
            "legacy".parse::<DerivationStyle>().unwrap(),
            DerivationStyle::LedgerLegacy
        );
        // Matching is case-insensitive.
        assert_eq!(
            "MetaMask".parse::<DerivationStyle>().unwrap(),
            DerivationStyle::Standard
        );
        assert_eq!(
            "Ledger-Live".parse::<DerivationStyle>().unwrap(),
            DerivationStyle::LedgerLive
        );
        assert!("definitely-not-a-style".parse::<DerivationStyle>().is_err());
        assert!("".parse::<DerivationStyle>().is_err());
    }

    /// Every token advertised in parse errors must actually parse, in any
    /// ASCII case. This guards against the alias list and `from_str` drifting.
    #[test]
    fn every_accepted_token_parses_in_any_ascii_case() {
        for token in ACCEPTED_TOKENS {
            assert!(
                token.parse::<DerivationStyle>().is_ok(),
                "accepted token {token:?} failed to parse"
            );
            assert!(
                token.to_ascii_uppercase().parse::<DerivationStyle>().is_ok(),
                "upper-cased token {token:?} failed to parse"
            );
        }
    }

    /// `all()` must list each variant exactly once.
    #[test]
    fn all_styles_lists_each_variant_once() {
        let all = <DerivationStyle as kobe_primitives::DerivationStyle>::all();
        let expected = [
            DerivationStyle::Standard,
            DerivationStyle::LedgerLive,
            DerivationStyle::LedgerLegacy,
        ];
        assert_eq!(all.len(), expected.len());
        for style in expected {
            assert_eq!(all.iter().filter(|&&s| s == style).count(), 1);
        }
    }

    #[test]
    fn derive_many_matches_individual() {
        let w = wallet();
        let d = Deriver::new(&w);
        let batch = d.derive_many(0, 5).unwrap();
        let single: Vec<_> = (0..5).map(|i| d.derive(i).unwrap()).collect();
        // `zip` stops at the shorter side, so check lengths explicitly; otherwise
        // a short (or empty) batch would pass silently.
        assert_eq!(batch.len(), single.len());
        for (b, s) in batch.iter().zip(single.iter()) {
            assert_eq!(b.address(), s.address());
            assert_eq!(b.path(), s.path());
        }
    }

    /// Batch derivation from a non-zero start agrees with one-by-one
    /// derivation for every layout.
    #[test]
    fn derive_many_with_matches_individual_for_each_style() {
        let w = wallet();
        let d = Deriver::new(&w);
        for &style in <DerivationStyle as kobe_primitives::DerivationStyle>::all() {
            let batch = d.derive_many_with(style, 2, 3).unwrap();
            assert_eq!(batch.len(), 3);
            for (index, account) in (2..).zip(&batch) {
                let single = d.derive_with(style, index).unwrap();
                assert_eq!(account.path(), single.path());
                assert_eq!(account.address(), single.address());
            }
        }
    }

    #[test]
    fn passphrase_changes_derivation() {
        let w = Wallet::from_mnemonic(MNEMONIC_ABANDON, Some("TREZOR")).unwrap();
        assert_ne!(
            Deriver::new(&wallet()).derive(0).unwrap().address(),
            Deriver::new(&w).derive(0).unwrap().address(),
        );
    }
}