use basin::{
Backtracking, BasicSimplexState, BasicState, CostFunction, CostTolerance, Executor, Gradient,
GradientDescent, GradientState, GradientTolerance, MaxCostEvals, MaxGradientEvals, MaxIter,
MaxTime, NelderMead, NoImprovement, ParamTolerance, Problem, RelativeCostTolerance,
RelativeGradientTolerance, RelativeParamTolerance, Solver, State, TargetCost,
TerminationCriterion, TerminationReason,
};
use std::time::Duration;
/// Test objective `f(x) = 0.5 * sum(x_i^2)` over a `Vec<f64>` parameter.
struct Quadratic;

impl CostFunction for Quadratic {
    type Param = Vec<f64>;
    type Output = f64;
    type Error = std::convert::Infallible;

    /// Returns half the sum of the squared components of `x`; this never fails.
    fn cost(&self, x: &Self::Param) -> Result<Self::Output, Self::Error> {
        let sum_of_squares: f64 = x.iter().map(|component| component * component).sum();
        Ok(0.5 * sum_of_squares)
    }
}
impl Gradient for Quadratic {
    type Gradient = Vec<f64>;

    /// Returns a copy of `x`, which is the gradient of `0.5 * sum(x_i^2)`; this never fails.
    fn gradient(&self, x: &Vec<f64>) -> Result<Self::Gradient, std::convert::Infallible> {
        let grad = x.to_vec();
        Ok(grad)
    }
}
/// Starting at the origin, `GradientTolerance` is expected to stop the run
/// before a single iteration has been performed.
#[test]
fn gradient_tolerance_fires_at_iter_zero_when_starting_at_optimum() {
    let origin = vec![0.0, 0.0];
    let criterion = GradientTolerance(1e-8);

    let outcome = Executor::new(Quadratic, GradientDescent::new(0.1), BasicState::new(origin))
        .terminate_on(criterion)
        .run()
        .unwrap();

    assert_eq!(outcome.reason, TerminationReason::GradientTolerance);
    assert_eq!(outcome.iter(), 0, "should not have done any iterations");
}
/// From a non-optimal start the run should end with `GradientTolerance`
/// strictly between iteration 0 and the 1 000-iteration cap, and the stored
/// gradient's Euclidean norm should be within the tolerance.
#[test]
fn gradient_tolerance_fires_after_convergence() {
    let start = vec![1.0, -1.0, 0.5];
    let solver = GradientDescent::new(0.5);

    let result = Executor::new(Quadratic, solver, BasicState::new(start))
        .max_iter(1_000)
        .terminate_on(GradientTolerance(1e-6))
        .run()
        .unwrap();

    assert_eq!(result.reason, TerminationReason::GradientTolerance);
    assert!(result.iter() > 0 && result.iter() < 1_000);

    // Recompute the norm from the gradient kept in the final state.
    let stored_gradient = result
        .state
        .gradient()
        .expect("gradient should be populated");
    let g_norm = stored_gradient.iter().map(|v| v * v).sum::<f64>().sqrt();
    assert!(g_norm <= 1e-6);
}
/// With a relative gradient tolerance of `1e-3` the run should stop for
/// `RelativeGradientTolerance` after at least one but fewer than 20 iterations.
#[test]
fn relative_gradient_tolerance_fires_after_convergence() {
    let start = vec![1.0, 1.0];
    let solver = GradientDescent::new(0.5);

    let outcome = Executor::new(Quadratic, solver, BasicState::new(start))
        .max_iter(1_000)
        .terminate_on(RelativeGradientTolerance::new(1e-3))
        .run()
        .unwrap();

    assert_eq!(outcome.reason, TerminationReason::RelativeGradientTolerance);

    let iters = outcome.iter();
    assert!(
        iters > 0 && iters < 20,
        "expected convergence near iter 10, got {}",
        iters
    );
}
/// Starting at the origin, `RelativeGradientTolerance` is expected to fire
/// at iteration 0.
#[test]
fn relative_gradient_tolerance_fires_at_iter_zero_when_starting_at_optimum() {
    let origin = vec![0.0, 0.0];
    let criterion = RelativeGradientTolerance::new(1e-6);

    let outcome = Executor::new(Quadratic, GradientDescent::new(0.1), BasicState::new(origin))
        .terminate_on(criterion)
        .run()
        .unwrap();

    assert_eq!(outcome.reason, TerminationReason::RelativeGradientTolerance);
    assert_eq!(outcome.iter(), 0);
}
/// Runs the same configuration from `[1, 1]` and from `[1e6, 1e6]` and expects
/// both runs to stop for `RelativeGradientTolerance` after the same number of
/// iterations.
#[test]
fn relative_gradient_tolerance_is_scale_invariant() {
    // Solves from the point `[start, start]` with an otherwise fixed setup.
    let solve_from = |start: f64| {
        let initial = BasicState::new(vec![start, start]);
        Executor::new(Quadratic, GradientDescent::new(0.5), initial)
            .max_iter(1_000)
            .terminate_on(RelativeGradientTolerance::new(1e-3))
            .run()
            .unwrap()
    };

    let unit_scale = solve_from(1.0);
    let huge_scale = solve_from(1.0e6);

    assert_eq!(unit_scale.reason, TerminationReason::RelativeGradientTolerance);
    assert_eq!(huge_scale.reason, TerminationReason::RelativeGradientTolerance);
    assert_eq!(
        unit_scale.iter(),
        huge_scale.iter(),
        "relative gradient tolerance should be scale-invariant"
    );
}
/// With no explicit cap and no extra criterion, the run is expected to end
/// with `MaxIter` after exactly 1 000 iterations.
#[test]
fn max_iter_field_default_is_one_thousand() {
    let start = vec![10.0, 10.0];
    let slow_solver = GradientDescent::new(0.001);

    let outcome = Executor::new(Quadratic, slow_solver, BasicState::new(start))
        .run()
        .unwrap();

    assert_eq!(outcome.reason, TerminationReason::MaxIter);
    assert_eq!(outcome.iter(), 1_000);
}
/// An explicit `MaxIter(5)` criterion should stop the run at iteration 5 even
/// though `max_iter` itself is never set on the executor.
#[test]
fn explicit_max_iter_criterion_works_alongside_default() {
    let start = vec![10.0, 10.0];
    let slow_solver = GradientDescent::new(0.001);

    let outcome = Executor::new(Quadratic, slow_solver, BasicState::new(start))
        .terminate_on(MaxIter(5))
        .run()
        .unwrap();

    assert_eq!(outcome.reason, TerminationReason::MaxIter);
    assert_eq!(outcome.iter(), 5);
}
/// Expects `ParamTolerance` (threshold `1e-8`) to be the reported reason when
/// the run is capped at 1 000 iterations.
#[test]
fn param_tolerance_fires_when_steps_become_small() {
    let start = vec![1.0, 1.0];
    let criterion = ParamTolerance::new(1e-8);

    let outcome = Executor::new(Quadratic, GradientDescent::new(0.5), BasicState::new(start))
        .max_iter(1_000)
        .terminate_on(criterion)
        .run()
        .unwrap();

    assert_eq!(outcome.reason, TerminationReason::ParamTolerance);
}
/// Expects `CostTolerance` (threshold `1e-12`) to be the reported reason when
/// the run is capped at 1 000 iterations.
#[test]
fn cost_tolerance_fires_when_cost_stagnates() {
    let start = vec![1.0, 1.0];
    let criterion = CostTolerance::new(1e-12);

    let outcome = Executor::new(Quadratic, GradientDescent::new(0.5), BasicState::new(start))
        .max_iter(1_000)
        .terminate_on(criterion)
        .run()
        .unwrap();

    assert_eq!(outcome.reason, TerminationReason::CostTolerance);
}
/// With a step size of `0.001` and a relative parameter tolerance of `1e-2`,
/// `RelativeParamTolerance` is expected to fire in fewer than 5 iterations.
#[test]
fn relative_param_tolerance_fires_when_relative_step_small() {
    let start = vec![1.0, 1.0];
    let slow_solver = GradientDescent::new(0.001);

    let outcome = Executor::new(Quadratic, slow_solver, BasicState::new(start))
        .max_iter(1_000)
        .terminate_on(RelativeParamTolerance::new(1e-2))
        .run()
        .unwrap();

    assert_eq!(outcome.reason, TerminationReason::RelativeParamTolerance);

    let iters = outcome.iter();
    assert!(iters < 5, "fired late at iter {}", iters);
}
/// With a step size of `0.001` and a relative cost tolerance of `1e-2`,
/// `RelativeCostTolerance` is expected to fire in fewer than 5 iterations.
#[test]
fn relative_cost_tolerance_fires_when_relative_reduction_small() {
    let start = vec![1.0, 1.0];
    let slow_solver = GradientDescent::new(0.001);

    let outcome = Executor::new(Quadratic, slow_solver, BasicState::new(start))
        .max_iter(1_000)
        .terminate_on(RelativeCostTolerance::new(1e-2))
        .run()
        .unwrap();

    assert_eq!(outcome.reason, TerminationReason::RelativeCostTolerance);

    let iters = outcome.iter();
    assert!(iters < 5, "fired late at iter {}", iters);
}
/// The start `[0.5, 0.5]` has cost `0.25`, already below the target of `1.0`,
/// so `TargetCost` is expected to fire at iteration 0.
#[test]
fn target_cost_fires_at_iter_zero_when_start_is_below_target() {
    let start = vec![0.5, 0.5];
    let criterion = TargetCost(1.0);

    let outcome = Executor::new(Quadratic, GradientDescent::new(0.1), BasicState::new(start))
        .terminate_on(criterion)
        .run()
        .unwrap();

    assert_eq!(outcome.reason, TerminationReason::TargetCost);
    assert_eq!(outcome.iter(), 0);
}
/// Expects `TargetCost(1e-3)` to fire strictly between iteration 0 and the
/// 1 000-iteration cap, with the final cost at or below the target.
#[test]
fn target_cost_fires_when_cost_drops_to_target() {
    let start = vec![1.0, 1.0];
    let solver = GradientDescent::new(0.5);

    let result = Executor::new(Quadratic, solver, BasicState::new(start))
        .max_iter(1_000)
        .terminate_on(TargetCost(1e-3))
        .run()
        .unwrap();

    assert_eq!(result.reason, TerminationReason::TargetCost);
    assert!(result.iter() > 0 && result.iter() < 1_000);
    assert!(result.state.cost() <= 1e-3);
}
/// The objective is never negative, so a target of `-1.0` cannot be reached;
/// the run is expected to end with `MaxIter` at iteration 10 instead.
#[test]
fn target_cost_does_not_fire_when_target_unreachable() {
    let start = vec![1.0, 1.0];
    let unreachable_target = TargetCost(-1.0);

    // Criteria are registered in the same order as before: cap first, target second.
    let outcome = Executor::new(Quadratic, GradientDescent::new(0.1), BasicState::new(start))
        .terminate_on(MaxIter(10))
        .terminate_on(unreachable_target)
        .run()
        .unwrap();

    assert_eq!(outcome.reason, TerminationReason::MaxIter);
    assert_eq!(outcome.iter(), 10);
}
/// Expects `NoImprovement::new(3, 10.0)` to end the run at iteration 3.
// NOTE(review): the arguments are presumably (patience, minimum improvement);
// confirm against `NoImprovement::new`.
#[test]
fn no_improvement_fires_after_patience_stalled_iters() {
    let start = vec![1.0, 1.0];
    let criterion = NoImprovement::new(3, 10.0);

    let outcome = Executor::new(Quadratic, GradientDescent::new(0.5), BasicState::new(start))
        .max_iter(100)
        .terminate_on(criterion)
        .run()
        .unwrap();

    assert_eq!(outcome.reason, TerminationReason::NoImprovement);
    assert_eq!(outcome.iter(), 3);
}
/// With `NoImprovement::new(5, 0.0)` registered after `MaxIter(20)`, the run
/// is expected to reach the iteration cap instead of stopping for lack of
/// improvement.
#[test]
fn no_improvement_does_not_fire_under_monotone_decrease() {
    let start = vec![1.0, 1.0];
    let stall_check = NoImprovement::new(5, 0.0);

    // Registration order is unchanged: iteration cap first, stall check second.
    let outcome = Executor::new(Quadratic, GradientDescent::new(0.5), BasicState::new(start))
        .terminate_on(MaxIter(20))
        .terminate_on(stall_check)
        .run()
        .unwrap();

    assert_eq!(outcome.reason, TerminationReason::MaxIter);
    assert_eq!(outcome.iter(), 20);
}
/// Expects `NoImprovement::new(3, 0.1)` to end the run at iteration 5, i.e.
/// later than the 3 iterations its first argument alone would suggest.
// NOTE(review): the arguments are presumably (patience, minimum improvement);
// confirm against `NoImprovement::new`.
#[test]
fn no_improvement_resets_counter_on_real_improvement() {
    let start = vec![1.0, 1.0];
    let criterion = NoImprovement::new(3, 0.1);

    let outcome = Executor::new(Quadratic, GradientDescent::new(0.5), BasicState::new(start))
        .max_iter(100)
        .terminate_on(criterion)
        .run()
        .unwrap();

    assert_eq!(outcome.reason, TerminationReason::NoImprovement);
    assert_eq!(outcome.iter(), 5);
}
/// Runs the same configuration from `[1, 1]` and from `[1e6, 1e6]` and expects
/// both runs to stop for `RelativeCostTolerance` after the same number of
/// iterations.
#[test]
fn relative_cost_tolerance_is_scale_invariant() {
    // Solves from the point `[start, start]` with an otherwise fixed setup.
    let solve_from = |start: f64| {
        let initial = BasicState::new(vec![start, start]);
        Executor::new(Quadratic, GradientDescent::new(0.001), initial)
            .max_iter(1_000)
            .terminate_on(RelativeCostTolerance::new(1e-2))
            .run()
            .unwrap()
    };

    let unit_scale = solve_from(1.0);
    let huge_scale = solve_from(1.0e6);

    assert_eq!(unit_scale.reason, TerminationReason::RelativeCostTolerance);
    assert_eq!(huge_scale.reason, TerminationReason::RelativeCostTolerance);
    assert_eq!(
        unit_scale.iter(),
        huge_scale.iter(),
        "relative cost tolerance should be scale-invariant"
    );
}
/// A very loose `ParamTolerance` of `100.0` is expected to be the reported
/// reason, firing in fewer than 5 iterations and well ahead of the
/// 1 000-iteration cap.
#[test]
fn first_criterion_to_fire_wins() {
    let start = vec![1.0, 1.0];
    let loose_tolerance = ParamTolerance::new(100.0);

    let result = Executor::new(Quadratic, GradientDescent::new(0.1), BasicState::new(start))
        .max_iter(1_000)
        .terminate_on(loose_tolerance)
        .run()
        .unwrap();

    assert_eq!(result.reason, TerminationReason::ParamTolerance);
    assert!(result.iter() < 5);
}
/// With the iteration cap raised to `u64::MAX` and a 50 ms time budget, the
/// run is expected to end with `MaxTime`.
#[test]
fn max_time_eventually_fires() {
    let budget = Duration::from_millis(50);
    let far_start = vec![1e6, 1e6, 1e6];
    let slow_solver = GradientDescent::new(0.001);

    let outcome = Executor::new(Quadratic, slow_solver, BasicState::new(far_start))
        .max_iter(u64::MAX)
        .terminate_on(MaxTime::new(budget))
        .run()
        .unwrap();

    assert_eq!(outcome.reason, TerminationReason::MaxTime);
}
/// Solver stub whose `terminate` hook always reports `SolverConverged`.
struct AlwaysConverged;

impl Solver<Quadratic, BasicState<Vec<f64>>> for AlwaysConverged {
    type Error = std::convert::Infallible;

    /// Hands the state back untouched and requests no termination.
    fn next_iter(
        &mut self,
        _: &mut Problem<Quadratic>,
        current: BasicState<Vec<f64>>,
    ) -> Result<(BasicState<Vec<f64>>, Option<TerminationReason>), Self::Error> {
        let no_stop: Option<TerminationReason> = None;
        Ok((current, no_stop))
    }

    /// Signals convergence unconditionally; the state is not inspected.
    fn terminate(&self, _: &BasicState<Vec<f64>>) -> Option<TerminationReason> {
        let verdict = TerminationReason::SolverConverged;
        Some(verdict)
    }
}
/// A solver whose `terminate` hook always fires should end the run with
/// `SolverConverged` at iteration 0.
#[test]
fn solver_terminate_hook_is_honored() {
    let start = vec![1.0, 2.0];

    let outcome = Executor::new(Quadratic, AlwaysConverged, BasicState::new(start))
        .run()
        .unwrap();

    assert_eq!(outcome.reason, TerminationReason::SolverConverged);
    assert_eq!(outcome.iter(), 0);
}
/// Solver stub that counts `next_iter` calls and reports `SolverFailed` from
/// the second call onwards.
struct FailsOnSecondCall {
    /// Number of `next_iter` calls made so far.
    calls: u64,
}

impl Solver<Quadratic, BasicState<Vec<f64>>> for FailsOnSecondCall {
    type Error = std::convert::Infallible;

    /// Bumps the call counter and returns the state unchanged, paired with
    /// `SolverFailed` once the counter has reached 2 and `None` before that.
    fn next_iter(
        &mut self,
        _: &mut Problem<Quadratic>,
        state: BasicState<Vec<f64>>,
    ) -> Result<(BasicState<Vec<f64>>, Option<TerminationReason>), Self::Error> {
        self.calls += 1;
        let signal = match self.calls {
            0 | 1 => None,
            _ => Some(TerminationReason::SolverFailed),
        };
        Ok((state, signal))
    }
}
/// A termination reason returned from `next_iter` should end the run: with a
/// solver that fails on its second call, the expected result is `SolverFailed`
/// at iteration 1.
#[test]
fn solver_can_signal_termination_mid_iter() {
    let start = vec![1.0, 2.0];
    let flaky_solver = FailsOnSecondCall { calls: 0 };

    let outcome = Executor::new(Quadratic, flaky_solver, BasicState::new(start))
        .max_iter(100)
        .run()
        .unwrap();

    assert_eq!(outcome.reason, TerminationReason::SolverFailed);
    assert_eq!(outcome.iter(), 1);
}
/// Custom criterion that fires when the state's iteration counter equals the
/// wrapped value.
struct StopAt(u64);

impl<S: State> TerminationCriterion<S> for StopAt {
    /// Returns `SolverConverged` when `state.iter()` equals the stored
    /// iteration number, and `None` otherwise.
    fn check(&mut self, state: &S) -> Option<TerminationReason> {
        if state.iter() == self.0 {
            Some(TerminationReason::SolverConverged)
        } else {
            None
        }
    }
}
/// After 20 iterations of constant-step gradient descent the cost-evaluation
/// counter is expected to read 21, one more than the iteration count.
#[test]
fn cost_evals_matches_iter_for_constant_step_gradient_descent() {
    let start = vec![10.0, 10.0];
    let slow_solver = GradientDescent::new(0.001);

    let outcome = Executor::new(Quadratic, slow_solver, BasicState::new(start))
        .terminate_on(MaxIter(20))
        .run()
        .unwrap();

    assert_eq!(outcome.iter(), 20);
    assert_eq!(outcome.state.cost_evals(), 21);
}
/// With a backtracking line search, the cost-evaluation counter after 10
/// iterations is expected to exceed `iter + 1`.
#[test]
fn cost_evals_exceeds_iter_with_backtracking() {
    let start = vec![1.0, 1.0];

    let outcome = Executor::new(
        Quadratic,
        GradientDescent::with_line_search(Backtracking::new().alpha_init(8.0).rho(0.5)),
        BasicState::new(start),
    )
    .terminate_on(MaxIter(10))
    .run()
    .unwrap();

    let iters = outcome.iter();
    let evals = outcome.state.cost_evals();

    assert_eq!(iters, 10);
    assert!(
        evals > iters + 1,
        "expected line search to inflate cost_evals beyond iter+1: cost_evals={}, iter={}",
        evals,
        iters
    );
}
/// For a Nelder-Mead run capped by `MaxIter(50)`, the cost-evaluation counter
/// is expected to be at least `iter + 3`.
#[test]
fn cost_evals_exceeds_iter_for_nelder_mead_shrinks() {
    let start = vec![2.0, -3.0];
    let solver = NelderMead::new();

    let result = Executor::new(Quadratic, solver, BasicSimplexState::new(start))
        .terminate_on(MaxIter(50))
        .run()
        .unwrap();

    assert!(result.state.cost_evals() >= result.iter() + 3);
}
/// A budget of 5 gradient evaluations is expected to end the run with
/// `MaxGradientEvals` ahead of the 10 000-iteration cap, with at least 5
/// gradient evaluations recorded.
#[test]
fn max_gradient_evals_fires_before_max_iter() {
    let start = vec![10.0, 10.0];
    let slow_solver = GradientDescent::new(0.001);

    let result = Executor::new(Quadratic, slow_solver, BasicState::new(start))
        .max_iter(10_000)
        .terminate_on(MaxGradientEvals(5))
        .run()
        .unwrap();

    assert_eq!(result.reason, TerminationReason::MaxGradientEvals);
    assert!(result.state.gradient_evals() >= 5);
}
/// A budget of 25 cost evaluations is expected to end a Nelder-Mead run with
/// `MaxCostEvals` ahead of the 10 000-iteration cap, with at least 25 cost
/// evaluations recorded.
#[test]
fn max_cost_evals_fires_before_max_iter() {
    let start = vec![5.0, -2.0, 4.0];
    let solver = NelderMead::new();

    let outcome = Executor::new(Quadratic, solver, BasicSimplexState::new(start))
        .max_iter(10_000)
        .terminate_on(MaxCostEvals(25))
        .run()
        .unwrap();

    assert_eq!(outcome.reason, TerminationReason::MaxCostEvals);

    let evals = outcome.state.cost_evals();
    assert!(
        evals >= 25,
        "cost_evals should have reached the budget: {}",
        evals
    );
}
/// A user-defined criterion (`StopAt(7)`) should be honored: the run is
/// expected to end at iteration 7 with the reason that criterion returns.
#[test]
fn custom_termination_criterion() {
    let start = vec![5.0, 5.0];
    let criterion = StopAt(7);

    let outcome = Executor::new(Quadratic, GradientDescent::new(0.1), BasicState::new(start))
        .max_iter(1_000)
        .terminate_on(criterion)
        .run()
        .unwrap();

    assert_eq!(outcome.reason, TerminationReason::SolverConverged);
    assert_eq!(outcome.iter(), 7);
}