optirs_core/second_order/
mod.rs

// Second-order optimization methods
//
// This module provides implementations of second-order optimization methods
// that use curvature information (Hessian matrix) to improve convergence.
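//
// The canonical second-order update is x_{k+1} = x_k - eta * H^{-1} * g, where g is the
// gradient and H is the Hessian or an approximation of it; the types below differ mainly
// in how H is represented (full, diagonal, L-BFGS quasi-Newton, or Gauss-Newton J^T * J).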

pub mod kfac;
pub mod newton_cg;

use crate::error::{OptimError, Result};
use scirs2_core::ndarray::{Array, Array1, Array2, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::collections::VecDeque;
use std::fmt::Debug;

pub use self::kfac::{KFACConfig, KFACLayerState, KFACStats, LayerInfo, LayerType, KFAC};
pub use self::newton_cg::NewtonCG;

/// Trait for second-order optimization methods
pub trait SecondOrderOptimizer<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension> {
    /// Update parameters using second-order information
    fn step_second_order(
        &mut self,
        params: &Array<A, D>,
        gradients: &Array<A, D>,
        hessian_info: &HessianInfo<A, D>,
    ) -> Result<Array<A, D>>;

    /// Reset optimizer state
    fn reset(&mut self);
}

/// Hessian information for second-order methods
#[derive(Debug, Clone)]
pub enum HessianInfo<A: Float, D: Dimension> {
    /// Full Hessian matrix (expensive, rarely used in practice)
    Full(Array2<A>),
    /// Diagonal approximation of Hessian
    Diagonal(Array<A, D>),
    /// L-BFGS style quasi-Newton approximation
    QuasiNewton {
        /// Parameter differences history
        s_history: VecDeque<Array<A, D>>,
        /// Gradient differences history
        y_history: VecDeque<Array<A, D>>,
    },
    /// Gauss-Newton approximation for least squares problems
    GaussNewton(Array2<A>),
}

/// Approximated Hessian computation methods
pub mod hessian_approximation {
    use super::*;

    /// Compute a diagonal Hessian approximation using central finite differences (1D only)
    pub fn diagonal_finite_difference<A, F>(
        params: &Array1<A>,
        gradient_fn: F,
        epsilon: A,
    ) -> Result<Array1<A>>
    where
        A: Float + ScalarOperand + Debug + Copy,
        F: Fn(&Array1<A>) -> Result<Array1<A>>,
    {
        let mut hessian_diag = Array1::zeros(params.len());

        for i in 0..params.len() {
            let mut param_plus = params.clone();
            let mut param_minus = params.clone();

            // Evaluate the gradient at x + epsilon * e_i
            param_plus[i] = params[i] + epsilon;
            let grad_plus = gradient_fn(&param_plus)?;

            // Evaluate the gradient at x - epsilon * e_i
            param_minus[i] = params[i] - epsilon;
            let grad_minus = gradient_fn(&param_minus)?;

            // Central difference of the i-th gradient component approximates H[i][i]:
            // (g_i(x + h) - g_i(x - h)) / (2h)
            let second_deriv = (grad_plus[i] - grad_minus[i]) / (A::from(2.0).unwrap() * epsilon);
            hessian_diag[i] = second_deriv;
        }

        Ok(hessian_diag)
    }

    /// Update the L-BFGS history used to approximate the Hessian
    pub fn update_lbfgs_approximation<A, D>(
        s_history: &mut VecDeque<Array<A, D>>,
        y_history: &mut VecDeque<Array<A, D>>,
        param_diff: Array<A, D>,
        grad_diff: Array<A, D>,
        max_history: usize,
    ) where
        A: Float + ScalarOperand + Debug,
        D: Dimension,
    {
        // Add the new difference pair to the history
        s_history.push_back(param_diff);
        y_history.push_back(grad_diff);

        // Enforce the maximum history size
        if s_history.len() > max_history {
            s_history.pop_front();
            y_history.pop_front();
        }
    }

    /// Apply L-BFGS two-loop recursion to approximate H^(-1) * grad
    pub fn lbfgs_two_loop_recursion<A, D>(
        gradient: &Array<A, D>,
        s_history: &VecDeque<Array<A, D>>,
        y_history: &VecDeque<Array<A, D>>,
        initial_hessian_scale: A,
    ) -> Result<Array<A, D>>
    where
        A: Float + ScalarOperand + Debug,
        D: Dimension,
    {
        if s_history.len() != y_history.len() {
            return Err(OptimError::InvalidConfig(
                "History sizes don't match in L-BFGS".to_string(),
            ));
        }

        let m = s_history.len();
        if m == 0 {
            // No history, return scaled gradient
            return Ok(gradient * initial_hessian_scale);
        }

        let mut q = gradient.clone();
        let mut alphas = Vec::with_capacity(m);

        // First loop: compute alphas and update q
        for i in (0..m).rev() {
            let s_i = &s_history[i];
            let y_i = &y_history[i];

            // Compute rho_i = 1 / (y_i^T * s_i)
            let y_dot_s = y_i
                .iter()
                .zip(s_i.iter())
                .map(|(&y, &s)| y * s)
                .fold(A::zero(), |acc, x| acc + x);

            if y_dot_s.abs() < A::from(1e-12).unwrap() {
                alphas.push(A::zero());
                continue;
            }

            let rho_i = A::one() / y_dot_s;

            // Compute alpha_i = rho_i * s_i^T * q
            let s_dot_q = s_i
                .iter()
                .zip(q.iter())
                .map(|(&s, &q_val)| s * q_val)
                .fold(A::zero(), |acc, x| acc + x);
            let alpha_i = rho_i * s_dot_q;
            alphas.push(alpha_i);

            // Update q = q - alpha_i * y_i
            for (q_val, &y_val) in q.iter_mut().zip(y_i.iter()) {
                *q_val = *q_val - alpha_i * y_val;
            }
        }

        // Scale by initial Hessian approximation
        q.mapv_inplace(|x| x * initial_hessian_scale);

        // Second loop: compute final result
        alphas.reverse(); // Reverse to match forward iteration
        for i in 0..m {
            let s_i = &s_history[i];
            let y_i = &y_history[i];

            // Compute rho_i = 1 / (y_i^T * s_i)
            let y_dot_s = y_i
                .iter()
                .zip(s_i.iter())
                .map(|(&y, &s)| y * s)
                .fold(A::zero(), |acc, x| acc + x);

            if y_dot_s.abs() < A::from(1e-12).unwrap() {
                continue;
            }

            let rho_i = A::one() / y_dot_s;

            // Compute beta = rho_i * y_i^T * q
            let y_dot_q = y_i
                .iter()
                .zip(q.iter())
                .map(|(&y, &q_val)| y * q_val)
                .fold(A::zero(), |acc, x| acc + x);
            let beta = rho_i * y_dot_q;

            // Update q = q + (alpha_i - beta) * s_i
            let alpha_i = alphas[i];
            let coeff = alpha_i - beta;
            for (q_val, &s_val) in q.iter_mut().zip(s_i.iter()) {
                *q_val = *q_val + coeff * s_val;
            }
        }

        Ok(q)
    }

    /// Gauss-Newton Hessian approximation for least squares problems
    pub fn gauss_newton_approximation<A>(jacobian: &Array2<A>) -> Result<Array2<A>>
    where
        A: Float + ScalarOperand + Debug,
    {
        // Gauss-Newton approximation: H ≈ J^T * J
        let j_transpose = jacobian.t();
        let hessian_approx = j_transpose.dot(jacobian);
        Ok(hessian_approx)
    }
}

/// Newton's method optimizer
#[derive(Debug, Clone)]
pub struct Newton<A: Float> {
    learning_rate: A,
    regularization: A, // For numerical stability
}

impl<A: Float + ScalarOperand + Debug + Send + Sync> Newton<A> {
    /// Create a new Newton optimizer
    pub fn new(learning_rate: A) -> Self {
        Self {
            learning_rate,
            regularization: A::from(1e-6).unwrap(),
        }
    }

    /// Set regularization parameter for numerical stability
    pub fn with_regularization(mut self, regularization: A) -> Self {
        self.regularization = regularization;
        self
    }
}

impl<A: Float + ScalarOperand + Debug + Send + Sync>
    SecondOrderOptimizer<A, scirs2_core::ndarray::Ix1> for Newton<A>
{
    fn step_second_order(
        &mut self,
        params: &Array1<A>,
        gradients: &Array1<A>,
        hessian_info: &HessianInfo<A, scirs2_core::ndarray::Ix1>,
    ) -> Result<Array1<A>> {
        match hessian_info {
            HessianInfo::Diagonal(hessian_diag) => {
                if params.len() != hessian_diag.len() || params.len() != gradients.len() {
                    return Err(OptimError::DimensionMismatch(
                        "Parameter, gradient, and Hessian dimensions must match".to_string(),
                    ));
                }

                let mut update = Array1::zeros(params.len());
                for i in 0..params.len() {
                    let h_ii = hessian_diag[i] + self.regularization;
                    if h_ii.abs() > A::from(1e-12).unwrap() {
                        update[i] = gradients[i] / h_ii;
                    } else {
                        // Fall back to gradient descent if Hessian is singular
                        update[i] = gradients[i];
                    }
                }

                Ok(params - &(update * self.learning_rate))
            }
            HessianInfo::QuasiNewton {
                s_history,
                y_history,
            } => {
                // Use L-BFGS approximation
                let search_direction = hessian_approximation::lbfgs_two_loop_recursion(
                    gradients,
                    s_history,
                    y_history,
                    A::one(), // Initial Hessian scale
                )?;

                Ok(params - &(search_direction * self.learning_rate))
            }
            _ => Err(OptimError::InvalidConfig(
                "Unsupported Hessian information type for Newton method".to_string(),
            )),
        }
    }

    fn reset(&mut self) {
        // Newton method is stateless, nothing to reset
    }
}

/// Quasi-Newton L-BFGS optimizer
#[derive(Debug)]
pub struct LBFGS<A: Float, D: Dimension> {
    learning_rate: A,
    max_history: usize,
    s_history: VecDeque<Array<A, D>>,
    y_history: VecDeque<Array<A, D>>,
    previous_params: Option<Array<A, D>>,
    previous_grad: Option<Array<A, D>>,
}

impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync> LBFGS<A, D> {
    /// Create a new L-BFGS optimizer
    pub fn new(learning_rate: A) -> Self {
        Self {
            learning_rate,
            max_history: 10,
            s_history: VecDeque::new(),
            y_history: VecDeque::new(),
            previous_params: None,
            previous_grad: None,
        }
    }

    /// Set maximum history size
    pub fn with_max_history(mut self, max_history: usize) -> Self {
        self.max_history = max_history;
        self
    }

    /// Perform L-BFGS step
    pub fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        // Update history if we have previous step information
        if let (Some(prev_params), Some(prev_grad)) = (&self.previous_params, &self.previous_grad) {
            let s = params - prev_params; // Parameter difference
            let y = gradients - prev_grad; // Gradient difference

            hessian_approximation::update_lbfgs_approximation(
                &mut self.s_history,
                &mut self.y_history,
                s,
                y,
                self.max_history,
            );
        }

        // Compute search direction using two-loop recursion
        let search_direction = if self.s_history.is_empty() {
            // No history, use gradient descent
            gradients.clone()
        } else {
            hessian_approximation::lbfgs_two_loop_recursion(
                gradients,
                &self.s_history,
                &self.y_history,
                A::one(),
            )?
        };

        // Update parameters
        let new_params = params - &(search_direction * self.learning_rate);

        // Store current information for next iteration
        self.previous_params = Some(params.clone());
        self.previous_grad = Some(gradients.clone());

        Ok(new_params)
    }
}

impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync>
    SecondOrderOptimizer<A, D> for LBFGS<A, D>
{
    fn step_second_order(
        &mut self,
        params: &Array<A, D>,
        gradients: &Array<A, D>,
        _hessian_info: &HessianInfo<A, D>, // L-BFGS maintains its own history
    ) -> Result<Array<A, D>> {
        self.step(params, gradients)
    }

    fn reset(&mut self) {
        self.s_history.clear();
        self.y_history.clear();
        self.previous_params = None;
        self.previous_grad = None;
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_diagonal_hessian_approximation() {
        // Test on a simple quadratic function: f(x) = x^2
        let params = Array1::from_vec(vec![1.0]);

        // Gradient function for quadratic: grad = 2*x
        let gradient_fn =
            |x: &Array1<f64>| -> Result<Array1<f64>> { Ok(Array1::from_vec(vec![2.0 * x[0]])) };

        let hessian_diag =
            hessian_approximation::diagonal_finite_difference(&params, gradient_fn, 1e-5).unwrap();

        // For quadratic function f(x) = x^2, second derivative should be 2.0
        assert_relative_eq!(hessian_diag[0], 2.0, epsilon = 1e-1);
    }

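    // Illustrative check on a non-quadratic function: for f(x) = x^4 the exact second
    // derivative at x = 1 is 12, and the central-difference estimate should recover it.
    #[test]
    fn test_diagonal_hessian_approximation_quartic() {
        let params = Array1::from_vec(vec![1.0]);

        // Gradient of f(x) = x^4 is 4*x^3
        let gradient_fn = |x: &Array1<f64>| -> Result<Array1<f64>> {
            Ok(Array1::from_vec(vec![4.0 * x[0].powi(3)]))
        };

        let hessian_diag =
            hessian_approximation::diagonal_finite_difference(&params, gradient_fn, 1e-5).unwrap();

        // f''(x) = 12*x^2, so the diagonal entry at x = 1 should be close to 12.0
        assert_relative_eq!(hessian_diag[0], 12.0, epsilon = 1e-3);
    }
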
    #[test]
    fn test_lbfgs_two_loop_recursion() {
        let gradient = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let mut s_history = VecDeque::new();
        let mut y_history = VecDeque::new();

        // Add some history
        s_history.push_back(Array1::from_vec(vec![0.1, 0.1, 0.1]));
        y_history.push_back(Array1::from_vec(vec![0.2, 0.3, 0.4]));

        let result =
            hessian_approximation::lbfgs_two_loop_recursion(&gradient, &s_history, &y_history, 1.0)
                .unwrap();

        // Result should be different from original gradient due to curvature information
        assert_ne!(result, gradient);
        assert_eq!(result.len(), gradient.len());
    }

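    // Degenerate case worth pinning down: with no history, the two-loop recursion reduces
    // to scaling the gradient by the initial Hessian scale.
    #[test]
    fn test_lbfgs_two_loop_recursion_empty_history() {
        let gradient = Array1::from_vec(vec![1.0, -2.0, 4.0]);
        let s_history: VecDeque<Array1<f64>> = VecDeque::new();
        let y_history: VecDeque<Array1<f64>> = VecDeque::new();

        let result =
            hessian_approximation::lbfgs_two_loop_recursion(&gradient, &s_history, &y_history, 0.5)
                .unwrap();

        // Expect 0.5 * gradient element-wise
        assert_relative_eq!(result[0], 0.5);
        assert_relative_eq!(result[1], -1.0);
        assert_relative_eq!(result[2], 2.0);
    }
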
    #[test]
    fn test_newton_method() {
        let mut optimizer = Newton::new(0.1);
        let params = Array1::from_vec(vec![1.0, 2.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2]);
        let hessian_diag = Array1::from_vec(vec![2.0, 4.0]);

        let hessian_info = HessianInfo::Diagonal(hessian_diag);
        let new_params = optimizer
            .step_second_order(&params, &gradients, &hessian_info)
            .unwrap();

        // Verify parameters were updated
        assert!(new_params[0] < params[0]);
        assert!(new_params[1] < params[1]);
    }

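    // End-to-end illustration of the workflow described in the module docs: estimate the
    // diagonal Hessian of f(x) = x0^2 + 2*x1^2 by finite differences, then take one Newton
    // step with a unit learning rate, which should land very close to the minimum at the origin.
    #[test]
    fn test_newton_with_finite_difference_hessian() {
        let params = Array1::from_vec(vec![1.0, 2.0]);

        // Gradient of f(x) = x0^2 + 2*x1^2 is [2*x0, 4*x1]
        let gradient_fn = |x: &Array1<f64>| -> Result<Array1<f64>> {
            Ok(Array1::from_vec(vec![2.0 * x[0], 4.0 * x[1]]))
        };

        let gradients = gradient_fn(&params).unwrap();
        let hessian_diag =
            hessian_approximation::diagonal_finite_difference(&params, gradient_fn, 1e-5).unwrap();

        let mut optimizer = Newton::new(1.0);
        let new_params = optimizer
            .step_second_order(&params, &gradients, &HessianInfo::Diagonal(hessian_diag))
            .unwrap();

        // For a quadratic, a full Newton step solves the problem in one iteration
        assert_relative_eq!(new_params[0], 0.0, epsilon = 1e-3);
        assert_relative_eq!(new_params[1], 0.0, epsilon = 1e-3);
    }
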
    #[test]
    fn test_lbfgs_optimizer() {
        let mut optimizer = LBFGS::new(0.01).with_max_history(5);
        let mut params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients1 = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let gradients2 = Array1::from_vec(vec![0.05, 0.15, 0.25]);

        // First step
        params = optimizer.step(&params, &gradients1).unwrap();

        // Second step (should use history)
        let new_params = optimizer.step(&params, &gradients2).unwrap();

        // Verify parameters were updated
        assert_ne!(new_params, params);
        assert_eq!(optimizer.s_history.len(), 1);
        assert_eq!(optimizer.y_history.len(), 1);
    }

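    // Focused check of the history bookkeeping that LBFGS relies on: once max_history
    // pairs have been stored, the oldest pair is dropped first.
    #[test]
    fn test_lbfgs_history_limit() {
        let mut s_history = VecDeque::new();
        let mut y_history = VecDeque::new();

        for i in 0..5 {
            hessian_approximation::update_lbfgs_approximation(
                &mut s_history,
                &mut y_history,
                Array1::from_vec(vec![i as f64]),
                Array1::from_vec(vec![2.0 * i as f64]),
                3,
            );
        }

        // Only the 3 most recent pairs (i = 2, 3, 4) should remain
        assert_eq!(s_history.len(), 3);
        assert_eq!(y_history.len(), 3);
        assert_relative_eq!(s_history[0][0], 2.0);
        assert_relative_eq!(y_history[0][0], 4.0);
    }
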
    #[test]
    fn test_gauss_newton_approximation() {
        let jacobian = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let hessian_approx = hessian_approximation::gauss_newton_approximation(&jacobian).unwrap();

        // Should be a 2x2 matrix (J^T * J)
        assert_eq!(hessian_approx.dim(), (2, 2));

        // Verify it's positive semidefinite by checking diagonal elements are non-negative
        assert!(hessian_approx[(0, 0)] >= 0.0);
        assert!(hessian_approx[(1, 1)] >= 0.0);
    }
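
    // Worked example for the Gauss-Newton formula: with J = [[1, 2], [3, 4], [5, 6]],
    // J^T * J = [[1+9+25, 2+12+30], [2+12+30, 4+16+36]] = [[35, 44], [44, 56]].
    #[test]
    fn test_gauss_newton_values() {
        let jacobian = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let h = hessian_approximation::gauss_newton_approximation(&jacobian).unwrap();

        assert_relative_eq!(h[(0, 0)], 35.0);
        assert_relative_eq!(h[(0, 1)], 44.0);
        assert_relative_eq!(h[(1, 0)], 44.0);
        assert_relative_eq!(h[(1, 1)], 56.0);
    }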
479}