use aes_gcm::{
aead::{Aead, KeyInit, OsRng},
Aes256Gcm, Nonce,
};
use anyhow::{anyhow, Context, Result};
use argon2::{
password_hash::{rand_core::RngCore, SaltString},
Argon2, PasswordHasher,
};
use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Nonce length in bytes (12 bytes = 96 bits) for the AES-256-GCM cipher:
/// `encrypt` generates nonces of this size and `decrypt` rejects any other size.
const NONCE_SIZE: usize = 12;
/// Serializable container holding a set of encrypted tokens.
///
/// Implements serde `Serialize`/`Deserialize`; the `Default` impl produces
/// `version` 1 with an empty salt and an empty credentials map.
#[derive(Debug, Serialize, Deserialize)]
pub struct EncryptedCredentials {
/// Format version number; `Default` sets it to 1.
pub version: u32,
/// Salt stored alongside the credentials.
/// NOTE(review): none of the visible code reads this field (`derive_key` builds
/// its salt from the machine ID) — confirm how callers populate and use it.
pub salt: String,
/// Encrypted tokens keyed by string.
/// NOTE(review): the key is presumably an account identifier — confirm against callers.
pub credentials: HashMap<String, EncryptedToken>,
}
/// A single encrypted value stored as a nonce/ciphertext pair of strings.
///
/// NOTE(review): the fields presumably hold the base64 strings returned by
/// `encrypt` (nonce first, ciphertext second) — confirm against callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EncryptedToken {
/// Nonce paired with `ciphertext` (presumably base64 — see the note above).
pub nonce: String,
/// Encrypted payload (presumably base64 — see the note above).
pub ciphertext: String,
}
impl Default for EncryptedCredentials {
    /// Builds an empty credential store: format version 1, an empty salt
    /// string, and no stored tokens.
    fn default() -> Self {
        let salt = String::new();
        let credentials = HashMap::new();
        Self {
            version: 1,
            salt,
            credentials,
        }
    }
}
/// Derives a 32-byte key for the current machine and user.
///
/// The password fed to Argon2 is `"<machine_id>:<username>"` and the salt is the
/// machine ID encoded as a `SaltString`. Hashing uses `Argon2::default()`, so the
/// same machine ID and username always give the same key.
///
/// If the username lookup fails, the literal `"unknown"` is used instead of
/// returning an error.
///
/// NOTE(review): both inputs are identifiers rather than secrets, so the key is
/// presumably reproducible by any process able to read them — confirm this fits
/// the intended threat model.
///
/// # Errors
///
/// Returns an error if the machine ID cannot be read, the salt cannot be
/// encoded, hashing fails, the hash has no output, or the output is shorter
/// than 32 bytes.
pub fn derive_key() -> Result<[u8; 32]> {
    const KEY_LEN: usize = 32;

    let machine_id = machine_uid::get().map_err(|e| anyhow!("Failed to get machine ID: {}", e))?;
    let username = whoami::username().unwrap_or_else(|_| "unknown".to_string());
    let password = format!("{}:{}", machine_id, username);

    let salt_string = SaltString::encode_b64(machine_id.as_bytes())
        .map_err(|e| anyhow!("Failed to encode salt: {}", e))?;

    let argon2 = Argon2::default();
    let hash = argon2
        .hash_password(password.as_bytes(), &salt_string)
        .map_err(|e| anyhow!("Failed to hash password: {}", e))?;
    let hash_bytes = hash.hash.ok_or_else(|| anyhow!("Hash output is missing"))?;

    // Checked slice: a too-short hash output becomes an error instead of a panic.
    let key_material = hash_bytes.as_bytes().get(..KEY_LEN).ok_or_else(|| {
        anyhow!(
            "Hash output too short: expected at least {} bytes, got {}",
            KEY_LEN,
            hash_bytes.as_bytes().len()
        )
    })?;

    let mut key = [0u8; KEY_LEN];
    key.copy_from_slice(key_material);
    Ok(key)
}
/// Encrypts `plaintext` with AES-256-GCM under `key`, using a fresh random nonce.
///
/// A new `NONCE_SIZE`-byte nonce is drawn from `OsRng` on every call, so
/// encrypting the same plaintext twice yields different output.
///
/// # Returns
///
/// `(nonce_b64, ciphertext_b64)`: the nonce and the cipher output, each encoded
/// with standard base64. Note that `decrypt` takes these two values in the
/// opposite order (ciphertext first, then nonce).
///
/// # Errors
///
/// Returns an error if the cipher reports an encryption failure.
pub fn encrypt(plaintext: &str, key: &[u8; 32]) -> Result<(String, String)> {
    let mut raw_nonce = [0u8; NONCE_SIZE];
    OsRng.fill_bytes(&mut raw_nonce);

    let sealed = Aes256Gcm::new(key.into())
        .encrypt(Nonce::from_slice(&raw_nonce), plaintext.as_bytes())
        .map_err(|e| anyhow!("Encryption failed: {}", e))?;

    Ok((BASE64.encode(raw_nonce), BASE64.encode(sealed)))
}
/// Decrypts a base64 ciphertext/nonce pair with AES-256-GCM under `key`.
///
/// Argument order is ciphertext first, then nonce — the reverse of the
/// `(nonce, ciphertext)` tuple returned by `encrypt`.
///
/// # Errors
///
/// Returns an error, checked in this order, when:
/// - `nonce_b64` is not valid standard base64;
/// - `ciphertext_b64` is not valid standard base64;
/// - the decoded nonce is not exactly `NONCE_SIZE` bytes;
/// - the cipher rejects the input (for example wrong key, wrong nonce, or
///   modified ciphertext);
/// - the decrypted bytes are not valid UTF-8.
pub fn decrypt(ciphertext_b64: &str, nonce_b64: &str, key: &[u8; 32]) -> Result<String> {
    let raw_nonce = BASE64
        .decode(nonce_b64)
        .context("Failed to decode nonce from base64")?;
    let sealed = BASE64
        .decode(ciphertext_b64)
        .context("Failed to decode ciphertext from base64")?;

    // `Nonce::from_slice` requires an exact length, so validate before building it.
    let nonce_len = raw_nonce.len();
    if nonce_len != NONCE_SIZE {
        return Err(anyhow!(
            "Invalid nonce size: expected {}, got {}",
            NONCE_SIZE,
            nonce_len
        ));
    }

    let opened = Aes256Gcm::new(key.into())
        .decrypt(Nonce::from_slice(&raw_nonce), sealed.as_slice())
        .map_err(|e| anyhow!("Decryption failed: {}", e))?;

    String::from_utf8(opened).context("Decrypted data is not valid UTF-8")
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed key for tests that exercise only the cipher. Keeps them independent
    /// of machine-ID availability and avoids running Argon2 in every test.
    const TEST_KEY: [u8; 32] = [0x42; 32];

    #[test]
    fn test_derive_key_deterministic() {
        let key1 = derive_key().expect("Failed to derive key");
        let key2 = derive_key().expect("Failed to derive key");
        assert_eq!(key1, key2, "Key derivation should be deterministic");
    }

    /// End-to-end check: a derived key works for a full encrypt/decrypt cycle.
    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let key = derive_key().expect("Failed to derive key");
        let plaintext = "my-secret-token-12345";
        let (nonce, ciphertext) = encrypt(plaintext, &key).expect("Encryption failed");
        assert_ne!(ciphertext, plaintext);
        assert!(!ciphertext.contains("secret"));
        let decrypted = decrypt(&ciphertext, &nonce, &key).expect("Decryption failed");
        assert_eq!(decrypted, plaintext, "Decrypted text should match original");
    }

    #[test]
    fn test_roundtrip_empty_and_non_ascii_plaintext() {
        for plaintext in ["", "pässwörd-🔑"] {
            let (nonce, ciphertext) = encrypt(plaintext, &TEST_KEY).expect("Encryption failed");
            let decrypted = decrypt(&ciphertext, &nonce, &TEST_KEY).expect("Decryption failed");
            assert_eq!(decrypted, plaintext);
        }
    }

    #[test]
    fn test_encrypt_produces_different_ciphertext() {
        let plaintext = "same-plaintext";
        let (nonce1, ciphertext1) = encrypt(plaintext, &TEST_KEY).expect("Encryption failed");
        let (nonce2, ciphertext2) = encrypt(plaintext, &TEST_KEY).expect("Encryption failed");
        assert_ne!(nonce1, nonce2, "Nonces should be randomly generated");
        assert_ne!(
            ciphertext1, ciphertext2,
            "Ciphertexts should differ with different nonces"
        );
        assert_eq!(
            decrypt(&ciphertext1, &nonce1, &TEST_KEY).unwrap(),
            plaintext
        );
        assert_eq!(
            decrypt(&ciphertext2, &nonce2, &TEST_KEY).unwrap(),
            plaintext
        );
    }

    #[test]
    fn test_decrypt_with_wrong_key_fails() {
        let mut wrong_key = TEST_KEY;
        wrong_key[0] ^= 0xFF;
        let (nonce, ciphertext) = encrypt("secret-data", &TEST_KEY).expect("Encryption failed");
        let result = decrypt(&ciphertext, &nonce, &wrong_key);
        assert!(result.is_err(), "Decryption with wrong key should fail");
    }

    #[test]
    fn test_decrypt_with_wrong_nonce_fails() {
        let (_, ciphertext) = encrypt("secret-data", &TEST_KEY).expect("Encryption failed");
        let (wrong_nonce, _) = encrypt("other", &TEST_KEY).expect("Encryption failed");
        let result = decrypt(&ciphertext, &wrong_nonce, &TEST_KEY);
        assert!(result.is_err(), "Decryption with wrong nonce should fail");
    }

    #[test]
    fn test_decrypt_rejects_invalid_nonce_size() {
        let (_, ciphertext) = encrypt("secret-data", &TEST_KEY).expect("Encryption failed");
        let short_nonce = BASE64.encode([0u8; NONCE_SIZE - 1]);
        let err = decrypt(&ciphertext, &short_nonce, &TEST_KEY)
            .expect_err("Nonce of the wrong length should be rejected");
        assert!(err.to_string().contains("Invalid nonce size"));
    }

    #[test]
    fn test_decrypt_rejects_invalid_base64() {
        let (nonce, ciphertext) = encrypt("secret-data", &TEST_KEY).expect("Encryption failed");

        let bad_nonce = decrypt(&ciphertext, "not base64!", &TEST_KEY)
            .expect_err("Malformed nonce should be rejected");
        assert!(bad_nonce
            .to_string()
            .contains("Failed to decode nonce from base64"));

        let bad_ciphertext = decrypt("***", &nonce, &TEST_KEY)
            .expect_err("Malformed ciphertext should be rejected");
        assert!(bad_ciphertext
            .to_string()
            .contains("Failed to decode ciphertext from base64"));
    }

    #[test]
    fn test_decrypt_rejects_tampered_ciphertext() {
        let (nonce, ciphertext) = encrypt("secret-data", &TEST_KEY).expect("Encryption failed");
        let mut raw = BASE64.decode(&ciphertext).expect("Ciphertext should be base64");
        raw[0] ^= 0x01;
        let tampered = BASE64.encode(raw);
        let result = decrypt(&tampered, &nonce, &TEST_KEY);
        assert!(result.is_err(), "Modified ciphertext should fail to decrypt");
    }

    #[test]
    fn test_encrypted_credentials_serialization() {
        let mut creds = EncryptedCredentials {
            salt: "test-salt".to_string(),
            ..Default::default()
        };
        creds.credentials.insert(
            "account1".to_string(),
            EncryptedToken {
                nonce: "nonce-b64".to_string(),
                ciphertext: "cipher-b64".to_string(),
            },
        );
        let json = serde_json::to_string(&creds).expect("Serialization failed");
        let deserialized: EncryptedCredentials =
            serde_json::from_str(&json).expect("Deserialization failed");
        assert_eq!(deserialized.version, 1);
        assert_eq!(deserialized.salt, "test-salt");
        assert_eq!(deserialized.credentials.len(), 1);
    }
}