sklears_linear/
optimizer.rs

1//! Optimization algorithms for linear models
2
3use scirs2_core::ndarray::{Array1, Array2};
4use scirs2_core::random::rngs::StdRng;
5use scirs2_core::random::SeedableRng;
6use sklears_core::types::Float;
7
/// Limited-memory BFGS (L-BFGS) optimizer for smooth unconstrained problems.
///
/// Approximates the inverse Hessian from the `m` most recent curvature pairs
/// and chooses step lengths with a backtracking line search.
pub struct LbfgsOptimizer {
    /// Maximum number of iterations
    pub max_iter: usize,
    /// Convergence tolerance, applied to the L1 norm of the gradient
    pub tol: Float,
    /// Number of corrections (stored `(s, y)` pairs) used to approximate the Hessian
    pub m: usize,
    /// Maximum number of backtracking trials per line search
    pub max_ls: usize,
}
19
20impl Default for LbfgsOptimizer {
21    fn default() -> Self {
22        Self {
23            max_iter: 100,
24            tol: 1e-4,
25            m: 10,
26            max_ls: 20,
27        }
28    }
29}
30
31impl LbfgsOptimizer {
32    /// Minimize a differentiable function using L-BFGS
33    pub fn minimize<F, G>(
34        &self,
35        f: F,
36        grad_f: G,
37        x0: Array1<Float>,
38    ) -> Result<Array1<Float>, String>
39    where
40        F: Fn(&Array1<Float>) -> Float,
41        G: Fn(&Array1<Float>) -> Array1<Float>,
42    {
43        let _n = x0.len();
44        let mut x = x0;
45        let mut f_k = f(&x);
46        let mut g_k = grad_f(&x);
47
48        // History for L-BFGS
49        let mut s_list: Vec<Array1<Float>> = Vec::with_capacity(self.m);
50        let mut y_list: Vec<Array1<Float>> = Vec::with_capacity(self.m);
51        let mut rho_list: Vec<Float> = Vec::with_capacity(self.m);
52
53        for _iter in 0..self.max_iter {
54            // Check convergence
55            let g_norm = g_k.mapv(Float::abs).sum();
56            if g_norm < self.tol {
57                return Ok(x);
58            }
59
60            // Compute search direction using L-BFGS
61            let p = self.compute_direction(&g_k, &s_list, &y_list, &rho_list);
62
63            // Line search
64            let alpha = self.line_search(&f, &x, &p, f_k, &g_k)?;
65
66            // Update x
67            let x_new = &x + alpha * &p;
68            let f_new = f(&x_new);
69            let g_new = grad_f(&x_new);
70
71            // Update history
72            let s = &x_new - &x;
73            let y = &g_new - &g_k;
74
75            let rho = 1.0 / s.dot(&y).max(1e-10);
76
77            // Maintain history size
78            if s_list.len() == self.m {
79                s_list.remove(0);
80                y_list.remove(0);
81                rho_list.remove(0);
82            }
83
84            s_list.push(s);
85            y_list.push(y);
86            rho_list.push(rho);
87
88            // Update for next iteration
89            x = x_new;
90            f_k = f_new;
91            g_k = g_new;
92        }
93
94        Ok(x)
95    }
96
97    /// Compute L-BFGS search direction
98    fn compute_direction(
99        &self,
100        g: &Array1<Float>,
101        s_list: &[Array1<Float>],
102        y_list: &[Array1<Float>],
103        rho_list: &[Float],
104    ) -> Array1<Float> {
105        let mut q = g.clone();
106        let k = s_list.len();
107        let mut alpha = vec![0.0; k];
108
109        // First loop
110        for i in (0..k).rev() {
111            alpha[i] = rho_list[i] * s_list[i].dot(&q);
112            q = q - alpha[i] * &y_list[i];
113        }
114
115        // Scale initial direction
116        let mut r = if k > 0 {
117            let gamma = s_list[k - 1].dot(&y_list[k - 1]) / y_list[k - 1].dot(&y_list[k - 1]);
118            gamma * q
119        } else {
120            q
121        };
122
123        // Second loop
124        for i in 0..k {
125            let beta = rho_list[i] * y_list[i].dot(&r);
126            r = r + (alpha[i] - beta) * &s_list[i];
127        }
128
129        -r
130    }
131
132    /// Backtracking line search
133    fn line_search<F>(
134        &self,
135        f: &F,
136        x: &Array1<Float>,
137        p: &Array1<Float>,
138        f_k: Float,
139        g_k: &Array1<Float>,
140    ) -> Result<Float, String>
141    where
142        F: Fn(&Array1<Float>) -> Float,
143    {
144        let c1 = 1e-4;
145        let rho = 0.5;
146        let mut alpha = 1.0;
147
148        let gp = g_k.dot(p);
149
150        for _ in 0..self.max_ls {
151            let x_new = x + alpha * p;
152            let f_new = f(&x_new);
153
154            // Armijo condition
155            if f_new <= f_k + c1 * alpha * gp {
156                return Ok(alpha);
157            }
158
159            alpha *= rho;
160        }
161
162        Err("Line search failed to find suitable step size".to_string())
163    }
164}
165
/// Stochastic Average Gradient (SAG) optimizer for finite-sum objectives.
///
/// Stores the last gradient computed for each sample and steps against the
/// running sum of those stored gradients.
pub struct SagOptimizer {
    /// Maximum number of passes (epochs) over the data
    pub max_epochs: usize,
    /// Convergence tolerance on the mean absolute stored gradient
    pub tol: Float,
    /// Step size; a fixed 0.01 is used when `None`
    pub learning_rate: Option<Float>,
    /// Seed for the shuffling RNG; an unseeded thread RNG is used when `None`
    pub random_state: Option<u64>,
}
177
178impl Default for SagOptimizer {
179    fn default() -> Self {
180        Self {
181            max_epochs: 100,
182            tol: 1e-4,
183            learning_rate: None,
184            random_state: None,
185        }
186    }
187}
188
189impl SagOptimizer {
190    /// Minimize using SAG for finite-sum problems
191    /// f_i: individual loss function for sample i
192    /// grad_f_i: gradient of individual loss function
193    pub fn minimize<F, G>(
194        &self,
195        _f_i: F,
196        grad_f_i: G,
197        x0: Array1<Float>,
198        n_samples: usize,
199    ) -> Result<Array1<Float>, String>
200    where
201        F: Fn(&Array1<Float>, usize) -> Float,
202        G: Fn(&Array1<Float>, usize) -> Array1<Float>,
203    {
204        let n_features = x0.len();
205        let mut x = x0;
206
207        // Initialize gradient memory
208        let mut gradient_memory = Array2::zeros((n_samples, n_features));
209        let mut gradient_sum = Array1::zeros(n_features);
210        let mut seen = vec![false; n_samples];
211
212        // Learning rate (can be auto-tuned based on Lipschitz constant)
213        let alpha = self.learning_rate.unwrap_or(0.01);
214
215        // Random number generator
216        let mut rng = match self.random_state {
217            Some(seed) => StdRng::seed_from_u64(seed),
218            None => StdRng::from_rng(&mut scirs2_core::random::thread_rng()),
219        };
220
221        for _epoch in 0..self.max_epochs {
222            // Random permutation for this epoch
223            let mut indices: Vec<usize> = (0..n_samples).collect();
224            use scirs2_core::random::SliceRandomExt;
225            indices.shuffle(&mut rng);
226
227            for &i in &indices {
228                // Compute gradient for sample i
229                let grad_i = grad_f_i(&x, i);
230
231                // Update gradient sum
232                if seen[i] {
233                    gradient_sum = &gradient_sum - &gradient_memory.row(i) + &grad_i;
234                } else {
235                    gradient_sum = &gradient_sum + &grad_i;
236                    seen[i] = true;
237                }
238
239                // Store gradient in memory
240                gradient_memory.row_mut(i).assign(&grad_i);
241
242                // Update parameters
243                x = &x - alpha * &gradient_sum / n_samples as Float;
244            }
245
246            // Check convergence
247            let avg_grad_norm = gradient_sum.mapv(Float::abs).sum() / n_samples as Float;
248            if avg_grad_norm < self.tol {
249                return Ok(x);
250            }
251        }
252
253        Ok(x)
254    }
255}
256
/// SAGA optimizer (improved SAG with support for non-smooth penalties).
///
/// Like SAG it keeps one stored gradient per sample, but its unbiased
/// gradient estimate and proximal step support composite objectives.
pub struct SagaOptimizer {
    /// Maximum number of passes (epochs) over the data
    pub max_epochs: usize,
    /// Convergence tolerance on the L1 norm of the average stored gradient
    pub tol: Float,
    /// Step size; a fixed 0.01 is used when `None`
    pub learning_rate: Option<Float>,
    /// Seed for the shuffling RNG; an unseeded thread RNG is used when `None`
    pub random_state: Option<u64>,
}
268
269impl Default for SagaOptimizer {
270    fn default() -> Self {
271        Self {
272            max_epochs: 100,
273            tol: 1e-4,
274            learning_rate: None,
275            random_state: None,
276        }
277    }
278}
279
280impl SagaOptimizer {
281    /// Minimize using SAGA for composite objectives: f(x) + g(x)
282    /// where f is smooth (finite-sum) and g is possibly non-smooth (e.g., L1 penalty)
283    pub fn minimize<F, GradF, ProxG>(
284        &self,
285        _f_i: F,
286        grad_f_i: GradF,
287        prox_g: ProxG,
288        x0: Array1<Float>,
289        n_samples: usize,
290    ) -> Result<Array1<Float>, String>
291    where
292        F: Fn(&Array1<Float>, usize) -> Float,
293        GradF: Fn(&Array1<Float>, usize) -> Array1<Float>,
294        ProxG: Fn(&Array1<Float>, Float) -> Array1<Float>,
295    {
296        let n_features = x0.len();
297        let mut x = x0;
298
299        // Initialize gradient memory and average
300        let mut gradient_memory = Array2::zeros((n_samples, n_features));
301        let mut gradient_avg = Array1::zeros(n_features);
302
303        // Initialize all gradients
304        for i in 0..n_samples {
305            let grad_i = grad_f_i(&x, i);
306            gradient_memory.row_mut(i).assign(&grad_i);
307            gradient_avg = &gradient_avg + &grad_i / n_samples as Float;
308        }
309
310        // Learning rate
311        let alpha = self.learning_rate.unwrap_or(0.01);
312
313        // Random number generator
314        let mut rng = match self.random_state {
315            Some(seed) => StdRng::seed_from_u64(seed),
316            None => StdRng::from_rng(&mut scirs2_core::random::thread_rng()),
317        };
318
319        for _epoch in 0..self.max_epochs {
320            // Random permutation
321            let mut indices: Vec<usize> = (0..n_samples).collect();
322            use scirs2_core::random::SliceRandomExt;
323            indices.shuffle(&mut rng);
324
325            for &i in &indices {
326                // Store old gradient
327                let grad_old = gradient_memory.row(i).to_owned();
328
329                // Compute new gradient
330                let grad_new = grad_f_i(&x, i);
331
332                // Update gradient memory and average
333                gradient_memory.row_mut(i).assign(&grad_new);
334                gradient_avg = &gradient_avg + (&grad_new - &grad_old) / n_samples as Float;
335
336                // Gradient step
337                let v = &grad_new - &grad_old + &gradient_avg;
338                let x_intermediate = &x - alpha * &v;
339
340                // Proximal step (handles non-smooth penalty)
341                x = prox_g(&x_intermediate, alpha);
342            }
343
344            // Check convergence
345            let grad_norm = gradient_avg.mapv(Float::abs).sum();
346            if grad_norm < self.tol {
347                return Ok(x);
348            }
349        }
350
351        Ok(x)
352    }
353}
354
/// Proximal Gradient Method optimizer for composite objectives `f(x) + g(x)`,
/// where `f` is smooth and `g` is handled through its proximal operator.
pub struct ProximalGradientOptimizer {
    /// Maximum number of iterations
    pub max_iter: usize,
    /// Convergence tolerance (iterate change and relative objective change)
    pub tol: Float,
    /// Initial step size; 1.0 is used when `None`
    pub step_size: Option<Float>,
    /// Whether to adapt the step size with backtracking line search
    pub use_line_search: bool,
    /// Whether to use Nesterov-style acceleration (FISTA)
    pub accelerated: bool,
    /// Backtracking factor: step size is multiplied by `beta` on rejection
    pub beta: Float,
    /// Slack coefficient added to the line-search acceptance test
    pub sigma: Float,
}
371
372impl Default for ProximalGradientOptimizer {
373    fn default() -> Self {
374        Self {
375            max_iter: 1000,
376            tol: 1e-6,
377            step_size: None,
378            use_line_search: true,
379            accelerated: false,
380            beta: 0.5,
381            sigma: 0.01,
382        }
383    }
384}
385
impl ProximalGradientOptimizer {
    /// Create accelerated version (FISTA): default settings with the
    /// Nesterov momentum sequence enabled.
    pub fn accelerated() -> Self {
        Self {
            accelerated: true,
            ..Default::default()
        }
    }

    /// Minimize composite objective f(x) + g(x) using proximal gradient method
    /// where f is smooth and g is convex (possibly non-smooth).
    ///
    /// `prox_g(z, t)` must evaluate `prox_{t*g}(z)`. Returns the final iterate.
    ///
    /// NOTE(review): the objective-based stopping test below only evaluates
    /// the smooth part `f`, not `f + g` — confirm that is intended for
    /// composite objectives.
    pub fn minimize<F, GradF, ProxG>(
        &self,
        f: F,
        grad_f: GradF,
        prox_g: ProxG,
        x0: Array1<Float>,
    ) -> Result<Array1<Float>, String>
    where
        F: Fn(&Array1<Float>) -> Float,
        GradF: Fn(&Array1<Float>) -> Array1<Float>,
        ProxG: Fn(&Array1<Float>, Float) -> Array1<Float>,
    {
        let mut x = x0.clone();
        let mut y = x0; // Extrapolated point used when `accelerated` is set
        let mut t = 1.0; // FISTA momentum parameter t_k

        // Initial step size; only ever shrinks afterwards, since the
        // backtracking search below never enlarges it.
        let mut step_size = self.step_size.unwrap_or(1.0);

        let mut f_prev = f(&x);

        for iter in 0..self.max_iter {
            // Gradient is taken at y for FISTA, at x for plain ISTA.
            let current_point = if self.accelerated { &y } else { &x };

            // Compute gradient at current point
            let grad = grad_f(current_point);

            // Line search for step size if enabled
            if self.use_line_search {
                step_size = self.backtracking_line_search(
                    &f,
                    &grad_f,
                    &prox_g,
                    current_point,
                    &grad,
                    step_size,
                )?;
            }

            // Proximal gradient step: x_new = prox_{t*g}(point - t * grad).
            let x_new = prox_g(&(current_point - step_size * &grad), step_size);

            // Convergence on the L1 norm of the iterate change.
            let diff_norm = (&x_new - &x).mapv(Float::abs).sum();
            if diff_norm < self.tol {
                return Ok(x_new);
            }

            // FISTA momentum update. Order matters: y is extrapolated from
            // x_new and the *previous* x, so this must run before `x = x_new`.
            if self.accelerated {
                let t_new: Float = (1.0_f64 + (1.0_f64 + 4.0_f64 * t * t).sqrt()) / 2.0_f64;
                let beta = (t - 1.0) / t_new;
                y = &x_new + beta * (&x_new - &x);
                t = t_new;
            }

            x = x_new;

            // Secondary stopping rule on the relative change of f; the
            // `iter > 10` guard avoids terminating before any real progress.
            let f_curr = f(&x);
            if (f_curr - f_prev).abs() < self.tol * f_prev.abs().max(1.0) && iter > 10 {
                // Avoid early termination
                return Ok(x);
            }
            f_prev = f_curr;
        }

        Ok(x)
    }

    /// Backtracking line search for the proximal gradient step.
    ///
    /// Shrinks `step_size` by `self.beta` until the prox-step objective is
    /// below a quadratic upper model of `f` around `x`.
    ///
    /// NOTE(review): the acceptance test adds a slack term
    /// `sigma * step_size * ||grad||²` to the standard quadratic bound
    /// `f(x) + grad·diff + ||diff||²/(2t)` — nonstandard; confirm intended.
    fn backtracking_line_search<F, GradF, ProxG>(
        &self,
        f: &F,
        _grad_f: &GradF,
        prox_g: &ProxG,
        x: &Array1<Float>,
        grad: &Array1<Float>,
        mut step_size: Float,
    ) -> Result<Float, String>
    where
        F: Fn(&Array1<Float>) -> Float,
        GradF: Fn(&Array1<Float>) -> Array1<Float>,
        ProxG: Fn(&Array1<Float>, Float) -> Array1<Float>,
    {
        let f_x = f(x);
        let max_iter = 50;

        for _ in 0..max_iter {
            // Trial proximal step at the current step size.
            let x_new = prox_g(&(x - step_size * grad), step_size);
            let f_new = f(&x_new);

            // Quadratic approximation condition
            let diff = &x_new - x;
            let quad_approx = f_x + grad.dot(&diff) + diff.dot(&diff) / (2.0 * step_size);

            if f_new <= quad_approx + self.sigma * step_size * grad.dot(grad) {
                return Ok(step_size);
            }

            step_size *= self.beta;

            if step_size < 1e-16 {
                return Err("Step size became too small in line search".to_string());
            }
        }

        Ok(step_size) // Return current step size if line search doesn't converge
    }
}
509
/// Convenience alias for the proximal gradient optimizer.
///
/// NOTE(review): `FistaOptimizer::default()` has `accelerated: false`, so it
/// behaves as plain ISTA; call `ProximalGradientOptimizer::accelerated()` to
/// obtain the FISTA behavior the name suggests.
pub type FistaOptimizer = ProximalGradientOptimizer;
512
/// Accelerated Gradient Descent with Nesterov momentum.
///
/// Evaluates the gradient at a momentum look-ahead point; optionally adapts
/// the learning rate from observed objective progress.
pub struct NesterovAcceleratedGradient {
    /// Maximum number of iterations
    pub max_iter: usize,
    /// Convergence tolerance (gradient L1 norm and relative objective change)
    pub tol: Float,
    /// Initial learning rate
    pub learning_rate: Float,
    /// Whether to adjust the learning rate based on objective progress
    pub adaptive_lr: bool,
    /// Momentum parameter (typically 0.9)
    pub momentum: Float,
    /// Per-iteration learning-rate decay factor (used only when `adaptive_lr` is off)
    pub lr_decay: Float,
}
528
529impl Default for NesterovAcceleratedGradient {
530    fn default() -> Self {
531        Self {
532            max_iter: 1000,
533            tol: 1e-6,
534            learning_rate: 0.01,
535            adaptive_lr: false,
536            momentum: 0.9,
537            lr_decay: 0.999,
538        }
539    }
540}
541
542impl NesterovAcceleratedGradient {
543    /// Create optimizer with adaptive learning rate
544    pub fn adaptive() -> Self {
545        Self {
546            adaptive_lr: true,
547            ..Default::default()
548        }
549    }
550
551    /// Minimize smooth objective using Nesterov accelerated gradient descent
552    pub fn minimize<F, GradF>(
553        &self,
554        f: F,
555        grad_f: GradF,
556        x0: Array1<Float>,
557    ) -> Result<Array1<Float>, String>
558    where
559        F: Fn(&Array1<Float>) -> Float,
560        GradF: Fn(&Array1<Float>) -> Array1<Float>,
561    {
562        let mut x = x0.clone();
563        let mut v = Array1::zeros(x0.len()); // velocity
564        let mut lr = self.learning_rate;
565
566        let mut f_prev = f(&x);
567
568        for iter in 0..self.max_iter {
569            // Nesterov look-ahead point
570            let y = &x + self.momentum * &v;
571
572            // Compute gradient at look-ahead point
573            let grad = grad_f(&y);
574
575            // Check convergence
576            let grad_norm = grad.mapv(Float::abs).sum();
577            if grad_norm < self.tol {
578                return Ok(x);
579            }
580
581            // Update velocity and position
582            v = self.momentum * &v - lr * &grad;
583            x = &x + &v;
584
585            // Adaptive learning rate
586            if self.adaptive_lr {
587                let f_curr = f(&x);
588                if f_curr > f_prev + 1e-10 {
589                    // If objective increased, reduce learning rate and reset
590                    lr *= 0.9;
591                    x = &x - &v; // undo the step
592                    v = Array1::zeros(x0.len()); // reset velocity for stability
593                    continue;
594                } else if f_curr < f_prev - 0.01 * lr * grad_norm {
595                    // If we're making good progress, increase learning rate slightly
596                    lr = (lr * 1.005).min(self.learning_rate * 1.5); // Very conservative increase
597                }
598
599                // Check objective-based convergence every iteration for adaptive mode
600                if (f_curr - f_prev).abs() < self.tol * f_prev.abs().max(1.0) {
601                    return Ok(x);
602                }
603                f_prev = f_curr;
604            } else {
605                // Fixed decay schedule
606                lr *= self.lr_decay;
607
608                // Check objective-based convergence every 10 iterations for non-adaptive mode
609                if iter % 10 == 0 {
610                    let f_curr = f(&x);
611                    if (f_curr - f_prev).abs() < self.tol * f_prev.abs().max(1.0) {
612                        return Ok(x);
613                    }
614                    f_prev = f_curr;
615                }
616            }
617        }
618
619        Ok(x)
620    }
621}
622
623/// Proximal operators for common penalties
624pub mod proximal {
625    use super::*;
626
627    /// Proximal operator for L1 penalty: prox_{α*λ*||.||_1}
628    pub fn prox_l1(x: &Array1<Float>, alpha_lambda: Float) -> Array1<Float> {
629        x.mapv(|xi| soft_threshold(xi, alpha_lambda))
630    }
631
632    /// Proximal operator for L2 penalty: prox_{α*λ/2*||.||²}
633    pub fn prox_l2(x: &Array1<Float>, alpha_lambda: Float) -> Array1<Float> {
634        x / (1.0 + alpha_lambda)
635    }
636
637    /// Proximal operator for Elastic Net
638    pub fn prox_elastic_net(x: &Array1<Float>, alpha: Float, l1_ratio: Float) -> Array1<Float> {
639        let l1_prox = prox_l1(x, alpha * l1_ratio);
640        prox_l2(&l1_prox, alpha * (1.0 - l1_ratio))
641    }
642
643    /// Soft thresholding function
644    #[inline]
645    fn soft_threshold(x: Float, lambda: Float) -> Float {
646        if x > lambda {
647            x - lambda
648        } else if x < -lambda {
649            x + lambda
650        } else {
651            0.0
652        }
653    }
654}
655
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// L-BFGS on a separable convex quadratic with a known minimizer.
    #[test]
    fn test_lbfgs_quadratic() {
        // Minimize f(x) = 0.5 * x^T * A * x - b^T * x
        // where A = [[2, 0], [0, 4]], b = [1, 2]
        // Solution: x* = [0.5, 0.5]

        let f = |x: &Array1<Float>| -> Float {
            0.5 * (2.0 * x[0] * x[0] + 4.0 * x[1] * x[1]) - x[0] - 2.0 * x[1]
        };

        let grad_f =
            |x: &Array1<Float>| -> Array1<Float> { array![2.0 * x[0] - 1.0, 4.0 * x[1] - 2.0] };

        let optimizer = LbfgsOptimizer::default();
        let x0 = array![0.0, 0.0];

        let result = optimizer.minimize(f, grad_f, x0).unwrap();

        // NOTE(review): the optimizer's tol (1e-4 on the gradient L1 norm)
        // only guarantees |x - x*| ≲ 5e-5; the 1e-6 epsilon here relies on
        // L-BFGS converging well past the tolerance on quadratics — confirm.
        assert_abs_diff_eq!(result[0], 0.5, epsilon = 1e-6);
        assert_abs_diff_eq!(result[1], 0.5, epsilon = 1e-6);
    }

    /// Element-wise checks for the L1 and L2 proximal operators.
    #[test]
    fn test_proximal_operators() {
        use proximal::*;

        let x = array![2.0, -1.5, 0.5, -0.3];

        // Test L1 proximal: soft threshold at 0.5 shrinks toward zero and
        // zeroes components with |x_i| <= 0.5.
        let prox_x = prox_l1(&x, 0.5);
        assert_abs_diff_eq!(prox_x[0], 1.5);
        assert_abs_diff_eq!(prox_x[1], -1.0);
        assert_abs_diff_eq!(prox_x[2], 0.0);
        assert_abs_diff_eq!(prox_x[3], 0.0);

        // Test L2 proximal: uniform scaling by 1 / (1 + 0.5).
        let prox_x = prox_l2(&x, 0.5);
        assert_abs_diff_eq!(prox_x[0], 2.0 / 1.5);
        assert_abs_diff_eq!(prox_x[1], -1.5 / 1.5);
    }

    /// Proximal gradient (ISTA) on a small Lasso problem.
    #[test]
    fn test_proximal_gradient_lasso() {
        use proximal::*;

        // Test Lasso problem: min 0.5 * ||Ax - b||^2 + lambda * ||x||_1
        // A = [[1, 0], [0, 1]], b = [2, 1], lambda = 0.5
        // Solution should be approximately [1.5, 0.5]

        let a = array![[1.0, 0.0], [0.0, 1.0]];
        let b = array![2.0, 1.0];
        let lambda = 0.5;

        let f = |x: &Array1<Float>| -> Float {
            let residual = a.dot(x) - &b;
            0.5 * residual.dot(&residual)
        };

        let grad_f = |x: &Array1<Float>| -> Array1<Float> {
            let residual = a.dot(x) - &b;
            a.t().dot(&residual)
        };

        // prox of t * lambda * ||.||_1 is soft thresholding at lambda * t.
        let prox_g = |x: &Array1<Float>, t: Float| -> Array1<Float> { prox_l1(x, lambda * t) };

        let optimizer = ProximalGradientOptimizer::default();
        let x0 = array![0.0, 0.0];

        let result = optimizer.minimize(f, grad_f, prox_g, x0).unwrap();

        // Check that solution is reasonable (loose bounds around [1.5, 0.5]).
        assert!(result[0] > 1.0 && result[0] < 2.0);
        assert!(result[1] > 0.0 && result[1] < 1.0);
    }

    /// Same Lasso problem, but with the FISTA-accelerated optimizer.
    #[test]
    fn test_fista_accelerated() {
        use proximal::*;

        // Same problem as above but with FISTA acceleration
        let a = array![[1.0, 0.0], [0.0, 1.0]];
        let b = array![2.0, 1.0];
        let lambda = 0.5;

        let f = |x: &Array1<Float>| -> Float {
            let residual = a.dot(x) - &b;
            0.5 * residual.dot(&residual)
        };

        let grad_f = |x: &Array1<Float>| -> Array1<Float> {
            let residual = a.dot(x) - &b;
            a.t().dot(&residual)
        };

        let prox_g = |x: &Array1<Float>, t: Float| -> Array1<Float> { prox_l1(x, lambda * t) };

        let optimizer = ProximalGradientOptimizer::accelerated();
        let x0 = array![0.0, 0.0];

        let result = optimizer.minimize(f, grad_f, prox_g, x0).unwrap();

        // Check that solution is reasonable (same loose bounds as ISTA test).
        assert!(result[0] > 1.0 && result[0] < 2.0);
        assert!(result[1] > 0.0 && result[1] < 1.0);
    }

    /// Nesterov AGD on a convex quadratic with a known minimizer.
    #[test]
    fn test_nesterov_accelerated_gradient() {
        // Test Nesterov AGD on quadratic function: f(x) = 0.5 * x^T * A * x - b^T * x
        // where A = [[4, 0], [0, 1]], b = [2, 1]
        // Solution: x* = A^{-1} * b = [0.5, 1.0]

        let f = |x: &Array1<Float>| -> Float {
            0.5 * (4.0 * x[0] * x[0] + x[1] * x[1]) - 2.0 * x[0] - x[1]
        };

        let grad_f = |x: &Array1<Float>| -> Array1<Float> { array![4.0 * x[0] - 2.0, x[1] - 1.0] };

        let optimizer = NesterovAcceleratedGradient {
            learning_rate: 0.1,
            max_iter: 100,
            tol: 1e-8,
            ..Default::default()
        };
        let x0 = array![0.0, 0.0];

        let result = optimizer.minimize(f, grad_f, x0).unwrap();

        assert_abs_diff_eq!(result[0], 0.5, epsilon = 1e-4);
        assert_abs_diff_eq!(result[1], 1.0, epsilon = 1e-4);
    }

    /// Adaptive-learning-rate Nesterov on an easy symmetric quadratic.
    #[test]
    fn test_nesterov_adaptive() {
        // Use a simpler quadratic function: f(x) = x₀² + x₁²
        // Minimum at (0, 0), which is easier to optimize
        let f = |x: &Array1<Float>| -> Float { x[0] * x[0] + x[1] * x[1] };
        let grad_f = |x: &Array1<Float>| -> Array1<Float> { array![2.0 * x[0], 2.0 * x[1]] };

        let mut optimizer = NesterovAcceleratedGradient::adaptive();
        optimizer.max_iter = 500; // Reasonable number of iterations
        optimizer.tol = 1e-3; // Practical convergence tolerance
        optimizer.learning_rate = 0.01; // Balanced learning rate
        optimizer.momentum = 0.9; // Standard momentum
        let x0 = array![1.0, 1.0]; // Start point

        let result = optimizer.minimize(f, grad_f, x0).unwrap();

        // Should converge to (0, 0) with reasonable tolerance
        assert_abs_diff_eq!(result[0], 0.0, epsilon = 0.1);
        assert_abs_diff_eq!(result[1], 0.0, epsilon = 0.1);
    }
}