axonml_optim/
lib.rs

//! axonml-optim - Optimization Algorithms
//!
//! Provides optimizers for training neural networks, including:
//! - `SGD` with momentum and Nesterov acceleration
//! - `Adam` and `AdamW`
//! - `RMSprop`
//! - Learning rate schedulers
//!
//! # Example
//!
//! ```ignore
//! use axonml_optim::prelude::*;
//! use axonml_nn::{Linear, Module, Sequential};
//!
//! // Create model
//! let model = Sequential::new()
//!     .add(Linear::new(784, 128))
//!     .add(Linear::new(128, 10));
//!
//! // Create optimizer
//! let mut optimizer = Adam::new(model.parameters(), 0.001);
//!
//! // Training loop
//! for epoch in 0..100 {
//!     let output = model.forward(&input);
//!     let loss = compute_loss(&output, &target);
//!
//!     optimizer.zero_grad();
//!     loss.backward();
//!     optimizer.step();
//! }
//! ```
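//!
//! Learning rate schedulers adjust an optimizer's learning rate over the course
//! of training. The sketch below pairs `SGD` with `StepLR`; the constructor
//! arguments (optimizer reference, step size, decay factor) follow the usage in
//! this crate's tests.
//!
//! ```ignore
//! use axonml_optim::prelude::*;
//!
//! let mut optimizer = SGD::new(model.parameters(), 0.1);
//! // Decay the learning rate by a factor of 0.1 every 10 scheduler steps.
//! let mut scheduler = StepLR::new(&optimizer, 10, 0.1);
//!
//! for epoch in 0..100 {
//!     // ... forward pass, loss.backward(), optimizer.step() ...
//!     scheduler.step(&mut optimizer);
//! }
//! ```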
//!
//! @version 0.1.0
//! @author `AutomataNexus` Development Team

#![warn(missing_docs)]
#![warn(clippy::all)]
#![warn(clippy::pedantic)]
// ML/tensor-specific allowances
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_possible_wrap)]
#![allow(clippy::missing_errors_doc)]
#![allow(clippy::missing_panics_doc)]
#![allow(clippy::must_use_candidate)]
#![allow(clippy::module_name_repetitions)]
#![allow(clippy::similar_names)]
#![allow(clippy::many_single_char_names)]
#![allow(clippy::too_many_arguments)]
#![allow(clippy::doc_markdown)]
#![allow(clippy::cast_lossless)]
#![allow(clippy::needless_pass_by_value)]
#![allow(clippy::redundant_closure_for_method_calls)]
#![allow(clippy::uninlined_format_args)]
#![allow(clippy::ptr_arg)]
#![allow(clippy::return_self_not_must_use)]
#![allow(clippy::not_unsafe_ptr_arg_deref)]
#![allow(clippy::items_after_statements)]
#![allow(clippy::unreadable_literal)]
#![allow(clippy::if_same_then_else)]
#![allow(clippy::needless_range_loop)]
#![allow(clippy::trivially_copy_pass_by_ref)]
#![allow(clippy::unnecessary_wraps)]
#![allow(clippy::match_same_arms)]
#![allow(clippy::unused_self)]
#![allow(clippy::too_many_lines)]
#![allow(clippy::single_match_else)]
#![allow(clippy::fn_params_excessive_bools)]
#![allow(clippy::struct_excessive_bools)]
#![allow(clippy::format_push_string)]
#![allow(clippy::erasing_op)]
#![allow(clippy::type_repetition_in_bounds)]
#![allow(clippy::iter_without_into_iter)]
#![allow(clippy::should_implement_trait)]
#![allow(clippy::use_debug)]
#![allow(clippy::case_sensitive_file_extension_comparisons)]
#![allow(clippy::large_enum_variant)]
#![allow(clippy::panic)]
#![allow(clippy::struct_field_names)]
#![allow(clippy::missing_fields_in_debug)]
#![allow(clippy::upper_case_acronyms)]
#![allow(clippy::assigning_clones)]
#![allow(clippy::option_if_let_else)]
#![allow(clippy::manual_let_else)]
#![allow(clippy::explicit_iter_loop)]
#![allow(clippy::default_trait_access)]
#![allow(clippy::only_used_in_recursion)]
#![allow(clippy::manual_clamp)]
#![allow(clippy::ref_option)]
#![allow(clippy::multiple_bound_locations)]
#![allow(clippy::comparison_chain)]
#![allow(clippy::manual_assert)]
#![allow(clippy::unnecessary_debug_formatting)]

// =============================================================================
// Module Declarations
// =============================================================================

pub mod adam;
pub mod lr_scheduler;
pub mod optimizer;
pub mod rmsprop;
pub mod sgd;

// =============================================================================
// Re-exports
// =============================================================================

pub use adam::{Adam, AdamW};
pub use lr_scheduler::{
    CosineAnnealingLR, ExponentialLR, LRScheduler, MultiStepLR, OneCycleLR, ReduceLROnPlateau,
    StepLR, WarmupLR,
};
pub use optimizer::Optimizer;
pub use rmsprop::RMSprop;
pub use sgd::SGD;

// =============================================================================
// Prelude
// =============================================================================

/// Common imports for optimization.
pub mod prelude {
    pub use crate::{
        Adam, AdamW, CosineAnnealingLR, ExponentialLR, LRScheduler, MultiStepLR, OneCycleLR,
        Optimizer, RMSprop, ReduceLROnPlateau, StepLR, WarmupLR, SGD,
    };
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;
    use axonml_nn::{Linear, MSELoss, Module, ReLU, Sequential};
    use axonml_tensor::Tensor;

    #[test]
    fn test_sgd_optimization() {
        let model = Sequential::new()
            .add(Linear::new(2, 4))
            .add(ReLU)
            .add(Linear::new(4, 1));

        let mut optimizer = SGD::new(model.parameters(), 0.01);
        let loss_fn = MSELoss::new();

        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap(),
            false,
        );
        let target = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2, 1]).unwrap(), false);

        let initial_loss = loss_fn.compute(&model.forward(&input), &target);
        let initial_loss_val = initial_loss.data().to_vec()[0];

        // Run a few optimization steps
        for _ in 0..10 {
            optimizer.zero_grad();
            let output = model.forward(&input);
            let loss = loss_fn.compute(&output, &target);
            loss.backward();
            optimizer.step();
        }

        let final_loss = loss_fn.compute(&model.forward(&input), &target);
        let final_loss_val = final_loss.data().to_vec()[0];

        // Loss should not increase after the optimization steps
        assert!(final_loss_val <= initial_loss_val);
    }

    #[test]
    fn test_adam_optimization() {
        let model = Sequential::new()
            .add(Linear::new(2, 4))
            .add(ReLU)
            .add(Linear::new(4, 1));

        let mut optimizer = Adam::new(model.parameters(), 0.01);
        let loss_fn = MSELoss::new();

        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap(),
            false,
        );
        let target = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2, 1]).unwrap(), false);

        // Run optimization
        for _ in 0..20 {
            optimizer.zero_grad();
            let output = model.forward(&input);
            let loss = loss_fn.compute(&output, &target);
            loss.backward();
            optimizer.step();
        }

        // Just verify it runs without error
        let final_output = model.forward(&input);
        assert_eq!(final_output.shape(), vec![2, 1]);
    }
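
    // `RMSprop` is re-exported above but not exercised by these tests. The
    // smoke test below is a sketch that mirrors `test_adam_optimization`; it
    // assumes `RMSprop::new(parameters, learning_rate)` has the same
    // constructor shape as `SGD::new` and `Adam::new`.
    #[test]
    fn test_rmsprop_optimization() {
        let model = Sequential::new()
            .add(Linear::new(2, 4))
            .add(ReLU)
            .add(Linear::new(4, 1));

        // Assumed constructor: parameters plus learning rate, as for SGD/Adam.
        let mut optimizer = RMSprop::new(model.parameters(), 0.01);
        let loss_fn = MSELoss::new();

        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap(),
            false,
        );
        let target = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2, 1]).unwrap(), false);

        for _ in 0..20 {
            optimizer.zero_grad();
            let output = model.forward(&input);
            let loss = loss_fn.compute(&output, &target);
            loss.backward();
            optimizer.step();
        }

        // As with the Adam test, just verify the optimizer runs and the output
        // shape is preserved.
        let final_output = model.forward(&input);
        assert_eq!(final_output.shape(), vec![2, 1]);
    }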

    #[test]
    fn test_lr_scheduler() {
        let model = Linear::new(10, 5);
        let mut optimizer = SGD::new(model.parameters(), 0.1);
        let mut scheduler = StepLR::new(&optimizer, 10, 0.1);

        assert!((optimizer.get_lr() - 0.1).abs() < 1e-6);

        for _ in 0..10 {
            scheduler.step(&mut optimizer);
        }

        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);
    }
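
    // A sketch of a training loop that combines an optimizer with a scheduler,
    // using only APIs exercised elsewhere in these tests (`SGD`, `StepLR`,
    // `MSELoss`). Stepping the scheduler each iteration should reduce the
    // learning rate from its starting value.
    #[test]
    fn test_sgd_with_step_lr_schedule() {
        let model = Sequential::new()
            .add(Linear::new(2, 4))
            .add(Linear::new(4, 1));

        let mut optimizer = SGD::new(model.parameters(), 0.1);
        let mut scheduler = StepLR::new(&optimizer, 5, 0.5);
        let loss_fn = MSELoss::new();

        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap(),
            false,
        );
        let target = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2, 1]).unwrap(), false);

        let initial_lr = optimizer.get_lr();

        for _ in 0..10 {
            optimizer.zero_grad();
            let output = model.forward(&input);
            let loss = loss_fn.compute(&output, &target);
            loss.backward();
            optimizer.step();
            scheduler.step(&mut optimizer);
        }

        // With a step size of 5 and decay factor of 0.5, at least one decay
        // applies within 10 scheduler steps, so the learning rate must drop.
        assert!(optimizer.get_lr() < initial_lr);
    }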
}