neural_network_rs/dataset/
example_datasets.rs

1use ndarray::array;
2
3use super::{Dataset, DatasetType};
4
5// The XOR dataset: [0, 0] -> 0, [0, 1] -> 1, [1, 0] -> 1, [1, 1] -> 0
6pub static XOR: Dataset = Dataset {
7    name: "XOR",
8    dataset_type: DatasetType::Static(|| {
9        let x = array![[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]];
10        let y = array![[0.0], [1.0], [1.0], [0.0]];
11        (x, y)
12    }),
13};
14
15// The Circle dataset: [x, y] -> 1 if (x-0.5)^2 + (y-0.5)^2 < 0.25, 0 otherwise
16pub static CIRCLE: Dataset = Dataset {
17    name: "Circle",
18    dataset_type: DatasetType::Dynamic(
19        |x| {
20            let dist_from_center = ((x[0] - 0.5).powi(2) + (x[1] - 0.5).powi(2)).sqrt();
21            let y = if dist_from_center < 0.25 { 1.0 } else { 0.0 };
22            array![y]
23        },
24        (2, 1),
25    ),
26};
27
28// The RGB_Test dataset: [x, y] -> [r=x, g=y, b=1-x]
29pub static RGB_TEST: Dataset = Dataset {
30    name: "RGB_TEST",
31    dataset_type: DatasetType::Dynamic(
32        |x| {
33            let r = x[0];
34            let g = x[1];
35            let b = 1.0 - r;
36            array![r, g, b]
37        },
38        (2, 3),
39    ),
40};
41
42// The RGB_DONUT dataset: represents a colorful donut-shape in RGB unit-square
43pub static RGB_DONUT: Dataset = Dataset {
44    name: "RGB_DONUT",
45    dataset_type: DatasetType::Dynamic(
46        |x| {
47            let dist_from_center = ((x[0] - 0.5).powi(2) + (x[1] - 0.5).powi(2)).sqrt();
48
49            let r = x[0];
50            let g = x[1];
51            let b = 1.0 - r;
52
53            if dist_from_center > 0.25 && dist_from_center < 0.45 {
54                array![r, g, b]
55            } else {
56                array![0.0, 0.0, 0.0]
57            }
58        },
59        (2, 3),
60    ),
61};