use aes_gcm::aead::{Aead, KeyInit, OsRng};
use aes_gcm::{AeadCore, Aes256Gcm, Key, Nonce};
use base64::{Engine as _, engine::general_purpose};
/// Errors produced by the encryption helpers in this module.
#[derive(Debug, thiserror::Error)]
pub enum EncryptionError {
/// The supplied key was not exactly 32 bytes long; the payload is the
/// length that was actually provided.
#[error("encryption key must be exactly 32 bytes (got {0})")]
InvalidKeyLength(usize),
/// The cipher reported a failure while encrypting, or (for the JSON
/// helper) the value could not be serialized before encryption.
#[error("encryption failed")]
EncryptionFailed,
/// The cipher rejected the blob during decryption, or (for the JSON
/// helper) the decrypted bytes could not be parsed as JSON.
#[error("decryption failed — wrong key or corrupted ciphertext")]
DecryptionFailed,
/// The blob is shorter than the 12 + 16 bytes that `decrypt` requires
/// before it attempts to split off the nonce.
#[error("ciphertext too short to contain nonce + tag")]
CiphertextTooShort,
/// Wraps a `base64::DecodeError`; the `#[from]` attribute lets `?`
/// convert such an error into this variant automatically.
#[error("base64 decode failed: {0}")]
Base64Decode(#[from] base64::DecodeError),
}
/// Encrypts `plaintext` with AES-256-GCM under a freshly generated random nonce.
///
/// The returned buffer is laid out as `nonce || ciphertext`: the 12-byte nonce
/// generated for this call, followed by the cipher output (which carries the
/// authentication tag).
///
/// # Errors
///
/// * [`EncryptionError::InvalidKeyLength`] if `key_bytes` is not exactly 32 bytes.
/// * [`EncryptionError::EncryptionFailed`] if the cipher reports a failure.
pub fn encrypt(key_bytes: &[u8], plaintext: &[u8]) -> Result<Vec<u8>, EncryptionError> {
    let cipher = Aes256Gcm::new(&parse_key(key_bytes)?);

    // A new nonce is drawn from the OS RNG on every call.
    let nonce = Aes256Gcm::generate_nonce(&mut OsRng);

    let sealed = match cipher.encrypt(&nonce, plaintext) {
        Ok(bytes) => bytes,
        Err(_) => return Err(EncryptionError::EncryptionFailed),
    };

    // Prefix the nonce so the blob is self-contained for decryption.
    let mut blob = Vec::with_capacity(12 + sealed.len());
    blob.extend_from_slice(&nonce);
    blob.extend_from_slice(&sealed);
    Ok(blob)
}
/// Decrypts a blob laid out as `nonce || ciphertext` using AES-256-GCM.
///
/// The first 12 bytes of `blob` are taken as the nonce; the remainder is handed
/// to the cipher for authenticated decryption.
///
/// # Errors
///
/// * [`EncryptionError::CiphertextTooShort`] if `blob` is shorter than 12 + 16 bytes.
/// * [`EncryptionError::InvalidKeyLength`] if `key_bytes` is not exactly 32 bytes.
/// * [`EncryptionError::DecryptionFailed`] if the cipher rejects the blob.
pub fn decrypt(key_bytes: &[u8], blob: &[u8]) -> Result<Vec<u8>, EncryptionError> {
    const NONCE_LEN: usize = 12;
    const TAG_LEN: usize = 16;

    // The length check comes before key parsing, so a truncated blob is
    // reported as such even when the key is also malformed.
    if blob.len() < NONCE_LEN + TAG_LEN {
        return Err(EncryptionError::CiphertextTooShort);
    }

    let cipher = Aes256Gcm::new(&parse_key(key_bytes)?);
    let (prefix, sealed) = blob.split_at(NONCE_LEN);

    match cipher.decrypt(Nonce::from_slice(prefix), sealed) {
        Ok(plaintext) => Ok(plaintext),
        Err(_) => Err(EncryptionError::DecryptionFailed),
    }
}
/// Serializes `value` to JSON bytes and encrypts them with [`encrypt`].
///
/// # Errors
///
/// * [`EncryptionError::EncryptionFailed`] if `value` cannot be serialized.
/// * Any error returned by [`encrypt`].
pub fn encrypt_json(
    key_bytes: &[u8],
    value: &serde_json::Value,
) -> Result<Vec<u8>, EncryptionError> {
    match serde_json::to_vec(value) {
        Ok(serialized) => encrypt(key_bytes, &serialized),
        // The serializer's own error is intentionally not surfaced.
        Err(_) => Err(EncryptionError::EncryptionFailed),
    }
}
/// Decrypts `blob` with [`decrypt`] and parses the plaintext as JSON.
///
/// # Errors
///
/// * Any error returned by [`decrypt`].
/// * [`EncryptionError::DecryptionFailed`] if the decrypted bytes are not valid JSON.
pub fn decrypt_json(key_bytes: &[u8], blob: &[u8]) -> Result<serde_json::Value, EncryptionError> {
    decrypt(key_bytes, blob).and_then(|plaintext| {
        // A parse failure is reported with the same variant as a cipher failure.
        serde_json::from_slice(&plaintext).map_err(|_| EncryptionError::DecryptionFailed)
    })
}
/// Turns a configured secret string into 32 bytes of key material.
///
/// The interpretations are tried in this order:
///
/// 1. `secret` is standard base64 that decodes to exactly 32 bytes: the decoded
///    bytes are returned.
/// 2. `secret` is itself exactly 32 bytes long: its raw bytes are returned.
/// 3. Anything else: the SHA-256 digest of the raw bytes is returned.
///
/// This function never fails; a base64 decode error simply moves on to the
/// next interpretation.
///
/// NOTE(review): the fallback is a single unsalted SHA-256 pass, not a
/// password-hashing KDF. It is kept as-is because changing it would alter the
/// derived key for existing secrets.
pub fn derive_encryption_key(secret: &str) -> Vec<u8> {
    use sha2::{Digest, Sha256};

    const KEY_LEN: usize = 32;

    // Interpretation 1: base64-encoded key of the right size.
    let decoded = general_purpose::STANDARD
        .decode(secret)
        .ok()
        .filter(|candidate| candidate.len() == KEY_LEN);
    if let Some(key) = decoded {
        return key;
    }

    // Interpretations 2 and 3 both operate on the raw bytes of the string.
    let raw = secret.as_bytes();
    if raw.len() == KEY_LEN {
        raw.to_vec()
    } else {
        Sha256::digest(raw).to_vec()
    }
}
/// Validates that `key_bytes` is exactly 32 bytes and converts it into an
/// AES-256-GCM key.
///
/// # Errors
///
/// Returns [`EncryptionError::InvalidKeyLength`] carrying the actual length
/// when `key_bytes` is not 32 bytes long.
fn parse_key(key_bytes: &[u8]) -> Result<Key<Aes256Gcm>, EncryptionError> {
    match key_bytes.len() {
        // `from_slice` is only reached once the length has been confirmed.
        32 => Ok(*Key::<Aes256Gcm>::from_slice(key_bytes)),
        actual => Err(EncryptionError::InvalidKeyLength(actual)),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// All-zero 32-byte key shared by every test.
    const KEY: [u8; 32] = [0u8; 32];

    /// Encrypting and then decrypting with the same key yields the input.
    #[test]
    fn roundtrip() {
        let message: &[u8] = b"hello world";
        let blob = encrypt(&KEY, message).expect("encrypt should succeed");
        let recovered = decrypt(&KEY, &blob).expect("decrypt should succeed");
        assert_eq!(recovered.as_slice(), message);
    }

    /// The JSON helpers round-trip a structured value.
    #[test]
    fn roundtrip_json() {
        let original = serde_json::json!({"bucket": "my-bucket", "secret": "s3cr3t"});
        let blob = encrypt_json(&KEY, &original).expect("encrypt_json should succeed");
        let recovered = decrypt_json(&KEY, &blob).expect("decrypt_json should succeed");
        assert_eq!(recovered, original);
    }

    /// A blob cannot be opened with a different key.
    #[test]
    fn wrong_key_fails() {
        let other_key = [1u8; 32];
        let blob = encrypt(&KEY, b"secret").expect("encrypt should succeed");
        assert!(decrypt(&other_key, &blob).is_err());
    }

    /// Flipping bits in the final byte of the blob makes decryption fail.
    #[test]
    fn corrupted_ciphertext_fails() {
        let mut blob = encrypt(&KEY, b"secret").expect("encrypt should succeed");
        *blob.last_mut().expect("encrypted blob is never empty") ^= 0xFF;
        assert!(decrypt(&KEY, &blob).is_err());
    }

    /// A blob shorter than nonce + tag is rejected.
    #[test]
    fn too_short_fails() {
        let truncated = [0u8; 10];
        assert!(decrypt(&KEY, &truncated).is_err());
    }
}