use argmin::{argmin_error, argmin_error_closure, core::*, float};
/// Damped Newton solver for a single scalar curve parameter, for use with argmin.
///
/// The objective's gradient and Hessian with respect to the parameter are supplied by the
/// argmin `Problem`; judging by the name the objective is presumably the distance from a
/// query point to the curve (confirm with the caller). After every step the parameter is
/// kept inside `knot_domain`.
///
/// `P` is the type of the domain bounds; the `Solver` implementation uses `P = F`.
#[derive(Clone, Copy, Debug)]
pub struct CurveClosestParameterNewton<F, P> {
    /// Step damping factor: each update moves by `gamma` times the full Newton step.
    gamma: F,
    /// Inclusive `(lower, upper)` bounds within which the parameter is kept.
    knot_domain: (P, P),
    /// `true` for a closed (periodic) curve: a parameter leaving `knot_domain` wraps around
    /// instead of being clamped to the violated bound.
    closed: bool,
}
impl<F, P> CurveClosestParameterNewton<F, P>
where
    F: ArgminFloat,
{
    /// Creates a solver over the inclusive parameter interval `domain = (lower, upper)`.
    ///
    /// `closed` selects wrapping (closed curve) rather than clamping (open curve) when a
    /// Newton step leaves `domain`. The damping factor `gamma` starts at `1.0`, i.e. a full
    /// Newton step; use `with_gamma` to change it.
    pub fn new(domain: (P, P), closed: bool) -> Self {
        Self {
            gamma: float!(1.0),
            knot_domain: domain,
            closed,
        }
    }

    /// Sets the step damping factor `gamma`, returning the updated solver.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidParameter` error unless `0 < gamma <= 1`. `NaN` is rejected too.
    // Optional builder setting that may have no caller inside this crate.
    #[allow(unused)]
    pub fn with_gamma(mut self, gamma: F) -> Result<Self, Error> {
        // NaN compares false against everything, so it must be rejected explicitly;
        // otherwise it slips past both range checks and poisons every step.
        if gamma.is_nan() || gamma <= float!(0.0) || gamma > float!(1.0) {
            return Err(argmin_error!(
                InvalidParameter,
                "Newton: gamma must be in (0, 1]."
            ));
        }
        self.gamma = gamma;
        Ok(self)
    }
}
impl<O, F> Solver<O, IterState<F, F, (), F, (), F>> for CurveClosestParameterNewton<F, F>
where
O: Gradient<Param = F, Gradient = F> + Hessian<Param = F, Hessian = F>,
F: Clone + ArgminFloat,
{
fn name(&self) -> &str {
"Closest parameter newton method"
}
fn next_iter(
&mut self,
problem: &mut Problem<O>,
state: IterState<F, F, (), F, (), F>,
) -> Result<(IterState<F, F, (), F, (), F>, Option<KV>), Error> {
let param = state.get_param().ok_or_else(argmin_error_closure!(
NotInitialized,
concat!(
"`Newton` requires an initial parameter vector. ",
"Please provide an initial guess via `Executor`s `configure` method."
)
))?;
let grad = problem.gradient(param)?;
let hessian = problem.hessian(param)?;
let inv = F::one() / hessian;
let new_param = *param - self.gamma * inv * grad;
let new_param = if new_param < self.knot_domain.0 {
if self.closed {
self.knot_domain.1 - (new_param - self.knot_domain.0)
} else {
self.knot_domain.0
}
} else if new_param > self.knot_domain.1 {
if self.closed {
self.knot_domain.0 + (new_param - self.knot_domain.1)
} else {
self.knot_domain.1
}
} else {
new_param
};
Ok((state.param(new_param), None))
}
fn terminate(&mut self, state: &IterState<F, F, (), F, (), F>) -> TerminationStatus {
if state.iter > state.max_iters {
return TerminationStatus::Terminated(TerminationReason::MaxItersReached);
}
match (state.get_param(), state.get_prev_param()) {
(Some(current_param), Some(prev_param)) => {
let delta = (*current_param - *prev_param).abs();
if delta < F::epsilon() {
TerminationStatus::Terminated(TerminationReason::SolverConverged)
} else {
TerminationStatus::NotTerminated
}
}
_ => TerminationStatus::NotTerminated,
}
}
}