1use axonml_tensor::Tensor;
18
/// An indexable collection of samples for data loading.
///
/// Implementors must be `Send + Sync` so a dataset can be shared across
/// data-loading worker threads.
pub trait Dataset: Send + Sync {
    /// The type of a single sample produced by [`Dataset::get`].
    type Item: Send;

    /// Returns the number of items in the dataset.
    fn len(&self) -> usize;

    /// Returns `true` when the dataset contains no items.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the item at `index`, or `None` when `index >= self.len()`.
    fn get(&self, index: usize) -> Option<Self::Item>;
}
41
/// A dataset backed by in-memory tensors.
///
/// The first dimension of both tensors indexes samples.
pub struct TensorDataset {
    /// Sample features; shape `[len, ...]`.
    data: Tensor<f32>,
    /// Per-sample targets; shape `[len]` or `[len, ...]`.
    targets: Tensor<f32>,
    /// Number of samples — cached first dimension of `data`.
    len: usize,
}
57
58impl TensorDataset {
59 #[must_use]
63 pub fn new(data: Tensor<f32>, targets: Tensor<f32>) -> Self {
64 let len = data.shape()[0];
65 assert_eq!(
66 len,
67 targets.shape()[0],
68 "Data and targets must have same first dimension"
69 );
70 Self { data, targets, len }
71 }
72
73 #[must_use]
75 pub fn from_data(data: Tensor<f32>) -> Self {
76 let len = data.shape()[0];
77 let targets = Tensor::from_vec(vec![0.0; len], &[len]).unwrap();
78 Self { data, targets, len }
79 }
80}
81
82impl Dataset for TensorDataset {
83 type Item = (Tensor<f32>, Tensor<f32>);
84
85 fn len(&self) -> usize {
86 self.len
87 }
88
89 fn get(&self, index: usize) -> Option<Self::Item> {
90 if index >= self.len {
91 return None;
92 }
93
94 let data_shape = self.data.shape();
96 let row_size: usize = data_shape[1..].iter().product();
97 let data_vec = self.data.to_vec();
98 let start = index * row_size;
99 let end = start + row_size;
100 let item_data = data_vec[start..end].to_vec();
101 let item_shape: Vec<usize> = data_shape[1..].to_vec();
102 let x = Tensor::from_vec(item_data, &item_shape).unwrap();
103
104 let target_shape = self.targets.shape();
106 let target_row_size: usize = if target_shape.len() > 1 {
107 target_shape[1..].iter().product()
108 } else {
109 1
110 };
111 let target_vec = self.targets.to_vec();
112 let target_start = index * target_row_size;
113 let target_end = target_start + target_row_size;
114 let item_target = target_vec[target_start..target_end].to_vec();
115 let target_item_shape: Vec<usize> = if target_shape.len() > 1 {
116 target_shape[1..].to_vec()
117 } else {
118 vec![1]
119 };
120 let y = Tensor::from_vec(item_target, &target_item_shape).unwrap();
121
122 Some((x, y))
123 }
124}
125
/// A dataset adapter that applies a transform to each item of an inner
/// dataset on access.
///
/// Trait bounds are intentionally left off the struct definition — they
/// belong on the impl blocks that need them (Rust API guidelines); the
/// struct itself is just storage.
pub struct MapDataset<D, F> {
    /// The wrapped source dataset.
    dataset: D,
    /// Item-to-item transform applied by `get`.
    transform: F,
}
139
140impl<D, F> MapDataset<D, F>
141where
142 D: Dataset,
143 F: Fn(D::Item) -> D::Item + Send + Sync,
144{
145 pub fn new(dataset: D, transform: F) -> Self {
147 Self { dataset, transform }
148 }
149}
150
151impl<D, F> Dataset for MapDataset<D, F>
152where
153 D: Dataset,
154 F: Fn(D::Item) -> D::Item + Send + Sync,
155{
156 type Item = D::Item;
157
158 fn len(&self) -> usize {
159 self.dataset.len()
160 }
161
162 fn get(&self, index: usize) -> Option<Self::Item> {
163 self.dataset.get(index).map(&self.transform)
164 }
165}
166
/// A dataset formed by concatenating several datasets end to end.
pub struct ConcatDataset<D: Dataset> {
    /// The concatenated datasets, in order.
    datasets: Vec<D>,
    /// Running totals of dataset lengths; `cumulative_sizes[i]` is the number
    /// of items contained in datasets `0..=i`.
    cumulative_sizes: Vec<usize>,
}
176
177impl<D: Dataset> ConcatDataset<D> {
178 #[must_use]
180 pub fn new(datasets: Vec<D>) -> Self {
181 let mut cumulative_sizes = Vec::with_capacity(datasets.len());
182 let mut total = 0;
183 for d in &datasets {
184 total += d.len();
185 cumulative_sizes.push(total);
186 }
187 Self {
188 datasets,
189 cumulative_sizes,
190 }
191 }
192
193 fn find_dataset(&self, index: usize) -> Option<(usize, usize)> {
195 if index >= self.len() {
196 return None;
197 }
198
199 for (i, &cum_size) in self.cumulative_sizes.iter().enumerate() {
200 if index < cum_size {
201 let prev_size = if i == 0 {
202 0
203 } else {
204 self.cumulative_sizes[i - 1]
205 };
206 return Some((i, index - prev_size));
207 }
208 }
209 None
210 }
211}
212
213impl<D: Dataset> Dataset for ConcatDataset<D> {
214 type Item = D::Item;
215
216 fn len(&self) -> usize {
217 *self.cumulative_sizes.last().unwrap_or(&0)
218 }
219
220 fn get(&self, index: usize) -> Option<Self::Item> {
221 let (dataset_idx, local_idx) = self.find_dataset(index)?;
222 self.datasets[dataset_idx].get(local_idx)
223 }
224}
225
/// A dataset exposing only a chosen subset of another dataset's indices.
pub struct SubsetDataset<D: Dataset> {
    /// The full underlying dataset.
    dataset: D,
    /// Indices into `dataset`, in the order this subset exposes them.
    indices: Vec<usize>,
}
235
236impl<D: Dataset> SubsetDataset<D> {
237 pub fn new(dataset: D, indices: Vec<usize>) -> Self {
239 Self { dataset, indices }
240 }
241
242 pub fn random_split(dataset: D, lengths: &[usize]) -> Vec<Self>
244 where
245 D: Clone,
246 {
247 use rand::seq::SliceRandom;
248 use rand::thread_rng;
249
250 let total_len: usize = lengths.iter().sum();
251 assert_eq!(
252 total_len,
253 dataset.len(),
254 "Split lengths must sum to dataset length"
255 );
256
257 let mut indices: Vec<usize> = (0..dataset.len()).collect();
258 indices.shuffle(&mut thread_rng());
259
260 let mut subsets = Vec::with_capacity(lengths.len());
261 let mut offset = 0;
262 for &len in lengths {
263 let subset_indices = indices[offset..offset + len].to_vec();
264 subsets.push(Self::new(dataset.clone(), subset_indices));
265 offset += len;
266 }
267 subsets
268 }
269}
270
271impl<D: Dataset> Dataset for SubsetDataset<D> {
272 type Item = D::Item;
273
274 fn len(&self) -> usize {
275 self.indices.len()
276 }
277
278 fn get(&self, index: usize) -> Option<Self::Item> {
279 let real_index = *self.indices.get(index)?;
280 self.dataset.get(real_index)
281 }
282}
283
/// A dataset holding its items directly in a `Vec`.
///
/// Trait bounds are intentionally left off the struct definition — the impl
/// blocks already restate exactly the bounds they need (Rust API guidelines:
/// put bounds on impls, not on struct definitions).
pub struct InMemoryDataset<T> {
    /// The stored items; position in the vector is the dataset index.
    items: Vec<T>,
}
292
293impl<T: Clone + Send> InMemoryDataset<T> {
294 #[must_use]
296 pub fn new(items: Vec<T>) -> Self {
297 Self { items }
298 }
299}
300
301impl<T: Clone + Send + Sync> Dataset for InMemoryDataset<T> {
302 type Item = T;
303
304 fn len(&self) -> usize {
305 self.items.len()
306 }
307
308 fn get(&self, index: usize) -> Option<Self::Item> {
309 self.items.get(index).cloned()
310 }
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    // Items come out as (data_row, target_row) pairs; out-of-range gets None.
    #[test]
    fn test_tensor_dataset() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2]).unwrap();
        let targets = Tensor::from_vec(vec![0.0, 1.0, 2.0], &[3]).unwrap();
        let dataset = TensorDataset::new(data, targets);

        assert_eq!(dataset.len(), 3);

        let (x, y) = dataset.get(0).unwrap();
        assert_eq!(x.to_vec(), vec![1.0, 2.0]);
        assert_eq!(y.to_vec(), vec![0.0]);

        let (x, y) = dataset.get(2).unwrap();
        assert_eq!(x.to_vec(), vec![5.0, 6.0]);
        assert_eq!(y.to_vec(), vec![2.0]);

        // One past the last valid index.
        assert!(dataset.get(3).is_none());
    }

    // The transform is applied on fetch; length is unchanged by mapping.
    #[test]
    fn test_map_dataset() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4, 1]).unwrap();
        let targets = Tensor::from_vec(vec![0.0, 1.0, 0.0, 1.0], &[4]).unwrap();
        let base = TensorDataset::new(data, targets);

        let mapped = MapDataset::new(base, |(x, y)| (x.mul_scalar(2.0), y));

        assert_eq!(mapped.len(), 4);
        let (x, _) = mapped.get(0).unwrap();
        assert_eq!(x.to_vec(), vec![2.0]);
    }

    // Global indices span both datasets: 0..2 hit ds1, 2..5 hit ds2.
    #[test]
    fn test_concat_dataset() {
        let data1 = Tensor::from_vec(vec![1.0, 2.0], &[2, 1]).unwrap();
        let targets1 = Tensor::from_vec(vec![0.0, 1.0], &[2]).unwrap();
        let ds1 = TensorDataset::new(data1, targets1);

        let data2 = Tensor::from_vec(vec![3.0, 4.0, 5.0], &[3, 1]).unwrap();
        let targets2 = Tensor::from_vec(vec![2.0, 3.0, 4.0], &[3]).unwrap();
        let ds2 = TensorDataset::new(data2, targets2);

        let concat = ConcatDataset::new(vec![ds1, ds2]);

        assert_eq!(concat.len(), 5);

        let (x, y) = concat.get(0).unwrap();
        assert_eq!(x.to_vec(), vec![1.0]);
        assert_eq!(y.to_vec(), vec![0.0]);

        // Global index 3 is local index 1 of the second dataset.
        let (x, y) = concat.get(3).unwrap();
        assert_eq!(x.to_vec(), vec![4.0]);
        assert_eq!(y.to_vec(), vec![3.0]);
    }

    // Subset indices [0, 2, 4] select every other sample.
    #[test]
    fn test_subset_dataset() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5, 1]).unwrap();
        let targets = Tensor::from_vec(vec![0.0, 1.0, 2.0, 3.0, 4.0], &[5]).unwrap();
        let base = TensorDataset::new(data, targets);

        let subset = SubsetDataset::new(base, vec![0, 2, 4]);

        assert_eq!(subset.len(), 3);

        let (x, _) = subset.get(0).unwrap();
        assert_eq!(x.to_vec(), vec![1.0]);

        let (x, _) = subset.get(1).unwrap();
        assert_eq!(x.to_vec(), vec![3.0]);

        let (x, _) = subset.get(2).unwrap();
        assert_eq!(x.to_vec(), vec![5.0]);
    }

    // Items are cloned out by value; out-of-range gets None.
    #[test]
    fn test_in_memory_dataset() {
        let dataset = InMemoryDataset::new(vec![1, 2, 3, 4, 5]);

        assert_eq!(dataset.len(), 5);
        assert_eq!(dataset.get(0), Some(1));
        assert_eq!(dataset.get(4), Some(5));
        assert_eq!(dataset.get(5), None);
    }
}