1use dashmap::DashMap;
36use parking_lot::RwLock;
37use std::collections::HashMap;
38use std::sync::Arc;
39use std::time::{Duration, Instant};
40use thiserror::Error;
41
/// Errors reported by the policy engine.
///
/// Every variant carries a free-form `String` that is interpolated into the
/// `Display` message generated by `thiserror`.
#[derive(Debug, Error)]
pub enum PolicyError {
    /// No policy with the requested name exists; the payload is that name.
    #[error("Policy not found: {0}")]
    NotFound(String),

    /// A policy was rejected; the payload explains why (for example, the
    /// per-type policy limit has been reached).
    #[error("Invalid policy: {0}")]
    Invalid(String),

    /// A policy conflicts with existing state; the payload describes the
    /// conflict.
    // NOTE(review): not constructed by the engine code in this file section.
    #[error("Policy conflict: {0}")]
    Conflict(String),

    /// Evaluating a policy failed; the payload describes the failure.
    // NOTE(review): not constructed by the engine code in this file section.
    #[error("Evaluation error: {0}")]
    Evaluation(String),
}
61
/// The action a policy (or the engine's default) prescribes for a request.
///
/// In the connection-evaluation code in this file only [`PolicyAction::Allow`]
/// produces an "allowed" decision; every other variant is reported as not
/// allowed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PolicyAction {
    /// Admit the request.
    Allow,
    /// Reject the request.
    Deny,
    /// Subject the request to rate limiting; matches are counted separately in
    /// the engine statistics.
    RateLimit,
    /// Log the request.
    // NOTE(review): no logging is performed by the code visible here — confirm
    // where this is enforced.
    Log,
    /// Require verification of the request.
    // NOTE(review): no verification step is visible here — confirm where this
    // is enforced.
    Verify,
}
76
/// Outcome of evaluating a request against a policy.
#[derive(Debug, Clone)]
pub struct PolicyResult {
    /// Whether the request is admitted.
    pub allowed: bool,

    /// The action that produced this outcome.
    pub action: PolicyAction,

    /// Name of the policy that decided, or `None` when no specific policy is
    /// attributed.
    pub policy_name: Option<String>,

    /// Human-readable explanation of the decision.
    pub reason: String,

    /// Confidence attached to the decision. The constructors in this file
    /// always set `1.0`.
    // NOTE(review): presumably a value in 0.0..=1.0 — confirm the intended range.
    pub confidence: f64,
}
95
96impl PolicyResult {
97 pub fn allow(reason: String) -> Self {
99 Self {
100 allowed: true,
101 action: PolicyAction::Allow,
102 policy_name: None,
103 reason,
104 confidence: 1.0,
105 }
106 }
107
108 pub fn deny(reason: String, policy_name: Option<String>) -> Self {
110 Self {
111 allowed: false,
112 action: PolicyAction::Deny,
113 policy_name,
114 reason,
115 confidence: 1.0,
116 }
117 }
118}
119
/// A named rule deciding whether a peer may connect.
#[derive(Debug, Clone)]
pub struct ConnectionPolicy {
    /// Policy name; reported in results and used to look the policy up for
    /// removal.
    pub name: String,

    /// Action applied to peers that pass the blacklist/whitelist checks.
    pub action: PolicyAction,

    /// Ordering key: the engine keeps policies sorted with higher values first.
    pub priority: u32,

    /// Peer ids this policy accepts. When non-empty, any peer not listed here
    /// is denied by [`ConnectionPolicy::evaluate`]. Empty means "no whitelist".
    pub peer_whitelist: Vec<String>,

    /// Peer ids that [`ConnectionPolicy::evaluate`] always denies. Checked
    /// before the whitelist.
    pub peer_blacklist: Vec<String>,

    /// Optional cap on simultaneous connections per peer; consulted by the
    /// engine's `can_connect`, not by [`ConnectionPolicy::evaluate`].
    pub max_connections_per_peer: Option<usize>,

    /// Optional rate limit.
    // NOTE(review): units are not established by this file, and no visible
    // code reads this field — confirm against the enforcing code.
    pub rate_limit: Option<f64>,

    /// When `false`, [`ConnectionPolicy::evaluate`] returns `None` so the
    /// policy takes no part in the decision.
    pub enabled: bool,
}
147
148impl ConnectionPolicy {
149 pub fn new(name: impl Into<String>) -> Self {
151 Self {
152 name: name.into(),
153 action: PolicyAction::Allow,
154 priority: 50,
155 peer_whitelist: Vec::new(),
156 peer_blacklist: Vec::new(),
157 max_connections_per_peer: None,
158 rate_limit: None,
159 enabled: true,
160 }
161 }
162
163 pub fn with_action(mut self, action: PolicyAction) -> Self {
165 self.action = action;
166 self
167 }
168
169 pub fn with_priority(mut self, priority: u32) -> Self {
171 self.priority = priority;
172 self
173 }
174
175 pub fn with_whitelist_peer(mut self, peer: impl Into<String>) -> Self {
177 self.peer_whitelist.push(peer.into());
178 self
179 }
180
181 pub fn with_blacklist_peer(mut self, peer: impl Into<String>) -> Self {
183 self.peer_blacklist.push(peer.into());
184 self
185 }
186
187 pub fn with_max_connections(mut self, max: usize) -> Self {
189 self.max_connections_per_peer = Some(max);
190 self
191 }
192
193 pub fn with_rate_limit(mut self, rate: f64) -> Self {
195 self.rate_limit = Some(rate);
196 self
197 }
198
199 pub fn evaluate(&self, peer_id: &str) -> Option<PolicyResult> {
201 if !self.enabled {
202 return None;
203 }
204
205 if self.peer_blacklist.iter().any(|p| p == peer_id) {
207 return Some(PolicyResult::deny(
208 format!("Peer {} is blacklisted", peer_id),
209 Some(self.name.clone()),
210 ));
211 }
212
213 if !self.peer_whitelist.is_empty() && !self.peer_whitelist.iter().any(|p| p == peer_id) {
215 return Some(PolicyResult::deny(
216 format!("Peer {} not in whitelist", peer_id),
217 Some(self.name.clone()),
218 ));
219 }
220
221 Some(PolicyResult {
223 allowed: self.action == PolicyAction::Allow,
224 action: self.action,
225 policy_name: Some(self.name.clone()),
226 reason: format!("Policy {} matched", self.name),
227 confidence: 1.0,
228 })
229 }
230}
231
/// A named set of bandwidth limits.
// NOTE(review): the `_bps` suffix presumably means bits or bytes per second —
// the unit is not established by this file; confirm against the enforcing code.
#[derive(Debug, Clone)]
pub struct BandwidthPolicy {
    /// Policy name.
    pub name: String,

    /// Optional cap on total upload rate; `None` means no cap is configured.
    pub max_upload_bps: Option<u64>,

    /// Optional cap on total download rate; `None` means no cap is configured.
    pub max_download_bps: Option<u64>,

    /// Optional cap applied per peer; `None` means no cap is configured.
    pub per_peer_limit_bps: Option<u64>,

    /// Ordering key: the engine keeps policies sorted with higher values first.
    pub priority: u32,

    /// Whether the policy is active.
    pub enabled: bool,
}
253
254impl BandwidthPolicy {
255 pub fn new(name: impl Into<String>) -> Self {
257 Self {
258 name: name.into(),
259 max_upload_bps: None,
260 max_download_bps: None,
261 per_peer_limit_bps: None,
262 priority: 50,
263 enabled: true,
264 }
265 }
266
267 pub fn with_max_upload(mut self, bps: u64) -> Self {
269 self.max_upload_bps = Some(bps);
270 self
271 }
272
273 pub fn with_max_download(mut self, bps: u64) -> Self {
275 self.max_download_bps = Some(bps);
276 self
277 }
278
279 pub fn with_per_peer_limit(mut self, bps: u64) -> Self {
281 self.per_peer_limit_bps = Some(bps);
282 self
283 }
284}
285
/// A named rule describing which content is acceptable.
#[derive(Debug, Clone)]
pub struct ContentPolicy {
    /// Policy name.
    pub name: String,

    /// Action associated with this policy.
    pub action: PolicyAction,

    /// Patterns describing permitted content.
    // NOTE(review): the pattern syntax is not established here (no matching
    // code is visible) — presumably regular expressions; confirm.
    pub allowed_patterns: Vec<String>,

    /// Patterns describing blocked content (same syntax caveat as
    /// `allowed_patterns`).
    pub blocked_patterns: Vec<String>,

    /// Optional maximum content size; `None` means no cap is configured.
    // NOTE(review): presumably bytes — confirm against the enforcing code.
    pub max_size: Option<u64>,

    /// Ordering key: the engine keeps policies sorted with higher values first.
    pub priority: u32,

    /// Whether the policy is active.
    pub enabled: bool,
}
310
311impl ContentPolicy {
312 pub fn new(name: impl Into<String>) -> Self {
314 Self {
315 name: name.into(),
316 action: PolicyAction::Allow,
317 allowed_patterns: Vec::new(),
318 blocked_patterns: Vec::new(),
319 max_size: None,
320 priority: 50,
321 enabled: true,
322 }
323 }
324
325 pub fn with_allowed_pattern(mut self, pattern: impl Into<String>) -> Self {
327 self.allowed_patterns.push(pattern.into());
328 self
329 }
330
331 pub fn with_blocked_pattern(mut self, pattern: impl Into<String>) -> Self {
333 self.blocked_patterns.push(pattern.into());
334 self
335 }
336
337 pub fn with_max_size(mut self, size: u64) -> Self {
339 self.max_size = Some(size);
340 self
341 }
342}
343
/// Engine-wide configuration.
#[derive(Debug, Clone)]
pub struct PolicyConfig {
    /// Master switch. When `false`, connection evaluation admits every peer
    /// without consulting any policy.
    pub enabled: bool,

    /// Action used when no policy produces a decision; the request is admitted
    /// only if this is [`PolicyAction::Allow`].
    pub default_action: PolicyAction,

    /// Whether decisions should be logged.
    // NOTE(review): not read by the engine code in this file section — confirm
    // where it is honoured.
    pub log_decisions: bool,

    /// Maximum number of policies of each kind (connection, bandwidth,
    /// content) the engine accepts.
    pub max_policies_per_type: usize,

    /// Time budget for a single evaluation.
    // NOTE(review): not read by the engine code in this file section — confirm
    // where it is honoured.
    pub evaluation_timeout: Duration,
}
362
363impl Default for PolicyConfig {
364 fn default() -> Self {
365 Self {
366 enabled: true,
367 default_action: PolicyAction::Allow,
368 log_decisions: true,
369 max_policies_per_type: 100,
370 evaluation_timeout: Duration::from_millis(100),
371 }
372 }
373}
374
375impl PolicyConfig {
376 pub fn strict() -> Self {
378 Self {
379 enabled: true,
380 default_action: PolicyAction::Deny,
381 log_decisions: true,
382 max_policies_per_type: 200,
383 evaluation_timeout: Duration::from_millis(50),
384 }
385 }
386
387 pub fn permissive() -> Self {
389 Self {
390 enabled: true,
391 default_action: PolicyAction::Allow,
392 log_decisions: false,
393 max_policies_per_type: 50,
394 evaluation_timeout: Duration::from_millis(200),
395 }
396 }
397}
398
/// Counters describing the decisions the engine has made.
///
/// All fields start at zero / empty (`Default`).
#[derive(Debug, Clone, Default)]
pub struct PolicyStats {
    /// Total number of recorded evaluations.
    pub evaluations: u64,

    /// Number of evaluations that admitted the request.
    pub allowed: u64,

    /// Number of evaluations that did not admit the request.
    pub denied: u64,

    /// Number of evaluations decided with [`PolicyAction::RateLimit`]. These
    /// are counted in addition to `allowed`/`denied`, not instead of them.
    pub rate_limited: u64,

    /// Number of decisions attributed to each policy, keyed by policy name.
    /// Decisions made by the default action are not attributed to any policy.
    pub policy_hits: HashMap<String, u64>,

    /// Exponentially weighted moving average of evaluation time in
    /// milliseconds: each new sample is weighted 0.3 against 0.7 for the
    /// previous average, starting from 0.0.
    pub avg_eval_time_ms: f64,
}
420
/// Holds the registered policies, per-peer connection counts and decision
/// statistics, and evaluates connection requests against them.
///
/// All state sits behind `Arc`-wrapped locks or a concurrent map, so every
/// method takes `&self`.
pub struct PolicyEngine {
    /// Engine-wide settings; fixed at construction.
    config: PolicyConfig,
    /// Connection policies, kept sorted by descending priority.
    connection_policies: Arc<RwLock<Vec<ConnectionPolicy>>>,
    /// Bandwidth policies, kept sorted by descending priority.
    bandwidth_policies: Arc<RwLock<Vec<BandwidthPolicy>>>,
    /// Content policies, kept sorted by descending priority.
    content_policies: Arc<RwLock<Vec<ContentPolicy>>>,
    /// Running decision statistics.
    stats: Arc<RwLock<PolicyStats>>,
    /// Number of currently recorded connections per peer id.
    connection_counts: Arc<DashMap<String, usize>>,
}
430
431impl PolicyEngine {
432 pub fn new(config: PolicyConfig) -> Self {
434 Self {
435 config,
436 connection_policies: Arc::new(RwLock::new(Vec::new())),
437 bandwidth_policies: Arc::new(RwLock::new(Vec::new())),
438 content_policies: Arc::new(RwLock::new(Vec::new())),
439 stats: Arc::new(RwLock::new(PolicyStats::default())),
440 connection_counts: Arc::new(DashMap::new()),
441 }
442 }
443
444 pub fn add_connection_policy(&self, policy: ConnectionPolicy) -> Result<(), PolicyError> {
446 let mut policies = self.connection_policies.write();
447
448 if policies.len() >= self.config.max_policies_per_type {
449 return Err(PolicyError::Invalid(
450 "Maximum connection policies reached".to_string(),
451 ));
452 }
453
454 policies.push(policy);
455 policies.sort_by(|a, b| b.priority.cmp(&a.priority));
456
457 Ok(())
458 }
459
460 pub fn add_bandwidth_policy(&self, policy: BandwidthPolicy) -> Result<(), PolicyError> {
462 let mut policies = self.bandwidth_policies.write();
463
464 if policies.len() >= self.config.max_policies_per_type {
465 return Err(PolicyError::Invalid(
466 "Maximum bandwidth policies reached".to_string(),
467 ));
468 }
469
470 policies.push(policy);
471 policies.sort_by(|a, b| b.priority.cmp(&a.priority));
472
473 Ok(())
474 }
475
476 pub fn add_content_policy(&self, policy: ContentPolicy) -> Result<(), PolicyError> {
478 let mut policies = self.content_policies.write();
479
480 if policies.len() >= self.config.max_policies_per_type {
481 return Err(PolicyError::Invalid(
482 "Maximum content policies reached".to_string(),
483 ));
484 }
485
486 policies.push(policy);
487 policies.sort_by(|a, b| b.priority.cmp(&a.priority));
488
489 Ok(())
490 }
491
492 pub async fn evaluate_connection(&self, peer_id: &str) -> Result<bool, PolicyError> {
494 let start = Instant::now();
495
496 if !self.config.enabled {
497 return Ok(true);
498 }
499
500 let policies = self.connection_policies.read();
501
502 for policy in policies.iter() {
504 if let Some(result) = policy.evaluate(peer_id) {
505 self.record_evaluation(&result, start.elapsed());
506 return Ok(result.allowed);
507 }
508 }
509
510 let allowed = self.config.default_action == PolicyAction::Allow;
512 self.record_default_evaluation(allowed, start.elapsed());
513
514 Ok(allowed)
515 }
516
517 pub fn can_connect(&self, peer_id: &str) -> bool {
519 let count = self.connection_counts.get(peer_id).map(|c| *c).unwrap_or(0);
520
521 let policies = self.connection_policies.read();
522
523 for policy in policies.iter() {
524 if let Some(max) = policy.max_connections_per_peer {
525 if count >= max {
526 return false;
527 }
528 }
529 }
530
531 true
532 }
533
534 pub fn record_connection(&self, peer_id: &str) {
536 self.connection_counts
537 .entry(peer_id.to_string())
538 .and_modify(|c| *c += 1)
539 .or_insert(1);
540 }
541
542 pub fn record_disconnection(&self, peer_id: &str) {
544 if let Some(mut count) = self.connection_counts.get_mut(peer_id) {
545 if *count > 0 {
546 *count -= 1;
547 }
548 }
549 }
550
551 pub fn remove_connection_policy(&self, name: &str) -> Result<(), PolicyError> {
553 let mut policies = self.connection_policies.write();
554 let len_before = policies.len();
555 policies.retain(|p| p.name != name);
556
557 if policies.len() == len_before {
558 Err(PolicyError::NotFound(name.to_string()))
559 } else {
560 Ok(())
561 }
562 }
563
564 pub fn connection_policies(&self) -> Vec<ConnectionPolicy> {
566 self.connection_policies.read().clone()
567 }
568
569 pub fn bandwidth_policies(&self) -> Vec<BandwidthPolicy> {
571 self.bandwidth_policies.read().clone()
572 }
573
574 pub fn content_policies(&self) -> Vec<ContentPolicy> {
576 self.content_policies.read().clone()
577 }
578
579 pub fn stats(&self) -> PolicyStats {
581 self.stats.read().clone()
582 }
583
584 pub fn reset_stats(&self) {
586 let mut stats = self.stats.write();
587 *stats = PolicyStats::default();
588 }
589
590 fn record_evaluation(&self, result: &PolicyResult, duration: Duration) {
592 let mut stats = self.stats.write();
593 stats.evaluations += 1;
594
595 if result.allowed {
596 stats.allowed += 1;
597 } else {
598 stats.denied += 1;
599 }
600
601 if result.action == PolicyAction::RateLimit {
602 stats.rate_limited += 1;
603 }
604
605 if let Some(ref policy_name) = result.policy_name {
606 *stats.policy_hits.entry(policy_name.clone()).or_insert(0) += 1;
607 }
608
609 let eval_time_ms = duration.as_secs_f64() * 1000.0;
611 let alpha = 0.3;
612 stats.avg_eval_time_ms = alpha * eval_time_ms + (1.0 - alpha) * stats.avg_eval_time_ms;
613 }
614
615 fn record_default_evaluation(&self, allowed: bool, duration: Duration) {
617 let mut stats = self.stats.write();
618 stats.evaluations += 1;
619
620 if allowed {
621 stats.allowed += 1;
622 } else {
623 stats.denied += 1;
624 }
625
626 let eval_time_ms = duration.as_secs_f64() * 1000.0;
627 let alpha = 0.3;
628 stats.avg_eval_time_ms = alpha * eval_time_ms + (1.0 - alpha) * stats.avg_eval_time_ms;
629 }
630}
631
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder methods on `ConnectionPolicy` store the values they are given.
    #[test]
    fn test_policy_creation() {
        let policy = ConnectionPolicy::new("test")
            .with_action(PolicyAction::Allow)
            .with_priority(100);

        assert_eq!(policy.name, "test");
        assert_eq!(policy.action, PolicyAction::Allow);
        assert_eq!(policy.priority, 100);
    }

    /// A fresh engine has no connection policies.
    #[test]
    fn test_policy_engine_creation() {
        let config = PolicyConfig::default();
        let engine = PolicyEngine::new(config);

        assert_eq!(engine.connection_policies().len(), 0);
    }

    /// Adding a policy makes it visible through the snapshot accessor.
    #[test]
    fn test_add_connection_policy() {
        let engine = PolicyEngine::new(PolicyConfig::default());

        let policy = ConnectionPolicy::new("test");
        assert!(engine.add_connection_policy(policy).is_ok());
        assert_eq!(engine.connection_policies().len(), 1);
    }

    /// Once the per-type limit is reached, further policies are rejected with
    /// `PolicyError::Invalid` and the stored list is left untouched.
    #[test]
    fn test_policy_limit() {
        let config = PolicyConfig {
            max_policies_per_type: 1,
            ..Default::default()
        };
        let engine = PolicyEngine::new(config);

        engine
            .add_connection_policy(ConnectionPolicy::new("first"))
            .unwrap();

        let rejected = engine.add_connection_policy(ConnectionPolicy::new("second"));
        assert!(matches!(rejected, Err(PolicyError::Invalid(_))));
        assert_eq!(engine.connection_policies().len(), 1);
    }

    /// Blacklisted peers are denied even when the policy action is `Allow`.
    #[tokio::test]
    async fn test_blacklist_evaluation() {
        let engine = PolicyEngine::new(PolicyConfig::default());

        let policy = ConnectionPolicy::new("blacklist")
            .with_action(PolicyAction::Allow)
            .with_blacklist_peer("bad_peer");

        engine.add_connection_policy(policy).unwrap();

        let allowed = engine.evaluate_connection("bad_peer").await.unwrap();
        assert!(!allowed);

        let allowed = engine.evaluate_connection("good_peer").await.unwrap();
        assert!(allowed);
    }

    /// With a non-empty whitelist, only listed peers are admitted.
    #[tokio::test]
    async fn test_whitelist_evaluation() {
        let engine = PolicyEngine::new(PolicyConfig::default());

        let policy = ConnectionPolicy::new("whitelist")
            .with_action(PolicyAction::Allow)
            .with_whitelist_peer("good_peer");

        engine.add_connection_policy(policy).unwrap();

        let allowed = engine.evaluate_connection("good_peer").await.unwrap();
        assert!(allowed);

        let allowed = engine.evaluate_connection("bad_peer").await.unwrap();
        assert!(!allowed);
    }

    /// Connection counts are tracked per peer and checked against the
    /// per-peer cap of a registered policy.
    #[test]
    fn test_connection_counting() {
        let engine = PolicyEngine::new(PolicyConfig::default());

        // Without a cap, counts never block a peer.
        engine.record_connection("uncapped");
        engine.record_connection("uncapped");
        assert!(engine.can_connect("uncapped"));

        engine
            .add_connection_policy(ConnectionPolicy::new("limit").with_max_connections(2))
            .unwrap();

        assert!(engine.can_connect("peer1"));
        engine.record_connection("peer1");
        assert!(engine.can_connect("peer1"));
        engine.record_connection("peer1");

        // The cap is reached for peer1 only.
        assert!(!engine.can_connect("peer1"));
        assert!(engine.can_connect("peer2"));

        engine.record_disconnection("peer1");
        assert!(engine.can_connect("peer1"));

        // Disconnecting a peer that was never recorded is a no-op.
        engine.record_disconnection("unknown");
        assert!(engine.can_connect("unknown"));
    }

    /// Policies are stored and evaluated in descending priority order,
    /// regardless of the order in which they were added.
    #[tokio::test]
    async fn test_policy_priority() {
        let engine = PolicyEngine::new(PolicyConfig::default());

        let policy1 = ConnectionPolicy::new("low")
            .with_action(PolicyAction::Deny)
            .with_priority(10);

        let policy2 = ConnectionPolicy::new("high")
            .with_action(PolicyAction::Allow)
            .with_priority(100);

        engine.add_connection_policy(policy1).unwrap();
        engine.add_connection_policy(policy2).unwrap();

        let names: Vec<String> = engine
            .connection_policies()
            .into_iter()
            .map(|p| p.name)
            .collect();
        assert_eq!(names, ["high", "low"]);

        // The higher-priority `Allow` policy decides.
        let allowed = engine.evaluate_connection("test_peer").await.unwrap();
        assert!(allowed);
    }

    /// Removing an existing policy succeeds; removing an unknown one reports
    /// `PolicyError::NotFound`.
    #[test]
    fn test_remove_policy() {
        let engine = PolicyEngine::new(PolicyConfig::default());

        let policy = ConnectionPolicy::new("test");
        engine.add_connection_policy(policy).unwrap();

        assert!(engine.remove_connection_policy("test").is_ok());
        assert_eq!(engine.connection_policies().len(), 0);

        assert!(matches!(
            engine.remove_connection_policy("nonexistent"),
            Err(PolicyError::NotFound(_))
        ));
    }

    /// Each evaluation is counted and attributed to the deciding policy.
    #[tokio::test]
    async fn test_statistics() {
        let engine = PolicyEngine::new(PolicyConfig::default());

        let policy = ConnectionPolicy::new("test").with_action(PolicyAction::Allow);
        engine.add_connection_policy(policy).unwrap();

        engine.evaluate_connection("peer1").await.unwrap();
        engine.evaluate_connection("peer2").await.unwrap();

        let stats = engine.stats();
        assert_eq!(stats.evaluations, 2);
        assert_eq!(stats.allowed, 2);
        assert_eq!(stats.denied, 0);
        assert_eq!(stats.policy_hits.get("test"), Some(&2));
    }

    /// Builder methods on `BandwidthPolicy` store the values they are given.
    #[test]
    fn test_bandwidth_policy() {
        let policy = BandwidthPolicy::new("test")
            .with_max_upload(1_000_000)
            .with_max_download(5_000_000)
            .with_per_peer_limit(100_000);

        assert_eq!(policy.max_upload_bps, Some(1_000_000));
        assert_eq!(policy.max_download_bps, Some(5_000_000));
        assert_eq!(policy.per_peer_limit_bps, Some(100_000));
    }

    /// Builder methods on `ContentPolicy` store the values they are given.
    #[test]
    fn test_content_policy() {
        let policy = ContentPolicy::new("test")
            .with_allowed_pattern("^Qm.*")
            .with_max_size(10_000_000);

        assert_eq!(policy.allowed_patterns.len(), 1);
        assert_eq!(policy.max_size, Some(10_000_000));
    }

    /// The presets differ in their default action.
    #[test]
    fn test_policy_config_presets() {
        let strict = PolicyConfig::strict();
        assert_eq!(strict.default_action, PolicyAction::Deny);

        let permissive = PolicyConfig::permissive();
        assert_eq!(permissive.default_action, PolicyAction::Allow);
    }

    /// With no policies registered, the configured default action decides.
    #[tokio::test]
    async fn test_default_action() {
        let config = PolicyConfig {
            default_action: PolicyAction::Deny,
            ..Default::default()
        };
        let engine = PolicyEngine::new(config);

        let allowed = engine.evaluate_connection("peer1").await.unwrap();
        assert!(!allowed);

        let stats = engine.stats();
        assert_eq!(stats.evaluations, 1);
        assert_eq!(stats.denied, 1);
        assert!(stats.policy_hits.is_empty());
    }

    /// Resetting the statistics clears previously recorded evaluations.
    #[tokio::test]
    async fn test_reset_stats() {
        let engine = PolicyEngine::new(PolicyConfig::default());

        let policy = ConnectionPolicy::new("test");
        engine.add_connection_policy(policy).unwrap();

        engine.evaluate_connection("peer1").await.unwrap();
        assert_eq!(engine.stats().evaluations, 1);

        engine.reset_stats();
        assert_eq!(engine.stats().evaluations, 0);
    }
}