//! Axonml Vision - Computer Vision Utilities
//!
//! This crate provides computer vision functionality for the Axonml ML framework:
//!
//! - **Transforms**: Image-specific data augmentation and preprocessing
//! - **Datasets**: Loaders for common vision datasets (MNIST, CIFAR)
//! - **Models**: Pre-defined neural network architectures (`LeNet`, `MLP`)
//!
//! # Example
//!
//! ```ignore
//! use axonml_vision::prelude::*;
//!
//! // Load synthetic MNIST data
//! let train_data = SyntheticMNIST::train();
//! let test_data = SyntheticMNIST::test();
//!
//! // Create a LeNet model
//! let model = LeNet::new();
//!
//! // Apply image transforms
//! let transform = Compose::empty()
//!     .add(ImageNormalize::mnist())
//!     .add(RandomHorizontalFlip::new());
//! ```
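//!
//! A single image can then be pushed through a model; this sketch mirrors
//! `tests::test_lenet_with_synthetic_data` at the bottom of this file:
//!
//! ```ignore
//! use axonml_vision::prelude::*;
//!
//! // Add a batch dimension: one [1, 1, 28, 28] image yields [1, 10] logits.
//! let (image, _label) = SyntheticMNIST::small().get(0).unwrap();
//! let batched = Tensor::from_vec(image.to_vec(), &[1, 1, 28, 28]).unwrap();
//! let output = LeNet::new().forward(&Variable::new(batched, false));
//! assert_eq!(output.data().shape(), &[1, 10]);
//! ```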
//!
//! @version 0.1.0
//! @author `AutomataNexus` Development Team

#![warn(missing_docs)]
#![warn(clippy::all)]
#![warn(clippy::pedantic)]
// ML/tensor-specific allowances
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_possible_wrap)]
#![allow(clippy::missing_errors_doc)]
#![allow(clippy::missing_panics_doc)]
#![allow(clippy::must_use_candidate)]
#![allow(clippy::module_name_repetitions)]
#![allow(clippy::similar_names)]
#![allow(clippy::many_single_char_names)]
#![allow(clippy::too_many_arguments)]
#![allow(clippy::doc_markdown)]
#![allow(clippy::cast_lossless)]
#![allow(clippy::needless_pass_by_value)]
#![allow(clippy::redundant_closure_for_method_calls)]
#![allow(clippy::uninlined_format_args)]
#![allow(clippy::ptr_arg)]
#![allow(clippy::return_self_not_must_use)]
#![allow(clippy::not_unsafe_ptr_arg_deref)]
#![allow(clippy::items_after_statements)]
#![allow(clippy::unreadable_literal)]
#![allow(clippy::if_same_then_else)]
#![allow(clippy::needless_range_loop)]
#![allow(clippy::trivially_copy_pass_by_ref)]
#![allow(clippy::unnecessary_wraps)]
#![allow(clippy::match_same_arms)]
#![allow(clippy::unused_self)]
#![allow(clippy::too_many_lines)]
#![allow(clippy::single_match_else)]
#![allow(clippy::fn_params_excessive_bools)]
#![allow(clippy::struct_excessive_bools)]
#![allow(clippy::format_push_string)]
#![allow(clippy::erasing_op)]
#![allow(clippy::type_repetition_in_bounds)]
#![allow(clippy::iter_without_into_iter)]
#![allow(clippy::should_implement_trait)]
#![allow(clippy::use_debug)]
#![allow(clippy::case_sensitive_file_extension_comparisons)]
#![allow(clippy::large_enum_variant)]
#![allow(clippy::panic)]
#![allow(clippy::struct_field_names)]
#![allow(clippy::missing_fields_in_debug)]
#![allow(clippy::upper_case_acronyms)]
#![allow(clippy::assigning_clones)]
#![allow(clippy::option_if_let_else)]
#![allow(clippy::manual_let_else)]
#![allow(clippy::explicit_iter_loop)]
#![allow(clippy::default_trait_access)]
#![allow(clippy::only_used_in_recursion)]
#![allow(clippy::manual_clamp)]
#![allow(clippy::ref_option)]
#![allow(clippy::multiple_bound_locations)]
#![allow(clippy::comparison_chain)]
#![allow(clippy::manual_assert)]
#![allow(clippy::unnecessary_debug_formatting)]

pub mod datasets;
pub mod hub;
pub mod models;
pub mod transforms;

// =============================================================================
// Re-exports
// =============================================================================

pub use transforms::{
    CenterCrop, ColorJitter, Grayscale, ImageNormalize, Pad, RandomHorizontalFlip, RandomRotation,
    RandomVerticalFlip, Resize, ToTensorImage,
};

pub use datasets::{FashionMNIST, SyntheticCIFAR, SyntheticMNIST, CIFAR10, CIFAR100, MNIST};

pub use models::{LeNet, SimpleCNN, MLP};

pub use hub::{
    cache_dir, download_weights, is_cached, list_models, load_state_dict, model_info,
    model_registry, HubError, HubResult, PretrainedModel, StateDict,
};

// =============================================================================
// Prelude
// =============================================================================

/// Common imports for computer vision tasks.
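///
/// # Example
///
/// A minimal sketch, using only constructors exercised by this crate's
/// tests (not compiled as a doctest):
///
/// ```ignore
/// use axonml_vision::prelude::*;
///
/// let dataset = SyntheticMNIST::small();
/// let normalize = ImageNormalize::mnist();
/// let model = LeNet::new();
/// ```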
pub mod prelude {
    // Transforms
    pub use crate::{
        CenterCrop, ColorJitter, Grayscale, ImageNormalize, Pad, RandomHorizontalFlip,
        RandomRotation, RandomVerticalFlip, Resize, ToTensorImage,
    };
    // Datasets
    pub use crate::{FashionMNIST, SyntheticCIFAR, SyntheticMNIST, CIFAR10, CIFAR100, MNIST};
    // Models
    pub use crate::{LeNet, SimpleCNN, MLP};

    // Re-export useful items from dependencies
    pub use axonml_autograd::Variable;
    pub use axonml_data::{Compose, DataLoader, Dataset, Transform};
    pub use axonml_nn::Module;
    pub use axonml_tensor::Tensor;
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_data::{Compose, Dataset, Transform};
    use axonml_tensor::Tensor;

    #[test]
    fn test_synthetic_mnist_with_transforms() {
        let dataset = SyntheticMNIST::small();
        let normalize = ImageNormalize::mnist();

        let (image, label) = dataset.get(0).unwrap();
        let normalized = normalize.apply(&image);

        assert_eq!(normalized.shape(), &[1, 28, 28]);
        assert_eq!(label.shape(), &[10]);
    }

    #[test]
    fn test_synthetic_cifar_with_transforms() {
        let dataset = SyntheticCIFAR::small();
        let normalize = ImageNormalize::cifar10();

        let (image, label) = dataset.get(0).unwrap();
        let normalized = normalize.apply(&image);

        assert_eq!(normalized.shape(), &[3, 32, 32]);
        assert_eq!(label.shape(), &[10]);
    }

    #[test]
    fn test_transform_pipeline() {
        let transform = Compose::empty()
            .add(Resize::new(32, 32))
            .add(RandomHorizontalFlip::with_probability(0.0)) // No flip for determinism
            .add(ImageNormalize::new(vec![0.5], vec![0.5]));

        let input = Tensor::from_vec(vec![0.5; 28 * 28], &[1, 28, 28]).unwrap();
        let output = transform.apply(&input);

        assert_eq!(output.shape(), &[1, 32, 32]);
    }

    #[test]
    fn test_lenet_with_synthetic_data() {
        use axonml_autograd::Variable;
        use axonml_nn::Module;

        let dataset = SyntheticMNIST::small();
        let model = LeNet::new();

        // Get a sample and run forward pass
        let (image, _label) = dataset.get(0).unwrap();

        // Add batch dimension
        let batched = Tensor::from_vec(image.to_vec(), &[1, 1, 28, 28]).unwrap();

        let input = Variable::new(batched, false);
        let output = model.forward(&input);

        assert_eq!(output.data().shape(), &[1, 10]);
    }

    #[test]
    fn test_mlp_with_synthetic_data() {
        use axonml_autograd::Variable;
        use axonml_nn::Module;

        let dataset = SyntheticMNIST::small();
        let model = MLP::for_mnist();

        let (image, _) = dataset.get(0).unwrap();

        // MLP expects flattened input with batch dimension
        let flattened = Tensor::from_vec(image.to_vec(), &[1, 784]).unwrap();

        let input = Variable::new(flattened, false);
        let output = model.forward(&input);

        assert_eq!(output.data().shape(), &[1, 10]);
    }

    #[test]
    fn test_resize_and_crop_pipeline() {
        let transform = Compose::empty()
            .add(Resize::new(64, 64))
            .add(CenterCrop::new(32, 32));

        let input = Tensor::from_vec(vec![0.5; 3 * 28 * 28], &[3, 28, 28]).unwrap();
        let output = transform.apply(&input);

        assert_eq!(output.shape(), &[3, 32, 32]);
    }

    #[test]
    fn test_grayscale_transform() {
        let transform = Grayscale::new();
        let input = Tensor::from_vec(vec![0.5; 3 * 32 * 32], &[3, 32, 32]).unwrap();
        let output = transform.apply(&input);

        assert_eq!(output.shape(), &[1, 32, 32]);
    }

    #[test]
    fn test_full_training_pipeline() {
        use axonml_autograd::Variable;
        use axonml_data::DataLoader;
        use axonml_nn::Module;

        // Create dataset
        let dataset = SyntheticMNIST::new(32);

        // Create dataloader
        let loader = DataLoader::new(dataset, 8);

        // Create model
        let model = MLP::for_mnist();

        // Process one batch
        let mut processed_batches = 0;
        for batch in loader.iter().take(1) {
            // Flatten images for MLP
            let batch_data = batch.data.to_vec();
            let batch_size = batch.data.shape()[0];
            let features: usize = batch.data.shape()[1..].iter().product();

            let flattened = Tensor::from_vec(batch_data, &[batch_size, features]).unwrap();
            let input = Variable::new(flattened, false);

            let output = model.forward(&input);
            assert_eq!(output.data().shape()[0], batch_size);
            assert_eq!(output.data().shape()[1], 10);

            processed_batches += 1;
        }

        assert_eq!(processed_batches, 1);
    }
}