Skip to main content

oxirs_stream/ml/
mod.rs

1//! # ML Model Integration for Stream Processing
2//!
3//! Provides ML inference capabilities embedded in the stream processing pipeline:
4//!
5//! - [`StreamingModelRunner`]: Runs ML inference on stream events with batching
6//! - [`StreamAnomalyDetector`]: Z-score based streaming anomaly detection with sliding window
7//! - [`StreamFeatureExtractor`]: Configurable feature extraction from RDF stream events
8//! - [`regression::StreamRegressor`]: Online regression (linear, GBT-streaming)
9//! - [`classification::StreamClassifier`]: Online classification (logistic, kNN streaming)
10
11pub mod classification;
12pub mod regression;
13
14pub use classification::{
15    ClassPrediction, ClassificationError, ClassificationResult, KnnConfig, LogisticConfig,
16    OnlineLogisticClassifier, StreamClassifier, StreamingKnnClassifier,
17};
18pub use regression::{
19    GbtConfig, LinearConfig, OnlineLinearRegressor, RegressionError, RegressionResult,
20    StreamRegressor, StreamingGradientBoostedRegressor,
21};
22
23use std::collections::{HashMap, VecDeque};
24use std::sync::Arc;
25use std::time::Instant;
26
27use chrono::{DateTime, Utc};
28use parking_lot::RwLock;
29use serde::{Deserialize, Serialize};
30use tracing::{debug, info, warn};
31
32use scirs2_core::ndarray_ext::Array1;
33
34use crate::event::StreamEvent;
35
36// ─── Model Configuration ─────────────────────────────────────────────────────
37
38/// Configuration for a streaming model runner
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct ModelConfig {
41    /// Path or identifier for the model
42    pub model_path: String,
43    /// Maximum batch size before forcing inference
44    pub batch_size: usize,
45    /// Maximum latency before forcing inference (even if batch is not full)
46    pub max_latency_ms: u64,
47    /// Number of input features expected
48    pub input_features: usize,
49    /// Model name for logging
50    pub model_name: String,
51}
52
53impl Default for ModelConfig {
54    fn default() -> Self {
55        Self {
56            model_path: "default".to_string(),
57            batch_size: 32,
58            max_latency_ms: 100,
59            input_features: 4,
60            model_name: "default-model".to_string(),
61        }
62    }
63}
64
65/// A single prediction from the model
66#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct Prediction {
68    /// Predicted value or class
69    pub value: f64,
70    /// Confidence score (0.0 to 1.0)
71    pub confidence: f64,
72    /// Source event identifier
73    pub source_event_id: String,
74    /// Timestamp of the prediction
75    pub predicted_at: DateTime<Utc>,
76    /// Model that produced the prediction
77    pub model_name: String,
78}
79
80/// Statistics for the streaming model runner
81#[derive(Debug, Clone, Serialize, Deserialize, Default)]
82pub struct ModelRunnerStats {
83    /// Total events processed
84    pub events_processed: u64,
85    /// Total batches executed
86    pub batches_executed: u64,
87    /// Total predictions produced
88    pub predictions_produced: u64,
89    /// Average batch size
90    pub avg_batch_size: f64,
91    /// Average inference time per batch (milliseconds)
92    pub avg_inference_time_ms: f64,
93    /// Batches triggered by size threshold
94    pub size_triggered_batches: u64,
95    /// Batches triggered by latency threshold
96    pub latency_triggered_batches: u64,
97}
98
99/// A pending event waiting to be included in a batch
100#[derive(Debug, Clone)]
101struct PendingEvent {
102    features: Array1<f64>,
103    event_id: String,
104    queued_at: Instant,
105}
106
107/// Model weights for a simple linear model (more models can be added)
108#[derive(Debug, Clone)]
109struct LinearModelWeights {
110    weights: Array1<f64>,
111    bias: f64,
112}
113
114/// Runs ML inference on stream events with automatic batching.
115///
116/// Events are collected until either `batch_size` events accumulate
117/// or `max_latency_ms` elapses, then inference is run on the batch.
118pub struct StreamingModelRunner {
119    config: ModelConfig,
120    /// Pending events waiting for batch inference
121    pending: Arc<RwLock<Vec<PendingEvent>>>,
122    /// Model weights (simple linear model for now)
123    model: Arc<RwLock<LinearModelWeights>>,
124    /// Runner statistics
125    stats: Arc<RwLock<ModelRunnerStats>>,
126    /// When the oldest pending event was queued
127    batch_start: Arc<RwLock<Option<Instant>>>,
128}
129
130impl StreamingModelRunner {
131    /// Creates a new streaming model runner.
132    pub fn new(config: ModelConfig) -> Self {
133        // Initialize with small default weights
134        let weights = Array1::from_vec(vec![0.1; config.input_features]);
135        Self {
136            config: config.clone(),
137            pending: Arc::new(RwLock::new(Vec::with_capacity(config.batch_size))),
138            model: Arc::new(RwLock::new(LinearModelWeights { weights, bias: 0.0 })),
139            stats: Arc::new(RwLock::new(ModelRunnerStats::default())),
140            batch_start: Arc::new(RwLock::new(None)),
141        }
142    }
143
144    /// Enqueues an event for prediction.
145    ///
146    /// Returns predictions if a batch was triggered.
147    pub fn enqueue(&self, features: Array1<f64>, event_id: String) -> Option<Vec<Prediction>> {
148        if features.len() != self.config.input_features {
149            warn!(
150                "Feature dimension mismatch: expected {}, got {}",
151                self.config.input_features,
152                features.len()
153            );
154            return None;
155        }
156
157        let mut pending = self.pending.write();
158        if pending.is_empty() {
159            *self.batch_start.write() = Some(Instant::now());
160        }
161        pending.push(PendingEvent {
162            features,
163            event_id,
164            queued_at: Instant::now(),
165        });
166
167        // Check if batch should be triggered
168        if pending.len() >= self.config.batch_size {
169            let events: Vec<PendingEvent> = std::mem::take(&mut *pending);
170            drop(pending);
171            *self.batch_start.write() = None;
172            self.stats.write().size_triggered_batches += 1;
173            Some(self.run_inference(events))
174        } else {
175            None
176        }
177    }
178
179    /// Flushes any pending events if the latency threshold has been exceeded.
180    ///
181    /// Returns predictions if flush was needed.
182    pub fn flush_if_due(&self) -> Option<Vec<Prediction>> {
183        let should_flush = {
184            let batch_start = self.batch_start.read();
185            match *batch_start {
186                Some(start) => start.elapsed().as_millis() as u64 >= self.config.max_latency_ms,
187                None => false,
188            }
189        };
190
191        if should_flush {
192            let mut pending = self.pending.write();
193            if pending.is_empty() {
194                return None;
195            }
196            let events: Vec<PendingEvent> = std::mem::take(&mut *pending);
197            drop(pending);
198            *self.batch_start.write() = None;
199            self.stats.write().latency_triggered_batches += 1;
200            Some(self.run_inference(events))
201        } else {
202            None
203        }
204    }
205
206    /// Forces inference on all pending events regardless of thresholds.
207    pub fn flush(&self) -> Vec<Prediction> {
208        let mut pending = self.pending.write();
209        if pending.is_empty() {
210            return Vec::new();
211        }
212        let events: Vec<PendingEvent> = std::mem::take(&mut *pending);
213        drop(pending);
214        *self.batch_start.write() = None;
215        self.run_inference(events)
216    }
217
218    /// Runs batched inference directly on a slice of stream events.
219    pub fn predict(&self, events: &[(Array1<f64>, String)]) -> Vec<Prediction> {
220        let pending_events: Vec<PendingEvent> = events
221            .iter()
222            .map(|(features, event_id)| PendingEvent {
223                features: features.clone(),
224                event_id: event_id.clone(),
225                queued_at: Instant::now(),
226            })
227            .collect();
228        self.run_inference(pending_events)
229    }
230
231    /// Updates the model weights.
232    pub fn update_weights(&self, weights: Array1<f64>, bias: f64) {
233        let mut model = self.model.write();
234        model.weights = weights;
235        model.bias = bias;
236        info!("Model {} weights updated", self.config.model_name);
237    }
238
239    /// Returns runner statistics.
240    pub fn stats(&self) -> ModelRunnerStats {
241        self.stats.read().clone()
242    }
243
244    /// Returns the number of pending events.
245    pub fn pending_count(&self) -> usize {
246        self.pending.read().len()
247    }
248
249    /// Internal inference function.
250    fn run_inference(&self, events: Vec<PendingEvent>) -> Vec<Prediction> {
251        let start = Instant::now();
252        let model = self.model.read();
253        let batch_size = events.len();
254
255        let predictions: Vec<Prediction> = events
256            .iter()
257            .map(|event| {
258                let mut value = model.bias;
259                let n = model.weights.len().min(event.features.len());
260                for i in 0..n {
261                    value += model.weights[i] * event.features[i];
262                }
263                // Sigmoid for confidence
264                let confidence = 1.0 / (1.0 + (-value).exp());
265
266                Prediction {
267                    value,
268                    confidence: confidence.clamp(0.0, 1.0),
269                    source_event_id: event.event_id.clone(),
270                    predicted_at: Utc::now(),
271                    model_name: self.config.model_name.clone(),
272                }
273            })
274            .collect();
275
276        let elapsed_ms = start.elapsed().as_micros() as f64 / 1000.0;
277
278        let mut stats = self.stats.write();
279        stats.events_processed += batch_size as u64;
280        stats.batches_executed += 1;
281        stats.predictions_produced += predictions.len() as u64;
282        let total_batches = stats.batches_executed as f64;
283        stats.avg_batch_size =
284            (stats.avg_batch_size * (total_batches - 1.0) + batch_size as f64) / total_batches;
285        stats.avg_inference_time_ms =
286            (stats.avg_inference_time_ms * (total_batches - 1.0) + elapsed_ms) / total_batches;
287
288        debug!(
289            "Inference batch: {} events, {:.2}ms",
290            batch_size, elapsed_ms
291        );
292
293        predictions
294    }
295}
296
297// ─── Streaming Anomaly Detector ──────────────────────────────────────────────
298
299/// Configuration for the streaming anomaly detector
300#[derive(Debug, Clone, Serialize, Deserialize)]
301pub struct AnomalyDetectorConfig {
302    /// Z-score threshold for anomaly detection
303    pub sigma_threshold: f64,
304    /// Sliding window size for statistics computation
305    pub window_size: usize,
306    /// Minimum samples before detection starts
307    pub min_samples: usize,
308    /// Adaptive threshold learning rate (0.0 = fixed, 1.0 = fully adaptive)
309    pub adaptive_rate: f64,
310}
311
312impl Default for AnomalyDetectorConfig {
313    fn default() -> Self {
314        Self {
315            sigma_threshold: 3.0,
316            window_size: 100,
317            min_samples: 10,
318            adaptive_rate: 0.0,
319        }
320    }
321}
322
323/// Result of anomaly detection on a single value
324#[derive(Debug, Clone, Serialize, Deserialize)]
325pub struct AnomalyCheckResult {
326    /// Whether the value is anomalous
327    pub is_anomaly: bool,
328    /// The Z-score of the value
329    pub z_score: f64,
330    /// The current mean of the sliding window
331    pub window_mean: f64,
332    /// The current standard deviation of the sliding window
333    pub window_stddev: f64,
334    /// The effective threshold used
335    pub threshold: f64,
336    /// Number of samples in the window
337    pub window_samples: usize,
338}
339
340/// Statistics for the anomaly detector
341#[derive(Debug, Clone, Serialize, Deserialize, Default)]
342pub struct AnomalyDetectorStats {
343    /// Total values processed
344    pub values_processed: u64,
345    /// Total anomalies detected
346    pub anomalies_detected: u64,
347    /// Current window mean
348    pub current_mean: f64,
349    /// Current window stddev
350    pub current_stddev: f64,
351    /// Detection rate
352    pub detection_rate: f64,
353}
354
355/// Z-score based streaming anomaly detector with a sliding window.
356///
357/// Maintains a sliding window of recent values, computes running mean and
358/// standard deviation, and flags values whose Z-score exceeds the configured
359/// sigma threshold.
360pub struct StreamAnomalyDetector {
361    config: AnomalyDetectorConfig,
362    /// Sliding window of recent values
363    window: Arc<RwLock<VecDeque<f64>>>,
364    /// Running sum for incremental mean computation
365    running_sum: Arc<RwLock<f64>>,
366    /// Running sum of squares for incremental stddev
367    running_sum_sq: Arc<RwLock<f64>>,
368    /// Effective threshold (may be adapted over time)
369    effective_threshold: Arc<RwLock<f64>>,
370    /// Statistics
371    stats: Arc<RwLock<AnomalyDetectorStats>>,
372}
373
374impl StreamAnomalyDetector {
375    /// Creates a new anomaly detector.
376    pub fn new(config: AnomalyDetectorConfig) -> Self {
377        let threshold = config.sigma_threshold;
378        Self {
379            config: config.clone(),
380            window: Arc::new(RwLock::new(VecDeque::with_capacity(config.window_size))),
381            running_sum: Arc::new(RwLock::new(0.0)),
382            running_sum_sq: Arc::new(RwLock::new(0.0)),
383            effective_threshold: Arc::new(RwLock::new(threshold)),
384            stats: Arc::new(RwLock::new(AnomalyDetectorStats::default())),
385        }
386    }
387
388    /// Checks whether a value is anomalous.
389    pub fn is_anomaly(&self, value: f64) -> AnomalyCheckResult {
390        let mut window = self.window.write();
391        let mut sum = self.running_sum.write();
392        let mut sum_sq = self.running_sum_sq.write();
393
394        // Add value to window
395        if window.len() >= self.config.window_size {
396            if let Some(removed) = window.pop_front() {
397                *sum -= removed;
398                *sum_sq -= removed * removed;
399            }
400        }
401        window.push_back(value);
402        *sum += value;
403        *sum_sq += value * value;
404
405        let n = window.len();
406
407        let mut stats = self.stats.write();
408        stats.values_processed += 1;
409
410        // Need minimum samples
411        if n < self.config.min_samples {
412            return AnomalyCheckResult {
413                is_anomaly: false,
414                z_score: 0.0,
415                window_mean: if n > 0 { *sum / n as f64 } else { 0.0 },
416                window_stddev: 0.0,
417                threshold: *self.effective_threshold.read(),
418                window_samples: n,
419            };
420        }
421
422        let mean = *sum / n as f64;
423        let variance = (*sum_sq / n as f64) - (mean * mean);
424        let stddev = if variance > 0.0 { variance.sqrt() } else { 0.0 };
425
426        let z_score = if stddev > 1e-10 {
427            (value - mean).abs() / stddev
428        } else {
429            0.0
430        };
431
432        let threshold = *self.effective_threshold.read();
433        let is_anomaly = z_score > threshold;
434
435        if is_anomaly {
436            stats.anomalies_detected += 1;
437        }
438        stats.current_mean = mean;
439        stats.current_stddev = stddev;
440        stats.detection_rate = if stats.values_processed > 0 {
441            stats.anomalies_detected as f64 / stats.values_processed as f64
442        } else {
443            0.0
444        };
445
446        AnomalyCheckResult {
447            is_anomaly,
448            z_score,
449            window_mean: mean,
450            window_stddev: stddev,
451            threshold,
452            window_samples: n,
453        }
454    }
455
456    /// Provides feedback to adapt the threshold.
457    pub fn feedback(&self, was_true_anomaly: bool) {
458        if self.config.adaptive_rate <= 0.0 {
459            return;
460        }
461        let mut threshold = self.effective_threshold.write();
462        if was_true_anomaly {
463            // Lower threshold slightly to catch more
464            *threshold *= 1.0 - (self.config.adaptive_rate * 0.02);
465        } else {
466            // Raise threshold slightly to reduce false positives
467            *threshold *= 1.0 + (self.config.adaptive_rate * 0.02);
468        }
469        // Clamp to reasonable range
470        *threshold = threshold.clamp(1.0, 10.0);
471    }
472
473    /// Returns detector statistics.
474    pub fn stats(&self) -> AnomalyDetectorStats {
475        self.stats.read().clone()
476    }
477
478    /// Returns the current effective threshold.
479    pub fn effective_threshold(&self) -> f64 {
480        *self.effective_threshold.read()
481    }
482
483    /// Resets the detector state.
484    pub fn reset(&self) {
485        self.window.write().clear();
486        *self.running_sum.write() = 0.0;
487        *self.running_sum_sq.write() = 0.0;
488        *self.stats.write() = AnomalyDetectorStats::default();
489    }
490}
491
492// ─── Feature Extractor ───────────────────────────────────────────────────────
493
494/// A feature definition describing how to extract a numeric feature from an event
495#[derive(Debug, Clone, Serialize, Deserialize)]
496pub struct FeatureDefinition {
497    /// Feature name
498    pub name: String,
499    /// Predicate selector: if the event's predicate contains this string, extract
500    pub predicate_selector: Option<String>,
501    /// Aggregation type for window-based features
502    pub aggregation: FeatureAggregation,
503}
504
505/// Aggregation type for a feature
506#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
507pub enum FeatureAggregation {
508    /// Use the latest value
509    Latest,
510    /// Count occurrences in window
511    Count,
512    /// Sum values in window
513    Sum,
514    /// Compute mean over window
515    Mean,
516}
517
518/// Configuration for the feature extractor
519#[derive(Debug, Clone, Serialize, Deserialize)]
520pub struct FeatureExtractorConfig {
521    /// Feature definitions
522    pub features: Vec<FeatureDefinition>,
523    /// Window size for aggregation-based features
524    pub window_size: usize,
525}
526
527impl Default for FeatureExtractorConfig {
528    fn default() -> Self {
529        Self {
530            features: vec![
531                FeatureDefinition {
532                    name: "event_count".to_string(),
533                    predicate_selector: None,
534                    aggregation: FeatureAggregation::Count,
535                },
536                FeatureDefinition {
537                    name: "event_rate".to_string(),
538                    predicate_selector: None,
539                    aggregation: FeatureAggregation::Mean,
540                },
541            ],
542            window_size: 50,
543        }
544    }
545}
546
547/// Extracted feature vector
548#[derive(Debug, Clone)]
549pub struct ExtractedFeatures {
550    /// Feature values as a numeric array
551    pub values: Array1<f64>,
552    /// Feature names (corresponding to values)
553    pub names: Vec<String>,
554    /// Timestamp of extraction
555    pub extracted_at: DateTime<Utc>,
556    /// Source event ID
557    pub event_id: String,
558}
559
560/// Configurable feature extractor for RDF stream events.
561///
562/// Extracts numeric features from stream events based on configured
563/// feature definitions with predicate selectors and aggregation types.
564pub struct StreamFeatureExtractor {
565    config: FeatureExtractorConfig,
566    /// History window of events for aggregation features
567    history: Arc<RwLock<VecDeque<EventSnapshot>>>,
568    /// Per-feature running values for aggregation
569    running_values: Arc<RwLock<HashMap<String, VecDeque<f64>>>>,
570}
571
/// Minimal record of a past event kept in the extractor's history window.
#[derive(Clone, Debug)]
struct EventSnapshot {
    /// Variant name of the originating event (e.g. `"TripleAdded"`).
    event_type: String,
    /// Predicate of triple/quad events; `None` for every other event kind.
    predicate: Option<String>,
    /// Moment the snapshot was taken.
    ///
    /// NOTE(review): `event_type` and `timestamp` are written but not read
    /// anywhere in this file's visible code.
    timestamp: Instant,
}
579
580impl StreamFeatureExtractor {
581    /// Creates a new feature extractor.
582    pub fn new(config: FeatureExtractorConfig) -> Self {
583        Self {
584            config: config.clone(),
585            history: Arc::new(RwLock::new(VecDeque::with_capacity(config.window_size))),
586            running_values: Arc::new(RwLock::new(HashMap::new())),
587        }
588    }
589
590    /// Extracts features from a stream event.
591    pub fn extract(&self, event: &StreamEvent, event_id: &str) -> ExtractedFeatures {
592        let event_type = Self::event_type_name(event);
593        let predicate = Self::extract_predicate(event);
594
595        // Update history
596        let mut history = self.history.write();
597        history.push_back(EventSnapshot {
598            event_type: event_type.clone(),
599            predicate: predicate.clone(),
600            timestamp: Instant::now(),
601        });
602        while history.len() > self.config.window_size {
603            history.pop_front();
604        }
605        let history_len = history.len();
606
607        // Compute features
608        let mut values = Vec::with_capacity(self.config.features.len());
609        let mut names = Vec::with_capacity(self.config.features.len());
610
611        for feature_def in &self.config.features {
612            let matched = match &feature_def.predicate_selector {
613                Some(selector) => predicate
614                    .as_ref()
615                    .map(|p| p.contains(selector))
616                    .unwrap_or(false),
617                None => true, // No selector means match all events
618            };
619
620            let value = match feature_def.aggregation {
621                FeatureAggregation::Count => {
622                    // Count matching events in the window (regardless of current event)
623                    match &feature_def.predicate_selector {
624                        Some(selector) => history
625                            .iter()
626                            .filter(|e| {
627                                e.predicate
628                                    .as_ref()
629                                    .map(|p| p.contains(selector))
630                                    .unwrap_or(false)
631                            })
632                            .count() as f64,
633                        None => history_len as f64,
634                    }
635                }
636                FeatureAggregation::Latest => {
637                    if matched {
638                        1.0
639                    } else {
640                        0.0
641                    }
642                }
643                FeatureAggregation::Sum => {
644                    let running = self.running_values.read();
645                    running
646                        .get(&feature_def.name)
647                        .map(|v| v.iter().sum())
648                        .unwrap_or(0.0)
649                }
650                FeatureAggregation::Mean => {
651                    if history_len > 0 {
652                        match &feature_def.predicate_selector {
653                            Some(selector) => {
654                                let count = history
655                                    .iter()
656                                    .filter(|e| {
657                                        e.predicate
658                                            .as_ref()
659                                            .map(|p| p.contains(selector))
660                                            .unwrap_or(false)
661                                    })
662                                    .count();
663                                count as f64 / history_len as f64
664                            }
665                            None => 1.0, // All events match
666                        }
667                    } else {
668                        0.0
669                    }
670                }
671            };
672
673            values.push(value);
674            names.push(feature_def.name.clone());
675        }
676
677        // Update running values for matched event
678        {
679            let mut running = self.running_values.write();
680            for feature_def in &self.config.features {
681                let entry = running.entry(feature_def.name.clone()).or_default();
682                let matched = match &feature_def.predicate_selector {
683                    Some(selector) => predicate
684                        .as_ref()
685                        .map(|p| p.contains(selector))
686                        .unwrap_or(false),
687                    None => true,
688                };
689                entry.push_back(if matched { 1.0 } else { 0.0 });
690                while entry.len() > self.config.window_size {
691                    entry.pop_front();
692                }
693            }
694        }
695
696        ExtractedFeatures {
697            values: Array1::from_vec(values),
698            names,
699            extracted_at: Utc::now(),
700            event_id: event_id.to_string(),
701        }
702    }
703
704    /// Resets the extractor state.
705    pub fn reset(&self) {
706        self.history.write().clear();
707        self.running_values.write().clear();
708    }
709
710    /// Returns the current window size.
711    pub fn current_window_size(&self) -> usize {
712        self.history.read().len()
713    }
714
715    /// Returns the event type name.
716    fn event_type_name(event: &StreamEvent) -> String {
717        match event {
718            StreamEvent::TripleAdded { .. } => "TripleAdded",
719            StreamEvent::TripleRemoved { .. } => "TripleRemoved",
720            StreamEvent::QuadAdded { .. } => "QuadAdded",
721            StreamEvent::QuadRemoved { .. } => "QuadRemoved",
722            StreamEvent::GraphCreated { .. } => "GraphCreated",
723            StreamEvent::GraphCleared { .. } => "GraphCleared",
724            StreamEvent::GraphDeleted { .. } => "GraphDeleted",
725            StreamEvent::SparqlUpdate { .. } => "SparqlUpdate",
726            StreamEvent::TransactionBegin { .. } => "TransactionBegin",
727            StreamEvent::TransactionCommit { .. } => "TransactionCommit",
728            StreamEvent::TransactionAbort { .. } => "TransactionAbort",
729            StreamEvent::SchemaChanged { .. } => "SchemaChanged",
730            _ => "Other",
731        }
732        .to_string()
733    }
734
735    /// Extracts the predicate from a stream event, if it has one.
736    fn extract_predicate(event: &StreamEvent) -> Option<String> {
737        match event {
738            StreamEvent::TripleAdded { predicate, .. }
739            | StreamEvent::TripleRemoved { predicate, .. }
740            | StreamEvent::QuadAdded { predicate, .. }
741            | StreamEvent::QuadRemoved { predicate, .. } => Some(predicate.clone()),
742            _ => None,
743        }
744    }
745}
746
747// ─── Tests ───────────────────────────────────────────────────────────────────
748
749#[cfg(test)]
750mod tests {
751    use super::*;
752    use crate::event::EventMetadata;
753    use std::time::Duration;
754
755    fn make_metadata(id: &str) -> EventMetadata {
756        EventMetadata {
757            event_id: id.to_string(),
758            timestamp: Utc::now(),
759            source: "test".to_string(),
760            user: None,
761            context: None,
762            caused_by: None,
763            version: "1.0".to_string(),
764            properties: HashMap::new(),
765            checksum: None,
766        }
767    }
768
769    fn make_triple_event(id: &str, predicate: &str) -> StreamEvent {
770        StreamEvent::TripleAdded {
771            subject: "http://example.org/s".to_string(),
772            predicate: predicate.to_string(),
773            object: "http://example.org/o".to_string(),
774            graph: None,
775            metadata: make_metadata(id),
776        }
777    }
778
779    // ── StreamingModelRunner Tests ───────────────────────────────────────────
780
781    #[test]
782    fn test_model_runner_basic_predict() {
783        let config = ModelConfig {
784            input_features: 3,
785            batch_size: 10,
786            max_latency_ms: 1000,
787            ..Default::default()
788        };
789        let runner = StreamingModelRunner::new(config);
790
791        let events = vec![
792            (Array1::from_vec(vec![1.0, 2.0, 3.0]), "evt-1".to_string()),
793            (Array1::from_vec(vec![4.0, 5.0, 6.0]), "evt-2".to_string()),
794        ];
795        let predictions = runner.predict(&events);
796        assert_eq!(predictions.len(), 2);
797        assert!(predictions[0].value.is_finite());
798        assert!(predictions[0].confidence >= 0.0 && predictions[0].confidence <= 1.0);
799    }
800
801    #[test]
802    fn test_model_runner_batch_trigger_by_size() {
803        let config = ModelConfig {
804            input_features: 2,
805            batch_size: 3,
806            max_latency_ms: 60_000,
807            ..Default::default()
808        };
809        let runner = StreamingModelRunner::new(config);
810
811        // Enqueue 2 events: no batch yet
812        let result1 = runner.enqueue(Array1::from_vec(vec![1.0, 2.0]), "e1".to_string());
813        assert!(result1.is_none());
814        let result2 = runner.enqueue(Array1::from_vec(vec![3.0, 4.0]), "e2".to_string());
815        assert!(result2.is_none());
816        assert_eq!(runner.pending_count(), 2);
817
818        // Third event triggers batch
819        let result3 = runner.enqueue(Array1::from_vec(vec![5.0, 6.0]), "e3".to_string());
820        assert!(result3.is_some());
821        let predictions = result3.expect("should have predictions");
822        assert_eq!(predictions.len(), 3);
823        assert_eq!(runner.pending_count(), 0);
824    }
825
826    #[test]
827    fn test_model_runner_flush() {
828        let config = ModelConfig {
829            input_features: 2,
830            batch_size: 100,
831            max_latency_ms: 60_000,
832            ..Default::default()
833        };
834        let runner = StreamingModelRunner::new(config);
835
836        runner.enqueue(Array1::from_vec(vec![1.0, 2.0]), "e1".to_string());
837        runner.enqueue(Array1::from_vec(vec![3.0, 4.0]), "e2".to_string());
838
839        let predictions = runner.flush();
840        assert_eq!(predictions.len(), 2);
841        assert_eq!(runner.pending_count(), 0);
842    }
843
844    #[test]
845    fn test_model_runner_flush_if_due() {
846        let config = ModelConfig {
847            input_features: 2,
848            batch_size: 100,
849            max_latency_ms: 10, // 10ms
850            ..Default::default()
851        };
852        let runner = StreamingModelRunner::new(config);
853
854        runner.enqueue(Array1::from_vec(vec![1.0, 2.0]), "e1".to_string());
855        std::thread::sleep(Duration::from_millis(20));
856
857        let result = runner.flush_if_due();
858        assert!(result.is_some());
859    }
860
861    #[test]
862    fn test_model_runner_wrong_dimensions_ignored() {
863        let config = ModelConfig {
864            input_features: 3,
865            ..Default::default()
866        };
867        let runner = StreamingModelRunner::new(config);
868        let result = runner.enqueue(Array1::from_vec(vec![1.0, 2.0]), "bad".to_string());
869        assert!(result.is_none());
870        assert_eq!(runner.pending_count(), 0);
871    }
872
873    #[test]
874    fn test_model_runner_update_weights() {
875        let config = ModelConfig {
876            input_features: 2,
877            ..Default::default()
878        };
879        let runner = StreamingModelRunner::new(config);
880        runner.update_weights(Array1::from_vec(vec![1.0, 2.0]), 0.5);
881
882        let predictions = runner.predict(&[(Array1::from_vec(vec![1.0, 1.0]), "e1".to_string())]);
883        // value = 0.5 + 1.0*1.0 + 2.0*1.0 = 3.5
884        assert!((predictions[0].value - 3.5).abs() < 1e-6);
885    }
886
887    #[test]
888    fn test_model_runner_stats() {
889        let config = ModelConfig {
890            input_features: 2,
891            batch_size: 2,
892            ..Default::default()
893        };
894        let runner = StreamingModelRunner::new(config);
895        runner.enqueue(Array1::from_vec(vec![1.0, 2.0]), "e1".to_string());
896        runner.enqueue(Array1::from_vec(vec![3.0, 4.0]), "e2".to_string());
897
898        let stats = runner.stats();
899        assert_eq!(stats.events_processed, 2);
900        assert_eq!(stats.batches_executed, 1);
901        assert_eq!(stats.size_triggered_batches, 1);
902    }
903
904    // ── StreamAnomalyDetector Tests ──────────────────────────────────────────
905
906    #[test]
907    fn test_anomaly_detector_normal_values() {
908        let config = AnomalyDetectorConfig {
909            sigma_threshold: 3.0,
910            window_size: 50,
911            min_samples: 5,
912            adaptive_rate: 0.0,
913        };
914        let detector = StreamAnomalyDetector::new(config);
915
916        // Feed normal values
917        for i in 0..20 {
918            let result = detector.is_anomaly(100.0 + (i as f64 * 0.1));
919            if i >= 5 {
920                assert!(
921                    !result.is_anomaly,
922                    "normal value should not be anomaly at i={}",
923                    i
924                );
925            }
926        }
927    }
928
929    #[test]
930    fn test_anomaly_detector_detects_outlier() {
931        let config = AnomalyDetectorConfig {
932            sigma_threshold: 3.0,
933            window_size: 100,
934            min_samples: 10,
935            adaptive_rate: 0.0,
936        };
937        let detector = StreamAnomalyDetector::new(config);
938
939        // Feed stable values
940        for _ in 0..30 {
941            detector.is_anomaly(100.0);
942        }
943
944        // Feed a huge outlier
945        let result = detector.is_anomaly(10000.0);
946        assert!(result.is_anomaly);
947        assert!(result.z_score > 3.0);
948    }
949
950    #[test]
951    fn test_anomaly_detector_insufficient_samples() {
952        let config = AnomalyDetectorConfig {
953            min_samples: 20,
954            ..Default::default()
955        };
956        let detector = StreamAnomalyDetector::new(config);
957
958        // Not enough samples yet
959        let result = detector.is_anomaly(999999.0);
960        assert!(!result.is_anomaly);
961        assert_eq!(result.window_samples, 1);
962    }
963
964    #[test]
965    fn test_anomaly_detector_sliding_window() {
966        let config = AnomalyDetectorConfig {
967            window_size: 10,
968            min_samples: 5,
969            sigma_threshold: 3.0,
970            adaptive_rate: 0.0,
971        };
972        let detector = StreamAnomalyDetector::new(config);
973
974        // Fill window with values around 100
975        for _ in 0..10 {
976            detector.is_anomaly(100.0);
977        }
978
979        // Now shift to values around 200 to fill the window
980        for _ in 0..10 {
981            detector.is_anomaly(200.0);
982        }
983
984        // After window shift, 200 should be normal
985        let result = detector.is_anomaly(200.0);
986        assert!(!result.is_anomaly);
987        assert!((result.window_mean - 200.0).abs() < 1.0);
988    }
989
990    #[test]
991    fn test_anomaly_detector_adaptive_threshold() {
992        let config = AnomalyDetectorConfig {
993            sigma_threshold: 3.0,
994            adaptive_rate: 1.0,
995            ..Default::default()
996        };
997        let detector = StreamAnomalyDetector::new(config);
998
999        let initial_threshold = detector.effective_threshold();
1000        detector.feedback(false); // false positive -> raise threshold
1001        let new_threshold = detector.effective_threshold();
1002        assert!(new_threshold > initial_threshold);
1003
1004        detector.feedback(true); // true positive -> lower threshold
1005        let final_threshold = detector.effective_threshold();
1006        assert!(final_threshold < new_threshold);
1007    }
1008
1009    #[test]
1010    fn test_anomaly_detector_stats() {
1011        let config = AnomalyDetectorConfig {
1012            sigma_threshold: 2.0,
1013            min_samples: 3,
1014            window_size: 20,
1015            adaptive_rate: 0.0,
1016        };
1017        let detector = StreamAnomalyDetector::new(config);
1018
1019        for _ in 0..10 {
1020            detector.is_anomaly(50.0);
1021        }
1022        detector.is_anomaly(9999.0); // anomaly
1023
1024        let stats = detector.stats();
1025        assert_eq!(stats.values_processed, 11);
1026        assert!(stats.anomalies_detected >= 1);
1027    }
1028
1029    #[test]
1030    fn test_anomaly_detector_reset() {
1031        let detector = StreamAnomalyDetector::new(AnomalyDetectorConfig::default());
1032        for _ in 0..20 {
1033            detector.is_anomaly(100.0);
1034        }
1035        detector.reset();
1036        let stats = detector.stats();
1037        assert_eq!(stats.values_processed, 0);
1038    }
1039
1040    // ── StreamFeatureExtractor Tests ─────────────────────────────────────────
1041
1042    #[test]
1043    fn test_feature_extractor_basic() {
1044        let config = FeatureExtractorConfig::default();
1045        let extractor = StreamFeatureExtractor::new(config);
1046
1047        let event = make_triple_event("e1", "http://example.org/name");
1048        let features = extractor.extract(&event, "e1");
1049        assert!(!features.values.is_empty());
1050        assert_eq!(features.values.len(), features.names.len());
1051    }
1052
1053    #[test]
1054    fn test_feature_extractor_predicate_selector() {
1055        let config = FeatureExtractorConfig {
1056            features: vec![
1057                FeatureDefinition {
1058                    name: "name_events".to_string(),
1059                    predicate_selector: Some("name".to_string()),
1060                    aggregation: FeatureAggregation::Count,
1061                },
1062                FeatureDefinition {
1063                    name: "age_events".to_string(),
1064                    predicate_selector: Some("age".to_string()),
1065                    aggregation: FeatureAggregation::Count,
1066                },
1067            ],
1068            window_size: 100,
1069        };
1070        let extractor = StreamFeatureExtractor::new(config);
1071
1072        // Add events with "name" predicate
1073        for i in 0..3 {
1074            let event = make_triple_event(&format!("n{}", i), "http://example.org/name");
1075            extractor.extract(&event, &format!("n{}", i));
1076        }
1077
1078        // Add events with "age" predicate
1079        let event = make_triple_event("a1", "http://example.org/age");
1080        let features = extractor.extract(&event, "a1");
1081
1082        // name_events should be 3, age_events should be 1
1083        assert_eq!(features.names[0], "name_events");
1084        assert!((features.values[0] - 3.0).abs() < 0.01);
1085        assert_eq!(features.names[1], "age_events");
1086        assert!((features.values[1] - 1.0).abs() < 0.01);
1087    }
1088
1089    #[test]
1090    fn test_feature_extractor_mean_aggregation() {
1091        let config = FeatureExtractorConfig {
1092            features: vec![FeatureDefinition {
1093                name: "ratio".to_string(),
1094                predicate_selector: Some("type".to_string()),
1095                aggregation: FeatureAggregation::Mean,
1096            }],
1097            window_size: 10,
1098        };
1099        let extractor = StreamFeatureExtractor::new(config);
1100
1101        // 2 matching out of 4 total
1102        extractor.extract(&make_triple_event("e1", "http://ex/type"), "e1");
1103        extractor.extract(&make_triple_event("e2", "http://ex/name"), "e2");
1104        extractor.extract(&make_triple_event("e3", "http://ex/type"), "e3");
1105        let features = extractor.extract(&make_triple_event("e4", "http://ex/name"), "e4");
1106
1107        // 2 matching out of 4 = 0.5 ratio
1108        assert!((features.values[0] - 0.5).abs() < 0.01);
1109    }
1110
1111    #[test]
1112    fn test_feature_extractor_window_eviction() {
1113        let config = FeatureExtractorConfig {
1114            features: vec![FeatureDefinition {
1115                name: "count".to_string(),
1116                predicate_selector: None,
1117                aggregation: FeatureAggregation::Count,
1118            }],
1119            window_size: 3,
1120        };
1121        let extractor = StreamFeatureExtractor::new(config);
1122
1123        for i in 0..5 {
1124            extractor.extract(
1125                &make_triple_event(&format!("e{}", i), "http://ex/p"),
1126                &format!("e{}", i),
1127            );
1128        }
1129
1130        assert_eq!(extractor.current_window_size(), 3);
1131    }
1132
1133    #[test]
1134    fn test_feature_extractor_reset() {
1135        let extractor = StreamFeatureExtractor::new(FeatureExtractorConfig::default());
1136        extractor.extract(&make_triple_event("e1", "http://ex/p"), "e1");
1137        extractor.reset();
1138        assert_eq!(extractor.current_window_size(), 0);
1139    }
1140
1141    #[test]
1142    fn test_feature_extractor_non_triple_events() {
1143        let config = FeatureExtractorConfig::default();
1144        let extractor = StreamFeatureExtractor::new(config);
1145
1146        let event = StreamEvent::SchemaChanged {
1147            schema_type: crate::event::SchemaType::Ontology,
1148            change_type: crate::event::SchemaChangeType::Added,
1149            details: "test".to_string(),
1150            metadata: make_metadata("schema-1"),
1151        };
1152        let features = extractor.extract(&event, "schema-1");
1153        assert!(!features.values.is_empty());
1154    }
1155}