mod builder;
mod cancellation;
mod checkpoint;
mod context;
mod event;
mod extensions;
mod policy;
mod result;
mod termination;
pub use policy::{
AbsoluteTolerancePolicy, CheckpointPolicy, MaxIterationPolicy, NoProgressPolicy,
RelativeTolerancePolicy, StagnationPolicy, TargetValuePolicy, TimeoutPolicy,
};
pub use builder::{GenerateBuilder, GenerateBuilderFallible};
pub use cancellation::CancellationGuard;
use context::EngineContext;
pub(crate) use event::{EngineAction, EngineSignal, EventBatch};
use extensions::Extensions;
use policy::EnginePolicy;
pub use result::{EngineFailure, EngineResult, EngineResultWithSnapshot};
use result::{InternalEngineFailure, InternalEngineResult};
pub use termination::Termination;
pub use checkpoint::InMemoryCheckpointStore;
#[cfg(feature = "writing")]
pub use checkpoint::JsonCheckpointStore;
use num_traits::float::FloatCore;
use std::time::Instant;
use tokio_util::sync::CancellationToken;
use tracing::instrument;
use crate::{
result::EngineOutput,
state::{Snapshotable, State, StateView},
FallibleProcedure,
};
use crate::{watchers::Observers, UserState};
/// Drives a [`FallibleProcedure`] against a problem until its [`EnginePolicy`] decides to stop
/// (or a procedure hook fails).
///
/// The engine owns the procedure, the problem and the policy. It keeps user state together with
/// runtime and convergence bookkeeping in a [`State`], and forwards every lifecycle signal first to
/// its extensions and then to its observers.
pub struct Engine<Proc, P, Q>
where
    Proc: FallibleProcedure<P>,
    Proc::State: UserState,
    Q: EnginePolicy<<Proc::State as UserState>::Float>,
{
    // The user-supplied algorithm; the run loop calls its `initialise_fallible`,
    // `step_fallible` and `finalise_fallible` hooks.
    procedure: Proc,
    // The problem the procedure operates on; passed mutably to every procedure hook.
    problem: P,
    // Consulted after every step to decide whether to continue, stop, or request a checkpoint.
    policy: Q,
    // User state plus the engine's runtime (iteration count, duration) and convergence tracking.
    state: State<Proc::State>,
    // NOTE(review): not read anywhere in this file — presumably toggles timing; confirm its use
    // in `builder` (or elsewhere) before relying on it.
    time: bool,
    // Set at the start of a run; `None` until the initialisation phase has executed.
    start_time: Option<std::time::Instant>,
    // Handed to the procedure via `CancellationGuard` on every step and checked after each step.
    cancellation: CancellationToken,
    // Receive every emitted engine signal together with a view of the current state.
    observers: Observers<Proc::State>,
    // Receive every emitted engine signal; dispatched to before the observers.
    extensions: Extensions<Proc::State>,
}
impl<Proc, P, Q> Engine<Proc, P, Q>
where
    Proc: FallibleProcedure<P>,
    Proc::State: UserState,
    <Proc::State as UserState>::Float: FloatCore,
    Q: EnginePolicy<<Proc::State as UserState>::Float>,
{
    /// Runs the engine to completion and attaches a snapshot of the final user state to the
    /// successful output.
    ///
    /// The snapshot is only taken when the run succeeds; a failed run does not pay for it.
    ///
    /// # Errors
    ///
    /// Returns an [`EngineFailure`] wrapping the procedure error together with the engine state
    /// at the moment the run stopped.
    pub fn run_with_snapshot(
        mut self,
    ) -> EngineResultWithSnapshot<Proc::Output, Proc::State, Proc::Error>
    where
        Proc::State: Snapshotable,
    {
        match self._run() {
            Ok(output) => {
                let snapshot = self.state.user.snapshot();
                Ok(output.with_snapshot(snapshot))
            }
            Err(internal) => Err(EngineFailure::from_internal(internal, self.state)),
        }
    }

    /// Runs the engine to completion: initialise, step until the policy stops, then finalise.
    ///
    /// # Errors
    ///
    /// Returns an [`EngineFailure`] wrapping the procedure error together with the engine state
    /// at the moment the run stopped.
    pub fn run(mut self) -> EngineResult<Proc::Output, Proc::State, Proc::Error> {
        self._run()
            .map_err(|internal| EngineFailure::from_internal(internal, self.state))
    }

    /// Core run loop shared by [`Self::run`] and [`Self::run_with_snapshot`].
    ///
    /// Initialises the procedure, then repeatedly steps and consults the policy. A checkpoint
    /// request is forwarded as a signal and the loop keeps going; a stop decision emits a
    /// termination signal and finalises.
    fn _run(&mut self) -> InternalEngineResult<Proc::Output, Proc::State, Proc::Error> {
        self.initialise_state()?;
        loop {
            match self.policy_step()? {
                EngineAction::Continue => {}
                EngineAction::EmitCheckpoint(reason) => {
                    self.emit_event(EngineSignal::CheckpointRequested(reason));
                }
                EngineAction::Stop(reason) => {
                    self.emit_event(EngineSignal::Termination(reason));
                    return self.finalise(reason);
                }
            }
        }
    }

    /// Performs one procedure step, then asks the policy what to do next.
    fn policy_step(&mut self) -> Result<EngineAction, InternalEngineFailure<Proc::Error>> {
        let batch = self.step_once()?;
        let start_time = self.start_time();
        let ctx = EngineContext {
            iter: self.state.runtime.iteration(),
            elapsed: start_time.elapsed(),
            cancelled: self.cancellation.is_cancelled(),
            // Checkpoint scheduling is not driven from here; the policy decides on its own.
            checkpoint_due: false,
            start_time,
            _marker: Default::default(),
        };
        Ok(self.policy.decide(&batch, &ctx))
    }

    /// Returns the instant the run started.
    ///
    /// # Panics
    ///
    /// Panics if called before [`Self::initialise_state`] — that would be an engine bug.
    fn start_time(&self) -> Instant {
        self.start_time
            .expect("start time should always be set in the initialisation phase")
    }

    /// Stores the wall-clock time elapsed since the run started in the runtime state.
    fn record_elapsed(&mut self) {
        let elapsed = self.start_time().elapsed();
        self.state.runtime.record_duration(elapsed);
    }

    /// Starts the clock, initialises the procedure and emits [`EngineSignal::Initialised`].
    #[instrument(name = "initialising runner", fields(ident = Proc::NAME), skip_all)]
    fn initialise_state(&mut self) -> Result<(), InternalEngineFailure<Proc::Error>> {
        self.start_time = Some(Instant::now());
        // Records an (effectively zero) duration before any work — presumably to reset any
        // duration carried in the state; confirm against `record_duration`.
        self.record_elapsed();
        self.procedure
            .initialise_fallible(&mut self.problem, &mut self.state.user)
            .map_err(InternalEngineFailure::new)?;
        self.emit_event(EngineSignal::Initialised);
        Ok(())
    }

    /// Finalises the procedure and packages its result with a view of the final state.
    ///
    /// The elapsed time is recorded whether or not finalisation succeeds. That way the state
    /// handed back with a failure is as current as the one attached to a successful output.
    #[instrument(name = "wrapping up runner", fields(ident = Proc::NAME), skip_all)]
    fn finalise(
        &mut self,
        reason: Termination,
    ) -> InternalEngineResult<Proc::Output, Proc::State, Proc::Error> {
        let outcome = self
            .procedure
            .finalise_fallible(&mut self.problem, &self.state.user);
        self.record_elapsed();
        let result = outcome.map_err(InternalEngineFailure::new)?;
        Ok(EngineOutput::new(
            result,
            StateView::new(&self.state),
            reason,
        ))
    }

    /// Advances one iteration.
    ///
    /// The iteration counter is bumped and the elapsed time recorded *before* the procedure
    /// step runs. After the step, progress is fed to convergence tracking and emitted as a
    /// signal, and returned in an [`EventBatch`] for the policy.
    fn step_once(
        &mut self,
    ) -> Result<EventBatch<<Proc::State as UserState>::Float>, InternalEngineFailure<Proc::Error>>
    {
        self.state.runtime.increment_iteration();
        self.record_elapsed();
        self.procedure
            .step_fallible(
                &mut self.problem,
                &mut self.state.user,
                CancellationGuard {
                    token: &self.cancellation,
                },
            )
            .map_err(InternalEngineFailure::new)?;
        let progress = self.state.user.progress();
        let iteration = self.state.runtime.iteration();
        self.state.convergence.observe(&progress, iteration);
        self.emit_event(EngineSignal::Progress(progress.clone()));
        Ok(EventBatch::new().add(progress))
    }

    /// Dispatches `signal` to the extensions first, then the observers.
    ///
    /// Both receive a view of the current state.
    fn emit_event(&mut self, signal: EngineSignal<<Proc::State as UserState>::Float>) {
        let state_view = StateView::new(&self.state);
        self.extensions.dispatch(state_view, &signal);
        self.observers.dispatch(Proc::NAME, state_view, &signal);
    }
}