rs_ml/classification/
mod.rs1use ndarray::Array2;
4
5pub mod naive_bayes;
6
7pub trait Classifier<Features, Label: Eq + Clone>
9where
10 Self: Sized,
11{
12 fn fit<I>(arr: &Features, y: I) -> Option<Self>
14 where
15 for<'a> &'a I: IntoIterator<Item = &'a Label>;
16
17 fn labels(&self) -> &[Label];
19
20 fn predict_proba(&self, arr: &Features) -> Option<Array2<f64>>;
23
24 fn predict(&self, arr: &Features) -> Option<Vec<Label>> {
26 let l = self.labels();
27 let predictions = self.predict_proba(arr)?;
28
29 let a = predictions
30 .rows()
31 .into_iter()
32 .map(|a| {
33 a.iter().zip(l).fold((f64::MIN, l[0].clone()), |agg, curr| {
34 match &agg.0 < curr.0 {
35 true => (*curr.0, curr.1.clone()),
36 false => agg,
37 }
38 })
39 })
40 .map(|(_, l)| l);
41
42 Some(a.collect())
43 }
44}