1use axonml_tensor::Tensor;
9
/// Random-access collection of samples that can be shared between threads.
///
/// Implementors report how many samples they hold and hand out one sample
/// per index. The trait requires `Self: Send + Sync` and `Self::Item: Send`.
pub trait Dataset: Send + Sync {
    /// Type of a single sample returned by [`Dataset::get`].
    type Item: Send;

    /// Returns the number of samples in the dataset.
    fn len(&self) -> usize;

    /// Returns `true` when [`Dataset::len`] reports zero samples.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }

    /// Returns the sample stored at `index`, or `None` if there is no
    /// sample at that position.
    fn get(&self, index: usize) -> Option<Self::Item>;
}
32
33pub struct TensorDataset {
41 data: Tensor<f32>,
43 targets: Tensor<f32>,
45 len: usize,
47}
48
49impl TensorDataset {
50 #[must_use]
54 pub fn new(data: Tensor<f32>, targets: Tensor<f32>) -> Self {
55 let len = data.shape()[0];
56 assert_eq!(
57 len,
58 targets.shape()[0],
59 "Data and targets must have same first dimension"
60 );
61 Self { data, targets, len }
62 }
63
64 #[must_use]
66 pub fn from_data(data: Tensor<f32>) -> Self {
67 let len = data.shape()[0];
68 let targets = Tensor::from_vec(vec![0.0; len], &[len]).unwrap();
69 Self { data, targets, len }
70 }
71}
72
73impl Dataset for TensorDataset {
74 type Item = (Tensor<f32>, Tensor<f32>);
75
76 fn len(&self) -> usize {
77 self.len
78 }
79
80 fn get(&self, index: usize) -> Option<Self::Item> {
81 if index >= self.len {
82 return None;
83 }
84
85 let data_shape = self.data.shape();
87 let row_size: usize = data_shape[1..].iter().product();
88 let data_vec = self.data.to_vec();
89 let start = index * row_size;
90 let end = start + row_size;
91 let item_data = data_vec[start..end].to_vec();
92 let item_shape: Vec<usize> = data_shape[1..].to_vec();
93 let x = Tensor::from_vec(item_data, &item_shape).unwrap();
94
95 let target_shape = self.targets.shape();
97 let target_row_size: usize = if target_shape.len() > 1 {
98 target_shape[1..].iter().product()
99 } else {
100 1
101 };
102 let target_vec = self.targets.to_vec();
103 let target_start = index * target_row_size;
104 let target_end = target_start + target_row_size;
105 let item_target = target_vec[target_start..target_end].to_vec();
106 let target_item_shape: Vec<usize> = if target_shape.len() > 1 {
107 target_shape[1..].to_vec()
108 } else {
109 vec![1]
110 };
111 let y = Tensor::from_vec(item_target, &target_item_shape).unwrap();
112
113 Some((x, y))
114 }
115}
116
117pub struct MapDataset<D, F>
123where
124 D: Dataset,
125 F: Fn(D::Item) -> D::Item + Send + Sync,
126{
127 dataset: D,
128 transform: F,
129}
130
131impl<D, F> MapDataset<D, F>
132where
133 D: Dataset,
134 F: Fn(D::Item) -> D::Item + Send + Sync,
135{
136 pub fn new(dataset: D, transform: F) -> Self {
138 Self { dataset, transform }
139 }
140}
141
142impl<D, F> Dataset for MapDataset<D, F>
143where
144 D: Dataset,
145 F: Fn(D::Item) -> D::Item + Send + Sync,
146{
147 type Item = D::Item;
148
149 fn len(&self) -> usize {
150 self.dataset.len()
151 }
152
153 fn get(&self, index: usize) -> Option<Self::Item> {
154 self.dataset.get(index).map(&self.transform)
155 }
156}
157
158pub struct ConcatDataset<D: Dataset> {
164 datasets: Vec<D>,
165 cumulative_sizes: Vec<usize>,
166}
167
168impl<D: Dataset> ConcatDataset<D> {
169 #[must_use]
171 pub fn new(datasets: Vec<D>) -> Self {
172 let mut cumulative_sizes = Vec::with_capacity(datasets.len());
173 let mut total = 0;
174 for d in &datasets {
175 total += d.len();
176 cumulative_sizes.push(total);
177 }
178 Self {
179 datasets,
180 cumulative_sizes,
181 }
182 }
183
184 fn find_dataset(&self, index: usize) -> Option<(usize, usize)> {
186 if index >= self.len() {
187 return None;
188 }
189
190 for (i, &cum_size) in self.cumulative_sizes.iter().enumerate() {
191 if index < cum_size {
192 let prev_size = if i == 0 {
193 0
194 } else {
195 self.cumulative_sizes[i - 1]
196 };
197 return Some((i, index - prev_size));
198 }
199 }
200 None
201 }
202}
203
204impl<D: Dataset> Dataset for ConcatDataset<D> {
205 type Item = D::Item;
206
207 fn len(&self) -> usize {
208 *self.cumulative_sizes.last().unwrap_or(&0)
209 }
210
211 fn get(&self, index: usize) -> Option<Self::Item> {
212 let (dataset_idx, local_idx) = self.find_dataset(index)?;
213 self.datasets[dataset_idx].get(local_idx)
214 }
215}
216
217pub struct SubsetDataset<D: Dataset> {
223 dataset: D,
224 indices: Vec<usize>,
225}
226
227impl<D: Dataset> SubsetDataset<D> {
228 pub fn new(dataset: D, indices: Vec<usize>) -> Self {
230 Self { dataset, indices }
231 }
232
233 pub fn random_split(dataset: D, lengths: &[usize]) -> Vec<Self>
235 where
236 D: Clone,
237 {
238 use rand::seq::SliceRandom;
239 use rand::thread_rng;
240
241 let total_len: usize = lengths.iter().sum();
242 assert_eq!(
243 total_len,
244 dataset.len(),
245 "Split lengths must sum to dataset length"
246 );
247
248 let mut indices: Vec<usize> = (0..dataset.len()).collect();
249 indices.shuffle(&mut thread_rng());
250
251 let mut subsets = Vec::with_capacity(lengths.len());
252 let mut offset = 0;
253 for &len in lengths {
254 let subset_indices = indices[offset..offset + len].to_vec();
255 subsets.push(Self::new(dataset.clone(), subset_indices));
256 offset += len;
257 }
258 subsets
259 }
260}
261
262impl<D: Dataset> Dataset for SubsetDataset<D> {
263 type Item = D::Item;
264
265 fn len(&self) -> usize {
266 self.indices.len()
267 }
268
269 fn get(&self, index: usize) -> Option<Self::Item> {
270 let real_index = *self.indices.get(index)?;
271 self.dataset.get(real_index)
272 }
273}
274
/// Dataset that keeps all of its items in a `Vec`.
pub struct InMemoryDataset<T>
where
    T: Clone + Send,
{
    /// The stored items, in index order.
    items: Vec<T>,
}

impl<T> InMemoryDataset<T>
where
    T: Clone + Send,
{
    /// Creates a dataset that takes ownership of `items`.
    #[must_use]
    pub fn new(items: Vec<T>) -> Self {
        InMemoryDataset { items }
    }
}
291
292impl<T: Clone + Send + Sync> Dataset for InMemoryDataset<T> {
293 type Item = T;
294
295 fn len(&self) -> usize {
296 self.items.len()
297 }
298
299 fn get(&self, index: usize) -> Option<Self::Item> {
300 self.items.get(index).cloned()
301 }
302}
303
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal cloneable dataset yielding its own indices; used to exercise
    /// `SubsetDataset::random_split`, which requires `D: Clone`.
    #[derive(Clone)]
    struct Counting(usize);

    impl Dataset for Counting {
        type Item = usize;

        fn len(&self) -> usize {
            self.0
        }

        fn get(&self, index: usize) -> Option<Self::Item> {
            if index < self.0 {
                Some(index)
            } else {
                None
            }
        }
    }

    #[test]
    fn test_tensor_dataset() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2]).unwrap();
        let targets = Tensor::from_vec(vec![0.0, 1.0, 2.0], &[3]).unwrap();
        let dataset = TensorDataset::new(data, targets);

        assert_eq!(dataset.len(), 3);
        assert!(!dataset.is_empty());

        let (x, y) = dataset.get(0).unwrap();
        assert_eq!(x.to_vec(), vec![1.0, 2.0]);
        assert_eq!(y.to_vec(), vec![0.0]);

        let (x, y) = dataset.get(2).unwrap();
        assert_eq!(x.to_vec(), vec![5.0, 6.0]);
        assert_eq!(y.to_vec(), vec![2.0]);

        assert!(dataset.get(3).is_none());
    }

    #[test]
    fn test_tensor_dataset_from_data() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let dataset = TensorDataset::from_data(data);

        assert_eq!(dataset.len(), 2);

        // Targets are placeholders filled with zeros.
        let (x, y) = dataset.get(1).unwrap();
        assert_eq!(x.to_vec(), vec![3.0, 4.0]);
        assert_eq!(y.to_vec(), vec![0.0]);

        assert!(dataset.get(2).is_none());
    }

    #[test]
    #[should_panic(expected = "Data and targets must have same first dimension")]
    fn test_tensor_dataset_mismatched_lengths() {
        let data = Tensor::from_vec(vec![1.0, 2.0], &[2, 1]).unwrap();
        let targets = Tensor::from_vec(vec![0.0, 1.0, 2.0], &[3]).unwrap();
        let _ = TensorDataset::new(data, targets);
    }

    #[test]
    fn test_map_dataset() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4, 1]).unwrap();
        let targets = Tensor::from_vec(vec![0.0, 1.0, 0.0, 1.0], &[4]).unwrap();
        let base = TensorDataset::new(data, targets);

        let mapped = MapDataset::new(base, |(x, y)| (x.mul_scalar(2.0), y));

        assert_eq!(mapped.len(), 4);
        let (x, _) = mapped.get(0).unwrap();
        assert_eq!(x.to_vec(), vec![2.0]);

        assert!(mapped.get(4).is_none());
    }

    #[test]
    fn test_concat_dataset() {
        let data1 = Tensor::from_vec(vec![1.0, 2.0], &[2, 1]).unwrap();
        let targets1 = Tensor::from_vec(vec![0.0, 1.0], &[2]).unwrap();
        let ds1 = TensorDataset::new(data1, targets1);

        let data2 = Tensor::from_vec(vec![3.0, 4.0, 5.0], &[3, 1]).unwrap();
        let targets2 = Tensor::from_vec(vec![2.0, 3.0, 4.0], &[3]).unwrap();
        let ds2 = TensorDataset::new(data2, targets2);

        let concat = ConcatDataset::new(vec![ds1, ds2]);

        assert_eq!(concat.len(), 5);

        let (x, y) = concat.get(0).unwrap();
        assert_eq!(x.to_vec(), vec![1.0]);
        assert_eq!(y.to_vec(), vec![0.0]);

        // First item of the second dataset sits right after the first one.
        let (x, y) = concat.get(2).unwrap();
        assert_eq!(x.to_vec(), vec![3.0]);
        assert_eq!(y.to_vec(), vec![2.0]);

        let (x, y) = concat.get(3).unwrap();
        assert_eq!(x.to_vec(), vec![4.0]);
        assert_eq!(y.to_vec(), vec![3.0]);

        assert!(concat.get(5).is_none());
    }

    #[test]
    fn test_concat_dataset_empty() {
        let concat = ConcatDataset::<TensorDataset>::new(Vec::new());

        assert_eq!(concat.len(), 0);
        assert!(concat.is_empty());
        assert!(concat.get(0).is_none());
    }

    #[test]
    fn test_subset_dataset() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5, 1]).unwrap();
        let targets = Tensor::from_vec(vec![0.0, 1.0, 2.0, 3.0, 4.0], &[5]).unwrap();
        let base = TensorDataset::new(data, targets);

        let subset = SubsetDataset::new(base, vec![0, 2, 4]);

        assert_eq!(subset.len(), 3);

        let (x, _) = subset.get(0).unwrap();
        assert_eq!(x.to_vec(), vec![1.0]);

        let (x, _) = subset.get(1).unwrap();
        assert_eq!(x.to_vec(), vec![3.0]);

        let (x, _) = subset.get(2).unwrap();
        assert_eq!(x.to_vec(), vec![5.0]);

        assert!(subset.get(3).is_none());
    }

    #[test]
    fn test_random_split() {
        let parts = SubsetDataset::random_split(Counting(10), &[7, 3]);

        assert_eq!(parts.len(), 2);
        assert_eq!(parts[0].len(), 7);
        assert_eq!(parts[1].len(), 3);

        // Whatever the shuffle produced, every index appears exactly once.
        let mut seen: Vec<usize> = parts
            .iter()
            .flat_map(|part| (0..part.len()).map(move |i| part.get(i).unwrap()))
            .collect();
        seen.sort_unstable();
        assert_eq!(seen, (0..10).collect::<Vec<usize>>());
    }

    #[test]
    #[should_panic(expected = "Split lengths must sum to dataset length")]
    fn test_random_split_wrong_total() {
        let _ = SubsetDataset::random_split(Counting(10), &[4, 4]);
    }

    #[test]
    fn test_in_memory_dataset() {
        let dataset = InMemoryDataset::new(vec![1, 2, 3, 4, 5]);

        assert_eq!(dataset.len(), 5);
        assert!(!dataset.is_empty());
        assert_eq!(dataset.get(0), Some(1));
        assert_eq!(dataset.get(4), Some(5));
        assert_eq!(dataset.get(5), None);
    }
}