use crate::error::{OptimError, Result};
use crate::optimizers::Optimizer;
use crate::parameter_groups::{
    GroupManager, GroupedOptimizer, ParameterGroup, ParameterGroupConfig,
};
use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;

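/// Adam optimizer with parameter-group support.
///
/// Each parameter group can override the default learning rate and weight
/// decay, and may supply custom "beta1"/"beta2" values; first- and
/// second-moment buffers are kept per group in the group's state map.
///
/// The sketch below mirrors the tests at the end of this file (marked
/// `ignore` so it is not compiled as a doctest; imports are omitted):
///
/// ```ignore
/// let mut optimizer: GroupedAdam<f64, Ix1> = GroupedAdam::new(0.001);
/// let group = optimizer
///     .add_group(
///         vec![Array1::from_vec(vec![1.0, 2.0])],
///         ParameterGroupConfig::new().with_learning_rate(0.01),
///     )
///     .unwrap();
/// let updated = optimizer
///     .step_group(group, &[Array1::from_vec(vec![0.1, 0.2])])
///     .unwrap();
/// ```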
#[derive(Debug)]
pub struct GroupedAdam<A: Float + Send + Sync, D: Dimension> {
    /// Default learning rate used when a group does not override it
    defaultlr: A,
    /// Default exponential decay rate for the first moment estimates
    default_beta1: A,
    /// Default exponential decay rate for the second moment estimates
    default_beta2: A,
    /// Default weight decay (L2 penalty)
    default_weight_decay: A,
    /// Small constant added to the denominator for numerical stability
    epsilon: A,
    /// Whether to use the AMSGrad variant
    amsgrad: bool,
    /// Manages parameter groups and their per-group optimizer state
    group_manager: GroupManager<A, D>,
    /// Global step counter shared by all groups
    step: usize,
}

impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync> GroupedAdam<A, D> {
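    /// Creates a grouped Adam optimizer with the given default learning rate.
    ///
    /// Defaults: `beta1 = 0.9`, `beta2 = 0.999`, `weight_decay = 0`,
    /// `epsilon = 1e-8`, AMSGrad disabled.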
    pub fn new(defaultlr: A) -> Self {
        Self {
            defaultlr,
            default_beta1: A::from(0.9).unwrap(),
            default_beta2: A::from(0.999).unwrap(),
            default_weight_decay: A::zero(),
            epsilon: A::from(1e-8).unwrap(),
            amsgrad: false,
            group_manager: GroupManager::new(),
            step: 0,
        }
    }

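    /// Sets the default first-moment decay rate (`beta1`) for groups that do not override it.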
    pub fn with_beta1(mut self, beta1: A) -> Self {
        self.default_beta1 = beta1;
        self
    }

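    /// Sets the default second-moment decay rate (`beta2`) for groups that do not override it.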
    pub fn with_beta2(mut self, beta2: A) -> Self {
        self.default_beta2 = beta2;
        self
    }

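    /// Sets the default weight decay (L2 penalty) for groups that do not override it.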
    pub fn with_weight_decay(mut self, weight_decay: A) -> Self {
        self.default_weight_decay = weight_decay;
        self
    }

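    /// Enables the AMSGrad variant, which normalizes by the running maximum of the second-moment estimate.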
    pub fn with_amsgrad(mut self) -> Self {
        self.amsgrad = true;
        self
    }

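    /// Lazily allocates zeroed moment buffers (and the AMSGrad maximum, if enabled) for a group.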
    fn init_group_state(&mut self, groupid: usize) -> Result<()> {
        let group = self.group_manager.get_group_mut(groupid)?;

        if group.state.is_empty() {
            let mut m_t = Vec::new();
            let mut v_t = Vec::new();
            let mut v_hat_max = Vec::new();

            for param in &group.params {
                m_t.push(Array::zeros(param.raw_dim()));
                v_t.push(Array::zeros(param.raw_dim()));
                if self.amsgrad {
                    v_hat_max.push(Array::zeros(param.raw_dim()));
                }
            }

            group.state.insert("m_t".to_string(), m_t);
            group.state.insert("v_t".to_string(), v_t);
            if self.amsgrad {
                group.state.insert("v_hat_max".to_string(), v_hat_max);
            }
        }

        Ok(())
    }

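    /// Applies one Adam update to every parameter in the group, honoring the
    /// group's learning-rate, beta, and weight-decay overrides.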
    fn step_group_internal(
        &mut self,
        groupid: usize,
        gradients: &[Array<A, D>],
    ) -> Result<Vec<Array<A, D>>> {
        // `self.step` was already incremented by `step_group`, so it is the
        // 1-based timestep used for bias correction.
        let t = A::from(self.step).unwrap();

        self.init_group_state(groupid)?;

        let group = self.group_manager.get_group_mut(groupid)?;

        if gradients.len() != group.params.len() {
            return Err(OptimError::InvalidConfig(format!(
                "Number of gradients ({}) doesn't match number of parameters ({})",
                gradients.len(),
                group.params.len()
            )));
        }

        // Resolve per-group hyperparameters, falling back to the defaults.
        let lr = group.learning_rate(self.defaultlr);
        let beta1 = group.get_custom_param("beta1", self.default_beta1);
        let beta2 = group.get_custom_param("beta2", self.default_beta2);
        let weightdecay = group.weight_decay(self.default_weight_decay);

        let mut updated_params = Vec::new();

        for i in 0..group.params.len() {
            let param = &group.params[i];
            let grad = &gradients[i];

            // L2 weight decay folded into the gradient (classic Adam, not AdamW).
            let grad_with_decay = if weightdecay > A::zero() {
                grad + &(param * weightdecay)
            } else {
                grad.clone()
            };

            let updated = {
                // First-moment estimate and its bias-corrected value.
                let m_t = group.state.get_mut("m_t").unwrap();
                m_t[i] = &m_t[i] * beta1 + &grad_with_decay * (A::one() - beta1);
                let m_hat = &m_t[i] / (A::one() - beta1.powi(t.to_i32().unwrap()));

                // Second-moment estimate and its bias-corrected value.
                let v_t = group.state.get_mut("v_t").unwrap();
                v_t[i] =
                    &v_t[i] * beta2 + &grad_with_decay * &grad_with_decay * (A::one() - beta2);
                let v_hat = &v_t[i] / (A::one() - beta2.powi(t.to_i32().unwrap()));

                if self.amsgrad {
                    // AMSGrad: normalize by the running maximum of v_hat.
                    let v_hat_max = group.state.get_mut("v_hat_max").unwrap();
                    v_hat_max[i].zip_mut_with(&v_hat, |a, &b| *a = a.max(b));
                    param - &(&m_hat * lr / (&v_hat_max[i].mapv(|x| x.sqrt()) + self.epsilon))
                } else {
                    param - &(&m_hat * lr / (&v_hat.mapv(|x| x.sqrt()) + self.epsilon))
                }
            };

            updated_params.push(updated);
        }

        group.params = updated_params.clone();

        Ok(updated_params)
    }
}

impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync>
    GroupedOptimizer<A, D> for GroupedAdam<A, D>
{
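    /// Registers a new parameter group and returns its id.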
    fn add_group(
        &mut self,
        params: Vec<Array<A, D>>,
        config: ParameterGroupConfig<A>,
    ) -> Result<usize> {
        Ok(self.group_manager.add_group(params, config))
    }

    fn get_group(&self, groupid: usize) -> Result<&ParameterGroup<A, D>> {
        self.group_manager.get_group(groupid)
    }

    fn get_group_mut(&mut self, groupid: usize) -> Result<&mut ParameterGroup<A, D>> {
        self.group_manager.get_group_mut(groupid)
    }

    fn groups(&self) -> &[ParameterGroup<A, D>] {
        self.group_manager.groups()
    }

    fn groups_mut(&mut self) -> &mut [ParameterGroup<A, D>] {
        self.group_manager.groups_mut()
    }

    fn step_group(
        &mut self,
        groupid: usize,
        gradients: &[Array<A, D>],
    ) -> Result<Vec<Array<A, D>>> {
        self.step += 1;
        self.step_group_internal(groupid, gradients)
    }

    fn set_group_learning_rate(&mut self, groupid: usize, lr: A) -> Result<()> {
        let group = self.group_manager.get_group_mut(groupid)?;
        group.config.learning_rate = Some(lr);
        Ok(())
    }

    fn set_group_weight_decay(&mut self, groupid: usize, wd: A) -> Result<()> {
        let group = self.group_manager.get_group_mut(groupid)?;
        group.config.weight_decay = Some(wd);
        Ok(())
    }
}

impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync> Optimizer<A, D>
    for GroupedAdam<A, D>
{
    /// Single-tensor `Optimizer` entry point.
    ///
    /// Note that each call wraps the parameters in a fresh group, so Adam's
    /// moment state is not carried across calls; prefer `add_group` plus
    /// `step_group` for stateful training.
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        let params_vec = vec![params.clone()];
        let gradients_vec = vec![gradients.clone()];
        let config = ParameterGroupConfig::new();

        let groupid = self.add_group(params_vec, config)?;
        let result = self.step_group(groupid, &gradients_vec)?;

        Ok(result.into_iter().next().unwrap())
    }

    fn get_learning_rate(&self) -> A {
        self.defaultlr
    }

    fn set_learning_rate(&mut self, learning_rate: A) {
        self.defaultlr = learning_rate;
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_grouped_adam_creation() {
        let optimizer: GroupedAdam<f64, scirs2_core::ndarray::Ix1> = GroupedAdam::new(0.001);
        assert_eq!(optimizer.defaultlr, 0.001);
        assert_eq!(optimizer.default_beta1, 0.9);
        assert_eq!(optimizer.default_beta2, 0.999);
    }

    #[test]
    fn test_grouped_adam_multiple_groups() {
        let mut optimizer = GroupedAdam::new(0.001);

        let params1 = vec![Array1::from_vec(vec![1.0, 2.0])];
        let config1 = ParameterGroupConfig::new().with_learning_rate(0.01);
        let group1 = optimizer.add_group(params1, config1).unwrap();

        let params2 = vec![Array1::from_vec(vec![3.0, 4.0, 5.0])];
        let config2 = ParameterGroupConfig::new().with_learning_rate(0.0001);
        let group2 = optimizer.add_group(params2, config2).unwrap();

        let grads1 = vec![Array1::from_vec(vec![0.1, 0.2])];
        let updated1 = optimizer.step_group(group1, &grads1).unwrap();

        let grads2 = vec![Array1::from_vec(vec![0.3, 0.4, 0.5])];
        let updated2 = optimizer.step_group(group2, &grads2).unwrap();

        // Group 1 (lr = 0.01) should move visibly below its starting value of 1.0,
        // while group 2 (lr = 0.0001) should barely move from 3.0.
        assert!(updated1[0][0] < 1.0);
        assert!(updated2[0][0] > 2.9);
    }

    #[test]
    fn test_grouped_adam_custom_betas() {
        let mut optimizer = GroupedAdam::new(0.001);

        let params = vec![Array1::from_vec(vec![1.0, 2.0])];
        let config = ParameterGroupConfig::new()
            .with_custom_param("beta1".to_string(), 0.8)
            .with_custom_param("beta2".to_string(), 0.99);
        let group = optimizer.add_group(params, config).unwrap();

        let group_ref = optimizer.get_group(group).unwrap();
        assert_eq!(group_ref.get_custom_param("beta1", 0.0), 0.8);
        assert_eq!(group_ref.get_custom_param("beta2", 0.0), 0.99);
    }

    #[test]
    fn test_grouped_adam_clear() {
        let mut optimizer = GroupedAdam::new(0.001);

        let params1 = vec![Array1::zeros(2)];
        let config1 = ParameterGroupConfig::new();
        optimizer.add_group(params1, config1).unwrap();

        assert_eq!(optimizer.groups().len(), 1);

        // Reset the optimizer by replacing the group manager and zeroing the step counter.
        optimizer.group_manager = GroupManager::new();
        optimizer.step = 0;

        assert_eq!(optimizer.groups().len(), 0);
        assert_eq!(optimizer.step, 0);
    }
}