1#![warn(missing_docs)]
31#![warn(clippy::all)]
32#![warn(clippy::pedantic)]
33#![allow(clippy::cast_possible_truncation)]
35#![allow(clippy::cast_sign_loss)]
36#![allow(clippy::cast_precision_loss)]
37#![allow(clippy::cast_possible_wrap)]
38#![allow(clippy::missing_errors_doc)]
39#![allow(clippy::missing_panics_doc)]
40#![allow(clippy::must_use_candidate)]
41#![allow(clippy::module_name_repetitions)]
42#![allow(clippy::similar_names)]
43#![allow(clippy::many_single_char_names)]
44#![allow(clippy::too_many_arguments)]
45#![allow(clippy::doc_markdown)]
46#![allow(clippy::cast_lossless)]
47#![allow(clippy::needless_pass_by_value)]
48#![allow(clippy::redundant_closure_for_method_calls)]
49#![allow(clippy::uninlined_format_args)]
50#![allow(clippy::ptr_arg)]
51#![allow(clippy::return_self_not_must_use)]
52#![allow(clippy::not_unsafe_ptr_arg_deref)]
53#![allow(clippy::items_after_statements)]
54#![allow(clippy::unreadable_literal)]
55#![allow(clippy::if_same_then_else)]
56#![allow(clippy::needless_range_loop)]
57#![allow(clippy::trivially_copy_pass_by_ref)]
58#![allow(clippy::unnecessary_wraps)]
59#![allow(clippy::match_same_arms)]
60#![allow(clippy::unused_self)]
61#![allow(clippy::too_many_lines)]
62#![allow(clippy::single_match_else)]
63#![allow(clippy::fn_params_excessive_bools)]
64#![allow(clippy::struct_excessive_bools)]
65#![allow(clippy::format_push_string)]
66#![allow(clippy::erasing_op)]
67#![allow(clippy::type_repetition_in_bounds)]
68#![allow(clippy::iter_without_into_iter)]
69#![allow(clippy::should_implement_trait)]
70#![allow(clippy::use_debug)]
71#![allow(clippy::case_sensitive_file_extension_comparisons)]
72#![allow(clippy::large_enum_variant)]
73#![allow(clippy::panic)]
74#![allow(clippy::struct_field_names)]
75#![allow(clippy::missing_fields_in_debug)]
76#![allow(clippy::upper_case_acronyms)]
77#![allow(clippy::assigning_clones)]
78#![allow(clippy::option_if_let_else)]
79#![allow(clippy::manual_let_else)]
80#![allow(clippy::explicit_iter_loop)]
81#![allow(clippy::default_trait_access)]
82#![allow(clippy::only_used_in_recursion)]
83#![allow(clippy::manual_clamp)]
84#![allow(clippy::ref_option)]
85#![allow(clippy::multiple_bound_locations)]
86#![allow(clippy::comparison_chain)]
87#![allow(clippy::manual_assert)]
88#![allow(clippy::unnecessary_debug_formatting)]
89
90pub mod datasets;
91pub mod hub;
92pub mod models;
93pub mod transforms;
94
95pub use transforms::{
100 CenterCrop, ColorJitter, Grayscale, ImageNormalize, Pad, RandomHorizontalFlip, RandomRotation,
101 RandomVerticalFlip, Resize, ToTensorImage,
102};
103
104pub use datasets::{FashionMNIST, SyntheticCIFAR, SyntheticMNIST, CIFAR10, CIFAR100, MNIST};
105
106pub use models::{LeNet, SimpleCNN, MLP};
107
108pub use hub::{
109 cache_dir, download_weights, is_cached, list_models, load_state_dict, model_info,
110 model_registry, HubError, HubResult, PretrainedModel, StateDict,
111};
112
/// Convenience re-exports of the crate's most commonly used items.
///
/// Bring everything into scope with `use <crate>::prelude::*;` — this covers
/// the datasets, transforms, and models re-exported at the crate root, plus
/// the core types from the sibling `axonml_*` crates so downstream code
/// needs only a single import line.
pub mod prelude {
    // Datasets, transforms, and models defined in this crate.
    pub use crate::{
        CenterCrop,
        ColorJitter,
        FashionMNIST,
        Grayscale,
        ImageNormalize,
        LeNet,
        Pad,
        RandomHorizontalFlip,
        RandomRotation,
        RandomVerticalFlip,
        Resize,
        SimpleCNN,
        SyntheticCIFAR,
        SyntheticMNIST,
        ToTensorImage,
        CIFAR10,
        CIFAR100,
        MLP,
        MNIST,
    };

    // Core building blocks from the companion crates (autograd variable,
    // data-loading traits, module trait, and the tensor type).
    pub use axonml_autograd::Variable;
    pub use axonml_data::{Compose, DataLoader, Dataset, Transform};
    pub use axonml_nn::Module;
    pub use axonml_tensor::Tensor;
}
150
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_data::{Compose, Dataset, Transform};
    use axonml_tensor::Tensor;

    /// Normalizing a synthetic MNIST sample preserves the [1, 28, 28] image
    /// shape; the label tensor is [10] (presumably one-hot — see dataset).
    #[test]
    fn test_synthetic_mnist_with_transforms() {
        let dataset = SyntheticMNIST::small();
        let normalize = ImageNormalize::mnist();

        let (image, label) = dataset.get(0).unwrap();
        let normalized = normalize.apply(&image);

        assert_eq!(normalized.shape(), &[1, 28, 28]);
        assert_eq!(label.shape(), &[10]);
    }

    /// Same shape-preservation check for synthetic CIFAR ([3, 32, 32] RGB).
    #[test]
    fn test_synthetic_cifar_with_transforms() {
        let dataset = SyntheticCIFAR::small();
        let normalize = ImageNormalize::cifar10();

        let (image, label) = dataset.get(0).unwrap();
        let normalized = normalize.apply(&image);

        assert_eq!(normalized.shape(), &[3, 32, 32]);
        assert_eq!(label.shape(), &[10]);
    }

    /// A composed pipeline (resize → flip → normalize) produces the resized
    /// shape. Flip probability is 0.0 so the test stays deterministic.
    #[test]
    fn test_transform_pipeline() {
        let transform = Compose::empty()
            .add(Resize::new(32, 32))
            .add(RandomHorizontalFlip::with_probability(0.0))
            .add(ImageNormalize::new(vec![0.5], vec![0.5]));

        let input = Tensor::from_vec(vec![0.5; 28 * 28], &[1, 28, 28]).unwrap();
        let output = transform.apply(&input);

        assert_eq!(output.shape(), &[1, 32, 32]);
    }

    /// LeNet forward pass on a single batched MNIST image yields [1, 10] logits.
    #[test]
    fn test_lenet_with_synthetic_data() {
        use axonml_autograd::Variable;
        use axonml_nn::Module;

        let dataset = SyntheticMNIST::small();
        let model = LeNet::new();

        let (image, _label) = dataset.get(0).unwrap();

        // Add the leading batch dimension expected by the model: [1, 1, 28, 28].
        let batched = Tensor::from_vec(image.to_vec(), &[1, 1, 28, 28]).unwrap();

        let input = Variable::new(batched, false);
        let output = model.forward(&input);

        assert_eq!(output.data().shape(), &[1, 10]);
    }

    /// MLP forward pass on a flattened MNIST image (784 features) yields [1, 10].
    #[test]
    fn test_mlp_with_synthetic_data() {
        use axonml_autograd::Variable;
        use axonml_nn::Module;

        let dataset = SyntheticMNIST::small();
        let model = MLP::for_mnist();

        let (image, _) = dataset.get(0).unwrap();

        // Flatten [1, 28, 28] into a [1, 784] row vector for the MLP.
        let flattened = Tensor::from_vec(image.to_vec(), &[1, 784]).unwrap();

        let input = Variable::new(flattened, false);
        let output = model.forward(&input);

        assert_eq!(output.data().shape(), &[1, 10]);
    }

    /// Resize-then-crop composition ends at the crop's [3, 32, 32] shape.
    #[test]
    fn test_resize_and_crop_pipeline() {
        let transform = Compose::empty()
            .add(Resize::new(64, 64))
            .add(CenterCrop::new(32, 32));

        let input = Tensor::from_vec(vec![0.5; 3 * 28 * 28], &[3, 28, 28]).unwrap();
        let output = transform.apply(&input);

        assert_eq!(output.shape(), &[3, 32, 32]);
    }

    /// Grayscale collapses the 3 RGB channels to 1.
    #[test]
    fn test_grayscale_transform() {
        let transform = Grayscale::new();
        let input = Tensor::from_vec(vec![0.5; 3 * 32 * 32], &[3, 32, 32]).unwrap();
        let output = transform.apply(&input);

        assert_eq!(output.shape(), &[1, 32, 32]);
    }

    /// End-to-end smoke test: dataset → DataLoader → flatten → MLP forward.
    /// Only the first batch is consumed (`take(1)`), keeping the test fast.
    #[test]
    fn test_full_training_pipeline() {
        use axonml_autograd::Variable;
        use axonml_data::DataLoader;
        use axonml_nn::Module;

        let dataset = SyntheticMNIST::new(32);

        let loader = DataLoader::new(dataset, 8);

        let model = MLP::for_mnist();

        let mut processed_batches = 0;
        for batch in loader.iter().take(1) {
            let batch_data = batch.data.to_vec();
            let batch_size = batch.data.shape()[0];
            // Flatten every non-batch dimension into a single feature axis.
            let features: usize = batch.data.shape()[1..].iter().product();

            let flattened = Tensor::from_vec(batch_data, &[batch_size, features]).unwrap();
            let input = Variable::new(flattened, false);

            let output = model.forward(&input);
            assert_eq!(output.data().shape()[0], batch_size);
            assert_eq!(output.data().shape()[1], 10);

            processed_batches += 1;
        }

        // Guard against the loader silently yielding zero batches.
        assert_eq!(processed_batches, 1);
    }
}