// optirs_core/optimizers/sparse_adam.rs
1// SparseAdam optimizer implementation for sparse gradients
2//
3// This module implements a variant of the Adam optimizer that efficiently
4// handles sparse gradients by only updating the parameters and moments
5// for indices that have non-zero gradients.
6
7use scirs2_core::ndarray::{Array, Ix1, ScalarOperand};
8use scirs2_core::numeric::Float;
9use std::collections::HashMap;
10use std::fmt::Debug;
11
12use crate::error::{OptimError, Result};
13use crate::optimizers::Optimizer;
14
15/// A struct representing a sparse gradient with indices and values
16///
17/// This provides a convenient interface for working with sparse gradients
18/// where most elements are zero.
19pub struct SparseGradient<A: Float + ScalarOperand + Debug> {
20    /// The indices of non-zero elements
21    pub indices: Vec<usize>,
22    /// The values at the non-zero elements
23    pub values: Vec<A>,
24    /// The total dimension of the gradient (including zero elements)
25    pub dim: usize,
26}
27
28impl<A: Float + ScalarOperand + Debug + Send + Sync> SparseGradient<A> {
29    /// Create a new sparse gradient from indices and values
30    pub fn new(indices: Vec<usize>, values: Vec<A>, dim: usize) -> Self {
31        assert_eq!(
32            indices.len(),
33            values.len(),
34            "Indices and values must have the same length"
35        );
36        // Ensure no index is out of bounds
37        if let Some(&max_idx) = indices.iter().max() {
38            assert!(
39                max_idx < dim,
40                "Index {} is out of bounds for dimension {}",
41                max_idx,
42                dim
43            );
44        }
45        Self {
46            indices,
47            values,
48            dim,
49        }
50    }
51
52    /// Create a sparse gradient from a dense array, keeping only non-zero entries
53    pub fn from_array(array: &Array<A, Ix1>) -> Self {
54        let mut indices = Vec::new();
55        let mut values = Vec::new();
56
57        for (idx, &val) in array.iter().enumerate() {
58            if !val.is_zero() {
59                indices.push(idx);
60                values.push(val);
61            }
62        }
63
64        Self {
65            indices,
66            values,
67            dim: array.len(),
68        }
69    }
70
71    /// Convert the sparse gradient to a dense array
72    pub fn to_array(&self) -> Array<A, Ix1> {
73        let mut array = Array::zeros(self.dim);
74        for (&idx, &val) in self.indices.iter().zip(&self.values) {
75            array[idx] = val;
76        }
77        array
78    }
79
80    /// Check if this sparse gradient is empty (all zeros)
81    pub fn is_empty(&self) -> bool {
82        self.indices.is_empty()
83    }
84}
85
86/// SparseAdam optimizer for sparse gradients
87///
88/// Implements a variant of the Adam optimization algorithm that's optimized for
89/// sparse gradients. It only updates the parameters and momentum vectors
90/// for indices that have non-zero gradients, saving computation and memory.
91///
92/// This optimizer is particularly useful for large embedding layers or
93/// models with sparse input features.
94///
95/// Formula (for non-zero gradient indices):
96/// m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
97/// v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
98/// m_hat_t = m_t / (1 - beta1^t)
99/// v_hat_t = v_t / (1 - beta2^t)
100/// theta_t = theta_{t-1} - alpha * m_hat_t / (sqrt(v_hat_t) + epsilon)
101///
102/// # Examples
103///
104/// ```
105/// use scirs2_core::ndarray::Array1;
106/// use optirs_core::optimizers::{SparseAdam, SparseGradient, Optimizer};
107///
108/// // Initialize parameters
109/// let params = Array1::zeros(5);
110///
111/// // Create sparse gradient with non-zero values at indices 1 and 3
112/// let sparse_grad = SparseGradient::new(
113///     vec![1, 3],             // Indices
114///     vec![0.2, 0.5],         // Values
115///     5                       // Total dimension
116/// );
117///
118/// // Create a SparseAdam optimizer
119/// let mut optimizer = SparseAdam::new(0.001);
120///
121/// // Update parameters with sparse gradient
122/// let new_params = optimizer.step_sparse(&params, &sparse_grad).expect("unwrap failed");
123/// ```
#[derive(Debug, Clone)]
pub struct SparseAdam<A: Float + ScalarOperand + Debug> {
    /// Learning rate (step size `alpha` in the Adam update rule)
    learning_rate: A,
    /// Exponential decay rate for the first moment estimates
    beta1: A,
    /// Exponential decay rate for the second moment estimates
    beta2: A,
    /// Small constant for numerical stability
    epsilon: A,
    /// Weight decay factor (L2 regularization); applied only when > 0
    weight_decay: A,
    /// First moment vector stored as a hash map for sparse updates
    /// (keyed by parameter index; absent keys are implicitly zero)
    m: HashMap<usize, A>,
    /// Second moment vector stored as a hash map for sparse updates
    /// (keyed by parameter index; absent keys are implicitly zero)
    v: HashMap<usize, A>,
    /// Current timestep (number of `step_sparse` calls with a non-empty
    /// gradient); drives the bias-correction terms
    t: usize,
}
143
144impl<A: Float + ScalarOperand + Debug + Send + Sync> SparseAdam<A> {
145    /// Creates a new SparseAdam optimizer with the given learning rate and default settings
146    ///
147    /// # Arguments
148    ///
149    /// * `learning_rate` - The learning rate for parameter updates
150    pub fn new(learning_rate: A) -> Self {
151        Self {
152            learning_rate,
153            beta1: A::from(0.9).expect("unwrap failed"),
154            beta2: A::from(0.999).expect("unwrap failed"),
155            epsilon: A::from(1e-8).expect("unwrap failed"),
156            weight_decay: A::zero(),
157            m: HashMap::new(),
158            v: HashMap::new(),
159            t: 0,
160        }
161    }
162
163    /// Creates a new SparseAdam optimizer with the full configuration
164    ///
165    /// # Arguments
166    ///
167    /// * `learning_rate` - The learning rate for parameter updates
168    /// * `beta1` - Exponential decay rate for the first moment estimates (default: 0.9)
169    /// * `beta2` - Exponential decay rate for the second moment estimates (default: 0.999)
170    /// * `epsilon` - Small constant for numerical stability (default: 1e-8)
171    /// * `weight_decay` - Weight decay factor for L2 regularization (default: 0.0)
172    pub fn new_with_config(
173        learning_rate: A,
174        beta1: A,
175        beta2: A,
176        epsilon: A,
177        weight_decay: A,
178    ) -> Self {
179        Self {
180            learning_rate,
181            beta1,
182            beta2,
183            epsilon,
184            weight_decay,
185            m: HashMap::new(),
186            v: HashMap::new(),
187            t: 0,
188        }
189    }
190
191    /// Sets the beta1 parameter
192    pub fn set_beta1(&mut self, beta1: A) -> &mut Self {
193        self.beta1 = beta1;
194        self
195    }
196
197    /// Builder method to set beta1 and return self
198    pub fn with_beta1(mut self, beta1: A) -> Self {
199        self.beta1 = beta1;
200        self
201    }
202
203    /// Gets the beta1 parameter
204    pub fn get_beta1(&self) -> A {
205        self.beta1
206    }
207
208    /// Sets the beta2 parameter
209    pub fn set_beta2(&mut self, beta2: A) -> &mut Self {
210        self.beta2 = beta2;
211        self
212    }
213
214    /// Builder method to set beta2 and return self
215    pub fn with_beta2(mut self, beta2: A) -> Self {
216        self.beta2 = beta2;
217        self
218    }
219
220    /// Gets the beta2 parameter
221    pub fn get_beta2(&self) -> A {
222        self.beta2
223    }
224
225    /// Sets the epsilon parameter
226    pub fn set_epsilon(&mut self, epsilon: A) -> &mut Self {
227        self.epsilon = epsilon;
228        self
229    }
230
231    /// Builder method to set epsilon and return self
232    pub fn with_epsilon(mut self, epsilon: A) -> Self {
233        self.epsilon = epsilon;
234        self
235    }
236
237    /// Gets the epsilon parameter
238    pub fn get_epsilon(&self) -> A {
239        self.epsilon
240    }
241
242    /// Sets the weight decay parameter
243    pub fn set_weight_decay(&mut self, weight_decay: A) -> &mut Self {
244        self.weight_decay = weight_decay;
245        self
246    }
247
248    /// Builder method to set weight decay and return self
249    pub fn with_weight_decay(mut self, weight_decay: A) -> Self {
250        self.weight_decay = weight_decay;
251        self
252    }
253
254    /// Gets the weight decay parameter
255    pub fn get_weight_decay(&self) -> A {
256        self.weight_decay
257    }
258
259    /// Updates parameters using sparse gradients
260    ///
261    /// This method efficiently updates only the parameters corresponding to
262    /// non-zero gradient entries, saving computation and memory.
263    ///
264    /// # Arguments
265    ///
266    /// * `params` - The parameters to update
267    /// * `gradient` - The sparse gradient
268    ///
269    /// # Returns
270    ///
271    /// The updated parameters
272    pub fn step_sparse(
273        &mut self,
274        params: &Array<A, Ix1>,
275        gradient: &SparseGradient<A>,
276    ) -> Result<Array<A, Ix1>> {
277        // Verify dimensions match
278        if params.len() != gradient.dim {
279            return Err(OptimError::InvalidConfig(format!(
280                "Parameter dimension ({}) doesn't match gradient dimension ({})",
281                params.len(),
282                gradient.dim
283            )));
284        }
285
286        // If gradient is empty, just return the parameters unchanged
287        if gradient.is_empty() {
288            return Ok(params.clone());
289        }
290
291        // Increment timestep
292        self.t += 1;
293
294        // Compute the bias correction terms
295        let bias_correction1 = A::one() - self.beta1.powi(self.t as i32);
296        let bias_correction2 = A::one() - self.beta2.powi(self.t as i32);
297
298        // Create a copy of the parameters that we'll update
299        let mut updated_params = params.clone();
300
301        // Update only the parameters for which we have non-zero gradients
302        for (&idx, &grad_val) in gradient.indices.iter().zip(&gradient.values) {
303            // Apply weight decay if needed
304            let adjusted_grad = if self.weight_decay > A::zero() {
305                grad_val + params[idx] * self.weight_decay
306            } else {
307                grad_val
308            };
309
310            // Update first moment (m)
311            let m_prev = *self.m.get(&idx).unwrap_or(&A::zero());
312            let m_t = self.beta1 * m_prev + (A::one() - self.beta1) * adjusted_grad;
313            self.m.insert(idx, m_t);
314
315            // Update second moment (v)
316            let v_prev = *self.v.get(&idx).unwrap_or(&A::zero());
317            let v_t = self.beta2 * v_prev + (A::one() - self.beta2) * adjusted_grad * adjusted_grad;
318            self.v.insert(idx, v_t);
319
320            // Bias-corrected first and second moment estimates
321            let m_hat = m_t / bias_correction1;
322            let v_hat = v_t / bias_correction2;
323
324            // Update parameter
325            let step = self.learning_rate * m_hat / (v_hat.sqrt() + self.epsilon);
326            updated_params[idx] = params[idx] - step;
327        }
328
329        Ok(updated_params)
330    }
331
332    /// Resets the internal state of the optimizer
333    pub fn reset(&mut self) {
334        self.m.clear();
335        self.v.clear();
336        self.t = 0;
337    }
338}
339
340impl<A> Optimizer<A, Ix1> for SparseAdam<A>
341where
342    A: Float + ScalarOperand + Debug + Send + Sync,
343{
344    fn step(&mut self, params: &Array<A, Ix1>, gradients: &Array<A, Ix1>) -> Result<Array<A, Ix1>> {
345        // Convert dense gradient to sparse
346        let sparse_gradient = SparseGradient::from_array(gradients);
347
348        // Call step_sparse with the converted gradient
349        self.step_sparse(params, &sparse_gradient)
350    }
351
352    fn get_learning_rate(&self) -> A {
353        self.learning_rate
354    }
355
356    fn set_learning_rate(&mut self, learning_rate: A) {
357        self.learning_rate = learning_rate;
358    }
359}
360
#[cfg(test)]
mod tests {
    // Unit tests covering SparseGradient construction/conversion and the
    // SparseAdam update rule: single step, equivalence with dense Adam,
    // multi-step momentum, weight decay, empty gradients, and reset.
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array1;

    // Constructor should store indices/values/dim verbatim.
    #[test]
    fn test_sparse_gradient_creation() {
        let indices = vec![0, 2, 4];
        let values = vec![1.0, 2.0, 3.0];
        let dim = 5;

        let sparse_grad = SparseGradient::new(indices, values, dim);

        assert_eq!(sparse_grad.indices, vec![0, 2, 4]);
        assert_eq!(sparse_grad.values, vec![1.0, 2.0, 3.0]);
        assert_eq!(sparse_grad.dim, 5);
    }

    // Dense -> sparse conversion keeps only the non-zero entries.
    #[test]
    fn test_sparse_gradient_from_array() {
        let dense = Array1::from_vec(vec![1.0, 0.0, 2.0, 0.0, 3.0]);
        let sparse_grad = SparseGradient::from_array(&dense);

        assert_eq!(sparse_grad.indices, vec![0, 2, 4]);
        assert_eq!(sparse_grad.values, vec![1.0, 2.0, 3.0]);
        assert_eq!(sparse_grad.dim, 5);
    }

    // Sparse -> dense conversion round-trips, with zeros at missing indices.
    #[test]
    fn test_sparse_gradient_to_array() {
        let indices = vec![0, 2, 4];
        let values = vec![1.0, 2.0, 3.0];
        let dim = 5;

        let sparse_grad = SparseGradient::new(indices, values, dim);
        let dense = sparse_grad.to_array();

        let expected = Array1::from_vec(vec![1.0, 0.0, 2.0, 0.0, 3.0]);
        assert_eq!(dense, expected);
    }

    // `new` must install the documented Adam defaults.
    #[test]
    fn test_sparse_adam_creation() {
        let optimizer = SparseAdam::<f64>::new(0.001);

        assert_eq!(optimizer.get_learning_rate(), 0.001);
        assert_eq!(optimizer.get_beta1(), 0.9);
        assert_eq!(optimizer.get_beta2(), 0.999);
        assert_eq!(optimizer.get_epsilon(), 1e-8);
        assert_eq!(optimizer.get_weight_decay(), 0.0);
    }

    #[test]
    fn test_sparse_adam_step() {
        let mut optimizer = SparseAdam::<f64>::new(0.1);

        // Initialize parameters
        let params = Array1::from_vec(vec![0.0, 0.0, 0.0, 0.0, 0.0]);

        // Create sparse gradient with non-zero values at indices 1 and 3
        let sparse_grad = SparseGradient::new(
            vec![1, 3],     // Indices
            vec![0.2, 0.5], // Values
            5,              // Total dimension
        );

        // First update
        let updated_params = optimizer
            .step_sparse(&params, &sparse_grad)
            .expect("unwrap failed");

        // Only the parameters at indices 1 and 3 should be updated
        assert_abs_diff_eq!(updated_params[0], 0.0);
        assert!(updated_params[1] < 0.0); // Should be negative due to gradient descent
        assert_abs_diff_eq!(updated_params[2], 0.0);
        assert!(updated_params[3] < 0.0); // Should be negative due to gradient descent
        assert_abs_diff_eq!(updated_params[4], 0.0);

        // The parameter at index 3 should have a larger update due to larger gradient
        assert!(updated_params[3].abs() > updated_params[1].abs());
    }

    // A sparse step must match a dense Adam step on the same gradient.
    #[test]
    fn test_sparse_adam_vs_dense_adam() {
        let mut sparse_optimizer = SparseAdam::<f64>::new(0.1);
        let mut dense_optimizer = crate::optimizers::adam::Adam::<f64>::new(0.1);

        // Initialize parameters
        let params = Array1::from_vec(vec![0.0, 0.0, 0.0, 0.0, 0.0]);

        // Create a dense gradient with only some non-zero values
        let dense_grad = Array1::from_vec(vec![0.0, 0.2, 0.0, 0.5, 0.0]);

        // Create equivalent sparse gradient
        let sparse_grad = SparseGradient::from_array(&dense_grad);

        // Update with both optimizers
        let sparse_result = sparse_optimizer
            .step_sparse(&params, &sparse_grad)
            .expect("unwrap failed");
        let dense_result = dense_optimizer
            .step(&params, &dense_grad)
            .expect("unwrap failed");

        // Results should be nearly identical
        assert_abs_diff_eq!(sparse_result[0], dense_result[0]);
        assert_abs_diff_eq!(sparse_result[1], dense_result[1], epsilon = 1e-10);
        assert_abs_diff_eq!(sparse_result[2], dense_result[2]);
        assert_abs_diff_eq!(sparse_result[3], dense_result[3], epsilon = 1e-10);
        assert_abs_diff_eq!(sparse_result[4], dense_result[4]);
    }

    // Moments persist across steps and accumulate per index.
    #[test]
    fn test_sparse_adam_multiple_steps() {
        let mut optimizer = SparseAdam::<f64>::new(0.1);
        let mut params = Array1::from_vec(vec![0.0, 0.0, 0.0, 0.0, 0.0]);

        // First step - update indices 1 and 3
        let sparse_grad1 = SparseGradient::new(
            vec![1, 3],     // Indices
            vec![0.2, 0.5], // Values
            5,              // Total dimension
        );

        params = optimizer
            .step_sparse(&params, &sparse_grad1)
            .expect("unwrap failed");

        // Second step - update indices 0 and 2
        let sparse_grad2 = SparseGradient::new(
            vec![0, 2],     // Indices
            vec![0.3, 0.4], // Values
            5,              // Total dimension
        );

        params = optimizer
            .step_sparse(&params, &sparse_grad2)
            .expect("unwrap failed");

        // All parameters except index 4 should now be updated
        assert!(params[0] < 0.0);
        assert!(params[1] < 0.0);
        assert!(params[2] < 0.0);
        assert!(params[3] < 0.0);
        assert_abs_diff_eq!(params[4], 0.0);

        // Third step - update the same indices again (accumulates momentum)
        params = optimizer
            .step_sparse(&params, &sparse_grad2)
            .expect("unwrap failed");

        // Parameters at indices 0 and 2 should have larger updates now
        let prev_param0 = params[0];
        let prev_param2 = params[2];

        params = optimizer
            .step_sparse(&params, &sparse_grad2)
            .expect("unwrap failed");

        // Magnitude keeps growing because the gradient direction is constant.
        assert!(params[0].abs() > prev_param0.abs());
        assert!(params[2].abs() > prev_param2.abs());
    }

    #[test]
    fn test_sparse_adam_with_weight_decay() {
        let mut optimizer = SparseAdam::<f64>::new(0.1).with_weight_decay(0.01);

        // Initialize parameters with non-zero values
        let params = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4, 0.5]);

        // Create sparse gradient
        let sparse_grad = SparseGradient::new(
            vec![1, 3],     // Indices
            vec![0.2, 0.5], // Values
            5,              // Total dimension
        );

        // Create a version without weight decay for comparison
        let mut optimizer_no_decay = SparseAdam::<f64>::new(0.1);

        let with_decay = optimizer
            .step_sparse(&params, &sparse_grad)
            .expect("unwrap failed");
        let without_decay = optimizer_no_decay
            .step_sparse(&params, &sparse_grad)
            .expect("unwrap failed");

        // Parameters with non-zero gradients should be different when weight decay is applied
        assert!(with_decay[1] != without_decay[1]);
        assert!(with_decay[3] != without_decay[3]);

        // Parameters without gradients should remain the same
        // (decay is only applied at indices present in the sparse gradient)
        assert_abs_diff_eq!(with_decay[0], params[0]);
        assert_abs_diff_eq!(with_decay[2], params[2]);
        assert_abs_diff_eq!(with_decay[4], params[4]);
    }

    // An empty gradient is a no-op: parameters are returned unchanged.
    #[test]
    fn test_sparse_adam_empty_gradient() {
        let mut optimizer = SparseAdam::<f64>::new(0.1);

        // Initialize parameters
        let params = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4, 0.5]);

        // Create an empty sparse gradient
        let sparse_grad = SparseGradient::new(
            vec![], // Empty indices
            vec![], // Empty values
            5,      // Total dimension
        );

        // No parameters should change
        let result = optimizer
            .step_sparse(&params, &sparse_grad)
            .expect("unwrap failed");
        assert_eq!(result, params);
    }

    // After reset(), the optimizer must behave like a freshly constructed one.
    #[test]
    fn test_sparse_adam_reset() {
        let mut optimizer = SparseAdam::<f64>::new(0.1);

        // Initialize parameters
        let params = Array1::from_vec(vec![0.0; 5]);

        // Create sparse gradient
        let sparse_grad = SparseGradient::new(
            vec![1, 3],     // Indices
            vec![0.2, 0.5], // Values
            5,              // Total dimension
        );

        // Do several steps to build up momentum
        for _ in 0..10 {
            optimizer
                .step_sparse(&params, &sparse_grad)
                .expect("unwrap failed");
        }

        // Reset optimizer
        optimizer.reset();

        // The next step should be the same as the first step with a new optimizer
        let mut new_optimizer = SparseAdam::<f64>::new(0.1);
        let reset_result = optimizer
            .step_sparse(&params, &sparse_grad)
            .expect("unwrap failed");
        let new_result = new_optimizer
            .step_sparse(&params, &sparse_grad)
            .expect("unwrap failed");

        assert_abs_diff_eq!(reset_result[1], new_result[1], epsilon = 1e-10);
        assert_abs_diff_eq!(reset_result[3], new_result[3], epsilon = 1e-10);
    }
}