argmin/solver/linesearch/
backtracking.rs

// Copyright 2018-2024 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

use crate::core::{
    ArgminFloat, CostFunction, Error, Gradient, IterState, LineSearch, Problem, Solver, State,
    TerminationReason, TerminationStatus, KV,
};
use crate::solver::linesearch::condition::*;
use argmin_math::ArgminScaledAdd;
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};

/// # Backtracking line search
///
/// The backtracking line search finds a step length from a given point along a given direction
/// such that the step length obeys the Armijo (sufficient decrease) condition.
///
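/// For a point `x`, a search direction `d`, and a parameter `c` in (0, 1), the Armijo condition
/// accepts a step length `alpha` if
///
/// ```text
/// f(x + alpha * d) <= f(x) + c * alpha * d^T grad f(x)
/// ```
///
/// i.e. if the cost decreases by at least a fixed fraction of the decrease predicted by the
/// local gradient. Starting from an initial step length, `alpha` is repeatedly multiplied by a
/// contraction factor `rho` in (0, 1) until the condition is satisfied.
///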
/// ## Requirements on the optimization problem
///
/// The optimization problem is required to implement [`CostFunction`] and [`Gradient`].
///
/// ## References
///
/// Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
/// Springer. ISBN 0-387-30303-0.
///
/// Wikipedia: <https://en.wikipedia.org/wiki/Backtracking_line_search>
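///
/// ## Example
///
/// A minimal end-to-end sketch, mirroring this module's tests (the quadratic problem `Sphere`
/// and all concrete numbers are illustrative only):
///
/// ```
/// # use argmin::core::{CostFunction, Error, Executor, Gradient, LineSearch};
/// # use argmin::solver::linesearch::BacktrackingLineSearch;
/// # use argmin::solver::linesearch::condition::ArmijoCondition;
/// // A simple quadratic test problem: f(x) = x_0^2 + x_1^2.
/// struct Sphere {}
///
/// impl CostFunction for Sphere {
///     type Param = Vec<f64>;
///     type Output = f64;
///
///     fn cost(&self, p: &Self::Param) -> Result<Self::Output, Error> {
///         Ok(p[0].powi(2) + p[1].powi(2))
///     }
/// }
///
/// impl Gradient for Sphere {
///     type Param = Vec<f64>;
///     type Gradient = Vec<f64>;
///
///     fn gradient(&self, p: &Self::Param) -> Result<Self::Gradient, Error> {
///         Ok(vec![2.0 * p[0], 2.0 * p[1]])
///     }
/// }
///
/// # fn main() -> Result<(), Error> {
/// let mut linesearch: BacktrackingLineSearch<Vec<f64>, Vec<f64>, _, f64> =
///     BacktrackingLineSearch::new(ArmijoCondition::new(0.0001f64)?);
///
/// // The search direction must be set before running the solver. The initial
/// // step length defaults to 1.0 and can be overridden.
/// linesearch.search_direction(vec![2.0, 0.0]);
/// linesearch.initial_step_length(0.8)?;
///
/// let res = Executor::new(Sphere {}, linesearch)
///     .configure(|state| state.param(vec![-1.0, 0.0]).max_iters(10))
///     .run()?;
/// # Ok(())
/// # }
/// ```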
#[derive(Clone, Eq, PartialEq, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct BacktrackingLineSearch<P, G, L, F> {
    /// Initial parameter vector
    init_param: Option<P>,
    /// Initial cost
    init_cost: F,
    /// Initial gradient
    init_grad: Option<G>,
    /// Search direction
    search_direction: Option<G>,
    /// Contraction factor rho
    rho: F,
    /// Stopping condition
    condition: L,
    /// Current step length alpha
    alpha: F,
}

impl<P, G, L, F> BacktrackingLineSearch<P, G, L, F>
where
    F: ArgminFloat,
{
    /// Construct a new instance of `BacktrackingLineSearch`
    ///
    /// # Example
    ///
    /// ```
    /// # use argmin::core::Error;
    /// # use argmin::solver::linesearch::BacktrackingLineSearch;
    /// # use argmin::solver::linesearch::condition::ArmijoCondition;
    /// # fn main() -> Result<(), Error> {
    /// let backtracking: BacktrackingLineSearch<Vec<f64>, Vec<f64>, _, f64> =
    ///     BacktrackingLineSearch::new(ArmijoCondition::new(0.0001f64)?);
    /// # Ok(())
    /// # }
    /// ```
    pub fn new(condition: L) -> Self {
        BacktrackingLineSearch {
            init_param: None,
            init_cost: F::infinity(),
            init_grad: None,
            search_direction: None,
            rho: float!(0.9),
            condition,
            alpha: float!(1.0),
        }
    }

    /// Set contraction factor rho
    ///
    /// This factor must be in (0, 1).
    ///
    /// # Example
    ///
    /// ```
    /// # use argmin::core::Error;
    /// # use argmin::solver::linesearch::BacktrackingLineSearch;
    /// # use argmin::solver::linesearch::condition::ArmijoCondition;
    /// # fn main() -> Result<(), Error> {
    /// # let backtracking: BacktrackingLineSearch<Vec<f64>, Vec<f64>, _, f64> =
    /// #     BacktrackingLineSearch::new(ArmijoCondition::new(0.0001f64)?);
    /// let backtracking = backtracking.rho(0.5)?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn rho(mut self, rho: F) -> Result<Self, Error> {
        if rho <= float!(0.0) || rho >= float!(1.0) {
            return Err(argmin_error!(
                InvalidParameter,
                "BacktrackingLineSearch: Contraction factor rho must be in (0, 1)."
            ));
        }
        self.rho = rho;
        Ok(self)
    }
}

impl<P, G, L, F> LineSearch<G, F> for BacktrackingLineSearch<P, G, L, F>
where
    F: ArgminFloat,
{
    /// Set search direction
    fn search_direction(&mut self, search_direction: G) {
        self.search_direction = Some(search_direction);
    }

    /// Set initial step length
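    ///
    /// # Example
    ///
    /// A short sketch of how the [`LineSearch`] setters are typically used before running the
    /// solver (the concrete numbers are illustrative):
    ///
    /// ```
    /// # use argmin::core::{Error, LineSearch};
    /// # use argmin::solver::linesearch::BacktrackingLineSearch;
    /// # use argmin::solver::linesearch::condition::ArmijoCondition;
    /// # fn main() -> Result<(), Error> {
    /// let mut linesearch: BacktrackingLineSearch<Vec<f64>, Vec<f64>, _, f64> =
    ///     BacktrackingLineSearch::new(ArmijoCondition::new(0.0001f64)?);
    /// linesearch.search_direction(vec![2.0, 0.0]);
    /// linesearch.initial_step_length(0.8)?;
    /// # Ok(())
    /// # }
    /// ```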
    fn initial_step_length(&mut self, alpha: F) -> Result<(), Error> {
        if alpha <= float!(0.0) {
            return Err(argmin_error!(
                InvalidParameter,
                "LineSearch: Initial alpha must be > 0."
            ));
        }
        self.alpha = alpha;
        Ok(())
    }
}

impl<P, G, L, F> BacktrackingLineSearch<P, G, L, F>
where
    P: ArgminScaledAdd<G, F, P>,
    L: LineSearchCondition<G, G, F>,
    IterState<P, G, (), (), (), F>: State<Float = F>,
    F: ArgminFloat,
{
    /// Perform a single backtracking step
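    ///
    /// Computes the trial point `init_param + alpha * search_direction`, evaluates the cost at
    /// this point, and also evaluates the gradient there if the line search condition requires
    /// it.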
    fn backtracking_step<O>(
        &self,
        problem: &mut Problem<O>,
        state: IterState<P, G, (), (), (), F>,
    ) -> Result<IterState<P, G, (), (), (), F>, Error>
    where
        O: CostFunction<Param = P, Output = F> + Gradient<Param = P, Gradient = G>,
        IterState<P, G, (), (), (), F>: State<Float = F>,
    {
        // Trial point: init_param + alpha * search_direction
        let new_param = self
            .init_param
            .as_ref()
            .ok_or_else(argmin_error_closure!(
                PotentialBug,
                "`BacktrackingLineSearch`: Initial parameter vector not set."
            ))?
            .scaled_add(
                &self.alpha,
                self.search_direction
                    .as_ref()
                    .ok_or_else(argmin_error_closure!(
                        PotentialBug,
                        "`BacktrackingLineSearch`: Search direction not set."
                    ))?,
            );

        let cur_cost = problem.cost(&new_param)?;

        // Conditions such as Wolfe need the gradient at the trial point; Armijo does not.
        let out = if self.condition.requires_current_gradient() {
            state
                .gradient(problem.gradient(&new_param)?)
                .param(new_param)
                .cost(cur_cost)
        } else {
            state.param(new_param).cost(cur_cost)
        };

        Ok(out)
    }
}

impl<O, P, G, L, F> Solver<O, IterState<P, G, (), (), (), F>> for BacktrackingLineSearch<P, G, L, F>
where
    P: Clone + ArgminScaledAdd<G, F, P>,
    G: ArgminScaledAdd<G, F, G>,
    O: CostFunction<Param = P, Output = F> + Gradient<Param = P, Gradient = G>,
    L: LineSearchCondition<G, G, F>,
    F: ArgminFloat,
{
    const NAME: &'static str = "Backtracking line search";

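    /// Initializes the line search: requires a search direction and an initial parameter
    /// vector, reuses the cost and gradient stored in `state` if present (computing them
    /// otherwise), and performs the first backtracking step.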
    fn init(
        &mut self,
        problem: &mut Problem<O>,
        mut state: IterState<P, G, (), (), (), F>,
    ) -> Result<(IterState<P, G, (), (), (), F>, Option<KV>), Error> {
        if self.search_direction.is_none() {
            return Err(argmin_error!(
                NotInitialized,
                "BacktrackingLineSearch: search_direction must be set."
            ));
        }

        let init_param = state.take_param().ok_or_else(argmin_error_closure!(
            NotInitialized,
            concat!(
                "`BacktrackingLineSearch` requires an initial parameter vector. ",
                "Please provide an initial guess via `Executor`s `configure` method."
            )
        ))?;

        let cost = state.get_cost();

        // Only evaluate the cost if it was not already provided via the state.
        self.init_cost = if cost.is_infinite() {
            problem.cost(&init_param)?
        } else {
            cost
        };

        // Likewise, reuse the gradient from the state if present.
        let init_grad = state
            .take_gradient()
            .map(Result::Ok)
            .unwrap_or_else(|| problem.gradient(&init_param))?;

        self.init_param = Some(init_param);
        self.init_grad = Some(init_grad);
        let state = self.backtracking_step(problem, state)?;
        Ok((state, None))
    }

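    /// Contracts the step length by the factor `rho` and performs another backtracking step.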
    fn next_iter(
        &mut self,
        problem: &mut Problem<O>,
        state: IterState<P, G, (), (), (), F>,
    ) -> Result<(IterState<P, G, (), (), (), F>, Option<KV>), Error> {
        // Contract the step length and try again.
        self.alpha = self.alpha * self.rho;
        let state = self.backtracking_step(problem, state)?;
        Ok((state, None))
    }

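    /// Terminates as soon as the line search condition is satisfied for the current step
    /// length.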
    fn terminate(&mut self, state: &IterState<P, G, (), (), (), F>) -> TerminationStatus {
        if self.condition.evaluate_condition(
            state.cost,
            state.get_gradient(),
            self.init_cost,
            self.init_grad.as_ref().unwrap(),
            self.search_direction.as_ref().unwrap(),
            self.alpha,
        ) {
            TerminationStatus::Terminated(TerminationReason::SolverConverged)
        } else {
            TerminationStatus::NotTerminated
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::assert_error;
    use crate::core::{test_utils::TestProblem, ArgminError, Executor, State};
    use crate::test_trait_impl;
    use approx::assert_relative_eq;
    use num_traits::Float;

    #[derive(Debug, Clone)]
    struct BTTestProblem {}

    impl CostFunction for BTTestProblem {
        type Param = Vec<f64>;
        type Output = f64;

        fn cost(&self, p: &Self::Param) -> Result<Self::Output, Error> {
            Ok(p[0].powi(2) + p[1].powi(2))
        }
    }

    impl Gradient for BTTestProblem {
        type Param = Vec<f64>;
        type Gradient = Vec<f64>;

        fn gradient(&self, p: &Self::Param) -> Result<Self::Gradient, Error> {
            Ok(vec![2.0 * p[0], 2.0 * p[1]])
        }
    }

    test_trait_impl!(
        backtrackinglinesearch,
        BacktrackingLineSearch<TestProblem, Vec<f64>, ArmijoCondition<f64>, f64>
    );

    #[test]
    fn test_new() {
        let c: f64 = 0.01;
        let armijo = ArmijoCondition::new(c).unwrap();
        let ls: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
            BacktrackingLineSearch::new(armijo);

        assert_eq!(ls.init_param, None);
        assert!(ls.init_cost.is_infinite());
        assert!(ls.init_cost.is_sign_positive());
        assert_eq!(ls.init_grad, None);
        assert_eq!(ls.search_direction, None);
        assert_eq!(ls.rho.to_ne_bytes(), 0.9f64.to_ne_bytes());
        assert_eq!(ls.alpha.to_ne_bytes(), 1.0f64.to_ne_bytes());
    }

    #[test]
    fn test_rho() {
        let c: f64 = 0.01;
        let armijo = ArmijoCondition::new(c).unwrap();
        let ls: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
            BacktrackingLineSearch::new(armijo);

        assert_error!(
            ls.rho(1.0f64),
            ArgminError,
            "Invalid parameter: \"BacktrackingLineSearch: Contraction factor rho must be in (0, 1).\""
        );

        let c: f64 = 0.01;
        let armijo = ArmijoCondition::new(c).unwrap();
        let ls: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
            BacktrackingLineSearch::new(armijo);

        assert_error!(
            ls.rho(0.0f64),
            ArgminError,
            "Invalid parameter: \"BacktrackingLineSearch: Contraction factor rho must be in (0, 1).\""
        );

        let c: f64 = 0.01;
        let armijo = ArmijoCondition::new(c).unwrap();
        let ls: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
            BacktrackingLineSearch::new(armijo);

        assert!(ls.rho(0.0f64 + f64::EPSILON).is_ok());

        let c: f64 = 0.01;
        let armijo = ArmijoCondition::new(c).unwrap();
        let ls: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
            BacktrackingLineSearch::new(armijo);

        assert!(ls.rho(1.0f64 - f64::EPSILON).is_ok());
    }

    #[test]
    fn test_search_direction() {
        let c: f64 = 0.01;
        let armijo = ArmijoCondition::new(c).unwrap();
        let mut ls: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
            BacktrackingLineSearch::new(armijo);
        ls.search_direction(vec![1.0f64, 1.0]);

        assert_eq!(ls.search_direction, Some(vec![1.0f64, 1.0]));
    }

    #[test]
    fn test_initial_step_length() {
        let c: f64 = 0.01;
        let armijo = ArmijoCondition::new(c).unwrap();
        let mut ls: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
            BacktrackingLineSearch::new(armijo);

        assert!(ls.initial_step_length(f64::EPSILON).is_ok());

        assert_error!(
            ls.initial_step_length(0.0f64),
            ArgminError,
            "Invalid parameter: \"LineSearch: Initial alpha must be > 0.\""
        );
    }

    #[test]
    fn test_init_param_not_initialized() {
        let mut linesearch: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
            BacktrackingLineSearch::new(ArmijoCondition::new(0.2).unwrap());
        linesearch.search_direction(vec![1.0f64, 1.0]);
        let res = linesearch.init(&mut Problem::new(TestProblem::new()), IterState::new());
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Not initialized: \"`BacktrackingLineSearch` requires an initial parameter vector. ",
                "Please provide an initial guess via `Executor`s `configure` method.\""
            )
        );
    }

    #[test]
    fn test_step_armijo() {
        use crate::core::Problem;

        let prob = BTTestProblem {};

        let c: f64 = 0.01;
        let armijo = ArmijoCondition::new(c).unwrap();
        let mut ls: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
            BacktrackingLineSearch::new(armijo);

        ls.init_param = Some(vec![-1.0, 0.0]);
        ls.init_cost = f64::infinity();
        ls.init_grad = Some(vec![-2.0, 0.0]);
        ls.search_direction(vec![2.0f64, 0.0]);
        ls.initial_step_length(0.8).unwrap();

        let data = ls.backtracking_step(&mut Problem::new(prob), IterState::new());
        assert!(data.is_ok());

        let param = data.as_ref().unwrap().get_param().unwrap();
        let cost = data.as_ref().unwrap().get_cost();
        assert_relative_eq!(param[0], 0.6, epsilon = f64::EPSILON);
        assert_relative_eq!(param[1], 0.0, epsilon = f64::EPSILON);
        assert_relative_eq!(cost, 0.6f64.powi(2), epsilon = f64::EPSILON);

        assert!(data.as_ref().unwrap().get_gradient().is_none());
    }

    #[test]
    fn test_step_wolfe() {
        // Wolfe, in contrast to Armijo, requires the current gradient. This test makes sure
        // that the implementation of the backtracking line search properly accounts for this.
        use crate::core::Problem;

        let prob = BTTestProblem {};

        let c1: f64 = 0.01;
        let c2: f64 = 0.9;
        let wolfe = WolfeCondition::new(c1, c2).unwrap();
        let mut ls: BacktrackingLineSearch<Vec<f64>, Vec<f64>, WolfeCondition<f64>, f64> =
            BacktrackingLineSearch::new(wolfe);

        ls.init_param = Some(vec![-1.0, 0.0]);
        ls.init_cost = f64::infinity();
        ls.init_grad = Some(vec![-2.0, 0.0]);
        ls.search_direction(vec![2.0f64, 0.0]);
        ls.initial_step_length(0.8).unwrap();

        let data = ls.backtracking_step(&mut Problem::new(prob), IterState::new());
        assert!(data.is_ok());

        let param = data.as_ref().unwrap().get_param().unwrap();
        let cost = data.as_ref().unwrap().get_cost();
        let gradient = data.as_ref().unwrap().get_gradient().unwrap();
        assert_relative_eq!(param[0], 0.6, epsilon = f64::EPSILON);
        assert_relative_eq!(param[1], 0.0, epsilon = f64::EPSILON);
        assert_relative_eq!(cost, 0.6f64.powi(2), epsilon = f64::EPSILON);
        assert_relative_eq!(gradient[0], 2.0 * 0.6, epsilon = f64::EPSILON);
        assert_relative_eq!(gradient[1], 0.0, epsilon = f64::EPSILON);
    }

    #[test]
    fn test_init_armijo() {
        use crate::core::IterState;
        use crate::core::Problem;

        let prob = BTTestProblem {};

        let c: f64 = 0.01;
        let armijo = ArmijoCondition::new(c).unwrap();
        let mut ls: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
            BacktrackingLineSearch::new(armijo);

        ls.init_param = Some(vec![-1.0, 0.0]);
        ls.init_cost = f64::infinity();
        // In contrast to the step tests above, it is not necessary to set init_grad here
        // because it will be computed in init if not present.
        ls.initial_step_length(0.8).unwrap();

        assert_error!(
            ls.init(
                &mut Problem::new(prob.clone()),
                IterState::new().param(ls.init_param.clone().unwrap())
            ),
            ArgminError,
            "Not initialized: \"BacktrackingLineSearch: search_direction must be set.\""
        );

        ls.search_direction(vec![2.0f64, 0.0]);

        let data = ls.init(
            &mut Problem::new(prob),
            IterState::new().param(ls.init_param.clone().unwrap()),
        );
        assert!(data.is_ok());

        let data = data.unwrap().0;

        let param = data.get_param().unwrap();
        let cost = data.get_cost();
        assert_relative_eq!(param[0], 0.6, epsilon = f64::EPSILON);
        assert_relative_eq!(param[1], 0.0, epsilon = f64::EPSILON);
        assert_relative_eq!(cost, 0.6f64.powi(2), epsilon = f64::EPSILON);

        assert!(data.get_gradient().is_none());
    }

    #[test]
    fn test_init_wolfe() {
        use crate::core::IterState;
        use crate::core::Problem;

        let prob = BTTestProblem {};

        let c1: f64 = 0.01;
        let c2: f64 = 0.9;
        let wolfe = WolfeCondition::new(c1, c2).unwrap();
        let mut ls: BacktrackingLineSearch<Vec<f64>, Vec<f64>, WolfeCondition<f64>, f64> =
            BacktrackingLineSearch::new(wolfe);

        ls.init_param = Some(vec![-1.0, 0.0]);
        ls.init_cost = f64::infinity();
        // In contrast to the step tests above, it is not necessary to set init_grad here
        // because it will be computed in init if not present.
        ls.initial_step_length(0.8).unwrap();

        assert_error!(
            ls.init(
                &mut Problem::new(prob.clone()),
                IterState::new().param(ls.init_param.clone().unwrap())
            ),
            ArgminError,
            "Not initialized: \"BacktrackingLineSearch: search_direction must be set.\""
        );

        ls.search_direction(vec![2.0f64, 0.0]);

        let data = ls.init(
            &mut Problem::new(prob),
            IterState::new().param(ls.init_param.clone().unwrap()),
        );
        assert!(data.is_ok());

        let data = data.unwrap().0;

        let param = data.get_param().unwrap();
        let cost = data.get_cost();
        let gradient = data.get_gradient().unwrap();
        assert_relative_eq!(param[0], 0.6, epsilon = f64::EPSILON);
        assert_relative_eq!(param[1], 0.0, epsilon = f64::EPSILON);
        assert_relative_eq!(cost, 0.6f64.powi(2), epsilon = f64::EPSILON);
        assert_relative_eq!(gradient[0], 2.0 * 0.6, epsilon = f64::EPSILON);
        assert_relative_eq!(gradient[1], 0.0, epsilon = f64::EPSILON);
    }

    #[test]
    fn test_next_iter() {
        // Similar to the step test, but additionally checks that self.alpha is reduced.
        use crate::core::Problem;

        let prob = BTTestProblem {};

        let c: f64 = 0.01;
        let armijo = ArmijoCondition::new(c).unwrap();
        let mut ls: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
            BacktrackingLineSearch::new(armijo);

        let init_alpha = 0.8;
        ls.init_param = Some(vec![-1.0, 0.0]);
        ls.init_cost = f64::infinity();
        ls.init_grad = Some(vec![-2.0, 0.0]);
        ls.search_direction(vec![2.0f64, 0.0]);
        ls.initial_step_length(init_alpha).unwrap();

        let data = ls.next_iter(
            &mut Problem::new(prob),
            IterState::new().param(ls.init_param.clone().unwrap()),
        );
        assert!(data.is_ok());

        let param = data.as_ref().unwrap().0.get_param().unwrap();
        let cost = data.as_ref().unwrap().0.get_cost();
        // The step is smaller than in the step test because alpha has already been reduced.
        assert_relative_eq!(param[0], 0.44, epsilon = f64::EPSILON);
        assert_relative_eq!(param[1], 0.0, epsilon = f64::EPSILON);
        assert_relative_eq!(cost, 0.44f64.powi(2), epsilon = f64::EPSILON);

        assert!(data.as_ref().unwrap().0.get_gradient().is_none());
        assert_relative_eq!(ls.alpha, ls.rho * 0.8, epsilon = f64::EPSILON);
    }

    #[test]
    fn test_termination() {
        let c: f64 = 0.01;
        let armijo = ArmijoCondition::new(c).unwrap();
        let mut ls: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
            BacktrackingLineSearch::new(armijo);

        let init_alpha = 0.8;
        ls.init_param = Some(vec![-1.0, 0.0]);
        ls.init_cost = f64::infinity();
        ls.init_grad = Some(vec![-2.0, 0.0]);
        ls.search_direction(vec![2.0f64, 0.0]);
        ls.initial_step_length(init_alpha).unwrap();

        let init_param = ls.init_param.clone().unwrap();
        assert_eq!(
            <BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> as Solver<
                TestProblem,
                IterState<Vec<f64>, Vec<f64>, (), (), (), f64>,
            >>::terminate(
                &mut ls,
                &IterState::<Vec<f64>, Vec<f64>, (), (), (), f64>::new().param(init_param)
            ),
            TerminationStatus::Terminated(TerminationReason::SolverConverged)
        );

        ls.init_cost = 0.0f64;

        let init_param = ls.init_param.clone().unwrap();
        assert_eq!(
            <BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> as Solver<
                TestProblem,
                IterState<Vec<f64>, Vec<f64>, (), (), (), f64>,
            >>::terminate(
                &mut ls,
                &IterState::<Vec<f64>, Vec<f64>, (), (), (), f64>::new().param(init_param)
            ),
            TerminationStatus::NotTerminated
        );
    }

    #[test]
    fn test_executor_1() {
        let prob = BTTestProblem {};

        let c: f64 = 0.01;
        let armijo = ArmijoCondition::new(c).unwrap();
        let mut ls: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
            BacktrackingLineSearch::new(armijo);

        ls.init_param = Some(vec![-1.0, 0.0]);
        ls.init_cost = f64::infinity();
        // In contrast to the step tests above, it is not necessary to set init_grad here
        // because it will be computed in init if not present.
        ls.initial_step_length(0.8).unwrap();

        assert_error!(
            Executor::new(prob.clone(), ls.clone())
                .configure(|config| config.param(ls.init_param.clone().unwrap()).max_iters(10))
                .run(),
            ArgminError,
            "Not initialized: \"BacktrackingLineSearch: search_direction must be set.\""
        );

        ls.search_direction(vec![2.0f64, 0.0]);

        let data = Executor::new(prob, ls.clone())
            .configure(|config| config.param(ls.init_param.clone().unwrap()).max_iters(10))
            .run();
        assert!(data.is_ok());

        let data = data.unwrap().state;

        let param = data.get_param().unwrap();
        assert_relative_eq!(param[0], 0.6, epsilon = f64::EPSILON);
        assert_relative_eq!(param[1], 0.0, epsilon = f64::EPSILON);
        assert_relative_eq!(data.get_cost(), 0.6f64.powi(2), epsilon = f64::EPSILON);
        assert_eq!(data.iter, 0);
        let func_counts = data.get_func_counts();
        assert_eq!(func_counts["cost_count"], 2);
        assert_eq!(func_counts["gradient_count"], 1);
        assert_eq!(
            data.termination_status,
            TerminationStatus::Terminated(TerminationReason::SolverConverged)
        );

        assert!(data.get_gradient().is_none());
    }

    #[test]
    fn test_executor_2() {
        let prob = BTTestProblem {};

        // Difference compared to test_executor_1: c is larger to force another backtracking step.
        let c: f64 = 0.2;
        let armijo = ArmijoCondition::new(c).unwrap();
        let mut ls: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
            BacktrackingLineSearch::new(armijo);

        ls.init_param = Some(vec![-1.0, 0.0]);
        ls.init_cost = f64::infinity();
        // In contrast to the step tests above, it is not necessary to set init_grad here
        // because it will be computed in init if not present.
        ls.initial_step_length(0.8).unwrap();

        assert_error!(
            Executor::new(prob.clone(), ls.clone())
                .configure(|config| config.param(ls.init_param.clone().unwrap()).max_iters(10))
                .run(),
            ArgminError,
            "Not initialized: \"BacktrackingLineSearch: search_direction must be set.\""
        );

        ls.search_direction(vec![2.0f64, 0.0]);

        let data = Executor::new(prob, ls.clone())
            .configure(|config| config.param(ls.init_param.clone().unwrap()).max_iters(10))
            .run();
        assert!(data.is_ok());

        let data = data.unwrap().state;

        let param = data.get_param().unwrap();
        assert_relative_eq!(param[0], 0.44, epsilon = f64::EPSILON);
        assert_relative_eq!(param[1], 0.0, epsilon = f64::EPSILON);
        assert_relative_eq!(data.get_cost(), 0.44f64.powi(2), epsilon = f64::EPSILON);
        assert_eq!(data.iter, 1);
        let func_counts = data.get_func_counts();
        assert_eq!(func_counts["cost_count"], 3);
        assert_eq!(func_counts["gradient_count"], 1);
        assert_eq!(
            data.termination_status,
            TerminationStatus::Terminated(TerminationReason::SolverConverged)
        );
        assert!(data.get_gradient().is_none());
    }
}