use aes_gcm::aead::{Aead, KeyInit};
use aes_gcm::{Aes256Gcm, Key, Nonce};
use base64::engine::general_purpose::STANDARD as BASE64;
use base64::Engine as _;
use rand::RngCore;
/// Name of the environment variable that `SecretBox::from_env` reads the master key from.
pub const SECRET_ENCRYPTION_KEY_ENV: &str = "LLMTRACE_SECRET_ENCRYPTION_KEY";
/// Length in bytes the decoded master key must have (it is used as an AES-256-GCM key).
const KEY_LEN: usize = 32;
/// Length in bytes of the nonce that prefixes every encrypted payload.
const NONCE_LEN: usize = 12;
/// Errors produced while loading the master key or while encrypting/decrypting a secret.
#[derive(Debug, thiserror::Error)]
pub enum SecretBoxError {
/// The master-key environment variable is not set.
#[error("{SECRET_ENCRYPTION_KEY_ENV} is not set")]
KeyMissing,
/// The master key text could not be turned into exactly `KEY_LEN` bytes.
/// The payload is a short human-readable reason that is appended to the message.
#[error(
"{SECRET_ENCRYPTION_KEY_ENV} must decode to {KEY_LEN} bytes (64-char hex or base64); {0}"
)]
KeyInvalid(String),
/// The cipher returned an error while encrypting.
#[error("encryption failed")]
EncryptFailed,
/// The encoded ciphertext could not be parsed (invalid base64, or too short to hold a
/// nonce plus payload). The payload describes what was wrong.
#[error("ciphertext is malformed: {0}")]
CiphertextMalformed(String),
/// The cipher rejected the ciphertext while decrypting.
#[error("decryption failed: ciphertext could not be authenticated")]
DecryptFailed,
}
/// Encrypts and decrypts secrets with a single AES-256-GCM key.
///
/// Construct it with `SecretBox::from_env` or `SecretBox::from_master_key_str`.
pub struct SecretBox {
/// Cipher instance initialised from the decoded master key.
cipher: Aes256Gcm,
}
impl std::fmt::Debug for SecretBox {
    /// Prints only the type name with a `..` placeholder so that no field (and
    /// therefore no key material) ever ends up in debug output.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut redacted = formatter.debug_struct("SecretBox");
        redacted.finish_non_exhaustive()
    }
}
impl SecretBox {
    /// Builds a `SecretBox` from the master key stored in the environment variable
    /// named by `SECRET_ENCRYPTION_KEY_ENV`.
    ///
    /// # Errors
    ///
    /// * `SecretBoxError::KeyMissing` when the variable is not set.
    /// * `SecretBoxError::KeyInvalid` when the variable is set but its value is not
    ///   valid Unicode, or when it does not decode to `KEY_LEN` bytes.
    pub fn from_env() -> Result<Self, SecretBoxError> {
        match std::env::var(SECRET_ENCRYPTION_KEY_ENV) {
            Ok(raw) => Self::from_master_key_str(&raw),
            Err(std::env::VarError::NotPresent) => Err(SecretBoxError::KeyMissing),
            // The variable IS set, so reporting "is not set" would send the operator
            // looking in the wrong place. The offending value is deliberately not
            // included in the error: it is (an attempt at) key material.
            Err(std::env::VarError::NotUnicode(_)) => Err(SecretBoxError::KeyInvalid(
                "value is not valid Unicode".to_string(),
            )),
        }
    }

    /// Builds a `SecretBox` from a textual master key.
    ///
    /// `raw` is handed to `decode_master_key`, which must yield exactly `KEY_LEN` bytes.
    ///
    /// # Errors
    ///
    /// Returns `SecretBoxError::KeyInvalid` when the key cannot be decoded or has the
    /// wrong length.
    pub fn from_master_key_str(raw: &str) -> Result<Self, SecretBoxError> {
        let key_bytes = decode_master_key(raw)?;
        let key = Key::<Aes256Gcm>::from_slice(&key_bytes);
        Ok(Self {
            cipher: Aes256Gcm::new(key),
        })
    }

    /// Encrypts `plaintext` and returns it as a base64 string.
    ///
    /// The encoded payload is the `NONCE_LEN`-byte nonce followed by the cipher's
    /// output, so `decrypt` needs nothing but the returned string.
    ///
    /// # Errors
    ///
    /// Returns `SecretBoxError::EncryptFailed` when the cipher reports an error.
    pub fn encrypt(&self, plaintext: &[u8]) -> Result<String, SecretBoxError> {
        // A fresh random nonce is drawn for every call.
        // NOTE(review): random 96-bit nonces bound how many messages may safely be
        // encrypted under one key (see NIST SP 800-38D) — confirm expected volume.
        let mut nonce_bytes = [0u8; NONCE_LEN];
        rand::thread_rng().fill_bytes(&mut nonce_bytes);
        let nonce = Nonce::from_slice(&nonce_bytes);

        let ciphertext = self
            .cipher
            .encrypt(nonce, plaintext)
            .map_err(|_| SecretBoxError::EncryptFailed)?;

        // Wire format: nonce || cipher output, base64-encoded as a single string.
        let mut combined = Vec::with_capacity(NONCE_LEN + ciphertext.len());
        combined.extend_from_slice(&nonce_bytes);
        combined.extend_from_slice(&ciphertext);
        Ok(BASE64.encode(combined))
    }

    /// Decrypts a string previously produced by `encrypt`.
    ///
    /// # Errors
    ///
    /// * `SecretBoxError::CiphertextMalformed` when `encoded` is not valid base64 or
    ///   decodes to `NONCE_LEN` bytes or fewer (nothing would be left after the nonce).
    /// * `SecretBoxError::DecryptFailed` when the cipher rejects the payload.
    pub fn decrypt(&self, encoded: &str) -> Result<Vec<u8>, SecretBoxError> {
        let combined = BASE64
            .decode(encoded.as_bytes())
            .map_err(|e| SecretBoxError::CiphertextMalformed(e.to_string()))?;
        if combined.len() <= NONCE_LEN {
            return Err(SecretBoxError::CiphertextMalformed(
                "shorter than nonce".to_string(),
            ));
        }

        // The length check above guarantees `split_at` cannot panic.
        let (nonce_bytes, ciphertext) = combined.split_at(NONCE_LEN);
        let nonce = Nonce::from_slice(nonce_bytes);
        self.cipher
            .decrypt(nonce, ciphertext)
            .map_err(|_| SecretBoxError::DecryptFailed)
    }
}
/// Decodes a textual master key into exactly `KEY_LEN` bytes.
///
/// Surrounding whitespace is trimmed first. The text is then tried as hex and,
/// only if that does not parse, as standard base64. Whichever decoding parses
/// first is the one whose length is validated by `to_key_array`.
///
/// Returns `SecretBoxError::KeyInvalid` when neither encoding parses or the
/// decoded length is wrong.
fn decode_master_key(raw: &str) -> Result<[u8; KEY_LEN], SecretBoxError> {
    let candidate = raw.trim();

    // Hex takes priority; base64 is only attempted when hex parsing fails.
    let decoded = hex::decode(candidate)
        .ok()
        .or_else(|| BASE64.decode(candidate.as_bytes()).ok());

    match decoded {
        Some(bytes) => to_key_array(bytes),
        None => Err(SecretBoxError::KeyInvalid(
            "not valid hex or base64".to_string(),
        )),
    }
}
/// Converts decoded key bytes into a fixed-size key array.
///
/// Returns `SecretBoxError::KeyInvalid` naming the actual and required lengths
/// when `bytes` is not exactly `KEY_LEN` bytes long.
fn to_key_array(bytes: Vec<u8>) -> Result<[u8; KEY_LEN], SecretBoxError> {
    match <[u8; KEY_LEN]>::try_from(bytes.as_slice()) {
        Ok(key) => Ok(key),
        Err(_) => {
            let len = bytes.len();
            Err(SecretBoxError::KeyInvalid(format!(
                "decoded to {len} bytes, need {KEY_LEN}"
            )))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// 32-byte key (bytes 0x00..=0x1f) written as 64 hex characters.
    const HEX_KEY: &str = "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f";

    /// Builds a `SecretBox` from `HEX_KEY`.
    fn box_from_hex() -> SecretBox {
        SecretBox::from_master_key_str(HEX_KEY).unwrap()
    }

    #[test]
    fn test_round_trip_hex_key() {
        let sb = box_from_hex();
        let secret = b"sk-super-secret-provider-key";
        let ct = sb.encrypt(secret).unwrap();
        let pt = sb.decrypt(&ct).unwrap();
        assert_eq!(pt, secret);
    }

    #[test]
    fn test_round_trip_base64_key() {
        let b64_key = BASE64.encode([0u8; KEY_LEN]);
        let sb = SecretBox::from_master_key_str(&b64_key).unwrap();
        let secret = b"another-secret";
        let ct = sb.encrypt(secret).unwrap();
        assert_eq!(sb.decrypt(&ct).unwrap(), secret);
    }

    #[test]
    fn test_round_trip_empty_plaintext() {
        // An empty secret still produces a payload longer than the nonce, so it
        // must not trip the "shorter than nonce" guard in `decrypt`.
        let sb = box_from_hex();
        let ct = sb.encrypt(b"").unwrap();
        assert!(sb.decrypt(&ct).unwrap().is_empty());
    }

    #[test]
    fn test_key_with_surrounding_whitespace_accepted() {
        // Keys read from env files often carry a trailing newline.
        let padded = format!("  {HEX_KEY}\n");
        let sb = SecretBox::from_master_key_str(&padded).unwrap();
        let ct = box_from_hex().encrypt(b"secret").unwrap();
        assert_eq!(sb.decrypt(&ct).unwrap(), b"secret");
    }

    #[test]
    fn test_ciphertext_is_not_plaintext() {
        let sb = box_from_hex();
        let secret = b"do-not-leak";
        let ct = sb.encrypt(secret).unwrap();
        assert!(!ct.as_bytes().windows(secret.len()).any(|w| w == secret));
        // The check above cannot fail on its own: `-` is not in the standard
        // base64 alphabet, so the encoded text never contains this secret.
        // Inspect the decoded bytes as well so the test is meaningful.
        let raw = BASE64.decode(ct.as_bytes()).unwrap();
        assert!(!raw.windows(secret.len()).any(|w| w == secret));
    }

    #[test]
    fn test_nonce_is_random_per_encryption() {
        let sb = box_from_hex();
        let secret = b"same-input";
        let a = sb.encrypt(secret).unwrap();
        let b = sb.encrypt(secret).unwrap();
        assert_ne!(a, b, "two encryptions of the same plaintext must differ");
    }

    #[test]
    fn test_wrong_key_fails_to_decrypt() {
        let sb = box_from_hex();
        let ct = sb.encrypt(b"secret").unwrap();
        let other = SecretBox::from_master_key_str(&BASE64.encode([0xAAu8; KEY_LEN])).unwrap();
        assert!(matches!(
            other.decrypt(&ct),
            Err(SecretBoxError::DecryptFailed)
        ));
    }

    #[test]
    fn test_tampered_ciphertext_fails() {
        let sb = box_from_hex();
        let ct = sb.encrypt(b"secret").unwrap();
        let mut bytes = BASE64.decode(ct.as_bytes()).unwrap();
        // Flip every bit of the final byte of the payload.
        let last = bytes.len() - 1;
        bytes[last] ^= 0xFF;
        let tampered = BASE64.encode(bytes);
        assert!(matches!(
            sb.decrypt(&tampered),
            Err(SecretBoxError::DecryptFailed)
        ));
    }

    #[test]
    fn test_short_key_rejected() {
        // 32 hex characters decode to only 16 bytes.
        let err = SecretBox::from_master_key_str("00112233445566778899aabbccddeeff").unwrap_err();
        assert!(matches!(err, SecretBoxError::KeyInvalid(_)));
    }

    #[test]
    fn test_garbage_key_rejected() {
        let err = SecretBox::from_master_key_str("not-a-key!!!").unwrap_err();
        assert!(matches!(err, SecretBoxError::KeyInvalid(_)));
    }

    #[test]
    fn test_malformed_ciphertext_rejected() {
        let sb = box_from_hex();
        assert!(matches!(
            sb.decrypt("not-base64-???"),
            Err(SecretBoxError::CiphertextMalformed(_))
        ));
        // Valid base64, but fewer bytes than a nonce.
        let short = BASE64.encode([0u8; 4]);
        assert!(matches!(
            sb.decrypt(&short),
            Err(SecretBoxError::CiphertextMalformed(_))
        ));
    }
}