use crate::json::{Map, Value as JsonValue};

use super::{IncrementalClassifier, TrainingExample};

/// Hyper-parameters for [`MultinomialNaiveBayes`].
#[derive(Debug, Clone)]
pub struct NaiveBayesConfig {
    /// Additive (Laplace/Lidstone) smoothing added to every feature count
    /// when computing likelihoods; must be > 0 to avoid ln(0).
    pub alpha: f32,
}

impl Default for NaiveBayesConfig {
    fn default() -> Self {
        // alpha = 1.0 is classic Laplace smoothing.
        Self { alpha: 1.0 }
    }
}
29
/// Multinomial naive Bayes classifier over non-negative count features
/// (e.g. bag-of-words). Training only accumulates counts, so it supports
/// both batch (`fit`) and incremental (`partial_fit`) learning.
#[derive(Debug, Clone)]
pub struct MultinomialNaiveBayes {
    config: NaiveBayesConfig,
    // Number of training examples seen per class (indexed by label).
    class_counts: Vec<u64>,
    // Accumulated feature values: feature_counts[class][feature].
    feature_counts: Vec<Vec<f64>>,
    // Per-class sum of accumulated feature values (row sums of feature_counts).
    feature_totals: Vec<f64>,
    num_features: usize,
    num_classes: usize,
    // Total examples passed to fit/partial_fit (includes skipped ones).
    samples_seen: u64,
}
44
45impl MultinomialNaiveBayes {
46 pub fn new(config: NaiveBayesConfig) -> Self {
47 Self {
48 config,
49 class_counts: Vec::new(),
50 feature_counts: Vec::new(),
51 feature_totals: Vec::new(),
52 num_features: 0,
53 num_classes: 0,
54 samples_seen: 0,
55 }
56 }
57
58 fn ensure_shape(&mut self, num_features: usize, num_classes: usize) {
59 if self.num_features == 0 {
60 self.num_features = num_features;
61 }
62 if num_classes > self.num_classes {
63 self.class_counts.resize(num_classes, 0);
64 self.feature_counts
65 .resize(num_classes, vec![0.0; self.num_features]);
66 self.feature_totals.resize(num_classes, 0.0);
67 self.num_classes = num_classes;
68 }
69 }
70
71 fn accumulate(&mut self, ex: &TrainingExample) {
72 if ex.features.len() != self.num_features {
73 return;
74 }
75 let c = ex.label as usize;
76 self.class_counts[c] += 1;
77 let mut total = 0.0;
78 for (i, &v) in ex.features.iter().enumerate() {
79 if v < 0.0 {
80 continue; }
82 self.feature_counts[c][i] += v as f64;
83 total += v as f64;
84 }
85 self.feature_totals[c] += total;
86 }
87
88 pub fn to_json(&self) -> String {
89 let mut obj = Map::new();
90 obj.insert(
91 "alpha".to_string(),
92 JsonValue::Number(self.config.alpha as f64),
93 );
94 obj.insert(
95 "num_features".to_string(),
96 JsonValue::Number(self.num_features as f64),
97 );
98 obj.insert(
99 "num_classes".to_string(),
100 JsonValue::Number(self.num_classes as f64),
101 );
102 obj.insert(
103 "samples_seen".to_string(),
104 JsonValue::Number(self.samples_seen as f64),
105 );
106 obj.insert(
107 "class_counts".to_string(),
108 JsonValue::Array(
109 self.class_counts
110 .iter()
111 .map(|v| JsonValue::Number(*v as f64))
112 .collect(),
113 ),
114 );
115 obj.insert(
116 "feature_counts".to_string(),
117 JsonValue::Array(
118 self.feature_counts
119 .iter()
120 .map(|row| {
121 JsonValue::Array(row.iter().map(|v| JsonValue::Number(*v)).collect())
122 })
123 .collect(),
124 ),
125 );
126 obj.insert(
127 "feature_totals".to_string(),
128 JsonValue::Array(
129 self.feature_totals
130 .iter()
131 .map(|v| JsonValue::Number(*v))
132 .collect(),
133 ),
134 );
135 JsonValue::Object(obj).to_string_compact()
136 }
137
138 pub fn from_json(raw: &str) -> Option<Self> {
139 let parsed = crate::json::parse_json(raw).ok()?;
140 let value = JsonValue::from(parsed);
141 let obj = value.as_object()?;
142 let alpha = obj.get("alpha")?.as_f64()? as f32;
143 let num_features = obj.get("num_features")?.as_i64()? as usize;
144 let num_classes = obj.get("num_classes")?.as_i64()? as usize;
145 let samples_seen = obj.get("samples_seen")?.as_i64()? as u64;
146 let class_counts: Vec<u64> = obj
147 .get("class_counts")?
148 .as_array()?
149 .iter()
150 .filter_map(|v| v.as_i64().map(|i| i as u64))
151 .collect();
152 let feature_counts: Vec<Vec<f64>> = obj
153 .get("feature_counts")?
154 .as_array()?
155 .iter()
156 .filter_map(|row| {
157 row.as_array().map(|inner| {
158 inner
159 .iter()
160 .filter_map(|v| v.as_f64())
161 .collect::<Vec<f64>>()
162 })
163 })
164 .collect();
165 let feature_totals: Vec<f64> = obj
166 .get("feature_totals")?
167 .as_array()?
168 .iter()
169 .filter_map(|v| v.as_f64())
170 .collect();
171 Some(Self {
172 config: NaiveBayesConfig { alpha },
173 class_counts,
174 feature_counts,
175 feature_totals,
176 num_features,
177 num_classes,
178 samples_seen,
179 })
180 }
181}
182
183impl IncrementalClassifier for MultinomialNaiveBayes {
184 fn fit(&mut self, examples: &[TrainingExample]) {
185 if examples.is_empty() {
186 return;
187 }
188 let num_features = examples[0].features.len();
189 let num_classes = examples.iter().map(|e| e.label as usize).max().unwrap() + 1;
190 self.class_counts = vec![0; num_classes];
191 self.feature_counts = vec![vec![0.0; num_features]; num_classes];
192 self.feature_totals = vec![0.0; num_classes];
193 self.num_features = num_features;
194 self.num_classes = num_classes;
195 self.samples_seen = 0;
196 for ex in examples {
197 self.accumulate(ex);
198 }
199 self.samples_seen = examples.len() as u64;
200 }
201
202 fn partial_fit(&mut self, examples: &[TrainingExample]) {
203 if examples.is_empty() {
204 return;
205 }
206 let num_features = examples[0].features.len();
207 let num_classes = examples.iter().map(|e| e.label as usize).max().unwrap() + 1;
208 self.ensure_shape(num_features, num_classes);
209 for ex in examples {
210 self.accumulate(ex);
211 }
212 self.samples_seen = self.samples_seen.saturating_add(examples.len() as u64);
213 }
214
215 fn predict(&self, features: &[f32]) -> Option<u32> {
216 let probs = self.predict_proba(features);
217 if probs.is_empty() {
218 return None;
219 }
220 let mut best = 0usize;
221 let mut best_p = probs[0];
222 for (i, &p) in probs.iter().enumerate().skip(1) {
223 if p > best_p {
224 best_p = p;
225 best = i;
226 }
227 }
228 Some(best as u32)
229 }
230
231 fn predict_proba(&self, features: &[f32]) -> Vec<f32> {
232 if features.len() != self.num_features || self.num_classes == 0 {
233 return Vec::new();
234 }
235 let total_samples: u64 = self.class_counts.iter().sum();
236 if total_samples == 0 {
237 return vec![1.0 / self.num_classes as f32; self.num_classes];
238 }
239 let alpha = self.config.alpha as f64;
240 let mut log_scores = vec![0f64; self.num_classes];
241 for (c, log_score) in log_scores.iter_mut().enumerate().take(self.num_classes) {
242 let prior = (self.class_counts[c] as f64).max(f64::MIN_POSITIVE) / total_samples as f64;
243 let mut lp = prior.ln();
244 let denom = self.feature_totals[c] + alpha * self.num_features as f64;
245 for (i, &x) in features.iter().enumerate() {
246 if x <= 0.0 {
247 continue;
248 }
249 let numer = self.feature_counts[c][i] + alpha;
250 lp += (x as f64) * (numer / denom).ln();
251 }
252 *log_score = lp;
253 }
254 let max = log_scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
256 let mut probs = Vec::with_capacity(self.num_classes);
257 let mut sum = 0.0f64;
258 for lp in &log_scores {
259 let v = (lp - max).exp();
260 probs.push(v);
261 sum += v;
262 }
263 if sum > 0.0 {
264 for p in probs.iter_mut() {
265 *p /= sum;
266 }
267 }
268 probs.into_iter().map(|p| p as f32).collect()
269 }
270
271 fn num_classes(&self) -> usize {
272 self.num_classes
273 }
274
275 fn num_features(&self) -> usize {
276 self.num_features
277 }
278
279 fn samples_seen(&self) -> u64 {
280 self.samples_seen
281 }
282}
283
#[cfg(test)]
mod tests {
    use super::*;

    /// Tiny bag-of-words corpus: class 0 is dominated by feature 0,
    /// class 1 by feature 1; feature 2 appears in both classes.
    fn bow_dataset() -> Vec<TrainingExample> {
        vec![
            TrainingExample {
                features: vec![3.0, 0.0, 1.0],
                label: 0,
            },
            TrainingExample {
                features: vec![2.0, 0.0, 2.0],
                label: 0,
            },
            TrainingExample {
                features: vec![4.0, 0.0, 0.0],
                label: 0,
            },
            TrainingExample {
                features: vec![0.0, 3.0, 1.0],
                label: 1,
            },
            TrainingExample {
                features: vec![0.0, 4.0, 2.0],
                label: 1,
            },
            TrainingExample {
                features: vec![0.0, 2.0, 1.0],
                label: 1,
            },
        ]
    }

    #[test]
    fn fit_learns_bow_dataset() {
        let data = bow_dataset();
        let mut m = MultinomialNaiveBayes::new(NaiveBayesConfig::default());
        m.fit(&data);
        // The separable dataset should be classified perfectly.
        for ex in &data {
            assert_eq!(m.predict(&ex.features), Some(ex.label));
        }
    }

    #[test]
    fn partial_fit_equivalent_to_fit_on_full_set() {
        let data = bow_dataset();
        let mut full = MultinomialNaiveBayes::new(NaiveBayesConfig::default());
        full.fit(&data);
        let mut incremental = MultinomialNaiveBayes::new(NaiveBayesConfig::default());
        // Feed one example at a time; accumulated counts must match batch fit.
        for ex in &data {
            incremental.partial_fit(std::slice::from_ref(ex));
        }
        for ex in &data {
            assert_eq!(
                full.predict(&ex.features),
                incremental.predict(&ex.features)
            );
        }
        assert_eq!(full.class_counts, incremental.class_counts);
        assert_eq!(full.feature_counts, incremental.feature_counts);
        assert_eq!(full.feature_totals, incremental.feature_totals);
    }

    #[test]
    fn partial_fit_is_associative() {
        let data = bow_dataset();
        let mut one_shot = MultinomialNaiveBayes::new(NaiveBayesConfig::default());
        one_shot.partial_fit(&data);
        let mut split = MultinomialNaiveBayes::new(NaiveBayesConfig::default());
        split.partial_fit(&data[..3]);
        split.partial_fit(&data[3..]);
        assert_eq!(one_shot.class_counts, split.class_counts);
        assert_eq!(one_shot.feature_counts, split.feature_counts);
    }

    #[test]
    fn partial_fit_extends_class_count() {
        let mut m = MultinomialNaiveBayes::new(NaiveBayesConfig::default());
        m.partial_fit(&[TrainingExample {
            features: vec![1.0, 0.0],
            label: 0,
        }]);
        // A gap in labels (0 then 2) must grow the tables to 3 classes,
        // leaving the never-seen class 1 with a zero count.
        m.partial_fit(&[TrainingExample {
            features: vec![0.0, 1.0],
            label: 2,
        }]);
        assert_eq!(m.num_classes(), 3);
        assert_eq!(m.class_counts[1], 0);
    }

    #[test]
    fn predict_proba_sums_to_one_and_has_correct_length() {
        let data = bow_dataset();
        let mut m = MultinomialNaiveBayes::new(NaiveBayesConfig::default());
        m.fit(&data);
        // Slice literal instead of `&vec![...]` (clippy: useless_vec).
        let p = m.predict_proba(&[1.0, 0.0, 1.0]);
        assert_eq!(p.len(), 2);
        let sum: f32 = p.iter().sum();
        assert!((sum - 1.0).abs() < 1e-4, "{p:?}");
        assert!(p[0] > p[1], "cat-heavy doc should prefer class 0: {p:?}");
    }

    #[test]
    fn json_round_trips() {
        let data = bow_dataset();
        let mut m = MultinomialNaiveBayes::new(NaiveBayesConfig::default());
        m.fit(&data);
        let back = MultinomialNaiveBayes::from_json(&m.to_json()).unwrap();
        for ex in &data {
            assert_eq!(m.predict(&ex.features), back.predict(&ex.features));
        }
    }
}