rs_ml/
lib.rs

//! rs-ml is a simple ML framework for the Rust language. It includes train-test splitting,
//! scalers, and a Gaussian naive Bayes model. It also includes traits for adding more transformers
//! and models to the framework.
4#![deny(
5    missing_docs,
6    unsafe_code,
7    missing_debug_implementations,
8    missing_copy_implementations,
9    clippy::missing_panics_doc
10)]
11
12use ndarray::{Array, Axis, Dimension, RemoveAxis};
13use rand::{rng, Rng};
14
15pub mod classification;
16pub mod metrics;
17pub mod regression;
18pub mod transformer;
19
/// Trait for fitting classification and regression models, and transformers.
///
/// The type on which this trait is implemented holds (and validates) the hyperparameters needed
/// to fit an estimator. `Input` is whatever the estimator is fitted on. For example, a
/// classification model may take a tuple of features and labels, while a transformer may take
/// only features.
///
/// # Examples
///
/// ```
/// use ndarray::{Array1, Array2};
/// use rs_ml::Estimator;
///
/// struct ModelParameters {
///     // Hyperparameters required to fit the model
///     learning_rate: f64,
/// }
///
/// struct Model {
///     // Internal state of the model required to predict features
///     means: Array2<f64>,
/// }
///
/// impl Estimator<(Array2<f64>, Array1<String>)> for ModelParameters {
///     type Estimator = Model;
///
///     fn fit(&self, input: &(Array2<f64>, Array1<String>)) -> Option<Self::Estimator> {
///         let (_features, _labels) = input;
///
///         // logic to fit the model
///         Some(Model {
///             means: Array2::zeros((1, 1)),
///         })
///     }
/// }
/// ```
pub trait Estimator<Input> {
    /// Model or transformer produced by fitting to the input data.
    type Estimator;

    /// Fit a model or transformer to `input`.
    ///
    /// Returns `None` if the estimator could not be fitted to the input data as expected.
    fn fit(&self, input: &Input) -> Option<Self::Estimator>;
}
60
61/// Train test split result. returns in order training features, testing features, training labels,
62/// testing labels.
63#[derive(Debug, Clone)]
64pub struct TrainTestSplitResult<Feature, Label, D: Dimension, D2: Dimension>(
65    pub Array<Feature, D>,
66    pub Array<Feature, D>,
67    pub Array<Label, D2>,
68    pub Array<Label, D2>,
69);
70
71/// Split data and features into training and testing set. `test_size` must be between 0 and 1.
72///
73/// # Panics
74///
75/// Panics if `test_size` is outside range 0..=1.
76///
77/// Example:
78/// ```
79/// use rs_ml::{train_test_split, TrainTestSplitResult};
80/// use ndarray::{arr1, arr2};
81///
82/// let features = arr2(&[
83///   [1., 0.],
84///   [0., 1.],
85///   [0., 0.],
86///   [1., 1.]]);
87///
88/// let labels = arr1(&[1, 1, 0, 0]);
89///
90/// let TrainTestSplitResult(train_features, test_features, train_labels, test_labels) = train_test_split(&features,
91/// &labels, 0.25);
92/// ```
93pub fn train_test_split<
94    D: Dimension + RemoveAxis,
95    D2: Dimension + RemoveAxis,
96    Feature: Clone,
97    Label: Clone,
98>(
99    arr: &Array<Feature, D>,
100    y: &Array<Label, D2>,
101    test_size: f64,
102) -> TrainTestSplitResult<Feature, Label, D, D2> {
103    let rows = arr.shape()[0];
104
105    let (test, train): (Vec<usize>, Vec<usize>) =
106        (0..rows).partition(|_| rng().random_bool(test_size));
107
108    TrainTestSplitResult(
109        arr.select(Axis(0), &train),
110        arr.select(Axis(0), &test),
111        y.select(Axis(0), &train),
112        y.select(Axis(0), &test),
113    )
114}