use crate::{
core::{Graph, Values},
dtype,
};
/// Reasons an [`Optimizer`] run can fail.
#[derive(Debug)]
pub enum OptError {
    /// No stopping criterion was met within `max_iterations` steps.
    ///
    /// Carries the latest values so callers can still inspect or use them.
    MaxIterations(Values),
    /// NOTE(review): not produced in this module; presumably returned by optimizer
    /// implementations when the system they build cannot be used — confirm.
    InvalidSystem,
    /// NOTE(review): not produced in this module; presumably returned by `step`
    /// implementations that cannot make progress — confirm.
    FailedToStep,
}

impl std::fmt::Display for OptError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            OptError::MaxIterations(_) => {
                f.write_str("reached the maximum number of iterations without converging")
            }
            OptError::InvalidSystem => f.write_str("invalid system"),
            OptError::FailedToStep => f.write_str("failed to step"),
        }
    }
}

// Lets `OptError` be boxed as `dyn Error` and propagated with `?` into generic error types.
impl std::error::Error for OptError {}
/// Result type returned by optimizer operations; the error side is always [`OptError`].
pub type OptResult<T> = std::result::Result<T, OptError>;
/// Parameter set accepted by an optimizer.
///
/// Implementors may carry optimizer-specific settings, but must always expose the
/// shared [`BaseOptParams`] (iteration limit and tolerances).
pub trait OptParams: Clone + Default {
    /// Borrows the shared base parameters embedded in this set.
    fn base_params(&self) -> &BaseOptParams;
}
/// Iteration limit and convergence tolerances used by the default optimization loop.
///
/// Build one with [`Default`] and override fields with struct-update syntax, e.g.
/// `BaseOptParams { max_iterations: 50, ..Default::default() }`.
#[derive(Clone, Debug)]
pub struct BaseOptParams {
    /// Maximum number of steps before giving up with `OptError::MaxIterations`.
    pub max_iterations: usize,
    /// Stop once the error decreases by a non-negative fraction (of the previous
    /// error) no larger than this.
    pub error_tol_relative: dtype,
    /// Stop once the error decreases by a non-negative amount no larger than this.
    pub error_tol_absolute: dtype,
    /// Stop (or skip optimizing entirely) once the error is at or below this value.
    pub error_tol: dtype,
}

impl Default for BaseOptParams {
    /// Sets 100 iterations, `1e-6` for both the relative and absolute decrease
    /// tolerances, and `0.0` for the error threshold.
    fn default() -> Self {
        BaseOptParams {
            error_tol: 0.0,
            error_tol_absolute: 1e-6,
            error_tol_relative: 1e-6,
            max_iterations: 100,
        }
    }
}

impl OptParams for BaseOptParams {
    /// A bare `BaseOptParams` is its own base parameter set.
    fn base_params(&self) -> &Self {
        self
    }
}
/// Callback invoked after each optimizer step.
pub trait OptObserver {
    /// Receives the values after a step and that step's index (as `i64`).
    fn on_step(&self, values: &Values, time: i64);
}

/// Ordered list of boxed [`OptObserver`]s that are notified together.
#[derive(Default)]
pub struct OptObserverVec {
    observers: Vec<Box<dyn OptObserver>>,
}

impl OptObserverVec {
    /// Boxes `callback` and appends it; observers are notified in insertion order.
    pub fn add(&mut self, callback: impl OptObserver + 'static) {
        self.observers.push(Box::new(callback));
    }

    /// Calls [`OptObserver::on_step`] on every registered observer.
    pub fn notify(&self, values: &Values, idx: usize) {
        // Plain `as` cast: an iteration index never approaches `i64::MAX` in practice.
        let time = idx as i64;
        self.observers
            .iter()
            .for_each(|observer| observer.on_step(values, time));
    }
}
/// Common interface for iterative optimizers over a [`Graph`].
///
/// Implementors supply the per-iteration [`step`](Optimizer::step) and the scalar
/// [`error`](Optimizer::error); the provided [`optimize`](Optimizer::optimize) method
/// drives the loop, logs a progress table, notifies observers and checks the stopping
/// criteria from [`BaseOptParams`].
pub trait Optimizer {
    /// Optimizer-specific parameter type; it must expose [`BaseOptParams`].
    type Params: OptParams
    where
        Self: Sized;

    /// Builds the optimizer from its parameters and a graph.
    fn new(params: Self::Params, graph: Graph) -> Self
    where
        Self: Sized;

    /// Observers notified after every step of [`optimize`](Optimizer::optimize).
    fn observers(&self) -> &OptObserverVec;

    /// Mutable access to the observers (used by [`add_observer`](Optimizer::add_observer)).
    fn observers_mut(&mut self) -> &mut OptObserverVec;

    /// Borrows the optimizer's graph.
    fn graph(&self) -> &Graph;

    /// Mutably borrows the optimizer's graph.
    fn graph_mut(&mut self) -> &mut Graph;

    /// Iteration limit and tolerances read by [`optimize`](Optimizer::optimize).
    fn params(&self) -> &BaseOptParams;

    /// Performs iteration `idx` (1-based) starting from `values`.
    ///
    /// Returns the updated values and a string that `optimize` logs as the last column
    /// of that iteration's progress row. Presumably it should be formatted to line up
    /// with the column names returned by [`init`](Optimizer::init).
    ///
    /// # Errors
    ///
    /// Any error is propagated unchanged out of `optimize`.
    fn step(&mut self, values: Values, idx: usize) -> OptResult<(Values, String)>;

    /// Scalar error of `values`; `optimize` treats smaller as better.
    fn error(&self, values: &Values) -> dtype;

    /// Hook called once at the start of `optimize`.
    ///
    /// Returns the names of extra columns appended to the progress-table header. The
    /// default adds none.
    fn init(&mut self, _values: &Values) -> Vec<&'static str> {
        Vec::new()
    }

    /// Runs the optimization loop starting from `values`.
    ///
    /// Calls [`init`](Optimizer::init) once. If the initial error is already at or
    /// below `error_tol`, `values` is returned without stepping. Otherwise the method
    /// repeatedly:
    ///
    /// - calls [`step`](Optimizer::step);
    /// - notifies every observer;
    /// - re-evaluates [`error`](Optimizer::error).
    ///
    /// After each step it returns `Ok` as soon as one of these holds, checked in this
    /// order:
    ///
    /// 1. the error is at or below `error_tol`;
    /// 2. the error decreased by a non-negative amount no larger than
    ///    `error_tol_absolute`;
    /// 3. the error decreased by a non-negative fraction of the previous error no
    ///    larger than `error_tol_relative`.
    ///
    /// Tolerances are re-read through [`params`](Optimizer::params) after every step.
    ///
    /// # Errors
    ///
    /// - Propagates any error returned by `step`.
    /// - Returns [`OptError::MaxIterations`] carrying the latest values if no criterion
    ///   is met within `max_iterations` steps.
    fn optimize(&mut self, mut values: Values) -> OptResult<Values> {
        let append = self.init(&values);
        let mut error_old = self.error(&values);
        if error_old <= self.params().error_tol {
            log::info!("Error is already below tolerance, skipping optimization");
            return Ok(values);
        }
        log_table_header(&append, error_old);

        let mut error_new = error_old;
        // Inclusive range so that `max_iterations == usize::MAX` cannot overflow.
        for i in 1..=self.params().max_iterations {
            error_old = error_new;
            let (stepped, info) = self.step(values, i)?;
            values = stepped;
            self.observers().notify(&values, i);
            error_new = self.error(&values);

            // Positive when the error went down. An increase is negative and is never
            // treated as convergence by the checks below.
            let error_decrease_abs = error_old - error_new;
            // `error_old` exceeded `error_tol` at the previous check, so this stays
            // finite whenever `error_tol` is non-negative.
            let error_decrease_rel = error_decrease_abs / error_old;
            log::info!(
                "{i:^5} | {error_new:^12.4e} | {error_decrease_abs:^12.4e} | {error_decrease_rel:^12.4e} | {info}"
            );

            // Params are re-read each time because `step` has `&mut self` and may change them.
            if error_new <= self.params().error_tol {
                log::info!("Error is below tolerance, stopping optimization");
                return Ok(values);
            }
            if error_decrease_abs >= 0.0 && error_decrease_abs <= self.params().error_tol_absolute {
                log::info!("Error decrease is below absolute tolerance, stopping optimization");
                return Ok(values);
            }
            if error_decrease_rel >= 0.0 && error_decrease_rel <= self.params().error_tol_relative {
                log::info!("Error decrease is below relative tolerance, stopping optimization");
                return Ok(values);
            }
        }
        Err(OptError::MaxIterations(values))
    }

    /// Registers `observer` so it is notified after every step of `optimize`.
    fn add_observer(&mut self, observer: impl OptObserver + 'static)
    where
        Self: Sized,
    {
        self.observers_mut().add(observer);
    }

    /// Builds the optimizer with `Self::Params::default()`.
    fn new_default(graph: Graph) -> Self
    where
        Self: Sized,
    {
        Self::new(Self::Params::default(), graph)
    }
}

/// Logs the first three rows of the progress table printed by [`Optimizer::optimize`].
///
/// The rows are the column headers, a dashed separator, and the iteration-0 row showing
/// `initial_error`. `columns` are the extra column names returned by [`Optimizer::init`].
fn log_table_header(columns: &[&str], initial_error: dtype) {
    // Extra columns get a trailing separator so the table stays closed on the right.
    let extra = if columns.is_empty() { "" } else { " |" };
    log::info!(
        "{:^5} | {:^12} | {:^12} | {:^12} | {}",
        "Iter",
        "Error",
        "ErrorAbs",
        "ErrorRel",
        columns.join(" | ") + extra,
    );
    log::info!(
        "{:^5} | {:^12} | {:^12} | {:^12} | {}",
        "-----",
        "------------",
        "------------",
        "------------",
        columns
            .iter()
            .map(|s| "-".repeat(s.len()))
            .collect::<Vec<_>>()
            .join(" | ")
            + extra
    );
    log::info!(
        "{:^5} | {:^12.4e} | {:^12} | {:^12} | {}",
        0,
        initial_error,
        "-",
        "-",
        columns
            .iter()
            .map(|s| format!("{:^width$}", "-", width = s.len()))
            .collect::<Vec<_>>()
            .join(" | ")
            + extra
    );
}