//! axonml-optim - Optimization Algorithms
//!
//! # File
//! `crates/axonml-optim/src/lib.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

// Crate-wide lint policy: enforce docs and all/pedantic clippy, then opt out
// of the pedantic lints that conflict with ML/tensor code (pervasive numeric
// casts, many-parameter kernels, single-letter math names, etc.).
#![warn(missing_docs)]
#![warn(clippy::all)]
#![warn(clippy::pedantic)]
// ML/tensor-specific allowances
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_possible_wrap)]
#![allow(clippy::missing_errors_doc)]
#![allow(clippy::missing_panics_doc)]
#![allow(clippy::must_use_candidate)]
#![allow(clippy::module_name_repetitions)]
#![allow(clippy::similar_names)]
#![allow(clippy::many_single_char_names)]
#![allow(clippy::too_many_arguments)]
#![allow(clippy::doc_markdown)]
#![allow(clippy::cast_lossless)]
#![allow(clippy::needless_pass_by_value)]
#![allow(clippy::redundant_closure_for_method_calls)]
#![allow(clippy::uninlined_format_args)]
#![allow(clippy::ptr_arg)]
#![allow(clippy::return_self_not_must_use)]
#![allow(clippy::not_unsafe_ptr_arg_deref)]
#![allow(clippy::items_after_statements)]
#![allow(clippy::unreadable_literal)]
#![allow(clippy::if_same_then_else)]
#![allow(clippy::needless_range_loop)]
#![allow(clippy::trivially_copy_pass_by_ref)]
#![allow(clippy::unnecessary_wraps)]
#![allow(clippy::match_same_arms)]
#![allow(clippy::unused_self)]
#![allow(clippy::too_many_lines)]
#![allow(clippy::single_match_else)]
#![allow(clippy::fn_params_excessive_bools)]
#![allow(clippy::struct_excessive_bools)]
#![allow(clippy::format_push_string)]
#![allow(clippy::erasing_op)]
#![allow(clippy::type_repetition_in_bounds)]
#![allow(clippy::iter_without_into_iter)]
#![allow(clippy::should_implement_trait)]
#![allow(clippy::use_debug)]
#![allow(clippy::case_sensitive_file_extension_comparisons)]
#![allow(clippy::large_enum_variant)]
#![allow(clippy::panic)]
#![allow(clippy::struct_field_names)]
#![allow(clippy::missing_fields_in_debug)]
#![allow(clippy::upper_case_acronyms)]
#![allow(clippy::assigning_clones)]
#![allow(clippy::option_if_let_else)]
#![allow(clippy::manual_let_else)]
#![allow(clippy::explicit_iter_loop)]
#![allow(clippy::default_trait_access)]
#![allow(clippy::only_used_in_recursion)]
#![allow(clippy::manual_clamp)]
#![allow(clippy::ref_option)]
#![allow(clippy::multiple_bound_locations)]
#![allow(clippy::comparison_chain)]
#![allow(clippy::manual_assert)]
#![allow(clippy::unnecessary_debug_formatting)]

77// =============================================================================
78// Module Declarations
79// =============================================================================
80
81pub mod adam;
82pub mod grad_scaler;
83pub mod health;
84pub mod lamb;
85pub mod lr_scheduler;
86pub mod optimizer;
87pub mod rmsprop;
88pub mod sgd;
89
90// =============================================================================
91// Re-exports
92// =============================================================================
93
94pub use adam::{Adam, AdamW};
95pub use grad_scaler::{GradScaler, GradScalerState};
96pub use health::{
97    AlertKind, AlertSeverity, HealthReport, LossTrend, MonitorConfig, TrainingAlert,
98    TrainingMonitor,
99};
100pub use lamb::LAMB;
101pub use lr_scheduler::{
102    CosineAnnealingLR, ExponentialLR, LRScheduler, MultiStepLR, OneCycleLR, ReduceLROnPlateau,
103    StepLR, WarmupLR,
104};
105pub use optimizer::Optimizer;
106pub use rmsprop::RMSprop;
107pub use sgd::SGD;
108
109// =============================================================================
110// Prelude
111// =============================================================================
112
113/// Common imports for optimization.
114pub mod prelude {
115    pub use crate::{
116        Adam, AdamW, CosineAnnealingLR, ExponentialLR, GradScaler, LAMB, LRScheduler, MultiStepLR,
117        OneCycleLR, Optimizer, RMSprop, ReduceLROnPlateau, SGD, StepLR, WarmupLR,
118    };
119}
120
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;
    use axonml_nn::{Linear, MSELoss, Module, ReLU, Sequential};
    use axonml_tensor::Tensor;

    /// SGD on a tiny 2-layer MLP should not increase the MSE loss after a
    /// handful of steps on a fixed batch.
    #[test]
    fn test_sgd_optimization() {
        let model = Sequential::new()
            .add(Linear::new(2, 4))
            .add(ReLU)
            .add(Linear::new(4, 1));

        let mut optimizer = SGD::new(model.parameters(), 0.01);
        let loss_fn = MSELoss::new();

        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap(),
            false,
        );
        let target = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2, 1]).unwrap(), false);

        let initial_loss = loss_fn.compute(&model.forward(&input), &target);
        let initial_loss_val = initial_loss.data().to_vec()[0];

        // Run a few optimization steps
        for _ in 0..10 {
            optimizer.zero_grad();
            let output = model.forward(&input);
            let loss = loss_fn.compute(&output, &target);
            loss.backward();
            optimizer.step();
        }

        let final_loss = loss_fn.compute(&model.forward(&input), &target);
        let final_loss_val = final_loss.data().to_vec()[0];

        // Loss should decrease (or at worst stay flat, e.g. dead ReLU units)
        assert!(final_loss_val <= initial_loss_val);
    }

    /// Adam smoke test: 20 steps on the same tiny MLP must run without
    /// panicking and produce an output of the expected shape.
    #[test]
    fn test_adam_optimization() {
        let model = Sequential::new()
            .add(Linear::new(2, 4))
            .add(ReLU)
            .add(Linear::new(4, 1));

        let mut optimizer = Adam::new(model.parameters(), 0.01);
        let loss_fn = MSELoss::new();

        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap(),
            false,
        );
        let target = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2, 1]).unwrap(), false);

        // Run optimization
        for _ in 0..20 {
            optimizer.zero_grad();
            let output = model.forward(&input);
            let loss = loss_fn.compute(&output, &target);
            loss.backward();
            optimizer.step();
        }

        // Just verify it runs without error
        let final_output = model.forward(&input);
        assert_eq!(final_output.shape(), vec![2, 1]);
    }

    /// StepLR with step_size=10, gamma=0.1 should decay lr 0.1 -> 0.01
    /// after exactly 10 scheduler steps.
    #[test]
    fn test_lr_scheduler() {
        let model = Linear::new(10, 5);
        let mut optimizer = SGD::new(model.parameters(), 0.1);
        let mut scheduler = StepLR::new(&optimizer, 10, 0.1);

        assert!((optimizer.get_lr() - 0.1).abs() < 1e-6);

        for _ in 0..10 {
            scheduler.step(&mut optimizer);
        }

        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);
    }
}