use aes_gcm::{
Aes256Gcm, KeyInit, Nonce,
aead::{Aead, AeadCore, OsRng},
};
use hkdf::Hkdf;
use secrecy::{ExposeSecret, SecretString};
use sha2::Sha256;
use crate::secrets::types::{DecryptedSecret, SecretError};
/// Length in bytes of each derived AES-256 key (256 bits).
const KEY_SIZE: usize = 256 / 8;
/// Length in bytes of the AES-GCM nonce stored at the front of every encrypted value (96 bits).
const NONCE_SIZE: usize = 96 / 8;
/// Length in bytes of the random HKDF salt generated for each encrypted value.
const SALT_SIZE: usize = 32;
/// Length in bytes of the GCM authentication tag carried in the ciphertext (128 bits).
const TAG_SIZE: usize = 128 / 8;
/// Encrypts and decrypts secret values with AES-256-GCM.
///
/// Every value is sealed under its own key, derived with HKDF-SHA256 from the
/// master key and a fresh random salt, so the master key is never used directly
/// as a cipher key.
pub struct SecretsCrypto {
    /// HKDF input keying material; guaranteed by [`SecretsCrypto::new`] to be at
    /// least `KEY_SIZE` bytes long.
    master_key: SecretString,
}

impl SecretsCrypto {
    /// Creates a crypto helper from `master_key`.
    ///
    /// # Errors
    ///
    /// Returns [`SecretError::InvalidMasterKey`] if the key is shorter than
    /// `KEY_SIZE` bytes. The length is measured in UTF-8 bytes, not characters.
    pub fn new(master_key: SecretString) -> Result<Self, SecretError> {
        if master_key.expose_secret().len() < KEY_SIZE {
            return Err(SecretError::InvalidMasterKey);
        }
        Ok(Self { master_key })
    }

    /// Returns `SALT_SIZE` random bytes from `rand::thread_rng()` for use as an HKDF salt.
    pub fn generate_salt() -> Vec<u8> {
        let mut salt = vec![0u8; SALT_SIZE];
        rand::RngCore::fill_bytes(&mut rand::thread_rng(), &mut salt);
        salt
    }

    /// Encrypts `plaintext` under a key derived from a freshly generated salt.
    ///
    /// Returns `(encrypted, salt)`:
    /// - `encrypted` is the random nonce followed by the AEAD output (ciphertext
    ///   plus authentication tag).
    /// - `salt` must be stored alongside `encrypted`.
    ///
    /// Both values are required by [`Self::decrypt`].
    ///
    /// # Errors
    ///
    /// Returns [`SecretError::EncryptionFailed`] if key derivation, cipher setup,
    /// or encryption fails.
    pub fn encrypt(&self, plaintext: &[u8]) -> Result<(Vec<u8>, Vec<u8>), SecretError> {
        let salt = Self::generate_salt();
        let derived_key = self.derive_key(&salt).ok_or_else(|| {
            SecretError::EncryptionFailed("HKDF expansion failed".to_string())
        })?;
        let cipher = Aes256Gcm::new_from_slice(&derived_key).map_err(|e| {
            SecretError::EncryptionFailed(format!("Failed to create cipher: {}", e))
        })?;
        let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
        let ciphertext = cipher
            .encrypt(&nonce, plaintext)
            .map_err(|e| SecretError::EncryptionFailed(format!("Encryption failed: {}", e)))?;

        // Layout: nonce || ciphertext-with-tag.
        let mut encrypted = Vec::with_capacity(NONCE_SIZE + ciphertext.len());
        encrypted.extend_from_slice(&nonce);
        encrypted.extend_from_slice(&ciphertext);
        Ok((encrypted, salt))
    }

    /// Decrypts a value produced by [`Self::encrypt`] using the salt returned with it.
    ///
    /// # Errors
    ///
    /// Returns [`SecretError::DecryptionFailed`] if any of the following holds:
    /// - `encrypted_value` is shorter than a nonce plus a tag.
    /// - Key derivation or cipher setup fails.
    /// - Authentication fails. This covers a wrong salt, a wrong master key, or
    ///   tampered bytes.
    ///
    /// Also returns any error produced by `DecryptedSecret::from_bytes`.
    pub fn decrypt(
        &self,
        encrypted_value: &[u8],
        salt: &[u8],
    ) -> Result<DecryptedSecret, SecretError> {
        // Even an empty plaintext yields nonce + tag, so anything shorter is malformed.
        if encrypted_value.len() < NONCE_SIZE + TAG_SIZE {
            return Err(SecretError::DecryptionFailed(
                "Encrypted value too short".to_string(),
            ));
        }
        // Report derivation failures as decryption errors on this path, not encryption errors.
        let derived_key = self.derive_key(salt).ok_or_else(|| {
            SecretError::DecryptionFailed("HKDF expansion failed".to_string())
        })?;
        let cipher = Aes256Gcm::new_from_slice(&derived_key).map_err(|e| {
            SecretError::DecryptionFailed(format!("Failed to create cipher: {}", e))
        })?;

        // The length check above guarantees split_at cannot panic.
        let (nonce_bytes, ciphertext) = encrypted_value.split_at(NONCE_SIZE);
        let nonce = Nonce::from_slice(nonce_bytes);
        let plaintext = cipher
            .decrypt(nonce, ciphertext)
            .map_err(|e| SecretError::DecryptionFailed(format!("Decryption failed: {}", e)))?;
        DecryptedSecret::from_bytes(plaintext)
    }

    /// Derives a `KEY_SIZE`-byte key from the master key and `salt` with HKDF-SHA256.
    ///
    /// Returns `None` only if HKDF rejects the requested output length. That cannot
    /// happen for 32 bytes from SHA-256, whose limit is 255 * 32 bytes. Callers map
    /// `None` to the error variant that matches their own operation.
    ///
    /// NOTE(review): the returned key array is not zeroized after use.
    fn derive_key(&self, salt: &[u8]) -> Option<[u8; KEY_SIZE]> {
        let master_bytes = self.master_key.expose_secret().as_bytes();
        let hk = Hkdf::<Sha256>::new(Some(salt), master_bytes);
        let mut derived = [0u8; KEY_SIZE];
        // The versioned info label ties derived keys to this purpose and scheme version.
        hk.expand(b"near-agent-secrets-v1", &mut derived).ok()?;
        Some(derived)
    }
}
/// Hand-written `Debug` that never prints the master key.
impl std::fmt::Debug for SecretsCrypto {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // A fixed placeholder stands in for the key material.
        const REDACTED: &str = "[REDACTED]";
        let mut out = f.debug_struct("SecretsCrypto");
        out.field("master_key", &REDACTED);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use secrecy::SecretString;

    use crate::secrets::crypto::{NONCE_SIZE, SALT_SIZE, SecretsCrypto, TAG_SIZE};
    use crate::secrets::types::SecretError;

    /// Builds a `SecretsCrypto` with a fixed 32-byte ASCII test master key.
    fn test_crypto() -> SecretsCrypto {
        let key = "0123456789abcdef0123456789abcdef";
        SecretsCrypto::new(SecretString::from(key.to_string())).unwrap()
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let crypto = test_crypto();
        let plaintext = b"my_super_secret_api_key_12345";
        let (encrypted, salt) = crypto.encrypt(plaintext).unwrap();
        assert!(encrypted.len() > plaintext.len());
        let decrypted = crypto.decrypt(&encrypted, &salt).unwrap();
        assert_eq!(decrypted.expose().as_bytes(), plaintext);
    }

    #[test]
    fn test_different_salts_different_ciphertext() {
        let crypto = test_crypto();
        let plaintext = b"same_secret";
        let (encrypted1, salt1) = crypto.encrypt(plaintext).unwrap();
        let (encrypted2, salt2) = crypto.encrypt(plaintext).unwrap();
        assert_ne!(salt1, salt2);
        assert_ne!(encrypted1, encrypted2);
        let decrypted1 = crypto.decrypt(&encrypted1, &salt1).unwrap();
        let decrypted2 = crypto.decrypt(&encrypted2, &salt2).unwrap();
        assert_eq!(decrypted1.expose(), decrypted2.expose());
    }

    #[test]
    fn test_wrong_salt_fails() {
        let crypto = test_crypto();
        let plaintext = b"secret";
        let (encrypted, _salt) = crypto.encrypt(plaintext).unwrap();
        let wrong_salt = SecretsCrypto::generate_salt();
        let result = crypto.decrypt(&encrypted, &wrong_salt);
        assert!(result.is_err());
    }

    #[test]
    fn test_tampered_ciphertext_fails() {
        let crypto = test_crypto();
        let plaintext = b"secret";
        let (mut encrypted, salt) = crypto.encrypt(plaintext).unwrap();
        if let Some(byte) = encrypted.last_mut() {
            *byte ^= 0xFF;
        }
        let result = crypto.decrypt(&encrypted, &salt);
        assert!(result.is_err());
    }

    #[test]
    fn test_master_key_too_short() {
        let short_key = "tooshort";
        let result = SecretsCrypto::new(SecretString::from(short_key.to_string()));
        assert!(matches!(result, Err(SecretError::InvalidMasterKey)));
    }

    #[test]
    fn test_empty_plaintext() {
        let crypto = test_crypto();
        let plaintext = b"";
        let (encrypted, salt) = crypto.encrypt(plaintext).unwrap();
        let decrypted = crypto.decrypt(&encrypted, &salt).unwrap();
        assert!(decrypted.is_empty());
    }

    #[test]
    fn test_large_plaintext() {
        let crypto = test_crypto();
        let plaintext = vec![0x42u8; 1024 * 1024];
        let (encrypted, salt) = crypto.encrypt(&plaintext).unwrap();
        let decrypted = crypto.decrypt(&encrypted, &salt).unwrap();
        assert_eq!(decrypted.expose().as_bytes(), plaintext.as_slice());
    }

    /// Inputs shorter than nonce + tag must be rejected by the length guard.
    #[test]
    fn test_truncated_ciphertext_fails() {
        let crypto = test_crypto();
        let (encrypted, salt) = crypto.encrypt(b"secret").unwrap();
        // One byte short of the minimum valid length.
        let truncated = &encrypted[..NONCE_SIZE + TAG_SIZE - 1];
        assert!(matches!(
            crypto.decrypt(truncated, &salt),
            Err(SecretError::DecryptionFailed(_))
        ));
        assert!(matches!(
            crypto.decrypt(&[], &salt),
            Err(SecretError::DecryptionFailed(_))
        ));
    }

    /// A value sealed under one master key must not open under another.
    #[test]
    fn test_different_master_key_fails() {
        let crypto = test_crypto();
        let other = SecretsCrypto::new(SecretString::from(
            "fedcba9876543210fedcba9876543210".to_string(),
        ))
        .unwrap();
        let (encrypted, salt) = crypto.encrypt(b"secret").unwrap();
        assert!(other.decrypt(&encrypted, &salt).is_err());
    }

    /// Pins the stored layout: nonce || ciphertext || tag, plus a SALT_SIZE salt.
    #[test]
    fn test_ciphertext_layout() {
        let crypto = test_crypto();
        let plaintext = b"layout";
        let (encrypted, salt) = crypto.encrypt(plaintext).unwrap();
        assert_eq!(encrypted.len(), NONCE_SIZE + plaintext.len() + TAG_SIZE);
        assert_eq!(salt.len(), SALT_SIZE);
    }
}