use crate::core::math::Scalar;
use crate::core::problem::EvalCounts;
use crate::core::state::{CountsMirror, State};
/// Solver state tracking the current iterate, evaluation counters (cost,
/// residual, Jacobian) and the best record seen so far.
///
/// `P` is the parameter type and `F` the floating-point type used for costs
/// (`f64` unless specified otherwise).
pub struct NllsState<P, F = f64> {
/// Current parameter value; handed out by `State::param`.
pub(crate) param: P,
/// Cost stored for the current iterate, or `None` while nothing has been
/// stored yet. `State::cost` panics when this is `None`.
pub(crate) cost: Option<F>,
/// Iteration counter; starts at 0 in `new` and is bumped by
/// `State::increment_iter`.
pub(crate) iter: u64,
/// Cost-evaluation count as last written by `CountsMirror::mirror`.
pub(crate) cost_evals: u64,
/// Residual-evaluation count as last written by `CountsMirror::mirror`.
pub(crate) residual_evals: u64,
/// Jacobian-evaluation count as last written by `CountsMirror::mirror`.
pub(crate) jacobian_evals: u64,
/// Parameter of the best record; `None` until `State::update_best` stores
/// one, and again after `State::reset_best`.
pub(crate) best_param: Option<P>,
/// Cost of the best record; `F::infinity()` while the record is empty.
pub(crate) best_cost: F,
/// Value of `iter` at the moment the best record was captured.
pub(crate) best_iter: u64,
/// Value of `cost_evals` at the moment the best record was captured.
pub(crate) best_cost_evals: u64,
}
impl<P, F: Scalar> NllsState<P, F> {
pub fn new(param: P) -> Self {
Self {
param,
cost: None,
iter: 0,
cost_evals: 0,
residual_evals: 0,
jacobian_evals: 0,
best_param: None,
best_cost: F::infinity(),
best_iter: 0,
best_cost_evals: 0,
}
}
pub fn residual_evals(&self) -> u64 {
self.residual_evals
}
pub fn jacobian_evals(&self) -> u64 {
self.jacobian_evals
}
}
/// `State` implementation: read access to the current iterate and the best
/// record, plus the bookkeeping that maintains the best record.
impl<P, F> State for NllsState<P, F>
where
    P: Clone,
    F: Scalar,
{
    type Param = P;
    type Float = F;

    /// Returns the current iteration counter.
    fn iter(&self) -> u64 {
        self.iter
    }

    /// Advances the iteration counter by one.
    fn increment_iter(&mut self) {
        self.iter += 1;
    }

    /// Returns the stored cost-evaluation count.
    fn cost_evals(&self) -> u64 {
        self.cost_evals
    }

    /// Borrows the current parameter.
    fn param(&self) -> &P {
        &self.param
    }

    /// Returns the stored cost of the current iterate.
    ///
    /// # Panics
    ///
    /// Panics if no cost has been stored yet.
    fn cost(&self) -> F {
        let stored: Option<F> = self.cost;
        stored.expect("NllsState::cost read before Solver::init populated it")
    }

    /// Borrows the parameter of the best record.
    ///
    /// # Panics
    ///
    /// Panics if the best record is empty.
    fn best_param(&self) -> &P {
        let stored: Option<&P> = self.best_param.as_ref();
        stored.expect("NllsState::best_param read before Solver::init populated it")
    }

    /// Returns the cost of the best record (`F::infinity()` while empty).
    fn best_cost(&self) -> F {
        self.best_cost
    }

    /// Returns the iteration at which the best record was captured.
    fn best_iter(&self) -> u64 {
        self.best_iter
    }

    /// Returns the cost-evaluation count captured with the best record.
    fn best_cost_evals(&self) -> u64 {
        self.best_cost_evals
    }

    /// Captures the current iterate as the best record when it qualifies.
    ///
    /// Does nothing while no cost is stored. Otherwise the current iterate is
    /// captured if the best record is empty, or if the current cost is
    /// strictly less than the recorded best cost.
    fn update_best(&mut self) {
        let current = match self.cost {
            Some(value) => value,
            None => return,
        };

        // An empty record always accepts the current iterate, whatever its
        // cost; the comparison is only consulted once a record exists.
        let should_record = self.best_param.is_none() || current < self.best_cost;
        if !should_record {
            return;
        }

        self.best_param = Some(self.param.clone());
        self.best_cost = current;
        self.best_iter = self.iter;
        self.best_cost_evals = self.cost_evals;
    }

    /// Empties the best record, restoring the values `new` starts with.
    fn reset_best(&mut self) {
        self.best_iter = 0;
        self.best_cost_evals = 0;
        self.best_cost = F::infinity();
        self.best_param = None;
    }
}
/// `CountsMirror` implementation: copies evaluation counts from an
/// `EvalCounts` value into the state's own counters.
impl<P, F> CountsMirror for NllsState<P, F>
where
    NllsState<P, F>: State,
{
    /// Overwrites the state's evaluation counters from `delta`.
    ///
    /// * `cost_evals` becomes `delta.cost_evals + delta.residual_evals`.
    /// * `residual_evals` becomes `delta.residual_evals`.
    /// * `jacobian_evals` becomes `delta.jacobian_evals + delta.hessian_evals`.
    ///
    /// NOTE(review): the argument is named `delta`, yet the counters are
    /// assigned rather than accumulated. This is only correct if callers pass
    /// running totals — confirm against the call site.
    ///
    /// NOTE(review): Hessian evaluations are folded into `jacobian_evals`;
    /// presumably intentional, but verify.
    fn mirror(&mut self, delta: &EvalCounts) {
        let residuals = delta.residual_evals;
        let costs = delta.cost_evals + residuals;
        let jacobians = delta.jacobian_evals + delta.hessian_evals;

        self.cost_evals = costs;
        self.residual_evals = residuals;
        self.jacobian_evals = jacobians;
    }
}