// optirs_core/optimizers/adadelta.rs
use crate::error::{OptimError, Result};
12use scirs2_core::ndarray_ext::{Array1, ArrayView1};
13use scirs2_core::numeric::{Float, Zero};
14use serde::{Deserialize, Serialize};
15
/// AdaDelta optimizer state (Zeiler, 2012).
///
/// Maintains per-parameter running averages of squared gradients and squared
/// updates, so no global learning rate needs to be tuned by the caller.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaDelta<T: Float> {
    // Decay rate for both running averages; `new` validates it lies in (0, 1).
    rho: T,

    // Small positive constant added before each square root for numerical stability.
    epsilon: T,

    // Running average of squared gradients E[g^2]; `None` until the first `step`
    // call, which sizes it to the parameter vector.
    accumulated_gradients: Option<Array1<T>>,

    // Running average of squared parameter updates E[dx^2]; lazily initialized
    // together with `accumulated_gradients`.
    accumulated_updates: Option<Array1<T>>,

    // Number of completed `step` calls; reset to 0 by `reset`.
    step_count: usize,
}
51
52impl<T: Float> Default for AdaDelta<T> {
53 fn default() -> Self {
54 Self::new(
55 T::from(0.95).expect("unwrap failed"), T::from(1e-6).expect("unwrap failed"), )
58 .expect("unwrap failed")
59 }
60}
61
62impl<T: Float> AdaDelta<T> {
63 pub fn new(rho: T, epsilon: T) -> Result<Self> {
79 let rho_f64 = rho.to_f64().expect("unwrap failed");
80 let epsilon_f64 = epsilon.to_f64().expect("unwrap failed");
81
82 if rho_f64 <= 0.0 || rho_f64 >= 1.0 {
83 return Err(OptimError::InvalidParameter(format!(
84 "rho must be in (0, 1), got {}",
85 rho_f64
86 )));
87 }
88
89 if epsilon_f64 <= 0.0 {
90 return Err(OptimError::InvalidParameter(format!(
91 "epsilon must be positive, got {}",
92 epsilon_f64
93 )));
94 }
95
96 Ok(Self {
97 rho,
98 epsilon,
99 accumulated_gradients: None,
100 accumulated_updates: None,
101 step_count: 0,
102 })
103 }
104
105 pub fn step(&mut self, params: ArrayView1<T>, grads: ArrayView1<T>) -> Result<Array1<T>> {
134 let n = params.len();
135
136 if grads.len() != n {
137 return Err(OptimError::DimensionMismatch(format!(
138 "Expected gradient size {}, got {}",
139 n,
140 grads.len()
141 )));
142 }
143
144 if self.accumulated_gradients.is_none() {
146 self.accumulated_gradients = Some(Array1::zeros(n));
147 self.accumulated_updates = Some(Array1::zeros(n));
148 }
149
150 let acc_grad = self.accumulated_gradients.as_mut().expect("unwrap failed");
151 let acc_update = self.accumulated_updates.as_mut().expect("unwrap failed");
152
153 let one = T::one();
156 let one_minus_rho = one - self.rho;
157
158 for i in 0..n {
159 let grad = grads[i];
160 acc_grad[i] = self.rho * acc_grad[i] + one_minus_rho * grad * grad;
161 }
162
163 let mut delta_params = Array1::zeros(n);
167
168 let warmup_boost = if self.step_count < 10 {
171 T::from(10.0).expect("unwrap failed") } else {
173 T::one()
174 };
175
176 for i in 0..n {
177 let rms_grad = (acc_grad[i] + self.epsilon).sqrt();
178 let rms_update = (acc_update[i] + self.epsilon).sqrt();
179
180 delta_params[i] = -(rms_update / rms_grad) * grads[i] * warmup_boost;
183 }
184
185 for i in 0..n {
188 let delta = delta_params[i];
189 acc_update[i] = self.rho * acc_update[i] + one_minus_rho * delta * delta;
190 }
191
192 let mut updated_params = params.to_owned();
194 for i in 0..n {
195 updated_params[i] = updated_params[i] + delta_params[i];
196 }
197
198 self.step_count += 1;
199
200 Ok(updated_params)
201 }
202
203 pub fn step_count(&self) -> usize {
205 self.step_count
206 }
207
208 pub fn reset(&mut self) {
212 self.accumulated_gradients = None;
213 self.accumulated_updates = None;
214 self.step_count = 0;
215 }
216
217 pub fn rms_gradients(&self) -> Option<Array1<T>> {
221 self.accumulated_gradients
222 .as_ref()
223 .map(|acc_grad| acc_grad.mapv(|x| (x + self.epsilon).sqrt()))
224 }
225
226 pub fn rms_updates(&self) -> Option<Array1<T>> {
230 self.accumulated_updates
231 .as_ref()
232 .map(|acc_update| acc_update.mapv(|x| (x + self.epsilon).sqrt()))
233 }
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray_ext::array;

    // NOTE(review): removed `use approx::assert_relative_eq;` — the name was
    // never used anywhere in this module and only produced a warning.

    #[test]
    fn test_adadelta_creation() {
        let optimizer = AdaDelta::<f32>::new(0.95, 1e-6).expect("unwrap failed");
        assert_eq!(optimizer.step_count(), 0);
    }

    #[test]
    fn test_adadelta_invalid_rho() {
        // rho must lie strictly inside (0, 1).
        assert!(AdaDelta::<f32>::new(1.5, 1e-6).is_err());
        assert!(AdaDelta::<f32>::new(-0.1, 1e-6).is_err());
    }

    #[test]
    fn test_adadelta_invalid_epsilon() {
        assert!(AdaDelta::<f32>::new(0.95, -1e-6).is_err());
    }

    #[test]
    fn test_adadelta_single_step() {
        let mut optimizer = AdaDelta::<f32>::new(0.9, 1e-6).expect("unwrap failed");
        let params = array![1.0, 2.0, 3.0];
        let grads = array![0.1, 0.2, 0.3];

        let updated_params = optimizer
            .step(params.view(), grads.view())
            .expect("unwrap failed");

        assert!(updated_params.len() == 3);
        assert_eq!(optimizer.step_count(), 1);

        // Every coordinate must have moved (non-zero gradient everywhere).
        for i in 0..3 {
            assert_ne!(updated_params[i], params[i]);
        }
    }

    #[test]
    fn test_adadelta_multiple_steps() {
        let mut optimizer = AdaDelta::<f32>::new(0.95, 1e-6).expect("unwrap failed");
        let mut params = array![1.0, 2.0, 3.0];

        for _ in 0..10 {
            let grads = array![0.1, 0.2, 0.3];
            params = optimizer
                .step(params.view(), grads.view())
                .expect("unwrap failed");
        }

        assert_eq!(optimizer.step_count(), 10);

        // Positive gradients must drive the parameters downward.
        assert!(params[0] < 1.0);
        assert!(params[1] < 2.0);
        assert!(params[2] < 3.0);
    }

    #[test]
    fn test_adadelta_shape_mismatch() {
        let mut optimizer = AdaDelta::<f32>::new(0.95, 1e-6).expect("unwrap failed");
        let params = array![1.0, 2.0, 3.0];
        let grads = array![0.1, 0.2];
        assert!(optimizer.step(params.view(), grads.view()).is_err());
    }

    #[test]
    fn test_adadelta_reset() {
        let mut optimizer = AdaDelta::<f32>::new(0.95, 1e-6).expect("unwrap failed");
        let params = array![1.0, 2.0, 3.0];
        let grads = array![0.1, 0.2, 0.3];

        optimizer
            .step(params.view(), grads.view())
            .expect("unwrap failed");
        assert_eq!(optimizer.step_count(), 1);
        assert!(optimizer.accumulated_gradients.is_some());

        optimizer.reset();
        assert_eq!(optimizer.step_count(), 0);
        assert!(optimizer.accumulated_gradients.is_none());
        assert!(optimizer.accumulated_updates.is_none());
    }

    #[test]
    fn test_adadelta_convergence() {
        // Minimize f(x) = x^2 starting at x = 10; gradient is 2x.
        let mut optimizer = AdaDelta::<f64>::new(0.99, 1e-6).expect("unwrap failed");
        let mut params = array![10.0];

        for _ in 0..500 {
            let grads = params.mapv(|x| 2.0 * x);
            params = optimizer
                .step(params.view(), grads.view())
                .expect("unwrap failed");
        }

        assert!(
            params[0].abs() < 0.5,
            "Failed to converge, got {}",
            params[0]
        );
    }

    #[test]
    fn test_adadelta_rms_values() {
        let mut optimizer = AdaDelta::<f32>::new(0.9, 1e-6).expect("unwrap failed");

        // No state before the first step.
        assert!(optimizer.rms_gradients().is_none());
        assert!(optimizer.rms_updates().is_none());

        let params = array![1.0, 2.0, 3.0];
        let grads = array![0.1, 0.2, 0.3];

        optimizer
            .step(params.view(), grads.view())
            .expect("unwrap failed");

        assert!(optimizer.rms_gradients().is_some());
        assert!(optimizer.rms_updates().is_some());

        let rms_grads = optimizer.rms_gradients().expect("unwrap failed");
        assert_eq!(rms_grads.len(), 3);
    }

    #[test]
    fn test_adadelta_f64() {
        let mut optimizer = AdaDelta::<f64>::new(0.95, 1e-8).expect("unwrap failed");
        let params = array![1.0, 2.0, 3.0];
        let grads = array![0.1, 0.2, 0.3];

        let updated_params = optimizer
            .step(params.view(), grads.view())
            .expect("unwrap failed");
        assert_eq!(updated_params.len(), 3);
    }
}