rs_ml/classification/
mod.rs1use ndarray::Array2;
4
5pub mod naive_bayes;
6
7pub trait Classifier<Features, Label>
9where
10 Label: Clone,
11{
12 fn labels(&self) -> &[Label];
14
15 fn predict_proba(&self, arr: &Features) -> Option<Array2<f64>>;
18
19 fn predict(&self, arr: &Features) -> Option<Vec<Label>> {
21 let l = self.labels();
22 let predictions = self.predict_proba(arr)?;
23
24 let a = predictions
25 .rows()
26 .into_iter()
27 .map(|a| {
28 a.iter().zip(l).fold((f64::MIN, l[0].clone()), |agg, curr| {
29 match &agg.0 < curr.0 {
30 true => (*curr.0, curr.1.clone()),
31 false => agg,
32 }
33 })
34 })
35 .map(|(_, l)| l);
36
37 Some(a.collect())
38 }
39}