1use crate::registry::ModelVariant;
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use std::time::Duration;
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct MLConfig {
73 pub enabled: bool,
75
76 pub model_variant: ModelVariant,
78
79 pub threshold: f32,
84
85 pub fallback_to_heuristic: bool,
87
88 pub cache_enabled: bool,
90
91 pub cache_config: CacheSettings,
93
94 #[serde(flatten)]
96 pub extra: HashMap<String, serde_json::Value>,
97}
98
99impl Default for MLConfig {
100 fn default() -> Self {
101 Self {
102 enabled: false, model_variant: ModelVariant::FP16,
104 threshold: 0.5,
105 fallback_to_heuristic: true,
106 cache_enabled: true,
107 cache_config: CacheSettings::default(),
108 extra: HashMap::new(),
109 }
110 }
111}
112
113impl MLConfig {
114 pub fn new(model_variant: ModelVariant, threshold: f32) -> Self {
116 Self {
117 enabled: true,
118 model_variant,
119 threshold,
120 ..Default::default()
121 }
122 }
123
124 pub fn production() -> Self {
131 Self {
132 enabled: true,
133 model_variant: ModelVariant::FP16,
134 threshold: 0.5,
135 fallback_to_heuristic: true,
136 cache_enabled: true,
137 cache_config: CacheSettings::production(),
138 extra: HashMap::new(),
139 }
140 }
141
142 pub fn edge() -> Self {
149 Self {
150 enabled: true,
151 model_variant: ModelVariant::INT8,
152 threshold: 0.6,
153 fallback_to_heuristic: true,
154 cache_enabled: true,
155 cache_config: CacheSettings::edge(),
156 extra: HashMap::new(),
157 }
158 }
159
160 pub fn high_accuracy() -> Self {
167 Self {
168 enabled: true,
169 model_variant: ModelVariant::FP32,
170 threshold: 0.3,
171 fallback_to_heuristic: false,
172 cache_enabled: true,
173 cache_config: CacheSettings::aggressive(),
174 extra: HashMap::new(),
175 }
176 }
177
178 pub fn disabled() -> Self {
180 Self {
181 enabled: false,
182 ..Default::default()
183 }
184 }
185
186 pub fn validate(&self) -> Result<(), String> {
188 if !(0.0..=1.0).contains(&self.threshold) {
189 return Err(format!(
190 "Threshold must be between 0.0 and 1.0, got {}",
191 self.threshold
192 ));
193 }
194 Ok(())
195 }
196}
197
198#[derive(Debug, Clone, Serialize, Deserialize)]
200pub struct CacheSettings {
201 pub max_size: usize,
203
204 #[serde(with = "duration_serde")]
206 pub ttl: Duration,
207}
208
209impl Default for CacheSettings {
210 fn default() -> Self {
211 Self {
212 max_size: 1000,
213 ttl: Duration::from_secs(3600), }
215 }
216}
217
218impl CacheSettings {
219 pub fn production() -> Self {
223 Self {
224 max_size: 1000,
225 ttl: Duration::from_secs(3600),
226 }
227 }
228
229 pub fn edge() -> Self {
233 Self {
234 max_size: 100,
235 ttl: Duration::from_secs(600),
236 }
237 }
238
239 pub fn aggressive() -> Self {
243 Self {
244 max_size: 10000,
245 ttl: Duration::from_secs(7200),
246 }
247 }
248
249 pub fn minimal() -> Self {
253 Self {
254 max_size: 10,
255 ttl: Duration::from_secs(60),
256 }
257 }
258
259 pub fn disabled() -> Self {
261 Self {
262 max_size: 0,
263 ttl: Duration::from_secs(0),
264 }
265 }
266}
267
268#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
290#[serde(rename_all = "snake_case")]
291pub enum HybridMode {
292 HeuristicOnly,
294
295 MLOnly,
297
298 Hybrid,
300
301 Both,
303}
304
305impl Default for HybridMode {
306 fn default() -> Self {
307 Self::Hybrid
308 }
309}
310
311#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
316pub enum DetectionMethod {
317 #[serde(rename = "heuristic")]
319 Heuristic,
320
321 #[serde(rename = "ml")]
323 ML,
324
325 #[serde(rename = "heuristic_short_circuit")]
327 HeuristicShortCircuit,
328
329 #[serde(rename = "ml_fallback_to_heuristic")]
331 MLFallbackToHeuristic,
332
333 #[serde(rename = "hybrid_both")]
335 HybridBoth,
336}
337
338#[derive(Debug, Clone, Default, Serialize, Deserialize)]
348pub struct InferenceMetrics {
349 pub total_calls: u64,
351
352 pub ml_calls: u64,
354
355 pub heuristic_calls: u64,
357
358 pub cache_hits: u64,
360
361 pub heuristic_short_circuits: u64,
363
364 pub total_inference_time_ms: u64,
366
367 pub ml_errors: u64,
369
370 pub fallback_count: u64,
372}
373
374impl InferenceMetrics {
375 pub fn cache_hit_rate(&self) -> f32 {
377 if self.total_calls == 0 {
378 0.0
379 } else {
380 self.cache_hits as f32 / self.total_calls as f32
381 }
382 }
383
384 pub fn heuristic_filter_rate(&self) -> f32 {
386 if self.total_calls == 0 {
387 0.0
388 } else {
389 self.heuristic_short_circuits as f32 / self.total_calls as f32
390 }
391 }
392
393 pub fn avg_inference_time_ms(&self) -> f32 {
395 if self.total_calls == 0 {
396 0.0
397 } else {
398 self.total_inference_time_ms as f32 / self.total_calls as f32
399 }
400 }
401
402 pub fn ml_error_rate(&self) -> f32 {
404 if self.ml_calls == 0 {
405 0.0
406 } else {
407 self.ml_errors as f32 / self.ml_calls as f32
408 }
409 }
410}
411
412mod duration_serde {
414 use serde::{Deserialize, Deserializer, Serialize, Serializer};
415 use std::time::Duration;
416
417 pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
418 where
419 S: Serializer,
420 {
421 duration.as_secs().serialize(serializer)
422 }
423
424 pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
425 where
426 D: Deserializer<'de>,
427 {
428 let secs = u64::deserialize(deserializer)?;
429 Ok(Duration::from_secs(secs))
430 }
431}
432
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ml_config_default() {
        let config = MLConfig::default();
        assert!(!config.enabled);
        assert_eq!(config.model_variant, ModelVariant::FP16);
        assert_eq!(config.threshold, 0.5);
        assert!(config.fallback_to_heuristic);
        assert!(config.cache_enabled);
    }

    #[test]
    fn test_ml_config_production() {
        let config = MLConfig::production();
        assert!(config.enabled);
        assert_eq!(config.model_variant, ModelVariant::FP16);
        assert_eq!(config.threshold, 0.5);
        assert!(config.fallback_to_heuristic);
        assert!(config.cache_enabled);
    }

    #[test]
    fn test_ml_config_edge() {
        let config = MLConfig::edge();
        assert!(config.enabled);
        assert_eq!(config.model_variant, ModelVariant::INT8);
        assert_eq!(config.threshold, 0.6);
        assert_eq!(config.cache_config.max_size, 100);
    }

    #[test]
    fn test_ml_config_high_accuracy() {
        let config = MLConfig::high_accuracy();
        assert!(config.enabled);
        assert_eq!(config.model_variant, ModelVariant::FP32);
        assert_eq!(config.threshold, 0.3);
        assert!(!config.fallback_to_heuristic);
    }

    #[test]
    fn test_ml_config_disabled() {
        let config = MLConfig::disabled();
        assert!(!config.enabled);
    }

    /// `new` enables the config and leaves every other field at its default.
    #[test]
    fn test_ml_config_new() {
        let config = MLConfig::new(ModelVariant::INT8, 0.7);
        assert!(config.enabled);
        assert_eq!(config.model_variant, ModelVariant::INT8);
        assert_eq!(config.threshold, 0.7);
        assert!(config.fallback_to_heuristic);
        assert!(config.cache_enabled);
        assert_eq!(
            config.cache_config.max_size,
            CacheSettings::default().max_size
        );
        assert!(config.extra.is_empty());
    }

    #[test]
    fn test_ml_config_validation() {
        let mut config = MLConfig::default();
        assert!(config.validate().is_ok());

        config.threshold = 1.5;
        assert!(config.validate().is_err());

        config.threshold = -0.1;
        assert!(config.validate().is_err());

        config.threshold = 0.0;
        assert!(config.validate().is_ok());

        config.threshold = 1.0;
        assert!(config.validate().is_ok());
    }

    /// NaN is not contained in `0.0..=1.0`, so validation must reject it.
    #[test]
    fn test_ml_config_validation_rejects_nan() {
        let config = MLConfig {
            threshold: f32::NAN,
            ..MLConfig::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_cache_settings_default() {
        let settings = CacheSettings::default();
        assert_eq!(settings.max_size, 1000);
        assert_eq!(settings.ttl, Duration::from_secs(3600));
    }

    #[test]
    fn test_cache_settings_production() {
        let settings = CacheSettings::production();
        assert_eq!(settings.max_size, 1000);
        assert_eq!(settings.ttl, Duration::from_secs(3600));
    }

    #[test]
    fn test_cache_settings_edge() {
        let settings = CacheSettings::edge();
        assert_eq!(settings.max_size, 100);
        assert_eq!(settings.ttl, Duration::from_secs(600));
    }

    #[test]
    fn test_cache_settings_aggressive() {
        let settings = CacheSettings::aggressive();
        assert_eq!(settings.max_size, 10000);
        assert_eq!(settings.ttl, Duration::from_secs(7200));
    }

    #[test]
    fn test_cache_settings_minimal() {
        let settings = CacheSettings::minimal();
        assert_eq!(settings.max_size, 10);
        assert_eq!(settings.ttl, Duration::from_secs(60));
    }

    #[test]
    fn test_cache_settings_disabled() {
        let settings = CacheSettings::disabled();
        assert_eq!(settings.max_size, 0);
        assert_eq!(settings.ttl, Duration::from_secs(0));
    }

    /// `ttl` goes over the wire as an integer number of seconds and comes
    /// back unchanged for whole-second durations.
    #[test]
    fn test_cache_settings_ttl_serialized_as_seconds() {
        let value = serde_json::to_value(CacheSettings::edge()).unwrap();
        assert_eq!(value["ttl"].as_u64(), Some(600));

        let restored: CacheSettings = serde_json::from_value(value).unwrap();
        assert_eq!(restored.max_size, 100);
        assert_eq!(restored.ttl, Duration::from_secs(600));
    }

    #[test]
    fn test_hybrid_mode_default() {
        assert_eq!(HybridMode::default(), HybridMode::Hybrid);
    }

    #[test]
    fn test_inference_metrics_default() {
        let metrics = InferenceMetrics::default();
        assert_eq!(metrics.total_calls, 0);
        assert_eq!(metrics.cache_hit_rate(), 0.0);
        assert_eq!(metrics.heuristic_filter_rate(), 0.0);
        assert_eq!(metrics.avg_inference_time_ms(), 0.0);
        assert_eq!(metrics.ml_error_rate(), 0.0);
    }

    #[test]
    fn test_inference_metrics_calculations() {
        let metrics = InferenceMetrics {
            total_calls: 100,
            ml_calls: 40,
            heuristic_calls: 100,
            cache_hits: 30,
            heuristic_short_circuits: 60,
            total_inference_time_ms: 5000,
            ml_errors: 4,
            fallback_count: 4,
        };

        assert_eq!(metrics.cache_hit_rate(), 0.3);
        assert_eq!(metrics.heuristic_filter_rate(), 0.6);
        assert_eq!(metrics.avg_inference_time_ms(), 50.0);
        assert_eq!(metrics.ml_error_rate(), 0.1);
    }

    #[test]
    fn test_ml_config_serialization() {
        let config = MLConfig::production();
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: MLConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(config.enabled, deserialized.enabled);
        assert_eq!(config.threshold, deserialized.threshold);
        assert_eq!(config.cache_config.max_size, deserialized.cache_config.max_size);
    }

    /// Unknown top-level keys are captured by the flattened `extra` map and
    /// written back out on re-serialization.
    #[test]
    fn test_ml_config_preserves_unknown_fields() {
        let mut value = serde_json::to_value(MLConfig::production()).unwrap();
        value["custom_flag"] = serde_json::Value::Bool(true);

        let config: MLConfig = serde_json::from_value(value).unwrap();
        assert!(config.enabled);
        assert_eq!(config.extra.len(), 1);
        assert_eq!(
            config.extra.get("custom_flag"),
            Some(&serde_json::Value::Bool(true))
        );

        let reserialized = serde_json::to_value(&config).unwrap();
        assert_eq!(reserialized["custom_flag"], serde_json::Value::Bool(true));
    }

    #[test]
    fn test_detection_method_serialization() {
        let method = DetectionMethod::ML;
        let json = serde_json::to_string(&method).unwrap();
        assert_eq!(json, "\"ml\"");

        let deserialized: DetectionMethod = serde_json::from_str(&json).unwrap();
        assert_eq!(method, deserialized);
    }

    /// Pins every `DetectionMethod` wire name in both directions.
    #[test]
    fn test_detection_method_wire_names() {
        let cases = [
            (DetectionMethod::Heuristic, "\"heuristic\""),
            (DetectionMethod::ML, "\"ml\""),
            (
                DetectionMethod::HeuristicShortCircuit,
                "\"heuristic_short_circuit\"",
            ),
            (
                DetectionMethod::MLFallbackToHeuristic,
                "\"ml_fallback_to_heuristic\"",
            ),
            (DetectionMethod::HybridBoth, "\"hybrid_both\""),
        ];
        for &(method, json) in cases.iter() {
            assert_eq!(serde_json::to_string(&method).unwrap(), json);
            let parsed: DetectionMethod = serde_json::from_str(json).unwrap();
            assert_eq!(parsed, method);
        }
    }

    #[test]
    fn test_hybrid_mode_serialization() {
        let mode = HybridMode::Hybrid;
        let json = serde_json::to_string(&mode).unwrap();
        assert_eq!(json, "\"hybrid\"");

        let deserialized: HybridMode = serde_json::from_str(&json).unwrap();
        assert_eq!(mode, deserialized);
    }

    /// Pins the `HybridMode` wire names and checks that every variant
    /// round-trips, including `MLOnly`.
    #[test]
    fn test_hybrid_mode_wire_names() {
        let cases = [
            (HybridMode::HeuristicOnly, "\"heuristic_only\""),
            (HybridMode::Hybrid, "\"hybrid\""),
            (HybridMode::Both, "\"both\""),
        ];
        for &(mode, json) in cases.iter() {
            assert_eq!(serde_json::to_string(&mode).unwrap(), json);
            let parsed: HybridMode = serde_json::from_str(json).unwrap();
            assert_eq!(parsed, mode);
        }

        let ml_json = serde_json::to_string(&HybridMode::MLOnly).unwrap();
        let ml_back: HybridMode = serde_json::from_str(&ml_json).unwrap();
        assert_eq!(ml_back, HybridMode::MLOnly);
    }
}