1use serde::{Deserialize, Serialize};
16use std::collections::{HashMap, VecDeque};
17use std::sync::Arc;
18use std::time::{Duration, Instant, SystemTime};
19use tokio::sync::RwLock;
20
21use scirs2_core::Rng;
22
23use crate::error::StreamError;
24
/// Tunable parameters for an online (streaming) learner.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OnlineLearningConfig {
    /// Step size for gradient-style weight updates.
    pub learning_rate: f64,
    /// L2 regularization coefficient applied to the weights during updates.
    pub regularization: f64,
    /// Number of buffered samples that triggers a batch update in `partial_fit`.
    pub batch_size: usize,
    /// Whether drift detection runs periodically during training.
    pub detect_drift: bool,
    /// Multiplier on the error std-dev used as the drift threshold
    /// (smaller = more sensitive).
    pub drift_sensitivity: f64,
    /// Minimum wall-clock interval between automatic checkpoints.
    pub checkpoint_interval: Duration,
    /// Maximum number of checkpoints retained; older ones are evicted.
    pub max_model_history: usize,
    /// Gate for `start_ab_test`; when false, starting a test errors.
    pub enable_ab_testing: bool,
    /// Fraction reserved for validation — NOTE(review): not referenced by the
    /// visible training code; confirm intended use.
    pub validation_split: f64,
}
47
impl Default for OnlineLearningConfig {
    /// Conservative defaults: small learning rate, light L2 regularization,
    /// 32-sample mini-batches, drift detection enabled, 5-minute checkpoints,
    /// A/B testing disabled.
    fn default() -> Self {
        Self {
            learning_rate: 0.01,
            regularization: 0.001,
            batch_size: 32,
            detect_drift: true,
            drift_sensitivity: 0.05,
            checkpoint_interval: Duration::from_secs(300),
            max_model_history: 10,
            enable_ab_testing: false,
            validation_split: 0.2,
        }
    }
}
63
/// Supported online model families.
///
/// Only `LinearRegression`, `OnlineGradientDescent`, `LogisticRegression`,
/// `Perceptron`, and `PassiveAggressive` have dedicated update/prediction
/// rules in the visible code; the remaining variants fall back to plain
/// regularized SGD updates and raw-score predictions.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ModelType {
    LinearRegression,
    LogisticRegression,
    Perceptron,
    PassiveAggressive,
    OnlineGradientDescent,
    HoeffdingTree,
    NaiveBayes,
    ApproximateKNN,
    OnlineRandomForest,
}
86
/// A single labeled training observation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Sample {
    /// Feature vector; dimensions beyond the model's weight vector are ignored.
    pub features: Vec<f64>,
    /// Supervised target (class label for classifiers, value for regressors).
    pub target: f64,
    /// Per-sample importance weight multiplied into gradient updates.
    pub weight: f64,
    /// When the sample was observed.
    pub timestamp: SystemTime,
}
99
/// Output of a single model prediction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Prediction {
    /// Predicted value (class label for classifiers, raw score otherwise).
    pub value: f64,
    /// Confidence in [0, 1]; fixed at 1.0 for model types without one.
    pub confidence: f64,
    /// Per-class probabilities (populated by logistic regression only).
    pub probabilities: Option<HashMap<i64, f64>>,
    /// Wall-clock latency of this prediction, in milliseconds.
    pub latency_ms: f64,
    /// Model version in effect when the prediction was made.
    pub model_version: u64,
}
114
/// Immutable snapshot of model state, used for restore and rollback.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelCheckpoint {
    /// Unique id (`ckpt_<version>_<uuid>`).
    pub checkpoint_id: String,
    /// Model version captured by this snapshot.
    pub version: u64,
    /// Creation timestamp.
    pub created_at: SystemTime,
    /// Snapshot of the weight vector.
    pub weights: Vec<f64>,
    /// Snapshot of the bias term.
    pub bias: f64,
    /// Metrics at snapshot time.
    pub metrics: ModelMetrics,
    /// Total samples trained on at snapshot time.
    pub samples_seen: u64,
}
133
/// Running evaluation metrics. Regression and classification fields coexist;
/// only those relevant to the active model type are updated.
///
/// NOTE(review): `r_squared`, `precision`, `recall`, `f1_score`, `log_loss`,
/// and `training_time_ms` are never written by the visible code — confirm
/// whether another module maintains them.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ModelMetrics {
    /// Mean squared error (EWMA over batches).
    pub mse: f64,
    /// Mean absolute error (EWMA over batches).
    pub mae: f64,
    /// Coefficient of determination.
    pub r_squared: f64,
    /// Classification accuracy (EWMA over batches).
    pub accuracy: f64,
    /// Classification precision.
    pub precision: f64,
    /// Classification recall.
    pub recall: f64,
    /// Harmonic mean of precision and recall.
    pub f1_score: f64,
    /// Logistic loss.
    pub log_loss: f64,
    /// Total samples folded into these metrics.
    pub sample_count: u64,
    /// Cumulative training time, in milliseconds.
    pub training_time_ms: f64,
}
158
/// Result of one drift-detection pass over the recent error history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DriftDetection {
    /// Whether the error distribution shifted beyond the threshold.
    pub drift_detected: bool,
    /// Shift magnitude relative to the threshold, clamped to [0, 1].
    pub severity: f64,
    /// Human-readable name of the detection method used.
    pub method: String,
    /// When the check was performed.
    pub detected_at: SystemTime,
    /// Suggested remediation based on severity.
    pub recommendation: DriftAction,
}
173
/// Remediation recommended after drift detection, ordered roughly by
/// increasing severity of the response.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum DriftAction {
    /// No drift, or no action needed.
    None,
    /// Adapt faster to the new distribution.
    IncreaseLearningRate,
    /// Discard learned state entirely (applied automatically by `update_batch`).
    ResetModel,
    /// Retrain from historical data.
    Retrain,
    /// Blend old and new models.
    UseEnsemble,
}
188
/// Configuration for an A/B test between two model parameter sets.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ABTestConfig {
    /// Human-readable test name.
    pub name: String,
    /// Version serving as control.
    pub control_version: u64,
    /// Version serving as treatment.
    pub treatment_version: u64,
    /// Probability in [0, 1] that a prediction is routed to treatment.
    pub traffic_split: f64,
    /// Minimum samples before significance can be evaluated.
    pub min_samples: usize,
    /// Alpha level for the significance test.
    pub significance_level: f64,
}
205
/// Summary produced when an A/B test is stopped.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ABTestResult {
    /// The configuration the test ran with.
    pub config: ABTestConfig,
    /// Metrics of the control arm.
    pub control_metrics: ModelMetrics,
    /// Metrics of the treatment arm.
    pub treatment_metrics: ModelMetrics,
    /// Whether the difference reached significance.
    pub is_significant: bool,
    /// P-value of the comparison.
    pub p_value: f64,
    /// "treatment" / "control", or `None` on a tie.
    pub winner: Option<String>,
    /// Relative accuracy improvement of treatment over control, in percent.
    pub improvement: f64,
}
224
/// Aggregate operational statistics for an [`OnlineLearningModel`].
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct OnlineLearningStats {
    /// Total samples trained on.
    pub total_samples: u64,
    /// Total predictions served.
    pub total_predictions: u64,
    /// Current model version (incremented on reset).
    pub current_version: u64,
    /// Number of checkpoints currently retained.
    pub checkpoint_count: usize,
    /// Number of drift detections that fired.
    pub drift_events: u64,
    /// Mean prediction latency over the recent window, in milliseconds.
    pub avg_prediction_latency_ms: f64,
    /// Mean training latency over the recent window, in milliseconds.
    pub avg_training_latency_ms: f64,
    /// Copy of the latest model metrics.
    pub current_metrics: ModelMetrics,
}
245
/// An incrementally-trained linear model with drift detection, checkpointing,
/// and optional A/B testing. All mutable state lives behind `Arc<RwLock<_>>`
/// so the model can be shared across async tasks.
pub struct OnlineLearningModel {
    // Static configuration (immutable after construction).
    config: OnlineLearningConfig,
    // Which update/prediction rule to apply.
    model_type: ModelType,
    // Linear weight vector, length fixed at construction.
    weights: Arc<RwLock<Vec<f64>>>,
    // Bias (intercept) term.
    bias: Arc<RwLock<f64>>,
    // Monotonic version counter, bumped on reset.
    version: Arc<RwLock<u64>>,
    // Total samples trained on.
    samples_seen: Arc<RwLock<u64>>,
    // Samples accumulated until `batch_size` triggers an update.
    batch_buffer: Arc<RwLock<Vec<Sample>>>,
    // Bounded FIFO of snapshots (capped at `max_model_history`).
    checkpoints: Arc<RwLock<VecDeque<ModelCheckpoint>>>,
    // Running evaluation metrics.
    metrics: Arc<RwLock<ModelMetrics>>,
    // Recent absolute errors (bounded at 1000) feeding drift detection.
    error_history: Arc<RwLock<VecDeque<f64>>>,
    // Operational counters and averages.
    stats: Arc<RwLock<OnlineLearningStats>>,
    // When the last automatic checkpoint was taken.
    last_checkpoint: Arc<RwLock<Instant>>,
    // Recent prediction latencies in ms (bounded at 1000).
    prediction_latencies: Arc<RwLock<VecDeque<f64>>>,
    // Recent training latencies in ms (bounded at 1000).
    training_latencies: Arc<RwLock<VecDeque<f64>>>,
    // Active A/B test, if any.
    ab_test: Arc<RwLock<Option<ABTestConfig>>>,
    // Treatment-arm weights (snapshot taken at test start).
    treatment_weights: Arc<RwLock<Option<Vec<f64>>>>,
    // Treatment-arm bias.
    treatment_bias: Arc<RwLock<Option<f64>>>,
}
283
284impl OnlineLearningModel {
    /// Creates a fresh model of `model_type` with `feature_dim`
    /// zero-initialized weights and the given configuration.
    pub fn new(model_type: ModelType, feature_dim: usize, config: OnlineLearningConfig) -> Self {
        Self {
            config,
            model_type,
            // Zero weights + zero bias: the untrained model predicts 0.0.
            weights: Arc::new(RwLock::new(vec![0.0; feature_dim])),
            bias: Arc::new(RwLock::new(0.0)),
            version: Arc::new(RwLock::new(0)),
            samples_seen: Arc::new(RwLock::new(0)),
            batch_buffer: Arc::new(RwLock::new(Vec::new())),
            checkpoints: Arc::new(RwLock::new(VecDeque::new())),
            metrics: Arc::new(RwLock::new(ModelMetrics::default())),
            error_history: Arc::new(RwLock::new(VecDeque::with_capacity(1000))),
            stats: Arc::new(RwLock::new(OnlineLearningStats::default())),
            last_checkpoint: Arc::new(RwLock::new(Instant::now())),
            prediction_latencies: Arc::new(RwLock::new(VecDeque::with_capacity(1000))),
            training_latencies: Arc::new(RwLock::new(VecDeque::with_capacity(1000))),
            ab_test: Arc::new(RwLock::new(None)),
            treatment_weights: Arc::new(RwLock::new(None)),
            treatment_bias: Arc::new(RwLock::new(None)),
        }
    }
307
308 pub async fn partial_fit(&self, sample: Sample) -> Result<(), StreamError> {
310 let start = Instant::now();
311
312 let mut buffer = self.batch_buffer.write().await;
314 buffer.push(sample);
315
316 if buffer.len() >= self.config.batch_size {
318 let batch: Vec<Sample> = buffer.drain(..).collect();
319 drop(buffer);
320
321 self.update_batch(batch).await?;
322 }
323
324 let latency = start.elapsed().as_secs_f64() * 1000.0;
326 self.record_training_latency(latency).await;
327
328 self.maybe_checkpoint().await?;
330
331 Ok(())
332 }
333
334 pub async fn partial_fit_batch(&self, samples: Vec<Sample>) -> Result<(), StreamError> {
336 let start = Instant::now();
337
338 self.update_batch(samples).await?;
339
340 let latency = start.elapsed().as_secs_f64() * 1000.0;
341 self.record_training_latency(latency).await;
342
343 self.maybe_checkpoint().await?;
344
345 Ok(())
346 }
347
348 pub async fn predict(&self, features: &[f64]) -> Result<Prediction, StreamError> {
350 let start = Instant::now();
351
352 let weights = self.weights.read().await;
353 let bias = *self.bias.read().await;
354 let version = *self.version.read().await;
355
356 let mut raw_value = bias;
358 for (i, &w) in weights.iter().enumerate() {
359 if i < features.len() {
360 raw_value += w * features[i];
361 }
362 }
363
364 let (value, confidence, probabilities) = match self.model_type {
366 ModelType::LinearRegression | ModelType::OnlineGradientDescent => {
367 (raw_value, 1.0, None)
368 }
369 ModelType::LogisticRegression => {
370 let sigmoid = 1.0 / (1.0 + (-raw_value).exp());
371 let class = if sigmoid >= 0.5 { 1.0 } else { 0.0 };
372 let conf = if sigmoid >= 0.5 {
373 sigmoid
374 } else {
375 1.0 - sigmoid
376 };
377
378 let mut probs = HashMap::new();
379 probs.insert(0, 1.0 - sigmoid);
380 probs.insert(1, sigmoid);
381
382 (class, conf, Some(probs))
383 }
384 ModelType::Perceptron | ModelType::PassiveAggressive => {
385 let class = if raw_value >= 0.0 { 1.0 } else { 0.0 };
386 let conf = raw_value.abs().min(1.0);
387 (class, conf, None)
388 }
389 _ => (raw_value, 1.0, None),
390 };
391
392 let latency = start.elapsed().as_secs_f64() * 1000.0;
393 self.record_prediction_latency(latency).await;
394
395 Ok(Prediction {
396 value,
397 confidence,
398 probabilities,
399 latency_ms: latency,
400 model_version: version,
401 })
402 }
403
    /// Routes a prediction through the A/B test when one is active: with
    /// probability `traffic_split` the treatment parameters are used;
    /// otherwise (or when no test / no treatment parameters exist) the live
    /// model predicts.
    ///
    /// Routing is re-randomized per call — there is no sticky per-caller
    /// assignment.
    pub async fn predict_ab(&self, features: &[f64]) -> Result<Prediction, StreamError> {
        let ab_test = self.ab_test.read().await;

        if let Some(test_config) = ab_test.as_ref() {
            let use_treatment =
                scirs2_core::random::rng().random::<f64>() < test_config.traffic_split;

            if use_treatment {
                // The tuple scrutinee keeps both lock guards alive for the
                // whole `if let` body, so `weights` stays valid while borrowed.
                if let (Some(weights), Some(bias)) = (
                    self.treatment_weights.read().await.as_ref(),
                    *self.treatment_bias.read().await,
                ) {
                    return self.predict_with_params(features, weights, bias).await;
                }
            }
        }

        // Fall through to the control (live) model.
        self.predict(features).await
    }
427
428 pub async fn detect_drift(&self) -> Result<DriftDetection, StreamError> {
430 let error_history = self.error_history.read().await;
431
432 if error_history.len() < 100 {
433 return Ok(DriftDetection {
434 drift_detected: false,
435 severity: 0.0,
436 method: "Insufficient data".to_string(),
437 detected_at: SystemTime::now(),
438 recommendation: DriftAction::None,
439 });
440 }
441
442 let mid = error_history.len() / 2;
444 let old_window: Vec<f64> = error_history.iter().take(mid).copied().collect();
445 let new_window: Vec<f64> = error_history.iter().skip(mid).copied().collect();
446
447 let old_mean = old_window.iter().sum::<f64>() / old_window.len() as f64;
449 let new_mean = new_window.iter().sum::<f64>() / new_window.len() as f64;
450
451 let old_var = old_window
453 .iter()
454 .map(|x| (x - old_mean).powi(2))
455 .sum::<f64>()
456 / old_window.len() as f64;
457 let new_var = new_window
458 .iter()
459 .map(|x| (x - new_mean).powi(2))
460 .sum::<f64>()
461 / new_window.len() as f64;
462
463 let old_std = old_var.sqrt();
464 let _new_std = new_var.sqrt();
465
466 let diff = (new_mean - old_mean).abs();
468 let threshold = self.config.drift_sensitivity * old_std.max(0.01);
469
470 let drift_detected = diff > threshold;
471 let severity = (diff / threshold.max(0.001)).min(1.0);
472
473 let recommendation = if drift_detected {
474 if severity > 0.8 {
475 DriftAction::ResetModel
476 } else if severity > 0.5 {
477 DriftAction::IncreaseLearningRate
478 } else {
479 DriftAction::UseEnsemble
480 }
481 } else {
482 DriftAction::None
483 };
484
485 if drift_detected {
486 let mut stats = self.stats.write().await;
487 stats.drift_events += 1;
488 }
489
490 Ok(DriftDetection {
491 drift_detected,
492 severity,
493 method: "Page-Hinkley".to_string(),
494 detected_at: SystemTime::now(),
495 recommendation,
496 })
497 }
498
499 pub async fn checkpoint(&self) -> Result<String, StreamError> {
501 let weights = self.weights.read().await.clone();
502 let bias = *self.bias.read().await;
503 let version = *self.version.read().await;
504 let metrics = self.metrics.read().await.clone();
505 let samples_seen = *self.samples_seen.read().await;
506
507 let checkpoint_id = format!("ckpt_{}_{}", version, uuid::Uuid::new_v4());
508
509 let checkpoint = ModelCheckpoint {
510 checkpoint_id: checkpoint_id.clone(),
511 version,
512 created_at: SystemTime::now(),
513 weights,
514 bias,
515 metrics,
516 samples_seen,
517 };
518
519 let mut checkpoints = self.checkpoints.write().await;
520 checkpoints.push_back(checkpoint);
521
522 while checkpoints.len() > self.config.max_model_history {
524 checkpoints.pop_front();
525 }
526
527 let mut stats = self.stats.write().await;
529 stats.checkpoint_count = checkpoints.len();
530
531 Ok(checkpoint_id)
532 }
533
    /// Restores model state (weights, bias, version, metrics, sample count)
    /// from a stored checkpoint.
    ///
    /// # Errors
    /// Returns `StreamError::NotFound` when no checkpoint has the given id.
    pub async fn restore(&self, checkpoint_id: &str) -> Result<(), StreamError> {
        let checkpoints = self.checkpoints.read().await;

        let checkpoint = checkpoints
            .iter()
            .find(|c| c.checkpoint_id == checkpoint_id)
            .ok_or_else(|| {
                StreamError::NotFound(format!("Checkpoint not found: {}", checkpoint_id))
            })?
            .clone();

        // Release the checkpoint read lock before taking the write locks below.
        drop(checkpoints);

        // `weights` first, matching the other write paths' lock order.
        let mut weights = self.weights.write().await;
        let mut bias = self.bias.write().await;
        let mut version = self.version.write().await;
        let mut metrics = self.metrics.write().await;
        let mut samples_seen = self.samples_seen.write().await;

        *weights = checkpoint.weights;
        *bias = checkpoint.bias;
        *version = checkpoint.version;
        *metrics = checkpoint.metrics;
        *samples_seen = checkpoint.samples_seen;

        Ok(())
    }
563
564 pub async fn start_ab_test(&self, config: ABTestConfig) -> Result<(), StreamError> {
566 if !self.config.enable_ab_testing {
567 return Err(StreamError::Configuration(
568 "A/B testing is not enabled".to_string(),
569 ));
570 }
571
572 let weights = self.weights.read().await.clone();
574 let bias = *self.bias.read().await;
575
576 *self.treatment_weights.write().await = Some(weights);
577 *self.treatment_bias.write().await = Some(bias);
578 *self.ab_test.write().await = Some(config);
579
580 Ok(())
581 }
582
    /// Ends the active A/B test (if any) and reports a result summary;
    /// returns `Ok(None)` when no test was running.
    ///
    /// NOTE(review): the treatment metrics are currently a clone of the
    /// control metrics and `is_significant` / `p_value` are hard-coded
    /// placeholders, so `improvement` is always 0 and `winner` is always
    /// `None`. Per-arm metric tracking and a real significance test are
    /// still TODO.
    pub async fn stop_ab_test(&self) -> Result<Option<ABTestResult>, StreamError> {
        let ab_test = self.ab_test.write().await.take();

        if let Some(config) = ab_test {
            let control_metrics = self.metrics.read().await.clone();

            // Placeholder: no separate treatment metrics are tracked yet.
            let treatment_metrics = control_metrics.clone();

            let is_significant = true;
            let p_value = 0.05;
            let improvement = (treatment_metrics.accuracy - control_metrics.accuracy)
                / control_metrics.accuracy.max(0.001)
                * 100.0;

            let winner = if improvement > 0.0 {
                Some("treatment".to_string())
            } else if improvement < 0.0 {
                Some("control".to_string())
            } else {
                None
            };

            Ok(Some(ABTestResult {
                config,
                control_metrics,
                treatment_metrics,
                is_significant,
                p_value,
                winner,
                improvement,
            }))
        } else {
            Ok(None)
        }
    }
621
    /// Returns a clone of the current weight vector.
    pub async fn get_weights(&self) -> Vec<f64> {
        self.weights.read().await.clone()
    }

    /// Returns a clone of the current running metrics.
    pub async fn get_metrics(&self) -> ModelMetrics {
        self.metrics.read().await.clone()
    }

    /// Returns a clone of the aggregate learning statistics.
    pub async fn get_stats(&self) -> OnlineLearningStats {
        self.stats.read().await.clone()
    }

    /// Returns clones of all retained checkpoints, oldest first.
    pub async fn get_checkpoints(&self) -> Vec<ModelCheckpoint> {
        self.checkpoints.read().await.iter().cloned().collect()
    }
641
642 pub async fn reset(&self) {
644 let mut weights = self.weights.write().await;
645 let mut bias = self.bias.write().await;
646 let mut version = self.version.write().await;
647 let mut samples_seen = self.samples_seen.write().await;
648 let mut metrics = self.metrics.write().await;
649 let mut error_history = self.error_history.write().await;
650
651 for w in weights.iter_mut() {
652 *w = 0.0;
653 }
654 *bias = 0.0;
655 *version += 1;
656 *samples_seen = 0;
657 *metrics = ModelMetrics::default();
658 error_history.clear();
659 }
660
661 async fn update_batch(&self, batch: Vec<Sample>) -> Result<(), StreamError> {
664 let mut weights = self.weights.write().await;
665 let mut bias = self.bias.write().await;
666 let mut samples_seen = self.samples_seen.write().await;
667 let mut error_history = self.error_history.write().await;
668 let mut metrics = self.metrics.write().await;
669 let mut stats = self.stats.write().await;
670
671 let lr = self.config.learning_rate;
672 let reg = self.config.regularization;
673
674 let mut total_error = 0.0;
675 let mut correct = 0;
676
677 for sample in &batch {
678 let mut pred = *bias;
680 for (i, &w) in weights.iter().enumerate() {
681 if i < sample.features.len() {
682 pred += w * sample.features[i];
683 }
684 }
685
686 let activated = match self.model_type {
688 ModelType::LogisticRegression => 1.0 / (1.0 + (-pred).exp()),
689 _ => pred,
690 };
691
692 let error = sample.target - activated;
694 total_error += error.powi(2);
695
696 if matches!(
698 self.model_type,
699 ModelType::LogisticRegression
700 | ModelType::Perceptron
701 | ModelType::PassiveAggressive
702 ) {
703 let predicted_class = if activated >= 0.5 { 1.0 } else { 0.0 };
704 if (predicted_class - sample.target).abs() < 0.5 {
705 correct += 1;
706 }
707 }
708
709 match self.model_type {
711 ModelType::LinearRegression | ModelType::OnlineGradientDescent => {
712 for (i, w) in weights.iter_mut().enumerate() {
713 if i < sample.features.len() {
714 *w += lr * sample.weight * error * sample.features[i] - reg * *w;
715 }
716 }
717 *bias += lr * sample.weight * error;
718 }
719 ModelType::LogisticRegression => {
720 let gradient = activated * (1.0 - activated);
721 for (i, w) in weights.iter_mut().enumerate() {
722 if i < sample.features.len() {
723 *w += lr * sample.weight * error * gradient * sample.features[i]
724 - reg * *w;
725 }
726 }
727 *bias += lr * sample.weight * error * gradient;
728 }
729 ModelType::Perceptron => {
730 if error.abs() > 0.0 {
731 for (i, w) in weights.iter_mut().enumerate() {
732 if i < sample.features.len() {
733 *w += lr * sample.weight * error.signum() * sample.features[i];
734 }
735 }
736 *bias += lr * sample.weight * error.signum();
737 }
738 }
739 ModelType::PassiveAggressive => {
740 let loss = 1.0 - sample.target * pred;
741 if loss > 0.0 {
742 let norm_sq: f64 = sample.features.iter().map(|x| x * x).sum();
743 let tau = loss / (norm_sq + 1e-8);
744 for (i, w) in weights.iter_mut().enumerate() {
745 if i < sample.features.len() {
746 *w += tau * sample.target * sample.features[i];
747 }
748 }
749 *bias += tau * sample.target;
750 }
751 }
752 _ => {
753 for (i, w) in weights.iter_mut().enumerate() {
755 if i < sample.features.len() {
756 *w += lr * sample.weight * error * sample.features[i] - reg * *w;
757 }
758 }
759 *bias += lr * sample.weight * error;
760 }
761 }
762
763 *samples_seen += 1;
764
765 error_history.push_back(error.abs());
767 if error_history.len() > 1000 {
768 error_history.pop_front();
769 }
770 }
771
772 let batch_len = batch.len() as f64;
774 let mse = total_error / batch_len;
775
776 metrics.mse = 0.9 * metrics.mse + 0.1 * mse;
777 metrics.mae = 0.9 * metrics.mae + 0.1 * (total_error.sqrt() / batch_len);
778 metrics.sample_count += batch.len() as u64;
779
780 if matches!(
781 self.model_type,
782 ModelType::LogisticRegression | ModelType::Perceptron | ModelType::PassiveAggressive
783 ) {
784 let batch_accuracy = correct as f64 / batch_len;
785 metrics.accuracy = 0.9 * metrics.accuracy + 0.1 * batch_accuracy;
786 }
787
788 stats.total_samples += batch.len() as u64;
790 stats.current_metrics = metrics.clone();
791
792 if self.config.detect_drift && *samples_seen % 100 == 0 {
794 drop(weights);
795 drop(bias);
796 drop(samples_seen);
797 drop(error_history);
798 drop(metrics);
799 drop(stats);
800
801 let drift = self.detect_drift().await?;
802 if drift.drift_detected {
803 match drift.recommendation {
804 DriftAction::IncreaseLearningRate => {
805 }
807 DriftAction::ResetModel => {
808 self.reset().await;
809 }
810 _ => {}
811 }
812 }
813 }
814
815 Ok(())
816 }
817
818 async fn predict_with_params(
819 &self,
820 features: &[f64],
821 weights: &[f64],
822 bias: f64,
823 ) -> Result<Prediction, StreamError> {
824 let start = Instant::now();
825 let version = *self.version.read().await;
826
827 let mut raw_value = bias;
828 for (i, &w) in weights.iter().enumerate() {
829 if i < features.len() {
830 raw_value += w * features[i];
831 }
832 }
833
834 let value = match self.model_type {
835 ModelType::LogisticRegression => {
836 let sigmoid = 1.0 / (1.0 + (-raw_value).exp());
837 if sigmoid >= 0.5 {
838 1.0
839 } else {
840 0.0
841 }
842 }
843 ModelType::Perceptron | ModelType::PassiveAggressive => {
844 if raw_value >= 0.0 {
845 1.0
846 } else {
847 0.0
848 }
849 }
850 _ => raw_value,
851 };
852
853 let latency = start.elapsed().as_secs_f64() * 1000.0;
854
855 Ok(Prediction {
856 value,
857 confidence: 1.0,
858 probabilities: None,
859 latency_ms: latency,
860 model_version: version,
861 })
862 }
863
864 async fn record_prediction_latency(&self, latency: f64) {
865 let mut latencies = self.prediction_latencies.write().await;
866 latencies.push_back(latency);
867
868 if latencies.len() > 1000 {
869 latencies.pop_front();
870 }
871
872 let mut stats = self.stats.write().await;
873 stats.total_predictions += 1;
874
875 if !latencies.is_empty() {
876 stats.avg_prediction_latency_ms =
877 latencies.iter().sum::<f64>() / latencies.len() as f64;
878 }
879 }
880
881 async fn record_training_latency(&self, latency: f64) {
882 let mut latencies = self.training_latencies.write().await;
883 latencies.push_back(latency);
884
885 if latencies.len() > 1000 {
886 latencies.pop_front();
887 }
888
889 let mut stats = self.stats.write().await;
890 if !latencies.is_empty() {
891 stats.avg_training_latency_ms = latencies.iter().sum::<f64>() / latencies.len() as f64;
892 }
893 }
894
895 async fn maybe_checkpoint(&self) -> Result<(), StreamError> {
896 let last = *self.last_checkpoint.read().await;
897
898 if last.elapsed() >= self.config.checkpoint_interval {
899 self.checkpoint().await?;
900
901 let mut last_checkpoint = self.last_checkpoint.write().await;
902 *last_checkpoint = Instant::now();
903 }
904
905 Ok(())
906 }
907}
908
/// Normalizes raw feature vectors using per-feature running statistics
/// maintained with Welford's online algorithm.
pub struct StreamFeatureExtractor {
    // Human-readable name of each tracked feature dimension.
    feature_names: Vec<String>,
    // Running per-feature mean.
    running_mean: Arc<RwLock<Vec<f64>>>,
    // NOTE(review): despite the name, this holds the Welford M2 accumulator
    // (sum of squared deviations, seeded at 1.0), not the variance itself —
    // see `update_stats`, which never divides by the count.
    running_var: Arc<RwLock<Vec<f64>>>,
    // Number of samples folded into the statistics.
    sample_count: Arc<RwLock<u64>>,
}
920
921impl StreamFeatureExtractor {
    /// Creates an extractor tracking one statistic slot per feature name.
    pub fn new(feature_names: Vec<String>) -> Self {
        let dim = feature_names.len();
        Self {
            feature_names,
            running_mean: Arc::new(RwLock::new(vec![0.0; dim])),
            // Seeded at 1.0 so normalization is defined before any samples arrive.
            running_var: Arc::new(RwLock::new(vec![1.0; dim])),
            sample_count: Arc::new(RwLock::new(0)),
        }
    }
932
933 pub async fn extract(&self, raw_features: &[f64]) -> Vec<f64> {
935 let mean = self.running_mean.read().await;
936 let var = self.running_var.read().await;
937
938 raw_features
939 .iter()
940 .enumerate()
941 .map(|(i, &x)| {
942 if i < mean.len() {
943 (x - mean[i]) / var[i].sqrt().max(1e-8)
944 } else {
945 x
946 }
947 })
948 .collect()
949 }
950
    /// Folds one observation into the running mean and M2 accumulator using
    /// Welford's online update. Dimensions beyond the tracked feature set are
    /// ignored.
    pub async fn update_stats(&self, features: &[f64]) {
        let mut mean = self.running_mean.write().await;
        let mut var = self.running_var.write().await;
        let mut count = self.sample_count.write().await;

        *count += 1;
        let n = *count as f64;

        for (i, &x) in features.iter().enumerate() {
            if i < mean.len() {
                // Welford: mean += delta/n; M2 += delta * (x - new_mean).
                let delta = x - mean[i];
                mean[i] += delta / n;
                var[i] += delta * (x - mean[i]);
            }
        }
    }
968
    /// Returns the tracked feature names, in slot order.
    pub fn get_feature_names(&self) -> &[String] {
        &self.feature_names
    }
973}
974
#[cfg(test)]
mod tests {
    use super::*;

    // Constant-target regression: after 100 single-sample updates the
    // prediction should at least be finite (convergence is not asserted).
    #[tokio::test]
    async fn test_linear_regression() {
        let config = OnlineLearningConfig {
            learning_rate: 0.1,
            batch_size: 1,
            ..Default::default()
        };

        let model = OnlineLearningModel::new(ModelType::LinearRegression, 2, config);

        for _ in 0..100 {
            let sample = Sample {
                features: vec![1.0, 1.0],
                target: 5.0,
                weight: 1.0,
                timestamp: SystemTime::now(),
            };
            model.partial_fit(sample).await.unwrap();
        }

        let pred = model.predict(&[1.0, 1.0]).await.unwrap();
        assert!(pred.value.is_finite());
    }

    // Linearly separable two-class problem; checks class outputs stay in [0, 1].
    #[tokio::test]
    async fn test_logistic_regression() {
        let config = OnlineLearningConfig {
            learning_rate: 0.5,
            batch_size: 1,
            ..Default::default()
        };

        let model = OnlineLearningModel::new(ModelType::LogisticRegression, 2, config);

        for _ in 0..50 {
            model
                .partial_fit(Sample {
                    features: vec![1.0, 1.0],
                    target: 1.0,
                    weight: 1.0,
                    timestamp: SystemTime::now(),
                })
                .await
                .unwrap();

            model
                .partial_fit(Sample {
                    features: vec![-1.0, -1.0],
                    target: 0.0,
                    weight: 1.0,
                    timestamp: SystemTime::now(),
                })
                .await
                .unwrap();
        }

        let pred_pos = model.predict(&[1.0, 1.0]).await.unwrap();
        let pred_neg = model.predict(&[-1.0, -1.0]).await.unwrap();

        assert!(
            pred_pos.value >= 0.0 && pred_pos.value <= 1.0,
            "Positive prediction out of range"
        );
        assert!(
            pred_neg.value >= 0.0 && pred_neg.value <= 1.0,
            "Negative prediction out of range"
        );
        assert!(pred_pos.value.is_finite() && pred_neg.value.is_finite());
    }

    // Batch API should account for every sample in the stats.
    #[tokio::test]
    async fn test_batch_training() {
        let config = OnlineLearningConfig {
            learning_rate: 0.1,
            batch_size: 10,
            ..Default::default()
        };

        let model = OnlineLearningModel::new(ModelType::LinearRegression, 2, config);

        let batch: Vec<Sample> = (0..20)
            .map(|i| Sample {
                features: vec![i as f64, i as f64 * 2.0],
                target: i as f64 * 3.0,
                weight: 1.0,
                timestamp: SystemTime::now(),
            })
            .collect();

        model.partial_fit_batch(batch).await.unwrap();

        let stats = model.get_stats().await;
        assert!(stats.total_samples >= 20);
    }

    // Restoring a checkpoint must bring the weights back exactly.
    #[tokio::test]
    async fn test_checkpoint_and_restore() {
        let config = OnlineLearningConfig::default();
        let model = OnlineLearningModel::new(ModelType::LinearRegression, 2, config);

        for _ in 0..10 {
            model
                .partial_fit(Sample {
                    features: vec![1.0, 2.0],
                    target: 3.0,
                    weight: 1.0,
                    timestamp: SystemTime::now(),
                })
                .await
                .unwrap();
        }

        let checkpoint_id = model.checkpoint().await.unwrap();
        let weights_before = model.get_weights().await;

        // Train further on different data so the weights diverge.
        for _ in 0..10 {
            model
                .partial_fit(Sample {
                    features: vec![5.0, 6.0],
                    target: 11.0,
                    weight: 1.0,
                    timestamp: SystemTime::now(),
                })
                .await
                .unwrap();
        }

        model.restore(&checkpoint_id).await.unwrap();
        let weights_after = model.get_weights().await;

        assert_eq!(weights_before, weights_after);
    }

    // Seeds the private error history directly (same crate, so fields are
    // accessible): an old half at 0.1 and a new half at 0.5 must trip drift.
    #[tokio::test]
    async fn test_drift_detection() {
        let config = OnlineLearningConfig {
            detect_drift: true,
            drift_sensitivity: 0.01,
            ..Default::default()
        };

        let model = OnlineLearningModel::new(ModelType::LinearRegression, 2, config);

        {
            let mut history = model.error_history.write().await;
            for _ in 0..500 {
                history.push_back(0.1);
            }
        }

        {
            let mut history = model.error_history.write().await;
            for _ in 0..500 {
                history.push_back(0.5);
            }
        }

        let drift = model.detect_drift().await.unwrap();
        assert!(drift.drift_detected);
    }

    // Separable data: the trained perceptron must classify a positive point as 1.
    #[tokio::test]
    async fn test_perceptron() {
        let config = OnlineLearningConfig {
            learning_rate: 1.0,
            batch_size: 1,
            ..Default::default()
        };

        let model = OnlineLearningModel::new(ModelType::Perceptron, 2, config);

        for _ in 0..100 {
            model
                .partial_fit(Sample {
                    features: vec![1.0, 1.0],
                    target: 1.0,
                    weight: 1.0,
                    timestamp: SystemTime::now(),
                })
                .await
                .unwrap();

            model
                .partial_fit(Sample {
                    features: vec![-1.0, -1.0],
                    target: 0.0,
                    weight: 1.0,
                    timestamp: SystemTime::now(),
                })
                .await
                .unwrap();
        }

        let pred = model.predict(&[1.0, 1.0]).await.unwrap();
        assert_eq!(pred.value, 1.0);
    }

    // Smoke test: normalization preserves dimensionality.
    #[tokio::test]
    async fn test_feature_extractor() {
        let extractor = StreamFeatureExtractor::new(vec!["f1".to_string(), "f2".to_string()]);

        for i in 0..100 {
            let features = vec![i as f64, (i * 2) as f64];
            extractor.update_stats(&features).await;
        }

        let normalized = extractor.extract(&[50.0, 100.0]).await;
        assert_eq!(normalized.len(), 2);
    }

    // After reset(), every weight must be back at zero.
    #[tokio::test]
    async fn test_model_reset() {
        let config = OnlineLearningConfig::default();
        let model = OnlineLearningModel::new(ModelType::LinearRegression, 2, config);

        model
            .partial_fit(Sample {
                features: vec![1.0, 2.0],
                target: 3.0,
                weight: 1.0,
                timestamp: SystemTime::now(),
            })
            .await
            .unwrap();

        model.reset().await;

        let weights = model.get_weights().await;
        assert!(weights.iter().all(|&w| w == 0.0));
    }

    // Metrics should count every trained sample.
    #[tokio::test]
    async fn test_metrics_tracking() {
        let config = OnlineLearningConfig {
            batch_size: 1,
            ..Default::default()
        };

        let model = OnlineLearningModel::new(ModelType::LinearRegression, 2, config);

        for _ in 0..10 {
            model
                .partial_fit(Sample {
                    features: vec![1.0, 1.0],
                    target: 2.0,
                    weight: 1.0,
                    timestamp: SystemTime::now(),
                })
                .await
                .unwrap();
        }

        let metrics = model.get_metrics().await;
        assert!(metrics.sample_count >= 10);
    }

    // PA training on orthogonal +1/-1 examples; the positive point must score
    // non-negatively afterwards. Note this test uses {-1,+1} targets, matching
    // the PA update rule.
    #[tokio::test]
    async fn test_passive_aggressive() {
        let config = OnlineLearningConfig {
            batch_size: 1,
            ..Default::default()
        };

        let model = OnlineLearningModel::new(ModelType::PassiveAggressive, 2, config);

        for _ in 0..50 {
            model
                .partial_fit(Sample {
                    features: vec![1.0, 0.0],
                    target: 1.0,
                    weight: 1.0,
                    timestamp: SystemTime::now(),
                })
                .await
                .unwrap();

            model
                .partial_fit(Sample {
                    features: vec![0.0, 1.0],
                    target: -1.0,
                    weight: 1.0,
                    timestamp: SystemTime::now(),
                })
                .await
                .unwrap();
        }

        let pred = model.predict(&[1.0, 0.0]).await.unwrap();
        assert!(pred.value >= 0.0);
    }

    // Start/stop lifecycle: a result summary must be produced after stopping.
    #[tokio::test]
    async fn test_ab_testing() {
        let config = OnlineLearningConfig {
            enable_ab_testing: true,
            ..Default::default()
        };

        let model = OnlineLearningModel::new(ModelType::LinearRegression, 2, config);

        let ab_config = ABTestConfig {
            name: "test".to_string(),
            control_version: 0,
            treatment_version: 1,
            traffic_split: 0.5,
            min_samples: 100,
            significance_level: 0.05,
        };

        model.start_ab_test(ab_config).await.unwrap();

        for _ in 0..10 {
            model.predict_ab(&[1.0, 1.0]).await.unwrap();
        }

        let result = model.stop_ab_test().await.unwrap();
        assert!(result.is_some());
    }
}