use std::sync::Arc;
use aes_gcm::aead::{Aead, KeyInit};
use aes_gcm::{Aes256Gcm, Key as AesKey, Nonce as AesNonce};
use chacha20poly1305::{ChaCha20Poly1305, Key as ChaChaKey, Nonce as ChaChaNonce};
use rand_core::{OsRng, RngCore};
use crate::{Error, Result};
/// Length in bytes of the random nonce stored in front of every ciphertext.
pub(crate) const NONCE_LEN: usize = 12;
/// Length in bytes of the AEAD authentication tag.
pub(crate) const TAG_LEN: usize = 16;
/// Bytes an encrypted buffer carries beyond its plaintext: the tag plus the nonce.
pub(crate) const ENCRYPTION_OVERHEAD: usize = TAG_LEN + NONCE_LEN;
/// Length in bytes of the salt accepted by passphrase key derivation.
pub(crate) const SALT_LEN: usize = 16;
/// Fixed 32-byte plaintext: the ASCII marker `EMDB-ENCRYPT-OK` padded out with NUL bytes.
///
/// NOTE(review): presumably encrypted and stored so a later decrypt can confirm the key
/// is correct before real data is touched — confirm against callers.
pub(crate) const VERIFICATION_PLAINTEXT: &[u8; 32] =
    b"EMDB-ENCRYPT-OK\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0";
/// AEAD algorithm an encryption context can be built with.
///
/// Both variants take a 32-byte key; `Aes256Gcm` is the default.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum Cipher {
    /// AES-256 in Galois/Counter Mode. This is the default choice.
    #[default]
    Aes256Gcm,
    /// ChaCha20 stream cipher authenticated with Poly1305.
    ChaCha20Poly1305,
}
/// Concrete, keyed cipher instance backing an `EncryptionContext`; one variant per
/// supported `Cipher`.
///
/// NOTE(review): each cipher is boxed — presumably to keep the enum (and every clone of
/// the context) small regardless of the cipher's internal key-schedule size; confirm intent.
#[derive(Clone)]
enum CipherImpl {
    /// AES-256-GCM instance holding the context key.
    Aes(Box<Aes256Gcm>),
    /// ChaCha20-Poly1305 instance holding the context key.
    ChaCha(Box<ChaCha20Poly1305>),
}
/// Keyed AEAD state used to encrypt and decrypt buffers.
///
/// `Debug` is implemented by hand so that the key is never printed.
#[derive(Clone)]
pub(crate) struct EncryptionContext {
    /// The instantiated cipher, which holds the key.
    cipher: CipherImpl,
    /// Which algorithm `cipher` implements. It is stored separately so it can be
    /// reported (via `kind()` and `Debug`) without touching the keyed cipher.
    kind: Cipher,
}
impl std::fmt::Debug for EncryptionContext {
    /// Formats the context without exposing key material: the `key` field always reads
    /// `"<redacted>"` and only the cipher kind is shown.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("EncryptionContext");
        builder.field("key", &"<redacted>");
        builder.field("cipher", &self.kind);
        builder.finish()
    }
}
impl EncryptionContext {
pub(crate) fn from_key(key: &[u8; 32]) -> Self {
Self::from_key_with_cipher(key, Cipher::Aes256Gcm)
}
pub(crate) fn from_key_with_cipher(key: &[u8; 32], kind: Cipher) -> Self {
let cipher = match kind {
Cipher::Aes256Gcm => CipherImpl::Aes(Box::new(Aes256Gcm::new(
AesKey::<Aes256Gcm>::from_slice(key),
))),
Cipher::ChaCha20Poly1305 => {
CipherImpl::ChaCha(Box::new(ChaCha20Poly1305::new(ChaChaKey::from_slice(key))))
}
};
Self { cipher, kind }
}
pub(crate) fn kind(&self) -> Cipher {
self.kind
}
pub(crate) fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>> {
let mut nonce_bytes = [0_u8; NONCE_LEN];
OsRng.fill_bytes(&mut nonce_bytes);
let ciphertext = match &self.cipher {
CipherImpl::Aes(c) => c
.encrypt(AesNonce::from_slice(&nonce_bytes), plaintext)
.map_err(|_| Error::Encryption("aead encrypt failed"))?,
CipherImpl::ChaCha(c) => c
.encrypt(ChaChaNonce::from_slice(&nonce_bytes), plaintext)
.map_err(|_| Error::Encryption("aead encrypt failed"))?,
};
let mut out = Vec::with_capacity(NONCE_LEN + ciphertext.len());
out.extend_from_slice(&nonce_bytes);
out.extend_from_slice(&ciphertext);
Ok(out)
}
pub(crate) fn decrypt(&self, encrypted: &[u8]) -> Result<Vec<u8>> {
if encrypted.len() < NONCE_LEN + TAG_LEN {
return Err(Error::Encryption(
"encrypted buffer too short to hold nonce + tag",
));
}
let (nonce_bytes, ciphertext) = encrypted.split_at(NONCE_LEN);
match &self.cipher {
CipherImpl::Aes(c) => c
.decrypt(AesNonce::from_slice(nonce_bytes), ciphertext)
.map_err(|_| Error::EncryptionKeyMismatch),
CipherImpl::ChaCha(c) => c
.decrypt(ChaChaNonce::from_slice(nonce_bytes), ciphertext)
.map_err(|_| Error::EncryptionKeyMismatch),
}
}
}
/// Optional encryption context behind an `Arc`, so many holders can share one context
/// without cloning the keyed cipher.
///
/// NOTE(review): `None` presumably means encryption is not configured — confirm at use sites.
pub(crate) type SharedEncryption = Option<Arc<EncryptionContext>>;
/// Caller-supplied encryption secret: either a raw key or a passphrase.
///
/// `Debug` is implemented by hand rather than derived. A derived impl would print the raw
/// key bytes or the passphrase verbatim into logs, panic messages and `{:?}` output. The
/// manual impl shows only the variant name, with the payload replaced by `"<redacted>"`.
#[derive(Clone)]
#[non_exhaustive]
pub enum EncryptionInput {
    /// A raw 32-byte key.
    Key([u8; 32]),
    /// A passphrase from which a key still has to be derived (presumably via
    /// `derive_key_from_passphrase`; confirm at the call site).
    Passphrase(String),
}

impl std::fmt::Debug for EncryptionInput {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never format the payload: both variants carry secret material.
        let variant = match self {
            Self::Key(_) => "Key",
            Self::Passphrase(_) => "Passphrase",
        };
        f.debug_tuple(variant).field(&"<redacted>").finish()
    }
}
/// Returns a fresh `SALT_LEN`-byte salt drawn from the operating-system RNG.
///
/// # Panics
///
/// `OsRng::fill_bytes` panics if the operating-system RNG cannot supply bytes.
pub(crate) fn random_salt() -> [u8; SALT_LEN] {
    let mut salt = [0_u8; SALT_LEN];
    OsRng.fill_bytes(salt.as_mut_slice());
    salt
}
/// Stretches `passphrase` into a 32-byte key with Argon2id (version 0x13), salted by `salt`.
///
/// The same passphrase and salt always produce the same key.
///
/// # Errors
///
/// - `Error::InvalidConfig` if `passphrase` is empty.
/// - `Error::Encryption` if the Argon2 parameters are rejected or hashing fails.
pub(crate) fn derive_key_from_passphrase(
    passphrase: &str,
    salt: &[u8; SALT_LEN],
) -> Result<[u8; 32]> {
    use argon2::{Algorithm, Argon2, Params, Version};

    // Argon2 cost parameters: memory cost (KiB), passes over memory, and parallel lanes.
    const MEMORY_KIB: u32 = 19_456;
    const PASSES: u32 = 2;
    const LANES: u32 = 1;
    const OUTPUT_LEN: usize = 32;

    if passphrase.is_empty() {
        return Err(Error::InvalidConfig(
            "encryption_passphrase must not be empty",
        ));
    }

    let cost = Params::new(MEMORY_KIB, PASSES, LANES, Some(OUTPUT_LEN))
        .map_err(|_| Error::Encryption("argon2 params construction failed"))?;

    let mut derived = [0_u8; OUTPUT_LEN];
    Argon2::new(Algorithm::Argon2id, Version::V0x13, cost)
        .hash_password_into(passphrase.as_bytes(), salt, &mut derived)
        .map_err(|_| Error::Encryption("argon2 key derivation failed"))?;
    Ok(derived)
}
#[cfg(test)]
mod tests {
    use super::{Cipher, EncryptionContext, VERIFICATION_PLAINTEXT};
    use crate::Error;

    /// Deterministic test key whose bytes are 0, 1, 2, ..., 31.
    fn key_a() -> [u8; 32] {
        let mut k = [0_u8; 32];
        for (i, b) in k.iter_mut().enumerate() {
            *b = i as u8;
        }
        k
    }

    /// A second test key that differs from `key_a` in every byte.
    fn key_b() -> [u8; 32] {
        [0xFF_u8; 32]
    }

    #[test]
    fn round_trip_recovers_plaintext() {
        let ctx = EncryptionContext::from_key(&key_a());
        let plaintext = b"the quick brown fox jumps over the lazy dog";
        let ct = match ctx.encrypt(plaintext) {
            Ok(c) => c,
            Err(err) => panic!("encrypt should succeed: {err}"),
        };
        // Wire format: 12-byte nonce + ciphertext + 16-byte tag.
        assert_eq!(ct.len(), 12 + plaintext.len() + 16);
        let pt = match ctx.decrypt(&ct) {
            Ok(p) => p,
            Err(err) => panic!("decrypt should succeed: {err}"),
        };
        assert_eq!(pt, plaintext);
    }

    #[test]
    fn chacha_round_trip_recovers_plaintext() {
        let ctx = EncryptionContext::from_key_with_cipher(&key_a(), Cipher::ChaCha20Poly1305);
        assert_eq!(ctx.kind(), Cipher::ChaCha20Poly1305);
        let plaintext = b"the quick brown fox jumps over the lazy dog";
        let ct = ctx.encrypt(plaintext).unwrap_or_else(|err| panic!("{err}"));
        assert_eq!(ct.len(), 12 + plaintext.len() + 16);
        let pt = ctx.decrypt(&ct).unwrap_or_else(|err| panic!("{err}"));
        assert_eq!(pt, plaintext);
    }

    #[test]
    fn from_key_defaults_to_aes_gcm() {
        let ctx = EncryptionContext::from_key(&key_a());
        assert_eq!(ctx.kind(), Cipher::Aes256Gcm);
    }

    #[test]
    fn distinct_nonces_for_repeated_calls() {
        let ctx = EncryptionContext::from_key(&key_a());
        let pt = b"identical-input";
        let ct1 = ctx.encrypt(pt).unwrap_or_else(|err| panic!("{err}"));
        let ct2 = ctx.encrypt(pt).unwrap_or_else(|err| panic!("{err}"));
        assert_ne!(ct1, ct2, "repeated encryption must use fresh nonces");
        let pt1 = ctx.decrypt(&ct1).unwrap_or_else(|err| panic!("{err}"));
        let pt2 = ctx.decrypt(&ct2).unwrap_or_else(|err| panic!("{err}"));
        assert_eq!(pt1, pt2);
        assert_eq!(pt1.as_slice(), pt);
    }

    #[test]
    fn wrong_key_fails_with_mismatch_error() {
        let producer = EncryptionContext::from_key(&key_a());
        let consumer = EncryptionContext::from_key(&key_b());
        let ct = producer
            .encrypt(b"secret")
            .unwrap_or_else(|err| panic!("{err}"));
        let result = consumer.decrypt(&ct);
        assert!(matches!(result, Err(Error::EncryptionKeyMismatch)));
    }

    #[test]
    fn wrong_cipher_fails_with_mismatch_error() {
        let producer = EncryptionContext::from_key_with_cipher(&key_a(), Cipher::Aes256Gcm);
        let consumer =
            EncryptionContext::from_key_with_cipher(&key_a(), Cipher::ChaCha20Poly1305);
        let ct = producer
            .encrypt(b"secret")
            .unwrap_or_else(|err| panic!("{err}"));
        let result = consumer.decrypt(&ct);
        assert!(
            matches!(result, Err(Error::EncryptionKeyMismatch)),
            "decrypting with a different cipher must fail authentication: {result:?}"
        );
    }

    #[test]
    fn tampered_ciphertext_fails_with_mismatch_error() {
        let ctx = EncryptionContext::from_key(&key_a());
        let mut ct = ctx
            .encrypt(b"do not modify")
            .unwrap_or_else(|err| panic!("{err}"));
        // Index 15 lies past the 12-byte nonce, inside the ciphertext body.
        ct[15] ^= 0x01;
        let result = ctx.decrypt(&ct);
        assert!(
            matches!(result, Err(Error::EncryptionKeyMismatch)),
            "tampered ciphertext must fail authentication: {result:?}"
        );
    }

    #[test]
    fn truncated_buffer_fails_with_encryption_error() {
        let ctx = EncryptionContext::from_key(&key_a());
        let too_short = [0_u8; 10];
        let result = ctx.decrypt(&too_short);
        assert!(matches!(result, Err(Error::Encryption(_))));
    }

    #[test]
    fn verification_plaintext_is_thirty_two_bytes() {
        assert_eq!(VERIFICATION_PLAINTEXT.len(), 32);
    }

    #[test]
    fn debug_does_not_leak_key() {
        let ctx = EncryptionContext::from_key(&key_a());
        let debug_str = format!("{ctx:?}");
        // `key_a()` is 0, 1, 2, ...; a derived `Debug` on a byte array prints it as
        // "[0, 1, 2, 3, ...", so that is the text an accidental leak would produce.
        assert!(
            !debug_str.contains("0, 1, 2, 3"),
            "Debug output must not leak key bytes: {debug_str}"
        );
        assert!(debug_str.contains("redacted"));
    }

    #[test]
    fn kdf_is_deterministic_for_fixed_passphrase_and_salt() {
        let salt = [0xAA_u8; super::SALT_LEN];
        let k1 = match super::derive_key_from_passphrase("hunter2", &salt) {
            Ok(k) => k,
            Err(err) => panic!("derive should succeed: {err}"),
        };
        let k2 = match super::derive_key_from_passphrase("hunter2", &salt) {
            Ok(k) => k,
            Err(err) => panic!("derive should succeed: {err}"),
        };
        assert_eq!(k1, k2, "same passphrase + salt must produce same key");
    }

    #[test]
    fn kdf_diverges_for_different_salts() {
        let s1 = [0x11_u8; super::SALT_LEN];
        let s2 = [0x22_u8; super::SALT_LEN];
        let k1 =
            super::derive_key_from_passphrase("hunter2", &s1).unwrap_or_else(|e| panic!("{e}"));
        let k2 =
            super::derive_key_from_passphrase("hunter2", &s2).unwrap_or_else(|e| panic!("{e}"));
        assert_ne!(k1, k2, "different salts must produce different keys");
    }

    #[test]
    fn kdf_diverges_for_different_passphrases() {
        let salt = [0x33_u8; super::SALT_LEN];
        let k1 =
            super::derive_key_from_passphrase("alpha", &salt).unwrap_or_else(|e| panic!("{e}"));
        let k2 =
            super::derive_key_from_passphrase("bravo", &salt).unwrap_or_else(|e| panic!("{e}"));
        assert_ne!(k1, k2, "different passphrases must produce different keys");
    }

    #[test]
    fn kdf_rejects_empty_passphrase() {
        let salt = [0x44_u8; super::SALT_LEN];
        let result = super::derive_key_from_passphrase("", &salt);
        assert!(matches!(result, Err(Error::InvalidConfig(_))));
    }

    #[test]
    fn random_salt_is_fresh_each_call() {
        let s1 = super::random_salt();
        let s2 = super::random_salt();
        assert_ne!(s1, s2, "random_salt must use the OS RNG");
    }
}