rs-ml 0.4.0

A simple machine-learning crate that includes a Gaussian Naive Bayes classifier.
Documentation
use ndarray::arr1;
use rs_ml::{
    classification::{naive_bayes::GaussianNBEstimator, ClassificationDataSet, Classifier},
    metrics::accuracy,
    train_test_split, Estimatable, Estimator,
};
use serde::Deserialize;

/// The four numeric measurements of a single iris flower.
///
/// Field names double as the CSV column names used during deserialization.
#[derive(Debug, Deserialize, Clone)]
struct Iris {
    sepal_length: f64,
    sepal_width: f64,
    petal_width: f64,
    petal_length: f64,
}

/// One CSV row: the flower's measurements plus its species label.
///
/// `#[serde(flatten)]` means the `Iris` fields are read from top-level columns
/// that sit alongside the `species` column, not from a nested structure.
#[derive(Debug, Deserialize)]
struct DataPoint {
    #[serde(flatten)]
    iris: Iris,
    species: String,
}

// Lets an `Iris` record be fed to rs-ml estimators as a numeric feature vector.
impl Estimatable for Iris {
    /// Returns the measurements as a length-4 vector converted to `F`, in the order
    /// `[sepal_length, sepal_width, petal_width, petal_length]`.
    ///
    /// # Panics
    ///
    /// Panics if a measurement cannot be converted to `F` (`NumCast::from` returns `None`).
    fn prepare_for_estimation<F: num_traits::Float>(&self) -> ndarray::Array1<F> {
        let measurements = [
            self.sepal_length,
            self.sepal_width,
            self.petal_width,
            self.petal_length,
        ];
        let converted: Vec<F> = measurements
            .iter()
            .map(|&value| F::from(value).unwrap())
            .collect();
        arr1(&converted)
    }
}

/// Trains a Gaussian Naive Bayes classifier on the Iris data and prints its test accuracy.
///
/// The CSV path is taken from the first command-line argument and defaults to `iris.csv`.
/// Rows are deserialized by column name into [`DataPoint`], so the file needs a header row.
///
/// # Errors
///
/// Returns an error if:
/// - the file cannot be opened;
/// - any row fails to deserialize;
/// - the file has no data rows;
/// - training, inference or the accuracy metric fails.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    let path = std::env::args()
        .nth(1)
        .unwrap_or_else(|| String::from("iris.csv"));
    let mut csv = csv::Reader::from_path(&path)?;

    // Propagate malformed rows instead of silently dropping them; otherwise a bad
    // file would quietly shrink the data set and skew the reported accuracy.
    let features: Vec<DataPoint> = csv
        .deserialize::<DataPoint>()
        .collect::<Result<_, _>>()?;

    if features.is_empty() {
        return Err(format!("no data rows found in {path}").into());
    }

    let dataset: ClassificationDataSet<_, _> = ClassificationDataSet::from_struct(
        features.iter(),
        |f| f.iris.clone(),
        |f| f.species.clone(),
    );
    // NOTE(review): 0.25 is presumably the fraction held out for testing — confirm
    // against `train_test_split` docs.
    let (train_dataset, test_dataset) = train_test_split(dataset, 0.25);

    let model = GaussianNBEstimator
        .fit(&train_dataset)
        .ok_or("Training failed")?;
    let inference = model
        .predict(test_dataset.get_features().into_iter().cloned())
        .ok_or("Inference failed")?;

    let labels: Vec<String> = test_dataset.get_labels().into_iter().cloned().collect();

    // Named distinctly so it does not shadow the imported `accuracy` metric function.
    let test_accuracy = accuracy(labels, inference).ok_or("Accuracy metric failed")?;

    println!("Test accuracy: {test_accuracy:.4}");

    Ok(())
}