use super::CredentialBackend;
use super::types::SecretPasswordSource;
use crate::error::{Error, Result};
use crate::utils::sync::RwLockExt;
use aes_gcm::{
Aes256Gcm, Nonce,
aead::{Aead, KeyInit},
};
use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
use log::debug;
use rand::RngExt;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::io::Write;
use std::path::{Path, PathBuf};
use std::sync::{Mutex, RwLock};
/// One encrypted credential as persisted in the store file.
///
/// Both fields hold standard-alphabet base64: `nonce` is the random nonce generated
/// for this entry by `encrypt`, and `ciphertext` is the output of `Aes256Gcm::encrypt`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct EncryptedEntry {
    nonce: String,
    ciphertext: String,
}
/// JSON document stored in the encrypted credential file.
#[derive(Debug, Default, Serialize, Deserialize)]
struct EncryptedStore {
    // Format version; `save_store` always writes 1.
    version: u32,
    // Base64 of the 16-byte salt used for password-based key derivation.
    // A missing `salt` field deserializes as `None`.
    #[serde(default)]
    salt: Option<String>,
    // Encrypted entries keyed by credential name.
    entries: HashMap<String, EncryptedEntry>,
}
/// Credential backend that keeps all secrets in one AES-256-GCM encrypted JSON file.
pub struct EncryptedFileBackend {
    // Location of the store file.
    path: PathBuf,
    // Cipher built from the 32-byte key passed to `new`.
    cipher: Aes256Gcm,
    // Salt written into the store file on every save.
    salt: [u8; 16],
    // Plaintext values already read or written through this instance.
    cache: RwLock<HashMap<String, String>>,
    // Serializes load-modify-save cycles on the file among users of this instance.
    write_lock: Mutex<()>,
}
impl EncryptedFileBackend {
pub(crate) fn new(path: PathBuf, key: &[u8; 32], salt: [u8; 16]) -> Result<Self> {
Ok(Self {
path,
cipher: Aes256Gcm::new_from_slice(key)
.map_err(|_| Error::Credential("Invalid encryption key length".into()))?,
salt,
cache: RwLock::new(HashMap::new()),
write_lock: Mutex::new(()),
})
}
pub fn with_source(path: PathBuf, source: &SecretPasswordSource) -> Result<Self> {
Self::with_password(path, &source.resolve()?)
}
pub fn with_password(path: PathBuf, password: &str) -> Result<Self> {
let salt = Self::read_salt(&path)?.unwrap_or_else(Self::generate_salt);
let key = Self::derive_key(password, &salt)?;
Self::new(path, &key, salt)
}
pub fn read_salt(path: &Path) -> Result<Option<[u8; 16]>> {
if !path.exists() {
return Ok(None);
}
let content = fs::read_to_string(path).map_err(|e| Error::FileRead {
path: path.to_path_buf(),
source: e,
})?;
let store: EncryptedStore = serde_json::from_str(&content)
.map_err(|e| Error::Credential(format!("Failed to parse encrypted store: {e}")))?;
let Some(salt_b64) = store.salt else {
return Ok(None);
};
let salt_vec = BASE64
.decode(&salt_b64)
.map_err(|e| Error::Credential(format!("Invalid salt encoding: {e}")))?;
if salt_vec.len() != 16 {
return Err(Error::Credential(format!(
"Invalid salt length: expected 16, got {}",
salt_vec.len()
)));
}
let mut salt = [0u8; 16];
salt.copy_from_slice(&salt_vec);
Ok(Some(salt))
}
#[must_use]
pub fn generate_key() -> [u8; 32] {
rand::rng().random()
}
#[must_use]
pub fn generate_salt() -> [u8; 16] {
rand::rng().random()
}
pub fn derive_key(password: &str, salt: &[u8]) -> Result<[u8; 32]> {
use argon2::{
Argon2,
password_hash::{PasswordHasher, SaltString},
};
let salt_string = SaltString::encode_b64(salt)
.map_err(|e| Error::Credential(format!("Invalid salt bytes: {e}")))?;
let password_hash = Argon2::default()
.hash_password(password.as_bytes(), &salt_string)
.map_err(|e| Error::Credential(format!("Argon2 hashing failed: {e}")))?;
let output = password_hash
.hash
.ok_or_else(|| Error::Credential("Argon2 hash output missing".into()))?;
let bytes = output.as_bytes();
if bytes.len() < 32 {
return Err(Error::Credential(format!(
"Argon2 output too short: {}",
bytes.len()
)));
}
let mut key = [0u8; 32];
key.copy_from_slice(&bytes[..32]);
Ok(key)
}
fn load_store(&self) -> Result<EncryptedStore> {
if !self.path.exists() {
return Ok(EncryptedStore::default());
}
let content = fs::read_to_string(&self.path).map_err(|e| Error::FileRead {
path: self.path.clone(),
source: e,
})?;
serde_json::from_str(&content)
.map_err(|e| Error::Credential(format!("Failed to parse encrypted store: {e}")))
}
fn save_store(&self, entries: HashMap<String, EncryptedEntry>) -> Result<()> {
let store = EncryptedStore {
version: 1,
salt: Some(BASE64.encode(self.salt)),
entries,
};
let content = serde_json::to_string_pretty(&store)
.map_err(|e| Error::Credential(format!("Failed to serialize encrypted store: {e}")))?;
if let Some(parent) = self.path.parent() {
crate::utils::security::ensure_secure_dir(parent)?;
}
let mut temp_path = self.path.clone();
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_nanos();
let file_name = self.path.file_name().unwrap_or_default().to_string_lossy();
temp_path.set_file_name(format!("{file_name}.{now}.tmp"));
let mut temp_file = fs::File::create(&temp_path).map_err(|e| Error::FileWrite {
path: temp_path.clone(),
source: e,
})?;
temp_file
.write_all(content.as_bytes())
.map_err(|e| Error::FileWrite {
path: temp_path.clone(),
source: e,
})?;
temp_file.sync_all().map_err(|e| Error::FileWrite {
path: temp_path.clone(),
source: e,
})?;
crate::utils::security::set_secure_file_permissions(&temp_path)?;
fs::rename(&temp_path, &self.path).map_err(|e| Error::FileWrite {
path: self.path.clone(),
source: e,
})?;
crate::utils::security::set_secure_file_permissions(&self.path)?;
Ok(())
}
fn encrypt(&self, plaintext: &str) -> Result<EncryptedEntry> {
let nonce_bytes: [u8; 12] = rand::rng().random();
let nonce = Nonce::from_slice(&nonce_bytes);
let ciphertext = self
.cipher
.encrypt(nonce, plaintext.as_bytes())
.map_err(|e| Error::Credential(format!("Encryption failed: {e}")))?;
Ok(EncryptedEntry {
nonce: BASE64.encode(nonce_bytes),
ciphertext: BASE64.encode(&ciphertext),
})
}
fn decrypt(&self, entry: &EncryptedEntry) -> Result<String> {
let nonce_bytes = BASE64
.decode(&entry.nonce)
.map_err(|e| Error::Credential(format!("Invalid nonce encoding: {e}")))?;
let ciphertext = BASE64
.decode(&entry.ciphertext)
.map_err(|e| Error::Credential(format!("Invalid ciphertext encoding: {e}")))?;
let nonce = Nonce::from_slice(&nonce_bytes);
let plaintext = self
.cipher
.decrypt(nonce, ciphertext.as_ref())
.map_err(|_| Error::Credential("Decryption failed (wrong key?)".into()))?;
String::from_utf8(plaintext)
.map_err(|e| Error::Credential(format!("Decrypted data is not valid UTF-8: {e}")))
}
}
impl CredentialBackend for EncryptedFileBackend {
    /// Encrypts `value` and persists it under `key`, replacing any previous value.
    fn store(&self, key: &str, value: &str) -> Result<()> {
        // The mutex guards no data, only the load-modify-save cycle, and saves are
        // atomic renames. A poisoned lock therefore leaves no broken invariant, and it
        // is recovered so that one panic does not disable all future writes.
        let _guard = self
            .write_lock
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let mut store = self.load_store()?;
        let encrypted = self.encrypt(value)?;
        store.entries.insert(key.to_string(), encrypted);
        self.save_store(store.entries)?;
        self.cache
            .write_recovered()?
            .insert(key.to_string(), value.to_string());
        debug!("Credential stored in encrypted file: {key}");
        Ok(())
    }

    /// Returns the decrypted value for `key`, or `None` if it is not stored.
    ///
    /// Values are served from the in-memory cache when present. Otherwise the file is
    /// read and decrypted, and the result is cached.
    fn get(&self, key: &str) -> Result<Option<String>> {
        {
            let cache = self.cache.read_recovered()?;
            if let Some(value) = cache.get(key) {
                return Ok(Some(value.clone()));
            }
        }
        // Hold the write lock while reading the file and filling the cache. Without it,
        // a concurrent `store`/`remove` could finish between our read and our cache
        // insert, leaving a stale (or already removed) value cached indefinitely.
        let _guard = self
            .write_lock
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let store = self.load_store()?;
        let Some(entry) = store.entries.get(key) else {
            return Ok(None);
        };
        let value = self.decrypt(entry)?;
        self.cache
            .write_recovered()?
            .insert(key.to_string(), value.clone());
        debug!("Credential retrieved from encrypted file: {key}");
        Ok(Some(value))
    }

    /// Deletes `key` from the file and the cache. Removing a missing key is not an error.
    fn remove(&self, key: &str) -> Result<()> {
        // See `store` for why recovering a poisoned write lock is sound.
        let _guard = self
            .write_lock
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let mut store = self.load_store()?;
        store.entries.remove(key);
        self.save_store(store.entries)?;
        self.cache.write_recovered()?.remove(key);
        debug!("Credential removed from encrypted file: {key}");
        Ok(())
    }

    /// Lists the credential names currently stored in the file (in arbitrary order).
    fn list_keys(&self) -> Result<Vec<String>> {
        Ok(self.load_store()?.entries.into_keys().collect())
    }

    fn backend_name(&self) -> &'static str {
        "encrypted_file"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn test_encrypted_store_and_get() {
        let temp = tempdir().unwrap();
        let path = temp.path().join("credentials.enc.json");
        let salt = EncryptedFileBackend::generate_salt();
        let key = EncryptedFileBackend::generate_key();
        let backend = EncryptedFileBackend::new(path.clone(), &key, salt).unwrap();
        backend.store("api_key", "secret123").unwrap();
        backend.store("password", "hunter2").unwrap();
        let backend2 = EncryptedFileBackend::new(path, &key, salt).unwrap();
        assert_eq!(
            backend2.get("api_key").unwrap(),
            Some("secret123".to_string())
        );
        assert_eq!(
            backend2.get("password").unwrap(),
            Some("hunter2".to_string())
        );
    }

    #[test]
    fn test_encrypted_wrong_key() {
        let temp = tempdir().unwrap();
        let path = temp.path().join("credentials.enc.json");
        let salt = EncryptedFileBackend::generate_salt();
        let key1 = EncryptedFileBackend::generate_key();
        let key2 = EncryptedFileBackend::generate_key();
        let backend1 = EncryptedFileBackend::new(path.clone(), &key1, salt).unwrap();
        backend1.store("secret", "value").unwrap();
        let backend2 = EncryptedFileBackend::new(path, &key2, salt).unwrap();
        let result = backend2.get("secret");
        assert!(result.is_err());
    }

    #[test]
    fn test_with_password() {
        let temp = tempdir().unwrap();
        let path = temp.path().join("credentials.enc.json");
        let backend = EncryptedFileBackend::with_password(path.clone(), "test_password").unwrap();
        backend.store("api_key", "secret123").unwrap();
        let backend2 = EncryptedFileBackend::with_password(path.clone(), "test_password").unwrap();
        assert_eq!(
            backend2.get("api_key").unwrap(),
            Some("secret123".to_string())
        );
        let backend3 = EncryptedFileBackend::with_password(path, "wrong_password").unwrap();
        assert!(backend3.get("api_key").is_err());
    }

    #[test]
    fn test_derive_key() {
        let salt = EncryptedFileBackend::generate_salt();
        let key1 = EncryptedFileBackend::derive_key("password123", &salt).unwrap();
        let key2 = EncryptedFileBackend::derive_key("password123", &salt).unwrap();
        assert_eq!(key1, key2);
        let key3 = EncryptedFileBackend::derive_key("different", &salt).unwrap();
        assert_ne!(key1, key3);
        let salt2 = EncryptedFileBackend::generate_salt();
        let key4 = EncryptedFileBackend::derive_key("password123", &salt2).unwrap();
        assert_ne!(key1, key4);
    }

    #[test]
    fn test_concurrent_store_no_lost_writes() {
        use std::sync::Arc;
        use std::thread;
        let temp = tempdir().unwrap();
        let path = temp.path().join("credentials.enc.json");
        let backend =
            Arc::new(EncryptedFileBackend::with_password(path.clone(), "password").unwrap());
        let b1 = Arc::clone(&backend);
        let b2 = Arc::clone(&backend);
        let t1 = thread::spawn(move || b1.store("key_a", "value_a").unwrap());
        let t2 = thread::spawn(move || b2.store("key_b", "value_b").unwrap());
        t1.join().unwrap();
        t2.join().unwrap();
        // `backend`'s cache was filled by `store` itself, so reading through it would
        // pass even if one write never reached the file. Use a fresh, uncached reader.
        let reader = EncryptedFileBackend::with_password(path, "password").unwrap();
        assert_eq!(reader.get("key_a").unwrap(), Some("value_a".to_string()));
        assert_eq!(reader.get("key_b").unwrap(), Some("value_b".to_string()));
    }

    #[test]
    fn test_remove_and_list_keys() {
        let temp = tempdir().unwrap();
        let path = temp.path().join("credentials.enc.json");
        let backend = EncryptedFileBackend::with_password(path.clone(), "pw").unwrap();
        backend.store("a", "1").unwrap();
        backend.store("b", "2").unwrap();
        let mut keys = backend.list_keys().unwrap();
        keys.sort();
        assert_eq!(keys, vec!["a".to_string(), "b".to_string()]);

        backend.remove("a").unwrap();
        assert_eq!(backend.get("a").unwrap(), None);

        let reader = EncryptedFileBackend::with_password(path, "pw").unwrap();
        assert_eq!(reader.list_keys().unwrap(), vec!["b".to_string()]);
        assert_eq!(reader.get("a").unwrap(), None);
        assert_eq!(reader.get("b").unwrap(), Some("2".to_string()));

        // Successful saves must not leave staging files behind.
        assert_eq!(fs::read_dir(temp.path()).unwrap().count(), 1);
    }
}