opensrdk_optimization/
line_search.rs1use opensrdk_linear_algebra::*;
2
/// Inexact line search looking for a step size that satisfies the (weak) Wolfe conditions.
///
/// Starting from `initial_step_size`, the trial step is multiplied by
/// `1 - step_update_rate` while the Armijo (sufficient decrease) condition fails and by
/// `1 + step_update_rate` while the curvature condition fails. The search stops as soon as
/// both conditions hold, or after `max_iter` trial steps have been evaluated.
#[derive(Clone, Debug)]
pub struct LineSearch {
    /// First trial step size; must be strictly positive.
    initial_step_size: f64,
    /// Relative shrink/growth applied to the step after a failed test; must lie in `(0, 1)`.
    step_update_rate: f64,
    /// Armijo constant `c1` in `f(x + a d) <= f(x) + c1 a ∇f(x)·d`.
    armijo_param: f64,
    /// Curvature constant `c2` in `∇f(x + a d)·d >= c2 ∇f(x)·d`.
    /// The Wolfe theory assumes `0 < c1 < c2 < 1`.
    curvature_param: f64,
    /// Maximum number of trial steps evaluated by [`LineSearch::search`].
    max_iter: usize,
}

impl Default for LineSearch {
    /// Initial step `1.0`, update rate `0.1`, `c1 = 0.1`, `c2 = 0.9`, at most 1000 trials.
    fn default() -> Self {
        Self {
            initial_step_size: 1.0,
            step_update_rate: 0.1,
            armijo_param: 0.1,
            curvature_param: 0.9,
            max_iter: 1000,
        }
    }
}

impl LineSearch {
    /// Sets the first trial step size.
    ///
    /// # Panics
    /// Panics if `initial_step_size` is not strictly positive (NaN included).
    pub fn with_initial_step_size(mut self, initial_step_size: f64) -> Self {
        assert!(
            initial_step_size > 0.0,
            "initial_step_size must be positive, got {}",
            initial_step_size
        );
        self.initial_step_size = initial_step_size;

        self
    }

    /// Sets the relative amount by which the step is shrunk or grown after a failed test.
    ///
    /// # Panics
    /// Panics unless `0 < step_update_rate < 1`; outside that range the shrink factor
    /// `1 - step_update_rate` would produce zero or negative steps.
    pub fn with_step_update_rate(mut self, step_update_rate: f64) -> Self {
        assert!(
            step_update_rate > 0.0 && step_update_rate < 1.0,
            "step_update_rate must lie in (0, 1), got {}",
            step_update_rate
        );
        self.step_update_rate = step_update_rate;

        self
    }

    /// Sets the Armijo (sufficient decrease) constant `c1`.
    pub fn with_armijo_param(mut self, armijo_param: f64) -> Self {
        self.armijo_param = armijo_param;

        self
    }

    /// Sets the curvature constant `c2`.
    pub fn with_curvature_param(mut self, curvature_param: f64) -> Self {
        self.curvature_param = curvature_param;

        self
    }

    /// Sets the maximum number of trial steps [`LineSearch::search`] evaluates.
    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;

        self
    }

    /// Searches for a step size `a` along `direction` from `x`.
    ///
    /// `fx_gfx` must return the objective value and its gradient at the given point.
    /// `direction` is expected to be a descent direction (`∇f(x)·d < 0`).
    ///
    /// The returned step satisfies both the Armijo condition
    /// `f(x + a d) <= f(x) + c1 a ∇f(x)·d` and the curvature condition
    /// `∇f(x + a d)·d >= c2 ∇f(x)·d`, unless `max_iter` trials are exhausted first (for
    /// example when the objective is unbounded below along `direction`). In that case the
    /// step size reached after the last update is returned without further checks.
    /// A NaN objective value at a trial point counts as an Armijo failure.
    ///
    /// # Panics
    /// Panics if `x`, `direction` and the gradients returned by `fx_gfx` do not all have
    /// the same length.
    pub fn search(
        &self,
        fx_gfx: &dyn Fn(&[f64]) -> (f64, Vec<f64>),
        x: &[f64],
        direction: &[f64],
    ) -> f64 {
        assert_eq!(
            x.len(),
            direction.len(),
            "x and direction must have the same length"
        );

        // f(x) and the directional derivative at x do not depend on the step size,
        // so they are evaluated once rather than on every trial.
        let (fx, dfx_dx) = fx_gfx(x);
        let dfx_dx_d = Self::dot(&dfx_dx, direction);

        let mut step_size = self.initial_step_size;
        // Trial point buffer, reused across iterations.
        let mut xad = vec![0.0; x.len()];

        for _ in 0..self.max_iter {
            for ((xad_i, &x_i), &d_i) in xad.iter_mut().zip(x).zip(direction) {
                *xad_i = x_i + step_size * d_i;
            }
            let (fxad, dfxad_dx) = fx_gfx(&xad);

            // `<=` is false when `fxad` is NaN, so a NaN trial is treated as a failure
            // and the step shrinks instead of being accepted.
            let armijo_right = fx + self.armijo_param * step_size * dfx_dx_d;
            let armijo_ok = fxad <= armijo_right;
            if !armijo_ok {
                step_size *= 1.0 - self.step_update_rate;

                continue;
            }

            let curvature_left = self.curvature_param * dfx_dx_d;
            let curvature_right = Self::dot(&dfxad_dx, direction);
            if curvature_left > curvature_right {
                step_size *= 1.0 + self.step_update_rate;

                continue;
            }

            return step_size;
        }

        step_size
    }

    /// Dot product of two slices of equal length.
    fn dot(a: &[f64], b: &[f64]) -> f64 {
        assert_eq!(a.len(), b.len(), "vector length mismatch in dot product");
        a.iter().zip(b).map(|(a_i, b_i)| a_i * b_i).sum()
    }
}