neural_network_rs/dataset/mod.rs

1use ndarray::prelude::*;
2use ndarray::{Array, Array2};
3use ndarray_rand::{rand, RandomExt};
4
5pub mod example_datasets;
6
7pub enum DatasetType {
8    Static(fn() -> (Array2<f64>, Array2<f64>)),
9    Dynamic(fn(Array1<f64>) -> Array1<f64>, (usize, usize)),
10}
11
12pub struct Dataset<'a> {
13    pub dataset_type: DatasetType,
14    pub name: &'a str,
15}
16
17impl Dataset<'_> {
18    pub fn new(name: &str, dataset_type: DatasetType) -> Dataset {
19        Dataset { dataset_type, name }
20    }
21
22    pub fn get_full(&self) -> (Array2<f64>, Array2<f64>) {
23        match &self.dataset_type {
24            DatasetType::Static(f) => f(),
25            _ => panic!("Cannot get full dataset from dynamic dataset"),
26        }
27    }
28
29    pub fn get_2d_unit_square(resolution: usize) -> Array2<f64> {
30        let linspace = Array::linspace(0.0, 1.0, resolution);
31
32        let mut x = Array::zeros((resolution * resolution, 2).f());
33        for i in 0..resolution {
34            for j in 0..resolution {
35                x[[i * resolution + j, 0]] = linspace[j];
36                x[[i * resolution + j, 1]] = linspace[i];
37            }
38        }
39
40        x
41    }
42
43    pub fn get_batch(&self, batch_size: usize) -> (Array2<f64>, Array2<f64>) {
44        match &self.dataset_type {
45            DatasetType::Static(f) => {
46                let (data, labels) = f();
47
48                let indices = Array1::random(
49                    batch_size,
50                    rand::distributions::Uniform::new(0, data.shape()[0]),
51                )
52                .to_vec();
53
54                let data = data.select(Axis(0), &indices);
55                let labels = labels.select(Axis(0), &indices);
56
57                (data, labels)
58            }
59            DatasetType::Dynamic(f, (input_dim, output_dim)) => {
60                let x = Array::random(
61                    (batch_size, *input_dim),
62                    rand::distributions::Uniform::new(0.0, 1.0),
63                );
64
65                let mut y = Array2::zeros((batch_size, *output_dim));
66                for (i, xi) in x.outer_iter().enumerate() {
67                    let yi = f(xi.to_owned());
68                    y.row_mut(i).assign(&yi);
69                }
70
71                (x, y)
72            }
73        }
74    }
75}