use aes_gcm_siv::Aes256GcmSiv;
use aes_gcm_siv::aead::generic_array::GenericArray;
use aes_gcm_siv::aead::{Aead, KeyInit};
use chacha20poly1305::ChaCha20Poly1305;
use hkdf::Hkdf;
use rand::RngCore;
use sha2::Sha256;
use crate::error::{DbxError, DbxResult};
/// Length in bytes of the random nonce stored in front of every ciphertext (96 bits).
/// Both supported ciphers are driven with a nonce of this size.
const NONCE_SIZE: usize = 12;
/// Length in bytes of the symmetric key (256 bits); the same key bytes feed either cipher.
const KEY_SIZE: usize = 32;
/// HKDF `info` (context) string used by `EncryptionConfig::derive_key`.
/// NOTE(review): changing this value changes every password-derived key, so data
/// encrypted under the old value could no longer be decrypted from its password.
const HKDF_INFO: &[u8] = b"dbx-encryption-v1";
/// Fixed HKDF salt used by `EncryptionConfig::derive_key`.
/// NOTE(review): because the salt is a constant, the same password yields the same key
/// everywhere; changing it has the same compatibility impact as changing `HKDF_INFO`.
const HKDF_SALT: &[u8] = b"dbx-default-salt-v1";
/// AEAD cipher used by `EncryptionConfig` to seal and open byte payloads.
///
/// The chosen algorithm is not written into the encrypted output (which is just
/// nonce followed by the cipher's output), so data must be decrypted with the same
/// algorithm that encrypted it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum EncryptionAlgorithm {
/// AES-256 in GCM-SIV mode; the default variant.
#[default]
Aes256GcmSiv,
/// ChaCha20 stream cipher with a Poly1305 authenticator.
ChaCha20Poly1305,
}
impl std::fmt::Display for EncryptionAlgorithm {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Aes256GcmSiv => write!(f, "AES-256-GCM-SIV"),
Self::ChaCha20Poly1305 => write!(f, "ChaCha20-Poly1305"),
}
}
}
impl EncryptionAlgorithm {
pub const ALL: &'static [EncryptionAlgorithm] = &[
EncryptionAlgorithm::Aes256GcmSiv,
EncryptionAlgorithm::ChaCha20Poly1305,
];
}
/// Symmetric key plus cipher choice used to encrypt and decrypt byte payloads.
///
/// `Debug` is implemented by hand (not derived) and prints the key as `[REDACTED]`.
/// NOTE(review): no `Drop` impl is visible here, so the key bytes do not appear to be
/// zeroized when a config is dropped, and each `clone()` makes another in-memory copy.
#[derive(Clone)]
pub struct EncryptionConfig {
/// Cipher used by the encrypt/decrypt methods.
algorithm: EncryptionAlgorithm,
/// 256-bit key, either supplied directly or derived from a password.
key: [u8; KEY_SIZE],
}
impl std::fmt::Debug for EncryptionConfig {
    /// Formats the config for diagnostics while masking the key, so the secret
    /// never ends up in logs, panic messages, or `dbg!` output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("EncryptionConfig");
        out.field("algorithm", &self.algorithm);
        out.field("key", &"[REDACTED]");
        out.finish()
    }
}
impl EncryptionConfig {
pub fn from_password(password: &str) -> Self {
Self::from_password_with_algorithm(password, EncryptionAlgorithm::default())
}
pub fn from_password_with_algorithm(password: &str, algorithm: EncryptionAlgorithm) -> Self {
let key = Self::derive_key(password.as_bytes());
Self { algorithm, key }
}
pub fn from_key(key: [u8; KEY_SIZE]) -> Self {
Self {
algorithm: EncryptionAlgorithm::default(),
key,
}
}
pub fn from_key_with_algorithm(key: [u8; KEY_SIZE], algorithm: EncryptionAlgorithm) -> Self {
Self { algorithm, key }
}
pub fn with_algorithm(mut self, algorithm: EncryptionAlgorithm) -> Self {
self.algorithm = algorithm;
self
}
pub fn algorithm(&self) -> EncryptionAlgorithm {
self.algorithm
}
pub fn encrypt(&self, plaintext: &[u8]) -> DbxResult<Vec<u8>> {
let mut nonce_bytes = [0u8; NONCE_SIZE];
rand::thread_rng().fill_bytes(&mut nonce_bytes);
let nonce = GenericArray::from_slice(&nonce_bytes);
let ciphertext = match self.algorithm {
EncryptionAlgorithm::Aes256GcmSiv => {
let cipher = Aes256GcmSiv::new(GenericArray::from_slice(&self.key));
cipher.encrypt(nonce, plaintext).map_err(|e| {
DbxError::Encryption(format!("AES-GCM-SIV encrypt failed: {}", e))
})?
}
EncryptionAlgorithm::ChaCha20Poly1305 => {
let cipher = ChaCha20Poly1305::new(GenericArray::from_slice(&self.key));
cipher
.encrypt(nonce, plaintext)
.map_err(|e| DbxError::Encryption(format!("ChaCha20 encrypt failed: {}", e)))?
}
};
let mut output = Vec::with_capacity(NONCE_SIZE + ciphertext.len());
output.extend_from_slice(&nonce_bytes);
output.extend_from_slice(&ciphertext);
Ok(output)
}
pub fn decrypt(&self, encrypted: &[u8]) -> DbxResult<Vec<u8>> {
if encrypted.len() < NONCE_SIZE {
return Err(DbxError::Encryption(
"encrypted data too short (missing nonce)".to_string(),
));
}
let (nonce_bytes, ciphertext) = encrypted.split_at(NONCE_SIZE);
let nonce = GenericArray::from_slice(nonce_bytes);
match self.algorithm {
EncryptionAlgorithm::Aes256GcmSiv => {
let cipher = Aes256GcmSiv::new(GenericArray::from_slice(&self.key));
cipher
.decrypt(nonce, ciphertext)
.map_err(|e| DbxError::Encryption(format!("AES-GCM-SIV decrypt failed: {}", e)))
}
EncryptionAlgorithm::ChaCha20Poly1305 => {
let cipher = ChaCha20Poly1305::new(GenericArray::from_slice(&self.key));
cipher
.decrypt(nonce, ciphertext)
.map_err(|e| DbxError::Encryption(format!("ChaCha20 decrypt failed: {}", e)))
}
}
}
pub fn encrypt_with_aad(&self, plaintext: &[u8], aad: &[u8]) -> DbxResult<Vec<u8>> {
use aes_gcm_siv::aead::Payload;
let mut nonce_bytes = [0u8; NONCE_SIZE];
rand::thread_rng().fill_bytes(&mut nonce_bytes);
let nonce = GenericArray::from_slice(&nonce_bytes);
let payload = Payload {
msg: plaintext,
aad,
};
let ciphertext = match self.algorithm {
EncryptionAlgorithm::Aes256GcmSiv => {
let cipher = Aes256GcmSiv::new(GenericArray::from_slice(&self.key));
cipher.encrypt(nonce, payload).map_err(|e| {
DbxError::Encryption(format!("AES-GCM-SIV encrypt+AAD failed: {}", e))
})?
}
EncryptionAlgorithm::ChaCha20Poly1305 => {
let cipher = ChaCha20Poly1305::new(GenericArray::from_slice(&self.key));
cipher.encrypt(nonce, payload).map_err(|e| {
DbxError::Encryption(format!("ChaCha20 encrypt+AAD failed: {}", e))
})?
}
};
let mut output = Vec::with_capacity(NONCE_SIZE + ciphertext.len());
output.extend_from_slice(&nonce_bytes);
output.extend_from_slice(&ciphertext);
Ok(output)
}
pub fn decrypt_with_aad(&self, encrypted: &[u8], aad: &[u8]) -> DbxResult<Vec<u8>> {
use aes_gcm_siv::aead::Payload;
if encrypted.len() < NONCE_SIZE {
return Err(DbxError::Encryption(
"encrypted data too short (missing nonce)".to_string(),
));
}
let (nonce_bytes, ciphertext) = encrypted.split_at(NONCE_SIZE);
let nonce = GenericArray::from_slice(nonce_bytes);
let payload = Payload {
msg: ciphertext,
aad,
};
match self.algorithm {
EncryptionAlgorithm::Aes256GcmSiv => {
let cipher = Aes256GcmSiv::new(GenericArray::from_slice(&self.key));
cipher.decrypt(nonce, payload).map_err(|e| {
DbxError::Encryption(format!("AES-GCM-SIV decrypt+AAD failed: {}", e))
})
}
EncryptionAlgorithm::ChaCha20Poly1305 => {
let cipher = ChaCha20Poly1305::new(GenericArray::from_slice(&self.key));
cipher.decrypt(nonce, payload).map_err(|e| {
DbxError::Encryption(format!("ChaCha20 decrypt+AAD failed: {}", e))
})
}
}
}
fn derive_key(input: &[u8]) -> [u8; KEY_SIZE] {
let hk = Hkdf::<Sha256>::new(Some(HKDF_SALT), input);
let mut key = [0u8; KEY_SIZE];
hk.expand(HKDF_INFO, &mut key)
.expect("HKDF expand should never fail for 32-byte output");
key
}
}
#[cfg(test)]
mod tests {
    //! Round-trip, authentication-failure, wire-format and error-message tests for
    //! `EncryptionConfig`.
    use super::*;

    #[test]
    fn default_algorithm_is_aes_gcm_siv() {
        let config = EncryptionConfig::from_password("test");
        assert_eq!(config.algorithm(), EncryptionAlgorithm::Aes256GcmSiv);
    }

    #[test]
    fn round_trip_aes_gcm_siv() {
        let config = EncryptionConfig::from_password("test-password");
        let plaintext = b"Hello, DBX encryption!";
        let encrypted = config.encrypt(plaintext).unwrap();
        assert_ne!(encrypted, plaintext);
        assert!(encrypted.len() > plaintext.len());
        let decrypted = config.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn round_trip_chacha20() {
        let config = EncryptionConfig::from_password("test-password")
            .with_algorithm(EncryptionAlgorithm::ChaCha20Poly1305);
        let plaintext = b"Hello, ChaCha20!";
        let encrypted = config.encrypt(plaintext).unwrap();
        let decrypted = config.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn round_trip_all_algorithms() {
        let plaintext = b"Testing all algorithms";
        for algo in EncryptionAlgorithm::ALL {
            let config = EncryptionConfig::from_password("pw").with_algorithm(*algo);
            let encrypted = config.encrypt(plaintext).unwrap();
            let decrypted = config.decrypt(&encrypted).unwrap();
            assert_eq!(decrypted, plaintext, "Round-trip failed for {:?}", algo);
        }
    }

    #[test]
    fn from_raw_key() {
        let key = [0xABu8; KEY_SIZE];
        let config = EncryptionConfig::from_key(key);
        let plaintext = b"raw key test";
        let encrypted = config.encrypt(plaintext).unwrap();
        let decrypted = config.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn wrong_password_fails() {
        let config1 = EncryptionConfig::from_password("correct-password");
        let config2 = EncryptionConfig::from_password("wrong-password");
        let plaintext = b"secret data";
        let encrypted = config1.encrypt(plaintext).unwrap();
        let result = config2.decrypt(&encrypted);
        assert!(result.is_err(), "Decryption with wrong key should fail");
    }

    #[test]
    fn wrong_algorithm_fails() {
        let config_aes = EncryptionConfig::from_password("same-password");
        let config_chacha = EncryptionConfig::from_password("same-password")
            .with_algorithm(EncryptionAlgorithm::ChaCha20Poly1305);
        let plaintext = b"algorithm mismatch test";
        let encrypted = config_aes.encrypt(plaintext).unwrap();
        let result = config_chacha.decrypt(&encrypted);
        assert!(
            result.is_err(),
            "Decryption with wrong algorithm should fail"
        );
    }

    #[test]
    fn tampered_data_fails() {
        let config = EncryptionConfig::from_password("test");
        let plaintext = b"tamper test";
        let mut encrypted = config.encrypt(plaintext).unwrap();
        let last = encrypted.len() - 1;
        encrypted[last] ^= 0xFF;
        let result = config.decrypt(&encrypted);
        assert!(result.is_err(), "Tampered data should fail authentication");
    }

    #[test]
    fn too_short_data_fails() {
        let config = EncryptionConfig::from_password("test");
        let result = config.decrypt(&[0u8; 5]);
        assert!(result.is_err());
    }

    #[test]
    fn nonce_only_input_fails() {
        // Long enough to pass the nonce check, but carries no tag to authenticate.
        let config = EncryptionConfig::from_password("test");
        assert!(config.decrypt(&[0u8; NONCE_SIZE]).is_err());
    }

    #[test]
    fn empty_plaintext() {
        let config = EncryptionConfig::from_password("test");
        let plaintext = b"";
        let encrypted = config.encrypt(plaintext).unwrap();
        let decrypted = config.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn large_data_round_trip() {
        let config = EncryptionConfig::from_password("test");
        let plaintext: Vec<u8> = (0..100_000).map(|i| (i % 256) as u8).collect();
        let encrypted = config.encrypt(&plaintext).unwrap();
        let decrypted = config.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn each_encrypt_produces_different_output() {
        let config = EncryptionConfig::from_password("test");
        let plaintext = b"same input";
        let enc1 = config.encrypt(plaintext).unwrap();
        let enc2 = config.encrypt(plaintext).unwrap();
        assert_ne!(enc1, enc2, "Each encryption should use a fresh nonce");
        assert_eq!(config.decrypt(&enc1).unwrap(), plaintext);
        assert_eq!(config.decrypt(&enc2).unwrap(), plaintext);
    }

    #[test]
    fn aad_round_trip() {
        let config = EncryptionConfig::from_password("test");
        let plaintext = b"sensitive data";
        let aad = b"table:users,column:email";
        let encrypted = config.encrypt_with_aad(plaintext, aad).unwrap();
        let decrypted = config.decrypt_with_aad(&encrypted, aad).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn aad_round_trip_all_algorithms() {
        let plaintext = b"sensitive data";
        let aad = b"table:users,column:email";
        for algo in EncryptionAlgorithm::ALL {
            let config = EncryptionConfig::from_password("pw").with_algorithm(*algo);
            let encrypted = config.encrypt_with_aad(plaintext, aad).unwrap();
            let decrypted = config.decrypt_with_aad(&encrypted, aad).unwrap();
            assert_eq!(decrypted, plaintext, "AAD round-trip failed for {:?}", algo);
        }
    }

    #[test]
    fn aad_mismatch_fails() {
        let config = EncryptionConfig::from_password("test");
        let plaintext = b"sensitive data";
        let aad = b"table:users";
        let encrypted = config.encrypt_with_aad(plaintext, aad).unwrap();
        let result = config.decrypt_with_aad(&encrypted, b"table:orders");
        assert!(result.is_err(), "Wrong AAD should fail authentication");
    }

    #[test]
    fn aad_too_short_data_fails() {
        let config = EncryptionConfig::from_password("test");
        assert!(config.decrypt_with_aad(&[0u8; 5], b"aad").is_err());
    }

    #[test]
    fn plain_and_empty_aad_are_interchangeable() {
        // `encrypt`/`decrypt` must behave exactly like the AAD variants with empty AAD.
        let plaintext = b"interop";
        for algo in EncryptionAlgorithm::ALL {
            let config = EncryptionConfig::from_password("pw").with_algorithm(*algo);
            let sealed = config.encrypt(plaintext).unwrap();
            assert_eq!(config.decrypt_with_aad(&sealed, b"").unwrap(), plaintext);
            let sealed_aad = config.encrypt_with_aad(plaintext, b"").unwrap();
            assert_eq!(config.decrypt(&sealed_aad).unwrap(), plaintext);
        }
    }

    #[test]
    fn error_messages_name_algorithm_and_operation() {
        let chacha = EncryptionConfig::from_password("test")
            .with_algorithm(EncryptionAlgorithm::ChaCha20Poly1305);
        let sealed = chacha.encrypt_with_aad(b"data", b"a").unwrap();
        let result = chacha.decrypt_with_aad(&sealed, b"b");
        assert!(matches!(
            &result,
            Err(DbxError::Encryption(msg)) if msg.starts_with("ChaCha20 decrypt+AAD failed")
        ));

        let aes = EncryptionConfig::from_password("test");
        let result = aes.decrypt(&[0u8; NONCE_SIZE + 16]);
        assert!(matches!(
            &result,
            Err(DbxError::Encryption(msg)) if msg.starts_with("AES-GCM-SIV decrypt failed")
        ));
    }

    #[test]
    fn display_names() {
        assert_eq!(
            format!("{}", EncryptionAlgorithm::Aes256GcmSiv),
            "AES-256-GCM-SIV"
        );
        assert_eq!(
            format!("{}", EncryptionAlgorithm::ChaCha20Poly1305),
            "ChaCha20-Poly1305"
        );
    }

    #[test]
    fn all_algorithms_count() {
        assert_eq!(EncryptionAlgorithm::ALL.len(), 2);
    }

    #[test]
    fn debug_redacts_key() {
        let config = EncryptionConfig::from_password("secret");
        let debug_str = format!("{:?}", config);
        assert!(debug_str.contains("REDACTED"));
        assert!(!debug_str.contains("secret"));
    }

    #[test]
    fn debug_does_not_print_key_bytes() {
        // The password is never stored, so the test above cannot catch a key leak.
        // 0xAB renders as "171" when `Debug` prints a `[u8]` array.
        let config = EncryptionConfig::from_key([0xABu8; KEY_SIZE]);
        let debug_str = format!("{:?}", config);
        assert!(debug_str.contains("REDACTED"));
        assert!(!debug_str.contains("171"), "key bytes leaked: {}", debug_str);
    }

    #[test]
    fn wire_format_structure() {
        let config = EncryptionConfig::from_password("test");
        let plaintext = b"hello";
        let encrypted = config.encrypt(plaintext).unwrap();
        // nonce || ciphertext (same length as plaintext) || 16-byte authentication tag
        assert_eq!(
            encrypted.len(),
            NONCE_SIZE + plaintext.len() + 16,
            "Wire format should be nonce + plaintext + tag"
        );
    }
}