use web_time::{Duration, Instant};
use crate::core::constraint::BoxConstrained;
use crate::core::math::{ClampInPlace, NormInfinity, NormSquared, ScaledAdd};
use crate::core::state::{GradientState, SimplexState, State};
/// Why an optimisation run stopped.
///
/// Most variants are reported by the same-named termination criterion; the last
/// two are not produced by any criterion in this module.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TerminationReason {
    /// The iteration counter reached its limit.
    MaxIter,
    /// The number of cost evaluations reached its limit.
    MaxCostEvals,
    /// The number of gradient evaluations reached its limit.
    MaxGradientEvals,
    /// The gradient norm fell to or below an absolute tolerance.
    GradientTolerance,
    /// The gradient norm fell to a fraction of the first observed gradient norm.
    RelativeGradientTolerance,
    /// The projected-gradient step is within tolerance (infinity norm).
    ProjectedGradientTolerance,
    /// The parameter change between consecutive checks is within an absolute tolerance.
    ParamTolerance,
    /// The parameter change is within a tolerance relative to the current parameters.
    RelativeParamTolerance,
    /// The cost change between consecutive checks is within an absolute tolerance.
    CostTolerance,
    /// The cost change is within a tolerance relative to the previous cost.
    RelativeCostTolerance,
    /// All simplex vertices and their costs are within tolerance of the best vertex.
    SimplexTolerance,
    /// The time budget elapsed.
    MaxTime,
    /// Presumably reported by a solver itself when it detects convergence.
    SolverConverged,
    /// Presumably reported by a solver itself when it cannot continue.
    SolverFailed,
}

impl TerminationReason {
    /// Returns `true` only for [`TerminationReason::SolverFailed`].
    ///
    /// Budget exhaustion (iterations, evaluations, time) and tolerance hits are
    /// ordinary stops, not failures. The match is exhaustive on purpose so that
    /// adding a variant forces a decision here.
    pub fn is_failure(&self) -> bool {
        match self {
            Self::SolverFailed => true,
            Self::MaxIter
            | Self::MaxCostEvals
            | Self::MaxGradientEvals
            | Self::GradientTolerance
            | Self::RelativeGradientTolerance
            | Self::ProjectedGradientTolerance
            | Self::ParamTolerance
            | Self::RelativeParamTolerance
            | Self::CostTolerance
            | Self::RelativeCostTolerance
            | Self::SimplexTolerance
            | Self::MaxTime
            | Self::SolverConverged => false,
        }
    }
}
/// A stopping rule evaluated against solver state `S`.
///
/// `check` takes `&mut self` so that implementations can keep history between
/// calls (for example the previously seen cost or parameter vector).
pub trait TerminationCriterion<S> {
/// Inspects `state` and returns `Some(reason)` if the run should stop, or
/// `None` to keep going.
fn check(&mut self, state: &S) -> Option<TerminationReason>;
}
/// Stops once the solver's iteration counter reaches the wrapped limit.
///
/// The comparison is `>=`, so `MaxIter(0)` fires on the very first check.
pub struct MaxIter(pub u64);

impl<S: State> TerminationCriterion<S> for MaxIter {
    fn check(&mut self, state: &S) -> Option<TerminationReason> {
        let limit = self.0;
        let exhausted = state.iter() >= limit;
        exhausted.then_some(TerminationReason::MaxIter)
    }
}
/// Stops once the number of cost-function evaluations reaches the wrapped limit.
pub struct MaxCostEvals(pub u64);

impl<S: State> TerminationCriterion<S> for MaxCostEvals {
    fn check(&mut self, state: &S) -> Option<TerminationReason> {
        let budget = self.0;
        // Still under budget: keep going.
        if state.cost_evals() < budget {
            return None;
        }
        Some(TerminationReason::MaxCostEvals)
    }
}
/// Stops once the number of gradient evaluations reaches the wrapped limit.
pub struct MaxGradientEvals(pub u64);

impl<S: GradientState> TerminationCriterion<S> for MaxGradientEvals {
    fn check(&mut self, state: &S) -> Option<TerminationReason> {
        match state.gradient_evals() {
            evals if evals >= self.0 => Some(TerminationReason::MaxGradientEvals),
            _ => None,
        }
    }
}
/// Stops when the current gradient's norm is at most the wrapped absolute tolerance.
///
/// Squared quantities are compared (`norm_squared(g) <= tol * tol`), which
/// avoids a square root and also makes the sign of the tolerance irrelevant.
/// If the state has no gradient, the criterion does not fire.
pub struct GradientTolerance(pub f64);

impl<S> TerminationCriterion<S> for GradientTolerance
where
    S: GradientState,
    S::Param: NormSquared,
{
    fn check(&mut self, state: &S) -> Option<TerminationReason> {
        let threshold = self.0 * self.0;
        state
            .gradient()
            .filter(|grad| grad.norm_squared() <= threshold)
            .map(|_| TerminationReason::GradientTolerance)
    }
}
/// Stops when the gradient norm has dropped to `tol` times the norm of the
/// first gradient this criterion observed.
///
/// The reference norm is captured lazily on the first call where the state
/// actually has a gradient. Calls without a gradient never fire and do not
/// set the reference.
pub struct RelativeGradientTolerance {
    tol: f64,
    /// Squared norm of the first gradient seen; `None` until one is observed.
    initial_norm_squared: Option<f64>,
}

impl RelativeGradientTolerance {
    /// Creates the criterion with relative tolerance `tol`.
    pub fn new(tol: f64) -> Self {
        Self {
            initial_norm_squared: None,
            tol,
        }
    }
}

impl<S> TerminationCriterion<S> for RelativeGradientTolerance
where
    S: GradientState,
    S::Param: NormSquared,
{
    fn check(&mut self, state: &S) -> Option<TerminationReason> {
        let current = state.gradient()?.norm_squared();
        let reference = match self.initial_norm_squared {
            Some(first) => first,
            None => *self.initial_norm_squared.insert(current),
        };
        let limit = self.tol * self.tol * reference;
        (current <= limit).then_some(TerminationReason::RelativeGradientTolerance)
    }
}
/// Stationarity test for box-constrained problems.
///
/// Computes the projected-gradient step `clamp(x - g, lower, upper) - x`.
/// The criterion fires when that step's infinity norm is at most `tol`. At a
/// bound where the gradient pushes outward, the projection cancels the step,
/// so such points count as stationary. If the state has no gradient, the
/// criterion does not fire.
pub struct ProjectedGradientTolerance<P> {
    lower: P,
    upper: P,
    tol: f64,
}

impl<P> ProjectedGradientTolerance<P> {
    /// Creates the criterion from explicit lower/upper bounds and a tolerance.
    pub fn new(lower: P, upper: P, tol: f64) -> Self {
        Self { lower, upper, tol }
    }

    /// Creates the criterion by cloning the bounds of a box-constrained problem.
    pub fn from_problem<Pr>(problem: &Pr, tol: f64) -> Self
    where
        Pr: BoxConstrained<Param = P>,
        P: Clone,
    {
        Self::new(problem.lower().clone(), problem.upper().clone(), tol)
    }
}

impl<S, P> TerminationCriterion<S> for ProjectedGradientTolerance<P>
where
    S: GradientState + State<Param = P>,
    P: ScaledAdd<f64> + ClampInPlace + NormInfinity + Clone,
{
    fn check(&mut self, state: &S) -> Option<TerminationReason> {
        let gradient = state.gradient()?;
        let x = state.param();

        // Take a unit gradient step, project it back into the box, then
        // subtract the starting point to get the projected step.
        let mut step = x.clone();
        step.scaled_add(-1.0, gradient);
        step.clamp_in_place(&self.lower, &self.upper);
        step.scaled_add(-1.0, x);

        let stationary = step.norm_infinity() <= self.tol;
        stationary.then_some(TerminationReason::ProjectedGradientTolerance)
    }
}
/// Stops when the parameters move by at most `tol` between consecutive checks.
///
/// The squared norm of `x_now - x_prev` is compared against `tol * tol`. The
/// squared tolerance is precomputed in `new`. The first call only records the
/// parameters and never fires.
pub struct ParamTolerance<P> {
    tol_squared: f64,
    /// Parameters seen at the previous check.
    last: Option<P>,
}

impl<P> ParamTolerance<P> {
    /// Creates the criterion with absolute tolerance `tol`.
    pub fn new(tol: f64) -> Self {
        let tol_squared = tol * tol;
        Self {
            last: None,
            tol_squared,
        }
    }
}

impl<S, P> TerminationCriterion<S> for ParamTolerance<P>
where
    S: State<Param = P>,
    P: ScaledAdd<f64> + NormSquared + Clone,
{
    fn check(&mut self, state: &S) -> Option<TerminationReason> {
        let current = state.param();
        // Store the new parameters and get back whatever was stored before.
        let previous = self.last.replace(current.clone());
        let converged = previous.is_some_and(|prev| {
            let mut step = current.clone();
            step.scaled_add(-1.0, &prev);
            step.norm_squared() <= self.tol_squared
        });
        converged.then_some(TerminationReason::ParamTolerance)
    }
}
/// Stops when the parameter change between consecutive checks is small
/// relative to the current parameters.
///
/// It fires when `norm_squared(x_now - x_prev) <= tol^2 * norm_squared(x_now)`.
/// When the current parameters are all zero, only an exactly zero step
/// qualifies. The first call only records the parameters and never fires.
pub struct RelativeParamTolerance<P> {
    tol: f64,
    /// Parameters seen at the previous check.
    last: Option<P>,
}

impl<P> RelativeParamTolerance<P> {
    /// Creates the criterion with relative tolerance `tol`.
    pub fn new(tol: f64) -> Self {
        Self { last: None, tol }
    }
}

impl<S, P> TerminationCriterion<S> for RelativeParamTolerance<P>
where
    S: State<Param = P>,
    P: ScaledAdd<f64> + NormSquared + Clone,
{
    fn check(&mut self, state: &S) -> Option<TerminationReason> {
        let current = state.param();
        match self.last.replace(current.clone()) {
            None => None,
            Some(prev) => {
                let mut step = current.clone();
                step.scaled_add(-1.0, &prev);
                let step_sq = step.norm_squared();
                let scale = self.tol * self.tol * current.norm_squared();
                (step_sq <= scale).then_some(TerminationReason::RelativeParamTolerance)
            }
        }
    }
}
/// Stops when the cost changes by at most `tol` (absolute) between consecutive
/// checks.
///
/// The current cost must be finite. The first call only records the cost and
/// never fires.
pub struct CostTolerance {
    tol: f64,
    /// Cost seen at the previous check.
    last: Option<f64>,
}

impl CostTolerance {
    /// Creates the criterion with absolute tolerance `tol`.
    pub fn new(tol: f64) -> Self {
        Self { last: None, tol }
    }
}

impl<S> TerminationCriterion<S> for CostTolerance
where
    S: State<Float = f64>,
{
    fn check(&mut self, state: &S) -> Option<TerminationReason> {
        let current = state.cost();
        match self.last.replace(current) {
            Some(prev) if current.is_finite() && (prev - current).abs() <= self.tol => {
                Some(TerminationReason::CostTolerance)
            }
            _ => None,
        }
    }
}
/// Stops when the cost changes by at most `tol` relative to the previous cost:
/// `|f_prev - f_curr| <= tol * |f_prev|`.
///
/// The first call only records the cost and never fires. Both the previous
/// and the current cost must be finite. If the previous cost is `±inf`, the
/// right-hand side is also infinite, and `inf <= inf` would report convergence
/// on the first finite cost that follows.
pub struct RelativeCostTolerance {
    tol: f64,
    /// Cost seen at the previous check.
    last: Option<f64>,
}

impl RelativeCostTolerance {
    /// Creates the criterion with relative tolerance `tol`.
    pub fn new(tol: f64) -> Self {
        Self { tol, last: None }
    }
}

impl<S> TerminationCriterion<S> for RelativeCostTolerance
where
    S: State<Float = f64>,
{
    fn check(&mut self, state: &S) -> Option<TerminationReason> {
        let curr = state.cost();
        let triggered = self.last.is_some_and(|prev| {
            prev.is_finite() && curr.is_finite() && (prev - curr).abs() <= self.tol * prev.abs()
        });
        // Always remember the latest cost so the next check compares consecutive values.
        self.last = Some(curr);
        triggered.then_some(TerminationReason::RelativeCostTolerance)
    }
}
/// Convergence test for a simplex (Nelder–Mead style).
///
/// Fires when every vertex is within `tol_x` of the first vertex in the
/// infinity norm, and every vertex cost is within `tol_f` of the first cost.
/// The criterion treats the first vertex as the best one, so it assumes the
/// state keeps vertices ordered by cost (TODO confirm against the solver).
///
/// NaN distances or costs count as not converged. An empty simplex (for
/// example before initialisation) never fires.
pub struct SimplexTolerance {
    tol_x: f64,
    tol_f: f64,
}

impl SimplexTolerance {
    /// `tol_x`: maximum infinity-norm distance of any vertex from the best vertex.
    /// `tol_f`: maximum absolute cost difference from the best vertex's cost.
    pub fn new(tol_x: f64, tol_f: f64) -> Self {
        Self { tol_x, tol_f }
    }
}

impl<S> TerminationCriterion<S> for SimplexTolerance
where
    S: SimplexState<Float = f64>,
    S::Param: Clone + ScaledAdd<f64> + NormInfinity,
{
    fn check(&mut self, state: &S) -> Option<TerminationReason> {
        let vertices = state.vertices();
        let costs = state.costs();
        // Bail out instead of panicking when there is no vertex / cost yet.
        let (best, other_vertices) = vertices.split_first()?;
        let (&best_cost, other_costs) = costs.split_first()?;
        // Test for closeness with `<= tol`, not reject with `> tol`: every
        // comparison involving NaN is false, so NaN now blocks convergence
        // instead of silently passing.
        let converged = other_vertices.iter().all(|x_i| {
            let mut diff = x_i.clone();
            diff.scaled_add(-1.0, best);
            diff.norm_infinity() <= self.tol_x
        }) && other_costs
            .iter()
            .all(|&f_i| (f_i - best_cost).abs() <= self.tol_f);
        converged.then_some(TerminationReason::SimplexTolerance)
    }
}
/// Stops once `limit` has elapsed.
///
/// The clock starts at the first call to `check`, not at construction. The
/// criterion ignores the solver state entirely.
pub struct MaxTime {
    limit: Duration,
    /// Instant of the first `check`; `None` until then.
    start: Option<Instant>,
}

impl MaxTime {
    /// Creates a time budget of `limit`; timing begins on the first `check`.
    pub fn new(limit: Duration) -> Self {
        Self { start: None, limit }
    }
}

impl<S> TerminationCriterion<S> for MaxTime {
    fn check(&mut self, _state: &S) -> Option<TerminationReason> {
        let started_at = match self.start {
            Some(instant) => instant,
            None => {
                let now = Instant::now();
                self.start = Some(now);
                now
            }
        };
        if started_at.elapsed() < self.limit {
            None
        } else {
            Some(TerminationReason::MaxTime)
        }
    }
}