use crate::registry::ModelVariant;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Duration;
/// Configuration for ML-based detection.
///
/// Obtain one via [`MLConfig::default`] (ML switched off), [`MLConfig::new`],
/// or one of the presets ([`MLConfig::production`], [`MLConfig::edge`],
/// [`MLConfig::high_accuracy`], [`MLConfig::disabled`]).
/// [`MLConfig::validate`] checks that `threshold` lies in `[0.0, 1.0]`.
///
/// NOTE(review): there is no `#[serde(default)]` here, so deserializing
/// requires every named field below to be present in the input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MLConfig {
    /// Master switch for ML inference: `false` in `default()`/`disabled()`,
    /// `true` in `new()` and the other presets.
    pub enabled: bool,
    /// Model variant to use (the presets pick FP16, INT8 or FP32).
    pub model_variant: ModelVariant,
    /// Score threshold; `validate()` requires it to be within `[0.0, 1.0]`.
    /// NOTE(review): presumably the ML decision cutoff — confirm with callers.
    pub threshold: f32,
    /// Whether heuristic detection may be used as a fallback (presumably when
    /// ML inference fails — confirm against the detector implementation).
    pub fallback_to_heuristic: bool,
    /// Whether result caching is turned on.
    pub cache_enabled: bool,
    /// Cache sizing and expiry settings.
    pub cache_config: CacheSettings,
    /// Keys that do not match any field above. Because of `#[serde(flatten)]`,
    /// they are collected here on deserialization and written back at the
    /// top level on serialization.
    #[serde(flatten)]
    pub extra: HashMap<String, serde_json::Value>,
}
impl Default for MLConfig {
    /// Baseline configuration with ML inference switched off.
    ///
    /// The other settings are an FP16 model and a 0.5 threshold. Heuristic
    /// fallback and caching are on, the cache uses `CacheSettings::default()`,
    /// and there are no extra keys.
    fn default() -> Self {
        let cache_config = CacheSettings::default();
        let extra = HashMap::new();
        Self {
            enabled: false,
            model_variant: ModelVariant::FP16,
            threshold: 0.5,
            fallback_to_heuristic: true,
            cache_enabled: true,
            cache_config,
            extra,
        }
    }
}
impl MLConfig {
pub fn new(model_variant: ModelVariant, threshold: f32) -> Self {
Self {
enabled: true,
model_variant,
threshold,
..Default::default()
}
}
pub fn production() -> Self {
Self {
enabled: true,
model_variant: ModelVariant::FP16,
threshold: 0.5,
fallback_to_heuristic: true,
cache_enabled: true,
cache_config: CacheSettings::production(),
extra: HashMap::new(),
}
}
pub fn edge() -> Self {
Self {
enabled: true,
model_variant: ModelVariant::INT8,
threshold: 0.6,
fallback_to_heuristic: true,
cache_enabled: true,
cache_config: CacheSettings::edge(),
extra: HashMap::new(),
}
}
pub fn high_accuracy() -> Self {
Self {
enabled: true,
model_variant: ModelVariant::FP32,
threshold: 0.3,
fallback_to_heuristic: false,
cache_enabled: true,
cache_config: CacheSettings::aggressive(),
extra: HashMap::new(),
}
}
pub fn disabled() -> Self {
Self {
enabled: false,
..Default::default()
}
}
pub fn validate(&self) -> Result<(), String> {
if !(0.0..=1.0).contains(&self.threshold) {
return Err(format!(
"Threshold must be between 0.0 and 1.0, got {}",
self.threshold
));
}
Ok(())
}
}
/// Capacity and expiry settings for the inference cache.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheSettings {
    /// Maximum cache size (presumably an entry count; confirm with the cache
    /// implementation). `CacheSettings::disabled()` sets it to 0.
    pub max_size: usize,
    /// Time-to-live (presumably per cached entry). It is serialized through
    /// `duration_serde` as whole seconds, so any sub-second part is dropped.
    #[serde(with = "duration_serde")]
    pub ttl: Duration,
}
impl Default for CacheSettings {
    /// Default cache: up to 1000 entries with a one-hour TTL.
    fn default() -> Self {
        let max_size = 1000;
        let ttl = Duration::from_secs(60 * 60);
        Self { max_size, ttl }
    }
}
impl CacheSettings {
    /// Builds a preset from a size cap and a TTL given in whole seconds.
    const fn preset(max_size: usize, ttl_secs: u64) -> Self {
        Self {
            max_size,
            ttl: Duration::from_secs(ttl_secs),
        }
    }

    /// Production cache: 1000 entries, 1 hour TTL.
    pub fn production() -> Self {
        Self::preset(1000, 3600)
    }

    /// Small cache for constrained devices: 100 entries, 10 minute TTL.
    pub fn edge() -> Self {
        Self::preset(100, 600)
    }

    /// Large cache: 10000 entries, 2 hour TTL.
    pub fn aggressive() -> Self {
        Self::preset(10000, 7200)
    }

    /// Tiny cache: 10 entries, 1 minute TTL.
    pub fn minimal() -> Self {
        Self::preset(10, 60)
    }

    /// Zero size and zero TTL.
    pub fn disabled() -> Self {
        Self::preset(0, 0)
    }
}
/// Strategy for combining heuristic and ML detection.
///
/// Serialized in `snake_case`: `"heuristic_only"`, `"ml_only"`, `"hybrid"`,
/// `"both"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum HybridMode {
    /// Heuristic detection only.
    HeuristicOnly,
    // `rename_all = "snake_case"` inserts `_` before every uppercase letter,
    // which would turn `MLOnly` into "m_l_only". Spell it "ml_only" to match
    // the "ml" naming used by `DetectionMethod`. The old spelling is kept as
    // an alias so previously serialized configs still deserialize.
    /// ML detection only.
    #[serde(rename = "ml_only", alias = "m_l_only")]
    MLOnly,
    /// Hybrid strategy (the `Default`).
    Hybrid,
    /// Run both detectors (presumably combining their results; confirm).
    Both,
}
impl Default for HybridMode {
fn default() -> Self {
Self::Hybrid
}
}
/// Identifies a detection path (presumably the one that produced a given
/// result; confirm with the detector).
///
/// Each variant has an explicit serde name instead of `rename_all`, because
/// serde's snake_case rule would split the `ML` acronym into `m_l`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DetectionMethod {
    /// Serialized as `"heuristic"`.
    #[serde(rename = "heuristic")]
    Heuristic,
    /// Serialized as `"ml"`.
    #[serde(rename = "ml")]
    ML,
    /// Serialized as `"heuristic_short_circuit"`.
    #[serde(rename = "heuristic_short_circuit")]
    HeuristicShortCircuit,
    /// Serialized as `"ml_fallback_to_heuristic"`.
    #[serde(rename = "ml_fallback_to_heuristic")]
    MLFallbackToHeuristic,
    /// Serialized as `"hybrid_both"`.
    #[serde(rename = "hybrid_both")]
    HybridBoth,
}
/// Counters for inference activity. The rate helpers on this type
/// (`cache_hit_rate`, `heuristic_filter_rate`, `avg_inference_time_ms`,
/// `ml_error_rate`) are computed from them.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct InferenceMetrics {
    /// Total calls recorded. Denominator of `cache_hit_rate`,
    /// `heuristic_filter_rate` and `avg_inference_time_ms`.
    pub total_calls: u64,
    /// ML calls recorded. Denominator of `ml_error_rate`.
    pub ml_calls: u64,
    /// Heuristic calls recorded. Not read by the rate helpers.
    pub heuristic_calls: u64,
    /// Cache hits. Numerator of `cache_hit_rate`.
    pub cache_hits: u64,
    /// Heuristic short-circuits. Numerator of `heuristic_filter_rate`.
    pub heuristic_short_circuits: u64,
    /// Accumulated inference time in milliseconds (per the field name).
    /// `avg_inference_time_ms` divides it by `total_calls`, not `ml_calls`.
    pub total_inference_time_ms: u64,
    /// ML errors. Numerator of `ml_error_rate`.
    pub ml_errors: u64,
    /// Fallback count (presumably ML -> heuristic; confirm). Not read by the
    /// rate helpers.
    pub fallback_count: u64,
}
impl InferenceMetrics {
    /// Divides `numerator` by `denominator` in `f32`. Returns `0.0` when the
    /// denominator is zero, so empty metrics never yield NaN or infinity.
    fn safe_ratio(numerator: u64, denominator: u64) -> f32 {
        match denominator {
            0 => 0.0,
            d => numerator as f32 / d as f32,
        }
    }

    /// Fraction of all calls that were cache hits.
    pub fn cache_hit_rate(&self) -> f32 {
        Self::safe_ratio(self.cache_hits, self.total_calls)
    }

    /// Fraction of all calls that were heuristic short-circuits.
    pub fn heuristic_filter_rate(&self) -> f32 {
        Self::safe_ratio(self.heuristic_short_circuits, self.total_calls)
    }

    /// Mean inference time per call, averaged over `total_calls`.
    pub fn avg_inference_time_ms(&self) -> f32 {
        Self::safe_ratio(self.total_inference_time_ms, self.total_calls)
    }

    /// Fraction of ML calls that errored.
    pub fn ml_error_rate(&self) -> f32 {
        Self::safe_ratio(self.ml_errors, self.ml_calls)
    }
}
/// Serde adapter that stores a `Duration` as an integer number of seconds.
///
/// NOTE(review): serialization uses `Duration::as_secs`, which drops any
/// sub-second part. For example, 1.5 s round-trips as 1 s and 500 ms as 0 s.
mod duration_serde {
    use serde::{Deserialize, Deserializer, Serializer};
    use std::time::Duration;

    /// Writes `duration` as its whole-second count (`u64`).
    pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.serialize_u64(duration.as_secs())
    }

    /// Reads a `u64` second count and converts it back into a `Duration`.
    pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
    where
        D: Deserializer<'de>,
    {
        u64::deserialize(deserializer).map(Duration::from_secs)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `settings` has the expected size cap and TTL (in seconds).
    #[track_caller]
    fn assert_cache(settings: &CacheSettings, max_size: usize, ttl_secs: u64) {
        assert_eq!(settings.max_size, max_size);
        assert_eq!(settings.ttl, Duration::from_secs(ttl_secs));
    }

    #[test]
    fn test_ml_config_default() {
        let MLConfig {
            enabled,
            model_variant,
            threshold,
            fallback_to_heuristic,
            cache_enabled,
            ..
        } = MLConfig::default();
        assert!(!enabled);
        assert_eq!(model_variant, ModelVariant::FP16);
        assert_eq!(threshold, 0.5);
        assert!(fallback_to_heuristic);
        assert!(cache_enabled);
    }

    #[test]
    fn test_ml_config_production() {
        let cfg = MLConfig::production();
        assert!(cfg.enabled);
        assert_eq!(cfg.model_variant, ModelVariant::FP16);
        assert_eq!(cfg.threshold, 0.5);
        assert!(cfg.fallback_to_heuristic);
        assert!(cfg.cache_enabled);
    }

    #[test]
    fn test_ml_config_edge() {
        let cfg = MLConfig::edge();
        assert!(cfg.enabled);
        assert_eq!(cfg.model_variant, ModelVariant::INT8);
        assert_eq!(cfg.threshold, 0.6);
        assert_eq!(cfg.cache_config.max_size, 100);
    }

    #[test]
    fn test_ml_config_high_accuracy() {
        let cfg = MLConfig::high_accuracy();
        assert!(cfg.enabled);
        assert_eq!(cfg.model_variant, ModelVariant::FP32);
        assert_eq!(cfg.threshold, 0.3);
        assert!(!cfg.fallback_to_heuristic);
    }

    #[test]
    fn test_ml_config_disabled() {
        assert!(!MLConfig::disabled().enabled);
    }

    #[test]
    fn test_ml_config_validation() {
        let mut config = MLConfig::default();
        assert!(config.validate().is_ok());
        for bad in [1.5_f32, -0.1] {
            config.threshold = bad;
            assert!(config.validate().is_err(), "threshold {} should be rejected", bad);
        }
        for boundary in [0.0_f32, 1.0] {
            config.threshold = boundary;
            assert!(config.validate().is_ok(), "threshold {} should be accepted", boundary);
        }
    }

    #[test]
    fn test_cache_settings_default() {
        assert_cache(&CacheSettings::default(), 1000, 3600);
    }

    #[test]
    fn test_cache_settings_production() {
        assert_cache(&CacheSettings::production(), 1000, 3600);
    }

    #[test]
    fn test_cache_settings_edge() {
        assert_cache(&CacheSettings::edge(), 100, 600);
    }

    #[test]
    fn test_cache_settings_aggressive() {
        assert_cache(&CacheSettings::aggressive(), 10000, 7200);
    }

    #[test]
    fn test_cache_settings_minimal() {
        assert_cache(&CacheSettings::minimal(), 10, 60);
    }

    #[test]
    fn test_cache_settings_disabled() {
        assert_cache(&CacheSettings::disabled(), 0, 0);
    }

    #[test]
    fn test_hybrid_mode_default() {
        assert_eq!(HybridMode::default(), HybridMode::Hybrid);
    }

    #[test]
    fn test_inference_metrics_default() {
        let metrics = InferenceMetrics::default();
        assert_eq!(metrics.total_calls, 0);
        for rate in [
            metrics.cache_hit_rate(),
            metrics.heuristic_filter_rate(),
            metrics.avg_inference_time_ms(),
            metrics.ml_error_rate(),
        ] {
            assert_eq!(rate, 0.0);
        }
    }

    #[test]
    fn test_inference_metrics_calculations() {
        let metrics = InferenceMetrics {
            total_calls: 100,
            ml_calls: 40,
            heuristic_calls: 100,
            cache_hits: 30,
            heuristic_short_circuits: 60,
            total_inference_time_ms: 5000,
            ml_errors: 4,
            fallback_count: 4,
        };
        assert_eq!(metrics.cache_hit_rate(), 0.3);
        assert_eq!(metrics.heuristic_filter_rate(), 0.6);
        assert_eq!(metrics.avg_inference_time_ms(), 50.0);
        assert_eq!(metrics.ml_error_rate(), 0.1);
    }

    #[test]
    fn test_ml_config_serialization() {
        let original = MLConfig::production();
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: MLConfig = serde_json::from_str(&encoded).unwrap();
        assert_eq!(original.enabled, decoded.enabled);
        assert_eq!(original.threshold, decoded.threshold);
        assert_eq!(original.cache_config.max_size, decoded.cache_config.max_size);
    }

    #[test]
    fn test_detection_method_serialization() {
        let method = DetectionMethod::ML;
        let encoded = serde_json::to_string(&method).unwrap();
        assert_eq!(encoded, r#""ml""#);
        let decoded: DetectionMethod = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded, method);
    }

    #[test]
    fn test_hybrid_mode_serialization() {
        let mode = HybridMode::Hybrid;
        let encoded = serde_json::to_string(&mode).unwrap();
        assert_eq!(encoded, r#""hybrid""#);
        let decoded: HybridMode = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded, mode);
    }
}