use basin::problems::Booth;
use basin::{
Backtracking, BasicState, CostFunction, CountsMirror, EvalCounts, Executor, Gradient,
GradientDescent, GradientTolerance, InnerExecutor, Problem, Solver, State,
TerminationCriterion, TerminationReason,
};
use std::cell::RefCell;
use std::rc::Rc;
/// Outer-loop state for a multi-start refinement: a population of candidate
/// parameter vectors plus the bookkeeping counters the `State` impl exposes.
struct MultiStartState {
    /// Candidate parameter vectors, one per start.
    iterates: Vec<Vec<f64>>,
    /// Cost associated with each entry of `iterates` (same index).
    costs: Vec<f64>,
    /// Outer iteration counter.
    iter: u64,
    /// Cost-evaluation count as last written by `CountsMirror::mirror`.
    cost_evals: u64,
    /// Lowest head cost recorded so far by `update_best`.
    best_cost: f64,
    /// Value of `iter` when `best_cost` was last improved.
    best_iter: u64,
    /// Value of `cost_evals` when `best_cost` was last improved.
    best_cost_evals: u64,
}

impl MultiStartState {
    /// Builds a fresh state from the given starting points.
    ///
    /// Every cost (including `best_cost`) starts at `+inf` and every counter
    /// starts at zero; `costs` gets exactly one slot per iterate.
    fn new(iterates: Vec<Vec<f64>>) -> Self {
        // One "not yet evaluated" cost per starting point.
        let costs: Vec<f64> = iterates.iter().map(|_| f64::INFINITY).collect();
        Self {
            iterates,
            costs,
            best_cost: f64::INFINITY,
            iter: 0,
            cost_evals: 0,
            best_iter: 0,
            best_cost_evals: 0,
        }
    }
}
/// Exposes the population through the single-point `State` interface by
/// treating slot 0 of `iterates` / `costs` as "the" current point.
///
/// All accessors that index slot 0 panic if the population is empty.
impl State for MultiStartState {
    type Param = Vec<f64>;
    type Float = f64;

    /// Current outer iteration count.
    fn iter(&self) -> u64 {
        self.iter
    }

    /// Advances the outer iteration count by one.
    fn increment_iter(&mut self) {
        self.iter += 1;
    }

    /// Cost-evaluation count as last stored by `CountsMirror::mirror`.
    fn cost_evals(&self) -> u64 {
        self.cost_evals
    }

    /// Head of the population (slot 0).
    fn param(&self) -> &Self::Param {
        &self.iterates[0]
    }

    /// Cost stored alongside the head of the population (slot 0).
    fn cost(&self) -> Self::Float {
        self.costs[0]
    }

    /// Head of the population (slot 0).
    ///
    /// NOTE(review): this returns the *current* head rather than a stored
    /// snapshot of the parameters that produced `best_cost` — confirm the two
    /// cannot diverge for the solvers used with this state.
    fn best_param(&self) -> &Self::Param {
        &self.iterates[0]
    }

    /// Lowest head cost recorded by `update_best` since the last reset.
    fn best_cost(&self) -> Self::Float {
        self.best_cost
    }

    /// Iteration at which `best_cost` was last improved.
    fn best_iter(&self) -> u64 {
        self.best_iter
    }

    /// Evaluation count at which `best_cost` was last improved.
    fn best_cost_evals(&self) -> u64 {
        self.best_cost_evals
    }

    /// Records the head cost as the new best when it is strictly lower than
    /// the stored best; a NaN head cost never counts as an improvement.
    fn update_best(&mut self) {
        let head_cost = self.costs[0];
        let improved = head_cost < self.best_cost;
        if improved {
            self.best_cost_evals = self.cost_evals;
            self.best_iter = self.iter;
            self.best_cost = head_cost;
        }
    }

    /// Forgets the recorded best: cost back to `+inf`, counters back to zero.
    fn reset_best(&mut self) {
        self.best_cost_evals = 0;
        self.best_iter = 0;
        self.best_cost = f64::INFINITY;
    }
}
/// Lets the executor copy evaluation counts into this state.
impl CountsMirror for MultiStartState {
    /// Overwrites `cost_evals` with `delta.total_work()`.
    ///
    /// NOTE(review): the argument is named `delta` but the value is assigned,
    /// not accumulated — confirm against the `CountsMirror` contract that the
    /// counts passed in are already running totals.
    fn mirror(&mut self, delta: &EvalCounts) {
        let total = delta.total_work();
        self.cost_evals = total;
    }
}
/// Reorders `iterates` and `costs` in lockstep so that costs are ascending.
///
/// * The sort is stable: entries with equal cost keep their relative order.
/// * NaN costs (of either sign) are ordered after every real cost, so a NaN
///   can never end up in slot 0 while a finite/infinite cost exists. The
///   previous comparator treated NaN as equal to everything, which is not a
///   total order and leaves the outcome of `sort_by` unspecified.
/// * Iterates are moved into place rather than cloned.
///
/// # Panics
///
/// Panics if `iterates` and `costs` have different lengths.
fn sort_by_cost(iterates: &mut [Vec<f64>], costs: &mut [f64]) {
    assert_eq!(
        iterates.len(),
        costs.len(),
        "sort_by_cost: iterates and costs must have the same length"
    );

    // Sort a permutation of indices so both slices can be rearranged alike.
    let mut order: Vec<usize> = (0..costs.len()).collect();
    order.sort_by(|&i, &j| {
        let (a, b) = (costs[i], costs[j]);
        // `false < true`, so non-NaN sorts before NaN; only when both sides
        // agree on NaN-ness do we fall through to the numeric comparison
        // (which is `None` only for NaN-vs-NaN, i.e. a tie).
        a.is_nan()
            .cmp(&b.is_nan())
            .then_with(|| a.partial_cmp(&b).unwrap_or(std::cmp::Ordering::Equal))
    });

    // `order` is a permutation, so every slot is taken exactly once.
    let sorted_iterates: Vec<Vec<f64>> = order
        .iter()
        .map(|&i| std::mem::take(&mut iterates[i]))
        .collect();
    let sorted_costs: Vec<f64> = order.iter().map(|&i| costs[i]).collect();

    for (slot, vertex) in iterates.iter_mut().zip(sorted_iterates) {
        *slot = vertex;
    }
    costs.copy_from_slice(&sorted_costs);
}
/// Outer solver that refines every vertex of a `MultiStartState` by running
/// the wrapped inner executor from it.
struct PerVertexRefine<G> {
    /// Inner executor reused for every vertex on every outer iteration.
    inner: InnerExecutor<BasicState<Vec<f64>>, G>,
}

impl<G> PerVertexRefine<G> {
    /// Wraps an already-configured inner executor.
    fn new(inner: InnerExecutor<BasicState<Vec<f64>>, G>) -> Self {
        PerVertexRefine { inner }
    }
}
/// Drives a `MultiStartState` by polishing each vertex with the inner
/// executor and keeping the population sorted by cost.
impl<P, G> Solver<P, MultiStartState> for PerVertexRefine<G>
where
    P: CostFunction<Param = Vec<f64>, Output = f64>
        + Gradient<Param = Vec<f64>, Gradient = Vec<f64>>,
    G: Solver<P, BasicState<Vec<f64>>, Error = P::Error>,
{
    type Error = P::Error;

    /// Evaluates the cost of every starting point once, then sorts the
    /// population so the cheapest vertex sits in slot 0.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by `problem.cost`.
    fn init(
        &mut self,
        problem: &mut Problem<P>,
        mut state: MultiStartState,
    ) -> Result<MultiStartState, Self::Error> {
        for (cost_slot, vertex) in state.costs.iter_mut().zip(&state.iterates) {
            *cost_slot = problem.cost(vertex)?;
        }
        sort_by_cost(&mut state.iterates, &mut state.costs);
        Ok(state)
    }

    /// Runs the inner executor from every vertex in turn and replaces the
    /// population with the refined points, re-sorted by cost.
    ///
    /// If an inner run ends with a failure reason, iteration stops
    /// immediately and that reason is returned. The state then holds only
    /// the vertices refined before the failure (the failing vertex and any
    /// not yet visited are dropped); if that leaves nothing, a single
    /// placeholder vertex with infinite cost is inserted so slot 0 exists.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the inner executor's `run`.
    fn next_iter(
        &mut self,
        problem: &mut Problem<P>,
        mut state: MultiStartState,
    ) -> Result<(MultiStartState, Option<TerminationReason>), Self::Error> {
        let population = state.iterates.len();
        let mut refined: Vec<Vec<f64>> = Vec::with_capacity(population);
        let mut refined_costs: Vec<f64> = Vec::with_capacity(population);

        // Take ownership of the old vertices; `state.iterates` is rebuilt below.
        for start in std::mem::take(&mut state.iterates) {
            let outcome = self.inner.run(problem, BasicState::new(start))?;

            if outcome.reason.is_failure() {
                if refined.is_empty() {
                    // NOTE(review): placeholder dimension is hard-coded to 2.
                    refined.push(vec![0.0; 2]);
                    refined_costs.push(f64::INFINITY);
                }
                state.iterates = refined;
                state.costs = refined_costs;
                return Ok((state, Some(outcome.reason)));
            }

            refined_costs.push(outcome.cost());
            refined.push(outcome.param().clone());
        }

        state.iterates = refined;
        state.costs = refined_costs;
        sort_by_cost(&mut state.iterates, &mut state.costs);
        Ok((state, None))
    }
}
/// Test double: a solver whose every iteration reports `SolverFailed`.
struct AlwaysFails;

impl<P, S: State> Solver<P, S> for AlwaysFails {
    type Error = std::convert::Infallible;

    /// Hands the state back untouched together with
    /// `TerminationReason::SolverFailed`; never returns `Err`.
    fn next_iter(
        &mut self,
        _problem: &mut Problem<P>,
        state: S,
    ) -> Result<(S, Option<TerminationReason>), Self::Error> {
        let verdict = Some(TerminationReason::SolverFailed);
        Ok((state, verdict))
    }
}
/// Three starts refined by gradient descent inside the outer executor must
/// end with the head vertex at the Booth minimum (1, 3) and near-zero cost.
#[test]
fn inner_executor_polishes_starts_to_booth_optimum() {
    let start_points = vec![vec![0.0, 0.0], vec![-1.0, 5.0], vec![3.0, 1.0]];
    let refine = InnerExecutor::new(GradientDescent::with_line_search(Backtracking::new()))
        .max_iter(50)
        .terminate_on(GradientTolerance(1e-8));

    let outcome = Executor::new(
        Booth::<Vec<f64>>::default(),
        PerVertexRefine::new(refine),
        MultiStartState::new(start_points),
    )
    .max_iter(3)
    .run()
    .unwrap();

    let final_cost = outcome.cost();
    assert!(
        final_cost < 1e-6,
        "expected near-zero cost at Booth optimum (1, 3), got {}",
        final_cost
    );

    let minimizer = outcome.param();
    let (x0, x1) = (minimizer[0], minimizer[1]);
    assert!((x0 - 1.0).abs() < 1e-3, "x[0] = {} (expected ≈ 1)", x0);
    assert!((x1 - 3.0).abs() < 1e-3, "x[1] = {} (expected ≈ 3)", x1);
}
/// The outer state's evaluation counter must reflect work done by the inner
/// runs, not just the outer solver's own evaluations.
#[test]
fn inner_executor_aggregates_cost_evals_into_outer() {
    let start_points = vec![vec![0.0, 0.0], vec![-1.0, 5.0], vec![3.0, 1.0]];
    let refine = InnerExecutor::new(GradientDescent::with_line_search(Backtracking::new()))
        .max_iter(50)
        .terminate_on(GradientTolerance(1e-8));

    let outcome = Executor::new(
        Booth::<Vec<f64>>::default(),
        PerVertexRefine::new(refine),
        MultiStartState::new(start_points),
    )
    .max_iter(2)
    .run()
    .unwrap();

    // 3 evaluations come from `init` (one per start); the remaining 6 are
    // presumably one per inner run (2 outer iterations × 3 starts).
    let lower_bound = 3 + 6;
    let observed = outcome.state.cost_evals();
    assert!(
        observed >= lower_bound,
        "expected outer to aggregate inner work; got {} cost evals (≥ 9 minimum)",
        observed
    );
}
/// Termination criterion that never fires; it only counts how many times
/// `reset` is called, via a counter shared with the test body.
struct CountResets(Rc<RefCell<u32>>);

impl<S> TerminationCriterion<S> for CountResets {
    /// Never requests termination.
    fn check(&mut self, _state: &S) -> Option<TerminationReason> {
        None
    }

    /// Bumps the shared reset counter by one.
    fn reset(&mut self) {
        let mut calls = self.0.borrow_mut();
        *calls += 1;
    }
}
/// Each inner run must reset the inner executor's criteria exactly once:
/// 2 outer iterations × 3 starts = 6 resets observed by `CountResets`.
#[test]
fn inner_executor_resets_criteria_once_per_run() {
    let reset_counter = Rc::new(RefCell::new(0u32));
    let start_points = vec![vec![0.0, 0.0], vec![-1.0, 5.0], vec![3.0, 1.0]];

    let refine = InnerExecutor::new(GradientDescent::with_line_search(Backtracking::new()))
        .max_iter(50)
        .terminate_on(GradientTolerance(1e-8))
        .terminate_on(CountResets(Rc::clone(&reset_counter)));

    Executor::new(
        Booth::<Vec<f64>>::default(),
        PerVertexRefine::new(refine),
        MultiStartState::new(start_points),
    )
    .max_iter(2)
    .run()
    .unwrap();

    let observed = *reset_counter.borrow();
    assert_eq!(
        observed, 6,
        "run_loop should reset the reused criteria vector once per inner run"
    );
}
/// A failure reported by the inner solver must surface as the outer run's
/// termination reason, with no outer iteration counted.
#[test]
fn inner_executor_bubbles_inner_solver_failed_via_outer() {
    let single_start = vec![vec![0.0, 0.0]];
    let failing_refine = PerVertexRefine::new(InnerExecutor::new(AlwaysFails));

    let outcome = Executor::new(
        Booth::<Vec<f64>>::default(),
        failing_refine,
        MultiStartState::new(single_start),
    )
    .max_iter(5)
    .run()
    .unwrap();

    assert_eq!(
        outcome.reason,
        TerminationReason::SolverFailed,
        "outer should bubble SolverFailed from the inner; got {:?}",
        outcome.reason
    );
    assert_eq!(outcome.iter(), 0);
}