use async_trait::async_trait;
use chrono::{DateTime, Utc};
use rustc_hash::FxHashMap;
use std::sync::RwLock;
use crate::{
runtimes::session::SessionState, schedulers::SchedulerState, state::VersionedState,
types::NodeKind,
};
/// Snapshot of a session's execution state, captured at a given step.
///
/// Instances are built by [`Checkpoint::from_session`] or
/// [`Checkpoint::from_step_report`] and can be turned back into a
/// `SessionState` with `restore_session_state`.
#[derive(Debug, Clone)]
pub struct Checkpoint {
    /// Identifier of the session this snapshot belongs to.
    pub session_id: String,
    /// Step counter copied from the session when the snapshot was taken.
    pub step: u64,
    /// Copy of the session's versioned state.
    pub state: VersionedState,
    /// Copy of the session's frontier.
    // NOTE(review): presumably the nodes scheduled to run next — confirm.
    pub frontier: Vec<NodeKind>,
    /// Copy of the scheduler state's `versions_seen` map.
    // NOTE(review): the meaning of the outer/inner keys is not visible here
    // (likely node id -> channel name -> version) — confirm against the scheduler.
    pub versions_seen: FxHashMap<String, FxHashMap<String, u64>>,
    /// Concurrency limit of the session's scheduler; used to rebuild the
    /// scheduler when the checkpoint is restored.
    pub concurrency_limit: usize,
    /// UTC timestamp taken when the checkpoint value was constructed.
    pub created_at: DateTime<Utc>,
    /// Nodes reported as run by the step report; empty when the checkpoint
    /// was built with `from_session`.
    pub ran_nodes: Vec<NodeKind>,
    /// Nodes reported as skipped by the step report; empty when the
    /// checkpoint was built with `from_session`.
    pub skipped_nodes: Vec<NodeKind>,
    /// Names of the channels the step report's barrier outcome lists as
    /// updated; empty when the checkpoint was built with `from_session`.
    pub updated_channels: Vec<String>,
}
impl Checkpoint {
    /// Captures the current state of `session` under the given `session_id`.
    ///
    /// The step counter, state, frontier, scheduler `versions_seen` map and
    /// concurrency limit are copied from the session; `created_at` is set to
    /// the current UTC time. The step-execution fields (`ran_nodes`,
    /// `skipped_nodes`, `updated_channels`) are left empty.
    #[must_use]
    pub fn from_session(session_id: &str, session: &SessionState) -> Self {
        let versions_seen = session.scheduler_state.versions_seen.clone();
        let concurrency_limit = session.scheduler.concurrency_limit;

        Self {
            session_id: session_id.to_owned(),
            step: session.step,
            state: session.state.clone(),
            frontier: session.frontier.clone(),
            versions_seen,
            concurrency_limit,
            created_at: Utc::now(),
            ran_nodes: Vec::new(),
            skipped_nodes: Vec::new(),
            updated_channels: Vec::new(),
        }
    }

    /// Captures `session_state` together with the execution details of
    /// `step_report`.
    ///
    /// All session-derived fields are filled exactly as in
    /// [`Checkpoint::from_session`]; in addition the report's ran/skipped
    /// nodes and the barrier outcome's updated channel names are recorded.
    #[must_use]
    pub fn from_step_report(
        session_id: &str,
        session_state: &SessionState,
        step_report: &crate::runtimes::execution::StepReport,
    ) -> Self {
        // Channel names are converted to owned strings for storage.
        let updated_channels: Vec<String> = step_report
            .barrier_outcome
            .updated_channels
            .iter()
            .map(|name| (*name).to_string())
            .collect();

        Self {
            ran_nodes: step_report.ran_nodes.clone(),
            skipped_nodes: step_report.skipped_nodes.clone(),
            updated_channels,
            // Every remaining field comes from the plain session snapshot.
            ..Self::from_session(session_id, session_state)
        }
    }
}
/// Errors returned by [`Checkpointer`] operations.
///
/// When the `diagnostics` feature is enabled the type also derives
/// `miette::Diagnostic`, which attaches the diagnostic codes and help texts
/// declared on each variant.
#[derive(Debug, thiserror::Error)]
#[cfg_attr(feature = "diagnostics", derive(miette::Diagnostic))]
pub enum CheckpointerError {
/// The requested session is unknown to the checkpointer.
///
/// `session_id` is interpolated into both the error message and the
/// diagnostic help text.
#[error("session not found: {session_id}")]
#[cfg_attr(
feature = "diagnostics",
diagnostic(
code(weavegraph::checkpointer::not_found),
help(
"Ensure the session ID `{session_id}` is correct and the session has been created."
)
)
)]
NotFound { session_id: String },
/// The storage backend reported a failure; `message` carries the
/// backend-provided description.
#[error("backend error: {message}")]
#[cfg_attr(
feature = "diagnostics",
diagnostic(
code(weavegraph::checkpointer::backend),
help("Check backend connectivity and permissions; backend message: {message}.")
)
)]
Backend { message: String },
/// Any other checkpointer failure, described by `message`.
#[error("checkpointer error: {message}")]
#[cfg_attr(
feature = "diagnostics",
diagnostic(code(weavegraph::checkpointer::other))
)]
Other { message: String },
}
/// Identifies a kind of checkpointer backend.
///
/// The `SQLite` and `Postgres` variants only exist when the crate is built
/// with the `sqlite` / `postgres` feature respectively.
// NOTE(review): how a variant is mapped to a concrete `Checkpointer`
// implementation is not visible in this part of the file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CheckpointerType {
/// In-process storage (see `InMemoryCheckpointer`).
InMemory,
/// SQLite-backed storage; requires the `sqlite` feature.
#[cfg(feature = "sqlite")]
#[cfg_attr(docsrs, doc(cfg(feature = "sqlite")))]
SQLite,
/// Postgres-backed storage; requires the `postgres` feature.
#[cfg(feature = "postgres")]
#[cfg_attr(docsrs, doc(cfg(feature = "postgres")))]
Postgres,
}
/// Result alias used throughout this module: `Ok(T)` or a [`CheckpointerError`].
pub type Result<T> = std::result::Result<T, CheckpointerError>;
/// Asynchronous storage interface for session checkpoints.
///
/// Implementations must be `Send + Sync` so they can be shared across
/// threads/tasks.
#[async_trait]
pub trait Checkpointer: Send + Sync {
/// Persists `checkpoint`, taking ownership of it.
async fn save(&self, checkpoint: Checkpoint) -> Result<()>;
/// Loads the most recent checkpoint stored for `session_id`.
///
/// The `Option` allows "no checkpoint" to be reported without an error;
/// the in-memory implementation in this file returns `Ok(None)` for an
/// unknown session.
async fn load_latest(&self, session_id: &str) -> Result<Option<Checkpoint>>;
/// Returns the identifiers of the sessions known to this checkpointer.
// NOTE(review): no ordering guarantee is visible; the in-memory
// implementation returns hash-map key order.
async fn list_sessions(&self) -> Result<Vec<String>>;
}
/// [`Checkpointer`] implementation that keeps checkpoints in process memory.
///
/// Exactly one checkpoint is held per session id: saving a checkpoint for a
/// session that already has one replaces the stored value.
#[derive(Default)]
pub struct InMemoryCheckpointer {
// Session id -> stored checkpoint, guarded by a std `RwLock` so the
// `&self` trait methods can mutate it.
inner: RwLock<FxHashMap<String, Checkpoint>>,
}
impl InMemoryCheckpointer {
#[must_use]
pub fn new() -> Self {
Self {
inner: RwLock::new(FxHashMap::default()),
}
}
}
#[async_trait]
impl Checkpointer for InMemoryCheckpointer {
#[tracing::instrument(skip(self), fields(session_id = %checkpoint.session_id, step = checkpoint.step))]
async fn save(&self, checkpoint: Checkpoint) -> Result<()> {
let mut map = self
.inner
.write()
.expect("InMemoryCheckpointer RwLock poisoned");
map.insert(checkpoint.session_id.clone(), checkpoint);
Ok(())
}
#[tracing::instrument(skip(self), fields(session_id = %session_id))]
async fn load_latest(&self, session_id: &str) -> Result<Option<Checkpoint>> {
let map = self
.inner
.read()
.expect("InMemoryCheckpointer RwLock poisoned");
Ok(map.get(session_id).cloned())
}
#[tracing::instrument(skip(self))]
async fn list_sessions(&self) -> Result<Vec<String>> {
let map = self
.inner
.read()
.expect("InMemoryCheckpointer RwLock poisoned");
Ok(map.keys().cloned().collect())
}
}
#[must_use = "restored session state should be used to continue execution"]
pub fn restore_session_state(cp: &Checkpoint) -> SessionState {
use crate::schedulers::Scheduler;
SessionState {
state: cp.state.clone(),
step: cp.step,
frontier: cp.frontier.clone(),
scheduler: Scheduler::new(cp.concurrency_limit),
scheduler_state: SchedulerState {
versions_seen: cp.versions_seen.clone(),
},
}
}