rs_ml/
lib.rs

//! rs-ml is a simple ML framework for the Rust language. It includes train/test splitting,
//! scalers, and a Gaussian naive Bayes model. It also includes traits for adding more
//! transformers and models to the framework.
4#![deny(missing_docs)]
5
6use core::f64;
7
8use ndarray::{Array, Axis, Dimension, RemoveAxis};
9use rand::{rng, Rng};
10
11pub mod classification;
12pub mod metrics;
13pub mod regression;
14pub mod transformer;
15
/// Common interface for fitting classification models, regression models, and transformers.
///
/// Implement this trait on the type that holds (and validates) the hyperparameters. Calling
/// [`Estimator::fit`] on that type with the training input yields the fitted model or
/// transformer, exposed through the associated type [`Estimator::Estimator`].
///
/// # Examples
///
/// A classification model may take a tuple of features and labels as its input:
///
/// ```
/// use ndarray::{Array1, Array2};
/// use rs_ml::Estimator;
///
/// struct ModelParameters {
///     // Hyperparameters required to fit the model
///     learning_rate: f64,
/// }
///
/// struct Model {
///     // Internal state of the model required to predict from features
///     means: Array2<f64>,
/// }
///
/// impl Estimator<(Array2<f64>, Array1<String>)> for ModelParameters {
///     type Estimator = Model;
///
///     fn fit(&self, input: &(Array2<f64>, Array1<String>)) -> Option<Self::Estimator> {
///         let (features, labels) = input;
///
///         // logic to fit the model
///         Some(Model {
///             means: Array2::zeros((1, 1)),
///         })
///     }
/// }
/// ```
pub trait Estimator<Input> {
    /// The fitted model or transformer produced by a successful call to [`Estimator::fit`].
    type Estimator;

    /// Fits a model or transformer to `input`.
    ///
    /// Returns `None` when the estimator could not be fitted to the input data as expected.
    fn fit(&self, input: &Input) -> Option<Self::Estimator>;
}
56
57/// Split data and features into training and testing set. `test_size` must be between 0 and 1.
58/// Panics if `test_size` is outside 0 and 1.
59pub fn train_test_split<
60    D: Dimension + RemoveAxis,
61    D2: Dimension + RemoveAxis,
62    Feature: Clone,
63    Label: Clone,
64>(
65    arr: &Array<Feature, D>,
66    y: &Array<Label, D2>,
67    test_size: f64,
68) -> (
69    Array<Feature, D>,
70    Array<Feature, D>,
71    Array<Label, D2>,
72    Array<Label, D2>,
73) {
74    let rows = arr.shape()[0];
75
76    let (test, train): (Vec<usize>, Vec<usize>) =
77        (0..rows).partition(|_| rng().random_bool(test_size));
78
79    (
80        arr.select(Axis(0), &train),
81        arr.select(Axis(0), &test),
82        y.select(Axis(0), &train),
83        y.select(Axis(0), &test),
84    )
85}