use chrono::Utc;
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::{
channels::{Channel, ExtrasChannel, MessagesChannel},
message::Message,
runtimes::checkpointer::Checkpoint,
state::VersionedState,
types::NodeKind,
utils::json_ext::JsonSerializable,
};
/// Blanket JSON (de)serialization for every type that is both serializable and
/// deserializable into an owned value, with failures reported as
/// [`PersistenceError::Serde`].
impl<T> JsonSerializable<PersistenceError> for T
where
    T: serde::Serialize + serde::de::DeserializeOwned,
{
    /// Encodes `self` as a JSON string using `serde_json::to_string`.
    ///
    /// # Errors
    /// Returns [`PersistenceError::Serde`] wrapping the `serde_json` error.
    fn to_json_string(&self) -> std::result::Result<String, PersistenceError> {
        let encoded = serde_json::to_string(self);
        encoded.map_err(|source| PersistenceError::Serde { source })
    }

    /// Decodes a value of this type from the JSON text `s` using `serde_json::from_str`.
    ///
    /// # Errors
    /// Returns [`PersistenceError::Serde`] wrapping the `serde_json` error.
    fn from_json_str(s: &str) -> std::result::Result<Self, PersistenceError> {
        serde_json::from_str::<Self>(s).map_err(|source| PersistenceError::Serde { source })
    }
}
/// Serializable form of a list-backed channel: a version number plus its items.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct PersistedVecChannel<T> {
    /// Version recorded for the channel. This key is required when deserializing.
    pub version: u32,
    /// Channel contents. A missing `items` key deserializes to an empty vector.
    #[serde(default)]
    pub items: Vec<T>,
}
/// The default list channel is empty and starts at version `1`.
///
/// Written by hand (rather than derived) so that no `T: Default` bound is needed.
impl<T> Default for PersistedVecChannel<T> {
    fn default() -> Self {
        let items = Vec::new();
        Self { version: 1, items }
    }
}
/// Serializable form of a map-backed channel: a version number plus string-keyed entries.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct PersistedMapChannel<V> {
    /// Version recorded for the channel. This key is required when deserializing.
    pub version: u32,
    /// Channel contents keyed by string. A missing `map` key deserializes to an empty map.
    #[serde(default)]
    pub map: FxHashMap<String, V>,
}
/// The default map channel is empty and starts at version `1`.
///
/// Written by hand (rather than derived) so that no `V: Default` bound is needed.
impl<V> Default for PersistedMapChannel<V> {
    fn default() -> Self {
        let map = FxHashMap::default();
        Self { version: 1, map }
    }
}
/// Serializable form of a `VersionedState`: one persisted channel per state channel.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct PersistedState {
    /// Persisted messages channel.
    pub messages: PersistedVecChannel<Message>,
    /// Persisted extras channel holding arbitrary JSON values.
    pub extra: PersistedMapChannel<Value>,
    /// Persisted errors channel. When the key is missing from the input, this
    /// falls back to `PersistedVecChannel::default()` (version `1`, no items).
    #[serde(default)]
    pub errors: PersistedVecChannel<crate::channels::errors::ErrorEvent>,
}
/// Newtype around the nested `String -> (String -> u64)` map stored in a
/// checkpoint's `versions_seen` field, so it can be (de)serialized as its own type.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct PersistedVersionsSeen(pub FxHashMap<String, FxHashMap<String, u64>>);
/// Serializable form of a runtime `Checkpoint`.
///
/// Node kinds are stored as their encoded strings and the creation time as an
/// RFC 3339 string; see the `From<&Checkpoint>` conversion in this file.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct PersistedCheckpoint {
    /// Session identifier copied from the checkpoint.
    pub session_id: String,
    /// Step counter copied from the checkpoint.
    pub step: u64,
    /// Persisted channel state.
    pub state: PersistedState,
    /// Encoded node kinds of the checkpoint's frontier.
    pub frontier: Vec<String>,
    /// Nested version-tracking map copied from the checkpoint.
    pub versions_seen: PersistedVersionsSeen,
    /// Concurrency limit copied from the checkpoint.
    pub concurrency_limit: usize,
    /// Creation timestamp, written as an RFC 3339 string.
    pub created_at: String,
    /// Encoded node kinds from the checkpoint's `ran_nodes`; empty when the key is missing.
    #[serde(default)]
    pub ran_nodes: Vec<String>,
    /// Encoded node kinds from the checkpoint's `skipped_nodes`; empty when the key is missing.
    #[serde(default)]
    pub skipped_nodes: Vec<String>,
    /// Channel names from the checkpoint's `updated_channels`; empty when the key is missing.
    #[serde(default)]
    pub updated_channels: Vec<String>,
}
use thiserror::Error;
/// Errors reported by the persistence conversions and JSON helpers in this module.
///
/// With the `diagnostics` feature enabled, the type also derives
/// `miette::Diagnostic`, attaching a code (and, where given, help text) to each variant.
#[derive(Debug, Error)]
#[cfg_attr(feature = "diagnostics", derive(miette::Diagnostic))]
pub enum PersistenceError {
    /// A required field was missing; the payload is rendered as the field name.
    #[error("missing field: {0}")]
    #[cfg_attr(
        feature = "diagnostics",
        diagnostic(
            code(weavegraph::persistence::missing_field),
            help("Populate the field in the persisted JSON before conversion.")
        )
    )]
    MissingField(&'static str),

    /// `serde_json` failed to serialize or deserialize a value.
    #[error("JSON serialization/deserialization failed: {source}")]
    #[cfg_attr(
        feature = "diagnostics",
        diagnostic(
            code(weavegraph::persistence::serde),
            help("Ensure the JSON structure matches Persisted* types; serde error: {source}.")
        )
    )]
    Serde {
        /// The underlying `serde_json` error, exposed as the error source.
        #[source]
        source: serde_json::Error,
    },

    /// Any other persistence failure, described by a free-form message.
    #[error("persistence error: {0}")]
    #[cfg_attr(
        feature = "diagnostics",
        diagnostic(code(weavegraph::persistence::other))
    )]
    Other(String),
}
/// Result alias for this module whose error type is fixed to [`PersistenceError`].
pub type Result<T> = std::result::Result<T, PersistenceError>;
/// Captures the version and a snapshot of every channel in a `VersionedState`.
impl From<&VersionedState> for PersistedState {
    fn from(s: &VersionedState) -> Self {
        // Channels are read in the same order as the struct's fields.
        let messages = PersistedVecChannel {
            version: s.messages.version(),
            items: s.messages.snapshot(),
        };
        let extra = PersistedMapChannel {
            version: s.extra.version(),
            map: s.extra.snapshot(),
        };
        let errors = PersistedVecChannel {
            version: s.errors.version(),
            items: s.errors.snapshot(),
        };

        Self {
            messages,
            extra,
            errors,
        }
    }
}
impl TryFrom<PersistedState> for VersionedState {
type Error = PersistenceError;
fn try_from(p: PersistedState) -> Result<Self> {
Ok(VersionedState {
messages: MessagesChannel::new(p.messages.items, p.messages.version),
extra: ExtrasChannel::new(p.extra.map, p.extra.version),
errors: crate::channels::ErrorsChannel::new(p.errors.items, p.errors.version),
})
}
}
impl From<&FxHashMap<String, FxHashMap<String, u64>>> for PersistedVersionsSeen {
fn from(v: &FxHashMap<String, FxHashMap<String, u64>>) -> Self {
PersistedVersionsSeen(v.clone())
}
}
impl From<PersistedVersionsSeen> for FxHashMap<String, FxHashMap<String, u64>> {
fn from(p: PersistedVersionsSeen) -> Self {
p.0
}
}
impl From<&Checkpoint> for PersistedCheckpoint {
fn from(cp: &Checkpoint) -> Self {
PersistedCheckpoint {
session_id: cp.session_id.clone(),
step: cp.step,
state: PersistedState::from(&cp.state),
frontier: cp.frontier.iter().map(|k| k.encode()).collect(),
versions_seen: PersistedVersionsSeen(cp.versions_seen.clone()),
concurrency_limit: cp.concurrency_limit,
created_at: cp.created_at.to_rfc3339(),
ran_nodes: cp.ran_nodes.iter().map(|k| k.encode()).collect(),
skipped_nodes: cp.skipped_nodes.iter().map(|k| k.encode()).collect(),
updated_channels: cp.updated_channels.clone(),
}
}
}
/// Rebuilds a runtime `Checkpoint` from its persisted form.
///
/// # Errors
/// Propagates any error from converting the embedded [`PersistedState`].
///
/// # Timestamp handling
/// If `created_at` is not a valid RFC 3339 string, the current UTC time is used
/// in its place rather than returning an error.
impl TryFrom<PersistedCheckpoint> for Checkpoint {
    type Error = PersistenceError;

    fn try_from(p: PersistedCheckpoint) -> Result<Self> {
        // Decodes every persisted string back into a `NodeKind`.
        fn decode_all(encoded: &[String]) -> Vec<NodeKind> {
            encoded.iter().map(|s| NodeKind::decode(s)).collect()
        }

        let state = VersionedState::try_from(p.state)?;
        let frontier = decode_all(&p.frontier);
        let ran_nodes = decode_all(&p.ran_nodes);
        let skipped_nodes = decode_all(&p.skipped_nodes);

        // Best-effort: an unparsable timestamp is replaced with "now" instead of failing.
        let created_at = match chrono::DateTime::parse_from_rfc3339(&p.created_at) {
            Ok(parsed) => parsed.with_timezone(&Utc),
            Err(_) => Utc::now(),
        };

        Ok(Checkpoint {
            session_id: p.session_id,
            step: p.step,
            state,
            frontier,
            versions_seen: p.versions_seen.into(),
            concurrency_limit: p.concurrency_limit,
            created_at,
            ran_nodes,
            skipped_nodes,
            updated_channels: p.updated_channels,
        })
    }
}