use crate::{
core::{Bounds, Callbacks},
traits::{Bound, Status, Transform, TransformExt},
};
use std::convert::Infallible;
/// An iterative algorithm that operates on a problem of type `P`.
///
/// Type parameters:
/// - `P`: the problem, borrowed immutably by every hook.
/// - `S`: the [`Status`] threaded through a run; [`process`](Algorithm::process)
///   starts from `S::default()`.
/// - `U`: extra user arguments forwarded unchanged to every hook (default `()`).
/// - `E`: the error type of every fallible hook (default [`Infallible`]).
pub trait Algorithm<P, S: Status, U = (), E = Infallible>: Send + Sync {
    /// Value produced by [`summarize`](Algorithm::summarize) at the end of a run.
    type Summary;
    /// Algorithm-specific configuration, passed by reference to every hook.
    type Config;
    /// Starting data handed to [`initialize`](Algorithm::initialize) and
    /// [`summarize`](Algorithm::summarize).
    type Init;

    /// Prepares `self` and `status` before the first [`step`](Algorithm::step).
    ///
    /// # Errors
    /// An `Err` here makes [`process`](Algorithm::process) return before any step runs.
    fn initialize(
        &mut self,
        problem: &P,
        status: &mut S,
        args: &U,
        init: &Self::Init,
        config: &Self::Config,
    ) -> Result<(), E>;

    /// Performs one iteration; `current_step` is the zero-based iteration index.
    ///
    /// # Errors
    /// An `Err` here makes [`process`](Algorithm::process) return immediately;
    /// neither `postprocessing` nor `summarize` is run.
    fn step(
        &mut self,
        current_step: usize,
        problem: &P,
        status: &mut S,
        args: &U,
        config: &Self::Config,
    ) -> Result<(), E>;

    /// Hook run once after the step loop terminates. The default does nothing.
    ///
    /// # Errors
    /// An `Err` here is returned by [`process`](Algorithm::process) instead of a summary.
    // The default body ignores every argument; the allowance keeps the
    // descriptive parameter names without triggering `unused_variables`.
    #[allow(unused_variables)]
    fn postprocessing(
        &mut self,
        problem: &P,
        status: &mut S,
        args: &U,
        config: &Self::Config,
    ) -> Result<(), E> {
        Ok(())
    }

    /// Builds the run summary from the final `status`.
    ///
    /// When called from [`process`](Algorithm::process), `current_step` is the
    /// zero-based index of the last step that was executed.
    ///
    /// # Errors
    /// Implementations may fail with `E`; `process` propagates it.
    fn summarize(
        &self,
        current_step: usize,
        problem: &P,
        status: &S,
        args: &U,
        init: &Self::Init,
        config: &Self::Config,
    ) -> Result<Self::Summary, E>;

    /// Clears internal state so the algorithm can be reused. The default does nothing.
    ///
    /// Note: [`process`](Algorithm::process) does not call this itself.
    fn reset(&mut self) {}

    /// Runs a complete optimization/solve.
    ///
    /// The sequence is: `initialize`, then repeatedly `step` followed by a
    /// callback termination check, then `postprocessing` and `summarize`. At
    /// least one step always runs. Termination is decided solely by
    /// `callbacks.check_for_termination` returning a break.
    // NOTE(review): if no callback ever breaks, this loops forever — confirm that
    // `Callbacks::empty()` (used by the default `default_callbacks`) or each
    // implementor's override guarantees termination.
    ///
    /// # Errors
    /// Returns the first `Err` produced by any hook.
    fn process<C>(
        &mut self,
        problem: &P,
        args: &U,
        init: Self::Init,
        config: Self::Config,
        callbacks: C,
    ) -> Result<Self::Summary, E>
    where
        C: Into<Callbacks<Self, P, S, U, E, Self::Config>>,
        Self: Sized,
    {
        let mut status = S::default();
        let mut cbs: Callbacks<Self, P, S, U, E, Self::Config> = callbacks.into();
        self.initialize(problem, &mut status, args, &init, &config)?;
        let mut current_step = 0;
        loop {
            self.step(current_step, problem, &mut status, args, &config)?;
            // Callbacks see the status *after* the step that was just taken.
            if cbs
                .check_for_termination(current_step, self, problem, &mut status, args, &config)
                .is_break()
            {
                break;
            }
            current_step += 1;
        }
        self.postprocessing(problem, &mut status, args, &config)?;
        self.summarize(current_step, problem, &status, args, &init, &config)
    }

    /// Same as [`process`](Algorithm::process) using
    /// [`default_callbacks`](Algorithm::default_callbacks).
    ///
    /// # Errors
    /// Propagates any error from `process`.
    fn process_with_default_callbacks(
        &mut self,
        problem: &P,
        user_data: &U,
        init: Self::Init,
        config: Self::Config,
    ) -> Result<Self::Summary, E>
    where
        Self: Sized,
    {
        self.process(problem, user_data, init, config, Self::default_callbacks())
    }

    /// Same as [`process`](Algorithm::process) using `Self::Config::default()` and
    /// [`default_callbacks`](Algorithm::default_callbacks).
    ///
    /// # Errors
    /// Propagates any error from `process`.
    fn process_default(
        &mut self,
        problem: &P,
        user_data: &U,
        init: Self::Init,
    ) -> Result<Self::Summary, E>
    where
        Self: Sized,
        Self::Config: Default,
    {
        self.process(
            problem,
            user_data,
            init,
            Self::Config::default(),
            Self::default_callbacks(),
        )
    }

    /// Callbacks used by the `*_default*` entry points. The default is `Callbacks::empty()`.
    fn default_callbacks() -> Callbacks<Self, P, S, U, E, Self::Config>
    where
        Self: Sized,
    {
        Callbacks::empty()
    }
}
/// Builder-style support for attaching [`Bounds`] to a value.
pub trait SupportsBounds: Sized {
    /// Gives mutable access to the value's optional bounds slot.
    fn get_bounds_mut(&mut self) -> &mut Option<Bounds>;

    /// Converts every item of `bounds` into a [`Bound`], stores the resulting
    /// [`Bounds`] (replacing anything stored before), and returns `self`.
    fn with_bounds<I: IntoIterator<Item = B>, B: Into<Bound>>(mut self, bounds: I) -> Self {
        let per_parameter: Vec<Bound> = bounds.into_iter().map(Into::into).collect();
        let slot = self.get_bounds_mut();
        *slot = Some(per_parameter.into());
        self
    }
}
/// Strategy for handing parameter bounds to an algorithm.
///
/// `Hash` is derived alongside `Eq` so the mode can be used as a map or set key.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum BoundsHandlingMode {
    /// Default mode; `resolve_bounds_and_transform` treats it exactly like
    /// [`BoundsHandlingMode::NativeBounds`].
    #[default]
    Auto,
    /// Pass bounds to the algorithm unchanged, alongside any user transform.
    NativeBounds,
    /// Fold the bounds into the transform and give the algorithm no explicit bounds.
    TransformBounds,
}
/// Decides which bounds and transform an algorithm should actually receive.
///
/// Returns a `(bounds, transform)` pair built from independent clones of the
/// inputs:
/// - `Auto` / `NativeBounds`: both inputs are returned as-is (cloned).
/// - `TransformBounds`: the returned bounds are always `None`. With both a
///   bounds set and a transform, the transform is `compose`d with the bounds.
///   With only bounds, the bounds themselves become the transform. With only a
///   transform, it is returned unchanged. With neither, the result is `None`.
pub(crate) fn resolve_bounds_and_transform(
    bounds: &Option<Bounds>,
    transform: &Option<Box<dyn Transform>>,
    mode: BoundsHandlingMode,
) -> (Option<Bounds>, Option<Box<dyn Transform>>) {
    // Fresh copy of the user transform; trait objects are cloned through `dyn_clone`.
    let copy_transform = || {
        transform
            .as_ref()
            .map(|boxed| dyn_clone::clone_box(boxed.as_ref()))
    };
    match mode {
        BoundsHandlingMode::Auto | BoundsHandlingMode::NativeBounds => {
            (bounds.clone(), copy_transform())
        }
        BoundsHandlingMode::TransformBounds => {
            let Some(user_bounds) = bounds else {
                return (None, copy_transform());
            };
            let folded = match copy_transform() {
                Some(user_transform) => {
                    Box::new(user_transform.compose(user_bounds.clone())) as Box<dyn Transform>
                }
                None => Box::new(user_bounds.clone()) as Box<dyn Transform>,
            };
            (None, Some(folded))
        }
    }
}
/// Builder-style support for attaching a user-supplied [`Transform`].
pub trait SupportsTransform: Sized {
    /// Gives mutable access to the value's optional transform slot.
    fn get_transform_mut(&mut self) -> &mut Option<Box<dyn Transform>>;

    /// Stores a boxed clone of `transform` (replacing any previous one) and returns `self`.
    fn with_transform<T: Transform + 'static>(mut self, transform: &T) -> Self {
        let boxed: Box<dyn Transform> = dyn_clone::clone_box(transform);
        let slot = self.get_transform_mut();
        *slot = Some(boxed);
        self
    }
}
/// Builder-style support for attaching human-readable parameter names.
pub trait SupportsParameterNames: Sized {
    /// Gives mutable access to the value's optional parameter-name slot.
    fn get_parameter_names_mut(&mut self) -> &mut Option<Vec<String>>;

    /// Copies every name into an owned `String`, stores the list (replacing
    /// any previous one), and returns `self`.
    fn with_parameter_names<I, S>(mut self, parameter_names: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        let owned_names: Vec<String> = parameter_names
            .into_iter()
            .map(|label| String::from(label.as_ref()))
            .collect();
        let slot = self.get_parameter_names_mut();
        *slot = Some(owned_names);
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{DMatrix, DVector, Float};
    use std::borrow::Cow;

    /// Minimal linear transform `x = s * z`, used to observe composition.
    #[derive(Clone)]
    struct Scale(Float);

    impl Transform for Scale {
        fn to_external<'a>(&'a self, z: &'a DVector<Float>) -> Cow<'a, DVector<Float>> {
            Cow::Owned(z.scale(self.0))
        }
        fn to_internal<'a>(&'a self, x: &'a DVector<Float>) -> Cow<'a, DVector<Float>> {
            Cow::Owned(x.unscale(self.0))
        }
        fn to_external_jacobian(&self, z: &DVector<Float>) -> DMatrix<Float> {
            DMatrix::identity(z.len(), z.len()).scale(self.0)
        }
        // A linear map has a zero Hessian for every output component.
        fn to_external_component_hessian(&self, _a: usize, z: &DVector<Float>) -> DMatrix<Float> {
            DMatrix::zeros(z.len(), z.len())
        }
    }

    /// One parameter bounded to `[0, 1]`.
    fn unit_bounds() -> Option<Bounds> {
        Some(Bounds::from([(0.0, 1.0)]))
    }

    /// A user transform that doubles its input.
    fn doubling_transform() -> Option<Box<dyn Transform>> {
        Some(Box::new(Scale(2.0)))
    }

    /// Asserts that `transform` still behaves like `Scale(2.0)`.
    fn assert_doubles(transform: &dyn Transform) {
        let x = transform.to_owned_external(&DVector::from_row_slice(&[3.0]));
        assert!((x[0] - 6.0).abs() < 1e-12, "expected 6.0, got {}", x[0]);
    }

    #[test]
    fn transform_bounds_mode_moves_bounds_into_transform() {
        let (resolved_bounds, resolved_transform) = resolve_bounds_and_transform(
            &unit_bounds(),
            &doubling_transform(),
            BoundsHandlingMode::TransformBounds,
        );
        assert!(resolved_bounds.is_none());
        let resolved_transform = resolved_transform.expect("transform should be composed");
        let x = resolved_transform.to_owned_external(&DVector::from_row_slice(&[10.0]));
        assert!(
            (0.0..=1.0).contains(&x[0]),
            "composed transform escaped the bounds: {}",
            x[0]
        );
    }

    #[test]
    fn transform_bounds_mode_with_only_bounds_builds_transform() {
        let (resolved_bounds, resolved_transform) =
            resolve_bounds_and_transform(&unit_bounds(), &None, BoundsHandlingMode::TransformBounds);
        assert!(resolved_bounds.is_none());
        assert!(resolved_transform.is_some());
    }

    #[test]
    fn transform_bounds_mode_without_bounds_passes_transform_through() {
        let (resolved_bounds, resolved_transform) = resolve_bounds_and_transform(
            &None,
            &doubling_transform(),
            BoundsHandlingMode::TransformBounds,
        );
        assert!(resolved_bounds.is_none());
        let resolved_transform = resolved_transform.expect("user transform should be kept");
        assert_doubles(resolved_transform.as_ref());
    }

    #[test]
    fn transform_bounds_mode_without_inputs_is_empty() {
        let (resolved_bounds, resolved_transform) =
            resolve_bounds_and_transform(&None, &None, BoundsHandlingMode::TransformBounds);
        assert!(resolved_bounds.is_none());
        assert!(resolved_transform.is_none());
    }

    #[test]
    fn native_bounds_mode_preserves_bounds_and_transform() {
        // `Auto` is resolved exactly like `NativeBounds`.
        for mode in [BoundsHandlingMode::Auto, BoundsHandlingMode::NativeBounds] {
            let (resolved_bounds, resolved_transform) =
                resolve_bounds_and_transform(&unit_bounds(), &doubling_transform(), mode);
            assert!(resolved_bounds.is_some(), "{mode:?} dropped the bounds");
            let resolved_transform = resolved_transform.expect("transform should be preserved");
            assert_doubles(resolved_transform.as_ref());
        }
    }
}