use std::collections::BTreeMap;
use crate::encryption::LogEncryptionKey;
/// Identifier for one generation of log-encryption key.
///
/// `KeyManager` hands these out sequentially, starting at 1 for its initial key.
pub type KeyVersion = u32;
/// Version number 0, which `KeyManager` never assigns (its versions start at 1).
///
/// NOTE(review): presumably tags data written before key versioning existed — confirm with callers.
pub const LEGACY_KEY_VERSION: KeyVersion = 0;
/// Holds the active log-encryption key plus a bounded history of previously
/// active keys, each tagged with a [`KeyVersion`].
///
/// Versions are assigned sequentially starting at 1 and are never reused, so a
/// version number stored alongside encrypted data identifies exactly one key.
/// At most `retention` keys (the current one included) are kept; older keys are
/// discarded on rotation and can no longer be looked up.
pub struct KeyManager {
    /// Version number of `current`.
    current_version: KeyVersion,
    /// The active key, returned by [`KeyManager::current`].
    current: LogEncryptionKey,
    /// Previously active keys by version; `BTreeMap` keeps them ordered oldest-first.
    history: BTreeMap<KeyVersion, LogEncryptionKey>,
    /// Maximum number of keys retained, including `current`; always at least 1.
    retention: usize,
}

impl KeyManager {
    /// Creates a manager whose active key is `initial`, assigned version 1.
    ///
    /// `retention` is the maximum number of keys kept (current + history).
    /// Values below 1 are clamped to 1, meaning only the current key is kept.
    pub fn new(initial: LogEncryptionKey, retention: usize) -> Self {
        Self {
            current_version: 1,
            current: initial,
            history: BTreeMap::new(),
            // The current key must always be retained.
            retention: retention.max(1),
        }
    }

    /// Makes `new_key` the active key and returns its freshly assigned version.
    ///
    /// The previously active key moves into history. If that pushes the number
    /// of retained keys past `retention`, the oldest historical keys are dropped.
    ///
    /// # Panics
    ///
    /// Panics if the version counter is already at `KeyVersion::MAX`, because no
    /// unused version remains. Reusing a version would make two different keys
    /// indistinguishable. Use [`KeyManager::try_rotate`] to handle this case
    /// without panicking.
    pub fn rotate(&mut self, new_key: LogEncryptionKey) -> KeyVersion {
        match self.try_rotate(new_key) {
            Ok(version) => version,
            Err(_) => panic!(
                "key version space exhausted: cannot rotate past version {}",
                KeyVersion::MAX
            ),
        }
    }

    /// Fallible variant of [`KeyManager::rotate`].
    ///
    /// Returns the new key's version on success. If the version counter is
    /// already at `KeyVersion::MAX`, the manager is left completely unchanged
    /// and `new_key` is handed back in `Err`.
    pub fn try_rotate(
        &mut self,
        new_key: LogEncryptionKey,
    ) -> Result<KeyVersion, LogEncryptionKey> {
        // Allocate the new version before mutating anything, so exhaustion can
        // never leave two keys registered under the same version number.
        let next_version = match self.current_version.checked_add(1) {
            Some(version) => version,
            None => return Err(new_key),
        };
        let previous_key = std::mem::replace(&mut self.current, new_key);
        self.history.insert(self.current_version, previous_key);
        self.current_version = next_version;
        self.prune_history();
        Ok(next_version)
    }

    /// Drops the oldest historical keys until at most `retention - 1` remain, so
    /// that together with the current key no more than `retention` are kept.
    fn prune_history(&mut self) {
        let max_history = self.retention.saturating_sub(1);
        while self.history.len() > max_history {
            // `BTreeMap` iterates in ascending key order: the first key is the oldest.
            let oldest = match self.history.keys().next() {
                Some(&version) => version,
                None => break,
            };
            self.history.remove(&oldest);
        }
    }

    /// Returns the active key together with its version.
    pub fn current(&self) -> (KeyVersion, &LogEncryptionKey) {
        (self.current_version, &self.current)
    }

    /// Returns the key registered under `version`, whether it is the current key
    /// or a retained historical one.
    ///
    /// Returns `None` if the version was never assigned or has been pruned.
    pub fn lookup(&self, version: KeyVersion) -> Option<&LogEncryptionKey> {
        if version == self.current_version {
            Some(&self.current)
        } else {
            self.history.get(&version)
        }
    }

    /// Number of keys currently retained, including the active key.
    pub fn version_count(&self) -> usize {
        1 + self.history.len()
    }

    /// Maximum number of keys retained, after clamping to at least 1.
    pub fn retention(&self) -> usize {
        self.retention
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a deterministic test key whose 32 bytes all equal `fill`.
    fn key(fill: u8) -> LogEncryptionKey {
        LogEncryptionKey::new([fill; 32])
    }

    #[test]
    fn test_key_manager_rotation_advances_version() {
        let mut manager = KeyManager::new(key(0x01), 3);
        assert_eq!(manager.current().0, 1, "initial current version is 1");

        let second = manager.rotate(key(0x02));
        assert_eq!(second, 2, "rotation increments to version 2");
        assert_eq!(manager.current().0, 2);

        let third = manager.rotate(key(0x03));
        assert_eq!(third, 3);
        assert_eq!(manager.current().0, 3);
    }

    #[test]
    fn test_key_manager_lookup_returns_current_and_history() {
        let mut manager = KeyManager::new(key(0xaa), 3);
        for &fill in &[0xbb, 0xcc] {
            let _ = manager.rotate(key(fill));
        }

        // Newest first: the current v3, then history entries v2 and v1.
        for version in (1..=3).rev() {
            assert!(manager.lookup(version).is_some());
        }
        assert!(
            manager.lookup(99).is_none(),
            "non-existent version returns None"
        );
    }

    #[test]
    fn test_key_manager_retention_drops_oldest() {
        let mut manager = KeyManager::new(key(0xaa), 2);
        let _ = manager.rotate(key(0xbb));
        let _ = manager.rotate(key(0xcc));
        assert!(manager.lookup(3).is_some(), "current v3 retained");
        assert!(manager.lookup(2).is_some(), "previous v2 retained");
        assert!(manager.lookup(1).is_none(), "oldest v1 dropped past retention");

        let _ = manager.rotate(key(0xdd));
        assert!(manager.lookup(4).is_some());
        assert!(manager.lookup(3).is_some());
        assert!(manager.lookup(2).is_none());
    }

    #[test]
    fn test_key_manager_retention_clamped_to_one() {
        let mut manager = KeyManager::new(key(0x10), 0);
        assert_eq!(manager.retention(), 1);

        let _ = manager.rotate(key(0x20));
        assert!(manager.lookup(2).is_some(), "current v2 retained");
        assert!(manager.lookup(1).is_none(), "v1 dropped immediately");
    }

    #[test]
    fn test_key_manager_version_count_grows_then_caps() {
        let mut manager = KeyManager::new(key(0x01), 3);
        assert_eq!(manager.version_count(), 1, "single key after construction");

        // The count grows with each rotation until it hits the retention limit of 3.
        let expectations: [(u8, usize); 4] = [(0x02, 2), (0x03, 3), (0x04, 3), (0x05, 3)];
        for &(fill, expected) in expectations.iter() {
            let _ = manager.rotate(key(fill));
            assert_eq!(manager.version_count(), expected);
        }
    }
}