Skip to main content

hyper_keyring/
vault.rs

1use serde::{Deserialize, Serialize};
2
3use crate::controller::WalletController;
4use crate::KeyringError;
5
6use aes_gcm::{
7    aead::{Aead, KeyInit},
8    Aes256Gcm, Nonce,
9};
10use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
11use hmac::Hmac;
12use pbkdf2::pbkdf2;
13use rand::RngCore;
14use sha2::Sha256;
15
/// PBKDF2-HMAC-SHA256 iteration count used when encrypting a vault.
/// OWASP 2023 recommended minimum for PBKDF2-SHA256.
const PBKDF2_ITERATIONS: u32 = 600_000;
/// Length in bytes of the random PBKDF2 salt generated for each encryption.
const SALT_LEN: usize = 32;
/// Length in bytes of the AES-GCM nonce.
const NONCE_LEN: usize = 12;
/// Length in bytes of the derived encryption key (AES-256).
const KEY_LEN: usize = 32;

/// Current vault format version.
/// Written into `VaultMetadata::version` by `encrypt`; `decrypt` does not inspect it.
const VAULT_VERSION: u32 = 1;
24
25// ---------------------------------------------------------------------------
26// Error
27// ---------------------------------------------------------------------------
28
/// Errors returned by the vault and its persistence helpers.
#[derive(Debug, thiserror::Error)]
pub enum VaultError {
    /// The vault holds no decrypted state; returned by the controller accessors
    /// (and therefore by `save`) while locked.
    #[error("Vault is locked")]
    Locked,
    /// AES-GCM decryption failed. `decrypt` maps every authentication failure to
    /// this variant, so a corrupted or tampered ciphertext is reported the same
    /// way as a wrong password.
    #[error("Wrong password")]
    WrongPassword,
    /// `unlock` was called on a vault that has no encrypted data attached.
    #[error("Vault not initialized")]
    NotInitialized,
    /// Catch-all carrying a message. Besides cipher failures, this module also
    /// wraps Base64/JSON decoding errors and file or keychain I/O errors in it.
    #[error("Encryption error: {0}")]
    EncryptionError(String),
    /// Error propagated from the crate-level `KeyringError`; `#[from]` lets `?`
    /// convert it automatically.
    #[error("Keyring error: {0}")]
    KeyringError(#[from] KeyringError),
}
42
43// ---------------------------------------------------------------------------
44// Persisted structures
45// ---------------------------------------------------------------------------
46
/// Encrypted vault metadata (stored alongside the encrypted data).
///
/// Holds the parameters `decrypt` needs to re-derive the key and open the
/// ciphertext. None of these values are secret.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct VaultMetadata {
    /// Vault format version; `encrypt` sets it to `VAULT_VERSION`.
    pub version: u32,
    /// PBKDF2 iteration count that was used to derive the key for this vault.
    pub pbkdf2_iterations: u32,
    /// Base64-encoded 32-byte salt.
    pub salt: String,
    /// Base64-encoded 12-byte AES-GCM nonce.
    pub nonce: String,
}
55
/// Full vault structure (what gets persisted as JSON).
///
/// Produced by `encrypt` and consumed by `decrypt`; the file and keychain
/// helpers in this module serialize it with `serde_json`.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct EncryptedVault {
    /// Key-derivation and cipher parameters needed to decrypt `ciphertext`.
    pub metadata: VaultMetadata,
    /// Base64-encoded encrypted keyring states (AES-256-GCM output).
    pub ciphertext: String,
}
62
63// ---------------------------------------------------------------------------
64// In-memory state
65// ---------------------------------------------------------------------------
66
/// Decrypted vault state (in-memory only).
///
/// Exists only while the vault is unlocked; `Vault::lock` drops it.
pub struct VaultState {
    /// Wallet controller rebuilt from the decrypted vault payload.
    pub controller: WalletController,
}
71
/// Password-protected container for a `WalletController`.
///
/// Pairs the persisted, encrypted form of the vault with the decrypted state
/// that is present only while the vault is unlocked.
pub struct Vault {
    /// Decrypted state; `None` when locked.
    state: Option<VaultState>,
    /// The persisted form; `None` until the vault is initialized or built via
    /// `from_encrypted`.
    encrypted: Option<EncryptedVault>,
}
76
77// ---------------------------------------------------------------------------
78// Crypto helpers
79// ---------------------------------------------------------------------------
80
81fn derive_key(password: &str, salt: &[u8], iterations: u32) -> [u8; KEY_LEN] {
82    let mut key = [0u8; KEY_LEN];
83    pbkdf2::<Hmac<Sha256>>(password.as_bytes(), salt, iterations, &mut key)
84        .expect("HMAC can be initialized with any key length");
85    key
86}
87
88fn encrypt(plaintext: &[u8], password: &str) -> Result<EncryptedVault, VaultError> {
89    let mut rng = rand::thread_rng();
90
91    let mut salt = [0u8; SALT_LEN];
92    rng.fill_bytes(&mut salt);
93
94    let mut nonce_bytes = [0u8; NONCE_LEN];
95    rng.fill_bytes(&mut nonce_bytes);
96
97    let key = derive_key(password, &salt, PBKDF2_ITERATIONS);
98    let cipher =
99        Aes256Gcm::new_from_slice(&key).map_err(|e| VaultError::EncryptionError(e.to_string()))?;
100
101    let nonce = Nonce::from(nonce_bytes);
102    let ciphertext = cipher
103        .encrypt(&nonce, plaintext)
104        .map_err(|e| VaultError::EncryptionError(e.to_string()))?;
105
106    Ok(EncryptedVault {
107        metadata: VaultMetadata {
108            version: VAULT_VERSION,
109            pbkdf2_iterations: PBKDF2_ITERATIONS,
110            salt: BASE64.encode(salt),
111            nonce: BASE64.encode(nonce_bytes),
112        },
113        ciphertext: BASE64.encode(ciphertext),
114    })
115}
116
117fn decrypt(vault: &EncryptedVault, password: &str) -> Result<Vec<u8>, VaultError> {
118    let salt = BASE64
119        .decode(&vault.metadata.salt)
120        .map_err(|e| VaultError::EncryptionError(e.to_string()))?;
121    let nonce_vec = BASE64
122        .decode(&vault.metadata.nonce)
123        .map_err(|e| VaultError::EncryptionError(e.to_string()))?;
124    let nonce_bytes: [u8; NONCE_LEN] = nonce_vec
125        .try_into()
126        .map_err(|_| VaultError::EncryptionError("invalid nonce length".to_string()))?;
127    let ciphertext = BASE64
128        .decode(&vault.ciphertext)
129        .map_err(|e| VaultError::EncryptionError(e.to_string()))?;
130
131    let key = derive_key(password, &salt, vault.metadata.pbkdf2_iterations);
132    let cipher =
133        Aes256Gcm::new_from_slice(&key).map_err(|e| VaultError::EncryptionError(e.to_string()))?;
134
135    let nonce = Nonce::from(nonce_bytes);
136    cipher
137        .decrypt(&nonce, ciphertext.as_ref())
138        .map_err(|_| VaultError::WrongPassword)
139}
140
141// ---------------------------------------------------------------------------
142// File persistence helpers
143// ---------------------------------------------------------------------------
144
145/// Return the default vault file path: `<data_dir>/hyper-agent/vault.json`.
146pub fn default_vault_path() -> Option<std::path::PathBuf> {
147    dirs::data_dir().map(|d| d.join("hyper-agent").join("vault.json"))
148}
149
150/// Load an `EncryptedVault` from the default file location.
151pub fn load_vault_from_file() -> Result<Option<EncryptedVault>, VaultError> {
152    let path = match default_vault_path() {
153        Some(p) => p,
154        None => return Ok(None),
155    };
156    if !path.exists() {
157        return Ok(None);
158    }
159    let data =
160        std::fs::read_to_string(&path).map_err(|e| VaultError::EncryptionError(e.to_string()))?;
161    let vault: EncryptedVault =
162        serde_json::from_str(&data).map_err(|e| VaultError::EncryptionError(e.to_string()))?;
163    Ok(Some(vault))
164}
165
166/// Persist an `EncryptedVault` to the default file location.
167pub fn save_vault_to_file(vault: &EncryptedVault) -> Result<(), VaultError> {
168    let path = match default_vault_path() {
169        Some(p) => p,
170        None => {
171            return Err(VaultError::EncryptionError(
172                "Cannot determine data directory".to_string(),
173            ))
174        }
175    };
176    if let Some(parent) = path.parent() {
177        std::fs::create_dir_all(parent).map_err(|e| VaultError::EncryptionError(e.to_string()))?;
178    }
179    let json = serde_json::to_string_pretty(vault)
180        .map_err(|e| VaultError::EncryptionError(e.to_string()))?;
181    std::fs::write(&path, json).map_err(|e| VaultError::EncryptionError(e.to_string()))?;
182    Ok(())
183}
184
185/// Store the encrypted vault JSON as a single keychain entry (fallback).
186pub fn save_vault_to_keychain(vault: &EncryptedVault) -> Result<(), VaultError> {
187    let json =
188        serde_json::to_string(vault).map_err(|e| VaultError::EncryptionError(e.to_string()))?;
189    let entry = keyring::Entry::new("hyper-agent", "encrypted-vault")
190        .map_err(|e| VaultError::EncryptionError(e.to_string()))?;
191    entry
192        .set_password(&json)
193        .map_err(|e| VaultError::EncryptionError(e.to_string()))?;
194    Ok(())
195}
196
197/// Load the encrypted vault from the OS keychain (fallback).
198pub fn load_vault_from_keychain() -> Result<Option<EncryptedVault>, VaultError> {
199    let entry = match keyring::Entry::new("hyper-agent", "encrypted-vault") {
200        Ok(e) => e,
201        Err(_) => return Ok(None),
202    };
203    match entry.get_password() {
204        Ok(json) => {
205            let vault: EncryptedVault = serde_json::from_str(&json)
206                .map_err(|e| VaultError::EncryptionError(e.to_string()))?;
207            Ok(Some(vault))
208        }
209        Err(_) => Ok(None),
210    }
211}
212
213// ---------------------------------------------------------------------------
214// Vault implementation
215// ---------------------------------------------------------------------------
216
217impl Vault {
218    /// Create a new empty vault.
219    pub fn new() -> Self {
220        Self {
221            state: None,
222            encrypted: None,
223        }
224    }
225
226    /// Create vault from existing encrypted data.
227    pub fn from_encrypted(encrypted: EncryptedVault) -> Self {
228        Self {
229            state: None,
230            encrypted: Some(encrypted),
231        }
232    }
233
234    /// Unlock vault with password (decrypt keyring states).
235    pub fn unlock(&mut self, password: &str) -> Result<(), VaultError> {
236        let enc = self.encrypted.as_ref().ok_or(VaultError::NotInitialized)?;
237        let plaintext = decrypt(enc, password)?;
238        let controller = WalletController::deserialize(&plaintext)?;
239        self.state = Some(VaultState { controller });
240        Ok(())
241    }
242
243    /// Lock vault (clear decrypted state from memory).
244    pub fn lock(&mut self) {
245        self.state = None;
246    }
247
248    /// Check if vault is unlocked.
249    pub fn is_unlocked(&self) -> bool {
250        self.state.is_some()
251    }
252
253    /// Get the wallet controller (must be unlocked).
254    pub fn controller(&self) -> Result<&WalletController, VaultError> {
255        self.state
256            .as_ref()
257            .map(|s| &s.controller)
258            .ok_or(VaultError::Locked)
259    }
260
261    /// Get mutable wallet controller (must be unlocked).
262    pub fn controller_mut(&mut self) -> Result<&mut WalletController, VaultError> {
263        self.state
264            .as_mut()
265            .map(|s| &mut s.controller)
266            .ok_or(VaultError::Locked)
267    }
268
269    /// Save current state: encrypt with password and return EncryptedVault.
270    pub fn save(&self, password: &str) -> Result<EncryptedVault, VaultError> {
271        let ctrl = self.controller()?;
272        let plaintext = ctrl.serialize()?;
273        encrypt(&plaintext, password)
274    }
275
276    /// Initialize a new vault with password (for first-time setup).
277    /// Creates an empty WalletController inside.
278    pub fn initialize(&mut self, password: &str) -> Result<EncryptedVault, VaultError> {
279        let controller = WalletController::new();
280        let plaintext = controller.serialize()?;
281        let enc = encrypt(&plaintext, password)?;
282        self.encrypted = Some(enc.clone());
283        self.state = Some(VaultState { controller });
284        Ok(enc)
285    }
286}
287
288impl Default for Vault {
289    fn default() -> Self {
290        Self::new()
291    }
292}
293
294// ---------------------------------------------------------------------------
295// Tests
296// ---------------------------------------------------------------------------
297
#[cfg(test)]
mod tests {
    use super::*;

    const TEST_PASSWORD: &str = "correct-horse-battery-staple";
    const WRONG_PASSWORD: &str = "wrong-password";
    const TEST_MNEMONIC: &str =
        "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about";

    /// Build a vault freshly initialized with `TEST_PASSWORD`, returning it
    /// together with the encrypted form `initialize` produced.
    fn fresh_vault() -> (Vault, EncryptedVault) {
        let mut vault = Vault::new();
        let enc = vault.initialize(TEST_PASSWORD).unwrap();
        (vault, enc)
    }

    /// Load `enc` into a brand-new vault and unlock it with `TEST_PASSWORD`.
    fn reopen(enc: EncryptedVault) -> Vault {
        let mut vault = Vault::from_encrypted(enc);
        vault.unlock(TEST_PASSWORD).unwrap();
        vault
    }

    /// Collect the addresses of every account in an unlocked vault.
    fn addresses(vault: &Vault) -> Vec<String> {
        vault
            .controller()
            .unwrap()
            .get_accounts()
            .iter()
            .map(|account| account.address.clone())
            .collect()
    }

    #[test]
    fn test_initialize_unlock_get_controller() {
        let mut vault = Vault::new();
        assert!(!vault.is_unlocked());

        let enc = vault.initialize(TEST_PASSWORD).unwrap();
        assert!(vault.is_unlocked());

        // A freshly initialized vault exposes a controller with no accounts.
        assert_eq!(vault.controller().unwrap().get_accounts().len(), 0);

        // The encrypted form starts locked and opens with the same password.
        let mut reopened = Vault::from_encrypted(enc);
        assert!(!reopened.is_unlocked());
        reopened.unlock(TEST_PASSWORD).unwrap();
        assert!(reopened.is_unlocked());
    }

    #[test]
    fn test_lock_vault_controller_inaccessible() {
        let (mut vault, _) = fresh_vault();
        assert!(vault.is_unlocked());

        vault.lock();
        assert!(!vault.is_unlocked());
        assert!(matches!(vault.controller(), Err(VaultError::Locked)));
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let (mut vault, _) = fresh_vault();

        // Add an HD wallet.
        vault
            .controller_mut()
            .unwrap()
            .create_hd_wallet(Some(TEST_MNEMONIC))
            .unwrap();

        let before = addresses(&vault);
        assert_eq!(before.len(), 1);

        // Save (encrypt), then load into a new vault and unlock.
        let reopened = reopen(vault.save(TEST_PASSWORD).unwrap());

        let after = addresses(&reopened);
        assert_eq!(before.len(), after.len());
        assert_eq!(before[0], after[0]);
    }

    #[test]
    fn test_wrong_password_fails() {
        let (_, enc) = fresh_vault();

        let mut other = Vault::from_encrypted(enc);
        assert!(matches!(
            other.unlock(WRONG_PASSWORD),
            Err(VaultError::WrongPassword)
        ));
        assert!(!other.is_unlocked());
    }

    #[test]
    fn test_hd_wallet_survives_save_reload() {
        let (mut vault, _) = fresh_vault();

        // Create the HD wallet and derive two additional accounts.
        {
            let ctrl = vault.controller_mut().unwrap();
            ctrl.create_hd_wallet(Some(TEST_MNEMONIC)).unwrap();
            for _ in 0..2 {
                ctrl.derive_next_agent().unwrap();
            }
        }

        let before = addresses(&vault);
        assert_eq!(before.len(), 3);

        // Save, reload, verify.
        let reopened = reopen(vault.save(TEST_PASSWORD).unwrap());
        assert_eq!(before, addresses(&reopened));

        // Private keys must match account by account.
        for addr in &before {
            let original_key = vault.controller().unwrap().export_account(addr).unwrap();
            let reloaded_key = reopened.controller().unwrap().export_account(addr).unwrap();
            assert_eq!(original_key, reloaded_key);
        }
    }

    #[test]
    fn test_unlock_not_initialized_fails() {
        let mut vault = Vault::new();
        assert!(matches!(
            vault.unlock(TEST_PASSWORD),
            Err(VaultError::NotInitialized)
        ));
    }

    #[test]
    fn test_save_while_locked_fails() {
        let (mut vault, _) = fresh_vault();
        vault.lock();
        assert!(matches!(
            vault.save(TEST_PASSWORD),
            Err(VaultError::Locked)
        ));
    }

    #[test]
    fn test_encrypt_decrypt_raw_helpers() {
        let plaintext = b"hello vault";
        let enc = encrypt(plaintext, TEST_PASSWORD).unwrap();

        // Correct password recovers the exact bytes.
        assert_eq!(decrypt(&enc, TEST_PASSWORD).unwrap(), plaintext);

        // Wrong password is rejected.
        assert!(matches!(
            decrypt(&enc, WRONG_PASSWORD),
            Err(VaultError::WrongPassword)
        ));
    }

    #[test]
    fn test_vault_metadata_version() {
        let (_, enc) = fresh_vault();
        assert_eq!(enc.metadata.version, VAULT_VERSION);
        assert_eq!(enc.metadata.pbkdf2_iterations, PBKDF2_ITERATIONS);
    }

    #[test]
    fn test_encrypted_vault_serializable_as_json() {
        let (_, enc) = fresh_vault();

        // Round-trip through JSON.
        let json = serde_json::to_string(&enc).unwrap();
        let decoded: EncryptedVault = serde_json::from_str(&json).unwrap();

        // The deserialized copy must still unlock.
        let reopened = reopen(decoded);
        assert!(reopened.is_unlocked());
    }
}
471}