Skip to main content

axonml_data/
lib.rs

1//! axonml-data - Data Loading Utilities
2//!
3//! # File
4//! `crates/axonml-data/src/lib.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr - AutomataNexus
8//!
9//! # Updated
10//! March 8, 2026
11//!
12//! # Disclaimer
13//! Use at own risk. This software is provided "as is", without warranty of any
14//! kind, express or implied. The author and AutomataNexus shall not be held
15//! liable for any damages arising from the use of this software.
16
17#![warn(missing_docs)]
18#![warn(clippy::all)]
19#![warn(clippy::pedantic)]
20// ML/tensor-specific allowances
21#![allow(clippy::cast_possible_truncation)]
22#![allow(clippy::cast_sign_loss)]
23#![allow(clippy::cast_precision_loss)]
24#![allow(clippy::cast_possible_wrap)]
25#![allow(clippy::missing_errors_doc)]
26#![allow(clippy::missing_panics_doc)]
27#![allow(clippy::must_use_candidate)]
28#![allow(clippy::module_name_repetitions)]
29#![allow(clippy::similar_names)]
30#![allow(clippy::many_single_char_names)]
31#![allow(clippy::too_many_arguments)]
32#![allow(clippy::doc_markdown)]
33#![allow(clippy::cast_lossless)]
34#![allow(clippy::needless_pass_by_value)]
35#![allow(clippy::redundant_closure_for_method_calls)]
36#![allow(clippy::uninlined_format_args)]
37#![allow(clippy::ptr_arg)]
38#![allow(clippy::return_self_not_must_use)]
39#![allow(clippy::not_unsafe_ptr_arg_deref)]
40#![allow(clippy::items_after_statements)]
41#![allow(clippy::unreadable_literal)]
42#![allow(clippy::if_same_then_else)]
43#![allow(clippy::needless_range_loop)]
44#![allow(clippy::trivially_copy_pass_by_ref)]
45#![allow(clippy::unnecessary_wraps)]
46#![allow(clippy::match_same_arms)]
47#![allow(clippy::unused_self)]
48#![allow(clippy::too_many_lines)]
49#![allow(clippy::single_match_else)]
50#![allow(clippy::fn_params_excessive_bools)]
51#![allow(clippy::struct_excessive_bools)]
52#![allow(clippy::format_push_string)]
53#![allow(clippy::erasing_op)]
54#![allow(clippy::type_repetition_in_bounds)]
55#![allow(clippy::iter_without_into_iter)]
56#![allow(clippy::should_implement_trait)]
57#![allow(clippy::use_debug)]
58#![allow(clippy::case_sensitive_file_extension_comparisons)]
59#![allow(clippy::large_enum_variant)]
60#![allow(clippy::panic)]
61#![allow(clippy::struct_field_names)]
62#![allow(clippy::missing_fields_in_debug)]
63#![allow(clippy::upper_case_acronyms)]
64#![allow(clippy::assigning_clones)]
65#![allow(clippy::option_if_let_else)]
66#![allow(clippy::manual_let_else)]
67#![allow(clippy::explicit_iter_loop)]
68#![allow(clippy::default_trait_access)]
69#![allow(clippy::only_used_in_recursion)]
70#![allow(clippy::manual_clamp)]
71#![allow(clippy::ref_option)]
72#![allow(clippy::multiple_bound_locations)]
73#![allow(clippy::comparison_chain)]
74#![allow(clippy::manual_assert)]
75#![allow(clippy::unnecessary_debug_formatting)]
76
77// =============================================================================
78// Module Declarations
79// =============================================================================
80
81pub mod collate;
82pub mod dataloader;
83pub mod dataset;
84pub mod sampler;
85pub mod transforms;
86
87// =============================================================================
88// Re-exports
89// =============================================================================
90
91pub use collate::{Collate, DefaultCollate, StackCollate};
92pub use dataloader::{Batch, DataLoader, DataLoaderIter, GpuPrefetchIter};
93pub use dataset::{
94    ConcatDataset, Dataset, InMemoryDataset, MapDataset, SubsetDataset, TensorDataset,
95};
96pub use sampler::{
97    BatchSampler, RandomSampler, Sampler, SequentialSampler, SubsetRandomSampler,
98    WeightedRandomSampler,
99};
100pub use transforms::{
101    Clamp, Compose, DropoutTransform, Flatten, Lambda, Normalize, RandomCrop, RandomFlip,
102    RandomNoise, Reshape, Scale, ToTensor, Transform,
103};
104
105// =============================================================================
106// Prelude
107// =============================================================================
108
109/// Common imports for data loading.
110pub mod prelude {
111    pub use crate::{
112        Batch, BatchSampler, Collate, Compose, ConcatDataset, DataLoader, DataLoaderIter, Dataset,
113        DefaultCollate, GpuPrefetchIter, InMemoryDataset, MapDataset, Normalize, RandomNoise,
114        RandomSampler, Sampler, SequentialSampler, StackCollate, SubsetDataset,
115        SubsetRandomSampler, TensorDataset, ToTensor, Transform, WeightedRandomSampler,
116    };
117    pub use axonml_tensor::Tensor;
118}
119
120// =============================================================================
121// Tests
122// =============================================================================
123
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_tensor::Tensor;

    /// `TensorDataset` pairs row `i` of the feature tensor with element `i`
    /// of the target tensor, and reports the number of rows as its length.
    #[test]
    fn test_tensor_dataset() {
        let x = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2]).unwrap();
        let y = Tensor::from_vec(vec![0.0, 1.0, 0.0], &[3]).unwrap();
        let dataset = TensorDataset::new(x, y);

        assert_eq!(dataset.len(), 3);
        let (x_item, y_item) = dataset.get(0).unwrap();
        assert_eq!(x_item.to_vec(), vec![1.0, 2.0]);
        assert_eq!(y_item.to_vec(), vec![0.0]);
    }

    /// An unshuffled loader splits 6 items into batches of 2.
    #[test]
    fn test_dataloader_basic() {
        let x = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6, 1]).unwrap();
        let y = Tensor::from_vec(vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0], &[6]).unwrap();
        let dataset = TensorDataset::new(x, y);
        let loader = DataLoader::new(dataset, 2);

        let batches: Vec<_> = loader.iter().collect();
        assert_eq!(batches.len(), 3); // 6 items / 2 batch_size = 3 batches
    }

    /// A shuffled loader can be iterated repeatedly, and shuffling does not
    /// change how many batches one pass produces.
    #[test]
    fn test_dataloader_shuffle() {
        let x = Tensor::from_vec((0..100).map(|i| i as f32).collect(), &[100, 1]).unwrap();
        let y = Tensor::from_vec((0..100).map(|i| i as f32).collect(), &[100]).unwrap();
        let dataset = TensorDataset::new(x, y);
        let loader = DataLoader::new(dataset, 10).shuffle(true);

        // The sample order is random, so two passes cannot be compared for
        // (in)equality without a flaky test. Instead, check that independent
        // passes each start yielding batches...
        let batch1: Vec<_> = loader.iter().take(1).collect();
        let batch2: Vec<_> = loader.iter().take(1).collect();
        assert!(!batch1.is_empty());
        assert!(!batch2.is_empty());

        // ...and that a full shuffled pass still yields every batch:
        // 100 items / 10 batch_size = 10 batches, exactly as without shuffling.
        assert_eq!(loader.iter().count(), 10);
    }

    /// `Compose` chains transforms and preserves the input shape here.
    #[test]
    fn test_transform_compose() {
        let normalize = Normalize::new(0.0, 1.0);
        let noise = RandomNoise::new(0.0);
        let transform = Compose::new(vec![Box::new(normalize), Box::new(noise)]);

        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let output = transform.apply(&input);
        assert_eq!(output.shape(), &[3]);
    }

    /// The sequential sampler yields `0..n` in order. The random sampler
    /// yields `n` indices, all of which must be valid positions in `0..n`.
    #[test]
    fn test_samplers() {
        let sequential = SequentialSampler::new(10);
        let indices: Vec<_> = sequential.iter().collect();
        assert_eq!(indices, vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);

        let random = RandomSampler::new(10);
        let indices: Vec<_> = random.iter().collect();
        assert_eq!(indices.len(), 10);
        // An out-of-range index would make a downstream `Dataset::get` fail.
        assert!(indices.iter().all(|&i| i < 10));
    }
}
190}