rs_ml/classification/
mod.rs1use ndarray::Array2;
4
5pub mod naive_bayes;
6
/// A single labelled example: one feature value paired with its class label.
///
/// Both fields are public, so a record can be built with a struct literal or
/// converted from a `(features, label)` tuple.
#[derive(Debug, Clone, PartialEq)]
pub struct ClassificationRecord<Features, Label> {
    /// Input features describing this example.
    pub features: Features,
    /// Class label associated with `features`.
    pub label: Label,
}

/// An ordered collection of [`ClassificationRecord`]s.
///
/// `Clone` and `PartialEq` are available whenever both `Features` and `Label`
/// implement them, which lets callers duplicate a dataset or compare two
/// datasets directly.
#[derive(Debug, Clone, PartialEq)]
pub struct ClassificationDataSet<Features, Label> {
    /// The records, in the order they were supplied.
    pub dataset: Vec<ClassificationRecord<Features, Label>>,
}
22
23impl<Features, Label> From<(Features, Label)> for ClassificationRecord<Features, Label> {
24 fn from(value: (Features, Label)) -> Self {
25 ClassificationRecord {
26 features: value.0,
27 label: value.1,
28 }
29 }
30}
31
32impl<Itr, Record, Features, Label> From<Itr> for ClassificationDataSet<Features, Label>
33where
34 Itr: IntoIterator<Item = Record>,
35 Record: Into<ClassificationRecord<Features, Label>>,
36{
37 fn from(value: Itr) -> Self {
38 ClassificationDataSet {
39 dataset: value.into_iter().map(|record| record.into()).collect(),
40 }
41 }
42}
43
44impl<Features, Label> ClassificationDataSet<Features, Label> {
45 pub fn get_labels(&self) -> Vec<&Label> {
47 self.dataset.iter().map(|record| &record.label).collect()
48 }
49
50 pub fn get_features(&self) -> Vec<&Features> {
52 self.dataset.iter().map(|record| &record.features).collect()
53 }
54
55 pub fn get_records(&self) -> &Vec<ClassificationRecord<Features, Label>> {
57 &self.dataset
58 }
59
60 pub fn consume_records(self) -> Vec<ClassificationRecord<Features, Label>> {
62 self.dataset
63 }
64
65 pub fn from_struct<'a, I, S: 'a>(
67 it: I,
68 feature_extraction: fn(&S) -> Features,
69 label_extraction: fn(&S) -> Label,
70 ) -> Self
71 where
72 I: Iterator<Item = &'a S>,
73 {
74 let dataset: Vec<ClassificationRecord<Features, Label>> = it
75 .map(|record| (feature_extraction(record), label_extraction(record)))
76 .map(|row| row.into())
77 .collect();
78
79 ClassificationDataSet { dataset }
80 }
81}
82
83pub trait Classifier<Features, Label>
85where
86 Label: Clone,
87{
88 fn labels(&self) -> &[Label];
90
91 fn predict_proba<I>(&self, arr: I) -> Option<Array2<f64>>
94 where
95 I: Iterator<Item = Features>;
96
97 fn predict<I>(&self, arr: I) -> Option<Vec<Label>>
100 where
101 I: Iterator<Item = Features>,
102 {
103 let l = self.labels();
104 let predictions = self.predict_proba(arr)?;
105
106 let a = predictions
107 .rows()
108 .into_iter()
109 .map(|a| {
110 a.iter().zip(l).fold((f64::MIN, l[0].clone()), |agg, curr| {
111 match &agg.0 < curr.0 {
112 true => (*curr.0, curr.1.clone()),
113 false => agg,
114 }
115 })
116 })
117 .map(|(_, l)| l);
118
119 Some(a.collect())
120 }
121}