rs_ml/classification/
mod.rs

1//! Commonly used classification models and trait to implement more.
2
3use ndarray::Array2;
4
5pub mod naive_bayes;
6
7/// Trait to make a classification model.
8pub trait Classifier<Features, Label>
9where
10    Label: Clone,
11{
12    /// Labels on which the model is fitted.
13    fn labels(&self) -> &[Label];
14
15    /// Predicts likelihood of each class per record. rows correspond to each record, columns are
16    /// in the same order as label function.
17    fn predict_proba(&self, arr: &Features) -> Option<Array2<f64>>;
18
19    /// Provided function which returns the most likely class per record.
20    fn predict(&self, arr: &Features) -> Option<Vec<Label>> {
21        let l = self.labels();
22        let predictions = self.predict_proba(arr)?;
23
24        let a = predictions
25            .rows()
26            .into_iter()
27            .map(|a| {
28                a.iter().zip(l).fold((f64::MIN, l[0].clone()), |agg, curr| {
29                    match &agg.0 < curr.0 {
30                        true => (*curr.0, curr.1.clone()),
31                        false => agg,
32                    }
33                })
34            })
35            .map(|(_, l)| l);
36
37        Some(a.collect())
38    }
39}