use async_trait::async_trait;
use futures_util::StreamExt as _;
use serde_json::json;
use weavegraph::{
channels::Channel as _,
event_bus::{Event, STREAM_END_SCOPE},
graphs::GraphBuilder,
message::{Message, Role},
node::{Node, NodeContext, NodeError, NodePartial},
state::{StateSnapshot, VersionedState},
types::NodeKind,
};
use tracing::info;
use tracing_error::ErrorLayer;
use tracing_subscriber::{EnvFilter, layer::SubscriberExt, util::SubscriberInitExt};
/// Result alias for this example: the error side is any boxed error that is `Send + Sync`,
/// so differing error types can be propagated with `?`.
type ExampleResult<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;
/// Returns the short lowercase tag used for an [`Event`] variant in the JSON payloads:
/// `"node"`, `"diagnostic"`, or `"llm"`.
fn event_kind_label(event: &Event) -> &'static str {
    // No wildcard arm on purpose: a new `Event` variant should fail to compile here.
    match *event {
        Event::LLM(_) => "llm",
        Event::Diagnostic(_) => "diagnostic",
        Event::Node(_) => "node",
    }
}
fn event_payload(event: &Event) -> serde_json::Value {
json!({
"kind": event_kind_label(event),
"scope": event.scope_label().unwrap_or("workflow"),
"message": event.message(),
"observed_at": chrono::Utc::now().to_rfc3339(),
})
}
/// Installs the global tracing subscriber used by this example.
///
/// The subscriber stacks three layers on the registry, in this order:
/// 1. a compact `fmt` layer with targets, thread ids, and thread names switched off;
/// 2. an [`EnvFilter`] read from the default environment, extended with `info`-level
///    directives for `weavegraph` and `streaming_events`;
/// 3. a default [`ErrorLayer`].
///
/// # Panics
/// Panics if either hard-coded directive string fails to parse.
fn init_tracing() {
    let fmt_layer = tracing_subscriber::fmt::layer()
        .with_target(false)
        .with_thread_ids(false)
        .with_thread_names(false)
        .compact();

    let env_filter = EnvFilter::from_default_env()
        .add_directive("weavegraph=info".parse().unwrap())
        .add_directive("streaming_events=info".parse().unwrap());

    tracing_subscriber::registry()
        .with(fmt_layer)
        .with(env_filter)
        .with(ErrorLayer::default())
        .init();
}
/// Example node that simulates a multi-step job, emitting a progress event between pauses.
struct ProcessingNode;

#[async_trait]
impl Node for ProcessingNode {
    /// Emits a start event naming the query, then four follow-up progress events, each
    /// preceded by a 300 ms pause, and finally returns one assistant message.
    ///
    /// The query is the content of the first message in `snapshot`, or `"default query"`
    /// when the snapshot holds no messages.
    ///
    /// # Errors
    /// Propagates the error from any failed `ctx.emit` call.
    async fn run(
        &self,
        snapshot: StateSnapshot,
        ctx: NodeContext,
    ) -> Result<NodePartial, NodeError> {
        const SCOPE: &str = "processing";
        const STEP_DELAY: tokio::time::Duration = tokio::time::Duration::from_millis(300);
        // Events sent after the start event; each one follows a `STEP_DELAY` pause.
        const FOLLOW_UPS: [&str; 4] = [
            "Step 1/3: Analyzing input",
            "Step 2/3: Computing result",
            "Step 3/3: Formatting output",
            "Processing complete!",
        ];

        let query = snapshot
            .messages
            .first()
            .map_or("default query", |first| first.content.as_str());

        ctx.emit(SCOPE, format!("Starting to process: {query}"))?;
        for update in FOLLOW_UPS {
            tokio::time::sleep(STEP_DELAY).await;
            ctx.emit(SCOPE, update)?;
        }

        let reply = Message::with_role(Role::Assistant, "Processing finished successfully.");
        Ok(NodePartial::new().with_messages(vec![reply]))
    }
}
#[tokio::main]
async fn main() -> ExampleResult<()> {
init_tracing();
info!("=== Streaming workflow events ===");
let worker_id = NodeKind::from("processor-stream");
let stream_builder = GraphBuilder::new()
.add_node(worker_id.clone(), ProcessingNode)
.add_edge(NodeKind::Start, worker_id.clone())
.add_edge(worker_id, NodeKind::End);
let app = stream_builder.compile()?;
let seed_state = VersionedState::new_with_user_message("Process dashboard data");
let (run_handle, live_events) = app.invoke_streaming(seed_state).await;
info!("๐ก Forwarding event payloads as structured JSON");
let collect_events = async move {
let mut stream = live_events.into_async_stream();
let mut seen = 0usize;
while let Some(next_event) = stream.next().await {
seen = seen.saturating_add(1);
let payload = event_payload(&next_event);
info!("๐จ {}", serde_json::to_string_pretty(&payload)?);
if matches!(next_event.scope_label(), Some(scope) if scope == STREAM_END_SCOPE) {
info!("โ
saw the terminal stream marker");
return Ok::<usize, serde_json::Error>(seen);
}
}
Ok(seen)
};
let collector = tokio::spawn(collect_events);
let completed_state = run_handle.join().await?;
let delivered = collector
.await
.map_err(|join_err| std::io::Error::other(join_err.to_string()))??;
let final_messages = completed_state.messages.snapshot();
info!("๐งพ Final state kept {} message(s)", final_messages.len());
info!("๐ Stream delivered {delivered} event(s)");
info!("๐ก Pair this with Axum SSE, or swap to invoke_with_channel for flume consumers.");
Ok(())
}