use crate::observers::Observer;
use crate::problem::Problem;
use crate::types::OQNLPParams;
use ndarray::Array1;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use std::sync::Mutex;
use thiserror::Error;
#[cfg(feature = "rayon")]
use rayon::prelude::*;
#[cfg(feature = "progress_bar")]
use kdam::{Bar, BarExt};
/// Box constraints for the decision variables: element `i` of `lower`/`upper`
/// bounds variable `i`. `ScatterSearch::new` rejects any dimension where
/// `lower >= upper`, so `lower < upper` holds element-wise once constructed.
#[derive(Debug, Clone)]
pub struct VariableBounds {
    /// Per-variable lower bounds.
    pub lower: Array1<f64>,
    /// Per-variable upper bounds (strictly greater than `lower`).
    pub upper: Array1<f64>,
}
/// Errors produced by the scatter-search (Stage 1) phase.
#[derive(Debug, Error)]
pub enum ScatterSearchError {
    /// The candidate pool was exhausted while growing the reference set.
    #[error("Scatter Search Error: No candidates left")]
    NoCandidates,
    /// Repeated resampling never produced enough constraint-satisfying points.
    #[error(
        "Scatter Search Error: No feasible candidates found after {attempts} attempts (problem dimension: {dimension})"
    )]
    NoFeasibleCandidates { attempts: usize, dimension: usize },
    /// The problem's objective function failed during the named phase.
    #[error("Scatter Search Error: Evaluation failed during {phase}: {source}")]
    EvaluationError {
        phase: String,
        #[source]
        source: crate::types::EvaluationError,
    },
    /// A variable's lower bound was not strictly below its upper bound.
    #[error(
        "Scatter Search Error: Invalid bounds for variable {dimension}: lower={lower}, upper={upper}. Lower bound must be < upper bound."
    )]
    InvalidBounds { dimension: usize, lower: f64, upper: f64 },
}
// Enables `?` on objective evaluations: wraps low-level evaluation failures
// with a generic "evaluation" phase label.
impl From<crate::types::EvaluationError> for ScatterSearchError {
    fn from(err: crate::types::EvaluationError) -> Self {
        ScatterSearchError::EvaluationError { phase: "evaluation".to_string(), source: err }
    }
}
/// Result of a full run: (reference points paired with their objective
/// values, best point found).
type ScatterSearchResult = (Vec<(Array1<f64>, f64)>, Array1<f64>);
/// Stage-1 scatter-search driver: builds a diverse reference set over the
/// problem's bounded domain and refines it with pairwise combinations.
pub struct ScatterSearch<'a, P: Problem> {
    // The optimization problem (objective, bounds, optional constraints).
    problem: P,
    // Algorithm parameters (population size, RNG seed, ...).
    params: OQNLPParams,
    // Current reference points; sorted ascending by objective after
    // `initialize_reference_set`.
    reference_set: Vec<Array1<f64>>,
    // Objective cache, index-aligned with `reference_set`.
    reference_set_objectives: Vec<f64>,
    // Per-variable box bounds extracted from the problem.
    bounds: VariableBounds,
    // Master RNG; Mutex-wrapped so per-sample seeds can be drawn via `&self`.
    rng: Mutex<StdRng>,
    #[cfg(feature = "progress_bar")]
    progress_bar: Option<Bar>,
    // Whether rayon-parallel code paths are taken (default: true).
    #[cfg(feature = "rayon")]
    enable_parallel: bool,
    // Optional Stage-1 progress observer.
    observer: Option<&'a mut Observer>,
    // Optional user-supplied starting points.
    custom_points: Option<Vec<Array1<f64>>>,
}
impl<'a, P: Problem + Sync + Send> ScatterSearch<'a, P> {
/// Construct a new scatter-search instance.
///
/// Extracts per-variable bounds from the problem (column 0 = lower,
/// column 1 = upper) and validates that each lower bound is strictly below
/// its upper bound.
///
/// # Errors
/// Returns [`ScatterSearchError::InvalidBounds`] for the first dimension
/// whose bounds are inverted or equal.
pub fn new(problem: P, params: OQNLPParams) -> Result<Self, ScatterSearchError> {
    let var_bounds = problem.variable_bounds();
    let bounds = VariableBounds {
        lower: var_bounds.column(0).to_owned(),
        upper: var_bounds.column(1).to_owned(),
    };
    for i in 0..bounds.lower.len() {
        if bounds.lower[i] >= bounds.upper[i] {
            return Err(ScatterSearchError::InvalidBounds {
                dimension: i,
                lower: bounds.lower[i],
                upper: bounds.upper[i],
            });
        }
    }
    // Copy the seed out before `params` is moved into the struct; this
    // avoids the redundant `params.clone()` the previous version needed.
    let seed: u64 = params.seed;
    Ok(Self {
        problem,
        params,
        reference_set: Vec::new(),
        reference_set_objectives: Vec::new(),
        bounds,
        rng: Mutex::new(StdRng::seed_from_u64(seed)),
        #[cfg(feature = "progress_bar")]
        progress_bar: None,
        #[cfg(feature = "rayon")]
        enable_parallel: true,
        observer: None,
        custom_points: None,
    })
}
/// Toggle rayon-based parallel execution (enabled by default).
#[cfg(feature = "rayon")]
pub fn parallel(mut self, enable: bool) -> Self {
    self.enable_parallel = enable;
    self
}
/// Attach an observer that receives Stage-1 progress callbacks.
pub fn with_observer(mut self, observer: &'a mut Observer) -> Self {
    self.observer = Some(observer);
    self
}
/// Provide user-chosen starting points to seed the reference set.
/// When the problem has constraints, infeasible points are filtered out
/// during initialization.
pub fn with_custom_points(mut self, points: Vec<Array1<f64>>) -> Self {
    self.custom_points = Some(points);
    self
}
/// Execute the full Stage-1 scatter search: initialize and diversify the
/// reference set, generate trial points, merge the improving ones, and
/// return the reference set (with objectives) plus the best point found.
pub fn run(mut self) -> Result<ScatterSearchResult, ScatterSearchError> {
    #[cfg(feature = "progress_bar")]
    {
        // A full run performs five `update(1)` calls: two inside
        // `initialize_reference_set` (the bar is already installed when it
        // runs) and three below. The previous total of 3 under-counted and
        // let the counter overflow its total.
        self.progress_bar = Some(
            Bar::builder()
                .total(5)
                .desc("Stage 1")
                .unit("steps")
                .build()
                .expect("Failed to create progress bar"),
        );
    }
    // Build and diversify the initial reference set.
    self.initialize_reference_set()?;
    if let Some(ref mut obs) = self.observer {
        if obs.should_observe_stage1() {
            if let Some(stage1) = obs.stage1_mut() {
                stage1.enter_substage("initialization_complete");
                // NOTE(review): the hard-coded 3 presumably reports the three
                // bound-derived seed points (lower, upper, midpoint) — with
                // constraints fewer may exist; confirm against Observer
                // semantics.
                stage1.set_reference_set_size(3);
            }
        }
        obs.invoke_callback();
        if obs.should_observe_stage1() {
            if let Some(stage1) = obs.stage1_mut() {
                stage1.enter_substage("diversification_complete");
                stage1.set_reference_set_size(self.reference_set.len());
                stage1.add_function_evaluations(self.reference_set.len());
            }
        }
        obs.invoke_callback();
    }
    #[cfg(feature = "progress_bar")]
    if let Some(pb) = &mut self.progress_bar {
        pb.set_description("Stage 1, initialized and diversified");
        pb.update(1).expect("Failed to update progress bar");
    }
    // Combine the best reference points pairwise into new trial points.
    let trial_points = self.generate_trial_points()?;
    #[cfg(feature = "progress_bar")]
    if let Some(pb) = &mut self.progress_bar {
        pb.set_description("Stage 1, generated trial points");
        pb.update(1).expect("Failed to update progress bar");
    }
    // Merge feasible, diverse, improving trials into the reference set.
    self.update_reference_set(&trial_points);
    if let Some(ref mut obs) = self.observer {
        if obs.should_observe_stage1() {
            if let Some(stage1) = obs.stage1_mut() {
                stage1.enter_substage("intensification_complete");
                stage1.add_trial_points(trial_points.len());
                stage1.set_reference_set_size(self.reference_set.len());
            }
        }
        obs.invoke_callback();
    }
    let best = self.best_solution()?;
    #[cfg(feature = "progress_bar")]
    if let Some(pb) = &mut self.progress_bar {
        pb.set_description("Stage 1, found best solution");
        pb.update(1).expect("Failed to update progress bar");
    }
    // Pair each reference point with its cached objective for the caller.
    let reference_set_with_objectives: Vec<(Array1<f64>, f64)> =
        self.reference_set.into_iter().zip(self.reference_set_objectives).collect();
    Ok((reference_set_with_objectives, best))
}
/// Build the initial reference set: bound-derived seed points, optional
/// user-supplied points, distance-based diversification up to
/// `population_size`, then evaluate and sort the whole set ascending by
/// objective (best point first).
pub fn initialize_reference_set(&mut self) -> Result<(), ScatterSearchError> {
    let mut ref_set: Vec<Array1<f64>> = Vec::with_capacity(self.params.population_size);
    let constraints = self.problem.constraints();
    if constraints.is_empty() {
        // Seed with the two box corners and the midpoint.
        ref_set.push(self.bounds.lower.to_owned());
        ref_set.push(self.bounds.upper.to_owned());
        ref_set.push((&self.bounds.lower + &self.bounds.upper) / 2.0);
    } else {
        // With constraints, only feasible seed points are kept.
        let seed_points = vec![
            self.bounds.lower.to_owned(),
            self.bounds.upper.to_owned(),
            (&self.bounds.lower + &self.bounds.upper) / 2.0,
        ];
        for point in seed_points {
            if is_feasible(&point, &constraints) {
                ref_set.push(point);
            }
        }
    }
    if let Some(ref custom_points) = self.custom_points {
        if constraints.is_empty() {
            ref_set.extend(custom_points.iter().cloned());
        } else {
            // Filter user points for feasibility; parallelize only when the
            // batch is large enough to amortize rayon's scheduling overhead.
            #[cfg(feature = "rayon")]
            let feasible_custom: Vec<Array1<f64>> =
                if self.enable_parallel && custom_points.len() >= 100 {
                    custom_points
                        .par_iter()
                        .filter(|point| is_feasible(point, &constraints))
                        .cloned()
                        .collect()
                } else {
                    custom_points
                        .iter()
                        .filter(|point| is_feasible(point, &constraints))
                        .cloned()
                        .collect()
                };
            #[cfg(not(feature = "rayon"))]
            let feasible_custom: Vec<Array1<f64>> = custom_points
                .iter()
                .filter(|point| is_feasible(point, &constraints))
                .cloned()
                .collect();
            ref_set.extend(feasible_custom);
        }
    }
    #[cfg(feature = "progress_bar")]
    if let Some(pb) = &mut self.progress_bar {
        pb.set_description("Stage 1, initialized reference set");
        pb.update(1).expect("Failed to update progress bar");
    }
    // Grow the set to `population_size` with maximally-distant samples.
    self.diversify_reference_set(&mut ref_set, &constraints)?;
    // Evaluate all objectives; parallel only for non-trivial set sizes.
    #[cfg(feature = "rayon")]
    let objectives: Vec<f64> = if self.enable_parallel && ref_set.len() >= 20 {
        ref_set
            .par_iter()
            .map(|point| self.problem.objective(point))
            .collect::<Result<Vec<f64>, _>>()?
    } else {
        ref_set
            .iter()
            .map(|point| self.problem.objective(point))
            .collect::<Result<Vec<f64>, _>>()?
    };
    #[cfg(not(feature = "rayon"))]
    let objectives: Vec<f64> = ref_set
        .iter()
        .map(|point| self.problem.objective(point))
        .collect::<Result<Vec<f64>, _>>()?;
    // Sort points and objectives together, lowest objective first.
    let mut points_with_objectives: Vec<(Array1<f64>, f64)> =
        ref_set.into_iter().zip(objectives).collect();
    points_with_objectives.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
    let (sorted_points, sorted_objectives): (Vec<Array1<f64>>, Vec<f64>) =
        points_with_objectives.into_iter().unzip();
    self.reference_set = sorted_points;
    self.reference_set_objectives = sorted_objectives;
    #[cfg(feature = "progress_bar")]
    if let Some(pb) = &mut self.progress_bar {
        pb.set_description("Stage 1, diversified reference set");
        pb.update(1).expect("Failed to update progress bar");
    }
    Ok(())
}
/// Grow `ref_set` to `population_size` by repeatedly adding the candidate
/// farthest (max-min squared distance) from the current set.
///
/// Candidates come from uniform sampling; with constraints, sampling is
/// retried in larger batches (up to 10 times) until enough feasible points
/// exist. A cache of each candidate's minimum distance to the set is
/// maintained incrementally: only the most recently added reference point
/// can lower a candidate's minimum.
pub fn diversify_reference_set(
    &mut self,
    ref_set: &mut Vec<Array1<f64>>,
    constraints: &[fn(&[f64], &mut ()) -> f64],
) -> Result<(), ScatterSearchError> {
    let mut candidates = self.generate_stratified_samples(self.params.population_size)?;
    if !constraints.is_empty() {
        candidates.retain(|point| is_feasible(point, constraints));
        // Refill the candidate pool with feasible points, bounded retries.
        let mut attempts = 0;
        while candidates.len() < self.params.population_size && attempts < 10 {
            let new_batch =
                self.generate_stratified_samples(self.params.population_size * 2)?;
            let feasible_batch: Vec<Array1<f64>> =
                new_batch.into_iter().filter(|point| is_feasible(point, constraints)).collect();
            candidates.extend(feasible_batch);
            attempts += 1;
        }
        if candidates.is_empty() {
            return Err(ScatterSearchError::NoFeasibleCandidates {
                attempts,
                dimension: self.bounds.lower.len(),
            });
        }
        // Even with some candidates, fail when the final set cannot reach
        // `population_size`.
        if ref_set.len() + candidates.len() < self.params.population_size {
            return Err(ScatterSearchError::NoFeasibleCandidates {
                attempts,
                dimension: self.bounds.lower.len(),
            });
        }
    }
    // Initial min-distance cache, index-aligned with `candidates`.
    #[cfg(feature = "rayon")]
    let mut min_dists: Vec<f64> = if self.enable_parallel {
        candidates.par_iter().map(|c| self.min_distance(c, ref_set)).collect()
    } else {
        candidates.iter().map(|c| self.min_distance(c, ref_set)).collect()
    };
    #[cfg(not(feature = "rayon"))]
    let mut min_dists: Vec<f64> =
        candidates.iter().map(|c| self.min_distance(c, ref_set)).collect();
    while ref_set.len() < self.params.population_size {
        // Pick the candidate with the largest cached min-distance.
        #[cfg(feature = "rayon")]
        let (max_idx, _) = if self.enable_parallel {
            (0..min_dists.len())
                .into_par_iter()
                .map(|i| (i, min_dists[i]))
                .max_by(|a, b| a.1.total_cmp(&b.1))
                .ok_or(ScatterSearchError::NoCandidates)?
        } else {
            min_dists
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.total_cmp(b))
                .map(|(i, &v)| (i, v))
                .ok_or(ScatterSearchError::NoCandidates)?
        };
        #[cfg(not(feature = "rayon"))]
        let (max_idx, _) = min_dists
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.total_cmp(b))
            .map(|(i, &v)| (i, v))
            .ok_or(ScatterSearchError::NoCandidates)?;
        // swap_remove on both vectors keeps them index-aligned.
        let farthest = candidates.swap_remove(max_idx);
        min_dists.swap_remove(max_idx);
        ref_set.push(farthest);
        // Refresh the cache against the newly added point only.
        #[cfg(feature = "rayon")]
        {
            if self.enable_parallel {
                let updater_iter = candidates.par_iter().zip(min_dists.par_iter_mut());
                updater_iter.for_each(|(candidate, min_dist)| {
                    if let Some(last) = ref_set.last() {
                        let dist = euclidean_distance_squared(candidate, last);
                        if dist < *min_dist {
                            *min_dist = dist;
                        }
                    }
                });
            } else {
                let updater_iter = candidates.iter().zip(min_dists.iter_mut());
                updater_iter.for_each(|(candidate, min_dist)| {
                    if let Some(last) = ref_set.last() {
                        let dist = euclidean_distance_squared(candidate, last);
                        if dist < *min_dist {
                            *min_dist = dist;
                        }
                    }
                });
            }
        }
        #[cfg(not(feature = "rayon"))]
        {
            let updater_iter = candidates.iter().zip(min_dists.iter_mut());
            updater_iter.for_each(|(candidate, min_dist)| {
                if let Some(last) = ref_set.last() {
                    let dist = euclidean_distance_squared(candidate, last);
                    if dist < *min_dist {
                        *min_dist = dist;
                    }
                }
            });
        }
    }
    Ok(())
}
/// Draw `n` uniform random points inside the variable bounds.
///
/// One seed per sample is drawn from the shared master RNG, and each sample
/// is then generated from its own `StdRng`; this keeps results fully
/// deterministic for a given master seed, whether or not samples are built
/// in parallel.
///
/// NOTE(review): despite the name, this is plain uniform sampling, not
/// stratified sampling — confirm whether stratification was intended.
pub fn generate_stratified_samples(
    &self,
    n: usize,
) -> Result<Vec<Array1<f64>>, ScatterSearchError> {
    let dim = self.bounds.lower.len();
    let seeds: Vec<u64> = {
        let mut rng = self.rng.lock().expect("RNG mutex poisoned");
        (0..n).map(|_| rng.random::<u64>()).collect()
    };
    // Shared per-seed sample builder (used by every execution path).
    let make_sample = |seed: u64| -> Result<Array1<f64>, ScatterSearchError> {
        let mut rng = StdRng::seed_from_u64(seed);
        Ok(Array1::from_shape_fn(dim, |i| {
            rng.random_range(self.bounds.lower[i]..=self.bounds.upper[i])
        }))
    };
    #[cfg(feature = "rayon")]
    let samples = if self.enable_parallel {
        seeds.into_par_iter().map(make_sample).collect::<Result<Vec<_>, _>>()
    } else {
        seeds.into_iter().map(make_sample).collect::<Result<Vec<_>, _>>()
    }?;
    #[cfg(not(feature = "rayon"))]
    let samples = seeds
        .into_iter()
        .map(make_sample)
        .collect::<Result<Vec<_>, ScatterSearchError>>()?;
    Ok(samples)
}
/// Smallest squared distance from `point` to any member of `ref_set`
/// (`f64::INFINITY` for an empty set).
pub fn min_distance(&self, point: &Array1<f64>, ref_set: &[Array1<f64>]) -> f64 {
    // Sequential scan with an early exit once a near-duplicate is found.
    fn scan(point: &Array1<f64>, ref_set: &[Array1<f64>]) -> f64 {
        let mut best = f64::INFINITY;
        for other in ref_set {
            let dist = euclidean_distance_squared(point, other);
            if dist < best {
                if dist < 1e-14 {
                    return 0.0;
                }
                best = dist;
            }
        }
        best
    }
    #[cfg(feature = "rayon")]
    if self.enable_parallel {
        return ref_set
            .par_iter()
            .map(|p| euclidean_distance_squared(point, p))
            .reduce(|| f64::INFINITY, f64::min);
    }
    scan(point, ref_set)
}
/// Create trial points by combining every unordered pair among the best `k`
/// reference points, where `k ~ sqrt(|reference_set|)` (the set is sorted
/// best-first after initialization). Each pair yields six points via
/// `combine_points`.
pub fn generate_trial_points(&mut self) -> Result<Vec<Array1<f64>>, ScatterSearchError> {
    let n = self.reference_set.len();
    let k = ((n as f64).sqrt() as usize).max(2).min(n);
    // All index pairs (i, j) with i < j < k.
    let mut pairs: Vec<(usize, usize)> = Vec::with_capacity(k * k.saturating_sub(1) / 2);
    for i in 0..k {
        for j in (i + 1)..k {
            pairs.push((i, j));
        }
    }
    // One RNG seed per pair keeps the perturbations deterministic even when
    // pairs are processed in parallel.
    let seeds: Vec<u64> = {
        let mut rng = self.rng.lock().expect("RNG mutex poisoned");
        (0..pairs.len()).map(|_| rng.random::<u64>()).collect()
    };
    #[cfg(feature = "rayon")]
    let trial_points: Vec<Array1<f64>> = if self.enable_parallel {
        pairs
            .par_iter()
            .zip(seeds.par_iter())
            .flat_map(|(&(i, j), &seed)| {
                self.combine_points(&self.reference_set[i], &self.reference_set[j], seed)
                    .into_par_iter()
            })
            .collect()
    } else {
        pairs
            .iter()
            .zip(seeds.iter())
            .flat_map(|(&(i, j), &seed)| {
                self.combine_points(&self.reference_set[i], &self.reference_set[j], seed)
            })
            .collect()
    };
    #[cfg(not(feature = "rayon"))]
    let trial_points: Vec<Array1<f64>> = pairs
        .iter()
        .zip(seeds.iter())
        .flat_map(|(&(i, j), &seed)| {
            self.combine_points(&self.reference_set[i], &self.reference_set[j], seed)
        })
        .collect();
    Ok(trial_points)
}
/// Combine two reference points into six trial points: four points along
/// the line through `a` and `b` (including one extrapolation at 1.25) plus
/// two randomly perturbed midpoints. Every result is clamped into bounds.
pub fn combine_points(&self, a: &Array1<f64>, b: &Array1<f64>, seed: u64) -> Vec<Array1<f64>> {
    const ALPHAS: [f64; 4] = [0.25, 0.5, 0.75, 1.25];
    // Line-search points: alpha * a + (1 - alpha) * b.
    let mut points: Vec<Array1<f64>> = ALPHAS
        .iter()
        .map(|&alpha| {
            let mut candidate = a * alpha + b * (1.0 - alpha);
            self.apply_bounds(&mut candidate);
            candidate
        })
        .collect();
    // Two midpoints, each coordinate jittered by up to ±10% of its range.
    let mut rng = StdRng::seed_from_u64(seed);
    for _ in 0..2 {
        let mut candidate = (a + b) / 2.0;
        for (i, coord) in candidate.iter_mut().enumerate() {
            *coord +=
                rng.random_range(-0.1..0.1) * (self.bounds.upper[i] - self.bounds.lower[i]);
        }
        self.apply_bounds(&mut candidate);
        points.push(candidate);
    }
    points
}
/// Clamp `point` component-wise into the variable-bounds box.
pub fn apply_bounds(&self, point: &mut Array1<f64>) {
    for ((value, &lo), &hi) in
        point.iter_mut().zip(&self.bounds.lower).zip(&self.bounds.upper)
    {
        *value = value.clamp(lo, hi);
    }
}
/// Merge trial points into the reference set.
///
/// A trial is accepted only if it is feasible, sufficiently far from the
/// five best reference points, and better than the worst cached objective.
/// The merged set is then shrunk to `population_size` and fully re-sorted
/// ascending, restoring the invariant established by
/// `initialize_reference_set` (best point first, worst last) that both
/// `worst_obj` below and `generate_trial_points` rely on.
pub fn update_reference_set(&mut self, trials: &[Array1<f64>]) {
    if trials.is_empty() {
        return;
    }
    // The set is kept sorted ascending, so the last objective is the worst.
    let worst_obj = self.reference_set_objectives.last().copied().unwrap_or(f64::INFINITY);
    // Diversity threshold: 10% of the mean squared distance over a small
    // sample of consecutive reference-point pairs.
    let min_dist_threshold = {
        let ref_set = &self.reference_set;
        if ref_set.len() < 2 {
            0.0
        } else {
            let k = ((ref_set.len() as f64).sqrt() as usize).max(2).min(ref_set.len());
            let sample_size = k;
            #[cfg(feature = "rayon")]
            let sum_dist = if self.enable_parallel {
                (0..sample_size)
                    .into_par_iter()
                    .map(|i| {
                        euclidean_distance_squared(
                            &ref_set[i],
                            &ref_set[(i + 1) % ref_set.len()],
                        )
                    })
                    .sum::<f64>()
            } else {
                (0..sample_size)
                    .map(|i| {
                        euclidean_distance_squared(
                            &ref_set[i],
                            &ref_set[(i + 1) % ref_set.len()],
                        )
                    })
                    .sum::<f64>()
            };
            #[cfg(not(feature = "rayon"))]
            let sum_dist = (0..sample_size)
                .map(|i| {
                    euclidean_distance_squared(&ref_set[i], &ref_set[(i + 1) % ref_set.len()])
                })
                .sum::<f64>();
            (sum_dist / sample_size as f64) * 0.10
        }
    };
    let constraints = self.problem.constraints();
    // Filter: feasibility, diversity vs. the 5 best reference points, and
    // improvement over the worst retained objective. Evaluation errors
    // silently reject the trial (best-effort filter).
    let evaluate_trial = |point: &Array1<f64>| -> Option<(Array1<f64>, f64)> {
        if !constraints.is_empty() && !is_feasible(point, &constraints) {
            return None;
        }
        let is_diverse =
            self.reference_set.iter().take(5).all(|ref_point| {
                euclidean_distance_squared(point, ref_point) > min_dist_threshold
            });
        if !is_diverse {
            return None;
        }
        let obj = self.problem.objective(point).ok()?;
        if obj < worst_obj { Some((point.clone(), obj)) } else { None }
    };
    #[cfg(feature = "rayon")]
    let trial_evaluated: Vec<(Array1<f64>, f64)> = if self.enable_parallel {
        trials.par_iter().filter_map(evaluate_trial).collect()
    } else {
        trials.iter().filter_map(evaluate_trial).collect()
    };
    #[cfg(not(feature = "rayon"))]
    let trial_evaluated: Vec<(Array1<f64>, f64)> =
        trials.iter().filter_map(evaluate_trial).collect();
    if trial_evaluated.is_empty() {
        return;
    }
    let ref_evaluated: Vec<(Array1<f64>, f64)> = std::mem::take(&mut self.reference_set)
        .into_iter()
        .zip(std::mem::take(&mut self.reference_set_objectives))
        .collect();
    let mut all_points = ref_evaluated;
    all_points.extend(trial_evaluated);
    let pop_size = self.params.population_size;
    // Keep only the best `pop_size` points, then sort ascending.
    //
    // BUGFIX: the previous sequence (select the k best, sort that prefix,
    // THEN select the pop_size-th element) let the second selection
    // rearrange the whole slice, scrambling the freshly sorted prefix and
    // leaving `.last()` holding an arbitrary objective — which broke the
    // "sorted, worst-last" invariant that `worst_obj` depends on.
    if all_points.len() > pop_size {
        all_points.select_nth_unstable_by(pop_size - 1, |a, b| a.1.total_cmp(&b.1));
        all_points.truncate(pop_size);
    }
    all_points.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
    let (points, objectives): (Vec<Array1<f64>>, Vec<f64>) = all_points.into_iter().unzip();
    self.reference_set = points;
    self.reference_set_objectives = objectives;
}
/// Return a clone of the reference point with the lowest cached objective.
///
/// Scans the full objective vector, so it is correct whether or not the set
/// is currently sorted. Fails with `NoCandidates` on an empty set.
pub fn best_solution(&self) -> Result<Array1<f64>, ScatterSearchError> {
    // Sequential argmin over the cached objectives.
    let argmin = |objs: &[f64]| -> Result<usize, ScatterSearchError> {
        objs.iter()
            .enumerate()
            .min_by(|(_, a), (_, b)| a.total_cmp(b))
            .map(|(idx, _)| idx)
            .ok_or(ScatterSearchError::NoCandidates)
    };
    #[cfg(feature = "rayon")]
    let best_idx = if self.enable_parallel {
        self.reference_set_objectives
            .par_iter()
            .enumerate()
            .min_by(|(_, a), (_, b)| a.total_cmp(b))
            .map(|(idx, _)| idx)
            .ok_or(ScatterSearchError::NoCandidates)?
    } else {
        argmin(&self.reference_set_objectives)?
    };
    #[cfg(not(feature = "rayon"))]
    let best_idx = argmin(&self.reference_set_objectives)?;
    Ok(self.reference_set[best_idx].clone())
}
/// Push a trial point directly into the reference set.
///
/// NOTE(review): this does not append a matching entry to
/// `reference_set_objectives`, so the two vectors go out of sync and the
/// final `zip` in `run` would silently drop the extra point. It appears to
/// be a test-only helper — confirm before using it elsewhere.
pub fn store_trial(&mut self, trial: Array1<f64>) {
    self.reference_set.push(trial);
}
}
/// Squared Euclidean distance between two points. No square root is taken —
/// callers only compare distances, so the monotone squared form suffices.
/// Kept as `diff.dot(&diff)` on purpose: ndarray's reduction order must stay
/// fixed so seeded runs remain bitwise reproducible.
#[inline]
fn euclidean_distance_squared(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
    let diff = a - b;
    diff.dot(&diff)
}
/// A point is feasible when no constraint evaluates negative (convention:
/// `g(x) >= 0` means satisfied).
#[inline]
fn is_feasible(point: &Array1<f64>, constraints: &[fn(&[f64], &mut ()) -> f64]) -> bool {
    if constraints.is_empty() {
        return true;
    }
    let coords = point.as_slice().expect("Failed to convert point to slice");
    // `!(v < 0.0)` rather than `v >= 0.0`: a NaN constraint value must NOT
    // mark the point infeasible, matching the original comparison.
    constraints.iter().all(|constraint| !(constraint(coords, &mut ()) < 0.0))
}
#[cfg(test)]
mod tests_scatter_search {
use super::*;
use crate::types::EvaluationError;
use crate::types::OQNLPParams;
use ndarray::{Array2, array};
// Six-Hump Camel benchmark: 2-D, multimodal, bounded box domain.
#[derive(Debug, Clone)]
pub struct SixHumpCamel;
impl Problem for SixHumpCamel {
    fn objective(&self, x: &Array1<f64>) -> Result<f64, EvaluationError> {
        Ok((4.0 - 2.1 * x[0].powi(2) + x[0].powi(4) / 3.0) * x[0].powi(2)
            + x[0] * x[1]
            + (-4.0 + 4.0 * x[1].powi(2)) * x[1].powi(2))
    }
    fn variable_bounds(&self) -> Array2<f64> {
        array![[-3.0, 3.0], [-2.0, 2.0]]
    }
}
// The final reference set must contain exactly `population_size` points.
#[test]
fn test_population_size() {
    let problem: SixHumpCamel = SixHumpCamel;
    let params: OQNLPParams = OQNLPParams {
        iterations: 50,
        wait_cycle: 30,
        threshold_factor: 0.2,
        distance_factor: 0.75,
        population_size: 100,
        seed: 0,
        ..OQNLPParams::default()
    };
    let ss: ScatterSearch<SixHumpCamel> = ScatterSearch::new(problem, params).unwrap();
    let (ref_set, _) = ss.run().unwrap();
    assert_eq!(ref_set.len(), 100);
}
// Every returned point must lie inside the variable bounds.
#[test]
fn test_bounds_in_reference_set() {
    let problem: SixHumpCamel = SixHumpCamel;
    let params: OQNLPParams = OQNLPParams {
        iterations: 50,
        wait_cycle: 30,
        threshold_factor: 0.2,
        distance_factor: 0.75,
        population_size: 100,
        seed: 0,
        ..OQNLPParams::default()
    };
    let ss: ScatterSearch<SixHumpCamel> = ScatterSearch::new(problem, params).unwrap();
    let bounds: VariableBounds = ss.bounds.clone();
    let (ref_set, _) = ss.run().unwrap();
    assert_eq!(ref_set.len(), 100);
    for (point, _obj) in ref_set {
        for i in 0..point.len() {
            assert!(point[i] >= bounds.lower[i]);
            assert!(point[i] <= bounds.upper[i]);
        }
    }
}
// Identical seeds must reproduce an identical reference set (determinism).
#[test]
fn test_same_reference_set() {
    let problem: SixHumpCamel = SixHumpCamel;
    let params: OQNLPParams = OQNLPParams {
        iterations: 50,
        wait_cycle: 30,
        threshold_factor: 0.2,
        distance_factor: 0.75,
        population_size: 100,
        seed: 0,
        ..OQNLPParams::default()
    };
    let ss1: ScatterSearch<SixHumpCamel> =
        ScatterSearch::new(problem.clone(), params.clone()).unwrap();
    let ss2: ScatterSearch<SixHumpCamel> = ScatterSearch::new(problem, params).unwrap();
    let (ref_set1, _) = ss1.run().unwrap();
    let (ref_set2, _) = ss2.run().unwrap();
    assert_eq!(ref_set1.len(), 100);
    assert_eq!(ref_set1.len(), ref_set2.len());
    for i in 0..ref_set1.len() {
        assert_eq!(ref_set1[i], ref_set2[i]);
    }
}
// Trial-point count: C(k, 2) pairs with k = floor(sqrt(n)), 6 points each.
#[test]
fn test_generate_trial_points() {
    let problem: SixHumpCamel = SixHumpCamel;
    let params: OQNLPParams = OQNLPParams {
        iterations: 1,
        wait_cycle: 30,
        threshold_factor: 0.2,
        distance_factor: 0.75,
        population_size: 10,
        seed: 0,
        ..OQNLPParams::default()
    };
    let mut ss: ScatterSearch<SixHumpCamel> = ScatterSearch::new(problem, params).unwrap();
    ss.initialize_reference_set().unwrap();
    let trial_points: Vec<Array1<f64>> = ss.generate_trial_points().unwrap();
    let n = ss.reference_set.len();
    let k = (n as f64).sqrt() as usize;
    let expected = k * (k - 1) / 2 * 6;
    assert_eq!(trial_points.len(), expected);
}
// Each pair yields exactly 6 combined points (4 line points + 2 perturbed).
#[test]
fn test_combine_points() {
    let problem: SixHumpCamel = SixHumpCamel;
    let params: OQNLPParams = OQNLPParams {
        iterations: 1,
        wait_cycle: 30,
        threshold_factor: 0.2,
        distance_factor: 0.75,
        population_size: 10,
        seed: 0,
        ..OQNLPParams::default()
    };
    let ss: ScatterSearch<SixHumpCamel> = ScatterSearch::new(problem, params).unwrap();
    let a: Array1<f64> = array![1.0, 1.0];
    let b: Array1<f64> = array![2.0, 2.0];
    let trial_points: Vec<Array1<f64>> = ss.combine_points(&a, &b, 0);
    assert_eq!(trial_points.len(), 6);
}
// `store_trial` appends directly to the (initially empty) reference set.
#[test]
fn test_store_trials() {
    let problem: SixHumpCamel = SixHumpCamel;
    let params: OQNLPParams = OQNLPParams {
        iterations: 1,
        wait_cycle: 30,
        threshold_factor: 0.2,
        distance_factor: 0.75,
        population_size: 4,
        seed: 0,
        ..OQNLPParams::default()
    };
    let mut ss: ScatterSearch<SixHumpCamel> = ScatterSearch::new(problem, params).unwrap();
    assert_eq!(ss.reference_set.len(), 0);
    let trial: Array1<f64> = array![1.0, 1.0];
    ss.store_trial(trial.clone());
    assert_eq!(ss.reference_set.len(), 1);
    assert_eq!(ss.reference_set[0], trial);
}
// Merging trials must keep the reference set capped at population_size.
#[test]
fn test_update_reference_set() {
    let problem: SixHumpCamel = SixHumpCamel;
    let params: OQNLPParams = OQNLPParams {
        iterations: 1,
        wait_cycle: 30,
        threshold_factor: 0.2,
        distance_factor: 0.75,
        population_size: 4,
        seed: 0,
        ..OQNLPParams::default()
    };
    let mut ss: ScatterSearch<SixHumpCamel> = ScatterSearch::new(problem, params).unwrap();
    ss.initialize_reference_set().unwrap();
    let trials: Vec<Array1<f64>> = vec![array![1.0, 1.0], array![2.0, 2.0]];
    ss.update_reference_set(&trials);
    assert_eq!(ss.reference_set.len(), 4);
}
// [-3, -2] is the lower-bounds seed point, so its min distance is exactly 0.
#[test]
fn test_min_distance() {
    let problem: SixHumpCamel = SixHumpCamel;
    let params: OQNLPParams = OQNLPParams {
        iterations: 1,
        wait_cycle: 30,
        threshold_factor: 0.2,
        distance_factor: 0.75,
        population_size: 4,
        seed: 0,
        ..OQNLPParams::default()
    };
    let mut ss: ScatterSearch<SixHumpCamel> = ScatterSearch::new(problem, params).unwrap();
    ss.initialize_reference_set().unwrap();
    let point: Array1<f64> = array![-3.0, -2.0];
    let min_dist: f64 = ss.min_distance(&point, &ss.reference_set);
    assert_eq!(min_dist, 0.0);
}
// (1-3)^2 + (2-4)^2 = 8: squared distance, no square root.
#[test]
fn test_euclidean_distance_squared() {
    let a: Array1<f64> = array![1.0, 2.0];
    let b: Array1<f64> = array![3.0, 4.0];
    let dist: f64 = euclidean_distance_squared(&a, &b);
    assert_eq!(dist, 8.0);
}
// Same expectations as the serial tests, exercised with the rayon feature.
#[cfg(feature = "rayon")]
#[test]
fn test_generate_trial_points_rayon() {
    let problem: SixHumpCamel = SixHumpCamel;
    let params: OQNLPParams = OQNLPParams {
        iterations: 1,
        wait_cycle: 30,
        threshold_factor: 0.2,
        distance_factor: 0.75,
        population_size: 10,
        seed: 0,
        ..OQNLPParams::default()
    };
    let mut ss: ScatterSearch<SixHumpCamel> = ScatterSearch::new(problem, params).unwrap();
    ss.initialize_reference_set().unwrap();
    let trial_points: Vec<Array1<f64>> = ss.generate_trial_points().unwrap();
    let n = ss.reference_set.len();
    let k = (n as f64).sqrt() as usize;
    let expected = k * (k - 1) / 2 * 6;
    assert_eq!(trial_points.len(), expected);
}
#[cfg(feature = "rayon")]
#[test]
fn test_update_reference_set_rayon() {
    let problem: SixHumpCamel = SixHumpCamel;
    let params: OQNLPParams = OQNLPParams {
        iterations: 1,
        wait_cycle: 30,
        threshold_factor: 0.2,
        distance_factor: 0.75,
        population_size: 4,
        seed: 0,
        ..OQNLPParams::default()
    };
    let mut ss: ScatterSearch<SixHumpCamel> = ScatterSearch::new(problem, params).unwrap();
    ss.initialize_reference_set().unwrap();
    let trials: Vec<Array1<f64>> = vec![array![1.0, 1.0], array![2.0, 2.0]];
    ss.update_reference_set(&trials);
    assert_eq!(ss.reference_set.len(), 4);
}
#[cfg(feature = "rayon")]
#[test]
fn test_min_distance_rayon() {
    let problem: SixHumpCamel = SixHumpCamel;
    let params: OQNLPParams = OQNLPParams {
        iterations: 1,
        wait_cycle: 30,
        threshold_factor: 0.2,
        distance_factor: 0.75,
        population_size: 4,
        seed: 0,
        ..OQNLPParams::default()
    };
    let mut ss: ScatterSearch<SixHumpCamel> = ScatterSearch::new(problem, params).unwrap();
    ss.initialize_reference_set().unwrap();
    let point: Array1<f64> = array![-3.0, -2.0];
    let min_dist: f64 = ss.min_distance(&point, &ss.reference_set);
    assert_eq!(min_dist, 0.0);
}
// The builder stores the provided custom points verbatim.
#[test]
fn test_with_custom_points() {
    let problem: SixHumpCamel = SixHumpCamel;
    let params: OQNLPParams = OQNLPParams {
        iterations: 1,
        wait_cycle: 30,
        threshold_factor: 0.2,
        distance_factor: 0.75,
        population_size: 10,
        seed: 0,
        ..OQNLPParams::default()
    };
    let custom_points = vec![array![0.0, 0.0], array![1.0, 1.0], array![-1.0, -1.0]];
    let ss = ScatterSearch::new(problem, params).unwrap().with_custom_points(custom_points);
    assert!(ss.custom_points.is_some(), "Custom points should be set");
    assert_eq!(ss.custom_points.as_ref().unwrap().len(), 3, "Should have 3 custom points");
}
// Custom points are folded into initialization; the set still reaches
// population_size.
#[test]
fn test_custom_points_in_reference_set() {
    let problem: SixHumpCamel = SixHumpCamel;
    let params: OQNLPParams = OQNLPParams {
        iterations: 1,
        wait_cycle: 30,
        threshold_factor: 0.2,
        distance_factor: 0.75,
        population_size: 15,
        seed: 0,
        ..OQNLPParams::default()
    };
    let custom_point = array![0.5, 0.5];
    let custom_points = vec![custom_point.clone()];
    let mut ss = ScatterSearch::new(problem, params).unwrap().with_custom_points(custom_points);
    ss.initialize_reference_set().unwrap();
    assert_eq!(ss.reference_set.len(), 15, "Reference set should have population_size points");
    assert!(
        ss.reference_set.len() >= 4,
        "Reference set should include at least the 3 seed points + 1 custom point"
    );
}
// An empty custom-point list must behave exactly like no custom points.
#[test]
fn test_custom_points_empty() {
    let problem: SixHumpCamel = SixHumpCamel;
    let params: OQNLPParams = OQNLPParams {
        iterations: 1,
        wait_cycle: 30,
        threshold_factor: 0.2,
        distance_factor: 0.75,
        population_size: 10,
        seed: 0,
        ..OQNLPParams::default()
    };
    let custom_points: Vec<Array1<f64>> = vec![];
    let mut ss = ScatterSearch::new(problem, params).unwrap().with_custom_points(custom_points);
    ss.initialize_reference_set().unwrap();
    assert_eq!(ss.reference_set.len(), 10, "Reference set should have population_size points");
}
// End-to-end run with custom points keeps the population-size invariant.
#[test]
fn test_custom_points_full_run() {
    let problem: SixHumpCamel = SixHumpCamel;
    let params: OQNLPParams = OQNLPParams {
        iterations: 1,
        wait_cycle: 30,
        threshold_factor: 0.2,
        distance_factor: 0.75,
        population_size: 20,
        seed: 0,
        ..OQNLPParams::default()
    };
    let custom_points = vec![
        array![0.0, 0.0],
        array![0.5, -0.5],
        array![-0.5, 0.5],
    ];
    let ss = ScatterSearch::new(problem, params).unwrap().with_custom_points(custom_points);
    let result = ss.run();
    assert!(result.is_ok(), "Scatter search should complete successfully with custom points");
    let (ref_set, _best) = result.unwrap();
    assert_eq!(ref_set.len(), 20, "Reference set should have population_size points after run");
}
// Many custom points still get capped at population_size after init.
#[test]
fn test_custom_points_many() {
    let problem: SixHumpCamel = SixHumpCamel;
    let params: OQNLPParams = OQNLPParams {
        iterations: 1,
        wait_cycle: 30,
        threshold_factor: 0.2,
        distance_factor: 0.75,
        population_size: 20,
        seed: 0,
        ..OQNLPParams::default()
    };
    let mut custom_points = Vec::new();
    for i in 0..10 {
        custom_points.push(array![i as f64 * 0.1, i as f64 * 0.1]);
    }
    let mut ss = ScatterSearch::new(problem, params).unwrap().with_custom_points(custom_points);
    ss.initialize_reference_set().unwrap();
    assert_eq!(ss.reference_set.len(), 20, "Reference set should be capped at population_size");
}
// With a g(x) >= 0 constraint, every returned point must satisfy it.
#[test]
fn test_constraints_in_reference_set() {
    #[derive(Debug, Clone)]
    struct ConstrainedProblem;
    impl Problem for ConstrainedProblem {
        fn objective(&self, x: &Array1<f64>) -> Result<f64, EvaluationError> {
            Ok((x[0] - 1.0).powi(2) + (x[1] - 1.0).powi(2))
        }
        fn variable_bounds(&self) -> Array2<f64> {
            array![[0.0, 2.0], [0.0, 2.0]]
        }
        fn constraints(&self) -> Vec<fn(&[f64], &mut ()) -> f64> {
            // Feasible iff x + y <= 1.5 (value >= 0 means satisfied).
            vec![
                |x: &[f64], _: &mut ()| 1.5 - x[0] - x[1],
            ]
        }
    }
    let problem = ConstrainedProblem;
    let params = OQNLPParams { population_size: 50, seed: 42, ..OQNLPParams::default() };
    let ss = ScatterSearch::new(problem.clone(), params).unwrap();
    let (ref_set, _) = ss.run().unwrap();
    let constraints = problem.constraints();
    for (point, _obj) in &ref_set {
        let x = point.as_slice().expect("Failed to convert point to slice");
        for constraint_fn in &constraints {
            let value = constraint_fn(x, &mut ());
            assert!(
                value >= -1e-10,
                "Constraint violated: point = {:?}, constraint value = {}",
                point,
                value
            );
        }
    }
    for (point, _obj) in &ref_set {
        let sum = point[0] + point[1];
        assert!(
            sum <= 1.5 + 1e-10,
            "Direct constraint check failed: x + y = {} > 1.5 for point {:?}",
            sum,
            point
        );
    }
}
// Inverted bounds must be rejected at construction time with details.
#[test]
fn test_invalid_bounds() {
    #[derive(Debug, Clone)]
    struct InvalidBoundsProblem;
    impl Problem for InvalidBoundsProblem {
        fn objective(&self, x: &Array1<f64>) -> Result<f64, EvaluationError> {
            Ok(x[0].powi(2) + x[1].powi(2))
        }
        fn variable_bounds(&self) -> Array2<f64> {
            // First dimension is inverted on purpose: lower 2.0 > upper 1.0.
            array![[2.0, 1.0], [-1.0, 1.0]]
        }
    }
    let problem = InvalidBoundsProblem;
    let params = OQNLPParams::default();
    let result = ScatterSearch::new(problem, params);
    assert!(result.is_err(), "ScatterSearch::new should fail with invalid bounds");
    match result {
        Err(ScatterSearchError::InvalidBounds { dimension, lower, upper }) => {
            assert_eq!(dimension, 0, "Should report error for first dimension");
            assert_eq!(lower, 2.0, "Lower bound should be 2.0");
            assert_eq!(upper, 1.0, "Upper bound should be 1.0");
        }
        _ => panic!("Expected InvalidBounds error"),
    }
}
}