optirs_core/optimizers/adadelta.rs

use crate::error::{OptimError, Result};
use scirs2_core::ndarray_ext::{Array1, ArrayView1};
use scirs2_core::numeric::{Float, Zero};
use serde::{Deserialize, Serialize};

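/// AdaDelta optimizer state.
///
/// AdaDelta (Zeiler, 2012) scales each parameter's step from running averages
/// of squared gradients and squared updates, so no explicit learning rate is
/// needed.
///
/// # Examples
///
/// A minimal usage sketch; the import path assumes this module is exposed as
/// `optirs_core::optimizers::adadelta`, so the block is marked `ignore`.
///
/// ```ignore
/// use optirs_core::optimizers::adadelta::AdaDelta;
/// use scirs2_core::ndarray_ext::array;
///
/// let mut opt = AdaDelta::<f64>::new(0.95, 1e-6).unwrap();
/// let params = array![1.0, 2.0, 3.0];
/// let grads = array![0.1, 0.2, 0.3];
///
/// // One step returns the updated parameter vector.
/// let updated = opt.step(params.view(), grads.view()).unwrap();
/// assert_eq!(updated.len(), 3);
/// ```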
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaDelta<T: Float> {
    /// Decay rate for the running averages, in (0, 1).
    rho: T,

    /// Small constant added inside the square roots for numerical stability.
    epsilon: T,

    /// Running average of squared gradients, E[g^2].
    accumulated_gradients: Option<Array1<T>>,

    /// Running average of squared parameter updates, E[Δx^2].
    accumulated_updates: Option<Array1<T>>,

    /// Number of optimization steps taken so far.
    step_count: usize,
}

impl<T: Float> Default for AdaDelta<T> {
    fn default() -> Self {
        Self::new(
            T::from(0.95).unwrap(), // rho, as recommended in the AdaDelta paper
            T::from(1e-6).unwrap(), // epsilon
        )
        .unwrap()
    }
}

impl<T: Float> AdaDelta<T> {
    /// Creates a new AdaDelta optimizer.
    ///
    /// # Errors
    ///
    /// Returns `OptimError::InvalidParameter` if `rho` is not in (0, 1) or if
    /// `epsilon` is not positive.
    pub fn new(rho: T, epsilon: T) -> Result<Self> {
        let rho_f64 = rho.to_f64().unwrap();
        let epsilon_f64 = epsilon.to_f64().unwrap();

        if rho_f64 <= 0.0 || rho_f64 >= 1.0 {
            return Err(OptimError::InvalidParameter(format!(
                "rho must be in (0, 1), got {}",
                rho_f64
            )));
        }

        if epsilon_f64 <= 0.0 {
            return Err(OptimError::InvalidParameter(format!(
                "epsilon must be positive, got {}",
                epsilon_f64
            )));
        }

        Ok(Self {
            rho,
            epsilon,
            accumulated_gradients: None,
            accumulated_updates: None,
            step_count: 0,
        })
    }

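    /// Performs one AdaDelta step and returns the updated parameters.
    ///
    /// Per-coordinate update rule, with RMS[v] = sqrt(E[v^2] + epsilon):
    ///
    ///   E[g^2]_t  = rho * E[g^2]_{t-1}  + (1 - rho) * g_t^2
    ///   Δx_t      = -(RMS[Δx]_{t-1} / RMS[g]_t) * g_t
    ///   E[Δx^2]_t = rho * E[Δx^2]_{t-1} + (1 - rho) * Δx_t^2
    ///
    /// This implementation additionally scales Δx by a factor of 10 during the
    /// first 10 steps to compensate for AdaDelta's initially tiny updates.
    ///
    /// # Errors
    ///
    /// Returns `OptimError::DimensionMismatch` if `grads` and `params` differ
    /// in length.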
    pub fn step(&mut self, params: ArrayView1<T>, grads: ArrayView1<T>) -> Result<Array1<T>> {
        let n = params.len();

        if grads.len() != n {
            return Err(OptimError::DimensionMismatch(format!(
                "Expected gradient size {}, got {}",
                n,
                grads.len()
            )));
        }

        // Lazily initialize the accumulators on the first step.
        if self.accumulated_gradients.is_none() {
            self.accumulated_gradients = Some(Array1::zeros(n));
            self.accumulated_updates = Some(Array1::zeros(n));
        }

        let acc_grad = self.accumulated_gradients.as_mut().unwrap();
        let acc_update = self.accumulated_updates.as_mut().unwrap();

        let one = T::one();
        let one_minus_rho = one - self.rho;

        // Accumulate the running average of squared gradients: E[g^2].
        for i in 0..n {
            let grad = grads[i];
            acc_grad[i] = self.rho * acc_grad[i] + one_minus_rho * grad * grad;
        }

        let mut delta_params = Array1::zeros(n);

        // Boost the first few steps, since both accumulators start at zero and
        // plain AdaDelta makes very small initial updates.
        let warmup_boost = if self.step_count < 10 {
            T::from(10.0).unwrap()
        } else {
            T::one()
        };

        // Compute the update: Δx = -(RMS[Δx] / RMS[g]) * g.
        for i in 0..n {
            let rms_grad = (acc_grad[i] + self.epsilon).sqrt();
            let rms_update = (acc_update[i] + self.epsilon).sqrt();

            delta_params[i] = -(rms_update / rms_grad) * grads[i] * warmup_boost;
        }

        // Accumulate the running average of squared updates: E[Δx^2].
        for i in 0..n {
            let delta = delta_params[i];
            acc_update[i] = self.rho * acc_update[i] + one_minus_rho * delta * delta;
        }

        let mut updated_params = params.to_owned();
        for i in 0..n {
            updated_params[i] = updated_params[i] + delta_params[i];
        }

        self.step_count += 1;

        Ok(updated_params)
    }

    /// Returns the number of steps taken so far.
    pub fn step_count(&self) -> usize {
        self.step_count
    }

    /// Clears all accumulated state and resets the step counter.
    pub fn reset(&mut self) {
        self.accumulated_gradients = None;
        self.accumulated_updates = None;
        self.step_count = 0;
    }

    /// Returns RMS[g] = sqrt(E[g^2] + epsilon), if at least one step has been taken.
    pub fn rms_gradients(&self) -> Option<Array1<T>> {
        self.accumulated_gradients
            .as_ref()
            .map(|acc_grad| acc_grad.mapv(|x| (x + self.epsilon).sqrt()))
    }

    /// Returns RMS[Δx] = sqrt(E[Δx^2] + epsilon), if at least one step has been taken.
    pub fn rms_updates(&self) -> Option<Array1<T>> {
        self.accumulated_updates
            .as_ref()
            .map(|acc_update| acc_update.mapv(|x| (x + self.epsilon).sqrt()))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray_ext::array;

    #[test]
    fn test_adadelta_creation() {
        let optimizer = AdaDelta::<f32>::new(0.95, 1e-6).unwrap();
        assert_eq!(optimizer.step_count(), 0);
    }

    #[test]
    fn test_adadelta_invalid_rho() {
        assert!(AdaDelta::<f32>::new(1.5, 1e-6).is_err());
        assert!(AdaDelta::<f32>::new(-0.1, 1e-6).is_err());
    }

    #[test]
    fn test_adadelta_invalid_epsilon() {
        assert!(AdaDelta::<f32>::new(0.95, -1e-6).is_err());
    }

    #[test]
    fn test_adadelta_single_step() {
        let mut optimizer = AdaDelta::<f32>::new(0.9, 1e-6).unwrap();
        let params = array![1.0, 2.0, 3.0];
        let grads = array![0.1, 0.2, 0.3];

        let updated_params = optimizer.step(params.view(), grads.view()).unwrap();

        assert!(updated_params.len() == 3);
        assert_eq!(optimizer.step_count(), 1);

        // Every parameter should have moved.
        for i in 0..3 {
            assert_ne!(updated_params[i], params[i]);
        }
    }

    #[test]
    fn test_adadelta_multiple_steps() {
        let mut optimizer = AdaDelta::<f32>::new(0.95, 1e-6).unwrap();
        let mut params = array![1.0, 2.0, 3.0];

        for _ in 0..10 {
            let grads = array![0.1, 0.2, 0.3];
            params = optimizer.step(params.view(), grads.view()).unwrap();
        }

        assert_eq!(optimizer.step_count(), 10);

        // With positive gradients, every parameter should have decreased.
        assert!(params[0] < 1.0);
        assert!(params[1] < 2.0);
        assert!(params[2] < 3.0);
    }

    #[test]
    fn test_adadelta_shape_mismatch() {
        let mut optimizer = AdaDelta::<f32>::new(0.95, 1e-6).unwrap();
        let params = array![1.0, 2.0, 3.0];
        let grads = array![0.1, 0.2];

        assert!(optimizer.step(params.view(), grads.view()).is_err());
    }

    #[test]
    fn test_adadelta_reset() {
        let mut optimizer = AdaDelta::<f32>::new(0.95, 1e-6).unwrap();
        let params = array![1.0, 2.0, 3.0];
        let grads = array![0.1, 0.2, 0.3];

        optimizer.step(params.view(), grads.view()).unwrap();
        assert_eq!(optimizer.step_count(), 1);
        assert!(optimizer.accumulated_gradients.is_some());

        optimizer.reset();
        assert_eq!(optimizer.step_count(), 0);
        assert!(optimizer.accumulated_gradients.is_none());
        assert!(optimizer.accumulated_updates.is_none());
    }

    #[test]
    fn test_adadelta_convergence() {
        let mut optimizer = AdaDelta::<f64>::new(0.99, 1e-6).unwrap();

        // Minimize f(x) = x^2 starting from x = 10; the gradient is 2x.
        let mut params = array![10.0];
        for _ in 0..500 {
            let grads = params.mapv(|x| 2.0 * x);
            params = optimizer.step(params.view(), grads.view()).unwrap();
        }

        assert!(
            params[0].abs() < 0.5,
            "Failed to converge, got {}",
            params[0]
        );
    }

    #[test]
    fn test_adadelta_rms_values() {
        let mut optimizer = AdaDelta::<f32>::new(0.9, 1e-6).unwrap();

        assert!(optimizer.rms_gradients().is_none());
        assert!(optimizer.rms_updates().is_none());

        let params = array![1.0, 2.0, 3.0];
        let grads = array![0.1, 0.2, 0.3];

        optimizer.step(params.view(), grads.view()).unwrap();

        assert!(optimizer.rms_gradients().is_some());
        assert!(optimizer.rms_updates().is_some());

        let rms_grads = optimizer.rms_gradients().unwrap();
        assert_eq!(rms_grads.len(), 3);
    }

    #[test]
    fn test_adadelta_f64() {
        let mut optimizer = AdaDelta::<f64>::new(0.95, 1e-8).unwrap();
        let params = array![1.0, 2.0, 3.0];
        let grads = array![0.1, 0.2, 0.3];

        let updated_params = optimizer.step(params.view(), grads.view()).unwrap();
        assert_eq!(updated_params.len(), 3);
    }
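
    // Additional sketch using only the public API above: since `reset()` clears
    // all accumulated state, replaying identical inputs after a reset should
    // reproduce the first result exactly.
    #[test]
    fn test_adadelta_reset_reproducibility() {
        let mut optimizer = AdaDelta::<f64>::new(0.95, 1e-6).unwrap();
        let params = array![1.0, 2.0, 3.0];
        let grads = array![0.1, 0.2, 0.3];

        let first = optimizer.step(params.view(), grads.view()).unwrap();

        optimizer.reset();
        let second = optimizer.step(params.view(), grads.view()).unwrap();

        for i in 0..3 {
            assert_eq!(first[i], second[i]);
        }
    }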
}