//! Optimization algorithms for AxonML neural network training.
//!
//! `Optimizer` trait with `step`, `zero_grad`, `get_lr`, `set_lr`.
//! Implementations: `SGD` (momentum, Nesterov), `Adam`, `AdamW` (decoupled
//! weight decay), `RMSprop`, `LAMB` (layer-wise adaptive moments for
//! large-batch training). `GradScaler` for AMP loss scaling. Seven LR
//! schedulers (StepLR, MultiStepLR, ExponentialLR, CosineAnnealingLR,
//! OneCycleLR, WarmupLR, ReduceLROnPlateau). Training Health Monitor
//! (`health` module) for real-time NaN/gradient-explosion/vanishing detection,
//! loss trend analysis, dead neuron tracking, convergence scoring, and
//! automatic learning rate suggestions.
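//!
//! # Examples
//!
//! A minimal training-loop sketch, mirroring the unit tests at the bottom of
//! this file. It is marked `ignore` because it pulls `Linear`, `MSELoss`,
//! `Variable`, and `Tensor` from the sibling `axonml-nn`, `axonml-autograd`,
//! and `axonml-tensor` crates:
//!
//! ```ignore
//! use axonml_autograd::Variable;
//! use axonml_nn::{Linear, MSELoss, Module, ReLU, Sequential};
//! use axonml_optim::prelude::*;
//! use axonml_tensor::Tensor;
//!
//! // Two-layer MLP and an SGD optimizer over its parameters.
//! let model = Sequential::new()
//!     .add(Linear::new(2, 4))
//!     .add(ReLU)
//!     .add(Linear::new(4, 1));
//! let mut optimizer = SGD::new(model.parameters(), 0.01);
//! // Decay the learning rate by 10x every 10 scheduler steps.
//! let mut scheduler = StepLR::new(&optimizer, 10, 0.1);
//! let loss_fn = MSELoss::new();
//!
//! let input = Variable::new(
//!     Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap(),
//!     false,
//! );
//! let target = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2, 1]).unwrap(), false);
//!
//! for _ in 0..10 {
//!     optimizer.zero_grad();
//!     let loss = loss_fn.compute(&model.forward(&input), &target);
//!     loss.backward();
//!     optimizer.step();
//!     scheduler.step(&mut optimizer);
//! }
//! ```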
//!
//! # File
//! `crates/axonml-optim/src/lib.rs`
//!
//! # Author
//! Andrew Jewell Sr. — AutomataNexus LLC
//! ORCID: 0009-0005-2158-7060
//!
//! # Updated
//! April 14, 2026 11:15 PM EST
//!
//! # Disclaimer
//! Use at your own risk. This software is provided "as is", without warranty of
//! any kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

#![warn(missing_docs)]
#![warn(clippy::all)]
#![warn(clippy::pedantic)]
// ML/tensor-specific allowances
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_possible_wrap)]
#![allow(clippy::missing_errors_doc)]
#![allow(clippy::missing_panics_doc)]
#![allow(clippy::must_use_candidate)]
#![allow(clippy::module_name_repetitions)]
#![allow(clippy::similar_names)]
#![allow(clippy::many_single_char_names)]
#![allow(clippy::too_many_arguments)]
#![allow(clippy::doc_markdown)]
#![allow(clippy::cast_lossless)]
#![allow(clippy::needless_pass_by_value)]
#![allow(clippy::redundant_closure_for_method_calls)]
#![allow(clippy::uninlined_format_args)]
#![allow(clippy::ptr_arg)]
#![allow(clippy::return_self_not_must_use)]
#![allow(clippy::not_unsafe_ptr_arg_deref)]
#![allow(clippy::items_after_statements)]
#![allow(clippy::unreadable_literal)]
#![allow(clippy::if_same_then_else)]
#![allow(clippy::needless_range_loop)]
#![allow(clippy::trivially_copy_pass_by_ref)]
#![allow(clippy::unnecessary_wraps)]
#![allow(clippy::match_same_arms)]
#![allow(clippy::unused_self)]
#![allow(clippy::too_many_lines)]
#![allow(clippy::single_match_else)]
#![allow(clippy::fn_params_excessive_bools)]
#![allow(clippy::struct_excessive_bools)]
#![allow(clippy::format_push_string)]
#![allow(clippy::erasing_op)]
#![allow(clippy::type_repetition_in_bounds)]
#![allow(clippy::iter_without_into_iter)]
#![allow(clippy::should_implement_trait)]
#![allow(clippy::use_debug)]
#![allow(clippy::case_sensitive_file_extension_comparisons)]
#![allow(clippy::large_enum_variant)]
#![allow(clippy::panic)]
#![allow(clippy::struct_field_names)]
#![allow(clippy::missing_fields_in_debug)]
#![allow(clippy::upper_case_acronyms)]
#![allow(clippy::assigning_clones)]
#![allow(clippy::option_if_let_else)]
#![allow(clippy::manual_let_else)]
#![allow(clippy::explicit_iter_loop)]
#![allow(clippy::default_trait_access)]
#![allow(clippy::only_used_in_recursion)]
#![allow(clippy::manual_clamp)]
#![allow(clippy::ref_option)]
#![allow(clippy::multiple_bound_locations)]
#![allow(clippy::comparison_chain)]
#![allow(clippy::manual_assert)]
#![allow(clippy::unnecessary_debug_formatting)]

// =============================================================================
// Module Declarations
// =============================================================================

pub mod adam;
pub mod grad_scaler;
pub mod health;
pub mod lamb;
pub mod lr_scheduler;
pub mod optimizer;
pub mod rmsprop;
pub mod sgd;

// =============================================================================
// Re-exports
// =============================================================================

pub use adam::{Adam, AdamW};
pub use grad_scaler::{GradScaler, GradScalerState};
pub use health::{
    AlertKind, AlertSeverity, HealthReport, LossTrend, MonitorConfig, TrainingAlert,
    TrainingMonitor,
};
pub use lamb::LAMB;
pub use lr_scheduler::{
    CosineAnnealingLR, ExponentialLR, LRScheduler, MultiStepLR, OneCycleLR, ReduceLROnPlateau,
    StepLR, WarmupLR,
};
pub use optimizer::Optimizer;
pub use rmsprop::RMSprop;
pub use sgd::SGD;

// =============================================================================
// Prelude
// =============================================================================

/// Common imports for optimization.
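///
/// # Example
///
/// One import brings the optimizers, schedulers, and `GradScaler` into scope
/// (assuming the crate is depended on under its package name `axonml-optim`):
///
/// ```ignore
/// use axonml_optim::prelude::*;
/// ```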
pub mod prelude {
    pub use crate::{
        Adam, AdamW, CosineAnnealingLR, ExponentialLR, GradScaler, LAMB, LRScheduler, MultiStepLR,
        OneCycleLR, Optimizer, RMSprop, ReduceLROnPlateau, SGD, StepLR, WarmupLR,
    };
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;
    use axonml_nn::{Linear, MSELoss, Module, ReLU, Sequential};
    use axonml_tensor::Tensor;

    #[test]
    fn test_sgd_optimization() {
        let model = Sequential::new()
            .add(Linear::new(2, 4))
            .add(ReLU)
            .add(Linear::new(4, 1));

        let mut optimizer = SGD::new(model.parameters(), 0.01);
        let loss_fn = MSELoss::new();

        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap(),
            false,
        );
        let target = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2, 1]).unwrap(), false);

        let initial_loss = loss_fn.compute(&model.forward(&input), &target);
        let initial_loss_val = initial_loss.data().to_vec()[0];

        // Run a few optimization steps
        for _ in 0..10 {
            optimizer.zero_grad();
            let output = model.forward(&input);
            let loss = loss_fn.compute(&output, &target);
            loss.backward();
            optimizer.step();
        }

        let final_loss = loss_fn.compute(&model.forward(&input), &target);
        let final_loss_val = final_loss.data().to_vec()[0];

        // Loss should not increase after optimization
        assert!(final_loss_val <= initial_loss_val);
    }

    #[test]
    fn test_adam_optimization() {
        let model = Sequential::new()
            .add(Linear::new(2, 4))
            .add(ReLU)
            .add(Linear::new(4, 1));

        let mut optimizer = Adam::new(model.parameters(), 0.01);
        let loss_fn = MSELoss::new();

        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap(),
            false,
        );
        let target = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2, 1]).unwrap(), false);

        // Run optimization
        for _ in 0..20 {
            optimizer.zero_grad();
            let output = model.forward(&input);
            let loss = loss_fn.compute(&output, &target);
            loss.backward();
            optimizer.step();
        }

        // Verify the loop runs without error and the output shape is preserved
        let final_output = model.forward(&input);
        assert_eq!(final_output.shape(), vec![2, 1]);
    }

    #[test]
    fn test_lr_scheduler() {
        let model = Linear::new(10, 5);
        let mut optimizer = SGD::new(model.parameters(), 0.1);
        let mut scheduler = StepLR::new(&optimizer, 10, 0.1);

        assert!((optimizer.get_lr() - 0.1).abs() < 1e-6);

        for _ in 0..10 {
            scheduler.step(&mut optimizer);
        }

        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);
    }
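
    // A minimal sketch exercising the documented `set_lr` accessor; it assumes
    // `set_lr` accepts the same float type that `get_lr` returns.
    #[test]
    fn test_set_lr_roundtrip() {
        let model = Linear::new(10, 5);
        let mut optimizer = SGD::new(model.parameters(), 0.1);

        optimizer.set_lr(0.05);
        assert!((optimizer.get_lr() - 0.05).abs() < 1e-6);
    }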
}