use crate::{
aead::{Ciphertext, CrabAead},
errors::{CrabError, CrabResult},
secrets::SecretVec,
};
use std::collections::HashMap;
use zeroize::Zeroize;
/// An AEAD cipher that [`KeyRotationManager`] can drive: it can mint new keys and be
/// instantiated from an existing key.
pub trait RotatableAead: CrabAead {
    /// Generates a new key for this cipher.
    ///
    /// NOTE(review): presumably backed by a CSPRNG in each cipher's own `generate_key` —
    /// confirm in the `aead` module.
    fn generate_key() -> CrabResult<SecretVec>;

    /// Builds a cipher instance from `key`.
    ///
    /// The implementations in this file forward to the cipher's `new`; whether a key of
    /// the wrong length is rejected there is up to each cipher — confirm per cipher.
    fn from_key(key: &SecretVec) -> CrabResult<Self>
    where
        Self: Sized;
}
/// A key paired with its version number.
///
/// Both fields are public, so the `version >= 1` check performed by
/// [`VersionedKey::new`] can be bypassed by building the struct literally.
// `#[zeroize(drop)]` asks the `Zeroize` derive to also wipe both fields when the value
// is dropped.
// NOTE(review): newer zeroize releases recommend `#[derive(ZeroizeOnDrop)]` instead of
// `#[zeroize(drop)]` — confirm against the crate version in use.
#[derive(Clone, Zeroize)]
#[zeroize(drop)]
pub struct VersionedKey {
    /// Version number of this key; `VersionedKey::new` rejects 0.
    pub version: u32,
    /// The secret key material.
    pub key: SecretVec,
}
impl VersionedKey {
    /// Pairs `key` with `version`.
    ///
    /// # Errors
    ///
    /// Returns an invalid-input error when `version` is 0; version numbers start at 1.
    pub fn new(version: u32, key: SecretVec) -> CrabResult<Self> {
        match version {
            0 => Err(CrabError::invalid_input("Key version must be >= 1")),
            _ => Ok(Self { version, key }),
        }
    }

    /// The version number this key was created with.
    pub fn version(&self) -> u32 {
        let version = self.version;
        version
    }

    /// Borrows the secret key material.
    pub fn key(&self) -> &SecretVec {
        let key_ref = &self.key;
        key_ref
    }
}
/// Holds versioned keys for cipher `C`: encryption always uses the newest ("current")
/// version, while older versions stay available for decryption until they are removed.
pub struct KeyRotationManager<C: RotatableAead> {
    /// Key material indexed by version number.
    keys: HashMap<u32, SecretVec>,
    /// Version used by `encrypt`; `remove_version` refuses to remove it.
    current_version: u32,
    /// Retention limit checked after each `rotate` / `rotate_with_key`.
    max_versions: usize,
    /// Ties the manager to cipher type `C` without storing a `C`.
    _phantom: std::marker::PhantomData<C>,
}
impl<C: RotatableAead> KeyRotationManager<C> {
    /// Retention limit used by [`KeyRotationManager::new`] and
    /// [`KeyRotationManager::from_key`].
    const DEFAULT_MAX_VERSIONS: usize = 256;

    /// Creates a manager whose version 1 is a freshly generated key, retaining up to
    /// 256 key versions.
    ///
    /// # Errors
    ///
    /// Propagates any error from `C::generate_key`.
    pub fn new() -> CrabResult<Self> {
        Self::with_max_versions(Self::DEFAULT_MAX_VERSIONS)
    }

    /// Creates a manager whose version 1 is a freshly generated key, retaining at most
    /// `max_versions` key versions (the current version included).
    ///
    /// # Errors
    ///
    /// Returns an invalid-input error if `max_versions` is 0, and propagates any error
    /// from `C::generate_key`.
    pub fn with_max_versions(max_versions: usize) -> CrabResult<Self> {
        if max_versions == 0 {
            return Err(CrabError::invalid_input("max_versions must be >= 1"));
        }
        let key = C::generate_key()?;
        Ok(Self::with_initial_key(key, max_versions))
    }

    /// Creates a manager that uses the caller-supplied `key` as version 1, retaining up
    /// to 256 key versions.
    ///
    /// # Errors
    ///
    /// Propagates the error from `C::from_key` if the cipher cannot be built from `key`.
    pub fn from_key(key: SecretVec) -> CrabResult<Self> {
        // Validate up front: accepting an unusable key would make every later
        // `encrypt` fail on the current version.
        let _ = C::from_key(&key)?;
        Ok(Self::with_initial_key(key, Self::DEFAULT_MAX_VERSIONS))
    }

    /// Builds a manager holding only `key`, installed as version 1.
    fn with_initial_key(key: SecretVec, max_versions: usize) -> Self {
        let mut keys = HashMap::new();
        keys.insert(1, key);
        Self {
            keys,
            current_version: 1,
            max_versions,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Generates a new key and makes it the current version.
    ///
    /// If more than `max_versions` keys are then stored, the oldest non-current versions
    /// are discarded; ciphertexts under those versions can no longer be decrypted.
    ///
    /// # Errors
    ///
    /// Returns an invalid-input error if the version counter would overflow `u32`, and
    /// propagates any error from `C::generate_key`. On error the manager is unchanged.
    pub fn rotate(&mut self) -> CrabResult<()> {
        let new_version = self.next_version()?;
        let new_key = C::generate_key()?;
        self.install_version(new_version, new_key);
        Ok(())
    }

    /// Installs the caller-supplied `key` as the new current version, with the same
    /// retention behaviour as [`KeyRotationManager::rotate`].
    ///
    /// # Errors
    ///
    /// Returns an invalid-input error if the version counter would overflow `u32`, and
    /// propagates the error from `C::from_key` if the cipher cannot be built from `key`.
    /// On error the manager is unchanged.
    pub fn rotate_with_key(&mut self, key: SecretVec) -> CrabResult<()> {
        let new_version = self.next_version()?;
        // Never let a key the cipher rejects become the current version.
        let _ = C::from_key(&key)?;
        self.install_version(new_version, key);
        Ok(())
    }

    /// Returns the version number following the current one.
    fn next_version(&self) -> CrabResult<u32> {
        self.current_version
            .checked_add(1)
            .ok_or_else(|| CrabError::invalid_input("Version number overflow"))
    }

    /// Stores `key` under `version`, makes it current, and enforces `max_versions`.
    fn install_version(&mut self, version: u32, key: SecretVec) {
        self.keys.insert(version, key);
        self.current_version = version;
        self.evict_excess_versions();
    }

    /// Removes the lowest-numbered non-current versions until at most `max_versions`
    /// remain.
    ///
    /// Evicts the actual minimum stored version rather than `current - max_versions`, so
    /// the bound still holds after versions were removed out of order via
    /// `remove_version`.
    fn evict_excess_versions(&mut self) {
        let current = self.current_version;
        while self.keys.len() > self.max_versions {
            let oldest = self.keys.keys().copied().filter(|&v| v != current).min();
            match oldest {
                Some(version) => {
                    self.keys.remove(&version);
                }
                // Only the current key is left; it is never evicted.
                None => break,
            }
        }
    }

    /// The version `encrypt` currently uses.
    pub fn current_version(&self) -> u32 {
        self.current_version
    }

    /// Number of key versions currently stored.
    pub fn version_count(&self) -> usize {
        self.keys.len()
    }

    /// Whether a key for `version` is stored.
    pub fn has_version(&self, version: u32) -> bool {
        self.keys.contains_key(&version)
    }

    /// Encrypts `plaintext` (with optional `aad`) under the current key.
    ///
    /// Returns the key version alongside the ciphertext; pass that version to
    /// [`KeyRotationManager::decrypt`].
    ///
    /// # Errors
    ///
    /// Propagates errors from building the cipher or from encryption.
    pub fn encrypt(&self, plaintext: &[u8], aad: Option<&[u8]>) -> CrabResult<(u32, Ciphertext)> {
        let key = self
            .keys
            .get(&self.current_version)
            .ok_or_else(|| CrabError::invalid_input("Current key version not found"))?;
        let cipher = C::from_key(key)?;
        let ciphertext = cipher.encrypt(plaintext, aad)?;
        Ok((self.current_version, ciphertext))
    }

    /// Decrypts `ciphertext` using the key stored for `version`.
    ///
    /// # Errors
    ///
    /// Returns an invalid-input error if `version` is not stored, and propagates errors
    /// from building the cipher or from decryption (e.g. an AAD mismatch).
    pub fn decrypt(
        &self,
        version: u32,
        ciphertext: &Ciphertext,
        aad: Option<&[u8]>,
    ) -> CrabResult<Vec<u8>> {
        let key = self.keys.get(&version).ok_or_else(|| {
            CrabError::invalid_input(format!("Key version {} not found", version))
        })?;
        let cipher = C::from_key(key)?;
        cipher.decrypt(ciphertext, aad)
    }

    /// Decrypts a ciphertext produced under `old_version` and re-encrypts it under the
    /// current version, using the same `aad` for both steps.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`KeyRotationManager::decrypt`] or
    /// [`KeyRotationManager::encrypt`].
    pub fn re_encrypt(
        &self,
        old_version: u32,
        old_ciphertext: &Ciphertext,
        aad: Option<&[u8]>,
    ) -> CrabResult<(u32, Ciphertext)> {
        let mut plaintext = self.decrypt(old_version, old_ciphertext, aad)?;
        let result = self.encrypt(&plaintext, aad);
        // The intermediate plaintext never leaves this function; wipe it so it does not
        // linger in freed heap memory.
        plaintext.zeroize();
        result
    }

    /// Removes the key stored for `version`. Removing a version that is not stored is
    /// not an error.
    ///
    /// # Errors
    ///
    /// Returns an invalid-input error if `version` is the current version.
    pub fn remove_version(&mut self, version: u32) -> CrabResult<()> {
        if version == self.current_version {
            return Err(CrabError::invalid_input("Cannot remove current key version"));
        }
        self.keys.remove(&version);
        Ok(())
    }

    /// All stored versions in ascending order.
    pub fn available_versions(&self) -> Vec<u32> {
        let mut versions: Vec<u32> = self.keys.keys().copied().collect();
        versions.sort_unstable();
        versions
    }
}
impl<C: RotatableAead> Drop for KeyRotationManager<C> {
    // Empties the key map as soon as the manager is dropped. The map's own destructor
    // would drop the keys anyway; this just makes the release explicit.
    // NOTE(review): presumably `SecretVec` wipes its bytes when dropped — confirm in
    // the `secrets` module.
    fn drop(&mut self) {
        self.keys.drain().for_each(drop);
    }
}
impl RotatableAead for crate::aead::AesGcm256 {
    /// Wraps the cipher's own generated key bytes in a [`SecretVec`].
    fn generate_key() -> CrabResult<SecretVec> {
        crate::aead::AesGcm256::generate_key().map(SecretVec::new)
    }

    /// Builds the cipher from the raw bytes of `key`.
    fn from_key(key: &SecretVec) -> CrabResult<Self> {
        let cipher = crate::aead::AesGcm256::new(key.as_ref());
        cipher
    }
}
impl RotatableAead for crate::aead::ChaCha20Poly1305 {
    /// Wraps the cipher's own generated key bytes in a [`SecretVec`].
    fn generate_key() -> CrabResult<SecretVec> {
        crate::aead::ChaCha20Poly1305::generate_key().map(SecretVec::new)
    }

    /// Builds the cipher from the raw bytes of `key`.
    fn from_key(key: &SecretVec) -> CrabResult<Self> {
        let cipher = crate::aead::ChaCha20Poly1305::new(key.as_ref());
        cipher
    }
}
impl RotatableAead for crate::aead::AesGcm128 {
    /// Wraps the cipher's own generated key bytes in a [`SecretVec`].
    fn generate_key() -> CrabResult<SecretVec> {
        crate::aead::AesGcm128::generate_key().map(SecretVec::new)
    }

    /// Builds the cipher from the raw bytes of `key`.
    fn from_key(key: &SecretVec) -> CrabResult<Self> {
        let cipher = crate::aead::AesGcm128::new(key.as_ref());
        cipher
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::aead::{AesGcm256, ChaCha20Poly1305};

    #[test]
    fn test_versioned_key_creation() {
        let key = SecretVec::new(vec![1u8; 32]);
        let vkey = VersionedKey::new(1, key).unwrap();
        assert_eq!(vkey.version(), 1);
        assert_eq!(vkey.key().as_ref().len(), 32);
    }

    #[test]
    fn test_versioned_key_zero_version_fails() {
        let key = SecretVec::new(vec![1u8; 32]);
        let result = VersionedKey::new(0, key);
        assert!(result.is_err());
    }

    #[test]
    fn test_key_rotation_manager_new() {
        let manager = KeyRotationManager::<AesGcm256>::new().unwrap();
        assert_eq!(manager.current_version(), 1);
        assert_eq!(manager.version_count(), 1);
        assert!(manager.has_version(1));
    }

    #[test]
    fn test_with_max_versions_zero_fails() {
        // A limit of 0 could never hold even the current key.
        assert!(KeyRotationManager::<AesGcm256>::with_max_versions(0).is_err());
    }

    #[test]
    fn test_rotation() {
        let mut manager = KeyRotationManager::<AesGcm256>::new().unwrap();
        assert_eq!(manager.current_version(), 1);
        manager.rotate().unwrap();
        assert_eq!(manager.current_version(), 2);
        assert_eq!(manager.version_count(), 2);
        manager.rotate().unwrap();
        assert_eq!(manager.current_version(), 3);
        assert_eq!(manager.version_count(), 3);
    }

    #[test]
    fn test_encrypt_decrypt_with_version() {
        let manager = KeyRotationManager::<AesGcm256>::new().unwrap();
        let plaintext = b"Secret message";
        let (version, ciphertext) = manager.encrypt(plaintext, None).unwrap();
        assert_eq!(version, 1);
        let decrypted = manager.decrypt(version, &ciphertext, None).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_encrypt_decrypt_with_aad() {
        let manager = KeyRotationManager::<ChaCha20Poly1305>::new().unwrap();
        let plaintext = b"Secret message";
        let aad = b"metadata";
        let (version, ciphertext) = manager.encrypt(plaintext, Some(aad)).unwrap();
        let decrypted = manager.decrypt(version, &ciphertext, Some(aad)).unwrap();
        assert_eq!(decrypted, plaintext);
        let result = manager.decrypt(version, &ciphertext, Some(b"wrong"));
        assert!(result.is_err());
    }

    #[test]
    fn test_decrypt_with_old_version() {
        let mut manager = KeyRotationManager::<AesGcm256>::new().unwrap();
        let plaintext1 = b"Message 1";
        let (v1, ct1) = manager.encrypt(plaintext1, None).unwrap();
        assert_eq!(v1, 1);
        manager.rotate().unwrap();
        let plaintext2 = b"Message 2";
        let (v2, ct2) = manager.encrypt(plaintext2, None).unwrap();
        assert_eq!(v2, 2);
        assert_eq!(manager.decrypt(v1, &ct1, None).unwrap(), plaintext1);
        assert_eq!(manager.decrypt(v2, &ct2, None).unwrap(), plaintext2);
    }

    #[test]
    fn test_re_encrypt() {
        let mut manager = KeyRotationManager::<AesGcm256>::new().unwrap();
        let plaintext = b"Important data";
        let (v1, ct1) = manager.encrypt(plaintext, None).unwrap();
        assert_eq!(v1, 1);
        manager.rotate().unwrap();
        let (v2, ct2) = manager.re_encrypt(v1, &ct1, None).unwrap();
        assert_eq!(v2, 2);
        let decrypted = manager.decrypt(v2, &ct2, None).unwrap();
        assert_eq!(decrypted, plaintext);
        let decrypted_old = manager.decrypt(v1, &ct1, None).unwrap();
        assert_eq!(decrypted_old, plaintext);
    }

    #[test]
    fn test_re_encrypt_with_aad() {
        let mut manager = KeyRotationManager::<ChaCha20Poly1305>::new().unwrap();
        let plaintext = b"Data";
        let aad = b"context";
        let (v1, ct1) = manager.encrypt(plaintext, Some(aad)).unwrap();
        manager.rotate().unwrap();
        let (v2, ct2) = manager.re_encrypt(v1, &ct1, Some(aad)).unwrap();
        assert_eq!(v2, 2);
        let decrypted = manager.decrypt(v2, &ct2, Some(aad)).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_re_encrypt_wrong_aad_fails() {
        let mut manager = KeyRotationManager::<ChaCha20Poly1305>::new().unwrap();
        let (v1, ct1) = manager.encrypt(b"Data", Some(b"context")).unwrap();
        manager.rotate().unwrap();
        assert!(manager.re_encrypt(v1, &ct1, Some(b"other")).is_err());
    }

    #[test]
    fn test_remove_version() {
        let mut manager = KeyRotationManager::<AesGcm256>::new().unwrap();
        manager.rotate().unwrap();
        assert!(manager.has_version(1));
        manager.remove_version(1).unwrap();
        assert!(!manager.has_version(1));
        assert!(manager.has_version(2));
    }

    #[test]
    fn test_remove_nonexistent_version_is_ok() {
        let mut manager = KeyRotationManager::<AesGcm256>::new().unwrap();
        assert!(manager.remove_version(42).is_ok());
        assert_eq!(manager.available_versions(), vec![1]);
    }

    #[test]
    fn test_cannot_remove_current_version() {
        let mut manager = KeyRotationManager::<AesGcm256>::new().unwrap();
        let result = manager.remove_version(1);
        assert!(result.is_err());
    }

    #[test]
    fn test_max_versions() {
        let mut manager = KeyRotationManager::<AesGcm256>::with_max_versions(3).unwrap();
        manager.rotate().unwrap();
        manager.rotate().unwrap();
        assert_eq!(manager.version_count(), 3);
        manager.rotate().unwrap();
        assert_eq!(manager.version_count(), 3);
        assert!(!manager.has_version(1));
        assert!(manager.has_version(2));
        assert!(manager.has_version(3));
        assert!(manager.has_version(4));
    }

    #[test]
    fn test_available_versions() {
        let mut manager = KeyRotationManager::<AesGcm256>::new().unwrap();
        manager.rotate().unwrap();
        manager.rotate().unwrap();
        let versions = manager.available_versions();
        assert_eq!(versions, vec![1, 2, 3]);
    }

    #[test]
    fn test_from_existing_key() {
        let key_bytes = AesGcm256::generate_key().unwrap();
        let key = SecretVec::new(key_bytes);
        let manager = KeyRotationManager::<AesGcm256>::from_key(key).unwrap();
        assert_eq!(manager.current_version(), 1);
        let (version, ciphertext) = manager.encrypt(b"test", None).unwrap();
        assert_eq!(version, 1);
        assert!(manager.decrypt(version, &ciphertext, None).is_ok());
    }

    #[test]
    fn test_decrypt_nonexistent_version() {
        let manager = KeyRotationManager::<AesGcm256>::new().unwrap();
        let (_, ciphertext) = manager.encrypt(b"test", None).unwrap();
        let result = manager.decrypt(999, &ciphertext, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_rotate_with_key() {
        let mut manager = KeyRotationManager::<AesGcm256>::new().unwrap();
        let custom_key_bytes = AesGcm256::generate_key().unwrap();
        let custom_key = SecretVec::new(custom_key_bytes);
        manager.rotate_with_key(custom_key).unwrap();
        assert_eq!(manager.current_version(), 2);
        let (version, ciphertext) = manager.encrypt(b"test", None).unwrap();
        assert_eq!(version, 2);
        assert!(manager.decrypt(version, &ciphertext, None).is_ok());
    }

    #[test]
    fn test_multiple_rotations() {
        let mut manager = KeyRotationManager::<ChaCha20Poly1305>::new().unwrap();
        let mut plaintexts = vec![];
        let mut ciphertexts = vec![];
        for i in 0..5 {
            let plaintext = format!("Message {}", i);
            plaintexts.push(plaintext.clone());
            let (version, ciphertext) = manager.encrypt(plaintext.as_bytes(), None).unwrap();
            ciphertexts.push((version, ciphertext));
            if i < 4 {
                manager.rotate().unwrap();
            }
        }
        for (i, (version, ciphertext)) in ciphertexts.iter().enumerate() {
            let decrypted = manager.decrypt(*version, ciphertext, None).unwrap();
            assert_eq!(decrypted, plaintexts[i].as_bytes());
        }
    }
}