reddb_server/storage/ml/classifier/
mod.rs1pub mod features;
24pub mod logreg;
25pub mod naive_bayes;
26
27pub use features::{one_hot, tf_idf_vectorize, Vocabulary};
28pub use logreg::{LogisticRegression, LogisticRegressionConfig};
29pub use naive_bayes::{MultinomialNaiveBayes, NaiveBayesConfig};
30
31use crate::json::{Map, Value as JsonValue};
32
/// A single labelled sample: a dense feature vector and the class it belongs to.
#[derive(Clone, Debug)]
pub struct TrainingExample {
    /// Dense feature values handed to the classifier.
    pub features: Vec<f32>,
    /// Zero-based class id of this sample.
    pub label: u32,
}
39
40pub trait IncrementalClassifier {
42 fn fit(&mut self, examples: &[TrainingExample]);
44
45 fn partial_fit(&mut self, examples: &[TrainingExample]);
48
49 fn predict(&self, features: &[f32]) -> Option<u32>;
51
52 fn predict_proba(&self, features: &[f32]) -> Vec<f32>;
54
55 fn num_classes(&self) -> usize;
57
58 fn num_features(&self) -> usize;
61
62 fn samples_seen(&self) -> u64;
66}
67
/// Evaluation results for a classifier on a labelled dataset.
///
/// Every per-class vector is indexed by class id.
#[derive(Clone, Debug, PartialEq)]
pub struct ClassifierMetrics {
    /// Fraction of evaluated samples that were predicted correctly.
    pub accuracy: f32,
    /// Per-class precision: true positives / samples predicted as that class.
    pub per_class_precision: Vec<f32>,
    /// Per-class recall: true positives / samples whose label is that class.
    pub per_class_recall: Vec<f32>,
    /// Per-class harmonic mean of precision and recall.
    pub per_class_f1: Vec<f32>,
    /// Confusion counts: `confusion_matrix[actual][predicted]`.
    pub confusion_matrix: Vec<Vec<u32>>,
    /// Number of samples the metrics were computed over.
    pub samples_evaluated: u64,
}
79
80impl ClassifierMetrics {
81 pub fn macro_f1(&self) -> f32 {
82 if self.per_class_f1.is_empty() {
83 return 0.0;
84 }
85 self.per_class_f1.iter().sum::<f32>() / self.per_class_f1.len() as f32
86 }
87
88 pub fn to_json(&self) -> String {
89 let mut obj = Map::new();
90 obj.insert(
91 "accuracy".to_string(),
92 JsonValue::Number(self.accuracy as f64),
93 );
94 obj.insert(
95 "macro_f1".to_string(),
96 JsonValue::Number(self.macro_f1() as f64),
97 );
98 obj.insert(
99 "samples".to_string(),
100 JsonValue::Number(self.samples_evaluated as f64),
101 );
102 obj.insert(
103 "precision".to_string(),
104 JsonValue::Array(
105 self.per_class_precision
106 .iter()
107 .map(|f| JsonValue::Number(*f as f64))
108 .collect(),
109 ),
110 );
111 obj.insert(
112 "recall".to_string(),
113 JsonValue::Array(
114 self.per_class_recall
115 .iter()
116 .map(|f| JsonValue::Number(*f as f64))
117 .collect(),
118 ),
119 );
120 obj.insert(
121 "f1".to_string(),
122 JsonValue::Array(
123 self.per_class_f1
124 .iter()
125 .map(|f| JsonValue::Number(*f as f64))
126 .collect(),
127 ),
128 );
129 obj.insert(
130 "confusion_matrix".to_string(),
131 JsonValue::Array(
132 self.confusion_matrix
133 .iter()
134 .map(|row| {
135 JsonValue::Array(row.iter().map(|v| JsonValue::Number(*v as f64)).collect())
136 })
137 .collect(),
138 ),
139 );
140 JsonValue::Object(obj).to_string_compact()
141 }
142}
143
144pub fn evaluate<C: IncrementalClassifier>(
147 model: &C,
148 examples: &[TrainingExample],
149) -> ClassifierMetrics {
150 let k = model.num_classes().max(1);
151 let mut confusion = vec![vec![0u32; k]; k];
152 let mut correct = 0u32;
153 for ex in examples {
154 let predicted = model.predict(&ex.features).unwrap_or(0) as usize;
155 let actual = ex.label as usize;
156 if predicted < k && actual < k {
157 confusion[actual][predicted] += 1;
158 if predicted == actual {
159 correct += 1;
160 }
161 }
162 }
163 let total = examples.len() as u32;
164 let accuracy = if total == 0 {
165 0.0
166 } else {
167 correct as f32 / total as f32
168 };
169 let mut precision = vec![0.0f32; k];
170 let mut recall = vec![0.0f32; k];
171 let mut f1 = vec![0.0f32; k];
172 for c in 0..k {
173 let tp = confusion[c][c] as f32;
174 let pred_positive: u32 = (0..k).map(|r| confusion[r][c]).sum();
175 let actual_positive: u32 = confusion[c].iter().sum();
176 let p = if pred_positive == 0 {
177 0.0
178 } else {
179 tp / pred_positive as f32
180 };
181 let r = if actual_positive == 0 {
182 0.0
183 } else {
184 tp / actual_positive as f32
185 };
186 precision[c] = p;
187 recall[c] = r;
188 f1[c] = if p + r == 0.0 {
189 0.0
190 } else {
191 2.0 * p * r / (p + r)
192 };
193 }
194 ClassifierMetrics {
195 accuracy,
196 per_class_precision: precision,
197 per_class_recall: recall,
198 per_class_f1: f1,
199 confusion_matrix: confusion,
200 samples_evaluated: total as u64,
201 }
202}
203
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double whose prediction is the first feature rounded to the
    /// nearest integer and clamped at 0. Each example can therefore encode
    /// the class the "model" will output for it.
    struct DummyClassifier {
        classes: usize,
    }

    impl IncrementalClassifier for DummyClassifier {
        fn fit(&mut self, _: &[TrainingExample]) {}
        fn partial_fit(&mut self, _: &[TrainingExample]) {}
        fn predict(&self, features: &[f32]) -> Option<u32> {
            // A missing feature reads as 0.0, i.e. class 0.
            let raw = features.first().copied().unwrap_or(0.0);
            // Clamp negatives to 0; the float-to-int `as` cast saturates.
            Some(raw.round().max(0.0) as u32)
        }
        fn predict_proba(&self, _: &[f32]) -> Vec<f32> {
            // Uniform scores over all classes.
            vec![1.0 / self.classes as f32; self.classes]
        }
        fn num_classes(&self) -> usize {
            self.classes
        }
        fn num_features(&self) -> usize {
            1
        }
        fn samples_seen(&self) -> u64 {
            0
        }
    }

    #[test]
    fn evaluate_reports_perfect_scores_for_oracle_model() {
        let dummy = DummyClassifier { classes: 2 };
        let examples: Vec<_> = (0..10)
            .map(|i| TrainingExample {
                features: vec![(i % 2) as f32],
                label: (i % 2) as u32,
            })
            .collect();
        let m = evaluate(&dummy, &examples);
        assert!((m.accuracy - 1.0).abs() < 1e-6);
        assert!((m.macro_f1() - 1.0).abs() < 1e-6);
        assert_eq!(m.samples_evaluated, 10);
    }

    #[test]
    fn evaluate_on_empty_input_reports_zeroed_metrics() {
        let dummy = DummyClassifier { classes: 2 };
        let m = evaluate(&dummy, &[]);
        assert_eq!(m.accuracy, 0.0);
        assert_eq!(m.samples_evaluated, 0);
        assert_eq!(m.confusion_matrix, vec![vec![0, 0], vec![0, 0]]);
        assert_eq!(m.per_class_f1, vec![0.0, 0.0]);
        assert_eq!(m.macro_f1(), 0.0);
    }

    #[test]
    fn macro_f1_without_classes_is_zero() {
        let m = ClassifierMetrics {
            accuracy: 0.0,
            per_class_precision: vec![],
            per_class_recall: vec![],
            per_class_f1: vec![],
            confusion_matrix: vec![],
            samples_evaluated: 0,
        };
        assert_eq!(m.macro_f1(), 0.0);
    }

    #[test]
    fn metrics_json_contains_every_field() {
        let m = ClassifierMetrics {
            accuracy: 0.8,
            per_class_precision: vec![0.9, 0.7],
            per_class_recall: vec![0.8, 0.8],
            per_class_f1: vec![0.85, 0.74],
            confusion_matrix: vec![vec![8, 2], vec![2, 8]],
            samples_evaluated: 20,
        };
        let raw = m.to_json();
        let keys = [
            "accuracy",
            "macro_f1",
            "samples",
            "precision",
            "recall",
            "f1",
            "confusion_matrix",
        ];
        for key in keys.iter() {
            let quoted = format!("\"{}\"", key);
            assert!(raw.contains(&quoted), "missing {} in {}", quoted, raw);
        }
    }
}