optirs_core/simd_optimizer.rs

//! SIMD-accelerated optimizer operations
//!
//! This module provides SIMD-optimized implementations of common optimizer
//! operations using scirs2_core's SimdUnifiedOps infrastructure.
//!
//! The module automatically selects the best SIMD backend available on the
//! target platform (AVX2, SSE, NEON, or scalar fallback).
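//!
//! # Example
//!
//! A minimal usage sketch (not compiled as a doctest; the `optirs_core::simd_optimizer`
//! import path is an assumption and may differ in your crate layout):
//!
//! ```ignore
//! use optirs_core::simd_optimizer::SimdOptimizer;
//! use scirs2_core::ndarray::Array1;
//!
//! let params = Array1::from_vec(vec![1.0f32, 2.0, 3.0, 4.0]);
//! let gradients = Array1::from_vec(vec![0.1f32, 0.2, 0.3, 0.4]);
//!
//! // params - learning_rate * gradients
//! let updated = f32::simd_sgd_update(&params.view(), &gradients.view(), 0.1);
//! assert!((updated[0] - 0.99).abs() < 1e-6);
//! ```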

use scirs2_core::ndarray::{Array1, ArrayView1};
use scirs2_core::numeric::Float;
use scirs2_core::simd_ops::SimdUnifiedOps;

/// Trait for SIMD-accelerated optimizer operations
///
/// This trait provides high-performance implementations of common
/// operations found in optimization algorithms.
pub trait SimdOptimizer<T: Float> {
    /// SIMD-accelerated parameter update: params - learning_rate * gradient
    ///
    /// # Arguments
    ///
    /// * `params` - Parameter array
    /// * `gradients` - Gradient array
    /// * `learning_rate` - Learning rate scalar
    ///
    /// # Returns
    ///
    /// Updated parameters
    fn simd_sgd_update(
        params: &ArrayView1<T>,
        gradients: &ArrayView1<T>,
        learning_rate: T,
    ) -> Array1<T>;

    /// SIMD-accelerated momentum update
    ///
    /// velocity = momentum * velocity + learning_rate * gradient
    /// params = params - velocity
    ///
    /// # Arguments
    ///
    /// * `params` - Parameter array
    /// * `gradients` - Gradient array
    /// * `velocity` - Velocity array (momentum state)
    /// * `learning_rate` - Learning rate scalar
    /// * `momentum` - Momentum coefficient
    ///
    /// # Returns
    ///
    /// Tuple of (updated_params, updated_velocity)
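    ///
    /// # Example
    ///
    /// A sketch of classical (heavy-ball) momentum, assuming the imports shown in the
    /// module-level example (not compiled as a doctest):
    ///
    /// ```ignore
    /// let params = Array1::from_vec(vec![1.0f32, 2.0, 3.0, 4.0]);
    /// let grads = Array1::from_vec(vec![0.1f32, 0.2, 0.3, 0.4]);
    /// let velocity = Array1::<f32>::zeros(4);
    ///
    /// let (new_params, new_velocity) =
    ///     f32::simd_momentum_update(&params.view(), &grads.view(), &velocity.view(), 0.1, 0.9);
    /// ```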
    fn simd_momentum_update(
        params: &ArrayView1<T>,
        gradients: &ArrayView1<T>,
        velocity: &ArrayView1<T>,
        learning_rate: T,
        momentum: T,
    ) -> (Array1<T>, Array1<T>);

    /// SIMD-accelerated Adam first moment update
    ///
    /// m = beta1 * m + (1 - beta1) * gradient
    ///
    /// # Arguments
    ///
    /// * `m` - First moment array
    /// * `gradients` - Gradient array
    /// * `beta1` - Exponential decay rate for first moment
    ///
    /// # Returns
    ///
    /// Updated first moment
    fn simd_adam_first_moment(m: &ArrayView1<T>, gradients: &ArrayView1<T>, beta1: T) -> Array1<T>;

    /// SIMD-accelerated Adam second moment update
    ///
    /// v = beta2 * v + (1 - beta2) * gradient^2
    ///
    /// # Arguments
    ///
    /// * `v` - Second moment array
    /// * `gradients` - Gradient array
    /// * `beta2` - Exponential decay rate for second moment
    ///
    /// # Returns
    ///
    /// Updated second moment
    fn simd_adam_second_moment(v: &ArrayView1<T>, gradients: &ArrayView1<T>, beta2: T)
        -> Array1<T>;

    /// SIMD-accelerated Adam parameter update
    ///
    /// params = params - learning_rate * m_hat / (sqrt(v_hat) + epsilon)
    ///
    /// # Arguments
    ///
    /// * `params` - Parameter array
    /// * `m_hat` - Bias-corrected first moment
    /// * `v_hat` - Bias-corrected second moment
    /// * `learning_rate` - Learning rate scalar
    /// * `epsilon` - Small constant for numerical stability
    ///
    /// # Returns
    ///
    /// Updated parameters
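    ///
    /// # Example
    ///
    /// A sketch of one full Adam step composed from the moment updates plus the standard
    /// bias correction `m_hat = m / (1 - beta1^t)`, `v_hat = v / (1 - beta2^t)`; the bias
    /// correction is left to the caller and shown here with plain element-wise scaling.
    /// Imports are as in the module-level example (not compiled as a doctest):
    ///
    /// ```ignore
    /// let params = Array1::from_vec(vec![1.0f32, 2.0, 3.0, 4.0]);
    /// let grads = Array1::from_vec(vec![0.1f32, 0.2, 0.3, 0.4]);
    /// let (m, v) = (Array1::<f32>::zeros(4), Array1::<f32>::zeros(4));
    /// let (beta1, beta2, t) = (0.9f32, 0.999f32, 1);
    ///
    /// let m = f32::simd_adam_first_moment(&m.view(), &grads.view(), beta1);
    /// let v = f32::simd_adam_second_moment(&v.view(), &grads.view(), beta2);
    /// let m_hat = m.mapv(|x| x / (1.0 - beta1.powi(t)));
    /// let v_hat = v.mapv(|x| x / (1.0 - beta2.powi(t)));
    /// let new_params =
    ///     f32::simd_adam_update(&params.view(), &m_hat.view(), &v_hat.view(), 1e-3, 1e-8);
    /// ```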
    fn simd_adam_update(
        params: &ArrayView1<T>,
        m_hat: &ArrayView1<T>,
        v_hat: &ArrayView1<T>,
        learning_rate: T,
        epsilon: T,
    ) -> Array1<T>;

    /// SIMD-accelerated weight decay application
    ///
    /// gradients = gradients + weight_decay * params
    ///
    /// # Arguments
    ///
    /// * `gradients` - Gradient array
    /// * `params` - Parameter array
    /// * `weight_decay` - Weight decay coefficient
    ///
    /// # Returns
    ///
    /// Gradients with weight decay applied
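    ///
    /// # Example
    ///
    /// A sketch of L2-style regularization: fold the penalty term into the gradients,
    /// then feed the result into a plain SGD step. Imports are as in the module-level
    /// example (not compiled as a doctest):
    ///
    /// ```ignore
    /// let params = Array1::from_vec(vec![1.0f32, 2.0, 3.0, 4.0]);
    /// let grads = Array1::from_vec(vec![0.1f32, 0.2, 0.3, 0.4]);
    ///
    /// let grads_wd = f32::simd_weight_decay(&grads.view(), &params.view(), 0.01);
    /// let new_params = f32::simd_sgd_update(&params.view(), &grads_wd.view(), 0.1);
    /// ```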
    fn simd_weight_decay(
        gradients: &ArrayView1<T>,
        params: &ArrayView1<T>,
        weight_decay: T,
    ) -> Array1<T>;

    /// SIMD-accelerated gradient norm computation
    ///
    /// # Arguments
    ///
    /// * `gradients` - Gradient array
    ///
    /// # Returns
    ///
    /// L2 norm of gradients
    fn simd_gradient_norm(gradients: &ArrayView1<T>) -> T;
}

/// Implementation of SIMD optimizer operations for f32
impl SimdOptimizer<f32> for f32 {
    fn simd_sgd_update(
        params: &ArrayView1<f32>,
        gradients: &ArrayView1<f32>,
        learning_rate: f32,
    ) -> Array1<f32> {
        // Use SIMD for large arrays, scalar for small ones
        if params.len() >= 16 {
            // SIMD path: params - learning_rate * gradients
            let scaled_grads = f32::simd_scalar_mul(gradients, learning_rate);
            f32::simd_sub(params, &scaled_grads.view())
        } else {
            // Scalar path for small arrays
            params
                .iter()
                .zip(gradients.iter())
                .map(|(&p, &g)| p - learning_rate * g)
                .collect()
        }
    }

    fn simd_momentum_update(
        params: &ArrayView1<f32>,
        gradients: &ArrayView1<f32>,
        velocity: &ArrayView1<f32>,
        learning_rate: f32,
        momentum: f32,
    ) -> (Array1<f32>, Array1<f32>) {
        if params.len() >= 16 {
            // SIMD path
            // velocity = momentum * velocity + learning_rate * gradient
            let scaled_velocity = f32::simd_scalar_mul(velocity, momentum);
            let scaled_gradients = f32::simd_scalar_mul(gradients, learning_rate);
            let new_velocity = f32::simd_add(&scaled_velocity.view(), &scaled_gradients.view());

            // params = params - velocity
            let new_params = f32::simd_sub(params, &new_velocity.view());

            (new_params, new_velocity)
        } else {
            // Scalar path
            let new_velocity: Array1<f32> = velocity
                .iter()
                .zip(gradients.iter())
                .map(|(&v, &g)| momentum * v + learning_rate * g)
                .collect();

            let new_params: Array1<f32> = params
                .iter()
                .zip(new_velocity.iter())
                .map(|(&p, &v)| p - v)
                .collect();

            (new_params, new_velocity)
        }
    }

    fn simd_adam_first_moment(
        m: &ArrayView1<f32>,
        gradients: &ArrayView1<f32>,
        beta1: f32,
    ) -> Array1<f32> {
        if m.len() >= 16 {
            // SIMD path: m = beta1 * m + (1 - beta1) * gradient
            let scaled_m = f32::simd_scalar_mul(m, beta1);
            let scaled_grads = f32::simd_scalar_mul(gradients, 1.0 - beta1);
            f32::simd_add(&scaled_m.view(), &scaled_grads.view())
        } else {
            // Scalar path
            m.iter()
                .zip(gradients.iter())
                .map(|(&m_val, &g)| beta1 * m_val + (1.0 - beta1) * g)
                .collect()
        }
    }

    fn simd_adam_second_moment(
        v: &ArrayView1<f32>,
        gradients: &ArrayView1<f32>,
        beta2: f32,
    ) -> Array1<f32> {
        if v.len() >= 16 {
            // SIMD path: v = beta2 * v + (1 - beta2) * gradient^2
            let scaled_v = f32::simd_scalar_mul(v, beta2);
            let grad_squared = f32::simd_mul(gradients, gradients);
            let scaled_grad_squared = f32::simd_scalar_mul(&grad_squared.view(), 1.0 - beta2);
            f32::simd_add(&scaled_v.view(), &scaled_grad_squared.view())
        } else {
            // Scalar path
            v.iter()
                .zip(gradients.iter())
                .map(|(&v_val, &g)| beta2 * v_val + (1.0 - beta2) * g * g)
                .collect()
        }
    }

    fn simd_adam_update(
        params: &ArrayView1<f32>,
        m_hat: &ArrayView1<f32>,
        v_hat: &ArrayView1<f32>,
        learning_rate: f32,
        epsilon: f32,
    ) -> Array1<f32> {
        if params.len() >= 16 {
            // SIMD path: params - learning_rate * m_hat / (sqrt(v_hat) + epsilon)
            // Compute sqrt(v_hat) + epsilon
            let v_hat_sqrt: Array1<f32> = v_hat.iter().map(|&v| v.sqrt() + epsilon).collect();

            // Compute m_hat / (sqrt(v_hat) + epsilon)
            let step = f32::simd_div(m_hat, &v_hat_sqrt.view());

            // Scale by learning rate
            let scaled_step = f32::simd_scalar_mul(&step.view(), learning_rate);

            // Update parameters
            f32::simd_sub(params, &scaled_step.view())
        } else {
            // Scalar path
            params
                .iter()
                .zip(m_hat.iter().zip(v_hat.iter()))
                .map(|(&p, (&m, &v))| p - learning_rate * m / (v.sqrt() + epsilon))
                .collect()
        }
    }

    fn simd_weight_decay(
        gradients: &ArrayView1<f32>,
        params: &ArrayView1<f32>,
        weight_decay: f32,
    ) -> Array1<f32> {
        if gradients.len() >= 16 {
            // SIMD path: gradients + weight_decay * params
            let scaled_params = f32::simd_scalar_mul(params, weight_decay);
            f32::simd_add(gradients, &scaled_params.view())
        } else {
            // Scalar path
            gradients
                .iter()
                .zip(params.iter())
                .map(|(&g, &p)| g + weight_decay * p)
                .collect()
        }
    }

    fn simd_gradient_norm(gradients: &ArrayView1<f32>) -> f32 {
        if gradients.len() >= 16 {
            // SIMD path using optimized dot product
            f32::simd_dot(gradients, gradients).sqrt()
        } else {
            // Scalar path
            gradients.iter().map(|&x| x * x).sum::<f32>().sqrt()
        }
    }
}

/// Implementation of SIMD optimizer operations for f64
impl SimdOptimizer<f64> for f64 {
    fn simd_sgd_update(
        params: &ArrayView1<f64>,
        gradients: &ArrayView1<f64>,
        learning_rate: f64,
    ) -> Array1<f64> {
        if params.len() >= 8 {
            // SIMD path
            let scaled_grads = f64::simd_scalar_mul(gradients, learning_rate);
            f64::simd_sub(params, &scaled_grads.view())
        } else {
            // Scalar path
            params
                .iter()
                .zip(gradients.iter())
                .map(|(&p, &g)| p - learning_rate * g)
                .collect()
        }
    }

    fn simd_momentum_update(
        params: &ArrayView1<f64>,
        gradients: &ArrayView1<f64>,
        velocity: &ArrayView1<f64>,
        learning_rate: f64,
        momentum: f64,
    ) -> (Array1<f64>, Array1<f64>) {
        if params.len() >= 8 {
            // SIMD path
            let scaled_velocity = f64::simd_scalar_mul(velocity, momentum);
            let scaled_gradients = f64::simd_scalar_mul(gradients, learning_rate);
            let new_velocity = f64::simd_add(&scaled_velocity.view(), &scaled_gradients.view());
            let new_params = f64::simd_sub(params, &new_velocity.view());
            (new_params, new_velocity)
        } else {
            // Scalar path
            let new_velocity: Array1<f64> = velocity
                .iter()
                .zip(gradients.iter())
                .map(|(&v, &g)| momentum * v + learning_rate * g)
                .collect();
            let new_params: Array1<f64> = params
                .iter()
                .zip(new_velocity.iter())
                .map(|(&p, &v)| p - v)
                .collect();
            (new_params, new_velocity)
        }
    }

    fn simd_adam_first_moment(
        m: &ArrayView1<f64>,
        gradients: &ArrayView1<f64>,
        beta1: f64,
    ) -> Array1<f64> {
        if m.len() >= 8 {
            // SIMD path
            let scaled_m = f64::simd_scalar_mul(m, beta1);
            let scaled_grads = f64::simd_scalar_mul(gradients, 1.0 - beta1);
            f64::simd_add(&scaled_m.view(), &scaled_grads.view())
        } else {
            // Scalar path
            m.iter()
                .zip(gradients.iter())
                .map(|(&m_val, &g)| beta1 * m_val + (1.0 - beta1) * g)
                .collect()
        }
    }

    fn simd_adam_second_moment(
        v: &ArrayView1<f64>,
        gradients: &ArrayView1<f64>,
        beta2: f64,
    ) -> Array1<f64> {
        if v.len() >= 8 {
            // SIMD path
            let scaled_v = f64::simd_scalar_mul(v, beta2);
            let grad_squared = f64::simd_mul(gradients, gradients);
            let scaled_grad_squared = f64::simd_scalar_mul(&grad_squared.view(), 1.0 - beta2);
            f64::simd_add(&scaled_v.view(), &scaled_grad_squared.view())
        } else {
            // Scalar path
            v.iter()
                .zip(gradients.iter())
                .map(|(&v_val, &g)| beta2 * v_val + (1.0 - beta2) * g * g)
                .collect()
        }
    }

    fn simd_adam_update(
        params: &ArrayView1<f64>,
        m_hat: &ArrayView1<f64>,
        v_hat: &ArrayView1<f64>,
        learning_rate: f64,
        epsilon: f64,
    ) -> Array1<f64> {
        if params.len() >= 8 {
            // SIMD path
            let v_hat_sqrt: Array1<f64> = v_hat.iter().map(|&v| v.sqrt() + epsilon).collect();
            let step = f64::simd_div(m_hat, &v_hat_sqrt.view());
            let scaled_step = f64::simd_scalar_mul(&step.view(), learning_rate);
            f64::simd_sub(params, &scaled_step.view())
        } else {
            // Scalar path
            params
                .iter()
                .zip(m_hat.iter().zip(v_hat.iter()))
                .map(|(&p, (&m, &v))| p - learning_rate * m / (v.sqrt() + epsilon))
                .collect()
        }
    }

    fn simd_weight_decay(
        gradients: &ArrayView1<f64>,
        params: &ArrayView1<f64>,
        weight_decay: f64,
    ) -> Array1<f64> {
        if gradients.len() >= 8 {
            // SIMD path
            let scaled_params = f64::simd_scalar_mul(params, weight_decay);
            f64::simd_add(gradients, &scaled_params.view())
        } else {
            // Scalar path
            gradients
                .iter()
                .zip(params.iter())
                .map(|(&g, &p)| g + weight_decay * p)
                .collect()
        }
    }

    fn simd_gradient_norm(gradients: &ArrayView1<f64>) -> f64 {
        if gradients.len() >= 8 {
            // SIMD path
            f64::simd_dot(gradients, gradients).sqrt()
        } else {
            // Scalar path
            gradients.iter().map(|&x| x * x).sum::<f64>().sqrt()
        }
    }
}

/// Helper function to determine if SIMD should be used based on array size
///
/// # Arguments
///
/// * `size` - Size of the array
/// * `dtype_size` - Size of the data type in bytes (4 for f32, 8 for f64)
///
/// # Returns
///
/// True if SIMD should be used, false otherwise
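///
/// # Example
///
/// Illustrative only; the `optirs_core::simd_optimizer` import path is an assumption
/// (not compiled as a doctest):
///
/// ```ignore
/// use optirs_core::simd_optimizer::should_use_simd;
///
/// assert!(should_use_simd(16, 4)); // 16 f32 elements: at the SIMD threshold
/// assert!(!should_use_simd(4, 8)); // 4 f64 elements: below the threshold
/// ```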
pub fn should_use_simd(size: usize, dtype_size: usize) -> bool {
    // Use SIMD for arrays with at least 16 f32 elements or 8 f64 elements
    let min_simd_size = match dtype_size {
        4 => 16,         // f32
        8 => 8,          // f64
        _ => usize::MAX, // Unknown type, don't use SIMD
    };

    size >= min_simd_size
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_simd_sgd_update_f32() {
        let params = Array1::from_vec(vec![1.0f32, 2.0, 3.0, 4.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);
        let learning_rate = 0.1;

        let result = f32::simd_sgd_update(&params.view(), &gradients.view(), learning_rate);

        assert_relative_eq!(result[0], 0.99, epsilon = 1e-6);
        assert_relative_eq!(result[1], 1.98, epsilon = 1e-6);
        assert_relative_eq!(result[2], 2.97, epsilon = 1e-6);
        assert_relative_eq!(result[3], 3.96, epsilon = 1e-6);
    }

    #[test]
    fn test_simd_sgd_update_f64() {
        let params = Array1::from_vec(vec![1.0f64, 2.0, 3.0, 4.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);
        let learning_rate = 0.1;

        let result = f64::simd_sgd_update(&params.view(), &gradients.view(), learning_rate);

        assert_relative_eq!(result[0], 0.99, epsilon = 1e-10);
        assert_relative_eq!(result[1], 1.98, epsilon = 1e-10);
        assert_relative_eq!(result[2], 2.97, epsilon = 1e-10);
        assert_relative_eq!(result[3], 3.96, epsilon = 1e-10);
    }

    #[test]
    fn test_simd_momentum_update() {
        let params = Array1::from_vec(vec![1.0f32, 2.0, 3.0, 4.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);
        let velocity = Array1::from_vec(vec![0.01, 0.02, 0.03, 0.04]);
        let learning_rate = 0.1;
        let momentum = 0.9;

        let (new_params, new_velocity) = f32::simd_momentum_update(
            &params.view(),
            &gradients.view(),
            &velocity.view(),
            learning_rate,
            momentum,
        );

        // Check velocity: 0.9 * old_velocity + 0.1 * gradient
        assert_relative_eq!(new_velocity[0], 0.9 * 0.01 + 0.1 * 0.1, epsilon = 1e-6);

        // Check params: old_params - new_velocity
        assert_relative_eq!(new_params[0], 1.0 - new_velocity[0], epsilon = 1e-6);
    }

    #[test]
    fn test_simd_adam_first_moment() {
        let m = Array1::from_vec(vec![0.01f32, 0.02, 0.03, 0.04]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);
        let beta1 = 0.9;

        let result = f32::simd_adam_first_moment(&m.view(), &gradients.view(), beta1);

        assert_relative_eq!(result[0], 0.9 * 0.01 + 0.1 * 0.1, epsilon = 1e-6);
        assert_relative_eq!(result[1], 0.9 * 0.02 + 0.1 * 0.2, epsilon = 1e-6);
    }

    #[test]
    fn test_simd_adam_second_moment() {
        let v = Array1::from_vec(vec![0.001f32, 0.002, 0.003, 0.004]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);
        let beta2 = 0.999;

        let result = f32::simd_adam_second_moment(&v.view(), &gradients.view(), beta2);

        assert_relative_eq!(result[0], 0.999 * 0.001 + 0.001 * 0.1 * 0.1, epsilon = 1e-6);
    }
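
    // A sketch of one complete Adam step composed from the trait methods. The bias
    // correction (m_hat = m / (1 - beta1^t), v_hat = v / (1 - beta2^t)) is not provided
    // by this module and is applied here with plain ndarray operations.
    #[test]
    fn test_simd_adam_full_step_sketch() {
        let params = Array1::from_vec(vec![1.0f32, 2.0, 3.0, 4.0]);
        let gradients = Array1::from_vec(vec![0.1f32, 0.2, 0.3, 0.4]);
        let m = Array1::from_vec(vec![0.0f32; 4]);
        let v = Array1::from_vec(vec![0.0f32; 4]);
        let (beta1, beta2, lr, eps) = (0.9f32, 0.999f32, 0.001f32, 1e-8f32);
        let t = 1;

        let m = f32::simd_adam_first_moment(&m.view(), &gradients.view(), beta1);
        let v = f32::simd_adam_second_moment(&v.view(), &gradients.view(), beta2);
        let m_hat = m.mapv(|x| x / (1.0 - beta1.powi(t)));
        let v_hat = v.mapv(|x| x / (1.0 - beta2.powi(t)));
        let updated = f32::simd_adam_update(&params.view(), &m_hat.view(), &v_hat.view(), lr, eps);

        // With zero-initialized moments at t = 1, m_hat equals the gradient and v_hat
        // equals the squared gradient, so each parameter moves by roughly -lr.
        let expected = 1.0 - lr * 0.1 / ((0.1f32 * 0.1).sqrt() + eps);
        assert_relative_eq!(updated[0], expected, epsilon = 1e-6);
    }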

    #[test]
    fn test_simd_weight_decay() {
        let gradients = Array1::from_vec(vec![0.1f32, 0.2, 0.3, 0.4]);
        let params = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let weight_decay = 0.01;

        let result = f32::simd_weight_decay(&gradients.view(), &params.view(), weight_decay);

        assert_relative_eq!(result[0], 0.1 + 0.01 * 1.0, epsilon = 1e-6);
        assert_relative_eq!(result[1], 0.2 + 0.01 * 2.0, epsilon = 1e-6);
    }

    #[test]
    fn test_simd_gradient_norm() {
        let gradients = Array1::from_vec(vec![3.0f32, 4.0]);
        let norm = f32::simd_gradient_norm(&gradients.view());
        assert_relative_eq!(norm, 5.0, epsilon = 1e-6);

        let gradients_f64 = Array1::from_vec(vec![3.0f64, 4.0]);
        let norm_f64 = f64::simd_gradient_norm(&gradients_f64.view());
        assert_relative_eq!(norm_f64, 5.0, epsilon = 1e-10);
    }

    #[test]
    fn test_should_use_simd() {
        // f32 tests
        assert!(!should_use_simd(8, 4)); // Too small for f32 SIMD
        assert!(should_use_simd(16, 4)); // Exactly at threshold
        assert!(should_use_simd(100, 4)); // Large enough

        // f64 tests
        assert!(!should_use_simd(4, 8)); // Too small for f64 SIMD
        assert!(should_use_simd(8, 8)); // Exactly at threshold
        assert!(should_use_simd(100, 8)); // Large enough
    }

    #[test]
    fn test_simd_large_array() {
        // Test with a large array to ensure SIMD path is taken
        let size = 1000;
        let params: Array1<f32> = Array1::from_vec((0..size).map(|i| i as f32).collect());
        let gradients: Array1<f32> = Array1::from_vec(vec![0.1; size]);
        let learning_rate = 0.01;

        let result = f32::simd_sgd_update(&params.view(), &gradients.view(), learning_rate);

        for i in 0..size {
            assert_relative_eq!(result[i], (i as f32) - learning_rate * 0.1, epsilon = 1e-6);
        }
    }
}
595}