use scirs2_core::ndarray::{Array, Ix1, ScalarOperand};
use scirs2_core::numeric::Float;
use std::collections::HashMap;
use std::fmt::Debug;

use crate::error::{OptimError, Result};
use crate::optimizers::Optimizer;

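/// A sparse gradient represented as parallel `indices`/`values` vectors plus
/// the dense dimension, so only the non-zero entries of a gradient need to be
/// stored and processed.
///
/// A minimal usage sketch (imports elided, so the block is not compiled as a
/// doctest):
///
/// ```ignore
/// // Only positions 0 and 3 of a 5-dimensional gradient are non-zero.
/// let grad = SparseGradient::new(vec![0, 3], vec![1.0_f64, -2.0], 5);
/// assert_eq!(grad.to_array(), Array1::from_vec(vec![1.0, 0.0, 0.0, -2.0, 0.0]));
/// ```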
pub struct SparseGradient<A: Float + ScalarOperand + Debug> {
    /// Indices of the non-zero gradient entries.
    pub indices: Vec<usize>,
    /// Gradient values corresponding to `indices`.
    pub values: Vec<A>,
    /// Dimensionality of the full (dense) gradient.
    pub dim: usize,
}

impl<A: Float + ScalarOperand + Debug + Send + Sync> SparseGradient<A> {
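    /// Creates a sparse gradient from explicit `indices` and `values`.
    ///
    /// # Panics
    ///
    /// Panics if `indices` and `values` differ in length, or if any index is
    /// out of bounds for `dim`.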
    pub fn new(indices: Vec<usize>, values: Vec<A>, dim: usize) -> Self {
        assert_eq!(
            indices.len(),
            values.len(),
            "Indices and values must have the same length"
        );
        // Validate that every index fits within the declared dimension.
        if let Some(&max_idx) = indices.iter().max() {
            assert!(
                max_idx < dim,
                "Index {} is out of bounds for dimension {}",
                max_idx,
                dim
            );
        }
        Self {
            indices,
            values,
            dim,
        }
    }

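    /// Builds a sparse gradient from a dense 1-D array, keeping only the
    /// entries that are non-zero.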
    pub fn from_array(array: &Array<A, Ix1>) -> Self {
        let mut indices = Vec::new();
        let mut values = Vec::new();

        for (idx, &val) in array.iter().enumerate() {
            if !val.is_zero() {
                indices.push(idx);
                values.push(val);
            }
        }

        Self {
            indices,
            values,
            dim: array.len(),
        }
    }

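    /// Converts the sparse gradient back to a dense 1-D array of length
    /// `self.dim`, with zeros at all unspecified positions.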
    pub fn to_array(&self) -> Array<A, Ix1> {
        let mut array = Array::zeros(self.dim);
        for (&idx, &val) in self.indices.iter().zip(&self.values) {
            array[idx] = val;
        }
        array
    }

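    /// Returns `true` if the gradient has no non-zero entries.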
    pub fn is_empty(&self) -> bool {
        self.indices.is_empty()
    }
}

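/// Adam optimizer specialized for sparse gradients.
///
/// First and second moment estimates are stored in `HashMap`s keyed by
/// parameter index, so optimizer state only exists for coordinates that have
/// actually received gradient updates.
///
/// A minimal usage sketch (imports elided, so the block is not compiled as a
/// doctest):
///
/// ```ignore
/// let mut optimizer = SparseAdam::<f64>::new(0.01).with_weight_decay(0.001);
/// let params = Array1::zeros(5);
/// let grad = SparseGradient::new(vec![1, 3], vec![0.2, 0.5], 5);
/// let updated = optimizer.step_sparse(&params, &grad).unwrap();
/// ```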
#[derive(Debug, Clone)]
pub struct SparseAdam<A: Float + ScalarOperand + Debug> {
    /// Learning rate.
    learning_rate: A,
    /// Exponential decay rate for the first moment estimates.
    beta1: A,
    /// Exponential decay rate for the second moment estimates.
    beta2: A,
    /// Small constant added to the denominator for numerical stability.
    epsilon: A,
    /// L2 weight decay coefficient (zero disables weight decay).
    weight_decay: A,
    /// First moment estimates, keyed by parameter index.
    m: HashMap<usize, A>,
    /// Second moment estimates, keyed by parameter index.
    v: HashMap<usize, A>,
    /// Number of optimization steps taken so far.
    t: usize,
}

impl<A: Float + ScalarOperand + Debug + Send + Sync> SparseAdam<A> {
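    /// Creates a `SparseAdam` optimizer with the given learning rate and the
    /// usual Adam defaults: `beta1 = 0.9`, `beta2 = 0.999`, `epsilon = 1e-8`,
    /// and no weight decay.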
    pub fn new(learning_rate: A) -> Self {
        Self {
            learning_rate,
            beta1: A::from(0.9).unwrap(),
            beta2: A::from(0.999).unwrap(),
            epsilon: A::from(1e-8).unwrap(),
            weight_decay: A::zero(),
            m: HashMap::new(),
            v: HashMap::new(),
            t: 0,
        }
    }

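    /// Creates a `SparseAdam` optimizer with explicitly specified
    /// hyperparameters.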
    pub fn new_with_config(
        learning_rate: A,
        beta1: A,
        beta2: A,
        epsilon: A,
        weight_decay: A,
    ) -> Self {
        Self {
            learning_rate,
            beta1,
            beta2,
            epsilon,
            weight_decay,
            m: HashMap::new(),
            v: HashMap::new(),
            t: 0,
        }
    }

    /// Sets `beta1` in place.
    pub fn set_beta1(&mut self, beta1: A) -> &mut Self {
        self.beta1 = beta1;
        self
    }

    /// Builder-style setter for `beta1`.
    pub fn with_beta1(mut self, beta1: A) -> Self {
        self.beta1 = beta1;
        self
    }

    /// Returns the current `beta1`.
    pub fn get_beta1(&self) -> A {
        self.beta1
    }

    /// Sets `beta2` in place.
    pub fn set_beta2(&mut self, beta2: A) -> &mut Self {
        self.beta2 = beta2;
        self
    }

    /// Builder-style setter for `beta2`.
    pub fn with_beta2(mut self, beta2: A) -> Self {
        self.beta2 = beta2;
        self
    }

    /// Returns the current `beta2`.
    pub fn get_beta2(&self) -> A {
        self.beta2
    }

    /// Sets `epsilon` in place.
    pub fn set_epsilon(&mut self, epsilon: A) -> &mut Self {
        self.epsilon = epsilon;
        self
    }

    /// Builder-style setter for `epsilon`.
    pub fn with_epsilon(mut self, epsilon: A) -> Self {
        self.epsilon = epsilon;
        self
    }

    /// Returns the current `epsilon`.
    pub fn get_epsilon(&self) -> A {
        self.epsilon
    }

    /// Sets the weight decay coefficient in place.
    pub fn set_weight_decay(&mut self, weight_decay: A) -> &mut Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Builder-style setter for the weight decay coefficient.
    pub fn with_weight_decay(mut self, weight_decay: A) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Returns the current weight decay coefficient.
    pub fn get_weight_decay(&self) -> A {
        self.weight_decay
    }

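    /// Performs a single Adam update restricted to the indices present in
    /// `gradient`; all other parameters (and their optimizer state) are left
    /// untouched.
    ///
    /// For each touched index with (optionally weight-decayed) gradient `g` at
    /// step `t`, the standard Adam recurrences are applied:
    ///
    /// ```text
    /// m_t     = beta1 * m_{t-1} + (1 - beta1) * g
    /// v_t     = beta2 * v_{t-1} + (1 - beta2) * g^2
    /// m_hat   = m_t / (1 - beta1^t)
    /// v_hat   = v_t / (1 - beta2^t)
    /// theta_t = theta_{t-1} - learning_rate * m_hat / (sqrt(v_hat) + epsilon)
    /// ```
    ///
    /// # Errors
    ///
    /// Returns `OptimError::InvalidConfig` if `params.len() != gradient.dim`.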
    pub fn step_sparse(
        &mut self,
        params: &Array<A, Ix1>,
        gradient: &SparseGradient<A>,
    ) -> Result<Array<A, Ix1>> {
        if params.len() != gradient.dim {
            return Err(OptimError::InvalidConfig(format!(
                "Parameter dimension ({}) doesn't match gradient dimension ({})",
                params.len(),
                gradient.dim
            )));
        }

        // Nothing to do for an empty gradient.
        if gradient.is_empty() {
            return Ok(params.clone());
        }

        self.t += 1;

        let bias_correction1 = A::one() - self.beta1.powi(self.t as i32);
        let bias_correction2 = A::one() - self.beta2.powi(self.t as i32);

        let mut updated_params = params.clone();

        // Update only the coordinates that appear in the sparse gradient.
        for (&idx, &grad_val) in gradient.indices.iter().zip(&gradient.values) {
            // Apply L2 weight decay to the gradient if enabled.
            let adjusted_grad = if self.weight_decay > A::zero() {
                grad_val + params[idx] * self.weight_decay
            } else {
                grad_val
            };

            // First moment (exponentially decayed mean of gradients).
            let m_prev = *self.m.get(&idx).unwrap_or(&A::zero());
            let m_t = self.beta1 * m_prev + (A::one() - self.beta1) * adjusted_grad;
            self.m.insert(idx, m_t);

            // Second moment (exponentially decayed mean of squared gradients).
            let v_prev = *self.v.get(&idx).unwrap_or(&A::zero());
            let v_t = self.beta2 * v_prev + (A::one() - self.beta2) * adjusted_grad * adjusted_grad;
            self.v.insert(idx, v_t);

            // Bias-corrected moment estimates.
            let m_hat = m_t / bias_correction1;
            let v_hat = v_t / bias_correction2;

            let step = self.learning_rate * m_hat / (v_hat.sqrt() + self.epsilon);
            updated_params[idx] = params[idx] - step;
        }

        Ok(updated_params)
    }

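    /// Clears all accumulated moment estimates and resets the step counter,
    /// as if the optimizer had just been constructed.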
    pub fn reset(&mut self) {
        self.m.clear();
        self.v.clear();
        self.t = 0;
    }
}

impl<A> Optimizer<A, Ix1> for SparseAdam<A>
where
    A: Float + ScalarOperand + Debug + Send + Sync,
{
    fn step(&mut self, params: &Array<A, Ix1>, gradients: &Array<A, Ix1>) -> Result<Array<A, Ix1>> {
        // Convert the dense gradient to a sparse representation and reuse the
        // sparse update path.
        let sparse_gradient = SparseGradient::from_array(gradients);

        self.step_sparse(params, &sparse_gradient)
    }

    fn get_learning_rate(&self) -> A {
        self.learning_rate
    }

    fn set_learning_rate(&mut self, learning_rate: A) {
        self.learning_rate = learning_rate;
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_sparse_gradient_creation() {
        let indices = vec![0, 2, 4];
        let values = vec![1.0, 2.0, 3.0];
        let dim = 5;

        let sparse_grad = SparseGradient::new(indices, values, dim);

        assert_eq!(sparse_grad.indices, vec![0, 2, 4]);
        assert_eq!(sparse_grad.values, vec![1.0, 2.0, 3.0]);
        assert_eq!(sparse_grad.dim, 5);
    }

    #[test]
    fn test_sparse_gradient_from_array() {
        let dense = Array1::from_vec(vec![1.0, 0.0, 2.0, 0.0, 3.0]);
        let sparse_grad = SparseGradient::from_array(&dense);

        assert_eq!(sparse_grad.indices, vec![0, 2, 4]);
        assert_eq!(sparse_grad.values, vec![1.0, 2.0, 3.0]);
        assert_eq!(sparse_grad.dim, 5);
    }

    #[test]
    fn test_sparse_gradient_to_array() {
        let indices = vec![0, 2, 4];
        let values = vec![1.0, 2.0, 3.0];
        let dim = 5;

        let sparse_grad = SparseGradient::new(indices, values, dim);
        let dense = sparse_grad.to_array();

        let expected = Array1::from_vec(vec![1.0, 0.0, 2.0, 0.0, 3.0]);
        assert_eq!(dense, expected);
    }

    #[test]
    fn test_sparse_adam_creation() {
        let optimizer = SparseAdam::<f64>::new(0.001);

        assert_eq!(optimizer.get_learning_rate(), 0.001);
        assert_eq!(optimizer.get_beta1(), 0.9);
        assert_eq!(optimizer.get_beta2(), 0.999);
        assert_eq!(optimizer.get_epsilon(), 1e-8);
        assert_eq!(optimizer.get_weight_decay(), 0.0);
    }

    #[test]
    fn test_sparse_adam_step() {
        let mut optimizer = SparseAdam::<f64>::new(0.1);

        let params = Array1::from_vec(vec![0.0, 0.0, 0.0, 0.0, 0.0]);

        let sparse_grad = SparseGradient::new(vec![1, 3], vec![0.2, 0.5], 5);

        let updated_params = optimizer.step_sparse(&params, &sparse_grad).unwrap();

        // Only the indices present in the sparse gradient should move.
        assert_abs_diff_eq!(updated_params[0], 0.0);
        assert!(updated_params[1] < 0.0);
        assert_abs_diff_eq!(updated_params[2], 0.0);
        assert!(updated_params[3] < 0.0);
        assert_abs_diff_eq!(updated_params[4], 0.0);

        // The larger gradient should produce the (slightly) larger update.
        assert!(updated_params[3].abs() > updated_params[1].abs());
    }

    #[test]
    fn test_sparse_adam_vs_dense_adam() {
        let mut sparse_optimizer = SparseAdam::<f64>::new(0.1);
        let mut dense_optimizer = crate::optimizers::adam::Adam::<f64>::new(0.1);

        let params = Array1::from_vec(vec![0.0, 0.0, 0.0, 0.0, 0.0]);

        let dense_grad = Array1::from_vec(vec![0.0, 0.2, 0.0, 0.5, 0.0]);

        let sparse_grad = SparseGradient::from_array(&dense_grad);

        let sparse_result = sparse_optimizer.step_sparse(&params, &sparse_grad).unwrap();
        let dense_result = dense_optimizer.step(&params, &dense_grad).unwrap();

        // Sparse and dense Adam should agree on every coordinate.
        assert_abs_diff_eq!(sparse_result[0], dense_result[0]);
        assert_abs_diff_eq!(sparse_result[1], dense_result[1], epsilon = 1e-10);
        assert_abs_diff_eq!(sparse_result[2], dense_result[2]);
        assert_abs_diff_eq!(sparse_result[3], dense_result[3], epsilon = 1e-10);
        assert_abs_diff_eq!(sparse_result[4], dense_result[4]);
    }

    #[test]
    fn test_sparse_adam_multiple_steps() {
        let mut optimizer = SparseAdam::<f64>::new(0.1);
        let mut params = Array1::from_vec(vec![0.0, 0.0, 0.0, 0.0, 0.0]);

        let sparse_grad1 = SparseGradient::new(vec![1, 3], vec![0.2, 0.5], 5);

        params = optimizer.step_sparse(&params, &sparse_grad1).unwrap();

        let sparse_grad2 = SparseGradient::new(vec![0, 2], vec![0.3, 0.4], 5);

        params = optimizer.step_sparse(&params, &sparse_grad2).unwrap();

        // Every index touched by either gradient should have moved.
        assert!(params[0] < 0.0);
        assert!(params[1] < 0.0);
        assert!(params[2] < 0.0);
        assert!(params[3] < 0.0);
        assert_abs_diff_eq!(params[4], 0.0);

        params = optimizer.step_sparse(&params, &sparse_grad2).unwrap();

        let prev_param0 = params[0];
        let prev_param2 = params[2];

        params = optimizer.step_sparse(&params, &sparse_grad2).unwrap();

        // Repeated updates in the same direction keep increasing the magnitude.
        assert!(params[0].abs() > prev_param0.abs());
        assert!(params[2].abs() > prev_param2.abs());
    }

    #[test]
    fn test_sparse_adam_with_weight_decay() {
        let mut optimizer = SparseAdam::<f64>::new(0.1).with_weight_decay(0.01);

        let params = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4, 0.5]);

        let sparse_grad = SparseGradient::new(vec![1, 3], vec![0.2, 0.5], 5);

        let mut optimizer_no_decay = SparseAdam::<f64>::new(0.1);

        let with_decay = optimizer.step_sparse(&params, &sparse_grad).unwrap();
        let without_decay = optimizer_no_decay
            .step_sparse(&params, &sparse_grad)
            .unwrap();

        // Weight decay changes the update for the touched indices.
        assert!(with_decay[1] != without_decay[1]);
        assert!(with_decay[3] != without_decay[3]);

        // Untouched indices are left exactly as they were.
        assert_abs_diff_eq!(with_decay[0], params[0]);
        assert_abs_diff_eq!(with_decay[2], params[2]);
        assert_abs_diff_eq!(with_decay[4], params[4]);
    }

    #[test]
    fn test_sparse_adam_empty_gradient() {
        let mut optimizer = SparseAdam::<f64>::new(0.1);

        let params = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4, 0.5]);

        let sparse_grad = SparseGradient::new(vec![], vec![], 5);

        // An empty gradient must leave the parameters unchanged.
        let result = optimizer.step_sparse(&params, &sparse_grad).unwrap();
        assert_eq!(result, params);
    }

    #[test]
    fn test_sparse_adam_reset() {
        let mut optimizer = SparseAdam::<f64>::new(0.1);

        let params = Array1::from_vec(vec![0.0; 5]);

        let sparse_grad = SparseGradient::new(vec![1, 3], vec![0.2, 0.5], 5);

        // Accumulate some optimizer state.
        for _ in 0..10 {
            optimizer.step_sparse(&params, &sparse_grad).unwrap();
        }

        optimizer.reset();

        // After a reset, the optimizer should behave like a freshly constructed one.
        let mut new_optimizer = SparseAdam::<f64>::new(0.1);
        let reset_result = optimizer.step_sparse(&params, &sparse_grad).unwrap();
        let new_result = new_optimizer.step_sparse(&params, &sparse_grad).unwrap();

        assert_abs_diff_eq!(reset_result[1], new_result[1], epsilon = 1e-10);
        assert_abs_diff_eq!(reset_result[3], new_result[3], epsilon = 1e-10);
    }
}