// oxigdal_ml/monitoring.rs
1//! Model monitoring and performance tracking
2//!
3//! This module provides comprehensive monitoring capabilities for ML models
4//! in production, including performance metrics, drift detection, and alerting.
5
6use std::collections::VecDeque;
7use std::time::{Duration, Instant};
8use tracing::{debug, info, warn};
9
/// Model performance metrics
///
/// Aggregate latency/throughput/error statistics over the monitor's
/// current sample window. `Default` yields the all-zero metrics an empty
/// window produces.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct PerformanceMetrics {
    /// Total inference count (samples currently in the window)
    pub total_inferences: u64,
    /// Average latency in milliseconds
    pub avg_latency_ms: f32,
    /// P50 latency in milliseconds
    pub p50_latency_ms: f32,
    /// P95 latency in milliseconds
    pub p95_latency_ms: f32,
    /// P99 latency in milliseconds
    pub p99_latency_ms: f32,
    /// Throughput (inferences per second)
    pub throughput: f32,
    /// Error rate (0.0 to 1.0)
    pub error_rate: f32,
}
28
/// Model drift metrics
///
/// Scores describing how far the model's inputs/outputs have shifted
/// from their reference distributions. `Default` is "no drift observed".
#[derive(Debug, Clone, Default, PartialEq)]
pub struct DriftMetrics {
    /// Input distribution drift score
    pub input_drift: f32,
    /// Output distribution drift score
    pub output_drift: f32,
    /// Concept drift detected
    pub concept_drift: bool,
    /// Data quality score (0.0 to 1.0)
    pub data_quality: f32,
}
41
/// Alert severity level
///
/// Variants are declared in ascending order so the derived `Ord` gives
/// `Info < Warning < Critical`, letting callers compare, sort, or take a
/// max over severities. `Hash` allows use as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum AlertSeverity {
    /// Informational
    Info,
    /// Warning - needs attention
    Warning,
    /// Critical - immediate action required
    Critical,
}
52
/// Model alert
///
/// A single alert raised when a monitored metric crosses its configured
/// threshold.
#[derive(Debug, Clone)]
pub struct ModelAlert {
    /// Alert timestamp. Note this is a monotonic `Instant` (useful for
    /// ordering and elapsed-time math), not a wall-clock time for display.
    pub timestamp: Instant,
    /// Alert severity
    pub severity: AlertSeverity,
    /// Human-readable alert message
    pub message: String,
    /// Metric name that triggered the alert
    /// (e.g. "latency_ms", "error_rate", "input_drift")
    pub metric: String,
    /// Threshold value that was crossed
    pub threshold: f32,
    /// Actual observed value
    pub actual: f32,
}
69
/// Monitoring configuration
#[derive(Debug, Clone)]
pub struct MonitoringConfig {
    /// Enable performance monitoring (latency and error tracking)
    pub enable_performance: bool,
    /// Enable drift detection
    pub enable_drift: bool,
    /// Enable alerting
    pub enable_alerting: bool,
    /// Latency alert threshold (ms); latencies above 2x this value
    /// escalate from a warning to a critical alert
    pub latency_threshold_ms: f32,
    /// Error rate alert threshold (fraction, 0.0 to 1.0)
    pub error_rate_threshold: f32,
    /// Drift alert threshold, applied to both input and output drift scores
    pub drift_threshold: f32,
    /// Metric retention period
    // NOTE(review): trimming in ModelMonitor is sample-count based (last
    // 10k samples); this duration is not consulted there — confirm whether
    // time-based retention is intended.
    pub retention_period: Duration,
}
88
89impl Default for MonitoringConfig {
90    fn default() -> Self {
91        Self {
92            enable_performance: true,
93            enable_drift: true,
94            enable_alerting: true,
95            latency_threshold_ms: 1000.0,
96            error_rate_threshold: 0.05,
97            drift_threshold: 0.2,
98            retention_period: Duration::from_secs(3600), // 1 hour
99        }
100    }
101}
102
/// Model monitor
///
/// Collects per-inference latency and error samples in sliding windows,
/// derives aggregate performance metrics from them, and raises alerts when
/// configured thresholds are crossed.
pub struct ModelMonitor {
    // Thresholds and feature switches for the monitor.
    config: MonitoringConfig,
    // Recent latency samples in ms, oldest first; capped to a fixed count.
    latencies: VecDeque<f32>,
    // Recent per-inference error flags; an independent window with the same cap.
    errors: VecDeque<bool>,
    // Every alert raised since creation or the last clear/reset.
    alerts: Vec<ModelAlert>,
    // Creation/reset time; denominator for throughput calculations.
    start_time: Instant,
}
111
112impl ModelMonitor {
113    /// Creates a new model monitor
114    #[must_use]
115    pub fn new(config: MonitoringConfig) -> Self {
116        info!("Initializing model monitor");
117        Self {
118            config,
119            latencies: VecDeque::new(),
120            errors: VecDeque::new(),
121            alerts: Vec::new(),
122            start_time: Instant::now(),
123        }
124    }
125
126    /// Records an inference latency
127    pub fn record_latency(&mut self, latency_ms: f32) {
128        if !self.config.enable_performance {
129            return;
130        }
131
132        self.latencies.push_back(latency_ms);
133        self.trim_old_metrics();
134
135        // Check for latency alerts
136        if self.config.enable_alerting && latency_ms > self.config.latency_threshold_ms {
137            self.add_alert(ModelAlert {
138                timestamp: Instant::now(),
139                severity: if latency_ms > self.config.latency_threshold_ms * 2.0 {
140                    AlertSeverity::Critical
141                } else {
142                    AlertSeverity::Warning
143                },
144                message: format!("High latency detected: {:.1}ms", latency_ms),
145                metric: "latency_ms".to_string(),
146                threshold: self.config.latency_threshold_ms,
147                actual: latency_ms,
148            });
149        }
150    }
151
152    /// Records an inference error
153    pub fn record_error(&mut self, is_error: bool) {
154        if !self.config.enable_performance {
155            return;
156        }
157
158        self.errors.push_back(is_error);
159        self.trim_old_metrics();
160
161        // Check for error rate alerts
162        if self.config.enable_alerting && is_error {
163            let error_rate = self.calculate_error_rate();
164            if error_rate > self.config.error_rate_threshold {
165                self.add_alert(ModelAlert {
166                    timestamp: Instant::now(),
167                    severity: AlertSeverity::Critical,
168                    message: format!("High error rate: {:.1}%", error_rate * 100.0),
169                    metric: "error_rate".to_string(),
170                    threshold: self.config.error_rate_threshold,
171                    actual: error_rate,
172                });
173            }
174        }
175    }
176
177    /// Records drift metrics
178    pub fn record_drift(&mut self, metrics: DriftMetrics) {
179        if !self.config.enable_drift {
180            return;
181        }
182
183        debug!(
184            "Drift metrics: input={:.3}, output={:.3}, concept={}",
185            metrics.input_drift, metrics.output_drift, metrics.concept_drift
186        );
187
188        // Check for drift alerts
189        if self.config.enable_alerting {
190            if metrics.input_drift > self.config.drift_threshold {
191                self.add_alert(ModelAlert {
192                    timestamp: Instant::now(),
193                    severity: AlertSeverity::Warning,
194                    message: "Input distribution drift detected".to_string(),
195                    metric: "input_drift".to_string(),
196                    threshold: self.config.drift_threshold,
197                    actual: metrics.input_drift,
198                });
199            }
200
201            if metrics.output_drift > self.config.drift_threshold {
202                self.add_alert(ModelAlert {
203                    timestamp: Instant::now(),
204                    severity: AlertSeverity::Warning,
205                    message: "Output distribution drift detected".to_string(),
206                    metric: "output_drift".to_string(),
207                    threshold: self.config.drift_threshold,
208                    actual: metrics.output_drift,
209                });
210            }
211
212            if metrics.concept_drift {
213                self.add_alert(ModelAlert {
214                    timestamp: Instant::now(),
215                    severity: AlertSeverity::Critical,
216                    message: "Concept drift detected - model retraining recommended".to_string(),
217                    metric: "concept_drift".to_string(),
218                    threshold: 0.0,
219                    actual: 1.0,
220                });
221            }
222        }
223    }
224
225    /// Calculates performance metrics
226    #[must_use]
227    pub fn performance_metrics(&self) -> PerformanceMetrics {
228        let total_inferences = self.latencies.len() as u64;
229        let avg_latency = self.calculate_average_latency();
230        let percentiles = self.calculate_latency_percentiles();
231        let throughput = self.calculate_throughput();
232        let error_rate = self.calculate_error_rate();
233
234        PerformanceMetrics {
235            total_inferences,
236            avg_latency_ms: avg_latency,
237            p50_latency_ms: percentiles.0,
238            p95_latency_ms: percentiles.1,
239            p99_latency_ms: percentiles.2,
240            throughput,
241            error_rate,
242        }
243    }
244
245    /// Returns all alerts
246    #[must_use]
247    pub fn alerts(&self) -> &[ModelAlert] {
248        &self.alerts
249    }
250
251    /// Returns alerts by severity
252    #[must_use]
253    pub fn alerts_by_severity(&self, severity: AlertSeverity) -> Vec<&ModelAlert> {
254        self.alerts
255            .iter()
256            .filter(|a| a.severity == severity)
257            .collect()
258    }
259
260    /// Clears all alerts
261    pub fn clear_alerts(&mut self) {
262        info!("Clearing {} alerts", self.alerts.len());
263        self.alerts.clear();
264    }
265
266    /// Resets all metrics
267    pub fn reset(&mut self) {
268        info!("Resetting monitor metrics");
269        self.latencies.clear();
270        self.errors.clear();
271        self.alerts.clear();
272        self.start_time = Instant::now();
273    }
274
275    // Private helper methods
276
277    fn trim_old_metrics(&mut self) {
278        let max_samples = 10000; // Keep last 10k samples
279        while self.latencies.len() > max_samples {
280            self.latencies.pop_front();
281        }
282        while self.errors.len() > max_samples {
283            self.errors.pop_front();
284        }
285    }
286
287    fn add_alert(&mut self, alert: ModelAlert) {
288        match alert.severity {
289            AlertSeverity::Info => debug!("Alert: {}", alert.message),
290            AlertSeverity::Warning => warn!("Alert: {}", alert.message),
291            AlertSeverity::Critical => {
292                warn!("CRITICAL Alert: {}", alert.message);
293            }
294        }
295        self.alerts.push(alert);
296    }
297
298    fn calculate_average_latency(&self) -> f32 {
299        if self.latencies.is_empty() {
300            return 0.0;
301        }
302        let sum: f32 = self.latencies.iter().sum();
303        sum / self.latencies.len() as f32
304    }
305
306    fn calculate_latency_percentiles(&self) -> (f32, f32, f32) {
307        if self.latencies.is_empty() {
308            return (0.0, 0.0, 0.0);
309        }
310
311        let mut sorted: Vec<_> = self.latencies.iter().copied().collect();
312        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
313
314        let p50 = sorted[sorted.len() * 50 / 100];
315        let p95 = sorted[sorted.len() * 95 / 100];
316        let p99 = sorted[sorted.len() * 99 / 100];
317
318        (p50, p95, p99)
319    }
320
321    fn calculate_throughput(&self) -> f32 {
322        let elapsed = self.start_time.elapsed().as_secs_f32();
323        if elapsed > 0.0 {
324            self.latencies.len() as f32 / elapsed
325        } else {
326            0.0
327        }
328    }
329
330    fn calculate_error_rate(&self) -> f32 {
331        if self.errors.is_empty() {
332            return 0.0;
333        }
334        let error_count = self.errors.iter().filter(|&&e| e).count();
335        error_count as f32 / self.errors.len() as f32
336    }
337}
338
/// Calculates input drift between two discrete distributions using the
/// KL divergence D(P‖Q) = Σ p·ln(p/q).
///
/// Returns `1.0` (treated as maximum drift) when the distributions have
/// different lengths. Bins where either probability is zero are skipped;
/// because skipping terms can make the partial sum negative — while a true
/// KL divergence is never negative — the result is clamped to be >= 0.
#[must_use]
pub fn calculate_input_drift(reference_distribution: &[f32], current_distribution: &[f32]) -> f32 {
    if reference_distribution.len() != current_distribution.len() {
        return 1.0; // Incomparable histograms: report maximum drift.
    }

    let mut divergence = 0.0_f32;
    for (&p, &q) in reference_distribution.iter().zip(current_distribution) {
        // Skip zero-probability bins: ln() is undefined there.
        if p > 0.0 && q > 0.0 {
            divergence += p * (p / q).ln();
        }
    }

    // KL divergence is mathematically non-negative; clamp away artifacts
    // introduced by the skipped bins.
    divergence.max(0.0)
}
358
/// Calculates output drift as the relative shift between the mean of the
/// reference predictions and the mean of the current predictions.
///
/// Returns `1.0` (maximum drift) when the slices differ in length and
/// `0.0` when both are empty (avoiding a 0/0 = NaN result).
#[must_use]
pub fn calculate_output_drift(reference_predictions: &[f32], current_predictions: &[f32]) -> f32 {
    if reference_predictions.len() != current_predictions.len() {
        return 1.0;
    }
    if reference_predictions.is_empty() {
        return 0.0; // No data to compare.
    }

    let ref_mean = reference_predictions.iter().sum::<f32>() / reference_predictions.len() as f32;
    let cur_mean = current_predictions.iter().sum::<f32>() / current_predictions.len() as f32;

    // Normalize by |ref_mean| so negative-valued outputs are handled; the
    // 1e-6 floor guards against division by a (near-)zero mean.
    (ref_mean - cur_mean).abs() / ref_mean.abs().max(1e-6)
}
371
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_monitoring_config_default() {
        // Every monitoring subsystem is on by default.
        let cfg = MonitoringConfig::default();
        assert!(cfg.enable_performance);
        assert!(cfg.enable_drift);
        assert!(cfg.enable_alerting);
    }

    #[test]
    fn test_model_monitor_latency() {
        let mut monitor = ModelMonitor::new(MonitoringConfig::default());

        for latency in [100.0, 150.0, 120.0] {
            monitor.record_latency(latency);
        }

        let metrics = monitor.performance_metrics();
        assert_eq!(metrics.total_inferences, 3);
        // Mean of the three samples is ~123.33 ms.
        assert!((metrics.avg_latency_ms - 123.33).abs() < 1.0);
    }

    #[test]
    fn test_model_monitor_errors() {
        let mut monitor = ModelMonitor::new(MonitoringConfig::default());

        for &failed in &[false, false, true, false] {
            monitor.record_error(failed);
        }

        // One failure out of four recorded inferences.
        let metrics = monitor.performance_metrics();
        assert!((metrics.error_rate - 0.25).abs() < 0.01);
    }

    #[test]
    fn test_alert_filtering() {
        let mut monitor = ModelMonitor::new(MonitoringConfig {
            latency_threshold_ms: 100.0,
            ..Default::default()
        });

        monitor.record_latency(150.0); // above threshold -> warning
        monitor.record_latency(250.0); // above 2x threshold -> critical

        assert_eq!(monitor.alerts_by_severity(AlertSeverity::Warning).len(), 1);
        assert_eq!(monitor.alerts_by_severity(AlertSeverity::Critical).len(), 1);
    }

    #[test]
    fn test_input_drift_calculation() {
        let drift = calculate_input_drift(&[0.25, 0.25, 0.25, 0.25], &[0.3, 0.2, 0.3, 0.2]);
        // A small but nonzero divergence.
        assert!(drift > 0.0);
        assert!(drift < 1.0);
    }

    #[test]
    fn test_output_drift_calculation() {
        // Current predictions sit ~20% below the reference window.
        let drift = calculate_output_drift(&[0.8, 0.7, 0.9, 0.75], &[0.6, 0.5, 0.7, 0.55]);
        assert!(drift > 0.15);
        assert!(drift < 0.30);
    }
}