1use num_traits::Float;
2
3use crate::convergence::{dot, norm, ConvergenceParams};
4use crate::objective::Objective;
5use crate::result::{OptimResult, TerminationReason};
6
7#[derive(Debug, Clone)]
9pub struct TrustRegionConfig<F> {
10 pub initial_radius: F,
12 pub max_radius: F,
14 pub eta: F,
16 pub max_cg_iter: usize,
19 pub convergence: ConvergenceParams<F>,
21}
22
23impl Default for TrustRegionConfig<f64> {
24 fn default() -> Self {
25 TrustRegionConfig {
26 initial_radius: 1.0,
27 max_radius: 100.0,
28 eta: 0.1,
29 max_cg_iter: 0,
30 convergence: ConvergenceParams::default(),
31 }
32 }
33}
34
35impl Default for TrustRegionConfig<f32> {
36 fn default() -> Self {
37 TrustRegionConfig {
38 initial_radius: 1.0,
39 max_radius: 100.0,
40 eta: 0.1,
41 max_cg_iter: 0,
42 convergence: ConvergenceParams::default(),
43 }
44 }
45}
46
47pub fn trust_region<F: Float, O: Objective<F>>(
53 obj: &mut O,
54 x0: &[F],
55 config: &TrustRegionConfig<F>,
56) -> OptimResult<F> {
57 let n = x0.len();
58
59 if config.convergence.max_iter == 0
60 || config.initial_radius <= F::zero()
61 || config.max_radius <= F::zero()
62 {
63 return OptimResult {
64 x: x0.to_vec(),
65 value: F::nan(),
66 gradient: vec![F::nan(); n],
67 gradient_norm: F::nan(),
68 iterations: 0,
69 func_evals: 0,
70 termination: TerminationReason::NumericalError,
71 };
72 }
73
74 let max_cg = if config.max_cg_iter == 0 {
75 2 * n
76 } else {
77 config.max_cg_iter
78 };
79
80 let mut x = x0.to_vec();
81 let (mut f_val, mut grad) = obj.eval_grad(&x);
82 let mut func_evals = 1usize;
83 let mut grad_norm = norm(&grad);
84 let mut radius = config.initial_radius;
85
86 if grad_norm < config.convergence.grad_tol {
87 return OptimResult {
88 x,
89 value: f_val,
90 gradient: grad,
91 gradient_norm: grad_norm,
92 iterations: 0,
93 func_evals,
94 termination: TerminationReason::GradientNorm,
95 };
96 }
97
98 let two = F::one() + F::one();
99 let quarter = F::one() / (two * two);
100 let three_quarter = F::one() - quarter;
101
102 for iter in 0..config.convergence.max_iter {
103 let step = steihaug_cg(obj, &x, &grad, radius, max_cg, &mut func_evals);
105
106 let (_, hvp_result) = obj.hvp(&x, &step);
109 func_evals += 1;
110 let gs = dot(&grad, &step);
111 let shs = dot(&step, &hvp_result);
112 let predicted = F::zero() - gs - shs / two;
113
114 let mut x_new = vec![F::zero(); n];
116 for i in 0..n {
117 x_new[i] = x[i] + step[i];
118 }
119 let (f_new, g_new) = obj.eval_grad(&x_new);
120 func_evals += 1;
121 let actual = f_val - f_new;
122
123 let step_norm = norm(&step);
124
125 let rho = if predicted.abs() < F::epsilon() {
127 if actual >= F::zero() {
128 F::one()
129 } else {
130 F::zero()
131 }
132 } else {
133 actual / predicted
134 };
135
136 if rho < quarter {
138 radius = quarter * step_norm;
139 } else if rho > three_quarter && (step_norm - radius).abs() < F::epsilon() * radius {
140 radius = (two * radius).min(config.max_radius);
142 }
143 if rho > config.eta {
147 let f_prev = f_val;
148 x = x_new;
149 f_val = f_new;
150 grad = g_new;
151 grad_norm = norm(&grad);
152
153 if grad_norm < config.convergence.grad_tol {
155 return OptimResult {
156 x,
157 value: f_val,
158 gradient: grad,
159 gradient_norm: grad_norm,
160 iterations: iter + 1,
161 func_evals,
162 termination: TerminationReason::GradientNorm,
163 };
164 }
165
166 if step_norm < config.convergence.step_tol {
167 return OptimResult {
168 x,
169 value: f_val,
170 gradient: grad,
171 gradient_norm: grad_norm,
172 iterations: iter + 1,
173 func_evals,
174 termination: TerminationReason::StepSize,
175 };
176 }
177
178 if config.convergence.func_tol > F::zero()
179 && (f_prev - f_val).abs() < config.convergence.func_tol
180 {
181 return OptimResult {
182 x,
183 value: f_val,
184 gradient: grad,
185 gradient_norm: grad_norm,
186 iterations: iter + 1,
187 func_evals,
188 termination: TerminationReason::FunctionChange,
189 };
190 }
191 }
192 }
194
195 OptimResult {
196 x,
197 value: f_val,
198 gradient: grad,
199 gradient_norm: grad_norm,
200 iterations: config.convergence.max_iter,
201 func_evals,
202 termination: TerminationReason::MaxIterations,
203 }
204}
205
206fn steihaug_cg<F: Float, O: Objective<F>>(
211 obj: &mut O,
212 x: &[F],
213 grad: &[F],
214 radius: F,
215 max_iter: usize,
216 func_evals: &mut usize,
217) -> Vec<F> {
218 let n = grad.len();
219 let mut s = vec![F::zero(); n];
220 let mut r: Vec<F> = grad.to_vec();
221 let mut d: Vec<F> = r.iter().map(|&ri| F::zero() - ri).collect();
222 let mut r_dot_r = dot(&r, &r);
223
224 if r_dot_r.sqrt() < F::epsilon() {
225 return s;
226 }
227
228 for _ in 0..max_iter {
229 let (_, hd) = obj.hvp(x, &d);
231 *func_evals += 1;
232
233 let d_hd = dot(&d, &hd);
234
235 if d_hd <= F::zero() {
237 let tau = boundary_tau(&s, &d, radius);
238 for i in 0..n {
239 s[i] = s[i] + tau * d[i];
240 }
241 return s;
242 }
243
244 let alpha = r_dot_r / d_hd;
245
246 let mut s_next = vec![F::zero(); n];
248 for i in 0..n {
249 s_next[i] = s[i] + alpha * d[i];
250 }
251 if norm(&s_next) >= radius {
252 let tau = boundary_tau(&s, &d, radius);
253 for i in 0..n {
254 s[i] = s[i] + tau * d[i];
255 }
256 return s;
257 }
258
259 s = s_next;
260
261 for i in 0..n {
263 r[i] = r[i] + alpha * hd[i];
264 }
265 let r_dot_r_new = dot(&r, &r);
266
267 if r_dot_r_new.sqrt() < F::epsilon() {
268 return s;
269 }
270
271 let beta = r_dot_r_new / r_dot_r;
272 r_dot_r = r_dot_r_new;
273
274 for i in 0..n {
275 d[i] = F::zero() - r[i] + beta * d[i];
276 }
277 }
278
279 s
280}
281
282fn boundary_tau<F: Float>(s: &[F], d: &[F], radius: F) -> F {
286 let dd = dot(d, d);
287 let sd = dot(s, d);
288 let ss = dot(s, s);
289 let two = F::one() + F::one();
290
291 let a = dd;
294 let b = two * sd;
295 let c = ss - radius * radius;
296
297 let disc = b * b - (two + two) * a * c;
298 if disc < F::zero() {
299 return F::zero();
300 }
301
302 let sqrt_disc = disc.sqrt();
304 let tau1 = (F::zero() - b + sqrt_disc) / (two * a);
305 let tau2 = (F::zero() - b - sqrt_disc) / (two * a);
306
307 if tau1 > F::zero() {
308 if tau2 > F::zero() {
309 tau1.min(tau2)
310 } else {
311 tau1
312 }
313 } else {
314 tau2.max(F::zero())
315 }
316}
317
#[cfg(test)]
mod tests {
    use super::*;

    /// 2-D Rosenbrock function `(1 − x0)² + 100 (x1 − x0²)²`, minimum at (1, 1).
    struct Rosenbrock;

    impl Objective<f64> for Rosenbrock {
        fn dim(&self) -> usize {
            2
        }

        fn eval_grad(&mut self, x: &[f64]) -> (f64, Vec<f64>) {
            let a = 1.0 - x[0];
            let b = x[1] - x[0] * x[0];
            let f = a * a + 100.0 * b * b;
            let g0 = -2.0 * a - 400.0 * x[0] * b;
            let g1 = 200.0 * b;
            (f, vec![g0, g1])
        }

        fn hvp(&mut self, x: &[f64], v: &[f64]) -> (Vec<f64>, Vec<f64>) {
            // Exact Hessian entries of the 2-D Rosenbrock function.
            let h00 = 2.0 - 400.0 * (x[1] - 3.0 * x[0] * x[0]);
            let h01 = -400.0 * x[0];
            let h11 = 200.0;

            let hv0 = h00 * v[0] + h01 * v[1];
            let hv1 = h01 * v[0] + h11 * v[1];

            let g0 = -2.0 * (1.0 - x[0]) - 400.0 * x[0] * (x[1] - x[0] * x[0]);
            let g1 = 200.0 * (x[1] - x[0] * x[0]);

            (vec![g0, g1], vec![hv0, hv1])
        }
    }

    #[test]
    fn trust_region_rosenbrock() {
        let mut obj = Rosenbrock;
        let config = TrustRegionConfig {
            convergence: ConvergenceParams {
                max_iter: 200,
                ..Default::default()
            },
            ..Default::default()
        };
        let result = trust_region(&mut obj, &[0.0, 0.0], &config);

        assert_eq!(
            result.termination,
            TerminationReason::GradientNorm,
            "terminated with {:?} after {} iterations",
            result.termination,
            result.iterations
        );
        assert!(
            (result.x[0] - 1.0).abs() < 1e-6,
            "x[0] = {}, expected 1.0",
            result.x[0]
        );
        assert!(
            (result.x[1] - 1.0).abs() < 1e-6,
            "x[1] = {}, expected 1.0",
            result.x[1]
        );
    }

    /// 4-D chained Rosenbrock function, minimum at (1, 1, 1, 1).
    struct Rosenbrock4D;

    impl Objective<f64> for Rosenbrock4D {
        fn dim(&self) -> usize {
            4
        }

        fn eval_grad(&mut self, x: &[f64]) -> (f64, Vec<f64>) {
            let mut f = 0.0;
            let mut g = vec![0.0; 4];
            for i in 0..3 {
                let a = 1.0 - x[i];
                let b = x[i + 1] - x[i] * x[i];
                f += a * a + 100.0 * b * b;
                g[i] += -2.0 * a - 400.0 * x[i] * b;
                g[i + 1] += 200.0 * b;
            }
            (f, g)
        }

        fn hvp(&mut self, x: &[f64], v: &[f64]) -> (Vec<f64>, Vec<f64>) {
            let n = 4;
            let mut hv = vec![0.0; n];
            let mut g = vec![0.0; n];

            // Each coupled pair (i, i + 1) contributes a 2x2 block to the Hessian.
            for i in 0..3 {
                let a = 1.0 - x[i];
                let b = x[i + 1] - x[i] * x[i];

                g[i] += -2.0 * a - 400.0 * x[i] * b;
                g[i + 1] += 200.0 * b;

                let h_ii = 2.0 - 400.0 * (x[i + 1] - 3.0 * x[i] * x[i]);
                let h_ij = -400.0 * x[i];
                let h_jj = 200.0;

                hv[i] += h_ii * v[i] + h_ij * v[i + 1];
                hv[i + 1] += h_ij * v[i] + h_jj * v[i + 1];
            }

            (g, hv)
        }
    }

    #[test]
    fn trust_region_rosenbrock_4d() {
        let mut obj = Rosenbrock4D;
        let config = TrustRegionConfig {
            convergence: ConvergenceParams {
                max_iter: 500,
                ..Default::default()
            },
            ..Default::default()
        };
        let result = trust_region(&mut obj, &[0.0, 0.0, 0.0, 0.0], &config);

        assert_eq!(
            result.termination,
            TerminationReason::GradientNorm,
            "terminated with {:?} after {} iterations, grad_norm={}",
            result.termination,
            result.iterations,
            result.gradient_norm
        );
        for i in 0..4 {
            assert!(
                (result.x[i] - 1.0).abs() < 1e-5,
                "x[{}] = {}, expected 1.0",
                i,
                result.x[i]
            );
        }
    }

    /// 1-D convex quadratic `(x − 3)²` with exact gradient and Hessian-vector product.
    struct ShiftedQuadratic;

    impl Objective<f64> for ShiftedQuadratic {
        fn dim(&self) -> usize {
            1
        }

        fn eval_grad(&mut self, x: &[f64]) -> (f64, Vec<f64>) {
            let r = x[0] - 3.0;
            (r * r, vec![2.0 * r])
        }

        fn hvp(&mut self, x: &[f64], v: &[f64]) -> (Vec<f64>, Vec<f64>) {
            (vec![2.0 * (x[0] - 3.0)], vec![2.0 * v[0]])
        }
    }

    /// The minimiser lies three units away while the initial radius is 1, so the
    /// solver must take boundary steps before reaching it.
    #[test]
    fn trust_region_quadratic_beyond_initial_radius() {
        let mut obj = ShiftedQuadratic;
        let config: TrustRegionConfig<f64> = TrustRegionConfig {
            convergence: ConvergenceParams {
                max_iter: 50,
                ..Default::default()
            },
            ..Default::default()
        };
        let result = trust_region(&mut obj, &[0.0], &config);

        assert_eq!(
            result.termination,
            TerminationReason::GradientNorm,
            "terminated with {:?} after {} iterations",
            result.termination,
            result.iterations
        );
        assert!(
            (result.x[0] - 3.0).abs() < 1e-8,
            "x[0] = {}, expected 3.0",
            result.x[0]
        );
    }

    /// Starting exactly at the minimiser must stop before any iteration, after the
    /// single initial evaluation.
    #[test]
    fn trust_region_stops_at_stationary_start() {
        let mut obj = Rosenbrock;
        let config: TrustRegionConfig<f64> = TrustRegionConfig {
            convergence: ConvergenceParams {
                max_iter: 10,
                ..Default::default()
            },
            ..Default::default()
        };
        let result = trust_region(&mut obj, &[1.0, 1.0], &config);

        assert_eq!(result.termination, TerminationReason::GradientNorm);
        assert_eq!(result.iterations, 0);
        assert_eq!(result.func_evals, 1);
        assert_eq!(result.x, vec![1.0, 1.0]);
    }

    /// A zero iteration budget or a non-positive radius is rejected without
    /// evaluating the objective.
    #[test]
    fn trust_region_rejects_invalid_config() {
        let zero_iters: TrustRegionConfig<f64> = TrustRegionConfig {
            convergence: ConvergenceParams {
                max_iter: 0,
                ..Default::default()
            },
            ..Default::default()
        };
        let negative_radius: TrustRegionConfig<f64> = TrustRegionConfig {
            initial_radius: -1.0,
            convergence: ConvergenceParams {
                max_iter: 10,
                ..Default::default()
            },
            ..Default::default()
        };

        for config in &[zero_iters, negative_radius] {
            let mut obj = Rosenbrock;
            let result = trust_region(&mut obj, &[0.0, 0.0], config);
            assert_eq!(result.termination, TerminationReason::NumericalError);
            assert_eq!(result.iterations, 0);
            assert_eq!(result.func_evals, 0);
            assert_eq!(result.x, vec![0.0, 0.0]);
        }
    }
}