use num_traits::float::FloatCore;
use std::sync::{Arc, Mutex};
use tokio_util::sync::CancellationToken;
use crate::engine::policy::{CancellationPolicy, CompletionPolicy, EnginePolicy, PolicyStack};
use crate::{
engine::{
checkpoint::{CheckpointBackend, CheckpointExtension},
extensions::Extensions,
Engine,
},
state::{Snapshotable, State, StateRestorer},
watchers::{Frequency, Observe, Observers},
FallibleProcedure, Infallible, Procedure, UserState,
};
/// Entry point for configuring an engine around a *fallible* procedure.
///
/// `build_for` pairs the procedure (`self`) with a `problem` and returns a
/// [`Builder`] in the [`Uninitialised`] typestate; an initial user state must
/// be supplied before the builder can be finalised into an engine.
pub trait GenerateBuilderFallible: Sized {
/// Starts a builder that will run `self` against `problem`.
///
/// Only callable when `Self` is a `FallibleProcedure<P>` whose state type
/// implements `UserState`.
fn build_for<P>(self, problem: P) -> Builder<Self, P, Uninitialised>
where
Self: FallibleProcedure<P>,
Self::State: UserState;
}
impl<Proc> GenerateBuilderFallible for Proc {
fn build_for<P>(self, problem: P) -> Builder<Self, P, Uninitialised>
where
Proc: FallibleProcedure<P>,
Proc::State: UserState,
{
Builder {
procedure: self,
problem,
state: None,
time: true,
cancellation_token: None,
observers: Observers::new(),
policies: PolicyStack::new()
.add(CancellationPolicy)
.add(CompletionPolicy),
extensions: Extensions::new(),
_initialised: std::marker::PhantomData,
}
}
}
/// Entry point for configuring an engine around an *infallible* procedure.
///
/// `build_for` pairs the procedure with a `problem` and returns a [`Builder`]
/// in the [`Uninitialised`] typestate. The procedure is carried as
/// `Infallible<Self>`, so the rest of the builder and engine machinery only
/// deals with the fallible interface.
pub trait GenerateBuilder: Sized {
/// Starts a builder that will run `self` (wrapped in `Infallible`) against
/// `problem`.
///
/// Only callable when `Self` is a `Procedure<P>` whose state type implements
/// `UserState`.
fn build_for<P>(self, problem: P) -> Builder<Infallible<Self>, P, Uninitialised>
where
Self: Procedure<P>,
Self::State: UserState;
}
impl<Proc> GenerateBuilder for Proc {
fn build_for<P>(self, problem: P) -> Builder<Infallible<Self>, P, Uninitialised>
where
Proc: Procedure<P>,
Proc::State: UserState,
{
Builder {
procedure: Infallible(self),
problem,
state: None,
time: true,
cancellation_token: None,
observers: Observers::new(),
policies: PolicyStack::new()
.add(CancellationPolicy)
.add(CompletionPolicy),
extensions: Extensions::new(),
_initialised: std::marker::PhantomData,
}
}
}
/// Typestate marker: the builder has not yet been given an initial user state.
///
/// Builders in this state can be moved to [`Initialised`] via
/// `with_initial_state` or `resume_from_checkpoint`; they cannot be finalised.
pub struct Uninitialised;
/// Typestate marker: the builder holds an initial user state and may be
/// turned into an engine with `finalise` / `finalise_with`.
pub struct Initialised;
/// Typestate builder that collects everything needed to construct an `Engine`.
///
/// `I` is either [`Uninitialised`] or [`Initialised`]; only the latter exposes
/// the finalising methods, so an engine cannot be built without a user state.
pub struct Builder<Proc, P, I>
where
Proc: FallibleProcedure<P>,
Proc::State: UserState,
<Proc::State as UserState>::Float: FloatCore,
{
/// Procedure that the resulting engine will run.
procedure: Proc,
/// Problem instance handed to the engine.
problem: P,
/// Initial user state; `None` until `with_initial_state` /
/// `resume_from_checkpoint`, and required (`expect`ed) when finalising.
state: Option<Proc::State>,
/// Flag copied into the engine's `time` field (set to `true` by `build_for`).
/// NOTE(review): presumably toggles run timing — confirm in `Engine`.
time: bool,
/// Caller-supplied cancellation token; a default token is used if unset.
cancellation_token: Option<CancellationToken>,
/// Observers registered through `attach_observer`.
observers: Observers<Proc::State>,
/// Accumulated policy stack (seeded by `build_for`, extended by `and_policy`
/// and `with_default_policies`).
policies: PolicyStack<<Proc::State as UserState>::Float>,
/// Engine extensions, e.g. checkpointing added by `with_checkpoint_backend`.
extensions: Extensions<Proc::State>,
/// Zero-sized typestate marker.
_initialised: std::marker::PhantomData<I>,
}
impl<Proc, P, I> Builder<Proc, P, I>
where
    Proc: FallibleProcedure<P>,
    Proc::State: UserState,
    <Proc::State as UserState>::Float: FloatCore + 'static,
{
    /// Sets the flag that is copied into the engine's `time` field.
    ///
    /// Builders created by `build_for` start with this set to `true`.
    #[must_use]
    pub fn time(self, time: bool) -> Self {
        Self { time, ..self }
    }

    /// Registers `observer` with the builder's observer set at the given
    /// `frequency`.
    ///
    /// The observer is stored behind `Arc<Mutex<_>>` so that it can be shared.
    #[must_use]
    pub fn attach_observer<OBS>(mut self, observer: OBS, frequency: Frequency) -> Self
    where
        OBS: Observe<Proc::State> + 'static,
    {
        let shared = Arc::new(Mutex::new(observer));
        self.observers.attach(shared, frequency);
        self
    }

    /// Adds `policy` to the builder's policy stack.
    #[must_use]
    pub fn and_policy<Q>(self, policy: Q) -> Self
    where
        Q: EnginePolicy<<Proc::State as UserState>::Float> + 'static,
    {
        Self {
            policies: self.policies.add(policy),
            ..self
        }
    }

    /// Uses `token` for cancellation instead of the default token that
    /// finalising would otherwise create.
    #[must_use]
    pub fn cancellation_token(self, token: CancellationToken) -> Self {
        Self {
            cancellation_token: Some(token),
            ..self
        }
    }

    /// Merges `PolicyStack::standard(max_iter, absolute_tolerance, window_size)`
    /// into the current policy stack.
    #[must_use]
    pub fn with_default_policies(
        self,
        max_iter: usize,
        absolute_tolerance: <Proc::State as UserState>::Float,
        window_size: usize,
    ) -> Self {
        let standard = PolicyStack::standard(max_iter, absolute_tolerance, window_size);
        Self {
            policies: self.policies.merge(standard),
            ..self
        }
    }

    /// Adds a `CheckpointExtension` backed by `store` to the builder's extensions.
    ///
    /// Requires the user state to be `Snapshotable`.
    #[must_use]
    pub fn with_checkpoint_backend<C>(self, store: C) -> Self
    where
        C: CheckpointBackend<
                <Proc::State as Snapshotable>::Snapshot,
                <Proc::State as UserState>::Float,
            > + 'static,
        Proc::State: Snapshotable,
    {
        let checkpointing = CheckpointExtension::new(store);
        Self {
            extensions: self.extensions.add(checkpointing),
            ..self
        }
    }
}
impl<Proc, P> Builder<Proc, P, Uninitialised>
where
    Proc: FallibleProcedure<P>,
    Proc::State: UserState,
    <Proc::State as UserState>::Float: FloatCore + 'static,
{
    /// Supplies the starting user state and moves the builder into the
    /// [`Initialised`] typestate, so that it can be finalised.
    ///
    /// All other configuration carries over unchanged: the timing flag,
    /// cancellation token, observers, policies and extensions.
    #[must_use]
    pub fn with_initial_state(self, user: Proc::State) -> Builder<Proc, P, Initialised> {
        let Builder {
            procedure,
            problem,
            time,
            cancellation_token,
            observers,
            policies,
            extensions,
            ..
        } = self;

        Builder {
            procedure,
            problem,
            state: Some(user),
            time,
            cancellation_token,
            observers,
            policies,
            extensions,
            _initialised: std::marker::PhantomData,
        }
    }

    /// Restores the user state from a checkpoint `snapshot` and uses it as
    /// the initial state.
    ///
    /// After restoring, this behaves exactly like `with_initial_state`.
    #[must_use]
    pub fn resume_from_checkpoint(
        self,
        snapshot: <Proc::State as Snapshotable>::Snapshot,
    ) -> Builder<Proc, P, Initialised>
    where
        Proc: FallibleProcedure<P>,
        Proc::State: Snapshotable + StateRestorer<Proc::State>,
    {
        let restored = Proc::State::restore(snapshot);
        self.with_initial_state(restored)
    }
}
impl<Proc, P> Builder<Proc, P, Initialised>
where
Proc: FallibleProcedure<P>,
Proc::State: UserState,
<Proc::State as UserState>::Float: FloatCore + 'static,
{
pub fn finalise(mut self) -> Engine<Proc, P, PolicyStack<<Proc::State as UserState>::Float>>
where
<Proc::State as UserState>::Float: num_traits::FromPrimitive,
{
let user = self.state.take().expect("builder invariant: user is set");
let cancellation = self.cancellation_token.unwrap_or_default();
#[cfg(feature = "ctrlc")]
{
let token = cancellation.clone();
ctrlc::set_handler(move || {
token.cancel();
})
.unwrap();
}
Engine {
procedure: self.procedure,
problem: self.problem,
state: State::new(user),
time: self.time,
start_time: None,
cancellation,
policy: self.policies,
observers: self.observers,
extensions: self.extensions,
}
}
pub fn finalise_with(
mut self,
policy: PolicyStack<<Proc::State as UserState>::Float>,
) -> Engine<Proc, P, PolicyStack<<Proc::State as UserState>::Float>> {
let user = self.state.take().expect("builder invariant: user is set");
let cancellation = self.cancellation_token.unwrap_or_default();
#[cfg(feature = "ctrlc")]
{
let token = cancellation.clone();
ctrlc::set_handler(move || {
token.cancel();
})
.unwrap();
}
Engine {
procedure: self.procedure,
problem: self.problem,
state: State::new(user),
time: self.time,
start_time: None,
cancellation,
policy,
observers: self.observers,
extensions: self.extensions,
}
}
}