optimization_solvers/steepest_descent/
coordinate_descent.rs1use super::*;
2
3#[derive(derive_getters::Getters)]
8pub struct CoordinateDescent {
9 pub grad_tol: Floating,
10 pub x: DVector<Floating>,
11 pub k: usize,
12}
13
14impl CoordinateDescent {
15 pub fn new(grad_tol: Floating, x0: DVector<Floating>) -> Self {
16 Self {
17 grad_tol,
18 x: x0,
19 k: 0,
20 }
21 }
22}
23
24impl ComputeDirection for CoordinateDescent {
25 fn compute_direction(
26 &mut self,
27 eval: &FuncEvalMultivariate,
28 ) -> Result<DVector<Floating>, SolverError> {
29 let grad_k = eval.g();
31 let (position, max_value) =
32 grad_k
33 .iter()
34 .enumerate()
35 .fold((0, 0.0), |(idx, max), (i, g)| {
36 if g.abs() > max {
37 (i, g.abs())
38 } else {
39 (idx, max)
40 }
41 });
42 let mut direction_k = DVector::zeros(grad_k.len());
43 direction_k[position] = -max_value.signum();
44 Ok(direction_k)
45 }
46}
47
48impl LineSearchSolver for CoordinateDescent {
49 fn xk(&self) -> &DVector<Floating> {
50 &self.x
51 }
52 fn xk_mut(&mut self) -> &mut DVector<Floating> {
53 &mut self.x
54 }
55 fn k(&self) -> &usize {
56 &self.k
57 }
58 fn k_mut(&mut self) -> &mut usize {
59 &mut self.k
60 }
61 fn has_converged(&self, eval: &FuncEvalMultivariate) -> bool {
62 let grad = eval.g();
64 grad.iter()
66 .fold(Floating::NEG_INFINITY, |acc, x| x.abs().max(acc))
67 < self.grad_tol
68 }
69
70 fn update_next_iterate<LS: LineSearch>(
71 &mut self,
72 line_search: &mut LS,
73 eval_x_k: &FuncEvalMultivariate, oracle: &mut impl FnMut(&DVector<Floating>) -> FuncEvalMultivariate,
75 direction: &DVector<Floating>,
76 max_iter_line_search: usize,
77 ) -> Result<(), SolverError> {
78 let step = line_search.compute_step_len(
79 self.xk(),
80 eval_x_k,
81 direction,
82 oracle,
83 max_iter_line_search,
84 );
85
86 debug!(target: "coordinate_descent", "ITERATE: {} + {} * {} = {}", self.xk(), step, direction, self.xk() + step * direction);
87
88 let next_iterate = self.xk() + step * direction;
89
90 *self.xk_mut() = next_iterate;
91
92 Ok(())
93 }
94}
95
#[cfg(test)]
mod steepest_descent_l1_test {
    use super::*;
    use nalgebra::DVector;

    /// Weight of the second coordinate in the quadratic test objective.
    const GAMMA: Floating = 90.0;

    /// Evaluates `f(x) = 0.5 * (x0^2 + GAMMA * x1^2)` together with its gradient.
    fn quadratic(x: &DVector<Floating>) -> FuncEvalMultivariate {
        let f = 0.5 * (x[0].powi(2) + GAMMA * x[1].powi(2));
        let g = DVector::from(vec![x[0], GAMMA * x[1]]);
        (f, g).into()
    }

    /// Sets the log level and builds the stdout tracer shared by both tests.
    fn init_tracing() {
        std::env::set_var("RUST_LOG", "info");

        let _ = Tracer::default()
            .with_stdout_layer(Some(LogFormat::Normal))
            .build();
    }

    /// Solver placed at the starting point used by both tests.
    fn solver_at_start(tol: Floating) -> CoordinateDescent {
        CoordinateDescent::new(tol, DVector::from(vec![180.0, 152.0]))
    }

    /// Prints the final state of `solver` and asserts that the objective value
    /// at its iterate is within `1e-6` of zero.
    fn report_and_check(solver: &CoordinateDescent, tol: Floating) {
        println!("Iterate: {:?}", solver.xk());

        let eval = quadratic(solver.xk());
        println!("Function eval: {:?}", eval);
        println!("Gradient norm: {:?}", eval.g().norm());
        println!("tol: {:?}", tol);

        let convergence = solver.has_converged(&eval);
        println!("Convergence: {:?}", convergence);

        assert!((eval.f() - 0.0).abs() < 1e-6);
    }

    /// Minimizes the quadratic using the More-Thuente line search.
    #[test]
    pub fn coordinate_descent_morethuente() {
        init_tracing();

        let mut line_search = MoreThuente::default();
        let tol = 1e-12;
        let mut solver = solver_at_start(tol);

        let max_iter_solver = 1000;
        let max_iter_line_search = 100;

        solver
            .minimize(
                &mut line_search,
                quadratic,
                max_iter_solver,
                max_iter_line_search,
                None,
            )
            .unwrap();

        report_and_check(&solver, tol);
    }

    /// Minimizes the quadratic using the backtracking line search.
    #[test]
    pub fn coordinate_descent_backtracking() {
        init_tracing();

        let alpha = 1e-4;
        let beta = 0.5;
        let mut line_search = BackTracking::new(alpha, beta);
        let tol = 1e-12;
        let mut solver = solver_at_start(tol);

        let max_iter_solver = 1000;
        let max_iter_line_search = 100;

        solver
            .minimize(
                &mut line_search,
                quadratic,
                max_iter_solver,
                max_iter_line_search,
                None,
            )
            .unwrap();

        report_and_check(&solver, tol);
    }
}