ai_dataloader/indexable/dataset/
ndarray_dataset.rs

1use super::{Dataset, GetSample};
2use crate::Len;
3use ndarray::{Array, Axis, Dimension, RemoveAxis};
4
5/// Basic dataset than can contains two `ndarray` of any dimension.
6#[derive(Debug, PartialEq, Hash, Eq)]
7pub struct NdarrayDataset<A1, A2, D1, D2>
8where
9    A1: Clone,
10    A2: Clone,
11    D1: Dimension + RemoveAxis,
12    D2: Dimension + RemoveAxis,
13{
14    /// The content of the dataset.
15    pub ndarrays: (Array<A1, D1>, Array<A2, D2>),
16}
17impl<A1, A2, D1, D2> Dataset for NdarrayDataset<A1, A2, D1, D2>
18where
19    A1: Clone,
20    A2: Clone,
21    D1: Dimension + RemoveAxis,
22    D2: Dimension + RemoveAxis,
23{
24}
25
26impl<A1, A2, D1, D2> Clone for NdarrayDataset<A1, A2, D1, D2>
27where
28    A1: Clone,
29    A2: Clone,
30    D1: Dimension + RemoveAxis,
31    D2: Dimension + RemoveAxis,
32{
33    fn clone(&self) -> Self {
34        Self {
35            ndarrays: self.ndarrays.clone(),
36        }
37    }
38}
39
40impl<A1, A2, D1, D2> Len for NdarrayDataset<A1, A2, D1, D2>
41where
42    A1: Clone,
43    A2: Clone,
44    D1: Dimension + RemoveAxis,
45    D2: Dimension + RemoveAxis,
46{
47    fn len(&self) -> usize {
48        self.ndarrays.0.shape()[0]
49    }
50}
51impl<A1, A2, D1, D2> GetSample for NdarrayDataset<A1, A2, D1, D2>
52where
53    A1: Clone,
54    A2: Clone,
55    D1: Dimension + RemoveAxis,
56    D2: Dimension + RemoveAxis,
57{
58    type Sample = (
59        Array<A1, <D1 as Dimension>::Smaller>,
60        Array<A2, <D2 as Dimension>::Smaller>,
61    );
62    fn get_sample(&self, index: usize) -> Self::Sample {
63        (
64            self.ndarrays.0.index_axis(Axis(0), index).into_owned(),
65            self.ndarrays.1.index_axis(Axis(0), index).into_owned(),
66        )
67    }
68}
69
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{arr0, array};

    /// Each index into a dataset of two 1-D arrays yields the matching pair
    /// of elements as 0-D arrays.
    #[test]
    fn ndarray_dataset() {
        let ndarrays = (array![1, 2], array![3, 4]);
        let dataset = NdarrayDataset { ndarrays };

        let expected = [(arr0(1), arr0(3)), (arr0(2), arr0(4))];
        for (index, want) in expected.iter().enumerate() {
            assert_eq!(&dataset.get_sample(index), want);
        }
    }
}