use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use std::time::{Duration, Instant, SystemTime};
use tokio::sync::RwLock;
use scirs2_core::random::RngExt;
use crate::error::StreamError;
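/// Hyperparameters and operational settings for online (incremental)
/// learning: SGD step sizes, drift-detection sensitivity, checkpoint
/// cadence, and A/B-testing switches.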
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OnlineLearningConfig {
pub learning_rate: f64,
pub regularization: f64,
pub batch_size: usize,
pub detect_drift: bool,
pub drift_sensitivity: f64,
pub checkpoint_interval: Duration,
pub max_model_history: usize,
pub enable_ab_testing: bool,
pub validation_split: f64,
}
impl Default for OnlineLearningConfig {
fn default() -> Self {
Self {
learning_rate: 0.01,
regularization: 0.001,
batch_size: 32,
detect_drift: true,
drift_sensitivity: 0.05,
checkpoint_interval: Duration::from_secs(300),
max_model_history: 10,
enable_ab_testing: false,
validation_split: 0.2,
}
}
}
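/// Supported incremental learners. The linear family (regression,
/// logistic, perceptron, passive-aggressive) has dedicated update rules
/// in `update_batch`; the remaining variants currently fall back to the
/// plain linear SGD update.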
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ModelType {
LinearRegression,
LogisticRegression,
Perceptron,
PassiveAggressive,
OnlineGradientDescent,
HoeffdingTree,
NaiveBayes,
ApproximateKNN,
OnlineRandomForest,
}
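/// A single labeled training example with an importance weight and
/// arrival timestamp.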
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Sample {
pub features: Vec<f64>,
pub target: f64,
pub weight: f64,
pub timestamp: SystemTime,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Prediction {
pub value: f64,
pub confidence: f64,
pub probabilities: Option<HashMap<i64, f64>>,
pub latency_ms: f64,
pub model_version: u64,
}
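/// An immutable snapshot of model parameters and metrics that can be
/// rolled back to via `restore`.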
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelCheckpoint {
pub checkpoint_id: String,
pub version: u64,
pub created_at: SystemTime,
pub weights: Vec<f64>,
pub bias: f64,
pub metrics: ModelMetrics,
pub samples_seen: u64,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ModelMetrics {
pub mse: f64,
pub mae: f64,
pub r_squared: f64,
pub accuracy: f64,
pub precision: f64,
pub recall: f64,
pub f1_score: f64,
pub log_loss: f64,
pub sample_count: u64,
pub training_time_ms: f64,
}
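/// Outcome of a drift check: whether drift was flagged, a severity
/// normalized to [0, 1], and a recommended mitigation.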
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DriftDetection {
pub drift_detected: bool,
pub severity: f64,
pub method: String,
pub detected_at: SystemTime,
pub recommendation: DriftAction,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum DriftAction {
None,
IncreaseLearningRate,
ResetModel,
Retrain,
UseEnsemble,
}
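/// Parameters for an A/B test between the current (control) parameters
/// and a treatment copy, including the traffic fraction routed to the
/// treatment arm.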
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ABTestConfig {
pub name: String,
pub control_version: u64,
pub treatment_version: u64,
pub traffic_split: f64,
pub min_samples: usize,
pub significance_level: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ABTestResult {
pub config: ABTestConfig,
pub control_metrics: ModelMetrics,
pub treatment_metrics: ModelMetrics,
pub is_significant: bool,
pub p_value: f64,
pub winner: Option<String>,
pub improvement: f64,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct OnlineLearningStats {
pub total_samples: u64,
pub total_predictions: u64,
pub current_version: u64,
pub checkpoint_count: usize,
pub drift_events: u64,
pub avg_prediction_latency_ms: f64,
pub avg_training_latency_ms: f64,
pub current_metrics: ModelMetrics,
}
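/// An incrementally trained linear model with mini-batch updates, drift
/// detection, checkpoint/rollback, and optional A/B testing. All state
/// sits behind `tokio::sync::RwLock`s so one instance can be shared
/// across tasks.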
pub struct OnlineLearningModel {
config: OnlineLearningConfig,
model_type: ModelType,
weights: Arc<RwLock<Vec<f64>>>,
bias: Arc<RwLock<f64>>,
version: Arc<RwLock<u64>>,
samples_seen: Arc<RwLock<u64>>,
batch_buffer: Arc<RwLock<Vec<Sample>>>,
checkpoints: Arc<RwLock<VecDeque<ModelCheckpoint>>>,
metrics: Arc<RwLock<ModelMetrics>>,
error_history: Arc<RwLock<VecDeque<f64>>>,
stats: Arc<RwLock<OnlineLearningStats>>,
last_checkpoint: Arc<RwLock<Instant>>,
prediction_latencies: Arc<RwLock<VecDeque<f64>>>,
training_latencies: Arc<RwLock<VecDeque<f64>>>,
ab_test: Arc<RwLock<Option<ABTestConfig>>>,
treatment_weights: Arc<RwLock<Option<Vec<f64>>>>,
treatment_bias: Arc<RwLock<Option<f64>>>,
}
impl OnlineLearningModel {
pub fn new(model_type: ModelType, feature_dim: usize, config: OnlineLearningConfig) -> Self {
Self {
config,
model_type,
weights: Arc::new(RwLock::new(vec![0.0; feature_dim])),
bias: Arc::new(RwLock::new(0.0)),
version: Arc::new(RwLock::new(0)),
samples_seen: Arc::new(RwLock::new(0)),
batch_buffer: Arc::new(RwLock::new(Vec::new())),
checkpoints: Arc::new(RwLock::new(VecDeque::new())),
metrics: Arc::new(RwLock::new(ModelMetrics::default())),
error_history: Arc::new(RwLock::new(VecDeque::with_capacity(1000))),
stats: Arc::new(RwLock::new(OnlineLearningStats::default())),
last_checkpoint: Arc::new(RwLock::new(Instant::now())),
prediction_latencies: Arc::new(RwLock::new(VecDeque::with_capacity(1000))),
training_latencies: Arc::new(RwLock::new(VecDeque::with_capacity(1000))),
ab_test: Arc::new(RwLock::new(None)),
treatment_weights: Arc::new(RwLock::new(None)),
treatment_bias: Arc::new(RwLock::new(None)),
}
}
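/// Buffers one sample and runs a mini-batch update once `batch_size`
/// samples have accumulated.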
pub async fn partial_fit(&self, sample: Sample) -> Result<(), StreamError> {
let start = Instant::now();
// Collect a full mini-batch under the lock, then release it before
// the batch update so the buffer is not held across awaits.
let batch = {
let mut buffer = self.batch_buffer.write().await;
buffer.push(sample);
if buffer.len() >= self.config.batch_size {
Some(buffer.drain(..).collect::<Vec<Sample>>())
} else {
None
}
};
if let Some(batch) = batch {
self.update_batch(batch).await?;
}
let latency = start.elapsed().as_secs_f64() * 1000.0;
self.record_training_latency(latency).await;
self.maybe_checkpoint().await?;
Ok(())
}
pub async fn partial_fit_batch(&self, samples: Vec<Sample>) -> Result<(), StreamError> {
let start = Instant::now();
self.update_batch(samples).await?;
let latency = start.elapsed().as_secs_f64() * 1000.0;
self.record_training_latency(latency).await;
self.maybe_checkpoint().await?;
Ok(())
}
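/// Predicts from the current parameters, applying the model-specific
/// output transform (sigmoid for logistic regression, sign threshold
/// for perceptron-style models) and recording the prediction latency.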
pub async fn predict(&self, features: &[f64]) -> Result<Prediction, StreamError> {
let start = Instant::now();
let weights = self.weights.read().await;
let bias = *self.bias.read().await;
let version = *self.version.read().await;
let mut raw_value = bias;
for (i, &w) in weights.iter().enumerate() {
if i < features.len() {
raw_value += w * features[i];
}
}
let (value, confidence, probabilities) = match self.model_type {
ModelType::LinearRegression | ModelType::OnlineGradientDescent => {
(raw_value, 1.0, None)
}
ModelType::LogisticRegression => {
let sigmoid = 1.0 / (1.0 + (-raw_value).exp());
let class = if sigmoid >= 0.5 { 1.0 } else { 0.0 };
let conf = if sigmoid >= 0.5 {
sigmoid
} else {
1.0 - sigmoid
};
let mut probs = HashMap::new();
probs.insert(0, 1.0 - sigmoid);
probs.insert(1, sigmoid);
(class, conf, Some(probs))
}
ModelType::Perceptron | ModelType::PassiveAggressive => {
let class = if raw_value >= 0.0 { 1.0 } else { 0.0 };
// Heuristic confidence from the raw margin, clamped to [0, 1].
let conf = raw_value.abs().min(1.0);
(class, conf, None)
}
_ => (raw_value, 1.0, None),
};
let latency = start.elapsed().as_secs_f64() * 1000.0;
self.record_prediction_latency(latency).await;
Ok(Prediction {
value,
confidence,
probabilities,
latency_ms: latency,
model_version: version,
})
}
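/// Like `predict`, but while an A/B test is active it routes a
/// `traffic_split` fraction of calls to the treatment parameters.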
pub async fn predict_ab(&self, features: &[f64]) -> Result<Prediction, StreamError> {
let ab_test = self.ab_test.read().await;
if let Some(test_config) = ab_test.as_ref() {
let use_treatment =
scirs2_core::random::rng().random::<f64>() < test_config.traffic_split;
if use_treatment {
if let (Some(weights), Some(bias)) = (
self.treatment_weights.read().await.as_ref(),
*self.treatment_bias.read().await,
) {
return self.predict_with_params(features, weights, bias).await;
}
}
}
self.predict(features).await
}
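/// Two-window drift check: compares the mean absolute error of the
/// older and newer halves of the recent error history, flagging drift
/// when the shift exceeds `drift_sensitivity` times the old window's
/// standard deviation.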
pub async fn detect_drift(&self) -> Result<DriftDetection, StreamError> {
let error_history = self.error_history.read().await;
if error_history.len() < 100 {
return Ok(DriftDetection {
drift_detected: false,
severity: 0.0,
method: "Insufficient data".to_string(),
detected_at: SystemTime::now(),
recommendation: DriftAction::None,
});
}
let mid = error_history.len() / 2;
let old_window: Vec<f64> = error_history.iter().take(mid).copied().collect();
let new_window: Vec<f64> = error_history.iter().skip(mid).copied().collect();
let old_mean = old_window.iter().sum::<f64>() / old_window.len() as f64;
let new_mean = new_window.iter().sum::<f64>() / new_window.len() as f64;
let old_var = old_window
.iter()
.map(|x| (x - old_mean).powi(2))
.sum::<f64>()
/ old_window.len() as f64;
// The new window's variance is not used by this test; only the old
// window's spread scales the threshold.
let old_std = old_var.sqrt();
let diff = (new_mean - old_mean).abs();
let threshold = self.config.drift_sensitivity * old_std.max(0.01);
let drift_detected = diff > threshold;
let severity = (diff / threshold.max(0.001)).min(1.0);
let recommendation = if drift_detected {
if severity > 0.8 {
DriftAction::ResetModel
} else if severity > 0.5 {
DriftAction::IncreaseLearningRate
} else {
DriftAction::UseEnsemble
}
} else {
DriftAction::None
};
if drift_detected {
let mut stats = self.stats.write().await;
stats.drift_events += 1;
}
Ok(DriftDetection {
drift_detected,
severity,
method: "Page-Hinkley".to_string(),
detected_at: SystemTime::now(),
recommendation,
})
}
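/// Snapshots the current parameters and metrics into the bounded
/// checkpoint history and returns the checkpoint id. The version
/// counter itself only advances on `reset`.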
pub async fn checkpoint(&self) -> Result<String, StreamError> {
let weights = self.weights.read().await.clone();
let bias = *self.bias.read().await;
let version = *self.version.read().await;
let metrics = self.metrics.read().await.clone();
let samples_seen = *self.samples_seen.read().await;
let checkpoint_id = format!("ckpt_{}_{}", version, uuid::Uuid::new_v4());
let checkpoint = ModelCheckpoint {
checkpoint_id: checkpoint_id.clone(),
version,
created_at: SystemTime::now(),
weights,
bias,
metrics,
samples_seen,
};
let mut checkpoints = self.checkpoints.write().await;
checkpoints.push_back(checkpoint);
while checkpoints.len() > self.config.max_model_history {
checkpoints.pop_front();
}
let mut stats = self.stats.write().await;
stats.checkpoint_count = checkpoints.len();
Ok(checkpoint_id)
}
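/// Restores parameters and metrics from a stored checkpoint, or
/// returns `StreamError::NotFound` for an unknown id.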
pub async fn restore(&self, checkpoint_id: &str) -> Result<(), StreamError> {
let checkpoints = self.checkpoints.read().await;
let checkpoint = checkpoints
.iter()
.find(|c| c.checkpoint_id == checkpoint_id)
.ok_or_else(|| {
StreamError::NotFound(format!("Checkpoint not found: {}", checkpoint_id))
})?
.clone();
drop(checkpoints);
let mut weights = self.weights.write().await;
let mut bias = self.bias.write().await;
let mut version = self.version.write().await;
let mut metrics = self.metrics.write().await;
let mut samples_seen = self.samples_seen.write().await;
*weights = checkpoint.weights;
*bias = checkpoint.bias;
*version = checkpoint.version;
*metrics = checkpoint.metrics;
*samples_seen = checkpoint.samples_seen;
Ok(())
}
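/// Starts an A/B test by cloning the current parameters as the
/// treatment arm. Fails unless `enable_ab_testing` is set.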
pub async fn start_ab_test(&self, config: ABTestConfig) -> Result<(), StreamError> {
if !self.config.enable_ab_testing {
return Err(StreamError::Configuration(
"A/B testing is not enabled".to_string(),
));
}
let weights = self.weights.read().await.clone();
let bias = *self.bias.read().await;
*self.treatment_weights.write().await = Some(weights);
*self.treatment_bias.write().await = Some(bias);
*self.ab_test.write().await = Some(config);
Ok(())
}
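/// Ends the active A/B test and returns a summary, or `None` when no
/// test is running. Treatment outcomes are not yet tracked separately,
/// so the summary mirrors the shared model metrics.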
pub async fn stop_ab_test(&self) -> Result<Option<ABTestResult>, StreamError> {
let ab_test = self.ab_test.write().await.take();
*self.treatment_weights.write().await = None;
*self.treatment_bias.write().await = None;
if let Some(config) = ab_test {
let control_metrics = self.metrics.read().await.clone();
// Treatment outcomes are not tracked separately yet, so this
// comparison is degenerate: both arms report the shared metrics
// and no statistical significance can be claimed.
let treatment_metrics = control_metrics.clone();
let is_significant = false;
let p_value = 1.0;
let improvement = (treatment_metrics.accuracy - control_metrics.accuracy)
/ control_metrics.accuracy.max(0.001)
* 100.0;
let winner = if improvement > 0.0 {
Some("treatment".to_string())
} else if improvement < 0.0 {
Some("control".to_string())
} else {
None
};
Ok(Some(ABTestResult {
config,
control_metrics,
treatment_metrics,
is_significant,
p_value,
winner,
improvement,
}))
} else {
Ok(None)
}
}
pub async fn get_weights(&self) -> Vec<f64> {
self.weights.read().await.clone()
}
pub async fn get_metrics(&self) -> ModelMetrics {
self.metrics.read().await.clone()
}
pub async fn get_stats(&self) -> OnlineLearningStats {
self.stats.read().await.clone()
}
pub async fn get_checkpoints(&self) -> Vec<ModelCheckpoint> {
self.checkpoints.read().await.iter().cloned().collect()
}
pub async fn reset(&self) {
let mut weights = self.weights.write().await;
let mut bias = self.bias.write().await;
let mut version = self.version.write().await;
let mut samples_seen = self.samples_seen.write().await;
let mut metrics = self.metrics.write().await;
let mut error_history = self.error_history.write().await;
for w in weights.iter_mut() {
*w = 0.0;
}
*bias = 0.0;
*version += 1;
*samples_seen = 0;
*metrics = ModelMetrics::default();
error_history.clear();
}
async fn update_batch(&self, batch: Vec<Sample>) -> Result<(), StreamError> {
// Guard against an empty batch so the per-batch averages below
// cannot divide by zero.
if batch.is_empty() {
return Ok(());
}
let mut weights = self.weights.write().await;
let mut bias = self.bias.write().await;
let mut samples_seen = self.samples_seen.write().await;
let mut error_history = self.error_history.write().await;
let mut metrics = self.metrics.write().await;
let mut stats = self.stats.write().await;
let lr = self.config.learning_rate;
let reg = self.config.regularization;
let start_count = *samples_seen;
let mut total_error = 0.0;
let mut total_abs_error = 0.0;
let mut correct = 0;
for sample in &batch {
let mut pred = *bias;
for (i, &w) in weights.iter().enumerate() {
if i < sample.features.len() {
pred += w * sample.features[i];
}
}
let activated = match self.model_type {
ModelType::LogisticRegression => 1.0 / (1.0 + (-pred).exp()),
_ => pred,
};
let error = sample.target - activated;
total_error += error.powi(2);
total_abs_error += error.abs();
match self.model_type {
ModelType::LogisticRegression => {
let predicted_class = if activated >= 0.5 { 1.0 } else { 0.0 };
if (predicted_class - sample.target).abs() < 0.5 {
correct += 1;
}
}
ModelType::Perceptron | ModelType::PassiveAggressive => {
// These models decide at 0.0 on the raw score (as in `predict`);
// map any non-positive target (0.0 or -1.0) to the negative class.
if (pred >= 0.0) == (sample.target > 0.0) {
correct += 1;
}
}
_ => {}
}
match self.model_type {
ModelType::LinearRegression | ModelType::OnlineGradientDescent => {
for (i, w) in weights.iter_mut().enumerate() {
if i < sample.features.len() {
*w += lr * sample.weight * error * sample.features[i] - reg * *w;
}
}
*bias += lr * sample.weight * error;
}
ModelType::LogisticRegression => {
// The cross-entropy gradient for logistic regression reduces to
// (target - sigmoid) * x; no extra sigmoid-derivative factor is
// needed as it would be for squared loss on the sigmoid output.
for (i, w) in weights.iter_mut().enumerate() {
if i < sample.features.len() {
*w += lr * (sample.weight * error * sample.features[i] - reg * *w);
}
}
*bias += lr * sample.weight * error;
}
ModelType::Perceptron => {
// Classic perceptron rule: update only on misclassification,
// using the class error rather than the raw-score error.
let predicted = if pred >= 0.0 { 1.0 } else { 0.0 };
let class_error = sample.target - predicted;
if class_error.abs() > 0.0 {
for (i, w) in weights.iter_mut().enumerate() {
if i < sample.features.len() {
*w += lr * sample.weight * class_error * sample.features[i];
}
}
*bias += lr * sample.weight * class_error;
}
}
ModelType::PassiveAggressive => {
// Passive-Aggressive update with hinge loss; targets are
// expected in {-1.0, +1.0}.
let loss = 1.0 - sample.target * pred;
if loss > 0.0 {
let norm_sq: f64 = sample.features.iter().map(|x| x * x).sum();
let tau = loss / (norm_sq + 1e-8);
for (i, w) in weights.iter_mut().enumerate() {
if i < sample.features.len() {
*w += tau * sample.target * sample.features[i];
}
}
*bias += tau * sample.target;
}
}
_ => {
// Model types without a dedicated rule fall back to linear SGD.
for (i, w) in weights.iter_mut().enumerate() {
if i < sample.features.len() {
*w += lr * (sample.weight * error * sample.features[i] - reg * *w);
}
}
*bias += lr * sample.weight * error;
}
}
*samples_seen += 1;
error_history.push_back(error.abs());
if error_history.len() > 1000 {
error_history.pop_front();
}
}
let batch_len = batch.len() as f64;
let mse = total_error / batch_len;
let mae = total_abs_error / batch_len;
// Exponential moving averages smooth batch-to-batch noise.
metrics.mse = 0.9 * metrics.mse + 0.1 * mse;
metrics.mae = 0.9 * metrics.mae + 0.1 * mae;
metrics.sample_count += batch.len() as u64;
if matches!(
self.model_type,
ModelType::LogisticRegression | ModelType::Perceptron | ModelType::PassiveAggressive
) {
let batch_accuracy = correct as f64 / batch_len;
metrics.accuracy = 0.9 * metrics.accuracy + 0.1 * batch_accuracy;
}
stats.total_samples += batch.len() as u64;
stats.current_metrics = metrics.clone();
// Check for drift whenever the sample count crosses a multiple of
// 100; a plain modulo test can skip multiples entirely when batch
// sizes do not divide 100.
if self.config.detect_drift && *samples_seen / 100 > start_count / 100 {
drop(weights);
drop(bias);
drop(samples_seen);
drop(error_history);
drop(metrics);
drop(stats);
let drift = self.detect_drift().await?;
if drift.drift_detected {
match drift.recommendation {
DriftAction::IncreaseLearningRate => {
// The learning rate lives in the immutable config, so this
// recommendation is surfaced to callers but not acted on here.
}
DriftAction::ResetModel => {
self.reset().await;
}
_ => {}
}
}
}
Ok(())
}
async fn predict_with_params(
&self,
features: &[f64],
weights: &[f64],
bias: f64,
) -> Result<Prediction, StreamError> {
let start = Instant::now();
let version = *self.version.read().await;
let mut raw_value = bias;
for (i, &w) in weights.iter().enumerate() {
if i < features.len() {
raw_value += w * features[i];
}
}
let value = match self.model_type {
ModelType::LogisticRegression => {
let sigmoid = 1.0 / (1.0 + (-raw_value).exp());
if sigmoid >= 0.5 {
1.0
} else {
0.0
}
}
ModelType::Perceptron | ModelType::PassiveAggressive => {
if raw_value >= 0.0 {
1.0
} else {
0.0
}
}
_ => raw_value,
};
let latency = start.elapsed().as_secs_f64() * 1000.0;
Ok(Prediction {
value,
confidence: 1.0,
probabilities: None,
latency_ms: latency,
model_version: version,
})
}
async fn record_prediction_latency(&self, latency: f64) {
let mut latencies = self.prediction_latencies.write().await;
latencies.push_back(latency);
if latencies.len() > 1000 {
latencies.pop_front();
}
let mut stats = self.stats.write().await;
stats.total_predictions += 1;
if !latencies.is_empty() {
stats.avg_prediction_latency_ms =
latencies.iter().sum::<f64>() / latencies.len() as f64;
}
}
async fn record_training_latency(&self, latency: f64) {
let mut latencies = self.training_latencies.write().await;
latencies.push_back(latency);
if latencies.len() > 1000 {
latencies.pop_front();
}
let mut stats = self.stats.write().await;
if !latencies.is_empty() {
stats.avg_training_latency_ms = latencies.iter().sum::<f64>() / latencies.len() as f64;
}
}
async fn maybe_checkpoint(&self) -> Result<(), StreamError> {
// Hold the write lock across the check so concurrent callers cannot
// both decide to checkpoint.
let mut last_checkpoint = self.last_checkpoint.write().await;
if last_checkpoint.elapsed() >= self.config.checkpoint_interval {
self.checkpoint().await?;
*last_checkpoint = Instant::now();
}
Ok(())
}
}
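/// Online feature standardizer: maintains running means and Welford M2
/// accumulators so features can be normalized in a single pass.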
pub struct StreamFeatureExtractor {
feature_names: Vec<String>,
running_mean: Arc<RwLock<Vec<f64>>>,
running_var: Arc<RwLock<Vec<f64>>>, // Welford M2 (sum of squared deviations)
sample_count: Arc<RwLock<u64>>,
}
impl StreamFeatureExtractor {
pub fn new(feature_names: Vec<String>) -> Self {
let dim = feature_names.len();
Self {
feature_names,
running_mean: Arc::new(RwLock::new(vec![0.0; dim])),
running_var: Arc::new(RwLock::new(vec![1.0; dim])), // M2 prior of 1.0 avoids a near-zero std early on
sample_count: Arc::new(RwLock::new(0)),
}
}
pub async fn extract(&self, raw_features: &[f64]) -> Vec<f64> {
let mean = self.running_mean.read().await;
let m2 = self.running_var.read().await;
// `running_var` holds Welford's M2, not the variance itself, so
// divide by the sample count before taking the square root.
let n = (*self.sample_count.read().await as f64).max(1.0);
raw_features
.iter()
.enumerate()
.map(|(i, &x)| {
if i < mean.len() {
let std = (m2[i] / n).sqrt().max(1e-8);
(x - mean[i]) / std
} else {
x
}
})
.collect()
}
pub async fn update_stats(&self, features: &[f64]) {
let mut mean = self.running_mean.write().await;
let mut var = self.running_var.write().await;
let mut count = self.sample_count.write().await;
*count += 1;
let n = *count as f64;
for (i, &x) in features.iter().enumerate() {
if i < mean.len() {
// Welford's online update: move the mean first, then grow the
// M2 accumulator by delta * (x - new_mean).
let delta = x - mean[i];
mean[i] += delta / n;
var[i] += delta * (x - mean[i]);
}
}
}
pub fn get_feature_names(&self) -> &[String] {
&self.feature_names
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_linear_regression() {
let config = OnlineLearningConfig {
learning_rate: 0.1,
batch_size: 1,
..Default::default()
};
let model = OnlineLearningModel::new(ModelType::LinearRegression, 2, config);
for _ in 0..100 {
let sample = Sample {
features: vec![1.0, 1.0],
target: 5.0,
weight: 1.0,
timestamp: SystemTime::now(),
};
model.partial_fit(sample).await.unwrap();
}
let pred = model.predict(&[1.0, 1.0]).await.unwrap();
assert!(pred.value.is_finite());
}
#[tokio::test]
async fn test_logistic_regression() {
let config = OnlineLearningConfig {
learning_rate: 0.5,
batch_size: 1,
..Default::default()
};
let model = OnlineLearningModel::new(ModelType::LogisticRegression, 2, config);
for _ in 0..50 {
model
.partial_fit(Sample {
features: vec![1.0, 1.0],
target: 1.0,
weight: 1.0,
timestamp: SystemTime::now(),
})
.await
.unwrap();
model
.partial_fit(Sample {
features: vec![-1.0, -1.0],
target: 0.0,
weight: 1.0,
timestamp: SystemTime::now(),
})
.await
.unwrap();
}
let pred_pos = model.predict(&[1.0, 1.0]).await.unwrap();
let pred_neg = model.predict(&[-1.0, -1.0]).await.unwrap();
assert!(
pred_pos.value >= 0.0 && pred_pos.value <= 1.0,
"Positive prediction out of range"
);
assert!(
pred_neg.value >= 0.0 && pred_neg.value <= 1.0,
"Negative prediction out of range"
);
assert!(pred_pos.value.is_finite() && pred_neg.value.is_finite());
}
#[tokio::test]
async fn test_batch_training() {
let config = OnlineLearningConfig {
learning_rate: 0.1,
batch_size: 10,
..Default::default()
};
let model = OnlineLearningModel::new(ModelType::LinearRegression, 2, config);
let batch: Vec<Sample> = (0..20)
.map(|i| Sample {
features: vec![i as f64, i as f64 * 2.0],
target: i as f64 * 3.0,
weight: 1.0,
timestamp: SystemTime::now(),
})
.collect();
model.partial_fit_batch(batch).await.unwrap();
let stats = model.get_stats().await;
assert!(stats.total_samples >= 20);
}
#[tokio::test]
async fn test_checkpoint_and_restore() {
let config = OnlineLearningConfig::default();
let model = OnlineLearningModel::new(ModelType::LinearRegression, 2, config);
for _ in 0..10 {
model
.partial_fit(Sample {
features: vec![1.0, 2.0],
target: 3.0,
weight: 1.0,
timestamp: SystemTime::now(),
})
.await
.unwrap();
}
let checkpoint_id = model.checkpoint().await.unwrap();
let weights_before = model.get_weights().await;
for _ in 0..10 {
model
.partial_fit(Sample {
features: vec![5.0, 6.0],
target: 11.0,
weight: 1.0,
timestamp: SystemTime::now(),
})
.await
.unwrap();
}
model.restore(&checkpoint_id).await.unwrap();
let weights_after = model.get_weights().await;
assert_eq!(weights_before, weights_after);
}
#[tokio::test]
async fn test_drift_detection() {
let config = OnlineLearningConfig {
detect_drift: true,
drift_sensitivity: 0.01,
..Default::default()
};
let model = OnlineLearningModel::new(ModelType::LinearRegression, 2, config);
{
let mut history = model.error_history.write().await;
for _ in 0..500 {
history.push_back(0.1);
}
}
{
let mut history = model.error_history.write().await;
for _ in 0..500 {
history.push_back(0.5);
}
}
let drift = model.detect_drift().await.unwrap();
assert!(drift.drift_detected);
}
#[tokio::test]
async fn test_perceptron() {
let config = OnlineLearningConfig {
learning_rate: 1.0,
batch_size: 1,
..Default::default()
};
let model = OnlineLearningModel::new(ModelType::Perceptron, 2, config);
for _ in 0..100 {
model
.partial_fit(Sample {
features: vec![1.0, 1.0],
target: 1.0,
weight: 1.0,
timestamp: SystemTime::now(),
})
.await
.unwrap();
model
.partial_fit(Sample {
features: vec![-1.0, -1.0],
target: 0.0,
weight: 1.0,
timestamp: SystemTime::now(),
})
.await
.unwrap();
}
let pred = model.predict(&[1.0, 1.0]).await.unwrap();
assert_eq!(pred.value, 1.0);
}
#[tokio::test]
async fn test_feature_extractor() {
let extractor = StreamFeatureExtractor::new(vec!["f1".to_string(), "f2".to_string()]);
for i in 0..100 {
let features = vec![i as f64, (i * 2) as f64];
extractor.update_stats(&features).await;
}
let normalized = extractor.extract(&[50.0, 100.0]).await;
assert_eq!(normalized.len(), 2);
}
#[tokio::test]
async fn test_model_reset() {
let config = OnlineLearningConfig::default();
let model = OnlineLearningModel::new(ModelType::LinearRegression, 2, config);
model
.partial_fit(Sample {
features: vec![1.0, 2.0],
target: 3.0,
weight: 1.0,
timestamp: SystemTime::now(),
})
.await
.unwrap();
model.reset().await;
let weights = model.get_weights().await;
assert!(weights.iter().all(|&w| w == 0.0));
}
#[tokio::test]
async fn test_metrics_tracking() {
let config = OnlineLearningConfig {
batch_size: 1,
..Default::default()
};
let model = OnlineLearningModel::new(ModelType::LinearRegression, 2, config);
for _ in 0..10 {
model
.partial_fit(Sample {
features: vec![1.0, 1.0],
target: 2.0,
weight: 1.0,
timestamp: SystemTime::now(),
})
.await
.unwrap();
}
let metrics = model.get_metrics().await;
assert!(metrics.sample_count >= 10);
}
#[tokio::test]
async fn test_passive_aggressive() {
let config = OnlineLearningConfig {
batch_size: 1,
..Default::default()
};
let model = OnlineLearningModel::new(ModelType::PassiveAggressive, 2, config);
for _ in 0..50 {
model
.partial_fit(Sample {
features: vec![1.0, 0.0],
target: 1.0,
weight: 1.0,
timestamp: SystemTime::now(),
})
.await
.unwrap();
model
.partial_fit(Sample {
features: vec![0.0, 1.0],
target: -1.0,
weight: 1.0,
timestamp: SystemTime::now(),
})
.await
.unwrap();
}
let pred = model.predict(&[1.0, 0.0]).await.unwrap();
assert!(pred.value >= 0.0);
}
#[tokio::test]
async fn test_ab_testing() {
let config = OnlineLearningConfig {
enable_ab_testing: true,
..Default::default()
};
let model = OnlineLearningModel::new(ModelType::LinearRegression, 2, config);
let ab_config = ABTestConfig {
name: "test".to_string(),
control_version: 0,
treatment_version: 1,
traffic_split: 0.5,
min_samples: 100,
significance_level: 0.05,
};
model.start_ab_test(ab_config).await.unwrap();
for _ in 0..10 {
model.predict_ab(&[1.0, 1.0]).await.unwrap();
}
let result = model.stop_ab_test().await.unwrap();
assert!(result.is_some());
}
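// A minimal end-to-end sketch (not part of the original suite):
// standardize features with StreamFeatureExtractor, then feed them to
// the model. Feature values and targets here are arbitrary.
#[tokio::test]
async fn test_extractor_model_pipeline() {
let extractor = StreamFeatureExtractor::new(vec!["x1".to_string(), "x2".to_string()]);
let model = OnlineLearningModel::new(
ModelType::LinearRegression,
2,
OnlineLearningConfig {
batch_size: 1,
..Default::default()
},
);
for i in 0..100 {
let raw = vec![i as f64, (i % 10) as f64];
// Update running statistics before normalizing each sample.
extractor.update_stats(&raw).await;
let features = extractor.extract(&raw).await;
model
.partial_fit(Sample {
features,
target: i as f64,
weight: 1.0,
timestamp: SystemTime::now(),
})
.await
.unwrap();
}
let features = extractor.extract(&[50.0, 5.0]).await;
let pred = model.predict(&features).await.unwrap();
assert!(pred.value.is_finite());
}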
}