use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use blazen_events::{AnyEvent, DynamicEvent, Event, EventEnvelope, StartEvent};
use serde::Serialize;
use tokio::sync::{broadcast, mpsc, oneshot};
use uuid::Uuid;
use crate::builder::InputHandlerFn;
use crate::context::Context;
use crate::error::WorkflowError;
#[cfg(feature = "persist")]
use crate::event_loop::CheckpointConfig;
use crate::event_loop::event_loop;
#[cfg(feature = "distributed")]
use crate::handler::WorkflowResult;
use crate::handler::{WorkflowControl, WorkflowHandler};
#[cfg(feature = "distributed")]
use crate::session_ref::RemoteRefDescriptor;
use crate::session_ref::{
RegistryKey, SERIALIZED_SESSION_REFS_META_KEY, SessionRefError, SessionRefRegistry,
SessionRefSerializable,
};
use crate::snapshot::WorkflowSnapshot;
use crate::step::StepRegistration;
/// Function pointer that rebuilds a session ref from its serialized bytes.
///
/// One deserializer is registered per `type_tag`; used by
/// [`Workflow::resume_with_deserializers`] to rehydrate refs that were
/// serialized into snapshot metadata.
pub type SessionRefDeserializerFn =
    fn(&[u8]) -> Result<Arc<dyn SessionRefSerializable>, SessionRefError>;
/// A fully built workflow definition.
///
/// Constructed by the builder (see [`crate::builder`]); the `run*` methods
/// spawn the event loop on a tokio task and hand back a `WorkflowHandler`
/// for observing and controlling the run.
pub struct Workflow {
    /// Workflow name; recorded in run metadata under `"workflow_name"`.
    pub(crate) name: String,
    /// Maps an event type id to the step registrations that accept it.
    pub(crate) step_registry: HashMap<String, Vec<StepRegistration>>,
    /// Overall run timeout forwarded to the event loop; `None` = unbounded.
    pub(crate) timeout: Option<Duration>,
    /// Optional input handler forwarded to the event loop.
    pub(crate) input_handler: Option<InputHandlerFn>,
    /// Forwarded to the event loop as its `auto_publish` flag
    /// (presumably controls automatic event publication — see `event_loop`).
    pub(crate) auto_publish_events: bool,
    /// Policy applied to the run context before the loop starts.
    pub(crate) session_pause_policy: crate::session_ref::SessionPausePolicy,
    /// Optional checkpoint store used to persist run state.
    #[cfg(feature = "persist")]
    pub(crate) checkpoint_store: Option<Arc<dyn blazen_persist::CheckpointStore>>,
    /// When true, checkpointing is configured per-step (`CheckpointConfig::after_step`).
    #[cfg(feature = "persist")]
    pub(crate) checkpoint_after_step: bool,
    /// When true, a history channel is created and handed to the event loop.
    #[cfg(feature = "telemetry")]
    pub(crate) collect_history: bool,
}
impl std::fmt::Debug for Workflow {
    /// Manual `Debug`: the registry holds non-`Debug` trait objects, so only
    /// summary fields are rendered (name, step count, timeout).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("Workflow");
        dbg.field("name", &self.name);
        dbg.field("step_count", &self.step_registry.len());
        dbg.field("timeout", &self.timeout);
        dbg.finish_non_exhaustive()
    }
}
impl Workflow {
/// Starts the workflow with `input` wrapped in a [`StartEvent`].
pub async fn run(&self, input: serde_json::Value) -> crate::error::Result<WorkflowHandler> {
    // Delegate to the event-based entry point with the canonical start event.
    self.run_with_event(StartEvent { data: input }).await
}
/// Starts the workflow from an arbitrary serializable start event.
pub async fn run_with_event<E: Event + Serialize>(
    &self,
    start_event: E,
) -> crate::error::Result<WorkflowHandler> {
    // No caller-supplied registry: the context creates a fresh one.
    self.run_with_event_and_session_refs(start_event, None).await
}
/// Starts the workflow with `input` and a pre-populated session-ref registry.
pub async fn run_with_registry(
    &self,
    input: serde_json::Value,
    session_refs: Arc<SessionRefRegistry>,
) -> crate::error::Result<WorkflowHandler> {
    let event = StartEvent { data: input };
    self.run_with_event_and_session_refs(event, Some(session_refs))
        .await
}
/// Core run path shared by all `run*` entry points.
///
/// Wires up the channels, builds the run context, seeds the start event,
/// spawns the event loop, and returns a handler for the caller.
pub(crate) async fn run_with_event_and_session_refs<E: Event + Serialize>(
    &self,
    start_event: E,
    session_refs: Option<Arc<SessionRefRegistry>>,
) -> crate::error::Result<WorkflowHandler> {
    // event_tx/rx: step-emitted events into the loop.
    // stream_tx: broadcast of events to external subscribers (capacity 256).
    // result_tx/rx: one-shot final result; control_tx/rx: pause/abort etc.
    let (event_tx, event_rx) = mpsc::unbounded_channel::<EventEnvelope>();
    let (stream_tx, _stream_rx) = broadcast::channel::<Box<dyn AnyEvent>>(256);
    let (result_tx, result_rx) = oneshot::channel();
    let (control_tx, control_rx) = mpsc::unbounded_channel::<WorkflowControl>();
    // A caller-supplied registry is attached to the context; otherwise the
    // context creates its own.
    let ctx = match session_refs {
        Some(refs) => Context::new_with_session_refs(event_tx.clone(), stream_tx.clone(), refs),
        None => Context::new(event_tx.clone(), stream_tx.clone()),
    };
    ctx.set_session_pause_policy(self.session_pause_policy)
        .await;
    // Re-fetch the registry Arc so the handler shares the one the ctx holds.
    let session_refs = ctx.session_refs_arc().await;
    let run_id = Uuid::new_v4();
    // Run identity is stored in context metadata for snapshots/telemetry.
    ctx.set_metadata("run_id", serde_json::Value::String(run_id.to_string()))
        .await;
    ctx.set_metadata(
        "workflow_name",
        serde_json::Value::String(self.name.clone()),
    )
    .await;
    // Seed the loop with the start event before it is spawned; the unbounded
    // channel buffers it until the loop begins draining.
    let envelope = EventEnvelope::new(Box::new(start_event), None);
    event_tx
        .send(envelope)
        .map_err(|_| WorkflowError::ChannelClosed)?;
    // History channel is only created when collection was requested.
    #[cfg(feature = "telemetry")]
    let (history_tx, history_rx) = if self.collect_history {
        let (tx, rx) = mpsc::unbounded_channel();
        (Some(tx), Some(rx))
    } else {
        (None, None)
    };
    // Clone everything the spawned task needs (it must be 'static).
    let registry = self.step_registry.clone();
    let timeout = self.timeout;
    let workflow_name = self.name.clone();
    let input_handler = self.input_handler.clone();
    let auto_publish = self.auto_publish_events;
    #[cfg(feature = "persist")]
    let checkpoint_config = CheckpointConfig {
        store: self.checkpoint_store.clone(),
        after_step: self.checkpoint_after_step,
    };
    // Feature-gated arguments must line up with event_loop's cfg'd parameters.
    let event_loop_handle = tokio::spawn(event_loop(
        event_rx,
        event_tx,
        registry,
        ctx,
        result_tx,
        timeout,
        control_rx,
        workflow_name,
        run_id,
        input_handler,
        auto_publish,
        #[cfg(feature = "persist")]
        checkpoint_config,
        #[cfg(feature = "telemetry")]
        history_tx,
    ));
    Ok(WorkflowHandler::new(
        result_rx,
        stream_tx,
        control_tx,
        event_loop_handle,
        session_refs,
        #[cfg(feature = "telemetry")]
        history_rx,
    ))
}
/// Resumes a snapshotted run with no session-ref deserializers registered.
pub async fn resume(
    snapshot: WorkflowSnapshot,
    steps: Vec<StepRegistration>,
    timeout: Option<Duration>,
) -> crate::error::Result<WorkflowHandler> {
    let no_deserializers = HashMap::new();
    Self::resume_inner(snapshot, steps, no_deserializers, timeout).await
}
/// Resumes a snapshotted run, rehydrating serialized session refs via the
/// supplied per-`type_tag` deserializers.
pub async fn resume_with_deserializers(
    snapshot: WorkflowSnapshot,
    steps: Vec<StepRegistration>,
    deserializers: HashMap<&'static str, SessionRefDeserializerFn>,
    timeout: Option<Duration>,
) -> crate::error::Result<WorkflowHandler> {
    let fut = Self::resume_inner(snapshot, steps, deserializers, timeout);
    fut.await
}
/// Shared resume path: rebuilds the step registry, restores context state
/// from the snapshot, optionally rehydrates session refs, replays pending
/// events, and spawns a fresh event loop.
async fn resume_inner(
    snapshot: WorkflowSnapshot,
    steps: Vec<StepRegistration>,
    deserializers: HashMap<&'static str, SessionRefDeserializerFn>,
    timeout: Option<Duration>,
) -> crate::error::Result<WorkflowHandler> {
    // Rebuild the event-type -> steps index from the caller's registrations.
    let mut registry: HashMap<String, Vec<StepRegistration>> = HashMap::new();
    for step in steps {
        for &event_type in &step.accepts {
            registry
                .entry(event_type.to_owned())
                .or_default()
                .push(step.clone());
        }
    }
    let (event_tx, event_rx) = mpsc::unbounded_channel::<EventEnvelope>();
    let (stream_tx, _stream_rx) = broadcast::channel::<Box<dyn AnyEvent>>(256);
    let (result_tx, result_rx) = oneshot::channel();
    let (control_tx, control_rx) = mpsc::unbounded_channel::<WorkflowControl>();
    // Restore state/collected-events/metadata before anything runs.
    let ctx = Context::new(event_tx.clone(), stream_tx.clone());
    ctx.restore_state(snapshot.context_state).await;
    ctx.restore_collected(snapshot.collected_events).await;
    ctx.restore_metadata(snapshot.metadata).await;
    let session_refs = ctx.session_refs_arc().await;
    // Rehydration only happens when serialized refs exist in metadata AND the
    // caller supplied at least one deserializer (plain `resume` supplies none).
    if let Some(meta) = ctx
        .snapshot_metadata()
        .await
        .get(SERIALIZED_SESSION_REFS_META_KEY)
        && !deserializers.is_empty()
    {
        rehydrate_serialized_session_refs(&session_refs, meta, &deserializers).await?;
    }
    // Re-queue pending events; unknown event types fall back to DynamicEvent
    // so a resume with a reduced type set still replays everything.
    for serialized in &snapshot.pending_events {
        let event: Box<dyn AnyEvent> =
            blazen_events::try_deserialize_event(&serialized.event_type, &serialized.data)
                .unwrap_or_else(|| {
                    Box::new(DynamicEvent {
                        event_type: serialized.event_type.clone(),
                        data: serialized.data.clone(),
                    })
                });
        let envelope = EventEnvelope::new(event, serialized.source_step.clone());
        event_tx
            .send(envelope)
            .map_err(|_| WorkflowError::ChannelClosed)?;
    }
    // Identity is carried over from the snapshot, not regenerated.
    let workflow_name = snapshot.workflow_name;
    let run_id = snapshot.run_id;
    // Resumed runs do not collect history and do not re-checkpoint.
    #[cfg(feature = "telemetry")]
    let history_tx: Option<mpsc::UnboundedSender<blazen_telemetry::HistoryEvent>> = None;
    #[cfg(feature = "persist")]
    let checkpoint_config = CheckpointConfig {
        store: None,
        after_step: false,
    };
    let event_loop_handle = tokio::spawn(event_loop(
        event_rx,
        event_tx,
        registry,
        ctx,
        result_tx,
        timeout,
        control_rx,
        workflow_name,
        run_id,
        // No input handler and no auto-publish on resume.
        None,
        false,
        #[cfg(feature = "persist")]
        checkpoint_config,
        #[cfg(feature = "telemetry")]
        history_tx,
    ));
    Ok(WorkflowHandler::new(
        result_rx,
        stream_tx,
        control_tx,
        event_loop_handle,
        session_refs,
        // Resumed runs expose no history receiver.
        #[cfg(feature = "telemetry")]
        None,
    ))
}
/// Loads the checkpoint for `run_id` from `store` and resumes it with a
/// default timeout.
///
/// # Errors
/// Returns `WorkflowError::Context` when the store fails to load, or when
/// no checkpoint exists for the given `run_id`.
#[cfg(feature = "persist")]
pub async fn resume_from(
    store: Arc<dyn blazen_persist::CheckpointStore>,
    run_id: &Uuid,
    steps: Vec<StepRegistration>,
) -> crate::error::Result<WorkflowHandler> {
    // Named constant replacing an undocumented magic number; callers needing
    // a different timeout should use `Workflow::resume` directly.
    const DEFAULT_RESUME_TIMEOUT: Duration = Duration::from_secs(300);
    let checkpoint = store
        .load(run_id)
        .await
        .map_err(|e| WorkflowError::Context(format!("checkpoint load failed: {e}")))?
        .ok_or_else(|| {
            WorkflowError::Context(format!("no checkpoint found for run_id {run_id}"))
        })?;
    let snapshot: WorkflowSnapshot = checkpoint.into();
    Self::resume(snapshot, steps, Some(DEFAULT_RESUME_TIMEOUT)).await
}
#[must_use]
pub fn step_names(&self) -> Vec<String> {
let mut seen = std::collections::HashSet::new();
let mut names = Vec::new();
for registrations in self.step_registry.values() {
for reg in registrations {
if seen.insert(®.name) {
names.push(reg.name.clone());
}
}
}
names
}
/// Runs this workflow on a remote peer and waits for the final result.
///
/// Sends the workflow name, step ids, input, and timeout; on success the
/// peer's remote-ref descriptors are installed into a fresh registry and
/// the result is wrapped in a `StopEvent`.
///
/// # Errors
/// Returns `WorkflowError::Context` when the peer invocation fails or the
/// response carries a non-empty error string.
#[cfg(feature = "distributed")]
pub async fn run_remote(
    &self,
    input: serde_json::Value,
    peer: &dyn crate::distributed::PeerClient,
) -> crate::error::Result<WorkflowResult> {
    use crate::distributed::{RemoteWorkflowRequest, RemoteWorkflowResponse};
    // The peer resolves steps by id from its own registry.
    let step_ids = self.step_names();
    let request = RemoteWorkflowRequest {
        workflow_name: self.name.clone(),
        step_ids,
        input,
        timeout_secs: self.timeout.map(|d| d.as_secs()),
    };
    let response: RemoteWorkflowResponse = peer
        .invoke_sub_workflow(request)
        .await
        .map_err(|e| WorkflowError::Context(format!("peer invocation failed: {e}")))?;
    // An empty error string is treated as success.
    if let Some(err) = &response.error
        && !err.is_empty()
    {
        return Err(WorkflowError::Context(format!(
            "remote workflow failed: {err}"
        )));
    }
    // Rebuild remote-ref descriptors into a local registry; insert failures
    // are deliberately ignored (best-effort).
    let registry = Arc::new(SessionRefRegistry::new());
    for (key_uuid, descriptor) in &response.remote_refs {
        let key = RegistryKey(*key_uuid);
        let remote_desc = RemoteRefDescriptor {
            origin_node_id: descriptor.origin_node_id.clone(),
            type_tag: descriptor.type_tag.clone(),
            created_at_epoch_ms: descriptor.created_at_epoch_ms,
        };
        let _ = registry.insert_remote(key, remote_desc).await;
    }
    // Missing result is normalized to JSON null.
    let result_json = response.result.unwrap_or(serde_json::Value::Null);
    let stop_event = blazen_events::StopEvent {
        result: result_json,
    };
    Ok(WorkflowResult {
        event: Box::new(stop_event),
        session_refs: registry,
    })
}
/// Builds a workflow from globally registered step ids.
///
/// # Errors
/// Returns `WorkflowError::UnknownStep` for any id missing from the global
/// step registry; builder validation errors are forwarded from `build`.
pub fn new_from_registered_steps(
    name: impl Into<String>,
    step_ids: Vec<&str>,
) -> crate::error::Result<Self> {
    use crate::builder::WorkflowBuilder;
    use crate::step_registry::lookup_step_builder;
    let mut builder = WorkflowBuilder::new(name);
    for id in step_ids {
        match lookup_step_builder(id) {
            Some(registration) => builder = builder.step(registration),
            None => {
                return Err(WorkflowError::UnknownStep {
                    step_id: id.to_string(),
                });
            }
        }
    }
    builder.no_timeout().build()
}
}
/// Rebuilds serialized session refs from snapshot metadata into `registry`.
///
/// `meta` is the JSON object stored under `SERIALIZED_SESSION_REFS_META_KEY`;
/// each entry maps a registry key to `{ type_tag, data }`. Per-entry failures
/// are accumulated rather than aborting, so all problems are reported at once.
///
/// # Errors
/// `WorkflowError::SessionRefsNotSerializable` when `meta` is not a JSON
/// object, or with one message per entry that failed to rehydrate.
async fn rehydrate_serialized_session_refs(
    registry: &Arc<SessionRefRegistry>,
    meta: &serde_json::Value,
    deserializers: &HashMap<&'static str, SessionRefDeserializerFn>,
) -> crate::error::Result<()> {
    let Some(entries) = meta.as_object() else {
        return Err(WorkflowError::SessionRefsNotSerializable {
            keys: vec!["__blazen_serialized_session_refs metadata is not a JSON object".to_owned()],
        });
    };
    let mut failures: Vec<String> = Vec::new();
    for (key_str, record) in entries {
        // Each guard records a failure message and skips to the next entry.
        let Ok(key) = RegistryKey::parse(key_str) else {
            failures.push(format!("invalid RegistryKey '{key_str}'"));
            continue;
        };
        let Some(record_obj) = record.as_object() else {
            failures.push(format!("record for key {key_str} is not an object"));
            continue;
        };
        let Some(type_tag) = record_obj.get("type_tag").and_then(|v| v.as_str()) else {
            failures.push(format!("record for key {key_str} missing type_tag"));
            continue;
        };
        let Some(data_value) = record_obj.get("data") else {
            failures.push(format!("record for key {key_str} missing data"));
            continue;
        };
        // Raw bytes round-trip through a wrapper for serde compatibility.
        let bytes: crate::value::BytesWrapper = match serde_json::from_value(data_value.clone()) {
            Ok(b) => b,
            Err(e) => {
                failures.push(format!("failed to decode data bytes for {key_str}: {e}"));
                continue;
            }
        };
        let Some(&deserializer) = deserializers.get(type_tag) else {
            failures.push(format!(
                "no registered deserializer for type_tag '{type_tag}' (key {key_str})"
            ));
            continue;
        };
        let value = match deserializer(&bytes.0) {
            Ok(v) => v,
            Err(e) => {
                failures.push(format!(
                    "deserializer for type_tag '{type_tag}' failed on key {key_str}: {e}"
                ));
                continue;
            }
        };
        // Re-insert under the original key so existing references stay valid.
        if let Err(e) = registry.insert_serializable_with_key(key, value).await {
            failures.push(format!("registry insert failed for {key_str}: {e}"));
        }
    }
    if failures.is_empty() {
        Ok(())
    } else {
        Err(WorkflowError::SessionRefsNotSerializable { keys: failures })
    }
}
/// Unit tests covering the run lifecycle, validation, timeouts, error
/// propagation, and the session-ref serialize/resume round trip.
#[cfg(test)]
mod tests {
    use blazen_events::{Event, StartEvent, StopEvent};
    use std::sync::Arc;
    use std::time::Duration;
    use crate::Workflow;
    use crate::builder::WorkflowBuilder;
    use crate::error::WorkflowError;
    use crate::step::{StepFn, StepOutput, StepRegistration};
    /// Step that echoes the `StartEvent` payload back as a `StopEvent`.
    fn echo_step() -> StepRegistration {
        let handler: StepFn = Arc::new(|event, _ctx| {
            Box::pin(async move {
                let start = event
                    .as_any()
                    .downcast_ref::<StartEvent>()
                    .expect("expected StartEvent");
                let stop = StopEvent {
                    result: start.data.clone(),
                };
                Ok(StepOutput::Single(Box::new(stop)))
            })
        });
        StepRegistration {
            name: "echo".into(),
            accepts: vec![StartEvent::event_type()],
            emits: vec![StopEvent::event_type()],
            handler,
            max_concurrency: 0,
        }
    }
    /// Happy path: start event in, stop event out with the same payload.
    #[tokio::test]
    async fn simple_start_to_stop() {
        let workflow = WorkflowBuilder::new("test")
            .step(echo_step())
            .build()
            .unwrap();
        let handler = workflow
            .run(serde_json::json!({"hello": "world"}))
            .await
            .unwrap();
        let result = handler.result().await.unwrap().event;
        assert_eq!(result.event_type_id(), StopEvent::event_type());
        let stop = result.downcast_ref::<StopEvent>().unwrap();
        assert_eq!(stop.result, serde_json::json!({"hello": "world"}));
    }
    /// Building a workflow with no steps must fail validation.
    #[tokio::test]
    async fn empty_workflow_fails_validation() {
        let result = WorkflowBuilder::new("empty").build();
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, WorkflowError::ValidationFailed(_)));
    }
    /// `step_names` deduplicates and reports each step once.
    #[test]
    fn step_names_returns_unique_names() {
        let handler_b: StepFn =
            Arc::new(|_event, _ctx| Box::pin(async move { Ok(StepOutput::None) }));
        let step_b = StepRegistration {
            name: "side_effect".into(),
            accepts: vec![StartEvent::event_type()],
            emits: vec![],
            handler: handler_b,
            max_concurrency: 0,
        };
        let workflow = WorkflowBuilder::new("step-names-test")
            .step(echo_step())
            .step(step_b)
            .build()
            .unwrap();
        let mut names = workflow.step_names();
        // Sort because registry iteration order is unspecified.
        names.sort();
        assert_eq!(names, vec!["echo", "side_effect"]);
    }
    /// A step that never finishes must be cut off by the workflow timeout.
    #[tokio::test]
    async fn timeout_triggers() {
        let handler: StepFn = Arc::new(|_event, _ctx| {
            Box::pin(async move {
                tokio::time::sleep(Duration::from_secs(3600)).await;
                Ok(StepOutput::None)
            })
        });
        let step = StepRegistration {
            name: "slow".into(),
            accepts: vec![StartEvent::event_type()],
            emits: vec![],
            handler,
            max_concurrency: 0,
        };
        let workflow = WorkflowBuilder::new("timeout-test")
            .step(step)
            .timeout(Duration::from_millis(50))
            .build()
            .unwrap();
        let wf_handler = workflow.run(serde_json::json!(null)).await.unwrap();
        let result = wf_handler.result().await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), WorkflowError::Timeout { .. }));
    }
    /// An error returned by a step surfaces as `StepFailed` on the handler.
    #[tokio::test]
    async fn step_error_propagates() {
        let handler: StepFn = Arc::new(|_event, _ctx| {
            Box::pin(async move { Err(WorkflowError::Context("test error".into())) })
        });
        let step = StepRegistration {
            name: "failing".into(),
            accepts: vec![StartEvent::event_type()],
            emits: vec![],
            handler,
            max_concurrency: 0,
        };
        let workflow = WorkflowBuilder::new("error-test")
            .step(step)
            .build()
            .unwrap();
        let wf_handler = workflow.run(serde_json::json!(null)).await.unwrap();
        let result = wf_handler.result().await;
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            WorkflowError::StepFailed { .. }
        ));
    }
    /// Minimal serializable session ref: a single i32 in big-endian bytes.
    struct TestSerializable {
        value: i32,
    }
    impl crate::session_ref::SessionRefSerializable for TestSerializable {
        fn blazen_serialize(&self) -> Result<Vec<u8>, crate::session_ref::SessionRefError> {
            Ok(self.value.to_be_bytes().to_vec())
        }
        fn blazen_type_tag(&self) -> &'static str {
            "test::TestSerializable"
        }
    }
    /// Matching deserializer for `TestSerializable` (expects exactly 4 bytes).
    fn test_deserialize(
        bytes: &[u8],
    ) -> Result<
        Arc<dyn crate::session_ref::SessionRefSerializable>,
        crate::session_ref::SessionRefError,
    > {
        if bytes.len() != 4 {
            return Err(crate::session_ref::SessionRefError::SerializationFailed {
                type_tag: "test::TestSerializable".to_owned(),
                source: "expected 4 bytes".into(),
            });
        }
        let mut buf = [0u8; 4];
        buf.copy_from_slice(bytes);
        let value = i32::from_be_bytes(buf);
        Ok(Arc::new(TestSerializable { value }))
    }
    /// Step that parks a `TestSerializable { value: 1234 }` in the registry
    /// and emits nothing, leaving the run idle for pause/snapshot.
    fn park_step() -> StepRegistration {
        let handler: StepFn = Arc::new(|_event, ctx| {
            Box::pin(async move {
                let registry = ctx.session_refs_arc().await;
                let _ = registry
                    .insert_serializable(Arc::new(TestSerializable { value: 1234 }))
                    .await
                    .unwrap();
                Ok(StepOutput::None)
            })
        });
        StepRegistration {
            name: "park".into(),
            accepts: vec![StartEvent::event_type()],
            emits: vec![],
            handler,
            max_concurrency: 0,
        }
    }
    /// Full round trip: run -> pause -> snapshot -> resume with a matching
    /// deserializer, verifying the serialized bytes and the rehydrated ref.
    #[tokio::test]
    async fn pickle_or_serialize_round_trip_through_snapshot() {
        use crate::session_ref::{SERIALIZED_SESSION_REFS_META_KEY, SessionPausePolicy};
        use std::collections::HashMap;
        let workflow = WorkflowBuilder::new("serialize-roundtrip")
            .step(park_step())
            .session_pause_policy(SessionPausePolicy::PickleOrSerialize)
            .build()
            .unwrap();
        let wf_handler = workflow.run(serde_json::json!(null)).await.unwrap();
        // Give the step time to park its ref before pausing.
        tokio::time::sleep(Duration::from_millis(50)).await;
        wf_handler.pause().unwrap();
        let snapshot = wf_handler.snapshot().await.unwrap();
        wf_handler.abort().unwrap();
        let raw = snapshot
            .metadata
            .get(SERIALIZED_SESSION_REFS_META_KEY)
            .expect("metadata must contain serialized session refs");
        let entries = raw
            .as_object()
            .expect("serialized session refs metadata must be a JSON object");
        assert_eq!(entries.len(), 1);
        let (_key_str, record) = entries.iter().next().unwrap();
        let record_obj = record.as_object().unwrap();
        assert_eq!(
            record_obj.get("type_tag").and_then(|v| v.as_str()).unwrap(),
            "test::TestSerializable"
        );
        // 1234 in big-endian bytes is [0, 0, 4, 210].
        let bytes: crate::value::BytesWrapper =
            serde_json::from_value(record_obj.get("data").unwrap().clone()).unwrap();
        assert_eq!(bytes.0, vec![0, 0, 4, 210]);
        let mut deserializers: HashMap<&'static str, crate::workflow::SessionRefDeserializerFn> =
            HashMap::new();
        deserializers.insert("test::TestSerializable", test_deserialize);
        let resumed_handler = Workflow::resume_with_deserializers(
            snapshot,
            vec![park_step()],
            deserializers,
            Some(Duration::from_millis(200)),
        )
        .await
        .unwrap();
        let resumed_refs = resumed_handler.session_refs();
        assert_eq!(resumed_refs.len().await, 1);
        let entries = resumed_refs.serializable_entries().await;
        assert_eq!(entries.len(), 1);
        let ser = &entries[0].1;
        let round_trip = ser.blazen_serialize().unwrap();
        assert_eq!(round_trip, vec![0, 0, 4, 210]);
        assert_eq!(ser.blazen_type_tag(), "test::TestSerializable");
        resumed_handler.abort().unwrap();
    }
    /// Resuming with an empty deserializer map skips rehydration entirely:
    /// the resume succeeds but the registry stays empty.
    #[tokio::test]
    async fn resume_with_missing_deserializer_errors() {
        use crate::session_ref::SessionPausePolicy;
        use std::collections::HashMap;
        let workflow = WorkflowBuilder::new("serialize-missing-deser")
            .step(park_step())
            .session_pause_policy(SessionPausePolicy::PickleOrSerialize)
            .build()
            .unwrap();
        let wf_handler = workflow.run(serde_json::json!(null)).await.unwrap();
        tokio::time::sleep(Duration::from_millis(50)).await;
        wf_handler.pause().unwrap();
        let snapshot = wf_handler.snapshot().await.unwrap();
        wf_handler.abort().unwrap();
        let deserializers: HashMap<&'static str, crate::workflow::SessionRefDeserializerFn> =
            HashMap::new();
        let resumed = Workflow::resume_with_deserializers(
            snapshot,
            vec![park_step()],
            deserializers,
            Some(Duration::from_millis(200)),
        )
        .await;
        let h = resumed.unwrap();
        assert_eq!(h.session_refs().len().await, 0);
        h.abort().unwrap();
    }
}