Skip to main content

axonml_data/
dataset.rs

//! Dataset Trait - Core Data Abstraction
//!
//! # File
//! `crates/axonml-data/src/dataset.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at your own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

17use axonml_tensor::Tensor;

// =============================================================================
// Dataset Trait
// =============================================================================

/// Indexed, random-access collection of samples.
///
/// Every dataset type in this module implements this trait. An implementor
/// reports how many samples it holds through [`Dataset::len`] and hands out
/// owned samples by position through [`Dataset::get`]. The `Send + Sync`
/// supertraits let a dataset be shared across threads.
pub trait Dataset: Send + Sync {
    /// Sample type returned by [`Dataset::get`]; must be `Send`.
    type Item: Send;

    /// Number of samples in the dataset.
    fn len(&self) -> usize;

    /// Whether the dataset holds no samples.
    ///
    /// Provided method: compares [`Dataset::len`] against zero.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Fetches the sample stored at `index`.
    ///
    /// Returns `None` when no sample exists at that position.
    fn get(&self, index: usize) -> Option<Self::Item>;
}

// =============================================================================
// TensorDataset
// =============================================================================

46/// A dataset wrapping tensors.
47///
48/// Each item is a tuple of (input, target) tensors.
49pub struct TensorDataset {
50    /// Input data tensor.
51    data: Tensor<f32>,
52    /// Target tensor.
53    targets: Tensor<f32>,
54    /// Number of samples.
55    len: usize,
56}
57
58impl TensorDataset {
59    /// Creates a new `TensorDataset` from input and target tensors.
60    ///
61    /// The first dimension of both tensors must match.
62    #[must_use]
63    pub fn new(data: Tensor<f32>, targets: Tensor<f32>) -> Self {
64        let len = data.shape()[0];
65        assert_eq!(
66            len,
67            targets.shape()[0],
68            "Data and targets must have same first dimension"
69        );
70        Self { data, targets, len }
71    }
72
73    /// Creates a `TensorDataset` from just input data (no targets).
74    #[must_use]
75    pub fn from_data(data: Tensor<f32>) -> Self {
76        let len = data.shape()[0];
77        let targets = Tensor::from_vec(vec![0.0; len], &[len]).unwrap();
78        Self { data, targets, len }
79    }
80}
81
82impl Dataset for TensorDataset {
83    type Item = (Tensor<f32>, Tensor<f32>);
84
85    fn len(&self) -> usize {
86        self.len
87    }
88
89    fn get(&self, index: usize) -> Option<Self::Item> {
90        if index >= self.len {
91            return None;
92        }
93
94        // Extract row from data tensor
95        let data_shape = self.data.shape();
96        let row_size: usize = data_shape[1..].iter().product();
97        let data_vec = self.data.to_vec();
98        let start = index * row_size;
99        let end = start + row_size;
100        let item_data = data_vec[start..end].to_vec();
101        let item_shape: Vec<usize> = data_shape[1..].to_vec();
102        let x = Tensor::from_vec(item_data, &item_shape).unwrap();
103
104        // Extract target
105        let target_shape = self.targets.shape();
106        let target_row_size: usize = if target_shape.len() > 1 {
107            target_shape[1..].iter().product()
108        } else {
109            1
110        };
111        let target_vec = self.targets.to_vec();
112        let target_start = index * target_row_size;
113        let target_end = target_start + target_row_size;
114        let item_target = target_vec[target_start..target_end].to_vec();
115        let target_item_shape: Vec<usize> = if target_shape.len() > 1 {
116            target_shape[1..].to_vec()
117        } else {
118            vec![1]
119        };
120        let y = Tensor::from_vec(item_target, &target_item_shape).unwrap();
121
122        Some((x, y))
123    }
124}

// =============================================================================
// MapDataset
// =============================================================================

130/// A dataset that applies a transform to another dataset.
131pub struct MapDataset<D, F>
132where
133    D: Dataset,
134    F: Fn(D::Item) -> D::Item + Send + Sync,
135{
136    dataset: D,
137    transform: F,
138}
139
140impl<D, F> MapDataset<D, F>
141where
142    D: Dataset,
143    F: Fn(D::Item) -> D::Item + Send + Sync,
144{
145    /// Creates a new `MapDataset`.
146    pub fn new(dataset: D, transform: F) -> Self {
147        Self { dataset, transform }
148    }
149}
150
151impl<D, F> Dataset for MapDataset<D, F>
152where
153    D: Dataset,
154    F: Fn(D::Item) -> D::Item + Send + Sync,
155{
156    type Item = D::Item;
157
158    fn len(&self) -> usize {
159        self.dataset.len()
160    }
161
162    fn get(&self, index: usize) -> Option<Self::Item> {
163        self.dataset.get(index).map(&self.transform)
164    }
165}

// =============================================================================
// ConcatDataset
// =============================================================================

171/// A dataset that concatenates multiple datasets.
172pub struct ConcatDataset<D: Dataset> {
173    datasets: Vec<D>,
174    cumulative_sizes: Vec<usize>,
175}
176
177impl<D: Dataset> ConcatDataset<D> {
178    /// Creates a new `ConcatDataset` from multiple datasets.
179    #[must_use]
180    pub fn new(datasets: Vec<D>) -> Self {
181        let mut cumulative_sizes = Vec::with_capacity(datasets.len());
182        let mut total = 0;
183        for d in &datasets {
184            total += d.len();
185            cumulative_sizes.push(total);
186        }
187        Self {
188            datasets,
189            cumulative_sizes,
190        }
191    }
192
193    /// Finds which dataset contains the given index.
194    fn find_dataset(&self, index: usize) -> Option<(usize, usize)> {
195        if index >= self.len() {
196            return None;
197        }
198
199        for (i, &cum_size) in self.cumulative_sizes.iter().enumerate() {
200            if index < cum_size {
201                let prev_size = if i == 0 {
202                    0
203                } else {
204                    self.cumulative_sizes[i - 1]
205                };
206                return Some((i, index - prev_size));
207            }
208        }
209        None
210    }
211}
212
213impl<D: Dataset> Dataset for ConcatDataset<D> {
214    type Item = D::Item;
215
216    fn len(&self) -> usize {
217        *self.cumulative_sizes.last().unwrap_or(&0)
218    }
219
220    fn get(&self, index: usize) -> Option<Self::Item> {
221        let (dataset_idx, local_idx) = self.find_dataset(index)?;
222        self.datasets[dataset_idx].get(local_idx)
223    }
224}

// =============================================================================
// SubsetDataset
// =============================================================================

230/// A dataset that provides a subset of another dataset.
231pub struct SubsetDataset<D: Dataset> {
232    dataset: D,
233    indices: Vec<usize>,
234}
235
236impl<D: Dataset> SubsetDataset<D> {
237    /// Creates a new `SubsetDataset` with specified indices.
238    pub fn new(dataset: D, indices: Vec<usize>) -> Self {
239        Self { dataset, indices }
240    }
241
242    /// Creates a random split of a dataset into two subsets.
243    pub fn random_split(dataset: D, lengths: &[usize]) -> Vec<Self>
244    where
245        D: Clone,
246    {
247        use rand::seq::SliceRandom;
248        use rand::thread_rng;
249
250        let total_len: usize = lengths.iter().sum();
251        assert_eq!(
252            total_len,
253            dataset.len(),
254            "Split lengths must sum to dataset length"
255        );
256
257        let mut indices: Vec<usize> = (0..dataset.len()).collect();
258        indices.shuffle(&mut thread_rng());
259
260        let mut subsets = Vec::with_capacity(lengths.len());
261        let mut offset = 0;
262        for &len in lengths {
263            let subset_indices = indices[offset..offset + len].to_vec();
264            subsets.push(Self::new(dataset.clone(), subset_indices));
265            offset += len;
266        }
267        subsets
268    }
269}
270
271impl<D: Dataset> Dataset for SubsetDataset<D> {
272    type Item = D::Item;
273
274    fn len(&self) -> usize {
275        self.indices.len()
276    }
277
278    fn get(&self, index: usize) -> Option<Self::Item> {
279        let real_index = *self.indices.get(index)?;
280        self.dataset.get(real_index)
281    }
282}

// =============================================================================
// InMemoryDataset
// =============================================================================

288/// A simple in-memory dataset from a vector.
289pub struct InMemoryDataset<T: Clone + Send> {
290    items: Vec<T>,
291}
292
293impl<T: Clone + Send> InMemoryDataset<T> {
294    /// Creates a new `InMemoryDataset` from a vector.
295    #[must_use]
296    pub fn new(items: Vec<T>) -> Self {
297        Self { items }
298    }
299}
300
301impl<T: Clone + Send + Sync> Dataset for InMemoryDataset<T> {
302    type Item = T;
303
304    fn len(&self) -> usize {
305        self.items.len()
306    }
307
308    fn get(&self, index: usize) -> Option<Self::Item> {
309        self.items.get(index).cloned()
310    }
311}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal cloneable dataset yielding its own indices. `random_split`
    /// requires `D: Clone`, which the other datasets in this module do not
    /// implement.
    #[derive(Clone)]
    struct IndexDataset(usize);

    impl Dataset for IndexDataset {
        type Item = usize;

        fn len(&self) -> usize {
            self.0
        }

        fn get(&self, index: usize) -> Option<usize> {
            if index < self.0 {
                Some(index)
            } else {
                None
            }
        }
    }

    #[test]
    fn test_tensor_dataset() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2]).unwrap();
        let targets = Tensor::from_vec(vec![0.0, 1.0, 2.0], &[3]).unwrap();
        let dataset = TensorDataset::new(data, targets);

        assert_eq!(dataset.len(), 3);

        let (x, y) = dataset.get(0).unwrap();
        assert_eq!(x.to_vec(), vec![1.0, 2.0]);
        assert_eq!(y.to_vec(), vec![0.0]);

        let (x, y) = dataset.get(2).unwrap();
        assert_eq!(x.to_vec(), vec![5.0, 6.0]);
        assert_eq!(y.to_vec(), vec![2.0]);

        assert!(dataset.get(3).is_none());
    }

    #[test]
    fn test_tensor_dataset_from_data() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3, 1]).unwrap();
        let dataset = TensorDataset::from_data(data);

        assert_eq!(dataset.len(), 3);
        assert!(!dataset.is_empty());

        // Targets default to zero, one value per sample.
        let (x, y) = dataset.get(1).unwrap();
        assert_eq!(x.to_vec(), vec![2.0]);
        assert_eq!(y.to_vec(), vec![0.0]);

        assert!(dataset.get(3).is_none());
    }

    #[test]
    #[should_panic(expected = "Data and targets must have same first dimension")]
    fn test_tensor_dataset_mismatched_lengths() {
        let data = Tensor::from_vec(vec![1.0, 2.0], &[2, 1]).unwrap();
        let targets = Tensor::from_vec(vec![0.0], &[1]).unwrap();
        let _ = TensorDataset::new(data, targets);
    }

    #[test]
    fn test_map_dataset() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4, 1]).unwrap();
        let targets = Tensor::from_vec(vec![0.0, 1.0, 0.0, 1.0], &[4]).unwrap();
        let base = TensorDataset::new(data, targets);

        let mapped = MapDataset::new(base, |(x, y)| (x.mul_scalar(2.0), y));

        assert_eq!(mapped.len(), 4);
        let (x, _) = mapped.get(0).unwrap();
        assert_eq!(x.to_vec(), vec![2.0]);

        assert!(mapped.get(4).is_none());
    }

    #[test]
    fn test_concat_dataset() {
        let data1 = Tensor::from_vec(vec![1.0, 2.0], &[2, 1]).unwrap();
        let targets1 = Tensor::from_vec(vec![0.0, 1.0], &[2]).unwrap();
        let ds1 = TensorDataset::new(data1, targets1);

        let data2 = Tensor::from_vec(vec![3.0, 4.0, 5.0], &[3, 1]).unwrap();
        let targets2 = Tensor::from_vec(vec![2.0, 3.0, 4.0], &[3]).unwrap();
        let ds2 = TensorDataset::new(data2, targets2);

        let concat = ConcatDataset::new(vec![ds1, ds2]);

        assert_eq!(concat.len(), 5);

        let (x, y) = concat.get(0).unwrap();
        assert_eq!(x.to_vec(), vec![1.0]);
        assert_eq!(y.to_vec(), vec![0.0]);

        let (x, y) = concat.get(3).unwrap();
        assert_eq!(x.to_vec(), vec![4.0]);
        assert_eq!(y.to_vec(), vec![3.0]);

        assert!(concat.get(5).is_none());
    }

    #[test]
    fn test_concat_dataset_skips_empty_children() {
        let concat = ConcatDataset::new(vec![
            InMemoryDataset::new(vec![1, 2]),
            InMemoryDataset::new(Vec::new()),
            InMemoryDataset::new(vec![3]),
        ]);

        assert_eq!(concat.len(), 3);
        assert_eq!(concat.get(1), Some(2));
        assert_eq!(concat.get(2), Some(3));
        assert_eq!(concat.get(3), None);

        let empty: ConcatDataset<InMemoryDataset<i32>> = ConcatDataset::new(Vec::new());
        assert!(empty.is_empty());
        assert_eq!(empty.get(0), None);
    }

    #[test]
    fn test_subset_dataset() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5, 1]).unwrap();
        let targets = Tensor::from_vec(vec![0.0, 1.0, 2.0, 3.0, 4.0], &[5]).unwrap();
        let base = TensorDataset::new(data, targets);

        let subset = SubsetDataset::new(base, vec![0, 2, 4]);

        assert_eq!(subset.len(), 3);

        let (x, _) = subset.get(0).unwrap();
        assert_eq!(x.to_vec(), vec![1.0]);

        let (x, _) = subset.get(1).unwrap();
        assert_eq!(x.to_vec(), vec![3.0]);

        let (x, _) = subset.get(2).unwrap();
        assert_eq!(x.to_vec(), vec![5.0]);

        assert!(subset.get(3).is_none());
    }

    #[test]
    fn test_random_split_partitions_all_indices() {
        let parts = SubsetDataset::random_split(IndexDataset(10), &[3, 0, 7]);

        let sizes: Vec<usize> = parts.iter().map(|part| part.len()).collect();
        assert_eq!(sizes, vec![3, 0, 7]);

        // Every original index appears in exactly one subset.
        let mut seen: Vec<usize> = parts
            .iter()
            .flat_map(|part| (0..part.len()).filter_map(move |i| part.get(i)))
            .collect();
        seen.sort_unstable();
        assert_eq!(seen, (0..10).collect::<Vec<_>>());
    }

    #[test]
    #[should_panic(expected = "Split lengths must sum to dataset length")]
    fn test_random_split_rejects_bad_lengths() {
        let _ = SubsetDataset::random_split(IndexDataset(5), &[2, 2]);
    }

    #[test]
    fn test_in_memory_dataset() {
        let dataset = InMemoryDataset::new(vec![1, 2, 3, 4, 5]);

        assert_eq!(dataset.len(), 5);
        assert!(!dataset.is_empty());
        assert_eq!(dataset.get(0), Some(1));
        assert_eq!(dataset.get(4), Some(5));
        assert_eq!(dataset.get(5), None);
    }
}