ghostflow_data/
dataset.rs1use ghostflow_core::Tensor;
4
5pub trait Dataset: Send + Sync {
7 fn get(&self, index: usize) -> (Tensor, Tensor);
9
10 fn len(&self) -> usize;
12
13 fn is_empty(&self) -> bool {
15 self.len() == 0
16 }
17}
18
19pub struct TensorDataset {
21 data: Tensor,
22 targets: Tensor,
23}
24
25impl TensorDataset {
26 pub fn new(data: Tensor, targets: Tensor) -> Self {
27 assert_eq!(data.dims()[0], targets.dims()[0],
28 "Data and targets must have same number of samples");
29 TensorDataset { data, targets }
30 }
31}
32
33impl Dataset for TensorDataset {
34 fn get(&self, index: usize) -> (Tensor, Tensor) {
35 let data_dims = self.data.dims();
37 let target_dims = self.targets.dims();
38
39 let data_slice_size: usize = data_dims[1..].iter().product();
40 let target_slice_size: usize = if target_dims.len() > 1 {
41 target_dims[1..].iter().product()
42 } else {
43 1
44 };
45
46 let data_vec = self.data.data_f32();
47 let target_vec = self.targets.data_f32();
48
49 let data_start = index * data_slice_size;
50 let data_end = data_start + data_slice_size;
51 let sample_data = &data_vec[data_start..data_end];
52
53 let target_start = index * target_slice_size;
54 let target_end = target_start + target_slice_size;
55 let sample_target = &target_vec[target_start..target_end];
56
57 let data_shape: Vec<usize> = data_dims[1..].to_vec();
58 let target_shape: Vec<usize> = if target_dims.len() > 1 {
59 target_dims[1..].to_vec()
60 } else {
61 vec![1]
62 };
63
64 (
65 Tensor::from_slice(sample_data, &data_shape).unwrap(),
66 Tensor::from_slice(sample_target, &target_shape).unwrap(),
67 )
68 }
69
70 fn len(&self) -> usize {
71 self.data.dims()[0]
72 }
73}
74
/// A view of a wrapped dataset restricted to, and reordered by, `indices`.
///
/// The `Dataset` bound lives on the impl blocks rather than on the struct, so
/// the type can be named without repeating the bound.
pub struct Subset<D> {
    /// The wrapped dataset that samples are fetched from.
    dataset: D,
    /// Position `i` of this subset maps to `dataset` index `indices[i]`.
    indices: Vec<usize>,
}
80
81impl<D: Dataset> Subset<D> {
82 pub fn new(dataset: D, indices: Vec<usize>) -> Self {
83 Subset { dataset, indices }
84 }
85}
86
87impl<D: Dataset> Dataset for Subset<D> {
88 fn get(&self, index: usize) -> (Tensor, Tensor) {
89 self.dataset.get(self.indices[index])
90 }
91
92 fn len(&self) -> usize {
93 self.indices.len()
94 }
95}
96
/// Presents several datasets of the same type as one, in order.
///
/// The `Dataset` bound lives on the impl blocks rather than on the struct, so
/// the type can be named without repeating the bound.
pub struct ConcatDataset<D> {
    /// Constituent datasets, in concatenation order.
    datasets: Vec<D>,
    /// Running totals of dataset lengths. Entry `i` is the combined length of
    /// `datasets[..=i]`, so the sequence is non-decreasing.
    cumulative_sizes: Vec<usize>,
}
102
103impl<D: Dataset> ConcatDataset<D> {
104 pub fn new(datasets: Vec<D>) -> Self {
105 let mut cumulative_sizes = Vec::with_capacity(datasets.len());
106 let mut total = 0;
107
108 for ds in &datasets {
109 total += ds.len();
110 cumulative_sizes.push(total);
111 }
112
113 ConcatDataset {
114 datasets,
115 cumulative_sizes,
116 }
117 }
118}
119
120impl<D: Dataset> Dataset for ConcatDataset<D> {
121 fn get(&self, index: usize) -> (Tensor, Tensor) {
122 let mut dataset_idx = 0;
123 let mut sample_idx = index;
124
125 for (i, &size) in self.cumulative_sizes.iter().enumerate() {
126 if index < size {
127 dataset_idx = i;
128 if i > 0 {
129 sample_idx = index - self.cumulative_sizes[i - 1];
130 }
131 break;
132 }
133 }
134
135 self.datasets[dataset_idx].get(sample_idx)
136 }
137
138 fn len(&self) -> usize {
139 *self.cumulative_sizes.last().unwrap_or(&0)
140 }
141}