optimization_solvers/steepest_descent/gradient_descent.rs

use super::*;

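/// Plain gradient-descent (steepest-descent) solver state: the tolerance on the
/// gradient's infinity norm, the current iterate `x`, and the iteration counter `k`.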
#[derive(derive_getters::Getters)]
pub struct GradientDescent {
    pub grad_tol: Floating,
    pub x: DVector<Floating>,
    pub k: usize,
}

impl GradientDescent {
    pub fn new(grad_tol: Floating, x0: DVector<Floating>) -> Self {
        Self {
            grad_tol,
            x: x0,
            k: 0,
        }
    }
}

impl ComputeDirection for GradientDescent {
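    /// Steepest descent: the search direction is simply the negative gradient.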
    fn compute_direction(
        &mut self,
        eval: &FuncEvalMultivariate,
    ) -> Result<DVector<Floating>, SolverError> {
        Ok(-eval.g())
    }
}

impl LineSearchSolver for GradientDescent {
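    // Accessors for the current iterate and iteration counter required by the trait.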
    fn xk(&self) -> &DVector<Floating> {
        &self.x
    }
    fn xk_mut(&mut self) -> &mut DVector<Floating> {
        &mut self.x
    }
    fn k(&self) -> &usize {
        &self.k
    }
    fn k_mut(&mut self) -> &mut usize {
        &mut self.k
    }
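    /// Convergence test: the infinity norm of the gradient must fall below `grad_tol`.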
    fn has_converged(&self, eval: &FuncEvalMultivariate) -> bool {
        let grad = eval.g();
        grad.iter()
            .fold(Floating::NEG_INFINITY, |acc, x| x.abs().max(acc))
            < self.grad_tol
    }

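    /// One iteration: ask the line search for a step length along `direction`
    /// and move the iterate to `x_k + step * direction`.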
    fn update_next_iterate<LS: LineSearch>(
        &mut self,
        line_search: &mut LS,
        eval_x_k: &FuncEvalMultivariate,
        oracle: &mut impl FnMut(&DVector<Floating>) -> FuncEvalMultivariate,
        direction: &DVector<Floating>,
        max_iter_line_search: usize,
    ) -> Result<(), SolverError> {
        let step = line_search.compute_step_len(
            self.xk(),
            eval_x_k,
            direction,
            oracle,
            max_iter_line_search,
        );

        debug!(target: "gradient_descent", "ITERATE: {} + {} * {} = {}", self.xk(), step, direction, self.xk() + step * direction);

        let next_iterate = self.xk() + step * direction;

        *self.xk_mut() = next_iterate;

        Ok(())
    }
}

#[cfg(test)]
mod gradient_descent_test {
    use super::*;
    #[test]
    pub fn grad_descent_more_thuente() {
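        // Minimize the ill-conditioned convex quadratic
        // f(x) = 0.5 * (x_0^2 + gamma * x_1^2), whose unique minimizer is the
        // origin, using the More-Thuente line search.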
        std::env::set_var("RUST_LOG", "info");

        let _ = Tracer::default()
            .with_stdout_layer(Some(LogFormat::Normal))
            .build();
        let gamma = 90.0;
        let f_and_g = |x: &DVector<Floating>| -> FuncEvalMultivariate {
            let f = 0.5 * (x[0].powi(2) + gamma * x[1].powi(2));
            let g = DVector::from(vec![x[0], gamma * x[1]]);
            (f, g).into()
        };
        let mut ls = MoreThuente::default();

        let tol = 1e-12;
        let x_0 = DVector::from(vec![180.0, 152.0]);
        let mut gd = GradientDescent::new(tol, x_0);

        let max_iter_solver = 1000;
        let max_iter_line_search = 100;

        gd.minimize(
            &mut ls,
            f_and_g,
            max_iter_solver,
            max_iter_line_search,
            None,
        )
        .unwrap();

        println!("Iterate: {:?}", gd.xk());

        let eval = f_and_g(gd.xk());
        println!("Function eval: {:?}", eval);
        println!("Gradient norm: {:?}", eval.g().norm());
        println!("tol: {:?}", tol);

        let convergence = gd.has_converged(&eval);
        println!("Convergence: {:?}", convergence);

        assert!((eval.f() - 0.0).abs() < 1e-6);
    }

    #[test]
    pub fn grad_desc_backtracking() {
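        // Same quadratic as above, this time minimized with a backtracking line
        // search; `alpha` and `beta` are presumably the sufficient-decrease
        // constant and the step shrink factor.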
        std::env::set_var("RUST_LOG", "debug");

        let _ = Tracer::default()
            .with_stdout_layer(Some(LogFormat::Normal))
            .build();
        let gamma = 90.0;
        let f_and_g = |x: &DVector<Floating>| -> FuncEvalMultivariate {
            let f = 0.5 * (x[0].powi(2) + gamma * x[1].powi(2));
            let g = DVector::from(vec![x[0], gamma * x[1]]);
            (f, g).into()
        };
        let alpha = 1e-4;
        let beta = 0.5;
        let mut ls = BackTracking::new(alpha, beta);

        let tol = 1e-12;
        let x_0 = DVector::from(vec![180.0, 152.0]);
        let mut gd = GradientDescent::new(tol, x_0);

        let max_iter_solver = 1000;
        let max_iter_line_search = 100;

        gd.minimize(
            &mut ls,
            f_and_g,
            max_iter_solver,
            max_iter_line_search,
            None,
        )
        .unwrap();

        println!("Iterate: {:?}", gd.xk());

        let eval = f_and_g(gd.xk());
        println!("Function eval: {:?}", eval);
        println!("Gradient norm: {:?}", eval.g().norm());
        println!("tol: {:?}", tol);

        let convergence = gd.has_converged(&eval);
        println!("Convergence: {:?}", convergence);

        assert!((eval.f() - 0.0).abs() < 1e-6);
    }
}