1#![deny(missing_docs)]
5
6use core::f64;
7
8use classification::Classifier;
9use ndarray::{Array, Axis, Dimension, RemoveAxis};
10use rand::{rng, Rng};
11
12pub mod classification;
13pub mod metrics;
14pub mod transformer;
15
16pub fn train_test_split<
19 D: Dimension + RemoveAxis,
20 D2: Dimension + RemoveAxis,
21 Feature: Clone,
22 Label: Clone,
23>(
24 arr: &Array<Feature, D>,
25 y: &Array<Label, D2>,
26 test_size: f64,
27) -> (
28 Array<Feature, D>,
29 Array<Feature, D>,
30 Array<Label, D2>,
31 Array<Label, D2>,
32) {
33 let rows = arr.shape()[0];
34
35 let (test, train): (Vec<usize>, Vec<usize>) =
36 (0..rows).partition(|_| rng().random_bool(test_size));
37
38 (
39 arr.select(Axis(0), &train),
40 arr.select(Axis(0), &test),
41 y.select(Axis(0), &train),
42 y.select(Axis(0), &test),
43 )
44}