Skip to main content

oxirs_stream/ml/
mod.rs

//! # ML Model Integration for Stream Processing
//!
//! Provides ML inference capabilities embedded in the stream processing pipeline:
//!
//! - [`StreamingModelRunner`]: runs ML inference on stream events, batching by size or latency
//! - [`StreamAnomalyDetector`]: Z-score-based streaming anomaly detection over a sliding window
//! - [`StreamFeatureExtractor`]: configurable feature extraction from RDF stream events
8
9use std::collections::{HashMap, VecDeque};
10use std::sync::Arc;
11use std::time::Instant;
12
13use chrono::{DateTime, Utc};
14use parking_lot::RwLock;
15use serde::{Deserialize, Serialize};
16use tracing::{debug, info, warn};
17
18use scirs2_core::ndarray_ext::Array1;
19
20use crate::event::StreamEvent;
21
22// ─── Model Configuration ─────────────────────────────────────────────────────
23
24/// Configuration for a streaming model runner
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct ModelConfig {
27    /// Path or identifier for the model
28    pub model_path: String,
29    /// Maximum batch size before forcing inference
30    pub batch_size: usize,
31    /// Maximum latency before forcing inference (even if batch is not full)
32    pub max_latency_ms: u64,
33    /// Number of input features expected
34    pub input_features: usize,
35    /// Model name for logging
36    pub model_name: String,
37}
38
39impl Default for ModelConfig {
40    fn default() -> Self {
41        Self {
42            model_path: "default".to_string(),
43            batch_size: 32,
44            max_latency_ms: 100,
45            input_features: 4,
46            model_name: "default-model".to_string(),
47        }
48    }
49}
50
51/// A single prediction from the model
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct Prediction {
54    /// Predicted value or class
55    pub value: f64,
56    /// Confidence score (0.0 to 1.0)
57    pub confidence: f64,
58    /// Source event identifier
59    pub source_event_id: String,
60    /// Timestamp of the prediction
61    pub predicted_at: DateTime<Utc>,
62    /// Model that produced the prediction
63    pub model_name: String,
64}
65
66/// Statistics for the streaming model runner
67#[derive(Debug, Clone, Serialize, Deserialize, Default)]
68pub struct ModelRunnerStats {
69    /// Total events processed
70    pub events_processed: u64,
71    /// Total batches executed
72    pub batches_executed: u64,
73    /// Total predictions produced
74    pub predictions_produced: u64,
75    /// Average batch size
76    pub avg_batch_size: f64,
77    /// Average inference time per batch (milliseconds)
78    pub avg_inference_time_ms: f64,
79    /// Batches triggered by size threshold
80    pub size_triggered_batches: u64,
81    /// Batches triggered by latency threshold
82    pub latency_triggered_batches: u64,
83}
84
85/// A pending event waiting to be included in a batch
86#[derive(Debug, Clone)]
87struct PendingEvent {
88    features: Array1<f64>,
89    event_id: String,
90    queued_at: Instant,
91}
92
93/// Model weights for a simple linear model (more models can be added)
94#[derive(Debug, Clone)]
95struct LinearModelWeights {
96    weights: Array1<f64>,
97    bias: f64,
98}
99
100/// Runs ML inference on stream events with automatic batching.
101///
102/// Events are collected until either `batch_size` events accumulate
103/// or `max_latency_ms` elapses, then inference is run on the batch.
104pub struct StreamingModelRunner {
105    config: ModelConfig,
106    /// Pending events waiting for batch inference
107    pending: Arc<RwLock<Vec<PendingEvent>>>,
108    /// Model weights (simple linear model for now)
109    model: Arc<RwLock<LinearModelWeights>>,
110    /// Runner statistics
111    stats: Arc<RwLock<ModelRunnerStats>>,
112    /// When the oldest pending event was queued
113    batch_start: Arc<RwLock<Option<Instant>>>,
114}
115
116impl StreamingModelRunner {
117    /// Creates a new streaming model runner.
118    pub fn new(config: ModelConfig) -> Self {
119        // Initialize with small default weights
120        let weights = Array1::from_vec(vec![0.1; config.input_features]);
121        Self {
122            config: config.clone(),
123            pending: Arc::new(RwLock::new(Vec::with_capacity(config.batch_size))),
124            model: Arc::new(RwLock::new(LinearModelWeights { weights, bias: 0.0 })),
125            stats: Arc::new(RwLock::new(ModelRunnerStats::default())),
126            batch_start: Arc::new(RwLock::new(None)),
127        }
128    }
129
130    /// Enqueues an event for prediction.
131    ///
132    /// Returns predictions if a batch was triggered.
133    pub fn enqueue(&self, features: Array1<f64>, event_id: String) -> Option<Vec<Prediction>> {
134        if features.len() != self.config.input_features {
135            warn!(
136                "Feature dimension mismatch: expected {}, got {}",
137                self.config.input_features,
138                features.len()
139            );
140            return None;
141        }
142
143        let mut pending = self.pending.write();
144        if pending.is_empty() {
145            *self.batch_start.write() = Some(Instant::now());
146        }
147        pending.push(PendingEvent {
148            features,
149            event_id,
150            queued_at: Instant::now(),
151        });
152
153        // Check if batch should be triggered
154        if pending.len() >= self.config.batch_size {
155            let events: Vec<PendingEvent> = std::mem::take(&mut *pending);
156            drop(pending);
157            *self.batch_start.write() = None;
158            self.stats.write().size_triggered_batches += 1;
159            Some(self.run_inference(events))
160        } else {
161            None
162        }
163    }
164
165    /// Flushes any pending events if the latency threshold has been exceeded.
166    ///
167    /// Returns predictions if flush was needed.
168    pub fn flush_if_due(&self) -> Option<Vec<Prediction>> {
169        let should_flush = {
170            let batch_start = self.batch_start.read();
171            match *batch_start {
172                Some(start) => start.elapsed().as_millis() as u64 >= self.config.max_latency_ms,
173                None => false,
174            }
175        };
176
177        if should_flush {
178            let mut pending = self.pending.write();
179            if pending.is_empty() {
180                return None;
181            }
182            let events: Vec<PendingEvent> = std::mem::take(&mut *pending);
183            drop(pending);
184            *self.batch_start.write() = None;
185            self.stats.write().latency_triggered_batches += 1;
186            Some(self.run_inference(events))
187        } else {
188            None
189        }
190    }
191
192    /// Forces inference on all pending events regardless of thresholds.
193    pub fn flush(&self) -> Vec<Prediction> {
194        let mut pending = self.pending.write();
195        if pending.is_empty() {
196            return Vec::new();
197        }
198        let events: Vec<PendingEvent> = std::mem::take(&mut *pending);
199        drop(pending);
200        *self.batch_start.write() = None;
201        self.run_inference(events)
202    }
203
204    /// Runs batched inference directly on a slice of stream events.
205    pub fn predict(&self, events: &[(Array1<f64>, String)]) -> Vec<Prediction> {
206        let pending_events: Vec<PendingEvent> = events
207            .iter()
208            .map(|(features, event_id)| PendingEvent {
209                features: features.clone(),
210                event_id: event_id.clone(),
211                queued_at: Instant::now(),
212            })
213            .collect();
214        self.run_inference(pending_events)
215    }
216
217    /// Updates the model weights.
218    pub fn update_weights(&self, weights: Array1<f64>, bias: f64) {
219        let mut model = self.model.write();
220        model.weights = weights;
221        model.bias = bias;
222        info!("Model {} weights updated", self.config.model_name);
223    }
224
225    /// Returns runner statistics.
226    pub fn stats(&self) -> ModelRunnerStats {
227        self.stats.read().clone()
228    }
229
230    /// Returns the number of pending events.
231    pub fn pending_count(&self) -> usize {
232        self.pending.read().len()
233    }
234
235    /// Internal inference function.
236    fn run_inference(&self, events: Vec<PendingEvent>) -> Vec<Prediction> {
237        let start = Instant::now();
238        let model = self.model.read();
239        let batch_size = events.len();
240
241        let predictions: Vec<Prediction> = events
242            .iter()
243            .map(|event| {
244                let mut value = model.bias;
245                let n = model.weights.len().min(event.features.len());
246                for i in 0..n {
247                    value += model.weights[i] * event.features[i];
248                }
249                // Sigmoid for confidence
250                let confidence = 1.0 / (1.0 + (-value).exp());
251
252                Prediction {
253                    value,
254                    confidence: confidence.clamp(0.0, 1.0),
255                    source_event_id: event.event_id.clone(),
256                    predicted_at: Utc::now(),
257                    model_name: self.config.model_name.clone(),
258                }
259            })
260            .collect();
261
262        let elapsed_ms = start.elapsed().as_micros() as f64 / 1000.0;
263
264        let mut stats = self.stats.write();
265        stats.events_processed += batch_size as u64;
266        stats.batches_executed += 1;
267        stats.predictions_produced += predictions.len() as u64;
268        let total_batches = stats.batches_executed as f64;
269        stats.avg_batch_size =
270            (stats.avg_batch_size * (total_batches - 1.0) + batch_size as f64) / total_batches;
271        stats.avg_inference_time_ms =
272            (stats.avg_inference_time_ms * (total_batches - 1.0) + elapsed_ms) / total_batches;
273
274        debug!(
275            "Inference batch: {} events, {:.2}ms",
276            batch_size, elapsed_ms
277        );
278
279        predictions
280    }
281}
282
283// ─── Streaming Anomaly Detector ──────────────────────────────────────────────
284
285/// Configuration for the streaming anomaly detector
286#[derive(Debug, Clone, Serialize, Deserialize)]
287pub struct AnomalyDetectorConfig {
288    /// Z-score threshold for anomaly detection
289    pub sigma_threshold: f64,
290    /// Sliding window size for statistics computation
291    pub window_size: usize,
292    /// Minimum samples before detection starts
293    pub min_samples: usize,
294    /// Adaptive threshold learning rate (0.0 = fixed, 1.0 = fully adaptive)
295    pub adaptive_rate: f64,
296}
297
298impl Default for AnomalyDetectorConfig {
299    fn default() -> Self {
300        Self {
301            sigma_threshold: 3.0,
302            window_size: 100,
303            min_samples: 10,
304            adaptive_rate: 0.0,
305        }
306    }
307}
308
309/// Result of anomaly detection on a single value
310#[derive(Debug, Clone, Serialize, Deserialize)]
311pub struct AnomalyCheckResult {
312    /// Whether the value is anomalous
313    pub is_anomaly: bool,
314    /// The Z-score of the value
315    pub z_score: f64,
316    /// The current mean of the sliding window
317    pub window_mean: f64,
318    /// The current standard deviation of the sliding window
319    pub window_stddev: f64,
320    /// The effective threshold used
321    pub threshold: f64,
322    /// Number of samples in the window
323    pub window_samples: usize,
324}
325
326/// Statistics for the anomaly detector
327#[derive(Debug, Clone, Serialize, Deserialize, Default)]
328pub struct AnomalyDetectorStats {
329    /// Total values processed
330    pub values_processed: u64,
331    /// Total anomalies detected
332    pub anomalies_detected: u64,
333    /// Current window mean
334    pub current_mean: f64,
335    /// Current window stddev
336    pub current_stddev: f64,
337    /// Detection rate
338    pub detection_rate: f64,
339}
340
341/// Z-score based streaming anomaly detector with a sliding window.
342///
343/// Maintains a sliding window of recent values, computes running mean and
344/// standard deviation, and flags values whose Z-score exceeds the configured
345/// sigma threshold.
346pub struct StreamAnomalyDetector {
347    config: AnomalyDetectorConfig,
348    /// Sliding window of recent values
349    window: Arc<RwLock<VecDeque<f64>>>,
350    /// Running sum for incremental mean computation
351    running_sum: Arc<RwLock<f64>>,
352    /// Running sum of squares for incremental stddev
353    running_sum_sq: Arc<RwLock<f64>>,
354    /// Effective threshold (may be adapted over time)
355    effective_threshold: Arc<RwLock<f64>>,
356    /// Statistics
357    stats: Arc<RwLock<AnomalyDetectorStats>>,
358}
359
360impl StreamAnomalyDetector {
361    /// Creates a new anomaly detector.
362    pub fn new(config: AnomalyDetectorConfig) -> Self {
363        let threshold = config.sigma_threshold;
364        Self {
365            config: config.clone(),
366            window: Arc::new(RwLock::new(VecDeque::with_capacity(config.window_size))),
367            running_sum: Arc::new(RwLock::new(0.0)),
368            running_sum_sq: Arc::new(RwLock::new(0.0)),
369            effective_threshold: Arc::new(RwLock::new(threshold)),
370            stats: Arc::new(RwLock::new(AnomalyDetectorStats::default())),
371        }
372    }
373
374    /// Checks whether a value is anomalous.
375    pub fn is_anomaly(&self, value: f64) -> AnomalyCheckResult {
376        let mut window = self.window.write();
377        let mut sum = self.running_sum.write();
378        let mut sum_sq = self.running_sum_sq.write();
379
380        // Add value to window
381        if window.len() >= self.config.window_size {
382            if let Some(removed) = window.pop_front() {
383                *sum -= removed;
384                *sum_sq -= removed * removed;
385            }
386        }
387        window.push_back(value);
388        *sum += value;
389        *sum_sq += value * value;
390
391        let n = window.len();
392
393        let mut stats = self.stats.write();
394        stats.values_processed += 1;
395
396        // Need minimum samples
397        if n < self.config.min_samples {
398            return AnomalyCheckResult {
399                is_anomaly: false,
400                z_score: 0.0,
401                window_mean: if n > 0 { *sum / n as f64 } else { 0.0 },
402                window_stddev: 0.0,
403                threshold: *self.effective_threshold.read(),
404                window_samples: n,
405            };
406        }
407
408        let mean = *sum / n as f64;
409        let variance = (*sum_sq / n as f64) - (mean * mean);
410        let stddev = if variance > 0.0 { variance.sqrt() } else { 0.0 };
411
412        let z_score = if stddev > 1e-10 {
413            (value - mean).abs() / stddev
414        } else {
415            0.0
416        };
417
418        let threshold = *self.effective_threshold.read();
419        let is_anomaly = z_score > threshold;
420
421        if is_anomaly {
422            stats.anomalies_detected += 1;
423        }
424        stats.current_mean = mean;
425        stats.current_stddev = stddev;
426        stats.detection_rate = if stats.values_processed > 0 {
427            stats.anomalies_detected as f64 / stats.values_processed as f64
428        } else {
429            0.0
430        };
431
432        AnomalyCheckResult {
433            is_anomaly,
434            z_score,
435            window_mean: mean,
436            window_stddev: stddev,
437            threshold,
438            window_samples: n,
439        }
440    }
441
442    /// Provides feedback to adapt the threshold.
443    pub fn feedback(&self, was_true_anomaly: bool) {
444        if self.config.adaptive_rate <= 0.0 {
445            return;
446        }
447        let mut threshold = self.effective_threshold.write();
448        if was_true_anomaly {
449            // Lower threshold slightly to catch more
450            *threshold *= 1.0 - (self.config.adaptive_rate * 0.02);
451        } else {
452            // Raise threshold slightly to reduce false positives
453            *threshold *= 1.0 + (self.config.adaptive_rate * 0.02);
454        }
455        // Clamp to reasonable range
456        *threshold = threshold.clamp(1.0, 10.0);
457    }
458
459    /// Returns detector statistics.
460    pub fn stats(&self) -> AnomalyDetectorStats {
461        self.stats.read().clone()
462    }
463
464    /// Returns the current effective threshold.
465    pub fn effective_threshold(&self) -> f64 {
466        *self.effective_threshold.read()
467    }
468
469    /// Resets the detector state.
470    pub fn reset(&self) {
471        self.window.write().clear();
472        *self.running_sum.write() = 0.0;
473        *self.running_sum_sq.write() = 0.0;
474        *self.stats.write() = AnomalyDetectorStats::default();
475    }
476}
477
478// ─── Feature Extractor ───────────────────────────────────────────────────────
479
480/// A feature definition describing how to extract a numeric feature from an event
481#[derive(Debug, Clone, Serialize, Deserialize)]
482pub struct FeatureDefinition {
483    /// Feature name
484    pub name: String,
485    /// Predicate selector: if the event's predicate contains this string, extract
486    pub predicate_selector: Option<String>,
487    /// Aggregation type for window-based features
488    pub aggregation: FeatureAggregation,
489}
490
491/// Aggregation type for a feature
492#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
493pub enum FeatureAggregation {
494    /// Use the latest value
495    Latest,
496    /// Count occurrences in window
497    Count,
498    /// Sum values in window
499    Sum,
500    /// Compute mean over window
501    Mean,
502}
503
504/// Configuration for the feature extractor
505#[derive(Debug, Clone, Serialize, Deserialize)]
506pub struct FeatureExtractorConfig {
507    /// Feature definitions
508    pub features: Vec<FeatureDefinition>,
509    /// Window size for aggregation-based features
510    pub window_size: usize,
511}
512
513impl Default for FeatureExtractorConfig {
514    fn default() -> Self {
515        Self {
516            features: vec![
517                FeatureDefinition {
518                    name: "event_count".to_string(),
519                    predicate_selector: None,
520                    aggregation: FeatureAggregation::Count,
521                },
522                FeatureDefinition {
523                    name: "event_rate".to_string(),
524                    predicate_selector: None,
525                    aggregation: FeatureAggregation::Mean,
526                },
527            ],
528            window_size: 50,
529        }
530    }
531}
532
533/// Extracted feature vector
534#[derive(Debug, Clone)]
535pub struct ExtractedFeatures {
536    /// Feature values as a numeric array
537    pub values: Array1<f64>,
538    /// Feature names (corresponding to values)
539    pub names: Vec<String>,
540    /// Timestamp of extraction
541    pub extracted_at: DateTime<Utc>,
542    /// Source event ID
543    pub event_id: String,
544}
545
546/// Configurable feature extractor for RDF stream events.
547///
548/// Extracts numeric features from stream events based on configured
549/// feature definitions with predicate selectors and aggregation types.
550pub struct StreamFeatureExtractor {
551    config: FeatureExtractorConfig,
552    /// History window of events for aggregation features
553    history: Arc<RwLock<VecDeque<EventSnapshot>>>,
554    /// Per-feature running values for aggregation
555    running_values: Arc<RwLock<HashMap<String, VecDeque<f64>>>>,
556}
557
/// Minimal record of an event, kept in the history window for windowed features.
#[derive(Clone, Debug)]
struct EventSnapshot {
    /// Name of the event variant.
    event_type: String,
    /// Predicate carried by the event, if it has one.
    predicate: Option<String>,
    /// Moment the snapshot was taken.
    timestamp: Instant,
}
565
566impl StreamFeatureExtractor {
567    /// Creates a new feature extractor.
568    pub fn new(config: FeatureExtractorConfig) -> Self {
569        Self {
570            config: config.clone(),
571            history: Arc::new(RwLock::new(VecDeque::with_capacity(config.window_size))),
572            running_values: Arc::new(RwLock::new(HashMap::new())),
573        }
574    }
575
576    /// Extracts features from a stream event.
577    pub fn extract(&self, event: &StreamEvent, event_id: &str) -> ExtractedFeatures {
578        let event_type = Self::event_type_name(event);
579        let predicate = Self::extract_predicate(event);
580
581        // Update history
582        let mut history = self.history.write();
583        history.push_back(EventSnapshot {
584            event_type: event_type.clone(),
585            predicate: predicate.clone(),
586            timestamp: Instant::now(),
587        });
588        while history.len() > self.config.window_size {
589            history.pop_front();
590        }
591        let history_len = history.len();
592
593        // Compute features
594        let mut values = Vec::with_capacity(self.config.features.len());
595        let mut names = Vec::with_capacity(self.config.features.len());
596
597        for feature_def in &self.config.features {
598            let matched = match &feature_def.predicate_selector {
599                Some(selector) => predicate
600                    .as_ref()
601                    .map(|p| p.contains(selector))
602                    .unwrap_or(false),
603                None => true, // No selector means match all events
604            };
605
606            let value = match feature_def.aggregation {
607                FeatureAggregation::Count => {
608                    // Count matching events in the window (regardless of current event)
609                    match &feature_def.predicate_selector {
610                        Some(selector) => history
611                            .iter()
612                            .filter(|e| {
613                                e.predicate
614                                    .as_ref()
615                                    .map(|p| p.contains(selector))
616                                    .unwrap_or(false)
617                            })
618                            .count() as f64,
619                        None => history_len as f64,
620                    }
621                }
622                FeatureAggregation::Latest => {
623                    if matched {
624                        1.0
625                    } else {
626                        0.0
627                    }
628                }
629                FeatureAggregation::Sum => {
630                    let running = self.running_values.read();
631                    running
632                        .get(&feature_def.name)
633                        .map(|v| v.iter().sum())
634                        .unwrap_or(0.0)
635                }
636                FeatureAggregation::Mean => {
637                    if history_len > 0 {
638                        match &feature_def.predicate_selector {
639                            Some(selector) => {
640                                let count = history
641                                    .iter()
642                                    .filter(|e| {
643                                        e.predicate
644                                            .as_ref()
645                                            .map(|p| p.contains(selector))
646                                            .unwrap_or(false)
647                                    })
648                                    .count();
649                                count as f64 / history_len as f64
650                            }
651                            None => 1.0, // All events match
652                        }
653                    } else {
654                        0.0
655                    }
656                }
657            };
658
659            values.push(value);
660            names.push(feature_def.name.clone());
661        }
662
663        // Update running values for matched event
664        {
665            let mut running = self.running_values.write();
666            for feature_def in &self.config.features {
667                let entry = running.entry(feature_def.name.clone()).or_default();
668                let matched = match &feature_def.predicate_selector {
669                    Some(selector) => predicate
670                        .as_ref()
671                        .map(|p| p.contains(selector))
672                        .unwrap_or(false),
673                    None => true,
674                };
675                entry.push_back(if matched { 1.0 } else { 0.0 });
676                while entry.len() > self.config.window_size {
677                    entry.pop_front();
678                }
679            }
680        }
681
682        ExtractedFeatures {
683            values: Array1::from_vec(values),
684            names,
685            extracted_at: Utc::now(),
686            event_id: event_id.to_string(),
687        }
688    }
689
690    /// Resets the extractor state.
691    pub fn reset(&self) {
692        self.history.write().clear();
693        self.running_values.write().clear();
694    }
695
696    /// Returns the current window size.
697    pub fn current_window_size(&self) -> usize {
698        self.history.read().len()
699    }
700
701    /// Returns the event type name.
702    fn event_type_name(event: &StreamEvent) -> String {
703        match event {
704            StreamEvent::TripleAdded { .. } => "TripleAdded",
705            StreamEvent::TripleRemoved { .. } => "TripleRemoved",
706            StreamEvent::QuadAdded { .. } => "QuadAdded",
707            StreamEvent::QuadRemoved { .. } => "QuadRemoved",
708            StreamEvent::GraphCreated { .. } => "GraphCreated",
709            StreamEvent::GraphCleared { .. } => "GraphCleared",
710            StreamEvent::GraphDeleted { .. } => "GraphDeleted",
711            StreamEvent::SparqlUpdate { .. } => "SparqlUpdate",
712            StreamEvent::TransactionBegin { .. } => "TransactionBegin",
713            StreamEvent::TransactionCommit { .. } => "TransactionCommit",
714            StreamEvent::TransactionAbort { .. } => "TransactionAbort",
715            StreamEvent::SchemaChanged { .. } => "SchemaChanged",
716            _ => "Other",
717        }
718        .to_string()
719    }
720
721    /// Extracts the predicate from a stream event, if it has one.
722    fn extract_predicate(event: &StreamEvent) -> Option<String> {
723        match event {
724            StreamEvent::TripleAdded { predicate, .. }
725            | StreamEvent::TripleRemoved { predicate, .. }
726            | StreamEvent::QuadAdded { predicate, .. }
727            | StreamEvent::QuadRemoved { predicate, .. } => Some(predicate.clone()),
728            _ => None,
729        }
730    }
731}
732
733// ─── Tests ───────────────────────────────────────────────────────────────────
734
735#[cfg(test)]
736mod tests {
737    use super::*;
738    use crate::event::EventMetadata;
739    use std::time::Duration;
740
741    fn make_metadata(id: &str) -> EventMetadata {
742        EventMetadata {
743            event_id: id.to_string(),
744            timestamp: Utc::now(),
745            source: "test".to_string(),
746            user: None,
747            context: None,
748            caused_by: None,
749            version: "1.0".to_string(),
750            properties: HashMap::new(),
751            checksum: None,
752        }
753    }
754
755    fn make_triple_event(id: &str, predicate: &str) -> StreamEvent {
756        StreamEvent::TripleAdded {
757            subject: "http://example.org/s".to_string(),
758            predicate: predicate.to_string(),
759            object: "http://example.org/o".to_string(),
760            graph: None,
761            metadata: make_metadata(id),
762        }
763    }
764
765    // ── StreamingModelRunner Tests ───────────────────────────────────────────
766
767    #[test]
768    fn test_model_runner_basic_predict() {
769        let config = ModelConfig {
770            input_features: 3,
771            batch_size: 10,
772            max_latency_ms: 1000,
773            ..Default::default()
774        };
775        let runner = StreamingModelRunner::new(config);
776
777        let events = vec![
778            (Array1::from_vec(vec![1.0, 2.0, 3.0]), "evt-1".to_string()),
779            (Array1::from_vec(vec![4.0, 5.0, 6.0]), "evt-2".to_string()),
780        ];
781        let predictions = runner.predict(&events);
782        assert_eq!(predictions.len(), 2);
783        assert!(predictions[0].value.is_finite());
784        assert!(predictions[0].confidence >= 0.0 && predictions[0].confidence <= 1.0);
785    }
786
787    #[test]
788    fn test_model_runner_batch_trigger_by_size() {
789        let config = ModelConfig {
790            input_features: 2,
791            batch_size: 3,
792            max_latency_ms: 60_000,
793            ..Default::default()
794        };
795        let runner = StreamingModelRunner::new(config);
796
797        // Enqueue 2 events: no batch yet
798        let result1 = runner.enqueue(Array1::from_vec(vec![1.0, 2.0]), "e1".to_string());
799        assert!(result1.is_none());
800        let result2 = runner.enqueue(Array1::from_vec(vec![3.0, 4.0]), "e2".to_string());
801        assert!(result2.is_none());
802        assert_eq!(runner.pending_count(), 2);
803
804        // Third event triggers batch
805        let result3 = runner.enqueue(Array1::from_vec(vec![5.0, 6.0]), "e3".to_string());
806        assert!(result3.is_some());
807        let predictions = result3.expect("should have predictions");
808        assert_eq!(predictions.len(), 3);
809        assert_eq!(runner.pending_count(), 0);
810    }
811
812    #[test]
813    fn test_model_runner_flush() {
814        let config = ModelConfig {
815            input_features: 2,
816            batch_size: 100,
817            max_latency_ms: 60_000,
818            ..Default::default()
819        };
820        let runner = StreamingModelRunner::new(config);
821
822        runner.enqueue(Array1::from_vec(vec![1.0, 2.0]), "e1".to_string());
823        runner.enqueue(Array1::from_vec(vec![3.0, 4.0]), "e2".to_string());
824
825        let predictions = runner.flush();
826        assert_eq!(predictions.len(), 2);
827        assert_eq!(runner.pending_count(), 0);
828    }
829
830    #[test]
831    fn test_model_runner_flush_if_due() {
832        let config = ModelConfig {
833            input_features: 2,
834            batch_size: 100,
835            max_latency_ms: 10, // 10ms
836            ..Default::default()
837        };
838        let runner = StreamingModelRunner::new(config);
839
840        runner.enqueue(Array1::from_vec(vec![1.0, 2.0]), "e1".to_string());
841        std::thread::sleep(Duration::from_millis(20));
842
843        let result = runner.flush_if_due();
844        assert!(result.is_some());
845    }
846
847    #[test]
848    fn test_model_runner_wrong_dimensions_ignored() {
849        let config = ModelConfig {
850            input_features: 3,
851            ..Default::default()
852        };
853        let runner = StreamingModelRunner::new(config);
854        let result = runner.enqueue(Array1::from_vec(vec![1.0, 2.0]), "bad".to_string());
855        assert!(result.is_none());
856        assert_eq!(runner.pending_count(), 0);
857    }
858
859    #[test]
860    fn test_model_runner_update_weights() {
861        let config = ModelConfig {
862            input_features: 2,
863            ..Default::default()
864        };
865        let runner = StreamingModelRunner::new(config);
866        runner.update_weights(Array1::from_vec(vec![1.0, 2.0]), 0.5);
867
868        let predictions = runner.predict(&[(Array1::from_vec(vec![1.0, 1.0]), "e1".to_string())]);
869        // value = 0.5 + 1.0*1.0 + 2.0*1.0 = 3.5
870        assert!((predictions[0].value - 3.5).abs() < 1e-6);
871    }
872
873    #[test]
874    fn test_model_runner_stats() {
875        let config = ModelConfig {
876            input_features: 2,
877            batch_size: 2,
878            ..Default::default()
879        };
880        let runner = StreamingModelRunner::new(config);
881        runner.enqueue(Array1::from_vec(vec![1.0, 2.0]), "e1".to_string());
882        runner.enqueue(Array1::from_vec(vec![3.0, 4.0]), "e2".to_string());
883
884        let stats = runner.stats();
885        assert_eq!(stats.events_processed, 2);
886        assert_eq!(stats.batches_executed, 1);
887        assert_eq!(stats.size_triggered_batches, 1);
888    }
889
890    // ── StreamAnomalyDetector Tests ──────────────────────────────────────────
891
892    #[test]
893    fn test_anomaly_detector_normal_values() {
894        let config = AnomalyDetectorConfig {
895            sigma_threshold: 3.0,
896            window_size: 50,
897            min_samples: 5,
898            adaptive_rate: 0.0,
899        };
900        let detector = StreamAnomalyDetector::new(config);
901
902        // Feed normal values
903        for i in 0..20 {
904            let result = detector.is_anomaly(100.0 + (i as f64 * 0.1));
905            if i >= 5 {
906                assert!(
907                    !result.is_anomaly,
908                    "normal value should not be anomaly at i={}",
909                    i
910                );
911            }
912        }
913    }
914
915    #[test]
916    fn test_anomaly_detector_detects_outlier() {
917        let config = AnomalyDetectorConfig {
918            sigma_threshold: 3.0,
919            window_size: 100,
920            min_samples: 10,
921            adaptive_rate: 0.0,
922        };
923        let detector = StreamAnomalyDetector::new(config);
924
925        // Feed stable values
926        for _ in 0..30 {
927            detector.is_anomaly(100.0);
928        }
929
930        // Feed a huge outlier
931        let result = detector.is_anomaly(10000.0);
932        assert!(result.is_anomaly);
933        assert!(result.z_score > 3.0);
934    }
935
936    #[test]
937    fn test_anomaly_detector_insufficient_samples() {
938        let config = AnomalyDetectorConfig {
939            min_samples: 20,
940            ..Default::default()
941        };
942        let detector = StreamAnomalyDetector::new(config);
943
944        // Not enough samples yet
945        let result = detector.is_anomaly(999999.0);
946        assert!(!result.is_anomaly);
947        assert_eq!(result.window_samples, 1);
948    }
949
950    #[test]
951    fn test_anomaly_detector_sliding_window() {
952        let config = AnomalyDetectorConfig {
953            window_size: 10,
954            min_samples: 5,
955            sigma_threshold: 3.0,
956            adaptive_rate: 0.0,
957        };
958        let detector = StreamAnomalyDetector::new(config);
959
960        // Fill window with values around 100
961        for _ in 0..10 {
962            detector.is_anomaly(100.0);
963        }
964
965        // Now shift to values around 200 to fill the window
966        for _ in 0..10 {
967            detector.is_anomaly(200.0);
968        }
969
970        // After window shift, 200 should be normal
971        let result = detector.is_anomaly(200.0);
972        assert!(!result.is_anomaly);
973        assert!((result.window_mean - 200.0).abs() < 1.0);
974    }
975
976    #[test]
977    fn test_anomaly_detector_adaptive_threshold() {
978        let config = AnomalyDetectorConfig {
979            sigma_threshold: 3.0,
980            adaptive_rate: 1.0,
981            ..Default::default()
982        };
983        let detector = StreamAnomalyDetector::new(config);
984
985        let initial_threshold = detector.effective_threshold();
986        detector.feedback(false); // false positive -> raise threshold
987        let new_threshold = detector.effective_threshold();
988        assert!(new_threshold > initial_threshold);
989
990        detector.feedback(true); // true positive -> lower threshold
991        let final_threshold = detector.effective_threshold();
992        assert!(final_threshold < new_threshold);
993    }
994
995    #[test]
996    fn test_anomaly_detector_stats() {
997        let config = AnomalyDetectorConfig {
998            sigma_threshold: 2.0,
999            min_samples: 3,
1000            window_size: 20,
1001            adaptive_rate: 0.0,
1002        };
1003        let detector = StreamAnomalyDetector::new(config);
1004
1005        for _ in 0..10 {
1006            detector.is_anomaly(50.0);
1007        }
1008        detector.is_anomaly(9999.0); // anomaly
1009
1010        let stats = detector.stats();
1011        assert_eq!(stats.values_processed, 11);
1012        assert!(stats.anomalies_detected >= 1);
1013    }
1014
1015    #[test]
1016    fn test_anomaly_detector_reset() {
1017        let detector = StreamAnomalyDetector::new(AnomalyDetectorConfig::default());
1018        for _ in 0..20 {
1019            detector.is_anomaly(100.0);
1020        }
1021        detector.reset();
1022        let stats = detector.stats();
1023        assert_eq!(stats.values_processed, 0);
1024    }
1025
1026    // ── StreamFeatureExtractor Tests ─────────────────────────────────────────
1027
1028    #[test]
1029    fn test_feature_extractor_basic() {
1030        let config = FeatureExtractorConfig::default();
1031        let extractor = StreamFeatureExtractor::new(config);
1032
1033        let event = make_triple_event("e1", "http://example.org/name");
1034        let features = extractor.extract(&event, "e1");
1035        assert!(!features.values.is_empty());
1036        assert_eq!(features.values.len(), features.names.len());
1037    }
1038
1039    #[test]
1040    fn test_feature_extractor_predicate_selector() {
1041        let config = FeatureExtractorConfig {
1042            features: vec![
1043                FeatureDefinition {
1044                    name: "name_events".to_string(),
1045                    predicate_selector: Some("name".to_string()),
1046                    aggregation: FeatureAggregation::Count,
1047                },
1048                FeatureDefinition {
1049                    name: "age_events".to_string(),
1050                    predicate_selector: Some("age".to_string()),
1051                    aggregation: FeatureAggregation::Count,
1052                },
1053            ],
1054            window_size: 100,
1055        };
1056        let extractor = StreamFeatureExtractor::new(config);
1057
1058        // Add events with "name" predicate
1059        for i in 0..3 {
1060            let event = make_triple_event(&format!("n{}", i), "http://example.org/name");
1061            extractor.extract(&event, &format!("n{}", i));
1062        }
1063
1064        // Add events with "age" predicate
1065        let event = make_triple_event("a1", "http://example.org/age");
1066        let features = extractor.extract(&event, "a1");
1067
1068        // name_events should be 3, age_events should be 1
1069        assert_eq!(features.names[0], "name_events");
1070        assert!((features.values[0] - 3.0).abs() < 0.01);
1071        assert_eq!(features.names[1], "age_events");
1072        assert!((features.values[1] - 1.0).abs() < 0.01);
1073    }
1074
1075    #[test]
1076    fn test_feature_extractor_mean_aggregation() {
1077        let config = FeatureExtractorConfig {
1078            features: vec![FeatureDefinition {
1079                name: "ratio".to_string(),
1080                predicate_selector: Some("type".to_string()),
1081                aggregation: FeatureAggregation::Mean,
1082            }],
1083            window_size: 10,
1084        };
1085        let extractor = StreamFeatureExtractor::new(config);
1086
1087        // 2 matching out of 4 total
1088        extractor.extract(&make_triple_event("e1", "http://ex/type"), "e1");
1089        extractor.extract(&make_triple_event("e2", "http://ex/name"), "e2");
1090        extractor.extract(&make_triple_event("e3", "http://ex/type"), "e3");
1091        let features = extractor.extract(&make_triple_event("e4", "http://ex/name"), "e4");
1092
1093        // 2 matching out of 4 = 0.5 ratio
1094        assert!((features.values[0] - 0.5).abs() < 0.01);
1095    }
1096
1097    #[test]
1098    fn test_feature_extractor_window_eviction() {
1099        let config = FeatureExtractorConfig {
1100            features: vec![FeatureDefinition {
1101                name: "count".to_string(),
1102                predicate_selector: None,
1103                aggregation: FeatureAggregation::Count,
1104            }],
1105            window_size: 3,
1106        };
1107        let extractor = StreamFeatureExtractor::new(config);
1108
1109        for i in 0..5 {
1110            extractor.extract(
1111                &make_triple_event(&format!("e{}", i), "http://ex/p"),
1112                &format!("e{}", i),
1113            );
1114        }
1115
1116        assert_eq!(extractor.current_window_size(), 3);
1117    }
1118
1119    #[test]
1120    fn test_feature_extractor_reset() {
1121        let extractor = StreamFeatureExtractor::new(FeatureExtractorConfig::default());
1122        extractor.extract(&make_triple_event("e1", "http://ex/p"), "e1");
1123        extractor.reset();
1124        assert_eq!(extractor.current_window_size(), 0);
1125    }
1126
1127    #[test]
1128    fn test_feature_extractor_non_triple_events() {
1129        let config = FeatureExtractorConfig::default();
1130        let extractor = StreamFeatureExtractor::new(config);
1131
1132        let event = StreamEvent::SchemaChanged {
1133            schema_type: crate::event::SchemaType::Ontology,
1134            change_type: crate::event::SchemaChangeType::Added,
1135            details: "test".to_string(),
1136            metadata: make_metadata("schema-1"),
1137        };
1138        let features = extractor.extract(&event, "schema-1");
1139        assert!(!features.values.is_empty());
1140    }
1141}