1use crate::core::{
9 ArgminFloat, Error, Executor, Gradient, Hessian, IterState, LineSearch, Operator,
10 OptimizationResult, Problem, Solver, State, TerminationReason, TerminationStatus, KV,
11};
12use crate::solver::conjugategradient::ConjugateGradient;
13use argmin_math::{
14 ArgminConj, ArgminDot, ArgminL2Norm, ArgminMul, ArgminScaledAdd, ArgminSub, ArgminZeroLike,
15};
16#[cfg(feature = "serde1")]
17use serde::{Deserialize, Serialize};
18
/// # Newton-Conjugate-Gradient (Newton-CG) method
///
/// Computes an (approximate) Newton direction by solving the Newton system
/// with a truncated conjugate-gradient inner loop, then performs a line
/// search along that direction.
#[derive(Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct NewtonCG<L, F> {
    // Line search performed along the computed search direction.
    linesearch: L,
    // Inner CG aborts when the curvature `p^T H p` is <= this threshold.
    curvature_threshold: F,
    // Convergence tolerance on the absolute change in cost between iterations.
    tol: F,
}
42
43impl<L, F> NewtonCG<L, F>
44where
45 F: ArgminFloat,
46{
47 pub fn new(linesearch: L) -> Self {
57 NewtonCG {
58 linesearch,
59 curvature_threshold: float!(0.0),
60 tol: F::epsilon(),
61 }
62 }
63
64 #[must_use]
76 pub fn with_curvature_threshold(mut self, threshold: F) -> Self {
77 self.curvature_threshold = threshold;
78 self
79 }
80
81 pub fn with_tolerance(mut self, tol: F) -> Result<Self, Error> {
97 if tol <= float!(0.0) {
98 return Err(argmin_error!(
99 InvalidParameter,
100 "`NewtonCG`: tol must be > 0."
101 ));
102 }
103 self.tol = tol;
104 Ok(self)
105 }
106}
107
108impl<O, L, P, G, H, F> Solver<O, IterState<P, G, (), H, (), F>> for NewtonCG<L, F>
109where
110 O: Gradient<Param = P, Gradient = G> + Hessian<Param = P, Hessian = H>,
111 P: Clone
112 + ArgminSub<P, P>
113 + ArgminDot<P, F>
114 + ArgminScaledAdd<P, F, P>
115 + ArgminMul<F, P>
116 + ArgminConj
117 + ArgminZeroLike,
118 G: ArgminL2Norm<F> + ArgminMul<F, P>,
119 H: Clone + ArgminDot<P, P>,
120 L: Clone + LineSearch<P, F> + Solver<O, IterState<P, G, (), (), (), F>>,
121 F: ArgminFloat + ArgminL2Norm<F>,
122{
123 fn name(&self) -> &str {
124 "Newton-CG"
125 }
126
127 fn next_iter(
128 &mut self,
129 problem: &mut Problem<O>,
130 mut state: IterState<P, G, (), H, (), F>,
131 ) -> Result<(IterState<P, G, (), H, (), F>, Option<KV>), Error> {
132 let param = state.take_param().ok_or_else(argmin_error_closure!(
133 NotInitialized,
134 concat!(
135 "`NewtonCG` requires an initial parameter vector. ",
136 "Please provide an initial guess via `Executor`s `configure` method."
137 )
138 ))?;
139 let grad = state
140 .take_gradient()
141 .map(Result::Ok)
142 .unwrap_or_else(|| problem.gradient(¶m))?;
143 let hessian = state
144 .take_hessian()
145 .map(Result::Ok)
146 .unwrap_or_else(|| problem.hessian(¶m))?;
147
148 let mut cg_problem = Problem::new(CGSubProblem::new(&hessian));
150
151 let mut x_p = param.zero_like();
152 let mut x = param.zero_like();
153 let mut cg = ConjugateGradient::new(grad.mul(&(float!(-1.0))));
154
155 let (mut cg_state, _): (IterState<_, _, _, _, _, _>, _) =
156 cg.init(&mut cg_problem, IterState::new().param(x_p.clone()))?;
157
158 let grad_norm_factor = float!(0.5).min(grad.l2_norm().sqrt()) * grad.l2_norm();
159
160 for iter in 0.. {
161 (cg_state, _) = cg.next_iter(&mut cg_problem, cg_state)?;
162
163 let cost = cg_state.get_cost();
164
165 x = cg_state.take_param().unwrap();
166 let p = cg.get_prev_p()?;
167
168 let curvature = p.dot(&hessian.dot(p));
169 if curvature <= self.curvature_threshold {
170 if iter == 0 {
171 x = grad.mul(&(float!(-1.0)));
172 } else {
173 x = x_p;
174 }
175 break;
176 }
177
178 if cost <= grad_norm_factor {
179 break;
180 }
181
182 cg_state = cg_state.param(x.clone()).cost(cost);
183 x_p = x.clone();
184 }
185
186 self.linesearch.search_direction(x);
189
190 let line_cost = state.get_cost();
191
192 let OptimizationResult {
194 problem: line_problem,
195 state: mut linesearch_state,
196 ..
197 } = Executor::new(problem.take_problem().unwrap(), self.linesearch.clone())
198 .configure(|state| state.param(param).gradient(grad).cost(line_cost))
199 .ctrlc(false)
200 .run()?;
201
202 problem.consume_problem(line_problem);
203
204 Ok((
205 state
206 .param(linesearch_state.take_param().unwrap())
207 .cost(linesearch_state.get_cost()),
208 None,
209 ))
210 }
211
212 fn terminate(&mut self, state: &IterState<P, G, (), H, (), F>) -> TerminationStatus {
213 if (state.get_cost() - state.get_prev_cost()).abs() < self.tol {
214 TerminationStatus::Terminated(TerminationReason::SolverConverged)
215 } else {
216 TerminationStatus::NotTerminated
217 }
218 }
219}
220
/// Helper problem that wraps a borrowed Hessian so the inner
/// `ConjugateGradient` solver can apply it as a linear operator.
#[derive(Clone)]
struct CGSubProblem<'a, P, H> {
    // Borrowed Hessian of the outer problem.
    hessian: &'a H,
    // Marks the parameter type `P` as logically used without storing one.
    phantom: std::marker::PhantomData<P>,
}
226
227impl<'a, P, H> CGSubProblem<'a, P, H> {
228 fn new(hessian: &'a H) -> Self {
230 CGSubProblem {
231 hessian,
232 phantom: std::marker::PhantomData,
233 }
234 }
235}
236
impl<P, H> Operator for CGSubProblem<'_, P, H>
where
    H: ArgminDot<P, P>,
{
    type Param = P;
    type Output = P;

    /// Applies the borrowed Hessian to `p`, i.e. computes the
    /// matrix-vector product `H · p`.
    fn apply(&self, p: &P) -> Result<P, Error> {
        Ok(self.hessian.dot(p))
    }
}
248
#[cfg(test)]
#[allow(clippy::let_unit_value)]
mod tests {
    use super::*;
    use crate::core::{test_utils::TestProblem, ArgminError};
    use crate::solver::linesearch::MoreThuenteLineSearch;

    // Check that `NewtonCG` implements the standard traits expected of a
    // solver (via the project's `test_trait_impl!` macro).
    test_trait_impl!(
        newton_cg,
        NewtonCG<MoreThuenteLineSearch<Vec<f64>, Vec<f64>, f64>, f64>
    );

    // Same check for the internal CG sub-problem wrapper.
    test_trait_impl!(cg_subproblem, CGSubProblem<Vec<f64>, Vec<Vec<f64>>>);

    // A valid tolerance passed to `with_tolerance` is stored in `tol`.
    #[test]
    fn test_tolerance() {
        let tol1: f64 = 1e-4;

        let linesearch: MoreThuenteLineSearch<Vec<f64>, Vec<f64>, f64> =
            MoreThuenteLineSearch::new();

        let NewtonCG { tol: t, .. }: NewtonCG<MoreThuenteLineSearch<Vec<f64>, Vec<f64>, f64>, f64> =
            NewtonCG::new(linesearch).with_tolerance(tol1).unwrap();

        assert!((t - tol1).abs() < f64::EPSILON);
    }

    // `new` stores the line search and applies the documented defaults
    // (curvature threshold 0, tolerance f64::EPSILON).
    #[test]
    fn test_new() {
        #[derive(Eq, PartialEq, Debug, Copy, Clone)]
        struct LineSearch {}
        let ls = LineSearch {};
        let ncg: NewtonCG<_, f64> = NewtonCG::new(ls);
        let NewtonCG {
            linesearch,
            curvature_threshold,
            tol,
        } = ncg;
        assert_eq!(linesearch, ls);
        // Byte-wise comparison avoids clippy lints about float equality.
        assert_eq!(curvature_threshold.to_ne_bytes(), 0.0f64.to_ne_bytes());
        assert_eq!(tol.to_ne_bytes(), f64::EPSILON.to_ne_bytes());
    }

    // `with_curvature_threshold` updates only the curvature threshold.
    #[test]
    fn test_with_curvature_threshold() {
        #[derive(Eq, PartialEq, Debug, Copy, Clone)]
        struct LineSearch {}
        let ls = LineSearch {};
        let ncg: NewtonCG<_, f64> = NewtonCG::new(ls).with_curvature_threshold(1e-6);
        let NewtonCG {
            linesearch,
            curvature_threshold,
            tol,
        } = ncg;
        assert_eq!(linesearch, ls);
        assert_eq!(curvature_threshold.to_ne_bytes(), 1e-6f64.to_ne_bytes());
        assert_eq!(tol.to_ne_bytes(), f64::EPSILON.to_ne_bytes());
    }

    // `with_tolerance` accepts strictly positive values and rejects
    // zero/negative values with an `InvalidParameter` error.
    #[test]
    fn test_with_tolerance() {
        let ls = ();
        for tolerance in [f64::EPSILON, 1.0, 10.0, 100.0] {
            let ncg: NewtonCG<_, f64> = NewtonCG::new(ls).with_tolerance(tolerance).unwrap();
            assert_eq!(ncg.tol.to_ne_bytes(), tolerance.to_ne_bytes());
        }

        for tolerance in [-f64::EPSILON, 0.0, -1.0] {
            let res = NewtonCG::new(ls).with_tolerance(tolerance);
            assert_error!(
                res,
                ArgminError,
                "Invalid parameter: \"`NewtonCG`: tol must be > 0.\""
            );
        }
    }

    // `next_iter` without an initial parameter vector must fail with the
    // documented `NotInitialized` error message.
    #[test]
    fn test_next_iter_param_not_initialized() {
        use crate::solver::linesearch::{condition::ArmijoCondition, BacktrackingLineSearch};
        let ls = BacktrackingLineSearch::new(ArmijoCondition::new(0.9f64).unwrap());
        let mut ncg: NewtonCG<_, f64> = NewtonCG::new(ls);
        let res = ncg.next_iter(&mut Problem::new(TestProblem::new()), IterState::new());
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Not initialized: \"`NewtonCG` requires an initial parameter vector. ",
                "Please provide an initial guess via `Executor`s `configure` method.\""
            )
        );
    }
}