use crate::traits::{AbortSignal, CheckpointableAlgorithm, Status, Terminator};
use parking_lot::Mutex;
use std::{ops::ControlFlow, sync::Arc};
/// What a checkpointing terminator does after it has saved a checkpoint.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CheckpointAction {
    /// Keep running after the checkpoint has been handed to the sink.
    Continue,
    /// Stop the run after the checkpoint has been handed to the sink.
    Stop,
}

impl Default for CheckpointAction {
    /// Stopping is the default action.
    fn default() -> Self {
        CheckpointAction::Stop
    }
}
/// Shared slot holding at most one checkpoint.
///
/// All clones of a store refer to the same slot (through an `Arc`). A checkpoint
/// saved through one handle is visible through every other handle.
pub struct CheckpointStore<T> {
    checkpoint: Arc<Mutex<Option<T>>>,
}

// Written by hand rather than derived: `#[derive(Clone)]` would require
// `T: Clone`. Cloning a store only bumps the `Arc` refcount and never clones
// a `T`, so that bound is unnecessary.
impl<T> Clone for CheckpointStore<T> {
    fn clone(&self) -> Self {
        Self {
            checkpoint: Arc::clone(&self.checkpoint),
        }
    }
}
impl<T> Default for CheckpointStore<T> {
fn default() -> Self {
Self::new()
}
}
impl<T> CheckpointStore<T> {
pub fn new() -> Self {
Self {
checkpoint: Arc::new(Mutex::new(None)),
}
}
pub fn save(&self, checkpoint: T) {
*self.checkpoint.lock() = Some(checkpoint);
}
pub fn load(&self) -> Option<T>
where
T: Clone,
{
self.checkpoint.lock().clone()
}
pub fn has_checkpoint(&self) -> bool {
self.checkpoint.lock().is_some()
}
}
/// Terminator that captures a checkpoint when `signal` reports an abort request.
///
/// The captured checkpoint is passed to `sink`. The run then stops or continues
/// according to `action`.
#[derive(Clone)]
pub struct CheckpointOnSignal<Sig, Sink> {
    /// Signal polled once per termination check.
    signal: Sig,
    /// Callback that receives each captured checkpoint.
    sink: Sink,
    /// What to do after a checkpoint has been delivered.
    action: CheckpointAction,
}

impl<Sig, Sink> CheckpointOnSignal<Sig, Sink> {
    /// Builds a terminator that stops after checkpointing ([`CheckpointAction::Stop`]).
    pub const fn new(signal: Sig, sink: Sink) -> Self {
        Self {
            action: CheckpointAction::Stop,
            signal,
            sink,
        }
    }

    /// Returns this terminator with its post-checkpoint action replaced by `action`.
    pub const fn with_action(self, action: CheckpointAction) -> Self {
        let mut configured = self;
        configured.action = action;
        configured
    }
}
impl<A, P, S, U, E, C, Sig, Sink> Terminator<A, P, S, U, E, C> for CheckpointOnSignal<Sig, Sink>
where
    A: CheckpointableAlgorithm<P, S, U, E, Config = C>,
    S: Status,
    Sig: AbortSignal + Clone,
    Sink: FnMut(A::Checkpoint) + Clone + Send + Sync + 'static,
{
    /// Polls the signal and, if it reports an abort, checkpoints the algorithm.
    ///
    /// On an abort request this:
    /// 1. builds a checkpoint with step `current_step + 1` (saturating);
    /// 2. hands it to the sink;
    /// 3. resets the signal;
    /// 4. sets the status message to "Checkpoint requested";
    /// 5. breaks for [`CheckpointAction::Stop`] and continues for
    ///    [`CheckpointAction::Continue`].
    ///
    /// Without an abort request it returns `Continue` and touches nothing.
    fn check_for_termination(
        &mut self,
        current_step: usize,
        algorithm: &mut A,
        _problem: &P,
        status: &mut S,
        _args: &U,
        _config: &C,
    ) -> ControlFlow<()> {
        if !self.signal.is_aborted() {
            return ControlFlow::Continue(());
        }

        // The `+ 1` presumably records the step to resume from, since
        // `current_step` has already been executed. TODO confirm against
        // `CheckpointableAlgorithm::checkpoint`.
        let next_step = current_step.saturating_add(1);
        let snapshot = algorithm.checkpoint(status, next_step);
        (self.sink)(snapshot);

        // NOTE(review): reset happens after the sink returns. Any abort request
        // raised while the sink runs is presumably cleared here too. Confirm
        // against `AbortSignal::reset` semantics.
        self.signal.reset();
        status.set_message().custom("Checkpoint requested");

        match self.action {
            CheckpointAction::Stop => ControlFlow::Break(()),
            CheckpointAction::Continue => ControlFlow::Continue(()),
        }
    }
}
/// Checkpoint-request signal; an alias of [`crate::core::AtomicAbortSignal`].
///
/// Intended for use as the `Sig` parameter of `CheckpointOnSignal`. That requires
/// the aliased type to implement `AbortSignal + Clone` (TODO confirm in `core`).
pub type AtomicCheckpointSignal = crate::core::AtomicAbortSignal;
/// Checkpoint-request signal; an alias of [`crate::core::CtrlCAbortSignal`].
///
/// Presumably raised by a Ctrl-C handler. Verify against the aliased type's docs.
pub type CtrlCCheckpointSignal = crate::core::CtrlCAbortSignal;
#[cfg(test)]
mod tests {
    use super::*;

    /// A saved value can be read back, and the presence flag tracks it.
    #[test]
    fn checkpoint_store_roundtrip() {
        let store = CheckpointStore::new();
        assert!(!store.has_checkpoint());
        store.save(7usize);
        assert!(store.has_checkpoint());
        assert_eq!(store.load(), Some(7));
    }

    /// `load` clones the value and leaves it in place.
    #[test]
    fn checkpoint_store_load_does_not_consume() {
        let store = CheckpointStore::new();
        store.save(String::from("first"));
        assert_eq!(store.load().as_deref(), Some("first"));
        assert_eq!(store.load().as_deref(), Some("first"));
        assert!(store.has_checkpoint());
    }

    /// A later `save` replaces the earlier checkpoint.
    #[test]
    fn checkpoint_store_save_overwrites() {
        let store = CheckpointStore::new();
        store.save(1u32);
        store.save(2u32);
        assert_eq!(store.load(), Some(2));
    }

    /// Clones refer to the same slot, so a save through one is visible through the other.
    #[test]
    fn checkpoint_store_clones_share_slot() {
        let store = CheckpointStore::<u32>::default();
        let handle = store.clone();
        assert!(!store.has_checkpoint());
        handle.save(42);
        assert!(store.has_checkpoint());
        assert_eq!(store.load(), Some(42));
    }

    /// The terminator stops by default, and `with_action` overrides that.
    #[test]
    fn checkpoint_on_signal_action_defaults_and_override() {
        assert_eq!(CheckpointAction::default(), CheckpointAction::Stop);

        let terminator = CheckpointOnSignal::new((), |_: u8| {});
        assert_eq!(terminator.action, CheckpointAction::Stop);

        let terminator = terminator.with_action(CheckpointAction::Continue);
        assert_eq!(terminator.action, CheckpointAction::Continue);
    }
}