Skip to main content

oxideshield_guard/
multilayer.rs

1//! Multi-Layer Defense Pipeline
2//!
3//! Implements a defense-in-depth architecture inspired by PromptGuard's 4-layer approach.
4//!
5//! ## Architecture
6//!
7//! ```text
8//! Layer 1: Fast Regex (PatternGuard)     → <1ms, ~70% detection
9//! Layer 2: Perplexity Analysis           → <5ms, +10% detection
10//! Layer 3: Semantic/ML (if enabled)      → <25ms, +15% detection
11//! Layer 4: PII/Toxicity Filters          → <10ms, comprehensive
12//! ```
13//!
14//! ## Research References
15//!
16//! - [PromptGuard](https://www.nature.com/articles/s41598-025-31086-y) - Nature Scientific Reports, 2025
17//!   4-layer defense: regex + MiniBERT + semantic + adaptive. F1=0.91, 67% injection reduction
18//! - [The Attacker Moves Second](https://simonwillison.net/2025/Nov/2/new-prompt-injection-papers/)
19//!   12 defenses bypassed at >90% with adaptive attacks - validates need for defense-in-depth
20//!
21//! ## Example
22//!
23//! ```rust,ignore
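//! use oxideshield_guard::multilayer::{AggregationStrategy, LayerConfig, MultiLayerDefense};
//!
//! // Each layer pairs a `LayerConfig` with the guard that implements it;
//! // `pattern_guard`, `perplexity_guard` and `pii_guard` stand for any `Guard` implementations.
//! let defense = MultiLayerDefense::builder("defense")
//!     .add_guard(LayerConfig::regex(), Box::new(pattern_guard))
//!     .add_guard(LayerConfig::perplexity(), Box::new(perplexity_guard))
//!     .add_guard(LayerConfig::pii(), Box::new(pii_guard))
//!     .with_strategy(AggregationStrategy::FailFast)
//!     .build()?; // returns `LicenseError` without a Professional or Enterprise license
//!
//! let result = defense.check("user input");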
34//! ```
35
36use std::collections::HashMap;
37use std::time::{Duration, Instant};
38
39use oxide_license::{require_feature_sync, Feature, LicenseError};
40use serde::{Deserialize, Serialize};
41use tracing::{debug, info, instrument, warn};
42
43use crate::guard::{Guard, GuardAction, GuardCheckResult};
44use oxideshield_core::Match;
45
46/// Aggregation strategy for combining guard results
47#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
48#[serde(rename_all = "snake_case")]
49#[derive(Default)]
50pub enum AggregationStrategy {
51    /// Stop at first Block action (fastest)
52    #[default]
53    FailFast,
54    /// All guards must block to trigger blocking
55    Unanimous,
56    /// More than 50% of guards must detect to block
57    Majority,
58    /// Weighted voting based on guard confidence/priority
59    Weighted,
60    /// Run all guards and return combined results
61    Comprehensive,
62}
63
64/// Configuration for a single layer
65#[derive(Debug, Clone)]
66pub struct LayerConfig {
67    /// Layer name
68    pub name: String,
69    /// Layer type identifier
70    pub layer_type: LayerType,
71    /// Weight for weighted aggregation (1.0 = normal)
72    pub weight: f32,
73    /// Whether this layer is enabled
74    pub enabled: bool,
75    /// Timeout for this layer
76    pub timeout: Duration,
77}
78
79/// Types of layers available
80#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
81#[serde(rename_all = "snake_case")]
82pub enum LayerType {
83    /// Pattern-based detection (fastest)
84    Regex,
85    /// Perplexity/entropy analysis
86    Perplexity,
87    /// PII detection and redaction
88    PII,
89    /// Toxicity content moderation
90    Toxicity,
91    /// ML-based classification (requires semantic feature)
92    MLClassifier,
93    /// Semantic similarity (requires semantic feature)
94    Semantic,
95    /// Custom guard
96    Custom,
97}
98
99impl LayerConfig {
100    /// Create a regex layer config
101    pub fn regex() -> Self {
102        Self {
103            name: "regex".to_string(),
104            layer_type: LayerType::Regex,
105            weight: 1.0,
106            enabled: true,
107            timeout: Duration::from_millis(10),
108        }
109    }
110
111    /// Create a perplexity layer config
112    pub fn perplexity() -> Self {
113        Self {
114            name: "perplexity".to_string(),
115            layer_type: LayerType::Perplexity,
116            weight: 1.0,
117            enabled: true,
118            timeout: Duration::from_millis(50),
119        }
120    }
121
122    /// Create a PII layer config
123    pub fn pii() -> Self {
124        Self {
125            name: "pii".to_string(),
126            layer_type: LayerType::PII,
127            weight: 1.0,
128            enabled: true,
129            timeout: Duration::from_millis(50),
130        }
131    }
132
133    /// Create a toxicity layer config
134    pub fn toxicity() -> Self {
135        Self {
136            name: "toxicity".to_string(),
137            layer_type: LayerType::Toxicity,
138            weight: 1.0,
139            enabled: true,
140            timeout: Duration::from_millis(50),
141        }
142    }
143
144    /// Create an ML classifier layer config
145    pub fn ml_classifier() -> Self {
146        Self {
147            name: "ml_classifier".to_string(),
148            layer_type: LayerType::MLClassifier,
149            weight: 1.5, // Higher weight for ML-based detection
150            enabled: true,
151            timeout: Duration::from_millis(100),
152        }
153    }
154
155    /// Create a semantic layer config
156    pub fn semantic() -> Self {
157        Self {
158            name: "semantic".to_string(),
159            layer_type: LayerType::Semantic,
160            weight: 1.5, // Higher weight for semantic detection
161            enabled: true,
162            timeout: Duration::from_millis(100),
163        }
164    }
165
166    /// Set the layer weight
167    pub fn with_weight(mut self, weight: f32) -> Self {
168        self.weight = weight.max(0.0);
169        self
170    }
171
172    /// Set enabled state
173    pub fn with_enabled(mut self, enabled: bool) -> Self {
174        self.enabled = enabled;
175        self
176    }
177
178    /// Set timeout
179    pub fn with_timeout(mut self, timeout: Duration) -> Self {
180        self.timeout = timeout;
181        self
182    }
183}
184
185/// Result from a single layer
186#[derive(Debug, Clone)]
187pub struct LayerResult {
188    /// Layer name
189    pub layer_name: String,
190    /// Layer type
191    pub layer_type: LayerType,
192    /// Guard result
193    pub result: GuardCheckResult,
194    /// Execution time
195    pub duration: Duration,
196    /// Layer weight
197    pub weight: f32,
198}
199
200/// Result from the multi-layer defense
201#[derive(Debug, Clone)]
202pub struct MultiLayerResult {
203    /// Overall pass/fail
204    pub passed: bool,
205    /// Final action to take
206    pub action: GuardAction,
207    /// Results from each layer
208    pub layer_results: Vec<LayerResult>,
209    /// Combined matches from all layers
210    pub all_matches: Vec<Match>,
211    /// Total execution time
212    pub total_duration: Duration,
213    /// Aggregation strategy used
214    pub strategy: AggregationStrategy,
215    /// Summary message
216    pub summary: String,
217}
218
219/// Telemetry data for monitoring
220///
221/// Uses interior mutability for thread-safe updates during checks.
222#[derive(Debug, Default)]
223pub struct DefenseTelemetry {
224    /// Total checks performed
225    total_checks: std::sync::atomic::AtomicU64,
226    /// Checks that passed
227    passed_checks: std::sync::atomic::AtomicU64,
228    /// Checks that blocked
229    blocked_checks: std::sync::atomic::AtomicU64,
230    /// Per-layer statistics
231    layer_stats: parking_lot::RwLock<HashMap<String, LayerStats>>,
232}
233
234impl DefenseTelemetry {
235    /// Create a new telemetry collector
236    pub fn new() -> Self {
237        Self::default()
238    }
239
240    /// Record a check result
241    pub fn record_check(&self, passed: bool, layer_results: &[LayerResult]) {
242        use std::sync::atomic::Ordering;
243
244        self.total_checks.fetch_add(1, Ordering::Relaxed);
245        if passed {
246            self.passed_checks.fetch_add(1, Ordering::Relaxed);
247        } else {
248            self.blocked_checks.fetch_add(1, Ordering::Relaxed);
249        }
250
251        // Update per-layer stats
252        let mut stats = self.layer_stats.write();
253        for layer_result in layer_results {
254            let layer_stats = stats.entry(layer_result.layer_name.clone()).or_default();
255            layer_stats.record(layer_result);
256        }
257    }
258
259    /// Get total checks
260    pub fn total_checks(&self) -> u64 {
261        self.total_checks.load(std::sync::atomic::Ordering::Relaxed)
262    }
263
264    /// Get passed checks
265    pub fn passed_checks(&self) -> u64 {
266        self.passed_checks
267            .load(std::sync::atomic::Ordering::Relaxed)
268    }
269
270    /// Get blocked checks
271    pub fn blocked_checks(&self) -> u64 {
272        self.blocked_checks
273            .load(std::sync::atomic::Ordering::Relaxed)
274    }
275
276    /// Get block rate
277    pub fn block_rate(&self) -> f64 {
278        let total = self.total_checks();
279        if total == 0 {
280            0.0
281        } else {
282            self.blocked_checks() as f64 / total as f64
283        }
284    }
285
286    /// Get per-layer statistics
287    pub fn layer_stats(&self) -> HashMap<String, LayerStats> {
288        self.layer_stats.read().clone()
289    }
290
291    /// Reset all telemetry
292    pub fn reset(&self) {
293        use std::sync::atomic::Ordering;
294        self.total_checks.store(0, Ordering::Relaxed);
295        self.passed_checks.store(0, Ordering::Relaxed);
296        self.blocked_checks.store(0, Ordering::Relaxed);
297        self.layer_stats.write().clear();
298    }
299}
300
301impl Clone for DefenseTelemetry {
302    fn clone(&self) -> Self {
303        use std::sync::atomic::Ordering;
304        Self {
305            total_checks: std::sync::atomic::AtomicU64::new(
306                self.total_checks.load(Ordering::Relaxed),
307            ),
308            passed_checks: std::sync::atomic::AtomicU64::new(
309                self.passed_checks.load(Ordering::Relaxed),
310            ),
311            blocked_checks: std::sync::atomic::AtomicU64::new(
312                self.blocked_checks.load(Ordering::Relaxed),
313            ),
314            layer_stats: parking_lot::RwLock::new(self.layer_stats.read().clone()),
315        }
316    }
317}
318
/// Running statistics for one named layer.
#[derive(Debug, Clone, Default)]
pub struct LayerStats {
    /// Number of results recorded for this layer.
    pub checks: u64,
    /// Number of those results that did not pass.
    pub detections: u64,
    /// Sum of recorded durations in milliseconds, kept so the mean can be
    /// updated incrementally.
    total_duration_ms: f64,
    /// Mean recorded duration in milliseconds.
    pub avg_duration_ms: f64,
    /// `detections / checks`.
    pub detection_rate: f64,
}
333
334impl LayerStats {
335    /// Record a layer result
336    pub fn record(&mut self, result: &LayerResult) {
337        self.checks += 1;
338        if !result.result.passed {
339            self.detections += 1;
340        }
341        self.total_duration_ms += result.duration.as_secs_f64() * 1000.0;
342        self.avg_duration_ms = self.total_duration_ms / self.checks as f64;
343        self.detection_rate = self.detections as f64 / self.checks as f64;
344    }
345}
346
347/// Multi-layer defense system
348///
349/// Orchestrates multiple guards in a layered defense architecture.
350pub struct MultiLayerDefense {
351    name: String,
352    /// Layer configurations
353    layers: Vec<LayerConfig>,
354    /// Actual guard instances
355    guards: Vec<Box<dyn Guard>>,
356    /// Aggregation strategy
357    strategy: AggregationStrategy,
358    /// Telemetry collector for defense-level metrics
359    telemetry: Option<DefenseTelemetry>,
360}
361
362impl MultiLayerDefense {
363    /// Create a new builder
364    pub fn builder(name: impl Into<String>) -> MultiLayerDefenseBuilder {
365        MultiLayerDefenseBuilder::new(name)
366    }
367
368    /// Check content through all layers
369    #[instrument(skip(self, content), fields(defense = %self.name, content_len = content.len()))]
370    pub fn check(&self, content: &str) -> MultiLayerResult {
371        let start = Instant::now();
372        let mut layer_results = Vec::new();
373        let mut all_matches = Vec::new();
374        let mut blocked_count = 0;
375        let mut total_weight = 0.0f32;
376        let mut blocked_weight = 0.0f32;
377
378        for (i, (layer_config, guard)) in self.layers.iter().zip(self.guards.iter()).enumerate() {
379            if !layer_config.enabled {
380                continue;
381            }
382
383            let layer_start = Instant::now();
384            let result = guard.check(content);
385            let duration = layer_start.elapsed();
386
387            debug!(
388                layer = %layer_config.name,
389                passed = result.passed,
390                duration_ms = %duration.as_millis(),
391                "Layer check complete"
392            );
393
394            let layer_result = LayerResult {
395                layer_name: layer_config.name.clone(),
396                layer_type: layer_config.layer_type,
397                result: result.clone(),
398                duration,
399                weight: layer_config.weight,
400            };
401
402            if !result.passed {
403                blocked_count += 1;
404                blocked_weight += layer_config.weight;
405                all_matches.extend(result.matches.clone());
406            }
407            total_weight += layer_config.weight;
408
409            layer_results.push(layer_result);
410
411            // Check for early termination based on strategy
412            if self.should_terminate_early(&result, blocked_count, i) {
413                break;
414            }
415        }
416
417        let total_duration = start.elapsed();
418
419        // Determine final result based on aggregation strategy
420        let (passed, action) =
421            self.aggregate_results(&layer_results, blocked_count, blocked_weight, total_weight);
422
423        let summary = self.build_summary(&layer_results, passed, total_duration);
424
425        info!(
426            passed = passed,
427            blocked_count = blocked_count,
428            layers_checked = layer_results.len(),
429            total_duration_ms = %total_duration.as_millis(),
430            "Multi-layer defense complete"
431        );
432
433        // Record telemetry if enabled
434        if let Some(ref telemetry) = self.telemetry {
435            telemetry.record_check(passed, &layer_results);
436        }
437
438        MultiLayerResult {
439            passed,
440            action,
441            layer_results,
442            all_matches,
443            total_duration,
444            strategy: self.strategy,
445            summary,
446        }
447    }
448
449    /// Check if we should terminate early based on strategy
450    fn should_terminate_early(
451        &self,
452        result: &GuardCheckResult,
453        blocked_count: usize,
454        layer_index: usize,
455    ) -> bool {
456        match self.strategy {
457            AggregationStrategy::FailFast => !result.passed && result.action == GuardAction::Block,
458            AggregationStrategy::Unanimous => {
459                // Continue until all checked
460                false
461            }
462            AggregationStrategy::Majority => {
463                // Early termination if we already have enough blocks for majority
464                // or if remaining layers can't change the outcome
465                let total_layers = self.layers.iter().filter(|l| l.enabled).count();
466                let layers_checked = layer_index + 1;
467                let layers_remaining = total_layers.saturating_sub(layers_checked);
468                let majority_threshold = (total_layers / 2) + 1;
469
470                // If we already have majority blocks, we can terminate
471                if blocked_count >= majority_threshold {
472                    debug!(
473                        blocked_count = blocked_count,
474                        threshold = majority_threshold,
475                        "Early termination: majority already blocked"
476                    );
477                    true
478                }
479                // If even all remaining layers blocking can't reach majority, terminate
480                else if blocked_count + layers_remaining < majority_threshold {
481                    debug!(
482                        blocked_count = blocked_count,
483                        remaining = layers_remaining,
484                        threshold = majority_threshold,
485                        "Early termination: majority impossible"
486                    );
487                    true
488                } else {
489                    false
490                }
491            }
492            AggregationStrategy::Weighted => {
493                // Run all to get full weights
494                false
495            }
496            AggregationStrategy::Comprehensive => {
497                // Always run all
498                false
499            }
500        }
501    }
502
503    /// Aggregate results based on strategy
504    fn aggregate_results(
505        &self,
506        layer_results: &[LayerResult],
507        blocked_count: usize,
508        blocked_weight: f32,
509        total_weight: f32,
510    ) -> (bool, GuardAction) {
511        if layer_results.is_empty() {
512            return (true, GuardAction::Allow);
513        }
514
515        match self.strategy {
516            AggregationStrategy::FailFast => {
517                // Any block = block
518                if blocked_count > 0 {
519                    let action = layer_results
520                        .iter()
521                        .filter(|r| !r.result.passed)
522                        .map(|r| r.result.action)
523                        .max_by_key(|a| match a {
524                            GuardAction::Block => 5,
525                            GuardAction::Alert => 4,
526                            GuardAction::Sanitize => 3,
527                            GuardAction::Suggest => 2,
528                            GuardAction::Log => 1,
529                            GuardAction::Allow => 0,
530                        })
531                        .unwrap_or(GuardAction::Block);
532                    (false, action)
533                } else {
534                    (true, GuardAction::Allow)
535                }
536            }
537            AggregationStrategy::Unanimous => {
538                // All must block to block
539                let total = layer_results.len();
540                if blocked_count == total && total > 0 {
541                    (false, GuardAction::Block)
542                } else {
543                    (true, GuardAction::Allow)
544                }
545            }
546            AggregationStrategy::Majority => {
547                // >50% must block
548                let total = layer_results.len();
549                if blocked_count * 2 > total {
550                    (false, GuardAction::Block)
551                } else {
552                    (true, GuardAction::Allow)
553                }
554            }
555            AggregationStrategy::Weighted => {
556                // Weighted voting
557                if total_weight > 0.0 && blocked_weight / total_weight > 0.5 {
558                    (false, GuardAction::Block)
559                } else {
560                    (true, GuardAction::Allow)
561                }
562            }
563            AggregationStrategy::Comprehensive => {
564                // Report all, block if any blocked
565                if blocked_count > 0 {
566                    (false, GuardAction::Block)
567                } else {
568                    (true, GuardAction::Allow)
569                }
570            }
571        }
572    }
573
574    /// Build a summary message
575    fn build_summary(
576        &self,
577        layer_results: &[LayerResult],
578        passed: bool,
579        duration: Duration,
580    ) -> String {
581        let triggered: Vec<&str> = layer_results
582            .iter()
583            .filter(|r| !r.result.passed)
584            .map(|r| r.layer_name.as_str())
585            .collect();
586
587        if passed {
588            format!(
589                "Passed all {} layers in {}ms",
590                layer_results.len(),
591                duration.as_millis()
592            )
593        } else {
594            format!(
595                "Blocked by {} of {} layers ({}) in {}ms",
596                triggered.len(),
597                layer_results.len(),
598                triggered.join(", "),
599                duration.as_millis()
600            )
601        }
602    }
603
604    /// Get the name
605    pub fn name(&self) -> &str {
606        &self.name
607    }
608
609    /// Get layer configurations
610    pub fn layers(&self) -> &[LayerConfig] {
611        &self.layers
612    }
613
614    /// Get the aggregation strategy
615    pub fn strategy(&self) -> AggregationStrategy {
616        self.strategy
617    }
618
619    /// Get telemetry data (if enabled)
620    pub fn telemetry(&self) -> Option<&DefenseTelemetry> {
621        self.telemetry.as_ref()
622    }
623
624    /// Check if telemetry is enabled
625    pub fn has_telemetry(&self) -> bool {
626        self.telemetry.is_some()
627    }
628}
629
630/// Builder for MultiLayerDefense
631pub struct MultiLayerDefenseBuilder {
632    name: String,
633    layers: Vec<LayerConfig>,
634    guards: Vec<Box<dyn Guard>>,
635    strategy: AggregationStrategy,
636    enable_telemetry: bool,
637}
638
639impl MultiLayerDefenseBuilder {
640    /// Create a new builder
641    pub fn new(name: impl Into<String>) -> Self {
642        Self {
643            name: name.into(),
644            layers: Vec::new(),
645            guards: Vec::new(),
646            strategy: AggregationStrategy::FailFast,
647            enable_telemetry: false,
648        }
649    }
650
651    /// Add a layer with a guard
652    pub fn add_guard(mut self, config: LayerConfig, guard: Box<dyn Guard>) -> Self {
653        self.layers.push(config);
654        self.guards.push(guard);
655        self
656    }
657
658    /// Set the aggregation strategy
659    pub fn with_strategy(mut self, strategy: AggregationStrategy) -> Self {
660        self.strategy = strategy;
661        self
662    }
663
664    /// Enable telemetry collection
665    pub fn with_telemetry(mut self, enabled: bool) -> Self {
666        self.enable_telemetry = enabled;
667        self
668    }
669
670    /// Build the defense with license validation.
671    ///
672    /// # License Requirement
673    ///
674    /// MultiLayerDefense requires a Professional or Enterprise license.
675    /// Returns an error if the license requirement is not met.
676    pub fn build(self) -> Result<MultiLayerDefense, LicenseError> {
677        require_feature_sync(Feature::MultiLayerDefense)?;
678        Ok(self.build_unchecked())
679    }
680
681    /// Build the defense without license validation.
682    ///
683    /// Restricted to crate-internal use.
684    pub(crate) fn build_unchecked(self) -> MultiLayerDefense {
685        MultiLayerDefense {
686            name: self.name,
687            layers: self.layers,
688            guards: self.guards,
689            strategy: self.strategy,
690            telemetry: if self.enable_telemetry {
691                Some(DefenseTelemetry::default())
692            } else {
693                None
694            },
695        }
696    }
697}
698
#[cfg(test)]
mod tests {
    use super::*;
    use crate::guard::LengthGuard;
    use crate::guards::PerplexityGuard;

    // Note: Tests use build_unchecked to bypass license validation since we're
    // testing guard functionality, not license enforcement.

    #[test]
    fn test_layer_config_defaults() {
        let regex = LayerConfig::regex();
        assert_eq!(regex.layer_type, LayerType::Regex);
        assert!(regex.enabled);

        let perplexity = LayerConfig::perplexity();
        assert_eq!(perplexity.layer_type, LayerType::Perplexity);
    }

    #[test]
    fn test_with_weight_clamps_negative_to_zero() {
        assert_eq!(LayerConfig::regex().with_weight(-2.0).weight, 0.0);
        assert_eq!(LayerConfig::regex().with_weight(2.5).weight, 2.5);
    }

    #[test]
    fn test_multilayer_builder() {
        let defense = MultiLayerDefense::builder("test")
            .add_guard(
                LayerConfig::regex(),
                Box::new(LengthGuard::new("length").with_max_chars(1000)),
            )
            .with_strategy(AggregationStrategy::FailFast)
            .build_unchecked();

        assert_eq!(defense.name(), "test");
        assert_eq!(defense.strategy(), AggregationStrategy::FailFast);
        assert_eq!(defense.layers().len(), 1);
    }

    #[test]
    fn test_multilayer_pass() {
        let defense = MultiLayerDefense::builder("test")
            .add_guard(
                LayerConfig::regex(),
                Box::new(LengthGuard::new("length").with_max_chars(1000)),
            )
            .build_unchecked();

        let result = defense.check("Short text");
        assert!(result.passed);
        assert_eq!(result.action, GuardAction::Allow);
    }

    #[test]
    fn test_multilayer_fail_fast() {
        let defense = MultiLayerDefense::builder("test")
            .add_guard(
                LayerConfig::regex(),
                Box::new(LengthGuard::new("length").with_max_chars(5)),
            )
            .add_guard(
                LayerConfig::perplexity(),
                Box::new(PerplexityGuard::new("perplexity")),
            )
            .with_strategy(AggregationStrategy::FailFast)
            .build_unchecked();

        let result = defense.check("This is too long");
        assert!(!result.passed);
        // With FailFast, should stop after first failure
        assert_eq!(result.layer_results.len(), 1);
    }

    #[test]
    fn test_multilayer_comprehensive() {
        let defense = MultiLayerDefense::builder("test")
            .add_guard(
                LayerConfig::regex(),
                Box::new(LengthGuard::new("length").with_max_chars(5)),
            )
            .add_guard(
                LayerConfig::perplexity(),
                Box::new(PerplexityGuard::new("perplexity")),
            )
            .with_strategy(AggregationStrategy::Comprehensive)
            .build_unchecked();

        let result = defense.check("This is too long");
        assert!(!result.passed);
        // With Comprehensive, should run all layers
        assert_eq!(result.layer_results.len(), 2);
    }

    #[test]
    fn test_disabled_layer_is_skipped() {
        let defense = MultiLayerDefense::builder("disabled")
            .add_guard(
                LayerConfig::regex().with_enabled(false),
                Box::new(LengthGuard::new("l1").with_max_chars(5)), // Would block if enabled
            )
            .add_guard(
                LayerConfig::perplexity(),
                Box::new(LengthGuard::new("l2").with_max_chars(1000)), // Passes
            )
            .build_unchecked();

        let result = defense.check("This is too long");
        assert!(result.passed);
        assert_eq!(result.layer_results.len(), 1);
        assert_eq!(result.layer_results[0].layer_name, "perplexity");
    }

    #[test]
    fn test_aggregation_unanimous() {
        let defense = MultiLayerDefense::builder("test")
            .add_guard(
                LayerConfig::regex(),
                Box::new(LengthGuard::new("length").with_max_chars(5)), // Will fail
            )
            .add_guard(
                LayerConfig::perplexity(),
                Box::new(LengthGuard::new("length2").with_max_chars(1000)), // Will pass
            )
            .with_strategy(AggregationStrategy::Unanimous)
            .build_unchecked();

        let result = defense.check("Medium text");
        // Unanimous requires all to block, one passes so overall passes
        assert!(result.passed);
    }

    #[test]
    fn test_aggregation_majority() {
        let defense = MultiLayerDefense::builder("test")
            .add_guard(
                LayerConfig::regex().with_weight(1.0),
                Box::new(LengthGuard::new("l1").with_max_chars(5)), // Fails
            )
            .add_guard(
                LayerConfig::regex().with_weight(1.0),
                Box::new(LengthGuard::new("l2").with_max_chars(5)), // Fails
            )
            .add_guard(
                LayerConfig::regex().with_weight(1.0),
                Box::new(LengthGuard::new("l3").with_max_chars(1000)), // Passes
            )
            .with_strategy(AggregationStrategy::Majority)
            .build_unchecked();

        let result = defense.check("Medium length text");
        // 2/3 blocked = majority, should block
        assert!(!result.passed);
    }

    #[test]
    fn test_aggregation_weighted() {
        // The failing layer carries 3/4 of the weight -> block.
        let heavy_block = MultiLayerDefense::builder("weighted_block")
            .add_guard(
                LayerConfig::regex().with_weight(3.0),
                Box::new(LengthGuard::new("l1").with_max_chars(5)), // Fails
            )
            .add_guard(
                LayerConfig::perplexity().with_weight(1.0),
                Box::new(LengthGuard::new("l2").with_max_chars(1000)), // Passes
            )
            .with_strategy(AggregationStrategy::Weighted)
            .build_unchecked();

        let result = heavy_block.check("Medium length text");
        assert!(!result.passed);
        assert_eq!(result.action, GuardAction::Block);
        // Weighted never terminates early.
        assert_eq!(result.layer_results.len(), 2);

        // The failing layer carries only 1/4 of the weight -> allow.
        let light_block = MultiLayerDefense::builder("weighted_allow")
            .add_guard(
                LayerConfig::regex().with_weight(1.0),
                Box::new(LengthGuard::new("l1").with_max_chars(5)), // Fails
            )
            .add_guard(
                LayerConfig::perplexity().with_weight(3.0),
                Box::new(LengthGuard::new("l2").with_max_chars(1000)), // Passes
            )
            .with_strategy(AggregationStrategy::Weighted)
            .build_unchecked();

        let result = light_block.check("Medium length text");
        assert!(result.passed);
        assert_eq!(result.action, GuardAction::Allow);
    }

    #[test]
    fn test_layer_results_contain_timing() {
        let defense = MultiLayerDefense::builder("test")
            .add_guard(
                LayerConfig::regex(),
                Box::new(LengthGuard::new("length").with_max_chars(1000)),
            )
            .build_unchecked();

        let result = defense.check("Test");
        assert!(!result.layer_results.is_empty());
        assert!(result.layer_results[0].duration.as_nanos() > 0);
        assert!(result.total_duration.as_nanos() > 0);
    }

    #[test]
    fn test_telemetry_recording() {
        let defense = MultiLayerDefense::builder("telemetry_test")
            .add_guard(
                LayerConfig::regex(), // Layer name is "regex"
                Box::new(LengthGuard::new("length").with_max_chars(1000)),
            )
            .add_guard(
                LayerConfig::perplexity(), // Layer name is "perplexity"
                Box::new(PerplexityGuard::new("perplexity")),
            )
            .with_telemetry(true)
            .build_unchecked();

        // Verify telemetry is enabled
        assert!(defense.has_telemetry());

        // Run some checks
        defense.check("Short text");
        defense.check("Another short text");
        defense.check("x".repeat(2000).as_str()); // This will fail length check

        // Verify telemetry recorded
        let telemetry = defense.telemetry().unwrap();
        assert_eq!(telemetry.total_checks(), 3);
        assert_eq!(telemetry.passed_checks(), 2);
        assert_eq!(telemetry.blocked_checks(), 1);
        assert!(telemetry.block_rate() > 0.0);

        // Verify per-layer stats (layer name comes from LayerConfig, not guard)
        let layer_stats = telemetry.layer_stats();
        assert!(
            layer_stats.contains_key("regex"),
            "Expected 'regex' layer stats"
        );
        let regex_stats = &layer_stats["regex"];
        assert_eq!(regex_stats.checks, 3);
        assert_eq!(regex_stats.detections, 1);
        assert!(regex_stats.avg_duration_ms > 0.0);
    }

    #[test]
    fn test_telemetry_disabled_by_default() {
        let defense = MultiLayerDefense::builder("no_telemetry")
            .add_guard(
                LayerConfig::regex(),
                Box::new(LengthGuard::new("length").with_max_chars(1000)),
            )
            .build_unchecked();

        assert!(!defense.has_telemetry());
        assert!(defense.telemetry().is_none());
    }

    #[test]
    fn test_majority_early_termination() {
        // Create 5 layers where 3 will block immediately
        let defense = MultiLayerDefense::builder("majority_early")
            .add_guard(
                LayerConfig::regex(),
                Box::new(LengthGuard::new("l1").with_max_chars(5)), // Blocks
            )
            .add_guard(
                LayerConfig::regex(),
                Box::new(LengthGuard::new("l2").with_max_chars(5)), // Blocks
            )
            .add_guard(
                LayerConfig::regex(),
                Box::new(LengthGuard::new("l3").with_max_chars(5)), // Blocks
            )
            .add_guard(
                LayerConfig::regex(),
                Box::new(LengthGuard::new("l4").with_max_chars(1000)), // Would pass
            )
            .add_guard(
                LayerConfig::regex(),
                Box::new(LengthGuard::new("l5").with_max_chars(1000)), // Would pass
            )
            .with_strategy(AggregationStrategy::Majority)
            .build_unchecked();

        let result = defense.check("This text is longer than 5 chars");
        assert!(!result.passed);

        // With 5 layers, majority is 3. After 3 blocks, we should terminate early
        // So layer_results should have 3 entries, not 5
        assert_eq!(
            result.layer_results.len(),
            3,
            "Should terminate after reaching majority"
        );
    }

    #[test]
    fn test_license_check_required() {
        // `build()` runs the license check. Without a Professional/Enterprise
        // license it returns `Err(LicenseError)`; when a license is present, the
        // resulting defense must be configured exactly as requested.
        let builder = MultiLayerDefense::builder("test").add_guard(
            LayerConfig::regex(),
            Box::new(LengthGuard::new("length").with_max_chars(1000)),
        );

        if let Ok(defense) = builder.build() {
            assert_eq!(defense.name(), "test");
            assert_eq!(defense.layers().len(), 1);
        }
    }
}