1use std::collections::HashMap;
37use std::time::{Duration, Instant};
38
39use oxide_license::{require_feature_sync, Feature, LicenseError};
40use serde::{Deserialize, Serialize};
41use tracing::{debug, info, instrument, warn};
42
43use crate::guard::{Guard, GuardAction, GuardCheckResult};
44use oxideshield_core::Match;
45
46#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
48#[serde(rename_all = "snake_case")]
49#[derive(Default)]
50pub enum AggregationStrategy {
51 #[default]
53 FailFast,
54 Unanimous,
56 Majority,
58 Weighted,
60 Comprehensive,
62}
63
64#[derive(Debug, Clone)]
66pub struct LayerConfig {
67 pub name: String,
69 pub layer_type: LayerType,
71 pub weight: f32,
73 pub enabled: bool,
75 pub timeout: Duration,
77}
78
79#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
81#[serde(rename_all = "snake_case")]
82pub enum LayerType {
83 Regex,
85 Perplexity,
87 PII,
89 Toxicity,
91 MLClassifier,
93 Semantic,
95 Custom,
97}
98
99impl LayerConfig {
100 pub fn regex() -> Self {
102 Self {
103 name: "regex".to_string(),
104 layer_type: LayerType::Regex,
105 weight: 1.0,
106 enabled: true,
107 timeout: Duration::from_millis(10),
108 }
109 }
110
111 pub fn perplexity() -> Self {
113 Self {
114 name: "perplexity".to_string(),
115 layer_type: LayerType::Perplexity,
116 weight: 1.0,
117 enabled: true,
118 timeout: Duration::from_millis(50),
119 }
120 }
121
122 pub fn pii() -> Self {
124 Self {
125 name: "pii".to_string(),
126 layer_type: LayerType::PII,
127 weight: 1.0,
128 enabled: true,
129 timeout: Duration::from_millis(50),
130 }
131 }
132
133 pub fn toxicity() -> Self {
135 Self {
136 name: "toxicity".to_string(),
137 layer_type: LayerType::Toxicity,
138 weight: 1.0,
139 enabled: true,
140 timeout: Duration::from_millis(50),
141 }
142 }
143
144 pub fn ml_classifier() -> Self {
146 Self {
147 name: "ml_classifier".to_string(),
148 layer_type: LayerType::MLClassifier,
149 weight: 1.5, enabled: true,
151 timeout: Duration::from_millis(100),
152 }
153 }
154
155 pub fn semantic() -> Self {
157 Self {
158 name: "semantic".to_string(),
159 layer_type: LayerType::Semantic,
160 weight: 1.5, enabled: true,
162 timeout: Duration::from_millis(100),
163 }
164 }
165
166 pub fn with_weight(mut self, weight: f32) -> Self {
168 self.weight = weight.max(0.0);
169 self
170 }
171
172 pub fn with_enabled(mut self, enabled: bool) -> Self {
174 self.enabled = enabled;
175 self
176 }
177
178 pub fn with_timeout(mut self, timeout: Duration) -> Self {
180 self.timeout = timeout;
181 self
182 }
183}
184
185#[derive(Debug, Clone)]
187pub struct LayerResult {
188 pub layer_name: String,
190 pub layer_type: LayerType,
192 pub result: GuardCheckResult,
194 pub duration: Duration,
196 pub weight: f32,
198}
199
200#[derive(Debug, Clone)]
202pub struct MultiLayerResult {
203 pub passed: bool,
205 pub action: GuardAction,
207 pub layer_results: Vec<LayerResult>,
209 pub all_matches: Vec<Match>,
211 pub total_duration: Duration,
213 pub strategy: AggregationStrategy,
215 pub summary: String,
217}
218
219#[derive(Debug, Default)]
223pub struct DefenseTelemetry {
224 total_checks: std::sync::atomic::AtomicU64,
226 passed_checks: std::sync::atomic::AtomicU64,
228 blocked_checks: std::sync::atomic::AtomicU64,
230 layer_stats: parking_lot::RwLock<HashMap<String, LayerStats>>,
232}
233
234impl DefenseTelemetry {
235 pub fn new() -> Self {
237 Self::default()
238 }
239
240 pub fn record_check(&self, passed: bool, layer_results: &[LayerResult]) {
242 use std::sync::atomic::Ordering;
243
244 self.total_checks.fetch_add(1, Ordering::Relaxed);
245 if passed {
246 self.passed_checks.fetch_add(1, Ordering::Relaxed);
247 } else {
248 self.blocked_checks.fetch_add(1, Ordering::Relaxed);
249 }
250
251 let mut stats = self.layer_stats.write();
253 for layer_result in layer_results {
254 let layer_stats = stats.entry(layer_result.layer_name.clone()).or_default();
255 layer_stats.record(layer_result);
256 }
257 }
258
259 pub fn total_checks(&self) -> u64 {
261 self.total_checks.load(std::sync::atomic::Ordering::Relaxed)
262 }
263
264 pub fn passed_checks(&self) -> u64 {
266 self.passed_checks
267 .load(std::sync::atomic::Ordering::Relaxed)
268 }
269
270 pub fn blocked_checks(&self) -> u64 {
272 self.blocked_checks
273 .load(std::sync::atomic::Ordering::Relaxed)
274 }
275
276 pub fn block_rate(&self) -> f64 {
278 let total = self.total_checks();
279 if total == 0 {
280 0.0
281 } else {
282 self.blocked_checks() as f64 / total as f64
283 }
284 }
285
286 pub fn layer_stats(&self) -> HashMap<String, LayerStats> {
288 self.layer_stats.read().clone()
289 }
290
291 pub fn reset(&self) {
293 use std::sync::atomic::Ordering;
294 self.total_checks.store(0, Ordering::Relaxed);
295 self.passed_checks.store(0, Ordering::Relaxed);
296 self.blocked_checks.store(0, Ordering::Relaxed);
297 self.layer_stats.write().clear();
298 }
299}
300
301impl Clone for DefenseTelemetry {
302 fn clone(&self) -> Self {
303 use std::sync::atomic::Ordering;
304 Self {
305 total_checks: std::sync::atomic::AtomicU64::new(
306 self.total_checks.load(Ordering::Relaxed),
307 ),
308 passed_checks: std::sync::atomic::AtomicU64::new(
309 self.passed_checks.load(Ordering::Relaxed),
310 ),
311 blocked_checks: std::sync::atomic::AtomicU64::new(
312 self.blocked_checks.load(Ordering::Relaxed),
313 ),
314 layer_stats: parking_lot::RwLock::new(self.layer_stats.read().clone()),
315 }
316 }
317}
318
/// Running statistics for one layer, as stored in [`DefenseTelemetry`].
#[derive(Debug, Clone, Default)]
pub struct LayerStats {
    /// Number of results recorded for this layer.
    pub checks: u64,
    /// Number of recorded results whose guard did not pass.
    pub detections: u64,
    /// Sum of recorded durations in milliseconds; basis for `avg_duration_ms`.
    total_duration_ms: f64,
    /// Mean duration per recorded result, in milliseconds.
    pub avg_duration_ms: f64,
    /// `detections / checks`.
    pub detection_rate: f64,
}
333
334impl LayerStats {
335 pub fn record(&mut self, result: &LayerResult) {
337 self.checks += 1;
338 if !result.result.passed {
339 self.detections += 1;
340 }
341 self.total_duration_ms += result.duration.as_secs_f64() * 1000.0;
342 self.avg_duration_ms = self.total_duration_ms / self.checks as f64;
343 self.detection_rate = self.detections as f64 / self.checks as f64;
344 }
345}
346
347pub struct MultiLayerDefense {
351 name: String,
352 layers: Vec<LayerConfig>,
354 guards: Vec<Box<dyn Guard>>,
356 strategy: AggregationStrategy,
358 telemetry: Option<DefenseTelemetry>,
360}
361
362impl MultiLayerDefense {
363 pub fn builder(name: impl Into<String>) -> MultiLayerDefenseBuilder {
365 MultiLayerDefenseBuilder::new(name)
366 }
367
368 #[instrument(skip(self, content), fields(defense = %self.name, content_len = content.len()))]
370 pub fn check(&self, content: &str) -> MultiLayerResult {
371 let start = Instant::now();
372 let mut layer_results = Vec::new();
373 let mut all_matches = Vec::new();
374 let mut blocked_count = 0;
375 let mut total_weight = 0.0f32;
376 let mut blocked_weight = 0.0f32;
377
378 for (i, (layer_config, guard)) in self.layers.iter().zip(self.guards.iter()).enumerate() {
379 if !layer_config.enabled {
380 continue;
381 }
382
383 let layer_start = Instant::now();
384 let result = guard.check(content);
385 let duration = layer_start.elapsed();
386
387 debug!(
388 layer = %layer_config.name,
389 passed = result.passed,
390 duration_ms = %duration.as_millis(),
391 "Layer check complete"
392 );
393
394 let layer_result = LayerResult {
395 layer_name: layer_config.name.clone(),
396 layer_type: layer_config.layer_type,
397 result: result.clone(),
398 duration,
399 weight: layer_config.weight,
400 };
401
402 if !result.passed {
403 blocked_count += 1;
404 blocked_weight += layer_config.weight;
405 all_matches.extend(result.matches.clone());
406 }
407 total_weight += layer_config.weight;
408
409 layer_results.push(layer_result);
410
411 if self.should_terminate_early(&result, blocked_count, i) {
413 break;
414 }
415 }
416
417 let total_duration = start.elapsed();
418
419 let (passed, action) =
421 self.aggregate_results(&layer_results, blocked_count, blocked_weight, total_weight);
422
423 let summary = self.build_summary(&layer_results, passed, total_duration);
424
425 info!(
426 passed = passed,
427 blocked_count = blocked_count,
428 layers_checked = layer_results.len(),
429 total_duration_ms = %total_duration.as_millis(),
430 "Multi-layer defense complete"
431 );
432
433 if let Some(ref telemetry) = self.telemetry {
435 telemetry.record_check(passed, &layer_results);
436 }
437
438 MultiLayerResult {
439 passed,
440 action,
441 layer_results,
442 all_matches,
443 total_duration,
444 strategy: self.strategy,
445 summary,
446 }
447 }
448
449 fn should_terminate_early(
451 &self,
452 result: &GuardCheckResult,
453 blocked_count: usize,
454 layer_index: usize,
455 ) -> bool {
456 match self.strategy {
457 AggregationStrategy::FailFast => !result.passed && result.action == GuardAction::Block,
458 AggregationStrategy::Unanimous => {
459 false
461 }
462 AggregationStrategy::Majority => {
463 let total_layers = self.layers.iter().filter(|l| l.enabled).count();
466 let layers_checked = layer_index + 1;
467 let layers_remaining = total_layers.saturating_sub(layers_checked);
468 let majority_threshold = (total_layers / 2) + 1;
469
470 if blocked_count >= majority_threshold {
472 debug!(
473 blocked_count = blocked_count,
474 threshold = majority_threshold,
475 "Early termination: majority already blocked"
476 );
477 true
478 }
479 else if blocked_count + layers_remaining < majority_threshold {
481 debug!(
482 blocked_count = blocked_count,
483 remaining = layers_remaining,
484 threshold = majority_threshold,
485 "Early termination: majority impossible"
486 );
487 true
488 } else {
489 false
490 }
491 }
492 AggregationStrategy::Weighted => {
493 false
495 }
496 AggregationStrategy::Comprehensive => {
497 false
499 }
500 }
501 }
502
503 fn aggregate_results(
505 &self,
506 layer_results: &[LayerResult],
507 blocked_count: usize,
508 blocked_weight: f32,
509 total_weight: f32,
510 ) -> (bool, GuardAction) {
511 if layer_results.is_empty() {
512 return (true, GuardAction::Allow);
513 }
514
515 match self.strategy {
516 AggregationStrategy::FailFast => {
517 if blocked_count > 0 {
519 let action = layer_results
520 .iter()
521 .filter(|r| !r.result.passed)
522 .map(|r| r.result.action)
523 .max_by_key(|a| match a {
524 GuardAction::Block => 5,
525 GuardAction::Alert => 4,
526 GuardAction::Sanitize => 3,
527 GuardAction::Suggest => 2,
528 GuardAction::Log => 1,
529 GuardAction::Allow => 0,
530 })
531 .unwrap_or(GuardAction::Block);
532 (false, action)
533 } else {
534 (true, GuardAction::Allow)
535 }
536 }
537 AggregationStrategy::Unanimous => {
538 let total = layer_results.len();
540 if blocked_count == total && total > 0 {
541 (false, GuardAction::Block)
542 } else {
543 (true, GuardAction::Allow)
544 }
545 }
546 AggregationStrategy::Majority => {
547 let total = layer_results.len();
549 if blocked_count * 2 > total {
550 (false, GuardAction::Block)
551 } else {
552 (true, GuardAction::Allow)
553 }
554 }
555 AggregationStrategy::Weighted => {
556 if total_weight > 0.0 && blocked_weight / total_weight > 0.5 {
558 (false, GuardAction::Block)
559 } else {
560 (true, GuardAction::Allow)
561 }
562 }
563 AggregationStrategy::Comprehensive => {
564 if blocked_count > 0 {
566 (false, GuardAction::Block)
567 } else {
568 (true, GuardAction::Allow)
569 }
570 }
571 }
572 }
573
574 fn build_summary(
576 &self,
577 layer_results: &[LayerResult],
578 passed: bool,
579 duration: Duration,
580 ) -> String {
581 let triggered: Vec<&str> = layer_results
582 .iter()
583 .filter(|r| !r.result.passed)
584 .map(|r| r.layer_name.as_str())
585 .collect();
586
587 if passed {
588 format!(
589 "Passed all {} layers in {}ms",
590 layer_results.len(),
591 duration.as_millis()
592 )
593 } else {
594 format!(
595 "Blocked by {} of {} layers ({}) in {}ms",
596 triggered.len(),
597 layer_results.len(),
598 triggered.join(", "),
599 duration.as_millis()
600 )
601 }
602 }
603
604 pub fn name(&self) -> &str {
606 &self.name
607 }
608
609 pub fn layers(&self) -> &[LayerConfig] {
611 &self.layers
612 }
613
614 pub fn strategy(&self) -> AggregationStrategy {
616 self.strategy
617 }
618
619 pub fn telemetry(&self) -> Option<&DefenseTelemetry> {
621 self.telemetry.as_ref()
622 }
623
624 pub fn has_telemetry(&self) -> bool {
626 self.telemetry.is_some()
627 }
628}
629
630pub struct MultiLayerDefenseBuilder {
632 name: String,
633 layers: Vec<LayerConfig>,
634 guards: Vec<Box<dyn Guard>>,
635 strategy: AggregationStrategy,
636 enable_telemetry: bool,
637}
638
639impl MultiLayerDefenseBuilder {
640 pub fn new(name: impl Into<String>) -> Self {
642 Self {
643 name: name.into(),
644 layers: Vec::new(),
645 guards: Vec::new(),
646 strategy: AggregationStrategy::FailFast,
647 enable_telemetry: false,
648 }
649 }
650
651 pub fn add_guard(mut self, config: LayerConfig, guard: Box<dyn Guard>) -> Self {
653 self.layers.push(config);
654 self.guards.push(guard);
655 self
656 }
657
658 pub fn with_strategy(mut self, strategy: AggregationStrategy) -> Self {
660 self.strategy = strategy;
661 self
662 }
663
664 pub fn with_telemetry(mut self, enabled: bool) -> Self {
666 self.enable_telemetry = enabled;
667 self
668 }
669
670 pub fn build(self) -> Result<MultiLayerDefense, LicenseError> {
677 require_feature_sync(Feature::MultiLayerDefense)?;
678 Ok(self.build_unchecked())
679 }
680
681 pub(crate) fn build_unchecked(self) -> MultiLayerDefense {
685 MultiLayerDefense {
686 name: self.name,
687 layers: self.layers,
688 guards: self.guards,
689 strategy: self.strategy,
690 telemetry: if self.enable_telemetry {
691 Some(DefenseTelemetry::default())
692 } else {
693 None
694 },
695 }
696 }
697}
698
#[cfg(test)]
mod tests {
    use super::*;
    use crate::guard::LengthGuard;
    use crate::guards::PerplexityGuard;

    #[test]
    fn test_layer_config_defaults() {
        let regex = LayerConfig::regex();
        assert_eq!(regex.layer_type, LayerType::Regex);
        assert!(regex.enabled);

        let perplexity = LayerConfig::perplexity();
        assert_eq!(perplexity.layer_type, LayerType::Perplexity);
    }

    #[test]
    fn test_with_weight_clamps_negative() {
        assert_eq!(LayerConfig::regex().with_weight(-2.0).weight, 0.0);
        assert_eq!(LayerConfig::regex().with_weight(2.5).weight, 2.5);
    }

    #[test]
    fn test_multilayer_builder() {
        let defense = MultiLayerDefense::builder("test")
            .add_guard(
                LayerConfig::regex(),
                Box::new(LengthGuard::new("length").with_max_chars(1000)),
            )
            .with_strategy(AggregationStrategy::FailFast)
            .build_unchecked();

        assert_eq!(defense.name(), "test");
        assert_eq!(defense.strategy(), AggregationStrategy::FailFast);
        assert_eq!(defense.layers().len(), 1);
    }

    #[test]
    fn test_multilayer_pass() {
        let defense = MultiLayerDefense::builder("test")
            .add_guard(
                LayerConfig::regex(),
                Box::new(LengthGuard::new("length").with_max_chars(1000)),
            )
            .build_unchecked();

        let result = defense.check("Short text");
        assert!(result.passed);
        assert_eq!(result.action, GuardAction::Allow);
    }

    #[test]
    fn test_multilayer_fail_fast() {
        let defense = MultiLayerDefense::builder("test")
            .add_guard(
                LayerConfig::regex(),
                Box::new(LengthGuard::new("length").with_max_chars(5)),
            )
            .add_guard(
                LayerConfig::perplexity(),
                Box::new(PerplexityGuard::new("perplexity")),
            )
            .with_strategy(AggregationStrategy::FailFast)
            .build_unchecked();

        let result = defense.check("This is too long");
        assert!(!result.passed);
        assert_eq!(result.layer_results.len(), 1);
    }

    #[test]
    fn test_multilayer_comprehensive() {
        let defense = MultiLayerDefense::builder("test")
            .add_guard(
                LayerConfig::regex(),
                Box::new(LengthGuard::new("length").with_max_chars(5)),
            )
            .add_guard(
                LayerConfig::perplexity(),
                Box::new(PerplexityGuard::new("perplexity")),
            )
            .with_strategy(AggregationStrategy::Comprehensive)
            .build_unchecked();

        let result = defense.check("This is too long");
        assert!(!result.passed);
        assert_eq!(result.layer_results.len(), 2);
    }

    #[test]
    fn test_aggregation_unanimous() {
        let defense = MultiLayerDefense::builder("test")
            .add_guard(
                LayerConfig::regex(),
                Box::new(LengthGuard::new("length").with_max_chars(5)),
            )
            .add_guard(
                LayerConfig::perplexity(),
                Box::new(LengthGuard::new("length2").with_max_chars(1000)),
            )
            .with_strategy(AggregationStrategy::Unanimous)
            .build_unchecked();

        // Only one of the two layers fails, so unanimity is not reached.
        let result = defense.check("Medium text");
        assert!(result.passed);
    }

    #[test]
    fn test_aggregation_majority() {
        let defense = MultiLayerDefense::builder("test")
            .add_guard(
                LayerConfig::regex().with_weight(1.0),
                Box::new(LengthGuard::new("l1").with_max_chars(5)),
            )
            .add_guard(
                LayerConfig::regex().with_weight(1.0),
                Box::new(LengthGuard::new("l2").with_max_chars(5)),
            )
            .add_guard(
                LayerConfig::regex().with_weight(1.0),
                Box::new(LengthGuard::new("l3").with_max_chars(1000)),
            )
            .with_strategy(AggregationStrategy::Majority)
            .build_unchecked();

        let result = defense.check("Medium length text");
        assert!(!result.passed);
    }

    #[test]
    fn test_layer_results_contain_timing() {
        let defense = MultiLayerDefense::builder("test")
            .add_guard(
                LayerConfig::regex(),
                Box::new(LengthGuard::new("length").with_max_chars(1000)),
            )
            .build_unchecked();

        let result = defense.check("Test");
        assert_eq!(result.layer_results.len(), 1);
        // A trivial guard can finish within a single tick of a coarse clock,
        // so asserting a non-zero duration is flaky. On a monotonic clock the
        // whole-check timer always brackets the per-layer timer.
        assert!(result.total_duration >= result.layer_results[0].duration);
    }

    #[test]
    fn test_telemetry_recording() {
        let defense = MultiLayerDefense::builder("telemetry_test")
            .add_guard(
                LayerConfig::regex(),
                Box::new(LengthGuard::new("length").with_max_chars(1000)),
            )
            .add_guard(
                LayerConfig::perplexity(),
                Box::new(PerplexityGuard::new("perplexity")),
            )
            .with_telemetry(true)
            .build_unchecked();

        assert!(defense.has_telemetry());

        defense.check("Short text");
        defense.check("Another short text");
        defense.check(&"x".repeat(2000));

        let telemetry = defense.telemetry().unwrap();
        assert_eq!(telemetry.total_checks(), 3);
        assert_eq!(telemetry.passed_checks(), 2);
        assert_eq!(telemetry.blocked_checks(), 1);
        assert_eq!(
            telemetry.passed_checks() + telemetry.blocked_checks(),
            telemetry.total_checks()
        );
        assert!(telemetry.block_rate() > 0.0);

        let layer_stats = telemetry.layer_stats();
        assert!(
            layer_stats.contains_key("regex"),
            "Expected 'regex' layer stats"
        );
        let regex_stats = &layer_stats["regex"];
        assert_eq!(regex_stats.checks, 3);
        assert_eq!(regex_stats.detections, 1);
        // Durations may legitimately round to zero on coarse clocks; require
        // only a well-formed, non-negative average.
        assert!(regex_stats.avg_duration_ms.is_finite());
        assert!(regex_stats.avg_duration_ms >= 0.0);
    }

    #[test]
    fn test_telemetry_disabled_by_default() {
        let defense = MultiLayerDefense::builder("no_telemetry")
            .add_guard(
                LayerConfig::regex(),
                Box::new(LengthGuard::new("length").with_max_chars(1000)),
            )
            .build_unchecked();

        assert!(!defense.has_telemetry());
        assert!(defense.telemetry().is_none());
    }

    #[test]
    fn test_majority_early_termination() {
        let defense = MultiLayerDefense::builder("majority_early")
            .add_guard(
                LayerConfig::regex(),
                Box::new(LengthGuard::new("l1").with_max_chars(5)),
            )
            .add_guard(
                LayerConfig::regex(),
                Box::new(LengthGuard::new("l2").with_max_chars(5)),
            )
            .add_guard(
                LayerConfig::regex(),
                Box::new(LengthGuard::new("l3").with_max_chars(5)),
            )
            .add_guard(
                LayerConfig::regex(),
                Box::new(LengthGuard::new("l4").with_max_chars(1000)),
            )
            .add_guard(
                LayerConfig::regex(),
                Box::new(LengthGuard::new("l5").with_max_chars(1000)),
            )
            .with_strategy(AggregationStrategy::Majority)
            .build_unchecked();

        let result = defense.check("This text is longer than 5 chars");
        assert!(!result.passed);

        // 3 of 5 layers blocking is already a majority, so l4 and l5 never run.
        assert_eq!(
            result.layer_results.len(),
            3,
            "Should terminate after reaching majority"
        );
    }

    #[test]
    fn test_license_check_required() {
        let builder = MultiLayerDefense::builder("test").add_guard(
            LayerConfig::regex(),
            Box::new(LengthGuard::new("length").with_max_chars(1000)),
        );

        // Whether the feature is licensed depends on the test environment, so
        // both outcomes are acceptable. A successful build must still carry
        // the builder's configuration through.
        match builder.build() {
            Ok(defense) => {
                assert_eq!(defense.name(), "test");
                assert_eq!(defense.layers().len(), 1);
            }
            Err(_) => {}
        }
    }
}