1use crate::types::{AnomalyResult, DataMatrix};
8use rand::prelude::*;
9use rand::{Rng, rng};
10use rustkernel_core::{domain::Domain, kernel::KernelMetadata, traits::GpuKernel};
11use serde::{Deserialize, Serialize};
12use std::collections::VecDeque;
13
/// Tunable parameters for [`StreamingIsolationForest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamingConfig {
    /// Number of isolation trees built on every forest rebuild.
    pub n_trees: usize,
    /// Upper bound on the number of window samples drawn for each tree.
    pub sample_size: usize,
    /// Maximum number of samples kept in the window when `use_sliding_window` is `true`.
    pub window_size: usize,
    /// Number of processed samples after which a periodic forest rebuild becomes due.
    pub rebuild_interval: usize,
    /// Expected anomaly fraction; translated into a standard-deviation multiplier
    /// for the adaptive score threshold.
    pub contamination: f64,
    /// When `true`, the oldest sample is evicted once the window grows past
    /// `window_size`; when `false`, the window is never trimmed.
    pub use_sliding_window: bool,
}

impl Default for StreamingConfig {
    /// Defaults: 100 trees, 256 samples per tree, a 10 000-sample sliding window,
    /// a rebuild every 1 000 samples, and 10 % contamination.
    fn default() -> Self {
        Self {
            n_trees: 100,
            sample_size: 256,
            window_size: 10000,
            rebuild_interval: 1000,
            contamination: 0.1,
            use_sliding_window: true,
        }
    }
}
47
/// Mutable state carried between calls to the streaming isolation forest.
#[derive(Debug, Clone)]
pub struct StreamingState {
    /// Recently seen samples, oldest first.
    window: VecDeque<Vec<f64>>,
    /// Expected length of every sample; `0` means "adopt the length of the first sample".
    n_features: usize,
    /// Current forest; empty until the first rebuild.
    trees: Vec<StreamingITree>,
    /// Samples accepted since the forest was last rebuilt.
    samples_since_rebuild: usize,
    /// Total number of samples accepted so far.
    total_samples: usize,
    /// Running statistics over every score produced so far.
    score_stats: OnlineStats,
    /// Score at or above which a sample is reported as anomalous.
    threshold: f64,
}

impl StreamingState {
    /// Creates an empty state expecting `n_features` values per sample.
    ///
    /// The threshold starts at `0.5` and the forest starts empty.
    pub fn new(n_features: usize) -> Self {
        Self {
            window: VecDeque::new(),
            n_features,
            trees: Vec::new(),
            samples_since_rebuild: 0,
            total_samples: 0,
            score_stats: OnlineStats::new(),
            threshold: 0.5,
        }
    }

    /// Number of samples currently held in the window (not the configured capacity).
    pub fn window_size(&self) -> usize {
        self.window.len()
    }

    /// Total number of samples accepted so far.
    pub fn total_samples(&self) -> usize {
        self.total_samples
    }

    /// Current anomaly threshold.
    pub fn threshold(&self) -> f64 {
        self.threshold
    }
}
96
/// Running count, mean, variance and extrema of a stream of `f64` values.
///
/// Mean and variance are maintained in a single pass with Welford's update,
/// so no individual value has to be stored.
#[derive(Debug, Clone)]
struct OnlineStats {
    /// Number of values observed.
    count: u64,
    /// Running mean of the observed values.
    mean: f64,
    /// Running sum of squared deviations from the mean.
    m2: f64,
    /// Smallest value observed (`+inf` while empty).
    min: f64,
    /// Largest value observed (`-inf` while empty).
    max: f64,
}

impl Default for OnlineStats {
    /// Same as [`OnlineStats::new`].
    ///
    /// Implemented by hand because a derived `Default` would start `min` and
    /// `max` at `0.0`, which silently corrupts the extrema.
    fn default() -> Self {
        Self::new()
    }
}

impl OnlineStats {
    /// Creates an empty accumulator.
    fn new() -> Self {
        Self {
            count: 0,
            mean: 0.0,
            m2: 0.0,
            // Identity elements for `min`/`max`, so the first observation always wins,
            // even if it is itself infinite.
            min: f64::INFINITY,
            max: f64::NEG_INFINITY,
        }
    }

    /// Folds one value into the running statistics.
    fn update(&mut self, value: f64) {
        self.count += 1;
        let delta = value - self.mean;
        self.mean += delta / self.count as f64;
        // Uses the *updated* mean; the product of the two deltas is Welford's increment.
        let delta2 = value - self.mean;
        self.m2 += delta * delta2;
        self.min = self.min.min(value);
        self.max = self.max.max(value);
    }

    /// Unbiased sample variance; `0.0` until at least two values were observed.
    fn variance(&self) -> f64 {
        if self.count < 2 {
            0.0
        } else {
            self.m2 / (self.count - 1) as f64
        }
    }

    /// Square root of [`OnlineStats::variance`].
    fn std_dev(&self) -> f64 {
        self.variance().sqrt()
    }
}
141
/// Node of an isolation tree.
#[derive(Debug, Clone)]
enum StreamingINode {
    /// Split node: points whose `split_feature` value is strictly below
    /// `split_value` descend into `left`, all others into `right`.
    Internal {
        split_feature: usize,
        split_value: f64,
        left: Box<StreamingINode>,
        right: Box<StreamingINode>,
    },
    /// Leaf node recording how many training samples ended up in it.
    External {
        size: usize,
    },
}
155
/// A single isolation tree.
#[derive(Debug, Clone)]
// `max_depth` is stored at build time but not read anywhere in this file.
#[allow(dead_code)]
struct StreamingITree {
    /// Root of the tree.
    root: StreamingINode,
    /// Depth limit that was applied while building the tree.
    max_depth: usize,
}
163
164impl StreamingITree {
165 fn build(samples: &[Vec<f64>], max_depth: usize) -> Self {
167 let root = Self::build_node(samples, 0, max_depth);
168 Self { root, max_depth }
169 }
170
171 fn build_node(samples: &[Vec<f64>], depth: usize, max_depth: usize) -> StreamingINode {
172 if samples.is_empty() || depth >= max_depth || samples.len() <= 1 {
173 return StreamingINode::External {
174 size: samples.len(),
175 };
176 }
177
178 let n_features = samples[0].len();
179 if n_features == 0 {
180 return StreamingINode::External {
181 size: samples.len(),
182 };
183 }
184
185 let mut rng = rng();
186 let feature = rng.random_range(0..n_features);
187
188 let values: Vec<f64> = samples.iter().map(|s| s[feature]).collect();
190 let min_val = values.iter().cloned().fold(f64::INFINITY, f64::min);
191 let max_val = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
192
193 if (max_val - min_val).abs() < 1e-10 {
194 return StreamingINode::External {
195 size: samples.len(),
196 };
197 }
198
199 let split_value = rng.random_range(min_val..max_val);
200
201 let (left_samples, right_samples): (Vec<_>, Vec<_>) = samples
202 .iter()
203 .cloned()
204 .partition(|s| s[feature] < split_value);
205
206 StreamingINode::Internal {
207 split_feature: feature,
208 split_value,
209 left: Box::new(Self::build_node(&left_samples, depth + 1, max_depth)),
210 right: Box::new(Self::build_node(&right_samples, depth + 1, max_depth)),
211 }
212 }
213
/// Path length of `point` in this tree: the depth of the leaf it reaches plus
/// the `c_factor` adjustment for that leaf's size.
fn path_length(&self, point: &[f64]) -> f64 {
    self.path_length_node(&self.root, point, 0)
}
218
219 fn path_length_node(&self, node: &StreamingINode, point: &[f64], depth: usize) -> f64 {
220 match node {
221 StreamingINode::External { size } => depth as f64 + Self::c_factor(*size),
222 StreamingINode::Internal {
223 split_feature,
224 split_value,
225 left,
226 right,
227 } => {
228 if point[*split_feature] < *split_value {
229 self.path_length_node(left, point, depth + 1)
230 } else {
231 self.path_length_node(right, point, depth + 1)
232 }
233 }
234 }
235 }
236
/// Path-length adjustment for a group of `n` samples:
/// `2 * (ln(n - 1) + gamma) - 2 * (n - 1) / n`, with fixed values for `n <= 2`.
fn c_factor(n: usize) -> f64 {
    /// Euler-Mascheroni constant, used to approximate harmonic numbers.
    const EULER_GAMMA: f64 = 0.5772156649;

    match n {
        0 | 1 => 0.0,
        2 => 1.0,
        _ => {
            let size = n as f64;
            let harmonic = (size - 1.0).ln() + EULER_GAMMA;
            2.0 * harmonic - 2.0 * (size - 1.0) / size
        }
    }
}
249}
250
/// Kernel handle for streaming isolation-forest anomaly detection.
///
/// The handle only carries metadata; all per-stream data lives in
/// [`StreamingState`] and is passed explicitly to the associated functions.
#[derive(Debug, Clone)]
pub struct StreamingIsolationForest {
    /// Descriptor returned by the `GpuKernel::metadata` implementation.
    metadata: KernelMetadata,
}

impl Default for StreamingIsolationForest {
    /// Equivalent to [`StreamingIsolationForest::new`].
    fn default() -> Self {
        Self::new()
    }
}
266
267impl StreamingIsolationForest {
268 #[must_use]
270 pub fn new() -> Self {
271 Self {
272 metadata: KernelMetadata::batch("ml/streaming-isolation-forest", Domain::StatisticalML)
273 .with_description("Online streaming anomaly detection with sliding window")
274 .with_throughput(50_000)
275 .with_latency_us(20.0),
276 }
277 }
278
/// Creates a fresh [`StreamingState`] expecting `n_features` values per sample.
///
/// Passing `0` lets the first processed sample decide the feature count.
pub fn init(n_features: usize) -> StreamingState {
    StreamingState::new(n_features)
}
283
284 pub fn process_sample(
288 state: &mut StreamingState,
289 sample: Vec<f64>,
290 config: &StreamingConfig,
291 ) -> (f64, bool) {
292 if sample.len() != state.n_features && state.n_features > 0 {
293 return (0.0, false); }
295
296 if state.n_features == 0 {
297 state.n_features = sample.len();
298 }
299
300 state.window.push_back(sample.clone());
302 if config.use_sliding_window && state.window.len() > config.window_size {
303 state.window.pop_front();
304 }
305
306 state.total_samples += 1;
307 state.samples_since_rebuild += 1;
308
309 if state.trees.is_empty()
311 || (state.samples_since_rebuild >= config.rebuild_interval
312 && state.window.len() >= config.sample_size)
313 {
314 Self::rebuild_forest(state, config);
315 state.samples_since_rebuild = 0;
316 }
317
318 let score = if state.trees.is_empty() {
320 0.5 } else {
322 Self::compute_score(&state.trees, &sample, config.sample_size)
323 };
324
325 state.score_stats.update(score);
327
328 if state.score_stats.count > 100 {
330 let k = Self::contamination_to_k(config.contamination);
333 state.threshold = state.score_stats.mean + k * state.score_stats.std_dev();
334 state.threshold = state.threshold.clamp(0.0, 1.0);
335 }
336
337 let is_anomaly = score >= state.threshold;
338 (score, is_anomaly)
339 }
340
341 pub fn process_batch(
343 state: &mut StreamingState,
344 samples: &DataMatrix,
345 config: &StreamingConfig,
346 ) -> AnomalyResult {
347 let mut scores = Vec::with_capacity(samples.n_samples);
348 let mut labels = Vec::with_capacity(samples.n_samples);
349
350 for i in 0..samples.n_samples {
351 let sample = samples.row(i).to_vec();
352 let (score, is_anomaly) = Self::process_sample(state, sample, config);
353 scores.push(score);
354 labels.push(if is_anomaly { -1 } else { 1 });
355 }
356
357 AnomalyResult {
358 scores,
359 labels,
360 threshold: state.threshold,
361 }
362 }
363
364 fn rebuild_forest(state: &mut StreamingState, config: &StreamingConfig) {
366 if state.window.is_empty() {
367 return;
368 }
369
370 let samples: Vec<Vec<f64>> = state.window.iter().cloned().collect();
371 let sample_size = config.sample_size.min(samples.len());
372 let max_depth = (sample_size as f64).log2().ceil() as usize;
373
374 let mut rng = rng();
375 state.trees = (0..config.n_trees)
376 .map(|_| {
377 let subset: Vec<Vec<f64>> = samples
378 .choose_multiple(&mut rng, sample_size)
379 .cloned()
380 .collect();
381 StreamingITree::build(&subset, max_depth)
382 })
383 .collect();
384 }
385
386 fn compute_score(trees: &[StreamingITree], point: &[f64], sample_size: usize) -> f64 {
388 if trees.is_empty() {
389 return 0.5;
390 }
391
392 let avg_path_length: f64 = trees
393 .iter()
394 .map(|tree| tree.path_length(point))
395 .sum::<f64>()
396 / trees.len() as f64;
397
398 let c_n = StreamingITree::c_factor(sample_size);
399 if c_n.abs() < 1e-10 {
400 return 0.5;
401 }
402
403 (2.0_f64).powf(-avg_path_length / c_n)
404 }
405
/// Maps the expected anomaly fraction to the number of standard deviations
/// above the mean score at which the threshold is placed.
///
/// The first bracket whose upper bound is not exceeded wins; anything above
/// the last bracket (or NaN) yields `0.5`.
fn contamination_to_k(contamination: f64) -> f64 {
    // (inclusive upper bound on contamination, multiplier)
    const BRACKETS: [(f64, f64); 4] = [(0.01, 2.33), (0.05, 1.65), (0.10, 1.28), (0.20, 0.84)];

    BRACKETS
        .iter()
        .find(|&&(upper, _)| contamination <= upper)
        .map_or(0.5, |&(_, k)| k)
}
422}
423
impl GpuKernel for StreamingIsolationForest {
    /// Returns the metadata assembled in [`StreamingIsolationForest::new`].
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}
429
/// Tunable parameters for [`AdaptiveThreshold`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveThresholdConfig {
    /// Threshold a freshly created state starts with.
    pub initial_threshold: f64,
    /// Maximum number of recent scores kept for quantile and drift estimation.
    pub window_size: usize,
    /// Desired fraction of scores at or above the threshold; the threshold tracks
    /// the `1 - target_fpr` quantile of the score window.
    pub target_fpr: f64,
    /// Step size in `[0, 1]` used when moving the threshold toward its target.
    pub learning_rate: f64,
    /// Lower clamp applied to the threshold after each update.
    pub min_threshold: f64,
    /// Upper clamp applied to the threshold after each update.
    pub max_threshold: f64,
    /// Enables comparison of the current window against a reference window.
    pub detect_drift: bool,
    /// Drift is reported when the mean-shift statistic exceeds this value.
    pub drift_sensitivity: f64,
}

impl Default for AdaptiveThresholdConfig {
    /// Defaults: threshold 0.5 clamped to `[0.1, 0.9]`, a 1 000-score window,
    /// 5 % target rate, learning rate 0.01, drift detection on with sensitivity 2.0.
    fn default() -> Self {
        Self {
            initial_threshold: 0.5,
            window_size: 1000,
            target_fpr: 0.05,
            learning_rate: 0.01,
            min_threshold: 0.1,
            max_threshold: 0.9,
            detect_drift: true,
            drift_sensitivity: 2.0,
        }
    }
}
469
/// Mutable state carried between calls to [`AdaptiveThreshold`].
#[derive(Debug, Clone)]
pub struct AdaptiveThresholdState {
    /// Current decision threshold.
    threshold: f64,
    /// Most recent scores, oldest first.
    score_window: VecDeque<f64>,
    /// Ground-truth labels aligned with `score_window` (`None` when unlabeled).
    /// Only written in this file, never read back.
    label_window: VecDeque<Option<bool>>,
    /// Running statistics over every score ever processed.
    stats: OnlineStats,
    /// Reference window statistics used as the drift baseline, once available.
    prev_window_stats: Option<WindowStats>,
    /// Statistics of the current score window.
    curr_window_stats: WindowStats,
    /// Total number of scores processed.
    total_samples: usize,
    /// Whether the most recent score triggered a drift detection.
    drift_detected: bool,
    /// Number of drift detections so far.
    drift_count: usize,
}
492
/// Summary statistics of one score window.
#[derive(Debug, Clone, Default)]
struct WindowStats {
    /// Arithmetic mean of the window.
    mean: f64,
    /// Population variance of the window (divides by `count`).
    variance: f64,
    /// Number of scores in the window.
    count: usize,
}
500
impl AdaptiveThresholdState {
    /// Creates an empty state whose threshold starts at `config.initial_threshold`.
    pub fn new(config: &AdaptiveThresholdConfig) -> Self {
        Self {
            threshold: config.initial_threshold,
            score_window: VecDeque::new(),
            label_window: VecDeque::new(),
            stats: OnlineStats::new(),
            prev_window_stats: None,
            curr_window_stats: WindowStats::default(),
            total_samples: 0,
            drift_detected: false,
            drift_count: 0,
        }
    }

    /// Current decision threshold.
    pub fn threshold(&self) -> f64 {
        self.threshold
    }

    /// Total number of scores processed.
    pub fn total_samples(&self) -> usize {
        self.total_samples
    }

    /// Whether the most recently processed score triggered a drift detection.
    pub fn drift_detected(&self) -> bool {
        self.drift_detected
    }

    /// Number of drift detections so far.
    pub fn drift_count(&self) -> usize {
        self.drift_count
    }
}
537
/// Outcome of processing one score with [`AdaptiveThreshold`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ThresholdResult {
    /// Threshold in effect after processing the score.
    pub threshold: f64,
    /// `true` when the score is at or above `threshold`.
    pub is_anomaly: bool,
    /// Fraction of the current score window at or above `threshold`.
    pub estimated_fpr: f64,
    /// Whether this score triggered a drift detection.
    pub drift_detected: bool,
    /// Heuristic confidence in `threshold`, in `[0, 1]`.
    pub confidence: f64,
}
552
/// Kernel handle for self-adjusting anomaly thresholds with drift detection.
///
/// The handle only carries metadata; all per-stream data lives in
/// [`AdaptiveThresholdState`] and is passed explicitly to the associated functions.
#[derive(Debug, Clone)]
pub struct AdaptiveThreshold {
    /// Descriptor returned by the `GpuKernel::metadata` implementation.
    metadata: KernelMetadata,
}

impl Default for AdaptiveThreshold {
    /// Equivalent to [`AdaptiveThreshold::new`].
    fn default() -> Self {
        Self::new()
    }
}
568
569impl AdaptiveThreshold {
570 #[must_use]
572 pub fn new() -> Self {
573 Self {
574 metadata: KernelMetadata::batch("ml/adaptive-threshold", Domain::StatisticalML)
575 .with_description("Self-adjusting anomaly thresholds with drift detection")
576 .with_throughput(100_000)
577 .with_latency_us(5.0),
578 }
579 }
580
/// Creates a fresh [`AdaptiveThresholdState`] seeded from `config`.
pub fn init(config: &AdaptiveThresholdConfig) -> AdaptiveThresholdState {
    AdaptiveThresholdState::new(config)
}
585
586 pub fn process_score(
588 state: &mut AdaptiveThresholdState,
589 score: f64,
590 ground_truth: Option<bool>,
591 config: &AdaptiveThresholdConfig,
592 ) -> ThresholdResult {
593 state.stats.update(score);
595 state.total_samples += 1;
596
597 state.score_window.push_back(score);
599 state.label_window.push_back(ground_truth);
600
601 if state.score_window.len() > config.window_size {
602 state.score_window.pop_front();
603 state.label_window.pop_front();
604 }
605
606 state.curr_window_stats = Self::compute_window_stats(&state.score_window);
608
609 state.drift_detected = false;
611 if config.detect_drift {
612 if let Some(prev) = &state.prev_window_stats {
613 let drift = Self::detect_drift(prev, &state.curr_window_stats, config);
614 if drift {
615 state.drift_detected = true;
616 state.drift_count += 1;
617 state.threshold = Self::estimate_threshold_from_window(
619 &state.score_window,
620 config.target_fpr,
621 );
622 }
623 }
624 }
625
626 if let Some(is_anomaly) = ground_truth {
628 Self::update_threshold_with_feedback(state, score, is_anomaly, config);
629 } else {
630 Self::update_threshold_quantile(state, config);
632 }
633
634 if state.score_window.len() == config.window_size {
637 if state.prev_window_stats.is_none() || state.drift_detected {
638 state.prev_window_stats = Some(state.curr_window_stats.clone());
639 }
640 }
641
642 let is_anomaly = score >= state.threshold;
643 let estimated_fpr = Self::estimate_fpr(state, config);
644 let confidence = Self::compute_confidence(state, config);
645
646 ThresholdResult {
647 threshold: state.threshold,
648 is_anomaly,
649 estimated_fpr,
650 drift_detected: state.drift_detected,
651 confidence,
652 }
653 }
654
655 fn compute_window_stats(window: &VecDeque<f64>) -> WindowStats {
657 if window.is_empty() {
658 return WindowStats::default();
659 }
660
661 let count = window.len();
662 let mean: f64 = window.iter().sum::<f64>() / count as f64;
663 let variance: f64 = window.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / count as f64;
664
665 WindowStats {
666 mean,
667 variance,
668 count,
669 }
670 }
671
672 fn detect_drift(
674 prev: &WindowStats,
675 curr: &WindowStats,
676 config: &AdaptiveThresholdConfig,
677 ) -> bool {
678 if prev.count < 10 || curr.count < 10 {
679 return false;
680 }
681
682 let se = ((prev.variance / prev.count as f64) + (curr.variance / curr.count as f64)).sqrt();
684 if se.abs() < 1e-10 {
685 return false;
686 }
687
688 let t_stat = (curr.mean - prev.mean).abs() / se;
689 t_stat > config.drift_sensitivity
690 }
691
/// Returns the `1 - target_fpr` quantile of `window`, or `0.5` for an empty window.
///
/// The quantile is the element at index `floor((1 - target_fpr) * len)` of the
/// ascending order, clamped to the last element. Values are ordered with
/// `f64::total_cmp`, so NaN scores have a well-defined position (positive NaN
/// sorts above `+inf`) instead of producing an arbitrary result.
fn estimate_threshold_from_window(window: &VecDeque<f64>, target_fpr: f64) -> f64 {
    if window.is_empty() {
        return 0.5;
    }

    let mut values: Vec<f64> = window.iter().copied().collect();

    // The float-to-usize cast saturates, so out-of-range or NaN rates still give a valid index.
    let rank = (((1.0 - target_fpr) * values.len() as f64) as usize).min(values.len() - 1);

    // Only one order statistic is needed; selection avoids sorting the whole window.
    let (_, quantile, _) = values.select_nth_unstable_by(rank, f64::total_cmp);
    *quantile
}
705
706 fn update_threshold_with_feedback(
708 state: &mut AdaptiveThresholdState,
709 score: f64,
710 is_anomaly: bool,
711 config: &AdaptiveThresholdConfig,
712 ) {
713 if score >= state.threshold && !is_anomaly {
715 state.threshold += config.learning_rate * (score - state.threshold);
717 }
718 else if score < state.threshold && is_anomaly {
720 state.threshold -= config.learning_rate * (state.threshold - score);
722 }
723
724 state.threshold = state
725 .threshold
726 .clamp(config.min_threshold, config.max_threshold);
727 }
728
729 fn update_threshold_quantile(
731 state: &mut AdaptiveThresholdState,
732 config: &AdaptiveThresholdConfig,
733 ) {
734 if state.score_window.len() < 10 {
735 return;
736 }
737
738 let target = Self::estimate_threshold_from_window(&state.score_window, config.target_fpr);
739
740 state.threshold =
742 state.threshold * (1.0 - config.learning_rate) + target * config.learning_rate;
743 state.threshold = state
744 .threshold
745 .clamp(config.min_threshold, config.max_threshold);
746 }
747
748 fn estimate_fpr(state: &AdaptiveThresholdState, _config: &AdaptiveThresholdConfig) -> f64 {
750 if state.score_window.is_empty() {
751 return 0.0;
752 }
753
754 let above_threshold = state
755 .score_window
756 .iter()
757 .filter(|&&s| s >= state.threshold)
758 .count();
759
760 above_threshold as f64 / state.score_window.len() as f64
761 }
762
763 fn compute_confidence(state: &AdaptiveThresholdState, config: &AdaptiveThresholdConfig) -> f64 {
765 let sample_factor = (state.score_window.len() as f64 / config.window_size as f64).min(1.0);
767
768 let drift_factor = if state.drift_detected { 0.5 } else { 1.0 };
770
771 let bound_factor = if (state.threshold - config.min_threshold).abs() < 0.01
773 || (state.threshold - config.max_threshold).abs() < 0.01
774 {
775 0.7
776 } else {
777 1.0
778 };
779
780 sample_factor * drift_factor * bound_factor
781 }
782
783 pub fn process_batch(
785 state: &mut AdaptiveThresholdState,
786 scores: &[f64],
787 ground_truth: Option<&[bool]>,
788 config: &AdaptiveThresholdConfig,
789 ) -> Vec<ThresholdResult> {
790 scores
791 .iter()
792 .enumerate()
793 .map(|(i, &score)| {
794 let gt = ground_truth.map(|gt| gt[i]);
795 Self::process_score(state, score, gt, config)
796 })
797 .collect()
798 }
799}
800
impl GpuKernel for AdaptiveThreshold {
    /// Returns the metadata assembled in [`AdaptiveThreshold::new`].
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}
806
// Unit tests for the streaming isolation forest, the adaptive threshold and
// the online statistics helper.
#[cfg(test)]
mod tests {
    use super::*;

    // The kernel reports the expected identifier.
    #[test]
    fn test_streaming_isolation_forest_metadata() {
        let kernel = StreamingIsolationForest::new();
        assert_eq!(kernel.metadata().id, "ml/streaming-isolation-forest");
    }

    // Smoke test: stream 50 random 2-D points, then score a far-away point.
    #[test]
    fn test_streaming_isolation_forest_basic() {
        let config = StreamingConfig {
            n_trees: 10,
            sample_size: 50,
            window_size: 100,
            rebuild_interval: 20,
            contamination: 0.1,
            use_sliding_window: true,
        };

        let mut state = StreamingIsolationForest::init(2);

        // Points drawn uniformly from the unit square.
        for _ in 0..50 {
            let sample = vec![rng().random_range(0.0..1.0), rng().random_range(0.0..1.0)];
            StreamingIsolationForest::process_sample(&mut state, sample, &config);
        }

        assert!(state.window_size() > 0);
        assert_eq!(state.total_samples(), 50);

        // Only checks that a positive score comes back; the label is not asserted.
        let (score, _is_anomaly) =
            StreamingIsolationForest::process_sample(&mut state, vec![100.0, 100.0], &config);
        assert!(score > 0.0);
    }

    // The window never holds more than `window_size` samples.
    #[test]
    fn test_streaming_sliding_window() {
        let config = StreamingConfig {
            window_size: 10,
            use_sliding_window: true,
            ..Default::default()
        };

        let mut state = StreamingIsolationForest::init(1);

        for i in 0..20 {
            StreamingIsolationForest::process_sample(&mut state, vec![i as f64], &config);
        }

        // 20 samples were accepted, but only the newest 10 are retained.
        assert_eq!(state.window_size(), 10);
        assert_eq!(state.total_samples(), 20);
    }

    // The kernel reports the expected identifier.
    #[test]
    fn test_adaptive_threshold_metadata() {
        let kernel = AdaptiveThreshold::new();
        assert_eq!(kernel.metadata().id, "ml/adaptive-threshold");
    }

    // After 50 low scores, a high score is flagged as anomalous.
    #[test]
    fn test_adaptive_threshold_basic() {
        let config = AdaptiveThresholdConfig {
            initial_threshold: 0.5,
            window_size: 100,
            target_fpr: 0.1,
            learning_rate: 0.1,
            ..Default::default()
        };

        let mut state = AdaptiveThreshold::init(&config);

        // Unlabeled scores drawn from [0.0, 0.4).
        for _ in 0..50 {
            let score = rng().random_range(0.0..0.4);
            AdaptiveThreshold::process_score(&mut state, score, None, &config);
        }

        let result = AdaptiveThreshold::process_score(&mut state, 0.9, None, &config);
        assert!(result.is_anomaly);
    }

    // Labelled feedback moves the threshold in the corrective direction.
    #[test]
    fn test_adaptive_threshold_feedback() {
        let config = AdaptiveThresholdConfig {
            initial_threshold: 0.5,
            learning_rate: 0.2,
            ..Default::default()
        };

        let mut state = AdaptiveThreshold::init(&config);

        // False positive: score above the threshold but labelled normal.
        let initial_threshold = state.threshold();
        AdaptiveThreshold::process_score(&mut state, 0.6, Some(false), &config);
        assert!(state.threshold() > initial_threshold);

        // False negative: score below the threshold but labelled anomalous.
        let prev_threshold = state.threshold();
        AdaptiveThreshold::process_score(&mut state, 0.3, Some(true), &config);
        assert!(state.threshold() < prev_threshold);
    }

    // A jump in the score level is reported as drift.
    #[test]
    fn test_drift_detection() {
        let config = AdaptiveThresholdConfig {
            window_size: 10,
            detect_drift: true,
            drift_sensitivity: 1.5, ..Default::default()
        };

        let mut state = AdaptiveThreshold::init(&config);

        // Fill the window with a constant low score to establish the reference.
        for _ in 0..10 {
            AdaptiveThreshold::process_score(&mut state, 0.15, None, &config);
        }

        // Switch to a constant high score.
        let mut drift_found = false;
        for _ in 0..15 {
            let result = AdaptiveThreshold::process_score(&mut state, 0.85, None, &config);
            if result.drift_detected {
                drift_found = true;
            }
        }

        assert!(
            drift_found || state.drift_count() > 0,
            "Should detect drift between 0.15 and 0.85 score ranges"
        );
    }

    // Batch processing yields one score and one label per input row.
    #[test]
    fn test_batch_processing() {
        let config = StreamingConfig::default();
        let mut state = StreamingIsolationForest::init(2);

        // Eight values passed with dimensions (4, 2).
        let data = DataMatrix::new(
            vec![
                0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 10.0, 10.0, ],
            4,
            2,
        );

        let result = StreamingIsolationForest::process_batch(&mut state, &data, &config);
        assert_eq!(result.scores.len(), 4);
        assert_eq!(result.labels.len(), 4);
    }

    // Running statistics agree with the batch values for a small data set.
    #[test]
    fn test_online_stats() {
        let mut stats = OnlineStats::new();

        for v in [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0] {
            stats.update(v);
        }

        // Mean is 5; sample variance is 32 / 7 (about 4.57).
        assert!((stats.mean - 5.0).abs() < 0.01);
        assert!((stats.variance() - 4.57).abs() < 0.1);
        assert_eq!(stats.min, 2.0);
        assert_eq!(stats.max, 9.0);
    }
}