
optirs_core/second_order/mod.rs

//! Second-order optimization methods
//!
//! This module provides implementations of second-order optimization methods
//! that use curvature information (the Hessian matrix) to improve convergence.

pub mod kfac;
pub mod newton_cg;

use crate::error::{OptimError, Result};
use scirs2_core::ndarray::{Array, Array1, Array2, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::collections::VecDeque;
use std::fmt::Debug;

pub use self::kfac::{KFACConfig, KFACLayerState, KFACStats, LayerInfo, LayerType, KFAC};
pub use self::newton_cg::NewtonCG;

/// Trait for second-order optimization methods
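///
/// A minimal usage sketch (illustrative values), using the `Newton`
/// optimizer defined below with a diagonal Hessian:
///
/// ```ignore
/// use scirs2_core::ndarray::Array1;
///
/// let mut opt = Newton::new(0.1_f64);
/// let params = Array1::from_vec(vec![1.0, 2.0]);
/// let grads = Array1::from_vec(vec![0.1, 0.2]);
/// let hessian = HessianInfo::Diagonal(Array1::from_vec(vec![2.0, 4.0]));
/// let updated = opt.step_second_order(&params, &grads, &hessian)?;
/// ```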
pub trait SecondOrderOptimizer<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension> {
    /// Update parameters using second-order information
    fn step_second_order(
        &mut self,
        params: &Array<A, D>,
        gradients: &Array<A, D>,
        hessian_info: &HessianInfo<A, D>,
    ) -> Result<Array<A, D>>;

    /// Reset optimizer state
    fn reset(&mut self);
}

/// Hessian information for second-order methods
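///
/// A minimal construction sketch (illustrative values):
///
/// ```ignore
/// use std::collections::VecDeque;
/// use scirs2_core::ndarray::Array1;
///
/// // Diagonal curvature for a two-parameter problem
/// let diag = HessianInfo::Diagonal(Array1::from_vec(vec![2.0_f64, 4.0]));
///
/// // Quasi-Newton variant carrying L-BFGS (s, y) history pairs
/// let quasi = HessianInfo::QuasiNewton {
///     s_history: VecDeque::from(vec![Array1::from_vec(vec![0.1_f64, 0.1])]),
///     y_history: VecDeque::from(vec![Array1::from_vec(vec![0.2_f64, 0.3])]),
/// };
/// ```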
#[derive(Debug, Clone)]
pub enum HessianInfo<A: Float, D: Dimension> {
    /// Full Hessian matrix (expensive, rarely used in practice)
    Full(Array2<A>),
    /// Diagonal approximation of the Hessian
    Diagonal(Array<A, D>),
    /// L-BFGS-style quasi-Newton approximation
    QuasiNewton {
        /// Parameter differences history
        s_history: VecDeque<Array<A, D>>,
        /// Gradient differences history
        y_history: VecDeque<Array<A, D>>,
    },
    /// Gauss-Newton approximation for least-squares problems
    GaussNewton(Array2<A>),
}

/// Methods for computing approximate Hessian information
pub mod hessian_approximation {
    use super::*;

    /// Compute a diagonal Hessian approximation using central finite differences (1D only)
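    ///
    /// A minimal sketch on f(x) = x^2, whose gradient is 2x (so the exact
    /// diagonal Hessian entry is 2):
    ///
    /// ```ignore
    /// let params = Array1::from_vec(vec![1.0_f64]);
    /// let grad_fn = |x: &Array1<f64>| Ok(Array1::from_vec(vec![2.0 * x[0]]));
    /// let diag = diagonal_finite_difference(&params, grad_fn, 1e-5)?;
    /// // diag[0] ≈ 2.0
    /// ```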
    pub fn diagonal_finite_difference<A, F>(
        params: &Array1<A>,
        gradient_fn: F,
        epsilon: A,
    ) -> Result<Array1<A>>
    where
        A: Float + ScalarOperand + Debug + Copy,
        F: Fn(&Array1<A>) -> Result<Array1<A>>,
    {
        let mut hessian_diag = Array1::zeros(params.len());

        for i in 0..params.len() {
            let mut param_plus = params.clone();
            let mut param_minus = params.clone();

            // Evaluate the gradient at x + eps * e_i
            param_plus[i] = params[i] + epsilon;
            let grad_plus = gradient_fn(&param_plus)?;

            // Evaluate the gradient at x - eps * e_i
            param_minus[i] = params[i] - epsilon;
            let grad_minus = gradient_fn(&param_minus)?;

            // Hessian diagonal: central difference of the i-th gradient component
            let second_deriv = (grad_plus[i] - grad_minus[i])
                / (A::from(2.0).expect("failed to convert 2.0") * epsilon);
            hessian_diag[i] = second_deriv;
        }

        Ok(hessian_diag)
    }

    /// Update the L-BFGS history used to approximate the Hessian
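    ///
    /// Note: pairs are stored unconditionally here. Classical L-BFGS accepts a
    /// pair only when the curvature condition s^T y > 0 holds; the two-loop
    /// recursion below compensates by skipping pairs whose y^T s is near zero.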
    pub fn update_lbfgs_approximation<A, D>(
        s_history: &mut VecDeque<Array<A, D>>,
        y_history: &mut VecDeque<Array<A, D>>,
        param_diff: Array<A, D>,
        grad_diff: Array<A, D>,
        max_history: usize,
    ) where
        A: Float + ScalarOperand + Debug,
        D: Dimension,
    {
        // Add the new difference pair to the history
        s_history.push_back(param_diff);
        y_history.push_back(grad_diff);

        // Enforce the maximum history size
        if s_history.len() > max_history {
            s_history.pop_front();
            y_history.pop_front();
        }
    }

    /// Apply the L-BFGS two-loop recursion to approximate H^(-1) * grad
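    ///
    /// With rho_i = 1 / (y_i^T s_i), the first (backward) loop computes
    /// alpha_i = rho_i * s_i^T q and updates q <- q - alpha_i * y_i; q is then
    /// scaled by the initial approximation H_0 = gamma * I; the second
    /// (forward) loop computes beta_i = rho_i * y_i^T q and updates
    /// q <- q + (alpha_i - beta_i) * s_i, so the final q ≈ H^(-1) * grad.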
    pub fn lbfgs_two_loop_recursion<A, D>(
        gradient: &Array<A, D>,
        s_history: &VecDeque<Array<A, D>>,
        y_history: &VecDeque<Array<A, D>>,
        initial_hessian_scale: A,
    ) -> Result<Array<A, D>>
    where
        A: Float + ScalarOperand + Debug,
        D: Dimension,
    {
        if s_history.len() != y_history.len() {
            return Err(OptimError::InvalidConfig(
                "History sizes don't match in L-BFGS".to_string(),
            ));
        }

        let m = s_history.len();
        if m == 0 {
            // No history yet; return the scaled gradient
            return Ok(gradient * initial_hessian_scale);
        }

        let mut q = gradient.clone();
        let mut alphas = Vec::with_capacity(m);

        // First (backward) loop: compute the alphas and update q, newest pair first
        for i in (0..m).rev() {
            let s_i = &s_history[i];
            let y_i = &y_history[i];

            // Compute rho_i = 1 / (y_i^T * s_i)
            let y_dot_s = y_i
                .iter()
                .zip(s_i.iter())
                .map(|(&y, &s)| y * s)
                .fold(A::zero(), |acc, x| acc + x);

            if y_dot_s.abs() < A::from(1e-12).expect("failed to convert curvature threshold") {
                alphas.push(A::zero());
                continue;
            }

            let rho_i = A::one() / y_dot_s;

            // Compute alpha_i = rho_i * s_i^T * q
            let s_dot_q = s_i
                .iter()
                .zip(q.iter())
                .map(|(&s, &q_val)| s * q_val)
                .fold(A::zero(), |acc, x| acc + x);
            let alpha_i = rho_i * s_dot_q;
            alphas.push(alpha_i);

            // Update q = q - alpha_i * y_i
            for (q_val, &y_val) in q.iter_mut().zip(y_i.iter()) {
                *q_val = *q_val - alpha_i * y_val;
            }
        }

        // Scale by the initial Hessian approximation
        q.mapv_inplace(|x| x * initial_hessian_scale);

        // Second (forward) loop: compute the final result
        alphas.reverse(); // Reverse to match forward iteration
        for i in 0..m {
            let s_i = &s_history[i];
            let y_i = &y_history[i];

            // Compute rho_i = 1 / (y_i^T * s_i)
            let y_dot_s = y_i
                .iter()
                .zip(s_i.iter())
                .map(|(&y, &s)| y * s)
                .fold(A::zero(), |acc, x| acc + x);

            if y_dot_s.abs() < A::from(1e-12).expect("failed to convert curvature threshold") {
                continue;
            }

            let rho_i = A::one() / y_dot_s;

            // Compute beta = rho_i * y_i^T * q
            let y_dot_q = y_i
                .iter()
                .zip(q.iter())
                .map(|(&y, &q_val)| y * q_val)
                .fold(A::zero(), |acc, x| acc + x);
            let beta = rho_i * y_dot_q;

            // Update q = q + (alpha_i - beta) * s_i
            let alpha_i = alphas[i];
            let coeff = alpha_i - beta;
            for (q_val, &s_val) in q.iter_mut().zip(s_i.iter()) {
                *q_val = *q_val + coeff * s_val;
            }
        }

        Ok(q)
    }

    /// Gauss-Newton Hessian approximation for least-squares problems
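    ///
    /// For a least-squares objective f(θ) = ½‖r(θ)‖² with Jacobian J = ∂r/∂θ,
    /// the exact Hessian is J^T J + Σ_i r_i ∇²r_i; near a solution the residual
    /// term is small, so H ≈ J^T J. The result is always positive semidefinite.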
    pub fn gauss_newton_approximation<A>(jacobian: &Array2<A>) -> Result<Array2<A>>
    where
        A: Float + ScalarOperand + Debug,
    {
        // Gauss-Newton approximation: H ≈ J^T * J
        let j_transpose = jacobian.t();
        let hessian_approx = j_transpose.dot(jacobian);
        Ok(hessian_approx)
    }
}

/// Newton's method optimizer
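///
/// For a diagonal Hessian this applies the damped update
/// x <- x - η * (H + λI)^(-1) * ∇f, where η is the learning rate and λ is a
/// small regularizer added for numerical stability; with quasi-Newton history
/// it falls back to the L-BFGS direction.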
#[derive(Debug, Clone)]
pub struct Newton<A: Float> {
    learning_rate: A,
    regularization: A, // For numerical stability
}

impl<A: Float + ScalarOperand + Debug + Send + Sync> Newton<A> {
    /// Create a new Newton optimizer
    pub fn new(learning_rate: A) -> Self {
        Self {
            learning_rate,
            regularization: A::from(1e-6).expect("failed to convert default regularization"),
        }
    }

    /// Set the regularization parameter for numerical stability
    pub fn with_regularization(mut self, regularization: A) -> Self {
        self.regularization = regularization;
        self
    }
}

impl<A: Float + ScalarOperand + Debug + Send + Sync>
    SecondOrderOptimizer<A, scirs2_core::ndarray::Ix1> for Newton<A>
{
    fn step_second_order(
        &mut self,
        params: &Array1<A>,
        gradients: &Array1<A>,
        hessian_info: &HessianInfo<A, scirs2_core::ndarray::Ix1>,
    ) -> Result<Array1<A>> {
        match hessian_info {
            HessianInfo::Diagonal(hessian_diag) => {
                if params.len() != hessian_diag.len() || params.len() != gradients.len() {
                    return Err(OptimError::DimensionMismatch(
                        "Parameter, gradient, and Hessian dimensions must match".to_string(),
                    ));
                }

                let mut update = Array1::zeros(params.len());
                for i in 0..params.len() {
                    let h_ii = hessian_diag[i] + self.regularization;
                    if h_ii.abs() > A::from(1e-12).expect("failed to convert threshold") {
                        update[i] = gradients[i] / h_ii;
                    } else {
                        // Fall back to gradient descent if the Hessian entry is near-singular
                        update[i] = gradients[i];
                    }
                }

                Ok(params - &(update * self.learning_rate))
            }
            HessianInfo::QuasiNewton {
                s_history,
                y_history,
            } => {
                // Use the L-BFGS approximation
                let search_direction = hessian_approximation::lbfgs_two_loop_recursion(
                    gradients,
                    s_history,
                    y_history,
                    A::one(), // Initial Hessian scale
                )?;

                Ok(params - &(search_direction * self.learning_rate))
            }
            _ => Err(OptimError::InvalidConfig(
                "Unsupported Hessian information type for Newton's method".to_string(),
            )),
        }
    }

    fn reset(&mut self) {
        // Newton's method is stateless; nothing to reset
    }
}

/// Quasi-Newton L-BFGS optimizer
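///
/// A minimal usage sketch (illustrative values); the optimizer accumulates
/// (s, y) history pairs across successive calls to `step`:
///
/// ```ignore
/// use scirs2_core::ndarray::Array1;
///
/// let mut opt = LBFGS::new(0.01_f64).with_max_history(5);
/// let mut params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
/// params = opt.step(&params, &Array1::from_vec(vec![0.1, 0.2, 0.3]))?;
/// params = opt.step(&params, &Array1::from_vec(vec![0.05, 0.15, 0.25]))?; // uses history
/// ```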
#[derive(Debug)]
pub struct LBFGS<A: Float, D: Dimension> {
    learning_rate: A,
    max_history: usize,
    s_history: VecDeque<Array<A, D>>,
    y_history: VecDeque<Array<A, D>>,
    previous_params: Option<Array<A, D>>,
    previous_grad: Option<Array<A, D>>,
}

impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync> LBFGS<A, D> {
    /// Create a new L-BFGS optimizer
    pub fn new(learning_rate: A) -> Self {
        Self {
            learning_rate,
            max_history: 10,
            s_history: VecDeque::new(),
            y_history: VecDeque::new(),
            previous_params: None,
            previous_grad: None,
        }
    }

    /// Set the maximum history size
    pub fn with_max_history(mut self, max_history: usize) -> Self {
        self.max_history = max_history;
        self
    }

    /// Perform an L-BFGS step
    pub fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        // Update history if we have previous step information
        if let (Some(prev_params), Some(prev_grad)) = (&self.previous_params, &self.previous_grad) {
            let s = params - prev_params; // Parameter difference
            let y = gradients - prev_grad; // Gradient difference

            hessian_approximation::update_lbfgs_approximation(
                &mut self.s_history,
                &mut self.y_history,
                s,
                y,
                self.max_history,
            );
        }

        // Compute the search direction using the two-loop recursion
        let search_direction = if self.s_history.is_empty() {
            // No history yet; use the gradient-descent direction
            gradients.clone()
        } else {
            hessian_approximation::lbfgs_two_loop_recursion(
                gradients,
                &self.s_history,
                &self.y_history,
                A::one(),
            )?
        };

        // Update parameters
        let new_params = params - &(search_direction * self.learning_rate);

        // Store current information for the next iteration
        self.previous_params = Some(params.clone());
        self.previous_grad = Some(gradients.clone());

        Ok(new_params)
    }
}

impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync>
    SecondOrderOptimizer<A, D> for LBFGS<A, D>
{
    fn step_second_order(
        &mut self,
        params: &Array<A, D>,
        gradients: &Array<A, D>,
        _hessian_info: &HessianInfo<A, D>, // L-BFGS maintains its own history
    ) -> Result<Array<A, D>> {
        self.step(params, gradients)
    }

    fn reset(&mut self) {
        self.s_history.clear();
        self.y_history.clear();
        self.previous_params = None;
        self.previous_grad = None;
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_diagonal_hessian_approximation() {
        // Test on a simple quadratic function: f(x) = x^2
        let params = Array1::from_vec(vec![1.0]);

        // Gradient of the quadratic: grad = 2*x
        let gradient_fn =
            |x: &Array1<f64>| -> Result<Array1<f64>> { Ok(Array1::from_vec(vec![2.0 * x[0]])) };

        let hessian_diag =
            hessian_approximation::diagonal_finite_difference(&params, gradient_fn, 1e-5)
                .expect("finite-difference approximation failed");

        // For f(x) = x^2 the second derivative is exactly 2.0
        assert_relative_eq!(hessian_diag[0], 2.0, epsilon = 1e-1);
    }

    #[test]
    fn test_lbfgs_two_loop_recursion() {
        let gradient = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let mut s_history = VecDeque::new();
        let mut y_history = VecDeque::new();

        // Add some history
        s_history.push_back(Array1::from_vec(vec![0.1, 0.1, 0.1]));
        y_history.push_back(Array1::from_vec(vec![0.2, 0.3, 0.4]));

        let result =
            hessian_approximation::lbfgs_two_loop_recursion(&gradient, &s_history, &y_history, 1.0)
                .expect("two-loop recursion failed");

        // The result should differ from the raw gradient once curvature information is applied
        assert_ne!(result, gradient);
        assert_eq!(result.len(), gradient.len());
    }

    #[test]
    fn test_newton_method() {
        let mut optimizer = Newton::new(0.1);
        let params = Array1::from_vec(vec![1.0, 2.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2]);
        let hessian_diag = Array1::from_vec(vec![2.0, 4.0]);

        let hessian_info = HessianInfo::Diagonal(hessian_diag);
        let new_params = optimizer
            .step_second_order(&params, &gradients, &hessian_info)
            .expect("Newton step failed");

        // Verify parameters were updated
        assert!(new_params[0] < params[0]);
        assert!(new_params[1] < params[1]);
    }

    #[test]
    fn test_lbfgs_optimizer() {
        let mut optimizer = LBFGS::new(0.01).with_max_history(5);
        let mut params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients1 = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let gradients2 = Array1::from_vec(vec![0.05, 0.15, 0.25]);

        // First step
        params = optimizer
            .step(&params, &gradients1)
            .expect("first L-BFGS step failed");

        // Second step (should use history)
        let new_params = optimizer
            .step(&params, &gradients2)
            .expect("second L-BFGS step failed");

        // Verify parameters were updated and one (s, y) pair was recorded
        assert_ne!(new_params, params);
        assert_eq!(optimizer.s_history.len(), 1);
        assert_eq!(optimizer.y_history.len(), 1);
    }

    #[test]
    fn test_gauss_newton_approximation() {
        let jacobian = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("failed to build Jacobian");
        let hessian_approx = hessian_approximation::gauss_newton_approximation(&jacobian)
            .expect("Gauss-Newton approximation failed");

        // Should be a 2x2 matrix (J^T * J)
        assert_eq!(hessian_approx.dim(), (2, 2));

        // Verify it's positive semidefinite by checking diagonal elements are non-negative
        assert!(hessian_approx[(0, 0)] >= 0.0);
        assert!(hessian_approx[(1, 1)] >= 0.0);
    }
}