use std::collections::HashMap;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use blazen_events::InputRequestEvent;
use crate::error::WorkflowError;
use crate::value::StateValue;
/// A pending event held in a [`WorkflowSnapshot`], with its payload kept as
/// untyped JSON alongside a string type identifier.
///
/// When the `persist` feature converts a snapshot into a checkpoint, only
/// `event_type` and `data` map onto the checkpoint's event type; `source_step`
/// travels separately in the checkpoint's metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedEvent {
/// Event type identifier (this module's tests use `"blazen::StartEvent"`).
pub event_type: String,
/// The event payload as a JSON value.
pub data: serde_json::Value,
/// Name of the step associated with this event, if recorded — presumably the
/// step that emitted it (TODO confirm against the producer).
pub source_step: Option<String>,
}
/// A serializable capture of a workflow run: its identity, context state,
/// collected and pending events, and free-form metadata.
///
/// Can be written and read back as JSON (`to_json` / `from_json`) or
/// MessagePack (`to_msgpack` / `from_msgpack`).
///
/// NOTE(review): `to_msgpack` uses `rmp_serde::to_vec`, which per the
/// rmp_serde docs encodes structs positionally (as arrays), so the field order
/// below is likely part of the binary format — confirm before reordering.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowSnapshot {
/// Name of the workflow this snapshot belongs to.
pub workflow_name: String,
/// Identifier of the captured run.
pub run_id: Uuid,
/// Timestamp associated with the snapshot, in UTC.
pub timestamp: DateTime<Utc>,
/// Context state entries, keyed by name.
pub context_state: HashMap<String, StateValue>,
/// Collected event payloads as raw JSON, grouped by a string key — presumably
/// the event type (tests use `"blazen::StartEvent"`); TODO confirm.
pub collected_events: HashMap<String, Vec<serde_json::Value>>,
/// Pending events captured with the snapshot, in order.
pub pending_events: Vec<SerializedEvent>,
/// Free-form metadata. The `__input_request` key is read by
/// `WorkflowSnapshot::input_request`. With the `persist` feature, the keys
/// `__blazen_collected_events` and `__blazen_pending_source_steps` are
/// reserved for the checkpoint conversion.
pub metadata: HashMap<String, serde_json::Value>,
/// Telemetry history; this field exists only with the `telemetry` feature.
/// `#[serde(default)]` lets data serialized without it still deserialize,
/// yielding an empty history.
#[cfg(feature = "telemetry")]
#[serde(default)]
pub history: Vec<blazen_telemetry::HistoryEvent>,
}
impl WorkflowSnapshot {
pub fn to_json(&self) -> Result<String, WorkflowError> {
serde_json::to_string(self).map_err(WorkflowError::Serialization)
}
pub fn to_json_pretty(&self) -> Result<String, WorkflowError> {
serde_json::to_string_pretty(self).map_err(WorkflowError::Serialization)
}
pub fn from_json(json: &str) -> Result<Self, WorkflowError> {
serde_json::from_str(json).map_err(WorkflowError::Serialization)
}
pub fn to_msgpack(&self) -> Result<Vec<u8>, WorkflowError> {
rmp_serde::to_vec(self).map_err(|e| WorkflowError::BinarySerialization(e.to_string()))
}
pub fn from_msgpack(bytes: &[u8]) -> Result<Self, WorkflowError> {
rmp_serde::from_slice(bytes).map_err(|e| WorkflowError::BinarySerialization(e.to_string()))
}
#[must_use]
pub fn input_request(&self) -> Option<InputRequestEvent> {
self.metadata
.get("__input_request")
.and_then(|v| serde_json::from_value(v.clone()).ok())
}
}
/// Checkpoint metadata key under which the snapshot → checkpoint conversion
/// stores `WorkflowSnapshot::collected_events`, because
/// `blazen_persist::WorkflowCheckpoint` has no field for it. The reverse
/// conversion removes this key again.
#[cfg(feature = "persist")]
const COLLECTED_EVENTS_META_KEY: &str = "__blazen_collected_events";
/// Checkpoint metadata key holding the pending events' `source_step` values as
/// a list aligned by index with `pending_events`, because
/// `blazen_persist::SerializedEvent` has no field for it.
#[cfg(feature = "persist")]
const SOURCE_STEPS_META_KEY: &str = "__blazen_pending_source_steps";
#[cfg(feature = "persist")]
impl From<WorkflowSnapshot> for blazen_persist::WorkflowCheckpoint {
fn from(snap: WorkflowSnapshot) -> Self {
let mut metadata = snap.metadata;
if !snap.collected_events.is_empty()
&& let Ok(val) = serde_json::to_value(&snap.collected_events)
{
metadata.insert(COLLECTED_EVENTS_META_KEY.to_owned(), val);
}
let source_steps: Vec<Option<String>> = snap
.pending_events
.iter()
.map(|e| e.source_step.clone())
.collect();
if source_steps.iter().any(Option::is_some)
&& let Ok(val) = serde_json::to_value(&source_steps)
{
metadata.insert(SOURCE_STEPS_META_KEY.to_owned(), val);
}
let pending_events = snap
.pending_events
.into_iter()
.map(|e| blazen_persist::SerializedEvent {
event_type: e.event_type,
data: e.data,
})
.collect();
let state = snap
.context_state
.into_iter()
.map(|(k, v)| {
let json = serde_json::to_value(&v).unwrap_or(serde_json::Value::Null);
(k, json)
})
.collect();
blazen_persist::WorkflowCheckpoint {
workflow_name: snap.workflow_name,
run_id: snap.run_id,
timestamp: snap.timestamp,
state,
pending_events,
metadata,
}
}
}
#[cfg(feature = "persist")]
impl From<blazen_persist::WorkflowCheckpoint> for WorkflowSnapshot {
    /// Rebuilds a snapshot from a persisted checkpoint.
    ///
    /// Undoes the extra encoding done by the snapshot → checkpoint conversion:
    /// - The reserved `__blazen_collected_events` and
    ///   `__blazen_pending_source_steps` entries are removed from `metadata`.
    /// - They are decoded back into `collected_events` and each pending event's
    ///   `source_step`. Missing or undecodable entries fall back to an empty map
    ///   and `None` source steps.
    /// - State values that do not decode as a [`StateValue`] are kept verbatim
    ///   as [`StateValue::Json`].
    /// - The telemetry `history`, which checkpoints do not carry, starts empty.
    fn from(cp: blazen_persist::WorkflowCheckpoint) -> Self {
        let mut metadata = cp.metadata;

        let collected_events = metadata
            .remove(COLLECTED_EVENTS_META_KEY)
            .and_then(|val| {
                serde_json::from_value::<HashMap<String, Vec<serde_json::Value>>>(val).ok()
            })
            .unwrap_or_default();

        let source_steps: Vec<Option<String>> = metadata
            .remove(SOURCE_STEPS_META_KEY)
            .and_then(|val| serde_json::from_value(val).ok())
            .unwrap_or_default();
        // Source steps line up with pending events by position. A shorter (or
        // absent) list leaves the remaining events with `None`. Consuming the
        // list moves each step out instead of cloning it.
        let mut source_steps = source_steps.into_iter();

        let pending_events = cp
            .pending_events
            .into_iter()
            .map(|e| SerializedEvent {
                event_type: e.event_type,
                data: e.data,
                source_step: source_steps.next().flatten(),
            })
            .collect();

        let context_state = cp
            .state
            .into_iter()
            .map(|(k, v)| {
                // Decode from a borrow so the fallback can take ownership of
                // `v`, instead of deep-cloning every value up front.
                let sv = match StateValue::deserialize(&v) {
                    Ok(sv) => sv,
                    Err(_) => StateValue::Json(v),
                };
                (k, sv)
            })
            .collect();

        WorkflowSnapshot {
            workflow_name: cp.workflow_name,
            run_id: cp.run_id,
            timestamp: cp.timestamp,
            context_state,
            collected_events,
            pending_events,
            metadata,
            #[cfg(feature = "telemetry")]
            history: Vec::new(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a snapshot that populates every field:
    /// - two state entries,
    /// - one collected event,
    /// - one pending event that has a source step,
    /// - two metadata entries.
    fn sample_snapshot() -> WorkflowSnapshot {
        let mut state = HashMap::new();
        state.insert(
            "counter".to_owned(),
            StateValue::Json(serde_json::json!(42)),
        );
        state.insert(
            "name".to_owned(),
            StateValue::Json(serde_json::json!("alice")),
        );
        let mut collected = HashMap::new();
        collected.insert(
            "blazen::StartEvent".to_owned(),
            vec![serde_json::json!({"data": 1})],
        );
        let mut metadata = HashMap::new();
        let run_id = Uuid::new_v4();
        metadata.insert(
            "run_id".to_owned(),
            serde_json::Value::String(run_id.to_string()),
        );
        metadata.insert(
            "workflow_name".to_owned(),
            serde_json::Value::String("test_wf".to_owned()),
        );
        WorkflowSnapshot {
            workflow_name: "test_wf".to_owned(),
            run_id,
            timestamp: Utc::now(),
            context_state: state,
            collected_events: collected,
            pending_events: vec![SerializedEvent {
                event_type: "blazen::StartEvent".to_owned(),
                data: serde_json::json!({"data": "hello"}),
                source_step: Some("step_a".to_owned()),
            }],
            metadata,
            #[cfg(feature = "telemetry")]
            history: Vec::new(),
        }
    }

    /// Asserts that every comparable field (all except `timestamp` and
    /// `history`) survived a round trip. `SerializedEvent` has no
    /// `PartialEq`, so pending events are compared field by field.
    fn assert_roundtrip_eq(restored: &WorkflowSnapshot, original: &WorkflowSnapshot) {
        assert_eq!(restored.workflow_name, original.workflow_name);
        assert_eq!(restored.run_id, original.run_id);
        assert_eq!(restored.context_state, original.context_state);
        assert_eq!(restored.collected_events, original.collected_events);
        assert_eq!(restored.metadata, original.metadata);
        assert_eq!(restored.pending_events.len(), original.pending_events.len());
        for (got, want) in restored.pending_events.iter().zip(&original.pending_events) {
            assert_eq!(got.event_type, want.event_type);
            assert_eq!(got.data, want.data);
            assert_eq!(got.source_step, want.source_step);
        }
    }

    #[test]
    fn json_roundtrip() {
        let snap = sample_snapshot();
        let json = snap.to_json().unwrap();
        let restored = WorkflowSnapshot::from_json(&json).unwrap();
        assert_roundtrip_eq(&restored, &snap);
    }

    #[test]
    fn pretty_json_roundtrip() {
        let snap = sample_snapshot();
        let json = snap.to_json_pretty().unwrap();
        let restored = WorkflowSnapshot::from_json(&json).unwrap();
        assert_roundtrip_eq(&restored, &snap);
    }

    #[test]
    fn from_invalid_json_fails() {
        let result = WorkflowSnapshot::from_json("not valid json");
        assert!(result.is_err());
    }

    #[test]
    fn msgpack_roundtrip() {
        let snap = sample_snapshot();
        let bytes = snap.to_msgpack().unwrap();
        let restored = WorkflowSnapshot::from_msgpack(&bytes).unwrap();
        assert_roundtrip_eq(&restored, &snap);
    }

    #[test]
    fn msgpack_with_bytes_roundtrip() {
        use crate::value::BytesWrapper;
        let mut state = HashMap::new();
        state.insert(
            "data".to_owned(),
            StateValue::Bytes(BytesWrapper(vec![0xDE, 0xAD, 0xBE, 0xEF])),
        );
        state.insert("count".to_owned(), StateValue::Json(serde_json::json!(42)));
        let snap = WorkflowSnapshot {
            workflow_name: "bytes_test".to_owned(),
            run_id: Uuid::new_v4(),
            timestamp: Utc::now(),
            context_state: state,
            collected_events: HashMap::new(),
            pending_events: Vec::new(),
            metadata: HashMap::new(),
            #[cfg(feature = "telemetry")]
            history: Vec::new(),
        };
        let bytes = snap.to_msgpack().unwrap();
        let restored = WorkflowSnapshot::from_msgpack(&bytes).unwrap();
        assert_eq!(restored.context_state, snap.context_state);
        assert_eq!(
            restored
                .context_state
                .get("data")
                .unwrap()
                .as_bytes()
                .unwrap(),
            &[0xDE, 0xAD, 0xBE, 0xEF]
        );
    }

    #[test]
    fn from_invalid_msgpack_fails() {
        let result = WorkflowSnapshot::from_msgpack(&[0xFF, 0xFF]);
        assert!(result.is_err());
    }

    #[test]
    fn input_request_absent_is_none() {
        assert!(sample_snapshot().input_request().is_none());
    }

    #[test]
    fn input_request_malformed_is_none() {
        let mut snap = sample_snapshot();
        snap.metadata.insert(
            "__input_request".to_owned(),
            serde_json::json!("not an input request"),
        );
        assert!(snap.input_request().is_none());
    }

    #[cfg(feature = "persist")]
    #[test]
    fn persist_checkpoint_roundtrip() {
        let snap = sample_snapshot();
        let checkpoint = blazen_persist::WorkflowCheckpoint::from(snap.clone());
        // Both side-channel entries must be written for this sample...
        assert!(checkpoint.metadata.contains_key(COLLECTED_EVENTS_META_KEY));
        assert!(checkpoint.metadata.contains_key(SOURCE_STEPS_META_KEY));
        let restored = WorkflowSnapshot::from(checkpoint);
        // ...and stripped again on restore, so metadata compares equal.
        assert_eq!(restored.timestamp, snap.timestamp);
        assert_roundtrip_eq(&restored, &snap);
    }
}