rs_ml/classification/
mod.rs1use ndarray::Array2;
4
5pub mod naive_bayes;
6
/// A single labelled example: one feature value paired with the label
/// assigned to it.
///
/// Both type parameters are unconstrained here; any bounds are placed on the
/// `impl` blocks and traits that need them.
#[derive(Debug)]
pub struct ClassificationRecord<Features, Label> {
    /// The feature value describing this example.
    pub features: Features,
    /// The label associated with `features`.
    pub label: Label,
}
15
16#[derive(Debug)]
18pub struct ClassificationDataSet<Features, Label> {
19 pub dataset: Vec<ClassificationRecord<Features, Label>>,
21}
22
23impl<Features, Label> From<(Features, Label)> for ClassificationRecord<Features, Label> {
24 fn from(value: (Features, Label)) -> Self {
25 ClassificationRecord {
26 features: value.0,
27 label: value.1,
28 }
29 }
30}
31
32impl<Features, Label> From<Vec<ClassificationRecord<Features, Label>>>
33 for ClassificationDataSet<Features, Label>
34{
35 fn from(value: Vec<ClassificationRecord<Features, Label>>) -> Self {
36 ClassificationDataSet { dataset: value }
37 }
38}
39
40impl<Features, Label> From<(Vec<Features>, Vec<Label>)> for ClassificationDataSet<Features, Label> {
41 fn from((train, test): (Vec<Features>, Vec<Label>)) -> Self {
42 ClassificationDataSet {
43 dataset: train.into_iter().zip(test).map(|r| r.into()).collect(),
44 }
45 }
46}
47
48impl<Features, Label> ClassificationDataSet<Features, Label> {
49 pub fn get_labels(&self) -> Vec<&Label> {
51 self.dataset.iter().map(|record| &record.label).collect()
52 }
53
54 pub fn get_features(&self) -> Vec<&Features> {
56 self.dataset.iter().map(|record| &record.features).collect()
57 }
58
59 pub fn get_records(&self) -> &Vec<ClassificationRecord<Features, Label>> {
61 &self.dataset
62 }
63
64 pub fn consume_records(self) -> Vec<ClassificationRecord<Features, Label>> {
66 self.dataset
67 }
68
69 pub fn from_struct<'a, I, S: 'a>(
71 it: I,
72 feature_extraction: fn(&S) -> Features,
73 label_extraction: fn(&S) -> Label,
74 ) -> Self
75 where
76 I: Iterator<Item = &'a S>,
77 {
78 let dataset: Vec<ClassificationRecord<Features, Label>> = it
79 .map(|record| (feature_extraction(record), label_extraction(record)))
80 .map(|row| row.into())
81 .collect();
82
83 ClassificationDataSet { dataset }
84 }
85}
86
87pub trait Classifier<Features, Label>
89where
90 Label: Clone,
91{
92 fn labels(&self) -> &[Label];
94
95 fn predict_proba<I>(&self, arr: I) -> Option<Array2<f64>>
98 where
99 I: Iterator<Item = Features>;
100
101 fn predict<I>(&self, arr: I) -> Option<Vec<Label>>
104 where
105 I: Iterator<Item = Features>,
106 {
107 let l = self.labels();
108 let predictions = self.predict_proba(arr)?;
109
110 let a = predictions
111 .rows()
112 .into_iter()
113 .map(|a| {
114 a.iter().zip(l).fold((f64::MIN, l[0].clone()), |agg, curr| {
115 match &agg.0 < curr.0 {
116 true => (*curr.0, curr.1.clone()),
117 false => agg,
118 }
119 })
120 })
121 .map(|(_, l)| l);
122
123 Some(a.collect())
124 }
125}