use web_time::{Duration, Instant};
use crate::core::constraint::BoxConstraints;
use crate::core::math::{ClampInPlace, NormInfinity, NormSquared, Scalar, ScaledAdd, VectorLen};
use crate::core::state::{CmaEsState, GradientState, SimplexState, State};
/// Why an optimization run stopped.
///
/// Every variant except `SolverConverged` and `SolverFailed` is produced by the
/// identically named criterion in this module. The enum is `#[non_exhaustive]`,
/// so matches outside this crate need a wildcard arm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum TerminationReason {
/// The iteration count reached the `MaxIter` limit.
MaxIter,
/// The number of cost evaluations reached the `MaxCostEvals` limit.
MaxCostEvals,
/// The number of gradient evaluations reached the `MaxGradientEvals` limit.
MaxGradientEvals,
/// The gradient norm fell to or below an absolute tolerance.
GradientTolerance,
/// The gradient norm fell to or below a fraction of the reference gradient norm.
RelativeGradientTolerance,
/// The projected-gradient ∞-norm fell to or below the tolerance (box constraints).
ProjectedGradientTolerance,
/// Two consecutive parameter vectors were within an absolute distance.
ParamTolerance,
/// Two consecutive parameter vectors were within a distance relative to the current one.
RelativeParamTolerance,
/// Two consecutive costs differed by at most an absolute tolerance.
CostTolerance,
/// Two consecutive costs differed by at most a tolerance relative to the previous cost.
RelativeCostTolerance,
/// The best cost reached the target value.
TargetCost,
/// The best cost failed to improve for `patience` consecutive checks.
NoImprovement,
/// All simplex vertices (and their costs) collapsed to within tolerance.
SimplexTolerance,
/// The CMA-ES step scale times the widest axis fell below the tolerance.
CmaEsTolerance,
/// The wall-clock budget was exhausted.
MaxTime,
/// Not produced by any criterion in this module; presumably reported by a
/// solver's own convergence test — confirm at the call site.
SolverConverged,
/// Not produced by any criterion in this module; the only variant for which
/// `is_failure` returns `true`.
SolverFailed,
}
impl TerminationReason {
    /// Reports whether this reason represents a failed run.
    ///
    /// Only [`TerminationReason::SolverFailed`] counts as a failure; budget
    /// exhaustion and every tolerance-based stop are treated as normal exits.
    pub fn is_failure(&self) -> bool {
        *self == Self::SolverFailed
    }
}
/// A stopping rule evaluated against solver state `S`.
///
/// `check` takes `&mut self` because several criteria (e.g. `ParamTolerance`,
/// `CostTolerance`, `NoImprovement`, `MaxTime`) remember values from earlier
/// checks within the same run.
pub trait TerminationCriterion<S> {
/// Inspects `state` and returns `Some(reason)` if the run should stop, or
/// `None` to keep going.
fn check(&mut self, state: &S) -> Option<TerminationReason>;
/// Clears any per-run memory so the criterion can be reused for a new run.
///
/// The default does nothing, which is correct for stateless criteria.
fn reset(&mut self) {}
}
/// Iteration budget: stops once `state.iter()` reaches the wrapped limit.
pub struct MaxIter(pub u64);
impl<S: State> TerminationCriterion<S> for MaxIter {
    /// Returns [`TerminationReason::MaxIter`] when the iteration count is at
    /// or beyond the limit.
    fn check(&mut self, state: &S) -> Option<TerminationReason> {
        let limit = self.0;
        let exhausted = state.iter() >= limit;
        exhausted.then_some(TerminationReason::MaxIter)
    }
}
/// Cost-evaluation budget: stops once `state.cost_evals()` reaches the wrapped limit.
pub struct MaxCostEvals(pub u64);
impl<S: State> TerminationCriterion<S> for MaxCostEvals {
    /// Returns [`TerminationReason::MaxCostEvals`] when the number of cost
    /// evaluations is at or beyond the limit.
    fn check(&mut self, state: &S) -> Option<TerminationReason> {
        match state.cost_evals() {
            used if used >= self.0 => Some(TerminationReason::MaxCostEvals),
            _ => None,
        }
    }
}
/// Gradient-evaluation budget: stops once `state.gradient_evals()` reaches the
/// wrapped limit.
pub struct MaxGradientEvals(pub u64);
impl<S: GradientState> TerminationCriterion<S> for MaxGradientEvals {
    /// Returns [`TerminationReason::MaxGradientEvals`] when the number of
    /// gradient evaluations is at or beyond the limit.
    fn check(&mut self, state: &S) -> Option<TerminationReason> {
        let budget = self.0;
        if state.gradient_evals() < budget {
            None
        } else {
            Some(TerminationReason::MaxGradientEvals)
        }
    }
}
/// Absolute gradient test: stops when the gradient norm is at most the wrapped
/// tolerance.
///
/// Squared quantities are compared (`norm_squared(g) <= tol²`) to avoid a
/// square root. Returns `None` while the state holds no gradient.
pub struct GradientTolerance<F = f64>(pub F);
impl<S, F> TerminationCriterion<S> for GradientTolerance<F>
where
    F: Scalar,
    S: GradientState,
    S::Param: NormSquared<F>,
{
    fn check(&mut self, state: &S) -> Option<TerminationReason> {
        let tol = self.0;
        state.gradient().and_then(|grad| {
            (grad.norm_squared() <= tol * tol).then_some(TerminationReason::GradientTolerance)
        })
    }
}
/// Relative gradient test: stops when `‖∇f‖ ≤ tol · ‖∇f₀‖`.
///
/// `∇f₀` is the first gradient with a *finite* squared norm observed since
/// construction or the last [`reset`](TerminationCriterion::reset). Squared
/// norms are compared (`‖∇f‖² ≤ tol² · ‖∇f₀‖²`). Returns `None` while the state
/// holds no gradient.
pub struct RelativeGradientTolerance<F = f64> {
    tol: F,
    /// Squared norm of the reference gradient; `None` until anchored.
    initial_norm_squared: Option<F>,
}
impl<F: Scalar> RelativeGradientTolerance<F> {
    /// Creates the criterion with relative tolerance `tol`. The reference
    /// gradient is captured lazily by the first eligible `check`.
    pub fn new(tol: F) -> Self {
        Self {
            tol,
            initial_norm_squared: None,
        }
    }
}
impl<S, F> TerminationCriterion<S> for RelativeGradientTolerance<F>
where
    F: Scalar,
    S: GradientState,
    S::Param: NormSquared<F>,
{
    fn check(&mut self, state: &S) -> Option<TerminationReason> {
        let g = state.gradient()?;
        let norm_squared = g.norm_squared();
        // Never anchor to (or compare against) a non-finite norm. An `∞` anchor
        // would make every later finite gradient look converged (`x ≤ tol²·∞`),
        // and a `NaN` anchor would silently disable the test for the whole run.
        // A non-finite gradient can never satisfy the test, so `None` is correct
        // here whether or not an anchor already exists.
        if !norm_squared.is_finite() {
            return None;
        }
        let initial = *self.initial_norm_squared.get_or_insert(norm_squared);
        (norm_squared <= self.tol * self.tol * initial)
            .then_some(TerminationReason::RelativeGradientTolerance)
    }
    fn reset(&mut self) {
        self.initial_norm_squared = None;
    }
}
/// Projected-gradient test for box-constrained problems.
///
/// Computes `step = clamp(x − ∇f, lower, upper) − x` (via `scaled_add` and
/// `clamp_in_place`) and stops when `norm_infinity(step) <= tol`. When
/// `x − ∇f` already lies inside the box this is just `‖∇f‖_∞ ≤ tol`. Returns
/// `None` while the state holds no gradient.
pub struct ProjectedGradientTolerance<P, F = f64> {
    lower: P,
    upper: P,
    tol: F,
}
impl<P, F> ProjectedGradientTolerance<P, F> {
    /// Builds the criterion from explicit lower/upper bounds.
    pub fn new(lower: P, upper: P, tol: F) -> Self {
        Self { lower, upper, tol }
    }
    /// Builds the criterion by cloning the bounds exposed by `problem`.
    pub fn from_problem<Pr>(problem: &Pr, tol: F) -> Self
    where
        Pr: BoxConstraints<Param = P>,
        P: Clone,
    {
        let lower = problem.lower().clone();
        let upper = problem.upper().clone();
        Self::new(lower, upper, tol)
    }
}
impl<S, P, F> TerminationCriterion<S> for ProjectedGradientTolerance<P, F>
where
    F: Scalar,
    S: GradientState + State<Param = P>,
    P: ScaledAdd<F> + ClampInPlace + NormInfinity<F> + Clone,
{
    fn check(&mut self, state: &S) -> Option<TerminationReason> {
        let grad = state.gradient()?;
        let x = state.param();
        // step ← x − ∇f, projected onto the box, then shifted back by −x.
        let mut step = x.clone();
        step.scaled_add(-F::one(), grad);
        step.clamp_in_place(&self.lower, &self.upper);
        step.scaled_add(-F::one(), x);
        (step.norm_infinity() <= self.tol).then_some(TerminationReason::ProjectedGradientTolerance)
    }
}
/// Absolute step test: stops when the parameters seen at two consecutive
/// checks satisfy `norm_squared(x_k − x_{k−1}) <= tol²`.
///
/// The first check after construction or `reset` only records `x` and never
/// triggers.
pub struct ParamTolerance<P, F = f64> {
    /// `tol²`, precomputed so `check` can compare squared norms.
    tol_squared: F,
    /// Parameters recorded at the previous check.
    last: Option<P>,
}
impl<P, F: Scalar> ParamTolerance<P, F> {
    /// Creates the criterion with absolute tolerance `tol`.
    pub fn new(tol: F) -> Self {
        let tol_squared = tol * tol;
        Self {
            tol_squared,
            last: None,
        }
    }
}
impl<S, P, F> TerminationCriterion<S> for ParamTolerance<P, F>
where
    F: Scalar,
    S: State<Param = P>,
    P: ScaledAdd<F> + NormSquared<F> + Clone,
{
    fn check(&mut self, state: &S) -> Option<TerminationReason> {
        let x = state.param();
        let previous = self.last.replace(x.clone());
        let converged = previous.is_some_and(|prev| {
            let mut step = x.clone();
            step.scaled_add(-F::one(), &prev);
            step.norm_squared() <= self.tol_squared
        });
        converged.then_some(TerminationReason::ParamTolerance)
    }
    fn reset(&mut self) {
        self.last = None;
    }
}
/// Relative step test: stops when two consecutive parameter vectors satisfy
/// `norm_squared(x_k − x_{k−1}) <= tol² · norm_squared(x_k)`.
///
/// The first check after construction or `reset` only records `x` and never
/// triggers.
pub struct RelativeParamTolerance<P, F = f64> {
    tol: F,
    /// Parameters recorded at the previous check.
    last: Option<P>,
}
impl<P, F> RelativeParamTolerance<P, F> {
    /// Creates the criterion with relative tolerance `tol`.
    pub fn new(tol: F) -> Self {
        Self { tol, last: None }
    }
}
impl<S, P, F> TerminationCriterion<S> for RelativeParamTolerance<P, F>
where
    F: Scalar,
    S: State<Param = P>,
    P: ScaledAdd<F> + NormSquared<F> + Clone,
{
    fn check(&mut self, state: &S) -> Option<TerminationReason> {
        let x = state.param();
        // Record the current point first; with no earlier point there is
        // nothing to compare against yet.
        let prev = self.last.replace(x.clone())?;
        let mut step = x.clone();
        step.scaled_add(-F::one(), &prev);
        let threshold = self.tol * self.tol * x.norm_squared();
        (step.norm_squared() <= threshold).then_some(TerminationReason::RelativeParamTolerance)
    }
    fn reset(&mut self) {
        self.last = None;
    }
}
/// Absolute cost test: stops when the costs at two consecutive checks differ
/// by at most `tol` and the current cost is finite.
///
/// The first check after construction or `reset` only records the cost.
pub struct CostTolerance<F = f64> {
    tol: F,
    /// Cost recorded at the previous check.
    last: Option<F>,
}
impl<F> CostTolerance<F> {
    /// Creates the criterion with absolute tolerance `tol`.
    pub fn new(tol: F) -> Self {
        Self { tol, last: None }
    }
}
impl<S, F> TerminationCriterion<S> for CostTolerance<F>
where
    F: Scalar,
    S: State<Float = F>,
{
    fn check(&mut self, state: &S) -> Option<TerminationReason> {
        let cost = state.cost();
        match self.last.replace(cost) {
            Some(prev) if cost.is_finite() && (prev - cost).abs() <= self.tol => {
                Some(TerminationReason::CostTolerance)
            }
            _ => None,
        }
    }
    fn reset(&mut self) {
        self.last = None;
    }
}
/// Relative cost test: stops when the costs at two consecutive checks satisfy
/// `|f_{k−1} − f_k| ≤ tol · |f_{k−1}|` and both costs are finite.
///
/// The first check after construction or `reset` only records the cost.
pub struct RelativeCostTolerance<F = f64> {
    tol: F,
    /// Cost recorded at the previous check.
    last: Option<F>,
}
impl<F> RelativeCostTolerance<F> {
    /// Creates the criterion with relative tolerance `tol`.
    pub fn new(tol: F) -> Self {
        Self { tol, last: None }
    }
}
impl<S, F> TerminationCriterion<S> for RelativeCostTolerance<F>
where
    F: Scalar,
    S: State<Float = F>,
{
    fn check(&mut self, state: &S) -> Option<TerminationReason> {
        let curr = state.cost();
        // The previous cost must be finite too: with `l = ±∞` the test becomes
        // `∞ ≤ tol·∞`, which is true, so the first finite cost after an
        // infinite one would be reported as convergence.
        let triggered = self.last.is_some_and(|l| {
            l.is_finite() && curr.is_finite() && (l - curr).abs() <= self.tol * l.abs()
        });
        self.last = Some(curr);
        triggered.then_some(TerminationReason::RelativeCostTolerance)
    }
    fn reset(&mut self) {
        self.last = None;
    }
}
/// Target test: stops as soon as `state.best_cost()` is at or below the
/// wrapped target value.
pub struct TargetCost<F = f64>(pub F);
impl<S, F> TerminationCriterion<S> for TargetCost<F>
where
    F: Scalar,
    S: State<Float = F>,
{
    fn check(&mut self, state: &S) -> Option<TerminationReason> {
        let target = self.0;
        if state.best_cost() <= target {
            Some(TerminationReason::TargetCost)
        } else {
            None
        }
    }
}
/// Stagnation test: stops after `patience` consecutive checks in which
/// `best_cost()` did not improve.
///
/// An improvement is a finite best cost strictly below `anchor − tol`, where
/// the anchor is the best cost at the last improvement. The first finite best
/// cost becomes the initial anchor. Non-finite costs never count as an
/// improvement. Once triggered, the criterion keeps firing until an
/// improvement or `reset`.
pub struct NoImprovement<F = f64> {
    patience: u64,
    tol: F,
    /// Best cost at the most recent improvement.
    anchor: Option<F>,
    /// Consecutive non-improving checks since the last improvement.
    stalled: u64,
}
impl<F> NoImprovement<F> {
    /// Creates the criterion with the given `patience` and improvement margin `tol`.
    pub fn new(patience: u64, tol: F) -> Self {
        Self {
            patience,
            tol,
            anchor: None,
            stalled: 0,
        }
    }
}
impl<S, F> TerminationCriterion<S> for NoImprovement<F>
where
    F: Scalar,
    S: State<Float = F>,
{
    fn check(&mut self, state: &S) -> Option<TerminationReason> {
        let best = state.best_cost();
        let beats_anchor = match self.anchor {
            Some(anchor) => best < anchor - self.tol,
            None => true,
        };
        if best.is_finite() && beats_anchor {
            self.anchor = Some(best);
            self.stalled = 0;
            return None;
        }
        self.stalled += 1;
        if self.stalled >= self.patience {
            Some(TerminationReason::NoImprovement)
        } else {
            None
        }
    }
    fn reset(&mut self) {
        self.anchor = None;
        self.stalled = 0;
    }
}
/// Simplex-collapse test: stops when both of the following hold.
///
/// - Every vertex lies within `tol_x` (∞-norm) of vertex 0.
/// - Every vertex cost lies within `tol_f` of the cost of vertex 0.
///
/// Vertex 0 is the reference point. This presumably relies on the solver
/// keeping the best vertex first (TODO: confirm against the simplex solvers).
/// An empty simplex, or an empty cost list, is reported as not converged.
pub struct SimplexTolerance<F = f64> {
    tol_x: F,
    tol_f: F,
}
impl<F> SimplexTolerance<F> {
    /// Creates the criterion with parameter-spread tolerance `tol_x` and
    /// cost-spread tolerance `tol_f`.
    pub fn new(tol_x: F, tol_f: F) -> Self {
        Self { tol_x, tol_f }
    }
}
impl<S, F> TerminationCriterion<S> for SimplexTolerance<F>
where
    F: Scalar,
    S: SimplexState<Float = F>,
    S::Param: Clone + ScaledAdd<F> + NormInfinity<F>,
{
    fn check(&mut self, state: &S) -> Option<TerminationReason> {
        let vertices = state.vertices();
        let costs = state.costs();
        // With no vertices or no costs there is nothing to compare; report
        // "not converged" rather than panicking on `[0]`.
        let (best, others) = vertices.split_first()?;
        let (&best_cost, other_costs) = costs.split_first()?;
        // Comparisons use `<= tol` so that a NaN distance or cost counts as
        // "not converged". A `> tol` rejection test lets NaN slip through.
        let vertices_close = others.iter().all(|x_i| {
            let mut diff = x_i.clone();
            diff.scaled_add(-F::one(), best);
            diff.norm_infinity() <= self.tol_x
        });
        let costs_close = vertices_close
            && other_costs
                .iter()
                .all(|&f_i| (f_i - best_cost).abs() <= self.tol_f);
        costs_close.then_some(TerminationReason::SimplexTolerance)
    }
}
/// CMA-ES spread test: stops when `state.sigma() * state.max_axis_std()` is
/// strictly below `tol_x`.
///
/// Presumably this is the global step size times the largest per-axis
/// standard deviation of the search distribution — confirm against `CmaEsState`.
pub struct CmaEsTolerance<F = f64> {
    tol_x: F,
}
impl<F> CmaEsTolerance<F> {
    /// Creates the criterion with spread tolerance `tol_x`.
    pub fn new(tol_x: F) -> Self {
        Self { tol_x }
    }
}
impl<V, M, F> TerminationCriterion<CmaEsState<V, M, F>> for CmaEsTolerance<F>
where
    F: Scalar,
    V: VectorLen + std::ops::Index<usize, Output = F>,
{
    fn check(&mut self, state: &CmaEsState<V, M, F>) -> Option<TerminationReason> {
        let step_scale = state.sigma();
        let widest_axis = state.max_axis_std();
        if step_scale * widest_axis < self.tol_x {
            Some(TerminationReason::CmaEsTolerance)
        } else {
            None
        }
    }
}
/// Wall-clock budget: stops once `limit` has elapsed since the first `check`
/// of the current run.
///
/// The clock starts at the first `check`, not at construction, and restarts
/// at the next `check` after `reset`.
pub struct MaxTime {
    limit: Duration,
    /// Instant of the first `check` in the current run.
    start: Option<Instant>,
}
impl MaxTime {
    /// Creates a criterion with the given wall-clock `limit`.
    pub fn new(limit: Duration) -> Self {
        Self { limit, start: None }
    }
}
impl<S> TerminationCriterion<S> for MaxTime {
    fn check(&mut self, _state: &S) -> Option<TerminationReason> {
        let started_at = if let Some(t) = self.start {
            t
        } else {
            let now = Instant::now();
            self.start = Some(now);
            now
        };
        if started_at.elapsed() < self.limit {
            None
        } else {
            Some(TerminationReason::MaxTime)
        }
    }
    fn reset(&mut self) {
        self.start = None;
    }
}
/// Regression tests for `TerminationCriterion::reset`: a criterion reused for a
/// second run must not carry state over from the first run.
#[cfg(test)]
mod reset_tests {
use super::*;
use crate::core::state::BasicState;
type S = BasicState<Vec<f64>>;
#[test]
fn relative_gradient_tolerance_reanchors_after_reset() {
let mut c = RelativeGradientTolerance::new(0.1_f64);
let mut state: S = BasicState::new(vec![0.0]);
// First run anchors the reference squared norm at 10² = 100.
state.gradient = Some(vec![10.0]);
assert!(TerminationCriterion::<S>::check(&mut c, &state).is_none());
TerminationCriterion::<S>::reset(&mut c);
// Without the reset, 0.5² = 0.25 ≤ 0.1²·100 = 1 would trigger. After the
// reset the anchor becomes 0.25, and 0.25 ≤ 0.1²·0.25 is false.
state.gradient = Some(vec![0.5]);
assert!(
TerminationCriterion::<S>::check(&mut c, &state).is_none(),
"reset should re-anchor ‖∇f_0‖ to this run's initial gradient"
);
}
#[test]
fn no_improvement_clears_stall_counter_after_reset() {
let mut c = NoImprovement::new(2, 0.0_f64);
let mut state: S = BasicState::new(vec![0.0]);
state.best_cost = 10.0;
// Check 1 anchors at 10. Check 2 is a stall (stalled = 1 < patience 2).
assert!(TerminationCriterion::<S>::check(&mut c, &state).is_none());
assert!(TerminationCriterion::<S>::check(&mut c, &state).is_none());
TerminationCriterion::<S>::reset(&mut c);
// Without the reset this would be stall #2 and fire. After the reset the
// same cost re-anchors instead.
assert!(
TerminationCriterion::<S>::check(&mut c, &state).is_none(),
"reset should clear the anchor and stall counter"
);
}
#[test]
fn max_time_restarts_clock_after_reset() {
let mut c = MaxTime::new(Duration::from_millis(20));
// The first check starts the clock.
assert!(TerminationCriterion::<()>::check(&mut c, &()).is_none());
// Sleep for twice the limit so the next check is past it.
std::thread::sleep(Duration::from_millis(40));
assert_eq!(
TerminationCriterion::<()>::check(&mut c, &()),
Some(TerminationReason::MaxTime)
);
TerminationCriterion::<()>::reset(&mut c);
// After the reset the clock restarts at this check, so elapsed ≈ 0 < 20 ms.
assert!(
TerminationCriterion::<()>::check(&mut c, &()).is_none(),
"reset should restart the wall-clock from the next check"
);
}
}