use std::sync::Arc;
/// Per-step outcome produced by a durability policy.
///
/// This value only records the choice; what "sync", "async" and "skip" mean in terms of
/// actual I/O is decided by whoever consumes it (not visible here — confirm with the caller).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DurabilityDecision {
    /// Persist this step synchronously.
    Sync,
    /// Persist this step asynchronously.
    Async,
    /// Do not persist this step.
    Skip,
}
/// User-supplied durability policy, stored by `Durability::Custom` as an
/// `Arc<dyn DurabilityHook>`.
///
/// The `Send + Sync` supertraits allow that shared `Arc` to cross thread boundaries.
pub trait DurabilityHook: Sync + Send {
    /// Choose what to do with `step`; `is_terminal` is the caller's flag for the
    /// terminal step, passed through unchanged.
    fn decide(&self, step: u64, is_terminal: bool) -> DurabilityDecision;
}
/// Lets any thread-safe closure or function `Fn(step, is_terminal) -> DurabilityDecision`
/// be used directly as a [`DurabilityHook`].
impl<F> DurabilityHook for F
where
    F: Fn(u64, bool) -> DurabilityDecision + Send + Sync,
{
    fn decide(&self, step: u64, is_terminal: bool) -> DurabilityDecision {
        // `&F` is itself callable, so forward both arguments straight to the closure.
        self(step, is_terminal)
    }
}
/// Policy that controls when, and how, steps are persisted.
///
/// `Durability::decide` turns this policy into a per-step [`DurabilityDecision`].
#[derive(Default, Clone)]
pub enum Durability {
    /// Every step yields [`DurabilityDecision::Sync`]. This is the default policy.
    #[default]
    Sync,
    /// Every step yields [`DurabilityDecision::Async`].
    Async,
    /// Only the terminal step is persisted, synchronously; all other steps are skipped.
    Exit,
    /// Consult `mode` on steps whose number is divisible by `n` and skip the rest.
    /// The terminal step is always persisted synchronously, and `n == 0` acts like `n == 1`.
    Every {
        /// Stride between consulted steps.
        n: u64,
        /// Policy applied to steps that fall on the stride.
        mode: Box<Durability>,
    },
    /// Delegate every decision, including the terminal one, to a caller-supplied hook.
    Custom(Arc<dyn DurabilityHook>),
}
/// Hand-written `Debug`: `DurabilityHook` has no `Debug` supertrait, so the hook inside
/// `Custom` cannot be formatted and is shown as the placeholder `Custom(<hook>)`.
impl std::fmt::Debug for Durability {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Every { n, mode } => {
                // The only variant with fields: render it struct-style so nested modes show up.
                return f
                    .debug_struct("Every")
                    .field("n", n)
                    .field("mode", mode)
                    .finish();
            }
            Self::Sync => "Sync",
            Self::Async => "Async",
            Self::Exit => "Exit",
            Self::Custom(_) => "Custom(<hook>)",
        };
        f.write_str(label)
    }
}
impl PartialEq for Durability {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Self::Sync, Self::Sync) | (Self::Async, Self::Async) | (Self::Exit, Self::Exit) => {
true
}
(Self::Every { n: a, mode: ma }, Self::Every { n: b, mode: mb }) => a == b && ma == mb,
_ => false,
}
}
}
impl Durability {
pub fn decide(&self, step: u64, is_terminal: bool) -> DurabilityDecision {
match self {
Self::Sync => DurabilityDecision::Sync,
Self::Async => DurabilityDecision::Async,
Self::Exit => {
if is_terminal {
DurabilityDecision::Sync
} else {
DurabilityDecision::Skip
}
}
Self::Every { n, mode } => {
if is_terminal {
return DurabilityDecision::Sync;
}
let stride = (*n).max(1);
if step.is_multiple_of(stride) {
mode.decide(step, false)
} else {
DurabilityDecision::Skip
}
}
Self::Custom(h) => h.decide(step, is_terminal),
}
}
pub fn save_per_step_sync(&self) -> bool {
matches!(self, Self::Sync)
}
pub fn save_per_step_async(&self) -> bool {
matches!(self, Self::Async)
}
pub fn save_on_exit(&self) -> bool {
matches!(self, Self::Exit)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `Durability::default()` must be the synchronous policy.
    #[test]
    fn default_is_sync() {
        let fallback: Durability = Default::default();
        assert_eq!(fallback, Durability::Sync);
    }

    /// Each legacy predicate is true for exactly one plain variant.
    #[test]
    fn predicates_match_variants() {
        // (policy, save_per_step_sync, save_per_step_async, save_on_exit)
        let table = [
            (Durability::Sync, true, false, false),
            (Durability::Async, false, true, false),
            (Durability::Exit, false, false, true),
        ];
        for (policy, want_sync, want_async, want_exit) in table {
            assert_eq!(policy.save_per_step_sync(), want_sync, "{policy:?}");
            assert_eq!(policy.save_per_step_async(), want_async, "{policy:?}");
            assert_eq!(policy.save_on_exit(), want_exit, "{policy:?}");
        }
    }

    /// `Every` consults its mode only on stride steps, and always syncs the terminal step.
    #[test]
    fn every_stride_skips_intermediate() {
        let every_third = Durability::Every {
            n: 3,
            mode: Box::new(Durability::Sync),
        };
        let expectations = [
            (0, false, DurabilityDecision::Sync),
            (1, false, DurabilityDecision::Skip),
            (2, false, DurabilityDecision::Skip),
            (3, false, DurabilityDecision::Sync),
            (7, true, DurabilityDecision::Sync),
        ];
        for (step, terminal, want) in expectations {
            assert_eq!(
                every_third.decide(step, terminal),
                want,
                "step {step}, terminal {terminal}"
            );
        }
    }

    /// A closure wrapped in `Custom` receives the step and terminal flag verbatim.
    #[test]
    fn custom_hook_is_invoked() {
        let even_or_last = |step: u64, terminal: bool| {
            if terminal || step % 2 == 0 {
                DurabilityDecision::Sync
            } else {
                DurabilityDecision::Skip
            }
        };
        let policy = Durability::Custom(Arc::new(even_or_last));
        let expectations = [
            (0, false, DurabilityDecision::Sync),
            (1, false, DurabilityDecision::Skip),
            (99, true, DurabilityDecision::Sync),
        ];
        for (step, terminal, want) in expectations {
            assert_eq!(policy.decide(step, terminal), want, "step {step}, terminal {terminal}");
        }
    }
}