// optirs_core/optimizers/sgd_simd.rs

//! SIMD-accelerated SGD optimizer
//!
//! This module provides a SIMD-optimized implementation of Stochastic Gradient Descent
//! for 1D parameter arrays using scirs2_core's SimdUnifiedOps.

use scirs2_core::ndarray::Array1;
use scirs2_core::numeric::Float;

use crate::error::Result;
use crate::optimizers::Optimizer;
use crate::simd_optimizer::SimdOptimizer;

/// SIMD-accelerated Stochastic Gradient Descent optimizer
///
/// This is a specialized version of SGD optimized for 1D arrays using SIMD operations.
/// For maximum performance, use this when working with flattened parameter vectors.
///
/// Formula:
/// v_t = momentum * v_{t-1} + learning_rate * (gradient + weight_decay * param)
/// param_t = param_{t-1} - v_t
///
/// # Performance
///
/// This implementation uses SIMD instructions (AVX2/SSE/NEON) for:
/// - Parameter updates
/// - Momentum computation
/// - Weight decay application
///
/// Expected speedup: 2-4x over the scalar implementation for large parameter arrays
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::optimizers::{SimdSGD, Optimizer};
///
/// // Initialize parameters and gradients (f32 selects the SIMD-specialized impl)
/// let params = Array1::<f32>::zeros(1000);
/// let gradients = Array1::from_elem(1000, 0.1f32);
///
/// // Create SIMD-accelerated SGD optimizer
/// let mut optimizer = SimdSGD::new(0.01f32);
/// optimizer.set_momentum(0.9);
///
/// // Update parameters with SIMD acceleration
/// let new_params = optimizer.step(&params, &gradients).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct SimdSGD<A: Float> {
    /// Learning rate
    learning_rate: A,
    /// Momentum factor (0.0 means no momentum)
    momentum: A,
    /// Weight decay factor (L2 regularization)
    weight_decay: A,
    /// Velocity (momentum state)
    velocity: Option<Array1<A>>,
}

impl<A: Float> SimdSGD<A> {
    /// Creates a new SIMD-accelerated SGD optimizer
    ///
    /// # Arguments
    ///
    /// * `learning_rate` - The learning rate for parameter updates
    pub fn new(learning_rate: A) -> Self {
        Self {
            learning_rate,
            momentum: A::zero(),
            weight_decay: A::zero(),
            velocity: None,
        }
    }

    /// Creates a new SIMD SGD optimizer with full configuration
    ///
    /// # Arguments
    ///
    /// * `learning_rate` - The learning rate for parameter updates
    /// * `momentum` - The momentum factor (0.0 means no momentum)
    /// * `weight_decay` - The weight decay factor (L2 regularization)
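    ///
    /// # Examples
    ///
    /// An illustrative sketch using only this type's constructor and getters:
    ///
    /// ```
    /// use optirs_core::optimizers::SimdSGD;
    ///
    /// // SGD with learning rate 0.01, momentum 0.9, and weight decay 1e-4
    /// let opt = SimdSGD::<f32>::new_with_config(0.01, 0.9, 1e-4);
    /// assert_eq!(opt.get_momentum(), 0.9);
    /// assert_eq!(opt.get_weight_decay(), 1e-4);
    /// ```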
    pub fn new_with_config(learning_rate: A, momentum: A, weight_decay: A) -> Self {
        Self {
            learning_rate,
            momentum,
            weight_decay,
            velocity: None,
        }
    }

    /// Sets the momentum factor
    pub fn set_momentum(&mut self, momentum: A) -> &mut Self {
        self.momentum = momentum;
        self
    }

    /// Builder method to set momentum and return self
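    ///
    /// # Examples
    ///
    /// An illustrative sketch chaining the builder methods defined on this type:
    ///
    /// ```
    /// use optirs_core::optimizers::SimdSGD;
    ///
    /// // Configure momentum and weight decay via the builder methods
    /// let opt = SimdSGD::<f32>::new(0.01)
    ///     .with_momentum(0.9)
    ///     .with_weight_decay(1e-4);
    /// assert_eq!(opt.get_momentum(), 0.9);
    /// ```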
    pub fn with_momentum(mut self, momentum: A) -> Self {
        self.momentum = momentum;
        self
    }

    /// Gets the current momentum factor
    pub fn get_momentum(&self) -> A {
        self.momentum
    }

    /// Gets the current learning rate
    pub fn learning_rate(&self) -> A {
        self.learning_rate
    }

    /// Sets the weight decay factor
    pub fn set_weight_decay(&mut self, weight_decay: A) -> &mut Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Builder method to set weight decay and return self
    pub fn with_weight_decay(mut self, weight_decay: A) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Gets the current weight decay factor
    pub fn get_weight_decay(&self) -> A {
        self.weight_decay
    }

    /// Resets the optimizer state
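    ///
    /// # Examples
    ///
    /// An illustrative sketch: clearing accumulated momentum before reusing the optimizer.
    ///
    /// ```
    /// use scirs2_core::ndarray::Array1;
    /// use optirs_core::optimizers::{SimdSGD, Optimizer};
    ///
    /// let params = Array1::from_vec(vec![1.0f32, 2.0]);
    /// let gradients = Array1::from_vec(vec![0.1f32, 0.2]);
    ///
    /// let mut optimizer = SimdSGD::new(0.1f32).with_momentum(0.9);
    /// let _ = optimizer.step(&params, &gradients).unwrap();
    ///
    /// // Discard the accumulated velocity
    /// optimizer.reset();
    /// ```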
    pub fn reset(&mut self) {
        self.velocity = None;
    }
}

// Specialized SIMD implementation for f32
impl Optimizer<f32, scirs2_core::ndarray::Ix1> for SimdSGD<f32> {
    fn step(&mut self, params: &Array1<f32>, gradients: &Array1<f32>) -> Result<Array1<f32>> {
        // Validate shapes
        if params.shape() != gradients.shape() {
            return Err(crate::error::OptimError::DimensionMismatch(format!(
                "Incompatible shapes: parameters have shape {:?}, gradients have shape {:?}",
                params.shape(),
                gradients.shape()
            )));
        }

        let params_view = params.view();
        let gradients_view = gradients.view();

        // Apply weight decay if needed
        let adjusted_gradients = if self.weight_decay > 0.0 {
            f32::simd_weight_decay(&gradients_view, &params_view, self.weight_decay)
        } else {
            gradients.to_owned()
        };

        // Initialize velocity if this is the first step
        if self.velocity.is_none() {
            self.velocity = Some(Array1::zeros(params.len()));
        }

        let velocity = self.velocity.as_mut().unwrap();

        // Ensure velocity has correct dimensions
        if velocity.len() != params.len() {
            *velocity = Array1::zeros(params.len());
        }

        // Compute update using SIMD operations
        let new_params = if self.momentum > 0.0 {
            // SIMD-accelerated momentum update
            let (updated_params, updated_velocity) = f32::simd_momentum_update(
                &params_view,
                &adjusted_gradients.view(),
                &velocity.view(),
                self.learning_rate,
                self.momentum,
            );
            *velocity = updated_velocity;
            updated_params
        } else {
            // SIMD-accelerated vanilla SGD
            f32::simd_sgd_update(&params_view, &adjusted_gradients.view(), self.learning_rate)
        };

        Ok(new_params)
    }

    fn get_learning_rate(&self) -> f32 {
        self.learning_rate
    }

    fn set_learning_rate(&mut self, learning_rate: f32) {
        self.learning_rate = learning_rate;
    }
}

// Specialized SIMD implementation for f64
impl Optimizer<f64, scirs2_core::ndarray::Ix1> for SimdSGD<f64> {
    fn step(&mut self, params: &Array1<f64>, gradients: &Array1<f64>) -> Result<Array1<f64>> {
        // Validate shapes
        if params.shape() != gradients.shape() {
            return Err(crate::error::OptimError::DimensionMismatch(format!(
                "Incompatible shapes: parameters have shape {:?}, gradients have shape {:?}",
                params.shape(),
                gradients.shape()
            )));
        }

        let params_view = params.view();
        let gradients_view = gradients.view();

        // Apply weight decay if needed
        let adjusted_gradients = if self.weight_decay > 0.0 {
            f64::simd_weight_decay(&gradients_view, &params_view, self.weight_decay)
        } else {
            gradients.to_owned()
        };

        // Initialize velocity if this is the first step
        if self.velocity.is_none() {
            self.velocity = Some(Array1::zeros(params.len()));
        }

        let velocity = self.velocity.as_mut().unwrap();

        // Ensure velocity has correct dimensions
        if velocity.len() != params.len() {
            *velocity = Array1::zeros(params.len());
        }

        // Compute update using SIMD operations
        let new_params = if self.momentum > 0.0 {
            // SIMD-accelerated momentum update
            let (updated_params, updated_velocity) = f64::simd_momentum_update(
                &params_view,
                &adjusted_gradients.view(),
                &velocity.view(),
                self.learning_rate,
                self.momentum,
            );
            *velocity = updated_velocity;
            updated_params
        } else {
            // SIMD-accelerated vanilla SGD
            f64::simd_sgd_update(&params_view, &adjusted_gradients.view(), self.learning_rate)
        };

        Ok(new_params)
    }

    fn get_learning_rate(&self) -> f64 {
        self.learning_rate
    }

    fn set_learning_rate(&mut self, learning_rate: f64) {
        self.learning_rate = learning_rate;
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_simd_sgd_basic() {
        let params = Array1::from_vec(vec![1.0f32, 2.0, 3.0, 4.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);

        let mut optimizer = SimdSGD::new(0.1f32);
        let result = optimizer.step(&params, &gradients).unwrap();

        assert_relative_eq!(result[0], 0.99, epsilon = 1e-6);
        assert_relative_eq!(result[1], 1.98, epsilon = 1e-6);
        assert_relative_eq!(result[2], 2.97, epsilon = 1e-6);
        assert_relative_eq!(result[3], 3.96, epsilon = 1e-6);
    }

    #[test]
    fn test_simd_sgd_momentum() {
        let params = Array1::from_vec(vec![1.0f32, 2.0, 3.0, 4.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);

        let mut optimizer = SimdSGD::new_with_config(0.1f32, 0.9, 0.0);

        // First step
        let result1 = optimizer.step(&params, &gradients).unwrap();

        // Second step - should show momentum effect
        let result2 = optimizer.step(&result1, &gradients).unwrap();

        // With momentum, the second step should move further than the first
        let first_step = params[0] - result1[0];
        let second_step = result1[0] - result2[0];
        assert!(result2[0] < result1[0]);
        assert!(second_step > first_step);
    }

    #[test]
    fn test_simd_sgd_weight_decay() {
        let params = Array1::from_vec(vec![1.0f32, 2.0, 3.0, 4.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);

        let mut optimizer = SimdSGD::new_with_config(0.1f32, 0.0, 0.01);
        let result = optimizer.step(&params, &gradients).unwrap();

        // Weight decay should reduce parameters more than vanilla SGD:
        // the effective gradient is gradient + weight_decay * param
        let expected_grad = 0.1 + 0.01 * 1.0;
        assert_relative_eq!(result[0], 1.0 - 0.1 * expected_grad, epsilon = 1e-6);
    }

    #[test]
    fn test_simd_sgd_large_array() {
        // Test with a large array to ensure the SIMD path is taken
        let size = 1000;
        let params: Array1<f32> = Array1::from_vec((0..size).map(|i| i as f32).collect());
        let gradients: Array1<f32> = Array1::from_elem(size, 0.1);

        let mut optimizer = SimdSGD::new(0.01f32);
        let result = optimizer.step(&params, &gradients).unwrap();

        for i in 0..size {
            assert_relative_eq!(result[i], (i as f32) - 0.01 * 0.1, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_simd_sgd_f64() {
        let params = Array1::from_vec(vec![1.0f64, 2.0, 3.0, 4.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);

        let mut optimizer = SimdSGD::new(0.1f64);
        let result = optimizer.step(&params, &gradients).unwrap();

        assert_relative_eq!(result[0], 0.99, epsilon = 1e-10);
        assert_relative_eq!(result[1], 1.98, epsilon = 1e-10);
        assert_relative_eq!(result[2], 2.97, epsilon = 1e-10);
        assert_relative_eq!(result[3], 3.96, epsilon = 1e-10);
    }

    #[test]
    fn test_simd_sgd_reset() {
        let params = Array1::from_vec(vec![1.0f32, 2.0, 3.0, 4.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);

        let mut optimizer = SimdSGD::new_with_config(0.1f32, 0.9, 0.0);

        // Take a step to initialize velocity
        let _ = optimizer.step(&params, &gradients).unwrap();
        assert!(optimizer.velocity.is_some());

        // Reset should clear velocity
        optimizer.reset();
        assert!(optimizer.velocity.is_none());
    }
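
    // The following tests are illustrative sketches added alongside the originals.
    // They rely only on behavior visible in this file: shape validation in `step`,
    // the configuration getters/builders, and the documented vanilla-SGD update rule.

    #[test]
    fn test_simd_sgd_shape_mismatch() {
        // Mismatched parameter/gradient lengths should be rejected with an error
        let params = Array1::from_vec(vec![1.0f32, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.1f32, 0.2]);

        let mut optimizer = SimdSGD::new(0.1f32);
        assert!(optimizer.step(&params, &gradients).is_err());
    }

    #[test]
    fn test_simd_sgd_builder_config() {
        // Builder methods should be equivalent to new_with_config
        let built = SimdSGD::<f32>::new(0.01)
            .with_momentum(0.9)
            .with_weight_decay(1e-4);
        let configured = SimdSGD::<f32>::new_with_config(0.01, 0.9, 1e-4);

        assert_eq!(built.get_momentum(), configured.get_momentum());
        assert_eq!(built.get_weight_decay(), configured.get_weight_decay());
        assert_eq!(built.learning_rate(), configured.learning_rate());
    }

    #[test]
    fn test_simd_sgd_matches_scalar_reference() {
        // Vanilla SGD (no momentum, no weight decay) should match the scalar rule:
        // param_t = param_{t-1} - learning_rate * gradient
        let params = Array1::from_vec(vec![0.5f32, -1.5, 2.25, -3.0]);
        let gradients = Array1::from_vec(vec![0.05f32, -0.1, 0.2, -0.4]);
        let lr = 0.1f32;

        let mut optimizer = SimdSGD::new(lr);
        let result = optimizer.step(&params, &gradients).unwrap();

        for i in 0..params.len() {
            assert_relative_eq!(result[i], params[i] - lr * gradients[i], epsilon = 1e-6);
        }
    }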
}