use crate::encryption::KeyProviderError;
use std::sync::RwLock;
/// Identifies a single key version by number and by a 16-byte identifier.
///
/// The type is a plain value: it can be cloned, compared for equality and
/// used as a key in hashed collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct KeyVersion {
    /// Numeric version of the key.
    pub version: u32,
    /// Fixed-size 16-byte identifier associated with this version.
    pub version_id: [u8; 16],
}
/// Phase of the key-rotation lifecycle tracked by `KeyRotationManager`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RotationState {
/// No rotation is underway; this is the only state from which
/// `begin_rotation` succeeds.
Idle,
/// A rotation has been started and has not yet been completed or cancelled.
InProgress {
/// Version that was current when the rotation began.
old_version: u32,
/// Version that becomes current if the rotation is completed.
new_version: u32,
},
/// NOTE(review): no code visible in this file ever assigns this variant
/// (completing a rotation returns the manager to `Idle`) — presumably
/// reserved for external or future use; confirm before relying on it.
Complete,
}
/// Tracks the current key version number and the progress of a rotation.
///
/// Only the version number and the rotation state are stored here. Each
/// field sits behind its own `RwLock`, so every method takes `&self`.
pub struct KeyRotationManager {
/// Key version currently in effect.
current_version: RwLock<u32>,
/// Current position in the rotation lifecycle.
state: RwLock<RotationState>,
}
impl KeyRotationManager {
    /// Creates a manager at key version 1 with no rotation in progress.
    #[must_use]
    pub fn new() -> Self {
        Self {
            current_version: RwLock::new(1),
            state: RwLock::new(RotationState::Idle),
        }
    }

    /// Returns the key version currently in effect.
    ///
    /// A poisoned lock is recovered instead of propagated: the value left
    /// behind by the panicking holder is returned.
    #[must_use]
    pub fn current_version(&self) -> u32 {
        *self
            .current_version
            .read()
            .unwrap_or_else(|e| e.into_inner())
    }

    /// Returns a snapshot of the current rotation state.
    ///
    /// A poisoned lock is recovered instead of propagated.
    #[must_use]
    pub fn state(&self) -> RotationState {
        self.state.read().unwrap_or_else(|e| e.into_inner()).clone()
    }

    /// Starts a rotation and returns the version that will become current
    /// once [`complete_rotation`](Self::complete_rotation) is called.
    ///
    /// The current version itself is left untouched until completion.
    ///
    /// # Errors
    ///
    /// Returns [`KeyProviderError::Unavailable`] when the state is anything
    /// other than [`RotationState::Idle`], or when the current version is
    /// already `u32::MAX` and no successor version exists.
    pub fn begin_rotation(&self) -> Result<u32, KeyProviderError> {
        // Lock order: `state` first, then `current_version`. `complete_rotation`
        // acquires them in the same order.
        let mut state = self.state.write().unwrap_or_else(|e| e.into_inner());
        if *state != RotationState::Idle {
            return Err(KeyProviderError::Unavailable(
                "Key rotation already in progress".into(),
            ));
        }

        let old = self.current_version();
        // A plain `old + 1` would panic in debug builds (while holding the
        // state lock) and wrap to 0 in release builds; report it instead.
        let new = old.checked_add(1).ok_or_else(|| {
            KeyProviderError::Unavailable("Key version space exhausted; cannot rotate".into())
        })?;

        *state = RotationState::InProgress {
            old_version: old,
            new_version: new,
        };
        Ok(new)
    }

    /// Finishes the rotation in progress: the pending version becomes the
    /// current version and the state returns to [`RotationState::Idle`].
    ///
    /// # Errors
    ///
    /// Returns [`KeyProviderError::Unavailable`] when no rotation is in
    /// progress.
    pub fn complete_rotation(&self) -> Result<(), KeyProviderError> {
        let mut state = self.state.write().unwrap_or_else(|e| e.into_inner());
        match *state {
            RotationState::InProgress { new_version, .. } => {
                // The state write lock is still held here, so the version
                // and the state change together from a caller's viewpoint.
                let mut ver = self
                    .current_version
                    .write()
                    .unwrap_or_else(|e| e.into_inner());
                *ver = new_version;
                *state = RotationState::Idle;
                Ok(())
            }
            RotationState::Idle | RotationState::Complete => Err(KeyProviderError::Unavailable(
                "No rotation in progress".into(),
            )),
        }
    }

    /// Abandons the rotation in progress, leaving the current version
    /// unchanged and returning the state to [`RotationState::Idle`].
    ///
    /// # Errors
    ///
    /// Returns [`KeyProviderError::Unavailable`] when no rotation is in
    /// progress.
    pub fn cancel_rotation(&self) -> Result<(), KeyProviderError> {
        let mut state = self.state.write().unwrap_or_else(|e| e.into_inner());
        match *state {
            RotationState::InProgress { .. } => {
                *state = RotationState::Idle;
                Ok(())
            }
            RotationState::Idle | RotationState::Complete => Err(KeyProviderError::Unavailable(
                "No rotation in progress".into(),
            )),
        }
    }
}
/// The default manager is the one produced by `KeyRotationManager::new`.
impl Default for KeyRotationManager {
/// Delegates to `KeyRotationManager::new`.
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built manager is idle and sits at version 1.
    #[test]
    fn new_starts_at_version_1() {
        let manager = KeyRotationManager::new();
        assert_eq!(manager.state(), RotationState::Idle);
        assert_eq!(manager.current_version(), 1);
    }

    /// Beginning a rotation records both versions but does not advance yet.
    #[test]
    fn begin_rotation_transitions_state() {
        let manager = KeyRotationManager::new();
        let pending = manager.begin_rotation().unwrap();
        let expected_state = RotationState::InProgress {
            old_version: 1,
            new_version: 2,
        };
        assert_eq!(pending, 2);
        assert_eq!(manager.state(), expected_state);
        assert_eq!(manager.current_version(), 1);
    }

    /// Completing a rotation promotes the pending version and goes idle.
    #[test]
    fn complete_rotation_advances_version() {
        let manager = KeyRotationManager::new();
        manager.begin_rotation().unwrap();
        manager.complete_rotation().unwrap();
        assert_eq!(manager.state(), RotationState::Idle);
        assert_eq!(manager.current_version(), 2);
    }

    /// Cancelling a rotation goes idle without touching the version.
    #[test]
    fn cancel_rotation_returns_to_idle() {
        let manager = KeyRotationManager::new();
        manager.begin_rotation().unwrap();
        manager.cancel_rotation().unwrap();
        assert_eq!(manager.state(), RotationState::Idle);
        assert_eq!(manager.current_version(), 1);
    }

    /// A second `begin_rotation` while one is pending is rejected.
    #[test]
    fn double_begin_fails() {
        let manager = KeyRotationManager::new();
        manager.begin_rotation().unwrap();
        let second_attempt = manager.begin_rotation();
        let err = second_attempt.unwrap_err();
        assert!(
            err.to_string().contains("already in progress"),
            "expected 'already in progress', got: {err}"
        );
    }

    /// Completing with nothing pending is rejected.
    #[test]
    fn complete_without_begin_fails() {
        let manager = KeyRotationManager::new();
        let outcome = manager.complete_rotation();
        let err = outcome.unwrap_err();
        assert!(
            err.to_string().contains("No rotation in progress"),
            "expected 'No rotation in progress', got: {err}"
        );
    }

    /// Cancelling with nothing pending is rejected.
    #[test]
    fn cancel_without_begin_fails() {
        let manager = KeyRotationManager::new();
        let outcome = manager.cancel_rotation();
        let err = outcome.unwrap_err();
        assert!(
            err.to_string().contains("No rotation in progress"),
            "expected 'No rotation in progress', got: {err}"
        );
    }

    /// Runs several begin/finish cycles; only completed ones bump the version.
    #[test]
    fn multiple_rotation_cycles() {
        let manager = KeyRotationManager::new();
        // (complete the rotation?, version expected afterwards)
        let script = [(true, 2), (true, 3), (false, 3), (true, 4)];
        for (commit, expected_version) in script {
            manager.begin_rotation().unwrap();
            if commit {
                manager.complete_rotation().unwrap();
            } else {
                manager.cancel_rotation().unwrap();
            }
            assert_eq!(manager.current_version(), expected_version);
        }
    }

    /// `KeyVersion` exposes its fields and clones faithfully.
    #[test]
    fn key_version_struct() {
        let original = KeyVersion {
            version: 42,
            version_id: [0xAA; 16],
        };
        assert_eq!(original.version, 42);
        assert_eq!(original.version_id, [0xAA; 16]);
        let duplicate = original.clone();
        assert_eq!(duplicate.version, original.version);
    }

    /// `Default` yields the same starting point as `new`.
    #[test]
    fn default_creates_same_as_new() {
        let manager = KeyRotationManager::default();
        assert_eq!(manager.state(), RotationState::Idle);
        assert_eq!(manager.current_version(), 1);
    }
}