neural_network_rs/dataset/
mod.rs1use ndarray::prelude::*;
2use ndarray::{Array, Array2};
3use ndarray_rand::{rand, RandomExt};
4
5pub mod example_datasets;
6
7pub enum DatasetType {
8 Static(fn() -> (Array2<f64>, Array2<f64>)),
9 Dynamic(fn(Array1<f64>) -> Array1<f64>, (usize, usize)),
10}
11
12pub struct Dataset<'a> {
13 pub dataset_type: DatasetType,
14 pub name: &'a str,
15}
16
17impl Dataset<'_> {
18 pub fn new(name: &str, dataset_type: DatasetType) -> Dataset {
19 Dataset { dataset_type, name }
20 }
21
22 pub fn get_full(&self) -> (Array2<f64>, Array2<f64>) {
23 match &self.dataset_type {
24 DatasetType::Static(f) => f(),
25 _ => panic!("Cannot get full dataset from dynamic dataset"),
26 }
27 }
28
29 pub fn get_2d_unit_square(resolution: usize) -> Array2<f64> {
30 let linspace = Array::linspace(0.0, 1.0, resolution);
31
32 let mut x = Array::zeros((resolution * resolution, 2).f());
33 for i in 0..resolution {
34 for j in 0..resolution {
35 x[[i * resolution + j, 0]] = linspace[j];
36 x[[i * resolution + j, 1]] = linspace[i];
37 }
38 }
39
40 x
41 }
42
43 pub fn get_batch(&self, batch_size: usize) -> (Array2<f64>, Array2<f64>) {
44 match &self.dataset_type {
45 DatasetType::Static(f) => {
46 let (data, labels) = f();
47
48 let indices = Array1::random(
49 batch_size,
50 rand::distributions::Uniform::new(0, data.shape()[0]),
51 )
52 .to_vec();
53
54 let data = data.select(Axis(0), &indices);
55 let labels = labels.select(Axis(0), &indices);
56
57 (data, labels)
58 }
59 DatasetType::Dynamic(f, (input_dim, output_dim)) => {
60 let x = Array::random(
61 (batch_size, *input_dim),
62 rand::distributions::Uniform::new(0.0, 1.0),
63 );
64
65 let mut y = Array2::zeros((batch_size, *output_dim));
66 for (i, xi) in x.outer_iter().enumerate() {
67 let yi = f(xi.to_owned());
68 y.row_mut(i).assign(&yi);
69 }
70
71 (x, y)
72 }
73 }
74 }
75}