// opensrdk_optimization/line_search.rs

1use opensrdk_linear_algebra::*;
2
/// Configuration of a line search that looks for a step size satisfying the Armijo
/// (sufficient decrease) condition and the curvature condition.
///
/// When in doubt about which values to choose, [`LineSearch::default()`] is sufficient.
#[derive(Debug, Clone)]
pub struct LineSearch {
    /// Step size tried first.
    initial_step_size: f64,
    /// Relative amount by which the step size is shrunk or grown between trials.
    step_update_rate: f64,
    /// Coefficient of the Armijo (sufficient decrease) condition.
    armijo_param: f64,
    /// Coefficient of the curvature condition.
    curvature_param: f64,
}
12
13impl Default for LineSearch {
14    fn default() -> Self {
15        Self {
16            initial_step_size: 1.0,
17            step_update_rate: 0.1,
18            armijo_param: 0.1,
19            curvature_param: 0.9,
20        }
21    }
22}
23
24impl LineSearch {
25    pub fn with_initial_step_size(mut self, initial_step_size: f64) -> Self {
26        self.initial_step_size = initial_step_size;
27
28        self
29    }
30
31    pub fn with_step_update_rate(mut self, step_update_rate: f64) -> Self {
32        self.step_update_rate = step_update_rate;
33
34        self
35    }
36
37    pub fn with_armijo_param(mut self, armijo_param: f64) -> Self {
38        self.armijo_param = armijo_param;
39
40        self
41    }
42
43    pub fn with_curvature_param(mut self, curvature_param: f64) -> Self {
44        self.curvature_param = curvature_param;
45
46        self
47    }
48
49    /// - `fx_gfx`: function to optimize. It must also return the gradients of each inputs.
50    /// - `x`: current input value
51    /// - `direction`: direction to search
52    pub fn search(
53        &self,
54        fx_gfx: &dyn Fn(&[f64]) -> (f64, Vec<f64>),
55        x: &[f64],
56        direction: &[f64],
57    ) -> f64 {
58        let mut step_size = self.initial_step_size;
59        let x = x.to_vec().col_mat();
60        let d = direction.to_vec().col_mat();
61
62        loop {
63            let xad = x.clone() + step_size * d.clone();
64
65            let (fx, dfx_dx) = fx_gfx(x.slice());
66            let (fxad, dfxad_dx) = fx_gfx(xad.slice());
67
68            let dfx_dx_d = (dfx_dx.row_mat() * &d)[0][0];
69
70            // Armijo condition
71            let armijo_left = fxad;
72            let armijo_right = fx + self.armijo_param * step_size * dfx_dx_d;
73
74            if armijo_left > armijo_right {
75                step_size *= 1.0 - self.step_update_rate;
76
77                continue;
78            }
79
80            // Curvature condition
81            let curvature_left = self.curvature_param * dfx_dx_d;
82            let curvature_right = (dfxad_dx.row_mat() * &d)[0][0];
83
84            if curvature_left > curvature_right {
85                step_size *= 1.0 + self.step_update_rate;
86
87                continue;
88            }
89
90            break;
91        }
92
93        step_size
94    }
95}