use crate::error::OptimizeError;
use crate::unconstrained::convergence_diagnostics::{
ConvergenceDiagnostics, DiagnosticCollector, DiagnosticOptions, LineSearchDiagnostic,
};
use crate::unconstrained::OptimizeResult;
use scirs2_core::ndarray::{Array1, ArrayView1};
use std::cell::RefCell;
use std::rc::Rc;
pub type OptimizationCallback = Box<dyn FnMut(&CallbackInfo) -> CallbackResult>;
#[derive(Debug, Clone)]
pub struct CallbackInfo {
pub iteration: usize,
pub x: Array1<f64>,
pub f: f64,
pub grad: Array1<f64>,
pub step: Option<Array1<f64>>,
pub direction: Option<Array1<f64>>,
pub line_search: Option<LineSearchDiagnostic>,
pub elapsed_time: std::time::Duration,
}
#[derive(Debug, Clone, Copy)]
pub enum CallbackResult {
Continue,
Stop,
StopWithMessage(&'static str),
}
pub struct DiagnosticOptimizer {
collector: Rc<RefCell<DiagnosticCollector>>,
callbacks: Vec<OptimizationCallback>,
start_time: std::time::Instant,
}
impl DiagnosticOptimizer {
pub fn new(_diagnostic_options: DiagnosticOptions) -> Self {
Self {
collector: Rc::new(RefCell::new(DiagnosticCollector::new(_diagnostic_options))),
callbacks: Vec::new(),
start_time: std::time::Instant::now(),
}
}
pub fn add_callback(&mut self, callback: OptimizationCallback) {
self.callbacks.push(callback);
}
pub fn add_progress_callback(&mut self, every_n_nit: usize) {
let mut last_printed = 0;
self.add_callback(Box::new(move |info| {
if info.iteration >= last_printed + every_n_nit {
println!(
"Iteration {}: f = {:.6e}, |grad| = {:.6e}",
info.iteration,
info.f,
info.grad.mapv(|x| x.abs()).sum()
);
last_printed = info.iteration;
}
CallbackResult::Continue
}));
}
pub fn add_convergence_monitor(&mut self, patience: usize, min_improvement: f64) {
let mut best_f = f64::INFINITY;
let mut no_improvement_count = 0;
self.add_callback(Box::new(move |info| {
if info.f < best_f - min_improvement {
best_f = info.f;
no_improvement_count = 0;
} else {
no_improvement_count += 1;
}
if no_improvement_count >= patience {
CallbackResult::StopWithMessage("Early stopping: no _improvement")
} else {
CallbackResult::Continue
}
}));
}
pub fn add_time_limit(&mut self, max_duration: std::time::Duration) {
self.add_callback(Box::new(move |info| {
if info.elapsed_time > max_duration {
CallbackResult::StopWithMessage("Time limit exceeded")
} else {
CallbackResult::Continue
}
}));
}
pub fn process_iteration(&mut self, info: &CallbackInfo) -> CallbackResult {
if let (Some(step), Some(direction), Some(line_search)) =
(&info.step, &info.direction, &info.line_search)
{
self.collector.borrow_mut().record_iteration(
info.f,
&info.grad.view(),
&step.view(),
&direction.view(),
line_search.clone(),
);
}
for callback in &mut self.callbacks {
match callback(info) {
CallbackResult::Continue => continue,
result => return result,
}
}
CallbackResult::Continue
}
pub fn get_diagnostics(self) -> ConvergenceDiagnostics {
let collector = Rc::try_unwrap(self.collector)
.expect("Failed to unwrap Rc")
.into_inner();
collector.finalize()
}
}
#[allow(dead_code)]
pub fn optimize_with_diagnostics<F, O>(
optimizer_fn: O,
fun: F,
x0: Array1<f64>,
diagnostic_options: DiagnosticOptions,
callbacks: Vec<OptimizationCallback>,
) -> Result<(OptimizeResult<f64>, ConvergenceDiagnostics), OptimizeError>
where
F: FnMut(&ArrayView1<f64>) -> f64 + Clone,
O: FnOnce(
F,
Array1<f64>,
&mut DiagnosticOptimizer,
) -> Result<OptimizeResult<f64>, OptimizeError>,
{
let mut diagnostic_optimizer = DiagnosticOptimizer::new(diagnostic_options);
for callback in callbacks {
diagnostic_optimizer.add_callback(callback);
}
let result = optimizer_fn(fun, x0, &mut diagnostic_optimizer)?;
let diagnostics = diagnostic_optimizer.get_diagnostics();
Ok((result, diagnostics))
}
#[allow(dead_code)]
pub fn minimize_with_diagnostics<F>(
mut fun: F,
x0: Array1<f64>,
options: &crate::unconstrained::Options,
diagnostic_optimizer: &mut DiagnosticOptimizer,
) -> Result<OptimizeResult<f64>, OptimizeError>
where
F: FnMut(&ArrayView1<f64>) -> f64,
{
let mut x = x0.clone();
let mut f = fun(&x.view());
let mut iteration = 0;
loop {
let grad = finite_diff_gradient(&mut fun, &x.view(), 1e-8);
let direction = -&grad;
let alpha = 0.1;
let step = alpha * &direction;
let x_new = &x + &step;
let f_new = fun(&x_new.view());
let callback_info = CallbackInfo {
iteration,
x: x.clone(),
f,
grad: grad.clone(),
step: Some(step.clone()),
direction: Some(direction.clone()),
line_search: Some(LineSearchDiagnostic {
n_fev: 1,
n_gev: 1,
alpha,
alpha_init: 1.0,
success: f_new < f,
wolfe_satisfied: (true, true),
}),
elapsed_time: diagnostic_optimizer.start_time.elapsed(),
};
match diagnostic_optimizer.process_iteration(&callback_info) {
CallbackResult::Continue => {}
CallbackResult::Stop => break,
CallbackResult::StopWithMessage(msg) => {
return Ok(OptimizeResult {
x,
fun: f,
nit: iteration,
func_evals: iteration * 2,
nfev: iteration * 2,
success: false,
message: msg.to_string(),
jacobian: Some(grad),
hessian: None,
});
}
}
if grad.mapv(|x| x.abs()).sum() < options.gtol {
break;
}
x = x_new;
f = f_new;
iteration += 1;
if iteration >= options.max_iter {
break;
}
}
Ok(OptimizeResult {
x,
fun: f,
nit: iteration,
func_evals: iteration * 2,
nfev: iteration * 2,
success: iteration < options.max_iter,
message: if iteration < options.max_iter {
"Optimization converged".to_string()
} else {
"Maximum iterations reached".to_string()
},
jacobian: None,
hessian: None,
})
}
#[allow(dead_code)]
fn finite_diff_gradient<F>(fun: &mut F, x: &ArrayView1<f64>, eps: f64) -> Array1<f64>
where
F: FnMut(&ArrayView1<f64>) -> f64,
{
let n = x.len();
let mut grad = Array1::zeros(n);
let f0 = fun(x);
let mut x_pert = x.to_owned();
for i in 0..n {
let h = eps * (1.0 + x[i].abs());
x_pert[i] = x[i] + h;
let f_plus = fun(&x_pert.view());
grad[i] = (f_plus - f0) / h;
x_pert[i] = x[i];
}
grad
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_diagnostic_optimizer() {
let options = DiagnosticOptions::default();
let mut optimizer = DiagnosticOptimizer::new(options);
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
let callback_called = Arc::new(AtomicBool::new(false));
let callback_called_clone = callback_called.clone();
optimizer.add_callback(Box::new(move |_info| {
callback_called_clone.store(true, Ordering::SeqCst);
CallbackResult::Continue
}));
let info = CallbackInfo {
iteration: 0,
x: Array1::zeros(2),
f: 1.0,
grad: Array1::ones(2),
step: Some(Array1::from_vec(vec![0.1, 0.1])),
direction: Some(Array1::from_vec(vec![-1.0, -1.0])),
line_search: Some(LineSearchDiagnostic {
n_fev: 1,
n_gev: 1,
alpha: 1.0,
alpha_init: 1.0,
success: true,
wolfe_satisfied: (true, true),
}),
elapsed_time: std::time::Duration::from_secs(0),
};
let result = optimizer.process_iteration(&info);
assert!(matches!(result, CallbackResult::Continue));
assert!(callback_called.load(Ordering::SeqCst));
}
#[test]
fn test_early_stopping_callback() {
let options = DiagnosticOptions::default();
let mut optimizer = DiagnosticOptimizer::new(options);
optimizer.add_convergence_monitor(2, 0.1);
let info = CallbackInfo {
iteration: 0,
x: Array1::zeros(2),
f: 1.0,
grad: Array1::ones(2),
step: None,
direction: None,
line_search: None,
elapsed_time: std::time::Duration::from_secs(0),
};
assert!(matches!(
optimizer.process_iteration(&info),
CallbackResult::Continue
));
assert!(matches!(
optimizer.process_iteration(&info),
CallbackResult::Continue
));
let result = optimizer.process_iteration(&info);
assert!(matches!(result, CallbackResult::StopWithMessage(_)));
}
}