llm_shield_models/
types.rs

1//! Common Types for ML Model Integration
2//!
3//! ## SPARC Phase 1: Specification
4//!
5//! This module defines common types for ML model configuration and hybrid
6//! detection mode that combines heuristic and ML approaches.
7//!
8//! ## Design Principles
9//!
10//! - **Flexibility**: Support multiple model variants (FP32, FP16, INT8)
11//! - **Graceful Degradation**: Fallback to heuristics if ML fails
12//! - **Performance**: Cache-aware configuration
13//! - **Observability**: Rich metadata for monitoring
14
15use crate::registry::ModelVariant;
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use std::time::Duration;
19
20/// ML detection configuration for scanners
21///
22/// ## Fields
23///
24/// - `enabled`: Whether ML detection is enabled
25/// - `model_variant`: Model precision (FP32, FP16, INT8)
26/// - `threshold`: Detection threshold (0.0 to 1.0)
27/// - `fallback_to_heuristic`: Use heuristic if ML fails
28/// - `cache_enabled`: Enable result caching
29/// - `cache_config`: Cache configuration
30///
31/// ## Recommended Configurations
32///
33/// **Production (balanced)**:
34/// ```rust,ignore
35/// MLConfig {
36///     enabled: true,
37///     model_variant: ModelVariant::FP16,
38///     threshold: 0.5,
39///     fallback_to_heuristic: true,
40///     cache_enabled: true,
41///     cache_config: CacheSettings::default(),
42/// }
43/// ```
44///
45/// **High accuracy**:
46/// ```rust,ignore
47/// MLConfig {
48///     enabled: true,
49///     model_variant: ModelVariant::FP32,
50///     threshold: 0.3,  // Lower threshold = more sensitive
51///     fallback_to_heuristic: false,  // Require ML
52///     cache_enabled: true,
53///     cache_config: CacheSettings::default(),
54/// }
55/// ```
56///
57/// **Edge/Mobile**:
58/// ```rust,ignore
59/// MLConfig {
60///     enabled: true,
61///     model_variant: ModelVariant::INT8,
62///     threshold: 0.6,
63///     fallback_to_heuristic: true,
64///     cache_enabled: true,
65///     cache_config: CacheSettings {
66///         max_size: 100,  // Smaller cache
67///         ttl: Duration::from_secs(600),
68///     },
69/// }
70/// ```
71#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct MLConfig {
73    /// Whether ML detection is enabled
74    pub enabled: bool,
75
76    /// Model variant to use (FP32, FP16, INT8)
77    pub model_variant: ModelVariant,
78
79    /// Detection threshold (0.0 to 1.0)
80    /// - Higher = fewer false positives, more false negatives
81    /// - Lower = fewer false negatives, more false positives
82    /// - Recommended: 0.5 for balanced results
83    pub threshold: f32,
84
85    /// Use heuristic detection if ML fails or is unavailable
86    pub fallback_to_heuristic: bool,
87
88    /// Enable result caching for repeated inputs
89    pub cache_enabled: bool,
90
91    /// Cache settings
92    pub cache_config: CacheSettings,
93
94    /// Additional model-specific configuration
95    #[serde(flatten)]
96    pub extra: HashMap<String, serde_json::Value>,
97}
98
99impl Default for MLConfig {
100    fn default() -> Self {
101        Self {
102            enabled: false, // Opt-in for ML
103            model_variant: ModelVariant::FP16,
104            threshold: 0.5,
105            fallback_to_heuristic: true,
106            cache_enabled: true,
107            cache_config: CacheSettings::default(),
108            extra: HashMap::new(),
109        }
110    }
111}
112
113impl MLConfig {
114    /// Create a new ML configuration
115    pub fn new(model_variant: ModelVariant, threshold: f32) -> Self {
116        Self {
117            enabled: true,
118            model_variant,
119            threshold,
120            ..Default::default()
121        }
122    }
123
124    /// Create ML configuration for production use
125    ///
126    /// - FP16 model (balanced speed/accuracy)
127    /// - 0.5 threshold (balanced sensitivity)
128    /// - Heuristic fallback enabled
129    /// - Caching enabled with 1000 entries, 1 hour TTL
130    pub fn production() -> Self {
131        Self {
132            enabled: true,
133            model_variant: ModelVariant::FP16,
134            threshold: 0.5,
135            fallback_to_heuristic: true,
136            cache_enabled: true,
137            cache_config: CacheSettings::production(),
138            extra: HashMap::new(),
139        }
140    }
141
142    /// Create ML configuration for edge/mobile deployment
143    ///
144    /// - INT8 model (smallest size)
145    /// - 0.6 threshold (fewer false positives)
146    /// - Heuristic fallback enabled
147    /// - Smaller cache (100 entries, 10 minutes TTL)
148    pub fn edge() -> Self {
149        Self {
150            enabled: true,
151            model_variant: ModelVariant::INT8,
152            threshold: 0.6,
153            fallback_to_heuristic: true,
154            cache_enabled: true,
155            cache_config: CacheSettings::edge(),
156            extra: HashMap::new(),
157        }
158    }
159
160    /// Create ML configuration for high accuracy
161    ///
162    /// - FP32 model (highest accuracy)
163    /// - 0.3 threshold (very sensitive)
164    /// - No heuristic fallback
165    /// - Aggressive caching
166    pub fn high_accuracy() -> Self {
167        Self {
168            enabled: true,
169            model_variant: ModelVariant::FP32,
170            threshold: 0.3,
171            fallback_to_heuristic: false,
172            cache_enabled: true,
173            cache_config: CacheSettings::aggressive(),
174            extra: HashMap::new(),
175        }
176    }
177
178    /// Disable ML detection (heuristic-only mode)
179    pub fn disabled() -> Self {
180        Self {
181            enabled: false,
182            ..Default::default()
183        }
184    }
185
186    /// Validate configuration
187    pub fn validate(&self) -> Result<(), String> {
188        if !(0.0..=1.0).contains(&self.threshold) {
189            return Err(format!(
190                "Threshold must be between 0.0 and 1.0, got {}",
191                self.threshold
192            ));
193        }
194        Ok(())
195    }
196}
197
198/// Cache settings for ML result caching
199#[derive(Debug, Clone, Serialize, Deserialize)]
200pub struct CacheSettings {
201    /// Maximum number of cached entries (LRU eviction)
202    pub max_size: usize,
203
204    /// Time-to-live for cache entries
205    #[serde(with = "duration_serde")]
206    pub ttl: Duration,
207}
208
209impl Default for CacheSettings {
210    fn default() -> Self {
211        Self {
212            max_size: 1000,
213            ttl: Duration::from_secs(3600), // 1 hour
214        }
215    }
216}
217
218impl CacheSettings {
219    /// Production cache settings
220    /// - 1000 entries
221    /// - 1 hour TTL
222    pub fn production() -> Self {
223        Self {
224            max_size: 1000,
225            ttl: Duration::from_secs(3600),
226        }
227    }
228
229    /// Edge/mobile cache settings (smaller)
230    /// - 100 entries
231    /// - 10 minutes TTL
232    pub fn edge() -> Self {
233        Self {
234            max_size: 100,
235            ttl: Duration::from_secs(600),
236        }
237    }
238
239    /// Aggressive caching for high-traffic scenarios
240    /// - 10000 entries
241    /// - 2 hours TTL
242    pub fn aggressive() -> Self {
243        Self {
244            max_size: 10000,
245            ttl: Duration::from_secs(7200),
246        }
247    }
248
249    /// Minimal caching (for testing or memory-constrained environments)
250    /// - 10 entries
251    /// - 1 minute TTL
252    pub fn minimal() -> Self {
253        Self {
254            max_size: 10,
255            ttl: Duration::from_secs(60),
256        }
257    }
258
259    /// Disable caching
260    pub fn disabled() -> Self {
261        Self {
262            max_size: 0,
263            ttl: Duration::from_secs(0),
264        }
265    }
266}
267
268/// Hybrid detection mode
269///
270/// ## Specification
271///
272/// Hybrid mode combines fast heuristic detection with accurate ML detection:
273///
274/// 1. **Heuristic Pre-filter**: Fast pattern-based detection (0.01ms)
275///    - If obviously safe → return safe (60-70% of inputs)
276///    - If obviously malicious → return malicious (5-10% of inputs)
277///    - If ambiguous → proceed to ML (20-30% of inputs)
278///
279/// 2. **ML Detection**: Accurate but slower (50-150ms)
280///    - Only called for ambiguous cases
281///    - Results are cached
282///    - Falls back to heuristic on error (if enabled)
283///
284/// ## Performance Impact
285///
286/// - Pure heuristic: ~15,500 req/sec
287/// - Pure ML: ~150 req/sec
288/// - Hybrid: ~2,000 req/sec (10x faster than pure ML)
289#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
290#[serde(rename_all = "snake_case")]
291pub enum HybridMode {
292    /// Only use heuristic detection (no ML)
293    HeuristicOnly,
294
295    /// Only use ML detection (no heuristic pre-filter)
296    MLOnly,
297
298    /// Use heuristic pre-filter, then ML for ambiguous cases
299    Hybrid,
300
301    /// Use both and combine results (max risk score)
302    Both,
303}
304
305impl Default for HybridMode {
306    fn default() -> Self {
307        Self::Hybrid
308    }
309}
310
311/// Detection method used for a scan result
312///
313/// Tracks which method(s) were used to generate the result.
314/// Useful for monitoring and debugging.
315#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
316pub enum DetectionMethod {
317    /// Only heuristic pattern matching was used
318    #[serde(rename = "heuristic")]
319    Heuristic,
320
321    /// Only ML model inference was used
322    #[serde(rename = "ml")]
323    ML,
324
325    /// Heuristic pre-filter detected safe/malicious
326    #[serde(rename = "heuristic_short_circuit")]
327    HeuristicShortCircuit,
328
329    /// ML was attempted but failed, fell back to heuristic
330    #[serde(rename = "ml_fallback_to_heuristic")]
331    MLFallbackToHeuristic,
332
333    /// Both heuristic and ML were used, results combined
334    #[serde(rename = "hybrid_both")]
335    HybridBoth,
336}
337
338/// Inference performance metrics
339///
340/// ## Specification
341///
342/// These metrics should be collected and reported for monitoring:
343/// - Latency (p50, p95, p99)
344/// - Throughput
345/// - Cache hit rate
346/// - Heuristic filter rate
347#[derive(Debug, Clone, Default, Serialize, Deserialize)]
348pub struct InferenceMetrics {
349    /// Total inference calls
350    pub total_calls: u64,
351
352    /// ML inference calls (not cached)
353    pub ml_calls: u64,
354
355    /// Heuristic pre-filter calls
356    pub heuristic_calls: u64,
357
358    /// Cache hits
359    pub cache_hits: u64,
360
361    /// Heuristic short-circuits (didn't need ML)
362    pub heuristic_short_circuits: u64,
363
364    /// Total inference time (milliseconds)
365    pub total_inference_time_ms: u64,
366
367    /// ML inference errors
368    pub ml_errors: u64,
369
370    /// Fallback to heuristic count
371    pub fallback_count: u64,
372}
373
374impl InferenceMetrics {
375    /// Calculate cache hit rate (0.0 to 1.0)
376    pub fn cache_hit_rate(&self) -> f32 {
377        if self.total_calls == 0 {
378            0.0
379        } else {
380            self.cache_hits as f32 / self.total_calls as f32
381        }
382    }
383
384    /// Calculate heuristic filter rate (% of inputs filtered by heuristic)
385    pub fn heuristic_filter_rate(&self) -> f32 {
386        if self.total_calls == 0 {
387            0.0
388        } else {
389            self.heuristic_short_circuits as f32 / self.total_calls as f32
390        }
391    }
392
393    /// Calculate average inference time (milliseconds)
394    pub fn avg_inference_time_ms(&self) -> f32 {
395        if self.total_calls == 0 {
396            0.0
397        } else {
398            self.total_inference_time_ms as f32 / self.total_calls as f32
399        }
400    }
401
402    /// Calculate ML error rate
403    pub fn ml_error_rate(&self) -> f32 {
404        if self.ml_calls == 0 {
405            0.0
406        } else {
407            self.ml_errors as f32 / self.ml_calls as f32
408        }
409    }
410}
411
412/// Serialization helper for Duration
413mod duration_serde {
414    use serde::{Deserialize, Deserializer, Serialize, Serializer};
415    use std::time::Duration;
416
417    pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
418    where
419        S: Serializer,
420    {
421        duration.as_secs().serialize(serializer)
422    }
423
424    pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
425    where
426        D: Deserializer<'de>,
427    {
428        let secs = u64::deserialize(deserializer)?;
429        Ok(Duration::from_secs(secs))
430    }
431}
432
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ml_config_default() {
        let config = MLConfig::default();
        assert!(!config.enabled);
        assert_eq!(config.model_variant, ModelVariant::FP16);
        assert_eq!(config.threshold, 0.5);
        assert!(config.fallback_to_heuristic);
        assert!(config.cache_enabled);
        assert!(config.extra.is_empty());
    }

    #[test]
    fn test_ml_config_new() {
        let config = MLConfig::new(ModelVariant::INT8, 0.7);
        assert!(config.enabled);
        assert_eq!(config.model_variant, ModelVariant::INT8);
        assert_eq!(config.threshold, 0.7);
        // Remaining fields come from the default configuration.
        assert!(config.fallback_to_heuristic);
        assert!(config.cache_enabled);
        assert!(config.extra.is_empty());
    }

    #[test]
    fn test_ml_config_production() {
        let config = MLConfig::production();
        assert!(config.enabled);
        assert_eq!(config.model_variant, ModelVariant::FP16);
        assert_eq!(config.threshold, 0.5);
        assert!(config.fallback_to_heuristic);
        assert!(config.cache_enabled);
        assert_eq!(config.cache_config.max_size, 1000);
        assert_eq!(config.cache_config.ttl, Duration::from_secs(3600));
    }

    #[test]
    fn test_ml_config_edge() {
        let config = MLConfig::edge();
        assert!(config.enabled);
        assert_eq!(config.model_variant, ModelVariant::INT8);
        assert_eq!(config.threshold, 0.6);
        assert_eq!(config.cache_config.max_size, 100);
    }

    #[test]
    fn test_ml_config_high_accuracy() {
        let config = MLConfig::high_accuracy();
        assert!(config.enabled);
        assert_eq!(config.model_variant, ModelVariant::FP32);
        assert_eq!(config.threshold, 0.3);
        assert!(!config.fallback_to_heuristic);
        assert_eq!(config.cache_config.max_size, 10000);
    }

    #[test]
    fn test_ml_config_disabled() {
        let config = MLConfig::disabled();
        assert!(!config.enabled);
    }

    #[test]
    fn test_ml_config_validation() {
        let mut config = MLConfig::default();
        assert!(config.validate().is_ok());

        config.threshold = 1.5;
        assert_eq!(
            config.validate().unwrap_err(),
            "Threshold must be between 0.0 and 1.0, got 1.5"
        );

        config.threshold = -0.1;
        assert!(config.validate().is_err());

        // NaN compares false against both bounds and must not slip through.
        config.threshold = f32::NAN;
        assert!(config.validate().is_err());

        config.threshold = 0.0;
        assert!(config.validate().is_ok());

        config.threshold = 1.0;
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_cache_settings_default() {
        let settings = CacheSettings::default();
        assert_eq!(settings.max_size, 1000);
        assert_eq!(settings.ttl, Duration::from_secs(3600));
    }

    #[test]
    fn test_cache_settings_production() {
        let settings = CacheSettings::production();
        assert_eq!(settings.max_size, 1000);
        assert_eq!(settings.ttl, Duration::from_secs(3600));
    }

    #[test]
    fn test_cache_settings_edge() {
        let settings = CacheSettings::edge();
        assert_eq!(settings.max_size, 100);
        assert_eq!(settings.ttl, Duration::from_secs(600));
    }

    #[test]
    fn test_cache_settings_aggressive() {
        let settings = CacheSettings::aggressive();
        assert_eq!(settings.max_size, 10000);
        assert_eq!(settings.ttl, Duration::from_secs(7200));
    }

    #[test]
    fn test_cache_settings_minimal() {
        let settings = CacheSettings::minimal();
        assert_eq!(settings.max_size, 10);
        assert_eq!(settings.ttl, Duration::from_secs(60));
    }

    #[test]
    fn test_cache_settings_disabled() {
        let settings = CacheSettings::disabled();
        assert_eq!(settings.max_size, 0);
        assert_eq!(settings.ttl, Duration::from_secs(0));
    }

    #[test]
    fn test_cache_settings_serialization() {
        // The TTL is written as a bare number of seconds.
        let settings = CacheSettings::edge();
        let json = serde_json::to_string(&settings).unwrap();
        assert_eq!(json, r#"{"max_size":100,"ttl":600}"#);

        let deserialized: CacheSettings = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.max_size, 100);
        assert_eq!(deserialized.ttl, Duration::from_secs(600));
    }

    #[test]
    fn test_hybrid_mode_default() {
        assert_eq!(HybridMode::default(), HybridMode::Hybrid);
    }

    #[test]
    fn test_inference_metrics_default() {
        let metrics = InferenceMetrics::default();
        assert_eq!(metrics.total_calls, 0);
        assert_eq!(metrics.cache_hit_rate(), 0.0);
        assert_eq!(metrics.heuristic_filter_rate(), 0.0);
        assert_eq!(metrics.avg_inference_time_ms(), 0.0);
        assert_eq!(metrics.ml_error_rate(), 0.0);
    }

    #[test]
    fn test_inference_metrics_calculations() {
        let metrics = InferenceMetrics {
            total_calls: 100,
            ml_calls: 40,
            heuristic_calls: 100,
            cache_hits: 30,
            heuristic_short_circuits: 60,
            total_inference_time_ms: 5000,
            ml_errors: 4,
            fallback_count: 4,
        };

        assert_eq!(metrics.cache_hit_rate(), 0.3);
        assert_eq!(metrics.heuristic_filter_rate(), 0.6);
        assert_eq!(metrics.avg_inference_time_ms(), 50.0);
        assert_eq!(metrics.ml_error_rate(), 0.1);
    }

    #[test]
    fn test_ml_config_serialization() {
        let config = MLConfig::production();
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: MLConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(config.enabled, deserialized.enabled);
        assert_eq!(config.threshold, deserialized.threshold);
        assert_eq!(config.cache_config.max_size, deserialized.cache_config.max_size);
        assert_eq!(config.cache_config.ttl, deserialized.cache_config.ttl);
    }

    #[test]
    fn test_ml_config_extra_fields_roundtrip() {
        // Unknown top-level keys are captured by the flattened `extra` map.
        let mut value = serde_json::to_value(MLConfig::production()).unwrap();
        value
            .as_object_mut()
            .expect("MLConfig serializes to a JSON object")
            .insert("custom_key".to_string(), serde_json::Value::from(42));

        let config: MLConfig = serde_json::from_value(value).unwrap();
        assert!(config.enabled);
        assert_eq!(config.extra.len(), 1);
        assert_eq!(
            config.extra.get("custom_key"),
            Some(&serde_json::Value::from(42))
        );
    }

    #[test]
    fn test_detection_method_serialization() {
        let method = DetectionMethod::ML;
        let json = serde_json::to_string(&method).unwrap();
        assert_eq!(json, "\"ml\"");

        let deserialized: DetectionMethod = serde_json::from_str(&json).unwrap();
        assert_eq!(method, deserialized);
    }

    #[test]
    fn test_detection_method_wire_names() {
        let cases = [
            (DetectionMethod::Heuristic, "\"heuristic\""),
            (DetectionMethod::ML, "\"ml\""),
            (
                DetectionMethod::HeuristicShortCircuit,
                "\"heuristic_short_circuit\"",
            ),
            (
                DetectionMethod::MLFallbackToHeuristic,
                "\"ml_fallback_to_heuristic\"",
            ),
            (DetectionMethod::HybridBoth, "\"hybrid_both\""),
        ];

        for &(method, expected) in &cases {
            let json = serde_json::to_string(&method).unwrap();
            assert_eq!(json, expected);

            let back: DetectionMethod = serde_json::from_str(&json).unwrap();
            assert_eq!(back, method);
        }
    }

    #[test]
    fn test_hybrid_mode_serialization() {
        let mode = HybridMode::Hybrid;
        let json = serde_json::to_string(&mode).unwrap();
        assert_eq!(json, "\"hybrid\"");

        let deserialized: HybridMode = serde_json::from_str(&json).unwrap();
        assert_eq!(mode, deserialized);

        assert_eq!(
            serde_json::to_string(&HybridMode::HeuristicOnly).unwrap(),
            "\"heuristic_only\""
        );
        assert_eq!(serde_json::to_string(&HybridMode::Both).unwrap(), "\"both\"");
    }
}