rs_ml/lib.rs
//! rs-ml is a simple ML framework for the Rust language. It includes train-test splitting,
//! scalers, and a Gaussian naive Bayes model. It also includes traits for adding more
//! transformers and models to the framework.
4#![deny(
5 missing_docs,
6 unsafe_code,
7 missing_debug_implementations,
8 missing_copy_implementations,
9 clippy::missing_panics_doc
10)]
11
12use ndarray::{Array, Axis, Dimension, RemoveAxis};
13use rand::{rng, Rng};
14
15pub mod classification;
16pub mod metrics;
17pub mod regression;
18pub mod transformer;
19
/// Common interface for anything that can be fitted to data: classifiers, regressors and
/// transformers.
///
/// The implementing type carries (and validates) the hyperparameters; calling [`Estimator::fit`]
/// on it produces the fitted artifact, whose type is given by the associated
/// [`Estimator::Estimator`] type. The `Input` type parameter describes the training data, e.g. a
/// tuple of features and labels for a classification model:
///
/// ```
/// use ndarray::{Array1, Array2};
/// use rs_ml::Estimator;
///
/// struct ModelParameters {
///     // Hyperparameters required to fit the model
///     learning_rate: f64,
/// }
///
/// struct Model {
///     // Internal state of model required to predict features
///     means: Array2<f64>,
/// }
///
/// impl Estimator<(Array2<f64>, Array1<String>)> for ModelParameters {
///     type Estimator = Model;
///
///     fn fit(&self, input: &(Array2<f64>, Array1<String>)) -> Option<Self::Estimator> {
///         let (features, labels) = input;
///
///         // logic to fit the model
///         Some(Model {
///             means: Array2::zeros((1, 1)),
///         })
///     }
/// }
/// ```
pub trait Estimator<Input> {
    /// The fitted model or transformer produced by a successful call to [`Estimator::fit`].
    type Estimator;

    /// Fits to `input` using the hyperparameters held by `self`.
    ///
    /// Returns `Some` with the fitted model or transformer, or `None` when the estimator could
    /// not be fitted to the given input.
    fn fit(&self, input: &Input) -> Option<Self::Estimator>;
}
60
61/// Train test split result. returns in order training features, testing features, training labels,
62/// testing labels.
63#[derive(Debug, Clone)]
64pub struct TrainTestSplitResult<Feature, Label, D: Dimension, D2: Dimension>(
65 pub Array<Feature, D>,
66 pub Array<Feature, D>,
67 pub Array<Label, D2>,
68 pub Array<Label, D2>,
69);
70
71/// Split data and features into training and testing set. `test_size` must be between 0 and 1.
72///
73/// # Panics
74///
75/// Panics if `test_size` is outside range 0..=1.
76///
77/// Example:
78/// ```
79/// use rs_ml::{train_test_split, TrainTestSplitResult};
80/// use ndarray::{arr1, arr2};
81///
82/// let features = arr2(&[
83/// [1., 0.],
84/// [0., 1.],
85/// [0., 0.],
86/// [1., 1.]]);
87///
88/// let labels = arr1(&[1, 1, 0, 0]);
89///
90/// let TrainTestSplitResult(train_features, test_features, train_labels, test_labels) = train_test_split(&features,
91/// &labels, 0.25);
92/// ```
93pub fn train_test_split<
94 D: Dimension + RemoveAxis,
95 D2: Dimension + RemoveAxis,
96 Feature: Clone,
97 Label: Clone,
98>(
99 arr: &Array<Feature, D>,
100 y: &Array<Label, D2>,
101 test_size: f64,
102) -> TrainTestSplitResult<Feature, Label, D, D2> {
103 let rows = arr.shape()[0];
104
105 let (test, train): (Vec<usize>, Vec<usize>) =
106 (0..rows).partition(|_| rng().random_bool(test_size));
107
108 TrainTestSplitResult(
109 arr.select(Axis(0), &train),
110 arr.select(Axis(0), &test),
111 y.select(Axis(0), &train),
112 y.select(Axis(0), &test),
113 )
114}