// reddb_server/storage/ml/classifier/logreg.rs
1//! Binary / multi-class logistic regression trained with SGD.
2//!
3//! One-vs-rest for the multi-class case: `K` independent binary
4//! classifiers, one per class. Each classifier stores `num_features`
5//! weights + 1 bias. Training passes one example at a time so
6//! [`Self::partial_fit`] reuses the same inner loop as `fit` — the
7//! only difference is whether weights are reset first.
8//!
9//! Features support L2 regularisation and a constant learning rate
10//! schedule. The implementation is deliberately minimal; production
11//! tuning (Adam, momentum, schedules, warm restarts) is follow-on
12//! work and lives behind the same `partial_fit` surface.
13
14use crate::json::{Map, Value as JsonValue};
15
16use super::{IncrementalClassifier, TrainingExample};
17
18/// Hyperparameters for [`LogisticRegression`].
#[derive(Debug, Clone)]
pub struct LogisticRegressionConfig {
    /// SGD step size applied to every gradient update.
    pub learning_rate: f32,
    /// L2 regularisation strength; `0.0` disables weight decay.
    /// (Applied to weights only — see `sgd_step`, which leaves the
    /// bias unregularised.)
    pub l2_penalty: f32,
    /// Training epochs per `fit` call. `partial_fit` always runs
    /// exactly one epoch over the incoming mini-batch.
    pub epochs: usize,
    /// Random seed for shuffling. `0` disables shuffle (tests rely on
    /// deterministic order).
    pub shuffle_seed: u64,
}
30
31impl Default for LogisticRegressionConfig {
32    fn default() -> Self {
33        Self {
34            learning_rate: 0.05,
35            l2_penalty: 0.0,
36            epochs: 10,
37            shuffle_seed: 0,
38        }
39    }
40}
41
#[derive(Debug, Clone)]
pub struct LogisticRegression {
    config: LogisticRegressionConfig,
    /// `weights[class][feature]`.
    weights: Vec<Vec<f32>>,
    /// One intercept per class, parallel to `weights`.
    biases: Vec<f32>,
    /// Fixed by the first training batch seen; `0` until then.
    num_features: usize,
    /// Grows monotonically as new labels appear (see `ensure_shape`).
    num_classes: usize,
    /// Lifetime count of training examples consumed.
    samples_seen: u64,
}
52
53impl LogisticRegression {
54    pub fn new(config: LogisticRegressionConfig) -> Self {
55        Self {
56            config,
57            weights: Vec::new(),
58            biases: Vec::new(),
59            num_features: 0,
60            num_classes: 0,
61            samples_seen: 0,
62        }
63    }
64
65    fn ensure_shape(&mut self, num_features: usize, num_classes: usize) {
66        if self.num_features == 0 {
67            self.num_features = num_features;
68        }
69        if num_classes > self.num_classes {
70            // Extend class count without discarding existing weights —
71            // vital for partial_fit on a stream that sees new classes
72            // over time.
73            self.weights
74                .resize(num_classes, vec![0.0; self.num_features]);
75            self.biases.resize(num_classes, 0.0);
76            self.num_classes = num_classes;
77        }
78    }
79
80    fn sgd_step(&mut self, ex: &TrainingExample) {
81        if ex.features.len() != self.num_features {
82            return;
83        }
84        let lr = self.config.learning_rate;
85        let l2 = self.config.l2_penalty;
86        for c in 0..self.num_classes {
87            let target = if ex.label as usize == c { 1.0 } else { 0.0 };
88            // dot(weights, features) + bias
89            let mut z = self.biases[c];
90            for (w, x) in self.weights[c].iter().zip(ex.features.iter()) {
91                z += w * x;
92            }
93            let p = sigmoid(z);
94            let error = p - target;
95            // Gradient descent on binary cross-entropy.
96            for i in 0..self.num_features {
97                let grad = error * ex.features[i] + l2 * self.weights[c][i];
98                self.weights[c][i] -= lr * grad;
99            }
100            self.biases[c] -= lr * error;
101        }
102    }
103
104    fn infer_shape(examples: &[TrainingExample]) -> Option<(usize, usize)> {
105        let num_features = examples.first()?.features.len();
106        let num_classes = examples.iter().map(|e| e.label as usize).max()? + 1;
107        Some((num_features, num_classes))
108    }
109
110    /// Serialise to JSON for `ModelRegistry` storage.
111    pub fn to_json(&self) -> String {
112        let mut obj = Map::new();
113        obj.insert(
114            "lr".to_string(),
115            JsonValue::Number(self.config.learning_rate as f64),
116        );
117        obj.insert(
118            "l2".to_string(),
119            JsonValue::Number(self.config.l2_penalty as f64),
120        );
121        obj.insert(
122            "epochs".to_string(),
123            JsonValue::Number(self.config.epochs as f64),
124        );
125        obj.insert(
126            "shuffle_seed".to_string(),
127            JsonValue::Number(self.config.shuffle_seed as f64),
128        );
129        obj.insert(
130            "num_features".to_string(),
131            JsonValue::Number(self.num_features as f64),
132        );
133        obj.insert(
134            "num_classes".to_string(),
135            JsonValue::Number(self.num_classes as f64),
136        );
137        obj.insert(
138            "samples_seen".to_string(),
139            JsonValue::Number(self.samples_seen as f64),
140        );
141        obj.insert(
142            "weights".to_string(),
143            JsonValue::Array(
144                self.weights
145                    .iter()
146                    .map(|row| {
147                        JsonValue::Array(row.iter().map(|f| JsonValue::Number(*f as f64)).collect())
148                    })
149                    .collect(),
150            ),
151        );
152        obj.insert(
153            "biases".to_string(),
154            JsonValue::Array(
155                self.biases
156                    .iter()
157                    .map(|f| JsonValue::Number(*f as f64))
158                    .collect(),
159            ),
160        );
161        JsonValue::Object(obj).to_string_compact()
162    }
163
164    pub fn from_json(raw: &str) -> Option<Self> {
165        let parsed = crate::json::parse_json(raw).ok()?;
166        let value = JsonValue::from(parsed);
167        let obj = value.as_object()?;
168        let lr = obj.get("lr")?.as_f64()? as f32;
169        let l2 = obj.get("l2")?.as_f64()? as f32;
170        let epochs = obj.get("epochs")?.as_i64()? as usize;
171        let shuffle_seed = obj.get("shuffle_seed")?.as_i64()? as u64;
172        let num_features = obj.get("num_features")?.as_i64()? as usize;
173        let num_classes = obj.get("num_classes")?.as_i64()? as usize;
174        let samples_seen = obj.get("samples_seen")?.as_i64()? as u64;
175        let weights: Vec<Vec<f32>> = obj
176            .get("weights")?
177            .as_array()?
178            .iter()
179            .filter_map(|row| {
180                row.as_array().map(|inner| {
181                    inner
182                        .iter()
183                        .filter_map(|v| v.as_f64().map(|f| f as f32))
184                        .collect()
185                })
186            })
187            .collect();
188        let biases: Vec<f32> = obj
189            .get("biases")?
190            .as_array()?
191            .iter()
192            .filter_map(|v| v.as_f64().map(|f| f as f32))
193            .collect();
194        Some(Self {
195            config: LogisticRegressionConfig {
196                learning_rate: lr,
197                l2_penalty: l2,
198                epochs,
199                shuffle_seed,
200            },
201            weights,
202            biases,
203            num_features,
204            num_classes,
205            samples_seen,
206        })
207    }
208}
209
210impl IncrementalClassifier for LogisticRegression {
211    fn fit(&mut self, examples: &[TrainingExample]) {
212        if examples.is_empty() {
213            return;
214        }
215        let Some((num_features, num_classes)) = Self::infer_shape(examples) else {
216            return;
217        };
218        // fresh model
219        self.weights = vec![vec![0.0; num_features]; num_classes];
220        self.biases = vec![0.0; num_classes];
221        self.num_features = num_features;
222        self.num_classes = num_classes;
223        self.samples_seen = 0;
224        for _ in 0..self.config.epochs {
225            let mut indices: Vec<usize> = (0..examples.len()).collect();
226            if self.config.shuffle_seed != 0 {
227                deterministic_shuffle(&mut indices, self.config.shuffle_seed);
228            }
229            for i in indices {
230                self.sgd_step(&examples[i]);
231            }
232        }
233        self.samples_seen = examples.len() as u64;
234    }
235
236    fn partial_fit(&mut self, examples: &[TrainingExample]) {
237        if examples.is_empty() {
238            return;
239        }
240        let (batch_features, batch_classes) = match Self::infer_shape(examples) {
241            Some(pair) => pair,
242            None => return,
243        };
244        self.ensure_shape(batch_features, batch_classes);
245        for ex in examples {
246            self.sgd_step(ex);
247        }
248        self.samples_seen = self.samples_seen.saturating_add(examples.len() as u64);
249    }
250
251    fn predict(&self, features: &[f32]) -> Option<u32> {
252        let probs = self.predict_proba(features);
253        if probs.is_empty() {
254            return None;
255        }
256        let mut best = 0usize;
257        let mut best_p = probs[0];
258        for (i, &p) in probs.iter().enumerate().skip(1) {
259            if p > best_p {
260                best_p = p;
261                best = i;
262            }
263        }
264        Some(best as u32)
265    }
266
267    fn predict_proba(&self, features: &[f32]) -> Vec<f32> {
268        if features.len() != self.num_features || self.num_classes == 0 {
269            return Vec::new();
270        }
271        let mut out = Vec::with_capacity(self.num_classes);
272        for c in 0..self.num_classes {
273            let mut z = self.biases[c];
274            for (w, x) in self.weights[c].iter().zip(features.iter()) {
275                z += w * x;
276            }
277            out.push(sigmoid(z));
278        }
279        // Normalise so probs sum to 1 (one-vs-rest outputs often don't).
280        let sum: f32 = out.iter().sum();
281        if sum > 0.0 {
282            for p in out.iter_mut() {
283                *p /= sum;
284            }
285        }
286        out
287    }
288
289    fn num_classes(&self) -> usize {
290        self.num_classes
291    }
292
293    fn num_features(&self) -> usize {
294        self.num_features
295    }
296
297    fn samples_seen(&self) -> u64 {
298        self.samples_seen
299    }
300}
301
/// Standard logistic function `1 / (1 + e^(-z))`, mapping any real
/// input into (0, 1). For large |z| the f32 `exp` saturates to 0 or
/// +inf, and the reciprocal still yields the correct limit (1 or 0).
fn sigmoid(z: f32) -> f32 {
    let neg_exp = (-z).exp();
    (1.0 + neg_exp).recip()
}
305
/// Fisher–Yates shuffle driven by a tiny xorshift64-style PRNG.
/// Fully deterministic for a given seed — we only need reproducible
/// shuffles in tests, not cryptographic quality. `seed | 1` forces a
/// non-zero PRNG state (xorshift is stuck at zero forever).
fn deterministic_shuffle<T>(items: &mut [T], seed: u64) {
    let len = items.len();
    if len < 2 {
        return;
    }
    let mut rng = seed | 1;
    let mut i = len - 1;
    while i > 0 {
        rng ^= rng << 13;
        rng ^= rng >> 7;
        rng ^= rng << 17;
        // Modulo bias is acceptable at this quality level.
        let j = (rng as usize) % (i + 1);
        items.swap(i, j);
        i -= 1;
    }
}
321
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `2 * n` examples in two well-separated clusters with a
    /// small deterministic jitter so the points aren't identical.
    fn linearly_separable(n: usize) -> Vec<TrainingExample> {
        // Two clusters: class 0 around (-1,0), class 1 around (1,0).
        let mut out = Vec::with_capacity(n * 2);
        for i in 0..n {
            let jitter = (i as f32) * 0.01;
            out.push(TrainingExample {
                features: vec![-1.0 + jitter, jitter],
                label: 0,
            });
            out.push(TrainingExample {
                features: vec![1.0 - jitter, jitter],
                label: 1,
            });
        }
        out
    }

    #[test]
    fn fit_learns_linearly_separable_classes() {
        let data = linearly_separable(50);
        let mut model = LogisticRegression::new(LogisticRegressionConfig {
            epochs: 50,
            ..Default::default()
        });
        model.fit(&data);
        // Training-set accuracy: separable clusters should be almost
        // perfectly memorised after 50 epochs.
        let correct: u32 = data
            .iter()
            .map(|ex| {
                if model.predict(&ex.features) == Some(ex.label) {
                    1
                } else {
                    0
                }
            })
            .sum();
        let acc = correct as f32 / data.len() as f32;
        assert!(acc > 0.95, "accuracy too low: {acc}");
    }

    #[test]
    fn partial_fit_moves_loss_in_the_right_direction() {
        // Noisy overlapping clusters so one pass can't converge —
        // weights must grow across partial_fit calls.
        let mut data = Vec::new();
        for i in 0..200 {
            let f = i as f32 * 0.01;
            data.push(TrainingExample {
                features: vec![-0.3 + f.sin() * 0.5, 0.2 * (f * 1.3).cos()],
                label: 0,
            });
            data.push(TrainingExample {
                features: vec![0.3 + f.cos() * 0.5, 0.2 * (f * 1.7).sin()],
                label: 1,
            });
        }
        let mut model = LogisticRegression::new(LogisticRegressionConfig {
            learning_rate: 0.01,
            epochs: 1,
            ..Default::default()
        });
        // Proxy for "the model is still learning": mean |weight|.
        fn mean_abs_weight(m: &LogisticRegression) -> f32 {
            let mut sum = 0.0f32;
            let mut n = 0usize;
            for row in &m.weights {
                for w in row {
                    sum += w.abs();
                    n += 1;
                }
            }
            if n == 0 {
                0.0
            } else {
                sum / n as f32
            }
        }
        model.partial_fit(&data[..40]);
        let w_early = mean_abs_weight(&model);
        for chunk in data[40..].chunks(40) {
            model.partial_fit(chunk);
        }
        let w_late = mean_abs_weight(&model);
        assert!(
            w_late > w_early,
            "partial_fit should keep updating weights: early={w_early} late={w_late}"
        );
        // Sanity: samples_seen reflects additive calls.
        assert_eq!(model.samples_seen(), data.len() as u64);
    }

    #[test]
    fn partial_fit_preserves_weights_across_calls() {
        let mut model = LogisticRegression::new(LogisticRegressionConfig {
            epochs: 1,
            ..Default::default()
        });
        let batch = linearly_separable(30);
        model.partial_fit(&batch);
        let weights_after_first = model.weights.clone();
        model.partial_fit(&batch);
        // Weights moved further; they should not have been reset to 0
        // (that's the `fit` contract, not `partial_fit`).
        let mut all_zero = true;
        for row in &weights_after_first {
            for w in row {
                if w.abs() > 1e-6 {
                    all_zero = false;
                }
            }
        }
        assert!(!all_zero, "weights should be non-zero after partial_fit");
        // And second call must have moved them again.
        assert_ne!(model.weights, weights_after_first);
    }

    #[test]
    fn partial_fit_extends_class_count_on_the_fly() {
        let mut model = LogisticRegression::new(LogisticRegressionConfig::default());
        model.partial_fit(&[TrainingExample {
            features: vec![0.0, 1.0],
            label: 0,
        }]);
        assert_eq!(model.num_classes, 1);
        // A previously-unseen label (3) must grow the model to 4
        // classes without touching the class-0 weights.
        model.partial_fit(&[TrainingExample {
            features: vec![1.0, 0.0],
            label: 3,
        }]);
        assert_eq!(model.num_classes, 4);
        assert_eq!(model.weights.len(), 4);
        for row in &model.weights {
            assert_eq!(row.len(), 2);
        }
    }

    #[test]
    fn samples_seen_tracks_lifetime_examples() {
        let mut model = LogisticRegression::new(LogisticRegressionConfig::default());
        let batch = linearly_separable(5);
        model.partial_fit(&batch);
        assert_eq!(model.samples_seen(), batch.len() as u64);
        model.partial_fit(&batch);
        assert_eq!(model.samples_seen(), 2 * batch.len() as u64);
        // fit resets to the freshly-fitted count.
        model.fit(&batch);
        assert_eq!(model.samples_seen(), batch.len() as u64);
    }

    #[test]
    fn json_round_trips_preserves_predictions() {
        let data = linearly_separable(40);
        let mut m = LogisticRegression::new(LogisticRegressionConfig {
            epochs: 20,
            ..Default::default()
        });
        m.fit(&data);
        // Serialise then restore; predictions must be identical.
        let restored = LogisticRegression::from_json(&m.to_json()).unwrap();
        for ex in &data {
            assert_eq!(m.predict(&ex.features), restored.predict(&ex.features));
        }
    }

    #[test]
    fn predict_proba_is_normalised() {
        let data = linearly_separable(30);
        let mut m = LogisticRegression::new(LogisticRegressionConfig::default());
        m.fit(&data);
        let probs = m.predict_proba(&data[0].features);
        let sum: f32 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-4, "probs must sum to 1: {probs:?}");
    }
}
495}