Skip to main content

axonml_data/
dataset.rs

1//! Dataset Trait - Core Data Abstraction
2//!
3//! Defines the Dataset trait that all data sources implement.
4//!
5//! @version 0.1.0
6//! @author `AutomataNexus` Development Team
7
8use axonml_tensor::Tensor;
9
10// =============================================================================
11// Dataset Trait
12// =============================================================================
13
/// Indexed, random-access source of data items; every dataset type implements it.
///
/// Implementors report how many items they hold via [`Dataset::len`] and hand
/// out individual items via [`Dataset::get`]. The `Send + Sync` supertraits
/// mean a dataset can be moved to, and shared between, threads.
pub trait Dataset: Send + Sync {
    /// Type of a single item produced by [`Dataset::get`]; must be `Send`.
    type Item: Send;

    /// Fetches the item at `index`, or `None` when there is no item there
    /// (for example, when `index` is out of range).
    fn get(&self, index: usize) -> Option<Self::Item>;

    /// Total number of items.
    fn len(&self) -> usize;

    /// `true` exactly when [`Dataset::len`] is zero.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }
}
32
33// =============================================================================
34// TensorDataset
35// =============================================================================
36
37/// A dataset wrapping tensors.
38///
39/// Each item is a tuple of (input, target) tensors.
40pub struct TensorDataset {
41    /// Input data tensor.
42    data: Tensor<f32>,
43    /// Target tensor.
44    targets: Tensor<f32>,
45    /// Number of samples.
46    len: usize,
47}
48
49impl TensorDataset {
50    /// Creates a new `TensorDataset` from input and target tensors.
51    ///
52    /// The first dimension of both tensors must match.
53    #[must_use]
54    pub fn new(data: Tensor<f32>, targets: Tensor<f32>) -> Self {
55        let len = data.shape()[0];
56        assert_eq!(
57            len,
58            targets.shape()[0],
59            "Data and targets must have same first dimension"
60        );
61        Self { data, targets, len }
62    }
63
64    /// Creates a `TensorDataset` from just input data (no targets).
65    #[must_use]
66    pub fn from_data(data: Tensor<f32>) -> Self {
67        let len = data.shape()[0];
68        let targets = Tensor::from_vec(vec![0.0; len], &[len]).unwrap();
69        Self { data, targets, len }
70    }
71}
72
73impl Dataset for TensorDataset {
74    type Item = (Tensor<f32>, Tensor<f32>);
75
76    fn len(&self) -> usize {
77        self.len
78    }
79
80    fn get(&self, index: usize) -> Option<Self::Item> {
81        if index >= self.len {
82            return None;
83        }
84
85        // Extract row from data tensor
86        let data_shape = self.data.shape();
87        let row_size: usize = data_shape[1..].iter().product();
88        let data_vec = self.data.to_vec();
89        let start = index * row_size;
90        let end = start + row_size;
91        let item_data = data_vec[start..end].to_vec();
92        let item_shape: Vec<usize> = data_shape[1..].to_vec();
93        let x = Tensor::from_vec(item_data, &item_shape).unwrap();
94
95        // Extract target
96        let target_shape = self.targets.shape();
97        let target_row_size: usize = if target_shape.len() > 1 {
98            target_shape[1..].iter().product()
99        } else {
100            1
101        };
102        let target_vec = self.targets.to_vec();
103        let target_start = index * target_row_size;
104        let target_end = target_start + target_row_size;
105        let item_target = target_vec[target_start..target_end].to_vec();
106        let target_item_shape: Vec<usize> = if target_shape.len() > 1 {
107            target_shape[1..].to_vec()
108        } else {
109            vec![1]
110        };
111        let y = Tensor::from_vec(item_target, &target_item_shape).unwrap();
112
113        Some((x, y))
114    }
115}
116
117// =============================================================================
118// MapDataset
119// =============================================================================
120
121/// A dataset that applies a transform to another dataset.
122pub struct MapDataset<D, F>
123where
124    D: Dataset,
125    F: Fn(D::Item) -> D::Item + Send + Sync,
126{
127    dataset: D,
128    transform: F,
129}
130
131impl<D, F> MapDataset<D, F>
132where
133    D: Dataset,
134    F: Fn(D::Item) -> D::Item + Send + Sync,
135{
136    /// Creates a new `MapDataset`.
137    pub fn new(dataset: D, transform: F) -> Self {
138        Self { dataset, transform }
139    }
140}
141
142impl<D, F> Dataset for MapDataset<D, F>
143where
144    D: Dataset,
145    F: Fn(D::Item) -> D::Item + Send + Sync,
146{
147    type Item = D::Item;
148
149    fn len(&self) -> usize {
150        self.dataset.len()
151    }
152
153    fn get(&self, index: usize) -> Option<Self::Item> {
154        self.dataset.get(index).map(&self.transform)
155    }
156}
157
158// =============================================================================
159// ConcatDataset
160// =============================================================================
161
162/// A dataset that concatenates multiple datasets.
163pub struct ConcatDataset<D: Dataset> {
164    datasets: Vec<D>,
165    cumulative_sizes: Vec<usize>,
166}
167
168impl<D: Dataset> ConcatDataset<D> {
169    /// Creates a new `ConcatDataset` from multiple datasets.
170    #[must_use]
171    pub fn new(datasets: Vec<D>) -> Self {
172        let mut cumulative_sizes = Vec::with_capacity(datasets.len());
173        let mut total = 0;
174        for d in &datasets {
175            total += d.len();
176            cumulative_sizes.push(total);
177        }
178        Self {
179            datasets,
180            cumulative_sizes,
181        }
182    }
183
184    /// Finds which dataset contains the given index.
185    fn find_dataset(&self, index: usize) -> Option<(usize, usize)> {
186        if index >= self.len() {
187            return None;
188        }
189
190        for (i, &cum_size) in self.cumulative_sizes.iter().enumerate() {
191            if index < cum_size {
192                let prev_size = if i == 0 {
193                    0
194                } else {
195                    self.cumulative_sizes[i - 1]
196                };
197                return Some((i, index - prev_size));
198            }
199        }
200        None
201    }
202}
203
204impl<D: Dataset> Dataset for ConcatDataset<D> {
205    type Item = D::Item;
206
207    fn len(&self) -> usize {
208        *self.cumulative_sizes.last().unwrap_or(&0)
209    }
210
211    fn get(&self, index: usize) -> Option<Self::Item> {
212        let (dataset_idx, local_idx) = self.find_dataset(index)?;
213        self.datasets[dataset_idx].get(local_idx)
214    }
215}
216
217// =============================================================================
218// SubsetDataset
219// =============================================================================
220
221/// A dataset that provides a subset of another dataset.
222pub struct SubsetDataset<D: Dataset> {
223    dataset: D,
224    indices: Vec<usize>,
225}
226
227impl<D: Dataset> SubsetDataset<D> {
228    /// Creates a new `SubsetDataset` with specified indices.
229    pub fn new(dataset: D, indices: Vec<usize>) -> Self {
230        Self { dataset, indices }
231    }
232
233    /// Creates a random split of a dataset into two subsets.
234    pub fn random_split(dataset: D, lengths: &[usize]) -> Vec<Self>
235    where
236        D: Clone,
237    {
238        use rand::seq::SliceRandom;
239        use rand::thread_rng;
240
241        let total_len: usize = lengths.iter().sum();
242        assert_eq!(
243            total_len,
244            dataset.len(),
245            "Split lengths must sum to dataset length"
246        );
247
248        let mut indices: Vec<usize> = (0..dataset.len()).collect();
249        indices.shuffle(&mut thread_rng());
250
251        let mut subsets = Vec::with_capacity(lengths.len());
252        let mut offset = 0;
253        for &len in lengths {
254            let subset_indices = indices[offset..offset + len].to_vec();
255            subsets.push(Self::new(dataset.clone(), subset_indices));
256            offset += len;
257        }
258        subsets
259    }
260}
261
262impl<D: Dataset> Dataset for SubsetDataset<D> {
263    type Item = D::Item;
264
265    fn len(&self) -> usize {
266        self.indices.len()
267    }
268
269    fn get(&self, index: usize) -> Option<Self::Item> {
270        let real_index = *self.indices.get(index)?;
271        self.dataset.get(real_index)
272    }
273}
274
275// =============================================================================
276// InMemoryDataset
277// =============================================================================
278
279/// A simple in-memory dataset from a vector.
280pub struct InMemoryDataset<T: Clone + Send> {
281    items: Vec<T>,
282}
283
284impl<T: Clone + Send> InMemoryDataset<T> {
285    /// Creates a new `InMemoryDataset` from a vector.
286    #[must_use]
287    pub fn new(items: Vec<T>) -> Self {
288        Self { items }
289    }
290}
291
292impl<T: Clone + Send + Sync> Dataset for InMemoryDataset<T> {
293    type Item = T;
294
295    fn len(&self) -> usize {
296        self.items.len()
297    }
298
299    fn get(&self, index: usize) -> Option<Self::Item> {
300        self.items.get(index).cloned()
301    }
302}
303
304// =============================================================================
305// Tests
306// =============================================================================
307
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_dataset() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2]).unwrap();
        let targets = Tensor::from_vec(vec![0.0, 1.0, 2.0], &[3]).unwrap();
        let dataset = TensorDataset::new(data, targets);

        assert_eq!(dataset.len(), 3);

        let (x, y) = dataset.get(0).unwrap();
        assert_eq!(x.to_vec(), vec![1.0, 2.0]);
        assert_eq!(y.to_vec(), vec![0.0]);

        let (x, y) = dataset.get(2).unwrap();
        assert_eq!(x.to_vec(), vec![5.0, 6.0]);
        assert_eq!(y.to_vec(), vec![2.0]);

        assert!(dataset.get(3).is_none());
    }

    #[test]
    fn test_tensor_dataset_multidim_targets() {
        let data = Tensor::from_vec(vec![1.0, 2.0], &[2, 1]).unwrap();
        let targets = Tensor::from_vec(vec![0.0, 1.0, 1.0, 0.0], &[2, 2]).unwrap();
        let dataset = TensorDataset::new(data, targets);

        let (x, y) = dataset.get(1).unwrap();
        assert_eq!(x.to_vec(), vec![2.0]);
        assert_eq!(y.to_vec(), vec![1.0, 0.0]);
    }

    #[test]
    fn test_tensor_dataset_from_data() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let dataset = TensorDataset::from_data(data);

        assert_eq!(dataset.len(), 2);
        let (x, y) = dataset.get(1).unwrap();
        assert_eq!(x.to_vec(), vec![4.0, 5.0, 6.0]);
        // Placeholder targets are a single zero per sample.
        assert_eq!(y.to_vec(), vec![0.0]);
        assert!(dataset.get(2).is_none());
    }

    #[test]
    #[should_panic(expected = "Data and targets must have same first dimension")]
    fn test_tensor_dataset_mismatched_lengths() {
        let data = Tensor::from_vec(vec![1.0, 2.0], &[2, 1]).unwrap();
        let targets = Tensor::from_vec(vec![0.0, 1.0, 2.0], &[3]).unwrap();
        let _ = TensorDataset::new(data, targets);
    }

    #[test]
    fn test_map_dataset() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4, 1]).unwrap();
        let targets = Tensor::from_vec(vec![0.0, 1.0, 0.0, 1.0], &[4]).unwrap();
        let base = TensorDataset::new(data, targets);

        let mapped = MapDataset::new(base, |(x, y)| (x.mul_scalar(2.0), y));

        assert_eq!(mapped.len(), 4);
        let (x, _) = mapped.get(0).unwrap();
        assert_eq!(x.to_vec(), vec![2.0]);
    }

    #[test]
    fn test_concat_dataset() {
        let data1 = Tensor::from_vec(vec![1.0, 2.0], &[2, 1]).unwrap();
        let targets1 = Tensor::from_vec(vec![0.0, 1.0], &[2]).unwrap();
        let ds1 = TensorDataset::new(data1, targets1);

        let data2 = Tensor::from_vec(vec![3.0, 4.0, 5.0], &[3, 1]).unwrap();
        let targets2 = Tensor::from_vec(vec![2.0, 3.0, 4.0], &[3]).unwrap();
        let ds2 = TensorDataset::new(data2, targets2);

        let concat = ConcatDataset::new(vec![ds1, ds2]);

        assert_eq!(concat.len(), 5);

        let (x, y) = concat.get(0).unwrap();
        assert_eq!(x.to_vec(), vec![1.0]);
        assert_eq!(y.to_vec(), vec![0.0]);

        let (x, y) = concat.get(3).unwrap();
        assert_eq!(x.to_vec(), vec![4.0]);
        assert_eq!(y.to_vec(), vec![3.0]);
    }

    #[test]
    fn test_concat_dataset_with_empty_parts() {
        let concat = ConcatDataset::new(vec![
            InMemoryDataset::new(vec![1, 2]),
            InMemoryDataset::new(vec![]),
            InMemoryDataset::new(vec![3]),
        ]);

        assert_eq!(concat.len(), 3);
        assert_eq!(concat.get(0), Some(1));
        assert_eq!(concat.get(1), Some(2));
        // The empty middle part must be skipped.
        assert_eq!(concat.get(2), Some(3));
        assert_eq!(concat.get(3), None);

        let empty: ConcatDataset<InMemoryDataset<i32>> = ConcatDataset::new(vec![]);
        assert!(empty.is_empty());
        assert_eq!(empty.get(0), None);
    }

    #[test]
    fn test_subset_dataset() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5, 1]).unwrap();
        let targets = Tensor::from_vec(vec![0.0, 1.0, 2.0, 3.0, 4.0], &[5]).unwrap();
        let base = TensorDataset::new(data, targets);

        let subset = SubsetDataset::new(base, vec![0, 2, 4]);

        assert_eq!(subset.len(), 3);

        let (x, _) = subset.get(0).unwrap();
        assert_eq!(x.to_vec(), vec![1.0]);

        let (x, _) = subset.get(1).unwrap();
        assert_eq!(x.to_vec(), vec![3.0]);

        let (x, _) = subset.get(2).unwrap();
        assert_eq!(x.to_vec(), vec![5.0]);
    }

    #[test]
    fn test_subset_dataset_out_of_range() {
        let subset = SubsetDataset::new(InMemoryDataset::new(vec![10, 20, 30]), vec![2, 7]);

        assert_eq!(subset.len(), 2);
        assert_eq!(subset.get(0), Some(30));
        // Index 7 does not exist in the wrapped dataset, so the lookup yields None.
        assert_eq!(subset.get(1), None);
        // Past the end of the index list.
        assert_eq!(subset.get(2), None);
    }

    #[test]
    fn test_in_memory_dataset() {
        let dataset = InMemoryDataset::new(vec![1, 2, 3, 4, 5]);

        assert_eq!(dataset.len(), 5);
        assert!(!dataset.is_empty());
        assert_eq!(dataset.get(0), Some(1));
        assert_eq!(dataset.get(4), Some(5));
        assert_eq!(dataset.get(5), None);

        let empty: InMemoryDataset<i32> = InMemoryDataset::new(vec![]);
        assert!(empty.is_empty());
        assert_eq!(empty.get(0), None);
    }
}