1use axonml_tensor::Tensor;
19
/// A random-access collection of samples that can be shared across threads.
///
/// Implementors report how many samples they hold and hand out one owned
/// sample per index. The `Send + Sync` supertraits and the `Send` bound on
/// [`Dataset::Item`] allow a dataset and its samples to cross thread
/// boundaries.
pub trait Dataset: Send + Sync {
    /// The owned sample type produced by [`Dataset::get`].
    type Item: Send;

    /// Returns the number of samples in the dataset.
    fn len(&self) -> usize;

    /// Returns `true` when the dataset holds no samples.
    ///
    /// The default implementation is derived from [`Dataset::len`].
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }

    /// Returns the sample stored at `index`, or `None` when there is no
    /// sample to return for that index.
    fn get(&self, index: usize) -> Option<Self::Item>;
}
42
/// A dataset of `(data, target)` pairs backed by flat `f32` buffers.
///
/// The first dimension of both shapes is the sample count; the remaining
/// dimensions describe a single sample.
#[derive(Clone)]
pub struct TensorDataset {
    /// Element buffer taken from the data tensor; read in consecutive runs of
    /// `row_size` elements, one run per sample.
    data_vec: Vec<f32>,
    /// Element buffer taken from the target tensor; read in consecutive runs
    /// of `target_row_size` elements, one run per sample.
    target_vec: Vec<f32>,
    /// Full shape of the data tensor, including the leading sample dimension.
    data_shape: Vec<usize>,
    /// Full shape of the target tensor, including the leading sample dimension.
    target_shape: Vec<usize>,
    /// Number of data elements per sample.
    row_size: usize,
    /// Number of target elements per sample.
    target_row_size: usize,
    /// Number of samples.
    len: usize,
}

/// Summary-only `Debug`: prints the sample count and both shapes but never
/// the element buffers, so formatting a large dataset stays cheap.
impl std::fmt::Debug for TensorDataset {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TensorDataset")
            .field("len", &self.len)
            .field("data_shape", &self.data_shape)
            .field("target_shape", &self.target_shape)
            .finish_non_exhaustive()
    }
}
68
69impl TensorDataset {
70 #[must_use]
74 pub fn new(data: Tensor<f32>, targets: Tensor<f32>) -> Self {
75 let data_shape = data.shape().to_vec();
76 let target_shape = targets.shape().to_vec();
77 let len = data_shape[0];
78 assert_eq!(
79 len, target_shape[0],
80 "Data and targets must have same first dimension"
81 );
82
83 let row_size: usize = data_shape[1..].iter().product();
84 let target_row_size: usize = if target_shape.len() > 1 {
85 target_shape[1..].iter().product()
86 } else {
87 1
88 };
89
90 Self {
91 data_vec: data.to_vec(),
92 target_vec: targets.to_vec(),
93 data_shape,
94 target_shape,
95 row_size,
96 target_row_size,
97 len,
98 }
99 }
100
101 #[must_use]
103 pub fn from_data(data: Tensor<f32>) -> Self {
104 let len = data.shape()[0];
105 let targets = Tensor::from_vec(vec![0.0; len], &[len]).unwrap();
106 Self::new(data, targets)
107 }
108}
109
110impl Dataset for TensorDataset {
111 type Item = (Tensor<f32>, Tensor<f32>);
112
113 fn len(&self) -> usize {
114 self.len
115 }
116
117 fn get(&self, index: usize) -> Option<Self::Item> {
118 if index >= self.len {
119 return None;
120 }
121
122 let start = index * self.row_size;
124 let end = start + self.row_size;
125 let item_data = self.data_vec[start..end].to_vec();
126 let item_shape: Vec<usize> = self.data_shape[1..].to_vec();
127 let x = Tensor::from_vec(item_data, &item_shape).unwrap();
128
129 let target_start = index * self.target_row_size;
130 let target_end = target_start + self.target_row_size;
131 let item_target = self.target_vec[target_start..target_end].to_vec();
132 let target_item_shape: Vec<usize> = if self.target_shape.len() > 1 {
133 self.target_shape[1..].to_vec()
134 } else {
135 vec![1]
136 };
137 let y = Tensor::from_vec(item_target, &target_item_shape).unwrap();
138
139 Some((x, y))
140 }
141}
142
143pub struct MapDataset<D, F>
149where
150 D: Dataset,
151 F: Fn(D::Item) -> D::Item + Send + Sync,
152{
153 dataset: D,
154 transform: F,
155}
156
157impl<D, F> MapDataset<D, F>
158where
159 D: Dataset,
160 F: Fn(D::Item) -> D::Item + Send + Sync,
161{
162 pub fn new(dataset: D, transform: F) -> Self {
164 Self { dataset, transform }
165 }
166}
167
168impl<D, F> Dataset for MapDataset<D, F>
169where
170 D: Dataset,
171 F: Fn(D::Item) -> D::Item + Send + Sync,
172{
173 type Item = D::Item;
174
175 fn len(&self) -> usize {
176 self.dataset.len()
177 }
178
179 fn get(&self, index: usize) -> Option<Self::Item> {
180 self.dataset.get(index).map(&self.transform)
181 }
182}
183
184pub struct ConcatDataset<D: Dataset> {
190 datasets: Vec<D>,
191 cumulative_sizes: Vec<usize>,
192}
193
194impl<D: Dataset> ConcatDataset<D> {
195 #[must_use]
197 pub fn new(datasets: Vec<D>) -> Self {
198 let mut cumulative_sizes = Vec::with_capacity(datasets.len());
199 let mut total = 0;
200 for d in &datasets {
201 total += d.len();
202 cumulative_sizes.push(total);
203 }
204 Self {
205 datasets,
206 cumulative_sizes,
207 }
208 }
209
210 fn find_dataset(&self, index: usize) -> Option<(usize, usize)> {
212 if index >= self.len() {
213 return None;
214 }
215
216 for (i, &cum_size) in self.cumulative_sizes.iter().enumerate() {
217 if index < cum_size {
218 let prev_size = if i == 0 {
219 0
220 } else {
221 self.cumulative_sizes[i - 1]
222 };
223 return Some((i, index - prev_size));
224 }
225 }
226 None
227 }
228}
229
230impl<D: Dataset> Dataset for ConcatDataset<D> {
231 type Item = D::Item;
232
233 fn len(&self) -> usize {
234 *self.cumulative_sizes.last().unwrap_or(&0)
235 }
236
237 fn get(&self, index: usize) -> Option<Self::Item> {
238 let (dataset_idx, local_idx) = self.find_dataset(index)?;
239 self.datasets[dataset_idx].get(local_idx)
240 }
241}
242
243pub struct SubsetDataset<D: Dataset> {
249 dataset: D,
250 indices: Vec<usize>,
251}
252
253impl<D: Dataset> SubsetDataset<D> {
254 pub fn new(dataset: D, indices: Vec<usize>) -> Self {
256 Self { dataset, indices }
257 }
258
259 pub fn random_split(dataset: D, lengths: &[usize]) -> Vec<Self>
261 where
262 D: Clone,
263 {
264 use rand::seq::SliceRandom;
265 use rand::thread_rng;
266
267 let total_len: usize = lengths.iter().sum();
268 assert_eq!(
269 total_len,
270 dataset.len(),
271 "Split lengths must sum to dataset length"
272 );
273
274 let mut indices: Vec<usize> = (0..dataset.len()).collect();
275 indices.shuffle(&mut thread_rng());
276
277 let mut subsets = Vec::with_capacity(lengths.len());
278 let mut offset = 0;
279 for &len in lengths {
280 let subset_indices = indices[offset..offset + len].to_vec();
281 subsets.push(Self::new(dataset.clone(), subset_indices));
282 offset += len;
283 }
284 subsets
285 }
286}
287
288impl<D: Dataset> Dataset for SubsetDataset<D> {
289 type Item = D::Item;
290
291 fn len(&self) -> usize {
292 self.indices.len()
293 }
294
295 fn get(&self, index: usize) -> Option<Self::Item> {
296 let real_index = *self.indices.get(index)?;
297 self.dataset.get(real_index)
298 }
299}
300
301pub struct InMemoryDataset<T: Clone + Send> {
307 items: Vec<T>,
308}
309
310impl<T: Clone + Send> InMemoryDataset<T> {
311 #[must_use]
313 pub fn new(items: Vec<T>) -> Self {
314 Self { items }
315 }
316}
317
318impl<T: Clone + Send + Sync> Dataset for InMemoryDataset<T> {
319 type Item = T;
320
321 fn len(&self) -> usize {
322 self.items.len()
323 }
324
325 fn get(&self, index: usize) -> Option<Self::Item> {
326 self.items.get(index).cloned()
327 }
328}
329
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `f32` tensor from literal values, panicking if construction fails.
    fn tensor(values: &[f32], shape: &[usize]) -> Tensor<f32> {
        Tensor::from_vec(values.to_vec(), shape).unwrap()
    }

    /// Builds a dataset with one feature per sample (`[n, 1]` data) and
    /// one-dimensional targets (`[n]`).
    fn column_dataset(values: &[f32], targets: &[f32]) -> TensorDataset {
        TensorDataset::new(
            tensor(values, &[values.len(), 1]),
            tensor(targets, &[targets.len()]),
        )
    }

    #[test]
    fn test_tensor_dataset() {
        let dataset = TensorDataset::new(
            tensor(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2]),
            tensor(&[0.0, 1.0, 2.0], &[3]),
        );

        assert_eq!(dataset.len(), 3);

        // First and last samples come back as (row, target) pairs.
        let (first_x, first_y) = dataset.get(0).unwrap();
        assert_eq!(first_x.to_vec(), vec![1.0, 2.0]);
        assert_eq!(first_y.to_vec(), vec![0.0]);

        let (last_x, last_y) = dataset.get(2).unwrap();
        assert_eq!(last_x.to_vec(), vec![5.0, 6.0]);
        assert_eq!(last_y.to_vec(), vec![2.0]);

        // One past the end yields nothing.
        assert!(dataset.get(3).is_none());
    }

    #[test]
    fn test_map_dataset() {
        let base = column_dataset(&[1.0, 2.0, 3.0, 4.0], &[0.0, 1.0, 0.0, 1.0]);

        let mapped = MapDataset::new(base, |(x, y)| (x.mul_scalar(2.0), y));

        assert_eq!(mapped.len(), 4);
        let (doubled, _) = mapped.get(0).unwrap();
        assert_eq!(doubled.to_vec(), vec![2.0]);
    }

    #[test]
    fn test_concat_dataset() {
        let ds1 = column_dataset(&[1.0, 2.0], &[0.0, 1.0]);
        let ds2 = column_dataset(&[3.0, 4.0, 5.0], &[2.0, 3.0, 4.0]);

        let concat = ConcatDataset::new(vec![ds1, ds2]);

        assert_eq!(concat.len(), 5);

        // Index 0 lands in the first member, index 3 in the second.
        let (x, y) = concat.get(0).unwrap();
        assert_eq!(x.to_vec(), vec![1.0]);
        assert_eq!(y.to_vec(), vec![0.0]);

        let (x, y) = concat.get(3).unwrap();
        assert_eq!(x.to_vec(), vec![4.0]);
        assert_eq!(y.to_vec(), vec![3.0]);
    }

    #[test]
    fn test_subset_dataset() {
        let base = column_dataset(&[1.0, 2.0, 3.0, 4.0, 5.0], &[0.0, 1.0, 2.0, 3.0, 4.0]);

        let subset = SubsetDataset::new(base, vec![0, 2, 4]);

        assert_eq!(subset.len(), 3);

        // Subset positions 0, 1, 2 map to base samples 0, 2, 4.
        let (x, _) = subset.get(0).unwrap();
        assert_eq!(x.to_vec(), vec![1.0]);

        let (x, _) = subset.get(1).unwrap();
        assert_eq!(x.to_vec(), vec![3.0]);

        let (x, _) = subset.get(2).unwrap();
        assert_eq!(x.to_vec(), vec![5.0]);
    }

    #[test]
    fn test_in_memory_dataset() {
        let dataset = InMemoryDataset::new(vec![1, 2, 3, 4, 5]);

        assert_eq!(dataset.len(), 5);
        assert_eq!(dataset.get(0), Some(1));
        assert_eq!(dataset.get(4), Some(5));
        assert_eq!(dataset.get(5), None);
    }
}