1use serde::{Deserialize, Serialize};
40use std::collections::{HashMap, HashSet};
41use std::time::{Duration, SystemTime};
42
43#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
45pub enum Operation {
46 Sign,
48 Verify,
50 Encrypt,
52 Decrypt,
54 KeyExchange,
56 DeriveKey,
58 WrapKey,
60 UnwrapKey,
62}
63
64#[derive(Debug, Clone, PartialEq, Eq)]
66pub enum PolicyViolation {
67 OperationDenied(Operation),
69 UsageLimitExceeded { limit: u64, current: u64 },
71 KeyExpired { expired_at: SystemTime },
73 KeyNotYetValid { valid_from: SystemTime },
75 MissingContext(String),
77 InvalidContext(String),
79 PolicyNotFound,
81}
82
83impl std::fmt::Display for PolicyViolation {
84 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85 match self {
86 PolicyViolation::OperationDenied(op) => {
87 write!(f, "Operation {:?} denied by policy", op)
88 }
89 PolicyViolation::UsageLimitExceeded { limit, current } => {
90 write!(f, "Usage limit exceeded: {}/{}", current, limit)
91 }
92 PolicyViolation::KeyExpired { expired_at } => {
93 write!(f, "Key expired at {:?}", expired_at)
94 }
95 PolicyViolation::KeyNotYetValid { valid_from } => {
96 write!(f, "Key not yet valid (valid from {:?})", valid_from)
97 }
98 PolicyViolation::MissingContext(ctx) => write!(f, "Missing required context: {}", ctx),
99 PolicyViolation::InvalidContext(msg) => write!(f, "Invalid context: {}", msg),
100 PolicyViolation::PolicyNotFound => write!(f, "Policy not found for key"),
101 }
102 }
103}
104
105impl std::error::Error for PolicyViolation {}
106
107#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct KeyPolicy {
110 allowed_operations: Option<HashSet<Operation>>,
112 denied_operations: HashSet<Operation>,
114 max_uses: Option<u64>,
116 valid_from: Option<SystemTime>,
118 valid_until: Option<SystemTime>,
120 required_context: HashSet<String>,
122 metadata: HashMap<String, String>,
124}
125
126impl Default for KeyPolicy {
127 fn default() -> Self {
128 Self::new()
129 }
130}
131
132impl KeyPolicy {
133 pub fn new() -> Self {
135 Self {
136 allowed_operations: None,
137 denied_operations: HashSet::new(),
138 max_uses: None,
139 valid_from: None,
140 valid_until: None,
141 required_context: HashSet::new(),
142 metadata: HashMap::new(),
143 }
144 }
145
146 pub fn restrictive() -> Self {
148 Self {
149 allowed_operations: Some(HashSet::new()),
150 denied_operations: HashSet::new(),
151 max_uses: None,
152 valid_from: None,
153 valid_until: None,
154 required_context: HashSet::new(),
155 metadata: HashMap::new(),
156 }
157 }
158
159 pub fn allow_operation(mut self, op: Operation) -> Self {
161 if self.allowed_operations.is_none() {
162 self.allowed_operations = Some(HashSet::new());
163 }
164 self.allowed_operations.as_mut().unwrap().insert(op);
165 self
166 }
167
168 pub fn deny_operation(mut self, op: Operation) -> Self {
170 self.denied_operations.insert(op);
171 self
172 }
173
174 pub fn max_uses(mut self, limit: u64) -> Self {
176 self.max_uses = Some(limit);
177 self
178 }
179
180 pub fn valid_for(mut self, duration: Duration) -> Self {
182 let now = SystemTime::now();
183 self.valid_from = Some(now);
184 self.valid_until = Some(now + duration);
185 self
186 }
187
188 pub fn valid_from(mut self, time: SystemTime) -> Self {
190 self.valid_from = Some(time);
191 self
192 }
193
194 pub fn valid_until(mut self, time: SystemTime) -> Self {
196 self.valid_until = Some(time);
197 self
198 }
199
200 pub fn require_context(mut self, key: String) -> Self {
202 self.required_context.insert(key);
203 self
204 }
205
206 pub fn with_metadata(mut self, key: String, value: String) -> Self {
208 self.metadata.insert(key, value);
209 self
210 }
211
212 pub fn allows_operation(&self, op: Operation) -> bool {
214 if self.denied_operations.contains(&op) {
216 return false;
217 }
218
219 match &self.allowed_operations {
221 None => true, Some(allowed) => allowed.contains(&op),
223 }
224 }
225
226 pub fn check_usage_limit(&self, current_uses: u64) -> Result<(), PolicyViolation> {
228 if let Some(limit) = self.max_uses {
229 if current_uses >= limit {
230 return Err(PolicyViolation::UsageLimitExceeded {
231 limit,
232 current: current_uses,
233 });
234 }
235 }
236 Ok(())
237 }
238
239 pub fn check_validity(&self) -> Result<(), PolicyViolation> {
241 let now = SystemTime::now();
242
243 if let Some(valid_from) = self.valid_from {
244 if now < valid_from {
245 return Err(PolicyViolation::KeyNotYetValid { valid_from });
246 }
247 }
248
249 if let Some(valid_until) = self.valid_until {
250 if now > valid_until {
251 return Err(PolicyViolation::KeyExpired {
252 expired_at: valid_until,
253 });
254 }
255 }
256
257 Ok(())
258 }
259
260 pub fn check_context(
262 &self,
263 context: Option<&HashMap<String, String>>,
264 ) -> Result<(), PolicyViolation> {
265 if self.required_context.is_empty() {
266 return Ok(());
267 }
268
269 let context = context
270 .ok_or_else(|| PolicyViolation::MissingContext("context required".to_string()))?;
271
272 for required_key in &self.required_context {
273 if !context.contains_key(required_key) {
274 return Err(PolicyViolation::MissingContext(required_key.clone()));
275 }
276 }
277
278 Ok(())
279 }
280}
281
282pub struct PolicyEngine {
284 policies: HashMap<[u8; 32], KeyPolicy>,
286 usage_counts: HashMap<[u8; 32], u64>,
288 violations: Vec<(SystemTime, [u8; 32], PolicyViolation)>,
290}
291
292impl Default for PolicyEngine {
293 fn default() -> Self {
294 Self::new()
295 }
296}
297
298impl PolicyEngine {
299 pub fn new() -> Self {
301 Self {
302 policies: HashMap::new(),
303 usage_counts: HashMap::new(),
304 violations: Vec::new(),
305 }
306 }
307
308 pub fn register_policy(&mut self, key_id: [u8; 32], policy: KeyPolicy) {
310 self.policies.insert(key_id, policy);
311 self.usage_counts.insert(key_id, 0);
312 }
313
314 pub fn update_policy(
316 &mut self,
317 key_id: &[u8; 32],
318 policy: KeyPolicy,
319 ) -> Result<(), PolicyViolation> {
320 if !self.policies.contains_key(key_id) {
321 return Err(PolicyViolation::PolicyNotFound);
322 }
323 self.policies.insert(*key_id, policy);
324 Ok(())
325 }
326
327 pub fn remove_policy(&mut self, key_id: &[u8; 32]) {
329 self.policies.remove(key_id);
330 self.usage_counts.remove(key_id);
331 }
332
333 pub fn check_policy(
335 &mut self,
336 key_id: &[u8; 32],
337 operation: Operation,
338 context: Option<&HashMap<String, String>>,
339 ) -> Result<(), PolicyViolation> {
340 let policy = self
341 .policies
342 .get(key_id)
343 .ok_or(PolicyViolation::PolicyNotFound)?;
344
345 if !policy.allows_operation(operation) {
347 let violation = PolicyViolation::OperationDenied(operation);
348 self.log_violation(*key_id, violation.clone());
349 return Err(violation);
350 }
351
352 if let Err(violation) = policy.check_validity() {
354 self.log_violation(*key_id, violation.clone());
355 return Err(violation);
356 }
357
358 let current_uses = *self.usage_counts.get(key_id).unwrap_or(&0);
360 if let Err(violation) = policy.check_usage_limit(current_uses) {
361 self.log_violation(*key_id, violation.clone());
362 return Err(violation);
363 }
364
365 if let Err(violation) = policy.check_context(context) {
367 self.log_violation(*key_id, violation.clone());
368 return Err(violation);
369 }
370
371 *self.usage_counts.entry(*key_id).or_insert(0) += 1;
373
374 Ok(())
375 }
376
377 pub fn get_usage_count(&self, key_id: &[u8; 32]) -> u64 {
379 *self.usage_counts.get(key_id).unwrap_or(&0)
380 }
381
382 pub fn reset_usage_count(&mut self, key_id: &[u8; 32]) {
384 if let Some(count) = self.usage_counts.get_mut(key_id) {
385 *count = 0;
386 }
387 }
388
389 pub fn get_policy(&self, key_id: &[u8; 32]) -> Option<&KeyPolicy> {
391 self.policies.get(key_id)
392 }
393
394 fn log_violation(&mut self, key_id: [u8; 32], violation: PolicyViolation) {
396 self.violations.push((SystemTime::now(), key_id, violation));
397 }
398
399 pub fn get_violations(&self) -> &[(SystemTime, [u8; 32], PolicyViolation)] {
401 &self.violations
402 }
403
404 pub fn get_key_violations(
406 &self,
407 key_id: &[u8; 32],
408 ) -> Vec<&(SystemTime, [u8; 32], PolicyViolation)> {
409 self.violations
410 .iter()
411 .filter(|(_, kid, _)| kid == key_id)
412 .collect()
413 }
414
415 pub fn clear_violations(&mut self) {
417 self.violations.clear();
418 }
419}
420
421pub trait KeyUsagePolicy {
423 fn check_key_usage(
425 &mut self,
426 key_id: &[u8; 32],
427 operation: Operation,
428 context: Option<&HashMap<String, String>>,
429 ) -> Result<(), PolicyViolation>;
430}
431
432impl KeyUsagePolicy for PolicyEngine {
433 fn check_key_usage(
434 &mut self,
435 key_id: &[u8; 32],
436 operation: Operation,
437 context: Option<&HashMap<String, String>>,
438 ) -> Result<(), PolicyViolation> {
439 self.check_policy(key_id, operation, context)
440 }
441}
442
#[cfg(test)]
mod tests {
    use super::*;

    /// Key identifier shared by the engine tests.
    const KEY_ID: [u8; 32] = [1u8; 32];

    /// Builds an engine with `policy` registered under `KEY_ID`.
    fn engine_with(policy: KeyPolicy) -> PolicyEngine {
        let mut engine = PolicyEngine::new();
        engine.register_policy(KEY_ID, policy);
        engine
    }

    #[test]
    fn test_policy_allows_all_by_default() {
        let permissive = KeyPolicy::new();
        for op in [Operation::Sign, Operation::Encrypt, Operation::Decrypt] {
            assert!(permissive.allows_operation(op));
        }
    }

    #[test]
    fn test_policy_restrictive() {
        let locked = KeyPolicy::restrictive();
        for op in [Operation::Sign, Operation::Encrypt] {
            assert!(!locked.allows_operation(op));
        }
    }

    #[test]
    fn test_policy_allow_operation() {
        let sign_only = KeyPolicy::restrictive().allow_operation(Operation::Sign);
        assert!(sign_only.allows_operation(Operation::Sign));
        assert!(!sign_only.allows_operation(Operation::Encrypt));
    }

    #[test]
    fn test_policy_deny_operation() {
        let no_decrypt = KeyPolicy::new().deny_operation(Operation::Decrypt);
        assert!(no_decrypt.allows_operation(Operation::Sign));
        assert!(!no_decrypt.allows_operation(Operation::Decrypt));
    }

    #[test]
    fn test_policy_deny_takes_precedence() {
        // The same operation is both allowed and denied; the denial must win.
        let conflicted = KeyPolicy::new()
            .allow_operation(Operation::Sign)
            .deny_operation(Operation::Sign);
        assert!(!conflicted.allows_operation(Operation::Sign));
    }

    #[test]
    fn test_usage_limit() {
        let capped = KeyPolicy::new().max_uses(5);
        for below in [0, 4] {
            assert!(capped.check_usage_limit(below).is_ok());
        }
        for at_or_above in [5, 10] {
            assert!(capped.check_usage_limit(at_or_above).is_err());
        }
    }

    #[test]
    fn test_validity_period() {
        let hour = Duration::from_secs(3600);
        let now = SystemTime::now();
        let earlier = now - hour;
        let later = now + hour;

        // Window has not opened yet.
        let not_yet = KeyPolicy::new().valid_from(later);
        assert!(not_yet.check_validity().is_err());

        // Window has already closed.
        let expired = KeyPolicy::new().valid_until(earlier);
        assert!(expired.check_validity().is_err());

        // Current time lies inside the window.
        let current = KeyPolicy::new().valid_from(earlier).valid_until(later);
        assert!(current.check_validity().is_ok());
    }

    #[test]
    fn test_valid_for() {
        let one_hour = KeyPolicy::new().valid_for(Duration::from_secs(3600));
        assert!(one_hour.check_validity().is_ok());
    }

    #[test]
    fn test_required_context() {
        let needs_user = KeyPolicy::new().require_context("user_id".to_string());

        // No context at all.
        assert!(needs_user.check_context(None).is_err());

        // Context present, but without the required key.
        let mut supplied: HashMap<String, String> = HashMap::new();
        supplied.insert("session_id".to_string(), "123".to_string());
        assert!(needs_user.check_context(Some(&supplied)).is_err());

        // Required key added.
        supplied.insert("user_id".to_string(), "alice".to_string());
        assert!(needs_user.check_context(Some(&supplied)).is_ok());
    }

    #[test]
    fn test_policy_engine_register() {
        let engine = engine_with(KeyPolicy::new());
        assert!(engine.get_policy(&KEY_ID).is_some());
        assert_eq!(engine.get_usage_count(&KEY_ID), 0);
    }

    #[test]
    fn test_policy_engine_check() {
        let mut engine = engine_with(KeyPolicy::new().allow_operation(Operation::Sign));

        let allowed = engine.check_policy(&KEY_ID, Operation::Sign, None);
        assert!(allowed.is_ok());
        assert_eq!(engine.get_usage_count(&KEY_ID), 1);

        // A rejected request must not consume a use.
        let rejected = engine.check_policy(&KEY_ID, Operation::Decrypt, None);
        assert!(rejected.is_err());
        assert_eq!(engine.get_usage_count(&KEY_ID), 1);
    }

    #[test]
    fn test_policy_engine_usage_limit() {
        let mut engine = engine_with(KeyPolicy::new().max_uses(3));

        for _ in 0..3 {
            assert!(engine.check_policy(&KEY_ID, Operation::Sign, None).is_ok());
        }

        // The fourth request exceeds the limit.
        assert!(engine.check_policy(&KEY_ID, Operation::Sign, None).is_err());
    }

    #[test]
    fn test_policy_engine_reset_usage() {
        let mut engine = engine_with(KeyPolicy::new().max_uses(2));

        for _ in 0..2 {
            assert!(engine.check_policy(&KEY_ID, Operation::Sign, None).is_ok());
        }
        assert_eq!(engine.get_usage_count(&KEY_ID), 2);

        engine.reset_usage_count(&KEY_ID);
        assert_eq!(engine.get_usage_count(&KEY_ID), 0);

        // The key is usable again after the reset.
        assert!(engine.check_policy(&KEY_ID, Operation::Sign, None).is_ok());
    }

    #[test]
    fn test_policy_engine_violations() {
        let mut engine = engine_with(KeyPolicy::new().deny_operation(Operation::Decrypt));

        // Outcome is irrelevant here; only the logged violation matters.
        let _ = engine.check_policy(&KEY_ID, Operation::Decrypt, None);

        let logged = engine.get_violations();
        assert_eq!(logged.len(), 1);
        assert_eq!(logged[0].1, KEY_ID);

        let for_key = engine.get_key_violations(&KEY_ID);
        assert_eq!(for_key.len(), 1);

        engine.clear_violations();
        assert_eq!(engine.get_violations().len(), 0);
    }

    #[test]
    fn test_policy_engine_update_policy() {
        let mut engine = engine_with(KeyPolicy::new().allow_operation(Operation::Sign));

        let replacement = KeyPolicy::restrictive();
        assert!(engine.update_policy(&KEY_ID, replacement).is_ok());

        // The replacement policy allows nothing.
        assert!(engine.check_policy(&KEY_ID, Operation::Sign, None).is_err());
    }

    #[test]
    fn test_policy_engine_remove_policy() {
        let mut engine = engine_with(KeyPolicy::new());
        assert!(engine.get_policy(&KEY_ID).is_some());

        engine.remove_policy(&KEY_ID);
        assert!(engine.get_policy(&KEY_ID).is_none());
    }

    #[test]
    fn test_policy_metadata() {
        let annotated = KeyPolicy::new()
            .with_metadata("purpose".to_string(), "signing".to_string())
            .with_metadata("owner".to_string(), "alice".to_string());

        for (key, expected) in [("purpose", "signing"), ("owner", "alice")] {
            assert_eq!(annotated.metadata.get(key).unwrap(), expected);
        }
    }

    #[test]
    fn test_policy_serialization() {
        let original = KeyPolicy::new()
            .allow_operation(Operation::Sign)
            .deny_operation(Operation::Decrypt)
            .max_uses(100)
            .require_context("user_id".to_string());

        let bytes = crate::codec::encode(&original).unwrap();
        let restored: KeyPolicy = crate::codec::decode(&bytes).unwrap();

        assert!(restored.allows_operation(Operation::Sign));
        assert!(!restored.allows_operation(Operation::Decrypt));
        assert_eq!(restored.max_uses, Some(100));
        assert!(restored.required_context.contains("user_id"));
    }
}