#[cfg(all(feature = "std", not(target_arch = "wasm32")))]
use std::time::{
Duration,
SystemTime,
UNIX_EPOCH,
};
use crate::error::HpkeError;
use crate::security::CryptoRng;
/// Limits that bound how long a single key may remain in use.
///
/// Rotation is flagged as soon as any one of the three limits is reached.
#[derive(Debug, Clone)]
pub struct KeyRotationPolicy {
    /// Number of processed messages at which rotation becomes required.
    pub max_messages: u64,
    /// Key age, in seconds, at which rotation becomes required.
    pub max_duration_secs: u64,
    /// Total number of processed bytes at which rotation becomes required.
    pub max_data_bytes: u64,
    /// When `true`, operations are rejected once rotation is required.
    pub strict_enforcement: bool,
}

impl Default for KeyRotationPolicy {
    /// Builds the default policy: 1000 messages, one hour, 1 MiB of data,
    /// with strict enforcement enabled.
    fn default() -> Self {
        const DEFAULT_MESSAGE_LIMIT: u64 = 1000;
        const DEFAULT_LIFETIME_SECS: u64 = 60 * 60;
        const DEFAULT_BYTE_LIMIT: u64 = 1 << 20;

        KeyRotationPolicy {
            max_messages: DEFAULT_MESSAGE_LIMIT,
            max_duration_secs: DEFAULT_LIFETIME_SECS,
            max_data_bytes: DEFAULT_BYTE_LIMIT,
            strict_enforcement: true,
        }
    }
}
/// Usage counters for one key, compared against a `KeyRotationPolicy`.
#[derive(Debug, Clone)]
pub struct KeyRotationState {
/// Number of messages recorded since the key was created or last reset.
pub message_count: u64,
/// Value of `current_timestamp()` captured when the key was created or last reset.
pub key_creation_time: u64,
/// Sum of the message sizes recorded since the key was created or last reset.
pub data_bytes_processed: u64,
/// Cached result of the most recent `check_rotation_needed` evaluation.
pub rotation_required: bool,
}
impl KeyRotationState {
    /// Creates state for a fresh key: zeroed counters, creation time taken
    /// from `current_timestamp()`, and no rotation pending.
    pub fn new() -> Self {
        Self {
            message_count: 0,
            key_creation_time: current_timestamp(),
            data_bytes_processed: 0,
            rotation_required: false,
        }
    }

    /// Records one processed message of `message_size` bytes and then
    /// re-evaluates `policy`.
    ///
    /// Both counters saturate at `u64::MAX` instead of overflowing. A wrapped
    /// counter would fall back below the policy limits and silently clear the
    /// rotation requirement; a saturated one keeps it raised.
    pub fn update_after_message(&mut self, message_size: usize, policy: &KeyRotationPolicy) {
        self.message_count = self.message_count.saturating_add(1);
        // `usize` is no wider than 64 bits on any target Rust currently
        // supports, so this widening cast is lossless.
        self.data_bytes_processed = self
            .data_bytes_processed
            .saturating_add(message_size as u64);
        self.check_rotation_needed(policy);
    }

    /// Recomputes `rotation_required` from the counters, the key age and `policy`.
    ///
    /// The flag is set when any limit is reached (`>=`). Key age is computed
    /// with a saturating subtraction, so a timestamp earlier than the
    /// recorded creation time yields an age of zero.
    pub fn check_rotation_needed(&mut self, policy: &KeyRotationPolicy) {
        let current_time = current_timestamp();
        let elapsed_time = current_time.saturating_sub(self.key_creation_time);
        self.rotation_required = self.message_count >= policy.max_messages
            || elapsed_time >= policy.max_duration_secs
            || self.data_bytes_processed >= policy.max_data_bytes;
    }

    /// Resets all counters and restarts the key age after a rotation.
    pub fn reset_after_rotation(&mut self) {
        self.message_count = 0;
        self.key_creation_time = current_timestamp();
        self.data_bytes_processed = 0;
        self.rotation_required = false;
    }

    /// Decides whether another operation may proceed under `policy`.
    ///
    /// This consults only the cached `rotation_required` flag; it does not
    /// re-read the clock or the counters.
    ///
    /// # Errors
    ///
    /// Returns `HpkeError::CryptoError` when rotation is required and the
    /// policy has `strict_enforcement` enabled.
    pub fn should_allow_operation(&self, policy: &KeyRotationPolicy) -> Result<(), HpkeError> {
        if self.rotation_required && policy.strict_enforcement {
            return Err(HpkeError::CryptoError(
                "Key rotation required before processing more messages".into(),
            ));
        }
        Ok(())
    }
}
impl Default for KeyRotationState {
fn default() -> Self {
Self::new()
}
}
/// Tracks rotation state for the sender and receiver directions separately,
/// both evaluated against one shared policy.
pub struct KeyRotationManager {
// Policy applied to both directions.
policy: KeyRotationPolicy,
// Usage counters updated through `update_sender_state`.
sender_state: KeyRotationState,
// Usage counters updated through `update_receiver_state`.
receiver_state: KeyRotationState,
}
impl KeyRotationManager {
pub fn new(policy: KeyRotationPolicy) -> Self {
Self {
policy,
sender_state: KeyRotationState::new(),
receiver_state: KeyRotationState::new(),
}
}
pub fn update_sender_state(&mut self, message_size: usize) -> Result<(), HpkeError> {
self.sender_state.should_allow_operation(&self.policy)?;
self.sender_state
.update_after_message(message_size, &self.policy);
Ok(())
}
pub fn update_receiver_state(&mut self, message_size: usize) -> Result<(), HpkeError> {
self.receiver_state.should_allow_operation(&self.policy)?;
self.receiver_state
.update_after_message(message_size, &self.policy);
Ok(())
}
pub fn is_sender_rotation_needed(&self) -> bool {
self.sender_state.rotation_required
}
pub fn is_receiver_rotation_needed(&self) -> bool {
self.receiver_state.rotation_required
}
pub fn reset_sender_state(&mut self) {
self.sender_state.reset_after_rotation();
}
pub fn reset_receiver_state(&mut self) {
self.receiver_state.reset_after_rotation();
}
pub fn policy(&self) -> &KeyRotationPolicy {
&self.policy
}
pub fn update_policy(&mut self, new_policy: KeyRotationPolicy) {
self.policy = new_policy;
self.sender_state.check_rotation_needed(&self.policy);
self.receiver_state.check_rotation_needed(&self.policy);
}
}
/// Derives a policy whose numeric limits are randomised around `base_policy`.
///
/// Twelve bytes are drawn from `rng` in a single call and decoded as three
/// little-endian `u32` words. The first word randomises `max_messages`, the
/// second `max_duration_secs` and the third `max_data_bytes`, each through
/// `apply_randomization`. `strict_enforcement` is copied unchanged.
///
/// # Errors
///
/// Returns `HpkeError::CryptoError` when `randomness_factor` lies outside
/// `0.0..=1.0` (NaN included), and propagates any error from
/// `rng.fill_bytes`.
pub fn generate_rotation_schedule(
    base_policy: &KeyRotationPolicy,
    randomness_factor: f64,
    rng: &mut dyn CryptoRng,
) -> Result<KeyRotationPolicy, HpkeError> {
    let factor_is_valid = (0.0..=1.0).contains(&randomness_factor);
    if !factor_is_valid {
        return Err(HpkeError::CryptoError(
            "Randomness factor must be between 0.0 and 1.0".into(),
        ));
    }

    let mut entropy = [0u8; 12];
    rng.fill_bytes(&mut entropy)?;

    // Decodes the little-endian word that starts at byte offset `start`.
    let word_at = |start: usize| -> u32 {
        u32::from_le_bytes([
            entropy[start],
            entropy[start + 1],
            entropy[start + 2],
            entropy[start + 3],
        ])
    };
    let jitter = |limit: u64, word: u32| apply_randomization(limit, randomness_factor, word);

    Ok(KeyRotationPolicy {
        max_messages: jitter(base_policy.max_messages, word_at(0)),
        max_duration_secs: jitter(base_policy.max_duration_secs, word_at(4)),
        max_data_bytes: jitter(base_policy.max_data_bytes, word_at(8)),
        strict_enforcement: base_policy.strict_enforcement,
    })
}
/// Jitters `base_value` by at most `base_value * randomness_factor` in either
/// direction.
///
/// With `variation = floor(base_value * randomness_factor)`, the result lies
/// in `[base_value - variation, base_value + variation]` (saturating at the
/// `u64` bounds) and is never less than 1. `random_u32 == 0` selects the
/// lower end and `random_u32 == u32::MAX` the upper end.
fn apply_randomization(base_value: u64, randomness_factor: f64, random_u32: u32) -> u64 {
    // Float-to-integer `as` casts saturate (and map NaN to 0), so this cannot
    // overflow even for very large bases.
    let variation = (base_value as f64 * randomness_factor) as u64;

    // Number of candidate offsets, 0..=2*variation, kept within u64.
    let span = variation.saturating_mul(2).saturating_add(1);

    // Multiply-shift scales the whole u32 range proportionally onto 0..span.
    // Unlike `random % span`, this still reaches the upper offsets when the
    // span is wider than 2^32. The product is below 2^96 and the shifted
    // value is below `span`, so narrowing back to u64 is lossless.
    let offset = ((u128::from(random_u32) * u128::from(span)) >> 32) as u64;

    // `offset` is centred on `variation`: values below it subtract from the
    // base, values above it add. A single unsigned `saturating_sub` would
    // clamp the whole lower half to zero and never reduce the base.
    let randomized = if offset >= variation {
        base_value.saturating_add(offset - variation)
    } else {
        base_value.saturating_sub(variation - offset)
    };
    randomized.max(1)
}
/// Returns the current time as whole seconds since the Unix epoch.
///
/// With the `std` feature on a non-wasm32 target the system clock is read; a
/// clock set before the epoch yields 0. On every other configuration the
/// function always returns 0.
fn current_timestamp() -> u64 {
    #[cfg(all(feature = "std", not(target_arch = "wasm32")))]
    {
        let since_epoch = match SystemTime::now().duration_since(UNIX_EPOCH) {
            Ok(elapsed) => elapsed,
            // The clock reads earlier than the epoch: treat it as time zero.
            Err(_) => Duration::from_secs(0),
        };
        since_epoch.as_secs()
    }
    #[cfg(any(not(feature = "std"), target_arch = "wasm32"))]
    {
        0
    }
}
// Unit tests for the rotation policy, state tracking, manager and schedule generation.
#[cfg(test)]
mod tests {
use super::*;
use crate::security::test_rng::TestRng;
// The default policy exposes the documented limits with strict enforcement on.
#[test]
fn test_key_rotation_policy_default() {
let policy = KeyRotationPolicy::default();
assert_eq!(policy.max_messages, 1000);
assert_eq!(policy.max_duration_secs, 3600);
assert_eq!(policy.max_data_bytes, 1024 * 1024);
assert!(policy.strict_enforcement);
}
// Counters advance per message and the flag is raised on reaching the message limit.
#[test]
fn test_key_rotation_state_update() {
let mut state = KeyRotationState::new();
let policy = KeyRotationPolicy {
max_messages: 5,
max_duration_secs: 3600,
max_data_bytes: 1000,
strict_enforcement: true,
};
assert!(!state.rotation_required);
// Four 100-byte messages stay below both the 5-message and the 1000-byte limits.
for i in 1..=4 {
state.update_after_message(100, &policy);
assert_eq!(state.message_count, i);
assert_eq!(state.data_bytes_processed, i * 100);
}
assert!(!state.rotation_required);
// The fifth message reaches max_messages, which uses an inclusive (>=) comparison.
state.update_after_message(100, &policy);
assert!(state.rotation_required);
}
// Without strict enforcement, updates keep succeeding; the flag still tracks the limit
// and is cleared by a reset.
#[test]
fn test_key_rotation_manager() {
let policy = KeyRotationPolicy {
max_messages: 3,
max_duration_secs: 3600,
max_data_bytes: 1000,
strict_enforcement: false,
};
let mut manager = KeyRotationManager::new(policy);
for _ in 0..2 {
assert!(manager.update_sender_state(100).is_ok());
assert!(!manager.is_sender_rotation_needed());
}
// The third message reaches max_messages.
assert!(manager.update_sender_state(100).is_ok());
assert!(manager.is_sender_rotation_needed());
manager.reset_sender_state();
assert!(!manager.is_sender_rotation_needed());
}
// With strict enforcement, an update after the limit has been reached is rejected.
#[test]
fn test_strict_enforcement() {
let policy = KeyRotationPolicy {
max_messages: 2,
max_duration_secs: 3600,
max_data_bytes: 1000,
strict_enforcement: true,
};
let mut manager = KeyRotationManager::new(policy);
assert!(manager.update_sender_state(100).is_ok());
assert!(manager.update_sender_state(100).is_ok());
assert!(manager.is_sender_rotation_needed());
// The limit was reached by the second message, so the third attempt must fail.
assert!(manager.update_sender_state(100).is_err());
}
// A randomised schedule keeps max_messages within the variation band around the base
// and leaves every limit positive.
#[test]
fn test_generate_rotation_schedule() {
let base_policy = KeyRotationPolicy::default();
let mut rng = TestRng::new();
let randomized_policy = generate_rotation_schedule(&base_policy, 0.1, &mut rng)
.expect("Should generate randomized policy");
// Same variation formula as apply_randomization: floor(base * factor).
let variation = ((base_policy.max_messages as f64) * 0.1) as u64;
let min_messages = base_policy.max_messages.saturating_sub(variation).max(1);
let max_messages = base_policy.max_messages.saturating_add(variation).max(1);
assert!(randomized_policy.max_messages >= min_messages);
assert!(randomized_policy.max_messages <= max_messages);
assert!(randomized_policy.max_messages > 0);
assert!(randomized_policy.max_duration_secs > 0);
assert!(randomized_policy.max_data_bytes > 0);
}
}