trellis-runner 0.3.1

Calculation runner
Documentation
use crate::{
    engine::{extensions::EngineSink, EngineSignal},
    state::{
        ConvergenceState, RuntimeState, Snapshotable, State, StateRestorer, StateView, UserState,
    },
};

use std::time::SystemTime;

mod in_memory;

#[cfg(feature = "writing")]
mod json;

pub use in_memory::InMemoryCheckpointStore;

#[cfg(feature = "writing")]
pub use json::JsonCheckpointStore;

#[derive(Debug, thiserror::Error)]
pub enum CheckpointError {
    #[error("filesystem error: {0}")]
    FileSystem(#[from] std::io::Error),
    #[error("serde json error: {0}")]
    SerdeJson(Box<dyn std::error::Error + 'static>),
}

pub struct CheckpointExtension<C> {
    store: C,
}

impl<C> CheckpointExtension<C> {
    pub(crate) fn new(store: C) -> Self {
        Self { store }
    }
}

impl<S, C> EngineSink<S> for CheckpointExtension<C>
where
    S: UserState + Snapshotable,
    C: CheckpointBackend<<S as Snapshotable>::Snapshot, <S as UserState>::Float> + 'static,
{
    fn handle(&mut self, state: StateView<'_, S>, signal: &EngineSignal<<S as UserState>::Float>) {
        if let EngineSignal::CheckpointRequested(_) = signal {
            let checkpoint = CheckpointView::from(state);

            if let Err(e) = self.store.save(checkpoint) {
                eprintln!("error saving checkpoint: {e:?}");
            }
        }
    }
}

#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct CheckpointView<'a, SN, F> {
    user: SN,
    runtime: &'a RuntimeState,
    convergence: &'a ConvergenceState<F>,
    version: u32,
    timestamp: SystemTime,
}

impl<'a, S> From<StateView<'a, S>> for CheckpointView<'a, S::Snapshot, <S as UserState>::Float>
where
    S: Snapshotable + UserState,
{
    fn from(state: StateView<'a, S>) -> CheckpointView<'a, S::Snapshot, <S as UserState>::Float> {
        Self {
            user: state.user().snapshot(),
            runtime: state.runtime(),
            convergence: state.convergence(),
            version: 1,
            timestamp: SystemTime::now(),
        }
    }
}

#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Checkpoint<SN, F> {
    user: SN,
    runtime: RuntimeState,
    convergence: ConvergenceState<F>,
    version: u32,
    timestamp: SystemTime,
}

impl<SN, F> Checkpoint<SN, F> {
    pub fn new(view: CheckpointView<'_, SN, F>) -> Self
    where
        F: Clone,
    {
        Self {
            user: view.user,
            runtime: view.runtime.clone(),
            convergence: view.convergence.clone(),
            version: view.version,
            timestamp: view.timestamp,
        }
    }
}

impl<SN, F> Checkpoint<SN, F> {
    pub fn into_state<S>(self) -> State<S>
    where
        S: StateRestorer<S> + Snapshotable<Snapshot = SN> + UserState<Float = F>,
    {
        State {
            user: S::restore(self.user),
            runtime: self.runtime,
            convergence: self.convergence,
        }
    }
}

pub trait CheckpointBackend<SN, F>: Send + Sync {
    fn save(&self, checkpoint: CheckpointView<'_, SN, F>) -> Result<(), CheckpointError>;

    fn load(&self) -> Result<Option<Checkpoint<SN, F>>, CheckpointError>;
}

pub trait EngineInitializer<S>
where
    S: UserState,
{
    fn try_load(&self) -> Result<Option<State<S>>, CheckpointError>;
}

impl<S, C> EngineInitializer<S> for C
where
    S: UserState + Snapshotable + StateRestorer<S>,
    C: CheckpointBackend<<S as Snapshotable>::Snapshot, <S as UserState>::Float>,
{
    fn try_load(&self) -> Result<Option<State<S>>, CheckpointError> {
        let maybe_checkpoint = self.load()?;

        Ok(maybe_checkpoint.map(|checkpoint| checkpoint.into_state()))
    }
}