use crate::calculus::{DifferentiableVectorFunction, VectorFunction};
use fenris_traits::Real;
use itertools::iterate;
use log::debug;
use nalgebra::{DVector, DVectorView, DVectorViewMut, Scalar};
use numeric_literals::replace_float_literals;
use std::error::Error;
use std::fmt;
use std::fmt::Display;
/// A solution vector paired with an iteration count.
///
/// NOTE(review): the solver functions in this file return only the iteration count and write
/// the solution into a caller-provided buffer; this type is not constructed by them.
#[derive(Clone, Debug)]
pub struct NewtonResult<T: Scalar> {
    /// The (approximate) solution vector.
    pub solution: DVector<T>,
    /// The number of iterations that produced `solution`.
    pub iterations: usize,
}
/// Stopping criteria for the Newton solvers in this module.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct NewtonSettings<T> {
    /// Maximum number of Newton iterations before giving up; `None` places no limit.
    pub max_iterations: Option<usize>,
    /// The solve is considered converged once the residual norm is no greater than this value.
    pub tolerance: T,
}
/// Reasons a Newton solve can stop without reaching the requested tolerance.
#[derive(Debug)]
pub enum NewtonError {
    /// The residual was still above the tolerance after this many iterations.
    MaximumIterationsReached(usize),
    /// Solving the Jacobian system failed; carries the underlying error.
    JacobianError(Box<dyn Error>),
    /// The line search could not produce an acceptable step; carries the underlying error.
    LineSearchError(Box<dyn Error>),
}

impl Display for NewtonError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            NewtonError::MaximumIterationsReached(maxit) => {
                write!(f, "Failed to converge within maximum number of iterations ({}).", maxit)
            }
            NewtonError::JacobianError(err) => {
                write!(f, "Failed to solve Jacobian system. Error: {}", err)
            }
            NewtonError::LineSearchError(err) => {
                write!(f, "Line search failed to produce valid step direction. Error: {}", err)
            }
        }
    }
}

impl Error for NewtonError {
    /// Exposes the wrapped Jacobian / line-search error so callers can walk the error chain or
    /// downcast it. The `Display` output also embeds the inner message (kept unchanged for
    /// compatibility), so chain-printing reporters will show it twice.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            NewtonError::MaximumIterationsReached(_) => None,
            NewtonError::JacobianError(err) | NewtonError::LineSearchError(err) => Some(err.as_ref()),
        }
    }
}
#[replace_float_literals(T::from_f64(literal).unwrap())]
pub fn newton<'a, T, F>(
function: F,
x: impl Into<DVectorViewMut<'a, T>>,
f: impl Into<DVectorViewMut<'a, T>>,
dx: impl Into<DVectorViewMut<'a, T>>,
settings: NewtonSettings<T>,
) -> Result<usize, NewtonError>
where
T: Real,
F: DifferentiableVectorFunction<T>,
{
newton_line_search(function, x, f, dx, settings, &mut NoLineSearch {})
}
/// Solves `F(x) = 0` with Newton's method, letting `line_search` decide how far to move along
/// each Newton direction.
///
/// - `function`: the system `F`; used for residual evaluation (`eval_into`) and for solving
///   with the Jacobian (`solve_jacobian_system`).
/// - `x`: on entry the initial guess; the line search updates it in place each iteration.
/// - `f`: receives the residual `F(x)`; on `Ok` it holds the residual at the returned `x`.
/// - `dx`: scratch buffer for the search direction passed to the line search.
/// - `settings`: tolerance on `f.norm()` and an optional iteration cap.
///
/// Returns the number of Newton iterations performed. The result is `Ok(0)` when the initial
/// guess already satisfies the tolerance.
///
/// NOTE(review): the loop test `norm > tolerance` is `false` for a NaN norm, so a NaN residual
/// ends the loop and is reported as `Ok`. Callers that care should check `f` for finiteness.
///
/// # Errors
///
/// - [`NewtonError::MaximumIterationsReached`] if `settings.max_iterations` is `Some(n)` and
///   the residual is still above the tolerance after `n` iterations.
/// - [`NewtonError::JacobianError`] if `solve_jacobian_system` fails.
/// - [`NewtonError::LineSearchError`] if `line_search` fails to produce a step.
///
/// # Panics
///
/// Panics if `x`, `f` and `dx` do not all have the same number of rows.
#[replace_float_literals(T::from_f64(literal).unwrap())]
pub fn newton_line_search<'a, T, F>(
    mut function: F,
    x: impl Into<DVectorViewMut<'a, T>>,
    f: impl Into<DVectorViewMut<'a, T>>,
    dx: impl Into<DVectorViewMut<'a, T>>,
    settings: NewtonSettings<T>,
    line_search: &mut impl LineSearch<T, F>,
) -> Result<usize, NewtonError>
where
    T: Real,
    F: DifferentiableVectorFunction<T>,
{
    let mut x = x.into();
    let mut f = f.into();
    // Receives the Jacobian solve for the current residual, then is negated in place so that
    // it holds the direction handed to the line search.
    let mut direction = dx.into();
    assert_eq!(x.nrows(), f.nrows());
    assert_eq!(direction.nrows(), f.nrows());

    function.eval_into(&mut f, &DVectorView::from(&x));

    let mut iter = 0;
    while f.norm() > settings.tolerance {
        if settings.max_iterations == Some(iter) {
            return Err(NewtonError::MaximumIterationsReached(iter));
        }

        function
            .solve_jacobian_system(&mut direction, &DVectorView::from(&x), &DVectorView::from(&f))
            .map_err(|err| NewtonError::JacobianError(err))?;
        direction *= -1.0;

        let step_length = line_search
            .step(
                &mut function,
                DVectorViewMut::from(&mut f),
                DVectorViewMut::from(&mut x),
                DVectorView::from(&direction),
            )
            .map_err(|err| NewtonError::LineSearchError(err))?;
        debug!("Newton step length at iter {}: {}", iter, step_length);
        iter += 1;
    }
    Ok(iter)
}
/// A strategy for choosing how far to move along a search direction.
///
/// `step` receives the current iterate `x`, the residual buffer `f` and a `direction`, and
/// returns the step length it used. `newton_line_search` only looks at `f` after calling it.
/// Implementations must therefore leave `x` at the accepted point and `f` holding the residual
/// there, as the implementations in this module do.
pub trait LineSearch<T, F>
where
    T: Scalar,
    F: VectorFunction<T>,
{
    /// Advances `x` along `direction`, keeping `f` consistent with `x`, and returns the step
    /// length taken.
    ///
    /// # Errors
    ///
    /// Returns an error if no acceptable step could be found.
    fn step(
        &mut self,
        function: &mut F,
        f: DVectorViewMut<'_, T>,
        x: DVectorViewMut<'_, T>,
        direction: DVectorView<'_, T>,
    ) -> Result<T, Box<dyn Error>>;
}
/// The trivial line search: always takes the full step `x <- x + direction`.
///
/// `newton` uses this, which gives the classical undamped Newton iteration.
#[derive(Clone, Debug)]
pub struct NoLineSearch;

impl<T, F> LineSearch<T, F> for NoLineSearch
where
    T: Real,
    F: VectorFunction<T>,
{
    /// Adds `direction` to `x`, re-evaluates `function` into `f`, and reports a step length of 1.
    /// Never fails.
    #[replace_float_literals(T::from_f64(literal).unwrap())]
    fn step(
        &mut self,
        function: &mut F,
        mut f: DVectorViewMut<T>,
        mut x: DVectorViewMut<T>,
        direction: DVectorView<T>,
    ) -> Result<T, Box<dyn Error>> {
        let unit = T::one();
        // x <- 1 * direction + 1 * x: the full, undamped step.
        x.axpy(unit, &direction, unit);
        function.eval_into(&mut f, &DVectorView::from(&x));
        Ok(unit)
    }
}
/// Backtracking line search on the merit function `g(x) = 0.5 * |F(x)|^2`.
///
/// Trial step lengths are tried in the order `1, 0.75, 0.5, 0.25, 0.0625, ...`; after `0.25`,
/// each trial is a quarter of the previous one. A trial `alpha` is accepted as soon as
/// `g(x + alpha * p) <= (1 - c * alpha) * g(x)` with `c = 1e-4`. The search fails once a
/// trial smaller than `1e-6` is also rejected.
#[derive(Clone, Debug)]
pub struct BacktrackingLineSearch;

impl<T, F> LineSearch<T, F> for BacktrackingLineSearch
where
    T: Real,
    F: VectorFunction<T>,
{
    /// Moves `x` to `x + alpha * direction` for the first accepted trial `alpha` and returns
    /// `alpha`. On success, `f` holds `F` evaluated at the new `x`.
    ///
    /// # Errors
    ///
    /// Returns an error when a rejected trial step is below the minimum step length. In that
    /// case `x` is left at that last rejected trial point and `f` holds `F` evaluated there.
    /// Neither is restored to its value on entry.
    #[replace_float_literals(T::from_f64(literal).unwrap())]
    fn step(
        &mut self,
        function: &mut F,
        mut f: DVectorViewMut<T>,
        mut x: DVectorViewMut<T>,
        direction: DVectorView<T>,
    ) -> Result<T, Box<dyn Error>> {
        // Sufficient-decrease constant of the acceptance test.
        let c = 1e-4;
        // A rejected trial below this length aborts the search.
        let alpha_min = 1e-6;
        let p = direction;
        let g_initial = 0.5 * f.magnitude_squared();

        // The leading 0.0 is the step already applied to x on entry (none); the remaining values
        // are the trial steps. `iterate` makes the sequence infinite, so `next()` never fails.
        let initial_alphas = [0.0, 1.0, 0.75, 0.5];
        let mut alpha_iter = initial_alphas
            .iter()
            .copied()
            .chain(iterate(0.25, |alpha_i| 0.25 * *alpha_i));
        let mut alpha_prev = alpha_iter.next().expect("trial step sequence is infinite");
        let mut alpha = alpha_iter.next().expect("trial step sequence is infinite");
        loop {
            // Shift x from (x0 + alpha_prev * p) to (x0 + alpha * p) in place rather than
            // keeping a copy of the starting point.
            let delta_alpha = alpha - alpha_prev;
            x.axpy(delta_alpha, &p, T::one());
            function.eval_into(&mut f, &DVectorView::from(&x));
            let g = 0.5 * f.magnitude_squared();
            if g <= (1.0 - c * alpha) * g_initial {
                break;
            } else if alpha < alpha_min {
                // The trailing space before `\` is needed: a string continuation drops the
                // newline and all leading whitespace of the next line.
                return Err(Box::from(format!(
                    "Failed to produce valid step direction. \
                     Alpha {} is smaller than minimum allowed alpha {}.",
                    alpha, alpha_min
                )));
            } else {
                alpha_prev = alpha;
                alpha = alpha_iter.next().expect("trial step sequence is infinite");
            }
        }
        Ok(alpha)
    }
}