// axonml_optim — crate root (lib.rs)

//! axonml-optim - Optimization Algorithms
//!
//! Provides optimizers for training neural networks with comprehensive support
//! for modern training techniques.
//!
//! # Optimizers
//!
//! - **SGD** - Stochastic Gradient Descent with momentum and Nesterov acceleration
//! - **Adam** - Adaptive Moment Estimation
//! - **AdamW** - Adam with decoupled weight decay
//! - **RMSprop** - Root Mean Square Propagation
//! - **LAMB** - Layer-wise Adaptive Moments for large batch training (BERT-scale)
//!
//! # Learning Rate Schedulers
//!
//! - **StepLR** - Step decay at fixed intervals
//! - **MultiStepLR** - Decay at specified milestones
//! - **ExponentialLR** - Exponential decay
//! - **CosineAnnealingLR** - Cosine annealing
//! - **OneCycleLR** - 1cycle policy (super-convergence)
//! - **WarmupLR** - Linear warmup
//! - **ReduceLROnPlateau** - Reduce on metric plateau
//!
//! # Mixed Precision Support
//!
//! - **GradScaler** - Gradient scaling for F16 training to prevent underflow
//!
//! # Basic Example
//!
//! ```ignore
//! use axonml_optim::prelude::*;
//! use axonml_nn::{Linear, Module, Sequential};
//!
//! // Create model
//! let model = Sequential::new()
//!     .add(Linear::new(784, 128))
//!     .add(Linear::new(128, 10));
//!
//! // Create optimizer
//! let mut optimizer = Adam::new(model.parameters(), 0.001);
//!
//! // Training loop
//! for epoch in 0..100 {
//!     let output = model.forward(&input);
//!     let loss = compute_loss(&output, &target);
//!
//!     optimizer.zero_grad();
//!     loss.backward();
//!     optimizer.step();
//! }
//! ```
//!
//! # Mixed Precision Training with GradScaler
//!
//! ```ignore
//! use axonml_optim::{Adam, GradScaler};
//!
//! let mut optimizer = Adam::new(params, 0.001);
//! let mut scaler = GradScaler::new();
//!
//! for batch in dataloader {
//!     // Forward pass (with autocast in F16)
//!     let loss = model.forward(&batch);
//!
//!     // Scale loss for backward
//!     let scaled_loss = scaler.scale_loss(loss);
//!
//!     // Backward
//!     optimizer.zero_grad();
//!     scaled_loss.backward();
//!
//!     // Unscale gradients and check for inf/nan
//!     if scaler.unscale_grads(&mut grads) {
//!         optimizer.step();
//!     }
//!
//!     // Update scale factor
//!     scaler.update();
//! }
//! ```
//!
//! # LAMB for Large Batch Training
//!
//! ```ignore
//! use axonml_optim::LAMB;
//!
//! // LAMB enables training with very large batches (32K+)
//! let optimizer = LAMB::new(params, 0.001)
//!     .betas(0.9, 0.999)
//!     .weight_decay(0.01);
//! ```
//!
//! @version 0.2.6
//! @author `AutomataNexus` Development Team
#![warn(missing_docs)]
#![warn(clippy::all)]
#![warn(clippy::pedantic)]
// ML/tensor-specific allowances
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_possible_wrap)]
#![allow(clippy::missing_errors_doc)]
#![allow(clippy::missing_panics_doc)]
#![allow(clippy::must_use_candidate)]
#![allow(clippy::module_name_repetitions)]
#![allow(clippy::similar_names)]
#![allow(clippy::many_single_char_names)]
#![allow(clippy::too_many_arguments)]
#![allow(clippy::doc_markdown)]
#![allow(clippy::cast_lossless)]
#![allow(clippy::needless_pass_by_value)]
#![allow(clippy::redundant_closure_for_method_calls)]
#![allow(clippy::uninlined_format_args)]
#![allow(clippy::ptr_arg)]
#![allow(clippy::return_self_not_must_use)]
#![allow(clippy::not_unsafe_ptr_arg_deref)]
#![allow(clippy::items_after_statements)]
#![allow(clippy::unreadable_literal)]
#![allow(clippy::if_same_then_else)]
#![allow(clippy::needless_range_loop)]
#![allow(clippy::trivially_copy_pass_by_ref)]
#![allow(clippy::unnecessary_wraps)]
#![allow(clippy::match_same_arms)]
#![allow(clippy::unused_self)]
#![allow(clippy::too_many_lines)]
#![allow(clippy::single_match_else)]
#![allow(clippy::fn_params_excessive_bools)]
#![allow(clippy::struct_excessive_bools)]
#![allow(clippy::format_push_string)]
#![allow(clippy::erasing_op)]
#![allow(clippy::type_repetition_in_bounds)]
#![allow(clippy::iter_without_into_iter)]
#![allow(clippy::should_implement_trait)]
#![allow(clippy::use_debug)]
#![allow(clippy::case_sensitive_file_extension_comparisons)]
#![allow(clippy::large_enum_variant)]
#![allow(clippy::panic)]
#![allow(clippy::struct_field_names)]
#![allow(clippy::missing_fields_in_debug)]
#![allow(clippy::upper_case_acronyms)]
#![allow(clippy::assigning_clones)]
#![allow(clippy::option_if_let_else)]
#![allow(clippy::manual_let_else)]
#![allow(clippy::explicit_iter_loop)]
#![allow(clippy::default_trait_access)]
#![allow(clippy::only_used_in_recursion)]
#![allow(clippy::manual_clamp)]
#![allow(clippy::ref_option)]
#![allow(clippy::multiple_bound_locations)]
#![allow(clippy::comparison_chain)]
#![allow(clippy::manual_assert)]
#![allow(clippy::unnecessary_debug_formatting)]
// =============================================================================
// Module Declarations
// =============================================================================

pub mod adam;
pub mod grad_scaler;
pub mod health;
pub mod lamb;
pub mod lr_scheduler;
pub mod optimizer;
pub mod rmsprop;
pub mod sgd;

// =============================================================================
// Re-exports
// =============================================================================

pub use adam::{Adam, AdamW};
pub use grad_scaler::{GradScaler, GradScalerState};
pub use health::{
    AlertKind, AlertSeverity, HealthReport, LossTrend, MonitorConfig, TrainingAlert,
    TrainingMonitor,
};
pub use lamb::LAMB;
pub use lr_scheduler::{
    CosineAnnealingLR, ExponentialLR, LRScheduler, MultiStepLR, OneCycleLR, ReduceLROnPlateau,
    StepLR, WarmupLR,
};
pub use optimizer::Optimizer;
pub use rmsprop::RMSprop;
pub use sgd::SGD;
// =============================================================================
// Prelude
// =============================================================================

192/// Common imports for optimization.
193pub mod prelude {
194    pub use crate::{
195        Adam, AdamW, CosineAnnealingLR, ExponentialLR, GradScaler, LRScheduler, MultiStepLR,
196        OneCycleLR, Optimizer, RMSprop, ReduceLROnPlateau, StepLR, WarmupLR, LAMB, SGD,
197    };
198}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;
    use axonml_nn::{Linear, MSELoss, Module, ReLU, Sequential};
    use axonml_tensor::Tensor;

    /// A handful of SGD steps on a tiny MLP must not increase the training
    /// loss on the batch being fitted.
    #[test]
    fn test_sgd_optimization() {
        let model = Sequential::new()
            .add(Linear::new(2, 4))
            .add(ReLU)
            .add(Linear::new(4, 1));

        let mut optimizer = SGD::new(model.parameters(), 0.01);
        let criterion = MSELoss::new();

        // Two samples with two features each; regression targets shaped [2, 1].
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap(),
            false,
        );
        let target = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2, 1]).unwrap(), false);

        let before = criterion.compute(&model.forward(&input), &target);
        let loss_before = before.data().to_vec()[0];

        // Ten gradient-descent updates.
        for _step in 0..10 {
            optimizer.zero_grad();
            let loss = criterion.compute(&model.forward(&input), &target);
            loss.backward();
            optimizer.step();
        }

        let after = criterion.compute(&model.forward(&input), &target);
        let loss_after = after.data().to_vec()[0];

        // Optimization should not make things worse on the training batch.
        assert!(loss_after <= loss_before);
    }

    /// Adam should run end-to-end on a small model without panicking, and the
    /// model's forward pass must keep its output shape afterwards.
    #[test]
    fn test_adam_optimization() {
        let model = Sequential::new()
            .add(Linear::new(2, 4))
            .add(ReLU)
            .add(Linear::new(4, 1));

        let mut optimizer = Adam::new(model.parameters(), 0.01);
        let criterion = MSELoss::new();

        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap(),
            false,
        );
        let target = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2, 1]).unwrap(), false);

        // Twenty adaptive-moment updates; this is a stability smoke test.
        for _step in 0..20 {
            optimizer.zero_grad();
            let loss = criterion.compute(&model.forward(&input), &target);
            loss.backward();
            optimizer.step();
        }

        // The prediction for two samples is still shaped [2, 1].
        let final_output = model.forward(&input);
        assert_eq!(final_output.shape(), vec![2, 1]);
    }

    /// StepLR with step size 10 and gamma 0.1 must cut the learning rate by
    /// exactly one decade after ten scheduler steps.
    #[test]
    fn test_lr_scheduler() {
        let model = Linear::new(10, 5);
        let mut optimizer = SGD::new(model.parameters(), 0.1);
        let mut scheduler = StepLR::new(&optimizer, 10, 0.1);

        // The initial LR is the one SGD was constructed with.
        assert!((optimizer.get_lr() - 0.1).abs() < 1e-6);

        (0..10).for_each(|_| scheduler.step(&mut optimizer));

        // After one full period the LR has decayed by the gamma factor.
        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);
    }
}