optirs_core/optimizers/adam.rs

use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;

use scirs2_optimize::stochastic::{minimize_adam, AdamOptions};

use crate::error::Result;
use crate::optimizers::Optimizer;

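/// Adam optimizer with optional L2 weight decay.
///
/// Keeps exponentially decaying averages of the gradient (first moment `m`)
/// and of the squared gradient (second moment `v`), bias-corrects both, and
/// updates parameters as `p - lr * m_hat / (sqrt(v_hat) + epsilon)`.
///
/// # Example
///
/// A minimal sketch (not a doctest), assuming `params` and `grads` are arrays
/// of the same shape:
///
/// ```ignore
/// let mut opt = Adam::new(0.001f64);
/// let params = opt.step(&params, &grads)?;
/// ```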
#[derive(Debug, Clone)]
pub struct Adam<A: Float + ScalarOperand + Debug> {
    learning_rate: A,
    beta1: A,
    beta2: A,
    epsilon: A,
    weight_decay: A,
    m: Option<Vec<Array<A, scirs2_core::ndarray::IxDyn>>>,
    v: Option<Vec<Array<A, scirs2_core::ndarray::IxDyn>>>,
    t: usize,
}

impl<A: Float + ScalarOperand + Debug + Send + Sync> Adam<A> {
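    /// Creates an Adam optimizer with the given learning rate and the usual
    /// defaults: `beta1 = 0.9`, `beta2 = 0.999`, `epsilon = 1e-8`, and no
    /// weight decay.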
    pub fn new(learning_rate: A) -> Self {
        Self {
            learning_rate,
            beta1: A::from(0.9).unwrap(),
            beta2: A::from(0.999).unwrap(),
            epsilon: A::from(1e-8).unwrap(),
            weight_decay: A::zero(),
            m: None,
            v: None,
            t: 0,
        }
    }

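    /// Creates an Adam optimizer with every hyperparameter specified
    /// explicitly.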
    pub fn new_with_config(
        learning_rate: A,
        beta1: A,
        beta2: A,
        epsilon: A,
        weight_decay: A,
    ) -> Self {
        Self {
            learning_rate,
            beta1,
            beta2,
            epsilon,
            weight_decay,
            m: None,
            v: None,
            t: 0,
        }
    }

    /// Sets the first-moment decay rate `beta1` in place.
    pub fn set_beta1(&mut self, beta1: A) -> &mut Self {
        self.beta1 = beta1;
        self
    }

    /// Builder-style variant of `set_beta1`.
    pub fn with_beta1(mut self, beta1: A) -> Self {
        self.beta1 = beta1;
        self
    }

    /// Returns the current `beta1`.
    pub fn get_beta1(&self) -> A {
        self.beta1
    }

    /// Sets the second-moment decay rate `beta2` in place.
    pub fn set_beta2(&mut self, beta2: A) -> &mut Self {
        self.beta2 = beta2;
        self
    }

    /// Builder-style variant of `set_beta2`.
    pub fn with_beta2(mut self, beta2: A) -> Self {
        self.beta2 = beta2;
        self
    }

    /// Returns the current `beta2`.
    pub fn get_beta2(&self) -> A {
        self.beta2
    }

    /// Sets the numerical-stability constant `epsilon` in place.
    pub fn set_epsilon(&mut self, epsilon: A) -> &mut Self {
        self.epsilon = epsilon;
        self
    }

    /// Builder-style variant of `set_epsilon`.
    pub fn with_epsilon(mut self, epsilon: A) -> Self {
        self.epsilon = epsilon;
        self
    }

    /// Returns the current `epsilon`.
    pub fn get_epsilon(&self) -> A {
        self.epsilon
    }

    /// Sets the L2 weight-decay coefficient in place.
    pub fn set_weight_decay(&mut self, weight_decay: A) -> &mut Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Builder-style variant of `set_weight_decay`.
    pub fn with_weight_decay(mut self, weight_decay: A) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Returns the current weight-decay coefficient.
    pub fn get_weight_decay(&self) -> A {
        self.weight_decay
    }

    /// Returns the current learning rate.
    pub fn learning_rate(&self) -> A {
        self.learning_rate
    }

    /// Sets the learning rate in place.
    pub fn set_lr(&mut self, lr: A) {
        self.learning_rate = lr;
    }

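    /// Clears the accumulated moment estimates and resets the timestep
    /// counter, so the next call to `step` starts from a fresh state.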
    pub fn reset(&mut self) {
        self.m = None;
        self.v = None;
        self.t = 0;
    }
}

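// Single-tensor Adam update: each `step` call advances the internal timestep
// and returns the updated parameters, leaving the inputs untouched.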
impl<A, D> Optimizer<A, D> for Adam<A>
where
    A: Float + ScalarOperand + Debug + Send + Sync,
    D: Dimension,
{
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        if params.shape() != gradients.shape() {
            return Err(crate::error::OptimError::DimensionMismatch(format!(
                "Incompatible shapes: parameters have shape {:?}, gradients have shape {:?}",
                params.shape(),
                gradients.shape()
            )));
        }

        let params_dyn = params.to_owned().into_dyn();
        let gradients_dyn = gradients.to_owned().into_dyn();

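        // When enabled, fold L2 weight decay into the gradient: g <- g + weight_decay * p.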
        let adjusted_gradients = if self.weight_decay > A::zero() {
            &gradients_dyn + &(&params_dyn * self.weight_decay)
        } else {
            gradients_dyn
        };

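        // Lazily allocate the first (m) and second (v) moment buffers, and
        // re-allocate them if the parameter shape changed since the last step.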
        if self.m.is_none() {
            self.m = Some(vec![Array::zeros(params_dyn.raw_dim())]);
            self.v = Some(vec![Array::zeros(params_dyn.raw_dim())]);
            self.t = 0;
        }

        let m = self.m.as_mut().unwrap();
        let v = self.v.as_mut().unwrap();

        if m.is_empty() {
            m.push(Array::zeros(params_dyn.raw_dim()));
            v.push(Array::zeros(params_dyn.raw_dim()));
        } else if m[0].raw_dim() != params_dyn.raw_dim() {
            m[0] = Array::zeros(params_dyn.raw_dim());
            v[0] = Array::zeros(params_dyn.raw_dim());
        }

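        // Advance the timestep, then update the exponential moving averages:
        // m <- beta1 * m + (1 - beta1) * g and v <- beta2 * v + (1 - beta2) * g^2.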
        self.t = self.t.checked_add(1).ok_or_else(|| {
            crate::error::OptimError::InvalidConfig(
                "Timestep counter overflow - too many optimization steps".to_string(),
            )
        })?;

        m[0] = &m[0] * self.beta1 + &(&adjusted_gradients * (A::one() - self.beta1));

        v[0] = &v[0] * self.beta2
            + &(&adjusted_gradients * &adjusted_gradients * (A::one() - self.beta2));

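        // Bias-correct the moment estimates:
        // m_hat = m / (1 - beta1^t), v_hat = v / (1 - beta2^t).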
        let exp_beta1 = i32::try_from(self.t).map_err(|_| {
            crate::error::OptimError::InvalidConfig(
                "Timestep too large for bias correction calculation".to_string(),
            )
        })?;
        let m_hat = &m[0] / (A::one() - self.beta1.powi(exp_beta1));

        let exp_beta2 = i32::try_from(self.t).map_err(|_| {
            crate::error::OptimError::InvalidConfig(
                "Timestep too large for bias correction calculation".to_string(),
            )
        })?;
        let v_hat = &v[0] / (A::one() - self.beta2.powi(exp_beta2));

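        // Parameter update: p <- p - lr * m_hat / (sqrt(v_hat) + epsilon).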
        let v_hat_sqrt = v_hat.mapv(|x| x.sqrt());

        let step = &m_hat / &(&v_hat_sqrt + self.epsilon) * self.learning_rate;
        let updated_params = &params_dyn - step;

        Ok(updated_params.into_dimensionality::<D>().unwrap())
    }

    fn get_learning_rate(&self) -> A {
        self.learning_rate
    }

    fn set_learning_rate(&mut self, learning_rate: A) {
        self.learning_rate = learning_rate;
    }
}