echidna_optim/solvers/newton.rs
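
//! Newton's method solver: an LU-factorized Hessian solve per iteration,
//! combined with Armijo backtracking line search.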

use num_traits::Float;

use crate::convergence::{norm, ConvergenceParams};
use crate::linalg::lu_solve;
use crate::line_search::{backtracking_armijo, ArmijoParams};
use crate::objective::Objective;
use crate::result::{OptimResult, TerminationReason};

/// Configuration for the Newton solver.
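///
/// # Example
///
/// A minimal sketch tightening the defaults. The `ConvergenceParams` field
/// names (`max_iter`, `grad_tol`) are the ones read by [`newton`] below;
/// their public visibility is assumed here rather than confirmed.
///
/// ```ignore
/// let mut config = NewtonConfig::<f64>::default();
/// config.convergence.max_iter = 50;
/// config.convergence.grad_tol = 1e-10;
/// ```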
#[derive(Debug, Clone)]
pub struct NewtonConfig<F> {
    /// Convergence parameters.
    pub convergence: ConvergenceParams<F>,
    /// Line search parameters.
    pub line_search: ArmijoParams<F>,
}

impl Default for NewtonConfig<f64> {
    fn default() -> Self {
        NewtonConfig {
            convergence: ConvergenceParams::default(),
            line_search: ArmijoParams::default(),
        }
    }
}

impl Default for NewtonConfig<f32> {
    fn default() -> Self {
        NewtonConfig {
            convergence: ConvergenceParams::default(),
            line_search: ArmijoParams::default(),
        }
    }
}

/// Newton's method with LU-based Hessian solve and Armijo line search.
///
/// Minimizes `obj` starting from `x0`. At each iteration, solves `H * delta = -g`
/// via LU factorization, then performs a backtracking line search along `delta`.
///
/// Requires `obj` to implement `eval_hessian`.
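///
/// # Example
///
/// A minimal sketch minimizing the convex quadratic `f(x) = x0^2 + x1^2`,
/// which Newton's method solves in a single step. The `Quadratic` type is
/// illustrative, and the module paths are assumed from the `crate::` imports
/// above rather than confirmed public re-exports.
///
/// ```ignore
/// use echidna_optim::objective::Objective;
/// use echidna_optim::result::TerminationReason;
/// use echidna_optim::solvers::newton::{newton, NewtonConfig};
///
/// struct Quadratic;
///
/// impl Objective<f64> for Quadratic {
///     fn dim(&self) -> usize {
///         2
///     }
///
///     fn eval_grad(&mut self, x: &[f64]) -> (f64, Vec<f64>) {
///         let f = x[0] * x[0] + x[1] * x[1];
///         (f, vec![2.0 * x[0], 2.0 * x[1]])
///     }
///
///     fn eval_hessian(&mut self, x: &[f64]) -> (f64, Vec<f64>, Vec<Vec<f64>>) {
///         let (f, g) = self.eval_grad(x);
///         // Constant Hessian of the quadratic: 2 * I.
///         (f, g, vec![vec![2.0, 0.0], vec![0.0, 2.0]])
///     }
/// }
///
/// let mut obj = Quadratic;
/// let result = newton(&mut obj, &[3.0, -4.0], &NewtonConfig::default());
/// assert_eq!(result.termination, TerminationReason::GradientNorm);
/// ```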
pub fn newton<F: Float, O: Objective<F>>(
    obj: &mut O,
    x0: &[F],
    config: &NewtonConfig<F>,
) -> OptimResult<F> {
    let n = x0.len();

    // A zero iteration budget leaves nothing evaluated; report a degenerate result.
    if config.convergence.max_iter == 0 {
        return OptimResult {
            x: x0.to_vec(),
            value: F::nan(),
            gradient: vec![F::nan(); n],
            gradient_norm: F::nan(),
            iterations: 0,
            func_evals: 0,
            termination: TerminationReason::NumericalError,
        };
    }

    let mut x = x0.to_vec();
    let (mut f_val, mut grad, mut hess) = obj.eval_hessian(&x);
    let mut func_evals = 1usize;
    let mut grad_norm = norm(&grad);

    // Already converged at the starting point.
    if grad_norm < config.convergence.grad_tol {
        return OptimResult {
            x,
            value: f_val,
            gradient: grad,
            gradient_norm: grad_norm,
            iterations: 0,
            func_evals,
            termination: TerminationReason::GradientNorm,
        };
    }

    for iter in 0..config.convergence.max_iter {
        // Solve H * delta = -g; a singular Hessian aborts with NumericalError.
        let neg_grad: Vec<F> = grad.iter().map(|&g| -g).collect();
        let delta = match lu_solve(&hess, &neg_grad) {
            Some(d) => d,
            None => {
                return OptimResult {
                    x,
                    value: f_val,
                    gradient: grad,
                    gradient_norm: grad_norm,
                    iterations: iter,
                    func_evals,
                    termination: TerminationReason::NumericalError,
                };
            }
        };

        // Line search along the Newton direction.
        let ls = match backtracking_armijo(obj, &x, &delta, f_val, &grad, &config.line_search) {
            Some(ls) => ls,
            None => {
                return OptimResult {
                    x,
                    value: f_val,
                    gradient: grad,
                    gradient_norm: grad_norm,
                    iterations: iter,
                    func_evals,
                    termination: TerminationReason::LineSearchFailed,
                };
            }
        };
        func_evals += ls.evals;

        // Take the step, accumulating its squared norm for the step-size test.
        let mut step_norm_sq = F::zero();
        for i in 0..n {
            let step = ls.alpha * delta[i];
            step_norm_sq = step_norm_sq + step * step;
            x[i] = x[i] + step;
        }

        let f_prev = f_val;

        // Re-evaluate value, gradient, and Hessian at the new point.
        (f_val, grad, hess) = obj.eval_hessian(&x);
        func_evals += 1;
        grad_norm = norm(&grad);

        // Convergence checks
        if grad_norm < config.convergence.grad_tol {
            return OptimResult {
                x,
                value: f_val,
                gradient: grad,
                gradient_norm: grad_norm,
                iterations: iter + 1,
                func_evals,
                termination: TerminationReason::GradientNorm,
            };
        }

        if step_norm_sq.sqrt() < config.convergence.step_tol {
            return OptimResult {
                x,
                value: f_val,
                gradient: grad,
                gradient_norm: grad_norm,
                iterations: iter + 1,
                func_evals,
                termination: TerminationReason::StepSize,
            };
        }

        if config.convergence.func_tol > F::zero()
            && (f_prev - f_val).abs() < config.convergence.func_tol
        {
            return OptimResult {
                x,
                value: f_val,
                gradient: grad,
                gradient_norm: grad_norm,
                iterations: iter + 1,
                func_evals,
                termination: TerminationReason::FunctionChange,
            };
        }
    }

    OptimResult {
        x,
        value: f_val,
        gradient: grad,
        gradient_norm: grad_norm,
        iterations: config.convergence.max_iter,
        func_evals,
        termination: TerminationReason::MaxIterations,
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    struct Rosenbrock;

    impl Objective<f64> for Rosenbrock {
        fn dim(&self) -> usize {
            2
        }

        fn eval_grad(&mut self, x: &[f64]) -> (f64, Vec<f64>) {
            let a = 1.0 - x[0];
            let b = x[1] - x[0] * x[0];
            let f = a * a + 100.0 * b * b;
            let g0 = -2.0 * a - 400.0 * x[0] * b;
            let g1 = 200.0 * b;
            (f, vec![g0, g1])
        }

        fn eval_hessian(&mut self, x: &[f64]) -> (f64, Vec<f64>, Vec<Vec<f64>>) {
            let a = 1.0 - x[0];
            let b = x[1] - x[0] * x[0];
            let f = a * a + 100.0 * b * b;
            let g0 = -2.0 * a - 400.0 * x[0] * b;
            let g1 = 200.0 * b;

            let h00 = 2.0 - 400.0 * (x[1] - 3.0 * x[0] * x[0]);
            let h01 = -400.0 * x[0];
            let h11 = 200.0;

            (f, vec![g0, g1], vec![vec![h00, h01], vec![h01, h11]])
        }
    }

    #[test]
    fn newton_rosenbrock() {
        let mut obj = Rosenbrock;
        let config = NewtonConfig::default();
        let result = newton(&mut obj, &[0.0, 0.0], &config);

        assert_eq!(result.termination, TerminationReason::GradientNorm);
        assert!(
            (result.x[0] - 1.0).abs() < 1e-6,
            "x[0] = {}, expected 1.0",
            result.x[0]
        );
        assert!(
            (result.x[1] - 1.0).abs() < 1e-6,
            "x[1] = {}, expected 1.0",
            result.x[1]
        );
        assert!(result.gradient_norm < 1e-8);
    }

    #[test]
    fn newton_singular_hessian() {
        struct SingularAtOrigin;

        impl Objective<f64> for SingularAtOrigin {
            fn dim(&self) -> usize {
                2
            }

            fn eval_grad(&mut self, x: &[f64]) -> (f64, Vec<f64>) {
                let f = x[0] * x[0] + x[1] * x[1];
                (f, vec![2.0 * x[0], 2.0 * x[1]])
            }

            fn eval_hessian(&mut self, _x: &[f64]) -> (f64, Vec<f64>, Vec<Vec<f64>>) {
                // Return a singular Hessian
                (1.0, vec![1.0, 1.0], vec![vec![1.0, 1.0], vec![1.0, 1.0]])
            }
        }

        let mut obj = SingularAtOrigin;
        let config = NewtonConfig::default();
        let result = newton(&mut obj, &[2.0, 3.0], &config);

        assert_eq!(result.termination, TerminationReason::NumericalError);
    }
}
263}