optirs_core/optimizers/sparse_adam.rs

// SparseAdam optimizer implementation for sparse gradients
//
// This module implements a variant of the Adam optimizer that efficiently
// handles sparse gradients by only updating the parameters and moments
// for indices that have non-zero gradients.

use scirs2_core::ndarray::{Array, Ix1, ScalarOperand};
use scirs2_core::numeric::Float;
use std::collections::HashMap;
use std::fmt::Debug;

use crate::error::{OptimError, Result};
use crate::optimizers::Optimizer;

/// A struct representing a sparse gradient with indices and values
///
/// This provides a convenient interface for working with sparse gradients
/// where most elements are zero.
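///
/// # Examples
///
/// A minimal round-trip sketch (crate paths follow the `SparseAdam` example
/// further below):
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::optimizers::SparseGradient;
///
/// // Keep only the non-zero entries of a dense gradient
/// let dense = Array1::from_vec(vec![0.0, 1.5, 0.0, -2.0]);
/// let sparse = SparseGradient::from_array(&dense);
/// assert_eq!(sparse.indices, vec![1, 3]);
/// assert_eq!(sparse.values, vec![1.5, -2.0]);
///
/// // Convert back to a dense array
/// assert_eq!(sparse.to_array(), dense);
/// ```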
pub struct SparseGradient<A: Float + ScalarOperand + Debug> {
    /// The indices of non-zero elements
    pub indices: Vec<usize>,
    /// The values at the non-zero elements
    pub values: Vec<A>,
    /// The total dimension of the gradient (including zero elements)
    pub dim: usize,
}

impl<A: Float + ScalarOperand + Debug + Send + Sync> SparseGradient<A> {
    /// Create a new sparse gradient from indices and values
    pub fn new(indices: Vec<usize>, values: Vec<A>, dim: usize) -> Self {
        assert_eq!(
            indices.len(),
            values.len(),
            "Indices and values must have the same length"
        );
        // Ensure no index is out of bounds
        if let Some(&max_idx) = indices.iter().max() {
            assert!(
                max_idx < dim,
                "Index {} is out of bounds for dimension {}",
                max_idx,
                dim
            );
        }
        Self {
            indices,
            values,
            dim,
        }
    }

    /// Create a sparse gradient from a dense array, keeping only non-zero entries
    pub fn from_array(array: &Array<A, Ix1>) -> Self {
        let mut indices = Vec::new();
        let mut values = Vec::new();

        for (idx, &val) in array.iter().enumerate() {
            if !val.is_zero() {
                indices.push(idx);
                values.push(val);
            }
        }

        Self {
            indices,
            values,
            dim: array.len(),
        }
    }

    /// Convert the sparse gradient to a dense array
    pub fn to_array(&self) -> Array<A, Ix1> {
        let mut array = Array::zeros(self.dim);
        for (&idx, &val) in self.indices.iter().zip(&self.values) {
            array[idx] = val;
        }
        array
    }

    /// Check if this sparse gradient is empty (all zeros)
    pub fn is_empty(&self) -> bool {
        self.indices.is_empty()
    }
}

/// SparseAdam optimizer for sparse gradients
///
/// Implements a variant of the Adam optimization algorithm tailored to sparse
/// gradients. It updates the parameters and moment estimates only for indices
/// with non-zero gradients, saving computation and memory.
///
/// This optimizer is particularly useful for large embedding layers or
/// models with sparse input features.
///
/// Formula (for non-zero gradient indices):
/// m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
/// v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
/// m_hat_t = m_t / (1 - beta1^t)
/// v_hat_t = v_t / (1 - beta2^t)
/// theta_t = theta_{t-1} - alpha * m_hat_t / (sqrt(v_hat_t) + epsilon)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::optimizers::{SparseAdam, SparseGradient, Optimizer};
///
/// // Initialize parameters
/// let params = Array1::zeros(5);
///
/// // Create sparse gradient with non-zero values at indices 1 and 3
/// let sparse_grad = SparseGradient::new(
///     vec![1, 3],             // Indices
///     vec![0.2, 0.5],         // Values
///     5                       // Total dimension
/// );
///
/// // Create a SparseAdam optimizer
/// let mut optimizer = SparseAdam::new(0.001);
///
/// // Update parameters with sparse gradient
/// let new_params = optimizer.step_sparse(&params, &sparse_grad).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct SparseAdam<A: Float + ScalarOperand + Debug> {
    /// Learning rate
    learning_rate: A,
    /// Exponential decay rate for the first moment estimates
    beta1: A,
    /// Exponential decay rate for the second moment estimates
    beta2: A,
    /// Small constant for numerical stability
    epsilon: A,
    /// Weight decay factor (L2 regularization)
    weight_decay: A,
    /// First moment vector stored as a hash map for sparse updates
    m: HashMap<usize, A>,
    /// Second moment vector stored as a hash map for sparse updates
    v: HashMap<usize, A>,
    /// Current timestep
    t: usize,
}

impl<A: Float + ScalarOperand + Debug + Send + Sync> SparseAdam<A> {
    /// Creates a new SparseAdam optimizer with the given learning rate and default settings
    ///
    /// # Arguments
    ///
    /// * `learning_rate` - The learning rate for parameter updates
    pub fn new(learning_rate: A) -> Self {
        Self {
            learning_rate,
            beta1: A::from(0.9).unwrap(),
            beta2: A::from(0.999).unwrap(),
            epsilon: A::from(1e-8).unwrap(),
            weight_decay: A::zero(),
            m: HashMap::new(),
            v: HashMap::new(),
            t: 0,
        }
    }

    /// Creates a new SparseAdam optimizer with the full configuration
    ///
    /// # Arguments
    ///
    /// * `learning_rate` - The learning rate for parameter updates
    /// * `beta1` - Exponential decay rate for the first moment estimates (default: 0.9)
    /// * `beta2` - Exponential decay rate for the second moment estimates (default: 0.999)
    /// * `epsilon` - Small constant for numerical stability (default: 1e-8)
    /// * `weight_decay` - Weight decay factor for L2 regularization (default: 0.0)
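    ///
    /// # Examples
    ///
    /// A minimal construction sketch with every setting given explicitly:
    ///
    /// ```
    /// use optirs_core::optimizers::SparseAdam;
    ///
    /// // learning rate, beta1, beta2, epsilon, weight decay
    /// let optimizer = SparseAdam::new_with_config(0.001f64, 0.9, 0.999, 1e-8, 0.01);
    /// assert_eq!(optimizer.get_weight_decay(), 0.01);
    /// ```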
    pub fn new_with_config(
        learning_rate: A,
        beta1: A,
        beta2: A,
        epsilon: A,
        weight_decay: A,
    ) -> Self {
        Self {
            learning_rate,
            beta1,
            beta2,
            epsilon,
            weight_decay,
            m: HashMap::new(),
            v: HashMap::new(),
            t: 0,
        }
    }

    /// Sets the beta1 parameter
    pub fn set_beta1(&mut self, beta1: A) -> &mut Self {
        self.beta1 = beta1;
        self
    }

    /// Builder method to set beta1 and return self
    pub fn with_beta1(mut self, beta1: A) -> Self {
        self.beta1 = beta1;
        self
    }

    /// Gets the beta1 parameter
    pub fn get_beta1(&self) -> A {
        self.beta1
    }

    /// Sets the beta2 parameter
    pub fn set_beta2(&mut self, beta2: A) -> &mut Self {
        self.beta2 = beta2;
        self
    }

    /// Builder method to set beta2 and return self
    pub fn with_beta2(mut self, beta2: A) -> Self {
        self.beta2 = beta2;
        self
    }

    /// Gets the beta2 parameter
    pub fn get_beta2(&self) -> A {
        self.beta2
    }

    /// Sets the epsilon parameter
    pub fn set_epsilon(&mut self, epsilon: A) -> &mut Self {
        self.epsilon = epsilon;
        self
    }

    /// Builder method to set epsilon and return self
    pub fn with_epsilon(mut self, epsilon: A) -> Self {
        self.epsilon = epsilon;
        self
    }

    /// Gets the epsilon parameter
    pub fn get_epsilon(&self) -> A {
        self.epsilon
    }

    /// Sets the weight decay parameter
    pub fn set_weight_decay(&mut self, weight_decay: A) -> &mut Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Builder method to set weight decay and return self
    pub fn with_weight_decay(mut self, weight_decay: A) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Gets the weight decay parameter
    pub fn get_weight_decay(&self) -> A {
        self.weight_decay
    }

    /// Updates parameters using sparse gradients
    ///
    /// This method efficiently updates only the parameters corresponding to
    /// non-zero gradient entries, saving computation and memory.
    ///
    /// # Arguments
    ///
    /// * `params` - The parameters to update
    /// * `gradient` - The sparse gradient
    ///
    /// # Returns
    ///
    /// The updated parameters
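    ///
    /// # Examples
    ///
    /// A minimal sketch: only index 2 receives a gradient, so only that entry
    /// moves and the remaining entries are returned unchanged.
    ///
    /// ```
    /// use scirs2_core::ndarray::Array1;
    /// use optirs_core::optimizers::{SparseAdam, SparseGradient};
    ///
    /// let mut optimizer = SparseAdam::new(0.01);
    /// let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
    /// let grad = SparseGradient::new(vec![2], vec![0.5], 3);
    ///
    /// let updated = optimizer.step_sparse(&params, &grad).unwrap();
    /// assert_eq!(updated[0], 1.0);
    /// assert_eq!(updated[1], 2.0);
    /// assert!(updated[2] < 3.0); // moved against the positive gradient
    /// ```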
    pub fn step_sparse(
        &mut self,
        params: &Array<A, Ix1>,
        gradient: &SparseGradient<A>,
    ) -> Result<Array<A, Ix1>> {
        // Verify dimensions match
        if params.len() != gradient.dim {
            return Err(OptimError::InvalidConfig(format!(
                "Parameter dimension ({}) doesn't match gradient dimension ({})",
                params.len(),
                gradient.dim
            )));
        }

        // If gradient is empty, just return the parameters unchanged
        if gradient.is_empty() {
            return Ok(params.clone());
        }

        // Increment timestep
        self.t += 1;

        // Compute the bias correction terms
        let bias_correction1 = A::one() - self.beta1.powi(self.t as i32);
        let bias_correction2 = A::one() - self.beta2.powi(self.t as i32);

        // Create a copy of the parameters that we'll update
        let mut updated_params = params.clone();

        // Update only the parameters for which we have non-zero gradients
        for (&idx, &grad_val) in gradient.indices.iter().zip(&gradient.values) {
            // Apply weight decay if needed
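            // (Note: this folds the L2 penalty `weight_decay * param` directly
            // into the gradient, i.e. coupled weight decay rather than the
            // decoupled AdamW variant.)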
            let adjusted_grad = if self.weight_decay > A::zero() {
                grad_val + params[idx] * self.weight_decay
            } else {
                grad_val
            };

            // Update first moment (m)
            let m_prev = *self.m.get(&idx).unwrap_or(&A::zero());
            let m_t = self.beta1 * m_prev + (A::one() - self.beta1) * adjusted_grad;
            self.m.insert(idx, m_t);

            // Update second moment (v)
            let v_prev = *self.v.get(&idx).unwrap_or(&A::zero());
            let v_t = self.beta2 * v_prev + (A::one() - self.beta2) * adjusted_grad * adjusted_grad;
            self.v.insert(idx, v_t);

            // Bias-corrected first and second moment estimates
            let m_hat = m_t / bias_correction1;
            let v_hat = v_t / bias_correction2;

            // Update parameter
            let step = self.learning_rate * m_hat / (v_hat.sqrt() + self.epsilon);
            updated_params[idx] = params[idx] - step;
        }

        Ok(updated_params)
    }

    /// Resets the internal state of the optimizer
    pub fn reset(&mut self) {
        self.m.clear();
        self.v.clear();
        self.t = 0;
    }
}

impl<A> Optimizer<A, Ix1> for SparseAdam<A>
where
    A: Float + ScalarOperand + Debug + Send + Sync,
{
    fn step(&mut self, params: &Array<A, Ix1>, gradients: &Array<A, Ix1>) -> Result<Array<A, Ix1>> {
        // Convert dense gradient to sparse
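        // Note: exact zeros are dropped, so indices with a zero gradient are
        // skipped entirely and their stored moment estimates are not decayed
        // on this step.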
        let sparse_gradient = SparseGradient::from_array(gradients);

        // Call step_sparse with the converted gradient
        self.step_sparse(params, &sparse_gradient)
    }

    fn get_learning_rate(&self) -> A {
        self.learning_rate
    }

    fn set_learning_rate(&mut self, learning_rate: A) {
        self.learning_rate = learning_rate;
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_sparse_gradient_creation() {
        let indices = vec![0, 2, 4];
        let values = vec![1.0, 2.0, 3.0];
        let dim = 5;

        let sparse_grad = SparseGradient::new(indices, values, dim);

        assert_eq!(sparse_grad.indices, vec![0, 2, 4]);
        assert_eq!(sparse_grad.values, vec![1.0, 2.0, 3.0]);
        assert_eq!(sparse_grad.dim, 5);
    }

    #[test]
    fn test_sparse_gradient_from_array() {
        let dense = Array1::from_vec(vec![1.0, 0.0, 2.0, 0.0, 3.0]);
        let sparse_grad = SparseGradient::from_array(&dense);

        assert_eq!(sparse_grad.indices, vec![0, 2, 4]);
        assert_eq!(sparse_grad.values, vec![1.0, 2.0, 3.0]);
        assert_eq!(sparse_grad.dim, 5);
    }

    #[test]
    fn test_sparse_gradient_to_array() {
        let indices = vec![0, 2, 4];
        let values = vec![1.0, 2.0, 3.0];
        let dim = 5;

        let sparse_grad = SparseGradient::new(indices, values, dim);
        let dense = sparse_grad.to_array();

        let expected = Array1::from_vec(vec![1.0, 0.0, 2.0, 0.0, 3.0]);
        assert_eq!(dense, expected);
    }

    #[test]
    fn test_sparse_adam_creation() {
        let optimizer = SparseAdam::<f64>::new(0.001);

        assert_eq!(optimizer.get_learning_rate(), 0.001);
        assert_eq!(optimizer.get_beta1(), 0.9);
        assert_eq!(optimizer.get_beta2(), 0.999);
        assert_eq!(optimizer.get_epsilon(), 1e-8);
        assert_eq!(optimizer.get_weight_decay(), 0.0);
    }

    #[test]
    fn test_sparse_adam_step() {
        let mut optimizer = SparseAdam::<f64>::new(0.1);

        // Initialize parameters
        let params = Array1::from_vec(vec![0.0, 0.0, 0.0, 0.0, 0.0]);

        // Create sparse gradient with non-zero values at indices 1 and 3
        let sparse_grad = SparseGradient::new(
            vec![1, 3],     // Indices
            vec![0.2, 0.5], // Values
            5,              // Total dimension
        );

        // First update
        let updated_params = optimizer.step_sparse(&params, &sparse_grad).unwrap();

        // Only the parameters at indices 1 and 3 should be updated
        assert_abs_diff_eq!(updated_params[0], 0.0);
        assert!(updated_params[1] < 0.0); // Should be negative due to gradient descent
        assert_abs_diff_eq!(updated_params[2], 0.0);
        assert!(updated_params[3] < 0.0); // Should be negative due to gradient descent
        assert_abs_diff_eq!(updated_params[4], 0.0);

        // The parameter at index 3 should have a larger update due to larger gradient
        assert!(updated_params[3].abs() > updated_params[1].abs());
    }

    #[test]
    fn test_sparse_adam_vs_dense_adam() {
        let mut sparse_optimizer = SparseAdam::<f64>::new(0.1);
        let mut dense_optimizer = crate::optimizers::adam::Adam::<f64>::new(0.1);

        // Initialize parameters
        let params = Array1::from_vec(vec![0.0, 0.0, 0.0, 0.0, 0.0]);

        // Create a dense gradient with only some non-zero values
        let dense_grad = Array1::from_vec(vec![0.0, 0.2, 0.0, 0.5, 0.0]);

        // Create equivalent sparse gradient
        let sparse_grad = SparseGradient::from_array(&dense_grad);

        // Update with both optimizers
        let sparse_result = sparse_optimizer.step_sparse(&params, &sparse_grad).unwrap();
        let dense_result = dense_optimizer.step(&params, &dense_grad).unwrap();

        // Results should be nearly identical
        assert_abs_diff_eq!(sparse_result[0], dense_result[0]);
        assert_abs_diff_eq!(sparse_result[1], dense_result[1], epsilon = 1e-10);
        assert_abs_diff_eq!(sparse_result[2], dense_result[2]);
        assert_abs_diff_eq!(sparse_result[3], dense_result[3], epsilon = 1e-10);
        assert_abs_diff_eq!(sparse_result[4], dense_result[4]);
    }

    #[test]
    fn test_sparse_adam_multiple_steps() {
        let mut optimizer = SparseAdam::<f64>::new(0.1);
        let mut params = Array1::from_vec(vec![0.0, 0.0, 0.0, 0.0, 0.0]);

        // First step - update indices 1 and 3
        let sparse_grad1 = SparseGradient::new(
            vec![1, 3],     // Indices
            vec![0.2, 0.5], // Values
            5,              // Total dimension
        );

        params = optimizer.step_sparse(&params, &sparse_grad1).unwrap();

        // Second step - update indices 0 and 2
        let sparse_grad2 = SparseGradient::new(
            vec![0, 2],     // Indices
            vec![0.3, 0.4], // Values
            5,              // Total dimension
        );

        params = optimizer.step_sparse(&params, &sparse_grad2).unwrap();

        // All parameters except index 4 should now be updated
        assert!(params[0] < 0.0);
        assert!(params[1] < 0.0);
        assert!(params[2] < 0.0);
        assert!(params[3] < 0.0);
        assert_abs_diff_eq!(params[4], 0.0);

        // Third step - update the same indices again (accumulates momentum)
        params = optimizer.step_sparse(&params, &sparse_grad2).unwrap();

        // A fourth step on the same indices should move parameters 0 and 2 further
        let prev_param0 = params[0];
        let prev_param2 = params[2];

        params = optimizer.step_sparse(&params, &sparse_grad2).unwrap();

        assert!(params[0].abs() > prev_param0.abs());
        assert!(params[2].abs() > prev_param2.abs());
    }

    #[test]
    fn test_sparse_adam_with_weight_decay() {
        let mut optimizer = SparseAdam::<f64>::new(0.1).with_weight_decay(0.01);

        // Initialize parameters with non-zero values
        let params = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4, 0.5]);

        // Create sparse gradient
        let sparse_grad = SparseGradient::new(
            vec![1, 3],     // Indices
            vec![0.2, 0.5], // Values
            5,              // Total dimension
        );

        // Create a version without weight decay for comparison
        let mut optimizer_no_decay = SparseAdam::<f64>::new(0.1);

        let with_decay = optimizer.step_sparse(&params, &sparse_grad).unwrap();
        let without_decay = optimizer_no_decay
            .step_sparse(&params, &sparse_grad)
            .unwrap();

        // Parameters with non-zero gradients should be different when weight decay is applied
        assert!(with_decay[1] != without_decay[1]);
        assert!(with_decay[3] != without_decay[3]);

        // Parameters without gradients should remain the same
        assert_abs_diff_eq!(with_decay[0], params[0]);
        assert_abs_diff_eq!(with_decay[2], params[2]);
        assert_abs_diff_eq!(with_decay[4], params[4]);
    }

    #[test]
    fn test_sparse_adam_empty_gradient() {
        let mut optimizer = SparseAdam::<f64>::new(0.1);

        // Initialize parameters
        let params = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4, 0.5]);

        // Create an empty sparse gradient
        let sparse_grad = SparseGradient::new(
            vec![], // Empty indices
            vec![], // Empty values
            5,      // Total dimension
        );

        // No parameters should change
        let result = optimizer.step_sparse(&params, &sparse_grad).unwrap();
        assert_eq!(result, params);
    }

    #[test]
    fn test_sparse_adam_reset() {
        let mut optimizer = SparseAdam::<f64>::new(0.1);

        // Initialize parameters
        let params = Array1::from_vec(vec![0.0; 5]);

        // Create sparse gradient
        let sparse_grad = SparseGradient::new(
            vec![1, 3],     // Indices
            vec![0.2, 0.5], // Values
            5,              // Total dimension
        );

        // Do several steps to build up momentum
        for _ in 0..10 {
            optimizer.step_sparse(&params, &sparse_grad).unwrap();
        }

        // Reset optimizer
        optimizer.reset();

        // The next step should be the same as the first step with a new optimizer
        let mut new_optimizer = SparseAdam::<f64>::new(0.1);
        let reset_result = optimizer.step_sparse(&params, &sparse_grad).unwrap();
        let new_result = new_optimizer.step_sparse(&params, &sparse_grad).unwrap();

        assert_abs_diff_eq!(reset_result[1], new_result[1], epsilon = 1e-10);
        assert_abs_diff_eq!(reset_result[3], new_result[3], epsilon = 1e-10);
    }
}