// reddb_server/storage/ml/classifier/mod.rs

1//! Classifier subsystem (ML Feature 5).
2//!
3//! Two algorithms ship in this sprint, both trivially incremental
4//! (support `partial_fit` over a single example or a mini-batch
5//! without replaying the full training set):
6//!
7//! * [`LogisticRegression`] — binary SGD; update per-example,
8//!   optional L2. Used for text + numerical features.
9//! * [`MultinomialNaiveBayes`] — count-based; `partial_fit`
10//!   literally adds the new counts to the existing tables. Works
11//!   perfectly over a data stream.
12//!
13//! Incremental training is a first-class operation, not an
14//! afterthought: every classifier exposes both `fit` (fresh model)
15//! and `partial_fit` (update existing weights with new examples).
16//! The surrounding `Classifier` enum routes both calls uniformly
17//! so callers don't special-case algorithm choice.
18//!
19//! Serialisation: each model has a compact JSON representation
20//! so `ModelRegistry` can store versions side-by-side with other
21//! model types.
22
23pub mod features;
24pub mod logreg;
25pub mod naive_bayes;
26
27pub use features::{one_hot, tf_idf_vectorize, Vocabulary};
28pub use logreg::{LogisticRegression, LogisticRegressionConfig};
29pub use naive_bayes::{MultinomialNaiveBayes, NaiveBayesConfig};
30
31use crate::json::{Map, Value as JsonValue};
32
/// Generic training example: feature vector + integer class label.
///
/// `features` is a dense vector; once a model has been trained its
/// length is expected to match the model's `num_features()`.
/// `label` is a zero-based class index (0..num_classes).
#[derive(Debug, Clone)]
pub struct TrainingExample {
    // Dense feature vector for this example.
    pub features: Vec<f32>,
    // Zero-based class index.
    pub label: u32,
}
39
/// Common surface every classifier exposes.
///
/// Implementations must support both batch training (`fit`) and
/// online/incremental training (`partial_fit`); see the module docs.
pub trait IncrementalClassifier {
    /// Train from scratch. Any previous weights are discarded.
    fn fit(&mut self, examples: &[TrainingExample]);

    /// Incrementally update with a mini-batch. Previous weights are
    /// preserved; this is the online-learning entrypoint.
    fn partial_fit(&mut self, examples: &[TrainingExample]);

    /// Predict the most likely class for one example.
    ///
    /// Returns `None` when the model cannot produce a prediction
    /// (e.g. before any training — exact conditions are
    /// implementation-defined).
    fn predict(&self, features: &[f32]) -> Option<u32>;

    /// Predict a probability per class (0..num_classes).
    fn predict_proba(&self, features: &[f32]) -> Vec<f32>;

    /// Number of distinct classes seen so far.
    fn num_classes(&self) -> usize;

    /// Number of features the model expects. 0 until `fit`/
    /// `partial_fit` has been called with at least one example.
    fn num_features(&self) -> usize;

    /// Total number of training examples the model has seen over
    /// its lifetime — incremented by both `fit` (reset to N) and
    /// `partial_fit` (additive). Useful for lineage + rate-limiting.
    fn samples_seen(&self) -> u64;
}
67
/// Evaluation metrics for a classifier. Always populated by
/// `evaluate()`; serialised into the model version's `metrics_json`.
#[derive(Debug, Clone, PartialEq)]
pub struct ClassifierMetrics {
    // Fraction of evaluated examples predicted correctly (0.0..=1.0).
    pub accuracy: f32,
    // One entry per class, indexed by class id.
    pub per_class_precision: Vec<f32>,
    pub per_class_recall: Vec<f32>,
    pub per_class_f1: Vec<f32>,
    // confusion_matrix[actual][predicted] = count.
    pub confusion_matrix: Vec<Vec<u32>>,
    // Total number of examples passed to `evaluate()`.
    pub samples_evaluated: u64,
}
79
80impl ClassifierMetrics {
81    pub fn macro_f1(&self) -> f32 {
82        if self.per_class_f1.is_empty() {
83            return 0.0;
84        }
85        self.per_class_f1.iter().sum::<f32>() / self.per_class_f1.len() as f32
86    }
87
88    pub fn to_json(&self) -> String {
89        let mut obj = Map::new();
90        obj.insert(
91            "accuracy".to_string(),
92            JsonValue::Number(self.accuracy as f64),
93        );
94        obj.insert(
95            "macro_f1".to_string(),
96            JsonValue::Number(self.macro_f1() as f64),
97        );
98        obj.insert(
99            "samples".to_string(),
100            JsonValue::Number(self.samples_evaluated as f64),
101        );
102        obj.insert(
103            "precision".to_string(),
104            JsonValue::Array(
105                self.per_class_precision
106                    .iter()
107                    .map(|f| JsonValue::Number(*f as f64))
108                    .collect(),
109            ),
110        );
111        obj.insert(
112            "recall".to_string(),
113            JsonValue::Array(
114                self.per_class_recall
115                    .iter()
116                    .map(|f| JsonValue::Number(*f as f64))
117                    .collect(),
118            ),
119        );
120        obj.insert(
121            "f1".to_string(),
122            JsonValue::Array(
123                self.per_class_f1
124                    .iter()
125                    .map(|f| JsonValue::Number(*f as f64))
126                    .collect(),
127            ),
128        );
129        obj.insert(
130            "confusion_matrix".to_string(),
131            JsonValue::Array(
132                self.confusion_matrix
133                    .iter()
134                    .map(|row| {
135                        JsonValue::Array(row.iter().map(|v| JsonValue::Number(*v as f64)).collect())
136                    })
137                    .collect(),
138            ),
139        );
140        JsonValue::Object(obj).to_string_compact()
141    }
142}
143
144/// Compute accuracy + per-class precision/recall/F1 + confusion
145/// matrix against a held-out slice of examples.
146pub fn evaluate<C: IncrementalClassifier>(
147    model: &C,
148    examples: &[TrainingExample],
149) -> ClassifierMetrics {
150    let k = model.num_classes().max(1);
151    let mut confusion = vec![vec![0u32; k]; k];
152    let mut correct = 0u32;
153    for ex in examples {
154        let predicted = model.predict(&ex.features).unwrap_or(0) as usize;
155        let actual = ex.label as usize;
156        if predicted < k && actual < k {
157            confusion[actual][predicted] += 1;
158            if predicted == actual {
159                correct += 1;
160            }
161        }
162    }
163    let total = examples.len() as u32;
164    let accuracy = if total == 0 {
165        0.0
166    } else {
167        correct as f32 / total as f32
168    };
169    let mut precision = vec![0.0f32; k];
170    let mut recall = vec![0.0f32; k];
171    let mut f1 = vec![0.0f32; k];
172    for c in 0..k {
173        let tp = confusion[c][c] as f32;
174        let pred_positive: u32 = (0..k).map(|r| confusion[r][c]).sum();
175        let actual_positive: u32 = confusion[c].iter().sum();
176        let p = if pred_positive == 0 {
177            0.0
178        } else {
179            tp / pred_positive as f32
180        };
181        let r = if actual_positive == 0 {
182            0.0
183        } else {
184            tp / actual_positive as f32
185        };
186        precision[c] = p;
187        recall[c] = r;
188        f1[c] = if p + r == 0.0 {
189            0.0
190        } else {
191            2.0 * p * r / (p + r)
192        };
193    }
194    ClassifierMetrics {
195        accuracy,
196        per_class_precision: precision,
197        per_class_recall: recall,
198        per_class_f1: f1,
199        confusion_matrix: confusion,
200        samples_evaluated: total as u64,
201    }
202}
203
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal stub: "predicts" by rounding the first feature to a
    /// non-negative class index — enough to drive the metrics tests.
    struct StubModel {
        classes: usize,
    }

    impl IncrementalClassifier for StubModel {
        fn fit(&mut self, _: &[TrainingExample]) {}
        fn partial_fit(&mut self, _: &[TrainingExample]) {}
        fn predict(&self, features: &[f32]) -> Option<u32> {
            let first = features.first().copied().unwrap_or(0.0);
            Some(first.round().max(0.0) as u32)
        }
        fn predict_proba(&self, _: &[f32]) -> Vec<f32> {
            // Uniform distribution — probability quality is not under test.
            vec![1.0 / self.classes as f32; self.classes]
        }
        fn num_classes(&self) -> usize {
            self.classes
        }
        fn num_features(&self) -> usize {
            1
        }
        fn samples_seen(&self) -> u64 {
            0
        }
    }

    #[test]
    fn evaluate_reports_perfect_scores_for_oracle_model() {
        let model = StubModel { classes: 2 };
        // Feature value equals the label, so the stub is an oracle.
        let examples: Vec<TrainingExample> = (0..10)
            .map(|i| {
                let class = (i % 2) as u32;
                TrainingExample {
                    features: vec![class as f32],
                    label: class,
                }
            })
            .collect();
        let metrics = evaluate(&model, &examples);
        assert!((metrics.accuracy - 1.0).abs() < 1e-6);
        assert!((metrics.macro_f1() - 1.0).abs() < 1e-6);
        assert_eq!(metrics.samples_evaluated, 10);
    }

    #[test]
    fn metrics_json_round_trips_every_field() {
        let metrics = ClassifierMetrics {
            accuracy: 0.8,
            per_class_precision: vec![0.9, 0.7],
            per_class_recall: vec![0.8, 0.8],
            per_class_f1: vec![0.85, 0.74],
            confusion_matrix: vec![vec![8, 2], vec![2, 8]],
            samples_evaluated: 20,
        };
        let raw = metrics.to_json();
        assert!(raw.contains("\"accuracy\""));
        assert!(raw.contains("\"confusion_matrix\""));
        assert!(raw.contains("\"macro_f1\""));
    }
}
265}