trueno/tuner/evolution/
online.rs1use super::super::pretrained;
6
7#[derive(Debug, Clone, Default)]
9pub struct OnlineLearner {
10 pub(super) weights: Vec<f32>,
12 learning_rate: f32,
14 momentum: f32,
16 velocity: Vec<f32>,
18 num_updates: usize,
20 ema_loss: f32,
22 replay_buffer: Vec<(Vec<f32>, f32)>,
24 replay_buffer_size: usize,
26}
27
28impl OnlineLearner {
29 pub fn new() -> Self {
31 let weights = pretrained::THROUGHPUT_WEIGHTS.to_vec();
32 let velocity = vec![0.0; weights.len()];
33 Self {
34 weights,
35 learning_rate: 0.001,
36 momentum: 0.9,
37 velocity,
38 num_updates: 0,
39 ema_loss: 0.0,
40 replay_buffer: Vec::new(),
41 replay_buffer_size: 100,
42 }
43 }
44
45 pub fn with_learning_rate(mut self, lr: f32) -> Self {
47 self.learning_rate = lr;
48 self
49 }
50
51 pub fn observe(&mut self, features: &[f32], actual_throughput: f32) {
53 if features.len() + 1 != self.weights.len() {
54 return; }
56
57 let predicted = self.predict(features);
59 let error = predicted - actual_throughput;
60
61 let alpha = 0.1;
63 self.ema_loss = alpha * error.abs() + (1.0 - alpha) * self.ema_loss;
64
65 let mut gradients = vec![0.0; self.weights.len()];
68 gradients[0] = 2.0 * error; for (i, &x) in features.iter().enumerate() {
70 gradients[i + 1] = 2.0 * error * x;
71 }
72
73 for i in 0..self.weights.len() {
75 self.velocity[i] = self.momentum * self.velocity[i] - self.learning_rate * gradients[i];
76 self.weights[i] += self.velocity[i];
77 }
78
79 if self.replay_buffer.len() >= self.replay_buffer_size {
81 self.replay_buffer.remove(0);
83 }
84 self.replay_buffer.push((features.to_vec(), actual_throughput));
85
86 self.num_updates += 1;
87
88 if self.num_updates % 10 == 0 && !self.replay_buffer.is_empty() {
90 self.replay_step();
91 }
92 }
93
94 fn replay_step(&mut self) {
96 if self.replay_buffer.is_empty() {
97 return;
98 }
99
100 let (features, target) = self.replay_buffer[0].clone();
102
103 let predicted = self.predict(&features);
104 let error = predicted - target;
105
106 let replay_lr = self.learning_rate * 0.1;
108 self.weights[0] -= replay_lr * 2.0 * error;
109 for (i, &x) in features.iter().enumerate() {
110 self.weights[i + 1] -= replay_lr * 2.0 * error * x;
111 }
112 }
113
114 pub fn predict(&self, features: &[f32]) -> f32 {
116 if features.is_empty() {
117 return self.weights[0]; }
119 contract_pre_predict!(features);
120 let mut result = self.weights[0]; for (i, &x) in features.iter().enumerate() {
122 if i + 1 < self.weights.len() {
123 result += self.weights[i + 1] * x;
124 }
125 }
126 result.max(0.0) }
128
129 pub fn weights(&self) -> &[f32] {
131 &self.weights
132 }
133
134 pub fn num_updates(&self) -> usize {
136 self.num_updates
137 }
138
139 pub fn ema_loss(&self) -> f32 {
141 self.ema_loss
142 }
143
144 pub fn is_converging(&self) -> bool {
146 self.ema_loss < 0.15 }
148}