Skip to main content

sklears_simd/
optimization.rs

1//! SIMD-optimized optimization algorithms
2//!
3//! This module implements high-performance optimization algorithms using SIMD instructions
4//! for machine learning applications including gradient descent, coordinate descent,
5//! and Newton-type methods.
6
7use crate::matrix::matrix_vector_multiply_f32;
8use crate::vector::{dot_product, norm_l2};
9use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayViewMut1};
10
11// Conditional imports for no-std compatibility
12#[cfg(feature = "no-std")]
13use alloc::string::String;
14#[cfg(not(feature = "no-std"))]
15use std::string::String;
16
/// SIMD-optimized gradient descent optimizer
pub struct GradientDescent {
    // Step size applied to every parameter update.
    learning_rate: f32,
    // Momentum coefficient; 0.0 disables the velocity term entirely.
    momentum: f32,
    // Dampening factor for momentum. NOTE(review): stored but there is no
    // builder setter and `step` never reads it — confirm intended semantics.
    dampening: f32,
    // L2 regularization strength folded into the gradient; 0.0 disables.
    weight_decay: f32,
    // When true, uses Nesterov momentum instead of standard momentum.
    nesterov: bool,
}
25
26impl GradientDescent {
27    /// Create a new gradient descent optimizer
28    pub fn new(learning_rate: f32) -> Self {
29        Self {
30            learning_rate,
31            momentum: 0.0,
32            dampening: 0.0,
33            weight_decay: 0.0,
34            nesterov: false,
35        }
36    }
37
38    /// Set momentum for the optimizer
39    pub fn with_momentum(mut self, momentum: f32) -> Self {
40        self.momentum = momentum;
41        self
42    }
43
44    /// Set weight decay (L2 regularization)
45    pub fn with_weight_decay(mut self, weight_decay: f32) -> Self {
46        self.weight_decay = weight_decay;
47        self
48    }
49
50    /// Enable Nesterov momentum
51    pub fn with_nesterov(mut self) -> Self {
52        self.nesterov = true;
53        self
54    }
55
56    /// Perform a single optimization step
57    pub fn step(
58        &self,
59        params: &mut ArrayViewMut1<f32>,
60        gradient: &ArrayView1<f32>,
61        velocity: &mut ArrayViewMut1<f32>,
62    ) {
63        // Add weight decay to gradient if specified
64        let mut grad = gradient.to_owned();
65        if self.weight_decay != 0.0 {
66            simd_axpy(self.weight_decay, &params.view(), &mut grad.view_mut());
67        }
68
69        if self.momentum != 0.0 {
70            // Update velocity: v = momentum * v + grad
71            simd_momentum_update(self.momentum, &grad.view(), velocity);
72
73            if self.nesterov {
74                // Nesterov momentum: param = param - lr * (momentum * v + grad)
75                let mut nesterov_grad = grad.clone();
76                simd_axpy(
77                    self.momentum,
78                    &velocity.view(),
79                    &mut nesterov_grad.view_mut(),
80                );
81                simd_axpy(-self.learning_rate, &nesterov_grad.view(), params);
82            } else {
83                // Standard momentum: param = param - lr * v
84                simd_axpy(-self.learning_rate, &velocity.view(), params);
85            }
86        } else {
87            // No momentum: param = param - lr * grad
88            simd_axpy(-self.learning_rate, &grad.view(), params);
89        }
90    }
91}
92
/// SIMD-optimized coordinate descent optimizer
pub struct CoordinateDescent {
    // L1 regularization strength used in the soft-thresholding step.
    alpha: f32,
    // Convergence threshold on the largest per-coordinate change per sweep.
    tolerance: f32,
    // Upper bound on full sweeps over all features.
    max_iterations: usize,
}
99
100impl CoordinateDescent {
101    /// Create a new coordinate descent optimizer
102    pub fn new(alpha: f32) -> Self {
103        Self {
104            alpha,
105            tolerance: 1e-4,
106            max_iterations: 1000,
107        }
108    }
109
110    /// Set convergence tolerance
111    pub fn with_tolerance(mut self, tolerance: f32) -> Self {
112        self.tolerance = tolerance;
113        self
114    }
115
116    /// Set maximum iterations
117    pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
118        self.max_iterations = max_iterations;
119        self
120    }
121
122    /// Optimize using coordinate descent for LASSO regression
123    pub fn optimize_lasso(
124        &self,
125        x: &Array2<f32>,
126        y: &Array1<f32>,
127        coeff: &mut Array1<f32>,
128    ) -> Result<(), String> {
129        let n_features = x.ncols();
130        let n_samples = x.nrows();
131
132        // Pre-compute X^T X diagonal for efficiency
133        let mut xtx_diag = Array1::zeros(n_features);
134        for j in 0..n_features {
135            let col = x.column(j).to_owned();
136            xtx_diag[j] = dot_product(
137                col.as_slice().expect("slice operation should succeed"),
138                col.as_slice().expect("slice operation should succeed"),
139            );
140        }
141
142        // Residuals: r = y - X * coeff
143        let mut residuals = y.clone();
144        let pred = matrix_vector_multiply_f32(x, coeff);
145        simd_axpy(-1.0, &pred.view(), &mut residuals.view_mut());
146
147        for _ in 0..self.max_iterations {
148            let mut max_change: f32 = 0.0;
149
150            for j in 0..n_features {
151                let old_coeff = coeff[j];
152
153                // Add back the contribution of feature j to residuals
154                let col = x.column(j);
155                simd_axpy(old_coeff, &col.to_owned().view(), &mut residuals.view_mut());
156
157                // Compute new coefficient
158                let col_slice = col.to_owned();
159                let rho = dot_product(
160                    col_slice
161                        .as_slice()
162                        .expect("slice operation should succeed"),
163                    residuals
164                        .as_slice()
165                        .expect("slice operation should succeed"),
166                );
167                let new_coeff = soft_threshold(rho / n_samples as f32, self.alpha)
168                    / (xtx_diag[j] / n_samples as f32);
169
170                // Update coefficient and residuals
171                coeff[j] = new_coeff;
172                let change = new_coeff - old_coeff;
173                max_change = max_change.max(change.abs());
174
175                // Subtract new contribution from residuals
176                simd_axpy(
177                    -new_coeff,
178                    &col.to_owned().view(),
179                    &mut residuals.view_mut(),
180                );
181            }
182
183            if max_change < self.tolerance {
184                return Ok(());
185            }
186        }
187
188        Ok(())
189    }
190}
191
/// SIMD-optimized quasi-Newton optimizer (L-BFGS)
pub struct QuasiNewton {
    // Number of correction pairs a limited-memory variant would retain.
    // NOTE(review): not consulted by the simplified `optimize` below.
    memory_size: usize,
    // Convergence threshold on the gradient's L2 norm.
    tolerance: f32,
    // Maximum number of outer optimization iterations.
    max_iterations: usize,
    // Maximum backtracking halvings in the Armijo line search.
    line_search_max_iter: usize,
}
199
impl Default for QuasiNewton {
    /// Delegates to [`QuasiNewton::new`] so the defaults live in one place.
    fn default() -> Self {
        Self::new()
    }
}
205
206impl QuasiNewton {
207    /// Create a new quasi-Newton optimizer
208    pub fn new() -> Self {
209        Self {
210            memory_size: 10,
211            tolerance: 1e-6,
212            max_iterations: 1000,
213            line_search_max_iter: 20,
214        }
215    }
216
217    /// Set L-BFGS memory size
218    pub fn with_memory_size(mut self, memory_size: usize) -> Self {
219        self.memory_size = memory_size;
220        self
221    }
222
223    /// Simple L-BFGS implementation for demonstration
224    pub fn optimize<F, G>(
225        &self,
226        mut x: Array1<f32>,
227        objective: F,
228        gradient: G,
229    ) -> Result<Array1<f32>, String>
230    where
231        F: Fn(&Array1<f32>) -> f32,
232        G: Fn(&Array1<f32>) -> Array1<f32>,
233    {
234        let n = x.len();
235        let mut grad = gradient(&x);
236        let h_inv = Array2::eye(n); // Initial Hessian inverse approximation
237
238        for _ in 0..self.max_iterations {
239            let grad_norm = norm_l2(grad.as_slice().expect("slice operation should succeed"));
240            if grad_norm < self.tolerance {
241                return Ok(x);
242            }
243
244            // Compute search direction: d = -H^{-1} * grad
245            let direction = matrix_vector_multiply_f32(&h_inv, &grad);
246            let mut search_dir = direction;
247            simd_scale(-1.0, &mut search_dir.view_mut());
248
249            // Line search to find step size
250            let step_size = self.line_search(&x, &search_dir, &objective, &gradient)?;
251
252            // Update parameters
253            let mut step = search_dir.clone();
254            simd_scale(step_size, &mut step.view_mut());
255            let x_new = &x + &step;
256
257            let grad_new = gradient(&x_new);
258
259            // BFGS update (simplified)
260            let s = &x_new - &x;
261            let y = &grad_new - &grad;
262
263            let sy = dot_product(
264                s.as_slice().expect("slice operation should succeed"),
265                y.as_slice().expect("slice operation should succeed"),
266            );
267            if sy > 1e-10 {
268                // Update Hessian inverse approximation (simplified rank-1 update)
269                // This is a simplified version - full L-BFGS would maintain a history
270            }
271
272            x = x_new;
273            grad = grad_new;
274        }
275
276        Ok(x)
277    }
278
279    /// Simple backtracking line search
280    fn line_search<F, G>(
281        &self,
282        x: &Array1<f32>,
283        direction: &Array1<f32>,
284        objective: &F,
285        gradient: &G,
286    ) -> Result<f32, String>
287    where
288        F: Fn(&Array1<f32>) -> f32,
289        G: Fn(&Array1<f32>) -> Array1<f32>,
290    {
291        let c1 = 1e-4;
292        let mut alpha = 1.0;
293        let f_x = objective(x);
294        let grad_x = gradient(x);
295        let grad_dot_dir = dot_product(
296            grad_x.as_slice().expect("slice operation should succeed"),
297            direction
298                .as_slice()
299                .expect("slice operation should succeed"),
300        );
301
302        for _ in 0..self.line_search_max_iter {
303            let mut x_new = x.clone();
304            let mut step = direction.clone();
305            simd_scale(alpha, &mut step.view_mut());
306            simd_axpy(1.0, &step.view(), &mut x_new.view_mut());
307
308            let f_x_new = objective(&x_new);
309
310            // Armijo condition
311            if f_x_new <= f_x + c1 * alpha * grad_dot_dir {
312                return Ok(alpha);
313            }
314
315            alpha *= 0.5;
316        }
317
318        Ok(alpha)
319    }
320}
321
322/// SIMD-optimized AXPY operation: y = alpha * x + y
323pub fn simd_axpy(alpha: f32, x: &ArrayView1<f32>, y: &mut ArrayViewMut1<f32>) {
324    assert_eq!(x.len(), y.len(), "Arrays must have the same length");
325
326    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
327    {
328        if crate::simd_feature_detected!("avx2") && crate::simd_feature_detected!("fma") {
329            unsafe { simd_axpy_avx2_fma(alpha, x, y) };
330            return;
331        } else if crate::simd_feature_detected!("avx2") {
332            unsafe { simd_axpy_avx2(alpha, x, y) };
333            return;
334        } else if crate::simd_feature_detected!("sse2") {
335            unsafe { simd_axpy_sse2(alpha, x, y) };
336            return;
337        }
338    }
339
340    // Scalar fallback
341    for i in 0..x.len() {
342        y[i] += alpha * x[i];
343    }
344}
345
346/// SIMD-optimized scaling: x = alpha * x
347pub fn simd_scale(alpha: f32, x: &mut ArrayViewMut1<f32>) {
348    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
349    {
350        if crate::simd_feature_detected!("avx2") {
351            unsafe { simd_scale_avx2(alpha, x) };
352            return;
353        } else if crate::simd_feature_detected!("sse2") {
354            unsafe { simd_scale_sse2(alpha, x) };
355            return;
356        }
357    }
358
359    // Scalar fallback
360    for val in x.iter_mut() {
361        *val *= alpha;
362    }
363}
364
365/// SIMD-optimized momentum update: v = momentum * v + grad
366pub fn simd_momentum_update(
367    momentum: f32,
368    grad: &ArrayView1<f32>,
369    velocity: &mut ArrayViewMut1<f32>,
370) {
371    assert_eq!(
372        grad.len(),
373        velocity.len(),
374        "Arrays must have the same length"
375    );
376
377    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
378    {
379        if crate::simd_feature_detected!("avx2") && crate::simd_feature_detected!("fma") {
380            unsafe { simd_momentum_update_avx2_fma(momentum, grad, velocity) };
381            return;
382        } else if crate::simd_feature_detected!("avx2") {
383            unsafe { simd_momentum_update_avx2(momentum, grad, velocity) };
384            return;
385        } else if crate::simd_feature_detected!("sse2") {
386            unsafe { simd_momentum_update_sse2(momentum, grad, velocity) };
387            return;
388        }
389    }
390
391    // Scalar fallback
392    for i in 0..grad.len() {
393        velocity[i] = momentum * velocity[i] + grad[i];
394    }
395}
396
/// Soft thresholding operator used by the LASSO coordinate updates.
///
/// Shrinks `x` toward zero by `threshold`, mapping anything inside
/// `[-threshold, threshold]` to exactly zero.
fn soft_threshold(x: f32, threshold: f32) -> f32 {
    match x {
        v if v > threshold => v - threshold,
        v if v < -threshold => v + threshold,
        _ => 0.0,
    }
}
407
408// SIMD implementations for x86/x86_64
409
410#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
411#[target_feature(enable = "sse2")]
412unsafe fn simd_axpy_sse2(alpha: f32, x: &ArrayView1<f32>, y: &mut ArrayViewMut1<f32>) {
413    use core::arch::x86_64::*;
414
415    let alpha_vec = _mm_set1_ps(alpha);
416    let len = x.len();
417    let mut i = 0;
418
419    while i + 4 <= len {
420        let x_vec = _mm_loadu_ps(&x[i]);
421        let y_vec = _mm_loadu_ps(&y[i]);
422        let result = _mm_add_ps(_mm_mul_ps(alpha_vec, x_vec), y_vec);
423        _mm_storeu_ps(&mut y[i], result);
424        i += 4;
425    }
426
427    // Handle remaining elements
428    while i < len {
429        y[i] += alpha * x[i];
430        i += 1;
431    }
432}
433
434#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
435#[target_feature(enable = "avx2")]
436unsafe fn simd_axpy_avx2(alpha: f32, x: &ArrayView1<f32>, y: &mut ArrayViewMut1<f32>) {
437    use core::arch::x86_64::*;
438
439    let alpha_vec = _mm256_set1_ps(alpha);
440    let len = x.len();
441    let mut i = 0;
442
443    while i + 8 <= len {
444        let x_vec = _mm256_loadu_ps(&x[i]);
445        let y_vec = _mm256_loadu_ps(&y[i]);
446        let result = _mm256_add_ps(_mm256_mul_ps(alpha_vec, x_vec), y_vec);
447        _mm256_storeu_ps(&mut y[i], result);
448        i += 8;
449    }
450
451    // Handle remaining elements
452    while i < len {
453        y[i] += alpha * x[i];
454        i += 1;
455    }
456}
457
458#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
459#[target_feature(enable = "avx2", enable = "fma")]
460unsafe fn simd_axpy_avx2_fma(alpha: f32, x: &ArrayView1<f32>, y: &mut ArrayViewMut1<f32>) {
461    use core::arch::x86_64::*;
462
463    let alpha_vec = _mm256_set1_ps(alpha);
464    let len = x.len();
465    let mut i = 0;
466
467    while i + 8 <= len {
468        let x_vec = _mm256_loadu_ps(&x[i]);
469        let y_vec = _mm256_loadu_ps(&y[i]);
470        let result = _mm256_fmadd_ps(alpha_vec, x_vec, y_vec);
471        _mm256_storeu_ps(&mut y[i], result);
472        i += 8;
473    }
474
475    // Handle remaining elements
476    while i < len {
477        y[i] += alpha * x[i];
478        i += 1;
479    }
480}
481
482#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
483#[target_feature(enable = "sse2")]
484unsafe fn simd_scale_sse2(alpha: f32, x: &mut ArrayViewMut1<f32>) {
485    use core::arch::x86_64::*;
486
487    let alpha_vec = _mm_set1_ps(alpha);
488    let len = x.len();
489    let mut i = 0;
490
491    while i + 4 <= len {
492        let x_vec = _mm_loadu_ps(&x[i]);
493        let result = _mm_mul_ps(alpha_vec, x_vec);
494        _mm_storeu_ps(&mut x[i], result);
495        i += 4;
496    }
497
498    // Handle remaining elements
499    while i < len {
500        x[i] *= alpha;
501        i += 1;
502    }
503}
504
505#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
506#[target_feature(enable = "avx2")]
507unsafe fn simd_scale_avx2(alpha: f32, x: &mut ArrayViewMut1<f32>) {
508    use core::arch::x86_64::*;
509
510    let alpha_vec = _mm256_set1_ps(alpha);
511    let len = x.len();
512    let mut i = 0;
513
514    while i + 8 <= len {
515        let x_vec = _mm256_loadu_ps(&x[i]);
516        let result = _mm256_mul_ps(alpha_vec, x_vec);
517        _mm256_storeu_ps(&mut x[i], result);
518        i += 8;
519    }
520
521    // Handle remaining elements
522    while i < len {
523        x[i] *= alpha;
524        i += 1;
525    }
526}
527
528#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
529#[target_feature(enable = "sse2")]
530unsafe fn simd_momentum_update_sse2(
531    momentum: f32,
532    grad: &ArrayView1<f32>,
533    velocity: &mut ArrayViewMut1<f32>,
534) {
535    use core::arch::x86_64::*;
536
537    let momentum_vec = _mm_set1_ps(momentum);
538    let len = grad.len();
539    let mut i = 0;
540
541    while i + 4 <= len {
542        let grad_vec = _mm_loadu_ps(&grad[i]);
543        let vel_vec = _mm_loadu_ps(&velocity[i]);
544        let result = _mm_add_ps(_mm_mul_ps(momentum_vec, vel_vec), grad_vec);
545        _mm_storeu_ps(&mut velocity[i], result);
546        i += 4;
547    }
548
549    // Handle remaining elements
550    while i < len {
551        velocity[i] = momentum * velocity[i] + grad[i];
552        i += 1;
553    }
554}
555
556#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
557#[target_feature(enable = "avx2")]
558unsafe fn simd_momentum_update_avx2(
559    momentum: f32,
560    grad: &ArrayView1<f32>,
561    velocity: &mut ArrayViewMut1<f32>,
562) {
563    use core::arch::x86_64::*;
564
565    let momentum_vec = _mm256_set1_ps(momentum);
566    let len = grad.len();
567    let mut i = 0;
568
569    while i + 8 <= len {
570        let grad_vec = _mm256_loadu_ps(&grad[i]);
571        let vel_vec = _mm256_loadu_ps(&velocity[i]);
572        let result = _mm256_add_ps(_mm256_mul_ps(momentum_vec, vel_vec), grad_vec);
573        _mm256_storeu_ps(&mut velocity[i], result);
574        i += 8;
575    }
576
577    // Handle remaining elements
578    while i < len {
579        velocity[i] = momentum * velocity[i] + grad[i];
580        i += 1;
581    }
582}
583
584#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
585#[target_feature(enable = "avx2", enable = "fma")]
586unsafe fn simd_momentum_update_avx2_fma(
587    momentum: f32,
588    grad: &ArrayView1<f32>,
589    velocity: &mut ArrayViewMut1<f32>,
590) {
591    use core::arch::x86_64::*;
592
593    let momentum_vec = _mm256_set1_ps(momentum);
594    let len = grad.len();
595    let mut i = 0;
596
597    while i + 8 <= len {
598        let grad_vec = _mm256_loadu_ps(&grad[i]);
599        let vel_vec = _mm256_loadu_ps(&velocity[i]);
600        let result = _mm256_fmadd_ps(momentum_vec, vel_vec, grad_vec);
601        _mm256_storeu_ps(&mut velocity[i], result);
602        i += 8;
603    }
604
605    // Handle remaining elements
606    while i < len {
607        velocity[i] = momentum * velocity[i] + grad[i];
608        i += 1;
609    }
610}
611
// Fix: dropped a pointless `#[allow(non_snake_case)]` (every item here is
// snake_case) and an unreachable `#[cfg(feature = "no-std")] use alloc::...`
// import — this module is only compiled when `no-std` is disabled, so that
// cfg branch could never be active.
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_gradient_descent() {
        let optimizer = GradientDescent::new(0.1).with_momentum(0.9);

        let mut params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradient = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let mut velocity = Array1::zeros(3);

        let params_before = params.clone();
        optimizer.step(
            &mut params.view_mut(),
            &gradient.view(),
            &mut velocity.view_mut(),
        );

        // Parameters should have moved in the opposite direction of the gradient
        for i in 0..params.len() {
            assert!(params[i] < params_before[i]);
        }
    }

    #[test]
    fn test_coordinate_descent() {
        let optimizer = CoordinateDescent::new(0.1);

        // Simple 2D problem
        let x = Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
            .expect("shape and data length should match");
        let y = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let mut coeff = Array1::zeros(2);

        let result = optimizer.optimize_lasso(&x, &y, &mut coeff);
        assert!(result.is_ok());
    }

    #[test]
    fn test_simd_axpy() {
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let mut y = Array1::from_vec(vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0]);
        let alpha = 2.0;

        let expected = &y + &(&x * alpha);
        simd_axpy(alpha, &x.view(), &mut y.view_mut());

        for i in 0..x.len() {
            assert_relative_eq!(y[i], expected[i], epsilon = 1e-6);
        }
    }

    #[test]
    fn test_simd_scale() {
        let mut x = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let alpha = 2.5;

        let expected = &x * alpha;
        simd_scale(alpha, &mut x.view_mut());

        for i in 0..x.len() {
            assert_relative_eq!(x[i], expected[i], epsilon = 1e-6);
        }
    }

    #[test]
    fn test_momentum_update() {
        let grad = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);
        let mut velocity = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let momentum = 0.9;

        let expected = &velocity * momentum + &grad;
        simd_momentum_update(momentum, &grad.view(), &mut velocity.view_mut());

        for i in 0..grad.len() {
            assert_relative_eq!(velocity[i], expected[i], epsilon = 1e-6);
        }
    }

    #[test]
    fn test_soft_threshold() {
        assert_eq!(soft_threshold(2.0, 1.0), 1.0);
        assert_eq!(soft_threshold(-2.0, 1.0), -1.0);
        assert_eq!(soft_threshold(0.5, 1.0), 0.0);
        assert_eq!(soft_threshold(-0.5, 1.0), 0.0);
    }
}