// converge_knowledge/agentic/online.rs

1//! Online Learning - Continual Adaptation
2//!
3//! Implements online/continual learning mechanisms that allow agents to:
4//!
5//! 1. Learn incrementally from new data without forgetting
6//! 2. Adapt to distribution shifts over time
7//! 3. Use Elastic Weight Consolidation (EWC) to protect important weights
8//! 4. Maintain a sliding window of recent experiences
9//!
10//! Based on continual learning research including EWC, Progressive Networks,
11//! and experience replay strategies.
12
13use chrono::{DateTime, Duration, Utc};
14use serde::{Deserialize, Serialize};
15use std::collections::VecDeque;
16use uuid::Uuid;
17
/// Online learner that adapts continuously to new data.
///
/// Key features:
/// - Incremental updates without full retraining
/// - EWC-style importance weighting
/// - Forgetting prevention via rehearsal
///
/// The model is a simple linear predictor (`y = parameters · features`).
/// Forgetting is mitigated by an EWC quadratic penalty anchored to the
/// snapshots stored in `parameter_history`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OnlineLearner {
    /// Learner identifier.
    pub id: Uuid,

    /// Name/description.
    pub name: String,

    /// Current parameter estimates (feature weights) of the linear model.
    pub parameters: Vec<f32>,

    /// Fisher information (importance of each parameter), maintained as
    /// an exponential moving average of squared gradients (see `update_fisher`).
    pub fisher_diagonal: Vec<f32>,

    /// Historical parameter snapshots for EWC (capped at 10 by `consolidate`).
    pub parameter_history: VecDeque<ParameterSnapshot>,

    /// Learning rate for gradient-descent updates (default 0.01).
    pub learning_rate: f32,

    /// EWC regularization strength (lambda, default 0.5).
    pub ewc_lambda: f32,

    /// Number of updates performed.
    pub update_count: u64,

    /// When this learner was created.
    pub created_at: DateTime<Utc>,

    /// When this learner was last updated.
    pub updated_at: DateTime<Utc>,
}
56
57impl OnlineLearner {
58    /// Create a new online learner.
59    pub fn new(name: impl Into<String>, num_parameters: usize) -> Self {
60        let now = Utc::now();
61        Self {
62            id: Uuid::new_v4(),
63            name: name.into(),
64            parameters: vec![0.0; num_parameters],
65            fisher_diagonal: vec![1.0; num_parameters], // Start with uniform importance
66            parameter_history: VecDeque::with_capacity(10),
67            learning_rate: 0.01,
68            ewc_lambda: 0.5,
69            update_count: 0,
70            created_at: now,
71            updated_at: now,
72        }
73    }
74
75    /// Set learning rate.
76    pub fn with_learning_rate(mut self, lr: f32) -> Self {
77        self.learning_rate = lr;
78        self
79    }
80
81    /// Set EWC lambda.
82    pub fn with_ewc_lambda(mut self, lambda: f32) -> Self {
83        self.ewc_lambda = lambda;
84        self
85    }
86
87    /// Update parameters with new observation.
88    ///
89    /// Uses gradient descent with EWC regularization to prevent forgetting.
90    pub fn update(&mut self, features: &[f32], target: f32) -> f32 {
91        if features.len() != self.parameters.len() {
92            return 0.0;
93        }
94
95        // Forward pass: linear prediction
96        let prediction: f32 = features
97            .iter()
98            .zip(self.parameters.iter())
99            .map(|(f, p)| f * p)
100            .sum();
101
102        // Compute loss
103        let error = prediction - target;
104        let loss = error * error;
105
106        // Compute gradients with EWC penalty
107        for i in 0..self.parameters.len() {
108            // Base gradient (MSE loss)
109            let base_grad = 2.0 * error * features[i];
110
111            // EWC penalty: sum over previous tasks
112            let mut ewc_grad = 0.0;
113            for snapshot in &self.parameter_history {
114                let delta = self.parameters[i] - snapshot.parameters[i];
115                let importance = snapshot.fisher[i];
116                ewc_grad += 2.0 * self.ewc_lambda * importance * delta;
117            }
118
119            // Combined update
120            let total_grad = base_grad + ewc_grad;
121            self.parameters[i] -= self.learning_rate * total_grad;
122        }
123
124        // Update Fisher diagonal based on gradient magnitude
125        self.update_fisher(features, error);
126
127        self.update_count += 1;
128        self.updated_at = Utc::now();
129
130        loss
131    }
132
133    /// Update Fisher information diagonal.
134    fn update_fisher(&mut self, features: &[f32], error: f32) {
135        // Fisher diagonal approximated by squared gradients
136        let decay = 0.99;
137        for i in 0..self.fisher_diagonal.len() {
138            let grad_sq = (2.0 * error * features[i]).powi(2);
139            self.fisher_diagonal[i] = decay * self.fisher_diagonal[i] + (1.0 - decay) * grad_sq;
140        }
141    }
142
143    /// Consolidate current knowledge (take a snapshot for EWC).
144    ///
145    /// Call this when switching to a new task/domain to remember
146    /// the current parameters and their importance.
147    pub fn consolidate(&mut self) {
148        let snapshot = ParameterSnapshot {
149            parameters: self.parameters.clone(),
150            fisher: self.fisher_diagonal.clone(),
151            timestamp: Utc::now(),
152            update_count: self.update_count,
153        };
154
155        self.parameter_history.push_back(snapshot);
156
157        // Keep only recent snapshots
158        while self.parameter_history.len() > 10 {
159            self.parameter_history.pop_front();
160        }
161    }
162
163    /// Make a prediction.
164    pub fn predict(&self, features: &[f32]) -> f32 {
165        if features.len() != self.parameters.len() {
166            return 0.0;
167        }
168
169        features
170            .iter()
171            .zip(self.parameters.iter())
172            .map(|(f, p)| f * p)
173            .sum()
174    }
175
176    /// Get current parameters.
177    pub fn get_parameters(&self) -> &[f32] {
178        &self.parameters
179    }
180
181    /// Get parameter importance.
182    pub fn get_importance(&self) -> &[f32] {
183        &self.fisher_diagonal
184    }
185
186    /// Number of consolidation snapshots.
187    pub fn num_snapshots(&self) -> usize {
188        self.parameter_history.len()
189    }
190}
191
/// A snapshot of parameters for EWC.
///
/// Captured by `OnlineLearner::consolidate` and used afterwards as an
/// anchor point for the EWC penalty during updates.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParameterSnapshot {
    /// Parameter values at snapshot time.
    pub parameters: Vec<f32>,

    /// Fisher diagonal (importance) at snapshot time.
    pub fisher: Vec<f32>,

    /// When this snapshot was taken.
    pub timestamp: DateTime<Utc>,

    /// Number of updates the learner had performed at snapshot time.
    pub update_count: u64,
}
207
/// Sliding window experience buffer for rehearsal.
///
/// Keeps recent experiences for periodic rehearsal to
/// prevent catastrophic forgetting. Entries are evicted oldest-first
/// when the buffer exceeds `capacity`, and pruned once older than `max_age`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExperienceWindow {
    /// Buffer of recent experiences, in insertion (chronological) order.
    experiences: VecDeque<Experience>,

    /// Maximum number of experiences retained.
    capacity: usize,

    /// How old experiences can be before removal (default 24 hours).
    max_age: Duration,
}
223
/// A single experience for rehearsal.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Experience {
    /// Feature vector.
    pub features: Vec<f32>,

    /// Target value.
    pub target: f32,

    /// When this was observed (set to `Utc::now()` on insertion).
    pub timestamp: DateTime<Utc>,

    /// Optional task/domain identifier, used for filtering during rehearsal.
    pub task_id: Option<String>,
}
239
240impl ExperienceWindow {
241    /// Create a new experience window.
242    pub fn new(capacity: usize) -> Self {
243        Self {
244            experiences: VecDeque::with_capacity(capacity),
245            capacity,
246            max_age: Duration::hours(24),
247        }
248    }
249
250    /// Set maximum age for experiences.
251    pub fn with_max_age(mut self, hours: i64) -> Self {
252        self.max_age = Duration::hours(hours);
253        self
254    }
255
256    /// Add an experience.
257    pub fn add(&mut self, features: Vec<f32>, target: f32, task_id: Option<String>) {
258        let exp = Experience {
259            features,
260            target,
261            timestamp: Utc::now(),
262            task_id,
263        };
264
265        self.experiences.push_back(exp);
266
267        // Trim if over capacity
268        while self.experiences.len() > self.capacity {
269            self.experiences.pop_front();
270        }
271
272        // Remove old experiences
273        self.prune_old();
274    }
275
276    /// Get experiences for rehearsal.
277    ///
278    /// Returns a random sample of experiences for replay.
279    pub fn sample(&self, count: usize) -> Vec<&Experience> {
280        if self.experiences.is_empty() || count == 0 {
281            return Vec::new();
282        }
283
284        // Simple reservoir sampling
285        use rand::Rng;
286        let mut rng = rand::thread_rng();
287        let mut result: Vec<&Experience> = Vec::with_capacity(count.min(self.experiences.len()));
288
289        for (i, exp) in self.experiences.iter().enumerate() {
290            if result.len() < count {
291                result.push(exp);
292            } else {
293                let j = rng.gen_range(0..=i);
294                if j < count {
295                    result[j] = exp;
296                }
297            }
298        }
299
300        result
301    }
302
303    /// Get experiences by task.
304    pub fn by_task(&self, task_id: &str) -> Vec<&Experience> {
305        self.experiences
306            .iter()
307            .filter(|e| e.task_id.as_deref() == Some(task_id))
308            .collect()
309    }
310
311    /// Prune old experiences.
312    fn prune_old(&mut self) {
313        let cutoff = Utc::now() - self.max_age;
314        while let Some(front) = self.experiences.front() {
315            if front.timestamp < cutoff {
316                self.experiences.pop_front();
317            } else {
318                break;
319            }
320        }
321    }
322
323    /// Current buffer size.
324    pub fn len(&self) -> usize {
325        self.experiences.len()
326    }
327
328    /// Check if empty.
329    pub fn is_empty(&self) -> bool {
330        self.experiences.is_empty()
331    }
332
333    /// Get capacity.
334    pub fn capacity(&self) -> usize {
335        self.capacity
336    }
337}
338
339impl Default for ExperienceWindow {
340    fn default() -> Self {
341        Self::new(1000)
342    }
343}
344
/// Distribution shift detector.
///
/// Monitors for changes in input distribution that may
/// require adaptation, using per-feature running mean/variance
/// (Welford-style updates) and a normalized distance score.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DriftDetector {
    /// Running mean of features.
    running_mean: Vec<f32>,

    /// Running variance of features (initialized to 1.0 per feature).
    running_var: Vec<f32>,

    /// Count of observations.
    count: u64,

    /// Recent shift scores (bounded at 100 entries).
    shift_scores: VecDeque<f32>,

    /// Threshold for drift detection, in standard-deviation units.
    threshold: f32,
}
366
367impl DriftDetector {
368    /// Create a new drift detector.
369    pub fn new(num_features: usize) -> Self {
370        Self {
371            running_mean: vec![0.0; num_features],
372            running_var: vec![1.0; num_features],
373            count: 0,
374            shift_scores: VecDeque::with_capacity(100),
375            threshold: 2.0, // Standard deviations
376        }
377    }
378
379    /// Set detection threshold.
380    pub fn with_threshold(mut self, threshold: f32) -> Self {
381        self.threshold = threshold;
382        self
383    }
384
385    /// Update statistics and check for drift.
386    ///
387    /// Returns true if drift is detected.
388    pub fn update(&mut self, features: &[f32]) -> bool {
389        if features.len() != self.running_mean.len() {
390            return false;
391        }
392
393        // Compute shift score (Mahalanobis-like distance)
394        let shift_score: f32 = features
395            .iter()
396            .zip(self.running_mean.iter())
397            .zip(self.running_var.iter())
398            .map(|((f, m), v)| ((f - m).powi(2)) / v.max(1e-6))
399            .sum::<f32>()
400            .sqrt()
401            / (features.len() as f32).sqrt();
402
403        self.shift_scores.push_back(shift_score);
404        while self.shift_scores.len() > 100 {
405            self.shift_scores.pop_front();
406        }
407
408        // Update running statistics (Welford's algorithm)
409        self.count += 1;
410        let n = self.count as f32;
411
412        for i in 0..features.len() {
413            let delta = features[i] - self.running_mean[i];
414            self.running_mean[i] += delta / n;
415            let delta2 = features[i] - self.running_mean[i];
416            self.running_var[i] += (delta * delta2 - self.running_var[i]) / n;
417        }
418
419        shift_score > self.threshold
420    }
421
422    /// Get average recent shift score.
423    pub fn average_shift(&self) -> f32 {
424        if self.shift_scores.is_empty() {
425            return 0.0;
426        }
427        self.shift_scores.iter().sum::<f32>() / self.shift_scores.len() as f32
428    }
429
430    /// Check if drift has been detected recently.
431    pub fn is_drifting(&self) -> bool {
432        self.average_shift() > self.threshold
433    }
434
435    /// Reset the detector.
436    pub fn reset(&mut self) {
437        self.running_mean.fill(0.0);
438        self.running_var.fill(1.0);
439        self.count = 0;
440        self.shift_scores.clear();
441    }
442}
443
#[cfg(test)]
mod tests {
    use super::*;

    /// Test: Online learning with incremental updates.
    ///
    /// What happens:
    /// 1. Create an online learner for a simple linear function
    /// 2. Feed examples one at a time
    /// 3. Learner converges to approximate the function
    ///
    /// NOTE(review): inputs are random, so this test is stochastic;
    /// the 0.3 tolerance below is empirical, not a hard guarantee.
    #[test]
    fn test_online_learning() {
        let mut learner = OnlineLearner::new("linear", 2).with_learning_rate(0.1);

        // Train on: y = 2*x1 + 3*x2
        for _ in 0..100 {
            let x1 = rand::random::<f32>();
            let x2 = rand::random::<f32>();
            let y = 2.0 * x1 + 3.0 * x2;

            learner.update(&[x1, x2], y);
        }

        // Check if parameters are close to [2, 3]
        let params = learner.get_parameters();
        assert!(
            (params[0] - 2.0).abs() < 0.3,
            "Expected ~2.0, got {}",
            params[0]
        );
        assert!(
            (params[1] - 3.0).abs() < 0.3,
            "Expected ~3.0, got {}",
            params[1]
        );
    }

    /// Test: EWC prevents catastrophic forgetting.
    ///
    /// What happens:
    /// 1. Train on Task A (y = 2*x)
    /// 2. Consolidate knowledge
    /// 3. Train on Task B (y = -x)
    /// 4. EWC preserves some Task A knowledge
    #[test]
    fn test_ewc_consolidation() {
        let mut learner = OnlineLearner::new("ewc_test", 1)
            .with_learning_rate(0.1)
            .with_ewc_lambda(1.0);

        // Task A: y = 2*x
        for _ in 0..50 {
            let x = rand::random::<f32>();
            let y = 2.0 * x;
            learner.update(&[x], y);
        }

        let task_a_param = learner.parameters[0];

        // Consolidate Task A knowledge
        learner.consolidate();
        assert_eq!(learner.num_snapshots(), 1);

        // Task B: y = -x (conflicting with Task A)
        for _ in 0..50 {
            let x = rand::random::<f32>();
            let y = -1.0 * x;
            learner.update(&[x], y);
        }

        let final_param = learner.parameters[0];

        // With EWC, parameter shouldn't have fully shifted to -1
        // It should be somewhere between 2.0 and -1.0
        assert!(
            final_param > -0.5,
            "EWC should prevent full forgetting: {}",
            final_param
        );
        assert!(
            final_param < task_a_param,
            "Should have adapted to Task B: {}",
            final_param
        );
    }

    /// Test: Experience window for rehearsal.
    ///
    /// What happens:
    /// 1. Add experiences to the buffer
    /// 2. Sample for rehearsal
    /// 3. Buffer respects capacity limits
    #[test]
    fn test_experience_window() {
        let mut window = ExperienceWindow::new(10);

        // Add 15 experiences
        for i in 0..15 {
            window.add(vec![i as f32], i as f32, Some("task1".to_string()));
        }

        // Should be capped at capacity
        assert_eq!(window.len(), 10);

        // Sample should return requested count
        let sample = window.sample(5);
        assert_eq!(sample.len(), 5);

        // Filter by task
        let task1 = window.by_task("task1");
        assert!(!task1.is_empty());
    }

    /// Test: Distribution drift detection.
    ///
    /// What happens:
    /// 1. Establish baseline with normal distribution
    /// 2. Shift to different distribution
    /// 3. Detector identifies the drift
    #[test]
    fn test_drift_detection() {
        let mut detector = DriftDetector::new(2).with_threshold(3.0);

        // Baseline: centered around (0.5, 0.5)
        for _ in 0..200 {
            let x1 = rand::random::<f32>();
            let x2 = rand::random::<f32>();
            detector.update(&[x1, x2]);
        }

        // Reset shift scores to clear baseline
        // (private-field access is allowed here: same module as DriftDetector)
        detector.shift_scores.clear();

        // Now feed normal data - should not drift
        for _ in 0..50 {
            let x1 = rand::random::<f32>();
            let x2 = rand::random::<f32>();
            detector.update(&[x1, x2]);
        }

        // Average shift should be low for normal data
        let baseline_shift = detector.average_shift();

        // Shift: centered around (10, 10) - very different distribution
        let mut _drift_detected = false;
        for _ in 0..20 {
            let x1 = rand::random::<f32>() + 9.5;
            let x2 = rand::random::<f32>() + 9.5;
            if detector.update(&[x1, x2]) {
                _drift_detected = true;
            }
        }

        // After drift, average should be higher
        let drift_shift = detector.average_shift();
        assert!(
            drift_shift > baseline_shift,
            "Drift shift {} should be > baseline {}",
            drift_shift,
            baseline_shift
        );
    }
}