1use crate::calculus::{DifferentiableVectorFunction, VectorFunction};
2use fenris_traits::Real;
3use itertools::iterate;
4use log::debug;
5use nalgebra::{DVector, DVectorView, DVectorViewMut, Scalar};
6use numeric_literals::replace_float_literals;
7use std::error::Error;
8use std::fmt;
9use std::fmt::Display;
10
11#[derive(Debug, Clone)]
12pub struct NewtonResult<T>
13where
14 T: Scalar,
15{
16 pub solution: DVector<T>,
17 pub iterations: usize,
18}
19
/// Stopping criteria for the Newton solver.
///
/// The solver treats the problem as converged as soon as the Euclidean norm of the
/// residual is no longer greater than `tolerance`. If `max_iterations` is `Some(n)`, the
/// solver reports failure when it is still unconverged after `n` steps; `None` means the
/// number of steps is unbounded.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct NewtonSettings<T> {
    /// Optional cap on the number of Newton steps (`None` = no cap).
    pub max_iterations: Option<usize>,
    /// Threshold on the residual norm below which iteration stops.
    pub tolerance: T,
}
25
/// Reasons a Newton solve can fail.
#[derive(Debug)]
pub enum NewtonError {
    /// The residual was still above the tolerance after the given number of iterations.
    MaximumIterationsReached(usize),
    /// Solving the Jacobian system failed; carries the underlying error.
    JacobianError(Box<dyn Error>),
    /// The line search could not produce an acceptable step; carries the underlying error.
    LineSearchError(Box<dyn Error>),
}

impl Display for NewtonError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            NewtonError::MaximumIterationsReached(maxit) => {
                write!(f, "Failed to converge within maximum number of iterations ({}).", maxit)
            }
            NewtonError::JacobianError(err) => {
                write!(f, "Failed to solve Jacobian system. Error: {}", err)
            }
            NewtonError::LineSearchError(err) => {
                write!(f, "Line search failed to produce valid step direction. Error: {}", err)
            }
        }
    }
}

impl Error for NewtonError {
    /// Exposes the wrapped error (if any) so callers and error reporters can walk the
    /// error chain instead of only seeing the flattened `Display` text.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            NewtonError::MaximumIterationsReached(_) => None,
            NewtonError::JacobianError(err) | NewtonError::LineSearchError(err) => Some(&**err),
        }
    }
}
53
54#[replace_float_literals(T::from_f64(literal).unwrap())]
61pub fn newton<'a, T, F>(
62 function: F,
63 x: impl Into<DVectorViewMut<'a, T>>,
64 f: impl Into<DVectorViewMut<'a, T>>,
65 dx: impl Into<DVectorViewMut<'a, T>>,
66 settings: NewtonSettings<T>,
67) -> Result<usize, NewtonError>
68where
69 T: Real,
70 F: DifferentiableVectorFunction<T>,
71{
72 newton_line_search(function, x, f, dx, settings, &mut NoLineSearch {})
73}
74
75#[replace_float_literals(T::from_f64(literal).unwrap())]
77pub fn newton_line_search<'a, T, F>(
78 mut function: F,
79 x: impl Into<DVectorViewMut<'a, T>>,
80 f: impl Into<DVectorViewMut<'a, T>>,
81 dx: impl Into<DVectorViewMut<'a, T>>,
82 settings: NewtonSettings<T>,
83 line_search: &mut impl LineSearch<T, F>,
84) -> Result<usize, NewtonError>
85where
86 T: Real,
87 F: DifferentiableVectorFunction<T>,
88{
89 let mut x = x.into();
90 let mut f = f.into();
91 let mut minus_dx = dx.into();
92
93 assert_eq!(x.nrows(), f.nrows());
94 assert_eq!(minus_dx.nrows(), f.nrows());
95
96 function.eval_into(&mut f, &DVectorView::from(&x));
97
98 let mut iter = 0;
99
100 while f.norm() > settings.tolerance {
101 if settings
102 .max_iterations
103 .map(|max_iter| iter == max_iter)
104 .unwrap_or(false)
105 {
106 return Err(NewtonError::MaximumIterationsReached(iter));
107 }
108
109 let j_result = function.solve_jacobian_system(&mut minus_dx, &DVectorView::from(&x), &DVectorView::from(&f));
111 if let Err(err) = j_result {
112 return Err(NewtonError::JacobianError(err));
113 }
114
115 minus_dx *= -1.0;
117 let dx = &minus_dx;
118
119 let step_length = line_search
120 .step(
121 &mut function,
122 DVectorViewMut::from(&mut f),
123 DVectorViewMut::from(&mut x),
124 DVectorView::from(dx),
125 )
126 .map_err(|err| NewtonError::LineSearchError(err))?;
127 debug!("Newton step length at iter {}: {}", iter, step_length);
128 iter += 1;
129 }
130
131 Ok(iter)
132}
133
134pub trait LineSearch<T: Scalar, F: VectorFunction<T>> {
135 fn step(
136 &mut self,
137 function: &mut F,
138 f: DVectorViewMut<T>,
139 x: DVectorViewMut<T>,
140 direction: DVectorView<T>,
141 ) -> Result<T, Box<dyn Error>>;
142}
143
/// Line search that always takes the full step (step length one), with no decrease check.
#[derive(Debug, Clone)]
pub struct NoLineSearch;
147
148impl<T, F> LineSearch<T, F> for NoLineSearch
149where
150 T: Real,
151 F: VectorFunction<T>,
152{
153 #[replace_float_literals(T::from_f64(literal).unwrap())]
154 fn step(
155 &mut self,
156 function: &mut F,
157 mut f: DVectorViewMut<T>,
158 mut x: DVectorViewMut<T>,
159 direction: DVectorView<T>,
160 ) -> Result<T, Box<dyn Error>> {
161 let p = direction;
162 x.axpy(T::one(), &p, T::one());
163 function.eval_into(&mut f, &DVectorView::from(&x));
164 Ok(T::one())
165 }
166}
167
/// Line search that backtracks from the full step until the squared residual norm
/// decreases sufficiently.
#[derive(Clone, Debug)]
pub struct BacktrackingLineSearch;
173
174impl<T, F> LineSearch<T, F> for BacktrackingLineSearch
175where
176 T: Real,
177 F: VectorFunction<T>,
178{
179 #[replace_float_literals(T::from_f64(literal).unwrap())]
180 fn step(
181 &mut self,
182 function: &mut F,
183 mut f: DVectorViewMut<T>,
184 mut x: DVectorViewMut<T>,
185 direction: DVectorView<T>,
186 ) -> Result<T, Box<dyn Error>> {
187 let c = 1e-4;
204 let alpha_min = 1e-6;
205
206 let p = direction;
207 let g_initial = 0.5 * f.magnitude_squared();
208
209 let initial_alphas = [0.0, 1.0, 0.75, 0.5];
213 let mut alpha_iter = initial_alphas
214 .iter()
215 .copied()
216 .chain(iterate(0.25, |alpha_i| 0.25 * *alpha_i));
217
218 let mut alpha_prev = alpha_iter.next().unwrap();
219 let mut alpha = alpha_iter.next().unwrap();
220
221 loop {
222 let delta_alpha = alpha - alpha_prev;
223
224 x.axpy(delta_alpha, &p, T::one());
230 function.eval_into(&mut f, &DVectorView::from(&x));
231
232 let g = 0.5 * f.magnitude_squared();
233 if g <= (1.0 - c * alpha) * g_initial {
234 break;
235 } else if alpha < alpha_min {
236 return Err(Box::from(format!(
237 "Failed to produce valid step direction.\
238 Alpha {} is smaller than minimum allowed alpha {}.",
239 alpha, alpha_min
240 )));
241 } else {
242 alpha_prev = alpha;
243 alpha = alpha_iter.next().unwrap();
244 }
245 }
246
247 Ok(alpha)
248 }
249}