ai_dataloader/indexable/dataset/
ndarray_dataset.rs1use super::{Dataset, GetSample};
2use crate::Len;
3use ndarray::{Array, Axis, Dimension, RemoveAxis};
4
5#[derive(Debug, PartialEq, Hash, Eq)]
7pub struct NdarrayDataset<A1, A2, D1, D2>
8where
9 A1: Clone,
10 A2: Clone,
11 D1: Dimension + RemoveAxis,
12 D2: Dimension + RemoveAxis,
13{
14 pub ndarrays: (Array<A1, D1>, Array<A2, D2>),
16}
17impl<A1, A2, D1, D2> Dataset for NdarrayDataset<A1, A2, D1, D2>
18where
19 A1: Clone,
20 A2: Clone,
21 D1: Dimension + RemoveAxis,
22 D2: Dimension + RemoveAxis,
23{
24}
25
26impl<A1, A2, D1, D2> Clone for NdarrayDataset<A1, A2, D1, D2>
27where
28 A1: Clone,
29 A2: Clone,
30 D1: Dimension + RemoveAxis,
31 D2: Dimension + RemoveAxis,
32{
33 fn clone(&self) -> Self {
34 Self {
35 ndarrays: self.ndarrays.clone(),
36 }
37 }
38}
39
40impl<A1, A2, D1, D2> Len for NdarrayDataset<A1, A2, D1, D2>
41where
42 A1: Clone,
43 A2: Clone,
44 D1: Dimension + RemoveAxis,
45 D2: Dimension + RemoveAxis,
46{
47 fn len(&self) -> usize {
48 self.ndarrays.0.shape()[0]
49 }
50}
51impl<A1, A2, D1, D2> GetSample for NdarrayDataset<A1, A2, D1, D2>
52where
53 A1: Clone,
54 A2: Clone,
55 D1: Dimension + RemoveAxis,
56 D2: Dimension + RemoveAxis,
57{
58 type Sample = (
59 Array<A1, <D1 as Dimension>::Smaller>,
60 Array<A2, <D2 as Dimension>::Smaller>,
61 );
62 fn get_sample(&self, index: usize) -> Self::Sample {
63 (
64 self.ndarrays.0.index_axis(Axis(0), index).into_owned(),
65 self.ndarrays.1.index_axis(Axis(0), index).into_owned(),
66 )
67 }
68}
69
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{arr0, array};

    /// 1-D arrays: each sample is a pair of scalars (0-D arrays).
    #[test]
    fn ndarray_dataset() {
        let dataset = NdarrayDataset {
            ndarrays: (array![1, 2], array![3, 4]),
        };
        assert_eq!(dataset.len(), 2);
        assert_eq!(dataset.get_sample(0), (arr0(1), arr0(3)));
        assert_eq!(dataset.get_sample(1), (arr0(2), arr0(4)));
    }

    /// 2-D first array: each sample drops axis 0, yielding a 1-D row.
    #[test]
    fn ndarray_dataset_multidimensional() {
        let dataset = NdarrayDataset {
            ndarrays: (array![[1, 2], [3, 4], [5, 6]], array![7, 8, 9]),
        };
        assert_eq!(dataset.len(), 3);
        assert_eq!(dataset.get_sample(1), (array![3, 4], arr0(8)));
    }

    /// A clone compares equal to the original.
    #[test]
    fn ndarray_dataset_clone() {
        let dataset = NdarrayDataset {
            ndarrays: (array![1, 2], array![3, 4]),
        };
        assert_eq!(dataset.clone(), dataset);
    }

    /// Indexing past the end panics. No message check, so the test does not
    /// depend on the exact panic text.
    #[test]
    #[should_panic]
    fn ndarray_dataset_index_out_of_bounds() {
        let dataset = NdarrayDataset {
            ndarrays: (array![1, 2], array![3, 4]),
        };
        let _ = dataset.get_sample(2);
    }
}