use anyhow::{anyhow, Result};
use chrono::{DateTime, Utc};
use dashmap::DashMap;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tracing::{debug, info};
use scirs2_core::ndarray_ext::Array1;
use scirs2_core::random::{rng, RngExt};
use crate::event::StreamEvent;
/// Shared, lock-protected buffer of `(feature_vector, target)` training samples
/// awaiting the next mini-batch weight update.
type SampleBuffer = Arc<RwLock<Vec<(Array1<f64>, f64)>>>;
/// Model families selectable via [`MLModelConfig`].
///
/// NOTE(review): only the linear-SGD path is implemented in this file;
/// the other variants are configuration placeholders — confirm against
/// the rest of the crate.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ModelType {
/// Linear regression trained with online SGD.
LinearRegression,
/// Logistic regression (binary classification).
LogisticRegression,
/// K-means clustering with `k` centroids.
KMeans { k: usize },
/// Exponentially weighted moving average with smoothing factor `alpha`.
EWMA { alpha: f64 },
/// Isolation forest with `n_trees` trees.
IsolationForest { n_trees: usize },
/// Recurrent LSTM network.
LSTM {
/// Width of each hidden layer.
hidden_size: usize,
/// Number of stacked LSTM layers.
num_layers: usize,
},
/// Escape hatch for externally defined models, identified by name.
Custom { name: String },
}
/// Detection strategies selectable via [`AnomalyDetectionConfig`].
///
/// NOTE(review): [`AnomalyDetector::detect`] implements the `Statistical`
/// and `IsolationForest` variants as z-score tests; the remaining variants
/// currently fall through to a default z-score path.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AnomalyDetectionAlgorithm {
/// Flag when the z-score exceeds `threshold` standard deviations.
Statistical { threshold: f64 },
/// Isolation-forest style; `contamination` is the expected anomaly fraction.
IsolationForest { contamination: f64 },
/// One-class SVM with regularization parameter `nu`.
OneClassSVM { nu: f64 },
/// Autoencoder reconstruction-error detector.
Autoencoder { encoding_dim: usize, threshold: f64 },
/// Sequence-based LSTM detector over a sliding window.
LSTM { window_size: usize },
/// Combine multiple detectors.
Ensemble {
algorithms: Vec<AnomalyDetectionAlgorithm>,
},
}
/// Controls which features a [`FeatureExtractor`] derives from the stream.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureConfig {
/// Number of recent events retained in the rolling history window.
pub window_size: usize,
/// Emit count/rate statistics over the window.
pub enable_statistical: bool,
/// Emit event-type frequency features over the window.
pub enable_frequency: bool,
/// Names of user-defined features (not consumed by the code in this file).
pub custom_features: Vec<String>,
}
/// Configuration for an [`OnlineLearningModel`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MLModelConfig {
/// Which model family to train.
pub model_type: ModelType,
/// Feature-extraction settings paired with this model.
pub feature_config: FeatureConfig,
/// SGD step size used during weight updates.
pub learning_rate: f64,
/// Number of buffered samples that triggers a weight update.
pub batch_size: usize,
/// Maximum time between weight updates, regardless of batch fill.
pub update_interval: Duration,
/// Whether the model should be persisted (not implemented in this file).
pub enable_persistence: bool,
/// Free-form model version tag.
pub version: String,
}
/// Configuration for an [`AnomalyDetector`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnomalyDetectionConfig {
/// Detection strategy to apply.
pub algorithm: AnomalyDetectionAlgorithm,
/// Detector sensitivity (not read by the detection code in this file).
pub sensitivity: f64,
/// EWMA smoothing factor for the historical mean/std baseline.
pub adaptive_learning_rate: f64,
/// Size of the rolling sample window used for baselining.
pub window_size: usize,
/// Minimum samples required before any detection verdict is issued.
pub min_samples: usize,
/// Allow [`AnomalyDetector::feedback`] to adapt the internal threshold.
pub enable_feedback: bool,
}
/// Numeric features extracted from a single stream event.
#[derive(Debug, Clone)]
pub struct FeatureVector {
/// Feature values; parallel to `feature_names`.
pub features: Array1<f64>,
/// Human-readable name for each feature, index-aligned with `features`.
pub feature_names: Vec<String>,
/// When the features were extracted.
pub timestamp: DateTime<Utc>,
/// Event id of the source event (from its metadata).
pub source_event_id: String,
}
/// Verdict produced by [`AnomalyDetector::detect`] for one event.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnomalyResult {
/// Whether the event was flagged as anomalous.
pub is_anomaly: bool,
/// Normalized anomaly score in `[0, 1]`.
pub score: f64,
/// Human-readable rationale (e.g. the z-score and threshold used).
pub explanation: String,
/// Names of the features that were considered.
pub contributing_features: Vec<String>,
/// When the verdict was produced.
pub timestamp: DateTime<Utc>,
/// Id of the event this verdict applies to.
pub event_id: String,
}
/// Output of [`OnlineLearningModel::predict`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictionResult {
/// Predicted value.
pub prediction: f64,
/// Confidence estimate (currently a fixed placeholder in this file).
pub confidence: f64,
/// Optional `(lower, upper)` prediction interval; `None` when unavailable.
pub interval: Option<(f64, f64)>,
/// When the prediction was made.
pub timestamp: DateTime<Utc>,
}
/// Cumulative quality/latency metrics for a model.
///
/// NOTE(review): only `predictions_made` and `avg_prediction_time_ms` are
/// updated by the code in this file; the accuracy/error fields appear to be
/// populated elsewhere — confirm against other call sites.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ModelMetrics {
/// Total predictions issued.
pub predictions_made: u64,
/// Predictions later confirmed correct.
pub correct_predictions: u64,
/// `correct_predictions / predictions_made`.
pub accuracy: f64,
/// Mean absolute error of predictions.
pub mean_absolute_error: f64,
/// Root mean squared error of predictions.
pub root_mean_squared_error: f64,
/// Coefficient of determination.
pub r_squared: f64,
/// Smoothed per-prediction latency in milliseconds.
pub avg_prediction_time_ms: f64,
}
/// Cumulative statistics for an [`AnomalyDetector`].
#[derive(Debug, Clone, Default)]
pub struct AnomalyStats {
/// Total events scored (including those below `min_samples`).
pub events_processed: u64,
/// Events flagged as anomalous.
pub anomalies_detected: u64,
/// Flagged events later reported as false alarms via feedback.
pub false_positives: u64,
/// Flagged events later confirmed as real anomalies via feedback.
pub true_positives: u64,
/// Smoothed average anomaly score.
pub avg_anomaly_score: f64,
/// `anomalies_detected / events_processed`.
pub detection_rate: f64,
}
/// Linear model trained incrementally with mini-batch SGD.
///
/// All state is behind `Arc<RwLock<_>>` so the model can be shared and
/// trained/queried concurrently from multiple threads.
pub struct OnlineLearningModel {
// Training hyperparameters and metadata.
config: MLModelConfig,
// Per-feature weights; length equals `num_features`.
weights: Arc<RwLock<Array1<f64>>>,
// Scalar intercept term.
bias: Arc<RwLock<f64>>,
// Expected input dimensionality; validated on every train/predict call.
num_features: usize,
// Samples buffered until the next batch update.
sample_buffer: SampleBuffer,
// Cumulative prediction metrics.
metrics: Arc<RwLock<ModelMetrics>>,
// Timestamp of the last weight update (drives `update_interval`).
last_update: Arc<RwLock<Instant>>,
}
impl OnlineLearningModel {
    /// Creates a model expecting `num_features` inputs.
    ///
    /// Weights are initialized to small random values in `[-0.01, 0.01)` to
    /// break symmetry; the bias starts at zero.
    pub fn new(config: MLModelConfig, num_features: usize) -> Self {
        let mut rng_instance = rng();
        let weights = Array1::from_vec(
            (0..num_features)
                .map(|_| rng_instance.random_range(-0.01..0.01))
                .collect(),
        );
        Self {
            config,
            weights: Arc::new(RwLock::new(weights)),
            bias: Arc::new(RwLock::new(0.0)),
            num_features,
            sample_buffer: Arc::new(RwLock::new(Vec::new())),
            metrics: Arc::new(RwLock::new(ModelMetrics::default())),
            last_update: Arc::new(RwLock::new(Instant::now())),
        }
    }

    /// Buffers one `(features, target)` sample and runs a mini-batch weight
    /// update once `batch_size` samples have accumulated or `update_interval`
    /// has elapsed since the last update.
    ///
    /// # Errors
    /// Returns an error if `features.len() != num_features`.
    pub fn train(&self, features: &Array1<f64>, target: f64) -> Result<()> {
        if features.len() != self.num_features {
            return Err(anyhow!(
                "Feature dimension mismatch: expected {}, got {}",
                self.num_features,
                features.len()
            ));
        }
        self.sample_buffer.write().push((features.clone(), target));
        // Decide under read locks only; a concurrent trainer may drain the
        // buffer in between, which is benign (update_weights tolerates empty).
        let should_update = {
            let buffer = self.sample_buffer.read();
            let last_update = self.last_update.read();
            buffer.len() >= self.config.batch_size
                || last_update.elapsed() >= self.config.update_interval
        };
        if should_update {
            self.update_weights()?;
        }
        Ok(())
    }

    /// Drains the sample buffer and applies one SGD pass over the drained
    /// samples (squared-error gradient per sample).
    fn update_weights(&self) -> Result<()> {
        // Take the buffer contents without holding the lock during training.
        let samples = {
            let mut buffer = self.sample_buffer.write();
            std::mem::take(&mut *buffer)
        };
        if samples.is_empty() {
            return Ok(());
        }
        let mut weights = self.weights.write();
        let mut bias = self.bias.write();
        for (features, target) in &samples {
            let prediction = self.predict_internal(&weights, *bias, features);
            // Gradient of 0.5 * (prediction - target)^2 w.r.t. the output.
            let error = prediction - target;
            for i in 0..self.num_features {
                weights[i] -= self.config.learning_rate * error * features[i];
            }
            *bias -= self.config.learning_rate * error;
        }
        *self.last_update.write() = Instant::now();
        debug!("Updated model weights with {} samples", samples.len());
        Ok(())
    }

    /// Predicts a value for `features` and records latency metrics.
    ///
    /// # Errors
    /// Returns an error if `features.len() != num_features`.
    pub fn predict(&self, features: &Array1<f64>) -> Result<PredictionResult> {
        if features.len() != self.num_features {
            // Same detailed message as `train`, for consistent diagnostics.
            return Err(anyhow!(
                "Feature dimension mismatch: expected {}, got {}",
                self.num_features,
                features.len()
            ));
        }
        let start_time = Instant::now();
        // Keep the weight/bias read locks scoped to the dot product only.
        let prediction = {
            let weights = self.weights.read();
            let bias = self.bias.read();
            self.predict_internal(&weights, *bias, features)
        };
        let prediction_time = start_time.elapsed().as_micros() as f64 / 1000.0;
        let mut metrics = self.metrics.write();
        metrics.predictions_made += 1;
        // True incremental mean: avg += (x - avg) / n. The previous
        // `(avg + x) / 2.0` weighted the latest sample at 50% regardless of
        // how many predictions had been made.
        let n = metrics.predictions_made as f64;
        metrics.avg_prediction_time_ms += (prediction_time - metrics.avg_prediction_time_ms) / n;
        Ok(PredictionResult {
            prediction,
            // NOTE(review): fixed placeholder — no confidence calibration is
            // implemented in this file.
            confidence: 0.8,
            interval: None,
            timestamp: Utc::now(),
        })
    }

    /// Computes `weights · features + bias`.
    fn predict_internal(&self, weights: &Array1<f64>, bias: f64, features: &Array1<f64>) -> f64 {
        weights
            .iter()
            .zip(features.iter())
            .map(|(w, x)| w * x)
            .sum::<f64>()
            + bias
    }

    /// Returns a snapshot of the accumulated model metrics.
    pub fn get_metrics(&self) -> ModelMetrics {
        self.metrics.read().clone()
    }
}
/// Online anomaly detector over a rolling window of scalar event metrics.
pub struct AnomalyDetector {
// Algorithm choice and window/feedback parameters.
config: AnomalyDetectionConfig,
// EWMA long-run baseline mean of the metric.
historical_mean: Arc<RwLock<f64>>,
// EWMA long-run baseline standard deviation of the metric.
historical_std: Arc<RwLock<f64>>,
// Rolling window of the most recent metric values (bounded by window_size).
recent_samples: Arc<RwLock<VecDeque<f64>>>,
// Adaptive threshold tuned by feedback (starts at 3.0 std deviations).
threshold: Arc<RwLock<f64>>,
// Cumulative detection statistics.
stats: Arc<RwLock<AnomalyStats>>,
}
impl AnomalyDetector {
    /// Creates a detector; the adaptive threshold starts at 3 standard
    /// deviations and is tuned by [`Self::feedback`].
    pub fn new(config: AnomalyDetectionConfig) -> Self {
        Self {
            historical_mean: Arc::new(RwLock::new(0.0)),
            historical_std: Arc::new(RwLock::new(1.0)),
            recent_samples: Arc::new(RwLock::new(VecDeque::with_capacity(config.window_size))),
            threshold: Arc::new(RwLock::new(3.0)),
            stats: Arc::new(RwLock::new(AnomalyStats::default())),
            // Moved last so the fields above can read it without a clone.
            config,
        }
    }

    /// Scores one feature vector against the rolling sample window.
    ///
    /// The mean of the feature values is used as a scalar summary metric; a
    /// rolling window of those metrics provides the baseline mean/std for a
    /// z-score test. The new event is added to the window before scoring.
    ///
    /// # Errors
    /// Currently infallible; returns `Result` for forward compatibility.
    pub fn detect(&self, features: &FeatureVector) -> Result<AnomalyResult> {
        // Guard empty vectors: the original `sum / len` would yield NaN.
        let metric = if features.features.is_empty() {
            0.0
        } else {
            features.features.iter().sum::<f64>() / features.features.len() as f64
        };
        let mut samples = self.recent_samples.write();
        samples.push_back(metric);
        if samples.len() > self.config.window_size {
            samples.pop_front();
        }
        let mut stats = self.stats.write();
        stats.events_processed += 1;
        // Withhold any verdict until the baseline has enough data.
        if samples.len() < self.config.min_samples {
            return Ok(AnomalyResult {
                is_anomaly: false,
                score: 0.0,
                explanation: "Insufficient samples for detection".to_string(),
                contributing_features: Vec::new(),
                timestamp: Utc::now(),
                event_id: features.source_event_id.clone(),
            });
        }
        // Window statistics (population variance over the current window).
        let mean = samples.iter().sum::<f64>() / samples.len() as f64;
        let variance =
            samples.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / samples.len() as f64;
        let std_dev = variance.sqrt();
        {
            // EWMA update of the long-run baseline. Not consumed by the
            // detectors below yet; exposed for future algorithms.
            let mut hist_mean = self.historical_mean.write();
            let mut hist_std = self.historical_std.write();
            let alpha = self.config.adaptive_learning_rate;
            *hist_mean = alpha * mean + (1.0 - alpha) * *hist_mean;
            *hist_std = alpha * std_dev + (1.0 - alpha) * *hist_std;
        }
        // Shared z-score; previously duplicated in every match arm.
        let z_score = if std_dev > 1e-10 {
            (metric - mean).abs() / std_dev
        } else {
            0.0
        };
        let (is_anomaly, score, explanation) = match &self.config.algorithm {
            AnomalyDetectionAlgorithm::Statistical { threshold } => (
                z_score > *threshold,
                (z_score / threshold).min(1.0),
                format!("Z-score: {:.2}, threshold: {:.2}", z_score, threshold),
            ),
            AnomalyDetectionAlgorithm::IsolationForest { contamination } => {
                // Clamp to avoid division by zero for contamination <= 0;
                // unchanged behavior for any positive contamination.
                let threshold = 3.0 / contamination.max(f64::EPSILON);
                let score = (z_score / threshold).min(1.0);
                (z_score > threshold, score, format!("Isolation score: {:.2}", score))
            }
            _ => {
                // Fallback z-test against the feedback-adapted threshold.
                // Previously hard-coded at 3.0, which made the threshold
                // adjustments in `feedback` a dead write; the initial value
                // is still 3.0, so behavior is unchanged until feedback runs.
                let threshold = *self.threshold.read();
                (
                    z_score > threshold,
                    (z_score / threshold).min(1.0),
                    format!("Z-score: {:.2}", z_score),
                )
            }
        };
        if is_anomaly {
            stats.anomalies_detected += 1;
            // Note: true/false positive counters are updated in `feedback`,
            // since only external feedback can confirm a detection.
        }
        // Exponential smoothing (alpha = 0.5) of the anomaly score.
        stats.avg_anomaly_score = (stats.avg_anomaly_score + score) / 2.0;
        stats.detection_rate = stats.anomalies_detected as f64 / stats.events_processed as f64;
        Ok(AnomalyResult {
            is_anomaly,
            score,
            explanation,
            contributing_features: features.feature_names.clone(),
            timestamp: Utc::now(),
            event_id: features.source_event_id.clone(),
        })
    }

    /// Records ground-truth feedback for a previously flagged event.
    ///
    /// Confirmed anomalies tighten the adaptive threshold by 2% (more
    /// sensitive); false alarms loosen it by 2%. The true/false positive
    /// counters are updated here — `detect` previously counted every flagged
    /// event as a true positive, which overstated precision.
    pub fn feedback(&self, event_id: &str, is_true_anomaly: bool) {
        debug!(
            "Received feedback for event {}: is_anomaly={}",
            event_id, is_true_anomaly
        );
        if self.config.enable_feedback {
            {
                let mut stats = self.stats.write();
                if is_true_anomaly {
                    stats.true_positives += 1;
                } else {
                    stats.false_positives += 1;
                }
            }
            let mut threshold = self.threshold.write();
            if is_true_anomaly {
                *threshold *= 0.98;
            } else {
                *threshold *= 1.02;
            }
        }
    }

    /// Returns a snapshot of the accumulated detection statistics.
    pub fn get_stats(&self) -> AnomalyStats {
        self.stats.read().clone()
    }
}
/// Derives numeric feature vectors from stream events using a rolling
/// history window.
pub struct FeatureExtractor {
// Which feature groups to emit and the history window size.
config: FeatureConfig,
// Recent events, bounded by `config.window_size`.
event_history: Arc<RwLock<VecDeque<StreamEvent>>>,
}
impl FeatureExtractor {
/// Creates an extractor with an empty, bounded event-history window.
pub fn new(config: FeatureConfig) -> Self {
Self {
config: config.clone(),
event_history: Arc::new(RwLock::new(VecDeque::with_capacity(config.window_size))),
}
}
/// Derives a feature vector from `event` and the rolling history.
///
/// The event is appended to the history (oldest entry dropped beyond
/// `window_size`) before any feature is computed, so the new event is
/// included in every windowed statistic.
///
/// NOTE(review): the "window_size" feature is `history.len()` — the same
/// value as "event_count" below. It looks like the configured window size
/// was intended; confirm before relying on these two features being
/// distinct.
pub fn extract_features(&self, event: &StreamEvent) -> Result<FeatureVector> {
let mut features = Vec::new();
let mut feature_names = Vec::new();
// Update the rolling window first.
let mut history = self.event_history.write();
history.push_back(event.clone());
if history.len() > self.config.window_size {
history.pop_front();
}
// Baseline feature: current history occupancy.
features.push(history.len() as f64);
feature_names.push("window_size".to_string());
if self.config.enable_statistical {
// Event count over the window (duplicates the feature above; see note).
features.push(history.len() as f64);
feature_names.push("event_count".to_string());
if history.len() >= 2 {
// Window fill ratio in [0, 1].
let rate = history.len() as f64 / self.config.window_size as f64;
features.push(rate);
feature_names.push("event_rate".to_string());
}
}
if self.config.enable_frequency {
// Count distinct event types present in the window.
let mut type_counts: HashMap<String, usize> = HashMap::new();
for evt in history.iter() {
let event_type = self.get_event_type(evt);
*type_counts.entry(event_type).or_insert(0) += 1;
}
let unique_types = type_counts.len() as f64;
features.push(unique_types);
feature_names.push("unique_event_types".to_string());
}
Ok(FeatureVector {
features: Array1::from_vec(features),
feature_names,
timestamp: Utc::now(),
source_event_id: self.get_event_id(event),
})
}
/// Maps an event to a coarse type label; variants not listed collapse to
/// "Other".
fn get_event_type(&self, event: &StreamEvent) -> String {
match event {
StreamEvent::TripleAdded { .. } => "TripleAdded",
StreamEvent::TripleRemoved { .. } => "TripleRemoved",
StreamEvent::QuadAdded { .. } => "QuadAdded",
StreamEvent::QuadRemoved { .. } => "QuadRemoved",
StreamEvent::GraphCreated { .. } => "GraphCreated",
StreamEvent::GraphCleared { .. } => "GraphCleared",
StreamEvent::GraphDeleted { .. } => "GraphDeleted",
StreamEvent::SparqlUpdate { .. } => "SparqlUpdate",
StreamEvent::TransactionBegin { .. } => "TransactionBegin",
StreamEvent::TransactionCommit { .. } => "TransactionCommit",
StreamEvent::TransactionAbort { .. } => "TransactionAbort",
StreamEvent::SchemaChanged { .. } => "SchemaChanged",
_ => "Other",
}
.to_string()
}
/// Extracts the event id from the event's metadata.
///
/// NOTE(review): this match must list every `StreamEvent` variant carrying
/// metadata; a variant added to the enum without a metadata field would
/// fail to compile here rather than silently misbehave.
fn get_event_id(&self, event: &StreamEvent) -> String {
let metadata = match event {
StreamEvent::TripleAdded { metadata, .. }
| StreamEvent::TripleRemoved { metadata, .. }
| StreamEvent::QuadAdded { metadata, .. }
| StreamEvent::QuadRemoved { metadata, .. }
| StreamEvent::GraphCreated { metadata, .. }
| StreamEvent::GraphCleared { metadata, .. }
| StreamEvent::GraphDeleted { metadata, .. }
| StreamEvent::SparqlUpdate { metadata, .. }
| StreamEvent::TransactionBegin { metadata, .. }
| StreamEvent::TransactionCommit { metadata, .. }
| StreamEvent::TransactionAbort { metadata, .. }
| StreamEvent::SchemaChanged { metadata, .. }
| StreamEvent::Heartbeat { metadata, .. }
| StreamEvent::QueryResultAdded { metadata, .. }
| StreamEvent::QueryResultRemoved { metadata, .. }
| StreamEvent::QueryCompleted { metadata, .. }
| StreamEvent::GraphMetadataUpdated { metadata, .. }
| StreamEvent::GraphPermissionsChanged { metadata, .. }
| StreamEvent::GraphStatisticsUpdated { metadata, .. }
| StreamEvent::GraphRenamed { metadata, .. }
| StreamEvent::GraphMerged { metadata, .. }
| StreamEvent::GraphSplit { metadata, .. }
| StreamEvent::SchemaDefinitionAdded { metadata, .. }
| StreamEvent::SchemaDefinitionRemoved { metadata, .. }
| StreamEvent::SchemaDefinitionModified { metadata, .. }
| StreamEvent::OntologyImported { metadata, .. }
| StreamEvent::OntologyRemoved { metadata, .. }
| StreamEvent::ConstraintAdded { metadata, .. }
| StreamEvent::ConstraintRemoved { metadata, .. }
| StreamEvent::ConstraintViolated { metadata, .. }
| StreamEvent::IndexCreated { metadata, .. }
| StreamEvent::IndexDropped { metadata, .. }
| StreamEvent::IndexRebuilt { metadata, .. }
| StreamEvent::SchemaUpdated { metadata, .. }
| StreamEvent::ShapeAdded { metadata, .. }
| StreamEvent::ShapeUpdated { metadata, .. }
| StreamEvent::ShapeRemoved { metadata, .. }
| StreamEvent::ShapeModified { metadata, .. }
| StreamEvent::ShapeValidationStarted { metadata, .. }
| StreamEvent::ShapeValidationCompleted { metadata, .. }
| StreamEvent::ShapeViolationDetected { metadata, .. }
| StreamEvent::ErrorOccurred { metadata, .. } => metadata,
};
metadata.event_id.clone()
}
}
/// Registry of named ML components shared across the streaming pipeline.
///
/// Backed by `DashMap` so registration and lookup are lock-free per entry
/// and safe from multiple threads.
pub struct MLIntegrationManager {
// Online-learning models, keyed by name.
models: Arc<DashMap<String, OnlineLearningModel>>,
// Anomaly detectors, keyed by name.
detectors: Arc<DashMap<String, AnomalyDetector>>,
// Feature extractors, keyed by name.
extractors: Arc<DashMap<String, FeatureExtractor>>,
}
impl MLIntegrationManager {
    /// Creates an empty manager with no registered components.
    pub fn new() -> Self {
        Self {
            models: Arc::new(DashMap::new()),
            detectors: Arc::new(DashMap::new()),
            extractors: Arc::new(DashMap::new()),
        }
    }

    /// Registers (or replaces) a named online-learning model.
    pub fn register_model(&self, name: String, model: OnlineLearningModel) {
        // Log before moving `name` into the map; avoids the previous clone.
        info!("Registered ML model: {}", name);
        self.models.insert(name, model);
    }

    /// Registers (or replaces) a named anomaly detector.
    pub fn register_detector(&self, name: String, detector: AnomalyDetector) {
        info!("Registered anomaly detector: {}", name);
        self.detectors.insert(name, detector);
    }

    /// Registers (or replaces) a named feature extractor.
    pub fn register_extractor(&self, name: String, extractor: FeatureExtractor) {
        info!("Registered feature extractor: {}", name);
        self.extractors.insert(name, extractor);
    }

    /// Looks up a model by name; the returned guard holds a shard read lock.
    pub fn get_model(
        &self,
        name: &str,
    ) -> Option<dashmap::mapref::one::Ref<'_, String, OnlineLearningModel>> {
        self.models.get(name)
    }

    /// Looks up a detector by name; the returned guard holds a shard read lock.
    pub fn get_detector(
        &self,
        name: &str,
    ) -> Option<dashmap::mapref::one::Ref<'_, String, AnomalyDetector>> {
        self.detectors.get(name)
    }

    /// Looks up an extractor by name; the returned guard holds a shard read lock.
    pub fn get_extractor(
        &self,
        name: &str,
    ) -> Option<dashmap::mapref::one::Ref<'_, String, FeatureExtractor>> {
        self.extractors.get(name)
    }
}
impl Default for MLIntegrationManager {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::event::EventMetadata;
// Smoke test: a 3-feature linear model accepts one training sample and
// produces a finite prediction (weights start near zero, so the exact
// value is not asserted).
#[test]
fn test_online_learning() {
let config = MLModelConfig {
model_type: ModelType::LinearRegression,
feature_config: FeatureConfig {
window_size: 10,
enable_statistical: true,
enable_frequency: false,
custom_features: Vec::new(),
},
learning_rate: 0.01,
batch_size: 10,
update_interval: Duration::from_secs(1),
enable_persistence: false,
version: "1.0".to_string(),
};
let model = OnlineLearningModel::new(config, 3);
let features = Array1::from_vec(vec![1.0, 2.0, 3.0]);
model.train(&features, 10.0).unwrap();
let result = model.predict(&features).unwrap();
assert!(result.prediction.is_finite());
}
// Feeds 20 slowly-drifting baseline values (no anomaly expected once
// min_samples is reached), then a large outlier that must be flagged.
#[test]
fn test_anomaly_detection() {
let config = AnomalyDetectionConfig {
algorithm: AnomalyDetectionAlgorithm::Statistical { threshold: 3.0 },
sensitivity: 0.8,
adaptive_learning_rate: 0.1,
window_size: 100,
min_samples: 10,
enable_feedback: true,
};
let detector = AnomalyDetector::new(config);
for i in 0..20 {
let features = FeatureVector {
features: Array1::from_vec(vec![100.0 + i as f64]),
feature_names: vec!["value".to_string()],
timestamp: Utc::now(),
source_event_id: format!("event-{}", i),
};
let result = detector.detect(&features).unwrap();
// Verdicts before min_samples are always "not anomalous", so only
// check once detection is actually active.
if i >= 10 {
assert!(!result.is_anomaly);
}
}
let anomalous_features = FeatureVector {
features: Array1::from_vec(vec![1000.0]),
feature_names: vec!["value".to_string()],
timestamp: Utc::now(),
source_event_id: "anomaly".to_string(),
};
let result = detector.detect(&anomalous_features).unwrap();
assert!(result.is_anomaly);
assert!(result.score > 0.0);
}
// Verifies feature extraction produces a non-empty vector whose values
// and names stay index-aligned.
#[test]
fn test_feature_extraction() {
let config = FeatureConfig {
window_size: 10,
enable_statistical: true,
enable_frequency: true,
custom_features: Vec::new(),
};
let extractor = FeatureExtractor::new(config);
let event = StreamEvent::SchemaChanged {
schema_type: crate::event::SchemaType::Ontology,
change_type: crate::event::SchemaChangeType::Added,
details: "test schema change".to_string(),
metadata: EventMetadata {
event_id: "test-1".to_string(),
timestamp: Utc::now(),
source: "test".to_string(),
user: None,
context: None,
caused_by: None,
version: "1.0".to_string(),
properties: HashMap::new(),
checksum: None,
},
};
let features = extractor.extract_features(&event).unwrap();
assert!(!features.features.is_empty());
assert_eq!(features.features.len(), features.feature_names.len());
}
}