1use serde::{Deserialize, Serialize};
2
3use crate::{hd::HdKeyring, simple::SimpleKeyring, Keyring, KeyringAccount, KeyringError};
4
5#[derive(Serialize, Deserialize)]
7struct WalletControllerState {
8 hd_keyrings: Vec<serde_json::Value>,
9 simple_keyrings: Vec<serde_json::Value>,
10}
11
12pub struct WalletController {
14 hd_keyrings: Vec<HdKeyring>,
15 simple_keyrings: Vec<SimpleKeyring>,
16}
17
18impl WalletController {
19 pub fn new() -> Self {
20 Self {
21 hd_keyrings: Vec::new(),
22 simple_keyrings: Vec::new(),
23 }
24 }
25
26 pub fn create_hd_wallet(
29 &mut self,
30 mnemonic: Option<&str>,
31 ) -> Result<Vec<KeyringAccount>, KeyringError> {
32 let mut hd = match mnemonic {
33 Some(m) => HdKeyring::from_mnemonic(m, None)?,
34 None => HdKeyring::new(12)?,
35 };
36 let accounts = hd.derive_accounts(1)?;
37 self.hd_keyrings.push(hd);
38 Ok(accounts)
39 }
40
41 pub fn import_key(&mut self, private_key: &str) -> Result<KeyringAccount, KeyringError> {
43 let mut temp = SimpleKeyring::new();
47 let accounts = temp.add_accounts(&[private_key.to_string()])?;
48 let address = &accounts[0].address;
49
50 for hd in &self.hd_keyrings {
52 for acc in hd.get_accounts() {
53 if acc.address == *address {
54 return Err(KeyringError::DuplicateAccount(address.clone()));
55 }
56 }
57 }
58 for sk in &self.simple_keyrings {
59 for acc in sk.get_accounts() {
60 if acc.address == *address {
61 return Err(KeyringError::DuplicateAccount(address.clone()));
62 }
63 }
64 }
65
66 if self.simple_keyrings.is_empty() {
68 self.simple_keyrings.push(temp);
69 } else {
70 self.simple_keyrings
71 .last_mut()
72 .unwrap()
73 .add_accounts(&[private_key.to_string()])?;
74 }
75 Ok(accounts.into_iter().next().unwrap())
76 }
77
78 pub fn derive_next_agent(&mut self) -> Result<KeyringAccount, KeyringError> {
80 let hd = self
81 .hd_keyrings
82 .first_mut()
83 .ok_or_else(|| KeyringError::AccountNotFound("no HD keyring exists".to_string()))?;
84 let accounts = hd.derive_accounts(1)?;
85 Ok(accounts.into_iter().next().unwrap())
86 }
87
88 pub fn get_accounts(&self) -> Vec<KeyringAccount> {
90 let mut accounts = Vec::new();
91 for hd in &self.hd_keyrings {
92 accounts.extend(hd.get_accounts());
93 }
94 for sk in &self.simple_keyrings {
95 accounts.extend(sk.get_accounts());
96 }
97 accounts
98 }
99
100 pub fn export_account(&self, address: &str) -> Result<String, KeyringError> {
102 let addr = address.to_lowercase();
103 for hd in &self.hd_keyrings {
104 match hd.export_account(&addr) {
105 Ok(key) => return Ok(key),
106 Err(KeyringError::AccountNotFound(_)) => continue,
107 Err(e) => return Err(e),
108 }
109 }
110 for sk in &self.simple_keyrings {
111 match sk.export_account(&addr) {
112 Ok(key) => return Ok(key),
113 Err(KeyringError::AccountNotFound(_)) => continue,
114 Err(e) => return Err(e),
115 }
116 }
117 Err(KeyringError::AccountNotFound(addr))
118 }
119
120 pub fn sign_for_account(
122 &self,
123 address: &str,
124 hash: &[u8; 32],
125 ) -> Result<[u8; 65], KeyringError> {
126 let addr = address.to_lowercase();
127 for hd in &self.hd_keyrings {
128 match hd.sign_hash(&addr, hash) {
129 Ok(sig) => return Ok(sig),
130 Err(KeyringError::AccountNotFound(_)) => continue,
131 Err(e) => return Err(e),
132 }
133 }
134 for sk in &self.simple_keyrings {
135 match sk.sign_hash(&addr, hash) {
136 Ok(sig) => return Ok(sig),
137 Err(KeyringError::AccountNotFound(_)) => continue,
138 Err(e) => return Err(e),
139 }
140 }
141 Err(KeyringError::AccountNotFound(addr))
142 }
143
144 pub fn remove_account(&mut self, address: &str) -> Result<(), KeyringError> {
146 let addr = address.to_lowercase();
147 for hd in &mut self.hd_keyrings {
148 match hd.remove_account(&addr) {
149 Ok(()) => return Ok(()),
150 Err(KeyringError::AccountNotFound(_)) => continue,
151 Err(e) => return Err(e),
152 }
153 }
154 for sk in &mut self.simple_keyrings {
155 match sk.remove_account(&addr) {
156 Ok(()) => return Ok(()),
157 Err(KeyringError::AccountNotFound(_)) => continue,
158 Err(e) => return Err(e),
159 }
160 }
161 Err(KeyringError::AccountNotFound(addr))
162 }
163
164 pub fn serialize(&self) -> Result<Vec<u8>, KeyringError> {
166 let hd_keyrings: Vec<serde_json::Value> = self
167 .hd_keyrings
168 .iter()
169 .map(|hd| {
170 let bytes = hd.serialize()?;
171 serde_json::from_slice(&bytes)
172 .map_err(|e| KeyringError::SerializationError(e.to_string()))
173 })
174 .collect::<Result<_, _>>()?;
175
176 let simple_keyrings: Vec<serde_json::Value> = self
177 .simple_keyrings
178 .iter()
179 .map(|sk| {
180 let bytes = sk.serialize()?;
181 serde_json::from_slice(&bytes)
182 .map_err(|e| KeyringError::SerializationError(e.to_string()))
183 })
184 .collect::<Result<_, _>>()?;
185
186 let state = WalletControllerState {
187 hd_keyrings,
188 simple_keyrings,
189 };
190 serde_json::to_vec(&state).map_err(|e| KeyringError::SerializationError(e.to_string()))
191 }
192
193 pub fn deserialize(data: &[u8]) -> Result<Self, KeyringError> {
195 let state: WalletControllerState = serde_json::from_slice(data)
196 .map_err(|e| KeyringError::SerializationError(e.to_string()))?;
197
198 let hd_keyrings: Vec<HdKeyring> = state
199 .hd_keyrings
200 .into_iter()
201 .map(|v| {
202 let bytes = serde_json::to_vec(&v)
203 .map_err(|e| KeyringError::SerializationError(e.to_string()))?;
204 HdKeyring::deserialize(&bytes)
205 })
206 .collect::<Result<_, _>>()?;
207
208 let simple_keyrings: Vec<SimpleKeyring> = state
209 .simple_keyrings
210 .into_iter()
211 .map(|v| {
212 let bytes = serde_json::to_vec(&v)
213 .map_err(|e| KeyringError::SerializationError(e.to_string()))?;
214 SimpleKeyring::deserialize(&bytes)
215 })
216 .collect::<Result<_, _>>()?;
217
218 Ok(Self {
219 hd_keyrings,
220 simple_keyrings,
221 })
222 }
223}
224
225impl Default for WalletController {
226 fn default() -> Self {
227 Self::new()
228 }
229}
230
231impl Keyring for WalletController {
232 fn keyring_type(&self) -> &str {
233 "controller"
234 }
235
236 fn serialize(&self) -> Result<Vec<u8>, KeyringError> {
237 WalletController::serialize(self)
238 }
239
240 fn deserialize(data: &[u8]) -> Result<Self, KeyringError>
241 where
242 Self: Sized,
243 {
244 WalletController::deserialize(data)
245 }
246
247 fn add_accounts(
248 &mut self,
249 private_keys: &[String],
250 ) -> Result<Vec<KeyringAccount>, KeyringError> {
251 let mut results = Vec::new();
252 for key in private_keys {
253 results.push(self.import_key(key)?);
254 }
255 Ok(results)
256 }
257
258 fn get_accounts(&self) -> Vec<KeyringAccount> {
259 WalletController::get_accounts(self)
260 }
261
262 fn export_account(&self, address: &str) -> Result<String, KeyringError> {
263 WalletController::export_account(self, address)
264 }
265
266 fn remove_account(&mut self, address: &str) -> Result<(), KeyringError> {
267 WalletController::remove_account(self, address)
268 }
269
270 fn sign_hash(&self, address: &str, hash: &[u8; 32]) -> Result<[u8; 65], KeyringError> {
271 self.sign_for_account(address, hash)
272 }
273}
274
#[cfg(test)]
mod tests {
    use super::*;

    const TEST_MNEMONIC: &str =
        "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about";
    const TEST_PRIVATE_KEY: &str =
        "0x4c0883a69102937d6231471b5dbb6204fe512961708279f22a82e1e0e3e1d0a2";
    const TEST_PRIVATE_KEY_2: &str =
        "0x0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef";
    /// Well-formed address that no test keyring is expected to hold.
    const UNKNOWN_ADDRESS: &str = "0xdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef";

    /// Recovers the lowercase `0x`-prefixed address that produced `sig` over `hash`.
    ///
    /// `sig` is laid out as 64 bytes of `r || s` followed by a one-byte recovery id.
    fn recover_address(hash: &[u8; 32], sig: &[u8; 65]) -> String {
        use k256::ecdsa::{RecoveryId, Signature, VerifyingKey};
        use sha3::Digest;

        let signature = Signature::from_slice(&sig[..64]).unwrap();
        let recovery_id = RecoveryId::from_byte(sig[64]).unwrap();
        let recovered = VerifyingKey::recover_from_prehash(hash, &signature, recovery_id).unwrap();
        let point = recovered.to_encoded_point(false);
        // Skip the uncompressed-point tag byte; the address is the last 20 bytes of
        // keccak256 over the remaining 64-byte public key.
        let digest = sha3::Keccak256::digest(&point.as_bytes()[1..]);
        format!("0x{}", hex::encode(&digest[12..]))
    }

    /// Asserts that `result` failed with `AccountNotFound`.
    fn assert_account_not_found<T: std::fmt::Debug>(result: Result<T, KeyringError>) {
        match result.expect_err("operation should have failed") {
            KeyringError::AccountNotFound(_) => {}
            other => panic!("Expected AccountNotFound, got: {:?}", other),
        }
    }

    /// Asserts that `result` failed with `DuplicateAccount`.
    fn assert_duplicate<T: std::fmt::Debug>(result: Result<T, KeyringError>) {
        match result.expect_err("operation should have failed") {
            KeyringError::DuplicateAccount(_) => {}
            other => panic!("Expected DuplicateAccount, got: {:?}", other),
        }
    }

    #[test]
    fn test_create_hd_wallet_and_derive_accounts() {
        let mut ctrl = WalletController::new();
        let accounts = ctrl.create_hd_wallet(Some(TEST_MNEMONIC)).unwrap();
        assert_eq!(accounts.len(), 1);
        assert!(accounts[0].address.starts_with("0x"));
        assert_eq!(accounts[0].address.len(), 42);

        // The next derivation must yield a different address from the same keyring.
        let next = ctrl.derive_next_agent().unwrap();
        assert_ne!(next.address, accounts[0].address);

        assert_eq!(ctrl.get_accounts().len(), 2);
    }

    #[test]
    fn test_create_hd_wallet_random() {
        let mut ctrl = WalletController::new();
        let accounts = ctrl.create_hd_wallet(None).unwrap();
        assert_eq!(accounts.len(), 1);
        assert!(accounts[0].address.starts_with("0x"));
    }

    #[test]
    fn test_import_standalone_key() {
        let mut ctrl = WalletController::new();
        let account = ctrl.import_key(TEST_PRIVATE_KEY).unwrap();
        assert!(account.address.starts_with("0x"));
        assert_eq!(account.address.len(), 42);

        let all = ctrl.get_accounts();
        assert_eq!(all.len(), 1);
        assert_eq!(all[0].address, account.address);
    }

    #[test]
    fn test_get_all_accounts_mixed_keyrings() {
        let mut ctrl = WalletController::new();
        let hd_accounts = ctrl.create_hd_wallet(Some(TEST_MNEMONIC)).unwrap();
        let simple_account = ctrl.import_key(TEST_PRIVATE_KEY).unwrap();

        let all = ctrl.get_accounts();
        assert_eq!(all.len(), 2);

        let addresses: Vec<&str> = all.iter().map(|a| a.address.as_str()).collect();
        assert!(addresses.contains(&hd_accounts[0].address.as_str()));
        assert!(addresses.contains(&simple_account.address.as_str()));
    }

    #[test]
    fn test_sign_with_hd_account() {
        let mut ctrl = WalletController::new();
        let accounts = ctrl.create_hd_wallet(Some(TEST_MNEMONIC)).unwrap();
        let addr = &accounts[0].address;

        let hash = [0xab_u8; 32];
        let sig = ctrl.sign_for_account(addr, &hash).unwrap();
        assert_eq!(sig.len(), 65);
        assert!(sig[64] == 0 || sig[64] == 1);
        assert_eq!(recover_address(&hash, &sig), *addr);
    }

    #[test]
    fn test_sign_with_simple_account() {
        let mut ctrl = WalletController::new();
        let account = ctrl.import_key(TEST_PRIVATE_KEY).unwrap();

        let hash = [0xcd_u8; 32];
        let sig = ctrl.sign_for_account(&account.address, &hash).unwrap();
        assert_eq!(sig.len(), 65);
        assert!(sig[64] == 0 || sig[64] == 1);
        // The signature must actually come from the imported key.
        assert_eq!(recover_address(&hash, &sig), account.address.to_lowercase());
    }

    #[test]
    fn test_export_account_from_hd() {
        let mut ctrl = WalletController::new();
        let accounts = ctrl.create_hd_wallet(Some(TEST_MNEMONIC)).unwrap();
        let exported = ctrl.export_account(&accounts[0].address).unwrap();
        assert!(exported.starts_with("0x"));
        // "0x" followed by 64 hex characters.
        assert_eq!(exported.len(), 66);
    }

    #[test]
    fn test_export_account_from_simple() {
        let mut ctrl = WalletController::new();
        let account = ctrl.import_key(TEST_PRIVATE_KEY).unwrap();
        let exported = ctrl.export_account(&account.address).unwrap();
        assert_eq!(exported, TEST_PRIVATE_KEY.to_lowercase());
    }

    #[test]
    fn test_remove_account() {
        let mut ctrl = WalletController::new();
        let accounts = ctrl.create_hd_wallet(Some(TEST_MNEMONIC)).unwrap();
        let simple = ctrl.import_key(TEST_PRIVATE_KEY).unwrap();
        assert_eq!(ctrl.get_accounts().len(), 2);

        // Removing the HD account leaves only the imported one.
        ctrl.remove_account(&accounts[0].address).unwrap();
        assert_eq!(ctrl.get_accounts().len(), 1);
        assert_eq!(ctrl.get_accounts()[0].address, simple.address);

        ctrl.remove_account(&simple.address).unwrap();
        assert!(ctrl.get_accounts().is_empty());
    }

    #[test]
    fn test_serialize_deserialize_roundtrip() {
        let mut ctrl = WalletController::new();
        ctrl.create_hd_wallet(Some(TEST_MNEMONIC)).unwrap();
        ctrl.derive_next_agent().unwrap();
        ctrl.import_key(TEST_PRIVATE_KEY).unwrap();

        let original_accounts = ctrl.get_accounts();
        assert_eq!(original_accounts.len(), 3);

        let data = ctrl.serialize().unwrap();
        let ctrl2 = WalletController::deserialize(&data).unwrap();
        assert_eq!(ctrl2.get_accounts().len(), original_accounts.len());

        // Every account must export the same key from the restored controller.
        for acc in &original_accounts {
            assert_eq!(
                ctrl.export_account(&acc.address).unwrap(),
                ctrl2.export_account(&acc.address).unwrap()
            );
        }
    }

    #[test]
    fn test_sign_unknown_address_error() {
        let ctrl = WalletController::new();
        assert_account_not_found(ctrl.sign_for_account(UNKNOWN_ADDRESS, &[0u8; 32]));
    }

    #[test]
    fn test_derive_next_agent_no_hd_keyring() {
        let mut ctrl = WalletController::new();
        assert_account_not_found(ctrl.derive_next_agent());
    }

    #[test]
    fn test_import_duplicate_key_across_keyrings() {
        // The same key imported twice into the simple keyrings.
        let mut ctrl = WalletController::new();
        ctrl.import_key(TEST_PRIVATE_KEY).unwrap();
        assert_duplicate(ctrl.import_key(TEST_PRIVATE_KEY));

        // A key that is already held by an HD keyring.
        let mut ctrl = WalletController::new();
        let hd_accounts = ctrl.create_hd_wallet(Some(TEST_MNEMONIC)).unwrap();
        let hd_key = ctrl.export_account(&hd_accounts[0].address).unwrap();
        assert_duplicate(ctrl.import_key(&hd_key));
        assert_eq!(ctrl.get_accounts().len(), 1);
    }

    #[test]
    fn test_import_multiple_keys() {
        let mut ctrl = WalletController::new();
        let acc1 = ctrl.import_key(TEST_PRIVATE_KEY).unwrap();
        let acc2 = ctrl.import_key(TEST_PRIVATE_KEY_2).unwrap();
        assert_ne!(acc1.address, acc2.address);
        assert_eq!(ctrl.get_accounts().len(), 2);
    }

    #[test]
    fn test_remove_unknown_account_error() {
        let mut ctrl = WalletController::new();
        assert_account_not_found(ctrl.remove_account(UNKNOWN_ADDRESS));
    }

    #[test]
    fn test_export_unknown_account_error() {
        let ctrl = WalletController::new();
        assert_account_not_found(ctrl.export_account("0xdeadbeef"));
    }
}