use std::sync::Arc;
use std::time::Instant;
mod stage1;
mod stage2;
pub use stage1::Stage1State;
pub use stage2::Stage2State;
/// Selects which optimization stages an [`Observer`] is allowed to report on.
///
/// The mode only takes effect together with per-stage tracking: a stage is
/// observed when the mode admits it *and* state for that stage has been enabled.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ObserverMode {
/// Observe Stage 1 only.
Stage1Only,
/// Observe Stage 2 only.
Stage2Only,
/// Observe both stages.
Both,
}
/// Reference-counted callback that receives mutable access to the [`Observer`].
///
/// The `Send + Sync` bounds let the callback be shared across threads.
pub type ObserverCallback = Arc<dyn Fn(&mut Observer) + Send + Sync>;
/// Snapshot of the Stage 2 metrics taken at a previous callback invocation.
///
/// Used to decide whether Stage 2 has changed since the last report, so that
/// repeated, identical progress lines can be suppressed.
#[derive(Debug, Clone, PartialEq)]
struct PreviousStage2State {
/// Value of `Stage2State::best_objective()` at capture time.
best_objective: f64,
/// Value of `Stage2State::solution_set_size()` at capture time.
solution_set_size: usize,
/// Value of `Stage2State::threshold_value()` at capture time.
threshold_value: f64,
/// Value of `Stage2State::local_solver_calls()` at capture time.
local_solver_calls: usize,
/// Value of `Stage2State::improved_local_calls()` at capture time.
improved_local_calls: usize,
/// Value of `Stage2State::function_evaluations()` at capture time.
function_evaluations: usize,
/// Value of `Stage2State::unchanged_cycles()` at capture time.
/// Captured for completeness; `has_changed` does not compare it.
unchanged_cycles: usize,
}
impl PreviousStage2State {
    /// Captures the current Stage 2 metrics into a new snapshot.
    fn from_stage2(stage2: &Stage2State) -> Self {
        let best_objective = stage2.best_objective();
        let solution_set_size = stage2.solution_set_size();
        let threshold_value = stage2.threshold_value();
        let local_solver_calls = stage2.local_solver_calls();
        let improved_local_calls = stage2.improved_local_calls();
        let function_evaluations = stage2.function_evaluations();
        let unchanged_cycles = stage2.unchanged_cycles();
        Self {
            best_objective,
            solution_set_size,
            threshold_value,
            local_solver_calls,
            improved_local_calls,
            function_evaluations,
            unchanged_cycles,
        }
    }

    /// Returns `true` when any compared metric of `stage2` differs from this snapshot.
    ///
    /// The iteration counter and `unchanged_cycles` are not part of the comparison.
    fn has_changed(&self, stage2: &Stage2State) -> bool {
        // Expressed as "not everything is equal"; for floats this is the exact
        // negation of the per-field `!=` chain, including NaN handling.
        let all_equal = self.best_objective == stage2.best_objective()
            && self.solution_set_size == stage2.solution_set_size()
            && self.threshold_value == stage2.threshold_value()
            && self.local_solver_calls == stage2.local_solver_calls()
            && self.improved_local_calls == stage2.improved_local_calls()
            && self.function_evaluations == stage2.function_evaluations();
        !all_equal
    }
}
/// Collects per-stage optimization state and optionally reports it through a callback.
///
/// Instances are configured with the consuming `with_*` builder methods.
pub struct Observer {
/// Which stages may be observed; combined with the presence of `stage1` / `stage2`.
mode: ObserverMode,
/// Stage 1 state; `Some` once Stage 1 tracking has been enabled.
stage1: Option<Stage1State>,
/// Stage 2 state; `Some` once Stage 2 tracking has been enabled.
stage2: Option<Stage2State>,
/// When `true`, `start_timer` records a start instant.
track_timing: bool,
/// Instant recorded by `start_timer`; `None` until the timer has been started.
start_time: Option<Instant>,
/// Callback run by `invoke_callback`, if any.
callback: Option<ObserverCallback>,
/// Iteration interval used by `should_invoke_callback`.
callback_frequency: usize,
/// Once set, `stage1()` stops exposing the Stage 1 state (`stage1_final()` still does).
stage1_completed: bool,
/// Until set, `stage2()` does not expose the Stage 2 state.
stage2_started: bool,
/// Last Stage 2 snapshot recorded by the default callback for change filtering.
previous_stage2_state: Option<std::sync::RwLock<PreviousStage2State>>,
/// Set by `unique_updates`; asks the default callback to skip unchanged Stage 2 reports.
filter_stage2_changes: bool,
}
impl std::fmt::Debug for Observer {
    /// Formats every field of the observer.
    ///
    /// The callback is an opaque closure, so only its presence is shown. The
    /// change-filter fields (`previous_stage2_state`, `filter_stage2_changes`)
    /// are included so the debug output reflects the complete configuration.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Observer")
            .field("mode", &self.mode)
            .field("stage1", &self.stage1)
            .field("stage2", &self.stage2)
            .field("track_timing", &self.track_timing)
            .field("start_time", &self.start_time)
            .field("callback", &self.callback.as_ref().map(|_| "Some(...)"))
            .field("callback_frequency", &self.callback_frequency)
            .field("stage1_completed", &self.stage1_completed)
            .field("stage2_started", &self.stage2_started)
            .field("previous_stage2_state", &self.previous_stage2_state)
            .field("filter_stage2_changes", &self.filter_stage2_changes)
            .finish()
    }
}
impl Clone for Observer {
    /// Produces an independent copy of the observer.
    ///
    /// The callback is shared (the `Arc` is cloned), while the change-filter
    /// snapshot is copied into a fresh lock so the clone does not share it.
    ///
    /// # Panics
    ///
    /// Panics if the snapshot lock is poisoned.
    fn clone(&self) -> Self {
        let snapshot_copy = match &self.previous_stage2_state {
            Some(lock) => {
                let guard = lock.read().unwrap();
                Some(std::sync::RwLock::new((*guard).clone()))
            }
            None => None,
        };

        Self {
            mode: self.mode,
            stage1: self.stage1.clone(),
            stage2: self.stage2.clone(),
            track_timing: self.track_timing,
            start_time: self.start_time,
            callback: self.callback.clone(),
            callback_frequency: self.callback_frequency,
            stage1_completed: self.stage1_completed,
            stage2_started: self.stage2_started,
            previous_stage2_state: snapshot_copy,
            filter_stage2_changes: self.filter_stage2_changes,
        }
    }
}
impl Observer {
    /// Creates an observer with no stage tracking, no timing and no callback.
    ///
    /// The mode defaults to [`ObserverMode::Both`] and the callback frequency to 1.
    pub fn new() -> Self {
        Self {
            mode: ObserverMode::Both,
            stage1: None,
            stage2: None,
            track_timing: false,
            start_time: None,
            callback: None,
            callback_frequency: 1,
            stage1_completed: false,
            stage2_started: false,
            previous_stage2_state: None,
            filter_stage2_changes: false,
        }
    }

    /// Enables Stage 1 tracking by attaching a fresh [`Stage1State`].
    pub fn with_stage1_tracking(mut self) -> Self {
        self.stage1 = Some(Stage1State::new());
        self
    }

    /// Enables Stage 2 tracking by attaching a fresh [`Stage2State`].
    pub fn with_stage2_tracking(mut self) -> Self {
        self.stage2 = Some(Stage2State::new());
        self
    }

    /// Enables timing, so that `start_timer` records a start instant.
    pub fn with_timing(mut self) -> Self {
        self.track_timing = true;
        self
    }

    /// Sets which stages may be observed.
    pub fn with_mode(mut self, mode: ObserverMode) -> Self {
        self.mode = mode;
        self
    }

    /// Installs `callback`, replacing any previously installed callback.
    pub fn with_callback<F>(mut self, callback: F) -> Self
    where
        F: Fn(&mut Observer) + Send + Sync + 'static,
    {
        self.callback = Some(Arc::new(callback));
        self
    }

    /// Sets the iteration interval at which the callback is considered due.
    ///
    /// If no callback has been installed yet, the default callback is installed.
    /// A frequency of `0` means the callback is never reported as due.
    pub fn with_callback_frequency(mut self, frequency: usize) -> Self {
        self.callback_frequency = frequency;
        if self.callback.is_none() {
            self.with_default_callback()
        } else {
            self
        }
    }

    /// Asks the default callback to print a Stage 2 line only when the tracked
    /// metrics changed since the previous report.
    ///
    /// Only the callback installed by [`Observer::with_default_callback`]
    /// consults this flag.
    pub fn unique_updates(mut self) -> Self {
        self.filter_stage2_changes = true;
        self
    }

    /// Installs a callback that prints Stage 1 and Stage 2 progress to stderr.
    ///
    /// Stage 1 lines are printed for the recognised substages; an unrecognised
    /// substage prints nothing for Stage 1 but no longer prevents the Stage 2
    /// line from being printed. Stage 2 lines are printed once the iteration
    /// counter is positive and, when [`Observer::unique_updates`] is enabled,
    /// only if the tracked metrics changed.
    pub fn with_default_callback(self) -> Self {
        self.with_callback(|obs| {
            if let Some(line) = obs.stage1().and_then(Self::stage1_progress_line) {
                eprintln!("{}", line);
            }
            if let Some(line) = obs.filtered_stage2_line() {
                eprintln!("{}", line);
            }
        })
    }

    /// Installs a callback that prints only Stage 1 progress to stderr.
    ///
    /// The lines are the same as the Stage 1 lines of the default callback.
    pub fn with_stage1_callback(self) -> Self {
        self.with_callback(|obs| {
            if let Some(line) = obs.stage1().and_then(Self::stage1_progress_line) {
                eprintln!("{}", line);
            }
        })
    }

    /// Installs a callback that prints only Stage 2 progress to stderr.
    ///
    /// A line is printed on every invocation once the iteration counter is
    /// positive; the `unique_updates` filter is not applied here.
    pub fn with_stage2_callback(self) -> Self {
        self.with_callback(|obs| {
            let active = obs
                .stage2()
                .filter(|state| state.current_iteration() > 0);
            if let Some(stage2) = active {
                eprintln!("{}", Self::stage2_progress_line(stage2));
            }
        })
    }

    /// Formats coordinates as `[x0, x1, ...]` with six decimal places each.
    fn format_coords(arr: &ndarray::Array1<f64>) -> String {
        let values: Vec<String> = arr.iter().map(|v| format!("{:.6}", v)).collect();
        format!("[{}]", values.join(", "))
    }

    /// Builds the Stage 1 progress line for the current substage.
    ///
    /// Returns `None` when the substage is not one of the recognised names.
    fn stage1_progress_line(stage1: &Stage1State) -> Option<String> {
        // Optional " at [..]" suffix, present only when a best point is known.
        let best_location = || match stage1.best_point() {
            Some(point) => format!(" at {}", Self::format_coords(point)),
            None => String::new(),
        };

        let substage = stage1.current_substage();
        let line = if substage == "scatter_search_running" {
            "[Stage 1] Starting Scatter Search...".to_string()
        } else if substage == "initialization_complete" {
            format!(
                "[Stage 1] Initialization Complete | Initial Points: {}",
                stage1.function_evaluations()
            )
        } else if substage == "diversification_complete" {
            format!(
                "[Stage 1] Diversification Complete | Ref. Set Size: {}",
                stage1.reference_set_size()
            )
        } else if substage == "intensification_complete" {
            format!(
                "[Stage 1] Intensification Complete | Trial Points Generated: {} | Accepted: {}",
                stage1.trial_points_generated(),
                stage1.reference_set_size()
            )
        } else if substage == "scatter_search_complete" {
            format!(
                "[Stage 1] Scatter Search Complete | Best: {:.6}{}",
                stage1.best_objective(),
                best_location()
            )
        } else if substage == "local_optimization_complete" {
            format!(
                "[Stage 1] Local Optimization Complete | Best: {:.6}{} | Total Fn Evals: {}",
                stage1.best_objective(),
                best_location(),
                stage1.function_evaluations()
            )
        } else {
            return None;
        };
        Some(line)
    }

    /// Builds the Stage 2 progress line from the current Stage 2 metrics.
    fn stage2_progress_line(stage2: &Stage2State) -> String {
        // Optional " at [..]" suffix, present only when a point was added.
        let location = stage2
            .last_added_point()
            .map(|point| format!(" at {}", Self::format_coords(point)))
            .unwrap_or_default();
        format!(
            "[Stage 2] Iter {} | Best: {:.6}{} | Solutions: {} | Threshold: {:.6} | Local Calls: {} | Fn Evals: {}",
            stage2.current_iteration(),
            stage2.best_objective(),
            location,
            stage2.solution_set_size(),
            stage2.threshold_value(),
            stage2.local_solver_calls(),
            stage2.function_evaluations()
        )
    }

    /// Returns the Stage 2 line the default callback should print, if any.
    ///
    /// Nothing is returned while Stage 2 is not visible or its iteration counter
    /// is zero. With change filtering enabled, the snapshot is refreshed on every
    /// call and a line is returned only when the metrics differ from the previous
    /// snapshot (or when there is no previous snapshot).
    ///
    /// # Panics
    ///
    /// Panics if the snapshot lock is poisoned.
    fn filtered_stage2_line(&mut self) -> Option<String> {
        // Everything that borrows the Stage 2 state is computed inside this
        // block so the snapshot field can be overwritten afterwards.
        let (snapshot, line) = {
            let stage2 = self
                .stage2()
                .filter(|state| state.current_iteration() > 0)?;
            if !self.filter_stage2_changes {
                return Some(Self::stage2_progress_line(stage2));
            }
            let changed = self.previous_stage2_state.as_ref().is_none_or(|previous| {
                previous
                    .read()
                    .expect("stage 2 snapshot lock poisoned")
                    .has_changed(stage2)
            });
            (
                PreviousStage2State::from_stage2(stage2),
                changed.then(|| Self::stage2_progress_line(stage2)),
            )
        };
        self.previous_stage2_state = Some(std::sync::RwLock::new(snapshot));
        line
    }

    /// Records the current instant as the start time when timing is enabled.
    pub(crate) fn start_timer(&mut self) {
        if self.track_timing {
            self.start_time = Some(Instant::now());
        }
    }

    /// Seconds elapsed since the timer was started, or `None` if it never was.
    pub fn elapsed_time(&self) -> Option<f64> {
        self.start_time.map(|start| start.elapsed().as_secs_f64())
    }

    /// Whether Stage 1 should be observed: the mode admits it and tracking is enabled.
    pub fn should_observe_stage1(&self) -> bool {
        matches!(self.mode, ObserverMode::Stage1Only | ObserverMode::Both) && self.stage1.is_some()
    }

    /// Whether Stage 2 should be observed: the mode admits it and tracking is enabled.
    pub fn should_observe_stage2(&self) -> bool {
        matches!(self.mode, ObserverMode::Stage2Only | ObserverMode::Both) && self.stage2.is_some()
    }

    /// Stage 1 state while Stage 1 is still running; `None` once it is marked complete.
    pub fn stage1(&self) -> Option<&Stage1State> {
        if self.stage1_completed {
            None
        } else {
            self.stage1.as_ref()
        }
    }

    /// Stage 1 state regardless of whether Stage 1 has been marked complete.
    pub fn stage1_final(&self) -> Option<&Stage1State> {
        self.stage1.as_ref()
    }

    /// Mutable access to the Stage 1 state, if tracking is enabled.
    pub(crate) fn stage1_mut(&mut self) -> Option<&mut Stage1State> {
        self.stage1.as_mut()
    }

    /// Marks Stage 1 as complete, hiding its state from `stage1()`.
    pub(crate) fn mark_stage1_complete(&mut self) {
        self.stage1_completed = true;
    }

    /// Stage 2 state once Stage 2 has been marked as started; `None` before that.
    pub fn stage2(&self) -> Option<&Stage2State> {
        if self.stage2_started {
            self.stage2.as_ref()
        } else {
            None
        }
    }

    /// Mutable access to the Stage 2 state, if tracking is enabled.
    pub(crate) fn stage2_mut(&mut self) -> Option<&mut Stage2State> {
        self.stage2.as_mut()
    }

    /// Marks Stage 2 as started, exposing its state through `stage2()`.
    pub(crate) fn mark_stage2_started(&mut self) {
        self.stage2_started = true;
    }

    /// Whether timing was enabled with `with_timing`.
    pub fn is_timing_enabled(&self) -> bool {
        self.track_timing
    }

    /// Runs the installed callback, if any, with mutable access to this observer.
    pub(crate) fn invoke_callback(&mut self) {
        // Clone the `Arc` first so the callback can borrow `self` mutably.
        if let Some(callback) = self.callback.as_ref().map(Arc::clone) {
            callback(self);
        }
    }

    /// Whether the callback is due at `iteration`.
    ///
    /// True when a callback is installed and `iteration` is a multiple of the
    /// callback frequency. A frequency of `0` never matches (previously this
    /// panicked with a remainder-by-zero).
    pub(crate) fn should_invoke_callback(&self, iteration: usize) -> bool {
        self.callback.is_some() && iteration.checked_rem(self.callback_frequency) == Some(0)
    }
}
impl Default for Observer {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests_observers {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Builds an observer that tracks both stages under the given mode.
    fn tracked_in_mode(mode: ObserverMode) -> Observer {
        Observer::new()
            .with_mode(mode)
            .with_stage1_tracking()
            .with_stage2_tracking()
    }

    /// Asserts which stages the observer reports as observable.
    fn assert_observes(observer: &Observer, stage1: bool, stage2: bool) {
        assert_eq!(observer.should_observe_stage1(), stage1);
        assert_eq!(observer.should_observe_stage2(), stage2);
    }

    /// Checks the stage visibility implied by each [`ObserverMode`].
    fn assert_mode_visibility() {
        assert_observes(&tracked_in_mode(ObserverMode::Stage1Only), true, false);
        assert_observes(&tracked_in_mode(ObserverMode::Stage2Only), false, true);
        assert_observes(&tracked_in_mode(ObserverMode::Both), true, true);
    }

    /// A fresh observer observes nothing and does not time.
    #[test]
    fn test_observer_creation() {
        let fresh = Observer::new();
        assert_observes(&fresh, false, false);
        assert!(!fresh.is_timing_enabled());
    }

    /// Stage 1 tracking alone makes only Stage 1 observable.
    #[test]
    fn test_observer_with_stage1() {
        assert_observes(&Observer::new().with_stage1_tracking(), true, false);
    }

    /// Stage 2 tracking alone makes only Stage 2 observable.
    #[test]
    fn test_observer_with_stage2() {
        assert_observes(&Observer::new().with_stage2_tracking(), false, true);
    }

    /// Tracking both stages makes both observable.
    #[test]
    fn test_observer_with_both_stages() {
        let both = Observer::new().with_stage1_tracking().with_stage2_tracking();
        assert_observes(&both, true, true);
    }

    /// `with_timing` turns the timing flag on.
    #[test]
    fn test_observer_with_timing() {
        assert!(Observer::new().with_timing().is_timing_enabled());
    }

    /// Each mode restricts observation to the stages it names.
    #[test]
    fn test_observer_modes() {
        assert_mode_visibility();
    }

    /// `stage1()` hides the state after completion; `stage1_final()` does not.
    #[test]
    fn test_observer_stage1_state_access() {
        let mut tracked = Observer::new().with_stage1_tracking();
        assert!(tracked.stage1().is_some());

        tracked.mark_stage1_complete();
        assert!(tracked.stage1().is_none());
        assert!(tracked.stage1_final().is_some());
    }

    /// `stage2()` exposes the state only after Stage 2 has started.
    #[test]
    fn test_observer_stage2_state_access() {
        let mut tracked = Observer::new().with_stage2_tracking();
        assert!(tracked.stage2().is_none());

        tracked.mark_stage2_started();
        assert!(tracked.stage2().is_some());
    }

    /// The elapsed time is available only after the timer was started.
    #[test]
    fn test_observer_timing() {
        let mut timed = Observer::new().with_timing();
        assert!(timed.elapsed_time().is_none());

        timed.start_timer();
        std::thread::sleep(std::time::Duration::from_millis(10));

        let seconds = timed.elapsed_time();
        assert!(seconds.is_some());
        assert!(seconds.unwrap() > 0.0);
    }

    /// `invoke_callback` runs the installed callback exactly once per call.
    #[test]
    fn test_observer_callbacks() {
        let hits = Arc::new(Mutex::new(0));
        let hits_in_callback = Arc::clone(&hits);
        let mut observer = Observer::new().with_callback(move |_| {
            *hits_in_callback.lock().unwrap() += 1;
        });

        observer.invoke_callback();

        assert_eq!(*hits.lock().unwrap(), 1);
    }

    /// With frequency 3 the callback is due on multiples of three only.
    #[test]
    fn test_observer_callback_frequency() {
        let hits = Arc::new(Mutex::new(0));
        let hits_in_callback = Arc::clone(&hits);
        let observer = Observer::new()
            .with_callback_frequency(3)
            .with_callback(move |_| {
                *hits_in_callback.lock().unwrap() += 1;
            });

        let expectations = [
            (1, false),
            (2, false),
            (3, true),
            (4, false),
            (5, false),
            (6, true),
        ];
        for (iteration, due) in expectations {
            assert_eq!(observer.should_invoke_callback(iteration), due);
        }
    }

    /// The built-in callbacks can be installed without panicking.
    #[test]
    fn test_observer_default_callbacks() {
        let _default = Observer::new().with_default_callback();
        let _stage1_only = Observer::new().with_stage1_callback();
        let _stage2_only = Observer::new().with_stage2_callback();
    }

    /// A clone keeps the mode, tracking and timing configuration.
    #[test]
    fn test_observer_clone_behavior() {
        let original = Observer::new()
            .with_stage1_tracking()
            .with_stage2_tracking()
            .with_timing()
            .with_mode(ObserverMode::Stage1Only);

        let copy = original.clone();

        assert_observes(&copy, true, false);
        assert!(copy.is_timing_enabled());
    }

    /// `Default` yields the same configuration as `new`.
    #[test]
    fn test_observer_default_implementation() {
        let defaulted = Observer::default();
        assert_observes(&defaulted, false, false);
        assert!(!defaulted.is_timing_enabled());
    }

    /// Changes made through `stage1_mut` are visible through `stage1`.
    #[test]
    fn test_observer_stage1_mut_access() {
        let mut observer = Observer::new().with_stage1_tracking();

        let writable = observer.stage1_mut().unwrap();
        writable.set_reference_set_size(10);
        writable.set_best_objective(5.0);

        let readable = observer.stage1().unwrap();
        assert_eq!(readable.reference_set_size(), 10);
        assert_eq!(readable.best_objective(), 5.0);
    }

    /// Changes made through `stage2_mut` are visible through `stage2`.
    #[test]
    fn test_observer_stage2_mut_access() {
        let mut observer = Observer::new().with_stage2_tracking();
        observer.mark_stage2_started();

        let writable = observer.stage2_mut().unwrap();
        writable.set_iteration(5);
        writable.set_best_objective(3.0);

        let readable = observer.stage2().unwrap();
        assert_eq!(readable.current_iteration(), 5);
        assert_eq!(readable.best_objective(), 3.0);
    }

    /// The mode restricts observation even when both stages are tracked.
    #[test]
    fn test_observer_mode_restrictions() {
        assert_mode_visibility();
    }

    /// Setting a frequency without a callback installs the default callback.
    #[test]
    fn test_observer_callback_with_frequency() {
        let configured = Observer::new().with_callback_frequency(5);
        assert!(configured.callback.is_some());
    }

    /// Stage visibility follows the complete/started markers.
    #[test]
    fn test_observer_stage_transitions() {
        let mut observer = Observer::new().with_stage1_tracking().with_stage2_tracking();
        assert!(observer.stage1().is_some());
        assert!(observer.stage2().is_none());

        observer.mark_stage1_complete();
        assert!(observer.stage1().is_none());
        assert!(observer.stage1_final().is_some());

        observer.mark_stage2_started();
        assert!(observer.stage2().is_some());
    }

    /// Without tracking, every state accessor returns `None`.
    #[test]
    fn test_observer_without_tracking() {
        let mut untracked = Observer::new();
        assert_observes(&untracked, false, false);
        assert!(untracked.stage1().is_none());
        assert!(untracked.stage1_final().is_none());
        assert!(untracked.stage2().is_none());
        assert!(untracked.stage1_mut().is_none());
        assert!(untracked.stage2_mut().is_none());
    }

    /// Runs the optimizer on a small quadratic problem with an observer attached.
    #[test]
    fn test_observer_with_simple_optimization_problem() {
        use crate::local_solver::builders::COBYLABuilder;
        use crate::oqnlp::OQNLP;
        use crate::problem::Problem;
        use crate::types::{EvaluationError, LocalSolverType, OQNLPParams};
        use ndarray::{Array1, Array2};

        /// Sum of squares over `dimension` variables, each bounded to [-4, 4].
        #[derive(Debug, Clone)]
        struct QuadraticSum {
            dimension: usize,
        }

        impl QuadraticSum {
            fn new(dimension: usize) -> Self {
                Self { dimension }
            }
        }

        impl Problem for QuadraticSum {
            fn objective(&self, x: &Array1<f64>) -> Result<f64, EvaluationError> {
                Ok(x.iter().map(|xi| xi * xi).sum())
            }

            fn variable_bounds(&self) -> Array2<f64> {
                let mut bounds = Array2::zeros((self.dimension, 2));
                for row in 0..self.dimension {
                    bounds[[row, 0]] = -4.0;
                    bounds[[row, 1]] = 4.0;
                }
                bounds
            }
        }

        let params = OQNLPParams {
            iterations: 100,
            wait_cycle: 5,
            threshold_factor: 0.5,
            distance_factor: 0.5,
            population_size: 150,
            local_solver_type: LocalSolverType::COBYLA,
            local_solver_config: COBYLABuilder::default().max_iter(25).build(),
            seed: 0,
        };
        let attached = Observer::new()
            .with_stage1_tracking()
            .with_stage2_tracking()
            .with_timing();

        let mut optimizer = OQNLP::new(QuadraticSum::new(2), params)
            .unwrap()
            .add_observer(attached);
        let solutions = optimizer.run().unwrap();
        let observer = optimizer.observer().unwrap();

        if let Some(stage1) = observer.stage1_final() {
            assert!(stage1.function_evaluations() > 0);
            assert!(stage1.reference_set_size() > 0);
            assert!(!stage1.best_objective().is_nan());
            assert!(stage1.best_objective() >= 0.0);
            assert!(stage1.trial_points_generated() > 0);
            if let Some(time) = stage1.total_time() {
                assert!(time > 0.0);
            }
        }

        match observer.stage2() {
            Some(stage2) => {
                println!("Stage 2 ran with {} function evaluations", stage2.function_evaluations());
                assert!(!stage2.best_objective().is_nan());
                assert!(stage2.best_objective() >= 0.0);
                assert!(stage2.threshold_value() >= 0.0);
                if let Some(time) = stage2.total_time() {
                    assert!(time >= 0.0);
                }
            }
            None => println!("Stage 2 did not run"),
        }

        let best = solutions.best_solution().unwrap();
        let best_objective = best.objective;
        assert!(best_objective >= 0.0);
        assert!(best_objective < 1e-3);

        println!("Optimization test completed successfully!");
        println!("Best objective found: {:.6}", best_objective);
        println!("Solution: {:?}", best.point);
    }

    /// With change filtering, only iterations that alter a tracked metric are reported.
    #[test]
    fn test_observer_stage2_unique_updates() {
        use std::sync::{Arc, Mutex};

        /// Applies `update` to the Stage 2 state, then fires the callback if it is due.
        fn advance(observer: &mut Observer, iteration: usize, update: impl FnOnce(&mut Stage2State)) {
            update(observer.stage2_mut().unwrap());
            if observer.should_invoke_callback(iteration) {
                observer.invoke_callback();
            }
        }

        let messages = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&messages);

        let mut observer = Observer::new()
            .with_stage2_tracking()
            .unique_updates()
            .with_callback_frequency(1)
            .with_callback(move |obs| {
                let Some(stage2) = obs.stage2() else {
                    return;
                };
                if stage2.current_iteration() == 0 {
                    return;
                }

                // Read everything needed from `stage2` before touching `obs` mutably.
                let iteration = stage2.current_iteration();
                let best = stage2.best_objective();
                let solutions = stage2.solution_set_size();
                let threshold = stage2.threshold_value();

                let report = if obs.filter_stage2_changes {
                    let changed = match obs.previous_stage2_state.as_ref() {
                        Some(lock) => lock.read().unwrap().has_changed(stage2),
                        None => true,
                    };
                    let snapshot = PreviousStage2State::from_stage2(stage2);
                    obs.previous_stage2_state = Some(std::sync::RwLock::new(snapshot));
                    changed
                } else {
                    true
                };

                if report {
                    let line = format!(
                        "[Stage 2] Iter {} | Best: {:.6} | Solutions: {} | Threshold: {:.6}",
                        iteration, best, solutions, threshold
                    );
                    sink.lock().unwrap().push(line);
                }
            });

        observer.mark_stage2_started();

        // Iteration 1: first report, always emitted.
        advance(&mut observer, 1, |stage2| {
            stage2.set_iteration(1);
            stage2.set_best_objective(10.0);
            stage2.set_solution_set_size(5);
            stage2.set_threshold_value(1.0);
        });
        // Iteration 2: only the counter moves, so nothing is emitted.
        advance(&mut observer, 2, |stage2| {
            stage2.set_iteration(2);
        });
        // Iteration 3: the best objective improves.
        advance(&mut observer, 3, |stage2| {
            stage2.set_iteration(3);
            stage2.set_best_objective(8.0);
        });
        // Iteration 4: only the counter moves again.
        advance(&mut observer, 4, |stage2| {
            stage2.set_iteration(4);
        });
        // Iteration 5: the solution set grows.
        advance(&mut observer, 5, |stage2| {
            stage2.set_iteration(5);
            stage2.set_solution_set_size(6);
        });

        let captured_messages = messages.lock().unwrap();
        println!("Captured {} messages:", captured_messages.len());
        captured_messages
            .iter()
            .for_each(|msg| println!(" {}", msg));

        assert_eq!(captured_messages.len(), 3, "Should have 3 messages (iterations 1, 3, and 5)");
        for (index, expected) in ["Iter 1", "Iter 3", "Iter 5"].into_iter().enumerate() {
            assert!(captured_messages[index].contains(expected));
        }
    }
}