rs_ml/
lib.rs

//! rs-ml is a simple ML framework for the Rust language. It includes train-test splitting,
//! scalers, and a Gaussian naive Bayes model. It also includes traits for adding more
//! transformers and models to the framework.
4//!
5//! # Usage
6//!
//! This library requires a compute backend to perform matrix operations. Compute backends are
//! selected through the crate's feature flags. Refer to the
//! [ndarray_linalg](https://github.com/rust-ndarray/ndarray-linalg?tab=readme-ov-file#backend-features)
//! docs for more information.
11#![deny(
12    missing_docs,
13    unsafe_code,
14    missing_debug_implementations,
15    missing_copy_implementations,
16    clippy::missing_panics_doc
17)]
18
19use classification::ClassificationDataSet;
20use ndarray::Axis;
21use num_traits::Float;
22
23pub mod classification;
24pub mod metrics;
25pub mod regression;
26pub mod transformer;
27
/// Trait for fitting classification and regression models, and transformers.
///
/// The struct on which this trait is implemented holds and validates the hyperparameters
/// necessary to fit the estimator to the desired output. `Input` is the data the estimator is
/// fitted to. For example, a classification model may take as input a tuple of features and
/// labels:
///
/// ```
/// use ndarray::{Array1, Array2};
/// use rs_ml::Estimator;
///
/// struct ModelParameters {
///     // Hyperparameters required to fit the model
///     learning_rate: f64,
/// }
///
/// struct Model {
///     // Internal state of the model required to predict features
///     means: Array2<f64>,
/// }
///
/// impl Estimator<(Array2<f64>, Array1<String>)> for ModelParameters {
///     type Estimator = Model;
///
///     fn fit(&self, input: &(Array2<f64>, Array1<String>)) -> Option<Self::Estimator> {
///         let (_features, _labels) = input;
///
///         // Logic to fit the model goes here.
///         Some(Model {
///             means: Array2::zeros((1, 1)),
///         })
///     }
/// }
/// ```
pub trait Estimator<Input> {
    /// Output model or transformer fitted to the input data.
    type Estimator;

    /// Fits a model or transformer to `input`.
    ///
    /// Returns `None` if the estimator was not able to fit to the input data as expected.
    fn fit(&self, input: &Input) -> Option<Self::Estimator>;
}
68
/// Result of a train/test split.
///
/// The fields hold, in order: the training features, the testing features, the training labels,
/// and the testing labels.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SplitDataset<Feature, Label>(
    pub Vec<Feature>,
    pub Vec<Feature>,
    pub Vec<Label>,
    pub Vec<Label>,
);
78
79/// Split data and features into training and testing set. `test_size` must be between 0 and 1.
80///
81/// # Panics
82///
83/// Panics if `test_size` is outside range 0..=1.
84///
85/// Example:
86/// ```
87/// use rs_ml::{train_test_split};
88/// use rs_ml::classification::ClassificationDataSet;
89/// use ndarray::{arr1, arr2};
90///
91/// let features = arr2(&[
92///   [1., 0.],
93///   [0., 1.],
94///   [0., 0.],
95///   [1., 1.]]);
96///
97/// let labels = vec![1, 1, 0, 0];
98///
99/// let dataset = ClassificationDataSet::from((features.rows().into_iter().collect(), labels));
100///
101/// let (train, test) = train_test_split(dataset, 0.25);
102/// ```
103pub fn train_test_split<Feature, Label>(
104    dataset: ClassificationDataSet<Feature, Label>,
105    test_size: f64,
106) -> (
107    ClassificationDataSet<Feature, Label>,
108    ClassificationDataSet<Feature, Label>,
109) {
110    let (train, test): (Vec<_>, Vec<_>) = dataset
111        .consume_records()
112        .into_iter()
113        .partition(|_| rand::random_bool(test_size));
114
115    (
116        ClassificationDataSet::from(train),
117        ClassificationDataSet::from(test),
118    )
119}
120
121fn iterative_mean<I, F>(it: I) -> Option<F>
122where
123    I: Iterator<Item = F>,
124    F: Float,
125{
126    it.into_iter().enumerate().fold(None, |acc, (i, curr)| {
127        let idx: F = F::from(i)?;
128        let idx_inc_1: F = F::from(i + 1)?;
129
130        Some((idx / idx_inc_1) * acc.unwrap_or(F::zero()) + curr / idx_inc_1)
131    })
132}