#![warn(missing_docs)]
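//! Optimizers, learning-rate schedulers, and gradient scaling utilities.
//!
//! The sketch below is a minimal training loop using the types re-exported
//! from this crate. It mirrors the usage exercised by the tests at the bottom
//! of this file and assumes the companion crates `axonml_autograd`,
//! `axonml_nn`, and `axonml_tensor` are available; it is illustrative rather
//! than a complete, tuned training setup.
//!
//! ```ignore
//! // `SGD` and `StepLR` are re-exported from this crate's root (see below).
//! use axonml_autograd::Variable;
//! use axonml_nn::{Linear, MSELoss, Module, ReLU, Sequential};
//! use axonml_tensor::Tensor;
//!
//! // A small two-layer MLP whose parameters the optimizer will update.
//! let model = Sequential::new()
//!     .add(Linear::new(2, 4))
//!     .add(ReLU)
//!     .add(Linear::new(4, 1));
//!
//! let mut optimizer = SGD::new(model.parameters(), 0.01);
//! // Decay the learning rate by a factor of 0.1 every 10 scheduler steps.
//! let mut scheduler = StepLR::new(&optimizer, 10, 0.1);
//! let loss_fn = MSELoss::new();
//!
//! let input = Variable::new(
//!     Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap(),
//!     false,
//! );
//! let target = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2, 1]).unwrap(), false);
//!
//! for _ in 0..100 {
//!     optimizer.zero_grad();                // clear old gradients
//!     let loss = loss_fn.compute(&model.forward(&input), &target);
//!     loss.backward();                      // accumulate new gradients
//!     optimizer.step();                     // apply the parameter update
//!     scheduler.step(&mut optimizer);       // advance the LR schedule
//! }
//! ```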
#![warn(clippy::all)]
#![warn(clippy::pedantic)]
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_possible_wrap)]
#![allow(clippy::missing_errors_doc)]
#![allow(clippy::missing_panics_doc)]
#![allow(clippy::must_use_candidate)]
#![allow(clippy::module_name_repetitions)]
#![allow(clippy::similar_names)]
#![allow(clippy::many_single_char_names)]
#![allow(clippy::too_many_arguments)]
#![allow(clippy::doc_markdown)]
#![allow(clippy::cast_lossless)]
#![allow(clippy::needless_pass_by_value)]
#![allow(clippy::redundant_closure_for_method_calls)]
#![allow(clippy::uninlined_format_args)]
#![allow(clippy::ptr_arg)]
#![allow(clippy::return_self_not_must_use)]
#![allow(clippy::not_unsafe_ptr_arg_deref)]
#![allow(clippy::items_after_statements)]
#![allow(clippy::unreadable_literal)]
#![allow(clippy::if_same_then_else)]
#![allow(clippy::needless_range_loop)]
#![allow(clippy::trivially_copy_pass_by_ref)]
#![allow(clippy::unnecessary_wraps)]
#![allow(clippy::match_same_arms)]
#![allow(clippy::unused_self)]
#![allow(clippy::too_many_lines)]
#![allow(clippy::single_match_else)]
#![allow(clippy::fn_params_excessive_bools)]
#![allow(clippy::struct_excessive_bools)]
#![allow(clippy::format_push_string)]
#![allow(clippy::erasing_op)]
#![allow(clippy::type_repetition_in_bounds)]
#![allow(clippy::iter_without_into_iter)]
#![allow(clippy::should_implement_trait)]
#![allow(clippy::use_debug)]
#![allow(clippy::case_sensitive_file_extension_comparisons)]
#![allow(clippy::large_enum_variant)]
#![allow(clippy::panic)]
#![allow(clippy::struct_field_names)]
#![allow(clippy::missing_fields_in_debug)]
#![allow(clippy::upper_case_acronyms)]
#![allow(clippy::assigning_clones)]
#![allow(clippy::option_if_let_else)]
#![allow(clippy::manual_let_else)]
#![allow(clippy::explicit_iter_loop)]
#![allow(clippy::default_trait_access)]
#![allow(clippy::only_used_in_recursion)]
#![allow(clippy::manual_clamp)]
#![allow(clippy::ref_option)]
#![allow(clippy::multiple_bound_locations)]
#![allow(clippy::comparison_chain)]
#![allow(clippy::manual_assert)]
#![allow(clippy::unnecessary_debug_formatting)]

/// Adam and AdamW optimizers.
pub mod adam;
/// Gradient scaling utilities (`GradScaler`, `GradScalerState`).
pub mod grad_scaler;
/// LAMB optimizer.
pub mod lamb;
/// Learning-rate schedulers (`StepLR`, `CosineAnnealingLR`, `OneCycleLR`, ...).
pub mod lr_scheduler;
/// The `Optimizer` trait.
pub mod optimizer;
/// RMSprop optimizer.
pub mod rmsprop;
/// Stochastic gradient descent (`SGD`) optimizer.
pub mod sgd;

pub use adam::{Adam, AdamW};
pub use grad_scaler::{GradScaler, GradScalerState};
pub use lamb::LAMB;
pub use lr_scheduler::{
    CosineAnnealingLR, ExponentialLR, LRScheduler, MultiStepLR, OneCycleLR, ReduceLROnPlateau,
    StepLR, WarmupLR,
};
pub use optimizer::Optimizer;
pub use rmsprop::RMSprop;
pub use sgd::SGD;

/// Convenience re-exports of the crate's most commonly used types.
pub mod prelude {
    pub use crate::{
        Adam, AdamW, CosineAnnealingLR, ExponentialLR, GradScaler, LRScheduler, LAMB, MultiStepLR,
        OneCycleLR, Optimizer, RMSprop, ReduceLROnPlateau, StepLR, WarmupLR, SGD,
    };
}

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;
    use axonml_nn::{Linear, MSELoss, Module, ReLU, Sequential};
    use axonml_tensor::Tensor;

    #[test]
    fn test_sgd_optimization() {
        let model = Sequential::new()
            .add(Linear::new(2, 4))
            .add(ReLU)
            .add(Linear::new(4, 1));

        let mut optimizer = SGD::new(model.parameters(), 0.01);
        let loss_fn = MSELoss::new();

        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap(),
            false,
        );
        let target = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2, 1]).unwrap(), false);

        let initial_loss = loss_fn.compute(&model.forward(&input), &target);
        let initial_loss_val = initial_loss.data().to_vec()[0];

        for _ in 0..10 {
            optimizer.zero_grad();
            let output = model.forward(&input);
            let loss = loss_fn.compute(&output, &target);
            loss.backward();
            optimizer.step();
        }

        let final_loss = loss_fn.compute(&model.forward(&input), &target);
        let final_loss_val = final_loss.data().to_vec()[0];

        assert!(final_loss_val <= initial_loss_val);
    }

    #[test]
    fn test_adam_optimization() {
        let model = Sequential::new()
            .add(Linear::new(2, 4))
            .add(ReLU)
            .add(Linear::new(4, 1));

        let mut optimizer = Adam::new(model.parameters(), 0.01);
        let loss_fn = MSELoss::new();

        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap(),
            false,
        );
        let target = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2, 1]).unwrap(), false);

        for _ in 0..20 {
            optimizer.zero_grad();
            let output = model.forward(&input);
            let loss = loss_fn.compute(&output, &target);
            loss.backward();
            optimizer.step();
        }

        let final_output = model.forward(&input);
        assert_eq!(final_output.shape(), vec![2, 1]);
    }

    #[test]
    fn test_lr_scheduler() {
        let model = Linear::new(10, 5);
        let mut optimizer = SGD::new(model.parameters(), 0.1);
        let mut scheduler = StepLR::new(&optimizer, 10, 0.1);

        assert!((optimizer.get_lr() - 0.1).abs() < 1e-6);

        // StepLR with step_size = 10 and gamma = 0.1: after 10 scheduler
        // steps the learning rate has decayed once, from 0.1 to 0.01.
        for _ in 0..10 {
            scheduler.step(&mut optimizer);
        }

        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);
    }
}