rs_ml/
lib.rs

//! rs-ml is a simple ML framework for the Rust language. It includes train-test splitting,
//! scalers, and a Gaussian naive Bayes model. It also includes traits for adding more
//! transformers and models to the framework.
4#![deny(missing_docs)]
5
6use core::f64;
7
8use classification::Classifier;
9use ndarray::{Array, Axis, Dimension, RemoveAxis};
10use rand::{rng, Rng};
11
12pub mod classification;
13pub mod metrics;
14pub mod transformer;
15
16/// Split data and features into training and testing set. `test_size` must be between 0 and 1.
17/// panics if `test_size` is outside 0 and 1.
18pub fn train_test_split<
19    D: Dimension + RemoveAxis,
20    D2: Dimension + RemoveAxis,
21    Feature: Clone,
22    Label: Clone,
23>(
24    arr: &Array<Feature, D>,
25    y: &Array<Label, D2>,
26    test_size: f64,
27) -> (
28    Array<Feature, D>,
29    Array<Feature, D>,
30    Array<Label, D2>,
31    Array<Label, D2>,
32) {
33    let rows = arr.shape()[0];
34
35    let (test, train): (Vec<usize>, Vec<usize>) =
36        (0..rows).partition(|_| rng().random_bool(test_size));
37
38    (
39        arr.select(Axis(0), &train),
40        arr.select(Axis(0), &test),
41        y.select(Axis(0), &train),
42        y.select(Axis(0), &test),
43    )
44}