optirs_core/optimizers/grouped_adam.rs

// Adam optimizer with parameter group support

use crate::error::{OptimError, Result};
use crate::optimizers::Optimizer;
use crate::parameter_groups::{
    GroupManager, GroupedOptimizer, ParameterGroup, ParameterGroupConfig,
};
use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;

/// Adam optimizer with parameter group support
///
/// This optimizer allows different parameter groups to have different
/// hyperparameters (learning rate, weight decay, betas).
///
/// # Example
///
/// ```no_run
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::optimizers::{GroupedAdam, Optimizer};
/// use optirs_core::parameter_groups::{GroupedOptimizer, ParameterGroupConfig};
///
/// // Create grouped optimizer
/// let mut optimizer = GroupedAdam::new(0.001);
///
/// // Add parameter groups with different learning rates
/// let params_fast = vec![Array1::zeros(5)];
/// let config_fast = ParameterGroupConfig::new().with_learning_rate(0.01);
/// let group_fast = optimizer.add_group(params_fast, config_fast).unwrap();
///
/// let params_slow = vec![Array1::zeros(3)];
/// let config_slow = ParameterGroupConfig::new().with_learning_rate(0.0001);
/// let group_slow = optimizer.add_group(params_slow, config_slow).unwrap();
///
/// // Optimize each group separately
/// let grads_fast = vec![Array1::ones(5)];
/// let updated_fast = optimizer.step_group(group_fast, &grads_fast).unwrap();
///
/// let grads_slow = vec![Array1::ones(3)];
/// let updated_slow = optimizer.step_group(group_slow, &grads_slow).unwrap();
/// ```
#[derive(Debug)]
pub struct GroupedAdam<A: Float + Send + Sync, D: Dimension> {
    /// Default learning rate
    defaultlr: A,
    /// Default beta1
    default_beta1: A,
    /// Default beta2
    default_beta2: A,
    /// Default weight decay
    default_weight_decay: A,
    /// Epsilon to prevent division by zero
    epsilon: A,
    /// AMSGrad flag
    amsgrad: bool,
    /// Parameter groups
    group_manager: GroupManager<A, D>,
    /// Global step counter (shared across all parameter groups)
    step: usize,
}

impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync> GroupedAdam<A, D> {
    /// Create a new grouped Adam optimizer
    pub fn new(defaultlr: A) -> Self {
        Self {
            defaultlr,
            default_beta1: A::from(0.9).unwrap(),
            default_beta2: A::from(0.999).unwrap(),
            default_weight_decay: A::zero(),
            epsilon: A::from(1e-8).unwrap(),
            amsgrad: false,
            group_manager: GroupManager::new(),
            step: 0,
        }
    }

    /// Set default beta1
    pub fn with_beta1(mut self, beta1: A) -> Self {
        self.default_beta1 = beta1;
        self
    }

    /// Set default beta2
    pub fn with_beta2(mut self, beta2: A) -> Self {
        self.default_beta2 = beta2;
        self
    }

    /// Set default weight decay
    pub fn with_weight_decay(mut self, weight_decay: A) -> Self {
        self.default_weight_decay = weight_decay;
        self
    }

    /// Enable AMSGrad
    pub fn with_amsgrad(mut self) -> Self {
        self.amsgrad = true;
        self
    }

    /// Initialize state for a group
    fn init_group_state(&mut self, groupid: usize) -> Result<()> {
        let group = self.group_manager.get_group_mut(groupid)?;

        if group.state.is_empty() {
            let mut m_t = Vec::new();
            let mut v_t = Vec::new();
            let mut v_hat_max = Vec::new();

            for param in &group.params {
                m_t.push(Array::zeros(param.raw_dim()));
                v_t.push(Array::zeros(param.raw_dim()));
                if self.amsgrad {
                    v_hat_max.push(Array::zeros(param.raw_dim()));
                }
            }

            group.state.insert("m_t".to_string(), m_t);
            group.state.insert("v_t".to_string(), v_t);
            if self.amsgrad {
                group.state.insert("v_hat_max".to_string(), v_hat_max);
            }
        }

        Ok(())
    }

    /// Step for a specific group
    fn step_group_internal(
        &mut self,
        groupid: usize,
        gradients: &[Array<A, D>],
    ) -> Result<Vec<Array<A, D>>> {
        // `step_group` has already incremented `self.step`, so it is the 1-based
        // step count used for bias correction.
        let t = A::from(self.step).unwrap();

        // Initialize state if needed
        self.init_group_state(groupid)?;

        let group = self.group_manager.get_group_mut(groupid)?;

        if gradients.len() != group.params.len() {
            return Err(OptimError::InvalidConfig(format!(
                "Number of gradients ({}) doesn't match number of parameters ({})",
                gradients.len(),
                group.params.len()
            )));
        }

        // Get hyperparameters for this group
        let lr = group.learning_rate(self.defaultlr);
        let beta1 = group.get_custom_param("beta1", self.default_beta1);
        let beta2 = group.get_custom_param("beta2", self.default_beta2);
        let weightdecay = group.weight_decay(self.default_weight_decay);

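        // Comment added for clarity: the per-parameter loop below applies the
        // standard Adam update (Kingma & Ba, 2015), where g_t is the gradient after
        // the optional weight-decay term has been added:
        //   m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
        //   v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
        //   m_hat = m_t / (1 - beta1^t),  v_hat = v_t / (1 - beta2^t)
        //   theta_t = theta_{t-1} - lr * m_hat / (sqrt(v_hat) + epsilon)
        // With AMSGrad enabled, sqrt(v_hat) is replaced by the square root of the
        // element-wise running maximum of all v_hat values seen so far.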
        let mut updated_params = Vec::new();

        // Process each parameter
        for i in 0..group.params.len() {
            let param = &group.params[i];
            let grad = &gradients[i];

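            // Note (added comment): this folds L2 regularization into the gradient
            // (classic Adam-style coupled weight decay), rather than applying
            // AdamW-style decoupled decay directly to the parameters.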
            // Apply weight decay
            let grad_with_decay = if weightdecay > A::zero() {
                grad + &(param * weightdecay)
            } else {
                grad.clone()
            };

            // Update states and compute new parameters
            let updated = {
                // Update first moment
                let m_t = group.state.get_mut("m_t").unwrap();
                m_t[i] = &m_t[i] * beta1 + &grad_with_decay * (A::one() - beta1);
                let m_hat = &m_t[i] / (A::one() - beta1.powi(t.to_i32().unwrap()));

                // Update second moment
                let v_t = group.state.get_mut("v_t").unwrap();
                v_t[i] = &v_t[i] * beta2 + &grad_with_decay * &grad_with_decay * (A::one() - beta2);
                let v_hat = &v_t[i] / (A::one() - beta2.powi(t.to_i32().unwrap()));

                // Update parameters
                if self.amsgrad {
                    let v_hat_max = group.state.get_mut("v_hat_max").unwrap();
                    v_hat_max[i].zip_mut_with(&v_hat, |a, &b| *a = a.max(b));
                    param - &(&m_hat * lr / (&v_hat_max[i].mapv(|x| x.sqrt()) + self.epsilon))
                } else {
                    param - &(&m_hat * lr / (&v_hat.mapv(|x| x.sqrt()) + self.epsilon))
                }
            };

            updated_params.push(updated);
        }

        // Update group parameters
        group.params = updated_params.clone();

        Ok(updated_params)
    }
}

impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync>
    GroupedOptimizer<A, D> for GroupedAdam<A, D>
{
    fn add_group(
        &mut self,
        params: Vec<Array<A, D>>,
        config: ParameterGroupConfig<A>,
    ) -> Result<usize> {
        Ok(self.group_manager.add_group(params, config))
    }

    fn get_group(&self, groupid: usize) -> Result<&ParameterGroup<A, D>> {
        self.group_manager.get_group(groupid)
    }

    fn get_group_mut(&mut self, groupid: usize) -> Result<&mut ParameterGroup<A, D>> {
        self.group_manager.get_group_mut(groupid)
    }

    fn groups(&self) -> &[ParameterGroup<A, D>] {
        self.group_manager.groups()
    }

    fn groups_mut(&mut self) -> &mut [ParameterGroup<A, D>] {
        self.group_manager.groups_mut()
    }

    fn step_group(
        &mut self,
        groupid: usize,
        gradients: &[Array<A, D>],
    ) -> Result<Vec<Array<A, D>>> {
        self.step += 1;
        self.step_group_internal(groupid, gradients)
    }

    fn set_group_learning_rate(&mut self, groupid: usize, lr: A) -> Result<()> {
        let group = self.group_manager.get_group_mut(groupid)?;
        group.config.learning_rate = Some(lr);
        Ok(())
    }

    fn set_group_weight_decay(&mut self, groupid: usize, wd: A) -> Result<()> {
        let group = self.group_manager.get_group_mut(groupid)?;
        group.config.weight_decay = Some(wd);
        Ok(())
    }
}

// Standard optimizer implementation for default behavior
impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync> Optimizer<A, D>
    for GroupedAdam<A, D>
{
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        // For single-parameter optimization, register a dedicated group with the
        // default configuration. Note that the group remains in the group manager,
        // so repeated calls through this entry point accumulate groups.
        let params_vec = vec![params.clone()];
        let gradients_vec = vec![gradients.clone()];
        let config = ParameterGroupConfig::new();

        let groupid = self.add_group(params_vec, config)?;
        let result = self.step_group(groupid, &gradients_vec)?;

        Ok(result.into_iter().next().unwrap())
    }

    fn get_learning_rate(&self) -> A {
        self.defaultlr
    }

    fn set_learning_rate(&mut self, learning_rate: A) {
        self.defaultlr = learning_rate;
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_grouped_adam_creation() {
        let optimizer: GroupedAdam<f64, scirs2_core::ndarray::Ix1> = GroupedAdam::new(0.001);
        assert_eq!(optimizer.defaultlr, 0.001);
        assert_eq!(optimizer.default_beta1, 0.9);
        assert_eq!(optimizer.default_beta2, 0.999);
    }
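
    // Added illustrative test (sketch): checks that the builder methods overwrite
    // the defaults established by `new`.
    #[test]
    fn test_grouped_adam_builder_configuration() {
        let optimizer: GroupedAdam<f64, scirs2_core::ndarray::Ix1> = GroupedAdam::new(0.001)
            .with_beta1(0.8)
            .with_beta2(0.99)
            .with_weight_decay(0.01)
            .with_amsgrad();

        assert_eq!(optimizer.default_beta1, 0.8);
        assert_eq!(optimizer.default_beta2, 0.99);
        assert_eq!(optimizer.default_weight_decay, 0.01);
        assert!(optimizer.amsgrad);
    }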

    #[test]
    fn test_grouped_adam_multiple_groups() {
        let mut optimizer = GroupedAdam::new(0.001);

        // Add first group with high learning rate
        let params1 = vec![Array1::from_vec(vec![1.0, 2.0])];
        let config1 = ParameterGroupConfig::new().with_learning_rate(0.01);
        let group1 = optimizer.add_group(params1, config1).unwrap();

        // Add second group with low learning rate
        let params2 = vec![Array1::from_vec(vec![3.0, 4.0, 5.0])];
        let config2 = ParameterGroupConfig::new().with_learning_rate(0.0001);
        let group2 = optimizer.add_group(params2, config2).unwrap();

        // Update first group
        let grads1 = vec![Array1::from_vec(vec![0.1, 0.2])];
        let updated1 = optimizer.step_group(group1, &grads1).unwrap();

        // Update second group
        let grads2 = vec![Array1::from_vec(vec![0.3, 0.4, 0.5])];
        let updated2 = optimizer.step_group(group2, &grads2).unwrap();

        // Verify different updates due to different learning rates
        assert!(updated1[0][0] < 1.0); // Larger learning rate: clearly below the initial 1.0
        assert!(updated2[0][0] > 2.9); // Smaller learning rate: stays close to the initial 3.0
    }
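
    // Added illustrative test (sketch): drives the optimizer through the standard
    // `Optimizer::step` entry point, which registers a new group internally. On the
    // first Adam step the update magnitude is roughly the learning rate, so the
    // bounds below are deliberately loose.
    #[test]
    fn test_grouped_adam_single_param_step() {
        let mut optimizer: GroupedAdam<f64, scirs2_core::ndarray::Ix1> = GroupedAdam::new(0.01);

        let params = Array1::from_vec(vec![1.0, -2.0]);
        let grads = Array1::from_vec(vec![0.5, -0.5]);
        let updated = optimizer.step(&params, &grads).unwrap();

        // Parameters move against the gradient by approximately lr = 0.01.
        assert!(updated[0] < 1.0 && updated[0] > 0.98);
        assert!(updated[1] > -2.0 && updated[1] < -1.98);
    }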

    #[test]
    fn test_grouped_adam_custom_betas() {
        let mut optimizer = GroupedAdam::new(0.001);

        // Add group with custom betas
        let params = vec![Array1::from_vec(vec![1.0, 2.0])];
        let config = ParameterGroupConfig::new()
            .with_custom_param("beta1".to_string(), 0.8)
            .with_custom_param("beta2".to_string(), 0.99);
        let group = optimizer.add_group(params, config).unwrap();

        // Verify custom parameters are used
        let group_ref = optimizer.get_group(group).unwrap();
        assert_eq!(group_ref.get_custom_param("beta1", 0.0), 0.8);
        assert_eq!(group_ref.get_custom_param("beta2", 0.0), 0.99);
    }
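
    // Added illustrative test (sketch): the per-group learning rate can also be
    // changed after the group has been created. Assumes the `learning_rate` field
    // of `ParameterGroupConfig` is readable here, as it is written elsewhere in
    // this module.
    #[test]
    fn test_grouped_adam_set_group_learning_rate() {
        let mut optimizer = GroupedAdam::new(0.001);

        let params = vec![Array1::from_vec(vec![1.0, 2.0])];
        let group = optimizer
            .add_group(params, ParameterGroupConfig::new())
            .unwrap();

        optimizer.set_group_learning_rate(group, 0.05).unwrap();

        let group_ref = optimizer.get_group(group).unwrap();
        assert_eq!(group_ref.config.learning_rate, Some(0.05));
    }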

    #[test]
    fn test_grouped_adam_clear() {
        let mut optimizer = GroupedAdam::new(0.001);

        // Add groups
        let params1 = vec![Array1::zeros(2)];
        let config1 = ParameterGroupConfig::new();
        optimizer.add_group(params1, config1).unwrap();

        assert_eq!(optimizer.groups().len(), 1);

        // Clear groups
        optimizer.group_manager = GroupManager::new();
        optimizer.step = 0;

        assert_eq!(optimizer.groups().len(), 0);
        assert_eq!(optimizer.step, 0);
    }
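
    // Added illustrative test (sketch): exercises the AMSGrad branch and checks that
    // the resulting parameters are finite and have moved against the gradient.
    #[test]
    fn test_grouped_adam_amsgrad_step() {
        let mut optimizer = GroupedAdam::new(0.01).with_amsgrad();

        let params = vec![Array1::from_vec(vec![1.0, 2.0])];
        let group = optimizer
            .add_group(params, ParameterGroupConfig::new())
            .unwrap();

        let grads = vec![Array1::from_vec(vec![0.1, 0.2])];
        let updated = optimizer.step_group(group, &grads).unwrap();

        assert!(updated[0].iter().all(|x| x.is_finite()));
        assert!(updated[0][0] < 1.0);
        assert!(updated[0][1] < 2.0);
    }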
}