Skip to main content

axonml_data/
dataset.rs

//! Dataset Trait - Core Data Abstraction
//!
//! # File
//! `crates/axonml-data/src/dataset.rs`
//!
//! # Author
//! Andrew Jewell Sr. — AutomataNexus LLC
//! ORCID: 0009-0005-2158-7060
//!
//! # Updated
//! April 14, 2026 11:15 PM EST
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use std::sync::Arc;

use axonml_tensor::Tensor;

// =============================================================================
// Dataset Trait
// =============================================================================

/// Core abstraction implemented by every dataset in this module.
///
/// A dataset is a fixed-size collection with random access by position:
/// callers query its length and fetch items by index. The `Send + Sync`
/// supertraits let a dataset be shared between threads.
pub trait Dataset: Send + Sync {
    /// Type produced for each index; `Send` so items can be moved across threads.
    type Item: Send;

    /// Number of items in the dataset.
    fn len(&self) -> usize;

    /// Returns `true` when the dataset holds no items (i.e. `len() == 0`).
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }

    /// Fetches the item at position `index`.
    ///
    /// The implementations in this module return `None` for out-of-range indices.
    fn get(&self, index: usize) -> Option<Self::Item>;
}
42
43// =============================================================================
44// TensorDataset
45// =============================================================================
46
47/// A dataset wrapping tensors.
48///
49/// Each item is a tuple of (input, target) tensors.
50/// Data is cached as flat vectors to avoid per-access O(N) copies.
51#[derive(Clone)]
52pub struct TensorDataset {
53    /// Cached flat data (extracted once at construction).
54    data_vec: Vec<f32>,
55    /// Cached flat targets (extracted once at construction).
56    target_vec: Vec<f32>,
57    /// Shape of the data tensor (first dim is sample count).
58    data_shape: Vec<usize>,
59    /// Shape of the targets tensor (first dim is sample count).
60    target_shape: Vec<usize>,
61    /// Number of elements per data sample (product of shape[1..]).
62    row_size: usize,
63    /// Number of elements per target sample.
64    target_row_size: usize,
65    /// Number of samples.
66    len: usize,
67}
68
69impl TensorDataset {
70    /// Creates a new `TensorDataset` from input and target tensors.
71    ///
72    /// The first dimension of both tensors must match.
73    #[must_use]
74    pub fn new(data: Tensor<f32>, targets: Tensor<f32>) -> Self {
75        let data_shape = data.shape().to_vec();
76        let target_shape = targets.shape().to_vec();
77        let len = data_shape[0];
78        assert_eq!(
79            len, target_shape[0],
80            "Data and targets must have same first dimension"
81        );
82
83        let row_size: usize = data_shape[1..].iter().product();
84        let target_row_size: usize = if target_shape.len() > 1 {
85            target_shape[1..].iter().product()
86        } else {
87            1
88        };
89
90        Self {
91            data_vec: data.to_vec(),
92            target_vec: targets.to_vec(),
93            data_shape,
94            target_shape,
95            row_size,
96            target_row_size,
97            len,
98        }
99    }
100
101    /// Creates a `TensorDataset` from just input data (no targets).
102    #[must_use]
103    pub fn from_data(data: Tensor<f32>) -> Self {
104        let len = data.shape()[0];
105        let targets = Tensor::from_vec(vec![0.0; len], &[len]).unwrap();
106        Self::new(data, targets)
107    }
108}
109
110impl Dataset for TensorDataset {
111    type Item = (Tensor<f32>, Tensor<f32>);
112
113    fn len(&self) -> usize {
114        self.len
115    }
116
117    fn get(&self, index: usize) -> Option<Self::Item> {
118        if index >= self.len {
119            return None;
120        }
121
122        // Slice directly from cached vec — O(row_size) not O(dataset_size)
123        let start = index * self.row_size;
124        let end = start + self.row_size;
125        let item_data = self.data_vec[start..end].to_vec();
126        let item_shape: Vec<usize> = self.data_shape[1..].to_vec();
127        let x = Tensor::from_vec(item_data, &item_shape).unwrap();
128
129        let target_start = index * self.target_row_size;
130        let target_end = target_start + self.target_row_size;
131        let item_target = self.target_vec[target_start..target_end].to_vec();
132        let target_item_shape: Vec<usize> = if self.target_shape.len() > 1 {
133            self.target_shape[1..].to_vec()
134        } else {
135            vec![1]
136        };
137        let y = Tensor::from_vec(item_target, &target_item_shape).unwrap();
138
139        Some((x, y))
140    }
141}
142
143// =============================================================================
144// MapDataset
145// =============================================================================
146
147/// A dataset that applies a transform to another dataset.
148pub struct MapDataset<D, F>
149where
150    D: Dataset,
151    F: Fn(D::Item) -> D::Item + Send + Sync,
152{
153    dataset: D,
154    transform: F,
155}
156
157impl<D, F> MapDataset<D, F>
158where
159    D: Dataset,
160    F: Fn(D::Item) -> D::Item + Send + Sync,
161{
162    /// Creates a new `MapDataset`.
163    pub fn new(dataset: D, transform: F) -> Self {
164        Self { dataset, transform }
165    }
166}
167
168impl<D, F> Dataset for MapDataset<D, F>
169where
170    D: Dataset,
171    F: Fn(D::Item) -> D::Item + Send + Sync,
172{
173    type Item = D::Item;
174
175    fn len(&self) -> usize {
176        self.dataset.len()
177    }
178
179    fn get(&self, index: usize) -> Option<Self::Item> {
180        self.dataset.get(index).map(&self.transform)
181    }
182}
183
184// =============================================================================
185// ConcatDataset
186// =============================================================================
187
188/// A dataset that concatenates multiple datasets.
189pub struct ConcatDataset<D: Dataset> {
190    datasets: Vec<D>,
191    cumulative_sizes: Vec<usize>,
192}
193
194impl<D: Dataset> ConcatDataset<D> {
195    /// Creates a new `ConcatDataset` from multiple datasets.
196    #[must_use]
197    pub fn new(datasets: Vec<D>) -> Self {
198        let mut cumulative_sizes = Vec::with_capacity(datasets.len());
199        let mut total = 0;
200        for d in &datasets {
201            total += d.len();
202            cumulative_sizes.push(total);
203        }
204        Self {
205            datasets,
206            cumulative_sizes,
207        }
208    }
209
210    /// Finds which dataset contains the given index.
211    fn find_dataset(&self, index: usize) -> Option<(usize, usize)> {
212        if index >= self.len() {
213            return None;
214        }
215
216        for (i, &cum_size) in self.cumulative_sizes.iter().enumerate() {
217            if index < cum_size {
218                let prev_size = if i == 0 {
219                    0
220                } else {
221                    self.cumulative_sizes[i - 1]
222                };
223                return Some((i, index - prev_size));
224            }
225        }
226        None
227    }
228}
229
230impl<D: Dataset> Dataset for ConcatDataset<D> {
231    type Item = D::Item;
232
233    fn len(&self) -> usize {
234        *self.cumulative_sizes.last().unwrap_or(&0)
235    }
236
237    fn get(&self, index: usize) -> Option<Self::Item> {
238        let (dataset_idx, local_idx) = self.find_dataset(index)?;
239        self.datasets[dataset_idx].get(local_idx)
240    }
241}
242
243// =============================================================================
244// SubsetDataset
245// =============================================================================
246
247/// A dataset that provides a subset of another dataset.
248pub struct SubsetDataset<D: Dataset> {
249    dataset: D,
250    indices: Vec<usize>,
251}
252
253impl<D: Dataset> SubsetDataset<D> {
254    /// Creates a new `SubsetDataset` with specified indices.
255    pub fn new(dataset: D, indices: Vec<usize>) -> Self {
256        Self { dataset, indices }
257    }
258
259    /// Creates a random split of a dataset into two subsets.
260    pub fn random_split(dataset: D, lengths: &[usize]) -> Vec<Self>
261    where
262        D: Clone,
263    {
264        use rand::seq::SliceRandom;
265        use rand::thread_rng;
266
267        let total_len: usize = lengths.iter().sum();
268        assert_eq!(
269            total_len,
270            dataset.len(),
271            "Split lengths must sum to dataset length"
272        );
273
274        let mut indices: Vec<usize> = (0..dataset.len()).collect();
275        indices.shuffle(&mut thread_rng());
276
277        let mut subsets = Vec::with_capacity(lengths.len());
278        let mut offset = 0;
279        for &len in lengths {
280            let subset_indices = indices[offset..offset + len].to_vec();
281            subsets.push(Self::new(dataset.clone(), subset_indices));
282            offset += len;
283        }
284        subsets
285    }
286}
287
288impl<D: Dataset> Dataset for SubsetDataset<D> {
289    type Item = D::Item;
290
291    fn len(&self) -> usize {
292        self.indices.len()
293    }
294
295    fn get(&self, index: usize) -> Option<Self::Item> {
296        let real_index = *self.indices.get(index)?;
297        self.dataset.get(real_index)
298    }
299}
300
301// =============================================================================
302// InMemoryDataset
303// =============================================================================
304
305/// A simple in-memory dataset from a vector.
306pub struct InMemoryDataset<T: Clone + Send> {
307    items: Vec<T>,
308}
309
310impl<T: Clone + Send> InMemoryDataset<T> {
311    /// Creates a new `InMemoryDataset` from a vector.
312    #[must_use]
313    pub fn new(items: Vec<T>) -> Self {
314        Self { items }
315    }
316}
317
318impl<T: Clone + Send + Sync> Dataset for InMemoryDataset<T> {
319    type Item = T;
320
321    fn len(&self) -> usize {
322        self.items.len()
323    }
324
325    fn get(&self, index: usize) -> Option<Self::Item> {
326        self.items.get(index).cloned()
327    }
328}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_dataset() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2]).unwrap();
        let targets = Tensor::from_vec(vec![0.0, 1.0, 2.0], &[3]).unwrap();
        let dataset = TensorDataset::new(data, targets);

        assert_eq!(dataset.len(), 3);

        let (x, y) = dataset.get(0).unwrap();
        assert_eq!(x.to_vec(), vec![1.0, 2.0]);
        assert_eq!(y.to_vec(), vec![0.0]);

        let (x, y) = dataset.get(2).unwrap();
        assert_eq!(x.to_vec(), vec![5.0, 6.0]);
        assert_eq!(y.to_vec(), vec![2.0]);

        assert!(dataset.get(3).is_none());
    }

    #[test]
    fn test_tensor_dataset_multidim_targets() {
        // Targets with a trailing dimension keep it per item (no [1] fallback).
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let targets = Tensor::from_vec(vec![1.0, 0.0, 0.0, 1.0], &[2, 2]).unwrap();
        let dataset = TensorDataset::new(data, targets);

        let (x, y) = dataset.get(1).unwrap();
        assert_eq!(x.to_vec(), vec![4.0, 5.0, 6.0]);
        assert_eq!(x.shape().to_vec(), vec![3]);
        assert_eq!(y.to_vec(), vec![0.0, 1.0]);
        assert_eq!(y.shape().to_vec(), vec![2]);
    }

    #[test]
    fn test_tensor_dataset_from_data() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let dataset = TensorDataset::from_data(data);

        assert_eq!(dataset.len(), 2);
        let (x, y) = dataset.get(1).unwrap();
        assert_eq!(x.to_vec(), vec![3.0, 4.0]);
        // Placeholder targets are all zero.
        assert_eq!(y.to_vec(), vec![0.0]);
    }

    #[test]
    #[should_panic(expected = "Data and targets must have same first dimension")]
    fn test_tensor_dataset_length_mismatch_panics() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3, 1]).unwrap();
        let targets = Tensor::from_vec(vec![0.0, 1.0], &[2]).unwrap();
        let _ = TensorDataset::new(data, targets);
    }

    #[test]
    fn test_map_dataset() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4, 1]).unwrap();
        let targets = Tensor::from_vec(vec![0.0, 1.0, 0.0, 1.0], &[4]).unwrap();
        let base = TensorDataset::new(data, targets);

        let mapped = MapDataset::new(base, |(x, y)| (x.mul_scalar(2.0), y));

        assert_eq!(mapped.len(), 4);
        let (x, _) = mapped.get(0).unwrap();
        assert_eq!(x.to_vec(), vec![2.0]);
    }

    #[test]
    fn test_map_dataset_out_of_range() {
        let mapped = MapDataset::new(InMemoryDataset::new(vec![1, 2, 3]), |x: i32| x * 10);
        assert_eq!(mapped.get(2), Some(30));
        assert_eq!(mapped.get(3), None);
    }

    #[test]
    fn test_concat_dataset() {
        let data1 = Tensor::from_vec(vec![1.0, 2.0], &[2, 1]).unwrap();
        let targets1 = Tensor::from_vec(vec![0.0, 1.0], &[2]).unwrap();
        let ds1 = TensorDataset::new(data1, targets1);

        let data2 = Tensor::from_vec(vec![3.0, 4.0, 5.0], &[3, 1]).unwrap();
        let targets2 = Tensor::from_vec(vec![2.0, 3.0, 4.0], &[3]).unwrap();
        let ds2 = TensorDataset::new(data2, targets2);

        let concat = ConcatDataset::new(vec![ds1, ds2]);

        assert_eq!(concat.len(), 5);

        let (x, y) = concat.get(0).unwrap();
        assert_eq!(x.to_vec(), vec![1.0]);
        assert_eq!(y.to_vec(), vec![0.0]);

        let (x, y) = concat.get(3).unwrap();
        assert_eq!(x.to_vec(), vec![4.0]);
        assert_eq!(y.to_vec(), vec![3.0]);
    }

    #[test]
    fn test_concat_dataset_skips_empty_members() {
        let concat = ConcatDataset::new(vec![
            InMemoryDataset::new(vec![1, 2]),
            InMemoryDataset::new(vec![]),
            InMemoryDataset::new(vec![3]),
        ]);

        assert_eq!(concat.len(), 3);
        assert_eq!(concat.get(1), Some(2));
        assert_eq!(concat.get(2), Some(3));
        assert_eq!(concat.get(3), None);
    }

    #[test]
    fn test_concat_dataset_of_nothing_is_empty() {
        let concat: ConcatDataset<InMemoryDataset<i32>> = ConcatDataset::new(Vec::new());
        assert!(concat.is_empty());
        assert_eq!(concat.get(0), None);
    }

    #[test]
    fn test_subset_dataset() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5, 1]).unwrap();
        let targets = Tensor::from_vec(vec![0.0, 1.0, 2.0, 3.0, 4.0], &[5]).unwrap();
        let base = TensorDataset::new(data, targets);

        let subset = SubsetDataset::new(base, vec![0, 2, 4]);

        assert_eq!(subset.len(), 3);

        let (x, _) = subset.get(0).unwrap();
        assert_eq!(x.to_vec(), vec![1.0]);

        let (x, _) = subset.get(1).unwrap();
        assert_eq!(x.to_vec(), vec![3.0]);

        let (x, _) = subset.get(2).unwrap();
        assert_eq!(x.to_vec(), vec![5.0]);
    }

    #[test]
    fn test_subset_dataset_out_of_range() {
        let subset = SubsetDataset::new(InMemoryDataset::new(vec![10, 20, 30]), vec![2, 0]);
        assert_eq!(subset.get(0), Some(30));
        assert_eq!(subset.get(1), Some(10));
        assert_eq!(subset.get(2), None);
    }

    #[test]
    fn test_random_split_partitions_indices() {
        let values: Vec<f32> = (0..10).map(|v| v as f32).collect();
        let data = Tensor::from_vec(values.clone(), &[10, 1]).unwrap();
        let targets = Tensor::from_vec(vec![0.0; 10], &[10]).unwrap();

        let splits = SubsetDataset::random_split(TensorDataset::new(data, targets), &[7, 3]);

        assert_eq!(splits.len(), 2);
        assert_eq!(splits[0].len(), 7);
        assert_eq!(splits[1].len(), 3);

        // Every original row appears in exactly one split.
        let mut seen: Vec<f32> = splits
            .iter()
            .flat_map(|s| (0..s.len()).map(move |i| s.get(i).unwrap().0.to_vec()[0]))
            .collect();
        seen.sort_by(|a, b| a.partial_cmp(b).unwrap());
        assert_eq!(seen, values);
    }

    #[test]
    fn test_in_memory_dataset() {
        let dataset = InMemoryDataset::new(vec![1, 2, 3, 4, 5]);

        assert_eq!(dataset.len(), 5);
        assert_eq!(dataset.get(0), Some(1));
        assert_eq!(dataset.get(4), Some(5));
        assert_eq!(dataset.get(5), None);
    }

    #[test]
    fn test_in_memory_dataset_empty() {
        let dataset: InMemoryDataset<u8> = InMemoryDataset::new(Vec::new());
        assert!(dataset.is_empty());
        assert_eq!(dataset.get(0), None);
    }
}