// hyper_keyring/controller.rs

use std::collections::HashSet;

use serde::{Deserialize, Serialize};

use crate::{hd::HdKeyring, simple::SimpleKeyring, Keyring, KeyringAccount, KeyringError};

5/// Serializable form for the entire wallet controller state.
6#[derive(Serialize, Deserialize)]
7struct WalletControllerState {
8    hd_keyrings: Vec<serde_json::Value>,
9    simple_keyrings: Vec<serde_json::Value>,
10}
11
12/// WalletController manages multiple keyrings (HD + Simple) with a unified interface.
13pub struct WalletController {
14    hd_keyrings: Vec<HdKeyring>,
15    simple_keyrings: Vec<SimpleKeyring>,
16}
17
18impl WalletController {
19    pub fn new() -> Self {
20        Self {
21            hd_keyrings: Vec::new(),
22            simple_keyrings: Vec::new(),
23        }
24    }
25
26    /// Create a new HD wallet (optionally from existing mnemonic).
27    /// Derives one initial account and returns it.
28    pub fn create_hd_wallet(
29        &mut self,
30        mnemonic: Option<&str>,
31    ) -> Result<Vec<KeyringAccount>, KeyringError> {
32        let mut hd = match mnemonic {
33            Some(m) => HdKeyring::from_mnemonic(m, None)?,
34            None => HdKeyring::new(12)?,
35        };
36        let accounts = hd.derive_accounts(1)?;
37        self.hd_keyrings.push(hd);
38        Ok(accounts)
39    }
40
41    /// Import a standalone private key into a simple keyring.
42    pub fn import_key(&mut self, private_key: &str) -> Result<KeyringAccount, KeyringError> {
43        // Check if the key already exists in any keyring
44        // We need to try adding it to detect duplicates across keyrings.
45        // First, create a temporary simple keyring to derive the address.
46        let mut temp = SimpleKeyring::new();
47        let accounts = temp.add_accounts(&[private_key.to_string()])?;
48        let address = &accounts[0].address;
49
50        // Check if this address already exists in any keyring
51        for hd in &self.hd_keyrings {
52            for acc in hd.get_accounts() {
53                if acc.address == *address {
54                    return Err(KeyringError::DuplicateAccount(address.clone()));
55                }
56            }
57        }
58        for sk in &self.simple_keyrings {
59            for acc in sk.get_accounts() {
60                if acc.address == *address {
61                    return Err(KeyringError::DuplicateAccount(address.clone()));
62                }
63            }
64        }
65
66        // Add to an existing simple keyring or create a new one
67        if self.simple_keyrings.is_empty() {
68            self.simple_keyrings.push(temp);
69        } else {
70            self.simple_keyrings
71                .last_mut()
72                .unwrap()
73                .add_accounts(&[private_key.to_string()])?;
74        }
75        Ok(accounts.into_iter().next().unwrap())
76    }
77
78    /// Derive next agent wallet from the first HD keyring.
79    pub fn derive_next_agent(&mut self) -> Result<KeyringAccount, KeyringError> {
80        let hd = self
81            .hd_keyrings
82            .first_mut()
83            .ok_or_else(|| KeyringError::AccountNotFound("no HD keyring exists".to_string()))?;
84        let accounts = hd.derive_accounts(1)?;
85        Ok(accounts.into_iter().next().unwrap())
86    }
87
88    /// Get all accounts across all keyrings.
89    pub fn get_accounts(&self) -> Vec<KeyringAccount> {
90        let mut accounts = Vec::new();
91        for hd in &self.hd_keyrings {
92            accounts.extend(hd.get_accounts());
93        }
94        for sk in &self.simple_keyrings {
95            accounts.extend(sk.get_accounts());
96        }
97        accounts
98    }
99
100    /// Export private key for an address (searches HD keyrings first, then simple).
101    pub fn export_account(&self, address: &str) -> Result<String, KeyringError> {
102        let addr = address.to_lowercase();
103        for hd in &self.hd_keyrings {
104            match hd.export_account(&addr) {
105                Ok(key) => return Ok(key),
106                Err(KeyringError::AccountNotFound(_)) => continue,
107                Err(e) => return Err(e),
108            }
109        }
110        for sk in &self.simple_keyrings {
111            match sk.export_account(&addr) {
112                Ok(key) => return Ok(key),
113                Err(KeyringError::AccountNotFound(_)) => continue,
114                Err(e) => return Err(e),
115            }
116        }
117        Err(KeyringError::AccountNotFound(addr))
118    }
119
120    /// Sign a hash with the key for a given address (routes to correct keyring).
121    pub fn sign_for_account(
122        &self,
123        address: &str,
124        hash: &[u8; 32],
125    ) -> Result<[u8; 65], KeyringError> {
126        let addr = address.to_lowercase();
127        for hd in &self.hd_keyrings {
128            match hd.sign_hash(&addr, hash) {
129                Ok(sig) => return Ok(sig),
130                Err(KeyringError::AccountNotFound(_)) => continue,
131                Err(e) => return Err(e),
132            }
133        }
134        for sk in &self.simple_keyrings {
135            match sk.sign_hash(&addr, hash) {
136                Ok(sig) => return Ok(sig),
137                Err(KeyringError::AccountNotFound(_)) => continue,
138                Err(e) => return Err(e),
139            }
140        }
141        Err(KeyringError::AccountNotFound(addr))
142    }
143
144    /// Remove an account from whichever keyring contains it.
145    pub fn remove_account(&mut self, address: &str) -> Result<(), KeyringError> {
146        let addr = address.to_lowercase();
147        for hd in &mut self.hd_keyrings {
148            match hd.remove_account(&addr) {
149                Ok(()) => return Ok(()),
150                Err(KeyringError::AccountNotFound(_)) => continue,
151                Err(e) => return Err(e),
152            }
153        }
154        for sk in &mut self.simple_keyrings {
155            match sk.remove_account(&addr) {
156                Ok(()) => return Ok(()),
157                Err(KeyringError::AccountNotFound(_)) => continue,
158                Err(e) => return Err(e),
159            }
160        }
161        Err(KeyringError::AccountNotFound(addr))
162    }
163
164    /// Serialize all keyrings for storage.
165    pub fn serialize(&self) -> Result<Vec<u8>, KeyringError> {
166        let hd_keyrings: Vec<serde_json::Value> = self
167            .hd_keyrings
168            .iter()
169            .map(|hd| {
170                let bytes = hd.serialize()?;
171                serde_json::from_slice(&bytes)
172                    .map_err(|e| KeyringError::SerializationError(e.to_string()))
173            })
174            .collect::<Result<_, _>>()?;
175
176        let simple_keyrings: Vec<serde_json::Value> = self
177            .simple_keyrings
178            .iter()
179            .map(|sk| {
180                let bytes = sk.serialize()?;
181                serde_json::from_slice(&bytes)
182                    .map_err(|e| KeyringError::SerializationError(e.to_string()))
183            })
184            .collect::<Result<_, _>>()?;
185
186        let state = WalletControllerState {
187            hd_keyrings,
188            simple_keyrings,
189        };
190        serde_json::to_vec(&state).map_err(|e| KeyringError::SerializationError(e.to_string()))
191    }
192
193    /// Deserialize and restore all keyrings.
194    pub fn deserialize(data: &[u8]) -> Result<Self, KeyringError> {
195        let state: WalletControllerState = serde_json::from_slice(data)
196            .map_err(|e| KeyringError::SerializationError(e.to_string()))?;
197
198        let hd_keyrings: Vec<HdKeyring> = state
199            .hd_keyrings
200            .into_iter()
201            .map(|v| {
202                let bytes = serde_json::to_vec(&v)
203                    .map_err(|e| KeyringError::SerializationError(e.to_string()))?;
204                HdKeyring::deserialize(&bytes)
205            })
206            .collect::<Result<_, _>>()?;
207
208        let simple_keyrings: Vec<SimpleKeyring> = state
209            .simple_keyrings
210            .into_iter()
211            .map(|v| {
212                let bytes = serde_json::to_vec(&v)
213                    .map_err(|e| KeyringError::SerializationError(e.to_string()))?;
214                SimpleKeyring::deserialize(&bytes)
215            })
216            .collect::<Result<_, _>>()?;
217
218        Ok(Self {
219            hd_keyrings,
220            simple_keyrings,
221        })
222    }
223}
224
225impl Default for WalletController {
226    fn default() -> Self {
227        Self::new()
228    }
229}
230
231impl Keyring for WalletController {
232    fn keyring_type(&self) -> &str {
233        "controller"
234    }
235
236    fn serialize(&self) -> Result<Vec<u8>, KeyringError> {
237        WalletController::serialize(self)
238    }
239
240    fn deserialize(data: &[u8]) -> Result<Self, KeyringError>
241    where
242        Self: Sized,
243    {
244        WalletController::deserialize(data)
245    }
246
247    fn add_accounts(
248        &mut self,
249        private_keys: &[String],
250    ) -> Result<Vec<KeyringAccount>, KeyringError> {
251        let mut results = Vec::new();
252        for key in private_keys {
253            results.push(self.import_key(key)?);
254        }
255        Ok(results)
256    }
257
258    fn get_accounts(&self) -> Vec<KeyringAccount> {
259        WalletController::get_accounts(self)
260    }
261
262    fn export_account(&self, address: &str) -> Result<String, KeyringError> {
263        WalletController::export_account(self, address)
264    }
265
266    fn remove_account(&mut self, address: &str) -> Result<(), KeyringError> {
267        WalletController::remove_account(self, address)
268    }
269
270    fn sign_hash(&self, address: &str, hash: &[u8; 32]) -> Result<[u8; 65], KeyringError> {
271        self.sign_for_account(address, hash)
272    }
273}
274
#[cfg(test)]
mod tests {
    use super::*;

    const TEST_MNEMONIC: &str =
        "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about";
    const TEST_PRIVATE_KEY: &str =
        "0x4c0883a69102937d6231471b5dbb6204fe512961708279f22a82e1e0e3e1d0a2";
    const TEST_PRIVATE_KEY_2: &str =
        "0x0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef";

    #[test]
    fn test_create_hd_wallet_and_derive_accounts() {
        let mut ctrl = WalletController::new();
        let accounts = ctrl.create_hd_wallet(Some(TEST_MNEMONIC)).unwrap();
        assert_eq!(accounts.len(), 1);
        assert!(accounts[0].address.starts_with("0x"));
        assert_eq!(accounts[0].address.len(), 42);

        // Derive more accounts
        let next = ctrl.derive_next_agent().unwrap();
        assert_ne!(next.address, accounts[0].address);

        let all = ctrl.get_accounts();
        assert_eq!(all.len(), 2);
    }

    #[test]
    fn test_create_hd_wallet_random() {
        let mut ctrl = WalletController::new();
        let accounts = ctrl.create_hd_wallet(None).unwrap();
        assert_eq!(accounts.len(), 1);
        assert!(accounts[0].address.starts_with("0x"));
    }

    #[test]
    fn test_import_standalone_key() {
        let mut ctrl = WalletController::new();
        let account = ctrl.import_key(TEST_PRIVATE_KEY).unwrap();
        assert!(account.address.starts_with("0x"));
        assert_eq!(account.address.len(), 42);

        let all = ctrl.get_accounts();
        assert_eq!(all.len(), 1);
        assert_eq!(all[0].address, account.address);
    }

    #[test]
    fn test_get_all_accounts_mixed_keyrings() {
        let mut ctrl = WalletController::new();
        let hd_accounts = ctrl.create_hd_wallet(Some(TEST_MNEMONIC)).unwrap();
        let simple_account = ctrl.import_key(TEST_PRIVATE_KEY).unwrap();

        let all = ctrl.get_accounts();
        assert_eq!(all.len(), 2);

        let addresses: Vec<&str> = all.iter().map(|a| a.address.as_str()).collect();
        assert!(addresses.contains(&hd_accounts[0].address.as_str()));
        assert!(addresses.contains(&simple_account.address.as_str()));
    }

    #[test]
    fn test_sign_with_hd_account() {
        let mut ctrl = WalletController::new();
        let accounts = ctrl.create_hd_wallet(Some(TEST_MNEMONIC)).unwrap();
        let addr = &accounts[0].address;

        let hash = [0xab_u8; 32];
        let sig = ctrl.sign_for_account(addr, &hash).unwrap();
        assert_eq!(sig.len(), 65);
        assert!(sig[64] == 0 || sig[64] == 1);

        // Verify recovery
        use k256::ecdsa::{RecoveryId, Signature, VerifyingKey};
        use sha3::Digest;
        let signature = Signature::from_slice(&sig[..64]).unwrap();
        let recovery_id = RecoveryId::from_byte(sig[64]).unwrap();
        let recovered = VerifyingKey::recover_from_prehash(&hash, &signature, recovery_id).unwrap();
        let point = recovered.to_encoded_point(false);
        let pubkey_bytes = &point.as_bytes()[1..];
        let h = sha3::Keccak256::digest(pubkey_bytes);
        let recovered_addr = format!("0x{}", hex::encode(&h[12..]));
        assert_eq!(recovered_addr, *addr);
    }

    #[test]
    fn test_sign_with_simple_account() {
        let mut ctrl = WalletController::new();
        let account = ctrl.import_key(TEST_PRIVATE_KEY).unwrap();

        let hash = [0xcd_u8; 32];
        let sig = ctrl.sign_for_account(&account.address, &hash).unwrap();
        assert_eq!(sig.len(), 65);
        assert!(sig[64] == 0 || sig[64] == 1);
    }

    #[test]
    fn test_export_account_from_hd() {
        let mut ctrl = WalletController::new();
        let accounts = ctrl.create_hd_wallet(Some(TEST_MNEMONIC)).unwrap();
        let exported = ctrl.export_account(&accounts[0].address).unwrap();
        assert!(exported.starts_with("0x"));
        assert_eq!(exported.len(), 66); // 0x + 64 hex
    }

    #[test]
    fn test_export_account_from_simple() {
        let mut ctrl = WalletController::new();
        let account = ctrl.import_key(TEST_PRIVATE_KEY).unwrap();
        let exported = ctrl.export_account(&account.address).unwrap();
        // Should match the original key (normalised)
        let expected = TEST_PRIVATE_KEY.to_lowercase();
        assert_eq!(exported, expected);
    }

    #[test]
    fn test_remove_account() {
        let mut ctrl = WalletController::new();
        let accounts = ctrl.create_hd_wallet(Some(TEST_MNEMONIC)).unwrap();
        let simple = ctrl.import_key(TEST_PRIVATE_KEY).unwrap();
        assert_eq!(ctrl.get_accounts().len(), 2);

        // Remove the HD account
        ctrl.remove_account(&accounts[0].address).unwrap();
        assert_eq!(ctrl.get_accounts().len(), 1);
        assert_eq!(ctrl.get_accounts()[0].address, simple.address);

        // Remove the simple account
        ctrl.remove_account(&simple.address).unwrap();
        assert_eq!(ctrl.get_accounts().len(), 0);
    }

    #[test]
    fn test_serialize_deserialize_roundtrip() {
        let mut ctrl = WalletController::new();
        ctrl.create_hd_wallet(Some(TEST_MNEMONIC)).unwrap();
        ctrl.derive_next_agent().unwrap();
        ctrl.import_key(TEST_PRIVATE_KEY).unwrap();

        let original_accounts = ctrl.get_accounts();
        assert_eq!(original_accounts.len(), 3);

        let data = ctrl.serialize().unwrap();
        let ctrl2 = WalletController::deserialize(&data).unwrap();
        let restored_accounts = ctrl2.get_accounts();

        assert_eq!(original_accounts.len(), restored_accounts.len());

        // All original accounts should be present in restored
        for acc in &original_accounts {
            let exported_orig = ctrl.export_account(&acc.address).unwrap();
            let exported_restored = ctrl2.export_account(&acc.address).unwrap();
            assert_eq!(exported_orig, exported_restored);
        }
    }

    #[test]
    fn test_sign_unknown_address_error() {
        let ctrl = WalletController::new();
        let hash = [0u8; 32];
        let result = ctrl.sign_for_account("0xdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef", &hash);
        assert!(result.is_err());
        match result.unwrap_err() {
            KeyringError::AccountNotFound(_) => {}
            other => panic!("Expected AccountNotFound, got: {:?}", other),
        }
    }

    #[test]
    fn test_derive_next_agent_no_hd_keyring() {
        let mut ctrl = WalletController::new();
        let result = ctrl.derive_next_agent();
        assert!(result.is_err());
    }

    #[test]
    fn test_import_duplicate_key_across_keyrings() {
        let mut ctrl = WalletController::new();
        ctrl.import_key(TEST_PRIVATE_KEY).unwrap();
        let result = ctrl.import_key(TEST_PRIVATE_KEY);
        assert!(result.is_err());
        match result.unwrap_err() {
            KeyringError::DuplicateAccount(_) => {}
            other => panic!("Expected DuplicateAccount, got: {:?}", other),
        }
    }

    #[test]
    fn test_import_multiple_keys() {
        let mut ctrl = WalletController::new();
        let acc1 = ctrl.import_key(TEST_PRIVATE_KEY).unwrap();
        let acc2 = ctrl.import_key(TEST_PRIVATE_KEY_2).unwrap();
        assert_ne!(acc1.address, acc2.address);
        assert_eq!(ctrl.get_accounts().len(), 2);
    }

    #[test]
    fn test_remove_unknown_account_error() {
        let mut ctrl = WalletController::new();
        let result = ctrl.remove_account("0xdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef");
        assert!(result.is_err());
        match result.unwrap_err() {
            KeyringError::AccountNotFound(_) => {}
            other => panic!("Expected AccountNotFound, got: {:?}", other),
        }
    }

    #[test]
    fn test_export_unknown_account_error() {
        let ctrl = WalletController::new();
        let result = ctrl.export_account("0xdeadbeef");
        assert!(result.is_err());
        match result.unwrap_err() {
            KeyringError::AccountNotFound(_) => {}
            other => panic!("Expected AccountNotFound, got: {:?}", other),
        }
    }

    #[test]
    fn test_default_is_empty() {
        let ctrl = WalletController::default();
        assert!(ctrl.get_accounts().is_empty());
    }

    #[test]
    fn test_keyring_trait_type_and_add_accounts() {
        let mut ctrl = WalletController::new();
        assert_eq!(ctrl.keyring_type(), "controller");

        let added = ctrl
            .add_accounts(&[TEST_PRIVATE_KEY.to_string(), TEST_PRIVATE_KEY_2.to_string()])
            .unwrap();
        assert_eq!(added.len(), 2);
        assert_ne!(added[0].address, added[1].address);
        assert_eq!(ctrl.get_accounts().len(), 2);
    }

    #[test]
    fn test_lookup_is_case_insensitive() {
        let mut ctrl = WalletController::new();
        let account = ctrl.import_key(TEST_PRIVATE_KEY).unwrap();
        let upper = account.address.to_uppercase();

        let hash = [0x11_u8; 32];
        assert!(ctrl.sign_for_account(&upper, &hash).is_ok());
        assert_eq!(
            ctrl.export_account(&upper).unwrap(),
            TEST_PRIVATE_KEY.to_lowercase()
        );
    }
}