1use axonml_tensor::Tensor;
18
/// A random-access collection of samples.
///
/// Implementors report their size through [`Dataset::len`] and hand out owned
/// samples through [`Dataset::get`]. The `Send + Sync` supertraits make every
/// dataset shareable across threads, and `Item: Send` lets a fetched sample be
/// moved to another thread.
pub trait Dataset: Send + Sync {
    /// The type of a single sample.
    type Item: Send;

    /// Returns the sample stored at `index`, or `None` when `index` is out of
    /// range (the convention followed by the implementations in this file).
    fn get(&self, index: usize) -> Option<Self::Item>;

    /// Returns the number of samples.
    fn len(&self) -> usize;

    /// Returns `true` when the dataset holds no samples.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }
}
41
/// Dataset built from an input tensor and a target tensor whose first
/// dimension indexes samples.
///
/// Both tensors are copied into owned flat `Vec<f32>` buffers at construction.
/// Sample `i` is the `i`-th run of `row_size` input values paired with the
/// `i`-th run of `target_row_size` target values.
#[derive(Clone)]
pub struct TensorDataset {
    /// Flat copy of the input tensor's elements.
    data_vec: Vec<f32>,
    /// Flat copy of the target tensor's elements.
    target_vec: Vec<f32>,
    /// Full shape of the input tensor; `data_shape[0]` is the sample count.
    data_shape: Vec<usize>,
    /// Full shape of the target tensor; `target_shape[0]` is the sample count.
    target_shape: Vec<usize>,
    /// Input values per sample: the product of `data_shape[1..]`.
    row_size: usize,
    /// Target values per sample: the product of `target_shape[1..]` (1 for 1-D targets).
    target_row_size: usize,
    /// Number of samples.
    len: usize,
}
67
68impl TensorDataset {
69 #[must_use]
73 pub fn new(data: Tensor<f32>, targets: Tensor<f32>) -> Self {
74 let data_shape = data.shape().to_vec();
75 let target_shape = targets.shape().to_vec();
76 let len = data_shape[0];
77 assert_eq!(
78 len, target_shape[0],
79 "Data and targets must have same first dimension"
80 );
81
82 let row_size: usize = data_shape[1..].iter().product();
83 let target_row_size: usize = if target_shape.len() > 1 {
84 target_shape[1..].iter().product()
85 } else {
86 1
87 };
88
89 Self {
90 data_vec: data.to_vec(),
91 target_vec: targets.to_vec(),
92 data_shape,
93 target_shape,
94 row_size,
95 target_row_size,
96 len,
97 }
98 }
99
100 #[must_use]
102 pub fn from_data(data: Tensor<f32>) -> Self {
103 let len = data.shape()[0];
104 let targets = Tensor::from_vec(vec![0.0; len], &[len]).unwrap();
105 Self::new(data, targets)
106 }
107}
108
109impl Dataset for TensorDataset {
110 type Item = (Tensor<f32>, Tensor<f32>);
111
112 fn len(&self) -> usize {
113 self.len
114 }
115
116 fn get(&self, index: usize) -> Option<Self::Item> {
117 if index >= self.len {
118 return None;
119 }
120
121 let start = index * self.row_size;
123 let end = start + self.row_size;
124 let item_data = self.data_vec[start..end].to_vec();
125 let item_shape: Vec<usize> = self.data_shape[1..].to_vec();
126 let x = Tensor::from_vec(item_data, &item_shape).unwrap();
127
128 let target_start = index * self.target_row_size;
129 let target_end = target_start + self.target_row_size;
130 let item_target = self.target_vec[target_start..target_end].to_vec();
131 let target_item_shape: Vec<usize> = if self.target_shape.len() > 1 {
132 self.target_shape[1..].to_vec()
133 } else {
134 vec![1]
135 };
136 let y = Tensor::from_vec(item_target, &target_item_shape).unwrap();
137
138 Some((x, y))
139 }
140}
141
/// Wraps a dataset and passes every fetched sample through `transform`.
///
/// The transform receives and returns the wrapped dataset's item type. It is
/// applied each time a sample is fetched, not ahead of time. The trait bounds
/// needed to use this as a `Dataset` live on the `impl` blocks rather than on the
/// struct, so merely naming the type does not force callers to repeat them.
#[derive(Clone)]
pub struct MapDataset<D, F> {
    /// The dataset whose samples are transformed.
    dataset: D,
    /// The per-sample transformation.
    transform: F,
}
155
156impl<D, F> MapDataset<D, F>
157where
158 D: Dataset,
159 F: Fn(D::Item) -> D::Item + Send + Sync,
160{
161 pub fn new(dataset: D, transform: F) -> Self {
163 Self { dataset, transform }
164 }
165}
166
167impl<D, F> Dataset for MapDataset<D, F>
168where
169 D: Dataset,
170 F: Fn(D::Item) -> D::Item + Send + Sync,
171{
172 type Item = D::Item;
173
174 fn len(&self) -> usize {
175 self.dataset.len()
176 }
177
178 fn get(&self, index: usize) -> Option<Self::Item> {
179 self.dataset.get(index).map(&self.transform)
180 }
181}
182
/// Chains several datasets of the same type end to end into one logical dataset.
///
/// The trait bounds needed to use this as a `Dataset` live on the `impl` blocks
/// rather than on the struct definition.
#[derive(Debug, Clone)]
pub struct ConcatDataset<D> {
    /// The constituent datasets, in order.
    datasets: Vec<D>,
    /// Running totals of constituent lengths: `cumulative_sizes[i]` is the
    /// combined length of `datasets[0..=i]`, recorded at construction.
    cumulative_sizes: Vec<usize>,
}
192
193impl<D: Dataset> ConcatDataset<D> {
194 #[must_use]
196 pub fn new(datasets: Vec<D>) -> Self {
197 let mut cumulative_sizes = Vec::with_capacity(datasets.len());
198 let mut total = 0;
199 for d in &datasets {
200 total += d.len();
201 cumulative_sizes.push(total);
202 }
203 Self {
204 datasets,
205 cumulative_sizes,
206 }
207 }
208
209 fn find_dataset(&self, index: usize) -> Option<(usize, usize)> {
211 if index >= self.len() {
212 return None;
213 }
214
215 for (i, &cum_size) in self.cumulative_sizes.iter().enumerate() {
216 if index < cum_size {
217 let prev_size = if i == 0 {
218 0
219 } else {
220 self.cumulative_sizes[i - 1]
221 };
222 return Some((i, index - prev_size));
223 }
224 }
225 None
226 }
227}
228
229impl<D: Dataset> Dataset for ConcatDataset<D> {
230 type Item = D::Item;
231
232 fn len(&self) -> usize {
233 *self.cumulative_sizes.last().unwrap_or(&0)
234 }
235
236 fn get(&self, index: usize) -> Option<Self::Item> {
237 let (dataset_idx, local_idx) = self.find_dataset(index)?;
238 self.datasets[dataset_idx].get(local_idx)
239 }
240}
241
/// A view onto selected samples of another dataset.
///
/// Position `i` of the subset maps to position `indices[i]` of the wrapped
/// dataset. The trait bounds needed to use this as a `Dataset` live on the
/// `impl` blocks rather than on the struct definition.
#[derive(Debug, Clone)]
pub struct SubsetDataset<D> {
    /// The wrapped dataset.
    dataset: D,
    /// Positions in `dataset` exposed by this subset, in subset order.
    indices: Vec<usize>,
}
251
252impl<D: Dataset> SubsetDataset<D> {
253 pub fn new(dataset: D, indices: Vec<usize>) -> Self {
255 Self { dataset, indices }
256 }
257
258 pub fn random_split(dataset: D, lengths: &[usize]) -> Vec<Self>
260 where
261 D: Clone,
262 {
263 use rand::seq::SliceRandom;
264 use rand::thread_rng;
265
266 let total_len: usize = lengths.iter().sum();
267 assert_eq!(
268 total_len,
269 dataset.len(),
270 "Split lengths must sum to dataset length"
271 );
272
273 let mut indices: Vec<usize> = (0..dataset.len()).collect();
274 indices.shuffle(&mut thread_rng());
275
276 let mut subsets = Vec::with_capacity(lengths.len());
277 let mut offset = 0;
278 for &len in lengths {
279 let subset_indices = indices[offset..offset + len].to_vec();
280 subsets.push(Self::new(dataset.clone(), subset_indices));
281 offset += len;
282 }
283 subsets
284 }
285}
286
287impl<D: Dataset> Dataset for SubsetDataset<D> {
288 type Item = D::Item;
289
290 fn len(&self) -> usize {
291 self.indices.len()
292 }
293
294 fn get(&self, index: usize) -> Option<Self::Item> {
295 let real_index = *self.indices.get(index)?;
296 self.dataset.get(real_index)
297 }
298}
299
/// A dataset backed by a `Vec` of ready-made samples.
///
/// Samples are returned by clone, so the stored items are never moved out.
/// The trait bounds needed to use this as a `Dataset` live on the `impl` blocks
/// rather than on the struct definition. Deriving `Clone` also lets it be used
/// with `SubsetDataset::random_split`.
#[derive(Debug, Clone)]
pub struct InMemoryDataset<T> {
    /// The samples, in index order.
    items: Vec<T>,
}
308
309impl<T: Clone + Send> InMemoryDataset<T> {
310 #[must_use]
312 pub fn new(items: Vec<T>) -> Self {
313 Self { items }
314 }
315}
316
317impl<T: Clone + Send + Sync> Dataset for InMemoryDataset<T> {
318 type Item = T;
319
320 fn len(&self) -> usize {
321 self.items.len()
322 }
323
324 fn get(&self, index: usize) -> Option<Self::Item> {
325 self.items.get(index).cloned()
326 }
327}
328
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_dataset() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2]).unwrap();
        let targets = Tensor::from_vec(vec![0.0, 1.0, 2.0], &[3]).unwrap();
        let dataset = TensorDataset::new(data, targets);

        assert_eq!(dataset.len(), 3);

        let (x, y) = dataset.get(0).unwrap();
        assert_eq!(x.to_vec(), vec![1.0, 2.0]);
        assert_eq!(y.to_vec(), vec![0.0]);

        let (x, y) = dataset.get(2).unwrap();
        assert_eq!(x.to_vec(), vec![5.0, 6.0]);
        assert_eq!(y.to_vec(), vec![2.0]);

        assert!(dataset.get(3).is_none());
    }

    #[test]
    fn test_map_dataset() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4, 1]).unwrap();
        let targets = Tensor::from_vec(vec![0.0, 1.0, 0.0, 1.0], &[4]).unwrap();
        let base = TensorDataset::new(data, targets);

        let mapped = MapDataset::new(base, |(x, y)| (x.mul_scalar(2.0), y));

        assert_eq!(mapped.len(), 4);
        let (x, _) = mapped.get(0).unwrap();
        assert_eq!(x.to_vec(), vec![2.0]);
    }

    #[test]
    fn test_concat_dataset() {
        let data1 = Tensor::from_vec(vec![1.0, 2.0], &[2, 1]).unwrap();
        let targets1 = Tensor::from_vec(vec![0.0, 1.0], &[2]).unwrap();
        let ds1 = TensorDataset::new(data1, targets1);

        let data2 = Tensor::from_vec(vec![3.0, 4.0, 5.0], &[3, 1]).unwrap();
        let targets2 = Tensor::from_vec(vec![2.0, 3.0, 4.0], &[3]).unwrap();
        let ds2 = TensorDataset::new(data2, targets2);

        let concat = ConcatDataset::new(vec![ds1, ds2]);

        assert_eq!(concat.len(), 5);

        let (x, y) = concat.get(0).unwrap();
        assert_eq!(x.to_vec(), vec![1.0]);
        assert_eq!(y.to_vec(), vec![0.0]);

        let (x, y) = concat.get(3).unwrap();
        assert_eq!(x.to_vec(), vec![4.0]);
        assert_eq!(y.to_vec(), vec![3.0]);
    }

    /// Empty constituents must be skipped, and an empty concatenation must be empty.
    #[test]
    fn test_concat_dataset_with_empty_parts() {
        let concat = ConcatDataset::new(vec![
            InMemoryDataset::new(vec![1, 2]),
            InMemoryDataset::new(Vec::new()),
            InMemoryDataset::new(vec![3]),
        ]);
        assert_eq!(concat.len(), 3);
        assert_eq!(concat.get(1), Some(2));
        assert_eq!(concat.get(2), Some(3));
        assert_eq!(concat.get(3), None);

        let none: ConcatDataset<InMemoryDataset<i32>> = ConcatDataset::new(Vec::new());
        assert!(none.is_empty());
        assert_eq!(none.get(0), None);
    }

    #[test]
    fn test_subset_dataset() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5, 1]).unwrap();
        let targets = Tensor::from_vec(vec![0.0, 1.0, 2.0, 3.0, 4.0], &[5]).unwrap();
        let base = TensorDataset::new(data, targets);

        let subset = SubsetDataset::new(base, vec![0, 2, 4]);

        assert_eq!(subset.len(), 3);

        let (x, _) = subset.get(0).unwrap();
        assert_eq!(x.to_vec(), vec![1.0]);

        let (x, _) = subset.get(1).unwrap();
        assert_eq!(x.to_vec(), vec![3.0]);

        let (x, _) = subset.get(2).unwrap();
        assert_eq!(x.to_vec(), vec![5.0]);
    }

    /// Positions past the selection, and selected indices past the base dataset,
    /// both yield `None`.
    #[test]
    fn test_subset_dataset_out_of_range() {
        let base = InMemoryDataset::new(vec![10, 20, 30]);
        let subset = SubsetDataset::new(base, vec![2, 7]);
        assert_eq!(subset.get(0), Some(30));
        assert_eq!(subset.get(1), None);
        assert_eq!(subset.get(2), None);
    }

    /// The splits must have the requested sizes and together cover every sample once.
    #[test]
    fn test_random_split() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5, 1]).unwrap();
        let targets = Tensor::from_vec(vec![0.0, 1.0, 2.0, 3.0, 4.0], &[5]).unwrap();
        let base = TensorDataset::new(data, targets);

        let splits = SubsetDataset::random_split(base, &[2, 3]);
        assert_eq!(splits.len(), 2);
        assert_eq!(splits[0].len(), 2);
        assert_eq!(splits[1].len(), 3);

        let mut seen = Vec::new();
        for split in &splits {
            for i in 0..split.len() {
                let (x, _) = split.get(i).unwrap();
                seen.extend(x.to_vec());
            }
        }
        seen.sort_by(|a, b| a.partial_cmp(b).unwrap());
        assert_eq!(seen, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    #[should_panic(expected = "Split lengths must sum to dataset length")]
    fn test_random_split_rejects_bad_lengths() {
        let base = TensorDataset::from_data(Tensor::from_vec(vec![1.0, 2.0], &[2, 1]).unwrap());
        let _ = SubsetDataset::random_split(base, &[1]);
    }

    #[test]
    fn test_in_memory_dataset() {
        let dataset = InMemoryDataset::new(vec![1, 2, 3, 4, 5]);

        assert_eq!(dataset.len(), 5);
        assert_eq!(dataset.get(0), Some(1));
        assert_eq!(dataset.get(4), Some(5));
        assert_eq!(dataset.get(5), None);
    }

    #[test]
    fn test_is_empty() {
        let empty: InMemoryDataset<i32> = InMemoryDataset::new(Vec::new());
        assert!(empty.is_empty());
        assert!(!InMemoryDataset::new(vec![1]).is_empty());
    }
}