use crate::json::{Map, Value as JsonValue};

use super::{IncrementalClassifier, TrainingExample};
/// Hyper-parameters for [`LogisticRegression`].
#[derive(Debug, Clone)]
pub struct LogisticRegressionConfig {
    /// Step size applied to every SGD update.
    pub learning_rate: f32,
    /// L2 (ridge) regularisation strength; `0.0` disables the penalty.
    pub l2_penalty: f32,
    /// Number of full passes over the data performed by `fit`.
    pub epochs: usize,
    /// Seed for the deterministic shuffle used in `fit`; `0` disables shuffling.
    pub shuffle_seed: u64,
}
30
31impl Default for LogisticRegressionConfig {
32 fn default() -> Self {
33 Self {
34 learning_rate: 0.05,
35 l2_penalty: 0.0,
36 epochs: 10,
37 shuffle_seed: 0,
38 }
39 }
40}
41
/// One-vs-rest logistic regression trained incrementally with plain SGD.
#[derive(Debug, Clone)]
pub struct LogisticRegression {
    /// Training hyper-parameters.
    config: LogisticRegressionConfig,
    /// One weight row per class; each row is `num_features` long.
    weights: Vec<Vec<f32>>,
    /// Per-class intercepts, parallel to `weights`.
    biases: Vec<f32>,
    /// Feature dimensionality; 0 until the first training call fixes it.
    num_features: usize,
    /// Number of classes seen so far (labels are 0-based class indices).
    num_classes: usize,
    /// Lifetime count of training examples presented to the model.
    samples_seen: u64,
}
52
impl LogisticRegression {
    /// Creates an untrained model. Weight shapes are fixed lazily by the
    /// first call to `fit` / `partial_fit`.
    pub fn new(config: LogisticRegressionConfig) -> Self {
        Self {
            config,
            weights: Vec::new(),
            biases: Vec::new(),
            num_features: 0,
            num_classes: 0,
            samples_seen: 0,
        }
    }

    /// Grows the model so it can hold at least `num_classes` classes.
    ///
    /// The feature count is fixed the first time a value is recorded and is
    /// never changed afterwards; new class rows start at all zeros, and
    /// existing rows are never shrunk or reshaped.
    fn ensure_shape(&mut self, num_features: usize, num_classes: usize) {
        if self.num_features == 0 {
            self.num_features = num_features;
        }
        if num_classes > self.num_classes {
            self.weights
                .resize(num_classes, vec![0.0; self.num_features]);
            self.biases.resize(num_classes, 0.0);
            self.num_classes = num_classes;
        }
    }

    /// One SGD update over a single example, treating every class as an
    /// independent one-vs-rest sigmoid classifier.
    ///
    /// Examples whose feature length does not match the model are silently
    /// skipped. The bias term is not L2-regularised.
    fn sgd_step(&mut self, ex: &TrainingExample) {
        if ex.features.len() != self.num_features {
            return;
        }
        let lr = self.config.learning_rate;
        let l2 = self.config.l2_penalty;
        for c in 0..self.num_classes {
            // One-hot target for this class.
            let target = if ex.label as usize == c { 1.0 } else { 0.0 };
            // Linear score: bias + dot(weights, features).
            let mut z = self.biases[c];
            for (w, x) in self.weights[c].iter().zip(ex.features.iter()) {
                z += w * x;
            }
            let p = sigmoid(z);
            let error = p - target;
            for i in 0..self.num_features {
                // Log-loss gradient plus the L2 penalty term.
                let grad = error * ex.features[i] + l2 * self.weights[c][i];
                self.weights[c][i] -= lr * grad;
            }
            self.biases[c] -= lr * error;
        }
    }

    /// Derives `(num_features, num_classes)` from a batch: the feature count
    /// of the first example and `max(label) + 1`. Returns `None` for an
    /// empty batch.
    fn infer_shape(examples: &[TrainingExample]) -> Option<(usize, usize)> {
        let num_features = examples.first()?.features.len();
        let num_classes = examples.iter().map(|e| e.label as usize).max()? + 1;
        Some((num_features, num_classes))
    }

    /// Serialises the config and learned parameters into a compact JSON
    /// object; all numeric fields are cast to `f64` for `JsonValue::Number`.
    pub fn to_json(&self) -> String {
        let mut obj = Map::new();
        obj.insert(
            "lr".to_string(),
            JsonValue::Number(self.config.learning_rate as f64),
        );
        obj.insert(
            "l2".to_string(),
            JsonValue::Number(self.config.l2_penalty as f64),
        );
        obj.insert(
            "epochs".to_string(),
            JsonValue::Number(self.config.epochs as f64),
        );
        obj.insert(
            "shuffle_seed".to_string(),
            JsonValue::Number(self.config.shuffle_seed as f64),
        );
        obj.insert(
            "num_features".to_string(),
            JsonValue::Number(self.num_features as f64),
        );
        obj.insert(
            "num_classes".to_string(),
            JsonValue::Number(self.num_classes as f64),
        );
        obj.insert(
            "samples_seen".to_string(),
            JsonValue::Number(self.samples_seen as f64),
        );
        obj.insert(
            "weights".to_string(),
            JsonValue::Array(
                self.weights
                    .iter()
                    .map(|row| {
                        JsonValue::Array(row.iter().map(|f| JsonValue::Number(*f as f64)).collect())
                    })
                    .collect(),
            ),
        );
        obj.insert(
            "biases".to_string(),
            JsonValue::Array(
                self.biases
                    .iter()
                    .map(|f| JsonValue::Number(*f as f64))
                    .collect(),
            ),
        );
        JsonValue::Object(obj).to_string_compact()
    }

    /// Rebuilds a model from `to_json` output.
    ///
    /// Returns `None` when parsing fails or any expected key is missing or
    /// mistyped. Note that malformed entries inside `weights`/`biases` are
    /// silently dropped by the `filter_map`s, so a corrupt payload can yield
    /// shapes inconsistent with `num_features`/`num_classes`.
    pub fn from_json(raw: &str) -> Option<Self> {
        let parsed = crate::json::parse_json(raw).ok()?;
        let value = JsonValue::from(parsed);
        let obj = value.as_object()?;
        let lr = obj.get("lr")?.as_f64()? as f32;
        let l2 = obj.get("l2")?.as_f64()? as f32;
        let epochs = obj.get("epochs")?.as_i64()? as usize;
        let shuffle_seed = obj.get("shuffle_seed")?.as_i64()? as u64;
        let num_features = obj.get("num_features")?.as_i64()? as usize;
        let num_classes = obj.get("num_classes")?.as_i64()? as usize;
        let samples_seen = obj.get("samples_seen")?.as_i64()? as u64;
        let weights: Vec<Vec<f32>> = obj
            .get("weights")?
            .as_array()?
            .iter()
            .filter_map(|row| {
                row.as_array().map(|inner| {
                    inner
                        .iter()
                        .filter_map(|v| v.as_f64().map(|f| f as f32))
                        .collect()
                })
            })
            .collect();
        let biases: Vec<f32> = obj
            .get("biases")?
            .as_array()?
            .iter()
            .filter_map(|v| v.as_f64().map(|f| f as f32))
            .collect();
        Some(Self {
            config: LogisticRegressionConfig {
                learning_rate: lr,
                l2_penalty: l2,
                epochs,
                shuffle_seed,
            },
            weights,
            biases,
            num_features,
            num_classes,
            samples_seen,
        })
    }
}
209
210impl IncrementalClassifier for LogisticRegression {
211 fn fit(&mut self, examples: &[TrainingExample]) {
212 if examples.is_empty() {
213 return;
214 }
215 let Some((num_features, num_classes)) = Self::infer_shape(examples) else {
216 return;
217 };
218 self.weights = vec![vec![0.0; num_features]; num_classes];
220 self.biases = vec![0.0; num_classes];
221 self.num_features = num_features;
222 self.num_classes = num_classes;
223 self.samples_seen = 0;
224 for _ in 0..self.config.epochs {
225 let mut indices: Vec<usize> = (0..examples.len()).collect();
226 if self.config.shuffle_seed != 0 {
227 deterministic_shuffle(&mut indices, self.config.shuffle_seed);
228 }
229 for i in indices {
230 self.sgd_step(&examples[i]);
231 }
232 }
233 self.samples_seen = examples.len() as u64;
234 }
235
236 fn partial_fit(&mut self, examples: &[TrainingExample]) {
237 if examples.is_empty() {
238 return;
239 }
240 let (batch_features, batch_classes) = match Self::infer_shape(examples) {
241 Some(pair) => pair,
242 None => return,
243 };
244 self.ensure_shape(batch_features, batch_classes);
245 for ex in examples {
246 self.sgd_step(ex);
247 }
248 self.samples_seen = self.samples_seen.saturating_add(examples.len() as u64);
249 }
250
251 fn predict(&self, features: &[f32]) -> Option<u32> {
252 let probs = self.predict_proba(features);
253 if probs.is_empty() {
254 return None;
255 }
256 let mut best = 0usize;
257 let mut best_p = probs[0];
258 for (i, &p) in probs.iter().enumerate().skip(1) {
259 if p > best_p {
260 best_p = p;
261 best = i;
262 }
263 }
264 Some(best as u32)
265 }
266
267 fn predict_proba(&self, features: &[f32]) -> Vec<f32> {
268 if features.len() != self.num_features || self.num_classes == 0 {
269 return Vec::new();
270 }
271 let mut out = Vec::with_capacity(self.num_classes);
272 for c in 0..self.num_classes {
273 let mut z = self.biases[c];
274 for (w, x) in self.weights[c].iter().zip(features.iter()) {
275 z += w * x;
276 }
277 out.push(sigmoid(z));
278 }
279 let sum: f32 = out.iter().sum();
281 if sum > 0.0 {
282 for p in out.iter_mut() {
283 *p /= sum;
284 }
285 }
286 out
287 }
288
289 fn num_classes(&self) -> usize {
290 self.num_classes
291 }
292
293 fn num_features(&self) -> usize {
294 self.num_features
295 }
296
297 fn samples_seen(&self) -> u64 {
298 self.samples_seen
299 }
300}
301
/// Standard logistic function, mapping any real score into (0, 1).
fn sigmoid(z: f32) -> f32 {
    let neg_exp = (-z).exp();
    1.0 / (1.0 + neg_exp)
}
305
/// Fisher–Yates shuffle driven by an xorshift64 generator, so the same
/// seed always produces the same permutation.
fn deterministic_shuffle<T>(items: &mut [T], seed: u64) {
    let n = items.len();
    if n < 2 {
        return;
    }
    // `| 1` guarantees a non-zero state, which xorshift requires.
    let mut state = seed | 1;
    let mut next_rand = || {
        state ^= state << 13;
        state ^= state >> 7;
        state ^= state << 17;
        state
    };
    for i in (1..n).rev() {
        let j = (next_rand() as usize) % (i + 1);
        items.swap(i, j);
    }
}
321
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `2 * n` examples in two jittered clusters that a linear
    /// boundary on the first feature separates cleanly.
    fn linearly_separable(n: usize) -> Vec<TrainingExample> {
        let mut out = Vec::with_capacity(n * 2);
        for i in 0..n {
            let jitter = (i as f32) * 0.01;
            out.push(TrainingExample {
                features: vec![-1.0 + jitter, jitter],
                label: 0,
            });
            out.push(TrainingExample {
                features: vec![1.0 - jitter, jitter],
                label: 1,
            });
        }
        out
    }

    #[test]
    fn fit_learns_linearly_separable_classes() {
        let data = linearly_separable(50);
        let mut model = LogisticRegression::new(LogisticRegressionConfig {
            epochs: 50,
            ..Default::default()
        });
        model.fit(&data);
        // Count training-set hits; separable data should be near-perfect.
        let correct: u32 = data
            .iter()
            .map(|ex| {
                if model.predict(&ex.features) == Some(ex.label) {
                    1
                } else {
                    0
                }
            })
            .sum();
        let acc = correct as f32 / data.len() as f32;
        assert!(acc > 0.95, "accuracy too low: {acc}");
    }

    #[test]
    fn partial_fit_moves_loss_in_the_right_direction() {
        // Noisy two-class data; we only check that weights keep evolving,
        // not that accuracy reaches a target.
        let mut data = Vec::new();
        for i in 0..200 {
            let f = i as f32 * 0.01;
            data.push(TrainingExample {
                features: vec![-0.3 + f.sin() * 0.5, 0.2 * (f * 1.3).cos()],
                label: 0,
            });
            data.push(TrainingExample {
                features: vec![0.3 + f.cos() * 0.5, 0.2 * (f * 1.7).sin()],
                label: 1,
            });
        }
        let mut model = LogisticRegression::new(LogisticRegressionConfig {
            learning_rate: 0.01,
            epochs: 1,
            ..Default::default()
        });
        // Mean |w| is a cheap proxy for "the model has moved away from zero".
        fn mean_abs_weight(m: &LogisticRegression) -> f32 {
            let mut sum = 0.0f32;
            let mut n = 0usize;
            for row in &m.weights {
                for w in row {
                    sum += w.abs();
                    n += 1;
                }
            }
            if n == 0 {
                0.0
            } else {
                sum / n as f32
            }
        }
        model.partial_fit(&data[..40]);
        let w_early = mean_abs_weight(&model);
        for chunk in data[40..].chunks(40) {
            model.partial_fit(chunk);
        }
        let w_late = mean_abs_weight(&model);
        assert!(
            w_late > w_early,
            "partial_fit should keep updating weights: early={w_early} late={w_late}"
        );
        assert_eq!(model.samples_seen(), data.len() as u64);
    }

    #[test]
    fn partial_fit_preserves_weights_across_calls() {
        let mut model = LogisticRegression::new(LogisticRegressionConfig {
            epochs: 1,
            ..Default::default()
        });
        let batch = linearly_separable(30);
        model.partial_fit(&batch);
        let weights_after_first = model.weights.clone();
        model.partial_fit(&batch);
        // First call must have produced non-zero weights...
        let mut all_zero = true;
        for row in &weights_after_first {
            for w in row {
                if w.abs() > 1e-6 {
                    all_zero = false;
                }
            }
        }
        assert!(!all_zero, "weights should be non-zero after partial_fit");
        // ...and the second call must continue from them, not reset.
        assert_ne!(model.weights, weights_after_first);
    }

    #[test]
    fn partial_fit_extends_class_count_on_the_fly() {
        let mut model = LogisticRegression::new(LogisticRegressionConfig::default());
        model.partial_fit(&[TrainingExample {
            features: vec![0.0, 1.0],
            label: 0,
        }]);
        assert_eq!(model.num_classes, 1);
        // A higher label should grow the model to max(label) + 1 classes.
        model.partial_fit(&[TrainingExample {
            features: vec![1.0, 0.0],
            label: 3,
        }]);
        assert_eq!(model.num_classes, 4);
        assert_eq!(model.weights.len(), 4);
        for row in &model.weights {
            assert_eq!(row.len(), 2);
        }
    }

    #[test]
    fn samples_seen_tracks_lifetime_examples() {
        let mut model = LogisticRegression::new(LogisticRegressionConfig::default());
        let batch = linearly_separable(5);
        model.partial_fit(&batch);
        assert_eq!(model.samples_seen(), batch.len() as u64);
        model.partial_fit(&batch);
        assert_eq!(model.samples_seen(), 2 * batch.len() as u64);
        // `fit` resets the counter rather than accumulating.
        model.fit(&batch);
        assert_eq!(model.samples_seen(), batch.len() as u64);
    }

    #[test]
    fn json_round_trips_preserves_predictions() {
        let data = linearly_separable(40);
        let mut m = LogisticRegression::new(LogisticRegressionConfig {
            epochs: 20,
            ..Default::default()
        });
        m.fit(&data);
        let restored = LogisticRegression::from_json(&m.to_json()).unwrap();
        for ex in &data {
            assert_eq!(m.predict(&ex.features), restored.predict(&ex.features));
        }
    }

    #[test]
    fn predict_proba_is_normalised() {
        let data = linearly_separable(30);
        let mut m = LogisticRegression::new(LogisticRegressionConfig::default());
        m.fit(&data);
        let probs = m.predict_proba(&data[0].features);
        let sum: f32 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-4, "probs must sum to 1: {probs:?}");
    }
}