// rs_ml/classification/mod.rs

//! Commonly used classification models and a trait for implementing more.
2
3use ndarray::Array2;
4
5pub mod naive_bayes;
6
7/// Trait to make a classification model.
8pub trait Classifier<Features, Label: Eq + Clone>
9where
10    Self: Sized,
11{
12    /// Fit data based on given input features and labels.
13    fn fit<I>(arr: &Features, y: I) -> Option<Self>
14    where
15        for<'a> &'a I: IntoIterator<Item = &'a Label>;
16
17    /// Labels on which the model is fitted.
18    fn labels(&self) -> &[Label];
19
20    /// Predicts likelihood of each class per record. rows correspond to each record, columns are
21    /// in the same order as label function.
22    fn predict_proba(&self, arr: &Features) -> Option<Array2<f64>>;
23
24    /// Provided function which returns the most likely class per record.
25    fn predict(&self, arr: &Features) -> Option<Vec<Label>> {
26        let l = self.labels();
27        let predictions = self.predict_proba(arr)?;
28
29        let a = predictions
30            .rows()
31            .into_iter()
32            .map(|a| {
33                a.iter().zip(l).fold((f64::MIN, l[0].clone()), |agg, curr| {
34                    match &agg.0 < curr.0 {
35                        true => (*curr.0, curr.1.clone()),
36                        false => agg,
37                    }
38                })
39            })
40            .map(|(_, l)| l);
41
42        Some(a.collect())
43    }
44}