rs_ml/lib.rs
//! rs-ml is a simple ML framework for the Rust language. It includes train-test splitting,
//! scalers, and a Gaussian naive Bayes model. It also includes traits for adding more
//! transformers and models to the framework.
4#![deny(missing_docs)]
5
6use core::f64;
7
8use ndarray::{Array, Axis, Dimension, RemoveAxis};
9use rand::{rng, Rng};
10
11pub mod classification;
12pub mod metrics;
13pub mod regression;
14pub mod transformer;
15
16/// Trait for fitting classification and regression models, and transformers.
17///
18/// The struct on which this trait is implemented holds and validates the hyperparameters necessary
19/// to fit the estimator to the desired output. For example, a classification model may take as
20/// input a tuple with features and labels:
21/// ```
22/// use ndarray::{Array1, Array2};
23/// use rs_ml::Estimator;
24///
25/// struct ModelParameters {
26/// // Hyperparameters required to fit the model
27/// learning_rate: f64
28/// }
29///
30/// struct Model {
31/// // Internal state of model required to predict features
32/// means: Array2<f64>
33/// };
34///
35/// impl Estimator<(Array2<f64>, Array1<String>)> for ModelParameters {
36/// type Estimator = Model;
37///
38/// fn fit(&self, input: &(Array2<f64>, Array1<String>)) -> Option<Self::Estimator> {
39/// let (features, labels) = input;
40///
41/// // logic to fit the model
42/// Some(Model {
43/// means: Array2::zeros((1, 1))
44/// })
45/// }
46/// }
47/// ```
48pub trait Estimator<Input> {
49 /// Output model or transformer fitted to input data.
50 type Estimator;
51
52 /// Fit model or transformer based on given inputs, or None if the estimator was not able to
53 /// fit to the input data as expected.
54 fn fit(&self, input: &Input) -> Option<Self::Estimator>;
55}
56
57/// Split data and features into training and testing set. `test_size` must be between 0 and 1.
58/// Panics if `test_size` is outside 0 and 1.
59pub fn train_test_split<
60 D: Dimension + RemoveAxis,
61 D2: Dimension + RemoveAxis,
62 Feature: Clone,
63 Label: Clone,
64>(
65 arr: &Array<Feature, D>,
66 y: &Array<Label, D2>,
67 test_size: f64,
68) -> (
69 Array<Feature, D>,
70 Array<Feature, D>,
71 Array<Label, D2>,
72 Array<Label, D2>,
73) {
74 let rows = arr.shape()[0];
75
76 let (test, train): (Vec<usize>, Vec<usize>) =
77 (0..rows).partition(|_| rng().random_bool(test_size));
78
79 (
80 arr.select(Axis(0), &train),
81 arr.select(Axis(0), &test),
82 y.select(Axis(0), &train),
83 y.select(Axis(0), &test),
84 )
85}