argmin/solver/newton/newton_cg.rs

1// Copyright 2018-2024 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use crate::core::{
9    ArgminFloat, Error, Executor, Gradient, Hessian, IterState, LineSearch, Operator,
10    OptimizationResult, Problem, Solver, State, TerminationReason, TerminationStatus, KV,
11};
12use crate::solver::conjugategradient::ConjugateGradient;
13use argmin_math::{
14    ArgminConj, ArgminDot, ArgminL2Norm, ArgminMul, ArgminScaledAdd, ArgminSub, ArgminZeroLike,
15};
16#[cfg(feature = "serde1")]
17use serde::{Deserialize, Serialize};
18
/// # Newton-Conjugate-Gradient (Newton-CG) method
///
/// The Newton-CG method (also called truncated Newton method) uses a modified CG to approximately
/// solve the Newton equations. After a search direction is found, a line search is performed.
///
/// ## Requirements on the optimization problem
///
/// The optimization problem is required to implement [`Gradient`] and [`Hessian`].
///
/// ## Reference
///
/// Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
/// Springer. ISBN 0-387-30303-0.
#[derive(Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct NewtonCG<L, F> {
    /// Line search used to determine the step length along the computed search direction
    linesearch: L,
    /// Threshold on the curvature `p^T H p` of a CG direction: if the curvature is at or below
    /// this value, the inner CG iteration is truncated (defaults to 0)
    curvature_threshold: F,
    /// Tolerance for the stopping criterion based on cost difference
    tol: F,
}
42
43impl<L, F> NewtonCG<L, F>
44where
45    F: ArgminFloat,
46{
47    /// Construct a new instance of [`NewtonCG`]
48    ///
49    /// # Example
50    ///
51    /// ```
52    /// # use argmin::solver::newton::NewtonCG;
53    /// # let linesearch = ();
54    /// let ncg: NewtonCG<_, f64> = NewtonCG::new(linesearch);
55    /// ```
56    pub fn new(linesearch: L) -> Self {
57        NewtonCG {
58            linesearch,
59            curvature_threshold: float!(0.0),
60            tol: F::epsilon(),
61        }
62    }
63
64    /// Set curvature threshold
65    ///
66    /// Defaults to 0.
67    ///
68    /// # Example
69    ///
70    /// ```
71    /// # use argmin::solver::newton::NewtonCG;
72    /// # let linesearch = ();
73    /// let ncg: NewtonCG<_, f64> = NewtonCG::new(linesearch).with_curvature_threshold(1e-6);
74    /// ```
75    #[must_use]
76    pub fn with_curvature_threshold(mut self, threshold: F) -> Self {
77        self.curvature_threshold = threshold;
78        self
79    }
80
81    /// Set tolerance for the stopping criterion based on cost difference
82    ///
83    /// Must be larger than 0 and defaults to EPSILON.
84    ///
85    /// # Example
86    ///
87    /// ```
88    /// # use argmin::solver::newton::NewtonCG;
89    /// # use argmin::core::Error;
90    /// # fn main() -> Result<(), Error> {
91    /// # let linesearch = ();
92    /// let ncg: NewtonCG<_, f64> = NewtonCG::new(linesearch).with_tolerance(1e-6)?;
93    /// # Ok(())
94    /// # }
95    /// ```
96    pub fn with_tolerance(mut self, tol: F) -> Result<Self, Error> {
97        if tol <= float!(0.0) {
98            return Err(argmin_error!(
99                InvalidParameter,
100                "`NewtonCG`: tol must be > 0."
101            ));
102        }
103        self.tol = tol;
104        Ok(self)
105    }
106}
107
impl<O, L, P, G, H, F> Solver<O, IterState<P, G, (), H, (), F>> for NewtonCG<L, F>
where
    O: Gradient<Param = P, Gradient = G> + Hessian<Param = P, Hessian = H>,
    P: Clone
        + ArgminSub<P, P>
        + ArgminDot<P, F>
        + ArgminScaledAdd<P, F, P>
        + ArgminMul<F, P>
        + ArgminConj
        + ArgminZeroLike,
    G: ArgminL2Norm<F> + ArgminMul<F, P>,
    H: Clone + ArgminDot<P, P>,
    L: Clone + LineSearch<P, F> + Solver<O, IterState<P, G, (), (), (), F>>,
    F: ArgminFloat + ArgminL2Norm<F>,
{
    fn name(&self) -> &str {
        "Newton-CG"
    }

    /// Performs one Newton-CG iteration: approximately solve `H * x = -g` with a truncated
    /// CG iteration, then run a line search along the resulting direction `x`.
    fn next_iter(
        &mut self,
        problem: &mut Problem<O>,
        mut state: IterState<P, G, (), H, (), F>,
    ) -> Result<(IterState<P, G, (), H, (), F>, Option<KV>), Error> {
        // An initial parameter vector is mandatory; fail early with a descriptive error.
        let param = state.take_param().ok_or_else(argmin_error_closure!(
            NotInitialized,
            concat!(
                "`NewtonCG` requires an initial parameter vector. ",
                "Please provide an initial guess via `Executor`s `configure` method."
            )
        ))?;
        // Reuse gradient/Hessian already stored in the state if present; otherwise compute them.
        let grad = state
            .take_gradient()
            .map(Result::Ok)
            .unwrap_or_else(|| problem.gradient(&param))?;
        let hessian = state
            .take_hessian()
            .map(Result::Ok)
            .unwrap_or_else(|| problem.hessian(&param))?;

        // Solve CG subproblem: the operator is `p -> H * p`, the right-hand side is `-grad`.
        let mut cg_problem = Problem::new(CGSubProblem::new(&hessian));

        // `x` is the current CG iterate, `x_p` the previous one (kept as a fallback in case
        // a direction of insufficient curvature is encountered). Both start at zero.
        let mut x_p = param.zero_like();
        let mut x = param.zero_like();
        let mut cg = ConjugateGradient::new(grad.mul(&(float!(-1.0))));

        let (mut cg_state, _): (IterState<_, _, _, _, _, _>, _) =
            cg.init(&mut cg_problem, IterState::new().param(x_p.clone()))?;

        // Inner-loop tolerance min(0.5, sqrt(||g||)) * ||g||, the forcing sequence used for
        // inexact/truncated Newton methods (cf. Nocedal & Wright).
        let grad_norm_factor = float!(0.5).min(grad.l2_norm().sqrt()) * grad.l2_norm();

        for iter in 0.. {
            (cg_state, _) = cg.next_iter(&mut cg_problem, cg_state)?;

            // NOTE(review): `cost` is assumed to be the residual measure reported by
            // `ConjugateGradient` — confirm against its implementation.
            let cost = cg_state.get_cost();

            x = cg_state.take_param().unwrap();
            let p = cg.get_prev_p()?;

            // Truncate CG on a direction of non-positive (or below-threshold) curvature:
            // fall back to steepest descent on the very first iteration, otherwise keep the
            // last iterate produced before the bad direction appeared.
            let curvature = p.dot(&hessian.dot(p));
            if curvature <= self.curvature_threshold {
                if iter == 0 {
                    x = grad.mul(&(float!(-1.0)));
                } else {
                    x = x_p;
                }
                break;
            }

            // Stop the inner CG loop once the tolerance derived from the gradient norm is met.
            if cost <= grad_norm_factor {
                break;
            }

            // Put the iterate back into the CG state (it was taken above) and remember it.
            cg_state = cg_state.param(x.clone()).cost(cost);
            x_p = x.clone();
        }

        // perform line search along the (approximate) Newton direction
        // TODO: Should the algorithm stop when search direction is close to 0?
        self.linesearch.search_direction(x);

        let line_cost = state.get_cost();

        // Run the line search as a nested solver on the original problem, starting from the
        // current parameters, gradient, and cost.
        let OptimizationResult {
            problem: line_problem,
            state: mut linesearch_state,
            ..
        } = Executor::new(problem.take_problem().unwrap(), self.linesearch.clone())
            .configure(|state| state.param(param).gradient(grad).cost(line_cost))
            .ctrlc(false)
            .run()?;

        // Hand the problem (and its accumulated evaluation counts) back to the outer executor.
        problem.consume_problem(line_problem);

        Ok((
            state
                .param(linesearch_state.take_param().unwrap())
                .cost(linesearch_state.get_cost()),
            None,
        ))
    }

    /// Terminates when the absolute cost difference between two iterations drops below `tol`.
    fn terminate(&mut self, state: &IterState<P, G, (), H, (), F>) -> TerminationStatus {
        if (state.get_cost() - state.get_prev_cost()).abs() < self.tol {
            TerminationStatus::Terminated(TerminationReason::SolverConverged)
        } else {
            TerminationStatus::NotTerminated
        }
    }
}
220
/// Wraps a borrowed Hessian so the CG solver can treat the Hessian-vector product `H * p`
/// as an [`Operator`] application.
#[derive(Clone)]
struct CGSubProblem<'a, P, H> {
    /// Borrowed Hessian of the outer problem
    hessian: &'a H,
    /// Ties the parameter type `P` to the struct without storing any `P` data
    phantom: std::marker::PhantomData<P>,
}
226
227impl<'a, P, H> CGSubProblem<'a, P, H> {
228    /// Constructor
229    fn new(hessian: &'a H) -> Self {
230        CGSubProblem {
231            hessian,
232            phantom: std::marker::PhantomData,
233        }
234    }
235}
236
impl<P, H> Operator for CGSubProblem<'_, P, H>
where
    H: ArgminDot<P, P>,
{
    type Param = P;
    type Output = P;

    /// Applies the operator: computes the Hessian-vector product `H * p`.
    fn apply(&self, p: &P) -> Result<P, Error> {
        Ok(self.hessian.dot(p))
    }
}
248
#[cfg(test)]
#[allow(clippy::let_unit_value)]
mod tests {
    use super::*;
    use crate::core::{test_utils::TestProblem, ArgminError};
    use crate::solver::linesearch::MoreThuenteLineSearch;

    test_trait_impl!(
        newton_cg,
        NewtonCG<MoreThuenteLineSearch<Vec<f64>, Vec<f64>, f64>, f64>
    );

    test_trait_impl!(cg_subproblem, CGSubProblem<Vec<f64>, Vec<Vec<f64>>>);

    #[test]
    fn test_tolerance() {
        // A tolerance passed via `with_tolerance` must be stored unchanged.
        let expected: f64 = 1e-4;

        let linesearch: MoreThuenteLineSearch<Vec<f64>, Vec<f64>, f64> =
            MoreThuenteLineSearch::new();

        let NewtonCG { tol, .. }: NewtonCG<MoreThuenteLineSearch<Vec<f64>, Vec<f64>, f64>, f64> =
            NewtonCG::new(linesearch).with_tolerance(expected).unwrap();

        assert!((tol - expected).abs() < f64::EPSILON);
    }

    #[test]
    fn test_new() {
        // `new` must store the line search and apply the documented defaults.
        #[derive(Eq, PartialEq, Debug, Copy, Clone)]
        struct LineSearch {}
        let ls = LineSearch {};
        let NewtonCG {
            linesearch,
            curvature_threshold,
            tol,
        }: NewtonCG<_, f64> = NewtonCG::new(ls);
        assert_eq!(linesearch, ls);
        // Bitwise comparison to avoid any float-comparison lint noise.
        assert_eq!(curvature_threshold.to_ne_bytes(), 0.0f64.to_ne_bytes());
        assert_eq!(tol.to_ne_bytes(), f64::EPSILON.to_ne_bytes());
    }

    #[test]
    fn test_with_curvature_threshold() {
        // `with_curvature_threshold` must only change the curvature threshold.
        #[derive(Eq, PartialEq, Debug, Copy, Clone)]
        struct LineSearch {}
        let ls = LineSearch {};
        let NewtonCG {
            linesearch,
            curvature_threshold,
            tol,
        }: NewtonCG<_, f64> = NewtonCG::new(ls).with_curvature_threshold(1e-6);
        assert_eq!(linesearch, ls);
        assert_eq!(curvature_threshold.to_ne_bytes(), 1e-6f64.to_ne_bytes());
        assert_eq!(tol.to_ne_bytes(), f64::EPSILON.to_ne_bytes());
    }

    #[test]
    fn test_with_tolerance() {
        let ls = ();

        // Any strictly positive tolerance is accepted and stored as-is.
        for tolerance in [f64::EPSILON, 1.0, 10.0, 100.0] {
            let solver: NewtonCG<_, f64> = NewtonCG::new(ls).with_tolerance(tolerance).unwrap();
            assert_eq!(solver.tol.to_ne_bytes(), tolerance.to_ne_bytes());
        }

        // Zero and negative tolerances are rejected with an `InvalidParameter` error.
        for tolerance in [-f64::EPSILON, 0.0, -1.0] {
            let res = NewtonCG::new(ls).with_tolerance(tolerance);
            assert_error!(
                res,
                ArgminError,
                "Invalid parameter: \"`NewtonCG`: tol must be > 0.\""
            );
        }
    }

    #[test]
    fn test_next_iter_param_not_initialized() {
        // Without an initial parameter vector, `next_iter` must fail with `NotInitialized`.
        use crate::solver::linesearch::{condition::ArmijoCondition, BacktrackingLineSearch};
        let ls = BacktrackingLineSearch::new(ArmijoCondition::new(0.9f64).unwrap());
        let mut solver: NewtonCG<_, f64> = NewtonCG::new(ls);
        let res = solver.next_iter(&mut Problem::new(TestProblem::new()), IterState::new());
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Not initialized: \"`NewtonCG` requires an initial parameter vector. ",
                "Please provide an initial guess via `Executor`s `configure` method.\""
            )
        );
    }

    // TODO: Test next_iter.
}