rs_ml/classification/
mod.rs

1//! Commonly used classification models.
2
3use ndarray::Array2;
4
5pub mod naive_bayes;
6
/// Single training record for a classification task: one set of features paired with its label.
#[derive(Debug, Clone, PartialEq)]
pub struct ClassificationRecord<Features, Label> {
    /// Features of a single classification record.
    pub features: Features,
    /// Label of a single classification record.
    pub label: Label,
}

/// Dataset to feed into a classification model for a training task.
///
/// `Clone` and `PartialEq` are available whenever `Features` and `Label` implement them.
#[derive(Debug, Clone, PartialEq)]
pub struct ClassificationDataSet<Features, Label> {
    /// Classification records on which to train, in insertion order.
    pub dataset: Vec<ClassificationRecord<Features, Label>>,
}
22
23impl<Features, Label> From<(Features, Label)> for ClassificationRecord<Features, Label> {
24    fn from(value: (Features, Label)) -> Self {
25        ClassificationRecord {
26            features: value.0,
27            label: value.1,
28        }
29    }
30}
31
32impl<Itr, Record, Features, Label> From<Itr> for ClassificationDataSet<Features, Label>
33where
34    Itr: IntoIterator<Item = Record>,
35    Record: Into<ClassificationRecord<Features, Label>>,
36{
37    fn from(value: Itr) -> Self {
38        ClassificationDataSet {
39            dataset: value.into_iter().map(|record| record.into()).collect(),
40        }
41    }
42}
43
44impl<Features, Label> ClassificationDataSet<Features, Label> {
45    /// get labels for record
46    pub fn get_labels(&self) -> Vec<&Label> {
47        self.dataset.iter().map(|record| &record.label).collect()
48    }
49
50    /// get features
51    pub fn get_features(&self) -> Vec<&Features> {
52        self.dataset.iter().map(|record| &record.features).collect()
53    }
54
55    /// get records
56    pub fn get_records(&self) -> &Vec<ClassificationRecord<Features, Label>> {
57        &self.dataset
58    }
59
60    /// consume records of dataset
61    pub fn consume_records(self) -> Vec<ClassificationRecord<Features, Label>> {
62        self.dataset
63    }
64
65    /// Create dataset from iterator of structs
66    pub fn from_struct<'a, I, S: 'a>(
67        it: I,
68        feature_extraction: fn(&S) -> Features,
69        label_extraction: fn(&S) -> Label,
70    ) -> Self
71    where
72        I: Iterator<Item = &'a S>,
73    {
74        let dataset: Vec<ClassificationRecord<Features, Label>> = it
75            .map(|record| (feature_extraction(record), label_extraction(record)))
76            .map(|row| row.into())
77            .collect();
78
79        ClassificationDataSet { dataset }
80    }
81}
82
83/// Trait to interface with a fitted classification model
84pub trait Classifier<Features, Label>
85where
86    Label: Clone,
87{
88    /// Labels on which the model is fitted.
89    fn labels(&self) -> &[Label];
90
91    /// Estimates likelihood of each class per record. Rows correspond to each record, columns are
92    /// in the same order as label function.
93    fn predict_proba<I>(&self, arr: I) -> Option<Array2<f64>>
94    where
95        I: Iterator<Item = Features>;
96
97    /// Provided function which returns the most likely class per record based on the results of
98    /// `predict_proba()`.
99    fn predict<I>(&self, arr: I) -> Option<Vec<Label>>
100    where
101        I: Iterator<Item = Features>,
102    {
103        let l = self.labels();
104        let predictions = self.predict_proba(arr)?;
105
106        let a = predictions
107            .rows()
108            .into_iter()
109            .map(|a| {
110                a.iter().zip(l).fold((f64::MIN, l[0].clone()), |agg, curr| {
111                    match &agg.0 < curr.0 {
112                        true => (*curr.0, curr.1.clone()),
113                        false => agg,
114                    }
115                })
116            })
117            .map(|(_, l)| l);
118
119        Some(a.collect())
120    }
121}