use crate::{
error::{DiffsolError, NonLinearSolverError},
non_linear_solver_error,
nonlinear_solver::convergence::ConvergenceStatus,
Convergence, InitialConditionSolverOptions, Scalar, Vector,
};
use num_traits::{FromPrimitive, One, Pow};
/// A line-search strategy used by a Newton-type nonlinear solver to decide
/// how far to move along the computed update direction on each iteration.
pub trait LineSearch<V: Vector>: Default {
    /// Perform one Newton iteration with a (possibly damped) step.
    ///
    /// * `x` - current iterate; updated in place with the accepted step.
    /// * `delta` - workspace; on exit holds the last computed update.
    /// * `error_y` - weights/tolerances passed to `convergence.norm`.
    /// * `fun` - evaluates the nonlinear residual: `fun(x, out)` writes the
    ///   residual at `x` into `out`.
    /// * `linear_solver` - solves the linearised system in place on its
    ///   argument, turning a residual into a Newton update.
    /// * `convergence` - tracks the iteration count and convergence tests.
    ///
    /// Returns the convergence status after the step, or an error if the
    /// linear solve (or the line search itself) fails.
    fn take_optimal_step(
        &mut self,
        x: &mut V,
        delta: &mut V,
        error_y: &V,
        fun: &impl Fn(&V, &mut V),
        linear_solver: &impl Fn(&mut V) -> Result<(), DiffsolError>,
        convergence: &mut Convergence<V>,
    ) -> Result<ConvergenceStatus, DiffsolError>;

    /// Reset any internal state (e.g. iteration counters) before a new solve.
    fn reset(&mut self);
}
/// Trivial line search: always takes the full, undamped Newton step.
#[derive(Default)]
pub struct NoLineSearch;
impl<V: Vector> LineSearch<V> for NoLineSearch {
    /// Take an undamped Newton step: evaluate the residual at the current
    /// iterate, solve for the update, subtract it, and check convergence.
    fn take_optimal_step(
        &mut self,
        state: &mut V,
        step: &mut V,
        tolerance: &V,
        residual: &impl Fn(&V, &mut V),
        solve: &impl Fn(&mut V) -> Result<(), DiffsolError>,
        convergence: &mut Convergence<V>,
    ) -> Result<ConvergenceStatus, DiffsolError> {
        // Residual at the current iterate, then solve the linearised system
        // in place so `step` becomes the Newton update.
        residual(state, step);
        solve(step)?;
        // Full step: x <- x - delta.
        state.sub_assign(&*step);
        // The weighted norm of the update drives the convergence test.
        let step_norm = convergence.norm(step, tolerance);
        Ok(convergence.check_new_iteration(step_norm))
    }

    /// Stateless strategy — nothing to reset.
    fn reset(&mut self) {}
}
/// Backtracking (Armijo) line search: starts from the full Newton step and
/// repeatedly shrinks the step length until a sufficient-decrease condition
/// on the residual norm holds.
pub struct BacktrackingLineSearch<V: Vector> {
    /// Step-length reduction factor applied on each backtrack; taken from
    /// `InitialConditionSolverOptions::step_reduction_factor` by default.
    pub tau: V::T,
    /// Armijo sufficient-decrease constant; taken from
    /// `InitialConditionSolverOptions::armijo_constant` by default.
    pub c: V::T,
    /// Minimum allowed scaled step length; the search fails with
    /// `LinesearchFailedMinStep` once alpha drops below `steptol / norm`.
    pub steptol: V::T,
    /// Maximum number of backtracking iterations per step.
    pub max_iter: usize,
    /// Line-search iteration counter (cleared by `reset`).
    pub n_iters: usize,
    // Saved update direction at the start of the current search.
    delta0: V,
    // Saved iterate at the start of the current search, used to restore `x`
    // after a rejected trial step.
    x0: V,
    // Residual norm at the current iterate, carried between calls.
    norm: V::T,
}
impl<V: Vector> Default for BacktrackingLineSearch<V> {
    /// Build a line search from the default initial-condition solver options.
    /// Workspace vectors start empty and are sized lazily on first use.
    fn default() -> Self {
        let options = InitialConditionSolverOptions::<V::T>::default();
        // steptol = eps^(2/3), a conventional minimum-step tolerance.
        let two_thirds = V::T::from_f64(2.0 / 3.0).unwrap();
        let steptol = V::T::EPSILON.pow(two_thirds);
        // Zero-length placeholders; real sizes are taken from the first
        // vectors handed to `take_optimal_step`.
        let empty = || V::zeros(0, Default::default());
        Self {
            tau: options.step_reduction_factor,
            c: options.armijo_constant,
            steptol,
            max_iter: options.max_linesearch_iterations,
            n_iters: 0,
            delta0: empty(),
            x0: empty(),
            norm: V::T::one(),
        }
    }
}
impl<V: Vector> LineSearch<V> for BacktrackingLineSearch<V> {
    /// Clear the cumulative line-search iteration counter before a new solve.
    fn reset(&mut self) {
        self.n_iters = 0;
    }

    /// Take a damped Newton step chosen by backtracking on the Armijo
    /// sufficient-decrease condition
    ///
    ///   phi(alpha) <= phi(0) - c * alpha * 2 * phi(0),
    ///
    /// where phi(alpha) = 0.5 * ||F(x - alpha * delta)||^2 in the weighted
    /// norm defined by `convergence.norm(.., error_y)`.
    ///
    /// # Errors
    /// - propagates any failure from `linear_solver`
    /// - `LinesearchFailedMinStep` if alpha falls below `steptol / norm`
    /// - `LinesearchFailedMaxIterations` if no step is accepted within
    ///   `max_iter` backtracks
    fn take_optimal_step(
        &mut self,
        x: &mut V,
        delta: &mut V,
        error_y: &V,
        fun: &impl Fn(&V, &mut V),
        linear_solver: &impl Fn(&mut V) -> Result<(), DiffsolError>,
        convergence: &mut Convergence<V>,
    ) -> Result<ConvergenceStatus, DiffsolError> {
        // On the very first Newton iteration the incoming `delta` has not
        // been computed yet, so evaluate the residual and solve for it here.
        // Later iterations reuse the `delta` / `self.norm` produced at the
        // end of the previous call.
        if convergence.niter() == 0 {
            fun(x, delta);
            linear_solver(delta)?;
            self.norm = convergence.norm(delta, error_y);
            // Already converged: take the full step and stop immediately.
            if let ConvergenceStatus::Converged = convergence.check_norm(self.norm) {
                x.sub_assign(&*delta);
                return Ok(ConvergenceStatus::Converged);
            }
        }
        // Lazily size the workspace vectors on first use.
        if self.x0.len() == 0 {
            self.x0 = V::zeros(x.len(), x.context().clone());
            self.delta0 = V::zeros(delta.len(), delta.context().clone());
        }
        // Save the iterate and search direction so rejected trial steps can
        // be undone.
        self.x0.copy_from(x);
        self.delta0.copy_from(delta);
        let half = V::T::from_f64(0.5).unwrap();
        let norm = self.norm;
        // phi(0) = 0.5 * ||F(x)||^2, plus its doubled value used in the
        // sufficient-decrease test (c * alpha * 2 * phi0).
        let phi0 = norm * norm * half;
        let two_phi0 = norm * norm;
        let min_alpha = self.steptol / norm;
        let mut alpha = V::T::one();
        for _ in 0..self.max_iter {
            // Trial step: x <- x0 - alpha * delta0 (`x` equals `x0` here,
            // either initially or after the restore below).
            x.axpy(-alpha, &self.delta0, V::T::one());
            fun(x, delta);
            linear_solver(delta)?;
            let new_norm = convergence.norm(delta, error_y);
            // BUG FIX: was `self.n_iters = i;`, which overwrote the counter
            // with the zero-based loop index — it read 0 after one iteration
            // and never accumulated across calls, making `reset` pointless.
            // Count every backtracking iteration, cumulative until `reset`.
            self.n_iters += 1;
            let phi1 = new_norm * new_norm * half;
            if phi1 <= phi0 - self.c * alpha * two_phi0 {
                // Sufficient decrease achieved: accept the damped step and
                // carry the new residual norm into the next call.
                self.norm = new_norm;
                return Ok(convergence.check_norm(new_norm));
            }
            // Reject the trial: shrink the step and restore the iterate.
            alpha *= self.tau;
            if alpha < min_alpha {
                return Err(non_linear_solver_error!(LinesearchFailedMinStep));
            }
            x.copy_from(&self.x0);
        }
        Err(non_linear_solver_error!(LinesearchFailedMaxIterations))
    }
}