use crate::graph::{
error::GraphError,
persistence::{checkpointer::CheckpointerBox, snapshot::StateSnapshot},
state::State,
};
/// How eagerly a graph state snapshot is written to a checkpointer.
///
/// The mode is interpreted by [`save_checkpoint`]; the variant docs below
/// describe what that function does for each one.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DurabilityMode {
    /// `save_checkpoint` performs no write in this mode. The name suggests
    /// persistence is deferred until the graph exits — confirm at the caller.
    Exit,
    /// `save_checkpoint` spawns the write on a background task and returns
    /// immediately; a failed write is only logged.
    Async,
    /// `save_checkpoint` awaits the write and returns any failure to the caller.
    Sync,
}

impl Default for DurabilityMode {
    /// Defaults to [`DurabilityMode::Sync`], the mode in which every
    /// checkpoint write is awaited and write failures reach the caller.
    fn default() -> Self {
        Self::Sync
    }
}

impl DurabilityMode {
    /// Parses a durability mode name; equivalent to `s.parse::<DurabilityMode>()`.
    ///
    /// Kept as an inherent method so existing `DurabilityMode::from_str(..)`
    /// callers keep compiling without importing [`std::str::FromStr`].
    ///
    /// # Errors
    ///
    /// Same as the [`FromStr`](std::str::FromStr) implementation: returns
    /// [`GraphError::ExecutionError`] for any unrecognized name.
    // The trait is implemented below; this shim exists only for backward compatibility.
    #[allow(clippy::should_implement_trait)]
    pub fn from_str(s: &str) -> Result<Self, GraphError> {
        s.parse()
    }
}

impl std::str::FromStr for DurabilityMode {
    type Err = GraphError;

    /// Parses `"exit"`, `"async"` or `"sync"`, ignoring letter case.
    ///
    /// Surrounding whitespace is NOT trimmed, so `" sync"` is rejected.
    ///
    /// # Errors
    ///
    /// Returns [`GraphError::ExecutionError`] naming the offending input and
    /// the accepted values when `s` is not one of the three mode names.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "exit" => Ok(Self::Exit),
            "async" => Ok(Self::Async),
            "sync" => Ok(Self::Sync),
            _ => Err(GraphError::ExecutionError(format!(
                "Invalid durability mode: {}. Must be one of: exit, async, sync",
                s
            ))),
        }
    }
}
pub async fn save_checkpoint<S: State + 'static>(
checkpointer: Option<&CheckpointerBox<S>>,
snapshot: &StateSnapshot<S>,
mode: DurabilityMode,
) -> Result<(), GraphError> {
if let Some(checkpointer) = checkpointer {
match mode {
DurabilityMode::Exit => {
Ok(())
}
DurabilityMode::Async => {
let checkpointer = checkpointer.clone();
let snapshot = snapshot.clone();
let thread_id = snapshot.thread_id().to_string();
tokio::spawn(async move {
if let Err(e) = checkpointer.put(&thread_id, &snapshot).await {
log::error!("Failed to save checkpoint asynchronously: {}", e);
}
});
Ok(())
}
DurabilityMode::Sync => {
checkpointer
.put(snapshot.thread_id(), snapshot)
.await
.map_err(|e| {
GraphError::ExecutionError(format!("Failed to save checkpoint: {}", e))
})?;
Ok(())
}
}
} else {
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every canonical (lowercase) spelling parses to its variant, and an
    /// unknown name is rejected.
    #[test]
    fn test_durability_mode_from_str() {
        let cases = [
            ("exit", DurabilityMode::Exit),
            ("async", DurabilityMode::Async),
            ("sync", DurabilityMode::Sync),
        ];
        for (input, expected) in cases {
            assert_eq!(
                DurabilityMode::from_str(input).unwrap(),
                expected,
                "input: {:?}",
                input
            );
        }
        assert!(DurabilityMode::from_str("invalid").is_err());
    }

    /// Parsing lowercases its input first, so letter case must not matter.
    #[test]
    fn test_durability_mode_from_str_is_case_insensitive() {
        assert_eq!(
            DurabilityMode::from_str("EXIT").unwrap(),
            DurabilityMode::Exit
        );
        assert_eq!(
            DurabilityMode::from_str("Async").unwrap(),
            DurabilityMode::Async
        );
        assert_eq!(
            DurabilityMode::from_str("sYnC").unwrap(),
            DurabilityMode::Sync
        );
    }

    /// Empty and whitespace-padded names are not accepted (no trimming).
    #[test]
    fn test_durability_mode_from_str_rejects_unknown() {
        for input in ["", " sync", "sync ", "synchronous"] {
            assert!(
                DurabilityMode::from_str(input).is_err(),
                "input: {:?}",
                input
            );
        }
    }

    /// The default mode is the one that awaits every write.
    #[test]
    fn test_durability_mode_default_is_sync() {
        assert_eq!(DurabilityMode::default(), DurabilityMode::Sync);
    }
}