// trueno/tuner/evolution/online.rs
1//! Online learning with SGD updates (MLT-12)
2//!
3//! Momentum SGD with replay buffer for catastrophic forgetting prevention.
4
5use super::super::pretrained;
6
/// Online learning state for SGD updates (MLT-12)
///
/// Wraps a linear throughput model (bias plus one weight per feature) that is
/// refined in place via momentum SGD as new observations arrive. A bounded
/// FIFO replay buffer of past samples is periodically revisited to limit
/// catastrophic forgetting.
///
/// NOTE(review): the derived `Default` yields empty `weights`/`velocity`
/// vectors; prefer `OnlineLearner::new`, which seeds pretrained weights.
#[derive(Debug, Clone, Default)]
pub struct OnlineLearner {
    /// Model weights, laid out as `[bias, w_1, .., w_n]`
    pub(super) weights: Vec<f32>,
    /// SGD step size (0.001 from `new`)
    learning_rate: f32,
    /// Momentum coefficient for the velocity update (0.9 from `new`)
    momentum: f32,
    /// Per-weight velocity terms for momentum SGD; same length as `weights`
    velocity: Vec<f32>,
    /// Number of samples that have produced an SGD update
    num_updates: usize,
    /// Exponential moving average of absolute prediction error (alpha = 0.1)
    ema_loss: f32,
    /// FIFO buffer of `(features, actual_throughput)` samples replayed to
    /// prevent catastrophic forgetting
    replay_buffer: Vec<(Vec<f32>, f32)>,
    /// Max number of samples retained in `replay_buffer` (100 from `new`)
    replay_buffer_size: usize,
}
27
28impl OnlineLearner {
29    /// Create new online learner with pretrained weights
30    pub fn new() -> Self {
31        let weights = pretrained::THROUGHPUT_WEIGHTS.to_vec();
32        let velocity = vec![0.0; weights.len()];
33        Self {
34            weights,
35            learning_rate: 0.001,
36            momentum: 0.9,
37            velocity,
38            num_updates: 0,
39            ema_loss: 0.0,
40            replay_buffer: Vec::new(),
41            replay_buffer_size: 100,
42        }
43    }
44
45    /// Create learner with custom learning rate
46    pub fn with_learning_rate(mut self, lr: f32) -> Self {
47        self.learning_rate = lr;
48        self
49    }
50
51    /// Observe a new sample and update weights (SGD step)
52    pub fn observe(&mut self, features: &[f32], actual_throughput: f32) {
53        if features.len() + 1 != self.weights.len() {
54            return; // Dimension mismatch
55        }
56
57        // Forward pass: predict
58        let predicted = self.predict(features);
59        let error = predicted - actual_throughput;
60
61        // Update EMA loss
62        let alpha = 0.1;
63        self.ema_loss = alpha * error.abs() + (1.0 - alpha) * self.ema_loss;
64
65        // Backward pass: compute gradients
66        // For linear model: dL/dw_i = 2 * error * x_i
67        let mut gradients = vec![0.0; self.weights.len()];
68        gradients[0] = 2.0 * error; // bias gradient
69        for (i, &x) in features.iter().enumerate() {
70            gradients[i + 1] = 2.0 * error * x;
71        }
72
73        // Momentum SGD update
74        for i in 0..self.weights.len() {
75            self.velocity[i] = self.momentum * self.velocity[i] - self.learning_rate * gradients[i];
76            self.weights[i] += self.velocity[i];
77        }
78
79        // Add to replay buffer
80        if self.replay_buffer.len() >= self.replay_buffer_size {
81            // Remove oldest
82            self.replay_buffer.remove(0);
83        }
84        self.replay_buffer.push((features.to_vec(), actual_throughput));
85
86        self.num_updates += 1;
87
88        // Periodic replay to prevent catastrophic forgetting
89        if self.num_updates % 10 == 0 && !self.replay_buffer.is_empty() {
90            self.replay_step();
91        }
92    }
93
94    /// Replay a random sample from buffer
95    fn replay_step(&mut self) {
96        if self.replay_buffer.is_empty() {
97            return;
98        }
99
100        // Simple: replay oldest sample
101        let (features, target) = self.replay_buffer[0].clone();
102
103        let predicted = self.predict(&features);
104        let error = predicted - target;
105
106        // Smaller learning rate for replay
107        let replay_lr = self.learning_rate * 0.1;
108        self.weights[0] -= replay_lr * 2.0 * error;
109        for (i, &x) in features.iter().enumerate() {
110            self.weights[i + 1] -= replay_lr * 2.0 * error * x;
111        }
112    }
113
114    /// Predict throughput
115    pub fn predict(&self, features: &[f32]) -> f32 {
116        if features.is_empty() {
117            return self.weights[0]; // bias-only prediction
118        }
119        contract_pre_predict!(features);
120        let mut result = self.weights[0]; // bias
121        for (i, &x) in features.iter().enumerate() {
122            if i + 1 < self.weights.len() {
123                result += self.weights[i + 1] * x;
124            }
125        }
126        result.max(0.0) // Throughput must be non-negative
127    }
128
129    /// Get current weights
130    pub fn weights(&self) -> &[f32] {
131        &self.weights
132    }
133
134    /// Get number of updates
135    pub fn num_updates(&self) -> usize {
136        self.num_updates
137    }
138
139    /// Get current EMA loss
140    pub fn ema_loss(&self) -> f32 {
141        self.ema_loss
142    }
143
144    /// Check if model is converging (loss decreasing)
145    pub fn is_converging(&self) -> bool {
146        self.ema_loss < 0.15 // 15% MAPE threshold
147    }
148}