use super::super::pretrained;
/// Incremental linear throughput model trained online with momentum SGD,
/// plus a small experience-replay rehearsal step every few updates.
///
/// NOTE(review): `derive(Default)` produces an empty `weights`/`velocity`
/// vector and a zero learning rate, unlike `new()` which seeds from the
/// pretrained table — confirm no caller relies on `Default` for a usable model.
#[derive(Debug, Clone, Default)]
pub struct OnlineLearner {
/// Model parameters: index 0 is the bias term, indices 1.. are the
/// per-feature coefficients (seeded from `pretrained::THROUGHPUT_WEIGHTS`).
pub(super) weights: Vec<f32>,
/// SGD step size (0.001 from `new()`; overridable via `with_learning_rate`).
learning_rate: f32,
/// Momentum coefficient applied to `velocity` each update (0.9 from `new()`).
momentum: f32,
/// Per-weight momentum accumulator; same length as `weights`.
velocity: Vec<f32>,
/// Count of successful `observe` calls (samples with matching feature count).
num_updates: usize,
/// Exponential moving average of absolute prediction error (alpha = 0.1).
ema_loss: f32,
/// Bounded FIFO of `(features, actual_throughput)` samples for replay.
replay_buffer: Vec<(Vec<f32>, f32)>,
/// Capacity cap for `replay_buffer` (100 from `new()`).
replay_buffer_size: usize,
}
impl OnlineLearner {
    /// Creates a learner seeded with the pretrained throughput weights.
    ///
    /// Weight layout: `weights[0]` is the bias term, `weights[1..]` are the
    /// per-feature coefficients.
    pub fn new() -> Self {
        let weights = pretrained::THROUGHPUT_WEIGHTS.to_vec();
        let velocity = vec![0.0; weights.len()];
        Self {
            weights,
            learning_rate: 0.001,
            momentum: 0.9,
            velocity,
            num_updates: 0,
            ema_loss: 0.0,
            replay_buffer: Vec::new(),
            replay_buffer_size: 100,
        }
    }

    /// Builder-style override of the SGD learning rate.
    pub fn with_learning_rate(mut self, lr: f32) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Performs one momentum-SGD step on an observed `(features, throughput)`
    /// sample, records it in the replay buffer, and every 10th update
    /// rehearses one buffered sample.
    ///
    /// Samples whose feature count does not match the model (one weight per
    /// feature plus a bias) are silently ignored — the caller is best-effort.
    pub fn observe(&mut self, features: &[f32], actual_throughput: f32) {
        if features.len() + 1 != self.weights.len() {
            return;
        }
        let predicted = self.predict(features);
        let error = predicted - actual_throughput;

        // EMA of |error| for the convergence heuristic. BUGFIX: seed the EMA
        // with the first observed error instead of blending into the 0.0
        // initial value, which biased early readings low and made
        // `is_converging()` report true before any data arrived.
        let alpha = 0.1;
        self.ema_loss = if self.num_updates == 0 {
            error.abs()
        } else {
            alpha * error.abs() + (1.0 - alpha) * self.ema_loss
        };

        // Squared-error gradient for a linear model: d/dw0 = 2e, d/dwi = 2e*x_i.
        let mut gradients = vec![0.0; self.weights.len()];
        gradients[0] = 2.0 * error;
        for (i, &x) in features.iter().enumerate() {
            gradients[i + 1] = 2.0 * error * x;
        }

        // Momentum SGD: velocity accumulates the (negated) gradient direction.
        for i in 0..self.weights.len() {
            self.velocity[i] =
                self.momentum * self.velocity[i] - self.learning_rate * gradients[i];
            self.weights[i] += self.velocity[i];
        }

        // Bounded FIFO replay buffer (100 entries; O(n) pop-front is fine at
        // this size and keeps the struct layout unchanged).
        if self.replay_buffer.len() >= self.replay_buffer_size {
            self.replay_buffer.remove(0);
        }
        self.replay_buffer.push((features.to_vec(), actual_throughput));

        self.num_updates += 1;
        // Periodic rehearsal to mitigate catastrophic forgetting.
        if self.num_updates % 10 == 0 && !self.replay_buffer.is_empty() {
            self.replay_step();
        }
    }

    /// Re-trains on one buffered sample with a reduced, momentum-free SGD step.
    fn replay_step(&mut self) {
        if self.replay_buffer.is_empty() {
            return;
        }
        // BUGFIX: the previous version always replayed index 0, so once the
        // buffer filled, only the single oldest sample was ever rehearsed.
        // Cycle deterministically so every stored sample gets replayed.
        let idx = (self.num_updates / 10) % self.replay_buffer.len();
        let (features, target) = self.replay_buffer[idx].clone();
        let predicted = self.predict(&features);
        let error = predicted - target;
        // Replay uses a 10x smaller step than the online update, no momentum.
        let replay_lr = self.learning_rate * 0.1;
        self.weights[0] -= replay_lr * 2.0 * error;
        for (i, &x) in features.iter().enumerate() {
            self.weights[i + 1] -= replay_lr * 2.0 * error * x;
        }
    }

    /// Linear prediction: bias + dot(weights[1..], features), clamped at zero
    /// (throughput cannot be negative). Features beyond the weight count are
    /// ignored; an empty feature slice returns the raw (unclamped) bias, as
    /// before.
    pub fn predict(&self, features: &[f32]) -> f32 {
        // BUGFIX: `self.weights[0]` panicked for a `Default`-constructed
        // learner (empty weight vector); fall back to 0.0 instead.
        let bias = self.weights.first().copied().unwrap_or(0.0);
        if features.is_empty() {
            return bias;
        }
        contract_pre_predict!(features);
        let mut result = bias;
        for (i, &x) in features.iter().enumerate() {
            if i + 1 < self.weights.len() {
                result += self.weights[i + 1] * x;
            }
        }
        result.max(0.0)
    }

    /// Read-only view of the current weight vector (bias first).
    pub fn weights(&self) -> &[f32] {
        &self.weights
    }

    /// Number of accepted `observe` calls (mismatched samples excluded).
    pub fn num_updates(&self) -> usize {
        self.num_updates
    }

    /// Exponential moving average of absolute prediction error.
    pub fn ema_loss(&self) -> f32 {
        self.ema_loss
    }

    /// Heuristic convergence check: EMA of |error| has dropped below 0.15.
    pub fn is_converging(&self) -> bool {
        self.ema_loss < 0.15
    }
}