// axonml_data/lib.rs

1//! axonml-data - Data Loading Utilities
2//!
3//! Provides data loading infrastructure for training neural networks:
4//! - Dataset trait for defining data sources
5//! - `DataLoader` for batched iteration with parallel loading
6//! - Samplers for controlling data access patterns
7//! - Transforms for data augmentation
8//!
9//! # Example
10//!
11//! ```ignore
12//! use axonml_data::prelude::*;
13//!
14//! // Define a simple dataset
15//! struct MyDataset {
16//!     data: Vec<(Tensor<f32>, Tensor<f32>)>,
17//! }
18//!
19//! impl Dataset for MyDataset {
20//!     type Item = (Tensor<f32>, Tensor<f32>);
21//!
22//!     fn len(&self) -> usize {
23//!         self.data.len()
24//!     }
25//!
26//!     fn get(&self, index: usize) -> Option<Self::Item> {
27//!         self.data.get(index).cloned()
28//!     }
29//! }
30//!
31//! // Create a DataLoader
32//! let loader = DataLoader::new(dataset, 32)
33//!     .shuffle(true)
34//!     .num_workers(4);
35//!
36//! for batch in loader.iter() {
37//!     // Process batch
38//! }
39//! ```
40//!
41//! @version 0.1.0
42//! @author `AutomataNexus` Development Team
43
44#![warn(missing_docs)]
45#![warn(clippy::all)]
46#![warn(clippy::pedantic)]
47// ML/tensor-specific allowances
48#![allow(clippy::cast_possible_truncation)]
49#![allow(clippy::cast_sign_loss)]
50#![allow(clippy::cast_precision_loss)]
51#![allow(clippy::cast_possible_wrap)]
52#![allow(clippy::missing_errors_doc)]
53#![allow(clippy::missing_panics_doc)]
54#![allow(clippy::must_use_candidate)]
55#![allow(clippy::module_name_repetitions)]
56#![allow(clippy::similar_names)]
57#![allow(clippy::many_single_char_names)]
58#![allow(clippy::too_many_arguments)]
59#![allow(clippy::doc_markdown)]
60#![allow(clippy::cast_lossless)]
61#![allow(clippy::needless_pass_by_value)]
62#![allow(clippy::redundant_closure_for_method_calls)]
63#![allow(clippy::uninlined_format_args)]
64#![allow(clippy::ptr_arg)]
65#![allow(clippy::return_self_not_must_use)]
66#![allow(clippy::not_unsafe_ptr_arg_deref)]
67#![allow(clippy::items_after_statements)]
68#![allow(clippy::unreadable_literal)]
69#![allow(clippy::if_same_then_else)]
70#![allow(clippy::needless_range_loop)]
71#![allow(clippy::trivially_copy_pass_by_ref)]
72#![allow(clippy::unnecessary_wraps)]
73#![allow(clippy::match_same_arms)]
74#![allow(clippy::unused_self)]
75#![allow(clippy::too_many_lines)]
76#![allow(clippy::single_match_else)]
77#![allow(clippy::fn_params_excessive_bools)]
78#![allow(clippy::struct_excessive_bools)]
79#![allow(clippy::format_push_string)]
80#![allow(clippy::erasing_op)]
81#![allow(clippy::type_repetition_in_bounds)]
82#![allow(clippy::iter_without_into_iter)]
83#![allow(clippy::should_implement_trait)]
84#![allow(clippy::use_debug)]
85#![allow(clippy::case_sensitive_file_extension_comparisons)]
86#![allow(clippy::large_enum_variant)]
87#![allow(clippy::panic)]
88#![allow(clippy::struct_field_names)]
89#![allow(clippy::missing_fields_in_debug)]
90#![allow(clippy::upper_case_acronyms)]
91#![allow(clippy::assigning_clones)]
92#![allow(clippy::option_if_let_else)]
93#![allow(clippy::manual_let_else)]
94#![allow(clippy::explicit_iter_loop)]
95#![allow(clippy::default_trait_access)]
96#![allow(clippy::only_used_in_recursion)]
97#![allow(clippy::manual_clamp)]
98#![allow(clippy::ref_option)]
99#![allow(clippy::multiple_bound_locations)]
100#![allow(clippy::comparison_chain)]
101#![allow(clippy::manual_assert)]
102#![allow(clippy::unnecessary_debug_formatting)]
103
104// =============================================================================
105// Module Declarations
106// =============================================================================
107
108pub mod collate;
109pub mod dataloader;
110pub mod dataset;
111pub mod sampler;
112pub mod transforms;
113
114// =============================================================================
115// Re-exports
116// =============================================================================
117
118pub use collate::{Collate, DefaultCollate, StackCollate};
119pub use dataloader::{Batch, DataLoader, DataLoaderIter};
120pub use dataset::{
121    ConcatDataset, Dataset, InMemoryDataset, MapDataset, SubsetDataset, TensorDataset,
122};
123pub use sampler::{
124    BatchSampler, RandomSampler, Sampler, SequentialSampler, SubsetRandomSampler,
125    WeightedRandomSampler,
126};
127pub use transforms::{Compose, Normalize, RandomNoise, ToTensor, Transform};
128
129// =============================================================================
130// Prelude
131// =============================================================================
132
133/// Common imports for data loading.
134pub mod prelude {
135    pub use crate::{
136        Batch, BatchSampler, Collate, Compose, ConcatDataset, DataLoader, DataLoaderIter, Dataset,
137        DefaultCollate, InMemoryDataset, MapDataset, Normalize, RandomNoise, RandomSampler,
138        Sampler, SequentialSampler, StackCollate, SubsetDataset, SubsetRandomSampler,
139        TensorDataset, ToTensor, Transform, WeightedRandomSampler,
140    };
141    pub use axonml_tensor::Tensor;
142}
143
144// =============================================================================
145// Tests
146// =============================================================================
147
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_tensor::Tensor;

    /// Indexed access on a small `TensorDataset` returns the matching
    /// feature row and label.
    #[test]
    fn test_tensor_dataset() {
        let features = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2]).unwrap();
        let labels = Tensor::from_vec(vec![0.0, 1.0, 0.0], &[3]).unwrap();
        let dataset = TensorDataset::new(features, labels);

        assert_eq!(dataset.len(), 3);

        let (row, label) = dataset.get(0).unwrap();
        assert_eq!(row.to_vec(), vec![1.0, 2.0]);
        assert_eq!(label.to_vec(), vec![0.0]);
    }

    /// A dataset of 6 items with batch size 2 yields exactly 3 batches.
    #[test]
    fn test_dataloader_basic() {
        let inputs = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6, 1]).unwrap();
        let targets = Tensor::from_vec(vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0], &[6]).unwrap();
        let loader = DataLoader::new(TensorDataset::new(inputs, targets), 2);

        let collected: Vec<_> = loader.iter().collect();
        assert_eq!(collected.len(), 3); // 6 items / batch size 2
    }

    /// A shuffled loader still produces batches; exact order is random,
    /// so we only verify that iteration yields something each time.
    #[test]
    fn test_dataloader_shuffle() {
        let values: Vec<f32> = (0..100).map(|i| i as f32).collect();
        let inputs = Tensor::from_vec(values.clone(), &[100, 1]).unwrap();
        let targets = Tensor::from_vec(values, &[100]).unwrap();
        let loader = DataLoader::new(TensorDataset::new(inputs, targets), 10).shuffle(true);

        // Two independent iterations: with shuffling the first batches may
        // differ, but randomness means we cannot assert inequality — just
        // confirm the loader produces output both times.
        let first_pass: Vec<_> = loader.iter().take(1).collect();
        let second_pass: Vec<_> = loader.iter().take(1).collect();
        assert!(!first_pass.is_empty());
        assert!(!second_pass.is_empty());
    }

    /// Composing transforms applies them in sequence and preserves shape.
    #[test]
    fn test_transform_compose() {
        let pipeline = Compose::new(vec![
            Box::new(Normalize::new(0.0, 1.0)),
            Box::new(RandomNoise::new(0.0)),
        ]);

        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        assert_eq!(pipeline.apply(&input).shape(), &[3]);
    }

    /// Sequential sampling visits indices in order; random sampling
    /// yields the expected number of indices.
    #[test]
    fn test_samplers() {
        let ordered: Vec<_> = SequentialSampler::new(10).iter().collect();
        assert_eq!(ordered, vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);

        let randomized: Vec<_> = RandomSampler::new(10).iter().collect();
        assert_eq!(randomized.len(), 10);
    }
}