
axolotl_rs/optimizer.rs

//! Optimizer implementations (AdamW).

use candle_core::Tensor;
use candle_nn::{Optimizer, ParamsAdamW, VarMap};

use crate::error::{AxolotlError, Result};

/// Optimizer configuration.
#[derive(Debug, Clone)]
pub struct OptimizerConfig {
    /// Learning rate
    pub learning_rate: f64,
    /// Beta1 for Adam
    pub beta1: f64,
    /// Beta2 for Adam
    pub beta2: f64,
    /// Weight decay
    pub weight_decay: f64,
    /// Epsilon for numerical stability
    pub eps: f64,
}

impl Default for OptimizerConfig {
    fn default() -> Self {
        Self {
            learning_rate: 5e-5,
            beta1: 0.9,
            beta2: 0.999,
            weight_decay: 0.01,
            eps: 1e-8,
        }
    }
}

impl OptimizerConfig {
    /// Create an AdamW optimizer configured with these parameters.
    ///
    /// # Errors
    ///
    /// Returns an error if the optimizer cannot be created.
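    ///
    /// # Example
    ///
    /// A minimal sketch; it assumes this module is exposed as
    /// `axolotl_rs::optimizer` (adjust the path to the actual crate layout).
    ///
    /// ```no_run
    /// use axolotl_rs::optimizer::OptimizerConfig;
    /// use candle_nn::VarMap;
    ///
    /// let varmap = VarMap::new(); // normally populated during model construction
    /// let config = OptimizerConfig {
    ///     learning_rate: 1e-4,
    ///     ..Default::default()
    /// };
    /// let optimizer = config.build_adamw(&varmap).unwrap();
    /// assert_eq!(optimizer.learning_rate(), 1e-4);
    /// ```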
    pub fn build_adamw(&self, varmap: &VarMap) -> Result<AdamWOptimizer> {
        let vars = varmap.all_vars();
        let params = ParamsAdamW {
            lr: self.learning_rate,
            beta1: self.beta1,
            beta2: self.beta2,
            eps: self.eps,
            weight_decay: self.weight_decay,
        };

        let opt = candle_nn::AdamW::new(vars, params)
            .map_err(|e| AxolotlError::Training(format!("Failed to create AdamW: {}", e)))?;

        Ok(AdamWOptimizer { inner: opt })
    }
}

/// Wrapper around candle's AdamW optimizer.
pub struct AdamWOptimizer {
    inner: candle_nn::AdamW,
}

impl AdamWOptimizer {
    /// Perform a single optimization step: backpropagate from `loss` and
    /// update all tracked variables.
    ///
    /// # Errors
    ///
    /// Returns an error if the backward pass or the parameter update fails.
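    ///
    /// # Example
    ///
    /// A sketch of one training iteration; `model.forward_loss` is a
    /// hypothetical method standing in for a real forward pass that
    /// produces a scalar loss tensor.
    ///
    /// ```ignore
    /// let loss = model.forward_loss(&batch)?; // scalar loss tensor
    /// optimizer.step(&loss)?;                 // backprop + parameter update
    /// ```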
    pub fn step(&mut self, loss: &Tensor) -> Result<()> {
        self.inner
            .backward_step(loss)
            .map_err(|e| AxolotlError::Training(format!("Optimizer step failed: {}", e)))
    }

    /// Get current learning rate.
    pub fn learning_rate(&self) -> f64 {
        self.inner.learning_rate()
    }

    /// Set learning rate (used by schedulers).
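    ///
    /// # Example
    ///
    /// A sketch of linear warmup; `step`, `base_lr`, and `warmup_steps` are
    /// illustrative values, not part of this crate.
    ///
    /// ```ignore
    /// let factor = ((step + 1) as f64 / warmup_steps as f64).min(1.0);
    /// optimizer.set_learning_rate(base_lr * factor);
    /// ```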
    pub fn set_learning_rate(&mut self, lr: f64) {
        self.inner.set_learning_rate(lr);
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_optimizer_config_default() {
        let config = OptimizerConfig::default();
        assert_eq!(config.learning_rate, 5e-5);
        assert_eq!(config.beta1, 0.9);
        assert_eq!(config.beta2, 0.999);
        assert_eq!(config.weight_decay, 0.01);
    }

    #[test]
    fn test_build_adamw() -> Result<()> {
        let config = OptimizerConfig::default();
        let varmap = VarMap::new();

        let optimizer = config.build_adamw(&varmap)?;
        assert_eq!(optimizer.learning_rate(), 5e-5);

        Ok(())
    }
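
    // A sketch of an end-to-end step test: minimize loss = x^2 for a single
    // scalar parameter and check that one AdamW step moves it. Assumes a CPU
    // device; the error mapping mirrors the wrapper's own style.
    #[test]
    fn test_step_updates_params() -> Result<()> {
        use candle_core::{DType, Device};
        use candle_nn::Init;

        let varmap = VarMap::new();
        // Register a scalar variable x = 1.0; the returned tensor shares
        // storage with the tracked Var, so it reflects optimizer updates.
        let x = varmap
            .get((), "x", Init::Const(1.0), DType::F32, &Device::Cpu)
            .map_err(|e| AxolotlError::Training(format!("var init failed: {}", e)))?;

        let mut optimizer = OptimizerConfig::default().build_adamw(&varmap)?;

        // loss = x^2 has gradient 2x > 0 at x = 1, so a step must shrink x.
        let loss = (&x * &x)
            .map_err(|e| AxolotlError::Training(format!("loss failed: {}", e)))?;
        optimizer.step(&loss)?;

        let updated = x
            .to_scalar::<f32>()
            .map_err(|e| AxolotlError::Training(format!("read failed: {}", e)))?;
        assert!(updated < 1.0, "one AdamW step should decrease x from 1.0");
        Ok(())
    }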
}