1use std::collections::HashMap;
2
3use alloy::primitives::B256;
4use alloy::signers::local::coins_bip39::{English, Mnemonic};
5use alloy::signers::local::PrivateKeySigner;
6use alloy::signers::SignerSync;
7use serde::{Deserialize, Serialize};
8
9use crate::{Keyring, KeyringAccount, KeyringError};
10
/// Parent BIP-32 path used when `HdKeyring::from_mnemonic` is called without an explicit
/// `hd_path`. Coin type 60 is Ethereum's SLIP-44 coin type.
const DEFAULT_HD_PATH: &str = "m/44'/60'/0'/0";

/// Persisted form of an `HdKeyring`, written by `Keyring::serialize` via `serde_json`.
///
/// Only the inputs needed to rebuild the keyring are stored. The derived signers
/// themselves are not stored; `deserialize` re-derives them from these fields.
/// The mnemonic is stored in plain text.
#[derive(Serialize, Deserialize)]
struct HdKeyringState {
    // BIP-39 phrase the accounts are derived from (secret material).
    mnemonic: String,
    // How many child indices have been derived; `deserialize` re-derives that many.
    number_of_accounts: usize,
    // Parent derivation path recorded on the keyring (what `HdKeyring::hd_path` returns).
    hd_path: String,
}
20
21pub struct HdKeyring {
22 mnemonic: String,
23 hd_path: String,
24 number_of_accounts: usize,
25 accounts: Vec<(String, PrivateKeySigner)>,
27 address_index: HashMap<String, usize>,
29}
30
31impl HdKeyring {
32 pub fn new(word_count: usize) -> Result<Self, KeyringError> {
36 let mut rng = rand::thread_rng();
37 let phrase = match word_count {
38 12 => Mnemonic::<English>::new_with_count(&mut rng, 12)
39 .map_err(|e| KeyringError::InvalidKey(format!("failed to generate mnemonic: {e}")))?
40 .to_phrase(),
41 24 => Mnemonic::<English>::new_with_count(&mut rng, 24)
42 .map_err(|e| KeyringError::InvalidKey(format!("failed to generate mnemonic: {e}")))?
43 .to_phrase(),
44 _ => {
45 return Err(KeyringError::InvalidKey(
46 "word_count must be 12 or 24".to_string(),
47 ))
48 }
49 };
50 Self::from_mnemonic(&phrase, None)
51 }
52
53 pub fn from_mnemonic(mnemonic: &str, hd_path: Option<&str>) -> Result<Self, KeyringError> {
55 motosan_wallet_core::mnemonic_to_signer(mnemonic, 0)
57 .map_err(|e| KeyringError::InvalidKey(format!("invalid mnemonic: {e}")))?;
58
59 let path = hd_path.unwrap_or(DEFAULT_HD_PATH).to_string();
60
61 Ok(Self {
62 mnemonic: mnemonic.to_string(),
63 hd_path: path,
64 number_of_accounts: 0,
65 accounts: Vec::new(),
66 address_index: HashMap::new(),
67 })
68 }
69
70 pub fn mnemonic(&self) -> &str {
72 &self.mnemonic
73 }
74
75 pub fn hd_path(&self) -> &str {
77 &self.hd_path
78 }
79
80 pub fn derive_accounts(&mut self, n: usize) -> Result<Vec<KeyringAccount>, KeyringError> {
82 let start = self.number_of_accounts;
83 let mut new_accounts = Vec::with_capacity(n);
84
85 for i in start..(start + n) {
86 let signer = motosan_wallet_core::mnemonic_to_signer(&self.mnemonic, i as u32)
87 .map_err(|e| {
88 KeyringError::InvalidKey(format!(
89 "failed to derive key at {}/{i}: {e}",
90 self.hd_path
91 ))
92 })?;
93
94 let address = format!("0x{}", hex::encode(signer.address().as_slice()));
95
96 let idx = self.accounts.len();
97 self.accounts.push((address.clone(), signer));
98 self.address_index.insert(address.clone(), idx);
99
100 new_accounts.push(KeyringAccount {
101 address,
102 label: None,
103 });
104 }
105
106 self.number_of_accounts = start + n;
107 Ok(new_accounts)
108 }
109}
110
111impl Keyring for HdKeyring {
112 fn keyring_type(&self) -> &str {
113 "hd"
114 }
115
116 fn serialize(&self) -> Result<Vec<u8>, KeyringError> {
117 let state = HdKeyringState {
118 mnemonic: self.mnemonic.clone(),
119 number_of_accounts: self.number_of_accounts,
120 hd_path: self.hd_path.clone(),
121 };
122 serde_json::to_vec(&state).map_err(|e| KeyringError::SerializationError(e.to_string()))
123 }
124
125 fn deserialize(data: &[u8]) -> Result<Self, KeyringError> {
126 let state: HdKeyringState = serde_json::from_slice(data)
127 .map_err(|e| KeyringError::SerializationError(e.to_string()))?;
128 let mut keyring = Self::from_mnemonic(&state.mnemonic, Some(&state.hd_path))?;
129 if state.number_of_accounts > 0 {
130 keyring.derive_accounts(state.number_of_accounts)?;
131 }
132 Ok(keyring)
133 }
134
135 fn add_accounts(
138 &mut self,
139 private_keys: &[String],
140 ) -> Result<Vec<KeyringAccount>, KeyringError> {
141 self.derive_accounts(private_keys.len())
142 }
143
144 fn get_accounts(&self) -> Vec<KeyringAccount> {
145 self.accounts
146 .iter()
147 .map(|(addr, _)| KeyringAccount {
148 address: addr.clone(),
149 label: None,
150 })
151 .collect()
152 }
153
154 fn export_account(&self, address: &str) -> Result<String, KeyringError> {
155 let addr = address.to_lowercase();
156 let idx = self
157 .address_index
158 .get(&addr)
159 .ok_or_else(|| KeyringError::AccountNotFound(addr.clone()))?;
160 let (_, ref signer) = self.accounts[*idx];
161 Ok(format!("0x{}", hex::encode(signer.to_bytes())))
162 }
163
164 fn remove_account(&mut self, address: &str) -> Result<(), KeyringError> {
165 let addr = address.to_lowercase();
166 let idx = self
167 .address_index
168 .remove(&addr)
169 .ok_or_else(|| KeyringError::AccountNotFound(addr.clone()))?;
170 self.accounts.remove(idx);
171 self.address_index.clear();
173 for (i, (a, _)) in self.accounts.iter().enumerate() {
174 self.address_index.insert(a.clone(), i);
175 }
176 Ok(())
177 }
178
179 fn sign_hash(&self, address: &str, hash: &[u8; 32]) -> Result<[u8; 65], KeyringError> {
180 let addr = address.to_lowercase();
181 let idx = self
182 .address_index
183 .get(&addr)
184 .ok_or_else(|| KeyringError::AccountNotFound(addr.clone()))?;
185 let (_, ref signer) = self.accounts[*idx];
186
187 let b256 = B256::from(*hash);
188 let alloy_sig = signer
189 .sign_hash_sync(&b256)
190 .map_err(|e| KeyringError::SigningError(e.to_string()))?;
191
192 let r_bytes: [u8; 32] = alloy_sig.r().to_be_bytes();
193 let s_bytes: [u8; 32] = alloy_sig.s().to_be_bytes();
194 let v: u8 = alloy_sig.v() as u8; let mut sig = [0u8; 65];
197 sig[..32].copy_from_slice(&r_bytes);
198 sig[32..64].copy_from_slice(&s_bytes);
199 sig[64] = v;
200 Ok(sig)
201 }
202}
203
#[cfg(test)]
mod tests {
    use super::*;
    use k256::ecdsa::{RecoveryId, Signature, SigningKey, VerifyingKey};
    use sha3::{Digest, Keccak256};

    /// Standard BIP-39 test phrase (all-zero entropy).
    const TEST_MNEMONIC: &str =
        "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about";

    /// Known address of `TEST_MNEMONIC` at `m/44'/60'/0'/0/0`.
    const EXPECTED_ADDR_0: &str = "0x9858effd232b4033e47d90003d41ec34ecaeda94";

    /// Keyring restored from `TEST_MNEMONIC` on the default path, with no accounts yet.
    fn test_keyring() -> HdKeyring {
        HdKeyring::from_mnemonic(TEST_MNEMONIC, None).unwrap()
    }

    /// Ethereum address of a public key, independently of the code under test.
    ///
    /// It is the last 20 bytes of keccak256 over the uncompressed SEC1 point with its
    /// leading tag byte dropped, rendered as lower-case `0x…` hex.
    fn eth_address(vk: &VerifyingKey) -> String {
        let point = vk.to_encoded_point(false);
        let digest = Keccak256::digest(&point.as_bytes()[1..]);
        format!("0x{}", hex::encode(&digest[12..]))
    }

    /// Parses an exported `0x…` private key and returns the address it controls.
    fn address_of_exported_key(exported: &str) -> String {
        let bytes = hex::decode(exported.strip_prefix("0x").unwrap()).unwrap();
        let sk = SigningKey::from_bytes(bytes.as_slice().into()).unwrap();
        eth_address(&VerifyingKey::from(&sk))
    }

    #[test]
    fn test_generate_mnemonic_12_words() {
        let kr = HdKeyring::new(12).unwrap();
        assert_eq!(kr.mnemonic().split_whitespace().count(), 12);
    }

    #[test]
    fn test_generate_mnemonic_24_words() {
        let kr = HdKeyring::new(24).unwrap();
        assert_eq!(kr.mnemonic().split_whitespace().count(), 24);
    }

    #[test]
    fn test_invalid_word_count() {
        assert!(HdKeyring::new(15).is_err());
    }

    #[test]
    fn test_from_mnemonic_known_vector() {
        let mut kr = test_keyring();
        let accounts = kr.derive_accounts(1).unwrap();
        assert_eq!(accounts.len(), 1);
        assert_eq!(accounts[0].address, EXPECTED_ADDR_0);
    }

    #[test]
    fn test_deterministic_derivation() {
        let mut kr1 = test_keyring();
        let mut kr2 = test_keyring();
        let accounts1 = kr1.derive_accounts(3).unwrap();
        let accounts2 = kr2.derive_accounts(3).unwrap();
        assert_eq!(accounts1.len(), accounts2.len());
        for (a, b) in accounts1.iter().zip(&accounts2) {
            assert_eq!(a.address, b.address);
        }
    }

    #[test]
    fn test_derive_multiple_accounts() {
        let mut kr = test_keyring();
        let accounts = kr.derive_accounts(5).unwrap();
        assert_eq!(accounts.len(), 5);
        let mut addrs: Vec<&str> = accounts.iter().map(|a| a.address.as_str()).collect();
        addrs.sort();
        addrs.dedup();
        assert_eq!(addrs.len(), 5, "derived addresses must be distinct");
    }

    #[test]
    fn test_incremental_derivation() {
        let mut kr = test_keyring();
        let first = kr.derive_accounts(2).unwrap();
        let second = kr.derive_accounts(2).unwrap();
        assert_eq!(kr.get_accounts().len(), 4);
        // The second batch must continue after the first, not restart at index 0.
        assert_ne!(first[0].address, second[0].address);
    }

    #[test]
    fn test_add_accounts_derives_by_count() {
        let mut kr = test_keyring();
        let accounts = kr
            .add_accounts(&["ignored".to_string(), "also_ignored".to_string()])
            .unwrap();
        assert_eq!(accounts.len(), 2);
        assert_eq!(accounts[0].address, EXPECTED_ADDR_0);
    }

    #[test]
    fn test_serialize_deserialize_roundtrip() {
        let mut kr = test_keyring();
        kr.derive_accounts(3).unwrap();
        let original_accounts = kr.get_accounts();

        let data = kr.serialize().unwrap();
        let kr2 = HdKeyring::deserialize(&data).unwrap();
        let restored_accounts = kr2.get_accounts();

        assert_eq!(original_accounts.len(), restored_accounts.len());
        for (a, b) in original_accounts.iter().zip(&restored_accounts) {
            assert_eq!(a.address, b.address);
        }

        // Same addresses must also mean the same private keys.
        for acc in &original_accounts {
            let key1 = kr.export_account(&acc.address).unwrap();
            let key2 = kr2.export_account(&acc.address).unwrap();
            assert_eq!(key1, key2);
        }
    }

    #[test]
    fn test_export_account() {
        let mut kr = test_keyring();
        let accounts = kr.derive_accounts(1).unwrap();
        let exported = kr.export_account(&accounts[0].address).unwrap();

        assert!(exported.starts_with("0x"));
        assert_eq!(exported.len(), 66); // "0x" + 32 bytes of hex
        assert_eq!(address_of_exported_key(&exported), accounts[0].address);
    }

    #[test]
    fn test_export_account_not_found() {
        let kr = test_keyring();
        let result = kr.export_account("0xdeadbeef");
        assert!(matches!(result, Err(KeyringError::AccountNotFound(_))));
    }

    #[test]
    fn test_remove_account() {
        let mut kr = test_keyring();
        kr.derive_accounts(3).unwrap();
        let accounts = kr.get_accounts();
        assert_eq!(accounts.len(), 3);

        kr.remove_account(&accounts[1].address).unwrap();
        let remaining = kr.get_accounts();
        assert_eq!(remaining.len(), 2);
        assert_eq!(remaining[0].address, accounts[0].address);
        assert_eq!(remaining[1].address, accounts[2].address);
    }

    #[test]
    fn test_remove_account_keeps_remaining_lookups_valid() {
        let mut kr = test_keyring();
        let accounts = kr.derive_accounts(3).unwrap();
        kr.remove_account(&accounts[0].address).unwrap();

        let gone = kr.export_account(&accounts[0].address);
        assert!(matches!(gone, Err(KeyringError::AccountNotFound(_))));

        // Later accounts moved down in storage; each must still resolve to its own key.
        for acc in &accounts[1..] {
            let exported = kr.export_account(&acc.address).unwrap();
            assert_eq!(address_of_exported_key(&exported), acc.address);
        }
    }

    #[test]
    fn test_remove_account_not_found() {
        let mut kr = test_keyring();
        let result = kr.remove_account("0xnonexistent");
        assert!(matches!(result, Err(KeyringError::AccountNotFound(_))));
    }

    #[test]
    fn test_lookup_ignores_address_case() {
        let mut kr = test_keyring();
        let accounts = kr.derive_accounts(1).unwrap();
        let upper = accounts[0].address.to_uppercase();
        assert!(kr.export_account(&upper).is_ok());
        assert!(kr.sign_hash(&upper, &[1u8; 32]).is_ok());
    }

    #[test]
    fn test_sign_hash() {
        let mut kr = test_keyring();
        let accounts = kr.derive_accounts(1).unwrap();

        let hash = [0xab_u8; 32];
        let sig = kr.sign_hash(&accounts[0].address, &hash).unwrap();
        assert!(sig[64] == 0 || sig[64] == 1, "v must be a 0/1 recovery id");

        // Recover the public key from the signature and check it maps back to the signer.
        let signature = Signature::from_slice(&sig[..64]).unwrap();
        let recovery_id = RecoveryId::from_byte(sig[64]).unwrap();
        let recovered = VerifyingKey::recover_from_prehash(&hash, &signature, recovery_id).unwrap();
        assert_eq!(eth_address(&recovered), accounts[0].address);
    }

    #[test]
    fn test_sign_hash_account_not_found() {
        let kr = test_keyring();
        let result = kr.sign_hash("0xnonexistent", &[0u8; 32]);
        assert!(matches!(result, Err(KeyringError::AccountNotFound(_))));
    }

    #[test]
    fn test_keyring_type() {
        assert_eq!(test_keyring().keyring_type(), "hd");
    }

    #[test]
    fn test_hd_path_default() {
        assert_eq!(test_keyring().hd_path(), "m/44'/60'/0'/0");
    }

    #[test]
    fn test_custom_hd_path() {
        let kr = HdKeyring::from_mnemonic(TEST_MNEMONIC, Some("m/44'/60'/1'/0")).unwrap();
        assert_eq!(kr.hd_path(), "m/44'/60'/1'/0");
    }

    #[test]
    fn test_invalid_mnemonic() {
        assert!(HdKeyring::from_mnemonic("not a valid mnemonic phrase", None).is_err());
    }
}