use super::EngineContext;
use num_traits::float::FloatCore;
use std::time::Duration;
mod absolute_tolerance;
mod cancellation;
mod checkpoint;
mod complete;
mod max_iter;
mod no_progress;
mod relative_tolerance;
mod stagnation;
mod target_value;
mod timeout;
pub use absolute_tolerance::AbsoluteTolerancePolicy;
pub use cancellation::CancellationPolicy;
pub use checkpoint::CheckpointPolicy;
pub use complete::CompletionPolicy;
pub use max_iter::MaxIterationPolicy;
pub use no_progress::NoProgressPolicy;
pub use relative_tolerance::RelativeTolerancePolicy;
pub use stagnation::StagnationPolicy;
pub use target_value::TargetValuePolicy;
pub use timeout::TimeoutPolicy;
use crate::engine::{event::CheckpointReason, EngineAction, EventBatch};
pub trait EnginePolicy<F> {
fn decide(&mut self, batch: &EventBatch<F>, context: &EngineContext) -> EngineAction;
}
pub trait PolicyExt<F>: EnginePolicy<F> + Sized + 'static {
fn boxed(self) -> Box<dyn EnginePolicy<F>> {
Box::new(self)
}
}
impl<F, T> PolicyExt<F> for T where T: EnginePolicy<F> + Sized + 'static {}
pub struct PolicyStack<F> {
policies: Vec<Box<dyn EnginePolicy<F>>>,
}
impl<F> Default for PolicyStack<F> {
fn default() -> Self {
Self::new()
}
}
impl<F> PolicyStack<F> {
pub fn new() -> Self {
Self { policies: vec![] }
}
pub fn is_empty(&self) -> bool {
self.policies.is_empty()
}
pub fn add<P>(mut self, p: P) -> Self
where
P: EnginePolicy<F> + 'static,
{
self.policies.push(Box::new(p));
self
}
pub fn merge(mut self, other: Self) -> Self {
for each in other.policies.into_iter() {
self.policies.push(each);
}
self
}
}
impl<F> EnginePolicy<F> for PolicyStack<F> {
fn decide(&mut self, batch: &EventBatch<F>, ctx: &EngineContext) -> EngineAction {
let mut checkpoint = false;
for p in &mut self.policies {
match p.decide(batch, ctx) {
EngineAction::Stop(t) => return EngineAction::Stop(t),
EngineAction::Continue => {}
EngineAction::EmitCheckpoint(_) => {
checkpoint = true;
}
}
}
if checkpoint {
return EngineAction::EmitCheckpoint(CheckpointReason::Scheduled);
}
EngineAction::Continue
}
}
impl<F> PolicyStack<F> {
pub fn standard(max_iter: usize, atol: F, window_size: usize) -> PolicyStack<F>
where
F: FloatCore + 'static,
{
PolicyStack::new()
.add(CancellationPolicy)
.add(MaxIterationPolicy::new(max_iter))
.add(AbsoluteTolerancePolicy::new(atol, window_size))
}
pub fn optimisation(max_iter: usize, atol: F, window_size: usize) -> PolicyStack<F>
where
F: FloatCore + 'static + num_traits::FromPrimitive + std::iter::Sum<F>,
{
PolicyStack::new()
.add(CancellationPolicy)
.add(MaxIterationPolicy::new(max_iter))
.add(AbsoluteTolerancePolicy::new(atol, window_size))
.add(StagnationPolicy::new(window_size))
}
pub fn global_optimisation(
max_iter: usize,
target: F,
tol: F,
window_size: usize,
) -> PolicyStack<F>
where
F: FloatCore + 'static + num_traits::FromPrimitive + std::iter::Sum<F>,
{
PolicyStack::new()
.add(CancellationPolicy)
.add(MaxIterationPolicy::new(max_iter))
.add(TargetValuePolicy::new(target, tol, window_size))
.add(StagnationPolicy::new(window_size))
.add(NoProgressPolicy::new(F::epsilon(), window_size))
}
pub fn monte_carlo(max_iter: usize) -> PolicyStack<F>
where
F: 'static,
{
PolicyStack::new()
.add(CancellationPolicy)
.add(MaxIterationPolicy::new(max_iter))
}
pub fn timed(timeout: Duration) -> PolicyStack<F>
where
F: 'static,
{
PolicyStack::new()
.add(CancellationPolicy)
.add(TimeoutPolicy::new(timeout))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::progress::Progress;
#[test]
fn empty_stack_continues() {
let mut stack = PolicyStack::<f64>::new();
let batch: EventBatch<f64> = EventBatch::new();
let ctx = EngineContext::default();
assert!(matches!(stack.decide(&batch, &ctx), EngineAction::Continue));
}
#[test]
fn checkpoint_request_is_propagated() {
let mut stack = PolicyStack::<f64>::new()
.add(CheckpointPolicy::every(10))
.add(MaxIterationPolicy::new(500));
let batch: EventBatch<f64> = EventBatch::new().add(Progress::Complete);
let ctx = EngineContext {
iter: 10,
..Default::default()
};
assert!(matches!(
stack.decide(&batch, &ctx),
EngineAction::EmitCheckpoint(_)
));
}
#[test]
fn stop_takes_precedence_over_checkpoint() {
let mut stack = PolicyStack::<f64>::new()
.add(CheckpointPolicy::every(10))
.add(MaxIterationPolicy::new(0));
let batch: EventBatch<f64> = EventBatch::new().add(Progress::Complete);
let ctx = EngineContext {
iter: 10,
..Default::default()
};
assert!(matches!(stack.decide(&batch, &ctx), EngineAction::Stop(_)));
}
#[test]
fn policy_stack_stop_overrides_all() {
let mut stack = PolicyStack::new()
.add(NoProgressPolicy::new(0.1, 10))
.add(MaxIterationPolicy::new(100));
let batch = EventBatch::new().add(Progress::Complete);
let ctx = EngineContext {
iter: 101,
..Default::default()
};
let action = stack.decide(&batch, &ctx);
if let EngineAction::Stop(_) = action {
} else {
panic!("Stop must dominate all policies");
}
}
#[test]
fn integration_converges_via_tolerance() {
let mut stack = PolicyStack::<f64>::optimisation(100, 0.01, 5);
let ctx = EngineContext::default();
let batch = EventBatch::new().add(Progress::Report {
measure: 1.0,
diagnostics: crate::progress::ProgressDiagnostics {
absolute_error: Some(0.001),
..Default::default()
},
});
for _ in 0..10 {
stack.decide(&batch, &ctx);
}
let action = stack.decide(&batch, &ctx);
assert!(matches!(
action,
EngineAction::Stop(crate::Termination::Converged)
));
}
#[test]
fn integration_stagnation_overrides_no_progress() {
let mut stack = PolicyStack::<f64>::new()
.add(NoProgressPolicy::new(0.01, 3))
.add(StagnationPolicy::new(5));
let ctx = EngineContext::default();
for _ in 0..10 {
let batch = EventBatch::new().add(Progress::Measure(1.0));
let action = stack.decide(&batch, &ctx);
if let EngineAction::Stop(_) = action {
break;
}
}
}
#[test]
fn integration_timeout_trumps_all() {
let mut stack = PolicyStack::<f64>::timed(Duration::from_secs(1));
let batch = EventBatch::new().add(Progress::Complete);
let ctx = EngineContext {
elapsed: Duration::from_secs(10),
..Default::default()
};
assert!(matches!(
stack.decide(&batch, &ctx),
EngineAction::Stop(crate::Termination::Timeout)
));
}
}