use aes_gcm::{Aes256Gcm, KeyInit, Nonce, aead::Aead};
use axess_rng::{SecureRng, SystemRng};
use std::sync::Arc;
/// Length in bytes of the nonce that `SessionCrypto::encrypt` prepends to its
/// output and that `SessionCrypto::decrypt` splits off the front of its input.
const NONCE_LEN: usize = 12;
/// A 32-byte (256-bit) symmetric key.
///
/// The raw bytes are only reachable from within this crate. They are zeroized
/// when the value is dropped (see the `Drop` impl for this type); each clone
/// holds its own copy of the bytes and is zeroized independently.
#[derive(Clone)]
pub struct EncryptionKey(pub(crate) [u8; 32]);
impl Drop for EncryptionKey {
fn drop(&mut self) {
zeroize::Zeroize::zeroize(&mut self.0);
}
}
/// Encrypts and decrypts session payloads with AES-256-GCM, with an optional
/// previous key that is consulted only when decrypting.
///
/// All fields are held behind `Arc`, so cloning this value shares the same
/// keys and random source instead of copying key material.
#[derive(Clone)]
pub struct SessionCrypto {
/// Key used for every encryption and tried first on decryption.
current: Arc<EncryptionKey>,
/// Optional older key, tried on decryption only after `current` fails.
previous: Option<Arc<EncryptionKey>>,
/// Random source used to fill the nonce for each encryption.
rng: Arc<dyn SecureRng>,
}
/// Opaque error returned by `SessionCrypto::encrypt` and `SessionCrypto::decrypt`.
///
/// It is a unit struct with a single fixed message, so it carries no detail
/// about which step failed or why.
#[derive(Debug, thiserror::Error)]
#[error("session encryption/decryption error")]
pub struct CryptoError;
impl SessionCrypto {
pub fn new(key: [u8; 32]) -> Self {
Self {
current: Arc::new(EncryptionKey(key)),
previous: None,
rng: Arc::new(SystemRng),
}
}
pub fn with_rng(mut self, rng: Arc<dyn SecureRng>) -> Self {
self.rng = rng;
self
}
pub fn with_previous_key(mut self, key: [u8; 32]) -> Self {
self.previous = Some(Arc::new(EncryptionKey(key)));
self
}
pub fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>, CryptoError> {
let cipher = Aes256Gcm::new_from_slice(&self.current.0).map_err(|_| CryptoError)?;
let mut nonce_bytes = [0u8; NONCE_LEN];
self.rng.fill_bytes(&mut nonce_bytes);
let nonce = Nonce::from_slice(&nonce_bytes);
let ciphertext = cipher.encrypt(nonce, plaintext).map_err(|_| CryptoError)?;
let mut out = Vec::with_capacity(NONCE_LEN + ciphertext.len());
out.extend_from_slice(&nonce_bytes);
out.extend_from_slice(&ciphertext);
Ok(out)
}
pub fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>, CryptoError> {
if data.len() < NONCE_LEN {
return Err(CryptoError);
}
let (nonce_bytes, ciphertext) = data.split_at(NONCE_LEN);
let nonce = Nonce::from_slice(nonce_bytes);
let cipher = Aes256Gcm::new_from_slice(&self.current.0).map_err(|_| CryptoError)?;
if let Ok(plaintext) = cipher.decrypt(nonce, ciphertext) {
return Ok(plaintext);
}
if let Some(prev) = &self.previous {
tracing::warn!(
"session decryption failed with current key; trying previous key (rotation fallback)"
);
let old_cipher = Aes256Gcm::new_from_slice(&prev.0).map_err(|_| CryptoError)?;
if let Ok(plaintext) = old_cipher.decrypt(nonce, ciphertext) {
tracing::debug!("session decrypted with previous (rotated) key");
return Ok(plaintext);
}
tracing::warn!(
"session decryption also failed with previous key; possible data corruption or key mismatch"
);
} else {
tracing::warn!(
"session decryption failed with current key and no previous key configured"
);
}
Err(CryptoError)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Data sealed by an instance is recovered unchanged by the same instance.
    #[test]
    fn encrypt_decrypt_roundtrip() {
        let session = SessionCrypto::new([42u8; 32]);
        let message = b"hello session data";

        let sealed = session.encrypt(message).unwrap();
        let opened = session.decrypt(&sealed).unwrap();

        assert_eq!(opened, message);
    }

    /// An instance holding a different key must not be able to decrypt.
    #[test]
    fn decrypt_wrong_key_fails() {
        let sealed = SessionCrypto::new([1u8; 32]).encrypt(b"secret").unwrap();
        let stranger = SessionCrypto::new([2u8; 32]);

        assert!(stranger.decrypt(&sealed).is_err());
    }

    /// Data sealed under the old key is still readable after rotation when the
    /// old key is registered as the previous key.
    #[test]
    fn key_rotation_decrypt_with_previous() {
        let retired = [1u8; 32];
        let active = [2u8; 32];

        let sealed = SessionCrypto::new(retired)
            .encrypt(b"rotated data")
            .unwrap();

        let rotated = SessionCrypto::new(active).with_previous_key(retired);
        let opened = rotated.decrypt(&sealed).unwrap();

        assert_eq!(opened, b"rotated data");
    }

    /// Input shorter than a nonce is rejected.
    #[test]
    fn short_data_fails() {
        let session = SessionCrypto::new([42u8; 32]);
        let too_short = [0u8; 5];

        assert!(session.decrypt(&too_short).is_err());
    }

    /// With both keys configured, data sealed under the current key decrypts.
    #[test]
    fn key_rotation_current_key_wins_when_valid() {
        let active = [9u8; 32];
        let retired = [1u8; 32];
        let session = SessionCrypto::new(active).with_previous_key(retired);

        let sealed = session.encrypt(b"current-key data").unwrap();
        let opened = session.decrypt(&sealed).unwrap();

        assert_eq!(opened, b"current-key data");
    }

    /// Input that is exactly one nonce long (no ciphertext at all) is rejected.
    #[test]
    fn decrypt_payload_at_nonce_length_boundary_fails() {
        let session = SessionCrypto::new([7u8; 32]);
        let nonce_only = [0u8; NONCE_LEN];

        assert!(session.decrypt(&nonce_only).is_err());
    }

    /// Registering a previous key leaves encryption bound to the current key:
    /// only a current-key instance can read the output, a previous-key one cannot.
    #[test]
    fn with_previous_key_does_not_replace_current_key() {
        let active = [3u8; 32];
        let retired = [4u8; 32];

        let sealed = SessionCrypto::new(active)
            .with_previous_key(retired)
            .encrypt(b"under-current")
            .unwrap();

        let active_only = SessionCrypto::new(active);
        assert_eq!(active_only.decrypt(&sealed).unwrap(), b"under-current");

        let retired_only = SessionCrypto::new(retired);
        assert!(retired_only.decrypt(&sealed).is_err());
    }
}