1use serde::{Deserialize, Serialize};
2
3use crate::controller::WalletController;
4use crate::KeyringError;
5
6use aes_gcm::{
7 aead::{Aead, KeyInit},
8 Aes256Gcm, Nonce,
9};
10use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
11use hmac::Hmac;
12use pbkdf2::pbkdf2;
13use rand::RngCore;
14use sha2::Sha256;
15
/// PBKDF2-HMAC-SHA256 iteration count used when encrypting a vault.
/// Decryption uses the count stored in each vault's metadata instead of this constant.
const PBKDF2_ITERATIONS: u32 = 600_000;
/// Length in bytes of the random PBKDF2 salt generated for every encryption.
const SALT_LEN: usize = 32;
/// Length in bytes of the AES-GCM nonce (96 bits).
const NONCE_LEN: usize = 12;
/// Length in bytes of the derived AES-256 key.
const KEY_LEN: usize = 32;
/// Format version written into `VaultMetadata::version` by `encrypt`.
const VAULT_VERSION: u32 = 1;
24
/// Errors produced while encrypting, decrypting, persisting or accessing a vault.
#[derive(Debug, thiserror::Error)]
pub enum VaultError {
    /// Wallet state was requested while the vault is locked.
    #[error("Vault is locked")]
    Locked,
    /// AES-GCM decryption failed. This is reported for a wrong password, but also for a
    /// corrupted or tampered ciphertext, because authentication failure cannot tell them apart.
    #[error("Wrong password")]
    WrongPassword,
    /// `unlock` was called on a vault that has no encrypted payload.
    #[error("Vault not initialized")]
    NotInitialized,
    /// Catch-all for cryptographic, base64, JSON, file and OS-keychain failures; the string is
    /// the underlying error's message. Note that the keychain helpers in this module report
    /// their failures through this variant, not through `KeyringError`.
    #[error("Encryption error: {0}")]
    EncryptionError(String),
    /// Wraps a crate-level `KeyringError` via `From`. NOTE(review): `unlock` relies on `?` to
    /// convert errors from `WalletController::deserialize`, presumably into this variant —
    /// confirm against the controller's error type.
    #[error("Keyring error: {0}")]
    KeyringError(#[from] KeyringError),
}
42
/// Parameters needed to decrypt an `EncryptedVault`, stored alongside the ciphertext.
///
/// Binary values are stored as standard (padded) base64 strings so the struct serializes to
/// plain JSON. Field order determines the JSON key order on serialization.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct VaultMetadata {
    /// Vault format version (`VAULT_VERSION` when written by this module).
    pub version: u32,
    /// PBKDF2-HMAC-SHA256 iteration count used to derive the encryption key.
    pub pbkdf2_iterations: u32,
    /// Base64-encoded PBKDF2 salt (`SALT_LEN` bytes when written by this module).
    pub salt: String,
    /// Base64-encoded AES-GCM nonce; decryption requires it to decode to `NONCE_LEN` bytes.
    pub nonce: String,
}
55
/// A password-encrypted, serialized `WalletController` plus the metadata needed to decrypt it.
///
/// This is the form that is written to disk or to the OS keychain (as JSON).
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct EncryptedVault {
    /// KDF and cipher parameters for `ciphertext`.
    pub metadata: VaultMetadata,
    /// Base64-encoded output of `Aes256Gcm::encrypt` over the serialized controller.
    pub ciphertext: String,
}
62
/// Decrypted, in-memory state that exists only while a `Vault` is unlocked.
pub struct VaultState {
    /// Wallet controller deserialized from the vault plaintext.
    pub controller: WalletController,
}
71
/// Password-protected container for a `WalletController`.
pub struct Vault {
    // `Some` only while unlocked; holds the decrypted controller.
    state: Option<VaultState>,
    // Cached ciphertext that `unlock` decrypts; `None` until the vault is initialized or
    // constructed from existing encrypted data.
    encrypted: Option<EncryptedVault>,
}
76
77fn derive_key(password: &str, salt: &[u8], iterations: u32) -> [u8; KEY_LEN] {
82 let mut key = [0u8; KEY_LEN];
83 pbkdf2::<Hmac<Sha256>>(password.as_bytes(), salt, iterations, &mut key)
84 .expect("HMAC can be initialized with any key length");
85 key
86}
87
88fn encrypt(plaintext: &[u8], password: &str) -> Result<EncryptedVault, VaultError> {
89 let mut rng = rand::thread_rng();
90
91 let mut salt = [0u8; SALT_LEN];
92 rng.fill_bytes(&mut salt);
93
94 let mut nonce_bytes = [0u8; NONCE_LEN];
95 rng.fill_bytes(&mut nonce_bytes);
96
97 let key = derive_key(password, &salt, PBKDF2_ITERATIONS);
98 let cipher =
99 Aes256Gcm::new_from_slice(&key).map_err(|e| VaultError::EncryptionError(e.to_string()))?;
100
101 let nonce = Nonce::from(nonce_bytes);
102 let ciphertext = cipher
103 .encrypt(&nonce, plaintext)
104 .map_err(|e| VaultError::EncryptionError(e.to_string()))?;
105
106 Ok(EncryptedVault {
107 metadata: VaultMetadata {
108 version: VAULT_VERSION,
109 pbkdf2_iterations: PBKDF2_ITERATIONS,
110 salt: BASE64.encode(salt),
111 nonce: BASE64.encode(nonce_bytes),
112 },
113 ciphertext: BASE64.encode(ciphertext),
114 })
115}
116
117fn decrypt(vault: &EncryptedVault, password: &str) -> Result<Vec<u8>, VaultError> {
118 let salt = BASE64
119 .decode(&vault.metadata.salt)
120 .map_err(|e| VaultError::EncryptionError(e.to_string()))?;
121 let nonce_vec = BASE64
122 .decode(&vault.metadata.nonce)
123 .map_err(|e| VaultError::EncryptionError(e.to_string()))?;
124 let nonce_bytes: [u8; NONCE_LEN] = nonce_vec
125 .try_into()
126 .map_err(|_| VaultError::EncryptionError("invalid nonce length".to_string()))?;
127 let ciphertext = BASE64
128 .decode(&vault.ciphertext)
129 .map_err(|e| VaultError::EncryptionError(e.to_string()))?;
130
131 let key = derive_key(password, &salt, vault.metadata.pbkdf2_iterations);
132 let cipher =
133 Aes256Gcm::new_from_slice(&key).map_err(|e| VaultError::EncryptionError(e.to_string()))?;
134
135 let nonce = Nonce::from(nonce_bytes);
136 cipher
137 .decrypt(&nonce, ciphertext.as_ref())
138 .map_err(|_| VaultError::WrongPassword)
139}
140
141pub fn default_vault_path() -> Option<std::path::PathBuf> {
147 dirs::data_dir().map(|d| d.join("hyper-agent").join("vault.json"))
148}
149
150pub fn load_vault_from_file() -> Result<Option<EncryptedVault>, VaultError> {
152 let path = match default_vault_path() {
153 Some(p) => p,
154 None => return Ok(None),
155 };
156 if !path.exists() {
157 return Ok(None);
158 }
159 let data =
160 std::fs::read_to_string(&path).map_err(|e| VaultError::EncryptionError(e.to_string()))?;
161 let vault: EncryptedVault =
162 serde_json::from_str(&data).map_err(|e| VaultError::EncryptionError(e.to_string()))?;
163 Ok(Some(vault))
164}
165
166pub fn save_vault_to_file(vault: &EncryptedVault) -> Result<(), VaultError> {
168 let path = match default_vault_path() {
169 Some(p) => p,
170 None => {
171 return Err(VaultError::EncryptionError(
172 "Cannot determine data directory".to_string(),
173 ))
174 }
175 };
176 if let Some(parent) = path.parent() {
177 std::fs::create_dir_all(parent).map_err(|e| VaultError::EncryptionError(e.to_string()))?;
178 }
179 let json = serde_json::to_string_pretty(vault)
180 .map_err(|e| VaultError::EncryptionError(e.to_string()))?;
181 std::fs::write(&path, json).map_err(|e| VaultError::EncryptionError(e.to_string()))?;
182 Ok(())
183}
184
185pub fn save_vault_to_keychain(vault: &EncryptedVault) -> Result<(), VaultError> {
187 let json =
188 serde_json::to_string(vault).map_err(|e| VaultError::EncryptionError(e.to_string()))?;
189 let entry = keyring::Entry::new("hyper-agent", "encrypted-vault")
190 .map_err(|e| VaultError::EncryptionError(e.to_string()))?;
191 entry
192 .set_password(&json)
193 .map_err(|e| VaultError::EncryptionError(e.to_string()))?;
194 Ok(())
195}
196
197pub fn load_vault_from_keychain() -> Result<Option<EncryptedVault>, VaultError> {
199 let entry = match keyring::Entry::new("hyper-agent", "encrypted-vault") {
200 Ok(e) => e,
201 Err(_) => return Ok(None),
202 };
203 match entry.get_password() {
204 Ok(json) => {
205 let vault: EncryptedVault = serde_json::from_str(&json)
206 .map_err(|e| VaultError::EncryptionError(e.to_string()))?;
207 Ok(Some(vault))
208 }
209 Err(_) => Ok(None),
210 }
211}
212
213impl Vault {
218 pub fn new() -> Self {
220 Self {
221 state: None,
222 encrypted: None,
223 }
224 }
225
226 pub fn from_encrypted(encrypted: EncryptedVault) -> Self {
228 Self {
229 state: None,
230 encrypted: Some(encrypted),
231 }
232 }
233
234 pub fn unlock(&mut self, password: &str) -> Result<(), VaultError> {
236 let enc = self.encrypted.as_ref().ok_or(VaultError::NotInitialized)?;
237 let plaintext = decrypt(enc, password)?;
238 let controller = WalletController::deserialize(&plaintext)?;
239 self.state = Some(VaultState { controller });
240 Ok(())
241 }
242
243 pub fn lock(&mut self) {
245 self.state = None;
246 }
247
248 pub fn is_unlocked(&self) -> bool {
250 self.state.is_some()
251 }
252
253 pub fn controller(&self) -> Result<&WalletController, VaultError> {
255 self.state
256 .as_ref()
257 .map(|s| &s.controller)
258 .ok_or(VaultError::Locked)
259 }
260
261 pub fn controller_mut(&mut self) -> Result<&mut WalletController, VaultError> {
263 self.state
264 .as_mut()
265 .map(|s| &mut s.controller)
266 .ok_or(VaultError::Locked)
267 }
268
269 pub fn save(&self, password: &str) -> Result<EncryptedVault, VaultError> {
271 let ctrl = self.controller()?;
272 let plaintext = ctrl.serialize()?;
273 encrypt(&plaintext, password)
274 }
275
276 pub fn initialize(&mut self, password: &str) -> Result<EncryptedVault, VaultError> {
279 let controller = WalletController::new();
280 let plaintext = controller.serialize()?;
281 let enc = encrypt(&plaintext, password)?;
282 self.encrypted = Some(enc.clone());
283 self.state = Some(VaultState { controller });
284 Ok(enc)
285 }
286}
287
288impl Default for Vault {
289 fn default() -> Self {
290 Self::new()
291 }
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    const TEST_PASSWORD: &str = "correct-horse-battery-staple";
    const WRONG_PASSWORD: &str = "wrong-password";
    /// Fixed 12-word mnemonic so derived addresses are deterministic across runs.
    const TEST_MNEMONIC: &str =
        "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about";

    #[test]
    fn test_initialize_unlock_get_controller() {
        let mut vault = Vault::new();
        assert!(!vault.is_unlocked());

        let enc = vault.initialize(TEST_PASSWORD).unwrap();
        assert!(vault.is_unlocked());

        // A freshly initialized wallet has no accounts.
        let ctrl = vault.controller().unwrap();
        assert_eq!(ctrl.get_accounts().len(), 0);

        // The returned ciphertext can be reopened by a separate vault instance.
        let mut vault2 = Vault::from_encrypted(enc);
        assert!(!vault2.is_unlocked());
        vault2.unlock(TEST_PASSWORD).unwrap();
        assert!(vault2.is_unlocked());
    }

    #[test]
    fn test_lock_vault_controller_inaccessible() {
        let mut vault = Vault::new();
        vault.initialize(TEST_PASSWORD).unwrap();
        assert!(vault.is_unlocked());

        vault.lock();
        assert!(!vault.is_unlocked());

        let result = vault.controller();
        assert!(matches!(result, Err(VaultError::Locked)));
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let mut vault = Vault::new();
        vault.initialize(TEST_PASSWORD).unwrap();

        vault
            .controller_mut()
            .unwrap()
            .create_hd_wallet(Some(TEST_MNEMONIC))
            .unwrap();

        let accounts_before = vault.controller().unwrap().get_accounts();
        assert_eq!(accounts_before.len(), 1);

        let enc = vault.save(TEST_PASSWORD).unwrap();

        let mut vault2 = Vault::from_encrypted(enc);
        vault2.unlock(TEST_PASSWORD).unwrap();

        let accounts_after = vault2.controller().unwrap().get_accounts();
        assert_eq!(accounts_before.len(), accounts_after.len());
        assert_eq!(accounts_before[0].address, accounts_after[0].address);
    }

    #[test]
    fn test_wrong_password_fails() {
        let mut vault = Vault::new();
        let enc = vault.initialize(TEST_PASSWORD).unwrap();

        let mut vault2 = Vault::from_encrypted(enc);
        let result = vault2.unlock(WRONG_PASSWORD);
        assert!(matches!(result, Err(VaultError::WrongPassword)));
        assert!(!vault2.is_unlocked());
    }

    #[test]
    fn test_hd_wallet_survives_save_reload() {
        let mut vault = Vault::new();
        vault.initialize(TEST_PASSWORD).unwrap();

        // One HD root account plus two derived agents.
        let ctrl = vault.controller_mut().unwrap();
        ctrl.create_hd_wallet(Some(TEST_MNEMONIC)).unwrap();
        ctrl.derive_next_agent().unwrap();
        ctrl.derive_next_agent().unwrap();

        let accounts_before: Vec<String> = vault
            .controller()
            .unwrap()
            .get_accounts()
            .iter()
            .map(|a| a.address.clone())
            .collect();
        assert_eq!(accounts_before.len(), 3);

        let enc = vault.save(TEST_PASSWORD).unwrap();
        let mut vault2 = Vault::from_encrypted(enc);
        vault2.unlock(TEST_PASSWORD).unwrap();

        let accounts_after: Vec<String> = vault2
            .controller()
            .unwrap()
            .get_accounts()
            .iter()
            .map(|a| a.address.clone())
            .collect();
        assert_eq!(accounts_before, accounts_after);

        // Exported key material must match, not just the addresses.
        for addr in &accounts_before {
            let key1 = vault.controller().unwrap().export_account(addr).unwrap();
            let key2 = vault2.controller().unwrap().export_account(addr).unwrap();
            assert_eq!(key1, key2);
        }
    }

    #[test]
    fn test_unlock_not_initialized_fails() {
        let mut vault = Vault::new();
        let result = vault.unlock(TEST_PASSWORD);
        assert!(matches!(result, Err(VaultError::NotInitialized)));
    }

    #[test]
    fn test_save_while_locked_fails() {
        let mut vault = Vault::new();
        vault.initialize(TEST_PASSWORD).unwrap();
        vault.lock();
        let result = vault.save(TEST_PASSWORD);
        assert!(matches!(result, Err(VaultError::Locked)));
    }

    #[test]
    fn test_encrypt_decrypt_raw_helpers() {
        let plaintext = b"hello vault";
        let enc = encrypt(plaintext, TEST_PASSWORD).unwrap();

        let decrypted = decrypt(&enc, TEST_PASSWORD).unwrap();
        assert_eq!(decrypted, plaintext);

        let result = decrypt(&enc, WRONG_PASSWORD);
        assert!(matches!(result, Err(VaultError::WrongPassword)));
    }

    #[test]
    fn test_encrypt_uses_fresh_salt_and_nonce() {
        // Reusing a nonce under the same AES-GCM key breaks confidentiality and integrity, so
        // identical inputs must still produce distinct salt, nonce and ciphertext.
        let a = encrypt(b"same plaintext", TEST_PASSWORD).unwrap();
        let b = encrypt(b"same plaintext", TEST_PASSWORD).unwrap();
        assert_ne!(a.metadata.salt, b.metadata.salt);
        assert_ne!(a.metadata.nonce, b.metadata.nonce);
        assert_ne!(a.ciphertext, b.ciphertext);
    }

    #[test]
    fn test_tampered_ciphertext_rejected() {
        let mut enc = encrypt(b"hello vault", TEST_PASSWORD).unwrap();
        let mut bytes = BASE64.decode(&enc.ciphertext).unwrap();
        bytes[0] ^= 0x01;
        enc.ciphertext = BASE64.encode(&bytes);

        // Authentication fails even with the right password; reported as WrongPassword.
        let result = decrypt(&enc, TEST_PASSWORD);
        assert!(matches!(result, Err(VaultError::WrongPassword)));
    }

    #[test]
    fn test_vault_metadata_version() {
        let mut vault = Vault::new();
        let enc = vault.initialize(TEST_PASSWORD).unwrap();
        assert_eq!(enc.metadata.version, VAULT_VERSION);
        assert_eq!(enc.metadata.pbkdf2_iterations, PBKDF2_ITERATIONS);
    }

    #[test]
    fn test_encrypted_vault_serializable_as_json() {
        let mut vault = Vault::new();
        let enc = vault.initialize(TEST_PASSWORD).unwrap();

        let json = serde_json::to_string(&enc).unwrap();
        let enc2: EncryptedVault = serde_json::from_str(&json).unwrap();

        let mut vault2 = Vault::from_encrypted(enc2);
        vault2.unlock(TEST_PASSWORD).unwrap();
        assert!(vault2.is_unlocked());
    }
}
471}