1use axonml_tensor::Tensor;
9
/// A fixed-length, index-addressable source of samples.
///
/// The `Send + Sync` supertraits mean every implementor can be moved to and
/// shared between threads.
pub trait Dataset: Send + Sync {
    /// The value handed out for a single index.
    type Item: Send;

    /// Returns how many samples the dataset contains.
    fn len(&self) -> usize;

    /// Returns `true` when the dataset contains no samples.
    ///
    /// The provided implementation simply checks whether [`Dataset::len`]
    /// reports zero.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }

    /// Returns the sample stored at `index`, or `None` when the implementor
    /// has nothing to return for that position.
    fn get(&self, index: usize) -> Option<Self::Item>;
}
32
/// Dataset backed by two tensors: one holding inputs, one holding targets.
///
/// Sample `i` is taken from position `i` along the first dimension of each
/// tensor.
pub struct TensorDataset {
    /// Input tensor; its first dimension enumerates the samples.
    data: Tensor<f32>,
    /// Target tensor; its first dimension enumerates the samples.
    targets: Tensor<f32>,
    /// Number of samples, captured from the first dimension of `data` at
    /// construction time.
    len: usize,
}
48
49impl TensorDataset {
50 #[must_use] pub fn new(data: Tensor<f32>, targets: Tensor<f32>) -> Self {
54 let len = data.shape()[0];
55 assert_eq!(
56 len,
57 targets.shape()[0],
58 "Data and targets must have same first dimension"
59 );
60 Self { data, targets, len }
61 }
62
63 #[must_use] pub fn from_data(data: Tensor<f32>) -> Self {
65 let len = data.shape()[0];
66 let targets = Tensor::from_vec(vec![0.0; len], &[len]).unwrap();
67 Self { data, targets, len }
68 }
69}
70
71impl Dataset for TensorDataset {
72 type Item = (Tensor<f32>, Tensor<f32>);
73
74 fn len(&self) -> usize {
75 self.len
76 }
77
78 fn get(&self, index: usize) -> Option<Self::Item> {
79 if index >= self.len {
80 return None;
81 }
82
83 let data_shape = self.data.shape();
85 let row_size: usize = data_shape[1..].iter().product();
86 let data_vec = self.data.to_vec();
87 let start = index * row_size;
88 let end = start + row_size;
89 let item_data = data_vec[start..end].to_vec();
90 let item_shape: Vec<usize> = data_shape[1..].to_vec();
91 let x = Tensor::from_vec(item_data, &item_shape).unwrap();
92
93 let target_shape = self.targets.shape();
95 let target_row_size: usize = if target_shape.len() > 1 {
96 target_shape[1..].iter().product()
97 } else {
98 1
99 };
100 let target_vec = self.targets.to_vec();
101 let target_start = index * target_row_size;
102 let target_end = target_start + target_row_size;
103 let item_target = target_vec[target_start..target_end].to_vec();
104 let target_item_shape: Vec<usize> = if target_shape.len() > 1 {
105 target_shape[1..].to_vec()
106 } else {
107 vec![1]
108 };
109 let y = Tensor::from_vec(item_target, &target_item_shape).unwrap();
110
111 Some((x, y))
112 }
113}
114
/// Dataset adapter that passes every item of an inner dataset through a
/// transform function before returning it.
///
/// The transform maps an item to another item of the same type.
pub struct MapDataset<D, F>
where
    D: Dataset,
    F: Fn(D::Item) -> D::Item + Send + Sync,
{
    /// Inner dataset that supplies the untransformed items.
    dataset: D,
    /// Function applied to each item fetched from `dataset`.
    transform: F,
}
128
129impl<D, F> MapDataset<D, F>
130where
131 D: Dataset,
132 F: Fn(D::Item) -> D::Item + Send + Sync,
133{
134 pub fn new(dataset: D, transform: F) -> Self {
136 Self { dataset, transform }
137 }
138}
139
140impl<D, F> Dataset for MapDataset<D, F>
141where
142 D: Dataset,
143 F: Fn(D::Item) -> D::Item + Send + Sync,
144{
145 type Item = D::Item;
146
147 fn len(&self) -> usize {
148 self.dataset.len()
149 }
150
151 fn get(&self, index: usize) -> Option<Self::Item> {
152 self.dataset.get(index).map(&self.transform)
153 }
154}
155
/// Dataset that presents several datasets of the same type as one continuous
/// sequence, in the order they were supplied.
pub struct ConcatDataset<D: Dataset> {
    /// Member datasets, in concatenation order.
    datasets: Vec<D>,
    /// Running total of member lengths: entry `i` is the combined length of
    /// members `0..=i`.
    cumulative_sizes: Vec<usize>,
}
165
166impl<D: Dataset> ConcatDataset<D> {
167 #[must_use] pub fn new(datasets: Vec<D>) -> Self {
169 let mut cumulative_sizes = Vec::with_capacity(datasets.len());
170 let mut total = 0;
171 for d in &datasets {
172 total += d.len();
173 cumulative_sizes.push(total);
174 }
175 Self {
176 datasets,
177 cumulative_sizes,
178 }
179 }
180
181 fn find_dataset(&self, index: usize) -> Option<(usize, usize)> {
183 if index >= self.len() {
184 return None;
185 }
186
187 for (i, &cum_size) in self.cumulative_sizes.iter().enumerate() {
188 if index < cum_size {
189 let prev_size = if i == 0 {
190 0
191 } else {
192 self.cumulative_sizes[i - 1]
193 };
194 return Some((i, index - prev_size));
195 }
196 }
197 None
198 }
199}
200
201impl<D: Dataset> Dataset for ConcatDataset<D> {
202 type Item = D::Item;
203
204 fn len(&self) -> usize {
205 *self.cumulative_sizes.last().unwrap_or(&0)
206 }
207
208 fn get(&self, index: usize) -> Option<Self::Item> {
209 let (dataset_idx, local_idx) = self.find_dataset(index)?;
210 self.datasets[dataset_idx].get(local_idx)
211 }
212}
213
/// View over a dataset that exposes only a chosen list of its positions.
pub struct SubsetDataset<D: Dataset> {
    /// Dataset the subset reads from.
    dataset: D,
    /// Positions in `dataset` exposed by this subset, in subset order.
    indices: Vec<usize>,
}
223
224impl<D: Dataset> SubsetDataset<D> {
225 pub fn new(dataset: D, indices: Vec<usize>) -> Self {
227 Self { dataset, indices }
228 }
229
230 pub fn random_split(dataset: D, lengths: &[usize]) -> Vec<Self>
232 where
233 D: Clone,
234 {
235 use rand::seq::SliceRandom;
236 use rand::thread_rng;
237
238 let total_len: usize = lengths.iter().sum();
239 assert_eq!(
240 total_len,
241 dataset.len(),
242 "Split lengths must sum to dataset length"
243 );
244
245 let mut indices: Vec<usize> = (0..dataset.len()).collect();
246 indices.shuffle(&mut thread_rng());
247
248 let mut subsets = Vec::with_capacity(lengths.len());
249 let mut offset = 0;
250 for &len in lengths {
251 let subset_indices = indices[offset..offset + len].to_vec();
252 subsets.push(Self::new(dataset.clone(), subset_indices));
253 offset += len;
254 }
255 subsets
256 }
257}
258
259impl<D: Dataset> Dataset for SubsetDataset<D> {
260 type Item = D::Item;
261
262 fn len(&self) -> usize {
263 self.indices.len()
264 }
265
266 fn get(&self, index: usize) -> Option<Self::Item> {
267 let real_index = *self.indices.get(index)?;
268 self.dataset.get(real_index)
269 }
270}
271
/// Dataset that keeps all of its items in a vector.
///
/// `Clone` is derived so the type can be used with APIs that need to
/// duplicate a dataset (for example a split that requires `D: Clone`);
/// `Debug` is derived for diagnostics and is available when `T: Debug`.
#[derive(Debug, Clone)]
pub struct InMemoryDataset<T: Clone + Send> {
    /// The stored items, in index order.
    items: Vec<T>,
}
280
281impl<T: Clone + Send> InMemoryDataset<T> {
282 #[must_use] pub fn new(items: Vec<T>) -> Self {
284 Self { items }
285 }
286}
287
288impl<T: Clone + Send + Sync> Dataset for InMemoryDataset<T> {
289 type Item = T;
290
291 fn len(&self) -> usize {
292 self.items.len()
293 }
294
295 fn get(&self, index: usize) -> Option<Self::Item> {
296 self.items.get(index).cloned()
297 }
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// Rows of a 2-D input tensor are paired with single-value targets, and
    /// an index past the end yields `None`.
    #[test]
    fn test_tensor_dataset() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2]).unwrap();
        let targets = Tensor::from_vec(vec![0.0, 1.0, 2.0], &[3]).unwrap();
        let dataset = TensorDataset::new(data, targets);

        assert_eq!(dataset.len(), 3);

        let (x, y) = dataset.get(0).unwrap();
        assert_eq!(x.to_vec(), vec![1.0, 2.0]);
        assert_eq!(y.to_vec(), vec![0.0]);

        let (x, y) = dataset.get(2).unwrap();
        assert_eq!(x.to_vec(), vec![5.0, 6.0]);
        assert_eq!(y.to_vec(), vec![2.0]);

        assert!(dataset.get(3).is_none());
    }

    /// The transform is applied to items and the length is unchanged.
    #[test]
    fn test_map_dataset() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4, 1]).unwrap();
        let targets = Tensor::from_vec(vec![0.0, 1.0, 0.0, 1.0], &[4]).unwrap();
        let base = TensorDataset::new(data, targets);

        let mapped = MapDataset::new(base, |(x, y)| (x.mul_scalar(2.0), y));

        assert_eq!(mapped.len(), 4);
        let (x, _) = mapped.get(0).unwrap();
        assert_eq!(x.to_vec(), vec![2.0]);
    }

    /// Global indices are routed to the right member dataset.
    #[test]
    fn test_concat_dataset() {
        let data1 = Tensor::from_vec(vec![1.0, 2.0], &[2, 1]).unwrap();
        let targets1 = Tensor::from_vec(vec![0.0, 1.0], &[2]).unwrap();
        let ds1 = TensorDataset::new(data1, targets1);

        let data2 = Tensor::from_vec(vec![3.0, 4.0, 5.0], &[3, 1]).unwrap();
        let targets2 = Tensor::from_vec(vec![2.0, 3.0, 4.0], &[3]).unwrap();
        let ds2 = TensorDataset::new(data2, targets2);

        let concat = ConcatDataset::new(vec![ds1, ds2]);

        assert_eq!(concat.len(), 5);

        let (x, y) = concat.get(0).unwrap();
        assert_eq!(x.to_vec(), vec![1.0]);
        assert_eq!(y.to_vec(), vec![0.0]);

        let (x, y) = concat.get(3).unwrap();
        assert_eq!(x.to_vec(), vec![4.0]);
        assert_eq!(y.to_vec(), vec![3.0]);
    }

    /// Subset indices are translated to positions in the base dataset.
    #[test]
    fn test_subset_dataset() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5, 1]).unwrap();
        let targets = Tensor::from_vec(vec![0.0, 1.0, 2.0, 3.0, 4.0], &[5]).unwrap();
        let base = TensorDataset::new(data, targets);

        let subset = SubsetDataset::new(base, vec![0, 2, 4]);

        assert_eq!(subset.len(), 3);

        let (x, _) = subset.get(0).unwrap();
        assert_eq!(x.to_vec(), vec![1.0]);

        let (x, _) = subset.get(1).unwrap();
        assert_eq!(x.to_vec(), vec![3.0]);

        let (x, _) = subset.get(2).unwrap();
        assert_eq!(x.to_vec(), vec![5.0]);
    }

    /// Items come back by value and out-of-range lookups yield `None`.
    #[test]
    fn test_in_memory_dataset() {
        let dataset = InMemoryDataset::new(vec![1, 2, 3, 4, 5]);

        assert_eq!(dataset.len(), 5);
        assert_eq!(dataset.get(0), Some(1));
        assert_eq!(dataset.get(4), Some(5));
        assert_eq!(dataset.get(5), None);
    }

    /// The default `is_empty` follows `len`.
    #[test]
    fn test_is_empty() {
        let empty = InMemoryDataset::new(Vec::<i32>::new());
        assert!(empty.is_empty());
        assert_eq!(empty.get(0), None);

        let filled = InMemoryDataset::new(vec![7]);
        assert!(!filled.is_empty());
    }

    /// Empty members are skipped and indices past the end yield `None`.
    #[test]
    fn test_concat_dataset_edge_cases() {
        let concat = ConcatDataset::new(vec![
            InMemoryDataset::new(vec![1, 2]),
            InMemoryDataset::new(Vec::new()),
            InMemoryDataset::new(vec![3]),
        ]);

        assert_eq!(concat.len(), 3);
        assert_eq!(concat.get(0), Some(1));
        assert_eq!(concat.get(1), Some(2));
        assert_eq!(concat.get(2), Some(3));
        assert_eq!(concat.get(3), None);

        let hollow = ConcatDataset::new(vec![InMemoryDataset::new(Vec::<i32>::new())]);
        assert!(hollow.is_empty());
        assert_eq!(hollow.get(0), None);
    }

    /// A stored position that the base dataset does not have yields `None`,
    /// as does a subset index past the end of the index list.
    #[test]
    fn test_subset_dataset_out_of_range() {
        let base = InMemoryDataset::new(vec![10, 20, 30]);
        let subset = SubsetDataset::new(base, vec![2, 0, 5]);

        assert_eq!(subset.len(), 3);
        assert_eq!(subset.get(0), Some(30));
        assert_eq!(subset.get(1), Some(10));
        assert_eq!(subset.get(2), None);
        assert_eq!(subset.get(3), None);
    }
}