use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
use uuid::Uuid;
/// Linear model trained online by SGD, with an Elastic Weight Consolidation (EWC)
/// penalty that pulls the weights back towards snapshots captured by `consolidate`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OnlineLearner {
/// Random (v4) identifier assigned in `new`.
pub id: Uuid,
/// Label supplied to `new`.
pub name: String,
/// Linear-model weights; a prediction is the dot product of features and weights.
pub parameters: Vec<f32>,
/// Per-parameter importance: an exponentially decayed (factor 0.99) average of squared
/// gradients, initialised to 1.0 and copied into each snapshot by `consolidate`.
pub fisher_diagonal: Vec<f32>,
/// Snapshots taken by `consolidate`, oldest first; at most 10 are kept.
pub parameter_history: VecDeque<ParameterSnapshot>,
/// SGD step size (0.01 by default).
pub learning_rate: f32,
/// Strength of the EWC penalty (0.5 by default).
pub ewc_lambda: f32,
/// Number of `update` calls whose feature length matched the model.
pub update_count: u64,
/// Creation time (UTC).
pub created_at: DateTime<Utc>,
/// Time of the most recent successful `update` (initially equal to `created_at`).
pub updated_at: DateTime<Utc>,
}
impl OnlineLearner {
    /// Maximum number of consolidation snapshots kept in `parameter_history`;
    /// the oldest snapshot is discarded first.
    pub const MAX_SNAPSHOTS: usize = 10;

    /// Creates a learner for a linear model with `num_parameters` weights.
    ///
    /// Weights start at `0.0` and every Fisher-importance entry starts at `1.0`.
    /// The learning rate defaults to `0.01` and the EWC strength to `0.5`.
    pub fn new(name: impl Into<String>, num_parameters: usize) -> Self {
        let now = Utc::now();
        Self {
            id: Uuid::new_v4(),
            name: name.into(),
            parameters: vec![0.0; num_parameters],
            fisher_diagonal: vec![1.0; num_parameters],
            parameter_history: VecDeque::with_capacity(Self::MAX_SNAPSHOTS),
            learning_rate: 0.01,
            ewc_lambda: 0.5,
            update_count: 0,
            created_at: now,
            updated_at: now,
        }
    }

    /// Builder-style setter for the SGD step size.
    pub fn with_learning_rate(mut self, lr: f32) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Builder-style setter for the EWC penalty strength.
    pub fn with_ewc_lambda(mut self, lambda: f32) -> Self {
        self.ewc_lambda = lambda;
        self
    }

    /// Performs one SGD step on the squared error `(prediction - target)^2`.
    ///
    /// The step adds an EWC term that pulls each weight towards every consolidated
    /// snapshot, in proportion to that snapshot's Fisher importance. It then refreshes
    /// the Fisher estimate, bumps `update_count`, and stamps `updated_at`.
    ///
    /// Returns the squared error measured *before* the step. If `features.len()` does
    /// not equal the number of parameters, nothing is modified and `0.0` is returned.
    ///
    /// # Panics
    /// Panics if a snapshot in `parameter_history` holds fewer entries than the
    /// learner has parameters. Snapshots created by `consolidate` always match, so
    /// this can only happen if the public field was edited by hand.
    pub fn update(&mut self, features: &[f32], target: f32) -> f32 {
        if features.len() != self.parameters.len() {
            return 0.0;
        }
        let error = self.predict(features) - target;
        let loss = error * error;

        let learning_rate = self.learning_rate;
        let ewc_lambda = self.ewc_lambda;
        // Disjoint field borrows: history is read while the weights are mutated.
        let history = &self.parameter_history;
        for (i, (param, &feat)) in self.parameters.iter_mut().zip(features).enumerate() {
            let base_grad = 2.0 * error * feat;
            // Gradient of lambda * F_i * (theta_i - theta*_i)^2, summed over snapshots.
            // The deltas use the weight as it was before this step.
            let ewc_grad = history.iter().fold(0.0, |acc, snapshot| {
                let delta = *param - snapshot.parameters[i];
                acc + 2.0 * ewc_lambda * snapshot.fisher[i] * delta
            });
            *param -= learning_rate * (base_grad + ewc_grad);
        }

        self.update_fisher(features, error);
        self.update_count += 1;
        self.updated_at = Utc::now();
        loss
    }

    /// Blends the squared gradient of this sample into the per-parameter Fisher
    /// estimate using an exponential moving average.
    fn update_fisher(&mut self, features: &[f32], error: f32) {
        const FISHER_DECAY: f32 = 0.99;
        for (fisher, &feat) in self.fisher_diagonal.iter_mut().zip(features.iter()) {
            let grad_sq = (2.0 * error * feat).powi(2);
            *fisher = FISHER_DECAY * *fisher + (1.0 - FISHER_DECAY) * grad_sq;
        }
    }

    /// Records the current weights and Fisher importances as a snapshot that later
    /// updates are regularised towards.
    ///
    /// Only the newest [`Self::MAX_SNAPSHOTS`] snapshots are kept.
    pub fn consolidate(&mut self) {
        let snapshot = ParameterSnapshot {
            parameters: self.parameters.clone(),
            fisher: self.fisher_diagonal.clone(),
            timestamp: Utc::now(),
            update_count: self.update_count,
        };
        self.parameter_history.push_back(snapshot);
        while self.parameter_history.len() > Self::MAX_SNAPSHOTS {
            self.parameter_history.pop_front();
        }
    }

    /// Returns the dot product of `features` and the weights.
    ///
    /// Returns `0.0` if the lengths differ.
    pub fn predict(&self, features: &[f32]) -> f32 {
        if features.len() != self.parameters.len() {
            return 0.0;
        }
        features
            .iter()
            .zip(self.parameters.iter())
            .map(|(f, p)| f * p)
            .sum()
    }

    /// Current model weights.
    pub fn get_parameters(&self) -> &[f32] {
        &self.parameters
    }

    /// Current per-parameter Fisher importance estimate.
    pub fn get_importance(&self) -> &[f32] {
        &self.fisher_diagonal
    }

    /// Number of consolidation snapshots currently retained.
    pub fn num_snapshots(&self) -> usize {
        self.parameter_history.len()
    }
}
/// Copy of an `OnlineLearner`'s weights and importances captured by `consolidate`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParameterSnapshot {
/// Weights at consolidation time; `OnlineLearner::update` penalises moving away from them.
pub parameters: Vec<f32>,
/// Per-parameter Fisher importance at consolidation time; scales that penalty.
pub fisher: Vec<f32>,
/// When the snapshot was taken (UTC).
pub timestamp: DateTime<Utc>,
/// The learner's `update_count` when the snapshot was taken.
pub update_count: u64,
}
/// Bounded FIFO buffer of recent [`Experience`]s, evicted by count and by age.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExperienceWindow {
/// Stored experiences, oldest at the front (new ones are pushed to the back).
experiences: VecDeque<Experience>,
/// Maximum number of experiences retained; the oldest are dropped first.
capacity: usize,
/// Experiences older than this are pruned from the front when a new one is added.
max_age: Duration,
}
/// One labelled observation stored in an [`ExperienceWindow`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Experience {
/// Input feature vector.
pub features: Vec<f32>,
/// Target value associated with `features`.
pub target: f32,
/// Time the experience was added to the window (UTC).
pub timestamp: DateTime<Utc>,
/// Optional task label, matched exactly by `ExperienceWindow::by_task`.
pub task_id: Option<String>,
}
impl ExperienceWindow {
    /// Builds an empty window holding up to `capacity` experiences.
    ///
    /// By default each experience is retained for at most 24 hours.
    pub fn new(capacity: usize) -> Self {
        let experiences = VecDeque::with_capacity(capacity);
        let max_age = Duration::hours(24);
        Self {
            experiences,
            capacity,
            max_age,
        }
    }

    /// Overrides the age limit with `hours` hours.
    ///
    /// Age-based pruning only happens inside [`ExperienceWindow::add`], so the new
    /// limit takes effect on the next insertion.
    pub fn with_max_age(mut self, hours: i64) -> Self {
        self.max_age = Duration::hours(hours);
        self
    }

    /// Stores a new experience timestamped with the current UTC time.
    ///
    /// Afterwards the oldest entries are evicted until at most `capacity` remain.
    /// Any leading entries older than the age limit are then dropped.
    pub fn add(&mut self, features: Vec<f32>, target: f32, task_id: Option<String>) {
        self.experiences.push_back(Experience {
            features,
            target,
            timestamp: Utc::now(),
            task_id,
        });
        let overflow = self.experiences.len().saturating_sub(self.capacity);
        for _ in 0..overflow {
            self.experiences.pop_front();
        }
        self.prune_old();
    }

    /// Draws up to `count` distinct experiences in a single pass using reservoir
    /// sampling (Algorithm R).
    ///
    /// Returns every stored experience when `count` is at least the window length.
    /// Returns an empty vector when the window is empty or `count` is zero.
    pub fn sample(&self, count: usize) -> Vec<&Experience> {
        use rand::Rng;
        if count == 0 || self.experiences.is_empty() {
            return Vec::new();
        }
        let mut rng = rand::thread_rng();
        let mut reservoir = Vec::with_capacity(count.min(self.experiences.len()));
        for (seen, candidate) in self.experiences.iter().enumerate() {
            if reservoir.len() < count {
                // Fill phase: the first `count` items go straight in.
                reservoir.push(candidate);
                continue;
            }
            // Replace phase: item `seen` (0-based) survives with probability
            // count / (seen + 1); `get_mut` is `Some` exactly when `slot < count`.
            let slot = rng.gen_range(0..=seen);
            if let Some(kept) = reservoir.get_mut(slot) {
                *kept = candidate;
            }
        }
        reservoir
    }

    /// Returns, oldest first, the experiences whose task label equals `task_id`.
    pub fn by_task(&self, task_id: &str) -> Vec<&Experience> {
        let wanted = Some(task_id);
        self.experiences
            .iter()
            .filter(|exp| exp.task_id.as_deref() == wanted)
            .collect()
    }

    /// Pops experiences off the front while they are older than `max_age`.
    ///
    /// Only the leading run is examined. This assumes timestamps are non-decreasing
    /// from front to back, since each is taken from `Utc::now()` at insertion.
    fn prune_old(&mut self) {
        let cutoff = Utc::now() - self.max_age;
        while matches!(self.experiences.front(), Some(oldest) if oldest.timestamp < cutoff) {
            self.experiences.pop_front();
        }
    }

    /// Number of experiences currently stored.
    pub fn len(&self) -> usize {
        self.experiences.len()
    }

    /// `true` when no experiences are stored.
    pub fn is_empty(&self) -> bool {
        self.experiences.is_empty()
    }

    /// Maximum number of experiences the window retains.
    pub fn capacity(&self) -> usize {
        self.capacity
    }
}
impl Default for ExperienceWindow {
fn default() -> Self {
Self::new(1000)
}
}
/// Flags input drift by scoring each sample against running per-feature mean and variance.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DriftDetector {
/// Running per-feature mean of the observed inputs (starts at 0.0).
running_mean: Vec<f32>,
/// Running per-feature population variance (starts at 1.0).
running_var: Vec<f32>,
/// Number of inputs folded into the running statistics.
count: u64,
/// Most recent shift scores (at most 100), averaged by `average_shift`.
shift_scores: VecDeque<f32>,
/// Score above which a single input, or the recent average, counts as drift.
threshold: f32,
}
impl DriftDetector {
    /// Number of most recent shift scores kept for [`DriftDetector::average_shift`].
    const SCORE_WINDOW: usize = 100;

    /// Minimum number of observations already absorbed before a sample is scored.
    ///
    /// With zero observations the statistics are just the initial placeholders.
    /// With exactly one, the running variance is exactly 0. Scoring against that
    /// would divide by the 1e-6 floor and report enormous spurious drift.
    const MIN_SAMPLES_FOR_SCORING: u64 = 2;

    /// Creates a detector for `num_features`-dimensional inputs with a default drift
    /// threshold of `2.0`.
    pub fn new(num_features: usize) -> Self {
        Self {
            running_mean: vec![0.0; num_features],
            running_var: vec![1.0; num_features],
            count: 0,
            shift_scores: VecDeque::with_capacity(Self::SCORE_WINDOW),
            threshold: 2.0,
        }
    }

    /// Builder-style setter for the drift threshold.
    pub fn with_threshold(mut self, threshold: f32) -> Self {
        self.threshold = threshold;
        self
    }

    /// Scores `features` against the current statistics, then folds them in.
    ///
    /// The score is the root-mean-square per-feature z-score. It is recorded in the
    /// sliding score window, and the method returns `true` when it exceeds the threshold.
    ///
    /// During warm-up (fewer than two prior observations), samples only update the
    /// statistics: no score is recorded and `false` is returned. Empty inputs, and
    /// inputs whose length differs from the configured feature count, are ignored
    /// and return `false`.
    pub fn update(&mut self, features: &[f32]) -> bool {
        if features.is_empty() || features.len() != self.running_mean.len() {
            return false;
        }
        let score =
            (self.count >= Self::MIN_SAMPLES_FOR_SCORING).then(|| self.shift_score(features));
        if let Some(score) = score {
            self.shift_scores.push_back(score);
            while self.shift_scores.len() > Self::SCORE_WINDOW {
                self.shift_scores.pop_front();
            }
        }
        self.observe(features);
        matches!(score, Some(s) if s > self.threshold)
    }

    /// Root-mean-square z-score of `features` against the running statistics.
    ///
    /// The variance is floored at 1e-6 to avoid division by zero.
    fn shift_score(&self, features: &[f32]) -> f32 {
        let sum_sq_z: f32 = features
            .iter()
            .zip(&self.running_mean)
            .zip(&self.running_var)
            .map(|((f, m), v)| ((f - m).powi(2)) / v.max(1e-6))
            .sum();
        sum_sq_z.sqrt() / (features.len() as f32).sqrt()
    }

    /// Folds `features` into the running mean and population variance using a
    /// Welford-style incremental update.
    fn observe(&mut self, features: &[f32]) {
        self.count += 1;
        let n = self.count as f32;
        for ((mean, var), &feat) in self
            .running_mean
            .iter_mut()
            .zip(self.running_var.iter_mut())
            .zip(features.iter())
        {
            let delta = feat - *mean;
            *mean += delta / n;
            let delta2 = feat - *mean;
            *var += (delta * delta2 - *var) / n;
        }
    }

    /// Mean of the recorded shift scores, or `0.0` when none are recorded.
    pub fn average_shift(&self) -> f32 {
        if self.shift_scores.is_empty() {
            return 0.0;
        }
        self.shift_scores.iter().sum::<f32>() / self.shift_scores.len() as f32
    }

    /// `true` when the average recent shift score exceeds the threshold.
    pub fn is_drifting(&self) -> bool {
        self.average_shift() > self.threshold
    }

    /// Restores the initial statistics and discards all recorded scores.
    ///
    /// The threshold is kept. The detector re-enters warm-up after a reset.
    pub fn reset(&mut self) {
        self.running_mean.fill(0.0);
        self.running_var.fill(1.0);
        self.count = 0;
        self.shift_scores.clear();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_online_learning() {
        let mut learner = OnlineLearner::new("linear", 2).with_learning_rate(0.1);
        for _ in 0..100 {
            let x1 = rand::random::<f32>();
            let x2 = rand::random::<f32>();
            let y = 2.0 * x1 + 3.0 * x2;
            learner.update(&[x1, x2], y);
        }
        let params = learner.get_parameters();
        assert!(
            (params[0] - 2.0).abs() < 0.3,
            "Expected ~2.0, got {}",
            params[0]
        );
        assert!(
            (params[1] - 3.0).abs() < 0.3,
            "Expected ~3.0, got {}",
            params[1]
        );
    }

    #[test]
    fn test_ewc_consolidation() {
        let mut learner = OnlineLearner::new("ewc_test", 1)
            .with_learning_rate(0.1)
            .with_ewc_lambda(1.0);
        for _ in 0..50 {
            let x = rand::random::<f32>();
            let y = 2.0 * x;
            learner.update(&[x], y);
        }
        let task_a_param = learner.parameters[0];
        learner.consolidate();
        assert_eq!(learner.num_snapshots(), 1);
        for _ in 0..50 {
            let x = rand::random::<f32>();
            let y = -x;
            learner.update(&[x], y);
        }
        let final_param = learner.parameters[0];
        assert!(
            final_param > -0.5,
            "EWC should prevent full forgetting: {}",
            final_param
        );
        assert!(
            final_param < task_a_param,
            "Should have adapted to Task B: {}",
            final_param
        );
    }

    #[test]
    fn test_experience_window() {
        let mut window = ExperienceWindow::new(10);
        for i in 0..15 {
            window.add(vec![i as f32], i as f32, Some("task1".to_string()));
        }
        // Capacity eviction keeps only the 10 newest experiences (targets 5..=14).
        assert_eq!(window.len(), 10);
        assert_eq!(window.capacity(), 10);

        let sample = window.sample(5);
        assert_eq!(sample.len(), 5);
        // Reservoir sampling never repeats an element and only draws retained ones.
        let mut targets: Vec<f32> = sample.iter().map(|e| e.target).collect();
        targets.sort_by(|a, b| a.partial_cmp(b).unwrap());
        targets.dedup();
        assert_eq!(targets.len(), 5);
        assert!(targets.iter().all(|&t| (5.0..=14.0).contains(&t)));

        assert!(window.sample(0).is_empty());
        assert_eq!(window.sample(100).len(), 10);

        let task1 = window.by_task("task1");
        assert_eq!(task1.len(), 10);
        assert_eq!(task1[0].target, 5.0);
        assert!(window.by_task("other").is_empty());
    }

    #[test]
    fn test_drift_detection() {
        let mut detector = DriftDetector::new(2).with_threshold(3.0);
        for _ in 0..200 {
            let x1 = rand::random::<f32>();
            let x2 = rand::random::<f32>();
            detector.update(&[x1, x2]);
        }
        detector.shift_scores.clear();
        for _ in 0..50 {
            let x1 = rand::random::<f32>();
            let x2 = rand::random::<f32>();
            detector.update(&[x1, x2]);
        }
        let baseline_shift = detector.average_shift();
        let mut drift_detected = false;
        for _ in 0..20 {
            let x1 = rand::random::<f32>() + 9.5;
            let x2 = rand::random::<f32>() + 9.5;
            drift_detected |= detector.update(&[x1, x2]);
        }
        // With ~250 uniform[0,1] samples absorbed (std ~0.29), a +9.5 shift scores
        // far above the 3.0 threshold on the very first shifted sample.
        assert!(
            drift_detected,
            "a 9.5-unit mean shift should exceed the 3.0 threshold"
        );
        let drift_shift = detector.average_shift();
        assert!(
            drift_shift > baseline_shift,
            "Drift shift {} should be > baseline {}",
            drift_shift,
            baseline_shift
        );
    }
}