1use anyhow::{anyhow, Result};
24use chrono::{DateTime, Utc};
25use dashmap::DashMap;
26use parking_lot::RwLock;
27use serde::{Deserialize, Serialize};
28use std::collections::{HashMap, VecDeque};
29use std::sync::Arc;
30use std::time::{Duration, Instant};
31use tracing::{debug, info};
32
33use scirs2_core::ndarray_ext::Array1;
35use scirs2_core::random::{rng, RngExt};
36
37use crate::event::StreamEvent;
38
/// Shared, lock-protected buffer of pending `(features, target)` training samples.
type SampleBuffer = Arc<RwLock<Vec<(Array1<f64>, f64)>>>;
41
/// Model family selected for a streaming ML workload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ModelType {
    /// Linear regression.
    LinearRegression,
    /// Logistic regression.
    LogisticRegression,
    /// K-means clustering with `k` clusters.
    KMeans { k: usize },
    /// Exponentially weighted moving average with smoothing factor `alpha`.
    EWMA { alpha: f64 },
    /// Isolation forest with `n_trees` trees.
    IsolationForest { n_trees: usize },
    /// LSTM network sizing parameters.
    LSTM {
        hidden_size: usize,
        num_layers: usize,
    },
    /// User-defined model identified by `name`.
    Custom { name: String },
}
63
/// Scoring algorithm used by [`AnomalyDetector::detect`].
///
/// NOTE(review): in this module only `Statistical` and `IsolationForest`
/// have dedicated handling; the remaining variants fall back to a plain
/// z-score test — confirm whether full implementations live elsewhere.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AnomalyDetectionAlgorithm {
    /// Z-score test: anomalous when |z| exceeds `threshold` std deviations.
    Statistical { threshold: f64 },
    /// Isolation-forest style detection; `contamination` is the expected
    /// fraction of anomalous events.
    IsolationForest { contamination: f64 },
    /// One-class SVM with margin parameter `nu`.
    OneClassSVM { nu: f64 },
    /// Autoencoder reconstruction-error detection.
    Autoencoder { encoding_dim: usize, threshold: f64 },
    /// LSTM-based detection over a sliding window.
    LSTM { window_size: usize },
    /// Combination of several algorithms.
    Ensemble {
        algorithms: Vec<AnomalyDetectionAlgorithm>,
    },
}
82
/// Configuration for [`FeatureExtractor`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureConfig {
    /// Maximum number of recent events kept in the sliding window.
    pub window_size: usize,
    /// Emit statistical features (`event_count`, `event_rate`).
    pub enable_statistical: bool,
    /// Emit frequency features (`unique_event_types`).
    pub enable_frequency: bool,
    /// Names of additional user-defined features.
    /// NOTE(review): not consumed by `FeatureExtractor::extract_features` —
    /// confirm whether another component reads this.
    pub custom_features: Vec<String>,
}
95
/// Configuration for an [`OnlineLearningModel`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MLModelConfig {
    /// Which model family this configuration describes.
    pub model_type: ModelType,
    /// Feature extraction settings associated with the model.
    pub feature_config: FeatureConfig,
    /// SGD step size used during weight updates.
    pub learning_rate: f64,
    /// Number of buffered samples that triggers a weight update.
    pub batch_size: usize,
    /// Maximum time between weight updates (update is also triggered when
    /// this interval elapses, even if the batch is not full).
    pub update_interval: Duration,
    /// Whether model state should be persisted.
    /// NOTE(review): not read anywhere in this module — confirm usage.
    pub enable_persistence: bool,
    /// Version tag for this configuration.
    pub version: String,
}
114
/// Configuration for an [`AnomalyDetector`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnomalyDetectionConfig {
    /// Scoring algorithm to apply.
    pub algorithm: AnomalyDetectionAlgorithm,
    /// Detector sensitivity.
    /// NOTE(review): not read anywhere in this module — confirm usage.
    pub sensitivity: f64,
    /// EWMA factor used to update the historical mean/std baselines.
    pub adaptive_learning_rate: f64,
    /// Maximum number of recent metric samples retained.
    pub window_size: usize,
    /// Minimum samples required before any anomaly verdict is produced.
    pub min_samples: usize,
    /// Whether `feedback` should adapt the detection threshold.
    pub enable_feedback: bool,
}
131
/// Numeric features derived from a single stream event.
///
/// `features[i]` corresponds to `feature_names[i]`.
#[derive(Debug, Clone)]
pub struct FeatureVector {
    /// Feature values.
    pub features: Array1<f64>,
    /// Human-readable name for each feature, parallel to `features`.
    pub feature_names: Vec<String>,
    /// When the vector was produced.
    pub timestamp: DateTime<Utc>,
    /// `event_id` of the event the features were extracted from.
    pub source_event_id: String,
}
144
/// Outcome of one anomaly-detection pass over a feature vector.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnomalyResult {
    /// Whether the event was flagged as anomalous.
    pub is_anomaly: bool,
    /// Anomaly score, clamped to `[0, 1]`.
    pub score: f64,
    /// Human-readable description of how the score was derived.
    pub explanation: String,
    /// Names of the features that contributed to the verdict.
    pub contributing_features: Vec<String>,
    /// When the detection was performed.
    pub timestamp: DateTime<Utc>,
    /// Identifier of the evaluated event.
    pub event_id: String,
}
161
/// Output of [`OnlineLearningModel::predict`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictionResult {
    /// Predicted value.
    pub prediction: f64,
    /// Confidence in the prediction.
    /// NOTE(review): currently a fixed placeholder in this module.
    pub confidence: f64,
    /// Optional prediction interval `(low, high)`.
    pub interval: Option<(f64, f64)>,
    /// When the prediction was made.
    pub timestamp: DateTime<Utc>,
}
174
/// Accumulated performance metrics for an [`OnlineLearningModel`].
///
/// NOTE(review): only `predictions_made` and `avg_prediction_time_ms` are
/// updated by this module; the remaining fields stay at their defaults —
/// confirm whether another component fills them in.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ModelMetrics {
    /// Total number of predictions served.
    pub predictions_made: u64,
    /// Predictions later confirmed correct.
    pub correct_predictions: u64,
    /// Fraction of correct predictions.
    pub accuracy: f64,
    /// Mean absolute error of predictions.
    pub mean_absolute_error: f64,
    /// Root mean squared error of predictions.
    pub root_mean_squared_error: f64,
    /// Coefficient of determination.
    pub r_squared: f64,
    /// Running average prediction latency, in milliseconds.
    pub avg_prediction_time_ms: f64,
}
193
/// Running statistics maintained by an [`AnomalyDetector`].
#[derive(Debug, Clone, Default)]
pub struct AnomalyStats {
    /// Total events passed through `detect`.
    pub events_processed: u64,
    /// Events flagged as anomalous.
    pub anomalies_detected: u64,
    /// Flagged events later labeled as normal.
    pub false_positives: u64,
    /// Flagged events later confirmed as anomalies.
    pub true_positives: u64,
    /// Running average of anomaly scores.
    pub avg_anomaly_score: f64,
    /// `anomalies_detected / events_processed`.
    pub detection_rate: f64,
}
210
/// Incrementally trained linear model for streaming data.
///
/// All state lives behind `Arc<RwLock<…>>` so the model can be shared
/// across threads and trained/queried concurrently.
pub struct OnlineLearningModel {
    /// Learning rate, batch size, and update cadence.
    config: MLModelConfig,
    /// Per-feature weights of the linear model.
    weights: Arc<RwLock<Array1<f64>>>,
    /// Bias (intercept) term.
    bias: Arc<RwLock<f64>>,
    /// Expected dimensionality of input feature vectors.
    num_features: usize,
    /// Pending `(features, target)` samples awaiting a batch update.
    sample_buffer: SampleBuffer,
    /// Prediction counters and timing metrics.
    metrics: Arc<RwLock<ModelMetrics>>,
    /// Instant of the most recent weight update.
    last_update: Arc<RwLock<Instant>>,
}
228
229impl OnlineLearningModel {
230 pub fn new(config: MLModelConfig, num_features: usize) -> Self {
232 let mut rng_instance = rng();
234 let weights = Array1::from_vec(
235 (0..num_features)
236 .map(|_| {
237 rng_instance.random_range(-0.01..0.01)
239 })
240 .collect(),
241 );
242
243 Self {
244 config,
245 weights: Arc::new(RwLock::new(weights)),
246 bias: Arc::new(RwLock::new(0.0)),
247 num_features,
248 sample_buffer: Arc::new(RwLock::new(Vec::new())),
249 metrics: Arc::new(RwLock::new(ModelMetrics::default())),
250 last_update: Arc::new(RwLock::new(Instant::now())),
251 }
252 }
253
254 pub fn train(&self, features: &Array1<f64>, target: f64) -> Result<()> {
256 if features.len() != self.num_features {
257 return Err(anyhow!(
258 "Feature dimension mismatch: expected {}, got {}",
259 self.num_features,
260 features.len()
261 ));
262 }
263
264 self.sample_buffer.write().push((features.clone(), target));
266
267 let should_update = {
269 let buffer = self.sample_buffer.read();
270 let last_update = self.last_update.read();
271 buffer.len() >= self.config.batch_size
272 || last_update.elapsed() >= self.config.update_interval
273 };
274
275 if should_update {
276 self.update_weights()?;
277 }
278
279 Ok(())
280 }
281
282 fn update_weights(&self) -> Result<()> {
284 let samples = {
285 let mut buffer = self.sample_buffer.write();
286 std::mem::take(&mut *buffer)
287 };
288
289 if samples.is_empty() {
290 return Ok(());
291 }
292
293 let mut weights = self.weights.write();
294 let mut bias = self.bias.write();
295
296 for (features, target) in &samples {
298 let prediction = self.predict_internal(&weights, *bias, features);
299 let error = prediction - target;
300
301 for i in 0..self.num_features {
303 weights[i] -= self.config.learning_rate * error * features[i];
304 }
305
306 *bias -= self.config.learning_rate * error;
308 }
309
310 *self.last_update.write() = Instant::now();
311 debug!("Updated model weights with {} samples", samples.len());
312 Ok(())
313 }
314
315 pub fn predict(&self, features: &Array1<f64>) -> Result<PredictionResult> {
317 if features.len() != self.num_features {
318 return Err(anyhow!("Feature dimension mismatch"));
319 }
320
321 let start_time = Instant::now();
322 let weights = self.weights.read();
323 let bias = self.bias.read();
324
325 let prediction = self.predict_internal(&weights, *bias, features);
326
327 let mut metrics = self.metrics.write();
329 metrics.predictions_made += 1;
330 let prediction_time = start_time.elapsed().as_micros() as f64 / 1000.0;
331 metrics.avg_prediction_time_ms = (metrics.avg_prediction_time_ms + prediction_time) / 2.0;
332
333 Ok(PredictionResult {
334 prediction,
335 confidence: 0.8, interval: None,
337 timestamp: Utc::now(),
338 })
339 }
340
341 fn predict_internal(&self, weights: &Array1<f64>, bias: f64, features: &Array1<f64>) -> f64 {
343 let mut result = bias;
344 for i in 0..self.num_features {
345 result += weights[i] * features[i];
346 }
347 result
348 }
349
350 pub fn get_metrics(&self) -> ModelMetrics {
352 self.metrics.read().clone()
353 }
354}
355
/// Streaming anomaly detector over a sliding window of scalar metrics.
pub struct AnomalyDetector {
    /// Algorithm choice and windowing parameters.
    config: AnomalyDetectionConfig,
    /// EWMA of the window mean.
    historical_mean: Arc<RwLock<f64>>,
    /// EWMA of the window standard deviation.
    historical_std: Arc<RwLock<f64>>,
    /// Sliding window of the most recent scalar metrics.
    recent_samples: Arc<RwLock<VecDeque<f64>>>,
    /// Adaptive threshold adjusted by `feedback`.
    /// NOTE(review): not currently consulted by `detect` — confirm intent.
    threshold: Arc<RwLock<f64>>,
    /// Running detection statistics.
    stats: Arc<RwLock<AnomalyStats>>,
}
370
371impl AnomalyDetector {
372 pub fn new(config: AnomalyDetectionConfig) -> Self {
374 Self {
375 config: config.clone(),
376 historical_mean: Arc::new(RwLock::new(0.0)),
377 historical_std: Arc::new(RwLock::new(1.0)),
378 recent_samples: Arc::new(RwLock::new(VecDeque::with_capacity(config.window_size))),
379 threshold: Arc::new(RwLock::new(3.0)), stats: Arc::new(RwLock::new(AnomalyStats::default())),
381 }
382 }
383
384 pub fn detect(&self, features: &FeatureVector) -> Result<AnomalyResult> {
386 let metric = features.features.iter().sum::<f64>() / features.features.len() as f64;
388
389 let mut samples = self.recent_samples.write();
391 samples.push_back(metric);
392 if samples.len() > self.config.window_size {
393 samples.pop_front();
394 }
395
396 let mut stats = self.stats.write();
397 stats.events_processed += 1;
398
399 if samples.len() < self.config.min_samples {
401 return Ok(AnomalyResult {
402 is_anomaly: false,
403 score: 0.0,
404 explanation: "Insufficient samples for detection".to_string(),
405 contributing_features: Vec::new(),
406 timestamp: Utc::now(),
407 event_id: features.source_event_id.clone(),
408 });
409 }
410
411 let mean = samples.iter().sum::<f64>() / samples.len() as f64;
413 let variance =
414 samples.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / samples.len() as f64;
415 let std_dev = variance.sqrt();
416
417 {
419 let mut hist_mean = self.historical_mean.write();
420 let mut hist_std = self.historical_std.write();
421 let alpha = self.config.adaptive_learning_rate;
422 *hist_mean = alpha * mean + (1.0 - alpha) * *hist_mean;
423 *hist_std = alpha * std_dev + (1.0 - alpha) * *hist_std;
424 }
425
426 let (is_anomaly, score, explanation) = match &self.config.algorithm {
428 AnomalyDetectionAlgorithm::Statistical { threshold } => {
429 let z_score = if std_dev > 1e-10 {
430 (metric - mean).abs() / std_dev
431 } else {
432 0.0
433 };
434
435 let is_anomaly = z_score > *threshold;
436 let score = (z_score / threshold).min(1.0);
437
438 (
439 is_anomaly,
440 score,
441 format!("Z-score: {:.2}, threshold: {:.2}", z_score, threshold),
442 )
443 }
444 AnomalyDetectionAlgorithm::IsolationForest { contamination } => {
445 let z_score = if std_dev > 1e-10 {
447 (metric - mean).abs() / std_dev
448 } else {
449 0.0
450 };
451
452 let threshold = 3.0 / contamination;
453 let is_anomaly = z_score > threshold;
454 let score = (z_score / threshold).min(1.0);
455
456 (is_anomaly, score, format!("Isolation score: {:.2}", score))
457 }
458 _ => {
459 let z_score = if std_dev > 1e-10 {
461 (metric - mean).abs() / std_dev
462 } else {
463 0.0
464 };
465
466 let is_anomaly = z_score > 3.0;
467 let score = (z_score / 3.0).min(1.0);
468
469 (is_anomaly, score, format!("Z-score: {:.2}", z_score))
470 }
471 };
472
473 if is_anomaly {
474 stats.anomalies_detected += 1;
475 stats.true_positives += 1;
476 }
477
478 stats.avg_anomaly_score = (stats.avg_anomaly_score + score) / 2.0;
479 stats.detection_rate = stats.anomalies_detected as f64 / stats.events_processed as f64;
480
481 Ok(AnomalyResult {
482 is_anomaly,
483 score,
484 explanation,
485 contributing_features: features.feature_names.clone(),
486 timestamp: Utc::now(),
487 event_id: features.source_event_id.clone(),
488 })
489 }
490
491 pub fn feedback(&self, event_id: &str, is_true_anomaly: bool) {
493 debug!(
494 "Received feedback for event {}: is_anomaly={}",
495 event_id, is_true_anomaly
496 );
497
498 if self.config.enable_feedback {
499 let mut threshold = self.threshold.write();
502 if is_true_anomaly {
503 *threshold *= 0.98; } else {
505 *threshold *= 1.02; }
507 }
508 }
509
510 pub fn get_stats(&self) -> AnomalyStats {
512 self.stats.read().clone()
513 }
514}
515
/// Derives numeric feature vectors from stream events over a sliding window.
pub struct FeatureExtractor {
    /// Window size and feature-group toggles.
    config: FeatureConfig,
    /// Sliding window of the most recent events.
    event_history: Arc<RwLock<VecDeque<StreamEvent>>>,
}
523
impl FeatureExtractor {
    /// Creates an extractor with an empty, pre-allocated event history.
    pub fn new(config: FeatureConfig) -> Self {
        Self {
            config: config.clone(),
            event_history: Arc::new(RwLock::new(VecDeque::with_capacity(config.window_size))),
        }
    }

    /// Appends `event` to the sliding history and builds a feature vector
    /// from the current window contents.
    ///
    /// Always emits a `window_size` feature (current history length); when
    /// enabled, statistical features (`event_count`, `event_rate`) and a
    /// frequency feature (`unique_event_types`) are appended. Feature order
    /// is fixed, so `features[i]` matches `feature_names[i]`.
    ///
    /// NOTE(review): `config.custom_features` is never consumed here, and
    /// `event_count` duplicates the `window_size` value — confirm intent.
    pub fn extract_features(&self, event: &StreamEvent) -> Result<FeatureVector> {
        let mut features = Vec::new();
        let mut feature_names = Vec::new();

        // Maintain a bounded sliding window of recent events.
        let mut history = self.event_history.write();
        history.push_back(event.clone());
        if history.len() > self.config.window_size {
            history.pop_front();
        }

        // Base feature: how many events the window currently holds.
        features.push(history.len() as f64);
        feature_names.push("window_size".to_string());

        if self.config.enable_statistical {
            features.push(history.len() as f64);
            feature_names.push("event_count".to_string());

            if history.len() >= 2 {
                // Fill ratio of the window (length / capacity).
                let rate = history.len() as f64 / self.config.window_size as f64;
                features.push(rate);
                feature_names.push("event_rate".to_string());
            }
        }

        if self.config.enable_frequency {
            // Count occurrences of each coarse event type in the window.
            let mut type_counts: HashMap<String, usize> = HashMap::new();
            for evt in history.iter() {
                let event_type = self.get_event_type(evt);
                *type_counts.entry(event_type).or_insert(0) += 1;
            }

            let unique_types = type_counts.len() as f64;
            features.push(unique_types);
            feature_names.push("unique_event_types".to_string());
        }

        Ok(FeatureVector {
            features: Array1::from_vec(features),
            feature_names,
            timestamp: Utc::now(),
            source_event_id: self.get_event_id(event),
        })
    }

    /// Maps an event to a coarse type label; variants not listed here
    /// collapse to "Other".
    fn get_event_type(&self, event: &StreamEvent) -> String {
        match event {
            StreamEvent::TripleAdded { .. } => "TripleAdded",
            StreamEvent::TripleRemoved { .. } => "TripleRemoved",
            StreamEvent::QuadAdded { .. } => "QuadAdded",
            StreamEvent::QuadRemoved { .. } => "QuadRemoved",
            StreamEvent::GraphCreated { .. } => "GraphCreated",
            StreamEvent::GraphCleared { .. } => "GraphCleared",
            StreamEvent::GraphDeleted { .. } => "GraphDeleted",
            StreamEvent::SparqlUpdate { .. } => "SparqlUpdate",
            StreamEvent::TransactionBegin { .. } => "TransactionBegin",
            StreamEvent::TransactionCommit { .. } => "TransactionCommit",
            StreamEvent::TransactionAbort { .. } => "TransactionAbort",
            StreamEvent::SchemaChanged { .. } => "SchemaChanged",
            _ => "Other",
        }
        .to_string()
    }

    /// Extracts the `event_id` from the event's metadata.
    ///
    /// The match enumerates every `StreamEvent` variant carrying metadata
    /// with no wildcard arm, so adding a variant upstream forces a
    /// compile-time update here rather than a silent miss.
    fn get_event_id(&self, event: &StreamEvent) -> String {
        let metadata = match event {
            StreamEvent::TripleAdded { metadata, .. }
            | StreamEvent::TripleRemoved { metadata, .. }
            | StreamEvent::QuadAdded { metadata, .. }
            | StreamEvent::QuadRemoved { metadata, .. }
            | StreamEvent::GraphCreated { metadata, .. }
            | StreamEvent::GraphCleared { metadata, .. }
            | StreamEvent::GraphDeleted { metadata, .. }
            | StreamEvent::SparqlUpdate { metadata, .. }
            | StreamEvent::TransactionBegin { metadata, .. }
            | StreamEvent::TransactionCommit { metadata, .. }
            | StreamEvent::TransactionAbort { metadata, .. }
            | StreamEvent::SchemaChanged { metadata, .. }
            | StreamEvent::Heartbeat { metadata, .. }
            | StreamEvent::QueryResultAdded { metadata, .. }
            | StreamEvent::QueryResultRemoved { metadata, .. }
            | StreamEvent::QueryCompleted { metadata, .. }
            | StreamEvent::GraphMetadataUpdated { metadata, .. }
            | StreamEvent::GraphPermissionsChanged { metadata, .. }
            | StreamEvent::GraphStatisticsUpdated { metadata, .. }
            | StreamEvent::GraphRenamed { metadata, .. }
            | StreamEvent::GraphMerged { metadata, .. }
            | StreamEvent::GraphSplit { metadata, .. }
            | StreamEvent::SchemaDefinitionAdded { metadata, .. }
            | StreamEvent::SchemaDefinitionRemoved { metadata, .. }
            | StreamEvent::SchemaDefinitionModified { metadata, .. }
            | StreamEvent::OntologyImported { metadata, .. }
            | StreamEvent::OntologyRemoved { metadata, .. }
            | StreamEvent::ConstraintAdded { metadata, .. }
            | StreamEvent::ConstraintRemoved { metadata, .. }
            | StreamEvent::ConstraintViolated { metadata, .. }
            | StreamEvent::IndexCreated { metadata, .. }
            | StreamEvent::IndexDropped { metadata, .. }
            | StreamEvent::IndexRebuilt { metadata, .. }
            | StreamEvent::SchemaUpdated { metadata, .. }
            | StreamEvent::ShapeAdded { metadata, .. }
            | StreamEvent::ShapeUpdated { metadata, .. }
            | StreamEvent::ShapeRemoved { metadata, .. }
            | StreamEvent::ShapeModified { metadata, .. }
            | StreamEvent::ShapeValidationStarted { metadata, .. }
            | StreamEvent::ShapeValidationCompleted { metadata, .. }
            | StreamEvent::ShapeViolationDetected { metadata, .. }
            | StreamEvent::ErrorOccurred { metadata, .. } => metadata,
        };
        metadata.event_id.clone()
    }
}
654
/// Thread-safe registry of named models, detectors, and extractors.
pub struct MLIntegrationManager {
    /// Registered online-learning models, keyed by name.
    models: Arc<DashMap<String, OnlineLearningModel>>,
    /// Registered anomaly detectors, keyed by name.
    detectors: Arc<DashMap<String, AnomalyDetector>>,
    /// Registered feature extractors, keyed by name.
    extractors: Arc<DashMap<String, FeatureExtractor>>,
}
664
665impl MLIntegrationManager {
666 pub fn new() -> Self {
668 Self {
669 models: Arc::new(DashMap::new()),
670 detectors: Arc::new(DashMap::new()),
671 extractors: Arc::new(DashMap::new()),
672 }
673 }
674
675 pub fn register_model(&self, name: String, model: OnlineLearningModel) {
677 self.models.insert(name.clone(), model);
678 info!("Registered ML model: {}", name);
679 }
680
681 pub fn register_detector(&self, name: String, detector: AnomalyDetector) {
683 self.detectors.insert(name.clone(), detector);
684 info!("Registered anomaly detector: {}", name);
685 }
686
687 pub fn register_extractor(&self, name: String, extractor: FeatureExtractor) {
689 self.extractors.insert(name.clone(), extractor);
690 info!("Registered feature extractor: {}", name);
691 }
692
693 pub fn get_model(
695 &self,
696 name: &str,
697 ) -> Option<dashmap::mapref::one::Ref<'_, String, OnlineLearningModel>> {
698 self.models.get(name)
699 }
700
701 pub fn get_detector(
703 &self,
704 name: &str,
705 ) -> Option<dashmap::mapref::one::Ref<'_, String, AnomalyDetector>> {
706 self.detectors.get(name)
707 }
708
709 pub fn get_extractor(
711 &self,
712 name: &str,
713 ) -> Option<dashmap::mapref::one::Ref<'_, String, FeatureExtractor>> {
714 self.extractors.get(name)
715 }
716}
717
718impl Default for MLIntegrationManager {
719 fn default() -> Self {
720 Self::new()
721 }
722}
723
#[cfg(test)]
mod tests {
    use super::*;
    use crate::event::EventMetadata;

    /// Trains a linear model on a single sample and checks that prediction
    /// produces a finite value (weights are randomly initialized).
    #[test]
    fn test_online_learning() {
        let config = MLModelConfig {
            model_type: ModelType::LinearRegression,
            feature_config: FeatureConfig {
                window_size: 10,
                enable_statistical: true,
                enable_frequency: false,
                custom_features: Vec::new(),
            },
            learning_rate: 0.01,
            batch_size: 10,
            update_interval: Duration::from_secs(1),
            enable_persistence: false,
            version: "1.0".to_string(),
        };

        let model = OnlineLearningModel::new(config, 3);

        let features = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        model.train(&features, 10.0).unwrap();

        let result = model.predict(&features).unwrap();
        assert!(result.prediction.is_finite());
    }

    /// Feeds a gentle upward drift (which must not be flagged once enough
    /// samples accumulate) followed by an extreme outlier (which must be).
    #[test]
    fn test_anomaly_detection() {
        let config = AnomalyDetectionConfig {
            algorithm: AnomalyDetectionAlgorithm::Statistical { threshold: 3.0 },
            sensitivity: 0.8,
            adaptive_learning_rate: 0.1,
            window_size: 100,
            min_samples: 10,
            enable_feedback: true,
        };

        let detector = AnomalyDetector::new(config);

        // Warm-up: slowly increasing values, well within 3 std deviations.
        for i in 0..20 {
            let features = FeatureVector {
                features: Array1::from_vec(vec![100.0 + i as f64]),
                feature_names: vec!["value".to_string()],
                timestamp: Utc::now(),
                source_event_id: format!("event-{}", i),
            };

            let result = detector.detect(&features).unwrap();
            // Only meaningful after min_samples (10) observations.
            if i >= 10 {
                assert!(!result.is_anomaly);
            }
        }

        // A value far outside the observed range must be flagged.
        let anomalous_features = FeatureVector {
            features: Array1::from_vec(vec![1000.0]),
            feature_names: vec!["value".to_string()],
            timestamp: Utc::now(),
            source_event_id: "anomaly".to_string(),
        };

        let result = detector.detect(&anomalous_features).unwrap();
        assert!(result.is_anomaly);
        assert!(result.score > 0.0);
    }

    /// Extracts features from a single schema-change event and checks that
    /// the values and names stay in lockstep.
    #[test]
    fn test_feature_extraction() {
        let config = FeatureConfig {
            window_size: 10,
            enable_statistical: true,
            enable_frequency: true,
            custom_features: Vec::new(),
        };

        let extractor = FeatureExtractor::new(config);

        let event = StreamEvent::SchemaChanged {
            schema_type: crate::event::SchemaType::Ontology,
            change_type: crate::event::SchemaChangeType::Added,
            details: "test schema change".to_string(),
            metadata: EventMetadata {
                event_id: "test-1".to_string(),
                timestamp: Utc::now(),
                source: "test".to_string(),
                user: None,
                context: None,
                caused_by: None,
                version: "1.0".to_string(),
                properties: HashMap::new(),
                checksum: None,
            },
        };

        let features = extractor.extract_features(&event).unwrap();
        assert!(!features.features.is_empty());
        assert_eq!(features.features.len(), features.feature_names.len());
    }
}
831}