use basin::problems::Booth;
use basin::{
Backtracking, BasicState, CostFunction, Executor, Gradient, GradientDescent, GradientTolerance,
InnerExecutor, Solver, State, TerminationReason,
};
/// Outer-solver state for a multi-start search: a population of candidate points held
/// alongside their costs.
///
/// The two vectors are index-aligned: `costs[k]` belongs to `iterates[k]`.
struct MultiStartState {
    /// Candidate points of the population.
    iterates: Vec<Vec<f64>>,
    /// Cost of each candidate, aligned with `iterates`.
    costs: Vec<f64>,
    /// Outer iteration counter.
    iter: u64,
    /// Running total of cost-function evaluations charged to this state.
    cost_evals: u64,
}

impl MultiStartState {
    /// Wraps `iterates` in a fresh state.
    ///
    /// Every cost starts at `+∞` (not yet evaluated) and both counters start at zero.
    fn new(iterates: Vec<Vec<f64>>) -> Self {
        let costs: Vec<f64> = iterates.iter().map(|_| f64::INFINITY).collect();
        Self {
            iterates,
            costs,
            iter: 0,
            cost_evals: 0,
        }
    }
}
/// Exposes slot 0 of the population as the executor-facing parameter/cost pair.
///
/// `PerVertexRefine` sorts the population by cost, so slot 0 is normally the cheapest
/// candidate. `param` and `cost` index slot 0 directly and therefore panic if the
/// population is empty.
impl State for MultiStartState {
    type Param = Vec<f64>;
    type Float = f64;

    fn param(&self) -> &Self::Param {
        &self.iterates[0]
    }

    fn cost(&self) -> Self::Float {
        self.costs[0]
    }

    fn iter(&self) -> u64 {
        self.iter
    }

    fn increment_iter(&mut self) {
        self.iter += 1;
    }

    fn cost_evals(&self) -> u64 {
        self.cost_evals
    }

    /// Adds `by` evaluations to the running total.
    fn increment_cost_evals(&mut self, by: u64) {
        self.cost_evals += by;
    }
}
/// Reorders `iterates` and `costs` together so that costs are ascending.
///
/// The slices are parallel: `costs[k]` is the cost of `iterates[k]`. Ordering rules:
/// - NaN costs are ranked after every other value, so a NaN never ends up in slot 0
///   (the slot reported as the best candidate).
/// - The sort is stable, so equal costs keep their relative order.
///
/// # Panics
///
/// Panics if the two slices have different lengths.
fn sort_by_cost(iterates: &mut [Vec<f64>], costs: &mut [f64]) {
    assert_eq!(
        iterates.len(),
        costs.len(),
        "sort_by_cost: iterates and costs must be the same length"
    );
    // Move (rather than deep-clone) every point into a (cost, point) pair.
    let mut ranked: Vec<(f64, Vec<f64>)> = costs
        .iter()
        .copied()
        .zip(iterates.iter_mut().map(std::mem::take))
        .collect();
    // `partial_cmp(..).unwrap_or(Equal)` is not a total order once NaN is involved.
    // It can leave the slice unsorted, and `sort_by` may panic on such comparators.
    // Ranking NaN last and using `total_cmp` for the rest is a genuine total order.
    ranked.sort_by(|(a, _), (b, _)| {
        a.is_nan()
            .cmp(&b.is_nan())
            .then_with(|| a.total_cmp(b))
    });
    for ((cost_slot, iterate_slot), (cost, iterate)) in
        costs.iter_mut().zip(iterates.iter_mut()).zip(ranked)
    {
        *cost_slot = cost;
        *iterate_slot = iterate;
    }
}
/// Outer solver that polishes each candidate of a `MultiStartState` with its own run of
/// a wrapped inner executor.
struct PerVertexRefine<G> {
    /// Inner executor, reused for every candidate.
    inner: InnerExecutor<BasicState<Vec<f64>>, G>,
}

impl<G> PerVertexRefine<G> {
    /// Builds the refiner around an already-configured inner executor.
    fn new(executor: InnerExecutor<BasicState<Vec<f64>>, G>) -> Self {
        PerVertexRefine { inner: executor }
    }
}
/// Refines every candidate of a `MultiStartState` with one run of the wrapped inner
/// executor per outer iteration, then re-sorts the population by cost.
///
/// If an inner run reports a failure, the outer iteration stops early and returns that
/// reason. The population keeps its size in that case:
/// - the vertex whose refinement failed keeps its pre-refinement point and cost;
/// - vertices not yet visited are carried over unchanged;
/// - the result is re-sorted, like the success path.
impl<P, G> Solver<P, MultiStartState> for PerVertexRefine<G>
where
    P: CostFunction<Param = Vec<f64>, Output = f64>
        + Gradient<Param = Vec<f64>, Gradient = Vec<f64>>,
    G: Solver<P, BasicState<Vec<f64>>>,
{
    /// Evaluates every starting point once (one cost evaluation each) and sorts the
    /// population by cost with `sort_by_cost`.
    fn init(&mut self, problem: &P, mut state: MultiStartState) -> MultiStartState {
        // Rebuilding `costs` from `iterates` keeps the two vectors index-aligned.
        state.costs = state.iterates.iter().map(|v| problem.cost(v)).collect();
        state.cost_evals += state.costs.len() as u64;
        sort_by_cost(&mut state.iterates, &mut state.costs);
        state
    }

    /// Runs the inner executor from each vertex in turn.
    ///
    /// Every inner run's cost evaluations are charged to the outer state, whether the
    /// run succeeded or failed.
    fn next_iter(
        &mut self,
        problem: &P,
        mut state: MultiStartState,
    ) -> (MultiStartState, Option<TerminationReason>) {
        let n = state.iterates.len();
        let mut new_iterates: Vec<Vec<f64>> = Vec::with_capacity(n);
        let mut new_costs: Vec<f64> = Vec::with_capacity(n);
        let mut failure: Option<TerminationReason> = None;

        let prev_iterates = std::mem::take(&mut state.iterates);
        let prev_costs = std::mem::take(&mut state.costs);
        let mut pending = prev_iterates.into_iter().zip(prev_costs);

        for (start, start_cost) in pending.by_ref() {
            // Hand the inner run a copy so the start can be restored if the run fails.
            let result = self.inner.run(problem, BasicState::new(start.clone()));
            state.cost_evals += result.state.cost_evals();
            if result.reason.is_failure() {
                new_iterates.push(start);
                new_costs.push(start_cost);
                failure = Some(result.reason);
                break;
            }
            new_costs.push(result.cost());
            new_iterates.push(result.param().clone());
        }

        // Only non-empty after a failure: vertices never refined this iteration are kept
        // as-is, so no candidate is lost and no placeholder point is invented.
        for (start, start_cost) in pending {
            new_iterates.push(start);
            new_costs.push(start_cost);
        }

        state.iterates = new_iterates;
        state.costs = new_costs;
        sort_by_cost(&mut state.iterates, &mut state.costs);
        (state, failure)
    }
}
/// Test double: a solver whose `next_iter` immediately reports
/// `TerminationReason::SolverFailed` and hands back the state untouched.
struct AlwaysFails;

impl<P, S> Solver<P, S> for AlwaysFails
where
    S: State,
{
    fn next_iter(&mut self, _: &P, state: S) -> (S, Option<TerminationReason>) {
        let verdict = TerminationReason::SolverFailed;
        (state, Some(verdict))
    }
}
/// Three scattered starts, each polished by gradient descent with backtracking line
/// search, should reach Booth's optimum at (1, 3) with near-zero cost.
#[test]
fn inner_executor_polishes_starts_to_booth_optimum() {
    let booth = Booth::<Vec<f64>>::default();
    let initial_points = vec![vec![0.0, 0.0], vec![-1.0, 5.0], vec![3.0, 1.0]];
    let population = MultiStartState::new(initial_points);
    let refiner = PerVertexRefine::new(
        InnerExecutor::new(GradientDescent::with_line_search(Backtracking::new()))
            .max_iter(50)
            .terminate_on(GradientTolerance(1e-8)),
    );
    let outcome = Executor::new(booth, refiner, population).max_iter(3).run();

    let best_cost = outcome.cost();
    assert!(
        best_cost < 1e-6,
        "expected near-zero cost at Booth optimum (1, 3), got {}",
        best_cost
    );

    let best_point = outcome.param();
    assert!(
        (best_point[0] - 1.0).abs() < 1e-3,
        "x[0] = {} (expected ≈ 1)",
        best_point[0]
    );
    assert!(
        (best_point[1] - 3.0).abs() < 1e-3,
        "x[1] = {} (expected ≈ 3)",
        best_point[1]
    );
}
/// The outer state's evaluation counter must include the work done by the inner
/// executor runs, not just the outer solver's own evaluations in `init`.
#[test]
fn inner_executor_aggregates_cost_evals_into_outer() {
    let booth = Booth::<Vec<f64>>::default();
    let initial_points = vec![vec![0.0, 0.0], vec![-1.0, 5.0], vec![3.0, 1.0]];
    let population = MultiStartState::new(initial_points);
    let refiner = PerVertexRefine::new(
        InnerExecutor::new(GradientDescent::with_line_search(Backtracking::new()))
            .max_iter(50)
            .terminate_on(GradientTolerance(1e-8)),
    );
    let outcome = Executor::new(booth, refiner, population).max_iter(2).run();

    // 3 evaluations come from `init` (one per start). The extra 6 assume at least one
    // inner evaluation per start in each of the 2 outer iterations.
    let minimum_evals = 3 + 6;
    let total_evals = outcome.state.cost_evals();
    assert!(
        total_evals >= minimum_evals,
        "expected outer to aggregate inner work; got {} cost evals (≥ 9 minimum)",
        total_evals
    );
}
/// A failure reported by the inner solver must surface as the outer run's termination
/// reason, with no outer iteration counted.
#[test]
fn inner_executor_bubbles_inner_solver_failed_via_outer() {
    let booth = Booth::<Vec<f64>>::default();
    let population = MultiStartState::new(vec![vec![0.0, 0.0]]);
    let refiner = PerVertexRefine::new(InnerExecutor::new(AlwaysFails));
    let outcome = Executor::new(booth, refiner, population).max_iter(5).run();

    assert_eq!(
        outcome.reason,
        TerminationReason::SolverFailed,
        "outer should bubble SolverFailed from the inner; got {:?}",
        outcome.reason
    );
    assert_eq!(outcome.iter(), 0);
}