axonml_data/
dataset.rs

1//! Dataset Trait - Core Data Abstraction
2//!
3//! Defines the Dataset trait that all data sources implement.
4//!
5//! @version 0.1.0
6//! @author `AutomataNexus` Development Team
7
8use axonml_tensor::Tensor;
9
10// =============================================================================
11// Dataset Trait
12// =============================================================================
13
/// Common interface implemented by every data source.
///
/// A dataset is a finite collection whose items are looked up by position.
/// Implementors must be `Send + Sync`, and the items they hand out must be
/// `Send`.
pub trait Dataset: Send + Sync {
    /// Type of a single item produced by [`Dataset::get`].
    type Item: Send;

    /// Number of items held by the dataset.
    fn len(&self) -> usize;

    /// Whether the dataset holds no items at all.
    ///
    /// The provided implementation reports `true` exactly when
    /// [`Dataset::len`] is zero.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }

    /// Looks up the item stored at position `index`.
    ///
    /// Returns `None` when no item is available for `index`.
    fn get(&self, index: usize) -> Option<Self::Item>;
}
32
33// =============================================================================
34// TensorDataset
35// =============================================================================
36
37/// A dataset wrapping tensors.
38///
39/// Each item is a tuple of (input, target) tensors.
40pub struct TensorDataset {
41    /// Input data tensor.
42    data: Tensor<f32>,
43    /// Target tensor.
44    targets: Tensor<f32>,
45    /// Number of samples.
46    len: usize,
47}
48
49impl TensorDataset {
50    /// Creates a new `TensorDataset` from input and target tensors.
51    ///
52    /// The first dimension of both tensors must match.
53    #[must_use] pub fn new(data: Tensor<f32>, targets: Tensor<f32>) -> Self {
54        let len = data.shape()[0];
55        assert_eq!(
56            len,
57            targets.shape()[0],
58            "Data and targets must have same first dimension"
59        );
60        Self { data, targets, len }
61    }
62
63    /// Creates a `TensorDataset` from just input data (no targets).
64    #[must_use] pub fn from_data(data: Tensor<f32>) -> Self {
65        let len = data.shape()[0];
66        let targets = Tensor::from_vec(vec![0.0; len], &[len]).unwrap();
67        Self { data, targets, len }
68    }
69}
70
71impl Dataset for TensorDataset {
72    type Item = (Tensor<f32>, Tensor<f32>);
73
74    fn len(&self) -> usize {
75        self.len
76    }
77
78    fn get(&self, index: usize) -> Option<Self::Item> {
79        if index >= self.len {
80            return None;
81        }
82
83        // Extract row from data tensor
84        let data_shape = self.data.shape();
85        let row_size: usize = data_shape[1..].iter().product();
86        let data_vec = self.data.to_vec();
87        let start = index * row_size;
88        let end = start + row_size;
89        let item_data = data_vec[start..end].to_vec();
90        let item_shape: Vec<usize> = data_shape[1..].to_vec();
91        let x = Tensor::from_vec(item_data, &item_shape).unwrap();
92
93        // Extract target
94        let target_shape = self.targets.shape();
95        let target_row_size: usize = if target_shape.len() > 1 {
96            target_shape[1..].iter().product()
97        } else {
98            1
99        };
100        let target_vec = self.targets.to_vec();
101        let target_start = index * target_row_size;
102        let target_end = target_start + target_row_size;
103        let item_target = target_vec[target_start..target_end].to_vec();
104        let target_item_shape: Vec<usize> = if target_shape.len() > 1 {
105            target_shape[1..].to_vec()
106        } else {
107            vec![1]
108        };
109        let y = Tensor::from_vec(item_target, &target_item_shape).unwrap();
110
111        Some((x, y))
112    }
113}
114
115// =============================================================================
116// MapDataset
117// =============================================================================
118
119/// A dataset that applies a transform to another dataset.
120pub struct MapDataset<D, F>
121where
122    D: Dataset,
123    F: Fn(D::Item) -> D::Item + Send + Sync,
124{
125    dataset: D,
126    transform: F,
127}
128
129impl<D, F> MapDataset<D, F>
130where
131    D: Dataset,
132    F: Fn(D::Item) -> D::Item + Send + Sync,
133{
134    /// Creates a new `MapDataset`.
135    pub fn new(dataset: D, transform: F) -> Self {
136        Self { dataset, transform }
137    }
138}
139
140impl<D, F> Dataset for MapDataset<D, F>
141where
142    D: Dataset,
143    F: Fn(D::Item) -> D::Item + Send + Sync,
144{
145    type Item = D::Item;
146
147    fn len(&self) -> usize {
148        self.dataset.len()
149    }
150
151    fn get(&self, index: usize) -> Option<Self::Item> {
152        self.dataset.get(index).map(&self.transform)
153    }
154}
155
156// =============================================================================
157// ConcatDataset
158// =============================================================================
159
160/// A dataset that concatenates multiple datasets.
161pub struct ConcatDataset<D: Dataset> {
162    datasets: Vec<D>,
163    cumulative_sizes: Vec<usize>,
164}
165
166impl<D: Dataset> ConcatDataset<D> {
167    /// Creates a new `ConcatDataset` from multiple datasets.
168    #[must_use] pub fn new(datasets: Vec<D>) -> Self {
169        let mut cumulative_sizes = Vec::with_capacity(datasets.len());
170        let mut total = 0;
171        for d in &datasets {
172            total += d.len();
173            cumulative_sizes.push(total);
174        }
175        Self {
176            datasets,
177            cumulative_sizes,
178        }
179    }
180
181    /// Finds which dataset contains the given index.
182    fn find_dataset(&self, index: usize) -> Option<(usize, usize)> {
183        if index >= self.len() {
184            return None;
185        }
186
187        for (i, &cum_size) in self.cumulative_sizes.iter().enumerate() {
188            if index < cum_size {
189                let prev_size = if i == 0 {
190                    0
191                } else {
192                    self.cumulative_sizes[i - 1]
193                };
194                return Some((i, index - prev_size));
195            }
196        }
197        None
198    }
199}
200
201impl<D: Dataset> Dataset for ConcatDataset<D> {
202    type Item = D::Item;
203
204    fn len(&self) -> usize {
205        *self.cumulative_sizes.last().unwrap_or(&0)
206    }
207
208    fn get(&self, index: usize) -> Option<Self::Item> {
209        let (dataset_idx, local_idx) = self.find_dataset(index)?;
210        self.datasets[dataset_idx].get(local_idx)
211    }
212}
213
214// =============================================================================
215// SubsetDataset
216// =============================================================================
217
218/// A dataset that provides a subset of another dataset.
219pub struct SubsetDataset<D: Dataset> {
220    dataset: D,
221    indices: Vec<usize>,
222}
223
224impl<D: Dataset> SubsetDataset<D> {
225    /// Creates a new `SubsetDataset` with specified indices.
226    pub fn new(dataset: D, indices: Vec<usize>) -> Self {
227        Self { dataset, indices }
228    }
229
230    /// Creates a random split of a dataset into two subsets.
231    pub fn random_split(dataset: D, lengths: &[usize]) -> Vec<Self>
232    where
233        D: Clone,
234    {
235        use rand::seq::SliceRandom;
236        use rand::thread_rng;
237
238        let total_len: usize = lengths.iter().sum();
239        assert_eq!(
240            total_len,
241            dataset.len(),
242            "Split lengths must sum to dataset length"
243        );
244
245        let mut indices: Vec<usize> = (0..dataset.len()).collect();
246        indices.shuffle(&mut thread_rng());
247
248        let mut subsets = Vec::with_capacity(lengths.len());
249        let mut offset = 0;
250        for &len in lengths {
251            let subset_indices = indices[offset..offset + len].to_vec();
252            subsets.push(Self::new(dataset.clone(), subset_indices));
253            offset += len;
254        }
255        subsets
256    }
257}
258
259impl<D: Dataset> Dataset for SubsetDataset<D> {
260    type Item = D::Item;
261
262    fn len(&self) -> usize {
263        self.indices.len()
264    }
265
266    fn get(&self, index: usize) -> Option<Self::Item> {
267        let real_index = *self.indices.get(index)?;
268        self.dataset.get(real_index)
269    }
270}
271
272// =============================================================================
273// InMemoryDataset
274// =============================================================================
275
/// Dataset that keeps all of its items in a `Vec`.
pub struct InMemoryDataset<T>
where
    T: Clone + Send,
{
    /// The stored items, in index order.
    items: Vec<T>,
}

impl<T> InMemoryDataset<T>
where
    T: Clone + Send,
{
    /// Takes ownership of `items` and wraps them as a dataset.
    #[must_use]
    pub fn new(items: Vec<T>) -> Self {
        InMemoryDataset { items }
    }
}
287
288impl<T: Clone + Send + Sync> Dataset for InMemoryDataset<T> {
289    type Item = T;
290
291    fn len(&self) -> usize {
292        self.items.len()
293    }
294
295    fn get(&self, index: usize) -> Option<Self::Item> {
296        self.items.get(index).cloned()
297    }
298}
299
300// =============================================================================
301// Tests
302// =============================================================================
303
#[cfg(test)]
mod tests {
    use super::*;

    // Builds an `f32` tensor from `values` with the given `shape`.
    fn tensor(values: &[f32], shape: &[usize]) -> Tensor<f32> {
        Tensor::from_vec(values.to_vec(), shape).unwrap()
    }

    // Builds a `TensorDataset` whose targets form a one-dimensional tensor
    // with one entry per element of `targets`.
    fn tensor_dataset(inputs: &[f32], input_shape: &[usize], targets: &[f32]) -> TensorDataset {
        TensorDataset::new(
            tensor(inputs, input_shape),
            tensor(targets, &[targets.len()]),
        )
    }

    // Rows and scalar targets come back in order; one past the end is `None`.
    #[test]
    fn test_tensor_dataset() {
        let dataset = tensor_dataset(
            &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            &[3, 2],
            &[0.0, 1.0, 2.0],
        );

        assert_eq!(dataset.len(), 3);

        let (first_x, first_y) = dataset.get(0).unwrap();
        assert_eq!(first_x.to_vec(), vec![1.0, 2.0]);
        assert_eq!(first_y.to_vec(), vec![0.0]);

        let (last_x, last_y) = dataset.get(2).unwrap();
        assert_eq!(last_x.to_vec(), vec![5.0, 6.0]);
        assert_eq!(last_y.to_vec(), vec![2.0]);

        assert!(dataset.get(3).is_none());
    }

    // The transform is applied to fetched items and the length is unchanged.
    #[test]
    fn test_map_dataset() {
        let base = tensor_dataset(&[1.0, 2.0, 3.0, 4.0], &[4, 1], &[0.0, 1.0, 0.0, 1.0]);

        let mapped = MapDataset::new(base, |(x, y)| (x.mul_scalar(2.0), y));

        assert_eq!(mapped.len(), 4);
        let (doubled, _) = mapped.get(0).unwrap();
        assert_eq!(doubled.to_vec(), vec![2.0]);
    }

    // Global indices are routed to the right member dataset.
    #[test]
    fn test_concat_dataset() {
        let head = tensor_dataset(&[1.0, 2.0], &[2, 1], &[0.0, 1.0]);
        let tail = tensor_dataset(&[3.0, 4.0, 5.0], &[3, 1], &[2.0, 3.0, 4.0]);

        let concat = ConcatDataset::new(vec![head, tail]);

        assert_eq!(concat.len(), 5);

        let (x, y) = concat.get(0).unwrap();
        assert_eq!(x.to_vec(), vec![1.0]);
        assert_eq!(y.to_vec(), vec![0.0]);

        let (x, y) = concat.get(3).unwrap();
        assert_eq!(x.to_vec(), vec![4.0]);
        assert_eq!(y.to_vec(), vec![3.0]);
    }

    // Subset positions map onto the chosen positions of the base dataset.
    #[test]
    fn test_subset_dataset() {
        let base = tensor_dataset(
            &[1.0, 2.0, 3.0, 4.0, 5.0],
            &[5, 1],
            &[0.0, 1.0, 2.0, 3.0, 4.0],
        );

        let subset = SubsetDataset::new(base, vec![0, 2, 4]);

        assert_eq!(subset.len(), 3);

        for (position, expected) in [1.0_f32, 3.0, 5.0].iter().enumerate() {
            let (x, _) = subset.get(position).unwrap();
            assert_eq!(x.to_vec(), vec![*expected]);
        }
    }

    // Items are returned by value; an index past the end gives `None`.
    #[test]
    fn test_in_memory_dataset() {
        let dataset = InMemoryDataset::new(vec![1, 2, 3, 4, 5]);

        assert_eq!(dataset.len(), 5);
        assert_eq!(dataset.get(0), Some(1));
        assert_eq!(dataset.get(4), Some(5));
        assert_eq!(dataset.get(5), None);
    }
}