//! Crate root.
//!
//! Declares the `collate`, `dataloader`, `dataset`, `sampler` and `transforms`
//! modules, re-exports their main types at the top level, and offers a
//! `prelude` module for glob-importing the most commonly used names.

// ---------------------------------------------------------------------------
// Lint policy
// ---------------------------------------------------------------------------

// Every public item is expected to carry documentation.
#![warn(missing_docs)]
// Turn on Clippy's default groups and the pedantic group as warnings...
#![warn(clippy::all)]
#![warn(clippy::pedantic)]
// ...then opt out, crate-wide, of the individual lints listed below.
// NOTE(review): the rationale for each opt-out is not recorded in this file;
// prefer narrowing an `allow` to the item that needs it when touching this list.
//
// Numeric `as`-cast lints.
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_possible_wrap)]
// Documentation-section lints (`# Errors`, `# Panics`) and `#[must_use]` nagging.
#![allow(clippy::missing_errors_doc)]
#![allow(clippy::missing_panics_doc)]
#![allow(clippy::must_use_candidate)]
#![allow(clippy::module_name_repetitions)]
#![allow(clippy::similar_names)]
#![allow(clippy::many_single_char_names)]
#![allow(clippy::too_many_arguments)]
#![allow(clippy::doc_markdown)]
#![allow(clippy::cast_lossless)]
#![allow(clippy::needless_pass_by_value)]
#![allow(clippy::redundant_closure_for_method_calls)]
#![allow(clippy::uninlined_format_args)]
#![allow(clippy::ptr_arg)]
#![allow(clippy::return_self_not_must_use)]
#![allow(clippy::not_unsafe_ptr_arg_deref)]
#![allow(clippy::items_after_statements)]
#![allow(clippy::unreadable_literal)]
#![allow(clippy::if_same_then_else)]
#![allow(clippy::needless_range_loop)]
#![allow(clippy::trivially_copy_pass_by_ref)]
#![allow(clippy::unnecessary_wraps)]
#![allow(clippy::match_same_arms)]
#![allow(clippy::unused_self)]
#![allow(clippy::too_many_lines)]
#![allow(clippy::single_match_else)]
#![allow(clippy::fn_params_excessive_bools)]
#![allow(clippy::struct_excessive_bools)]
#![allow(clippy::format_push_string)]
// NOTE(review): `erasing_op` belongs to Clippy's correctness group; allowing it
// for the whole crate can hide genuine mistakes — confirm it is still needed.
#![allow(clippy::erasing_op)]
#![allow(clippy::type_repetition_in_bounds)]
#![allow(clippy::iter_without_into_iter)]
#![allow(clippy::should_implement_trait)]
// NOTE(review): `use_debug` and `panic` are restriction-group lints, which
// `clippy::all` / `clippy::pedantic` do not enable; these two lines only have
// an effect if the lints are switched on elsewhere (e.g. workspace config).
#![allow(clippy::use_debug)]
#![allow(clippy::case_sensitive_file_extension_comparisons)]
#![allow(clippy::large_enum_variant)]
#![allow(clippy::panic)]
#![allow(clippy::struct_field_names)]
#![allow(clippy::missing_fields_in_debug)]
#![allow(clippy::upper_case_acronyms)]
#![allow(clippy::assigning_clones)]
#![allow(clippy::option_if_let_else)]
#![allow(clippy::manual_let_else)]
#![allow(clippy::explicit_iter_loop)]
#![allow(clippy::default_trait_access)]
#![allow(clippy::only_used_in_recursion)]
#![allow(clippy::manual_clamp)]
#![allow(clippy::ref_option)]
#![allow(clippy::multiple_bound_locations)]
#![allow(clippy::comparison_chain)]
#![allow(clippy::manual_assert)]
#![allow(clippy::unnecessary_debug_formatting)]
85
86pub mod collate;
91pub mod dataloader;
92pub mod dataset;
93pub mod sampler;
94pub mod transforms;
95
96pub use collate::{Collate, DefaultCollate, StackCollate};
101pub use dataloader::{Batch, DataLoader, DataLoaderIter, GpuPrefetchIter};
102pub use dataset::{
103 ConcatDataset, Dataset, InMemoryDataset, MapDataset, SubsetDataset, TensorDataset,
104};
105pub use sampler::{
106 BatchSampler, RandomSampler, Sampler, SequentialSampler, SubsetRandomSampler,
107 WeightedRandomSampler,
108};
109pub use transforms::{
110 Clamp, Compose, DropoutTransform, Flatten, Lambda, Normalize, RandomCrop, RandomFlip,
111 RandomNoise, Reshape, Scale, ToTensor, Transform,
112};
113
114pub mod prelude {
120 pub use crate::{
121 Batch, BatchSampler, Collate, Compose, ConcatDataset, DataLoader, DataLoaderIter, Dataset,
122 DefaultCollate, GpuPrefetchIter, InMemoryDataset, MapDataset, Normalize, RandomNoise,
123 RandomSampler, Sampler, SequentialSampler, StackCollate, SubsetDataset,
124 SubsetRandomSampler, TensorDataset, ToTensor, Transform, WeightedRandomSampler,
125 };
126 pub use axonml_tensor::Tensor;
127}
128
// Smoke tests exercising the crate-root re-exports end to end.
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_tensor::Tensor;

    // A dataset built from a [3, 2] input and a [3] target reports three
    // items, and item 0 is the first input row paired with the first target.
    #[test]
    fn test_tensor_dataset() {
        let x = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2]).unwrap();
        let y = Tensor::from_vec(vec![0.0, 1.0, 0.0], &[3]).unwrap();
        let dataset = TensorDataset::new(x, y);

        assert_eq!(dataset.len(), 3);
        let (x_item, y_item) = dataset.get(0).unwrap();
        assert_eq!(x_item.to_vec(), vec![1.0, 2.0]);
        assert_eq!(y_item.to_vec(), vec![0.0]);
    }

    // Six samples with a batch size of two are expected to produce three batches.
    #[test]
    fn test_dataloader_basic() {
        let x = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6, 1]).unwrap();
        let y = Tensor::from_vec(vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0], &[6]).unwrap();
        let dataset = TensorDataset::new(x, y);
        let loader = DataLoader::new(dataset, 2);

        let batches: Vec<_> = loader.iter().collect();
        assert_eq!(batches.len(), 3);
    }

    // With shuffling enabled, two independent iterations each yield a first
    // batch. This only checks that batches are produced; it deliberately does
    // not assert that the two orderings differ.
    #[test]
    fn test_dataloader_shuffle() {
        let x = Tensor::from_vec((0..100).map(|i| i as f32).collect(), &[100, 1]).unwrap();
        let y = Tensor::from_vec((0..100).map(|i| i as f32).collect(), &[100]).unwrap();
        let dataset = TensorDataset::new(x, y);
        let loader = DataLoader::new(dataset, 10).shuffle(true);

        let batch1: Vec<_> = loader.iter().take(1).collect();
        let batch2: Vec<_> = loader.iter().take(1).collect();

        assert!(!batch1.is_empty());
        assert!(!batch2.is_empty());
    }

    // Chaining `Normalize` and `RandomNoise` through `Compose` keeps the
    // input's shape; output values are not checked.
    #[test]
    fn test_transform_compose() {
        let normalize = Normalize::new(0.0, 1.0);
        let noise = RandomNoise::new(0.0);
        let transform = Compose::new(vec![Box::new(normalize), Box::new(noise)]);

        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let output = transform.apply(&input);
        assert_eq!(output.shape(), &[3]);
    }

    // The sequential sampler yields 0..10 in order; for the random sampler
    // only the number of yielded indices is checked.
    #[test]
    fn test_samplers() {
        let sequential = SequentialSampler::new(10);
        let indices: Vec<_> = sequential.iter().collect();
        assert_eq!(indices, vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);

        let random = RandomSampler::new(10);
        let indices: Vec<_> = random.iter().collect();
        assert_eq!(indices.len(), 10);
    }
}