Skip to main content

echidna_optim/solvers/
trust_region.rs

1use num_traits::Float;
2
3use crate::convergence::{dot, norm, ConvergenceParams};
4use crate::objective::Objective;
5use crate::result::{OptimResult, TerminationReason};
6
/// Configuration for the trust-region solver.
///
/// Defaults are provided for `f32` and `f64` via `Default`.
#[derive(Debug, Clone)]
pub struct TrustRegionConfig<F> {
    /// Initial trust-region radius (default: 1.0). Must be positive;
    /// `trust_region` returns `NumericalError` otherwise.
    pub initial_radius: F,
    /// Maximum trust-region radius (default: 100.0). Must be positive;
    /// the radius is capped at this value when expanding.
    pub max_radius: F,
    /// Acceptance threshold for the ratio of actual to predicted reduction (default: 0.1).
    /// A step is accepted only when the ratio strictly exceeds this value.
    pub eta: F,
    /// Maximum CG iterations per trust-region subproblem (default: 2 * dim).
    /// If 0, defaults to 2 * dim.
    pub max_cg_iter: usize,
    /// Convergence parameters (gradient, step, and function-change tolerances).
    pub convergence: ConvergenceParams<F>,
}
22
23impl Default for TrustRegionConfig<f64> {
24    fn default() -> Self {
25        TrustRegionConfig {
26            initial_radius: 1.0,
27            max_radius: 100.0,
28            eta: 0.1,
29            max_cg_iter: 0,
30            convergence: ConvergenceParams::default(),
31        }
32    }
33}
34
35impl Default for TrustRegionConfig<f32> {
36    fn default() -> Self {
37        TrustRegionConfig {
38            initial_radius: 1.0,
39            max_radius: 100.0,
40            eta: 0.1,
41            max_cg_iter: 0,
42            convergence: ConvergenceParams::default(),
43        }
44    }
45}
46
47/// Trust-region optimization using Steihaug-Toint CG.
48///
49/// Minimizes `obj` starting from `x0`. Uses Hessian-vector products
50/// (via `obj.hvp()`) to solve the trust-region subproblem approximately
51/// with truncated conjugate gradients (Steihaug-Toint).
52pub fn trust_region<F: Float, O: Objective<F>>(
53    obj: &mut O,
54    x0: &[F],
55    config: &TrustRegionConfig<F>,
56) -> OptimResult<F> {
57    let n = x0.len();
58
59    if config.convergence.max_iter == 0
60        || config.initial_radius <= F::zero()
61        || config.max_radius <= F::zero()
62    {
63        return OptimResult {
64            x: x0.to_vec(),
65            value: F::nan(),
66            gradient: vec![F::nan(); n],
67            gradient_norm: F::nan(),
68            iterations: 0,
69            func_evals: 0,
70            termination: TerminationReason::NumericalError,
71        };
72    }
73
74    let max_cg = if config.max_cg_iter == 0 {
75        2 * n
76    } else {
77        config.max_cg_iter
78    };
79
80    let mut x = x0.to_vec();
81    let (mut f_val, mut grad) = obj.eval_grad(&x);
82    let mut func_evals = 1usize;
83    let mut grad_norm = norm(&grad);
84    let mut radius = config.initial_radius;
85
86    if grad_norm < config.convergence.grad_tol {
87        return OptimResult {
88            x,
89            value: f_val,
90            gradient: grad,
91            gradient_norm: grad_norm,
92            iterations: 0,
93            func_evals,
94            termination: TerminationReason::GradientNorm,
95        };
96    }
97
98    let two = F::one() + F::one();
99    let quarter = F::one() / (two * two);
100    let three_quarter = F::one() - quarter;
101
102    for iter in 0..config.convergence.max_iter {
103        // Solve the trust-region subproblem with Steihaug-Toint CG
104        let step = steihaug_cg(obj, &x, &grad, radius, max_cg, &mut func_evals);
105
106        // Predicted reduction: -g^T s - 0.5 * s^T H s
107        // We need H*s for the predicted reduction; get it via hvp
108        let (_, hvp_result) = obj.hvp(&x, &step);
109        func_evals += 1;
110        let gs = dot(&grad, &step);
111        let shs = dot(&step, &hvp_result);
112        let predicted = F::zero() - gs - shs / two;
113
114        // Actual reduction
115        let mut x_new = vec![F::zero(); n];
116        for i in 0..n {
117            x_new[i] = x[i] + step[i];
118        }
119        let (f_new, g_new) = obj.eval_grad(&x_new);
120        func_evals += 1;
121        let actual = f_val - f_new;
122
123        let step_norm = norm(&step);
124
125        // Ratio of actual to predicted reduction
126        let rho = if predicted.abs() < F::epsilon() {
127            if actual >= F::zero() {
128                F::one()
129            } else {
130                F::zero()
131            }
132        } else {
133            actual / predicted
134        };
135
136        // Update trust-region radius
137        if rho < quarter {
138            radius = quarter * step_norm;
139        } else if rho > three_quarter && (step_norm - radius).abs() < F::epsilon() * radius {
140            // Step was on the boundary and rho is good — expand
141            radius = (two * radius).min(config.max_radius);
142        }
143        // Otherwise keep radius unchanged
144
145        // Accept or reject step
146        if rho > config.eta {
147            let f_prev = f_val;
148            x = x_new;
149            f_val = f_new;
150            grad = g_new;
151            grad_norm = norm(&grad);
152
153            // Convergence checks
154            if grad_norm < config.convergence.grad_tol {
155                return OptimResult {
156                    x,
157                    value: f_val,
158                    gradient: grad,
159                    gradient_norm: grad_norm,
160                    iterations: iter + 1,
161                    func_evals,
162                    termination: TerminationReason::GradientNorm,
163                };
164            }
165
166            if step_norm < config.convergence.step_tol {
167                return OptimResult {
168                    x,
169                    value: f_val,
170                    gradient: grad,
171                    gradient_norm: grad_norm,
172                    iterations: iter + 1,
173                    func_evals,
174                    termination: TerminationReason::StepSize,
175                };
176            }
177
178            if config.convergence.func_tol > F::zero()
179                && (f_prev - f_val).abs() < config.convergence.func_tol
180            {
181                return OptimResult {
182                    x,
183                    value: f_val,
184                    gradient: grad,
185                    gradient_norm: grad_norm,
186                    iterations: iter + 1,
187                    func_evals,
188                    termination: TerminationReason::FunctionChange,
189                };
190            }
191        }
192        // If rejected, loop again with smaller radius
193    }
194
195    OptimResult {
196        x,
197        value: f_val,
198        gradient: grad,
199        gradient_norm: grad_norm,
200        iterations: config.convergence.max_iter,
201        func_evals,
202        termination: TerminationReason::MaxIterations,
203    }
204}
205
206/// Steihaug-Toint truncated CG for the trust-region subproblem.
207///
208/// Approximately minimizes `m(s) = g^T s + 0.5 s^T H s` subject to `||s|| <= radius`.
209/// Returns the step `s`.
210fn steihaug_cg<F: Float, O: Objective<F>>(
211    obj: &mut O,
212    x: &[F],
213    grad: &[F],
214    radius: F,
215    max_iter: usize,
216    func_evals: &mut usize,
217) -> Vec<F> {
218    let n = grad.len();
219    let mut s = vec![F::zero(); n];
220    let mut r: Vec<F> = grad.to_vec();
221    let mut d: Vec<F> = r.iter().map(|&ri| F::zero() - ri).collect();
222    let mut r_dot_r = dot(&r, &r);
223
224    if r_dot_r.sqrt() < F::epsilon() {
225        return s;
226    }
227
228    for _ in 0..max_iter {
229        // H * d via hvp
230        let (_, hd) = obj.hvp(x, &d);
231        *func_evals += 1;
232
233        let d_hd = dot(&d, &hd);
234
235        // Negative curvature: go to the boundary
236        if d_hd <= F::zero() {
237            let tau = boundary_tau(&s, &d, radius);
238            for i in 0..n {
239                s[i] = s[i] + tau * d[i];
240            }
241            return s;
242        }
243
244        let alpha = r_dot_r / d_hd;
245
246        // Check if step would leave the trust region
247        let mut s_next = vec![F::zero(); n];
248        for i in 0..n {
249            s_next[i] = s[i] + alpha * d[i];
250        }
251        if norm(&s_next) >= radius {
252            let tau = boundary_tau(&s, &d, radius);
253            for i in 0..n {
254                s[i] = s[i] + tau * d[i];
255            }
256            return s;
257        }
258
259        s = s_next;
260
261        // Update residual
262        for i in 0..n {
263            r[i] = r[i] + alpha * hd[i];
264        }
265        let r_dot_r_new = dot(&r, &r);
266
267        if r_dot_r_new.sqrt() < F::epsilon() {
268            return s;
269        }
270
271        let beta = r_dot_r_new / r_dot_r;
272        r_dot_r = r_dot_r_new;
273
274        for i in 0..n {
275            d[i] = F::zero() - r[i] + beta * d[i];
276        }
277    }
278
279    s
280}
281
282/// Find `tau > 0` such that `||s + tau * d|| = radius`.
283///
284/// Solves `||s + tau * d||^2 = radius^2` for the positive root.
285fn boundary_tau<F: Float>(s: &[F], d: &[F], radius: F) -> F {
286    let dd = dot(d, d);
287    let sd = dot(s, d);
288    let ss = dot(s, s);
289    let two = F::one() + F::one();
290
291    // tau^2 * dd + 2*tau*sd + ss = radius^2
292    // Quadratic: a*tau^2 + b*tau + c = 0
293    let a = dd;
294    let b = two * sd;
295    let c = ss - radius * radius;
296
297    let disc = b * b - (two + two) * a * c;
298    if disc < F::zero() {
299        return F::zero();
300    }
301
302    // We want the positive root
303    let sqrt_disc = disc.sqrt();
304    let tau1 = (F::zero() - b + sqrt_disc) / (two * a);
305    let tau2 = (F::zero() - b - sqrt_disc) / (two * a);
306
307    if tau1 > F::zero() {
308        if tau2 > F::zero() {
309            tau1.min(tau2)
310        } else {
311            tau1
312        }
313    } else {
314        tau2.max(F::zero())
315    }
316}
317
#[cfg(test)]
mod tests {
    use super::*;

    // Classic 2-D Rosenbrock: f(x) = (1 - x0)^2 + 100*(x1 - x0^2)^2,
    // minimum f = 0 at (1, 1). Gradient and Hessian are analytic.
    struct Rosenbrock;

    impl Objective<f64> for Rosenbrock {
        fn dim(&self) -> usize {
            2
        }

        fn eval_grad(&mut self, x: &[f64]) -> (f64, Vec<f64>) {
            let a = 1.0 - x[0];
            let b = x[1] - x[0] * x[0];
            let f = a * a + 100.0 * b * b;
            // df/dx0 = -2(1 - x0) - 400*x0*(x1 - x0^2)
            let g0 = -2.0 * a - 400.0 * x[0] * b;
            // df/dx1 = 200*(x1 - x0^2)
            let g1 = 200.0 * b;
            (f, vec![g0, g1])
        }

        // Exact Hessian-vector product (2x2 analytic Hessian).
        fn hvp(&mut self, x: &[f64], v: &[f64]) -> (Vec<f64>, Vec<f64>) {
            // H = [[2 - 400*(x1 - 3*x0^2), -400*x0],
            //       [-400*x0,                  200  ]]
            let h00 = 2.0 - 400.0 * (x[1] - 3.0 * x[0] * x[0]);
            let h01 = -400.0 * x[0];
            let h11 = 200.0;

            let hv0 = h00 * v[0] + h01 * v[1];
            let hv1 = h01 * v[0] + h11 * v[1];

            // Gradient recomputed here so hvp can also report it.
            let g0 = -2.0 * (1.0 - x[0]) - 400.0 * x[0] * (x[1] - x[0] * x[0]);
            let g1 = 200.0 * (x[1] - x[0] * x[0]);

            (vec![g0, g1], vec![hv0, hv1])
        }
    }

    #[test]
    fn trust_region_rosenbrock() {
        let mut obj = Rosenbrock;
        let config = TrustRegionConfig {
            convergence: ConvergenceParams {
                max_iter: 200,
                ..Default::default()
            },
            ..Default::default()
        };
        // Start at the origin; the solver must reach the minimum at (1, 1).
        let result = trust_region(&mut obj, &[0.0, 0.0], &config);

        assert_eq!(
            result.termination,
            TerminationReason::GradientNorm,
            "terminated with {:?} after {} iterations",
            result.termination,
            result.iterations
        );
        assert!(
            (result.x[0] - 1.0).abs() < 1e-6,
            "x[0] = {}, expected 1.0",
            result.x[0]
        );
        assert!(
            (result.x[1] - 1.0).abs() < 1e-6,
            "x[1] = {}, expected 1.0",
            result.x[1]
        );
    }

    // Chained 4-D Rosenbrock: sum over i of (1 - x_i)^2 + 100*(x_{i+1} - x_i^2)^2,
    // minimum at (1, 1, 1, 1). Exercises the CG subproblem in n > 2.
    struct Rosenbrock4D;

    impl Objective<f64> for Rosenbrock4D {
        fn dim(&self) -> usize {
            4
        }

        fn eval_grad(&mut self, x: &[f64]) -> (f64, Vec<f64>) {
            let mut f = 0.0;
            let mut g = vec![0.0; 4];
            // Each chain link contributes to f and to two gradient entries.
            for i in 0..3 {
                let a = 1.0 - x[i];
                let b = x[i + 1] - x[i] * x[i];
                f += a * a + 100.0 * b * b;
                g[i] += -2.0 * a - 400.0 * x[i] * b;
                g[i + 1] += 200.0 * b;
            }
            (f, g)
        }

        // Hessian-vector product assembled link-by-link from the 2x2
        // per-link Hessian blocks (tridiagonal overall structure).
        fn hvp(&mut self, x: &[f64], v: &[f64]) -> (Vec<f64>, Vec<f64>) {
            let n = 4;
            let mut hv = vec![0.0; n];
            let mut g = vec![0.0; n];

            for i in 0..3 {
                let a = 1.0 - x[i];
                let b = x[i + 1] - x[i] * x[i];

                g[i] += -2.0 * a - 400.0 * x[i] * b;
                g[i + 1] += 200.0 * b;

                let h_ii = 2.0 - 400.0 * (x[i + 1] - 3.0 * x[i] * x[i]);
                let h_ij = -400.0 * x[i];
                let h_jj = 200.0;

                hv[i] += h_ii * v[i] + h_ij * v[i + 1];
                hv[i + 1] += h_ij * v[i] + h_jj * v[i + 1];
            }

            (g, hv)
        }
    }

    #[test]
    fn trust_region_rosenbrock_4d() {
        let mut obj = Rosenbrock4D;
        let config = TrustRegionConfig {
            convergence: ConvergenceParams {
                max_iter: 500,
                ..Default::default()
            },
            ..Default::default()
        };
        // Looser tolerance (1e-5) than the 2-D test: the 4-D valley is
        // harder and more iterations accumulate rounding.
        let result = trust_region(&mut obj, &[0.0, 0.0, 0.0, 0.0], &config);

        assert_eq!(
            result.termination,
            TerminationReason::GradientNorm,
            "terminated with {:?} after {} iterations, grad_norm={}",
            result.termination,
            result.iterations,
            result.gradient_norm
        );
        for i in 0..4 {
            assert!(
                (result.x[i] - 1.0).abs() < 1e-5,
                "x[{}] = {}, expected 1.0",
                i,
                result.x[i]
            );
        }
    }
}
459}