use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::time::{Duration, SystemTime};
/// Kinds of key operations that a [`KeyPolicy`] can allow or deny.
///
/// The type is `Copy` and `Hash`, so values are stored directly in the
/// policy's allow/deny `HashSet`s and passed by value to the check methods.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Operation {
Sign,
Verify,
Encrypt,
Decrypt,
KeyExchange,
DeriveKey,
WrapKey,
UnwrapKey,
}
/// Reasons a key-usage request can be rejected.
///
/// Returned as the error type of the policy checks and stored in the
/// engine's violation log.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PolicyViolation {
/// The requested operation is in the deny set or missing from the allow-list.
OperationDenied(Operation),
/// The key has already been used `current` times, which is at or above `limit`.
UsageLimitExceeded { limit: u64, current: u64 },
/// The current time is later than the policy's `valid_until`.
KeyExpired { expired_at: SystemTime },
/// The current time is earlier than the policy's `valid_from`.
KeyNotYetValid { valid_from: SystemTime },
/// A required context key is absent, or no context was supplied at all.
MissingContext(String),
/// NOTE(review): not constructed anywhere in this file; presumably reserved
/// for callers that validate context values — confirm.
InvalidContext(String),
/// No policy is registered for the given key id.
PolicyNotFound,
}
impl std::fmt::Display for PolicyViolation {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
PolicyViolation::OperationDenied(op) => {
write!(f, "Operation {:?} denied by policy", op)
}
PolicyViolation::UsageLimitExceeded { limit, current } => {
write!(f, "Usage limit exceeded: {}/{}", current, limit)
}
PolicyViolation::KeyExpired { expired_at } => {
write!(f, "Key expired at {:?}", expired_at)
}
PolicyViolation::KeyNotYetValid { valid_from } => {
write!(f, "Key not yet valid (valid from {:?})", valid_from)
}
PolicyViolation::MissingContext(ctx) => write!(f, "Missing required context: {}", ctx),
PolicyViolation::InvalidContext(msg) => write!(f, "Invalid context: {}", msg),
PolicyViolation::PolicyNotFound => write!(f, "Policy not found for key"),
}
}
}
// Empty impl: `PolicyViolation` relies on the default `std::error::Error`
// methods, with its message coming from the `Display` impl.
impl std::error::Error for PolicyViolation {}
/// Usage rules attached to a single key.
///
/// Built through the chained builder methods on `KeyPolicy` and evaluated by
/// `allows_operation`, `check_usage_limit`, `check_validity` and
/// `check_context`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KeyPolicy {
/// Allow-list of operations. `None` means every operation that is not
/// explicitly denied is allowed; `Some(set)` allows only members of `set`.
allowed_operations: Option<HashSet<Operation>>,
/// Operations that are always rejected; checked before the allow-list.
denied_operations: HashSet<Operation>,
/// Maximum number of uses; a use count at or above this value is rejected.
/// `None` means unlimited.
max_uses: Option<u64>,
/// Earliest time at which the key may be used (`None` = no lower bound).
valid_from: Option<SystemTime>,
/// Latest time at which the key may be used (`None` = no expiry).
valid_until: Option<SystemTime>,
/// Context keys that must be present in the caller-supplied context map.
required_context: HashSet<String>,
/// Free-form key/value annotations; not consulted by any check in this file.
metadata: HashMap<String, String>,
}
impl Default for KeyPolicy {
fn default() -> Self {
Self::new()
}
}
impl KeyPolicy {
    /// Creates a permissive policy: no allow-list, no denied operations, no
    /// usage limit, no validity window, no required context and no metadata.
    pub fn new() -> Self {
        Self {
            allowed_operations: None,
            denied_operations: HashSet::new(),
            max_uses: None,
            valid_from: None,
            valid_until: None,
            required_context: HashSet::new(),
            metadata: HashMap::new(),
        }
    }

    /// Creates a policy with an *empty* allow-list, so every operation is
    /// rejected until it is added with [`KeyPolicy::allow_operation`].
    pub fn restrictive() -> Self {
        Self {
            allowed_operations: Some(HashSet::new()),
            ..Self::new()
        }
    }

    /// Adds `op` to the allow-list.
    ///
    /// If the policy had no allow-list yet (permissive mode), one is created,
    /// which means only explicitly allowed operations pass from then on.
    #[must_use = "builder methods return the updated policy"]
    pub fn allow_operation(mut self, op: Operation) -> Self {
        self.allowed_operations
            .get_or_insert_with(HashSet::new)
            .insert(op);
        self
    }

    /// Adds `op` to the deny set. Denials take precedence over the allow-list.
    #[must_use = "builder methods return the updated policy"]
    pub fn deny_operation(mut self, op: Operation) -> Self {
        self.denied_operations.insert(op);
        self
    }

    /// Sets the usage limit; a use count equal to or above `limit` is rejected.
    #[must_use = "builder methods return the updated policy"]
    pub fn max_uses(mut self, limit: u64) -> Self {
        self.max_uses = Some(limit);
        self
    }

    /// Makes the key valid from now for `duration`.
    ///
    /// If `now + duration` cannot be represented as a `SystemTime`, the
    /// policy is given no expiry instead of panicking.
    #[must_use = "builder methods return the updated policy"]
    pub fn valid_for(mut self, duration: Duration) -> Self {
        let now = SystemTime::now();
        self.valid_from = Some(now);
        // `SystemTime + Duration` panics on overflow; `checked_add` yields
        // `None`, which this policy already interprets as "never expires".
        self.valid_until = now.checked_add(duration);
        self
    }

    /// Sets the earliest time at which the key may be used.
    #[must_use = "builder methods return the updated policy"]
    pub fn valid_from(mut self, time: SystemTime) -> Self {
        self.valid_from = Some(time);
        self
    }

    /// Sets the latest time at which the key may be used.
    #[must_use = "builder methods return the updated policy"]
    pub fn valid_until(mut self, time: SystemTime) -> Self {
        self.valid_until = Some(time);
        self
    }

    /// Requires `key` to be present in the context map passed to
    /// [`KeyPolicy::check_context`].
    #[must_use = "builder methods return the updated policy"]
    pub fn require_context(mut self, key: String) -> Self {
        self.required_context.insert(key);
        self
    }

    /// Stores a metadata entry, replacing any previous value for `key`.
    #[must_use = "builder methods return the updated policy"]
    pub fn with_metadata(mut self, key: String, value: String) -> Self {
        self.metadata.insert(key, value);
        self
    }

    /// Returns `true` when `op` is not denied and, if an allow-list exists,
    /// is a member of it.
    pub fn allows_operation(&self, op: Operation) -> bool {
        if self.denied_operations.contains(&op) {
            return false;
        }
        self.allowed_operations
            .as_ref()
            .map_or(true, |allowed| allowed.contains(&op))
    }

    /// Checks `current_uses` against the usage limit.
    ///
    /// # Errors
    /// Returns [`PolicyViolation::UsageLimitExceeded`] when a limit is set and
    /// `current_uses >= limit`.
    pub fn check_usage_limit(&self, current_uses: u64) -> Result<(), PolicyViolation> {
        match self.max_uses {
            Some(limit) if current_uses >= limit => Err(PolicyViolation::UsageLimitExceeded {
                limit,
                current: current_uses,
            }),
            _ => Ok(()),
        }
    }

    /// Checks the current system time against the validity window.
    ///
    /// Both bounds are inclusive: the key is valid at exactly `valid_from`
    /// and at exactly `valid_until`.
    ///
    /// # Errors
    /// Returns [`PolicyViolation::KeyNotYetValid`] before `valid_from`, or
    /// [`PolicyViolation::KeyExpired`] after `valid_until`.
    pub fn check_validity(&self) -> Result<(), PolicyViolation> {
        let now = SystemTime::now();
        if let Some(valid_from) = self.valid_from {
            if now < valid_from {
                return Err(PolicyViolation::KeyNotYetValid { valid_from });
            }
        }
        if let Some(valid_until) = self.valid_until {
            if now > valid_until {
                return Err(PolicyViolation::KeyExpired {
                    expired_at: valid_until,
                });
            }
        }
        Ok(())
    }

    /// Verifies that every required context key is present in `context`.
    /// Only key presence is checked; values are not inspected.
    ///
    /// # Errors
    /// Returns [`PolicyViolation::MissingContext`] carrying
    /// `"context required"` when keys are required but `context` is `None`,
    /// or the lexicographically smallest missing key otherwise.
    pub fn check_context(
        &self,
        context: Option<&HashMap<String, String>>,
    ) -> Result<(), PolicyViolation> {
        if self.required_context.is_empty() {
            return Ok(());
        }
        let context = match context {
            Some(map) => map,
            None => {
                return Err(PolicyViolation::MissingContext(
                    "context required".to_string(),
                ))
            }
        };
        // `HashSet` iteration order is randomized, so pick the smallest
        // missing key to keep errors and violation logs reproducible.
        let first_missing = self
            .required_context
            .iter()
            .filter(|required_key| !context.contains_key(*required_key))
            .min();
        match first_missing {
            Some(required_key) => Err(PolicyViolation::MissingContext(required_key.clone())),
            None => Ok(()),
        }
    }
}
/// In-memory registry of key policies with usage counters and a violation log.
///
/// Keys are identified by a 32-byte id.
pub struct PolicyEngine {
/// Policy registered for each key id.
policies: HashMap<[u8; 32], KeyPolicy>,
/// Number of successful `check_policy` calls per key id since the policy
/// was registered or the counter was last reset.
usage_counts: HashMap<[u8; 32], u64>,
/// Log of rejected requests as `(time logged, key id, violation)`;
/// it grows until `clear_violations` is called.
violations: Vec<(SystemTime, [u8; 32], PolicyViolation)>,
}
impl Default for PolicyEngine {
fn default() -> Self {
Self::new()
}
}
impl PolicyEngine {
    /// Creates an engine with no policies, no usage counters and an empty
    /// violation log.
    pub fn new() -> Self {
        Self {
            policies: HashMap::new(),
            usage_counts: HashMap::new(),
            violations: Vec::new(),
        }
    }

    /// Registers `policy` for `key_id`, replacing any existing policy.
    ///
    /// The usage counter for `key_id` is (re)set to zero; use
    /// [`PolicyEngine::update_policy`] to swap a policy while keeping the count.
    pub fn register_policy(&mut self, key_id: [u8; 32], policy: KeyPolicy) {
        self.policies.insert(key_id, policy);
        self.usage_counts.insert(key_id, 0);
    }

    /// Replaces the policy of an already-registered key, keeping its usage count.
    ///
    /// # Errors
    /// Returns [`PolicyViolation::PolicyNotFound`] when `key_id` has no policy.
    pub fn update_policy(
        &mut self,
        key_id: &[u8; 32],
        policy: KeyPolicy,
    ) -> Result<(), PolicyViolation> {
        // Single lookup: overwrite the existing slot in place.
        let slot = self
            .policies
            .get_mut(key_id)
            .ok_or(PolicyViolation::PolicyNotFound)?;
        *slot = policy;
        Ok(())
    }

    /// Removes the policy and the usage counter of `key_id`, if present.
    /// Entries already in the violation log are kept.
    pub fn remove_policy(&mut self, key_id: &[u8; 32]) {
        self.policies.remove(key_id);
        self.usage_counts.remove(key_id);
    }

    /// Authorizes `operation` on `key_id` and, on success, counts one use.
    ///
    /// Checks run in this order and stop at the first failure: operation
    /// allowed, validity window, usage limit, required context. A failed
    /// check is appended to the violation log and does not consume a use.
    ///
    /// # Errors
    /// Returns [`PolicyViolation::PolicyNotFound`] when no policy is
    /// registered (this case is *not* written to the violation log), or the
    /// violation produced by the first failing check.
    pub fn check_policy(
        &mut self,
        key_id: &[u8; 32],
        operation: Operation,
        context: Option<&HashMap<String, String>>,
    ) -> Result<(), PolicyViolation> {
        let policy = self
            .policies
            .get(key_id)
            .ok_or(PolicyViolation::PolicyNotFound)?;
        let current_uses = self.usage_counts.get(key_id).copied().unwrap_or(0);

        let verdict = if policy.allows_operation(operation) {
            policy
                .check_validity()
                .and_then(|()| policy.check_usage_limit(current_uses))
                .and_then(|()| policy.check_context(context))
        } else {
            Err(PolicyViolation::OperationDenied(operation))
        };

        match verdict {
            Ok(()) => {
                let count = self.usage_counts.entry(*key_id).or_insert(0);
                // Saturate rather than overflow: a wrapped counter would
                // restart at zero and silently lift the usage limit.
                *count = count.saturating_add(1);
                Ok(())
            }
            Err(violation) => {
                self.log_violation(*key_id, violation.clone());
                Err(violation)
            }
        }
    }

    /// Returns the number of successful uses recorded for `key_id`
    /// (zero when the key is unknown).
    pub fn get_usage_count(&self, key_id: &[u8; 32]) -> u64 {
        self.usage_counts.get(key_id).copied().unwrap_or(0)
    }

    /// Resets the usage counter of `key_id` to zero; unknown keys are ignored.
    pub fn reset_usage_count(&mut self, key_id: &[u8; 32]) {
        if let Some(count) = self.usage_counts.get_mut(key_id) {
            *count = 0;
        }
    }

    /// Returns the policy registered for `key_id`, if any.
    pub fn get_policy(&self, key_id: &[u8; 32]) -> Option<&KeyPolicy> {
        self.policies.get(key_id)
    }

    /// Appends a violation to the log, stamped with the current system time.
    fn log_violation(&mut self, key_id: [u8; 32], violation: PolicyViolation) {
        self.violations.push((SystemTime::now(), key_id, violation));
    }

    /// Returns every logged violation, oldest first.
    pub fn get_violations(&self) -> &[(SystemTime, [u8; 32], PolicyViolation)] {
        &self.violations
    }

    /// Returns the logged violations that belong to `key_id`, oldest first.
    pub fn get_key_violations(
        &self,
        key_id: &[u8; 32],
    ) -> Vec<&(SystemTime, [u8; 32], PolicyViolation)> {
        self.violations
            .iter()
            .filter(|(_, logged_key, _)| logged_key == key_id)
            .collect()
    }

    /// Empties the violation log.
    pub fn clear_violations(&mut self) {
        self.violations.clear();
    }
}
/// Interface for components that decide whether a key may be used for an operation.
pub trait KeyUsagePolicy {
/// Checks whether `operation` may be performed with the key `key_id`,
/// given the optional request `context`.
///
/// Takes `&mut self`, so implementations may update internal state.
/// Returns `Ok(())` when the use is permitted, or the `PolicyViolation`
/// describing why it was rejected.
fn check_key_usage(
&mut self,
key_id: &[u8; 32],
operation: Operation,
context: Option<&HashMap<String, String>>,
) -> Result<(), PolicyViolation>;
}
impl KeyUsagePolicy for PolicyEngine {
fn check_key_usage(
&mut self,
key_id: &[u8; 32],
operation: Operation,
context: Option<&HashMap<String, String>>,
) -> Result<(), PolicyViolation> {
self.check_policy(key_id, operation, context)
}
}
// Unit tests for `KeyPolicy` (builder and individual checks) and for
// `PolicyEngine` (registration, counting, violation log).
#[cfg(test)]
mod tests {
use super::*;
// A policy from `new()` has no allow-list, so any operation passes.
#[test]
fn test_policy_allows_all_by_default() {
let policy = KeyPolicy::new();
assert!(policy.allows_operation(Operation::Sign));
assert!(policy.allows_operation(Operation::Encrypt));
assert!(policy.allows_operation(Operation::Decrypt));
}
// `restrictive()` starts with an empty allow-list, so nothing passes.
#[test]
fn test_policy_restrictive() {
let policy = KeyPolicy::restrictive();
assert!(!policy.allows_operation(Operation::Sign));
assert!(!policy.allows_operation(Operation::Encrypt));
}
// Only the explicitly allowed operation passes on a restrictive policy.
#[test]
fn test_policy_allow_operation() {
let policy = KeyPolicy::restrictive().allow_operation(Operation::Sign);
assert!(policy.allows_operation(Operation::Sign));
assert!(!policy.allows_operation(Operation::Encrypt));
}
// Denying one operation on a permissive policy leaves the others allowed.
#[test]
fn test_policy_deny_operation() {
let policy = KeyPolicy::new().deny_operation(Operation::Decrypt);
assert!(policy.allows_operation(Operation::Sign));
assert!(!policy.allows_operation(Operation::Decrypt));
}
// An operation that is both allowed and denied is rejected.
#[test]
fn test_policy_deny_takes_precedence() {
let policy = KeyPolicy::new()
.allow_operation(Operation::Sign)
.deny_operation(Operation::Sign);
assert!(!policy.allows_operation(Operation::Sign));
}
// The limit is exclusive: a count equal to `max_uses` is already rejected.
#[test]
fn test_usage_limit() {
let policy = KeyPolicy::new().max_uses(5);
assert!(policy.check_usage_limit(0).is_ok());
assert!(policy.check_usage_limit(4).is_ok());
assert!(policy.check_usage_limit(5).is_err());
assert!(policy.check_usage_limit(10).is_err());
}
// Covers not-yet-valid, expired, and inside-the-window cases.
#[test]
fn test_validity_period() {
let now = SystemTime::now();
let past = now - Duration::from_secs(3600);
let future = now + Duration::from_secs(3600);
let policy = KeyPolicy::new().valid_from(future);
assert!(policy.check_validity().is_err());
let policy = KeyPolicy::new().valid_until(past);
assert!(policy.check_validity().is_err());
let policy = KeyPolicy::new().valid_from(past).valid_until(future);
assert!(policy.check_validity().is_ok());
}
// A window opened with `valid_for` is valid immediately.
#[test]
fn test_valid_for() {
let policy = KeyPolicy::new().valid_for(Duration::from_secs(3600));
assert!(policy.check_validity().is_ok());
}
// Context fails when absent or missing the required key, and passes once
// the key is present (unrelated keys do not matter).
#[test]
fn test_required_context() {
let policy = KeyPolicy::new().require_context("user_id".to_string());
assert!(policy.check_context(None).is_err());
let mut context = HashMap::new();
context.insert("session_id".to_string(), "123".to_string());
assert!(policy.check_context(Some(&context)).is_err());
context.insert("user_id".to_string(), "alice".to_string());
assert!(policy.check_context(Some(&context)).is_ok());
}
// Registering makes the policy retrievable and starts the counter at zero.
#[test]
fn test_policy_engine_register() {
let mut engine = PolicyEngine::new();
let key_id = [1u8; 32];
let policy = KeyPolicy::new();
engine.register_policy(key_id, policy);
assert!(engine.get_policy(&key_id).is_some());
assert_eq!(engine.get_usage_count(&key_id), 0);
}
// A successful check increments the usage count; a rejected one (Decrypt
// is not in the allow-list) leaves it unchanged.
#[test]
fn test_policy_engine_check() {
let mut engine = PolicyEngine::new();
let key_id = [1u8; 32];
let policy = KeyPolicy::new().allow_operation(Operation::Sign);
engine.register_policy(key_id, policy);
assert!(engine.check_policy(&key_id, Operation::Sign, None).is_ok());
assert_eq!(engine.get_usage_count(&key_id), 1);
assert!(
engine
.check_policy(&key_id, Operation::Decrypt, None)
.is_err()
);
assert_eq!(engine.get_usage_count(&key_id), 1); }
// With `max_uses(3)` exactly three checks succeed; the fourth is rejected.
#[test]
fn test_policy_engine_usage_limit() {
let mut engine = PolicyEngine::new();
let key_id = [1u8; 32];
let policy = KeyPolicy::new().max_uses(3);
engine.register_policy(key_id, policy);
assert!(engine.check_policy(&key_id, Operation::Sign, None).is_ok());
assert!(engine.check_policy(&key_id, Operation::Sign, None).is_ok());
assert!(engine.check_policy(&key_id, Operation::Sign, None).is_ok());
assert!(engine.check_policy(&key_id, Operation::Sign, None).is_err());
}
// Resetting the counter makes an exhausted key usable again.
#[test]
fn test_policy_engine_reset_usage() {
let mut engine = PolicyEngine::new();
let key_id = [1u8; 32];
let policy = KeyPolicy::new().max_uses(2);
engine.register_policy(key_id, policy);
assert!(engine.check_policy(&key_id, Operation::Sign, None).is_ok());
assert!(engine.check_policy(&key_id, Operation::Sign, None).is_ok());
assert_eq!(engine.get_usage_count(&key_id), 2);
engine.reset_usage_count(&key_id);
assert_eq!(engine.get_usage_count(&key_id), 0);
assert!(engine.check_policy(&key_id, Operation::Sign, None).is_ok());
}
// A rejected check is logged under its key id; the log can be cleared.
#[test]
fn test_policy_engine_violations() {
let mut engine = PolicyEngine::new();
let key_id = [1u8; 32];
let policy = KeyPolicy::new().deny_operation(Operation::Decrypt);
engine.register_policy(key_id, policy);
let _ = engine.check_policy(&key_id, Operation::Decrypt, None);
let violations = engine.get_violations();
assert_eq!(violations.len(), 1);
assert_eq!(violations[0].1, key_id);
let key_violations = engine.get_key_violations(&key_id);
assert_eq!(key_violations.len(), 1);
engine.clear_violations();
assert_eq!(engine.get_violations().len(), 0);
}
// After swapping in a restrictive policy, a previously allowed operation fails.
#[test]
fn test_policy_engine_update_policy() {
let mut engine = PolicyEngine::new();
let key_id = [1u8; 32];
let policy1 = KeyPolicy::new().allow_operation(Operation::Sign);
engine.register_policy(key_id, policy1);
let policy2 = KeyPolicy::restrictive();
assert!(engine.update_policy(&key_id, policy2).is_ok());
assert!(engine.check_policy(&key_id, Operation::Sign, None).is_err());
}
// Removing a policy makes it unavailable through `get_policy`.
#[test]
fn test_policy_engine_remove_policy() {
let mut engine = PolicyEngine::new();
let key_id = [1u8; 32];
let policy = KeyPolicy::new();
engine.register_policy(key_id, policy);
assert!(engine.get_policy(&key_id).is_some());
engine.remove_policy(&key_id);
assert!(engine.get_policy(&key_id).is_none());
}
// Metadata entries accumulate across chained `with_metadata` calls.
#[test]
fn test_policy_metadata() {
let policy = KeyPolicy::new()
.with_metadata("purpose".to_string(), "signing".to_string())
.with_metadata("owner".to_string(), "alice".to_string());
assert_eq!(policy.metadata.get("purpose").unwrap(), "signing");
assert_eq!(policy.metadata.get("owner").unwrap(), "alice");
}
// Round-trips a policy through `crate::codec` (defined outside this file)
// and checks that the rules are preserved.
#[test]
fn test_policy_serialization() {
let policy = KeyPolicy::new()
.allow_operation(Operation::Sign)
.deny_operation(Operation::Decrypt)
.max_uses(100)
.require_context("user_id".to_string());
let serialized = crate::codec::encode(&policy).unwrap();
let deserialized: KeyPolicy = crate::codec::decode(&serialized).unwrap();
assert!(deserialized.allows_operation(Operation::Sign));
assert!(!deserialized.allows_operation(Operation::Decrypt));
assert_eq!(deserialized.max_uses, Some(100));
assert!(deserialized.required_context.contains("user_id"));
}
}