rs_ml/classification/mod.rs
1//! Commonly used classification models.
2
3use ndarray::Array2;
4
5pub mod naive_bayes;
6
7/// Trait to interface with a fitted classification model
8pub trait Classifier<Features, Label>
9where
10 Label: Clone,
11{
12 /// Labels on which the model is fitted.
13 fn labels(&self) -> &[Label];
14
15 /// Estimates likelihood of each class per record. Rows correspond to each record, columns are
16 /// in the same order as label function.
17 fn predict_proba(&self, arr: &Features) -> Option<Array2<f64>>;
18
19 /// Provided function which returns the most likely class per record based on the results of
20 /// `predict_proba()`.
21 fn predict(&self, arr: &Features) -> Option<Vec<Label>> {
22 let l = self.labels();
23 let predictions = self.predict_proba(arr)?;
24
25 let a = predictions
26 .rows()
27 .into_iter()
28 .map(|a| {
29 a.iter().zip(l).fold((f64::MIN, l[0].clone()), |agg, curr| {
30 match &agg.0 < curr.0 {
31 true => (*curr.0, curr.1.clone()),
32 false => agg,
33 }
34 })
35 })
36 .map(|(_, l)| l);
37
38 Some(a.collect())
39 }
40}