//! Optimizer configuration and a thin wrapper around `candle_nn::AdamW`.

use candle_core::Tensor;
use candle_nn::{Optimizer, ParamsAdamW, VarMap};

use crate::error::{AxolotlError, Result};
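
/// Hyperparameters for building an AdamW optimizer.
///
/// A sketch of overriding a single field while keeping the other defaults,
/// using struct-update syntax (the `1e-4` value is illustrative only):
///
/// ```ignore
/// let config = OptimizerConfig { learning_rate: 1e-4, ..Default::default() };
/// ```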
#[derive(Debug, Clone)]
pub struct OptimizerConfig {
    /// Initial learning rate.
    pub learning_rate: f64,
    /// Exponential decay rate for the first-moment estimates.
    pub beta1: f64,
    /// Exponential decay rate for the second-moment estimates.
    pub beta2: f64,
    /// Decoupled weight decay coefficient.
    pub weight_decay: f64,
    /// Small constant added to the denominator for numerical stability.
    pub eps: f64,
}

impl Default for OptimizerConfig {
    fn default() -> Self {
        Self {
            learning_rate: 5e-5,
            beta1: 0.9,
            beta2: 0.999,
            weight_decay: 0.01,
            eps: 1e-8,
        }
    }
}

impl OptimizerConfig {
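    /// Builds an AdamW optimizer over every variable currently registered in
    /// the given `VarMap`.
    ///
    /// A minimal usage sketch; in practice the `VarMap` would already hold
    /// the model's trainable parameters, here it is empty for brevity:
    ///
    /// ```ignore
    /// let config = OptimizerConfig::default();
    /// let varmap = VarMap::new();
    /// let mut optimizer = config.build_adamw(&varmap)?;
    /// optimizer.set_learning_rate(1e-4);
    /// ```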
    pub fn build_adamw(&self, varmap: &VarMap) -> Result<AdamWOptimizer> {
        let vars = varmap.all_vars();
        let params = ParamsAdamW {
            lr: self.learning_rate,
            beta1: self.beta1,
            beta2: self.beta2,
            eps: self.eps,
            weight_decay: self.weight_decay,
        };

        let opt = candle_nn::AdamW::new(vars, params)
            .map_err(|e| AxolotlError::Training(format!("Failed to create AdamW: {}", e)))?;

        Ok(AdamWOptimizer { inner: opt })
    }
}

/// Thin wrapper around `candle_nn::AdamW` that converts candle errors into
/// this crate's `AxolotlError`.
pub struct AdamWOptimizer {
    inner: candle_nn::AdamW,
}

impl AdamWOptimizer {
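    /// Runs one optimization step: backpropagates through `loss` and updates
    /// every tracked variable with AdamW.
    ///
    /// A sketch of how this fits into a training loop; `model`, `batches`,
    /// and the cross-entropy loss are placeholders assumed to exist elsewhere:
    ///
    /// ```ignore
    /// for batch in batches {
    ///     let logits = model.forward(&batch.inputs)?;
    ///     let loss = candle_nn::loss::cross_entropy(&logits, &batch.targets)?;
    ///     optimizer.step(&loss)?;
    /// }
    /// ```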
    pub fn step(&mut self, loss: &Tensor) -> Result<()> {
        self.inner
            .backward_step(loss)
            .map_err(|e| AxolotlError::Training(format!("Optimizer step failed: {}", e)))
    }

    /// Returns the current learning rate.
    pub fn learning_rate(&self) -> f64 {
        self.inner.learning_rate()
    }
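
    /// Overrides the learning rate; typically called once per step by an
    /// external LR scheduler.
    ///
    /// A sketch of a linear warmup schedule driving this setter (the
    /// `warmup_steps` and `base_lr` values are illustrative, not taken from
    /// this crate):
    ///
    /// ```ignore
    /// let warmup_steps = 100usize;
    /// let base_lr = 5e-5;
    /// for step in 0..warmup_steps {
    ///     let lr = base_lr * (step + 1) as f64 / warmup_steps as f64;
    ///     optimizer.set_learning_rate(lr);
    ///     // ... forward pass, loss, optimizer.step(&loss)? ...
    /// }
    /// ```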
    pub fn set_learning_rate(&mut self, lr: f64) {
        self.inner.set_learning_rate(lr);
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_optimizer_config_default() {
        let config = OptimizerConfig::default();
        assert_eq!(config.learning_rate, 5e-5);
        assert_eq!(config.beta1, 0.9);
        assert_eq!(config.beta2, 0.999);
        assert_eq!(config.weight_decay, 0.01);
    }

    #[test]
    fn test_build_adamw() -> Result<()> {
        let config = OptimizerConfig::default();
        let varmap = VarMap::new();

        let optimizer = config.build_adamw(&varmap)?;
        assert_eq!(optimizer.learning_rate(), 5e-5);

        Ok(())
    }
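
    // A minimal end-to-end sketch of `step`, assuming candle's CPU device and
    // `VarBuilder::from_varmap`; the loss is just the sum of squared weights,
    // not a real training objective.
    #[test]
    fn test_optimizer_step() -> Result<()> {
        use candle_core::{DType, Device};
        use candle_nn::VarBuilder;

        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
        // Register a small trainable weight so AdamW has something to update.
        let w = vb.get_with_hints((2, 2), "w", candle_nn::init::ONE).unwrap();

        let config = OptimizerConfig::default();
        let mut optimizer = config.build_adamw(&varmap)?;

        // Scalar loss: sum of squared weights.
        let loss = (&w * &w).unwrap().sum_all().unwrap();
        optimizer.step(&loss)?;

        Ok(())
    }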
}