converge_knowledge/agentic/
online.rs1use chrono::{DateTime, Duration, Utc};
14use serde::{Deserialize, Serialize};
15use std::collections::VecDeque;
16use uuid::Uuid;
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct OnlineLearner {
26 pub id: Uuid,
28
29 pub name: String,
31
32 pub parameters: Vec<f32>,
34
35 pub fisher_diagonal: Vec<f32>,
37
38 pub parameter_history: VecDeque<ParameterSnapshot>,
40
41 pub learning_rate: f32,
43
44 pub ewc_lambda: f32,
46
47 pub update_count: u64,
49
50 pub created_at: DateTime<Utc>,
52
53 pub updated_at: DateTime<Utc>,
55}
56
57impl OnlineLearner {
58 pub fn new(name: impl Into<String>, num_parameters: usize) -> Self {
60 let now = Utc::now();
61 Self {
62 id: Uuid::new_v4(),
63 name: name.into(),
64 parameters: vec![0.0; num_parameters],
65 fisher_diagonal: vec![1.0; num_parameters], parameter_history: VecDeque::with_capacity(10),
67 learning_rate: 0.01,
68 ewc_lambda: 0.5,
69 update_count: 0,
70 created_at: now,
71 updated_at: now,
72 }
73 }
74
75 pub fn with_learning_rate(mut self, lr: f32) -> Self {
77 self.learning_rate = lr;
78 self
79 }
80
81 pub fn with_ewc_lambda(mut self, lambda: f32) -> Self {
83 self.ewc_lambda = lambda;
84 self
85 }
86
87 pub fn update(&mut self, features: &[f32], target: f32) -> f32 {
91 if features.len() != self.parameters.len() {
92 return 0.0;
93 }
94
95 let prediction: f32 = features
97 .iter()
98 .zip(self.parameters.iter())
99 .map(|(f, p)| f * p)
100 .sum();
101
102 let error = prediction - target;
104 let loss = error * error;
105
106 for i in 0..self.parameters.len() {
108 let base_grad = 2.0 * error * features[i];
110
111 let mut ewc_grad = 0.0;
113 for snapshot in &self.parameter_history {
114 let delta = self.parameters[i] - snapshot.parameters[i];
115 let importance = snapshot.fisher[i];
116 ewc_grad += 2.0 * self.ewc_lambda * importance * delta;
117 }
118
119 let total_grad = base_grad + ewc_grad;
121 self.parameters[i] -= self.learning_rate * total_grad;
122 }
123
124 self.update_fisher(features, error);
126
127 self.update_count += 1;
128 self.updated_at = Utc::now();
129
130 loss
131 }
132
133 fn update_fisher(&mut self, features: &[f32], error: f32) {
135 let decay = 0.99;
137 for i in 0..self.fisher_diagonal.len() {
138 let grad_sq = (2.0 * error * features[i]).powi(2);
139 self.fisher_diagonal[i] = decay * self.fisher_diagonal[i] + (1.0 - decay) * grad_sq;
140 }
141 }
142
143 pub fn consolidate(&mut self) {
148 let snapshot = ParameterSnapshot {
149 parameters: self.parameters.clone(),
150 fisher: self.fisher_diagonal.clone(),
151 timestamp: Utc::now(),
152 update_count: self.update_count,
153 };
154
155 self.parameter_history.push_back(snapshot);
156
157 while self.parameter_history.len() > 10 {
159 self.parameter_history.pop_front();
160 }
161 }
162
163 pub fn predict(&self, features: &[f32]) -> f32 {
165 if features.len() != self.parameters.len() {
166 return 0.0;
167 }
168
169 features
170 .iter()
171 .zip(self.parameters.iter())
172 .map(|(f, p)| f * p)
173 .sum()
174 }
175
176 pub fn get_parameters(&self) -> &[f32] {
178 &self.parameters
179 }
180
181 pub fn get_importance(&self) -> &[f32] {
183 &self.fisher_diagonal
184 }
185
186 pub fn num_snapshots(&self) -> usize {
188 self.parameter_history.len()
189 }
190}
191
192#[derive(Debug, Clone, Serialize, Deserialize)]
194pub struct ParameterSnapshot {
195 pub parameters: Vec<f32>,
197
198 pub fisher: Vec<f32>,
200
201 pub timestamp: DateTime<Utc>,
203
204 pub update_count: u64,
206}
207
208#[derive(Debug, Clone, Serialize, Deserialize)]
213pub struct ExperienceWindow {
214 experiences: VecDeque<Experience>,
216
217 capacity: usize,
219
220 max_age: Duration,
222}
223
224#[derive(Debug, Clone, Serialize, Deserialize)]
226pub struct Experience {
227 pub features: Vec<f32>,
229
230 pub target: f32,
232
233 pub timestamp: DateTime<Utc>,
235
236 pub task_id: Option<String>,
238}
239
240impl ExperienceWindow {
241 pub fn new(capacity: usize) -> Self {
243 Self {
244 experiences: VecDeque::with_capacity(capacity),
245 capacity,
246 max_age: Duration::hours(24),
247 }
248 }
249
250 pub fn with_max_age(mut self, hours: i64) -> Self {
252 self.max_age = Duration::hours(hours);
253 self
254 }
255
256 pub fn add(&mut self, features: Vec<f32>, target: f32, task_id: Option<String>) {
258 let exp = Experience {
259 features,
260 target,
261 timestamp: Utc::now(),
262 task_id,
263 };
264
265 self.experiences.push_back(exp);
266
267 while self.experiences.len() > self.capacity {
269 self.experiences.pop_front();
270 }
271
272 self.prune_old();
274 }
275
276 pub fn sample(&self, count: usize) -> Vec<&Experience> {
280 if self.experiences.is_empty() || count == 0 {
281 return Vec::new();
282 }
283
284 use rand::Rng;
286 let mut rng = rand::thread_rng();
287 let mut result: Vec<&Experience> = Vec::with_capacity(count.min(self.experiences.len()));
288
289 for (i, exp) in self.experiences.iter().enumerate() {
290 if result.len() < count {
291 result.push(exp);
292 } else {
293 let j = rng.gen_range(0..=i);
294 if j < count {
295 result[j] = exp;
296 }
297 }
298 }
299
300 result
301 }
302
303 pub fn by_task(&self, task_id: &str) -> Vec<&Experience> {
305 self.experiences
306 .iter()
307 .filter(|e| e.task_id.as_deref() == Some(task_id))
308 .collect()
309 }
310
311 fn prune_old(&mut self) {
313 let cutoff = Utc::now() - self.max_age;
314 while let Some(front) = self.experiences.front() {
315 if front.timestamp < cutoff {
316 self.experiences.pop_front();
317 } else {
318 break;
319 }
320 }
321 }
322
323 pub fn len(&self) -> usize {
325 self.experiences.len()
326 }
327
328 pub fn is_empty(&self) -> bool {
330 self.experiences.is_empty()
331 }
332
333 pub fn capacity(&self) -> usize {
335 self.capacity
336 }
337}
338
339impl Default for ExperienceWindow {
340 fn default() -> Self {
341 Self::new(1000)
342 }
343}
344
345#[derive(Debug, Clone, Serialize, Deserialize)]
350pub struct DriftDetector {
351 running_mean: Vec<f32>,
353
354 running_var: Vec<f32>,
356
357 count: u64,
359
360 shift_scores: VecDeque<f32>,
362
363 threshold: f32,
365}
366
367impl DriftDetector {
368 pub fn new(num_features: usize) -> Self {
370 Self {
371 running_mean: vec![0.0; num_features],
372 running_var: vec![1.0; num_features],
373 count: 0,
374 shift_scores: VecDeque::with_capacity(100),
375 threshold: 2.0, }
377 }
378
379 pub fn with_threshold(mut self, threshold: f32) -> Self {
381 self.threshold = threshold;
382 self
383 }
384
385 pub fn update(&mut self, features: &[f32]) -> bool {
389 if features.len() != self.running_mean.len() {
390 return false;
391 }
392
393 let shift_score: f32 = features
395 .iter()
396 .zip(self.running_mean.iter())
397 .zip(self.running_var.iter())
398 .map(|((f, m), v)| ((f - m).powi(2)) / v.max(1e-6))
399 .sum::<f32>()
400 .sqrt()
401 / (features.len() as f32).sqrt();
402
403 self.shift_scores.push_back(shift_score);
404 while self.shift_scores.len() > 100 {
405 self.shift_scores.pop_front();
406 }
407
408 self.count += 1;
410 let n = self.count as f32;
411
412 for i in 0..features.len() {
413 let delta = features[i] - self.running_mean[i];
414 self.running_mean[i] += delta / n;
415 let delta2 = features[i] - self.running_mean[i];
416 self.running_var[i] += (delta * delta2 - self.running_var[i]) / n;
417 }
418
419 shift_score > self.threshold
420 }
421
422 pub fn average_shift(&self) -> f32 {
424 if self.shift_scores.is_empty() {
425 return 0.0;
426 }
427 self.shift_scores.iter().sum::<f32>() / self.shift_scores.len() as f32
428 }
429
430 pub fn is_drifting(&self) -> bool {
432 self.average_shift() > self.threshold
433 }
434
435 pub fn reset(&mut self) {
437 self.running_mean.fill(0.0);
438 self.running_var.fill(1.0);
439 self.count = 0;
440 self.shift_scores.clear();
441 }
442}
443
#[cfg(test)]
mod tests {
    use super::*;

    /// Fits `y = 2*x1 + 3*x2` from random samples and checks that both weights
    /// land near their true values.
    #[test]
    fn test_online_learning() {
        let mut learner = OnlineLearner::new("linear", 2).with_learning_rate(0.1);

        for _ in 0..100 {
            let x1 = rand::random::<f32>();
            let x2 = rand::random::<f32>();
            let y = 2.0 * x1 + 3.0 * x2;

            learner.update(&[x1, x2], y);
        }

        let params = learner.get_parameters();
        assert!(
            (params[0] - 2.0).abs() < 0.3,
            "Expected ~2.0, got {}",
            params[0]
        );
        assert!(
            (params[1] - 3.0).abs() < 0.3,
            "Expected ~3.0, got {}",
            params[1]
        );
    }

    /// Learns task A (`y = 2x`), consolidates, then trains on the conflicting
    /// task B (`y = -x`). The weight must move toward B without fully reaching
    /// it, because the snapshot anchors it to A.
    #[test]
    fn test_ewc_consolidation() {
        let mut learner = OnlineLearner::new("ewc_test", 1)
            .with_learning_rate(0.1)
            .with_ewc_lambda(1.0);

        for _ in 0..50 {
            let x = rand::random::<f32>();
            let y = 2.0 * x;
            learner.update(&[x], y);
        }

        let task_a_param = learner.parameters[0];

        learner.consolidate();
        assert_eq!(learner.num_snapshots(), 1);

        for _ in 0..50 {
            let x = rand::random::<f32>();
            let y = -1.0 * x;
            learner.update(&[x], y);
        }

        let final_param = learner.parameters[0];

        assert!(
            final_param > -0.5,
            "EWC should prevent full forgetting: {}",
            final_param
        );
        assert!(
            final_param < task_a_param,
            "Should have adapted to Task B: {}",
            final_param
        );
    }

    /// Overfills a 10-slot window and checks capacity eviction, sampling size
    /// and task filtering.
    #[test]
    fn test_experience_window() {
        let mut window = ExperienceWindow::new(10);

        for i in 0..15 {
            window.add(vec![i as f32], i as f32, Some("task1".to_string()));
        }

        assert_eq!(window.len(), 10);

        let sample = window.sample(5);
        assert_eq!(sample.len(), 5);

        let task1 = window.by_task("task1");
        assert!(!task1.is_empty());
        // Every retained experience was added with the same task id.
        assert_eq!(task1.len(), window.len());
    }

    /// Feeds samples from `[0, 1)`, then samples shifted by +9.5, and checks
    /// both the per-sample flag and the rise in the average shift score.
    #[test]
    fn test_drift_detection() {
        let mut detector = DriftDetector::new(2).with_threshold(3.0);

        // Let the running statistics settle on the baseline distribution.
        for _ in 0..200 {
            let x1 = rand::random::<f32>();
            let x2 = rand::random::<f32>();
            detector.update(&[x1, x2]);
        }

        // Discard scores produced while the statistics were still settling.
        detector.shift_scores.clear();

        for _ in 0..50 {
            let x1 = rand::random::<f32>();
            let x2 = rand::random::<f32>();
            detector.update(&[x1, x2]);
        }

        let baseline_shift = detector.average_shift();

        let mut drift_detected = false;
        for _ in 0..20 {
            let x1 = rand::random::<f32>() + 9.5;
            let x2 = rand::random::<f32>() + 9.5;
            if detector.update(&[x1, x2]) {
                drift_detected = true;
            }
        }

        // The shifted samples sit far outside the baseline range, so at least
        // one of them must be flagged by `update` itself.
        assert!(
            drift_detected,
            "update() should flag samples shifted by +9.5"
        );

        let drift_shift = detector.average_shift();
        assert!(
            drift_shift > baseline_shift,
            "Drift shift {} should be > baseline {}",
            drift_shift,
            baseline_shift
        );
    }
}