1use anyhow::{anyhow, Result};
24use chrono::{DateTime, Utc};
25use dashmap::DashMap;
26use parking_lot::RwLock;
27use serde::{Deserialize, Serialize};
28use std::collections::{HashMap, VecDeque};
29use std::sync::Arc;
30use std::time::{Duration, Instant};
31use tracing::{debug, info};
32
33use scirs2_core::ndarray_ext::Array1;
35use scirs2_core::random::rng;
36use scirs2_core::Rng; use crate::event::StreamEvent;
39
/// Shared, lock-protected buffer of `(feature vector, target)` training samples.
type SampleBuffer = Arc<RwLock<Vec<(Array1<f64>, f64)>>>;
42
/// Model families selectable via [`MLModelConfig`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ModelType {
    /// Linear regression model.
    LinearRegression,
    /// Logistic regression classifier.
    LogisticRegression,
    /// K-means clustering with `k` clusters.
    KMeans { k: usize },
    /// Exponentially weighted moving average with smoothing factor `alpha`.
    EWMA { alpha: f64 },
    /// Isolation forest built from `n_trees` trees.
    IsolationForest { n_trees: usize },
    /// LSTM network with the given hidden size and layer count.
    LSTM {
        hidden_size: usize,
        num_layers: usize,
    },
    /// User-defined model identified by `name`.
    Custom { name: String },
}
64
/// Detection strategies selectable via [`AnomalyDetectionConfig`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AnomalyDetectionAlgorithm {
    /// Z-score test against `threshold` standard deviations.
    Statistical { threshold: f64 },
    /// Isolation-forest style detection with an expected `contamination` rate.
    IsolationForest { contamination: f64 },
    /// One-class SVM with margin parameter `nu`.
    OneClassSVM { nu: f64 },
    /// Autoencoder reconstruction-error detection.
    Autoencoder { encoding_dim: usize, threshold: f64 },
    /// LSTM sequence-based detection over a sliding window.
    LSTM { window_size: usize },
    /// Combination of several algorithms.
    Ensemble {
        algorithms: Vec<AnomalyDetectionAlgorithm>,
    },
}
83
/// Controls which features a `FeatureExtractor` produces.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureConfig {
    /// Number of recent events kept in the extraction window.
    pub window_size: usize,
    /// Emit statistical features (event count / rate).
    pub enable_statistical: bool,
    /// Emit frequency features (per-event-type counts).
    pub enable_frequency: bool,
    /// Names of additional user-defined features.
    /// NOTE(review): not consumed by `FeatureExtractor` in this module — confirm intended use.
    pub custom_features: Vec<String>,
}
96
/// Configuration for an [`OnlineLearningModel`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MLModelConfig {
    /// Which model family to use.
    pub model_type: ModelType,
    /// Feature extraction settings paired with this model.
    pub feature_config: FeatureConfig,
    /// Gradient step size applied per training sample.
    pub learning_rate: f64,
    /// Number of buffered samples that triggers a weight update.
    pub batch_size: usize,
    /// Maximum time allowed between weight updates.
    pub update_interval: Duration,
    /// Whether model state should be persisted.
    /// NOTE(review): persistence is not implemented in this module — confirm.
    pub enable_persistence: bool,
    /// Model version string.
    pub version: String,
}
115
/// Configuration for an [`AnomalyDetector`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnomalyDetectionConfig {
    /// Detection algorithm and its parameters.
    pub algorithm: AnomalyDetectionAlgorithm,
    /// Detector sensitivity.
    /// NOTE(review): not read by `AnomalyDetector` in this module — confirm.
    pub sensitivity: f64,
    /// Smoothing factor for the adaptive historical baseline.
    pub adaptive_learning_rate: f64,
    /// Size of the rolling sample window.
    pub window_size: usize,
    /// Minimum samples required before detection is attempted.
    pub min_samples: usize,
    /// Whether operator feedback should adapt the detection threshold.
    pub enable_feedback: bool,
}
132
/// A feature vector extracted from a single stream event.
#[derive(Debug, Clone)]
pub struct FeatureVector {
    /// Numeric feature values.
    pub features: Array1<f64>,
    /// Human-readable name of each feature, parallel to `features`.
    pub feature_names: Vec<String>,
    /// When the vector was extracted.
    pub timestamp: DateTime<Utc>,
    /// Identifier of the event the features were derived from.
    pub source_event_id: String,
}
145
/// Outcome of scoring a single event for anomalies.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnomalyResult {
    /// True when the event exceeded the detection threshold.
    pub is_anomaly: bool,
    /// Anomaly score, capped at 1.0 by the detector.
    pub score: f64,
    /// Human-readable explanation of the decision.
    pub explanation: String,
    /// Names of features that contributed to the score.
    pub contributing_features: Vec<String>,
    /// When the detection ran.
    pub timestamp: DateTime<Utc>,
    /// Identifier of the scored event.
    pub event_id: String,
}
162
/// Result of a single model prediction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictionResult {
    /// Predicted value.
    pub prediction: f64,
    /// Confidence in the prediction.
    /// NOTE(review): currently a fixed placeholder (0.8) in `OnlineLearningModel::predict` — confirm.
    pub confidence: f64,
    /// Optional (lower, upper) prediction interval.
    pub interval: Option<(f64, f64)>,
    /// When the prediction was made.
    pub timestamp: DateTime<Utc>,
}
175
/// Accumulated quality and performance metrics for a model.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ModelMetrics {
    /// Total predictions served.
    pub predictions_made: u64,
    /// Predictions later confirmed correct.
    /// NOTE(review): never updated in this module — confirm upstream use.
    pub correct_predictions: u64,
    /// Classification accuracy.
    pub accuracy: f64,
    /// Mean absolute error of predictions.
    pub mean_absolute_error: f64,
    /// Root mean squared error of predictions.
    pub root_mean_squared_error: f64,
    /// Coefficient of determination (R²).
    pub r_squared: f64,
    /// Average prediction latency in milliseconds.
    pub avg_prediction_time_ms: f64,
}
194
/// Running statistics for an [`AnomalyDetector`].
#[derive(Debug, Clone, Default)]
pub struct AnomalyStats {
    /// Total events scored (including those skipped for lack of samples).
    pub events_processed: u64,
    /// Events flagged as anomalous.
    pub anomalies_detected: u64,
    /// Flagged events later reported as normal.
    pub false_positives: u64,
    /// Flagged events later confirmed anomalous.
    pub true_positives: u64,
    /// Smoothed average anomaly score.
    pub avg_anomaly_score: f64,
    /// `anomalies_detected / events_processed`.
    pub detection_rate: f64,
}
211
/// Linear model trained incrementally from a stream of labeled samples.
pub struct OnlineLearningModel {
    /// Learning rate, batch size, update interval, etc.
    config: MLModelConfig,
    /// Current weight vector, one entry per feature.
    weights: Arc<RwLock<Array1<f64>>>,
    /// Current bias (intercept) term.
    bias: Arc<RwLock<f64>>,
    /// Expected feature dimensionality.
    num_features: usize,
    /// Buffered `(features, target)` samples awaiting a batch update.
    sample_buffer: SampleBuffer,
    /// Accumulated prediction metrics.
    metrics: Arc<RwLock<ModelMetrics>>,
    /// Time of the most recent weight update.
    last_update: Arc<RwLock<Instant>>,
}
229
230impl OnlineLearningModel {
231 pub fn new(config: MLModelConfig, num_features: usize) -> Self {
233 let mut rng_instance = rng();
235 let weights = Array1::from_vec(
236 (0..num_features)
237 .map(|_| {
238 rng_instance.random_range(-0.01..0.01)
240 })
241 .collect(),
242 );
243
244 Self {
245 config,
246 weights: Arc::new(RwLock::new(weights)),
247 bias: Arc::new(RwLock::new(0.0)),
248 num_features,
249 sample_buffer: Arc::new(RwLock::new(Vec::new())),
250 metrics: Arc::new(RwLock::new(ModelMetrics::default())),
251 last_update: Arc::new(RwLock::new(Instant::now())),
252 }
253 }
254
255 pub fn train(&self, features: &Array1<f64>, target: f64) -> Result<()> {
257 if features.len() != self.num_features {
258 return Err(anyhow!(
259 "Feature dimension mismatch: expected {}, got {}",
260 self.num_features,
261 features.len()
262 ));
263 }
264
265 self.sample_buffer.write().push((features.clone(), target));
267
268 let should_update = {
270 let buffer = self.sample_buffer.read();
271 let last_update = self.last_update.read();
272 buffer.len() >= self.config.batch_size
273 || last_update.elapsed() >= self.config.update_interval
274 };
275
276 if should_update {
277 self.update_weights()?;
278 }
279
280 Ok(())
281 }
282
283 fn update_weights(&self) -> Result<()> {
285 let samples = {
286 let mut buffer = self.sample_buffer.write();
287 std::mem::take(&mut *buffer)
288 };
289
290 if samples.is_empty() {
291 return Ok(());
292 }
293
294 let mut weights = self.weights.write();
295 let mut bias = self.bias.write();
296
297 for (features, target) in &samples {
299 let prediction = self.predict_internal(&weights, *bias, features);
300 let error = prediction - target;
301
302 for i in 0..self.num_features {
304 weights[i] -= self.config.learning_rate * error * features[i];
305 }
306
307 *bias -= self.config.learning_rate * error;
309 }
310
311 *self.last_update.write() = Instant::now();
312 debug!("Updated model weights with {} samples", samples.len());
313 Ok(())
314 }
315
316 pub fn predict(&self, features: &Array1<f64>) -> Result<PredictionResult> {
318 if features.len() != self.num_features {
319 return Err(anyhow!("Feature dimension mismatch"));
320 }
321
322 let start_time = Instant::now();
323 let weights = self.weights.read();
324 let bias = self.bias.read();
325
326 let prediction = self.predict_internal(&weights, *bias, features);
327
328 let mut metrics = self.metrics.write();
330 metrics.predictions_made += 1;
331 let prediction_time = start_time.elapsed().as_micros() as f64 / 1000.0;
332 metrics.avg_prediction_time_ms = (metrics.avg_prediction_time_ms + prediction_time) / 2.0;
333
334 Ok(PredictionResult {
335 prediction,
336 confidence: 0.8, interval: None,
338 timestamp: Utc::now(),
339 })
340 }
341
342 fn predict_internal(&self, weights: &Array1<f64>, bias: f64, features: &Array1<f64>) -> f64 {
344 let mut result = bias;
345 for i in 0..self.num_features {
346 result += weights[i] * features[i];
347 }
348 result
349 }
350
351 pub fn get_metrics(&self) -> ModelMetrics {
353 self.metrics.read().clone()
354 }
355}
356
/// Streaming anomaly detector over a rolling window of scalar metrics.
pub struct AnomalyDetector {
    /// Detection algorithm and thresholds.
    config: AnomalyDetectionConfig,
    /// Exponentially smoothed historical mean of the metric.
    historical_mean: Arc<RwLock<f64>>,
    /// Exponentially smoothed historical standard deviation.
    historical_std: Arc<RwLock<f64>>,
    /// Most recent metric samples, bounded by `config.window_size`.
    recent_samples: Arc<RwLock<VecDeque<f64>>>,
    /// Adaptive threshold tuned by feedback.
    /// NOTE(review): adjusted by `feedback` but not consulted by `detect` — confirm.
    threshold: Arc<RwLock<f64>>,
    /// Running detection statistics.
    stats: Arc<RwLock<AnomalyStats>>,
}
371
372impl AnomalyDetector {
373 pub fn new(config: AnomalyDetectionConfig) -> Self {
375 Self {
376 config: config.clone(),
377 historical_mean: Arc::new(RwLock::new(0.0)),
378 historical_std: Arc::new(RwLock::new(1.0)),
379 recent_samples: Arc::new(RwLock::new(VecDeque::with_capacity(config.window_size))),
380 threshold: Arc::new(RwLock::new(3.0)), stats: Arc::new(RwLock::new(AnomalyStats::default())),
382 }
383 }
384
385 pub fn detect(&self, features: &FeatureVector) -> Result<AnomalyResult> {
387 let metric = features.features.iter().sum::<f64>() / features.features.len() as f64;
389
390 let mut samples = self.recent_samples.write();
392 samples.push_back(metric);
393 if samples.len() > self.config.window_size {
394 samples.pop_front();
395 }
396
397 let mut stats = self.stats.write();
398 stats.events_processed += 1;
399
400 if samples.len() < self.config.min_samples {
402 return Ok(AnomalyResult {
403 is_anomaly: false,
404 score: 0.0,
405 explanation: "Insufficient samples for detection".to_string(),
406 contributing_features: Vec::new(),
407 timestamp: Utc::now(),
408 event_id: features.source_event_id.clone(),
409 });
410 }
411
412 let mean = samples.iter().sum::<f64>() / samples.len() as f64;
414 let variance =
415 samples.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / samples.len() as f64;
416 let std_dev = variance.sqrt();
417
418 {
420 let mut hist_mean = self.historical_mean.write();
421 let mut hist_std = self.historical_std.write();
422 let alpha = self.config.adaptive_learning_rate;
423 *hist_mean = alpha * mean + (1.0 - alpha) * *hist_mean;
424 *hist_std = alpha * std_dev + (1.0 - alpha) * *hist_std;
425 }
426
427 let (is_anomaly, score, explanation) = match &self.config.algorithm {
429 AnomalyDetectionAlgorithm::Statistical { threshold } => {
430 let z_score = if std_dev > 1e-10 {
431 (metric - mean).abs() / std_dev
432 } else {
433 0.0
434 };
435
436 let is_anomaly = z_score > *threshold;
437 let score = (z_score / threshold).min(1.0);
438
439 (
440 is_anomaly,
441 score,
442 format!("Z-score: {:.2}, threshold: {:.2}", z_score, threshold),
443 )
444 }
445 AnomalyDetectionAlgorithm::IsolationForest { contamination } => {
446 let z_score = if std_dev > 1e-10 {
448 (metric - mean).abs() / std_dev
449 } else {
450 0.0
451 };
452
453 let threshold = 3.0 / contamination;
454 let is_anomaly = z_score > threshold;
455 let score = (z_score / threshold).min(1.0);
456
457 (is_anomaly, score, format!("Isolation score: {:.2}", score))
458 }
459 _ => {
460 let z_score = if std_dev > 1e-10 {
462 (metric - mean).abs() / std_dev
463 } else {
464 0.0
465 };
466
467 let is_anomaly = z_score > 3.0;
468 let score = (z_score / 3.0).min(1.0);
469
470 (is_anomaly, score, format!("Z-score: {:.2}", z_score))
471 }
472 };
473
474 if is_anomaly {
475 stats.anomalies_detected += 1;
476 stats.true_positives += 1;
477 }
478
479 stats.avg_anomaly_score = (stats.avg_anomaly_score + score) / 2.0;
480 stats.detection_rate = stats.anomalies_detected as f64 / stats.events_processed as f64;
481
482 Ok(AnomalyResult {
483 is_anomaly,
484 score,
485 explanation,
486 contributing_features: features.feature_names.clone(),
487 timestamp: Utc::now(),
488 event_id: features.source_event_id.clone(),
489 })
490 }
491
492 pub fn feedback(&self, event_id: &str, is_true_anomaly: bool) {
494 debug!(
495 "Received feedback for event {}: is_anomaly={}",
496 event_id, is_true_anomaly
497 );
498
499 if self.config.enable_feedback {
500 let mut threshold = self.threshold.write();
503 if is_true_anomaly {
504 *threshold *= 0.98; } else {
506 *threshold *= 1.02; }
508 }
509 }
510
511 pub fn get_stats(&self) -> AnomalyStats {
513 self.stats.read().clone()
514 }
515}
516
/// Extracts numeric feature vectors from stream events over a rolling window.
pub struct FeatureExtractor {
    /// Feature selection and window settings.
    config: FeatureConfig,
    /// Recent events, bounded by `config.window_size`.
    event_history: Arc<RwLock<VecDeque<StreamEvent>>>,
}
524
impl FeatureExtractor {
    /// Creates an extractor with an empty event-history window.
    pub fn new(config: FeatureConfig) -> Self {
        Self {
            config: config.clone(),
            event_history: Arc::new(RwLock::new(VecDeque::with_capacity(config.window_size))),
        }
    }

    /// Adds `event` to the rolling window and returns features extracted
    /// from the updated window.
    ///
    /// Feature order is: `window_size`, then (if enabled) statistical
    /// features, then (if enabled) frequency features. Downstream consumers
    /// see this exact ordering in `feature_names`.
    pub fn extract_features(&self, event: &StreamEvent) -> Result<FeatureVector> {
        let mut features = Vec::new();
        let mut feature_names = Vec::new();

        // Update the bounded event history with the newest event.
        let mut history = self.event_history.write();
        history.push_back(event.clone());
        if history.len() > self.config.window_size {
            history.pop_front();
        }

        features.push(history.len() as f64);
        feature_names.push("window_size".to_string());

        if self.config.enable_statistical {
            // NOTE(review): "event_count" duplicates "window_size" above —
            // both push `history.len()`; confirm whether this is intentional.
            features.push(history.len() as f64);
            feature_names.push("event_count".to_string());

            if history.len() >= 2 {
                // Fraction of the window currently filled.
                let rate = history.len() as f64 / self.config.window_size as f64;
                features.push(rate);
                feature_names.push("event_rate".to_string());
            }
        }

        if self.config.enable_frequency {
            // Count occurrences of each event type across the window.
            let mut type_counts: HashMap<String, usize> = HashMap::new();
            for evt in history.iter() {
                let event_type = self.get_event_type(evt);
                *type_counts.entry(event_type).or_insert(0) += 1;
            }

            let unique_types = type_counts.len() as f64;
            features.push(unique_types);
            feature_names.push("unique_event_types".to_string());
        }

        Ok(FeatureVector {
            features: Array1::from_vec(features),
            feature_names,
            timestamp: Utc::now(),
            source_event_id: self.get_event_id(event),
        })
    }

    /// Maps an event to a coarse type label used for frequency features.
    /// Variants not listed explicitly collapse to "Other".
    fn get_event_type(&self, event: &StreamEvent) -> String {
        match event {
            StreamEvent::TripleAdded { .. } => "TripleAdded",
            StreamEvent::TripleRemoved { .. } => "TripleRemoved",
            StreamEvent::QuadAdded { .. } => "QuadAdded",
            StreamEvent::QuadRemoved { .. } => "QuadRemoved",
            StreamEvent::GraphCreated { .. } => "GraphCreated",
            StreamEvent::GraphCleared { .. } => "GraphCleared",
            StreamEvent::GraphDeleted { .. } => "GraphDeleted",
            StreamEvent::SparqlUpdate { .. } => "SparqlUpdate",
            StreamEvent::TransactionBegin { .. } => "TransactionBegin",
            StreamEvent::TransactionCommit { .. } => "TransactionCommit",
            StreamEvent::TransactionAbort { .. } => "TransactionAbort",
            StreamEvent::SchemaChanged { .. } => "SchemaChanged",
            _ => "Other",
        }
        .to_string()
    }

    /// Extracts the event id from any event variant's metadata.
    fn get_event_id(&self, event: &StreamEvent) -> String {
        let metadata = match event {
            StreamEvent::TripleAdded { metadata, .. }
            | StreamEvent::TripleRemoved { metadata, .. }
            | StreamEvent::QuadAdded { metadata, .. }
            | StreamEvent::QuadRemoved { metadata, .. }
            | StreamEvent::GraphCreated { metadata, .. }
            | StreamEvent::GraphCleared { metadata, .. }
            | StreamEvent::GraphDeleted { metadata, .. }
            | StreamEvent::SparqlUpdate { metadata, .. }
            | StreamEvent::TransactionBegin { metadata, .. }
            | StreamEvent::TransactionCommit { metadata, .. }
            | StreamEvent::TransactionAbort { metadata, .. }
            | StreamEvent::SchemaChanged { metadata, .. }
            | StreamEvent::Heartbeat { metadata, .. }
            | StreamEvent::QueryResultAdded { metadata, .. }
            | StreamEvent::QueryResultRemoved { metadata, .. }
            | StreamEvent::QueryCompleted { metadata, .. }
            | StreamEvent::GraphMetadataUpdated { metadata, .. }
            | StreamEvent::GraphPermissionsChanged { metadata, .. }
            | StreamEvent::GraphStatisticsUpdated { metadata, .. }
            | StreamEvent::GraphRenamed { metadata, .. }
            | StreamEvent::GraphMerged { metadata, .. }
            | StreamEvent::GraphSplit { metadata, .. }
            | StreamEvent::SchemaDefinitionAdded { metadata, .. }
            | StreamEvent::SchemaDefinitionRemoved { metadata, .. }
            | StreamEvent::SchemaDefinitionModified { metadata, .. }
            | StreamEvent::OntologyImported { metadata, .. }
            | StreamEvent::OntologyRemoved { metadata, .. }
            | StreamEvent::ConstraintAdded { metadata, .. }
            | StreamEvent::ConstraintRemoved { metadata, .. }
            | StreamEvent::ConstraintViolated { metadata, .. }
            | StreamEvent::IndexCreated { metadata, .. }
            | StreamEvent::IndexDropped { metadata, .. }
            | StreamEvent::IndexRebuilt { metadata, .. }
            | StreamEvent::SchemaUpdated { metadata, .. }
            | StreamEvent::ShapeAdded { metadata, .. }
            | StreamEvent::ShapeUpdated { metadata, .. }
            | StreamEvent::ShapeRemoved { metadata, .. }
            | StreamEvent::ShapeModified { metadata, .. }
            | StreamEvent::ShapeValidationStarted { metadata, .. }
            | StreamEvent::ShapeValidationCompleted { metadata, .. }
            | StreamEvent::ShapeViolationDetected { metadata, .. }
            | StreamEvent::ErrorOccurred { metadata, .. } => metadata,
        };
        metadata.event_id.clone()
    }
}
655
/// Registry tying together models, detectors, and feature extractors by name.
pub struct MLIntegrationManager {
    /// Registered online-learning models, keyed by name.
    models: Arc<DashMap<String, OnlineLearningModel>>,
    /// Registered anomaly detectors, keyed by name.
    detectors: Arc<DashMap<String, AnomalyDetector>>,
    /// Registered feature extractors, keyed by name.
    extractors: Arc<DashMap<String, FeatureExtractor>>,
}
665
666impl MLIntegrationManager {
667 pub fn new() -> Self {
669 Self {
670 models: Arc::new(DashMap::new()),
671 detectors: Arc::new(DashMap::new()),
672 extractors: Arc::new(DashMap::new()),
673 }
674 }
675
676 pub fn register_model(&self, name: String, model: OnlineLearningModel) {
678 self.models.insert(name.clone(), model);
679 info!("Registered ML model: {}", name);
680 }
681
682 pub fn register_detector(&self, name: String, detector: AnomalyDetector) {
684 self.detectors.insert(name.clone(), detector);
685 info!("Registered anomaly detector: {}", name);
686 }
687
688 pub fn register_extractor(&self, name: String, extractor: FeatureExtractor) {
690 self.extractors.insert(name.clone(), extractor);
691 info!("Registered feature extractor: {}", name);
692 }
693
694 pub fn get_model(
696 &self,
697 name: &str,
698 ) -> Option<dashmap::mapref::one::Ref<'_, String, OnlineLearningModel>> {
699 self.models.get(name)
700 }
701
702 pub fn get_detector(
704 &self,
705 name: &str,
706 ) -> Option<dashmap::mapref::one::Ref<'_, String, AnomalyDetector>> {
707 self.detectors.get(name)
708 }
709
710 pub fn get_extractor(
712 &self,
713 name: &str,
714 ) -> Option<dashmap::mapref::one::Ref<'_, String, FeatureExtractor>> {
715 self.extractors.get(name)
716 }
717}
718
719impl Default for MLIntegrationManager {
720 fn default() -> Self {
721 Self::new()
722 }
723}
724
#[cfg(test)]
mod tests {
    use super::*;
    use crate::event::EventMetadata;

    /// The model should accept a training sample and produce a finite
    /// prediction afterwards.
    #[test]
    fn test_online_learning() {
        let config = MLModelConfig {
            model_type: ModelType::LinearRegression,
            feature_config: FeatureConfig {
                window_size: 10,
                enable_statistical: true,
                enable_frequency: false,
                custom_features: Vec::new(),
            },
            learning_rate: 0.01,
            batch_size: 10,
            update_interval: Duration::from_secs(1),
            enable_persistence: false,
            version: "1.0".to_string(),
        };

        let model = OnlineLearningModel::new(config, 3);

        let features = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        model.train(&features, 10.0).unwrap();

        let result = model.predict(&features).unwrap();
        assert!(result.prediction.is_finite());
    }

    /// A slowly increasing series should not be flagged, while a large
    /// outlier afterwards should be.
    #[test]
    fn test_anomaly_detection() {
        let config = AnomalyDetectionConfig {
            algorithm: AnomalyDetectionAlgorithm::Statistical { threshold: 3.0 },
            sensitivity: 0.8,
            adaptive_learning_rate: 0.1,
            window_size: 100,
            min_samples: 10,
            enable_feedback: true,
        };

        let detector = AnomalyDetector::new(config);

        // Warm up with non-anomalous values.
        for i in 0..20 {
            let features = FeatureVector {
                features: Array1::from_vec(vec![100.0 + i as f64]),
                feature_names: vec!["value".to_string()],
                timestamp: Utc::now(),
                source_event_id: format!("event-{}", i),
            };

            let result = detector.detect(&features).unwrap();
            // Only assert once min_samples (10) has been reached.
            if i >= 10 {
                assert!(!result.is_anomaly);
            }
        }

        // A value far outside the window should be flagged.
        let anomalous_features = FeatureVector {
            features: Array1::from_vec(vec![1000.0]),
            feature_names: vec!["value".to_string()],
            timestamp: Utc::now(),
            source_event_id: "anomaly".to_string(),
        };

        let result = detector.detect(&anomalous_features).unwrap();
        assert!(result.is_anomaly);
        assert!(result.score > 0.0);
    }

    /// Extraction should yield a non-empty vector with parallel name/value
    /// lists of equal length.
    #[test]
    fn test_feature_extraction() {
        let config = FeatureConfig {
            window_size: 10,
            enable_statistical: true,
            enable_frequency: true,
            custom_features: Vec::new(),
        };

        let extractor = FeatureExtractor::new(config);

        let event = StreamEvent::SchemaChanged {
            schema_type: crate::event::SchemaType::Ontology,
            change_type: crate::event::SchemaChangeType::Added,
            details: "test schema change".to_string(),
            metadata: EventMetadata {
                event_id: "test-1".to_string(),
                timestamp: Utc::now(),
                source: "test".to_string(),
                user: None,
                context: None,
                caused_by: None,
                version: "1.0".to_string(),
                properties: HashMap::new(),
                checksum: None,
            },
        };

        let features = extractor.extract_features(&event).unwrap();
        assert!(!features.features.is_empty());
        assert_eq!(features.features.len(), features.feature_names.len());
    }
}