rs_ml/classification/
mod.rs

1//! Commonly used classification models.
2
3use ndarray::Array2;
4
5pub mod naive_bayes;
6
/// Single training record for a classification task.
#[derive(Debug, Clone, PartialEq)]
pub struct ClassificationRecord<Features, Label> {
    /// Features of a single classification record.
    pub features: Features,
    /// Label of a single classification record.
    pub label: Label,
}

/// Dataset to feed into a classification model for a training task.
#[derive(Debug, Clone, PartialEq)]
pub struct ClassificationDataSet<Features, Label> {
    /// Classification records on which to train.
    pub dataset: Vec<ClassificationRecord<Features, Label>>,
}
22
23impl<Features, Label> From<(Features, Label)> for ClassificationRecord<Features, Label> {
24    fn from(value: (Features, Label)) -> Self {
25        ClassificationRecord {
26            features: value.0,
27            label: value.1,
28        }
29    }
30}
31
32impl<Features, Label> From<Vec<ClassificationRecord<Features, Label>>>
33    for ClassificationDataSet<Features, Label>
34{
35    fn from(value: Vec<ClassificationRecord<Features, Label>>) -> Self {
36        ClassificationDataSet { dataset: value }
37    }
38}
39
40impl<Features, Label> From<(Vec<Features>, Vec<Label>)> for ClassificationDataSet<Features, Label> {
41    fn from((train, test): (Vec<Features>, Vec<Label>)) -> Self {
42        ClassificationDataSet {
43            dataset: train.into_iter().zip(test).map(|r| r.into()).collect(),
44        }
45    }
46}
47
48impl<Features, Label> ClassificationDataSet<Features, Label> {
49    /// get labels for record
50    pub fn get_labels(&self) -> Vec<&Label> {
51        self.dataset.iter().map(|record| &record.label).collect()
52    }
53
54    /// get features
55    pub fn get_features(&self) -> Vec<&Features> {
56        self.dataset.iter().map(|record| &record.features).collect()
57    }
58
59    /// get records
60    pub fn get_records(&self) -> &Vec<ClassificationRecord<Features, Label>> {
61        &self.dataset
62    }
63
64    /// consume records of dataset
65    pub fn consume_records(self) -> Vec<ClassificationRecord<Features, Label>> {
66        self.dataset
67    }
68
69    /// Create dataset from iterator of structs
70    pub fn from_struct<'a, I, S: 'a>(
71        it: I,
72        feature_extraction: fn(&S) -> Features,
73        label_extraction: fn(&S) -> Label,
74    ) -> Self
75    where
76        I: Iterator<Item = &'a S>,
77    {
78        let dataset: Vec<ClassificationRecord<Features, Label>> = it
79            .map(|record| (feature_extraction(record), label_extraction(record)))
80            .map(|row| row.into())
81            .collect();
82
83        ClassificationDataSet { dataset }
84    }
85}
86
87/// Trait to interface with a fitted classification model
88pub trait Classifier<Features, Label>
89where
90    Label: Clone,
91{
92    /// Labels on which the model is fitted.
93    fn labels(&self) -> &[Label];
94
95    /// Estimates likelihood of each class per record. Rows correspond to each record, columns are
96    /// in the same order as label function.
97    fn predict_proba<I>(&self, arr: I) -> Option<Array2<f64>>
98    where
99        I: Iterator<Item = Features>;
100
101    /// Provided function which returns the most likely class per record based on the results of
102    /// `predict_proba()`.
103    fn predict<I>(&self, arr: I) -> Option<Vec<Label>>
104    where
105        I: Iterator<Item = Features>,
106    {
107        let l = self.labels();
108        let predictions = self.predict_proba(arr)?;
109
110        let a = predictions
111            .rows()
112            .into_iter()
113            .map(|a| {
114                a.iter().zip(l).fold((f64::MIN, l[0].clone()), |agg, curr| {
115                    match &agg.0 < curr.0 {
116                        true => (*curr.0, curr.1.clone()),
117                        false => agg,
118                    }
119                })
120            })
121            .map(|(_, l)| l);
122
123        Some(a.collect())
124    }
125}