// rs_ml/classification/mod.rs

//! Commonly used classification models.
2
3use ndarray::Array2;
4
5pub mod naive_bayes;
6
7/// Trait to interface with a fitted classification model
8pub trait Classifier<Features, Label>
9where
10    Label: Clone,
11{
12    /// Labels on which the model is fitted.
13    fn labels(&self) -> &[Label];
14
15    /// Estimates likelihood of each class per record. Rows correspond to each record, columns are
16    /// in the same order as label function.
17    fn predict_proba(&self, arr: &Features) -> Option<Array2<f64>>;
18
19    /// Provided function which returns the most likely class per record based on the results of
20    /// `predict_proba()`.
21    fn predict(&self, arr: &Features) -> Option<Vec<Label>> {
22        let l = self.labels();
23        let predictions = self.predict_proba(arr)?;
24
25        let a = predictions
26            .rows()
27            .into_iter()
28            .map(|a| {
29                a.iter().zip(l).fold((f64::MIN, l[0].clone()), |agg, curr| {
30                    match &agg.0 < curr.0 {
31                        true => (*curr.0, curr.1.clone()),
32                        false => agg,
33                    }
34                })
35            })
36            .map(|(_, l)| l);
37
38        Some(a.collect())
39    }
40}