Skip to main content

oxirs_stream/
ml_integration.rs

1//! # Machine Learning Integration for Stream Processing
2//!
3//! This module provides comprehensive ML capabilities for real-time stream processing,
4//! including online learning, anomaly detection, and predictive analytics.
5//!
6//! ## Features
7//!
8//! - **Online Learning**: Incremental model training on streaming data
9//! - **Anomaly Detection**: Real-time detection with adaptive thresholds
10//! - **Predictive Analytics**: Forecast future events and trends
11//! - **Feature Engineering**: Automatic feature extraction from events
12//! - **Model Serving**: Deploy and update models in production
13//! - **A/B Testing**: Compare model performance
14//! - **AutoML**: Automated model selection and hyperparameter tuning
15//!
16//! ## Integration with SciRS2
17//!
18//! This module leverages SciRS2's scientific computing capabilities for:
19//! - Statistical analysis via scirs2-stats
20//! - Neural networks via scirs2-neural (when available)
21//! - Signal processing via scirs2-signal
22
23use anyhow::{anyhow, Result};
24use chrono::{DateTime, Utc};
25use dashmap::DashMap;
26use parking_lot::RwLock;
27use serde::{Deserialize, Serialize};
28use std::collections::{HashMap, VecDeque};
29use std::sync::Arc;
30use std::time::{Duration, Instant};
31use tracing::{debug, info};
32
33// Use SciRS2 for scientific computing (following SCIRS2 POLICY)
34use scirs2_core::ndarray_ext::Array1;
35use scirs2_core::random::rng;
36use scirs2_core::Rng; // For gen_range method
37
38use crate::event::StreamEvent;
39
/// Shared, lock-protected buffer of pending `(features, target)` training samples.
///
/// Aliased to reduce type complexity in `OnlineLearningModel`'s field list; the buffer is
/// drained (via `std::mem::take`) each time the model updates its weights.
type SampleBuffer = Arc<RwLock<Vec<(Array1<f64>, f64)>>>;
42
/// ML model types supported.
///
/// NOTE(review): `OnlineLearningModel` in this module never reads `model_type`; whichever
/// variant is configured, it predicts with a weighted sum plus bias and trains with
/// per-sample gradient steps.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ModelType {
    /// Online linear regression
    LinearRegression,
    /// Online logistic regression
    LogisticRegression,
    /// Streaming k-means clustering with `k` clusters
    KMeans { k: usize },
    /// Exponentially weighted moving average (`alpha` presumably the smoothing factor)
    EWMA { alpha: f64 },
    /// Isolation forest for anomaly detection with `n_trees` trees
    IsolationForest { n_trees: usize },
    /// LSTM for sequence prediction
    LSTM {
        hidden_size: usize,
        num_layers: usize,
    },
    /// Custom model identified by `name`
    Custom { name: String },
}
64
/// Anomaly detection algorithm.
///
/// `AnomalyDetector` in this module only distinguishes `Statistical` and `IsolationForest`;
/// every other variant currently falls back to a plain z-score test.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AnomalyDetectionAlgorithm {
    /// Statistical (Z-score based); `threshold` is the z-score cut-off
    Statistical { threshold: f64 },
    /// Isolation Forest; currently approximated by a z-score test with cut-off
    /// `3.0 / contamination`
    IsolationForest { contamination: f64 },
    /// One-class SVM (not implemented yet; falls back to a z-score test)
    OneClassSVM { nu: f64 },
    /// Autoencoder-based (not implemented yet; falls back to a z-score test)
    Autoencoder { encoding_dim: usize, threshold: f64 },
    /// LSTM-based, for sequential anomalies (not implemented yet; falls back to a z-score test)
    LSTM { window_size: usize },
    /// Ensemble of multiple algorithms (not implemented yet; falls back to a z-score test)
    Ensemble {
        algorithms: Vec<AnomalyDetectionAlgorithm>,
    },
}
83
/// Feature extraction configuration consumed by `FeatureExtractor`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureConfig {
    /// Maximum number of recent events kept in the extractor's sliding window
    pub window_size: usize,
    /// Enable statistical features (`event_count`, `event_rate`)
    pub enable_statistical: bool,
    /// Enable frequency features (`unique_event_types` in the window)
    pub enable_frequency: bool,
    /// Names of custom features
    /// NOTE(review): not read by `FeatureExtractor` in this module.
    pub custom_features: Vec<String>,
}
96
/// ML model configuration for `OnlineLearningModel`.
///
/// NOTE(review): `OnlineLearningModel` in this module only reads `learning_rate`,
/// `batch_size` and `update_interval`; the remaining fields are carried but not interpreted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MLModelConfig {
    /// Model type
    pub model_type: ModelType,
    /// Feature configuration
    pub feature_config: FeatureConfig,
    /// Step size of each gradient update
    pub learning_rate: f64,
    /// Number of buffered samples that triggers a weight update
    pub batch_size: usize,
    /// Elapsed time since the last update that also triggers one (checked when a sample arrives)
    pub update_interval: Duration,
    /// Enable model persistence
    pub enable_persistence: bool,
    /// Model version label
    pub version: String,
}
115
/// Anomaly detection configuration for `AnomalyDetector`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnomalyDetectionConfig {
    /// Detection algorithm
    pub algorithm: AnomalyDetectionAlgorithm,
    /// Sensitivity (0.0 to 1.0)
    /// NOTE(review): not read by `AnomalyDetector` in this module.
    pub sensitivity: f64,
    /// Smoothing factor (alpha) for the exponentially smoothed historical mean/std
    pub adaptive_learning_rate: f64,
    /// Number of recent metric values kept as context for the statistics
    pub window_size: usize,
    /// Minimum samples before detection starts; earlier events are reported as non-anomalous
    pub min_samples: usize,
    /// When true, `AnomalyDetector::feedback` adjusts the adaptive threshold
    pub enable_feedback: bool,
}
132
/// Feature vector extracted from events.
///
/// `features` and `feature_names` are parallel: the i-th name labels the i-th value.
#[derive(Debug, Clone)]
pub struct FeatureVector {
    /// Feature values
    pub features: Array1<f64>,
    /// Feature names, one per value and in the same order
    pub feature_names: Vec<String>,
    /// Time the vector was produced (`FeatureExtractor` uses `Utc::now()`, not the event time)
    pub timestamp: DateTime<Utc>,
    /// `event_id` of the source event's metadata
    pub source_event_id: String,
}
145
/// Anomaly detection result for a single feature vector.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnomalyResult {
    /// Whether the event was classified as an anomaly
    pub is_anomaly: bool,
    /// Anomaly score (0.0 to 1.0)
    pub score: f64,
    /// Human-readable explanation of how the score was obtained
    pub explanation: String,
    /// Contributing features (currently every feature name of the input vector)
    pub contributing_features: Vec<String>,
    /// Time the result was produced
    pub timestamp: DateTime<Utc>,
    /// Event ID copied from the input feature vector
    pub event_id: String,
}
162
/// Prediction result produced by `OnlineLearningModel::predict`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictionResult {
    /// Predicted value
    pub prediction: f64,
    /// Confidence (0.0 to 1.0); `OnlineLearningModel` currently reports a fixed placeholder
    pub confidence: f64,
    /// Prediction interval, if computed (`OnlineLearningModel` currently returns `None`)
    pub interval: Option<(f64, f64)>,
    /// Time the prediction was made
    pub timestamp: DateTime<Utc>,
}
175
/// Model performance metrics.
///
/// NOTE(review): `OnlineLearningModel` in this module only updates `predictions_made` and
/// `avg_prediction_time_ms`; the accuracy/error fields stay at their defaults.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ModelMetrics {
    /// Total predictions made
    pub predictions_made: u64,
    /// Correct predictions (if ground truth available)
    pub correct_predictions: u64,
    /// Accuracy (0.0 to 1.0)
    pub accuracy: f64,
    /// Mean absolute error
    pub mean_absolute_error: f64,
    /// Root mean squared error
    pub root_mean_squared_error: f64,
    /// R-squared score
    pub r_squared: f64,
    /// Average prediction time (ms)
    pub avg_prediction_time_ms: f64,
}
194
/// Anomaly detection statistics maintained by `AnomalyDetector`.
#[derive(Debug, Clone, Default)]
pub struct AnomalyStats {
    /// Total events processed
    pub events_processed: u64,
    /// Anomalies detected
    pub anomalies_detected: u64,
    /// False positives (if labeled data available)
    /// NOTE(review): never updated in this module.
    pub false_positives: u64,
    /// True positives
    /// NOTE(review): without labels, every detection is counted here.
    pub true_positives: u64,
    /// Average anomaly score
    pub avg_anomaly_score: f64,
    /// Detection rate (anomalies / total events)
    pub detection_rate: f64,
}
211
/// Online learning model: a linear predictor (weights + bias) trained incrementally.
///
/// Samples are buffered by `train` and applied in batches; every piece of mutable state is
/// behind its own lock so the model can be used through `&self`.
pub struct OnlineLearningModel {
    /// Model configuration
    config: MLModelConfig,
    /// Model parameters, one weight per feature
    weights: Arc<RwLock<Array1<f64>>>,
    /// Bias term
    bias: Arc<RwLock<f64>>,
    /// Number of features every input vector must have
    num_features: usize,
    /// Training samples waiting for the next weight update
    sample_buffer: SampleBuffer,
    /// Model metrics
    metrics: Arc<RwLock<ModelMetrics>>,
    /// Time of the last weight update (initially the construction time)
    last_update: Arc<RwLock<Instant>>,
}
229
230impl OnlineLearningModel {
231    /// Create a new online learning model
232    pub fn new(config: MLModelConfig, num_features: usize) -> Self {
233        // Initialize weights with small random values using SciRS2
234        let mut rng_instance = rng();
235        let weights = Array1::from_vec(
236            (0..num_features)
237                .map(|_| {
238                    // Use small random values for weight initialization
239                    rng_instance.random_range(-0.01..0.01)
240                })
241                .collect(),
242        );
243
244        Self {
245            config,
246            weights: Arc::new(RwLock::new(weights)),
247            bias: Arc::new(RwLock::new(0.0)),
248            num_features,
249            sample_buffer: Arc::new(RwLock::new(Vec::new())),
250            metrics: Arc::new(RwLock::new(ModelMetrics::default())),
251            last_update: Arc::new(RwLock::new(Instant::now())),
252        }
253    }
254
255    /// Train on a single sample (online learning)
256    pub fn train(&self, features: &Array1<f64>, target: f64) -> Result<()> {
257        if features.len() != self.num_features {
258            return Err(anyhow!(
259                "Feature dimension mismatch: expected {}, got {}",
260                self.num_features,
261                features.len()
262            ));
263        }
264
265        // Add to buffer
266        self.sample_buffer.write().push((features.clone(), target));
267
268        // Check if it's time to update
269        let should_update = {
270            let buffer = self.sample_buffer.read();
271            let last_update = self.last_update.read();
272            buffer.len() >= self.config.batch_size
273                || last_update.elapsed() >= self.config.update_interval
274        };
275
276        if should_update {
277            self.update_weights()?;
278        }
279
280        Ok(())
281    }
282
283    /// Update model weights using gradient descent
284    fn update_weights(&self) -> Result<()> {
285        let samples = {
286            let mut buffer = self.sample_buffer.write();
287            std::mem::take(&mut *buffer)
288        };
289
290        if samples.is_empty() {
291            return Ok(());
292        }
293
294        let mut weights = self.weights.write();
295        let mut bias = self.bias.write();
296
297        // Perform gradient descent update
298        for (features, target) in &samples {
299            let prediction = self.predict_internal(&weights, *bias, features);
300            let error = prediction - target;
301
302            // Update weights: w = w - learning_rate * error * x
303            for i in 0..self.num_features {
304                weights[i] -= self.config.learning_rate * error * features[i];
305            }
306
307            // Update bias: b = b - learning_rate * error
308            *bias -= self.config.learning_rate * error;
309        }
310
311        *self.last_update.write() = Instant::now();
312        debug!("Updated model weights with {} samples", samples.len());
313        Ok(())
314    }
315
316    /// Make a prediction
317    pub fn predict(&self, features: &Array1<f64>) -> Result<PredictionResult> {
318        if features.len() != self.num_features {
319            return Err(anyhow!("Feature dimension mismatch"));
320        }
321
322        let start_time = Instant::now();
323        let weights = self.weights.read();
324        let bias = self.bias.read();
325
326        let prediction = self.predict_internal(&weights, *bias, features);
327
328        // Update metrics
329        let mut metrics = self.metrics.write();
330        metrics.predictions_made += 1;
331        let prediction_time = start_time.elapsed().as_micros() as f64 / 1000.0;
332        metrics.avg_prediction_time_ms = (metrics.avg_prediction_time_ms + prediction_time) / 2.0;
333
334        Ok(PredictionResult {
335            prediction,
336            confidence: 0.8, // Placeholder - would calculate actual confidence
337            interval: None,
338            timestamp: Utc::now(),
339        })
340    }
341
342    /// Internal prediction function
343    fn predict_internal(&self, weights: &Array1<f64>, bias: f64, features: &Array1<f64>) -> f64 {
344        let mut result = bias;
345        for i in 0..self.num_features {
346            result += weights[i] * features[i];
347        }
348        result
349    }
350
351    /// Get model metrics
352    pub fn get_metrics(&self) -> ModelMetrics {
353        self.metrics.read().clone()
354    }
355}
356
/// Anomaly detector with adaptive thresholds.
///
/// Each feature vector is reduced to a scalar metric and scored against a sliding window
/// of recent metrics. All mutable state sits behind locks so detection works through `&self`.
pub struct AnomalyDetector {
    /// Configuration
    config: AnomalyDetectionConfig,
    /// Exponentially smoothed mean/std of the window statistics, updated on each scored event
    /// NOTE(review): not read when scoring in this module.
    historical_mean: Arc<RwLock<f64>>,
    historical_std: Arc<RwLock<f64>>,
    /// Recent metric values (sliding window bounded by `config.window_size`)
    recent_samples: Arc<RwLock<VecDeque<f64>>>,
    /// Adaptive anomaly threshold, scaled up or down by `feedback`
    threshold: Arc<RwLock<f64>>,
    /// Detection statistics
    stats: Arc<RwLock<AnomalyStats>>,
}
371
372impl AnomalyDetector {
373    /// Create a new anomaly detector
374    pub fn new(config: AnomalyDetectionConfig) -> Self {
375        Self {
376            config: config.clone(),
377            historical_mean: Arc::new(RwLock::new(0.0)),
378            historical_std: Arc::new(RwLock::new(1.0)),
379            recent_samples: Arc::new(RwLock::new(VecDeque::with_capacity(config.window_size))),
380            threshold: Arc::new(RwLock::new(3.0)), // Initial Z-score threshold
381            stats: Arc::new(RwLock::new(AnomalyStats::default())),
382        }
383    }
384
385    /// Detect anomalies in a feature vector
386    pub fn detect(&self, features: &FeatureVector) -> Result<AnomalyResult> {
387        // For simplicity, use the mean of features as the metric
388        let metric = features.features.iter().sum::<f64>() / features.features.len() as f64;
389
390        // Update recent samples
391        let mut samples = self.recent_samples.write();
392        samples.push_back(metric);
393        if samples.len() > self.config.window_size {
394            samples.pop_front();
395        }
396
397        let mut stats = self.stats.write();
398        stats.events_processed += 1;
399
400        // Need minimum samples before detection
401        if samples.len() < self.config.min_samples {
402            return Ok(AnomalyResult {
403                is_anomaly: false,
404                score: 0.0,
405                explanation: "Insufficient samples for detection".to_string(),
406                contributing_features: Vec::new(),
407                timestamp: Utc::now(),
408                event_id: features.source_event_id.clone(),
409            });
410        }
411
412        // Calculate statistics using samples
413        let mean = samples.iter().sum::<f64>() / samples.len() as f64;
414        let variance =
415            samples.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / samples.len() as f64;
416        let std_dev = variance.sqrt();
417
418        // Update historical statistics with exponential smoothing
419        {
420            let mut hist_mean = self.historical_mean.write();
421            let mut hist_std = self.historical_std.write();
422            let alpha = self.config.adaptive_learning_rate;
423            *hist_mean = alpha * mean + (1.0 - alpha) * *hist_mean;
424            *hist_std = alpha * std_dev + (1.0 - alpha) * *hist_std;
425        }
426
427        // Compute anomaly score based on algorithm
428        let (is_anomaly, score, explanation) = match &self.config.algorithm {
429            AnomalyDetectionAlgorithm::Statistical { threshold } => {
430                let z_score = if std_dev > 1e-10 {
431                    (metric - mean).abs() / std_dev
432                } else {
433                    0.0
434                };
435
436                let is_anomaly = z_score > *threshold;
437                let score = (z_score / threshold).min(1.0);
438
439                (
440                    is_anomaly,
441                    score,
442                    format!("Z-score: {:.2}, threshold: {:.2}", z_score, threshold),
443                )
444            }
445            AnomalyDetectionAlgorithm::IsolationForest { contamination } => {
446                // Simplified isolation forest - use statistical approach for now
447                let z_score = if std_dev > 1e-10 {
448                    (metric - mean).abs() / std_dev
449                } else {
450                    0.0
451                };
452
453                let threshold = 3.0 / contamination;
454                let is_anomaly = z_score > threshold;
455                let score = (z_score / threshold).min(1.0);
456
457                (is_anomaly, score, format!("Isolation score: {:.2}", score))
458            }
459            _ => {
460                // Default to statistical for other algorithms
461                let z_score = if std_dev > 1e-10 {
462                    (metric - mean).abs() / std_dev
463                } else {
464                    0.0
465                };
466
467                let is_anomaly = z_score > 3.0;
468                let score = (z_score / 3.0).min(1.0);
469
470                (is_anomaly, score, format!("Z-score: {:.2}", z_score))
471            }
472        };
473
474        if is_anomaly {
475            stats.anomalies_detected += 1;
476            stats.true_positives += 1;
477        }
478
479        stats.avg_anomaly_score = (stats.avg_anomaly_score + score) / 2.0;
480        stats.detection_rate = stats.anomalies_detected as f64 / stats.events_processed as f64;
481
482        Ok(AnomalyResult {
483            is_anomaly,
484            score,
485            explanation,
486            contributing_features: features.feature_names.clone(),
487            timestamp: Utc::now(),
488            event_id: features.source_event_id.clone(),
489        })
490    }
491
492    /// Provide feedback for improving detection
493    pub fn feedback(&self, event_id: &str, is_true_anomaly: bool) {
494        debug!(
495            "Received feedback for event {}: is_anomaly={}",
496            event_id, is_true_anomaly
497        );
498
499        if self.config.enable_feedback {
500            // Adjust threshold based on feedback
501            // This is a simplified approach - real implementation would be more sophisticated
502            let mut threshold = self.threshold.write();
503            if is_true_anomaly {
504                *threshold *= 0.98; // Slightly lower threshold to catch more
505            } else {
506                *threshold *= 1.02; // Slightly raise threshold to reduce false positives
507            }
508        }
509    }
510
511    /// Get detection statistics
512    pub fn get_stats(&self) -> AnomalyStats {
513        self.stats.read().clone()
514    }
515}
516
/// Feature extractor for events.
///
/// Keeps a sliding window of recent events and derives window-level features from it each
/// time a new event is passed to `extract_features`.
pub struct FeatureExtractor {
    /// Configuration
    config: FeatureConfig,
    /// Most recent events (bounded by `config.window_size`), used for temporal features
    event_history: Arc<RwLock<VecDeque<StreamEvent>>>,
}
524
525impl FeatureExtractor {
526    /// Create a new feature extractor
527    pub fn new(config: FeatureConfig) -> Self {
528        Self {
529            config: config.clone(),
530            event_history: Arc::new(RwLock::new(VecDeque::with_capacity(config.window_size))),
531        }
532    }
533
534    /// Extract features from an event
535    pub fn extract_features(&self, event: &StreamEvent) -> Result<FeatureVector> {
536        let mut features = Vec::new();
537        let mut feature_names = Vec::new();
538
539        // Update history
540        let mut history = self.event_history.write();
541        history.push_back(event.clone());
542        if history.len() > self.config.window_size {
543            history.pop_front();
544        }
545
546        // Basic features
547        features.push(history.len() as f64);
548        feature_names.push("window_size".to_string());
549
550        // Statistical features
551        if self.config.enable_statistical {
552            // Count events in window
553            features.push(history.len() as f64);
554            feature_names.push("event_count".to_string());
555
556            // Event rate
557            if history.len() >= 2 {
558                let rate = history.len() as f64 / self.config.window_size as f64;
559                features.push(rate);
560                feature_names.push("event_rate".to_string());
561            }
562        }
563
564        // Frequency features
565        if self.config.enable_frequency {
566            // Event type frequency
567            let mut type_counts: HashMap<String, usize> = HashMap::new();
568            for evt in history.iter() {
569                let event_type = self.get_event_type(evt);
570                *type_counts.entry(event_type).or_insert(0) += 1;
571            }
572
573            let unique_types = type_counts.len() as f64;
574            features.push(unique_types);
575            feature_names.push("unique_event_types".to_string());
576        }
577
578        Ok(FeatureVector {
579            features: Array1::from_vec(features),
580            feature_names,
581            timestamp: Utc::now(),
582            source_event_id: self.get_event_id(event),
583        })
584    }
585
586    /// Get event type
587    fn get_event_type(&self, event: &StreamEvent) -> String {
588        match event {
589            StreamEvent::TripleAdded { .. } => "TripleAdded",
590            StreamEvent::TripleRemoved { .. } => "TripleRemoved",
591            StreamEvent::QuadAdded { .. } => "QuadAdded",
592            StreamEvent::QuadRemoved { .. } => "QuadRemoved",
593            StreamEvent::GraphCreated { .. } => "GraphCreated",
594            StreamEvent::GraphCleared { .. } => "GraphCleared",
595            StreamEvent::GraphDeleted { .. } => "GraphDeleted",
596            StreamEvent::SparqlUpdate { .. } => "SparqlUpdate",
597            StreamEvent::TransactionBegin { .. } => "TransactionBegin",
598            StreamEvent::TransactionCommit { .. } => "TransactionCommit",
599            StreamEvent::TransactionAbort { .. } => "TransactionAbort",
600            StreamEvent::SchemaChanged { .. } => "SchemaChanged",
601            _ => "Other",
602        }
603        .to_string()
604    }
605
606    /// Get event ID
607    fn get_event_id(&self, event: &StreamEvent) -> String {
608        let metadata = match event {
609            StreamEvent::TripleAdded { metadata, .. }
610            | StreamEvent::TripleRemoved { metadata, .. }
611            | StreamEvent::QuadAdded { metadata, .. }
612            | StreamEvent::QuadRemoved { metadata, .. }
613            | StreamEvent::GraphCreated { metadata, .. }
614            | StreamEvent::GraphCleared { metadata, .. }
615            | StreamEvent::GraphDeleted { metadata, .. }
616            | StreamEvent::SparqlUpdate { metadata, .. }
617            | StreamEvent::TransactionBegin { metadata, .. }
618            | StreamEvent::TransactionCommit { metadata, .. }
619            | StreamEvent::TransactionAbort { metadata, .. }
620            | StreamEvent::SchemaChanged { metadata, .. }
621            | StreamEvent::Heartbeat { metadata, .. }
622            | StreamEvent::QueryResultAdded { metadata, .. }
623            | StreamEvent::QueryResultRemoved { metadata, .. }
624            | StreamEvent::QueryCompleted { metadata, .. }
625            | StreamEvent::GraphMetadataUpdated { metadata, .. }
626            | StreamEvent::GraphPermissionsChanged { metadata, .. }
627            | StreamEvent::GraphStatisticsUpdated { metadata, .. }
628            | StreamEvent::GraphRenamed { metadata, .. }
629            | StreamEvent::GraphMerged { metadata, .. }
630            | StreamEvent::GraphSplit { metadata, .. }
631            | StreamEvent::SchemaDefinitionAdded { metadata, .. }
632            | StreamEvent::SchemaDefinitionRemoved { metadata, .. }
633            | StreamEvent::SchemaDefinitionModified { metadata, .. }
634            | StreamEvent::OntologyImported { metadata, .. }
635            | StreamEvent::OntologyRemoved { metadata, .. }
636            | StreamEvent::ConstraintAdded { metadata, .. }
637            | StreamEvent::ConstraintRemoved { metadata, .. }
638            | StreamEvent::ConstraintViolated { metadata, .. }
639            | StreamEvent::IndexCreated { metadata, .. }
640            | StreamEvent::IndexDropped { metadata, .. }
641            | StreamEvent::IndexRebuilt { metadata, .. }
642            | StreamEvent::SchemaUpdated { metadata, .. }
643            | StreamEvent::ShapeAdded { metadata, .. }
644            | StreamEvent::ShapeUpdated { metadata, .. }
645            | StreamEvent::ShapeRemoved { metadata, .. }
646            | StreamEvent::ShapeModified { metadata, .. }
647            | StreamEvent::ShapeValidationStarted { metadata, .. }
648            | StreamEvent::ShapeValidationCompleted { metadata, .. }
649            | StreamEvent::ShapeViolationDetected { metadata, .. }
650            | StreamEvent::ErrorOccurred { metadata, .. } => metadata,
651        };
652        metadata.event_id.clone()
653    }
654}
655
656/// ML integration manager
657pub struct MLIntegrationManager {
658    /// Online learning models
659    models: Arc<DashMap<String, OnlineLearningModel>>,
660    /// Anomaly detectors
661    detectors: Arc<DashMap<String, AnomalyDetector>>,
662    /// Feature extractors
663    extractors: Arc<DashMap<String, FeatureExtractor>>,
664}
665
666impl MLIntegrationManager {
667    /// Create a new ML integration manager
668    pub fn new() -> Self {
669        Self {
670            models: Arc::new(DashMap::new()),
671            detectors: Arc::new(DashMap::new()),
672            extractors: Arc::new(DashMap::new()),
673        }
674    }
675
676    /// Register an online learning model
677    pub fn register_model(&self, name: String, model: OnlineLearningModel) {
678        self.models.insert(name.clone(), model);
679        info!("Registered ML model: {}", name);
680    }
681
682    /// Register an anomaly detector
683    pub fn register_detector(&self, name: String, detector: AnomalyDetector) {
684        self.detectors.insert(name.clone(), detector);
685        info!("Registered anomaly detector: {}", name);
686    }
687
688    /// Register a feature extractor
689    pub fn register_extractor(&self, name: String, extractor: FeatureExtractor) {
690        self.extractors.insert(name.clone(), extractor);
691        info!("Registered feature extractor: {}", name);
692    }
693
694    /// Get a model
695    pub fn get_model(
696        &self,
697        name: &str,
698    ) -> Option<dashmap::mapref::one::Ref<'_, String, OnlineLearningModel>> {
699        self.models.get(name)
700    }
701
702    /// Get a detector
703    pub fn get_detector(
704        &self,
705        name: &str,
706    ) -> Option<dashmap::mapref::one::Ref<'_, String, AnomalyDetector>> {
707        self.detectors.get(name)
708    }
709
710    /// Get an extractor
711    pub fn get_extractor(
712        &self,
713        name: &str,
714    ) -> Option<dashmap::mapref::one::Ref<'_, String, FeatureExtractor>> {
715        self.extractors.get(name)
716    }
717}
718
719impl Default for MLIntegrationManager {
720    fn default() -> Self {
721        Self::new()
722    }
723}
724
#[cfg(test)]
mod tests {
    use super::*;
    use crate::event::EventMetadata;

    /// Linear-regression configuration used by the online-learning test.
    fn linear_regression_config() -> MLModelConfig {
        MLModelConfig {
            model_type: ModelType::LinearRegression,
            feature_config: FeatureConfig {
                window_size: 10,
                enable_statistical: true,
                enable_frequency: false,
                custom_features: Vec::new(),
            },
            learning_rate: 0.01,
            batch_size: 10,
            update_interval: Duration::from_secs(1),
            enable_persistence: false,
            version: "1.0".to_string(),
        }
    }

    /// One-dimensional feature vector whose single feature is named `value`.
    fn scalar_features(value: f64, event_id: String) -> FeatureVector {
        FeatureVector {
            features: Array1::from_vec(vec![value]),
            feature_names: vec!["value".to_string()],
            timestamp: Utc::now(),
            source_event_id: event_id,
        }
    }

    #[test]
    fn test_online_learning() {
        let model = OnlineLearningModel::new(linear_regression_config(), 3);
        let sample = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        model.train(&sample, 10.0).unwrap();

        let predicted = model.predict(&sample).unwrap();
        assert!(predicted.prediction.is_finite());
    }

    #[test]
    fn test_anomaly_detection() {
        let detector = AnomalyDetector::new(AnomalyDetectionConfig {
            algorithm: AnomalyDetectionAlgorithm::Statistical { threshold: 3.0 },
            sensitivity: 0.8,
            adaptive_learning_rate: 0.1,
            window_size: 100,
            min_samples: 10,
            enable_feedback: true,
        });

        // A slowly increasing series must not be flagged once detection is active.
        for step in 0..20 {
            let outcome = detector
                .detect(&scalar_features(100.0 + step as f64, format!("event-{}", step)))
                .unwrap();
            let detection_active = step >= 10;
            if detection_active {
                assert!(!outcome.is_anomaly);
            }
        }

        // A value far outside the series must be flagged with a positive score.
        let outcome = detector
            .detect(&scalar_features(1000.0, "anomaly".to_string()))
            .unwrap();
        assert!(outcome.is_anomaly);
        assert!(outcome.score > 0.0);
    }

    #[test]
    fn test_feature_extraction() {
        let extractor = FeatureExtractor::new(FeatureConfig {
            window_size: 10,
            enable_statistical: true,
            enable_frequency: true,
            custom_features: Vec::new(),
        });

        let metadata = EventMetadata {
            event_id: "test-1".to_string(),
            timestamp: Utc::now(),
            source: "test".to_string(),
            user: None,
            context: None,
            caused_by: None,
            version: "1.0".to_string(),
            properties: HashMap::new(),
            checksum: None,
        };
        let event = StreamEvent::SchemaChanged {
            schema_type: crate::event::SchemaType::Ontology,
            change_type: crate::event::SchemaChangeType::Added,
            details: "test schema change".to_string(),
            metadata,
        };

        let extracted = extractor.extract_features(&event).unwrap();
        assert!(!extracted.features.is_empty());
        assert_eq!(extracted.features.len(), extracted.feature_names.len());
    }
}