rs_ml/classification/
mod.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
use ndarray::Array2;

pub mod naive_bayes;

pub trait Classifier<Features, Label: Eq + Clone>
where
    Self: Sized,
{
    fn fit(arr: &Features, y: &[Label]) -> Option<Self>;
    fn labels(&self) -> &[Label];
    fn predict_proba(&self, arr: &Features) -> Option<Array2<f64>>;
    fn predict(&self, arr: &Features) -> Option<Vec<Label>> {
        let l = self.labels();
        let predictions = self.predict_proba(arr)?;

        let a = predictions
            .rows()
            .into_iter()
            .map(|a| {
                a.iter().zip(l).fold((f64::MIN, l[0].clone()), |agg, curr| {
                    match &agg.0 < curr.0 {
                        true => (*curr.0, curr.1.clone()),
                        false => agg,
                    }
                })
            })
            .map(|(_, l)| l);

        Some(a.collect())
    }
}