use crate::filters::{DistanceFilter, MeritFilter};
use crate::local_solver::runner::LocalSolver;
use crate::observers::Observer;
use crate::problem::Problem;
use crate::scatter_search::ScatterSearch;
use crate::types::{FilterParams, LocalSolution, OQNLPParams, SolutionSet};
#[cfg(feature = "checkpointing")]
use crate::{
checkpoint::{CheckpointError, CheckpointManager},
types::{CheckpointConfig, OQNLPCheckpoint},
};
#[cfg(feature = "checkpointing")]
use chrono;
#[cfg(feature = "progress_bar")]
use kdam::{Bar, BarExt};
use ndarray::{Array1, Array2};
use rand::SeedableRng;
use rand::rngs::StdRng;
use rand::seq::SliceRandom;
use thiserror::Error;
#[cfg(feature = "rayon")]
use rayon::prelude::*;
/// Errors produced while configuring or running the OQNLP optimizer.
#[derive(Debug, Error)]
pub enum OQNLPError {
/// The local solver failed; wraps the local-solver error (also the target of `?` on local solves).
#[error("OQNLP Error: Local solver failed: {0}")]
LocalSolverError(#[from] crate::local_solver::runner::LocalSolverError),
/// No solution set is available when a result is requested.
///
/// NOTE(review): the call sites in this file always report `candidates_evaluated: 0`.
#[error(
"OQNLP Error: No feasible solution found after evaluating {candidates_evaluated} candidates."
)]
NoFeasibleSolution { candidates_evaluated: usize },
/// Evaluating the objective function failed; `stage` names where it happened.
#[error("OQNLP Error: Objective function evaluation failed during {stage}.")]
ObjectiveFunctionEvaluationFailed { stage: String },
/// Conversion target for `?` on scatter-search operations such as `ScatterSearch::new`.
#[error("OQNLP Error: Failed to create a new ScatterSearch instance: {0}")]
ScatterSearchError(#[from] crate::scatter_search::ScatterSearchError),
/// Running the Stage 1 scatter search failed (mapped explicitly, not via `From`).
#[error("OQNLP Error: Failed to run the ScatterSearch instance: {0}")]
ScatterSearchRunError(crate::scatter_search::ScatterSearchError),
/// Returned by `OQNLP::new` for an invalid `population_size`.
///
/// NOTE(review): `new` rejects `population_size <= 3`, while this message says "at least 3";
/// one of the two should be aligned.
#[error(
"OQNLP Error: Population size should be at least 3, got {0}. Reference Set size should be at least 3, since it pushes the bounds and the midpoint."
)]
InvalidPopulationSize(usize),
/// Returned by `OQNLP::new` when `iterations > population_size`; carries `(iterations, population_size)`.
#[error(
"OQNLP Error: Iterations should be less than or equal to population size. OQNLP received `iterations`: {0}, `population size`: {1}."
)]
InvalidIterations(usize, usize),
/// Creating the distance filter failed (wraps the filter error).
#[error("OQNLP Error: Failed to create distance filter: {0}")]
DistanceFilterError(#[from] crate::filters::FiltersErrors),
/// Returned by `OQNLP::new` when `threshold_factor <= 0.0`.
#[error("OQNLP Error: Threshold factor must be positive, got {0}.")]
InvalidThresholdFactor(f64),
/// A custom starting point passed to `with_points` has the wrong number of coordinates.
#[error(
"OQNLP Error: Custom point at index {point_index} has invalid dimension. Expected {expected} dimensions, got {got}."
)]
InvalidCustomPointsDimension { point_index: usize, expected: usize, got: usize },
/// A custom starting point passed to `with_points` lies outside the variable bounds.
#[error("OQNLP Error: Custom point at index {index} is outside variable bounds.")]
CustomPointOutOfBounds { index: usize },
/// Saving or loading a checkpoint failed.
#[cfg(feature = "checkpointing")]
#[error("OQNLP Error: Checkpointing error: {0}")]
CheckpointError(#[from] CheckpointError),
}
/// Two-stage OQNLP global optimizer.
///
/// Stage 1 runs a [`ScatterSearch`] to build a reference set and locally optimizes the
/// scatter-search candidate. Stage 2 walks the reference set and starts a local solve from
/// every trial point that passes the merit filter (objective `<=` threshold) and the distance
/// filter.
///
/// ### Configuration
///
/// - [`max_time()`](Self::max_time): Stage 2 wall-clock limit in seconds
/// - [`target_objective()`](Self::target_objective): stop once the best objective reaches a target
/// - [`exclude_out_of_bounds()`](Self::exclude_out_of_bounds): reject local solutions outside the variable bounds
/// - [`with_points()`](Self::with_points): seed Stage 1 with custom starting points
/// - [`with_tolerance()`](Self::with_tolerance): tolerances for treating objective values as equal
/// - [`add_observer()`](Self::add_observer): attach an [`Observer`]
/// - [`verbose()`](Self::verbose): print progress messages to stdout
#[cfg_attr(feature = "rayon", doc = "")]
#[cfg_attr(feature = "rayon", doc = "### Parallel Processing")]
#[cfg_attr(feature = "rayon", doc = "")]
#[cfg_attr(
feature = "rayon",
doc = "- [`batch_iterations()`](OQNLP::batch_iterations): Set batch size for parallel processing of stage two iterations"
)]
#[cfg_attr(feature = "rayon", doc = "")]
#[cfg_attr(feature = "checkpointing", doc = "")]
#[cfg_attr(feature = "checkpointing", doc = "### State Persistence")]
#[cfg_attr(feature = "checkpointing", doc = "")]
#[cfg_attr(
feature = "checkpointing",
doc = "- [`with_checkpointing()`](Self::with_checkpointing): Enable automatic state saving"
)]
#[cfg_attr(
feature = "checkpointing",
doc = "- [`resume_with_modified_params()`](Self::resume_with_modified_params): Continue with new parameters"
)]
#[cfg_attr(feature = "checkpointing", doc = "")]
pub struct OQNLP<P: Problem + Clone> {
/// Problem being optimized; cloned into scatter searches and local solvers.
problem: P,
/// Algorithm parameters; replaced when state is restored from a checkpoint.
params: OQNLPParams,
/// Merit filter: a Stage 2 trial starts a local solve only if its objective is `<= threshold`.
merit_filter: MeritFilter,
/// Distance filter consulted via `check` before a local solve; processed local solutions are added to it.
distance_filter: DistanceFilter,
/// Local solver used for the Stage 1 candidate and for sequential Stage 2 solves.
local_solver: LocalSolver<P>,
/// Solutions whose objectives tie the best within `abs_tol`/`rel_tol`; `None` until one is accepted.
solution_set: Option<SolutionSet>,
/// Optional Stage 2 wall-clock limit in seconds, checked before each batch.
max_time: Option<f64>,
/// User-requested Stage 2 batch size; when `None` a size is derived from the thread count.
#[cfg(feature = "rayon")]
batch_iterations: Option<usize>,
/// Print progress messages to stdout.
verbose: bool,
/// Whether scatter search and Stage 2 may use rayon parallelism.
#[cfg(feature = "rayon")]
enable_parallel: bool,
/// Stop early once the best objective is `<=` this value.
target_objective: Option<f64>,
/// Reject local solutions outside the variable bounds (also required for the target check).
exclude_out_of_bounds: bool,
/// Saves and loads checkpoints when checkpointing is configured.
#[cfg(feature = "checkpointing")]
checkpoint_manager: Option<CheckpointManager>,
/// Stage 2 iteration currently being processed (stored in checkpoints).
#[cfg(feature = "checkpointing")]
current_iteration: usize,
/// Stage 1 reference set kept for checkpointing and resumption.
#[cfg(feature = "checkpointing")]
current_reference_set: Option<Vec<Array1<f64>>>,
/// Snapshot of the consecutive-rejection counter, stored in checkpoints.
#[cfg(feature = "checkpointing")]
unchanged_cycles: usize,
/// Time `run` first started; used for `elapsed_time` in checkpoints.
#[cfg(feature = "checkpointing")]
start_time: Option<std::time::Instant>,
/// Seed stored in checkpoints (`params.seed + current_iteration` during Stage 2).
#[cfg(feature = "checkpointing")]
current_seed: u64,
/// Optional observer receiving Stage 1 / Stage 2 progress.
observer: Option<Observer>,
/// User-provided starting points forwarded to the Stage 1 scatter search.
custom_points: Option<Vec<Array1<f64>>>,
/// Absolute tolerance: two objectives tie if they differ by at most `max(abs_tol, rel_tol * max(|a|, |b|))`.
abs_tol: f64,
/// Relative tolerance used in the same tie test as `abs_tol`.
rel_tol: f64,
}
impl<P: Problem + Clone + Send + Sync> OQNLP<P> {
/// Creates an optimizer for `problem` after validating `params`.
///
/// Prints a warning to stderr (but still succeeds) when `wait_cycle >= iterations`.
///
/// # Errors
///
/// - [`OQNLPError::InvalidPopulationSize`] if `population_size <= 3`
/// - [`OQNLPError::InvalidIterations`] if `iterations > population_size`
/// - [`OQNLPError::InvalidThresholdFactor`] if `threshold_factor <= 0.0`
/// - any error raised while creating the distance filter
pub fn new(problem: P, params: OQNLPParams) -> Result<Self, OQNLPError> {
    let population = params.population_size;
    if population <= 3 {
        return Err(OQNLPError::InvalidPopulationSize(population));
    }
    if params.iterations > population {
        return Err(OQNLPError::InvalidIterations(params.iterations, population));
    }
    if params.threshold_factor <= 0.0 {
        return Err(OQNLPError::InvalidThresholdFactor(params.threshold_factor));
    }
    if params.wait_cycle >= params.iterations {
        eprintln!(
            "Warning: `wait_cycle` is greater than or equal to `iterations`. This may lead to suboptimal results."
        );
    }

    let distance_filter = DistanceFilter::new(FilterParams {
        distance_factor: params.distance_factor,
        wait_cycle: params.wait_cycle,
        threshold_factor: params.threshold_factor,
    })?;
    let local_solver = LocalSolver::new(
        problem.clone(),
        params.local_solver_type.clone(),
        params.local_solver_config.clone(),
    );

    Ok(Self {
        problem,
        merit_filter: MeritFilter::new(),
        distance_filter,
        local_solver,
        solution_set: None,
        max_time: None,
        #[cfg(feature = "rayon")]
        batch_iterations: None,
        verbose: false,
        #[cfg(feature = "rayon")]
        enable_parallel: true,
        target_objective: None,
        exclude_out_of_bounds: false,
        #[cfg(feature = "checkpointing")]
        checkpoint_manager: None,
        #[cfg(feature = "checkpointing")]
        current_iteration: 0,
        #[cfg(feature = "checkpointing")]
        current_reference_set: None,
        #[cfg(feature = "checkpointing")]
        unchanged_cycles: 0,
        #[cfg(feature = "checkpointing")]
        start_time: None,
        #[cfg(feature = "checkpointing")]
        current_seed: params.seed,
        observer: None,
        custom_points: None,
        abs_tol: 1e-8,
        rel_tol: 1e-6,
        // Moved last so the fields above can still read from `params`.
        params,
    })
}
/// Attaches an [`Observer`], replacing (and dropping) any previously attached one.
pub fn add_observer(mut self, observer: Observer) -> Self {
    let _ = self.observer.replace(observer);
    self
}
/// Returns the attached observer, if any.
pub fn observer(&self) -> Option<&Observer> {
    Option::as_ref(&self.observer)
}
pub fn with_points(mut self, points: Array2<f64>) -> Result<Self, OQNLPError> {
let bounds = self.problem.variable_bounds();
let n_dims = bounds.nrows();
let mut custom_points = Vec::new();
for (i, point_row) in points.outer_iter().enumerate() {
let point = point_row.to_owned();
if point.len() != n_dims {
return Err(OQNLPError::InvalidCustomPointsDimension {
point_index: i,
expected: n_dims,
got: point.len(),
});
}
for (j, &value) in point.iter().enumerate() {
let lower = bounds[[j, 0]];
let upper = bounds[[j, 1]];
if value < lower || value > upper {
return Err(OQNLPError::CustomPointOutOfBounds { index: i });
}
}
custom_points.push(point);
}
self.custom_points = Some(custom_points);
Ok(self)
}
/// Runs the two-stage OQNLP optimization and returns the resulting solution set.
///
/// Stage 1 runs a scatter search to build a reference set (each point paired with its
/// objective value) and locally optimizes the scatter-search candidate. Stage 2 shuffles the
/// reference set, takes its first `iterations` points, and starts a local solve from every
/// trial point that passes the merit and distance filters. Stage 2 stops when the points run
/// out, `max_time` elapses, or `target_objective` is reached.
///
/// With checkpointing enabled and an auto-resumable checkpoint present, Stage 1 is skipped
/// and Stage 2 continues from the saved iteration.
///
/// # Errors
///
/// Propagates scatter-search, local-solver, objective-evaluation and checkpoint errors.
/// Returns [`OQNLPError::NoFeasibleSolution`] if no solution was ever accepted.
pub fn run(&mut self) -> Result<SolutionSet, OQNLPError> {
    if let Some(ref mut observer) = self.observer {
        observer.start_timer();
        if observer.should_observe_stage1() {
            if let Some(stage1) = observer.stage1_mut() {
                stage1.start();
            }
        }
    }

    #[cfg(feature = "checkpointing")]
    let resumed_from_checkpoint = self.try_resume_from_checkpoint()?;
    #[cfg(not(feature = "checkpointing"))]
    let resumed_from_checkpoint = false;

    #[cfg(feature = "checkpointing")]
    if self.start_time.is_none() {
        self.start_time = Some(std::time::Instant::now());
    }

    // `ref_objectives[i]` (when present) is the objective value of `ref_set[i]`. It is `None`
    // after resuming from a checkpoint; Stage 2 then evaluates each trial point itself.
    let (mut ref_set, mut unchanged_cycles, mut ref_objectives) = if resumed_from_checkpoint {
        #[cfg(feature = "checkpointing")]
        {
            if self.verbose {
                println!("Resuming from checkpoint at iteration {}", self.current_iteration);
            }
            let ref_set = self.current_reference_set.clone().unwrap_or_default();
            let unchanged_cycles = self.unchanged_cycles;
            (ref_set, unchanged_cycles, None)
        }
        #[cfg(not(feature = "checkpointing"))]
        unreachable!()
    } else {
        if self.verbose {
            println!("Starting Stage 1");
        }
        #[cfg(feature = "rayon")]
        let mut ss = ScatterSearch::new(self.problem.clone(), self.params.clone())?
            .parallel(self.enable_parallel);
        #[cfg(not(feature = "rayon"))]
        let mut ss = ScatterSearch::new(self.problem.clone(), self.params.clone())?;
        if let Some(ref custom_points) = self.custom_points {
            ss = ss.with_custom_points(custom_points.clone());
        }
        let (ref_set_with_objectives, scatter_candidate) = if let Some(ref mut observer) =
            self.observer
        {
            if observer.should_observe_stage1() {
                ss.with_observer(observer).run().map_err(OQNLPError::ScatterSearchRunError)?
            } else {
                ss.run().map_err(OQNLPError::ScatterSearchRunError)?
            }
        } else {
            ss.run().map_err(OQNLPError::ScatterSearchRunError)?
        };
        let (ref_set, ref_objectives): (Vec<Array1<f64>>, Vec<f64>) =
            ref_set_with_objectives.into_iter().unzip();

        if let Some(ref mut observer) = self.observer {
            if observer.should_observe_stage1() {
                if let Some(stage1) = observer.stage1_mut() {
                    stage1.enter_substage("scatter_search_complete");
                    stage1.set_reference_set_size(ref_set.len());
                    // `total_cmp` gives a total order, so a NaN objective cannot panic here.
                    if let Some((best_idx, &best_obj)) = ref_objectives
                        .iter()
                        .enumerate()
                        .min_by(|(_, a), (_, b)| a.total_cmp(b))
                    {
                        stage1.set_best_solution(best_obj, &ref_set[best_idx]);
                    }
                }
            }
            observer.invoke_callback();
        }
        if let Some(ref mut observer) = self.observer {
            if observer.should_observe_stage1() {
                if let Some(stage1) = observer.stage1_mut() {
                    stage1.enter_substage("local_optimization");
                }
            }
        }

        let (local_sol, local_fn_evals) =
            if self.observer.as_ref().map(|obs| obs.should_observe_stage1()).unwrap_or(false) {
                self.local_solver.solve_with_tracking(scatter_candidate, true)?
            } else {
                let sol = self.local_solver.solve(scatter_candidate)?;
                (sol, 0)
            };
        self.merit_filter.update_threshold(local_sol.objective);
        if self.verbose {
            println!(
                "Stage 1: Local solution found with objective = {:.8}",
                local_sol.objective
            );
        }
        if let Some(ref mut observer) = self.observer {
            if observer.should_observe_stage1() {
                if let Some(stage1) = observer.stage1_mut() {
                    stage1.enter_substage("local_optimization_complete");
                    stage1.set_best_solution(local_sol.objective, &local_sol.point);
                    stage1.add_function_evaluations(local_fn_evals as usize);
                }
            }
            observer.invoke_callback();
        }

        self.process_local_solution(local_sol)?;

        if self.target_objective_reached() {
            if self.verbose {
                println!(
                    "Stage 1: Target objective {:.8} reached. Stopping optimization.",
                    self.target_objective.unwrap()
                );
            }
            if let Some(ref mut observer) = self.observer {
                if observer.should_observe_stage1() {
                    if let Some(stage1) = observer.stage1_mut() {
                        stage1.enter_substage("stage1_complete");
                        stage1.end();
                    }
                }
                observer.invoke_callback();
                observer.mark_stage1_complete();
            }
            return self
                .solution_set
                .clone()
                .ok_or(OQNLPError::NoFeasibleSolution { candidates_evaluated: 0 });
        }

        #[cfg(feature = "checkpointing")]
        {
            self.current_reference_set = Some(ref_set.clone());
        }
        if let Some(ref mut observer) = self.observer {
            if observer.should_observe_stage1() {
                if let Some(stage1) = observer.stage1_mut() {
                    stage1.enter_substage("stage1_complete");
                    stage1.end();
                }
            }
            observer.invoke_callback();
            observer.mark_stage1_complete();
        }
        (ref_set, 0, Some(ref_objectives))
    };

    if self.verbose {
        println!("Starting Stage 2");
    }
    if let Some(ref mut observer) = self.observer {
        if observer.should_observe_stage2() {
            if let Some(stage2) = observer.stage2_mut() {
                stage2.start();
                if let Some(ref sol_set) = self.solution_set {
                    if let Some(best) = sol_set.best_solution() {
                        stage2.set_best_solution(best.objective, &best.point);
                    }
                    stage2.set_solution_set_size(sol_set.len());
                }
                stage2.set_threshold_value(self.merit_filter.threshold);
            }
        }
        observer.mark_stage2_started();
    }

    #[cfg(feature = "progress_bar")]
    let mut stage2_bar = Bar::builder()
        .total(self.params.iterations)
        .desc("Stage 2")
        .unit("it")
        .postfix(format!(
            "Objective function: {:.6}",
            self.solution_set
                .as_ref()
                .and_then(|s| s.best_solution())
                .map_or(f64::INFINITY, |s| s.objective)
        ))
        .build()
        .expect("Failed to create progress bar");
    #[cfg(all(feature = "progress_bar", feature = "checkpointing"))]
    if resumed_from_checkpoint {
        for _ in 0..self.current_iteration {
            stage2_bar.update(1).expect("Failed to update progress bar");
        }
    }

    #[cfg(feature = "checkpointing")]
    let mut rng: StdRng = if resumed_from_checkpoint {
        StdRng::seed_from_u64(self.current_seed)
    } else {
        let seed = self.params.seed + self.current_iteration as u64;
        self.current_seed = seed;
        StdRng::seed_from_u64(seed)
    };
    #[cfg(not(feature = "checkpointing"))]
    let mut rng: StdRng = StdRng::seed_from_u64(self.params.seed);

    if !resumed_from_checkpoint {
        match ref_objectives.take() {
            // Shuffle each point together with its objective so that `ref_objectives[i]`
            // keeps describing `ref_set[i]` in Stage 2. rand's Fisher-Yates shuffle draws
            // randomness based only on the slice length, so the resulting point order is the
            // same as shuffling `ref_set` alone.
            Some(objectives) if objectives.len() == ref_set.len() => {
                let mut paired: Vec<(Array1<f64>, f64)> =
                    ref_set.drain(..).zip(objectives).collect();
                paired.shuffle(&mut rng);
                let (points, values): (Vec<Array1<f64>>, Vec<f64>) =
                    paired.into_iter().unzip();
                ref_set = points;
                ref_objectives = Some(values);
            }
            // Objectives that cannot be aligned are dropped; Stage 2 re-evaluates instead.
            _ => ref_set.shuffle(&mut rng),
        }
    }

    let start_timer: Option<std::time::Instant> =
        self.max_time.map(|_| std::time::Instant::now());

    #[cfg(feature = "checkpointing")]
    let start_iter = if resumed_from_checkpoint { self.current_iteration } else { 0 };
    #[cfg(not(feature = "checkpointing"))]
    let start_iter = 0;

    #[cfg(feature = "rayon")]
    let effective_batch_size = if !self.enable_parallel {
        1
    } else {
        self.batch_iterations
            .unwrap_or_else(|| {
                let thread_count = rayon::current_num_threads();
                let remaining_iterations = self.params.iterations.saturating_sub(start_iter);
                if remaining_iterations < 4 || thread_count == 1 {
                    1
                } else {
                    (remaining_iterations / (thread_count * 2)).clamp(2, 8)
                }
            })
            // `chunks(0)` panics, so a user-supplied batch size of 0 means "sequential".
            .max(1)
    };
    #[cfg(not(feature = "rayon"))]
    let effective_batch_size = 1;

    let trials_to_process: Vec<_> =
        ref_set.iter().take(self.params.iterations).enumerate().skip(start_iter).collect();

    for batch in trials_to_process.chunks(effective_batch_size) {
        if let (Some(max_secs), Some(start)) = (self.max_time, start_timer) {
            if start.elapsed().as_secs_f64() > max_secs {
                if self.verbose {
                    println!("Timeout reached after {} seconds", max_secs);
                }
                break;
            }
        }

        #[cfg(feature = "rayon")]
        if effective_batch_size > 1 {
            self.process_batch_parallel(
                batch,
                &mut unchanged_cycles,
                start_iter,
                resumed_from_checkpoint,
                ref_objectives.as_ref(),
            )?;
        } else {
            self.process_batch_sequential(
                batch,
                &mut unchanged_cycles,
                start_iter,
                resumed_from_checkpoint,
                ref_objectives.as_ref(),
            )?;
        }
        #[cfg(not(feature = "rayon"))]
        {
            self.process_batch_sequential(
                batch,
                &mut unchanged_cycles,
                start_iter,
                resumed_from_checkpoint,
                ref_objectives.as_ref(),
            )?;
        }

        // Advance the Stage 2 progress bar and refresh the best objective it displays.
        #[cfg(feature = "progress_bar")]
        {
            stage2_bar.set_postfix(format!(
                "Objective function: {:.6}",
                self.solution_set
                    .as_ref()
                    .and_then(|s| s.best_solution())
                    .map_or(f64::INFINITY, |s| s.objective)
            ));
            stage2_bar.update(batch.len()).expect("Failed to update progress bar");
        }

        if self.target_objective_reached() {
            if self.verbose {
                println!(
                    "Stage 2: Target objective {:.8} reached. Stopping optimization.",
                    self.target_objective.unwrap()
                );
            }
            break;
        }
    }

    if let Some(ref mut observer) = self.observer {
        if observer.should_observe_stage2() {
            if let Some(stage2) = observer.stage2_mut() {
                stage2.end();
            }
        }
    }

    #[cfg(feature = "checkpointing")]
    self.maybe_save_final_checkpoint()?;

    self.solution_set.clone().ok_or(OQNLPError::NoFeasibleSolution { candidates_evaluated: 0 })
}
/// Decides whether a Stage 2 trial point should launch a local solve.
///
/// The objective must not exceed the merit threshold and the distance filter must accept the
/// point. Both checks are always evaluated. This currently never returns `Err`.
fn should_start_local(&self, point: &Array1<f64>, obj: f64) -> Result<bool, OQNLPError> {
    let within_merit_threshold = obj <= self.merit_filter.threshold;
    let accepted_by_distance_filter = self.distance_filter.check(point);
    Ok(within_merit_threshold && accepted_by_distance_filter)
}
/// Reports whether the best known solution meets `target_objective`.
///
/// Returns `false` when no target is set or no solution exists. With
/// `exclude_out_of_bounds`, the best point must also lie within the variable bounds.
fn target_objective_reached(&self) -> bool {
    self.target_objective
        .zip(self.solution_set.as_ref())
        .and_then(|(target, set)| set.best_solution().map(|best| (target, best)))
        .map_or(false, |(target, best)| {
            best.objective <= target
                && (!self.exclude_out_of_bounds || self.is_within_bounds(&best.point))
        })
}
/// Returns `true` unless some coordinate of `point` is below its lower bound or above its
/// upper bound (column 0 / column 1 of the problem's variable bounds).
fn is_within_bounds(&self, point: &Array1<f64>) -> bool {
    let limits = self.problem.variable_bounds();
    !point
        .iter()
        .enumerate()
        .any(|(dim, &x)| x < limits[[dim, 0]] || x > limits[[dim, 1]])
}
/// Offers a local solution to the solution set. The solution is always recorded in the
/// distance filter.
///
/// - With `exclude_out_of_bounds` set, a solution outside the variable bounds is rejected.
/// - If there is no (or an empty) solution set, the solution becomes its only member.
/// - If it improves on the current best by more than
///   `tol = max(abs_tol, rel_tol * max(|new|, |best|))`, it replaces the whole set.
/// - If it ties the best within `tol` and is not a near-duplicate point (see
///   `is_duplicate_in_set`), it is appended to the set.
///
/// The merit threshold is updated whenever the solution becomes the (new) best.
///
/// Returns `Ok(true)` when the solution entered the solution set, either as a new member or
/// as the new best, and `Ok(false)` otherwise.
fn process_local_solution(&mut self, solution: LocalSolution) -> Result<bool, OQNLPError> {
    if self.exclude_out_of_bounds && !self.is_within_bounds(&solution.point) {
        if self.verbose {
            println!(
                "Solution with objective {:.8} rejected: out of bounds",
                solution.objective
            );
        }
        self.distance_filter.add_solution(solution);
        return Ok(false);
    }

    // An empty restored set is treated like a missing one so `solutions[0]` cannot panic.
    let existing = match self.solution_set.as_ref() {
        Some(set) if !set.solutions.is_empty() => Some(set.solutions.clone()),
        _ => None,
    };
    let solutions = match existing {
        Some(solutions) => solutions,
        None => {
            if self.verbose {
                println!(
                    "New solution added to solution set: objective = {:.8}, point = {}",
                    solution.objective, solution.point
                );
                println!("Solution set size: 1");
            }
            self.solution_set =
                Some(SolutionSet { solutions: Array1::from(vec![solution.clone()]) });
            self.merit_filter.update_threshold(solution.objective);
            self.distance_filter.add_solution(solution);
            return Ok(true);
        }
    };

    let current_best: &LocalSolution = &solutions[0];
    let obj1 = solution.objective;
    let obj2 = current_best.objective;
    let obj_diff = (obj1 - obj2).abs();
    let tol = self.abs_tol.max(self.rel_tol * obj1.abs().max(obj2.abs()));

    let added: bool = if obj1 < obj2 - tol {
        // Strictly better: the new solution replaces the whole set.
        self.solution_set =
            Some(SolutionSet { solutions: Array1::from(vec![solution.clone()]) });
        self.merit_filter.update_threshold(solution.objective);
        if let Some(ref mut observer) = self.observer {
            if let Some(stage2) = observer.stage2_mut() {
                stage2.set_last_added_solution(&solution.point);
            }
        }
        if self.verbose {
            println!(
                "New best solution found (replacing solution set): objective = {:.8}, point = {}",
                solution.objective, solution.point
            );
            println!("Solution set size: 1");
        }
        // The new best is now in the solution set, so it counts as added.
        true
    } else if obj_diff <= tol && !self.is_duplicate_in_set(&solution, &solutions) {
        // Ties the best objective at a sufficiently distinct point: append it.
        let mut new_solutions: Vec<LocalSolution> = solutions.to_vec();
        new_solutions.push(solution.clone());
        self.solution_set = Some(SolutionSet { solutions: Array1::from(new_solutions) });
        if let Some(ref mut observer) = self.observer {
            if let Some(stage2) = observer.stage2_mut() {
                stage2.set_last_added_solution(&solution.point);
            }
        }
        if self.verbose {
            println!(
                "New solution added to solution set: objective = {:.8}, point = {}",
                solution.objective, solution.point
            );
            println!("Solution set size: {}", self.solution_set.as_ref().unwrap().len());
        }
        true
    } else {
        false
    };

    self.distance_filter.add_solution(solution);
    Ok(added)
}
/// Returns `true` if `candidate.point` lies strictly within `distance_factor` (Euclidean
/// distance) of any point already in `set`.
fn is_duplicate_in_set(&self, candidate: &LocalSolution, set: &Array1<LocalSolution>) -> bool {
    let radius = self.params.distance_factor;
    let radius_sq = radius * radius;
    for member in set.iter() {
        let delta = &candidate.point - &member.point;
        if delta.dot(&delta) < radius_sq {
            return true;
        }
    }
    false
}
/// Limits Stage 2 to `max_time` seconds of wall-clock time (checked before each batch).
pub fn max_time(self, max_time: f64) -> Self {
    let mut configured = self;
    configured.max_time = Some(max_time);
    configured
}
/// Sets how many Stage 2 iterations are grouped into one batch for parallel processing.
///
/// Values larger than `iterations` are accepted but trigger a warning on stderr.
#[cfg(feature = "rayon")]
pub fn batch_iterations(mut self, batch_size: usize) -> Self {
    let total = self.params.iterations;
    if batch_size > total {
        eprintln!(
            "Warning: batch_iterations ({}) is larger than total iterations ({}). \
             This may lead to suboptimal resource usage. Consider setting batch_iterations <= iterations.",
            batch_size, total
        );
    }
    self.batch_iterations = Some(batch_size);
    self
}
/// Enables progress messages on stdout.
pub fn verbose(self) -> Self {
    let mut configured = self;
    configured.verbose = true;
    configured
}
/// Enables or disables rayon parallelism for scatter search and Stage 2.
#[cfg(feature = "rayon")]
pub fn parallel(self, enable: bool) -> Self {
    let mut configured = self;
    configured.enable_parallel = enable;
    configured
}
/// Stops the optimization as soon as the best objective is `<= target`.
pub fn target_objective(self, target: f64) -> Self {
    let mut configured = self;
    configured.target_objective = Some(target);
    configured
}
/// Chooses whether local solutions outside the variable bounds are rejected.
pub fn set_exclude_out_of_bounds(self, enable: bool) -> Self {
    let mut configured = self;
    configured.exclude_out_of_bounds = enable;
    configured
}
/// Sets the tolerances used to decide whether a new local solution ties the best objective:
/// two values tie when they differ by at most `max(abs_tol, rel_tol * max(|a|, |b|))`.
pub fn with_tolerance(self, abs_tol: f64, rel_tol: f64) -> Self {
    let mut configured = self;
    configured.rel_tol = rel_tol;
    configured.abs_tol = abs_tol;
    configured
}
/// Shorthand for `set_exclude_out_of_bounds(true)`: reject out-of-bounds local solutions.
pub fn exclude_out_of_bounds(self) -> Self {
    let mut configured = self;
    configured.exclude_out_of_bounds = true;
    configured
}
/// Enables checkpointing with the given configuration.
///
/// # Errors
///
/// Propagates the error from `CheckpointManager::new`.
#[cfg(feature = "checkpointing")]
pub fn with_checkpointing(mut self, config: CheckpointConfig) -> Result<Self, OQNLPError> {
    let manager = CheckpointManager::new(config)?;
    self.checkpoint_manager = Some(manager);
    Ok(self)
}
/// Restores state from the latest checkpoint when `auto_resume` is enabled and a checkpoint
/// exists.
///
/// Returns `Ok(true)` if state was restored and `Ok(false)` otherwise (including when no
/// checkpoint manager is configured).
#[cfg(feature = "checkpointing")]
pub fn try_resume_from_checkpoint(&mut self) -> Result<bool, OQNLPError> {
    let checkpoint = match self.checkpoint_manager.as_ref() {
        Some(manager) if manager.config().auto_resume && manager.checkpoint_exists() => {
            manager.load_latest_checkpoint()?
        }
        _ => return Ok(false),
    };
    self.restore_from_checkpoint(checkpoint)?;
    Ok(true)
}
/// Restores the latest checkpoint but replaces its parameters with `new_params`.
///
/// A larger `population_size` expands the stored reference set via `expand_reference_set`.
/// A smaller one only prints a warning; the original reference set is kept.
///
/// Returns `Ok(false)` when there is no checkpoint manager or no checkpoint exists.
#[cfg(feature = "checkpointing")]
pub fn resume_with_modified_params(
    &mut self,
    new_params: OQNLPParams,
) -> Result<bool, OQNLPError> {
    let mut checkpoint = match self.checkpoint_manager.as_ref() {
        Some(manager) if manager.checkpoint_exists() => manager.load_latest_checkpoint()?,
        _ => return Ok(false),
    };

    let previous_iterations = checkpoint.params.iterations;
    let previous_population = checkpoint.params.population_size;
    let requested_population = new_params.population_size;

    if requested_population > previous_population {
        self.expand_reference_set(
            &mut checkpoint.reference_set,
            previous_population,
            requested_population,
        )?;
        if self.verbose {
            println!(
                "Expanded reference set from {} to {} points",
                previous_population, requested_population
            );
        }
    } else if requested_population < previous_population {
        eprintln!(
            "Warning: New population size ({}) is smaller than original ({}). Using original reference set with {} points.",
            requested_population, previous_population, previous_population
        );
        if self.verbose {
            println!(
                "Keeping original reference set size of {} points despite smaller population_size parameter",
                previous_population
            );
        }
    }

    let requested_iterations = new_params.iterations;
    checkpoint.params = new_params;
    self.restore_from_checkpoint(checkpoint)?;

    if self.verbose {
        println!(
            "Resumed with modified parameters. Iterations changed from {} to {}",
            previous_iterations, requested_iterations
        );
    }
    Ok(true)
}
/// Grows `ref_set` toward `new_size` points using the scatter search's diversification step.
///
/// Does nothing when `new_size <= old_size`. The temporary scatter search uses the current
/// params with `population_size = new_size` and a seed offset by the current set length.
#[cfg(feature = "checkpointing")]
fn expand_reference_set(
    &self,
    ref_set: &mut Vec<Array1<f64>>,
    old_size: usize,
    new_size: usize,
) -> Result<(), OQNLPError> {
    if new_size > old_size {
        let expansion_params = OQNLPParams {
            population_size: new_size,
            seed: self.current_seed + ref_set.len() as u64,
            ..self.params.clone()
        };
        #[cfg(feature = "rayon")]
        let mut search = ScatterSearch::new(self.problem.clone(), expansion_params)?
            .parallel(self.enable_parallel);
        #[cfg(not(feature = "rayon"))]
        let mut search = ScatterSearch::new(self.problem.clone(), expansion_params)?;
        let problem_constraints = self.problem.constraints();
        search.diversify_reference_set(ref_set, &problem_constraints)?;
    }
    Ok(())
}
/// Restores the checkpoint stored at `checkpoint_path`, replacing its parameters with
/// `new_params`.
///
/// NOTE(review): without a configured checkpoint manager this silently returns `Ok(())`
/// and restores nothing.
#[cfg(feature = "checkpointing")]
pub fn resume_from_checkpoint_with_params(
    &mut self,
    checkpoint_path: &std::path::Path,
    new_params: OQNLPParams,
) -> Result<(), OQNLPError> {
    let mut checkpoint = match &self.checkpoint_manager {
        Some(manager) => manager.load_checkpoint_from_path(checkpoint_path)?,
        None => return Ok(()),
    };
    let previous_iterations = checkpoint.params.iterations;
    let requested_iterations = new_params.iterations;
    checkpoint.params = new_params;
    self.restore_from_checkpoint(checkpoint)?;

    if self.verbose {
        println!(
            "Resumed from {} with modified parameters. Iterations changed from {} to {}",
            checkpoint_path.display(),
            previous_iterations,
            requested_iterations
        );
    }
    Ok(())
}
/// Replaces the optimizer's state with the contents of `checkpoint`.
///
/// The local solver and distance filter are rebuilt from the restored parameters. Parameter
/// changes made through `resume_with_modified_params` or
/// `resume_from_checkpoint_with_params` (solver type/config, distance settings) therefore
/// take effect, and the sequential path matches the parallel path, which builds its solvers
/// from `self.params`.
///
/// # Errors
///
/// Returns the distance-filter construction error if the restored parameters are rejected.
/// `self` is left unchanged in that case.
#[cfg(feature = "checkpointing")]
fn restore_from_checkpoint(&mut self, checkpoint: OQNLPCheckpoint) -> Result<(), OQNLPError> {
    let solution_count = checkpoint.solution_set.as_ref().map_or(0, |s| s.len());
    let resumed_iteration = checkpoint.current_iteration;

    // Build the fallible component first so an error leaves `self` untouched.
    let mut distance_filter = DistanceFilter::new(FilterParams {
        distance_factor: checkpoint.params.distance_factor,
        wait_cycle: checkpoint.params.wait_cycle,
        threshold_factor: checkpoint.params.threshold_factor,
    })?;
    distance_filter.set_solutions(checkpoint.distance_filter_solutions);

    self.params = checkpoint.params;
    self.local_solver = LocalSolver::new(
        self.problem.clone(),
        self.params.local_solver_type.clone(),
        self.params.local_solver_config.clone(),
    );
    self.distance_filter = distance_filter;

    self.current_iteration = resumed_iteration;
    self.merit_filter.update_threshold(checkpoint.merit_threshold);
    self.solution_set = checkpoint.solution_set;
    self.current_reference_set = Some(checkpoint.reference_set);
    self.unchanged_cycles = checkpoint.unchanged_cycles;
    self.current_seed = checkpoint.current_seed;
    self.target_objective = checkpoint.target_objective;
    self.exclude_out_of_bounds = checkpoint.exclude_out_of_bounds;
    #[cfg(feature = "rayon")]
    {
        self.batch_iterations = checkpoint.batch_iterations;
        self.enable_parallel = checkpoint.enable_parallel;
    }
    self.abs_tol = checkpoint.abs_tol;
    self.rel_tol = checkpoint.rel_tol;

    if self.verbose {
        println!(
            "Resumed from checkpoint at iteration {} with {} solutions",
            resumed_iteration, solution_count
        );
    }
    Ok(())
}
/// Captures the optimizer's current state as an [`OQNLPCheckpoint`].
///
/// `elapsed_time` is measured from `start_time` (0.0 before `run` has started). The
/// timestamp is the current UTC time in RFC 3339 format.
#[cfg(feature = "checkpointing")]
fn create_checkpoint(&self) -> OQNLPCheckpoint {
    let elapsed_time = match self.start_time {
        Some(started) => started.elapsed().as_secs_f64(),
        None => 0.0,
    };
    let reference_set = match &self.current_reference_set {
        Some(points) => points.clone(),
        None => Vec::new(),
    };
    OQNLPCheckpoint {
        params: self.params.clone(),
        current_iteration: self.current_iteration,
        current_seed: self.current_seed,
        unchanged_cycles: self.unchanged_cycles,
        merit_threshold: self.merit_filter.threshold,
        solution_set: self.solution_set.clone(),
        reference_set,
        distance_filter_solutions: self.distance_filter.get_solutions().clone(),
        target_objective: self.target_objective,
        exclude_out_of_bounds: self.exclude_out_of_bounds,
        #[cfg(feature = "rayon")]
        batch_iterations: self.batch_iterations,
        #[cfg(feature = "rayon")]
        enable_parallel: self.enable_parallel,
        abs_tol: self.abs_tol,
        rel_tol: self.rel_tol,
        elapsed_time,
        timestamp: chrono::Utc::now().to_rfc3339(),
    }
}
/// Saves a checkpoint when the current iteration is a multiple of the configured
/// `save_frequency`.
///
/// A `save_frequency` of 0 disables these periodic saves instead of panicking on the
/// modulo. The final checkpoint written at the end of `run` does not depend on it.
///
/// # Errors
///
/// Propagates errors from saving the checkpoint.
#[cfg(feature = "checkpointing")]
fn maybe_save_checkpoint(&self) -> Result<(), OQNLPError> {
    if let Some(manager) = self.checkpoint_manager.as_ref() {
        let frequency = manager.config().save_frequency;
        if frequency > 0 && self.current_iteration % frequency == 0 {
            let checkpoint = self.create_checkpoint();
            let saved_path = manager.save_checkpoint(&checkpoint, self.current_iteration)?;
            if self.verbose {
                println!("Checkpoint saved to: {}", saved_path.display());
            }
        }
    }
    Ok(())
}
/// Unconditionally saves a checkpoint (tagged with the current iteration) when a checkpoint
/// manager is configured.
#[cfg(feature = "checkpointing")]
fn maybe_save_final_checkpoint(&self) -> Result<(), OQNLPError> {
    match self.checkpoint_manager.as_ref() {
        None => Ok(()),
        Some(manager) => {
            let written_to =
                manager.save_checkpoint(&self.create_checkpoint(), self.current_iteration)?;
            if self.verbose {
                println!("Final checkpoint saved to: {}", written_to.display());
            }
            Ok(())
        }
    }
}
/// Relaxes the merit threshold to
/// `current_threshold + threshold_factor * (1 + |current_threshold|)`.
fn adjust_threshold(&mut self, current_threshold: f64) {
    let increment = self.params.threshold_factor * (1.0 + current_threshold.abs());
    self.merit_filter.update_threshold(current_threshold + increment);
}
/// Processes a batch of Stage 2 trial points one at a time.
///
/// For each `(iteration, trial)` pair:
/// 1. The trial's objective is read from `ref_objectives` (when provided) or evaluated.
/// 2. If the trial passes the merit and distance filters, a local solve is started and its
///    result is offered to the solution set.
/// 3. Otherwise `unchanged_cycles` is incremented. Once it reaches `wait_cycle`, the merit
///    threshold is relaxed via `adjust_threshold` and the counter resets.
///
/// Returns early with `Ok(())` as soon as the target objective is reached.
///
/// # Errors
///
/// Propagates objective-evaluation, local-solver and checkpoint-saving errors.
fn process_batch_sequential(
    &mut self,
    batch: &[(usize, &Array1<f64>)],
    unchanged_cycles: &mut usize,
    start_iter: usize,
    resumed_from_checkpoint: bool,
    ref_objectives: Option<&Vec<f64>>,
) -> Result<(), OQNLPError> {
    for &(local_iter, trial) in batch {
        // Iterations completed before the checkpoint was written are not repeated.
        if resumed_from_checkpoint && local_iter < start_iter {
            continue;
        }
        #[cfg(feature = "checkpointing")]
        {
            self.current_iteration = local_iter;
            self.unchanged_cycles = *unchanged_cycles;
            self.current_seed = self.params.seed + self.current_iteration as u64;
        }
        let trial = trial.clone();
        if let Some(ref mut observer) = self.observer {
            if observer.should_observe_stage2() {
                if let Some(stage2) = observer.stage2_mut() {
                    stage2.set_iteration(local_iter);
                    stage2.set_unchanged_cycles(*unchanged_cycles);
                    stage2.set_threshold_value(self.merit_filter.threshold);
                }
            }
        }

        // Reuse the Stage 1 objective (index-aligned with the reference set) when available.
        let obj: f64 = if let Some(objectives) = ref_objectives {
            objectives[local_iter % objectives.len()]
        } else {
            self.problem.objective(&trial).map_err(|_| {
                OQNLPError::ObjectiveFunctionEvaluationFailed { stage: "unknown".to_string() }
            })?
        };
        if let Some(ref mut observer) = self.observer {
            if observer.should_observe_stage2() && ref_objectives.is_none() {
                if let Some(stage2) = observer.stage2_mut() {
                    stage2.add_function_evaluations(1);
                }
            }
        }

        if self.should_start_local(&trial, obj)? {
            self.merit_filter.update_threshold(obj);
            let track_evals =
                self.observer.as_ref().map_or(false, |obs| obs.should_observe_stage2());
            let (local_trial, eval_count) = if track_evals {
                self.local_solver.solve_with_tracking(trial, true)?
            } else {
                (self.local_solver.solve(trial)?, 0)
            };
            let added: bool = self.process_local_solution(local_trial.clone())?;
            if let Some(ref mut observer) = self.observer {
                if observer.should_observe_stage2() {
                    if let Some(stage2) = observer.stage2_mut() {
                        stage2.add_local_solver_call(added);
                        stage2.add_function_evaluations(eval_count as usize);
                        if let Some(ref sol_set) = self.solution_set {
                            if let Some(best) = sol_set.best_solution() {
                                stage2.set_best_solution(best.objective, &best.point);
                            }
                            stage2.set_solution_set_size(sol_set.len());
                        }
                    }
                }
            }
            if self.verbose && added {
                println!(
                    "Stage 2, iteration {}: Added local solution found with objective = {:.8}",
                    local_iter, local_trial.objective
                );
                println!("x0 = {}", local_trial.point);
            }
            if self.target_objective_reached() {
                if self.verbose {
                    println!(
                        "Stage 2, iteration {}: Target objective {:.8} reached. Stopping optimization.",
                        local_iter,
                        self.target_objective.unwrap()
                    );
                }
                return Ok(());
            }
        } else {
            *unchanged_cycles += 1;
            if *unchanged_cycles >= self.params.wait_cycle {
                let previous_threshold = self.merit_filter.threshold;
                self.adjust_threshold(previous_threshold);
                // Report the threshold actually applied by `adjust_threshold`.
                if self.verbose {
                    println!(
                        "Stage 2, iteration {}: Adjusting threshold from {:.8} to {:.8}",
                        local_iter, previous_threshold, self.merit_filter.threshold
                    );
                }
                *unchanged_cycles = 0;
            }
        }

        if let Some(ref mut observer) = self.observer {
            if observer.should_invoke_callback(local_iter) {
                observer.invoke_callback();
            }
        }
        #[cfg(feature = "checkpointing")]
        self.maybe_save_checkpoint()?;
    }
    Ok(())
}
#[cfg(feature = "rayon")]
fn process_batch_parallel(
&mut self,
batch: &[(usize, &Array1<f64>)],
unchanged_cycles: &mut usize,
start_iter: usize,
resumed_from_checkpoint: bool,
ref_objectives: Option<&Vec<f64>>,
) -> Result<(), OQNLPError> {
let filtered_batch: Vec<_> = if resumed_from_checkpoint {
batch.iter().filter(|&&(local_iter, _)| local_iter >= start_iter).copied().collect()
} else {
batch.to_vec()
};
if filtered_batch.is_empty() {
return Ok(());
}
let batch_results: Result<Vec<(usize, Array1<f64>, f64)>, OQNLPError> = filtered_batch
.par_iter()
.map(|&(local_iter, trial)| {
let obj = if let Some(objectives) = ref_objectives {
let ref_set_index = local_iter % objectives.len();
objectives[ref_set_index]
} else {
self.problem.objective(trial).map_err(|_| {
OQNLPError::ObjectiveFunctionEvaluationFailed {
stage: "unknown".to_string(),
}
})?
};
Ok((local_iter, trial.clone(), obj))
})
.collect();
let batch_results = batch_results?;
if let Some(ref mut observer) = self.observer {
if observer.should_observe_stage2() && ref_objectives.is_none() {
if let Some(stage2) = observer.stage2_mut() {
stage2.add_function_evaluations(batch_results.len());
}
}
}
let (local_candidates, quick_candidates): (Vec<_>, Vec<_>) =
batch_results.into_iter().partition(|(_, trial, obj)| {
let passes_merit = *obj <= self.merit_filter.threshold;
let passes_distance = self.distance_filter.check(trial);
passes_merit && passes_distance
});
for (local_iter, _trial, _obj) in quick_candidates {
#[cfg(feature = "checkpointing")]
{
self.current_iteration = local_iter;
self.unchanged_cycles = *unchanged_cycles;
self.current_seed = self.params.seed + self.current_iteration as u64;
}
*unchanged_cycles += 1;
if *unchanged_cycles >= self.params.wait_cycle {
if self.verbose {
println!(
"Stage 2, iteration {}: Adjusting threshold from {:.8} to {:.8}",
local_iter,
self.merit_filter.threshold,
self.merit_filter.threshold + 0.1 * self.merit_filter.threshold.abs()
);
}
self.adjust_threshold(self.merit_filter.threshold);
*unchanged_cycles = 0;
}
if let Some(ref mut observer) = self.observer {
if observer.should_observe_stage2() {
if let Some(stage2) = observer.stage2_mut() {
stage2.set_iteration(local_iter);
stage2.set_unchanged_cycles(*unchanged_cycles);
stage2.set_threshold_value(self.merit_filter.threshold);
}
}
}
if let Some(ref mut observer) = self.observer {
if observer.should_invoke_callback(local_iter) {
observer.invoke_callback();
}
}
#[cfg(feature = "checkpointing")]
self.maybe_save_checkpoint()?;
}
if !local_candidates.is_empty() {
#[cfg(feature = "rayon")]
if local_candidates.len() >= 2 && self.enable_parallel {
let problem = self.problem.clone();
let solver_type = self.params.local_solver_type.clone();
let solver_config = self.params.local_solver_config.clone();
let track_evals =
self.observer.as_ref().map(|obs| obs.should_observe_stage2()).unwrap_or(false);
let local_results: Result<Vec<_>, OQNLPError> = local_candidates
.par_iter()
.map(|(local_iter, trial, obj)| {
let local_solver = LocalSolver::new(
problem.clone(),
solver_type.clone(),
solver_config.clone(),
);
let (local_solution, eval_count) = if track_evals {
local_solver.solve_with_tracking(trial.clone(), true)?
} else {
let sol = local_solver.solve(trial.clone())?;
(sol, 0)
};
Ok((*local_iter, trial.clone(), *obj, local_solution, eval_count))
})
.collect();
let local_results = local_results?;
for (local_iter, _trial, obj, local_solution, eval_count) in local_results {
#[cfg(feature = "checkpointing")]
{
self.current_iteration = local_iter;
self.unchanged_cycles = *unchanged_cycles;
self.current_seed = self.params.seed + self.current_iteration as u64;
}
self.merit_filter.update_threshold(obj);
let added = self.process_local_solution(local_solution.clone())?;
if let Some(ref mut observer) = self.observer {
if observer.should_observe_stage2() {
if let Some(stage2) = observer.stage2_mut() {
stage2.add_local_solver_call(added);
stage2.add_function_evaluations(eval_count as usize);
if let Some(ref sol_set) = self.solution_set {
if let Some(best) = sol_set.best_solution() {
stage2.set_best_solution(best.objective, &best.point);
}
stage2.set_solution_set_size(sol_set.len());
}
}
}
}
if self.verbose && added {
println!(
"Stage 2, iteration {}: Added local solution found with objective = {:.8}",
local_iter, local_solution.objective
);
println!("x0 = {}", local_solution.point);
}
if let Some(ref mut observer) = self.observer {
if observer.should_observe_stage2() {
if let Some(stage2) = observer.stage2_mut() {
stage2.set_iteration(local_iter);
stage2.set_unchanged_cycles(*unchanged_cycles);
stage2.set_threshold_value(self.merit_filter.threshold);
}
}
}
if let Some(ref mut observer) = self.observer {
if observer.should_invoke_callback(local_iter) {
observer.invoke_callback();
}
}
if self.target_objective_reached() {
if self.verbose {
println!(
"Stage 2, iteration {}: Target objective {:.8} reached. Stopping optimization.",
local_iter,
self.target_objective.unwrap()
);
}
return Ok(());
}
#[cfg(feature = "checkpointing")]
self.maybe_save_checkpoint()?;
}
} else {
for (local_iter, trial, obj) in local_candidates {
#[cfg(feature = "checkpointing")]
{
self.current_iteration = local_iter;
self.unchanged_cycles = *unchanged_cycles;
self.current_seed = self.params.seed + self.current_iteration as u64;
}
self.merit_filter.update_threshold(obj);
let (local_trial, eval_count) = if let Some(ref mut observer) = self.observer {
if observer.should_observe_stage2() {
self.local_solver.solve_with_tracking(trial, true)?
} else {
let sol = self.local_solver.solve(trial)?;
(sol, 0)
}
} else {
let sol = self.local_solver.solve(trial)?;
(sol, 0)
};
let added = self.process_local_solution(local_trial.clone())?;
if let Some(ref mut observer) = self.observer {
if observer.should_observe_stage2() {
if let Some(stage2) = observer.stage2_mut() {
stage2.add_local_solver_call(added);
stage2.add_function_evaluations(eval_count as usize);
if let Some(ref sol_set) = self.solution_set {
if let Some(best) = sol_set.best_solution() {
stage2.set_best_solution(best.objective, &best.point);
}
stage2.set_solution_set_size(sol_set.len());
}
}
}
}
if let Some(ref mut observer) = self.observer {
if observer.should_observe_stage2() {
if let Some(stage2) = observer.stage2_mut() {
stage2.set_iteration(local_iter);
stage2.set_unchanged_cycles(*unchanged_cycles);
stage2.set_threshold_value(self.merit_filter.threshold);
}
}
}
if let Some(ref mut observer) = self.observer {
observer.invoke_callback();
}
if self.verbose && added {
println!(
"Stage 2, iteration {}: Added local solution found with objective = {:.8}",
local_iter, local_trial.objective
);
println!("x0 = {}", local_trial.point);
}
if self.target_objective_reached() {
if self.verbose {
println!(
"Stage 2, iteration {}: Target objective {:.8} reached. Stopping optimization.",
local_iter,
self.target_objective.unwrap()
);
}
return Ok(());
}
#[cfg(feature = "checkpointing")]
self.maybe_save_checkpoint()?;
}
}
#[cfg(not(feature = "rayon"))]
{
for (local_iter, trial, obj) in local_candidates {
#[cfg(feature = "checkpointing")]
{
self.current_iteration = local_iter;
self.unchanged_cycles = *unchanged_cycles;
self.current_seed = self.params.seed + self.current_iteration as u64;
}
self.merit_filter.update_threshold(obj);
let local_trial = self
.local_solver
.solve(trial)
.map_err(|e| OQNLPError::LocalSolverError(e.to_string()))?;
let added = self.process_local_solution(local_trial.clone())?;
if self.verbose && added {
println!(
"Stage 2, iteration {}: Added local solution found with objective = {:.8}",
local_iter, local_trial.objective
);
println!("x0 = {}", local_trial.point);
}
if self.target_objective_reached() {
if self.verbose {
println!(
"Stage 2, iteration {}: Target objective {:.8} reached. Stopping optimization.",
local_iter,
self.target_objective.unwrap()
);
}
return Ok(());
}
#[cfg(feature = "checkpointing")]
self.maybe_save_checkpoint()?;
}
}
}
Ok(())
}
}
#[cfg(test)]
mod tests_oqnlp {
use super::*;
use crate::types::EvaluationError;
use ndarray::{Array1, Array2, array};
/// Minimal three-variable test problem: the objective is the plain sum of the
/// coordinates, so the gradient is the constant all-ones vector.
#[derive(Clone)]
struct DummyProblem;
impl Problem for DummyProblem {
    fn objective(&self, trial: &Array1<f64>) -> Result<f64, EvaluationError> {
        let total = trial.sum();
        Ok(total)
    }
    fn gradient(&self, trial: &Array1<f64>) -> Result<Array1<f64>, EvaluationError> {
        // d/dx_i of sum(x) is 1 for every coordinate.
        let dims = trial.len();
        Ok(Array1::from_elem(dims, 1.0))
    }
    /// Each of the three variables is boxed to [-5, 5].
    fn variable_bounds(&self) -> Array2<f64> {
        array![[-5.0, 5.0], [-5.0, 5.0], [-5.0, 5.0]]
    }
}
/// Six-hump camel benchmark in two variables:
///
/// f(x, y) = (4 - 2.1 x^2 + x^4 / 3) x^2 + x y + (-4 + 4 y^2) y^2
///
/// restricted to x in [-3, 3], y in [-2, 2]. The function has two symmetric
/// global minima (f ≈ -1.0316).
#[derive(Clone)]
struct SixHumpCamel;
impl Problem for SixHumpCamel {
    fn objective(&self, x: &Array1<f64>) -> Result<f64, EvaluationError> {
        let (a, b) = (x[0], x[1]);
        let value = (4.0 - 2.1 * a.powi(2) + a.powi(4) / 3.0) * a.powi(2)
            + a * b
            + (-4.0 + 4.0 * b.powi(2)) * b.powi(2);
        Ok(value)
    }
    /// Analytic gradient of the objective above.
    fn gradient(&self, x: &Array1<f64>) -> Result<Array1<f64>, EvaluationError> {
        let (a, b) = (x[0], x[1]);
        let d_da = (8.0 - 8.4 * a.powi(2) + 2.0 * a.powi(4)) * a + b;
        let d_db = a + (-8.0 + 16.0 * b.powi(2)) * b;
        Ok(array![d_da, d_db])
    }
    fn variable_bounds(&self) -> Array2<f64> {
        array![[-3.0, 3.0], [-2.0, 2.0]]
    }
}
/// The first solution offered to an empty solution set is accepted and stored.
#[test]
fn test_process_local_solution_new() {
    let mut oqnlp: OQNLP<DummyProblem> =
        OQNLP::new(DummyProblem, OQNLPParams::default()).unwrap();
    let point = Array1::from(vec![1.0, 2.0, 3.0]);
    let candidate = LocalSolution { objective: point.sum(), point };
    assert!(oqnlp.process_local_solution(candidate.clone()).unwrap());
    let stored: SolutionSet = oqnlp.solution_set.unwrap();
    assert_eq!(stored.len(), 1);
    assert!((stored[0].objective - candidate.objective).abs() < 1e-6);
}
/// Submitting the same solution twice must not grow the set, and the second
/// submission reports `false`.
#[test]
fn test_process_local_solution_duplicate() {
    let mut oqnlp: OQNLP<DummyProblem> =
        OQNLP::new(DummyProblem, OQNLPParams::default()).unwrap();
    let point = Array1::from(vec![1.0, 2.0, 3.0]);
    let candidate = LocalSolution { objective: point.sum(), point };
    oqnlp.process_local_solution(candidate.clone()).unwrap();
    let added_again: bool = oqnlp.process_local_solution(candidate).unwrap();
    let stored = oqnlp.solution_set.unwrap();
    assert_eq!(stored.len(), 1);
    assert!(!added_again);
}
/// A strictly better solution replaces the incumbent, leaving a single entry;
/// the call reports `false` for this replacement (as asserted below).
#[test]
fn test_process_local_solution_better() {
    let mut oqnlp: OQNLP<DummyProblem> =
        OQNLP::new(DummyProblem, OQNLPParams::default()).unwrap();
    // Builds a solution whose objective is the coordinate sum (DummyProblem's objective).
    let make = |coords: Vec<f64>| {
        let point = Array1::from(coords);
        LocalSolution { objective: point.sum(), point }
    };
    oqnlp.process_local_solution(make(vec![2.0, 2.0, 2.0])).unwrap();
    let improved = make(vec![1.0, 1.0, 1.0]);
    let added: bool = oqnlp.process_local_solution(improved.clone()).unwrap();
    let stored: SolutionSet = oqnlp.solution_set.unwrap();
    assert_eq!(stored.len(), 1);
    assert!((stored[0].objective - improved.objective).abs() < 1e-6);
    assert!(!added);
}
/// After the merit threshold is set to 10, a point with objective 30 must not
/// trigger a local search while a point with objective 3 must.
#[test]
fn test_should_start_local() {
    let mut oqnlp: OQNLP<DummyProblem> =
        OQNLP::new(DummyProblem, OQNLPParams::default()).unwrap();
    oqnlp.merit_filter.update_threshold(10.0);

    let far = Array1::from(vec![10.0, 10.0, 10.0]);
    let far_obj: f64 = far.sum();
    assert!(!oqnlp.should_start_local(&far, far_obj).unwrap());

    let near = Array1::from(vec![1.0, 1.0, 1.0]);
    let near_obj: f64 = near.sum();
    assert!(oqnlp.should_start_local(&near, near_obj).unwrap());
}
/// `adjust_threshold(10.0)` with default parameters must move the merit
/// threshold to 12.2.
#[test]
fn test_adjust_threshold() {
    let problem: DummyProblem = DummyProblem;
    let params: OQNLPParams = OQNLPParams::default();
    let mut oqnlp: OQNLP<DummyProblem> = OQNLP::new(problem, params).unwrap();
    oqnlp.adjust_threshold(10.0);
    let expected = 12.2;
    // `f64::EPSILON` is the float spacing near 1.0. One ULP near 12.2 is ~1.8e-15,
    // so an EPSILON bound amounts to exact equality and breaks on any harmless
    // reordering of the arithmetic. Use an explicit small absolute tolerance instead.
    assert!(
        (oqnlp.merit_filter.threshold - expected).abs() < 1e-12,
        "threshold {} should be {}",
        oqnlp.merit_filter.threshold,
        expected
    );
}
/// `max_time` is stored by the builder and bounds the run. A generous limit
/// yields two solutions on the six-hump camel; a near-zero limit yields one.
#[test]
fn test_max_time() {
    let limited: OQNLP<DummyProblem> =
        OQNLP::new(DummyProblem, OQNLPParams::default()).unwrap().max_time(10.0);
    assert_eq!(limited.max_time, Some(10.0));

    let camel_params: OQNLPParams =
        OQNLPParams { iterations: 250, population_size: 1500, ..Default::default() };

    let mut generous: OQNLP<SixHumpCamel> =
        OQNLP::new(SixHumpCamel, camel_params.clone()).unwrap().verbose().max_time(60.0);
    let generous_set: SolutionSet = generous.run().unwrap();
    assert_eq!(generous_set.len(), 2);

    let mut starved: OQNLP<SixHumpCamel> =
        OQNLP::new(SixHumpCamel, camel_params).unwrap().verbose().max_time(0.00000000001);
    let starved_set: SolutionSet = starved.run().unwrap();
    assert_eq!(starved_set.len(), 1);
}
/// A population size of 1 (below the minimum of 3) must be rejected.
#[test]
fn test_oqnlp_params_invalid_population_size() {
    let params = OQNLPParams { population_size: 1, ..OQNLPParams::default() };
    let outcome = OQNLP::new(DummyProblem, params);
    assert!(matches!(outcome, Err(OQNLPError::InvalidPopulationSize(1))));
}
/// A threshold factor of zero is not positive and must be rejected, carrying
/// the offending value in the error.
#[test]
fn test_oqnlp_params_invalid_threshold_factor_zero() {
    let params = OQNLPParams { threshold_factor: 0.0, ..OQNLPParams::default() };
    let outcome = OQNLP::new(DummyProblem, params);
    assert!(matches!(outcome, Err(OQNLPError::InvalidThresholdFactor(f)) if f == 0.0));
}
/// A negative threshold factor must be rejected, and the error must carry the
/// rejected value (-0.5), mirroring the zero-factor test.
#[test]
fn test_oqnlp_params_invalid_threshold_factor_negative() {
    let problem: DummyProblem = DummyProblem {};
    let params: OQNLPParams = OQNLPParams { threshold_factor: -0.5, ..OQNLPParams::default() };
    let oqnlp = OQNLP::new(problem, params);
    // Pin the payload rather than matching `_`, so an error reporting the wrong
    // value is caught.
    assert!(matches!(oqnlp, Err(OQNLPError::InvalidThresholdFactor(f)) if f == -0.5));
}
/// Smoke test: a short verbose run with the progress bar enabled must succeed
/// and return at least one solution.
#[test]
#[cfg(feature = "progress_bar")]
fn test_progress_bar() {
    use kdam::term;
    use std::io::{IsTerminal, stderr};
    // Initialise kdam's terminal handling according to whether stderr is a TTY.
    term::init(stderr().is_terminal());
    let params = OQNLPParams { iterations: 5, population_size: 10, ..Default::default() };
    let mut optimizer = OQNLP::new(SixHumpCamel, params).unwrap().verbose();
    let outcome = optimizer.run();
    assert!(outcome.is_ok(), "OQNLP should run successfully with progress bar");
    let solutions = outcome.unwrap();
    assert!(!solutions.is_empty(), "Should find at least one solution");
}
/// With a reachable target the best objective must be at or below it. With an
/// impossible target the run must still succeed without reaching it.
#[test]
fn test_target_objective() {
    let params = OQNLPParams { iterations: 50, population_size: 100, ..Default::default() };

    let mut reachable =
        OQNLP::new(SixHumpCamel, params.clone()).unwrap().target_objective(-0.5);
    let run = reachable.run();
    assert!(run.is_ok(), "OQNLP should run successfully");
    let solutions = run.unwrap();
    let best = solutions.best_solution().unwrap();
    assert!(
        best.objective <= -0.5,
        "Best objective {} should be <= target -0.5",
        best.objective
    );

    let mut unreachable = OQNLP::new(SixHumpCamel, params).unwrap().target_objective(-10.0);
    let run = unreachable.run();
    assert!(run.is_ok(), "OQNLP should run successfully even if target not reached");
    let solutions = run.unwrap();
    let best = solutions.best_solution().unwrap();
    assert!(
        best.objective > -10.0,
        "Best objective {} should be > impossible target -10.0",
        best.objective
    );
}
/// Resuming with larger `iterations` / `population_size` must load the saved
/// checkpoint and adopt the new values. When no checkpoint exists,
/// `resume_with_modified_params` must return `Ok(false)`.
#[test]
#[cfg(feature = "checkpointing")]
fn test_resume_with_modified_params() {
    use crate::types::CheckpointConfig;
    use std::env;
    let checkpoint_dir = env::temp_dir().join("globalsearch_test_resume");
    // Start from an empty directory. The cleanup at the end of this test does not
    // run if an earlier run panicked. A leftover checkpoint at this fixed location
    // could then be picked up (`auto_resume: true`), making the result depend on
    // previous runs.
    let _ = std::fs::remove_dir_all(&checkpoint_dir);
    std::fs::create_dir_all(&checkpoint_dir).expect("Failed to create test directory");
    let problem = SixHumpCamel;
    let initial_params =
        OQNLPParams { iterations: 10, population_size: 20, ..Default::default() };
    let checkpoint_config = CheckpointConfig {
        checkpoint_dir: checkpoint_dir.clone(),
        checkpoint_name: "test_resume".to_string(),
        save_frequency: 2,
        keep_all: false,
        auto_resume: true,
    };
    // The first run is expected to leave a checkpoint behind for the resume below.
    let mut oqnlp = OQNLP::new(problem.clone(), initial_params.clone())
        .unwrap()
        .with_checkpointing(checkpoint_config.clone())
        .unwrap()
        .verbose();
    let _result = oqnlp.run();
    let modified_params = OQNLPParams { iterations: 25, population_size: 30, ..initial_params };
    let mut oqnlp2 = OQNLP::new(problem.clone(), modified_params.clone())
        .unwrap()
        .with_checkpointing(checkpoint_config)
        .unwrap()
        .verbose();
    let resumed = oqnlp2.resume_with_modified_params(modified_params.clone());
    assert!(resumed.is_ok(), "Resume with modified params should succeed");
    assert!(resumed.unwrap(), "Should have resumed from checkpoint");
    assert_eq!(oqnlp2.params.iterations, 25);
    assert_eq!(oqnlp2.params.population_size, 30);

    // Negative case: a directory holding no checkpoint with the configured name.
    let empty_checkpoint_dir = env::temp_dir().join("globalsearch_test_empty");
    let _ = std::fs::remove_dir_all(&empty_checkpoint_dir);
    std::fs::create_dir_all(&empty_checkpoint_dir).expect("Failed to create test directory");
    let empty_checkpoint_config = CheckpointConfig {
        checkpoint_dir: empty_checkpoint_dir.clone(),
        checkpoint_name: "nonexistent".to_string(),
        save_frequency: 2,
        keep_all: false,
        auto_resume: true,
    };
    let mut oqnlp3 = OQNLP::new(problem, modified_params.clone())
        .unwrap()
        .with_checkpointing(empty_checkpoint_config)
        .unwrap();
    let not_resumed = oqnlp3.resume_with_modified_params(modified_params);
    assert!(not_resumed.is_ok(), "Should handle no checkpoint gracefully");
    assert!(!not_resumed.unwrap(), "Should return false when no checkpoint exists");
    let _ = std::fs::remove_dir_all(&checkpoint_dir);
    let _ = std::fs::remove_dir_all(&empty_checkpoint_dir);
}
/// With `exclude_out_of_bounds` enabled, a solution outside the variable bounds
/// is rejected and an in-bounds one is kept. With the default setting, an
/// out-of-bounds solution is accepted.
#[test]
fn test_exclude_out_of_bounds() {
    let params = OQNLPParams { iterations: 5, population_size: 10, ..Default::default() };

    // Exclusion enabled.
    let mut strict =
        OQNLP::new(DummyProblem, params.clone()).unwrap().exclude_out_of_bounds().verbose();
    assert!(strict.exclude_out_of_bounds);
    let outside = LocalSolution { point: Array1::from(vec![10.0, 10.0, 10.0]), objective: 30.0 };
    assert!(!strict.process_local_solution(outside).unwrap());
    assert!(strict.solution_set.is_none());
    let inside = LocalSolution { point: Array1::from(vec![1.0, 2.0, 3.0]), objective: 6.0 };
    assert!(strict.process_local_solution(inside.clone()).unwrap());
    assert!(strict.solution_set.is_some());
    let kept = strict.solution_set.unwrap();
    assert_eq!(kept.len(), 1);
    assert!((kept[0].objective - inside.objective).abs() < 1e-6);

    // Exclusion disabled (default).
    let mut lenient = OQNLP::new(DummyProblem, params).unwrap();
    assert!(!lenient.exclude_out_of_bounds);
    let far_outside =
        LocalSolution { point: Array1::from(vec![15.0, 20.0, 25.0]), objective: 60.0 };
    assert!(lenient.process_local_solution(far_outside.clone()).unwrap());
    assert!(lenient.solution_set.is_some());
    let kept = lenient.solution_set.unwrap();
    assert_eq!(kept.len(), 1);
    assert!((kept[0].objective - far_outside.objective).abs() < 1e-6);
}
/// An out-of-bounds solution must not satisfy the target objective even when
/// its value would. Only in-bounds solutions count.
#[test]
fn test_exclude_out_of_bounds_with_target_objective() {
    let params = OQNLPParams { iterations: 5, population_size: 10, ..Default::default() };
    let mut oqnlp = OQNLP::new(DummyProblem, params)
        .unwrap()
        .exclude_out_of_bounds()
        .target_objective(50.0)
        .verbose();

    // Objective 30 would meet the target of 50, but the point is out of bounds.
    let outside_good =
        LocalSolution { point: Array1::from(vec![10.0, 10.0, 10.0]), objective: 30.0 };
    oqnlp.process_local_solution(outside_good).unwrap();
    assert!(!oqnlp.target_objective_reached());

    // In-bounds point with objective 40 reaches the target.
    let inside_good = LocalSolution { point: Array1::from(vec![1.0, 2.0, 3.0]), objective: 40.0 };
    oqnlp.process_local_solution(inside_good).unwrap();
    assert!(oqnlp.target_objective_reached());

    // After clearing the set, an in-bounds point with objective 60 does not.
    oqnlp.solution_set = None;
    let inside_bad = LocalSolution { point: Array1::from(vec![0.0, 0.0, 0.0]), objective: 60.0 };
    oqnlp.process_local_solution(inside_bad).unwrap();
    assert!(!oqnlp.target_objective_reached());
}
/// Table-driven check of `is_within_bounds` against DummyProblem's [-5, 5]
/// box. Bounds are inclusive, and an empty point counts as within bounds.
#[test]
fn test_is_within_bounds() {
    let oqnlp = OQNLP::new(DummyProblem, OQNLPParams::default()).unwrap();
    let cases: Vec<(Vec<f64>, bool)> = vec![
        (vec![1.0, 2.0, 3.0], true),
        (vec![-5.0, -5.0, -5.0], true),
        (vec![5.0, 5.0, 5.0], true),
        (vec![-6.0, 0.0, 0.0], false),
        (vec![0.0, 6.0, 0.0], false),
        (vec![], true),
    ];
    for (coords, expected) in cases {
        let point = Array1::from(coords);
        assert_eq!(oqnlp.is_within_bounds(&point), expected);
    }
}
/// The `exclude_out_of_bounds` flag must survive a checkpoint round trip.
#[test]
#[cfg(feature = "checkpointing")]
fn test_exclude_out_of_bounds_checkpointing() {
    use crate::types::CheckpointConfig;
    use std::env;
    let checkpoint_dir = env::temp_dir().join("globalsearch_test_exclude_bounds");
    std::fs::create_dir_all(&checkpoint_dir).expect("Failed to create test directory");
    let params = OQNLPParams { iterations: 5, population_size: 10, ..Default::default() };
    let config = CheckpointConfig {
        checkpoint_dir: checkpoint_dir.clone(),
        checkpoint_name: "test_exclude_bounds".to_string(),
        save_frequency: 1,
        keep_all: false,
        auto_resume: false,
    };

    // Snapshot an optimizer that has the flag switched on.
    let source = OQNLP::new(DummyProblem, params.clone())
        .unwrap()
        .with_checkpointing(config.clone())
        .unwrap()
        .exclude_out_of_bounds()
        .verbose();
    let snapshot = source.create_checkpoint();
    assert!(snapshot.exclude_out_of_bounds);

    // A fresh optimizer starts with the flag off and inherits it on restore.
    let mut target =
        OQNLP::new(DummyProblem, params).unwrap().with_checkpointing(config).unwrap();
    assert!(!target.exclude_out_of_bounds);
    target.restore_from_checkpoint(snapshot).unwrap();
    assert!(target.exclude_out_of_bounds);

    let _ = std::fs::remove_dir_all(&checkpoint_dir);
}
/// Resuming with a smaller population (30 -> 20) must still load the saved
/// checkpoint.
#[test]
#[cfg(feature = "checkpointing")]
fn test_resume_with_modified_params_decreased_population() {
    use crate::types::CheckpointConfig;
    use std::env;
    let checkpoint_dir = env::temp_dir().join("globalsearch_test_decrease");
    // Clear leftovers from a previously aborted run. The end-of-test cleanup is
    // skipped on panic, and with `auto_resume: true` a stale checkpoint here
    // could be picked up by the first run.
    let _ = std::fs::remove_dir_all(&checkpoint_dir);
    std::fs::create_dir_all(&checkpoint_dir).expect("Failed to create test directory");
    let problem = SixHumpCamel;
    let initial_params =
        OQNLPParams { iterations: 10, population_size: 30, ..Default::default() };
    let checkpoint_config = CheckpointConfig {
        checkpoint_dir: checkpoint_dir.clone(),
        checkpoint_name: "test_decrease".to_string(),
        save_frequency: 2,
        keep_all: false,
        auto_resume: true,
    };
    let mut oqnlp = OQNLP::new(problem.clone(), initial_params.clone())
        .unwrap()
        .with_checkpointing(checkpoint_config.clone())
        .unwrap()
        .verbose();
    let _result = oqnlp.run();
    let modified_params = OQNLPParams { iterations: 15, population_size: 20, ..initial_params };
    let mut oqnlp2 = OQNLP::new(problem, modified_params.clone())
        .unwrap()
        .with_checkpointing(checkpoint_config)
        .unwrap()
        .verbose();
    let resumed = oqnlp2.resume_with_modified_params(modified_params);
    assert!(resumed.is_ok(), "Resume with decreased population should succeed");
    assert!(resumed.unwrap(), "Should have resumed from checkpoint");
    let _ = std::fs::remove_dir_all(&checkpoint_dir);
}
/// `expand_reference_set` must grow the set to the requested size, keep the
/// original points, and produce only points inside the problem bounds.
#[test]
#[cfg(feature = "checkpointing")]
fn test_expand_reference_set() {
    use crate::types::CheckpointConfig;
    use std::env;
    let checkpoint_dir = env::temp_dir().join("globalsearch_test_expand");
    std::fs::create_dir_all(&checkpoint_dir).expect("Failed to create test directory");
    let params = OQNLPParams { iterations: 5, population_size: 10, ..Default::default() };
    let config = CheckpointConfig {
        checkpoint_dir: checkpoint_dir.clone(),
        checkpoint_name: "test_expand".to_string(),
        save_frequency: 1,
        keep_all: false,
        auto_resume: true,
    };
    let oqnlp = OQNLP::new(SixHumpCamel, params).unwrap().with_checkpointing(config).unwrap();

    let seeds = [[-1.0, -1.0], [0.0, 0.0], [1.0, 1.0]];
    let mut ref_set: Vec<Array1<f64>> =
        seeds.iter().map(|seed| Array1::from(seed.to_vec())).collect();
    let old_size = ref_set.len();
    let new_size = 8;
    let result = oqnlp.expand_reference_set(&mut ref_set, old_size, new_size);
    assert!(result.is_ok(), "expand_reference_set should succeed");
    assert_eq!(ref_set.len(), new_size, "Reference set should be expanded to new size");
    for seed in seeds.iter() {
        assert!(ref_set.contains(&Array1::from(seed.to_vec())));
    }

    let bounds = SixHumpCamel.variable_bounds();
    for point in &ref_set {
        assert!(point.len() == bounds.nrows(), "Point should have correct dimensions");
        for (i, &val) in point.iter().enumerate() {
            let (lo, hi) = (bounds[[i, 0]], bounds[[i, 1]]);
            assert!(
                val >= lo && val <= hi,
                "Point value {} should be within bounds [{}, {}]",
                val,
                lo,
                hi
            );
        }
    }
    let _ = std::fs::remove_dir_all(&checkpoint_dir);
}
/// Edge cases for `expand_reference_set`: a requested size not larger than
/// the old one leaves the set untouched, and an empty set is filled up to the
/// requested size.
#[test]
#[cfg(feature = "checkpointing")]
fn test_expand_reference_set_edge_cases() {
    use crate::types::CheckpointConfig;
    use std::env;
    let checkpoint_dir = env::temp_dir().join("globalsearch_test_expand_edge");
    std::fs::create_dir_all(&checkpoint_dir).expect("Failed to create test directory");
    let params = OQNLPParams { iterations: 5, population_size: 10, ..Default::default() };
    let config = CheckpointConfig {
        checkpoint_dir: checkpoint_dir.clone(),
        checkpoint_name: "test_expand_edge".to_string(),
        save_frequency: 1,
        keep_all: false,
        auto_resume: true,
    };
    let oqnlp = OQNLP::new(SixHumpCamel, params).unwrap().with_checkpointing(config).unwrap();

    // new_size < old_size: no change.
    let mut shrink_request = vec![
        Array1::from(vec![-1.0, -1.0]),
        Array1::from(vec![0.0, 0.0]),
        Array1::from(vec![1.0, 1.0]),
    ];
    let before = shrink_request.len();
    let outcome = oqnlp.expand_reference_set(&mut shrink_request, 5, 3);
    assert!(outcome.is_ok(), "expand_reference_set should handle new_size <= old_size");
    assert_eq!(
        shrink_request.len(),
        before,
        "Reference set should not change when new_size <= old_size"
    );

    // new_size == old_size: no change.
    let mut same_size = vec![Array1::from(vec![-1.0, -1.0]), Array1::from(vec![0.0, 0.0])];
    let before = same_size.len();
    let outcome = oqnlp.expand_reference_set(&mut same_size, before, before);
    assert!(outcome.is_ok(), "expand_reference_set should handle new_size == old_size");
    assert_eq!(
        same_size.len(),
        before,
        "Reference set should not change when new_size == old_size"
    );

    // Empty input: filled up to new_size.
    let mut empty: Vec<Array1<f64>> = Vec::new();
    let outcome = oqnlp.expand_reference_set(&mut empty, 0, 5);
    assert!(outcome.is_ok(), "expand_reference_set should handle empty reference set");
    assert_eq!(empty.len(), 5, "Reference set should be expanded from empty to new size");

    let _ = std::fs::remove_dir_all(&checkpoint_dir);
}
/// Resuming with a larger population (10 -> 25) must load the saved
/// checkpoint and adopt the new population size.
#[test]
#[cfg(feature = "checkpointing")]
fn test_expand_reference_set_integration() {
    use crate::types::CheckpointConfig;
    use std::env;
    let checkpoint_dir = env::temp_dir().join("globalsearch_test_expand_integration");
    // Clear leftovers from a previously aborted run. The end-of-test cleanup is
    // skipped on panic, and with `auto_resume: true` a stale checkpoint here
    // could be picked up by the first run.
    let _ = std::fs::remove_dir_all(&checkpoint_dir);
    std::fs::create_dir_all(&checkpoint_dir).expect("Failed to create test directory");
    let problem = SixHumpCamel;
    let initial_params =
        OQNLPParams { iterations: 5, population_size: 10, ..Default::default() };
    let checkpoint_config = CheckpointConfig {
        checkpoint_dir: checkpoint_dir.clone(),
        checkpoint_name: "test_expand_integration".to_string(),
        save_frequency: 1,
        keep_all: false,
        auto_resume: true,
    };
    let mut oqnlp = OQNLP::new(problem.clone(), initial_params.clone())
        .unwrap()
        .with_checkpointing(checkpoint_config.clone())
        .unwrap()
        .verbose();
    let _result = oqnlp.run();
    let modified_params = OQNLPParams { iterations: 8, population_size: 25, ..initial_params };
    let mut oqnlp2 = OQNLP::new(problem, modified_params.clone())
        .unwrap()
        .with_checkpointing(checkpoint_config)
        .unwrap()
        .verbose();
    let resumed = oqnlp2.resume_with_modified_params(modified_params);
    assert!(resumed.is_ok(), "Resume with larger population should succeed");
    assert!(resumed.unwrap(), "Should have resumed from checkpoint");
    assert_eq!(oqnlp2.params.population_size, 25);
    let _ = std::fs::remove_dir_all(&checkpoint_dir);
}
/// Resuming from an explicit checkpoint file with new parameters must succeed,
/// adopt the new `iterations` / `population_size`, and allow the optimization
/// to continue.
#[test]
#[cfg(feature = "checkpointing")]
fn test_resume_from_checkpoint_with_params() {
    use crate::types::CheckpointConfig;
    use std::env;
    let checkpoint_dir = env::temp_dir().join("globalsearch_test_resume_specific");
    // Start empty. With `keep_all: true`, files from an earlier aborted run
    // (cleanup is skipped on panic) would otherwise remain. `read_dir` order is
    // unspecified, so one of those stale files could be chosen below.
    let _ = std::fs::remove_dir_all(&checkpoint_dir);
    std::fs::create_dir_all(&checkpoint_dir).expect("Failed to create test directory");
    let problem = SixHumpCamel;
    let initial_params =
        OQNLPParams { iterations: 8, population_size: 15, ..Default::default() };
    let checkpoint_config = CheckpointConfig {
        checkpoint_dir: checkpoint_dir.clone(),
        checkpoint_name: "test_resume_specific".to_string(),
        save_frequency: 2,
        keep_all: true,
        auto_resume: false,
    };
    let mut oqnlp = OQNLP::new(problem.clone(), initial_params.clone())
        .unwrap()
        .with_checkpointing(checkpoint_config.clone())
        .unwrap()
        .verbose();
    let _result = oqnlp.run();
    // Collect the `.bin` checkpoint files written for this checkpoint name.
    let checkpoint_files: Vec<_> = std::fs::read_dir(&checkpoint_dir)
        .expect("Failed to read checkpoint directory")
        .filter_map(|entry| {
            let entry = entry.ok()?;
            let path = entry.path();
            if path.extension()? == "bin"
                && path.file_name()?.to_str()?.contains("test_resume_specific")
            {
                Some(path)
            } else {
                None
            }
        })
        .collect();
    assert!(!checkpoint_files.is_empty(), "Should have created at least one checkpoint file");
    let checkpoint_path = &checkpoint_files[0];
    let modified_params = OQNLPParams { iterations: 20, population_size: 25, ..initial_params };
    let mut oqnlp2 = OQNLP::new(problem.clone(), modified_params.clone())
        .unwrap()
        .with_checkpointing(checkpoint_config)
        .unwrap()
        .verbose();
    let result =
        oqnlp2.resume_from_checkpoint_with_params(checkpoint_path, modified_params.clone());
    assert!(result.is_ok(), "Resume from specific checkpoint should succeed");
    assert_eq!(oqnlp2.params.iterations, 20);
    assert_eq!(oqnlp2.params.population_size, 25);
    let continued_result = oqnlp2.run();
    assert!(continued_result.is_ok(), "Should be able to continue optimization after resume");
    let _ = std::fs::remove_dir_all(&checkpoint_dir);
}
/// Resuming from a checkpoint path that does not exist must return an error.
#[test]
#[cfg(feature = "checkpointing")]
fn test_resume_from_checkpoint_with_params_nonexistent_file() {
    use crate::types::CheckpointConfig;
    use std::env;
    use std::path::PathBuf;
    let checkpoint_dir = env::temp_dir().join("globalsearch_test_resume_nonexistent");
    std::fs::create_dir_all(&checkpoint_dir).expect("Failed to create test directory");
    let params = OQNLPParams { iterations: 5, population_size: 10, ..Default::default() };
    let config = CheckpointConfig {
        checkpoint_dir: checkpoint_dir.clone(),
        checkpoint_name: "test_nonexistent".to_string(),
        save_frequency: 1,
        keep_all: false,
        auto_resume: false,
    };
    let mut oqnlp =
        OQNLP::new(SixHumpCamel, params.clone()).unwrap().with_checkpointing(config).unwrap();
    let missing = PathBuf::from("nonexistent_checkpoint.bin");
    let outcome = oqnlp.resume_from_checkpoint_with_params(&missing, params);
    assert!(outcome.is_err(), "Should return error for nonexistent checkpoint file");
    let _ = std::fs::remove_dir_all(&checkpoint_dir);
}
/// Without checkpointing configured, `resume_from_checkpoint_with_params` is
/// expected to return `Ok` instead of failing.
#[test]
#[cfg(feature = "checkpointing")]
fn test_resume_from_checkpoint_with_params_no_manager() {
    use std::path::PathBuf;
    let params = OQNLPParams { iterations: 5, population_size: 10, ..Default::default() };
    let mut oqnlp = OQNLP::new(SixHumpCamel, params.clone()).unwrap();
    let placeholder = PathBuf::from("dummy_checkpoint.bin");
    let outcome = oqnlp.resume_from_checkpoint_with_params(&placeholder, params);
    assert!(outcome.is_ok(), "Should handle absence of checkpoint manager gracefully");
}
/// Resuming from one checkpoint file must apply both a single changed
/// parameter (`iterations`) and several changed parameters at once.
#[test]
#[cfg(feature = "checkpointing")]
fn test_resume_from_checkpoint_with_params_various_params() {
    use crate::types::CheckpointConfig;
    use std::env;
    let checkpoint_dir = env::temp_dir().join("globalsearch_test_resume_various");
    // Start empty. With `keep_all: true`, files from an earlier aborted run
    // (cleanup is skipped on panic) would otherwise accumulate here and could be
    // chosen as `checkpoint_files[0]`.
    let _ = std::fs::remove_dir_all(&checkpoint_dir);
    std::fs::create_dir_all(&checkpoint_dir).expect("Failed to create test directory");
    let problem = SixHumpCamel;
    let initial_params = OQNLPParams {
        iterations: 6,
        population_size: 12,
        distance_factor: 0.1,
        threshold_factor: 0.3,
        ..Default::default()
    };
    let checkpoint_config = CheckpointConfig {
        checkpoint_dir: checkpoint_dir.clone(),
        checkpoint_name: "test_resume_various".to_string(),
        save_frequency: 1,
        keep_all: true,
        auto_resume: false,
    };
    let mut oqnlp = OQNLP::new(problem.clone(), initial_params.clone())
        .unwrap()
        .with_checkpointing(checkpoint_config.clone())
        .unwrap()
        .verbose();
    let _result = oqnlp.run();
    let checkpoint_files: Vec<_> = std::fs::read_dir(&checkpoint_dir)
        .expect("Failed to read checkpoint directory")
        .filter_map(|entry| {
            let entry = entry.ok()?;
            let path = entry.path();
            if path.extension()? == "bin" { Some(path) } else { None }
        })
        .collect();
    assert!(!checkpoint_files.is_empty(), "Should have created checkpoint files");
    let checkpoint_path = &checkpoint_files[0];

    // Only `iterations` changes.
    let modified_params1 = OQNLPParams { iterations: 10, ..initial_params.clone() };
    let mut oqnlp1 = OQNLP::new(problem.clone(), modified_params1.clone())
        .unwrap()
        .with_checkpointing(checkpoint_config.clone())
        .unwrap();
    let result1 = oqnlp1.resume_from_checkpoint_with_params(checkpoint_path, modified_params1);
    assert!(result1.is_ok(), "Should handle iterations-only modification");
    assert_eq!(oqnlp1.params.iterations, 10);

    // Several parameters change together.
    let modified_params2 = OQNLPParams {
        iterations: 25,
        population_size: 30,
        distance_factor: 0.05,
        threshold_factor: 0.4,
        ..initial_params
    };
    let mut oqnlp2 = OQNLP::new(problem, modified_params2.clone())
        .unwrap()
        .with_checkpointing(checkpoint_config)
        .unwrap();
    let result2 = oqnlp2.resume_from_checkpoint_with_params(checkpoint_path, modified_params2);
    assert!(result2.is_ok(), "Should handle multiple parameter modifications");
    assert_eq!(oqnlp2.params.iterations, 25);
    assert_eq!(oqnlp2.params.population_size, 30);
    assert!((oqnlp2.params.distance_factor - 0.05).abs() < 1e-10);
    assert!((oqnlp2.params.threshold_factor - 0.4).abs() < 1e-10);
    let _ = std::fs::remove_dir_all(&checkpoint_dir);
}
/// Restoring a checkpoint must overwrite the builder-configured state with the checkpoint's.
#[test]
#[cfg(feature = "checkpointing")]
fn test_restore_from_checkpoint() {
    use crate::types::{CheckpointConfig, OQNLPCheckpoint};
    use std::env;

    let workdir = env::temp_dir().join("globalsearch_test_restore");
    std::fs::create_dir_all(&workdir).expect("Failed to create test directory");

    // Builder-side configuration that the checkpoint is expected to replace.
    let base_params = OQNLPParams {
        iterations: 15,
        population_size: 20,
        distance_factor: 0.15,
        threshold_factor: 0.25,
        ..Default::default()
    };
    let config = CheckpointConfig {
        checkpoint_dir: workdir.clone(),
        checkpoint_name: "test_restore".to_string(),
        save_frequency: 1,
        keep_all: false,
        auto_resume: false,
    };
    let mut optimizer = OQNLP::new(SixHumpCamel, base_params.clone())
        .unwrap()
        .with_checkpointing(config)
        .unwrap()
        .verbose()
        .target_objective(-0.8);

    let best = LocalSolution { objective: -0.5, point: Array1::from(vec![0.1, -0.7]) };
    let runner_up = LocalSolution { objective: -0.3, point: Array1::from(vec![-0.9, 0.2]) };
    let saved_refs = [
        Array1::from(vec![-1.5, 1.0]),
        Array1::from(vec![0.5, -1.2]),
        Array1::from(vec![2.0, 0.8]),
    ];

    let snapshot = OQNLPCheckpoint {
        params: OQNLPParams {
            iterations: 25,
            population_size: 30,
            distance_factor: 0.08,
            threshold_factor: 0.35,
            ..base_params
        },
        current_iteration: 12,
        merit_threshold: -0.2,
        solution_set: Some(SolutionSet {
            solutions: Array1::from(vec![best.clone(), runner_up.clone()]),
        }),
        reference_set: saved_refs.to_vec(),
        unchanged_cycles: 3,
        elapsed_time: 45.67,
        distance_filter_solutions: vec![best, runner_up],
        current_seed: 98765,
        target_objective: Some(-0.9),
        exclude_out_of_bounds: true,
        #[cfg(feature = "rayon")]
        batch_iterations: Some(2),
        #[cfg(feature = "rayon")]
        enable_parallel: false,
        abs_tol: 1e-8,
        rel_tol: 1e-6,
        timestamp: chrono::Utc::now().to_rfc3339(),
    };

    let outcome = optimizer.restore_from_checkpoint(snapshot);
    assert!(outcome.is_ok(), "restore_from_checkpoint should succeed");

    // Tolerance-based comparison for restored floating-point values.
    let close = |a: f64, b: f64| (a - b).abs() < 1e-10;

    // Parameters and scalar progress state.
    assert_eq!(optimizer.params.iterations, 25, "Iterations should be restored");
    assert_eq!(optimizer.params.population_size, 30, "Population size should be restored");
    assert!(close(optimizer.params.distance_factor, 0.08), "Distance factor should be restored");
    assert!(
        close(optimizer.params.threshold_factor, 0.35),
        "Threshold factor should be restored"
    );
    assert_eq!(optimizer.current_iteration, 12, "Current iteration should be restored");
    assert!(close(optimizer.merit_filter.threshold, -0.2), "Merit threshold should be restored");
    assert_eq!(optimizer.unchanged_cycles, 3, "Unchanged cycles should be restored");
    assert_eq!(optimizer.current_seed, 98765, "Current seed should be restored");
    assert_eq!(optimizer.target_objective, Some(-0.9), "Target objective should be restored");

    // Solution set, in the order it was saved.
    let restored_solutions =
        optimizer.solution_set.as_ref().expect("Solution set should be restored");
    assert_eq!(restored_solutions.len(), 2, "Solution set should have 2 solutions");
    let expected_objectives = [
        (-0.5, "First solution objective should match"),
        (-0.3, "Second solution objective should match"),
    ];
    for (slot, &(objective, message)) in expected_objectives.iter().enumerate() {
        assert!(close(restored_solutions[slot].objective, objective), "{}", message);
    }

    // Reference set, point by point.
    let restored_refs =
        optimizer.current_reference_set.as_ref().expect("Reference set should be restored");
    assert_eq!(restored_refs.len(), 3, "Reference set should have 3 points");
    let ref_messages = [
        "First reference point should match",
        "Second reference point should match",
        "Third reference point should match",
    ];
    for ((got, want), message) in
        restored_refs.iter().zip(saved_refs.iter()).zip(ref_messages.iter())
    {
        assert_eq!(got, want, "{}", message);
    }

    assert_eq!(
        optimizer.distance_filter.get_solutions().len(),
        2,
        "Distance filter should have 2 solutions"
    );

    let _ = std::fs::remove_dir_all(&workdir);
}
/// A checkpoint without a solution set restores scalar state and leaves `solution_set` unset.
#[test]
#[cfg(feature = "checkpointing")]
fn test_restore_from_checkpoint_empty_solution_set() {
    use crate::types::{CheckpointConfig, OQNLPCheckpoint};
    use std::env;

    let tmp = env::temp_dir().join("globalsearch_test_restore_empty");
    std::fs::create_dir_all(&tmp).expect("Failed to create test directory");

    let params = OQNLPParams { iterations: 10, population_size: 15, ..Default::default() };
    let mut oqnlp = OQNLP::new(SixHumpCamel, params.clone())
        .unwrap()
        .with_checkpointing(CheckpointConfig {
            checkpoint_dir: tmp.clone(),
            checkpoint_name: "test_restore_empty".to_string(),
            save_frequency: 1,
            keep_all: false,
            auto_resume: false,
        })
        .unwrap()
        .verbose();

    let empty_checkpoint = OQNLPCheckpoint {
        params,
        current_iteration: 5,
        merit_threshold: 100.0,
        solution_set: None,
        reference_set: vec![Array1::from(vec![0.0, 0.0])],
        unchanged_cycles: 0,
        elapsed_time: 10.0,
        distance_filter_solutions: vec![],
        current_seed: 12345,
        target_objective: None,
        exclude_out_of_bounds: false,
        #[cfg(feature = "rayon")]
        batch_iterations: None,
        #[cfg(feature = "rayon")]
        enable_parallel: true,
        abs_tol: 1e-8,
        rel_tol: 1e-6,
        timestamp: chrono::Utc::now().to_rfc3339(),
    };

    let restored = oqnlp.restore_from_checkpoint(empty_checkpoint);
    assert!(restored.is_ok(), "Should handle empty solution set gracefully");
    assert!(oqnlp.solution_set.is_none(), "Solution set should remain None");
    assert_eq!(oqnlp.current_iteration, 5, "Current iteration should be restored");
    let threshold_error = (oqnlp.merit_filter.threshold - 100.0).abs();
    assert!(threshold_error < 1e-10, "Merit threshold should be restored");

    let _ = std::fs::remove_dir_all(&tmp);
}
/// Exercises both logging paths of `restore_from_checkpoint`:
/// - a verbose instance restoring a checkpoint that carries a solution set;
/// - a quiet instance restoring one that carries no solution set.
///
/// Beyond "does not error", each restored instance is compared against its checkpoint, so a
/// restore that silently drops state is caught.
#[test]
#[cfg(feature = "checkpointing")]
fn test_restore_from_checkpoint_verbose_output() {
    use crate::types::{CheckpointConfig, OQNLPCheckpoint};
    use std::env;

    let checkpoint_dir = env::temp_dir().join("globalsearch_test_restore_verbose");
    std::fs::create_dir_all(&checkpoint_dir).expect("Failed to create test directory");

    let problem = SixHumpCamel;
    let params = OQNLPParams { iterations: 8, population_size: 12, ..Default::default() };
    let checkpoint_config = CheckpointConfig {
        checkpoint_dir: checkpoint_dir.clone(),
        checkpoint_name: "test_restore_verbose".to_string(),
        save_frequency: 1,
        keep_all: false,
        auto_resume: false,
    };

    // Verbose instance, checkpoint with one solution.
    let mut oqnlp_verbose = OQNLP::new(problem.clone(), params.clone())
        .unwrap()
        .with_checkpointing(checkpoint_config.clone())
        .unwrap()
        .verbose();
    let solution = LocalSolution { objective: -1.2, point: Array1::from(vec![0.5, -0.3]) };
    let solution_set = SolutionSet { solutions: Array1::from(vec![solution]) };
    let checkpoint_with_solutions = OQNLPCheckpoint {
        params: params.clone(),
        current_iteration: 7,
        merit_threshold: -1.0,
        solution_set: Some(solution_set),
        reference_set: vec![Array1::from(vec![1.0, 1.0])],
        unchanged_cycles: 2,
        elapsed_time: 25.0,
        distance_filter_solutions: vec![],
        current_seed: 54321,
        target_objective: None,
        exclude_out_of_bounds: false,
        #[cfg(feature = "rayon")]
        batch_iterations: Some(6),
        #[cfg(feature = "rayon")]
        enable_parallel: false,
        abs_tol: 1e-8,
        rel_tol: 1e-6,
        timestamp: chrono::Utc::now().to_rfc3339(),
    };
    let result_verbose = oqnlp_verbose.restore_from_checkpoint(checkpoint_with_solutions);
    assert!(result_verbose.is_ok(), "Verbose restore should succeed");
    assert_eq!(oqnlp_verbose.current_iteration, 7, "Verbose restore should set current iteration");
    assert_eq!(oqnlp_verbose.unchanged_cycles, 2, "Verbose restore should set unchanged cycles");
    assert_eq!(oqnlp_verbose.current_seed, 54321, "Verbose restore should set current seed");
    assert!(
        (oqnlp_verbose.merit_filter.threshold - (-1.0)).abs() < 1e-10,
        "Verbose restore should set merit threshold"
    );
    let verbose_solutions = oqnlp_verbose
        .solution_set
        .as_ref()
        .expect("Verbose restore should restore the solution set");
    assert_eq!(verbose_solutions.len(), 1, "Restored solution set should have 1 solution");
    assert!(
        (verbose_solutions[0].objective - (-1.2)).abs() < 1e-10,
        "Restored solution objective should match"
    );

    // Quiet instance, checkpoint without solutions.
    let mut oqnlp_quiet = OQNLP::new(problem, params.clone())
        .unwrap()
        .with_checkpointing(checkpoint_config)
        .unwrap();
    let checkpoint_no_solutions = OQNLPCheckpoint {
        params,
        current_iteration: 3,
        merit_threshold: 50.0,
        solution_set: None,
        reference_set: vec![Array1::from(vec![-1.0, -1.0])],
        unchanged_cycles: 1,
        elapsed_time: 15.0,
        distance_filter_solutions: vec![],
        current_seed: 11111,
        target_objective: None,
        exclude_out_of_bounds: false,
        #[cfg(feature = "rayon")]
        batch_iterations: None,
        #[cfg(feature = "rayon")]
        enable_parallel: true,
        abs_tol: 1e-8,
        rel_tol: 1e-6,
        timestamp: chrono::Utc::now().to_rfc3339(),
    };
    let result_quiet = oqnlp_quiet.restore_from_checkpoint(checkpoint_no_solutions);
    assert!(result_quiet.is_ok(), "Quiet restore should succeed");
    assert_eq!(oqnlp_quiet.current_iteration, 3, "Quiet restore should set current iteration");
    assert_eq!(oqnlp_quiet.unchanged_cycles, 1, "Quiet restore should set unchanged cycles");
    assert_eq!(oqnlp_quiet.current_seed, 11111, "Quiet restore should set current seed");
    assert!(
        (oqnlp_quiet.merit_filter.threshold - 50.0).abs() < 1e-10,
        "Quiet restore should set merit threshold"
    );
    assert!(
        oqnlp_quiet.solution_set.is_none(),
        "Quiet restore without solutions should leave solution set empty"
    );

    let _ = std::fs::remove_dir_all(&checkpoint_dir);
}
/// Restores three successive checkpoints into one optimizer:
/// - an all-zero start state;
/// - extreme float magnitudes;
/// - a very large unchanged-cycle counter.
#[test]
#[cfg(feature = "checkpointing")]
fn test_restore_from_checkpoint_edge_cases() {
    use crate::types::{CheckpointConfig, OQNLPCheckpoint};
    use std::env;

    let scratch = env::temp_dir().join("globalsearch_test_restore_edge");
    std::fs::create_dir_all(&scratch).expect("Failed to create test directory");

    let params = OQNLPParams { iterations: 5, population_size: 8, ..Default::default() };
    let config = CheckpointConfig {
        checkpoint_dir: scratch.clone(),
        checkpoint_name: "test_restore_edge".to_string(),
        save_frequency: 1,
        keep_all: false,
        auto_resume: false,
    };
    let mut oqnlp =
        OQNLP::new(SixHumpCamel, params.clone()).unwrap().with_checkpointing(config).unwrap();

    // Case 1: nothing has happened yet (iteration 0, empty reference set).
    let at_start = OQNLPCheckpoint {
        params: params.clone(),
        current_iteration: 0,
        merit_threshold: 1000.0,
        solution_set: None,
        reference_set: vec![],
        unchanged_cycles: 0,
        elapsed_time: 0.0,
        distance_filter_solutions: vec![],
        current_seed: 1,
        target_objective: None,
        exclude_out_of_bounds: false,
        #[cfg(feature = "rayon")]
        batch_iterations: None,
        #[cfg(feature = "rayon")]
        enable_parallel: false,
        abs_tol: 1e-8,
        rel_tol: 1e-6,
        timestamp: chrono::Utc::now().to_rfc3339(),
    };
    let first_restore = oqnlp.restore_from_checkpoint(at_start);
    assert!(first_restore.is_ok(), "Should handle iteration 0 checkpoint");
    assert_eq!(oqnlp.current_iteration, 0);
    assert_eq!(oqnlp.unchanged_cycles, 0);

    // Case 2: extreme floating-point magnitudes for the threshold and target.
    let extreme_values = OQNLPCheckpoint {
        params: params.clone(),
        current_iteration: 2,
        merit_threshold: f64::MAX / 2.0,
        solution_set: None,
        reference_set: vec![Array1::from(vec![0.0, 0.0])],
        unchanged_cycles: 1,
        elapsed_time: 5.0,
        distance_filter_solutions: vec![],
        current_seed: 999,
        target_objective: Some(f64::MIN / 2.0),
        exclude_out_of_bounds: true,
        #[cfg(feature = "rayon")]
        batch_iterations: Some(8),
        #[cfg(feature = "rayon")]
        enable_parallel: false,
        abs_tol: 1e-8,
        rel_tol: 1e-6,
        timestamp: chrono::Utc::now().to_rfc3339(),
    };
    let second_restore = oqnlp.restore_from_checkpoint(extreme_values);
    assert!(second_restore.is_ok(), "Should handle large threshold values");
    assert!((oqnlp.merit_filter.threshold - (f64::MAX / 2.0)).abs() < f64::MAX / 4.0);
    assert_eq!(oqnlp.target_objective, Some(f64::MIN / 2.0));

    // Case 3: a long stagnation streak.
    let long_stagnation = OQNLPCheckpoint {
        params,
        current_iteration: 4,
        merit_threshold: 0.0,
        solution_set: None,
        reference_set: vec![Array1::from(vec![1.0, -1.0])],
        unchanged_cycles: 1000,
        elapsed_time: 100.0,
        distance_filter_solutions: vec![],
        current_seed: 777,
        target_objective: Some(0.0),
        exclude_out_of_bounds: false,
        #[cfg(feature = "rayon")]
        batch_iterations: Some(1),
        #[cfg(feature = "rayon")]
        enable_parallel: true,
        abs_tol: 1e-8,
        rel_tol: 1e-6,
        timestamp: chrono::Utc::now().to_rfc3339(),
    };
    let third_restore = oqnlp.restore_from_checkpoint(long_stagnation);
    assert!(third_restore.is_ok(), "Should handle many unchanged cycles");
    assert_eq!(oqnlp.unchanged_cycles, 1000);
    assert_eq!(oqnlp.current_seed, 777);

    let _ = std::fs::remove_dir_all(&scratch);
}
/// `create_checkpoint` snapshots parameters, progress counters, solutions, the reference set,
/// distance-filter contents, elapsed time and a timestamp.
#[test]
#[cfg(feature = "checkpointing")]
fn test_create_checkpoint() {
    use crate::types::CheckpointConfig;
    use std::env;
    use std::time::{Duration, Instant};

    let dir = env::temp_dir().join("globalsearch_test_create_checkpoint");
    std::fs::create_dir_all(&dir).expect("Failed to create test directory");

    let params = OQNLPParams {
        iterations: 20,
        population_size: 25,
        distance_factor: 0.12,
        threshold_factor: 0.28,
        seed: 42,
        ..Default::default()
    };
    let mut oqnlp = OQNLP::new(SixHumpCamel, params)
        .unwrap()
        .with_checkpointing(CheckpointConfig {
            checkpoint_dir: dir.clone(),
            checkpoint_name: "test_create".to_string(),
            save_frequency: 1,
            keep_all: false,
            auto_resume: false,
        })
        .unwrap()
        .target_objective(-0.75);

    // Progress counters and merit threshold.
    oqnlp.current_iteration = 8;
    oqnlp.unchanged_cycles = 2;
    oqnlp.current_seed = 98765;
    oqnlp.merit_filter.update_threshold(-0.3);

    // Solutions, reference set and distance-filter contents.
    let near = LocalSolution { objective: -0.6, point: Array1::from(vec![0.2, -0.8]) };
    let far = LocalSolution { objective: -0.4, point: Array1::from(vec![-0.7, 0.3]) };
    oqnlp.solution_set =
        Some(SolutionSet { solutions: Array1::from(vec![near.clone(), far.clone()]) });
    let reference_points = [
        Array1::from(vec![-2.0, 1.5]),
        Array1::from(vec![1.0, -1.0]),
        Array1::from(vec![0.0, 0.0]),
    ];
    oqnlp.current_reference_set = Some(reference_points.to_vec());
    oqnlp.distance_filter.add_solution(near);
    oqnlp.distance_filter.add_solution(far);

    // Pretend the run started 30 seconds ago.
    oqnlp.start_time = Some(Instant::now() - Duration::from_secs(30));

    let snapshot = oqnlp.create_checkpoint();

    assert_eq!(snapshot.params.iterations, 20, "Parameters should be captured");
    assert_eq!(snapshot.params.population_size, 25, "Parameters should be captured");
    assert!(
        (snapshot.params.distance_factor - 0.12).abs() < 1e-10,
        "Parameters should be captured"
    );
    assert!(
        (snapshot.params.threshold_factor - 0.28).abs() < 1e-10,
        "Parameters should be captured"
    );
    assert_eq!(snapshot.params.seed, 42, "Parameters should be captured");

    assert_eq!(snapshot.current_iteration, 8, "Current iteration should be captured");
    assert!((snapshot.merit_threshold - (-0.3)).abs() < 1e-10, "Merit threshold should be captured");
    assert_eq!(snapshot.unchanged_cycles, 2, "Unchanged cycles should be captured");
    assert_eq!(snapshot.current_seed, 98765, "Current seed should be captured");
    assert_eq!(snapshot.target_objective, Some(-0.75), "Target objective should be captured");

    let captured = snapshot.solution_set.as_ref().expect("Solution set should be captured");
    assert_eq!(captured.len(), 2, "Solution set should have 2 solutions");
    assert!((captured[0].objective - (-0.6)).abs() < 1e-10, "First solution should match");
    assert!((captured[1].objective - (-0.4)).abs() < 1e-10, "Second solution should match");

    assert_eq!(snapshot.reference_set.len(), 3, "Reference set should have 3 points");
    for (got, want) in snapshot.reference_set.iter().zip(reference_points.iter()) {
        assert_eq!(got, want, "Reference point should match");
    }

    assert_eq!(
        snapshot.distance_filter_solutions.len(),
        2,
        "Distance filter should have 2 solutions"
    );
    assert!(
        (29.0..=31.0).contains(&snapshot.elapsed_time),
        "Elapsed time should be approximately 30 seconds, got {}",
        snapshot.elapsed_time
    );
    assert!(
        chrono::DateTime::parse_from_rfc3339(&snapshot.timestamp).is_ok(),
        "Timestamp should be valid RFC3339 format"
    );

    let _ = std::fs::remove_dir_all(&dir);
}
/// A snapshot of a freshly built optimizer (never run, no start time) holds only the
/// construction parameters and default progress state.
#[test]
#[cfg(feature = "checkpointing")]
fn test_create_checkpoint_empty_state() {
    use crate::types::CheckpointConfig;
    use std::env;

    let dir = env::temp_dir().join("globalsearch_test_create_empty");
    std::fs::create_dir_all(&dir).expect("Failed to create test directory");

    let params = OQNLPParams { iterations: 5, population_size: 8, ..Default::default() };
    let initial_seed = params.seed;
    let fresh = OQNLP::new(SixHumpCamel, params)
        .unwrap()
        .with_checkpointing(CheckpointConfig {
            checkpoint_dir: dir.clone(),
            checkpoint_name: "test_create_empty".to_string(),
            save_frequency: 1,
            keep_all: false,
            auto_resume: false,
        })
        .unwrap();

    let snapshot = fresh.create_checkpoint();

    // Parameters and counters.
    assert_eq!(snapshot.params.iterations, 5, "Parameters should be captured");
    assert_eq!(snapshot.params.population_size, 8, "Parameters should be captured");
    assert_eq!(snapshot.current_iteration, 0, "Should start at iteration 0");
    assert_eq!(snapshot.unchanged_cycles, 0, "Should start with 0 unchanged cycles");
    assert_eq!(snapshot.current_seed, initial_seed, "Should have initial seed");
    assert_eq!(snapshot.target_objective, None, "Should have no target objective");

    // Collections start out empty.
    assert!(snapshot.solution_set.is_none(), "Solution set should be None initially");
    assert!(snapshot.reference_set.is_empty(), "Reference set should be empty initially");
    assert!(
        snapshot.distance_filter_solutions.is_empty(),
        "Distance filter should be empty initially"
    );

    // Without a start time, no time has elapsed.
    assert_eq!(snapshot.elapsed_time, 0.0, "Elapsed time should be 0 when no start time");
    assert!(
        chrono::DateTime::parse_from_rfc3339(&snapshot.timestamp).is_ok(),
        "Timestamp should be valid RFC3339 format"
    );

    let _ = std::fs::remove_dir_all(&dir);
}
/// Snapshots taken at extreme values (huge counters, seed and threshold; `-inf` target) and
/// again at minimal values must capture those values.
///
/// Each snapshot must also carry its own valid RFC3339 timestamp.
#[test]
#[cfg(feature = "checkpointing")]
fn test_create_checkpoint_edge_cases() {
    use crate::types::CheckpointConfig;
    use std::env;
    use std::time::{Duration, Instant};

    let checkpoint_dir = env::temp_dir().join("globalsearch_test_create_edge");
    std::fs::create_dir_all(&checkpoint_dir).expect("Failed to create test directory");

    let problem = SixHumpCamel;
    let params = OQNLPParams { iterations: 10, population_size: 15, ..Default::default() };
    let checkpoint_config = CheckpointConfig {
        checkpoint_dir: checkpoint_dir.clone(),
        checkpoint_name: "test_create_edge".to_string(),
        save_frequency: 1,
        keep_all: false,
        auto_resume: false,
    };
    let mut oqnlp = OQNLP::new(problem, params)
        .unwrap()
        .with_checkpointing(checkpoint_config)
        .unwrap()
        .target_objective(f64::NEG_INFINITY);

    // Extreme state.
    oqnlp.current_iteration = usize::MAX / 2;
    oqnlp.unchanged_cycles = 999999;
    oqnlp.current_seed = u64::MAX / 2;
    oqnlp.merit_filter.update_threshold(f64::MAX / 2.0);
    // Falling back to "now" here would only make the elapsed-time assertion below fail with a
    // misleading message, so fail fast with a clear one instead.
    oqnlp.start_time = Some(
        Instant::now()
            .checked_sub(Duration::from_secs(60))
            .expect("Instant should be representable 60 seconds in the past"),
    );
    let checkpoint1 = oqnlp.create_checkpoint();
    assert_eq!(
        checkpoint1.current_iteration,
        usize::MAX / 2,
        "Should handle large iteration numbers"
    );
    assert_eq!(checkpoint1.unchanged_cycles, 999999, "Should handle large unchanged cycles");
    assert_eq!(checkpoint1.current_seed, u64::MAX / 2, "Should handle large seed values");
    assert!(
        (checkpoint1.merit_threshold - (f64::MAX / 2.0)).abs() < f64::MAX / 4.0,
        "Should handle large threshold"
    );
    assert_eq!(
        checkpoint1.target_objective,
        Some(f64::NEG_INFINITY),
        "Should handle extreme target objective"
    );
    assert!(
        checkpoint1.elapsed_time >= 55.0 && checkpoint1.elapsed_time <= 65.0,
        "Should handle elapsed time, got {}",
        checkpoint1.elapsed_time
    );

    // Wall-clock timestamps taken back-to-back can coincide on platforms with a coarse system
    // clock. Pause for longer than the coarsest common clock tick so the second snapshot gets
    // a distinct timestamp. The pause happens before `start_time` is reset, so it does not
    // affect the second snapshot's elapsed time.
    std::thread::sleep(Duration::from_millis(20));

    // Minimal state.
    oqnlp.current_iteration = 0;
    oqnlp.unchanged_cycles = 0;
    oqnlp.current_seed = 1;
    oqnlp.merit_filter.update_threshold(0.0);
    oqnlp.target_objective = Some(0.0);
    oqnlp.start_time = Some(Instant::now());
    let checkpoint2 = oqnlp.create_checkpoint();
    assert_eq!(checkpoint2.current_iteration, 0, "Should handle zero iteration");
    assert_eq!(checkpoint2.unchanged_cycles, 0, "Should handle zero unchanged cycles");
    assert_eq!(checkpoint2.current_seed, 1, "Should handle minimal seed");
    assert!((checkpoint2.merit_threshold - 0.0).abs() < 1e-10, "Should handle zero threshold");
    assert_eq!(checkpoint2.target_objective, Some(0.0), "Should handle zero target objective");
    assert!(
        checkpoint2.elapsed_time >= 0.0 && checkpoint2.elapsed_time <= 1.0,
        "Should handle minimal elapsed time, got {}",
        checkpoint2.elapsed_time
    );

    assert!(
        chrono::DateTime::parse_from_rfc3339(&checkpoint1.timestamp).is_ok()
            && chrono::DateTime::parse_from_rfc3339(&checkpoint2.timestamp).is_ok(),
        "Timestamps should be valid RFC3339 format"
    );
    assert_ne!(checkpoint1.timestamp, checkpoint2.timestamp, "Timestamps should be different");

    let _ = std::fs::remove_dir_all(&checkpoint_dir);
}
/// The `batch_iterations` builder stores the requested size and composes with other builders.
#[test]
#[cfg(feature = "rayon")]
fn test_batch_iterations() {
    let batched: OQNLP<DummyProblem> =
        OQNLP::new(DummyProblem, OQNLPParams::default()).unwrap().batch_iterations(4);
    assert_eq!(batched.batch_iterations, Some(4));

    // Chained together with max_time and verbose.
    let oqnlp2 = OQNLP::new(DummyProblem, OQNLPParams::default())
        .unwrap()
        .batch_iterations(8)
        .max_time(30.0)
        .verbose();
    assert_eq!(oqnlp2.batch_iterations, Some(8));
    assert_eq!(oqnlp2.max_time, Some(30.0));
    assert!(oqnlp2.verbose);
}
/// A batch size of one and an unset batch size both complete on the same problem.
#[test]
#[cfg(feature = "rayon")]
fn test_batch_processing_sequential() {
    let problem = DummyProblem;
    let params = OQNLPParams { iterations: 6, population_size: 10, ..Default::default() };

    let mut one_at_a_time =
        OQNLP::new(problem.clone(), params.clone()).unwrap().batch_iterations(1).verbose();
    let batched = match one_at_a_time.run() {
        Ok(solutions) => solutions,
        Err(_) => panic!("OQNLP should run successfully with batch_iterations = 1"),
    };
    assert!(!batched.is_empty(), "Should find at least one solution");

    let mut default_batching = OQNLP::new(problem, params).unwrap().verbose();
    let unbatched = match default_batching.run() {
        Ok(solutions) => solutions,
        Err(_) => panic!("OQNLP should run successfully without batch_iterations"),
    };
    assert!(!unbatched.is_empty(), "Should find at least one solution");
}
/// Batches of four iterations, and the default batch size, both yield solutions.
#[test]
#[cfg(feature = "rayon")]
fn test_batch_processing_parallel() {
    let problem = SixHumpCamel;
    let params = OQNLPParams { iterations: 8, population_size: 12, ..Default::default() };

    let mut four_wide =
        OQNLP::new(problem.clone(), params.clone()).unwrap().batch_iterations(4).verbose();
    let outcome = four_wide.run();
    assert!(outcome.is_ok(), "OQNLP should run successfully with parallel batch processing");
    let found = outcome.unwrap();
    assert!(!found.is_empty(), "Should find at least one solution with parallel processing");

    let mut auto_sized = OQNLP::new(problem, params).unwrap().verbose();
    let auto_outcome = auto_sized.run();
    assert!(auto_outcome.is_ok(), "OQNLP should run successfully with auto batch size");
    let auto_found = auto_outcome.unwrap();
    assert!(!auto_found.is_empty(), "Should find solutions with auto batch size");
}
/// With a target objective and batching enabled, the best solution still reaches the target.
#[test]
#[cfg(feature = "rayon")]
fn test_batch_iterations_with_target_objective() {
    let params = OQNLPParams { iterations: 20, population_size: 30, ..Default::default() };
    let mut oqnlp = OQNLP::new(SixHumpCamel, params)
        .unwrap()
        .batch_iterations(3)
        .target_objective(-0.5)
        .verbose();

    let result = oqnlp.run();
    assert!(
        result.is_ok(),
        "OQNLP should run successfully with batch_iterations and target_objective"
    );
    let solutions = result.unwrap();
    assert!(!solutions.is_empty(), "Should find at least one solution");

    let best_objective = solutions.best_solution().unwrap().objective;
    assert!(best_objective <= -0.5, "Best objective {} should be <= target -0.5", best_objective);
}
/// A tiny `max_time` stops a large batched run early while still returning solutions.
#[test]
#[cfg(feature = "rayon")]
fn test_batch_iterations_with_max_time() {
    let params = OQNLPParams { iterations: 100, population_size: 200, ..Default::default() };
    let mut oqnlp =
        OQNLP::new(SixHumpCamel, params).unwrap().batch_iterations(5).max_time(0.1).verbose();

    let clock = std::time::Instant::now();
    let result = oqnlp.run();
    let wall_seconds = clock.elapsed().as_secs_f64();

    assert!(result.is_ok(), "OQNLP should run successfully with batch_iterations and max_time");
    assert!(wall_seconds < 2.0, "Should stop within reasonable time due to max_time limit");
    let solutions = result.unwrap();
    assert!(!solutions.is_empty(), "Should find at least one solution before timeout");
}
/// Runs succeed for batch sizes below, at and above the iteration count, and when unset.
#[test]
#[cfg(feature = "rayon")]
fn test_various_batch_iterations_values() {
    let problem = DummyProblem;
    let params = OQNLPParams { iterations: 4, population_size: 8, ..Default::default() };

    let cases: [(usize, &str); 3] = [
        (1, "Should work with batch_iterations = 1"),
        (2, "Should work with batch_iterations = 2"),
        (10, "Should work with batch_iterations > iterations"),
    ];
    for &(batch, message) in cases.iter() {
        let mut oqnlp =
            OQNLP::new(problem.clone(), params.clone()).unwrap().batch_iterations(batch);
        let outcome = oqnlp.run();
        assert!(outcome.is_ok(), "{}", message);
    }

    let mut unbatched = OQNLP::new(problem, params).unwrap();
    let outcome = unbatched.run();
    assert!(outcome.is_ok(), "Should work without setting batch_iterations");
}
/// Sizing batches to the global rayon pool's thread count still produces solutions.
#[test]
#[cfg(feature = "rayon")]
fn test_parallel_processing_thread_usage() {
    let workers = rayon::current_num_threads();
    let params = OQNLPParams { iterations: 16, population_size: 20, ..Default::default() };
    let mut oqnlp = OQNLP::new(SixHumpCamel, params).unwrap().batch_iterations(workers).verbose();

    let outcome = oqnlp.run();
    assert!(outcome.is_ok(), "OQNLP should run successfully with parallel processing");
    let found = outcome.unwrap();
    assert!(!found.is_empty(), "Should find solutions with parallel processing");
}
/// A batch size that does not evenly divide the iteration count is handled without error.
#[test]
#[cfg(feature = "rayon")]
fn test_batch_iterations_error_handling() {
    let params = OQNLPParams { iterations: 5, population_size: 10, ..Default::default() };
    let mut oqnlp = OQNLP::new(DummyProblem, params).unwrap().batch_iterations(3);

    let found = match oqnlp.run() {
        Ok(solutions) => solutions,
        Err(_) => panic!("OQNLP should handle batch processing correctly"),
    };
    assert!(!found.is_empty(), "Should find solutions even with batch processing");
}
/// Batched execution and periodic checkpoint saving work together in one run.
#[test]
#[cfg(all(feature = "rayon", feature = "checkpointing"))]
fn test_batch_iterations_with_checkpointing() {
    use crate::types::CheckpointConfig;
    use std::env;

    let workdir = env::temp_dir().join("globalsearch_test_batch_checkpoint");
    std::fs::create_dir_all(&workdir).expect("Failed to create test directory");

    let params = OQNLPParams { iterations: 8, population_size: 12, ..Default::default() };
    let mut oqnlp = OQNLP::new(SixHumpCamel, params)
        .unwrap()
        .with_checkpointing(CheckpointConfig {
            checkpoint_dir: workdir.clone(),
            checkpoint_name: "test_batch".to_string(),
            save_frequency: 2,
            keep_all: false,
            auto_resume: false,
        })
        .unwrap()
        .batch_iterations(4)
        .verbose();

    let outcome = oqnlp.run();
    assert!(outcome.is_ok(), "OQNLP should work with batch_iterations and checkpointing");
    let found = outcome.unwrap();
    assert!(!found.is_empty(), "Should find solutions with batch processing and checkpointing");

    let _ = std::fs::remove_dir_all(&workdir);
}
/// Round-trip check: snapshot a populated optimizer and restore it into a fresh one.
///
/// The two instances must agree on every piece of restorable state, and both must then be
/// able to resume running.
#[test]
#[cfg(feature = "checkpointing")]
fn test_create_checkpoint_consistency() {
    use crate::types::CheckpointConfig;
    use std::env;

    let checkpoint_dir = env::temp_dir().join("globalsearch_test_checkpoint_consistency");
    std::fs::create_dir_all(&checkpoint_dir).expect("Failed to create test directory");

    let problem = SixHumpCamel;
    let params = OQNLPParams {
        iterations: 15,
        population_size: 20,
        distance_factor: 0.08,
        threshold_factor: 0.32,
        seed: 12345,
        ..Default::default()
    };
    let config = CheckpointConfig {
        checkpoint_dir: checkpoint_dir.clone(),
        checkpoint_name: "test_consistency".to_string(),
        save_frequency: 1,
        keep_all: false,
        auto_resume: false,
    };

    // Source instance with hand-populated state.
    let mut source = OQNLP::new(problem.clone(), params.clone())
        .unwrap()
        .with_checkpointing(config.clone())
        .unwrap()
        .target_objective(-0.6)
        .exclude_out_of_bounds()
        .verbose();
    #[cfg(feature = "rayon")]
    {
        source = source.batch_iterations(3);
    }
    source.current_iteration = 7;
    source.unchanged_cycles = 2;
    source.current_seed = 98765;
    source.merit_filter.update_threshold(-0.4);
    source.start_time = Some(std::time::Instant::now() - std::time::Duration::from_secs(25));

    let first = LocalSolution { objective: -0.7, point: Array1::from(vec![0.3, -0.6]) };
    let second = LocalSolution { objective: -0.5, point: Array1::from(vec![-0.8, 0.4]) };
    source.solution_set =
        Some(SolutionSet { solutions: Array1::from(vec![first.clone(), second.clone()]) });
    source.current_reference_set = Some(vec![
        Array1::from(vec![-1.5, 1.2]),
        Array1::from(vec![0.8, -0.9]),
        Array1::from(vec![2.1, 0.3]),
    ]);
    source.distance_filter.add_solution(first);
    source.distance_filter.add_solution(second);

    // Snapshot and restore into a fresh instance.
    let snapshot = source.create_checkpoint();
    let mut restored = OQNLP::new(problem, params).unwrap().with_checkpointing(config).unwrap();
    let restore_outcome = restored.restore_from_checkpoint(snapshot);
    assert!(restore_outcome.is_ok(), "Checkpoint restoration should succeed");

    let close = |a: f64, b: f64| (a - b).abs() < 1e-10;

    // Parameters.
    assert_eq!(source.params.iterations, restored.params.iterations, "Iterations should match");
    assert_eq!(
        source.params.population_size, restored.params.population_size,
        "Population size should match"
    );
    assert!(
        close(source.params.distance_factor, restored.params.distance_factor),
        "Distance factor should match"
    );
    assert!(
        close(source.params.threshold_factor, restored.params.threshold_factor),
        "Threshold factor should match"
    );
    assert_eq!(source.params.seed, restored.params.seed, "Seed should match");

    // Progress and configuration state.
    assert_eq!(
        source.current_iteration, restored.current_iteration,
        "Current iteration should match"
    );
    assert!(
        close(source.merit_filter.threshold, restored.merit_filter.threshold),
        "Merit threshold should match"
    );
    assert_eq!(
        source.unchanged_cycles, restored.unchanged_cycles,
        "Unchanged cycles should match"
    );
    assert_eq!(source.current_seed, restored.current_seed, "Current seed should match");
    assert_eq!(
        source.target_objective, restored.target_objective,
        "Target objective should match"
    );
    assert_eq!(
        source.exclude_out_of_bounds, restored.exclude_out_of_bounds,
        "Exclude out of bounds should match"
    );
    #[cfg(feature = "rayon")]
    {
        assert_eq!(
            source.batch_iterations, restored.batch_iterations,
            "Batch iterations should match"
        );
    }

    // Solution sets, element by element.
    let original_solutions =
        source.solution_set.as_ref().expect("Original should have solution set");
    let restored_solutions =
        restored.solution_set.as_ref().expect("Restored should have solution set");
    assert_eq!(
        original_solutions.len(),
        restored_solutions.len(),
        "Solution set lengths should match"
    );
    for (a, b) in original_solutions.solutions.iter().zip(restored_solutions.solutions.iter()) {
        assert!(close(a.objective, b.objective), "Solution objectives should match");
        assert_eq!(a.point.len(), b.point.len(), "Solution point dimensions should match");
        for (x, y) in a.point.iter().zip(b.point.iter()) {
            assert!(close(*x, *y), "Solution point values should match");
        }
    }

    // Reference sets.
    let original_refs =
        source.current_reference_set.as_ref().expect("Original should have reference set");
    let restored_refs =
        restored.current_reference_set.as_ref().expect("Restored should have reference set");
    assert_eq!(original_refs.len(), restored_refs.len(), "Reference set lengths should match");
    for (a, b) in original_refs.iter().zip(restored_refs.iter()) {
        assert_eq!(a.len(), b.len(), "Reference point dimensions should match");
        for (x, y) in a.iter().zip(b.iter()) {
            assert!(close(*x, *y), "Reference point values should match");
        }
    }

    // Distance-filter contents.
    let original_filter = source.distance_filter.get_solutions();
    let restored_filter = restored.distance_filter.get_solutions();
    assert_eq!(
        original_filter.len(),
        restored_filter.len(),
        "Distance filter solution counts should match"
    );
    for (a, b) in original_filter.iter().zip(restored_filter.iter()) {
        assert!(close(a.objective, b.objective), "Distance filter objectives should match");
        assert_eq!(
            a.point.len(),
            b.point.len(),
            "Distance filter point dimensions should match"
        );
        for (x, y) in a.point.iter().zip(b.point.iter()) {
            assert!(close(*x, *y), "Distance filter point values should match");
        }
    }

    // Both instances must be able to resume.
    let resumed_source = source.run();
    let resumed_restored = restored.run();
    assert!(resumed_source.is_ok(), "Original instance should continue successfully");
    assert!(resumed_restored.is_ok(), "Restored instance should continue successfully");

    let _ = std::fs::remove_dir_all(&checkpoint_dir);
}
/// A batch size larger than the total iteration count is accepted.
#[test]
#[cfg(feature = "rayon")]
fn test_batch_iterations_edge_case_small_remaining() {
    let params = OQNLPParams { iterations: 5, population_size: 10, ..Default::default() };
    let mut oversized = OQNLP::new(DummyProblem, params).unwrap().batch_iterations(10).verbose();

    let outcome = oversized.run();
    assert!(outcome.is_ok(), "OQNLP should handle large batch size gracefully");
    let found = outcome.unwrap();
    assert!(!found.is_empty(), "Should find at least one solution");
}
/// A batch size exactly equal to the iteration count runs everything in one batch.
#[test]
#[cfg(feature = "rayon")]
fn test_batch_iterations_exact_match() {
    let params = OQNLPParams { iterations: 8, population_size: 12, ..Default::default() };
    let mut single_batch = OQNLP::new(DummyProblem, params).unwrap().batch_iterations(8).verbose();

    let outcome = single_batch.run();
    assert!(outcome.is_ok(), "OQNLP should handle exact match gracefully");
    let found = outcome.unwrap();
    assert!(!found.is_empty(), "Should find at least one solution");
}
/// Parallel evaluation defaults to on and can be toggled explicitly.
///
/// A run with parallelism disabled still succeeds.
#[test]
#[cfg(feature = "rayon")]
fn test_parallel_control() {
    let params = OQNLPParams { iterations: 5, population_size: 8, ..Default::default() };

    let default_mode = OQNLP::new(DummyProblem, params.clone()).unwrap();
    assert!(default_mode.enable_parallel, "Parallel should be enabled by default");

    let mut sequential = OQNLP::new(DummyProblem, params.clone()).unwrap().parallel(false);
    assert!(!sequential.enable_parallel, "Parallel should be disabled");

    let forced_parallel = OQNLP::new(DummyProblem, params.clone()).unwrap().parallel(true);
    assert!(forced_parallel.enable_parallel, "Parallel should be enabled");

    let outcome = sequential.run();
    assert!(outcome.is_ok(), "OQNLP should work with parallel disabled");
    let found = outcome.unwrap();
    assert!(!found.is_empty(), "Should find solutions even with parallel disabled");
}
/// In-bounds points of the right dimension are accepted and stored as custom points.
#[test]
fn test_with_points_valid() {
    let params = OQNLPParams { iterations: 5, population_size: 20, ..Default::default() };
    let seeds = array![[1.0, 2.0, 3.0], [-1.0, -2.0, -3.0], [0.0, 0.0, 0.0]];

    let result = OQNLP::new(DummyProblem, params).unwrap().with_points(seeds);
    assert!(result.is_ok(), "with_points should succeed with valid points");

    let oqnlp = result.unwrap();
    let stored = oqnlp.custom_points.as_ref();
    assert!(stored.is_some(), "Custom points should be stored");
    assert_eq!(stored.unwrap().len(), 3, "Should have 3 custom points");
}
/// The second row lies outside the problem bounds, so the error must report index 1.
#[test]
fn test_with_points_out_of_bounds() {
    let candidate_points = array![[1.0, 2.0, 3.0], [10.0, 0.0, 0.0]];
    let result =
        OQNLP::new(DummyProblem, OQNLPParams::default()).unwrap().with_points(candidate_points);
    assert!(result.is_err(), "with_points should fail with out-of-bounds points");

    if let Err(OQNLPError::CustomPointOutOfBounds { index }) = result {
        assert_eq!(index, 1, "Should report the second point (index 1) as out of bounds");
    } else {
        panic!("Expected CustomPointOutOfBounds error");
    }
}
/// Points whose length differs from the problem dimension are rejected with a
/// dimension-mismatch error carrying the expected and received sizes.
#[test]
fn test_with_points_wrong_dimension() {
    // One 2-component point; the assertions below expect the problem to be 3-dimensional.
    let points = array![[1.0, 2.0]];
    let outcome = OQNLP::new(DummyProblem, OQNLPParams::default())
        .unwrap()
        .with_points(points);
    assert!(outcome.is_err(), "with_points should fail with wrong dimension");

    if let Err(OQNLPError::InvalidCustomPointsDimension { expected, got, .. }) = outcome {
        assert_eq!(expected, 3, "Expected 3 dimensions");
        assert_eq!(got, 2, "Got 2 dimensions");
    } else {
        panic!("Expected InvalidCustomPointsDimension error");
    }
}
/// End-to-end run seeded with two user-supplied points, (-4.5, -4.5, -4.5) and
/// (-4.8, -4.8, -4.8).
///
/// Seeds the optimizer via `with_points`, runs it, and checks that the best objective
/// found is below `-10.0`. NOTE(review): the threshold assumes DummyProblem's optimum
/// lies near the seeded points, as the assertion message states — confirm if
/// DummyProblem changes.
#[test]
fn test_with_points_integration() {
    let problem = DummyProblem;
    let params = OQNLPParams {
        iterations: 10,
        population_size: 20,
        seed: 42,
        ..Default::default()
    };
    let custom_points = array![[-4.5, -4.5, -4.5], [-4.8, -4.8, -4.8]];

    // Neither `problem` nor `params` is used after construction, so move them
    // instead of cloning.
    let mut oqnlp = OQNLP::new(problem, params)
        .unwrap()
        .with_points(custom_points)
        .unwrap()
        .verbose();

    let result = oqnlp.run();
    assert!(result.is_ok(), "Optimization should succeed with custom points");

    let sol_set = result.unwrap();
    assert!(!sol_set.is_empty(), "Should find solutions");

    let best = sol_set
        .best_solution()
        .expect("a non-empty solution set should have a best solution");
    assert!(
        best.objective < -10.0,
        "Best objective should be good with custom points near optimum, got {}",
        best.objective
    );
}
/// An empty (0 x 3) point array is accepted and stored as an empty collection.
#[test]
fn test_with_points_empty() {
    let no_points = Array2::<f64>::zeros((0, 3));
    let outcome = OQNLP::new(DummyProblem, OQNLPParams::default())
        .unwrap()
        .with_points(no_points);
    assert!(outcome.is_ok(), "with_points should accept empty array");

    let solver = outcome.unwrap();
    let stored = solver.custom_points.as_ref();
    assert!(stored.is_some(), "Custom points should be set");
    assert_eq!(stored.unwrap().len(), 0, "Should have 0 custom points");
}
/// Points lying exactly on the bounds, including a mixed lower/upper point, are accepted.
#[test]
fn test_with_points_at_bounds() {
    // The test treats ±5.0 as DummyProblem's bounds in each of 3 dimensions — TODO confirm.
    let boundary_points = array![[-5.0, -5.0, -5.0], [5.0, 5.0, 5.0], [-5.0, 0.0, 5.0]];
    let outcome = OQNLP::new(DummyProblem, OQNLPParams::default())
        .unwrap()
        .with_points(boundary_points);
    assert!(outcome.is_ok(), "with_points should accept points at bounds");

    let solver = outcome.unwrap();
    assert_eq!(solver.custom_points.as_ref().unwrap().len(), 3, "Should have 3 custom points");
}
/// `with_points` composes with the other builder methods without losing any setting.
#[test]
fn test_with_points_with_other_methods() {
    let params = OQNLPParams { iterations: 5, population_size: 15, ..Default::default() };
    let points = array![[1.0, 1.0, 1.0], [-1.0, -1.0, -1.0]];

    // The binding holds the configured solver itself (both `Result`s are unwrapped).
    let oqnlp = OQNLP::new(DummyProblem, params)
        .unwrap()
        .with_points(points)
        .unwrap()
        .verbose()
        .max_time(60.0)
        .target_objective(0.5)
        .exclude_out_of_bounds();

    assert!(oqnlp.custom_points.is_some(), "Custom points should be set");
    assert!(oqnlp.verbose, "Verbose should be set");
    assert_eq!(oqnlp.max_time, Some(60.0), "Max time should be set");
    assert_eq!(oqnlp.target_objective, Some(0.5), "Target objective should be set");
    assert!(oqnlp.exclude_out_of_bounds, "Exclude out of bounds should be set");
}
/// Seeds a 2-D `SixHumpCamel` run with five custom points and checks that it succeeds.
///
/// NOTE(review): despite its name, this test only asserts that optimization succeeds and
/// returns solutions; it does not inspect the reference set itself.
#[test]
fn test_with_points_affects_reference_set() {
    let params = OQNLPParams {
        iterations: 5,
        population_size: 15,
        seed: 123,
        ..Default::default()
    };
    let seeds = array![[0.0, 0.0], [1.0, 1.0], [-1.0, -1.0], [0.5, -0.5], [-0.5, 0.5]];

    let mut solver = OQNLP::new(SixHumpCamel, params)
        .unwrap()
        .with_points(seeds)
        .unwrap()
        .verbose();

    let outcome = solver.run();
    assert!(outcome.is_ok(), "Optimization should succeed");
    assert!(!outcome.unwrap().is_empty(), "Should find solutions");
}
}