1use scirs2_core::ndarray::{Array, Ix1, ScalarOperand};
8use scirs2_core::numeric::Float;
9use std::collections::HashMap;
10use std::fmt::Debug;
11
12use crate::error::{OptimError, Result};
13use crate::optimizers::Optimizer;
14
15pub struct SparseGradient<A: Float + ScalarOperand + Debug> {
20 pub indices: Vec<usize>,
22 pub values: Vec<A>,
24 pub dim: usize,
26}
27
28impl<A: Float + ScalarOperand + Debug + Send + Sync> SparseGradient<A> {
29 pub fn new(indices: Vec<usize>, values: Vec<A>, dim: usize) -> Self {
31 assert_eq!(
32 indices.len(),
33 values.len(),
34 "Indices and values must have the same length"
35 );
36 if let Some(&max_idx) = indices.iter().max() {
38 assert!(
39 max_idx < dim,
40 "Index {} is out of bounds for dimension {}",
41 max_idx,
42 dim
43 );
44 }
45 Self {
46 indices,
47 values,
48 dim,
49 }
50 }
51
52 pub fn from_array(array: &Array<A, Ix1>) -> Self {
54 let mut indices = Vec::new();
55 let mut values = Vec::new();
56
57 for (idx, &val) in array.iter().enumerate() {
58 if !val.is_zero() {
59 indices.push(idx);
60 values.push(val);
61 }
62 }
63
64 Self {
65 indices,
66 values,
67 dim: array.len(),
68 }
69 }
70
71 pub fn to_array(&self) -> Array<A, Ix1> {
73 let mut array = Array::zeros(self.dim);
74 for (&idx, &val) in self.indices.iter().zip(&self.values) {
75 array[idx] = val;
76 }
77 array
78 }
79
80 pub fn is_empty(&self) -> bool {
82 self.indices.is_empty()
83 }
84}
85
/// Adam optimizer specialized for sparse gradients: first/second moment
/// estimates are kept per-index in hash maps, so only indices that ever
/// receive a gradient consume state.
#[derive(Debug, Clone)]
pub struct SparseAdam<A: Float + ScalarOperand + Debug> {
    // Step size applied to each parameter update.
    learning_rate: A,
    // Exponential decay rate for the first-moment estimate.
    beta1: A,
    // Exponential decay rate for the second-moment estimate.
    beta2: A,
    // Small constant added to the denominator for numerical stability.
    epsilon: A,
    // L2 penalty coefficient folded into the gradient when > 0.
    weight_decay: A,
    // First-moment (mean) estimates, keyed by parameter index.
    m: HashMap<usize, A>,
    // Second-moment (uncentered variance) estimates, keyed by parameter index.
    v: HashMap<usize, A>,
    // Global timestep, shared by all indices; advanced once per non-empty step.
    t: usize,
}
143
144impl<A: Float + ScalarOperand + Debug + Send + Sync> SparseAdam<A> {
145 pub fn new(learning_rate: A) -> Self {
151 Self {
152 learning_rate,
153 beta1: A::from(0.9).expect("unwrap failed"),
154 beta2: A::from(0.999).expect("unwrap failed"),
155 epsilon: A::from(1e-8).expect("unwrap failed"),
156 weight_decay: A::zero(),
157 m: HashMap::new(),
158 v: HashMap::new(),
159 t: 0,
160 }
161 }
162
163 pub fn new_with_config(
173 learning_rate: A,
174 beta1: A,
175 beta2: A,
176 epsilon: A,
177 weight_decay: A,
178 ) -> Self {
179 Self {
180 learning_rate,
181 beta1,
182 beta2,
183 epsilon,
184 weight_decay,
185 m: HashMap::new(),
186 v: HashMap::new(),
187 t: 0,
188 }
189 }
190
191 pub fn set_beta1(&mut self, beta1: A) -> &mut Self {
193 self.beta1 = beta1;
194 self
195 }
196
197 pub fn with_beta1(mut self, beta1: A) -> Self {
199 self.beta1 = beta1;
200 self
201 }
202
203 pub fn get_beta1(&self) -> A {
205 self.beta1
206 }
207
208 pub fn set_beta2(&mut self, beta2: A) -> &mut Self {
210 self.beta2 = beta2;
211 self
212 }
213
214 pub fn with_beta2(mut self, beta2: A) -> Self {
216 self.beta2 = beta2;
217 self
218 }
219
220 pub fn get_beta2(&self) -> A {
222 self.beta2
223 }
224
225 pub fn set_epsilon(&mut self, epsilon: A) -> &mut Self {
227 self.epsilon = epsilon;
228 self
229 }
230
231 pub fn with_epsilon(mut self, epsilon: A) -> Self {
233 self.epsilon = epsilon;
234 self
235 }
236
237 pub fn get_epsilon(&self) -> A {
239 self.epsilon
240 }
241
242 pub fn set_weight_decay(&mut self, weight_decay: A) -> &mut Self {
244 self.weight_decay = weight_decay;
245 self
246 }
247
248 pub fn with_weight_decay(mut self, weight_decay: A) -> Self {
250 self.weight_decay = weight_decay;
251 self
252 }
253
254 pub fn get_weight_decay(&self) -> A {
256 self.weight_decay
257 }
258
259 pub fn step_sparse(
273 &mut self,
274 params: &Array<A, Ix1>,
275 gradient: &SparseGradient<A>,
276 ) -> Result<Array<A, Ix1>> {
277 if params.len() != gradient.dim {
279 return Err(OptimError::InvalidConfig(format!(
280 "Parameter dimension ({}) doesn't match gradient dimension ({})",
281 params.len(),
282 gradient.dim
283 )));
284 }
285
286 if gradient.is_empty() {
288 return Ok(params.clone());
289 }
290
291 self.t += 1;
293
294 let bias_correction1 = A::one() - self.beta1.powi(self.t as i32);
296 let bias_correction2 = A::one() - self.beta2.powi(self.t as i32);
297
298 let mut updated_params = params.clone();
300
301 for (&idx, &grad_val) in gradient.indices.iter().zip(&gradient.values) {
303 let adjusted_grad = if self.weight_decay > A::zero() {
305 grad_val + params[idx] * self.weight_decay
306 } else {
307 grad_val
308 };
309
310 let m_prev = *self.m.get(&idx).unwrap_or(&A::zero());
312 let m_t = self.beta1 * m_prev + (A::one() - self.beta1) * adjusted_grad;
313 self.m.insert(idx, m_t);
314
315 let v_prev = *self.v.get(&idx).unwrap_or(&A::zero());
317 let v_t = self.beta2 * v_prev + (A::one() - self.beta2) * adjusted_grad * adjusted_grad;
318 self.v.insert(idx, v_t);
319
320 let m_hat = m_t / bias_correction1;
322 let v_hat = v_t / bias_correction2;
323
324 let step = self.learning_rate * m_hat / (v_hat.sqrt() + self.epsilon);
326 updated_params[idx] = params[idx] - step;
327 }
328
329 Ok(updated_params)
330 }
331
332 pub fn reset(&mut self) {
334 self.m.clear();
335 self.v.clear();
336 self.t = 0;
337 }
338}
339
340impl<A> Optimizer<A, Ix1> for SparseAdam<A>
341where
342 A: Float + ScalarOperand + Debug + Send + Sync,
343{
344 fn step(&mut self, params: &Array<A, Ix1>, gradients: &Array<A, Ix1>) -> Result<Array<A, Ix1>> {
345 let sparse_gradient = SparseGradient::from_array(gradients);
347
348 self.step_sparse(params, &sparse_gradient)
350 }
351
352 fn get_learning_rate(&self) -> A {
353 self.learning_rate
354 }
355
356 fn set_learning_rate(&mut self, learning_rate: A) {
357 self.learning_rate = learning_rate;
358 }
359}
360
// NOTE(review): the original test module was corrupted by HTML-entity
// mojibake — every `&params` had collapsed into `¶ms` (`&para` → `¶`),
// which cannot compile. Restored below; test logic is unchanged.
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_sparse_gradient_creation() {
        let indices = vec![0, 2, 4];
        let values = vec![1.0, 2.0, 3.0];
        let dim = 5;

        let sparse_grad = SparseGradient::new(indices, values, dim);

        assert_eq!(sparse_grad.indices, vec![0, 2, 4]);
        assert_eq!(sparse_grad.values, vec![1.0, 2.0, 3.0]);
        assert_eq!(sparse_grad.dim, 5);
    }

    #[test]
    fn test_sparse_gradient_from_array() {
        let dense = Array1::from_vec(vec![1.0, 0.0, 2.0, 0.0, 3.0]);
        let sparse_grad = SparseGradient::from_array(&dense);

        assert_eq!(sparse_grad.indices, vec![0, 2, 4]);
        assert_eq!(sparse_grad.values, vec![1.0, 2.0, 3.0]);
        assert_eq!(sparse_grad.dim, 5);
    }

    #[test]
    fn test_sparse_gradient_to_array() {
        let indices = vec![0, 2, 4];
        let values = vec![1.0, 2.0, 3.0];
        let dim = 5;

        let sparse_grad = SparseGradient::new(indices, values, dim);
        let dense = sparse_grad.to_array();

        let expected = Array1::from_vec(vec![1.0, 0.0, 2.0, 0.0, 3.0]);
        assert_eq!(dense, expected);
    }

    #[test]
    fn test_sparse_adam_creation() {
        let optimizer = SparseAdam::<f64>::new(0.001);

        assert_eq!(optimizer.get_learning_rate(), 0.001);
        assert_eq!(optimizer.get_beta1(), 0.9);
        assert_eq!(optimizer.get_beta2(), 0.999);
        assert_eq!(optimizer.get_epsilon(), 1e-8);
        assert_eq!(optimizer.get_weight_decay(), 0.0);
    }

    #[test]
    fn test_sparse_adam_step() {
        let mut optimizer = SparseAdam::<f64>::new(0.1);

        let params = Array1::from_vec(vec![0.0, 0.0, 0.0, 0.0, 0.0]);

        let sparse_grad = SparseGradient::new(vec![1, 3], vec![0.2, 0.5], 5);

        let updated_params = optimizer
            .step_sparse(&params, &sparse_grad)
            .expect("unwrap failed");

        // Only the touched indices move, and they move against the gradient.
        assert_abs_diff_eq!(updated_params[0], 0.0);
        assert!(updated_params[1] < 0.0);
        assert_abs_diff_eq!(updated_params[2], 0.0);
        assert!(updated_params[3] < 0.0);
        assert_abs_diff_eq!(updated_params[4], 0.0);

        // The larger gradient produces the (slightly) larger first step.
        assert!(updated_params[3].abs() > updated_params[1].abs());
    }

    #[test]
    fn test_sparse_adam_vs_dense_adam() {
        let mut sparse_optimizer = SparseAdam::<f64>::new(0.1);
        let mut dense_optimizer = crate::optimizers::adam::Adam::<f64>::new(0.1);

        let params = Array1::from_vec(vec![0.0, 0.0, 0.0, 0.0, 0.0]);

        let dense_grad = Array1::from_vec(vec![0.0, 0.2, 0.0, 0.5, 0.0]);

        let sparse_grad = SparseGradient::from_array(&dense_grad);

        let sparse_result = sparse_optimizer
            .step_sparse(&params, &sparse_grad)
            .expect("unwrap failed");
        let dense_result = dense_optimizer
            .step(&params, &dense_grad)
            .expect("unwrap failed");

        assert_abs_diff_eq!(sparse_result[0], dense_result[0]);
        assert_abs_diff_eq!(sparse_result[1], dense_result[1], epsilon = 1e-10);
        assert_abs_diff_eq!(sparse_result[2], dense_result[2]);
        assert_abs_diff_eq!(sparse_result[3], dense_result[3], epsilon = 1e-10);
        assert_abs_diff_eq!(sparse_result[4], dense_result[4]);
    }

    #[test]
    fn test_sparse_adam_multiple_steps() {
        let mut optimizer = SparseAdam::<f64>::new(0.1);
        let mut params = Array1::from_vec(vec![0.0, 0.0, 0.0, 0.0, 0.0]);

        let sparse_grad1 = SparseGradient::new(vec![1, 3], vec![0.2, 0.5], 5);

        params = optimizer
            .step_sparse(&params, &sparse_grad1)
            .expect("unwrap failed");

        let sparse_grad2 = SparseGradient::new(vec![0, 2], vec![0.3, 0.4], 5);

        params = optimizer
            .step_sparse(&params, &sparse_grad2)
            .expect("unwrap failed");

        assert!(params[0] < 0.0);
        assert!(params[1] < 0.0);
        assert!(params[2] < 0.0);
        assert!(params[3] < 0.0);
        assert_abs_diff_eq!(params[4], 0.0);

        params = optimizer
            .step_sparse(&params, &sparse_grad2)
            .expect("unwrap failed");

        let prev_param0 = params[0];
        let prev_param2 = params[2];

        params = optimizer
            .step_sparse(&params, &sparse_grad2)
            .expect("unwrap failed");

        // Repeated gradients in the same direction keep growing the update.
        assert!(params[0].abs() > prev_param0.abs());
        assert!(params[2].abs() > prev_param2.abs());
    }

    #[test]
    fn test_sparse_adam_with_weight_decay() {
        let mut optimizer = SparseAdam::<f64>::new(0.1).with_weight_decay(0.01);

        let params = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4, 0.5]);

        let sparse_grad = SparseGradient::new(vec![1, 3], vec![0.2, 0.5], 5);

        let mut optimizer_no_decay = SparseAdam::<f64>::new(0.1);

        let with_decay = optimizer
            .step_sparse(&params, &sparse_grad)
            .expect("unwrap failed");
        let without_decay = optimizer_no_decay
            .step_sparse(&params, &sparse_grad)
            .expect("unwrap failed");

        // Weight decay alters the touched indices...
        assert!(with_decay[1] != without_decay[1]);
        assert!(with_decay[3] != without_decay[3]);

        // ...but never the untouched ones, even with non-zero parameters.
        assert_abs_diff_eq!(with_decay[0], params[0]);
        assert_abs_diff_eq!(with_decay[2], params[2]);
        assert_abs_diff_eq!(with_decay[4], params[4]);
    }

    #[test]
    fn test_sparse_adam_empty_gradient() {
        let mut optimizer = SparseAdam::<f64>::new(0.1);

        let params = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4, 0.5]);

        let sparse_grad = SparseGradient::new(vec![], vec![], 5);

        let result = optimizer
            .step_sparse(&params, &sparse_grad)
            .expect("unwrap failed");
        assert_eq!(result, params);
    }

    #[test]
    fn test_sparse_adam_reset() {
        let mut optimizer = SparseAdam::<f64>::new(0.1);

        let params = Array1::from_vec(vec![0.0; 5]);

        let sparse_grad = SparseGradient::new(vec![1, 3], vec![0.2, 0.5], 5);

        for _ in 0..10 {
            optimizer
                .step_sparse(&params, &sparse_grad)
                .expect("unwrap failed");
        }

        optimizer.reset();

        // After reset, the optimizer must behave like a brand-new one.
        let mut new_optimizer = SparseAdam::<f64>::new(0.1);
        let reset_result = optimizer
            .step_sparse(&params, &sparse_grad)
            .expect("unwrap failed");
        let new_result = new_optimizer
            .step_sparse(&params, &sparse_grad)
            .expect("unwrap failed");

        assert_abs_diff_eq!(reset_result[1], new_result[1], epsilon = 1e-10);
        assert_abs_diff_eq!(reset_result[3], new_result[3], epsilon = 1e-10);
    }
}