use ndarray::Array2;
pub mod naive_bayes;
/// A single labelled example: a `features` value paired with its `label`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClassificationRecord<Features, Label> {
    /// The input features of this example.
    pub features: Features,
    /// The label associated with `features`.
    pub label: Label,
}
/// An ordered collection of [`ClassificationRecord`]s.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClassificationDataSet<Features, Label> {
    /// The records, kept in the order they were supplied.
    pub dataset: Vec<ClassificationRecord<Features, Label>>,
}
impl<Features, Label> From<(Features, Label)> for ClassificationRecord<Features, Label> {
fn from(value: (Features, Label)) -> Self {
ClassificationRecord {
features: value.0,
label: value.1,
}
}
}
impl<Itr, Record, Features, Label> From<Itr> for ClassificationDataSet<Features, Label>
where
Itr: IntoIterator<Item = Record>,
Record: Into<ClassificationRecord<Features, Label>>,
{
fn from(value: Itr) -> Self {
ClassificationDataSet {
dataset: value.into_iter().map(|record| record.into()).collect(),
}
}
}
impl<Features, Label> ClassificationDataSet<Features, Label> {
pub fn get_labels(&self) -> Vec<&Label> {
self.dataset.iter().map(|record| &record.label).collect()
}
pub fn get_features(&self) -> Vec<&Features> {
self.dataset.iter().map(|record| &record.features).collect()
}
pub fn get_records(&self) -> &Vec<ClassificationRecord<Features, Label>> {
&self.dataset
}
pub fn consume_records(self) -> Vec<ClassificationRecord<Features, Label>> {
self.dataset
}
pub fn from_struct<'a, I, S: 'a>(
it: I,
feature_extraction: fn(&S) -> Features,
label_extraction: fn(&S) -> Label,
) -> Self
where
I: Iterator<Item = &'a S>,
{
let dataset: Vec<ClassificationRecord<Features, Label>> = it
.map(|record| (feature_extraction(record), label_extraction(record)))
.map(|row| row.into())
.collect();
ClassificationDataSet { dataset }
}
}
/// A model that scores inputs against a fixed list of labels and can pick the
/// best-scoring label for each input.
pub trait Classifier<Features, Label>
where
    Label: Clone,
{
    /// The labels this classifier can output.
    ///
    /// The default [`Classifier::predict`] pairs column `j` of the
    /// [`Classifier::predict_proba`] matrix with `labels()[j]`.
    fn labels(&self) -> &[Label];

    /// Scores every item of `arr`.
    ///
    /// The default [`Classifier::predict`] treats each row of the returned
    /// matrix as the scores of one input. Columns are read in the same order
    /// as [`Classifier::labels`]. A `None` result is propagated by `predict`.
    fn predict_proba<I>(&self, arr: I) -> Option<Array2<f64>>
    where
        I: Iterator<Item = Features>;

    /// Predicts one label per row of [`Classifier::predict_proba`] by arg-max.
    ///
    /// - Ties resolve to the lowest column index.
    /// - Columns beyond `labels().len()` are ignored.
    /// - A row with no score greater than `f64::MIN` (e.g. empty or all-NaN)
    ///   falls back to the first label.
    ///
    /// Returns `None` in two cases:
    /// - `predict_proba` returns `None`;
    /// - there is at least one row but `labels()` is empty.
    fn predict<I>(&self, arr: I) -> Option<Vec<Label>>
    where
        I: Iterator<Item = Features>,
    {
        let labels = self.labels();
        let probabilities = self.predict_proba(arr)?;
        probabilities
            .rows()
            .into_iter()
            .map(|row| {
                // Track (column, score) of the first strictly-largest score.
                // Starting from (0, f64::MIN) makes a row without any usable
                // score select column 0.
                let (best, _) = row.iter().take(labels.len()).enumerate().fold(
                    (0usize, f64::MIN),
                    |(best_idx, best_score), (idx, &score)| {
                        if best_score < score {
                            (idx, score)
                        } else {
                            (best_idx, best_score)
                        }
                    },
                );
                // `get` is `None` only when `labels` is empty. That makes the
                // whole call return `None` instead of panicking on an index.
                labels.get(best).cloned()
            })
            .collect()
    }
}