use crate::core::math::Scalar;
use crate::core::problem::EvalCounts;
use crate::core::state::{CountsMirror, GradientState, State};
/// Optimizer state for a solver that works on a single scalar parameter and a gradient.
///
/// Holds the current point (`param`, `cost`, `gradient`), iteration and evaluation
/// counters, and a snapshot of the best point recorded so far. `F` defaults to `f64`.
pub struct ScalarGradientState<F = f64> {
    /// Current parameter value.
    pub(crate) param: F,
    /// Most recently stored cost; `None` until populated (reading it through
    /// `State::cost` while it is `None` panics).
    pub(crate) cost: Option<F>,
    /// Most recently stored gradient; `None` until one is stored.
    pub(crate) gradient: Option<F>,
    /// Iteration counter, advanced by `State::increment_iter`.
    pub(crate) iter: u64,
    /// Cost-like evaluations; `CountsMirror::mirror` overwrites it with
    /// cost + residual evaluations.
    pub(crate) cost_evals: u64,
    /// Derivative-like evaluations; `CountsMirror::mirror` overwrites it with
    /// gradient + Jacobian + Hessian evaluations.
    pub(crate) gradient_evals: u64,
    /// Best parameter recorded by `State::update_best`; `None` until the first record.
    pub(crate) best_param: Option<F>,
    /// Cost of `best_param`; `F::infinity()` until a best is recorded.
    pub(crate) best_cost: F,
    /// Value of `iter` when the current best was recorded.
    pub(crate) best_iter: u64,
    /// Value of `cost_evals` when the current best was recorded.
    pub(crate) best_cost_evals: u64,
    /// Value of `gradient_evals` when the current best was recorded.
    pub(crate) best_gradient_evals: u64,
}
impl<F: Scalar> ScalarGradientState<F> {
    /// Creates a fresh state positioned at `param`.
    ///
    /// Nothing has been evaluated yet: `cost` and `gradient` are `None` and every
    /// counter is zero. No best point exists either, so `best_param` is `None` and
    /// `best_cost` starts at `F::infinity()`.
    pub fn new(param: F) -> Self {
        Self {
            // Current point.
            param,
            cost: None,
            gradient: None,
            // Progress counters.
            iter: 0,
            cost_evals: 0,
            gradient_evals: 0,
            // Best-so-far snapshot (empty).
            best_cost: F::infinity(),
            best_param: None,
            best_iter: 0,
            best_gradient_evals: 0,
            best_cost_evals: 0,
        }
    }
}
/// Core solver bookkeeping for [`ScalarGradientState`]: the current point,
/// counters, and the best-so-far snapshot.
impl<F: Scalar> State for ScalarGradientState<F> {
    type Param = F;
    type Float = F;

    /// Current iteration counter.
    fn iter(&self) -> u64 {
        self.iter
    }

    /// Advances the iteration counter by one.
    fn increment_iter(&mut self) {
        self.iter += 1;
    }

    /// Cost-like evaluation count currently stored in the state.
    fn cost_evals(&self) -> u64 {
        self.cost_evals
    }

    /// Current parameter value.
    fn param(&self) -> &F {
        &self.param
    }

    /// Most recently stored cost.
    ///
    /// # Panics
    ///
    /// Panics if no cost has been stored yet.
    fn cost(&self) -> F {
        self.cost
            .expect("ScalarGradientState::cost read before Solver::init populated it")
    }

    /// Best parameter recorded by `update_best`.
    ///
    /// # Panics
    ///
    /// Panics if no best point has been recorded yet.
    fn best_param(&self) -> &F {
        self.best_param
            .as_ref()
            .expect("ScalarGradientState::best_param read before Solver::init populated it")
    }

    /// Cost of the best recorded point; `F::infinity()` until one is recorded.
    fn best_cost(&self) -> F {
        self.best_cost
    }

    /// Iteration at which the current best was recorded.
    fn best_iter(&self) -> u64 {
        self.best_iter
    }

    /// Cost-evaluation count at the time the current best was recorded.
    fn best_cost_evals(&self) -> u64 {
        self.best_cost_evals
    }

    /// Records the current point as the best if it improves on the stored best.
    ///
    /// Does nothing while `cost` is `None`. The current point is accepted when no
    /// best exists yet, when its cost is strictly lower than the stored best cost, or
    /// when the stored best cost is NaN and the current one is not.
    fn update_best(&mut self) {
        // A value that is unordered with itself is NaN (for IEEE floats); this
        // needs only `PartialOrd`, which the `<` comparison below already requires.
        fn is_unordered<T: PartialOrd>(x: &T) -> bool {
            x.partial_cmp(x).is_none()
        }

        if let Some(curr) = self.cost {
            // Every comparison against NaN is false, so without the last clause a NaN
            // first cost would become a permanent "best" that no later finite cost
            // could ever replace.
            let improves = self.best_param.is_none()
                || curr < self.best_cost
                || (is_unordered(&self.best_cost) && !is_unordered(&curr));
            if improves {
                self.best_param = Some(self.param);
                self.best_cost = curr;
                self.best_iter = self.iter;
                self.best_cost_evals = self.cost_evals;
                self.best_gradient_evals = self.gradient_evals;
            }
        }
    }

    /// Forgets the recorded best, restoring the same values `new` starts with.
    fn reset_best(&mut self) {
        self.best_param = None;
        self.best_cost = F::infinity();
        self.best_iter = 0;
        self.best_cost_evals = 0;
        self.best_gradient_evals = 0;
    }
}
/// Gradient-specific accessors for [`ScalarGradientState`].
impl<F: Scalar> GradientState for ScalarGradientState<F> {
    /// Derivative-evaluation count captured when the current best was recorded.
    fn best_gradient_evals(&self) -> u64 {
        self.best_gradient_evals
    }

    /// Most recently stored gradient, or `None` if none has been stored.
    fn gradient(&self) -> Option<&F> {
        Option::as_ref(&self.gradient)
    }

    /// Derivative-like evaluation count currently stored in the state.
    fn gradient_evals(&self) -> u64 {
        self.gradient_evals
    }
}
/// Copies problem-level evaluation counts into the state's two counters.
impl<F> CountsMirror for ScalarGradientState<F>
where
    ScalarGradientState<F>: State,
{
    /// Overwrites (does not add to) `cost_evals` and `gradient_evals` from `delta`.
    ///
    /// Residual evaluations are folded into `cost_evals`. Jacobian and Hessian
    /// evaluations are folded into `gradient_evals`.
    ///
    /// NOTE(review): the parameter is named `delta`, but the counters are assigned
    /// rather than accumulated. That is only correct if callers pass counts that
    /// cover the whole run so far. Confirm against the caller.
    fn mirror(&mut self, delta: &EvalCounts) {
        let cost_like = delta.cost_evals + delta.residual_evals;
        let derivative_like = delta.gradient_evals + delta.jacobian_evals + delta.hessian_evals;

        self.cost_evals = cost_like;
        self.gradient_evals = derivative_like;
    }
}