Skip to main content

axonml_data/
dataset.rs

1//! Dataset Trait - Core Data Abstraction
2//!
3//! # File
4//! `crates/axonml-data/src/dataset.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr - AutomataNexus
8//!
9//! # Updated
10//! March 8, 2026
11//!
12//! # Disclaimer
13//! Use at own risk. This software is provided "as is", without warranty of any
14//! kind, express or implied. The author and AutomataNexus shall not be held
15//! liable for any damages arising from the use of this software.
16
17use axonml_tensor::Tensor;
18
19// =============================================================================
20// Dataset Trait
21// =============================================================================
22
/// Common interface implemented by every dataset.
///
/// A dataset is a collection whose items are fetched by position.
/// Implementors must be `Send + Sync`, and the items they yield must be `Send`.
pub trait Dataset: Send + Sync {
    /// Type of a single item yielded by [`Dataset::get`].
    type Item: Send;

    /// Reports how many items the dataset holds.
    fn len(&self) -> usize;

    /// Fetches the item stored at `index`, if there is one.
    fn get(&self, index: usize) -> Option<Self::Item>;

    /// Reports whether the dataset holds no items, i.e. whether `len()` is zero.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }
}
41
42// =============================================================================
43// TensorDataset
44// =============================================================================
45
46/// A dataset wrapping tensors.
47///
48/// Each item is a tuple of (input, target) tensors.
49/// Data is cached as flat vectors to avoid per-access O(N) copies.
50#[derive(Clone)]
51pub struct TensorDataset {
52    /// Cached flat data (extracted once at construction).
53    data_vec: Vec<f32>,
54    /// Cached flat targets (extracted once at construction).
55    target_vec: Vec<f32>,
56    /// Shape of the data tensor (first dim is sample count).
57    data_shape: Vec<usize>,
58    /// Shape of the targets tensor (first dim is sample count).
59    target_shape: Vec<usize>,
60    /// Number of elements per data sample (product of shape[1..]).
61    row_size: usize,
62    /// Number of elements per target sample.
63    target_row_size: usize,
64    /// Number of samples.
65    len: usize,
66}
67
68impl TensorDataset {
69    /// Creates a new `TensorDataset` from input and target tensors.
70    ///
71    /// The first dimension of both tensors must match.
72    #[must_use]
73    pub fn new(data: Tensor<f32>, targets: Tensor<f32>) -> Self {
74        let data_shape = data.shape().to_vec();
75        let target_shape = targets.shape().to_vec();
76        let len = data_shape[0];
77        assert_eq!(
78            len, target_shape[0],
79            "Data and targets must have same first dimension"
80        );
81
82        let row_size: usize = data_shape[1..].iter().product();
83        let target_row_size: usize = if target_shape.len() > 1 {
84            target_shape[1..].iter().product()
85        } else {
86            1
87        };
88
89        Self {
90            data_vec: data.to_vec(),
91            target_vec: targets.to_vec(),
92            data_shape,
93            target_shape,
94            row_size,
95            target_row_size,
96            len,
97        }
98    }
99
100    /// Creates a `TensorDataset` from just input data (no targets).
101    #[must_use]
102    pub fn from_data(data: Tensor<f32>) -> Self {
103        let len = data.shape()[0];
104        let targets = Tensor::from_vec(vec![0.0; len], &[len]).unwrap();
105        Self::new(data, targets)
106    }
107}
108
109impl Dataset for TensorDataset {
110    type Item = (Tensor<f32>, Tensor<f32>);
111
112    fn len(&self) -> usize {
113        self.len
114    }
115
116    fn get(&self, index: usize) -> Option<Self::Item> {
117        if index >= self.len {
118            return None;
119        }
120
121        // Slice directly from cached vec — O(row_size) not O(dataset_size)
122        let start = index * self.row_size;
123        let end = start + self.row_size;
124        let item_data = self.data_vec[start..end].to_vec();
125        let item_shape: Vec<usize> = self.data_shape[1..].to_vec();
126        let x = Tensor::from_vec(item_data, &item_shape).unwrap();
127
128        let target_start = index * self.target_row_size;
129        let target_end = target_start + self.target_row_size;
130        let item_target = self.target_vec[target_start..target_end].to_vec();
131        let target_item_shape: Vec<usize> = if self.target_shape.len() > 1 {
132            self.target_shape[1..].to_vec()
133        } else {
134            vec![1]
135        };
136        let y = Tensor::from_vec(item_target, &target_item_shape).unwrap();
137
138        Some((x, y))
139    }
140}
141
142// =============================================================================
143// MapDataset
144// =============================================================================
145
146/// A dataset that applies a transform to another dataset.
147pub struct MapDataset<D, F>
148where
149    D: Dataset,
150    F: Fn(D::Item) -> D::Item + Send + Sync,
151{
152    dataset: D,
153    transform: F,
154}
155
156impl<D, F> MapDataset<D, F>
157where
158    D: Dataset,
159    F: Fn(D::Item) -> D::Item + Send + Sync,
160{
161    /// Creates a new `MapDataset`.
162    pub fn new(dataset: D, transform: F) -> Self {
163        Self { dataset, transform }
164    }
165}
166
167impl<D, F> Dataset for MapDataset<D, F>
168where
169    D: Dataset,
170    F: Fn(D::Item) -> D::Item + Send + Sync,
171{
172    type Item = D::Item;
173
174    fn len(&self) -> usize {
175        self.dataset.len()
176    }
177
178    fn get(&self, index: usize) -> Option<Self::Item> {
179        self.dataset.get(index).map(&self.transform)
180    }
181}
182
183// =============================================================================
184// ConcatDataset
185// =============================================================================
186
187/// A dataset that concatenates multiple datasets.
188pub struct ConcatDataset<D: Dataset> {
189    datasets: Vec<D>,
190    cumulative_sizes: Vec<usize>,
191}
192
193impl<D: Dataset> ConcatDataset<D> {
194    /// Creates a new `ConcatDataset` from multiple datasets.
195    #[must_use]
196    pub fn new(datasets: Vec<D>) -> Self {
197        let mut cumulative_sizes = Vec::with_capacity(datasets.len());
198        let mut total = 0;
199        for d in &datasets {
200            total += d.len();
201            cumulative_sizes.push(total);
202        }
203        Self {
204            datasets,
205            cumulative_sizes,
206        }
207    }
208
209    /// Finds which dataset contains the given index.
210    fn find_dataset(&self, index: usize) -> Option<(usize, usize)> {
211        if index >= self.len() {
212            return None;
213        }
214
215        for (i, &cum_size) in self.cumulative_sizes.iter().enumerate() {
216            if index < cum_size {
217                let prev_size = if i == 0 {
218                    0
219                } else {
220                    self.cumulative_sizes[i - 1]
221                };
222                return Some((i, index - prev_size));
223            }
224        }
225        None
226    }
227}
228
229impl<D: Dataset> Dataset for ConcatDataset<D> {
230    type Item = D::Item;
231
232    fn len(&self) -> usize {
233        *self.cumulative_sizes.last().unwrap_or(&0)
234    }
235
236    fn get(&self, index: usize) -> Option<Self::Item> {
237        let (dataset_idx, local_idx) = self.find_dataset(index)?;
238        self.datasets[dataset_idx].get(local_idx)
239    }
240}
241
242// =============================================================================
243// SubsetDataset
244// =============================================================================
245
246/// A dataset that provides a subset of another dataset.
247pub struct SubsetDataset<D: Dataset> {
248    dataset: D,
249    indices: Vec<usize>,
250}
251
252impl<D: Dataset> SubsetDataset<D> {
253    /// Creates a new `SubsetDataset` with specified indices.
254    pub fn new(dataset: D, indices: Vec<usize>) -> Self {
255        Self { dataset, indices }
256    }
257
258    /// Creates a random split of a dataset into two subsets.
259    pub fn random_split(dataset: D, lengths: &[usize]) -> Vec<Self>
260    where
261        D: Clone,
262    {
263        use rand::seq::SliceRandom;
264        use rand::thread_rng;
265
266        let total_len: usize = lengths.iter().sum();
267        assert_eq!(
268            total_len,
269            dataset.len(),
270            "Split lengths must sum to dataset length"
271        );
272
273        let mut indices: Vec<usize> = (0..dataset.len()).collect();
274        indices.shuffle(&mut thread_rng());
275
276        let mut subsets = Vec::with_capacity(lengths.len());
277        let mut offset = 0;
278        for &len in lengths {
279            let subset_indices = indices[offset..offset + len].to_vec();
280            subsets.push(Self::new(dataset.clone(), subset_indices));
281            offset += len;
282        }
283        subsets
284    }
285}
286
287impl<D: Dataset> Dataset for SubsetDataset<D> {
288    type Item = D::Item;
289
290    fn len(&self) -> usize {
291        self.indices.len()
292    }
293
294    fn get(&self, index: usize) -> Option<Self::Item> {
295        let real_index = *self.indices.get(index)?;
296        self.dataset.get(real_index)
297    }
298}
299
300// =============================================================================
301// InMemoryDataset
302// =============================================================================
303
304/// A simple in-memory dataset from a vector.
305pub struct InMemoryDataset<T: Clone + Send> {
306    items: Vec<T>,
307}
308
309impl<T: Clone + Send> InMemoryDataset<T> {
310    /// Creates a new `InMemoryDataset` from a vector.
311    #[must_use]
312    pub fn new(items: Vec<T>) -> Self {
313        Self { items }
314    }
315}
316
317impl<T: Clone + Send + Sync> Dataset for InMemoryDataset<T> {
318    type Item = T;
319
320    fn len(&self) -> usize {
321        self.items.len()
322    }
323
324    fn get(&self, index: usize) -> Option<Self::Item> {
325        self.items.get(index).cloned()
326    }
327}
328
329// =============================================================================
330// Tests
331// =============================================================================
332
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `TensorDataset` with one feature per sample (data shape
    /// `[n, 1]`) and 1-D targets (shape `[m]`).
    fn column_dataset(features: Vec<f32>, labels: Vec<f32>) -> TensorDataset {
        let sample_count = features.len();
        let label_count = labels.len();
        let data = Tensor::from_vec(features, &[sample_count, 1]).unwrap();
        let targets = Tensor::from_vec(labels, &[label_count]).unwrap();
        TensorDataset::new(data, targets)
    }

    #[test]
    fn test_tensor_dataset() {
        let inputs = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2]).unwrap();
        let labels = Tensor::from_vec(vec![0.0, 1.0, 2.0], &[3]).unwrap();
        let ds = TensorDataset::new(inputs, labels);

        assert_eq!(ds.len(), 3);

        // First sample.
        let (first_x, first_y) = ds.get(0).unwrap();
        assert_eq!(first_x.to_vec(), vec![1.0, 2.0]);
        assert_eq!(first_y.to_vec(), vec![0.0]);

        // Last sample.
        let (last_x, last_y) = ds.get(2).unwrap();
        assert_eq!(last_x.to_vec(), vec![5.0, 6.0]);
        assert_eq!(last_y.to_vec(), vec![2.0]);

        // One past the end.
        assert!(ds.get(3).is_none());
    }

    #[test]
    fn test_map_dataset() {
        let source = column_dataset(vec![1.0, 2.0, 3.0, 4.0], vec![0.0, 1.0, 0.0, 1.0]);

        let doubled = MapDataset::new(source, |(x, y)| (x.mul_scalar(2.0), y));

        assert_eq!(doubled.len(), 4);
        let (first_x, _) = doubled.get(0).unwrap();
        assert_eq!(first_x.to_vec(), vec![2.0]);
    }

    #[test]
    fn test_concat_dataset() {
        let head = column_dataset(vec![1.0, 2.0], vec![0.0, 1.0]);
        let tail = column_dataset(vec![3.0, 4.0, 5.0], vec![2.0, 3.0, 4.0]);

        let joined = ConcatDataset::new(vec![head, tail]);

        assert_eq!(joined.len(), 5);

        // Index 0 falls in the first member.
        let (x, y) = joined.get(0).unwrap();
        assert_eq!(x.to_vec(), vec![1.0]);
        assert_eq!(y.to_vec(), vec![0.0]);

        // Index 3 falls in the second member, at local index 1.
        let (x, y) = joined.get(3).unwrap();
        assert_eq!(x.to_vec(), vec![4.0]);
        assert_eq!(y.to_vec(), vec![3.0]);
    }

    #[test]
    fn test_subset_dataset() {
        let source = column_dataset(
            vec![1.0, 2.0, 3.0, 4.0, 5.0],
            vec![0.0, 1.0, 2.0, 3.0, 4.0],
        );

        let picked = SubsetDataset::new(source, vec![0, 2, 4]);

        assert_eq!(picked.len(), 3);

        let (x, _) = picked.get(0).unwrap();
        assert_eq!(x.to_vec(), vec![1.0]);

        let (x, _) = picked.get(1).unwrap();
        assert_eq!(x.to_vec(), vec![3.0]);

        let (x, _) = picked.get(2).unwrap();
        assert_eq!(x.to_vec(), vec![5.0]);
    }

    #[test]
    fn test_in_memory_dataset() {
        let numbers = InMemoryDataset::new(vec![1, 2, 3, 4, 5]);

        assert_eq!(numbers.len(), 5);
        assert_eq!(numbers.get(0), Some(1));
        assert_eq!(numbers.get(4), Some(5));
        assert_eq!(numbers.get(5), None);
    }
}