// oxirs_stream/ml_integration.rs

//! # Machine Learning Integration for Stream Processing
//!
//! This module provides comprehensive ML capabilities for real-time stream processing,
//! including online learning, anomaly detection, and predictive analytics.
//!
//! ## Features
//!
//! - **Online Learning**: Incremental model training on streaming data
//! - **Anomaly Detection**: Real-time detection with adaptive thresholds
//! - **Predictive Analytics**: Forecast future events and trends
//! - **Feature Engineering**: Automatic feature extraction from events
//! - **Model Serving**: Deploy and update models in production
//! - **A/B Testing**: Compare model performance
//! - **AutoML**: Automated model selection and hyperparameter tuning
//!
//! ## Integration with SciRS2
//!
//! This module leverages SciRS2's scientific computing capabilities for:
//! - Statistical analysis via scirs2-stats
//! - Neural networks via scirs2-neural (when available)
//! - Signal processing via scirs2-signal

use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use std::time::{Duration, Instant};

use anyhow::{anyhow, Result};
use chrono::{DateTime, Utc};
use dashmap::DashMap;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use tracing::{debug, info};

// Use SciRS2 for scientific computing (following SCIRS2 POLICY)
use scirs2_core::ndarray_ext::Array1;
use scirs2_core::random::{rng, RngExt};

use crate::event::StreamEvent;

39/// Type alias for training sample buffer to reduce type complexity
40type SampleBuffer = Arc<RwLock<Vec<(Array1<f64>, f64)>>>;
41
42/// ML model types supported
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub enum ModelType {
45    /// Online linear regression
46    LinearRegression,
47    /// Online logistic regression
48    LogisticRegression,
49    /// Streaming k-means clustering
50    KMeans { k: usize },
51    /// Exponentially weighted moving average
52    EWMA { alpha: f64 },
53    /// Isolation forest for anomaly detection
54    IsolationForest { n_trees: usize },
55    /// LSTM for sequence prediction
56    LSTM {
57        hidden_size: usize,
58        num_layers: usize,
59    },
60    /// Custom model
61    Custom { name: String },
62}
63
64/// Anomaly detection algorithm
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub enum AnomalyDetectionAlgorithm {
67    /// Statistical (Z-score based)
68    Statistical { threshold: f64 },
69    /// Isolation Forest
70    IsolationForest { contamination: f64 },
71    /// One-class SVM
72    OneClassSVM { nu: f64 },
73    /// Autoencoder-based
74    Autoencoder { encoding_dim: usize, threshold: f64 },
75    /// LSTM-based (for sequential anomalies)
76    LSTM { window_size: usize },
77    /// Ensemble of multiple algorithms
78    Ensemble {
79        algorithms: Vec<AnomalyDetectionAlgorithm>,
80    },
81}
82
83/// Feature extraction configuration
84#[derive(Debug, Clone, Serialize, Deserialize)]
85pub struct FeatureConfig {
86    /// Window size for temporal features
87    pub window_size: usize,
88    /// Enable statistical features
89    pub enable_statistical: bool,
90    /// Enable frequency features
91    pub enable_frequency: bool,
92    /// Enable custom features
93    pub custom_features: Vec<String>,
94}
95
96/// ML model configuration
97#[derive(Debug, Clone, Serialize, Deserialize)]
98pub struct MLModelConfig {
99    /// Model type
100    pub model_type: ModelType,
101    /// Feature configuration
102    pub feature_config: FeatureConfig,
103    /// Learning rate
104    pub learning_rate: f64,
105    /// Batch size for mini-batch learning
106    pub batch_size: usize,
107    /// Model update interval
108    pub update_interval: Duration,
109    /// Enable model persistence
110    pub enable_persistence: bool,
111    /// Model version
112    pub version: String,
113}
114
115/// Anomaly detection configuration
116#[derive(Debug, Clone, Serialize, Deserialize)]
117pub struct AnomalyDetectionConfig {
118    /// Detection algorithm
119    pub algorithm: AnomalyDetectionAlgorithm,
120    /// Sensitivity (0.0 to 1.0)
121    pub sensitivity: f64,
122    /// Adaptive threshold learning rate
123    pub adaptive_learning_rate: f64,
124    /// Window size for context
125    pub window_size: usize,
126    /// Minimum samples before detection starts
127    pub min_samples: usize,
128    /// Enable feedback loop for improvement
129    pub enable_feedback: bool,
130}
131
132/// Feature vector extracted from events
133#[derive(Debug, Clone)]
134pub struct FeatureVector {
135    /// Feature values
136    pub features: Array1<f64>,
137    /// Feature names
138    pub feature_names: Vec<String>,
139    /// Timestamp
140    pub timestamp: DateTime<Utc>,
141    /// Source event ID
142    pub source_event_id: String,
143}
144
145/// Anomaly detection result
146#[derive(Debug, Clone, Serialize, Deserialize)]
147pub struct AnomalyResult {
148    /// Is anomaly
149    pub is_anomaly: bool,
150    /// Anomaly score (0.0 to 1.0)
151    pub score: f64,
152    /// Explanation
153    pub explanation: String,
154    /// Contributing features
155    pub contributing_features: Vec<String>,
156    /// Timestamp
157    pub timestamp: DateTime<Utc>,
158    /// Event ID
159    pub event_id: String,
160}
161
162/// Prediction result
163#[derive(Debug, Clone, Serialize, Deserialize)]
164pub struct PredictionResult {
165    /// Predicted value
166    pub prediction: f64,
167    /// Confidence (0.0 to 1.0)
168    pub confidence: f64,
169    /// Prediction interval
170    pub interval: Option<(f64, f64)>,
171    /// Timestamp
172    pub timestamp: DateTime<Utc>,
173}
174
175/// Model performance metrics
176#[derive(Debug, Clone, Default, Serialize, Deserialize)]
177pub struct ModelMetrics {
178    /// Total predictions made
179    pub predictions_made: u64,
180    /// Correct predictions (if ground truth available)
181    pub correct_predictions: u64,
182    /// Accuracy (0.0 to 1.0)
183    pub accuracy: f64,
184    /// Mean absolute error
185    pub mean_absolute_error: f64,
186    /// Root mean squared error
187    pub root_mean_squared_error: f64,
188    /// R-squared score
189    pub r_squared: f64,
190    /// Average prediction time (ms)
191    pub avg_prediction_time_ms: f64,
192}
193
194/// Anomaly detection statistics
195#[derive(Debug, Clone, Default)]
196pub struct AnomalyStats {
197    /// Total events processed
198    pub events_processed: u64,
199    /// Anomalies detected
200    pub anomalies_detected: u64,
201    /// False positives (if labeled data available)
202    pub false_positives: u64,
203    /// True positives
204    pub true_positives: u64,
205    /// Average anomaly score
206    pub avg_anomaly_score: f64,
207    /// Detection rate (anomalies / total events)
208    pub detection_rate: f64,
209}
210
211/// Online learning model
212pub struct OnlineLearningModel {
213    /// Model configuration
214    config: MLModelConfig,
215    /// Model parameters
216    weights: Arc<RwLock<Array1<f64>>>,
217    /// Bias term
218    bias: Arc<RwLock<f64>>,
219    /// Number of features
220    num_features: usize,
221    /// Training samples buffer
222    sample_buffer: SampleBuffer,
223    /// Model metrics
224    metrics: Arc<RwLock<ModelMetrics>>,
225    /// Last update time
226    last_update: Arc<RwLock<Instant>>,
227}
228
229impl OnlineLearningModel {
230    /// Create a new online learning model
231    pub fn new(config: MLModelConfig, num_features: usize) -> Self {
232        // Initialize weights with small random values using SciRS2
233        let mut rng_instance = rng();
234        let weights = Array1::from_vec(
235            (0..num_features)
236                .map(|_| {
237                    // Use small random values for weight initialization
238                    rng_instance.random_range(-0.01..0.01)
239                })
240                .collect(),
241        );
242
243        Self {
244            config,
245            weights: Arc::new(RwLock::new(weights)),
246            bias: Arc::new(RwLock::new(0.0)),
247            num_features,
248            sample_buffer: Arc::new(RwLock::new(Vec::new())),
249            metrics: Arc::new(RwLock::new(ModelMetrics::default())),
250            last_update: Arc::new(RwLock::new(Instant::now())),
251        }
252    }
253
254    /// Train on a single sample (online learning)
255    pub fn train(&self, features: &Array1<f64>, target: f64) -> Result<()> {
256        if features.len() != self.num_features {
257            return Err(anyhow!(
258                "Feature dimension mismatch: expected {}, got {}",
259                self.num_features,
260                features.len()
261            ));
262        }
263
264        // Add to buffer
265        self.sample_buffer.write().push((features.clone(), target));
266
267        // Check if it's time to update
268        let should_update = {
269            let buffer = self.sample_buffer.read();
270            let last_update = self.last_update.read();
271            buffer.len() >= self.config.batch_size
272                || last_update.elapsed() >= self.config.update_interval
273        };
274
275        if should_update {
276            self.update_weights()?;
277        }
278
279        Ok(())
280    }
281
282    /// Update model weights using gradient descent
283    fn update_weights(&self) -> Result<()> {
284        let samples = {
285            let mut buffer = self.sample_buffer.write();
286            std::mem::take(&mut *buffer)
287        };
288
289        if samples.is_empty() {
290            return Ok(());
291        }
292
293        let mut weights = self.weights.write();
294        let mut bias = self.bias.write();
295
296        // Perform gradient descent update
297        for (features, target) in &samples {
298            let prediction = self.predict_internal(&weights, *bias, features);
299            let error = prediction - target;
300
301            // Update weights: w = w - learning_rate * error * x
302            for i in 0..self.num_features {
303                weights[i] -= self.config.learning_rate * error * features[i];
304            }
305
306            // Update bias: b = b - learning_rate * error
307            *bias -= self.config.learning_rate * error;
308        }
309
310        *self.last_update.write() = Instant::now();
311        debug!("Updated model weights with {} samples", samples.len());
312        Ok(())
313    }
314
315    /// Make a prediction
316    pub fn predict(&self, features: &Array1<f64>) -> Result<PredictionResult> {
317        if features.len() != self.num_features {
318            return Err(anyhow!("Feature dimension mismatch"));
319        }
320
321        let start_time = Instant::now();
322        let weights = self.weights.read();
323        let bias = self.bias.read();
324
325        let prediction = self.predict_internal(&weights, *bias, features);
326
327        // Update metrics
328        let mut metrics = self.metrics.write();
329        metrics.predictions_made += 1;
330        let prediction_time = start_time.elapsed().as_micros() as f64 / 1000.0;
331        metrics.avg_prediction_time_ms = (metrics.avg_prediction_time_ms + prediction_time) / 2.0;
332
333        Ok(PredictionResult {
334            prediction,
335            confidence: 0.8, // Placeholder - would calculate actual confidence
336            interval: None,
337            timestamp: Utc::now(),
338        })
339    }
340
341    /// Internal prediction function
342    fn predict_internal(&self, weights: &Array1<f64>, bias: f64, features: &Array1<f64>) -> f64 {
343        let mut result = bias;
344        for i in 0..self.num_features {
345            result += weights[i] * features[i];
346        }
347        result
348    }
349
350    /// Get model metrics
351    pub fn get_metrics(&self) -> ModelMetrics {
352        self.metrics.read().clone()
353    }
354}
355
356/// Anomaly detector with adaptive thresholds
357pub struct AnomalyDetector {
358    /// Configuration
359    config: AnomalyDetectionConfig,
360    /// Historical statistics (using SciRS2)
361    historical_mean: Arc<RwLock<f64>>,
362    historical_std: Arc<RwLock<f64>>,
363    /// Recent samples for statistics
364    recent_samples: Arc<RwLock<VecDeque<f64>>>,
365    /// Anomaly threshold
366    threshold: Arc<RwLock<f64>>,
367    /// Detection statistics
368    stats: Arc<RwLock<AnomalyStats>>,
369}
370
371impl AnomalyDetector {
372    /// Create a new anomaly detector
373    pub fn new(config: AnomalyDetectionConfig) -> Self {
374        Self {
375            config: config.clone(),
376            historical_mean: Arc::new(RwLock::new(0.0)),
377            historical_std: Arc::new(RwLock::new(1.0)),
378            recent_samples: Arc::new(RwLock::new(VecDeque::with_capacity(config.window_size))),
379            threshold: Arc::new(RwLock::new(3.0)), // Initial Z-score threshold
380            stats: Arc::new(RwLock::new(AnomalyStats::default())),
381        }
382    }
383
384    /// Detect anomalies in a feature vector
385    pub fn detect(&self, features: &FeatureVector) -> Result<AnomalyResult> {
386        // For simplicity, use the mean of features as the metric
387        let metric = features.features.iter().sum::<f64>() / features.features.len() as f64;
388
389        // Update recent samples
390        let mut samples = self.recent_samples.write();
391        samples.push_back(metric);
392        if samples.len() > self.config.window_size {
393            samples.pop_front();
394        }
395
396        let mut stats = self.stats.write();
397        stats.events_processed += 1;
398
399        // Need minimum samples before detection
400        if samples.len() < self.config.min_samples {
401            return Ok(AnomalyResult {
402                is_anomaly: false,
403                score: 0.0,
404                explanation: "Insufficient samples for detection".to_string(),
405                contributing_features: Vec::new(),
406                timestamp: Utc::now(),
407                event_id: features.source_event_id.clone(),
408            });
409        }
410
411        // Calculate statistics using samples
412        let mean = samples.iter().sum::<f64>() / samples.len() as f64;
413        let variance =
414            samples.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / samples.len() as f64;
415        let std_dev = variance.sqrt();
416
417        // Update historical statistics with exponential smoothing
418        {
419            let mut hist_mean = self.historical_mean.write();
420            let mut hist_std = self.historical_std.write();
421            let alpha = self.config.adaptive_learning_rate;
422            *hist_mean = alpha * mean + (1.0 - alpha) * *hist_mean;
423            *hist_std = alpha * std_dev + (1.0 - alpha) * *hist_std;
424        }
425
426        // Compute anomaly score based on algorithm
427        let (is_anomaly, score, explanation) = match &self.config.algorithm {
428            AnomalyDetectionAlgorithm::Statistical { threshold } => {
429                let z_score = if std_dev > 1e-10 {
430                    (metric - mean).abs() / std_dev
431                } else {
432                    0.0
433                };
434
435                let is_anomaly = z_score > *threshold;
436                let score = (z_score / threshold).min(1.0);
437
438                (
439                    is_anomaly,
440                    score,
441                    format!("Z-score: {:.2}, threshold: {:.2}", z_score, threshold),
442                )
443            }
444            AnomalyDetectionAlgorithm::IsolationForest { contamination } => {
445                // Simplified isolation forest - use statistical approach for now
446                let z_score = if std_dev > 1e-10 {
447                    (metric - mean).abs() / std_dev
448                } else {
449                    0.0
450                };
451
452                let threshold = 3.0 / contamination;
453                let is_anomaly = z_score > threshold;
454                let score = (z_score / threshold).min(1.0);
455
456                (is_anomaly, score, format!("Isolation score: {:.2}", score))
457            }
458            _ => {
459                // Default to statistical for other algorithms
460                let z_score = if std_dev > 1e-10 {
461                    (metric - mean).abs() / std_dev
462                } else {
463                    0.0
464                };
465
466                let is_anomaly = z_score > 3.0;
467                let score = (z_score / 3.0).min(1.0);
468
469                (is_anomaly, score, format!("Z-score: {:.2}", z_score))
470            }
471        };
472
473        if is_anomaly {
474            stats.anomalies_detected += 1;
475            stats.true_positives += 1;
476        }
477
478        stats.avg_anomaly_score = (stats.avg_anomaly_score + score) / 2.0;
479        stats.detection_rate = stats.anomalies_detected as f64 / stats.events_processed as f64;
480
481        Ok(AnomalyResult {
482            is_anomaly,
483            score,
484            explanation,
485            contributing_features: features.feature_names.clone(),
486            timestamp: Utc::now(),
487            event_id: features.source_event_id.clone(),
488        })
489    }
490
491    /// Provide feedback for improving detection
492    pub fn feedback(&self, event_id: &str, is_true_anomaly: bool) {
493        debug!(
494            "Received feedback for event {}: is_anomaly={}",
495            event_id, is_true_anomaly
496        );
497
498        if self.config.enable_feedback {
499            // Adjust threshold based on feedback
500            // This is a simplified approach - real implementation would be more sophisticated
501            let mut threshold = self.threshold.write();
502            if is_true_anomaly {
503                *threshold *= 0.98; // Slightly lower threshold to catch more
504            } else {
505                *threshold *= 1.02; // Slightly raise threshold to reduce false positives
506            }
507        }
508    }
509
510    /// Get detection statistics
511    pub fn get_stats(&self) -> AnomalyStats {
512        self.stats.read().clone()
513    }
514}
515
516/// Feature extractor for events
517pub struct FeatureExtractor {
518    /// Configuration
519    config: FeatureConfig,
520    /// Event history for temporal features
521    event_history: Arc<RwLock<VecDeque<StreamEvent>>>,
522}
523
524impl FeatureExtractor {
525    /// Create a new feature extractor
526    pub fn new(config: FeatureConfig) -> Self {
527        Self {
528            config: config.clone(),
529            event_history: Arc::new(RwLock::new(VecDeque::with_capacity(config.window_size))),
530        }
531    }
532
533    /// Extract features from an event
534    pub fn extract_features(&self, event: &StreamEvent) -> Result<FeatureVector> {
535        let mut features = Vec::new();
536        let mut feature_names = Vec::new();
537
538        // Update history
539        let mut history = self.event_history.write();
540        history.push_back(event.clone());
541        if history.len() > self.config.window_size {
542            history.pop_front();
543        }
544
545        // Basic features
546        features.push(history.len() as f64);
547        feature_names.push("window_size".to_string());
548
549        // Statistical features
550        if self.config.enable_statistical {
551            // Count events in window
552            features.push(history.len() as f64);
553            feature_names.push("event_count".to_string());
554
555            // Event rate
556            if history.len() >= 2 {
557                let rate = history.len() as f64 / self.config.window_size as f64;
558                features.push(rate);
559                feature_names.push("event_rate".to_string());
560            }
561        }
562
563        // Frequency features
564        if self.config.enable_frequency {
565            // Event type frequency
566            let mut type_counts: HashMap<String, usize> = HashMap::new();
567            for evt in history.iter() {
568                let event_type = self.get_event_type(evt);
569                *type_counts.entry(event_type).or_insert(0) += 1;
570            }
571
572            let unique_types = type_counts.len() as f64;
573            features.push(unique_types);
574            feature_names.push("unique_event_types".to_string());
575        }
576
577        Ok(FeatureVector {
578            features: Array1::from_vec(features),
579            feature_names,
580            timestamp: Utc::now(),
581            source_event_id: self.get_event_id(event),
582        })
583    }
584
585    /// Get event type
586    fn get_event_type(&self, event: &StreamEvent) -> String {
587        match event {
588            StreamEvent::TripleAdded { .. } => "TripleAdded",
589            StreamEvent::TripleRemoved { .. } => "TripleRemoved",
590            StreamEvent::QuadAdded { .. } => "QuadAdded",
591            StreamEvent::QuadRemoved { .. } => "QuadRemoved",
592            StreamEvent::GraphCreated { .. } => "GraphCreated",
593            StreamEvent::GraphCleared { .. } => "GraphCleared",
594            StreamEvent::GraphDeleted { .. } => "GraphDeleted",
595            StreamEvent::SparqlUpdate { .. } => "SparqlUpdate",
596            StreamEvent::TransactionBegin { .. } => "TransactionBegin",
597            StreamEvent::TransactionCommit { .. } => "TransactionCommit",
598            StreamEvent::TransactionAbort { .. } => "TransactionAbort",
599            StreamEvent::SchemaChanged { .. } => "SchemaChanged",
600            _ => "Other",
601        }
602        .to_string()
603    }
604
605    /// Get event ID
606    fn get_event_id(&self, event: &StreamEvent) -> String {
607        let metadata = match event {
608            StreamEvent::TripleAdded { metadata, .. }
609            | StreamEvent::TripleRemoved { metadata, .. }
610            | StreamEvent::QuadAdded { metadata, .. }
611            | StreamEvent::QuadRemoved { metadata, .. }
612            | StreamEvent::GraphCreated { metadata, .. }
613            | StreamEvent::GraphCleared { metadata, .. }
614            | StreamEvent::GraphDeleted { metadata, .. }
615            | StreamEvent::SparqlUpdate { metadata, .. }
616            | StreamEvent::TransactionBegin { metadata, .. }
617            | StreamEvent::TransactionCommit { metadata, .. }
618            | StreamEvent::TransactionAbort { metadata, .. }
619            | StreamEvent::SchemaChanged { metadata, .. }
620            | StreamEvent::Heartbeat { metadata, .. }
621            | StreamEvent::QueryResultAdded { metadata, .. }
622            | StreamEvent::QueryResultRemoved { metadata, .. }
623            | StreamEvent::QueryCompleted { metadata, .. }
624            | StreamEvent::GraphMetadataUpdated { metadata, .. }
625            | StreamEvent::GraphPermissionsChanged { metadata, .. }
626            | StreamEvent::GraphStatisticsUpdated { metadata, .. }
627            | StreamEvent::GraphRenamed { metadata, .. }
628            | StreamEvent::GraphMerged { metadata, .. }
629            | StreamEvent::GraphSplit { metadata, .. }
630            | StreamEvent::SchemaDefinitionAdded { metadata, .. }
631            | StreamEvent::SchemaDefinitionRemoved { metadata, .. }
632            | StreamEvent::SchemaDefinitionModified { metadata, .. }
633            | StreamEvent::OntologyImported { metadata, .. }
634            | StreamEvent::OntologyRemoved { metadata, .. }
635            | StreamEvent::ConstraintAdded { metadata, .. }
636            | StreamEvent::ConstraintRemoved { metadata, .. }
637            | StreamEvent::ConstraintViolated { metadata, .. }
638            | StreamEvent::IndexCreated { metadata, .. }
639            | StreamEvent::IndexDropped { metadata, .. }
640            | StreamEvent::IndexRebuilt { metadata, .. }
641            | StreamEvent::SchemaUpdated { metadata, .. }
642            | StreamEvent::ShapeAdded { metadata, .. }
643            | StreamEvent::ShapeUpdated { metadata, .. }
644            | StreamEvent::ShapeRemoved { metadata, .. }
645            | StreamEvent::ShapeModified { metadata, .. }
646            | StreamEvent::ShapeValidationStarted { metadata, .. }
647            | StreamEvent::ShapeValidationCompleted { metadata, .. }
648            | StreamEvent::ShapeViolationDetected { metadata, .. }
649            | StreamEvent::ErrorOccurred { metadata, .. } => metadata,
650        };
651        metadata.event_id.clone()
652    }
653}
654
655/// ML integration manager
656pub struct MLIntegrationManager {
657    /// Online learning models
658    models: Arc<DashMap<String, OnlineLearningModel>>,
659    /// Anomaly detectors
660    detectors: Arc<DashMap<String, AnomalyDetector>>,
661    /// Feature extractors
662    extractors: Arc<DashMap<String, FeatureExtractor>>,
663}
664
665impl MLIntegrationManager {
666    /// Create a new ML integration manager
667    pub fn new() -> Self {
668        Self {
669            models: Arc::new(DashMap::new()),
670            detectors: Arc::new(DashMap::new()),
671            extractors: Arc::new(DashMap::new()),
672        }
673    }
674
675    /// Register an online learning model
676    pub fn register_model(&self, name: String, model: OnlineLearningModel) {
677        self.models.insert(name.clone(), model);
678        info!("Registered ML model: {}", name);
679    }
680
681    /// Register an anomaly detector
682    pub fn register_detector(&self, name: String, detector: AnomalyDetector) {
683        self.detectors.insert(name.clone(), detector);
684        info!("Registered anomaly detector: {}", name);
685    }
686
687    /// Register a feature extractor
688    pub fn register_extractor(&self, name: String, extractor: FeatureExtractor) {
689        self.extractors.insert(name.clone(), extractor);
690        info!("Registered feature extractor: {}", name);
691    }
692
693    /// Get a model
694    pub fn get_model(
695        &self,
696        name: &str,
697    ) -> Option<dashmap::mapref::one::Ref<'_, String, OnlineLearningModel>> {
698        self.models.get(name)
699    }
700
701    /// Get a detector
702    pub fn get_detector(
703        &self,
704        name: &str,
705    ) -> Option<dashmap::mapref::one::Ref<'_, String, AnomalyDetector>> {
706        self.detectors.get(name)
707    }
708
709    /// Get an extractor
710    pub fn get_extractor(
711        &self,
712        name: &str,
713    ) -> Option<dashmap::mapref::one::Ref<'_, String, FeatureExtractor>> {
714        self.extractors.get(name)
715    }
716}
717
718impl Default for MLIntegrationManager {
719    fn default() -> Self {
720        Self::new()
721    }
722}
723
#[cfg(test)]
mod tests {
    use super::*;
    use crate::event::EventMetadata;

    /// Training then predicting should produce a finite value.
    #[test]
    fn test_online_learning() {
        let cfg = MLModelConfig {
            model_type: ModelType::LinearRegression,
            feature_config: FeatureConfig {
                window_size: 10,
                enable_statistical: true,
                enable_frequency: false,
                custom_features: Vec::new(),
            },
            learning_rate: 0.01,
            batch_size: 10,
            update_interval: Duration::from_secs(1),
            enable_persistence: false,
            version: "1.0".to_string(),
        };

        let model = OnlineLearningModel::new(cfg, 3);

        // One training sample, then a prediction on the same input.
        let sample = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        model.train(&sample, 10.0).unwrap();

        let result = model.predict(&sample).unwrap();
        assert!(result.prediction.is_finite());
    }

    /// Steady values should not be flagged; a large spike should be.
    #[test]
    fn test_anomaly_detection() {
        let cfg = AnomalyDetectionConfig {
            algorithm: AnomalyDetectionAlgorithm::Statistical { threshold: 3.0 },
            sensitivity: 0.8,
            adaptive_learning_rate: 0.1,
            window_size: 100,
            min_samples: 10,
            enable_feedback: true,
        };

        let detector = AnomalyDetector::new(cfg);

        // Feed a gently increasing series of "normal" metrics.
        for i in 0..20 {
            let normal = FeatureVector {
                features: Array1::from_vec(vec![100.0 + i as f64]),
                feature_names: vec!["value".to_string()],
                timestamp: Utc::now(),
                source_event_id: format!("event-{}", i),
            };

            let result = detector.detect(&normal).unwrap();
            // Once min_samples is reached, none of these should be flagged.
            if i >= 10 {
                assert!(!result.is_anomaly);
            }
        }

        // A value an order of magnitude larger should trip the detector.
        let spike = FeatureVector {
            features: Array1::from_vec(vec![1000.0]),
            feature_names: vec!["value".to_string()],
            timestamp: Utc::now(),
            source_event_id: "anomaly".to_string(),
        };

        let result = detector.detect(&spike).unwrap();
        assert!(result.is_anomaly);
        assert!(result.score > 0.0);
    }

    /// Extracted features and their names must stay aligned.
    #[test]
    fn test_feature_extraction() {
        let cfg = FeatureConfig {
            window_size: 10,
            enable_statistical: true,
            enable_frequency: true,
            custom_features: Vec::new(),
        };

        let extractor = FeatureExtractor::new(cfg);

        let event = StreamEvent::SchemaChanged {
            schema_type: crate::event::SchemaType::Ontology,
            change_type: crate::event::SchemaChangeType::Added,
            details: "test schema change".to_string(),
            metadata: EventMetadata {
                event_id: "test-1".to_string(),
                timestamp: Utc::now(),
                source: "test".to_string(),
                user: None,
                context: None,
                caused_by: None,
                version: "1.0".to_string(),
                properties: HashMap::new(),
                checksum: None,
            },
        };

        let features = extractor.extract_features(&event).unwrap();
        assert!(!features.features.is_empty());
        assert_eq!(features.features.len(), features.feature_names.len());
    }
}