use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::{Duration, Instant};
use blazen_events::{
AnyEvent, DynamicEvent, Event, EventEnvelope, InputRequestEvent, InputResponseEvent,
StartEvent, StopEvent,
};
use chrono::Utc;
use serde::Serialize;
use tokio::sync::{broadcast, mpsc, oneshot};
use tokio::task::JoinSet;
use uuid::Uuid;
use tracing::Instrument;
use crate::context::Context;
use crate::error::WorkflowError;
use crate::handler::WorkflowHandler;
use crate::snapshot::{SerializedEvent, WorkflowSnapshot};
use crate::step::{StepOutput, StepRegistration};
/// Async callback invoked by the event loop when an `InputRequestEvent` is
/// seen and a handler was configured via [`WorkflowBuilder::input_handler`].
/// It must produce the matching [`InputResponseEvent`] (which is fed back into
/// the event loop) or return an error that fails the whole workflow.
pub type InputHandlerFn = Arc<
    dyn Fn(
            InputRequestEvent,
        )
            -> Pin<Box<dyn Future<Output = Result<InputResponseEvent, WorkflowError>> + Send>>
        + Send
        + Sync,
>;
/// Fluent builder for [`Workflow`]; collect steps and options, then call
/// [`WorkflowBuilder::build`].
pub struct WorkflowBuilder {
    // Human-readable workflow name, recorded in metadata and tracing spans.
    name: String,
    // Registered steps; indexed by accepted event type at build time.
    steps: Vec<StepRegistration>,
    // Overall wall-clock limit for a run; `None` disables the timeout.
    timeout: Option<Duration>,
    // Optional in-process responder for `InputRequestEvent`s; without one the
    // workflow pauses and waits for external input.
    input_handler: Option<InputHandlerFn>,
    // When true, `blazen::lifecycle` events are written to the broadcast stream.
    auto_publish_events: bool,
    #[cfg(feature = "persist")]
    checkpoint_store: Option<Arc<dyn blazen_persist::CheckpointStore>>,
    // When true (and a store is set), a best-effort checkpoint is saved after
    // each dispatch.
    #[cfg(feature = "persist")]
    checkpoint_after_step: bool,
    // When true, a history channel is created and telemetry events are emitted.
    #[cfg(feature = "telemetry")]
    collect_history: bool,
}
impl WorkflowBuilder {
    /// Creates a builder with no steps, a 5-minute default timeout, and every
    /// optional feature disabled.
    #[must_use]
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            steps: Vec::new(),
            timeout: Some(Duration::from_secs(300)),
            input_handler: None,
            auto_publish_events: false,
            #[cfg(feature = "persist")]
            checkpoint_store: None,
            #[cfg(feature = "persist")]
            checkpoint_after_step: false,
            #[cfg(feature = "telemetry")]
            collect_history: false,
        }
    }

    /// Registers one step with the workflow.
    #[must_use]
    pub fn step(mut self, registration: StepRegistration) -> Self {
        self.steps.push(registration);
        self
    }

    /// Sets the overall wall-clock timeout for a run.
    #[must_use]
    pub fn timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    /// Removes the timeout entirely; the run may block indefinitely.
    #[must_use]
    pub fn no_timeout(mut self) -> Self {
        self.timeout = None;
        self
    }

    /// Installs an in-process responder for `InputRequestEvent`s.
    #[must_use]
    pub fn input_handler(mut self, handler: InputHandlerFn) -> Self {
        self.input_handler = Some(handler);
        self
    }

    /// Enables/disables `blazen::lifecycle` events on the broadcast stream.
    #[must_use]
    pub fn auto_publish_events(mut self, enabled: bool) -> Self {
        self.auto_publish_events = enabled;
        self
    }

    /// Turns on telemetry history collection for the run.
    #[cfg(feature = "telemetry")]
    #[must_use]
    pub fn with_history(mut self) -> Self {
        self.collect_history = true;
        self
    }

    /// Sets the checkpoint store used for persistence.
    #[cfg(feature = "persist")]
    #[must_use]
    pub fn checkpoint_store(mut self, store: Arc<dyn blazen_persist::CheckpointStore>) -> Self {
        self.checkpoint_store = Some(store);
        self
    }

    /// Enables/disables best-effort checkpointing after each dispatch.
    #[cfg(feature = "persist")]
    #[must_use]
    pub fn checkpoint_after_step(mut self, enabled: bool) -> Self {
        self.checkpoint_after_step = enabled;
        self
    }

    /// Validates the configuration and produces a [`Workflow`].
    ///
    /// Builds the dispatch registry: one entry per accepted event type, each
    /// holding every step that accepts it (a step appears once per type).
    ///
    /// # Errors
    /// Returns [`WorkflowError::ValidationFailed`] when no steps are registered.
    pub fn build(self) -> crate::error::Result<Workflow> {
        if self.steps.is_empty() {
            return Err(WorkflowError::ValidationFailed(
                "workflow must have at least one step".into(),
            ));
        }
        let mut registry: HashMap<String, Vec<StepRegistration>> = HashMap::new();
        for step in self.steps {
            for accepted in step.accepts.iter().copied() {
                registry
                    .entry(accepted.to_owned())
                    .or_default()
                    .push(step.clone());
            }
        }
        Ok(Workflow {
            name: self.name,
            step_registry: registry,
            timeout: self.timeout,
            input_handler: self.input_handler,
            auto_publish_events: self.auto_publish_events,
            #[cfg(feature = "persist")]
            checkpoint_store: self.checkpoint_store,
            #[cfg(feature = "persist")]
            checkpoint_after_step: self.checkpoint_after_step,
            #[cfg(feature = "telemetry")]
            collect_history: self.collect_history,
        })
    }
}
/// A built, immutable workflow definition. Each call to `run`/`resume` spawns
/// an independent event loop; the `Workflow` itself holds only configuration.
pub struct Workflow {
    name: String,
    // Event type id -> all steps accepting that type (built by `WorkflowBuilder::build`).
    step_registry: HashMap<String, Vec<StepRegistration>>,
    // Overall run deadline; `None` means no timeout.
    timeout: Option<Duration>,
    input_handler: Option<InputHandlerFn>,
    auto_publish_events: bool,
    #[cfg(feature = "persist")]
    checkpoint_store: Option<Arc<dyn blazen_persist::CheckpointStore>>,
    #[cfg(feature = "persist")]
    checkpoint_after_step: bool,
    #[cfg(feature = "telemetry")]
    collect_history: bool,
}
// Manual impl: step handlers and the input handler are not `Debug`, so only
// summary fields are printed; `finish_non_exhaustive` marks the omission.
impl std::fmt::Debug for Workflow {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Workflow")
            .field("name", &self.name)
            .field("step_count", &self.step_registry.len())
            .field("timeout", &self.timeout)
            .finish_non_exhaustive()
    }
}
impl Workflow {
    /// Runs the workflow with `input` wrapped in a [`StartEvent`].
    ///
    /// # Errors
    /// Propagates errors from [`Workflow::run_with_event`].
    pub async fn run(&self, input: serde_json::Value) -> crate::error::Result<WorkflowHandler> {
        let start_event = StartEvent { data: input };
        self.run_with_event(start_event).await
    }
    /// Runs the workflow, seeding the event loop with an arbitrary start event.
    ///
    /// Wires up all channels, spawns the event loop as a detached tokio task,
    /// and returns a [`WorkflowHandler`] for awaiting the result, streaming
    /// events, pausing, and snapshotting.
    ///
    /// NOTE(review): the `Serialize` bound on `E` is not used directly in this
    /// body — presumably required by event erasure/persistence downstream;
    /// confirm before relaxing it.
    ///
    /// # Errors
    /// Returns [`WorkflowError::ChannelClosed`] if the seed event cannot be
    /// enqueued (receiver already dropped).
    pub async fn run_with_event<E: Event + Serialize>(
        &self,
        start_event: E,
    ) -> crate::error::Result<WorkflowHandler> {
        // Channels: unbounded event queue, bounded broadcast stream for
        // observers (capacity 256), and oneshots for final result, manual
        // pause, and the pause snapshot.
        let (event_tx, event_rx) = mpsc::unbounded_channel::<EventEnvelope>();
        let (stream_tx, _stream_rx) = broadcast::channel::<Box<dyn AnyEvent>>(256);
        let (result_tx, result_rx) = oneshot::channel();
        let (pause_tx, pause_rx) = oneshot::channel::<()>();
        let (snapshot_tx, snapshot_rx) = oneshot::channel::<WorkflowSnapshot>();
        let ctx = Context::new(event_tx.clone(), stream_tx.clone());
        let run_id = Uuid::new_v4();
        // Seed well-known metadata before the first event is processed.
        ctx.set_metadata("run_id", serde_json::Value::String(run_id.to_string()))
            .await;
        ctx.set_metadata(
            "workflow_name",
            serde_json::Value::String(self.name.clone()),
        )
        .await;
        // Enqueue the start event; `source_step` is `None` for the seed.
        let envelope = EventEnvelope::new(Box::new(start_event), None);
        event_tx
            .send(envelope)
            .map_err(|_| WorkflowError::ChannelClosed)?;
        #[cfg(feature = "telemetry")]
        let (history_tx, history_rx) = if self.collect_history {
            let (tx, rx) = mpsc::unbounded_channel();
            (Some(tx), Some(rx))
        } else {
            (None, None)
        };
        // Clone everything the background task needs; the `Workflow` itself is
        // not moved into the task.
        let registry = self.step_registry.clone();
        let timeout = self.timeout;
        let workflow_name = self.name.clone();
        let input_handler = self.input_handler.clone();
        let auto_publish = self.auto_publish_events;
        #[cfg(feature = "persist")]
        let checkpoint_config = CheckpointConfig {
            store: self.checkpoint_store.clone(),
            after_step: self.checkpoint_after_step,
        };
        let event_loop_handle = tokio::spawn(event_loop(
            event_rx,
            event_tx,
            registry,
            ctx,
            result_tx,
            timeout,
            pause_rx,
            snapshot_tx,
            workflow_name,
            run_id,
            input_handler,
            auto_publish,
            #[cfg(feature = "persist")]
            checkpoint_config,
            #[cfg(feature = "telemetry")]
            history_tx,
        ));
        Ok(WorkflowHandler::new(
            result_rx,
            stream_tx,
            Some(pause_tx),
            Some(snapshot_rx),
            event_loop_handle,
            #[cfg(feature = "telemetry")]
            history_rx,
        ))
    }
    /// Resumes a paused workflow from a [`WorkflowSnapshot`].
    ///
    /// Rebuilds the dispatch registry from `steps` (handlers are not
    /// serializable, so the caller must re-supply them), restores context
    /// state/collected events/metadata, and re-enqueues the snapshot's pending
    /// events. Events whose concrete type cannot be deserialized are replayed
    /// as [`DynamicEvent`]s.
    ///
    /// NOTE(review): the resumed run has no input handler, no auto-publish,
    /// and no checkpointing regardless of the original configuration — confirm
    /// this is intended.
    ///
    /// # Errors
    /// Returns [`WorkflowError::ChannelClosed`] if a pending event cannot be
    /// re-enqueued.
    pub async fn resume(
        snapshot: WorkflowSnapshot,
        steps: Vec<StepRegistration>,
        timeout: Option<Duration>,
    ) -> crate::error::Result<WorkflowHandler> {
        let mut registry: HashMap<String, Vec<StepRegistration>> = HashMap::new();
        for step in steps {
            for &event_type in &step.accepts {
                registry
                    .entry(event_type.to_owned())
                    .or_default()
                    .push(step.clone());
            }
        }
        let (event_tx, event_rx) = mpsc::unbounded_channel::<EventEnvelope>();
        let (stream_tx, _stream_rx) = broadcast::channel::<Box<dyn AnyEvent>>(256);
        let (result_tx, result_rx) = oneshot::channel();
        let (pause_tx, pause_rx) = oneshot::channel::<()>();
        let (snapshot_tx, snapshot_rx) = oneshot::channel::<WorkflowSnapshot>();
        let ctx = Context::new(event_tx.clone(), stream_tx.clone());
        // Restore the paused run's state before any event is processed.
        ctx.restore_state(snapshot.context_state).await;
        ctx.restore_collected(snapshot.collected_events).await;
        ctx.restore_metadata(snapshot.metadata).await;
        for serialized in &snapshot.pending_events {
            // Prefer the registered concrete type; fall back to a DynamicEvent
            // carrying the raw JSON so no pending event is dropped.
            let event: Box<dyn AnyEvent> =
                blazen_events::try_deserialize_event(&serialized.event_type, &serialized.data)
                    .unwrap_or_else(|| {
                        Box::new(DynamicEvent {
                            event_type: serialized.event_type.clone(),
                            data: serialized.data.clone(),
                        })
                    });
            let envelope = EventEnvelope::new(event, serialized.source_step.clone());
            event_tx
                .send(envelope)
                .map_err(|_| WorkflowError::ChannelClosed)?;
        }
        let workflow_name = snapshot.workflow_name;
        let run_id = snapshot.run_id;
        // History collection is not resumed across a snapshot boundary.
        #[cfg(feature = "telemetry")]
        let history_tx: Option<mpsc::UnboundedSender<blazen_telemetry::HistoryEvent>> = None;
        #[cfg(feature = "persist")]
        let checkpoint_config = CheckpointConfig {
            store: None,
            after_step: false,
        };
        let event_loop_handle = tokio::spawn(event_loop(
            event_rx,
            event_tx,
            registry,
            ctx,
            result_tx,
            timeout,
            pause_rx,
            snapshot_tx,
            workflow_name,
            run_id,
            // input_handler: none; auto_publish_events: disabled on resume.
            None,
            false,
            #[cfg(feature = "persist")]
            checkpoint_config,
            #[cfg(feature = "telemetry")]
            history_tx,
        ));
        Ok(WorkflowHandler::new(
            result_rx,
            stream_tx,
            Some(pause_tx),
            Some(snapshot_rx),
            event_loop_handle,
            #[cfg(feature = "telemetry")]
            None,
        ))
    }
    /// Loads a checkpoint for `run_id` from `store` and resumes it with the
    /// default 5-minute timeout (matching `WorkflowBuilder::new`).
    ///
    /// # Errors
    /// Returns [`WorkflowError::Context`] when the load fails or no checkpoint
    /// exists for `run_id`; otherwise propagates [`Workflow::resume`] errors.
    #[cfg(feature = "persist")]
    pub async fn resume_from(
        store: Arc<dyn blazen_persist::CheckpointStore>,
        run_id: &Uuid,
        steps: Vec<StepRegistration>,
    ) -> crate::error::Result<WorkflowHandler> {
        let checkpoint = store
            .load(run_id)
            .await
            .map_err(|e| WorkflowError::Context(format!("checkpoint load failed: {e}")))?
            .ok_or_else(|| {
                WorkflowError::Context(format!("no checkpoint found for run_id {run_id}"))
            })?;
        let snapshot: WorkflowSnapshot = checkpoint.into();
        Self::resume(snapshot, steps, Some(Duration::from_secs(300))).await
    }
    /// Resumes a workflow paused for human input, injecting `response` as the
    /// first pending event and clearing the stored `__input_request` marker.
    ///
    /// NOTE(review): the event type id is hard-coded as
    /// "blazen::InputResponseEvent" — confirm it matches
    /// `InputResponseEvent::event_type()` so the registry routes it.
    ///
    /// # Errors
    /// Propagates [`Workflow::resume`] errors.
    pub async fn resume_with_input(
        snapshot: WorkflowSnapshot,
        response: InputResponseEvent,
        steps: Vec<StepRegistration>,
        timeout: Option<Duration>,
    ) -> crate::error::Result<WorkflowHandler> {
        let mut snapshot = snapshot;
        snapshot.pending_events.push(SerializedEvent {
            event_type: "blazen::InputResponseEvent".to_owned(),
            data: serde_json::to_value(&response)
                .expect("InputResponseEvent serialization should never fail"),
            source_step: Some("__human_input".to_owned()),
        });
        snapshot.metadata.remove("__input_request");
        Self::resume(snapshot, steps, timeout).await
    }
}
// Bundles the per-run checkpointing options handed to the event loop.
#[cfg(feature = "persist")]
struct CheckpointConfig {
    // Destination for best-effort checkpoints; `None` disables saving.
    store: Option<Arc<dyn blazen_persist::CheckpointStore>>,
    // When true, a checkpoint is attempted after each dispatch.
    after_step: bool,
}
/// Saves a best-effort checkpoint of the current context state.
///
/// The snapshot carries no pending events (those only exist when pausing);
/// failures are logged at `warn` and otherwise ignored, so a flaky store
/// never fails the workflow.
#[cfg(feature = "persist")]
async fn save_checkpoint(
    store: &dyn blazen_persist::CheckpointStore,
    ctx: &Context,
    workflow_name: &str,
    run_id: Uuid,
) {
    let snapshot = WorkflowSnapshot {
        workflow_name: workflow_name.to_owned(),
        run_id,
        timestamp: Utc::now(),
        context_state: ctx.snapshot_state().await,
        collected_events: ctx.snapshot_collected().await,
        pending_events: Vec::new(),
        metadata: ctx.snapshot_metadata().await,
        #[cfg(feature = "telemetry")]
        history: Vec::new(),
    };
    let checkpoint: blazen_persist::WorkflowCheckpoint = snapshot.into();
    match store.save(&checkpoint).await {
        Ok(_) => {
            tracing::debug!(run_id = %run_id, "auto-checkpoint saved");
        }
        Err(e) => {
            tracing::warn!(
                run_id = %run_id,
                error = %e,
                "auto-checkpoint failed (best-effort)"
            );
        }
    }
}
/// Entry point for the spawned workflow task: wraps [`event_loop_inner`] in a
/// `workflow.run` tracing span and signals end-of-stream once the loop exits
/// for any reason (completion, error, timeout, or pause).
#[allow(clippy::too_many_arguments, clippy::too_many_lines)]
async fn event_loop(
    event_rx: mpsc::UnboundedReceiver<EventEnvelope>,
    event_tx: mpsc::UnboundedSender<EventEnvelope>,
    registry: HashMap<String, Vec<StepRegistration>>,
    ctx: Context,
    result_tx: oneshot::Sender<Result<Box<dyn AnyEvent>, WorkflowError>>,
    timeout: Option<Duration>,
    pause_rx: oneshot::Receiver<()>,
    snapshot_tx: oneshot::Sender<WorkflowSnapshot>,
    workflow_name: String,
    run_id: Uuid,
    input_handler: Option<InputHandlerFn>,
    auto_publish_events: bool,
    #[cfg(feature = "persist")] checkpoint_config: CheckpointConfig,
    #[cfg(feature = "telemetry")] history_tx: Option<
        mpsc::UnboundedSender<blazen_telemetry::HistoryEvent>,
    >,
) {
    // Cloned before `ctx` moves into the inner loop, so the stream can still
    // be closed afterwards.
    let stream_ctx = ctx.clone();
    let span = tracing::info_span!(
        "workflow.run",
        workflow_name = %workflow_name,
        run_id = %run_id,
    );
    event_loop_inner(
        event_rx,
        event_tx,
        registry,
        ctx,
        result_tx,
        timeout,
        pause_rx,
        snapshot_tx,
        workflow_name,
        run_id,
        input_handler,
        auto_publish_events,
        #[cfg(feature = "persist")]
        checkpoint_config,
        #[cfg(feature = "telemetry")]
        history_tx,
    )
    .instrument(span)
    .await;
    // Tell stream subscribers there will be no more events.
    stream_ctx.signal_stream_end().await;
}
/// Core dispatch loop: pulls event envelopes off `event_rx`, routes them to
/// registered steps, and terminates on `StopEvent`, step error, timeout,
/// channel closure, or a pause/input request. Exactly one message is sent on
/// `result_tx` before returning.
#[allow(clippy::too_many_arguments, clippy::too_many_lines)]
async fn event_loop_inner(
    mut event_rx: mpsc::UnboundedReceiver<EventEnvelope>,
    event_tx: mpsc::UnboundedSender<EventEnvelope>,
    registry: HashMap<String, Vec<StepRegistration>>,
    ctx: Context,
    result_tx: oneshot::Sender<Result<Box<dyn AnyEvent>, WorkflowError>>,
    timeout: Option<Duration>,
    mut pause_rx: oneshot::Receiver<()>,
    snapshot_tx: oneshot::Sender<WorkflowSnapshot>,
    workflow_name: String,
    run_id: Uuid,
    input_handler: Option<InputHandlerFn>,
    auto_publish_events: bool,
    #[cfg(feature = "persist")] checkpoint_config: CheckpointConfig,
    #[cfg(feature = "telemetry")] history_tx: Option<
        mpsc::UnboundedSender<blazen_telemetry::HistoryEvent>,
    >,
) {
    let start = Instant::now();
    #[cfg(feature = "telemetry")]
    if let Some(ref tx) = history_tx {
        let _ = tx.send(blazen_telemetry::HistoryEvent {
            timestamp: Utc::now(),
            // NOTE(review): every history event in this loop uses sequence 0 —
            // presumably sequencing is assigned by the consumer; confirm.
            sequence: 0,
            kind: blazen_telemetry::HistoryEventKind::WorkflowStarted {
                input: serde_json::json!({}),
            },
        });
    }
    // Step failures are funneled back through this channel so the loop can
    // fail the workflow from any in-flight task.
    let (error_tx, mut error_rx) = mpsc::unbounded_channel::<WorkflowError>();
    // NOTE(review): completed tasks are only reaped from `in_flight` during a
    // pause (`join_next` drain); JoinSet retains finished-task results until
    // polled, so long-running workflows accumulate them — confirm acceptable.
    let mut in_flight: JoinSet<()> = JoinSet::new();
    let in_flight_count = Arc::new(AtomicUsize::new(0));
    // Builds a future that writes a `blazen::lifecycle` DynamicEvent to the
    // stream; optional fields are inserted only when provided.
    let publish_lifecycle = |ctx: &Context,
                             kind: &str,
                             step_name: Option<&str>,
                             event_type_str: Option<&str>,
                             duration_ms: Option<u64>,
                             error: Option<&str>| {
        let ctx = ctx.clone();
        let kind = kind.to_owned();
        let step_name = step_name.map(ToOwned::to_owned);
        let event_type_str = event_type_str.map(ToOwned::to_owned);
        let error = error.map(ToOwned::to_owned);
        async move {
            let mut data = serde_json::Map::new();
            data.insert("kind".into(), serde_json::Value::String(kind));
            if let Some(s) = step_name {
                data.insert("step_name".into(), serde_json::Value::String(s));
            }
            if let Some(e) = event_type_str {
                data.insert("event_type".into(), serde_json::Value::String(e));
            }
            if let Some(d) = duration_ms {
                data.insert("duration_ms".into(), serde_json::Value::Number(d.into()));
            }
            if let Some(e) = error {
                data.insert("error".into(), serde_json::Value::String(e));
            }
            ctx.write_event_to_stream(DynamicEvent {
                event_type: "blazen::lifecycle".to_owned(),
                data: serde_json::Value::Object(data),
            })
            .await;
        }
    };
    loop {
        // Receive the next envelope, racing (biased, in order) against step
        // errors, the overall deadline (when configured), and a manual pause.
        let recv_result = if let Some(timeout_dur) = timeout {
            let remaining = timeout_dur.saturating_sub(start.elapsed());
            if remaining.is_zero() {
                // Deadline already passed before we could select.
                #[cfg(feature = "telemetry")]
                if let Some(ref tx) = history_tx {
                    let _ = tx.send(blazen_telemetry::HistoryEvent {
                        timestamp: Utc::now(),
                        sequence: 0,
                        kind: blazen_telemetry::HistoryEventKind::WorkflowTimedOut {
                            elapsed_ms: u64::try_from(start.elapsed().as_millis())
                                .unwrap_or(u64::MAX),
                        },
                    });
                }
                let _ = result_tx.send(Err(WorkflowError::Timeout {
                    elapsed: start.elapsed(),
                }));
                return;
            }
            tokio::select! {
                biased;
                err = error_rx.recv() => {
                    if let Some(workflow_err) = err {
                        #[cfg(feature = "telemetry")]
                        if let Some(ref tx) = history_tx {
                            let _ = tx.send(blazen_telemetry::HistoryEvent {
                                timestamp: Utc::now(),
                                sequence: 0,
                                kind: blazen_telemetry::HistoryEventKind::WorkflowFailed {
                                    error: workflow_err.to_string(),
                                    duration_ms: u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
                                },
                            });
                        }
                        let _ = result_tx.send(Err(workflow_err));
                        return;
                    }
                    continue;
                }
                maybe_envelope = event_rx.recv() => {
                    // `Err(())` signals "channel closed" to the code below.
                    maybe_envelope.ok_or(())
                }
                () = tokio::time::sleep(remaining) => {
                    #[cfg(feature = "telemetry")]
                    if let Some(ref tx) = history_tx {
                        let _ = tx.send(blazen_telemetry::HistoryEvent {
                            timestamp: Utc::now(),
                            sequence: 0,
                            kind: blazen_telemetry::HistoryEventKind::WorkflowTimedOut {
                                elapsed_ms: u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
                            },
                        });
                    }
                    let _ = result_tx.send(Err(WorkflowError::Timeout {
                        elapsed: start.elapsed(),
                    }));
                    return;
                }
                _ = &mut pause_rx => {
                    #[cfg(feature = "telemetry")]
                    if let Some(ref tx) = history_tx {
                        let _ = tx.send(blazen_telemetry::HistoryEvent {
                            timestamp: Utc::now(),
                            sequence: 0,
                            kind: blazen_telemetry::HistoryEventKind::WorkflowPaused {
                                reason: blazen_telemetry::PauseReason::Manual,
                                pending_count: 0,
                            },
                        });
                    }
                    handle_pause(
                        &mut in_flight,
                        &mut event_rx,
                        &ctx,
                        result_tx,
                        snapshot_tx,
                        &workflow_name,
                        run_id,
                    )
                    .instrument(tracing::info_span!("workflow.pause", pause_type = "manual"))
                    .await;
                    return;
                }
            }
        } else {
            // Same select without the deadline arm.
            tokio::select! {
                biased;
                err = error_rx.recv() => {
                    if let Some(workflow_err) = err {
                        #[cfg(feature = "telemetry")]
                        if let Some(ref tx) = history_tx {
                            let _ = tx.send(blazen_telemetry::HistoryEvent {
                                timestamp: Utc::now(),
                                sequence: 0,
                                kind: blazen_telemetry::HistoryEventKind::WorkflowFailed {
                                    error: workflow_err.to_string(),
                                    duration_ms: u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
                                },
                            });
                        }
                        let _ = result_tx.send(Err(workflow_err));
                        return;
                    }
                    continue;
                }
                maybe_envelope = event_rx.recv() => {
                    maybe_envelope.ok_or(())
                }
                _ = &mut pause_rx => {
                    #[cfg(feature = "telemetry")]
                    if let Some(ref tx) = history_tx {
                        let _ = tx.send(blazen_telemetry::HistoryEvent {
                            timestamp: Utc::now(),
                            sequence: 0,
                            kind: blazen_telemetry::HistoryEventKind::WorkflowPaused {
                                reason: blazen_telemetry::PauseReason::Manual,
                                pending_count: 0,
                            },
                        });
                    }
                    handle_pause(
                        &mut in_flight,
                        &mut event_rx,
                        &ctx,
                        result_tx,
                        snapshot_tx,
                        &workflow_name,
                        run_id,
                    )
                    .instrument(tracing::info_span!("workflow.pause", pause_type = "manual"))
                    .await;
                    return;
                }
            }
        };
        let Ok(envelope) = recv_result else {
            // All senders dropped: no further progress is possible.
            let _ = result_tx.send(Err(WorkflowError::ChannelClosed));
            return;
        };
        let event = envelope.event;
        let event_type = event.event_type_id();
        #[cfg(feature = "telemetry")]
        if let Some(ref tx) = history_tx {
            let _ = tx.send(blazen_telemetry::HistoryEvent {
                timestamp: Utc::now(),
                sequence: 0,
                kind: blazen_telemetry::HistoryEventKind::EventReceived {
                    event_type: event_type.to_string(),
                    source_step: envelope.source_step.clone(),
                },
            });
        }
        if auto_publish_events {
            publish_lifecycle(&ctx, "event_routed", None, Some(event_type), None, None).await;
        }
        {
            let _event_span = tracing::debug_span!(
                "workflow.event",
                event_type = %event_type,
                source_step = ?envelope.source_step,
            )
            .entered();
            tracing::debug!(
                event_type,
                source_step = ?envelope.source_step,
                "event loop received event"
            );
        }
        // StopEvent ends the workflow; coerce dynamic/foreign events into a
        // StopEvent so the caller always receives one.
        if event_type == StopEvent::event_type() {
            tracing::info!("workflow completed via StopEvent");
            #[cfg(feature = "telemetry")]
            if let Some(ref tx) = history_tx {
                let _ = tx.send(blazen_telemetry::HistoryEvent {
                    timestamp: Utc::now(),
                    sequence: 0,
                    kind: blazen_telemetry::HistoryEventKind::WorkflowCompleted {
                        duration_ms: u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
                    },
                });
            }
            let final_event: Box<dyn AnyEvent> =
                if event.as_any().downcast_ref::<StopEvent>().is_some() {
                    event
                } else if let Some(dynamic) = event.as_any().downcast_ref::<DynamicEvent>() {
                    match serde_json::from_value::<StopEvent>(dynamic.data.clone()) {
                        Ok(stop) => Box::new(stop),
                        Err(_) => {
                            // Not StopEvent-shaped JSON: wrap the raw payload.
                            Box::new(StopEvent {
                                result: dynamic.data.clone(),
                            })
                        }
                    }
                } else {
                    // Foreign type with the stop type id: wrap its "result"
                    // field if present, otherwise the whole JSON form.
                    let json = event.to_json();
                    Box::new(StopEvent {
                        result: json.get("result").cloned().unwrap_or(json),
                    })
                };
            let _ = result_tx.send(Ok(final_event));
            return;
        }
        // Input requests are answered in-process when a handler is set;
        // otherwise the workflow pauses and surfaces InputRequired.
        if event_type == InputRequestEvent::event_type() {
            let request = if let Some(req) = event.as_any().downcast_ref::<InputRequestEvent>() {
                req.clone()
            } else if let Some(dynamic) = event.as_any().downcast_ref::<DynamicEvent>() {
                if let Ok(req) = serde_json::from_value::<InputRequestEvent>(dynamic.data.clone()) {
                    req
                } else {
                    let _ = result_tx.send(Err(WorkflowError::Context(
                        "failed to deserialize InputRequestEvent from DynamicEvent".into(),
                    )));
                    return;
                }
            } else {
                let _ = result_tx.send(Err(WorkflowError::Context(
                    "InputRequestEvent type mismatch".into(),
                )));
                return;
            };
            #[cfg(feature = "telemetry")]
            if let Some(ref tx) = history_tx {
                let _ = tx.send(blazen_telemetry::HistoryEvent {
                    timestamp: Utc::now(),
                    sequence: 0,
                    kind: blazen_telemetry::HistoryEventKind::InputRequested {
                        request_id: request.request_id.clone(),
                        prompt: request.prompt.clone(),
                    },
                });
            }
            if let Some(ref handler) = input_handler {
                match handler(request).await {
                    Ok(response) => {
                        // Feed the response back as a regular event.
                        let envelope =
                            EventEnvelope::new(Box::new(response), Some("__input_handler".into()));
                        let _ = event_tx.send(envelope);
                        continue;
                    }
                    Err(e) => {
                        let _ = result_tx.send(Err(e));
                        return;
                    }
                }
            }
            #[cfg(feature = "telemetry")]
            if let Some(ref tx) = history_tx {
                let _ = tx.send(blazen_telemetry::HistoryEvent {
                    timestamp: Utc::now(),
                    sequence: 0,
                    kind: blazen_telemetry::HistoryEventKind::WorkflowPaused {
                        reason: blazen_telemetry::PauseReason::InputRequired,
                        pending_count: 0,
                    },
                });
            }
            handle_input_pause(
                &mut in_flight,
                &mut event_rx,
                &ctx,
                result_tx,
                snapshot_tx,
                &workflow_name,
                run_id,
                &request,
            )
            .instrument(tracing::info_span!("workflow.pause", pause_type = "input"))
            .await;
            return;
        }
        // An event nobody accepts is a hard failure, not a silent drop.
        let Some(handlers) = registry.get(event_type) else {
            tracing::warn!(event_type, "no handler registered for event type");
            #[cfg(feature = "telemetry")]
            if let Some(ref tx) = history_tx {
                let _ = tx.send(blazen_telemetry::HistoryEvent {
                    timestamp: Utc::now(),
                    sequence: 0,
                    kind: blazen_telemetry::HistoryEventKind::WorkflowFailed {
                        error: format!("no handler registered for event type: {event_type}"),
                        duration_ms: u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
                    },
                });
            }
            let _ = result_tx.send(Err(WorkflowError::NoHandler {
                event_type: event_type.to_owned(),
            }));
            return;
        };
        let handlers = handlers.clone();
        // Record the event before dispatch so snapshots include it.
        ctx.push_collected(&*event).await;
        dispatch_to_handlers(
            &handlers,
            &*event,
            &ctx,
            &event_tx,
            &error_tx,
            &mut in_flight,
            &in_flight_count,
            auto_publish_events,
            #[cfg(feature = "telemetry")]
            history_tx.as_ref(),
        );
        // Best-effort checkpoint after each dispatch, when configured.
        #[cfg(feature = "persist")]
        if checkpoint_config.after_step
            && let Some(ref store) = checkpoint_config.store
        {
            save_checkpoint(&**store, &ctx, &workflow_name, run_id).await;
        }
    }
}
async fn handle_pause(
in_flight: &mut JoinSet<()>,
event_rx: &mut mpsc::UnboundedReceiver<EventEnvelope>,
ctx: &Context,
result_tx: oneshot::Sender<Result<Box<dyn AnyEvent>, WorkflowError>>,
snapshot_tx: oneshot::Sender<WorkflowSnapshot>,
workflow_name: &str,
run_id: Uuid,
) {
tracing::info!("pause requested -- waiting for in-flight steps to complete");
while in_flight.join_next().await.is_some() {}
tracing::debug!("all in-flight steps completed");
let mut pending_events = Vec::new();
while let Ok(envelope) = event_rx.try_recv() {
let serialized = SerializedEvent {
event_type: envelope.event.event_type_id().to_owned(),
data: envelope.event.to_json(),
source_step: envelope.source_step,
};
pending_events.push(serialized);
}
tracing::debug!(
pending_count = pending_events.len(),
"drained pending events"
);
let context_state = ctx.snapshot_state().await;
let collected_events = ctx.snapshot_collected().await;
let metadata = ctx.snapshot_metadata().await;
let snapshot = WorkflowSnapshot {
workflow_name: workflow_name.to_owned(),
run_id,
timestamp: Utc::now(),
context_state,
collected_events,
pending_events,
metadata,
#[cfg(feature = "telemetry")]
history: Vec::new(),
};
let _ = snapshot_tx.send(snapshot);
let _ = result_tx.send(Err(WorkflowError::Paused));
}
#[allow(clippy::too_many_arguments)]
async fn handle_input_pause(
in_flight: &mut JoinSet<()>,
event_rx: &mut mpsc::UnboundedReceiver<EventEnvelope>,
ctx: &Context,
result_tx: oneshot::Sender<Result<Box<dyn AnyEvent>, WorkflowError>>,
snapshot_tx: oneshot::Sender<WorkflowSnapshot>,
workflow_name: &str,
run_id: Uuid,
request: &InputRequestEvent,
) {
tracing::info!(
request_id = %request.request_id,
"input requested -- pausing for human input"
);
while in_flight.join_next().await.is_some() {}
let mut pending_events = Vec::new();
while let Ok(envelope) = event_rx.try_recv() {
let serialized = SerializedEvent {
event_type: envelope.event.event_type_id().to_owned(),
data: envelope.event.to_json(),
source_step: envelope.source_step,
};
pending_events.push(serialized);
}
ctx.set_metadata(
"__input_request",
serde_json::to_value(request).expect("InputRequestEvent serialization should never fail"),
)
.await;
let context_state = ctx.snapshot_state().await;
let collected_events = ctx.snapshot_collected().await;
let metadata = ctx.snapshot_metadata().await;
let snapshot = WorkflowSnapshot {
workflow_name: workflow_name.to_owned(),
run_id,
timestamp: Utc::now(),
context_state,
collected_events,
pending_events,
metadata,
#[cfg(feature = "telemetry")]
history: Vec::new(),
};
let _ = snapshot_tx.send(snapshot);
let _ = result_tx.send(Err(WorkflowError::InputRequired {
request_id: request.request_id.clone(),
prompt: request.prompt.clone(),
metadata: request.metadata.clone(),
}));
}
#[allow(clippy::too_many_arguments, clippy::too_many_lines)]
fn dispatch_to_handlers(
handlers: &[StepRegistration],
event: &dyn AnyEvent,
ctx: &Context,
event_tx: &mpsc::UnboundedSender<EventEnvelope>,
error_tx: &mpsc::UnboundedSender<WorkflowError>,
in_flight: &mut JoinSet<()>,
in_flight_count: &Arc<AtomicUsize>,
auto_publish_events: bool,
#[cfg(feature = "telemetry")] history_tx: Option<
&mpsc::UnboundedSender<blazen_telemetry::HistoryEvent>,
>,
) {
for step in handlers {
let event_clone = event.clone_boxed();
let ctx_clone = ctx.clone();
let handler = step.handler.clone();
let step_name = step.name.clone();
let event_tx_clone = event_tx.clone();
let error_tx_clone = error_tx.clone();
let counter = Arc::clone(in_flight_count);
let event_type = event.event_type_id().to_owned();
#[cfg(feature = "telemetry")]
let htx = history_tx.cloned();
#[cfg(feature = "telemetry")]
if let Some(ref tx) = htx {
let _ = tx.send(blazen_telemetry::HistoryEvent {
timestamp: Utc::now(),
sequence: 0,
kind: blazen_telemetry::HistoryEventKind::StepDispatched {
step_name: step_name.clone(),
event_type: event_type.clone(),
},
});
}
let stream_ctx = if auto_publish_events {
Some(ctx.clone())
} else {
None
};
counter.fetch_add(1, Ordering::Relaxed);
let step_span = tracing::info_span!(
"workflow.step",
step_name = %step_name,
event_type = %event_type,
otel.status_code = tracing::field::Empty,
duration_ms = tracing::field::Empty,
);
let step_span_clone = step_span.clone();
in_flight.spawn(
async move {
if let Some(ref sctx) = stream_ctx {
let mut data = serde_json::Map::new();
data.insert(
"kind".into(),
serde_json::Value::String("step_started".into()),
);
data.insert(
"step_name".into(),
serde_json::Value::String(step_name.clone()),
);
data.insert(
"event_type".into(),
serde_json::Value::String(event_type.clone()),
);
sctx.write_event_to_stream(DynamicEvent {
event_type: "blazen::lifecycle".to_owned(),
data: serde_json::Value::Object(data),
})
.await;
}
let start = Instant::now();
match handler(event_clone, ctx_clone).await {
Ok(StepOutput::Single(output_event)) => {
let duration =
u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
step_span_clone.record("duration_ms", duration);
step_span_clone.record("otel.status_code", "OK");
#[cfg(feature = "telemetry")]
if let Some(ref tx) = htx {
let output_type = output_event.event_type_id().to_owned();
let _ = tx.send(blazen_telemetry::HistoryEvent {
timestamp: Utc::now(),
sequence: 0,
kind: blazen_telemetry::HistoryEventKind::StepCompleted {
step_name: step_name.clone(),
duration_ms: duration,
output_type,
},
});
}
if let Some(ref sctx) = stream_ctx {
let mut data = serde_json::Map::new();
data.insert(
"kind".into(),
serde_json::Value::String("step_completed".into()),
);
data.insert(
"step_name".into(),
serde_json::Value::String(step_name.clone()),
);
data.insert(
"duration_ms".into(),
serde_json::Value::Number(duration.into()),
);
sctx.write_event_to_stream(DynamicEvent {
event_type: "blazen::lifecycle".to_owned(),
data: serde_json::Value::Object(data),
})
.await;
}
let envelope = EventEnvelope::new(output_event, Some(step_name));
let _ = event_tx_clone.send(envelope);
}
Ok(StepOutput::Multiple(events)) => {
let duration =
u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
step_span_clone.record("duration_ms", duration);
step_span_clone.record("otel.status_code", "OK");
#[cfg(feature = "telemetry")]
if let Some(ref tx) = htx {
let _ = tx.send(blazen_telemetry::HistoryEvent {
timestamp: Utc::now(),
sequence: 0,
kind: blazen_telemetry::HistoryEventKind::StepCompleted {
step_name: step_name.clone(),
duration_ms: duration,
output_type: "Multiple".to_owned(),
},
});
}
if let Some(ref sctx) = stream_ctx {
let mut data = serde_json::Map::new();
data.insert(
"kind".into(),
serde_json::Value::String("step_completed".into()),
);
data.insert(
"step_name".into(),
serde_json::Value::String(step_name.clone()),
);
data.insert(
"duration_ms".into(),
serde_json::Value::Number(duration.into()),
);
sctx.write_event_to_stream(DynamicEvent {
event_type: "blazen::lifecycle".to_owned(),
data: serde_json::Value::Object(data),
})
.await;
}
for e in events {
let envelope = EventEnvelope::new(e, Some(step_name.clone()));
let _ = event_tx_clone.send(envelope);
}
}
Ok(StepOutput::None) => {
let duration =
u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
step_span_clone.record("duration_ms", duration);
step_span_clone.record("otel.status_code", "OK");
#[cfg(feature = "telemetry")]
if let Some(ref tx) = htx {
let _ = tx.send(blazen_telemetry::HistoryEvent {
timestamp: Utc::now(),
sequence: 0,
kind: blazen_telemetry::HistoryEventKind::StepCompleted {
step_name: step_name.clone(),
duration_ms: duration,
output_type: "None".to_owned(),
},
});
}
if let Some(ref sctx) = stream_ctx {
let mut data = serde_json::Map::new();
data.insert(
"kind".into(),
serde_json::Value::String("step_completed".into()),
);
data.insert(
"step_name".into(),
serde_json::Value::String(step_name.clone()),
);
data.insert(
"duration_ms".into(),
serde_json::Value::Number(duration.into()),
);
sctx.write_event_to_stream(DynamicEvent {
event_type: "blazen::lifecycle".to_owned(),
data: serde_json::Value::Object(data),
})
.await;
}
}
Err(err) => {
let duration =
u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
step_span_clone.record("duration_ms", duration);
step_span_clone.record("otel.status_code", "ERROR");
let err_str = err.to_string();
#[cfg(feature = "telemetry")]
if let Some(ref tx) = htx {
let _ = tx.send(blazen_telemetry::HistoryEvent {
timestamp: Utc::now(),
sequence: 0,
kind: blazen_telemetry::HistoryEventKind::StepFailed {
step_name: step_name.clone(),
error: err_str.clone(),
duration_ms: duration,
},
});
}
if let Some(ref sctx) = stream_ctx {
let mut data = serde_json::Map::new();
data.insert(
"kind".into(),
serde_json::Value::String("step_failed".into()),
);
data.insert(
"step_name".into(),
serde_json::Value::String(step_name.clone()),
);
data.insert(
"duration_ms".into(),
serde_json::Value::Number(duration.into()),
);
data.insert("error".into(), serde_json::Value::String(err_str));
sctx.write_event_to_stream(DynamicEvent {
event_type: "blazen::lifecycle".to_owned(),
data: serde_json::Value::Object(data),
})
.await;
}
tracing::error!(
step = %step_name,
error = %err,
"step failed"
);
let _ = error_tx_clone.send(WorkflowError::StepFailed {
step_name,
source: Box::new(err),
});
}
}
counter.fetch_sub(1, Ordering::Relaxed);
}
.instrument(step_span),
);
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use blazen_events::{StartEvent, StopEvent};
    use std::sync::Arc;
    use crate::step::{StepFn, StepOutput, StepRegistration};

    /// A single step that copies the `StartEvent` payload into a `StopEvent`.
    fn echo_step() -> StepRegistration {
        let echo: StepFn = Arc::new(|incoming, _ctx| {
            Box::pin(async move {
                let payload = incoming
                    .as_any()
                    .downcast_ref::<StartEvent>()
                    .expect("expected StartEvent")
                    .data
                    .clone();
                Ok(StepOutput::Single(Box::new(StopEvent { result: payload })))
            })
        });
        StepRegistration {
            name: "echo".into(),
            accepts: vec![StartEvent::event_type()],
            emits: vec![StopEvent::event_type()],
            handler: echo,
            max_concurrency: 0,
        }
    }

    /// Happy path: one step, StartEvent in, StopEvent with same payload out.
    #[tokio::test]
    async fn simple_start_to_stop() {
        let workflow = WorkflowBuilder::new("test")
            .step(echo_step())
            .build()
            .unwrap();
        let payload = serde_json::json!({"hello": "world"});
        let handle = workflow.run(payload.clone()).await.unwrap();
        let final_event = handle.result().await.unwrap();
        assert_eq!(final_event.event_type_id(), StopEvent::event_type());
        let stop = final_event.downcast_ref::<StopEvent>().unwrap();
        assert_eq!(stop.result, payload);
    }

    /// Building with zero steps must be rejected at build time.
    #[tokio::test]
    async fn empty_workflow_fails_validation() {
        let err = WorkflowBuilder::new("empty")
            .build()
            .expect_err("building with no steps must fail");
        assert!(matches!(err, WorkflowError::ValidationFailed(_)));
    }

    /// A step that never finishes must trip the configured run timeout.
    #[tokio::test]
    async fn timeout_triggers() {
        let sleepy: StepFn = Arc::new(|_event, _ctx| {
            Box::pin(async move {
                tokio::time::sleep(Duration::from_secs(3600)).await;
                Ok(StepOutput::None)
            })
        });
        let workflow = WorkflowBuilder::new("timeout-test")
            .step(StepRegistration {
                name: "slow".into(),
                accepts: vec![StartEvent::event_type()],
                emits: vec![],
                handler: sleepy,
                max_concurrency: 0,
            })
            .timeout(Duration::from_millis(50))
            .build()
            .unwrap();
        let outcome = workflow
            .run(serde_json::json!(null))
            .await
            .unwrap()
            .result()
            .await;
        assert!(matches!(outcome, Err(WorkflowError::Timeout { .. })));
    }

    /// A handler error must surface to the caller as `StepFailed`.
    #[tokio::test]
    async fn step_error_propagates() {
        let failing: StepFn = Arc::new(|_event, _ctx| {
            Box::pin(async move { Err(WorkflowError::Context("test error".into())) })
        });
        let workflow = WorkflowBuilder::new("error-test")
            .step(StepRegistration {
                name: "failing".into(),
                accepts: vec![StartEvent::event_type()],
                emits: vec![],
                handler: failing,
                max_concurrency: 0,
            })
            .build()
            .unwrap();
        let outcome = workflow
            .run(serde_json::json!(null))
            .await
            .unwrap()
            .result()
            .await;
        assert!(matches!(outcome, Err(WorkflowError::StepFailed { .. })));
    }
}