use crate::error::OptimizeError;
use crate::unconstrained::{minimize, Bounds, Method, OptimizeResult, Options};
use scirs2_core::ndarray::{Array1, ArrayView1};
use scirs2_core::random::rngs::StdRng;
use scirs2_core::random::{Rng, RngExt, SeedableRng};
/// Maps `val` back into the closed interval `[lb, ub]`.
///
/// * A value already inside the interval is returned unchanged.
/// * A value that overshoots a bound by at most the interval width is mirrored back
///   across that bound.
/// * Anything farther out, including NaN (which compares false against both bounds), is
///   replaced by a uniform draw from `[lb, ub]` taken from `rng`.
#[allow(dead_code)]
fn enforce_bounds_with_reflection<R: Rng>(rng: &mut R, val: f64, lb: f64, ub: f64) -> f64 {
    if (lb..=ub).contains(&val) {
        return val;
    }
    let width = ub - lb;
    let mirrored = if val < lb {
        let overshoot = lb - val;
        (overshoot <= width).then_some(lb + overshoot)
    } else {
        let overshoot = val - ub;
        (overshoot <= width).then_some(ub - overshoot)
    };
    // Too far out to mirror once: resample uniformly inside the box instead.
    mirrored.unwrap_or_else(|| rng.random_range(lb..=ub))
}
/// Checks that every `(lb, ub)` pair describes a usable, finite interval.
///
/// # Errors
///
/// Returns `OptimizeError::InvalidInput` naming the first offending variable when:
/// * either bound is non-finite (NaN or ±∞);
/// * `lb >= ub`;
/// * the width `ub - lb` overflows to infinity. Such boxes make the reflection step
///   produce non-finite coordinates.
/// * the width is below `1e-12`.
#[allow(dead_code)]
fn validate_bounds(bounds: &[(f64, f64)]) -> Result<(), OptimizeError> {
    for (i, &(lb, ub)) in bounds.iter().enumerate() {
        if !lb.is_finite() || !ub.is_finite() {
            return Err(OptimizeError::InvalidInput(format!(
                "Bounds must be finite values. Variable {}: bounds = ({}, {})",
                i, lb, ub
            )));
        }
        if lb >= ub {
            return Err(OptimizeError::InvalidInput(format!(
                "Lower bound must be less than upper bound. Variable {}: lb = {}, ub = {}",
                i, lb, ub
            )));
        }
        let width = ub - lb;
        // Two finite bounds can still be more than f64::MAX apart.
        if !width.is_finite() {
            return Err(OptimizeError::InvalidInput(format!(
                "Bounds range overflows f64. Variable {}: lb = {}, ub = {}",
                i, lb, ub
            )));
        }
        if width < 1e-12 {
            return Err(OptimizeError::InvalidInput(format!(
                "Bounds range is too small. Variable {}: range = {}",
                i, width
            )));
        }
    }
    Ok(())
}
/// Configuration for a basin-hopping run.
#[derive(Debug, Clone)]
pub struct BasinHoppingOptions {
    /// Maximum number of basin-hopping iterations performed by `run`.
    pub niter: usize,
    /// Temperature handed to the acceptance test on every step.
    pub temperature: f64,
    /// Half-width of the uniform per-coordinate displacement used by the default step,
    /// which draws from `[-stepsize, stepsize)`.
    pub stepsize: f64,
    /// If set, stop early after this many consecutive iterations that do not improve the
    /// stored best minimum.
    pub niter_success: Option<usize>,
    /// Seed for the solver's random number generator; `None` draws a random seed.
    pub seed: Option<u64>,
    /// Method passed to the local minimizer (`minimize`) on every local search.
    pub minimizer_method: Method,
    /// Optional per-variable `(lb, ub)` bounds. Applied to the starting point, to the
    /// default step, and to the local minimizer.
    pub bounds: Option<Vec<(f64, f64)>>,
}
impl Default for BasinHoppingOptions {
fn default() -> Self {
Self {
niter: 100,
temperature: 1.0,
stepsize: 0.5,
niter_success: None,
seed: None,
minimizer_method: Method::LBFGS,
bounds: None,
}
}
}
/// Custom acceptance predicate, called as `test(f_new, f_old, temperature)`.
///
/// The arguments are:
/// * `f_new`: objective value of the freshly minimized trial point;
/// * `f_old`: the reference objective value the trial is compared against;
/// * `temperature`: `BasinHoppingOptions::temperature`.
///
/// Returning `true` moves the solver to the trial point.
pub type AcceptTest = Box<dyn Fn(f64, f64, f64) -> bool>;
/// Custom proposal function: given the current point, returns a trial point to be
/// locally minimized.
pub type TakeStep = Box<dyn FnMut(&Array1<f64>) -> Array1<f64>>;
pub struct BasinHopping<F>
where
F: Fn(&ArrayView1<f64>) -> f64 + Clone,
{
func: F,
x0: Array1<f64>,
options: BasinHoppingOptions,
ndim: usize,
rng: StdRng,
accept_test: AcceptTest,
take_step: TakeStep,
storage: Storage,
nfev: usize,
}
#[derive(Debug, Clone)]
struct Storage {
x: Array1<f64>,
fun: f64,
success: bool,
}
impl Storage {
fn new(x: Array1<f64>, fun: f64, success: bool) -> Self {
Self { x, fun, success }
}
fn update(&mut self, x: Array1<f64>, fun: f64, success: bool) -> bool {
if success && (fun < self.fun || !self.success) {
self.x = x;
self.fun = fun;
self.success = success;
true
} else {
false
}
}
}
impl<F> BasinHopping<F>
where
F: Fn(&ArrayView1<f64>) -> f64 + Clone,
{
pub fn new(
func: F,
x0: Array1<f64>,
options: BasinHoppingOptions,
accept_test: Option<AcceptTest>,
take_step: Option<TakeStep>,
) -> Self {
let ndim = x0.len();
let seed = options
.seed
.unwrap_or_else(|| scirs2_core::random::rng().random_range(0..u64::MAX));
let mut rng = StdRng::seed_from_u64(seed);
let accept_test = accept_test.unwrap_or_else(|| {
Box::new(move |f_new: f64, f_old: f64, temp: f64| {
if f_new < f_old {
true
} else {
let delta = (f_old - f_new) / temp;
delta > 0.0 && scirs2_core::random::rng().random_range(0.0..1.0) < delta.exp()
}
})
});
let take_step = take_step.unwrap_or_else(|| {
let stepsize = options.stepsize;
let bounds = options.bounds.clone();
let seed = options
.seed
.unwrap_or_else(|| scirs2_core::random::rng().random_range(0..u64::MAX));
Box::new(move |x: &Array1<f64>| {
let mut local_rng = StdRng::seed_from_u64(seed + x.len() as u64);
let mut x_new = x.clone();
for i in 0..x.len() {
x_new[i] += local_rng.random_range(-stepsize..stepsize);
if let Some(ref bounds) = bounds {
if i < bounds.len() {
let (lb, ub) = bounds[i];
x_new[i] =
enforce_bounds_with_reflection(&mut local_rng, x_new[i], lb, ub);
}
}
}
x_new
})
});
let mut x0_bounded = x0.clone();
if let Some(ref bounds) = options.bounds {
for (i, &(lb, ub)) in bounds.iter().enumerate() {
if i < x0_bounded.len() {
x0_bounded[i] = enforce_bounds_with_reflection(&mut rng, x0_bounded[i], lb, ub);
}
}
}
let initial_result = minimize(
func.clone(),
&x0_bounded.to_vec(),
options.minimizer_method,
Some(Options {
bounds: options.bounds.clone().map(|b| {
Bounds::from_vecs(
b.iter().map(|&(lb, _)| Some(lb)).collect(),
b.iter().map(|&(_, ub)| Some(ub)).collect(),
)
.expect("Operation failed")
}),
..Default::default()
}),
)
.expect("Operation failed");
let storage = Storage::new(
initial_result.x.clone(),
initial_result.fun,
initial_result.success,
);
Self {
func,
x0: initial_result.x,
options,
ndim,
rng,
accept_test,
take_step,
storage,
nfev: initial_result.nfev,
}
}
fn step(&mut self) -> (Array1<f64>, f64, bool) {
let x_new = (self.take_step)(&self.x0);
let result = minimize(
|x| (self.func)(x),
&x_new.to_vec(),
self.options.minimizer_method,
Some(Options {
bounds: self.options.bounds.clone().map(|b| {
Bounds::from_vecs(
b.iter().map(|&(lb, _)| Some(lb)).collect(),
b.iter().map(|&(_, ub)| Some(ub)).collect(),
)
.expect("Operation failed")
}),
..Default::default()
}),
)
.expect("Operation failed");
self.nfev += result.nfev;
let accept = (self.accept_test)(result.fun, self.storage.fun, self.temperature());
if accept {
self.x0 = result.x.clone();
}
(result.x, result.fun, result.success)
}
fn temperature(&self) -> f64 {
self.options.temperature
}
pub fn run(&mut self) -> OptimizeResult<f64> {
let mut nit = 0;
let mut success_counter = 0;
let mut message = "Maximum number of iterations reached".to_string();
for _ in 0..self.options.niter {
let (x, fun, success) = self.step();
nit += 1;
if self.storage.update(x.clone(), fun, success) {
success_counter = 0;
} else {
success_counter += 1;
}
if let Some(niter_success) = self.options.niter_success {
if success_counter >= niter_success {
message = format!("No improvement in {} iterations", niter_success);
break;
}
}
}
OptimizeResult {
x: self.storage.x.clone(),
fun: self.storage.fun,
nfev: self.nfev,
func_evals: self.nfev,
nit,
success: self.storage.success,
message,
..Default::default()
}
}
}
#[allow(dead_code)]
pub fn basinhopping<F>(
func: F,
x0: Array1<f64>,
options: Option<BasinHoppingOptions>,
accept_test: Option<AcceptTest>,
take_step: Option<TakeStep>,
) -> Result<OptimizeResult<f64>, OptimizeError>
where
F: Fn(&ArrayView1<f64>) -> f64 + Clone,
{
let options = options.unwrap_or_default();
if let Some(ref bounds) = options.bounds {
validate_bounds(bounds)?;
}
let mut solver = BasinHopping::new(func, x0, options, accept_test, take_step);
Ok(solver.run())
}