Skip to main content

reddb_server/storage/ml/classifier/
naive_bayes.rs

1//! Multinomial Naive Bayes with Laplace smoothing.
2//!
3//! Ideal counterpart to [`super::LogisticRegression`] for the
4//! incremental story: `partial_fit` is a pure additive update
5//! (class counts + per-feature counts accumulate), so you get
6//! identical results whether you train on the whole set in one
7//! pass or drip-feed examples through many calls. No epochs, no
8//! learning rate, no randomisation — the algorithm is
9//! deterministic given its counts.
10//!
11//! Features are treated as non-negative counts (TF-IDF values
12//! also work as long as they're ≥ 0). Smoothing parameter is
13//! configurable; default `alpha = 1.0` (classic Laplace).
14
15use crate::json::{Map, Value as JsonValue};
16
17use super::{IncrementalClassifier, TrainingExample};
18
#[derive(Debug, Clone)]
pub struct NaiveBayesConfig {
    /// Additive (Laplace/Lidstone) smoothing parameter; `Default`
    /// yields `1.0` (classic Laplace). Nothing in this module
    /// validates it, but a negative value would produce `ln` of a
    /// non-positive likelihood ratio in `predict_proba` — callers are
    /// expected to keep it ≥ 0.
    pub alpha: f32,
}
23
24impl Default for NaiveBayesConfig {
25    fn default() -> Self {
26        Self { alpha: 1.0 }
27    }
28}
29
/// Multinomial Naive Bayes classifier backed purely by additive
/// sufficient statistics, so batch and incremental training over the
/// same examples yield identical counts (and thus identical output).
#[derive(Debug, Clone)]
pub struct MultinomialNaiveBayes {
    /// Smoothing configuration (see [`NaiveBayesConfig`]).
    config: NaiveBayesConfig,
    /// `class_counts[c]` = number of examples seen with label c.
    class_counts: Vec<u64>,
    /// `feature_counts[c][f]` = sum of `features[f]` across examples
    /// with label c.
    feature_counts: Vec<Vec<f64>>,
    /// `feature_totals[c]` = total mass in class c (for smoothing).
    feature_totals: Vec<f64>,
    /// Feature dimensionality, fixed by the first data observed
    /// (0 until then).
    num_features: usize,
    /// Highest label seen so far + 1; grows as new labels appear via
    /// `partial_fit`.
    num_classes: usize,
    /// Running count of examples offered to `fit`/`partial_fit`
    /// (includes examples `accumulate` skipped for a feature-width
    /// mismatch — see the NOTE in the trait impl).
    samples_seen: u64,
}
44
45impl MultinomialNaiveBayes {
46    pub fn new(config: NaiveBayesConfig) -> Self {
47        Self {
48            config,
49            class_counts: Vec::new(),
50            feature_counts: Vec::new(),
51            feature_totals: Vec::new(),
52            num_features: 0,
53            num_classes: 0,
54            samples_seen: 0,
55        }
56    }
57
58    fn ensure_shape(&mut self, num_features: usize, num_classes: usize) {
59        if self.num_features == 0 {
60            self.num_features = num_features;
61        }
62        if num_classes > self.num_classes {
63            self.class_counts.resize(num_classes, 0);
64            self.feature_counts
65                .resize(num_classes, vec![0.0; self.num_features]);
66            self.feature_totals.resize(num_classes, 0.0);
67            self.num_classes = num_classes;
68        }
69    }
70
71    fn accumulate(&mut self, ex: &TrainingExample) {
72        if ex.features.len() != self.num_features {
73            return;
74        }
75        let c = ex.label as usize;
76        self.class_counts[c] += 1;
77        let mut total = 0.0;
78        for (i, &v) in ex.features.iter().enumerate() {
79            if v < 0.0 {
80                continue; // counts stay non-negative
81            }
82            self.feature_counts[c][i] += v as f64;
83            total += v as f64;
84        }
85        self.feature_totals[c] += total;
86    }
87
88    pub fn to_json(&self) -> String {
89        let mut obj = Map::new();
90        obj.insert(
91            "alpha".to_string(),
92            JsonValue::Number(self.config.alpha as f64),
93        );
94        obj.insert(
95            "num_features".to_string(),
96            JsonValue::Number(self.num_features as f64),
97        );
98        obj.insert(
99            "num_classes".to_string(),
100            JsonValue::Number(self.num_classes as f64),
101        );
102        obj.insert(
103            "samples_seen".to_string(),
104            JsonValue::Number(self.samples_seen as f64),
105        );
106        obj.insert(
107            "class_counts".to_string(),
108            JsonValue::Array(
109                self.class_counts
110                    .iter()
111                    .map(|v| JsonValue::Number(*v as f64))
112                    .collect(),
113            ),
114        );
115        obj.insert(
116            "feature_counts".to_string(),
117            JsonValue::Array(
118                self.feature_counts
119                    .iter()
120                    .map(|row| {
121                        JsonValue::Array(row.iter().map(|v| JsonValue::Number(*v)).collect())
122                    })
123                    .collect(),
124            ),
125        );
126        obj.insert(
127            "feature_totals".to_string(),
128            JsonValue::Array(
129                self.feature_totals
130                    .iter()
131                    .map(|v| JsonValue::Number(*v))
132                    .collect(),
133            ),
134        );
135        JsonValue::Object(obj).to_string_compact()
136    }
137
138    pub fn from_json(raw: &str) -> Option<Self> {
139        let parsed = crate::json::parse_json(raw).ok()?;
140        let value = JsonValue::from(parsed);
141        let obj = value.as_object()?;
142        let alpha = obj.get("alpha")?.as_f64()? as f32;
143        let num_features = obj.get("num_features")?.as_i64()? as usize;
144        let num_classes = obj.get("num_classes")?.as_i64()? as usize;
145        let samples_seen = obj.get("samples_seen")?.as_i64()? as u64;
146        let class_counts: Vec<u64> = obj
147            .get("class_counts")?
148            .as_array()?
149            .iter()
150            .filter_map(|v| v.as_i64().map(|i| i as u64))
151            .collect();
152        let feature_counts: Vec<Vec<f64>> = obj
153            .get("feature_counts")?
154            .as_array()?
155            .iter()
156            .filter_map(|row| {
157                row.as_array().map(|inner| {
158                    inner
159                        .iter()
160                        .filter_map(|v| v.as_f64())
161                        .collect::<Vec<f64>>()
162                })
163            })
164            .collect();
165        let feature_totals: Vec<f64> = obj
166            .get("feature_totals")?
167            .as_array()?
168            .iter()
169            .filter_map(|v| v.as_f64())
170            .collect();
171        Some(Self {
172            config: NaiveBayesConfig { alpha },
173            class_counts,
174            feature_counts,
175            feature_totals,
176            num_features,
177            num_classes,
178            samples_seen,
179        })
180    }
181}
182
impl IncrementalClassifier for MultinomialNaiveBayes {
    /// Trains from scratch: previously accumulated counts are
    /// discarded and rebuilt from `examples` in one pass. Feature
    /// width comes from `examples[0]`; class count is max label + 1.
    fn fit(&mut self, examples: &[TrainingExample]) {
        if examples.is_empty() {
            return;
        }
        let num_features = examples[0].features.len();
        // `unwrap` is safe: the emptiness check above guarantees at
        // least one element.
        let num_classes = examples.iter().map(|e| e.label as usize).max().unwrap() + 1;
        self.class_counts = vec![0; num_classes];
        self.feature_counts = vec![vec![0.0; num_features]; num_classes];
        self.feature_totals = vec![0.0; num_classes];
        self.num_features = num_features;
        self.num_classes = num_classes;
        self.samples_seen = 0;
        for ex in examples {
            self.accumulate(ex);
        }
        // NOTE(review): counts every example offered, including any
        // that `accumulate` skipped for a feature-width mismatch —
        // confirm that is the intended meaning of `samples_seen`.
        self.samples_seen = examples.len() as u64;
    }

    /// Purely additive update: counts accumulate, so many small calls
    /// produce the same state as one call over the concatenation.
    /// New (higher) labels grow the tables via `ensure_shape`.
    fn partial_fit(&mut self, examples: &[TrainingExample]) {
        if examples.is_empty() {
            return;
        }
        let num_features = examples[0].features.len();
        let num_classes = examples.iter().map(|e| e.label as usize).max().unwrap() + 1;
        self.ensure_shape(num_features, num_classes);
        for ex in examples {
            self.accumulate(ex);
        }
        // Saturating add: see the NOTE in `fit` about skipped examples.
        self.samples_seen = self.samples_seen.saturating_add(examples.len() as u64);
    }

    /// Returns the argmax class, or `None` when the model has no
    /// classes or `features` has the wrong width (empty proba vector).
    /// Ties keep the lowest class index (strict `>` comparison).
    fn predict(&self, features: &[f32]) -> Option<u32> {
        let probs = self.predict_proba(features);
        if probs.is_empty() {
            return None;
        }
        let mut best = 0usize;
        let mut best_p = probs[0];
        for (i, &p) in probs.iter().enumerate().skip(1) {
            if p > best_p {
                best_p = p;
                best = i;
            }
        }
        Some(best as u32)
    }

    /// Per-class posterior probabilities (summing to 1), computed in
    /// the log domain for numerical stability.
    ///
    /// Returns an empty vec on width mismatch or an untrained shape,
    /// and a uniform distribution when classes exist but no example
    /// has been accumulated yet.
    fn predict_proba(&self, features: &[f32]) -> Vec<f32> {
        if features.len() != self.num_features || self.num_classes == 0 {
            return Vec::new();
        }
        let total_samples: u64 = self.class_counts.iter().sum();
        if total_samples == 0 {
            return vec![1.0 / self.num_classes as f32; self.num_classes];
        }
        let alpha = self.config.alpha as f64;
        let mut log_scores = vec![0f64; self.num_classes];
        for (c, log_score) in log_scores.iter_mut().enumerate().take(self.num_classes) {
            // Clamp a zero class count to the smallest positive f64 so
            // `ln` stays finite; an unseen class gets a huge negative
            // log-prior (effectively zero probability) instead of NaN.
            let prior = (self.class_counts[c] as f64).max(f64::MIN_POSITIVE) / total_samples as f64;
            let mut lp = prior.ln();
            // Laplace-smoothed denominator: class mass + alpha per feature.
            let denom = self.feature_totals[c] + alpha * self.num_features as f64;
            for (i, &x) in features.iter().enumerate() {
                if x <= 0.0 {
                    continue; // zero/negative counts contribute nothing
                }
                let numer = self.feature_counts[c][i] + alpha;
                lp += (x as f64) * (numer / denom).ln();
            }
            *log_score = lp;
        }
        // Softmax over log-scores → normalised probabilities.
        // Subtracting the max before exp prevents overflow/underflow.
        let max = log_scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let mut probs = Vec::with_capacity(self.num_classes);
        let mut sum = 0.0f64;
        for lp in &log_scores {
            let v = (lp - max).exp();
            probs.push(v);
            sum += v;
        }
        if sum > 0.0 {
            for p in probs.iter_mut() {
                *p /= sum;
            }
        }
        probs.into_iter().map(|p| p as f32).collect()
    }

    fn num_classes(&self) -> usize {
        self.num_classes
    }

    fn num_features(&self) -> usize {
        self.num_features
    }

    fn samples_seen(&self) -> u64 {
        self.samples_seen
    }
}
283
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-class bag-of-words dataset over the vocabulary
    /// `[cat, dog, the]`: label-0 documents are cat-heavy, label-1
    /// documents are dog-heavy.
    fn bow_dataset() -> Vec<TrainingExample> {
        let rows: [([f32; 3], u32); 6] = [
            ([3.0, 0.0, 1.0], 0),
            ([2.0, 0.0, 2.0], 0),
            ([4.0, 0.0, 0.0], 0),
            ([0.0, 3.0, 1.0], 1),
            ([0.0, 4.0, 2.0], 1),
            ([0.0, 2.0, 1.0], 1),
        ];
        rows.iter()
            .map(|&(features, label)| TrainingExample {
                features: features.to_vec(),
                label,
            })
            .collect()
    }

    #[test]
    fn fit_learns_bow_dataset() {
        let data = bow_dataset();
        let mut model = MultinomialNaiveBayes::new(NaiveBayesConfig::default());
        model.fit(&data);
        for ex in &data {
            assert_eq!(model.predict(&ex.features), Some(ex.label));
        }
    }

    #[test]
    fn partial_fit_equivalent_to_fit_on_full_set() {
        let data = bow_dataset();
        let mut batch = MultinomialNaiveBayes::new(NaiveBayesConfig::default());
        batch.fit(&data);
        let mut streamed = MultinomialNaiveBayes::new(NaiveBayesConfig::default());
        for chunk in data.chunks(1) {
            streamed.partial_fit(chunk);
        }
        // Predictions on the training set must agree — NB with the
        // same counts produces identical probabilities.
        for ex in &data {
            assert_eq!(
                batch.predict(&ex.features),
                streamed.predict(&ex.features)
            );
        }
        assert_eq!(batch.class_counts, streamed.class_counts);
        assert_eq!(batch.feature_counts, streamed.feature_counts);
        assert_eq!(batch.feature_totals, streamed.feature_totals);
    }

    #[test]
    fn partial_fit_is_associative() {
        let data = bow_dataset();
        let mut one_shot = MultinomialNaiveBayes::new(NaiveBayesConfig::default());
        one_shot.partial_fit(&data);
        let (head, tail) = data.split_at(3);
        let mut split = MultinomialNaiveBayes::new(NaiveBayesConfig::default());
        split.partial_fit(head);
        split.partial_fit(tail);
        assert_eq!(one_shot.class_counts, split.class_counts);
        assert_eq!(one_shot.feature_counts, split.feature_counts);
    }

    #[test]
    fn partial_fit_extends_class_count() {
        let mut model = MultinomialNaiveBayes::new(NaiveBayesConfig::default());
        let first = TrainingExample {
            features: vec![1.0, 0.0],
            label: 0,
        };
        let second = TrainingExample {
            features: vec![0.0, 1.0],
            label: 2,
        };
        model.partial_fit(std::slice::from_ref(&first));
        model.partial_fit(std::slice::from_ref(&second));
        assert_eq!(model.num_classes(), 3);
        // Class 1 was never seen — counts stay zero.
        assert_eq!(model.class_counts[1], 0);
    }

    #[test]
    fn predict_proba_sums_to_one_and_has_correct_length() {
        let mut model = MultinomialNaiveBayes::new(NaiveBayesConfig::default());
        model.fit(&bow_dataset());
        let p = model.predict_proba(&[1.0, 0.0, 1.0]);
        assert_eq!(p.len(), 2);
        let sum: f32 = p.iter().sum();
        assert!((sum - 1.0).abs() < 1e-4, "{p:?}");
        assert!(p[0] > p[1], "cat-heavy doc should prefer class 0: {p:?}");
    }

    #[test]
    fn json_round_trips() {
        let data = bow_dataset();
        let mut original = MultinomialNaiveBayes::new(NaiveBayesConfig::default());
        original.fit(&data);
        let restored = MultinomialNaiveBayes::from_json(&original.to_json()).unwrap();
        for ex in &data {
            assert_eq!(
                original.predict(&ex.features),
                restored.predict(&ex.features)
            );
        }
    }
}