use std::{convert::Infallible, sync::Arc, time::Duration};
use async_trait::async_trait;
use axum::{
Router,
extract::{Query, State},
response::{
IntoResponse,
sse::{Event as SseEvent, KeepAlive, Sse},
},
routing::get,
};
use futures_util::{Stream, StreamExt};
use serde::{Deserialize, Serialize};
use tracing::{error, info, warn};
use tracing_subscriber::{EnvFilter, layer::SubscriberExt, util::SubscriberInitExt};
use weavegraph::{
app::{App, InvocationHandle},
channels::Channel,
event_bus::{Event, EventStream, STREAM_END_SCOPE},
graphs::GraphBuilder,
message::{Message, Role},
node::{Node, NodeContext, NodeError, NodePartial, NodeResultExt},
runtimes::{EventBusConfig, PostgresCheckpointer, RuntimeConfig},
state::{StateSnapshot, VersionedState},
types::NodeKind,
};
/// Boxed, thread-safe catch-all error type returned by `build_app` and `main`.
type BoxError = Box<dyn std::error::Error + Send + Sync>;
/// Simulated LLM node used to demonstrate token streaming.
///
/// For a fixed script of tokens it emits one `llm.token` event per token (each
/// message prefixed with the latest prompt), pausing 150 ms after each emit. It then
/// contributes a single assistant message `Response to '<prompt>'` to the state.
///
/// NOTE(review): the final assistant message does not contain the streamed tokens;
/// confirm that this is intended for the demo.
#[derive(Clone)]
struct LlmNode;

#[async_trait]
impl Node for LlmNode {
    async fn run(
        &self,
        snapshot: StateSnapshot,
        ctx: NodeContext,
    ) -> Result<NodePartial, NodeError> {
        const SCRIPTED_TOKENS: [&str; 8] = [
            "Hello",
            ", ",
            "I",
            " am",
            " a",
            " streaming",
            " assistant",
            "!",
        ];
        // Artificial per-token latency so clients can observe the stream.
        const TOKEN_DELAY: Duration = Duration::from_millis(150);

        // The most recent message acts as the prompt; fall back to a placeholder.
        let user_input = match snapshot.messages.last() {
            Some(latest) => latest.content.as_str(),
            None => "(no input)",
        };

        for piece in SCRIPTED_TOKENS {
            ctx.emit(
                "llm.token",
                format!("Response to '{user_input}': {piece}"),
            )?;
            tokio::time::sleep(TOKEN_DELAY).await;
        }

        let reply = format!("Response to '{user_input}'");
        let answer = Message::with_role(Role::Assistant, &reply);
        Ok(NodePartial::new().with_messages(vec![answer]))
    }
}
/// Maximum accepted prompt length, counted in Unicode scalar values (`char`s), not bytes.
const MAX_PROMPT_CHARS: usize = 4096;

/// Checks that `prompt` is non-blank and at most [`MAX_PROMPT_CHARS`] characters long.
///
/// # Errors
///
/// Returns the human-readable rejection reason when the prompt is empty or
/// whitespace-only, or when it is too long.
fn check_prompt(prompt: &str) -> Result<(), String> {
    if prompt.trim().is_empty() {
        return Err("prompt must not be empty".to_string());
    }
    // Count chars, not bytes: the limit and the message are expressed in characters,
    // and non-ASCII text takes up to 4 bytes per char, so `len()` would reject
    // prompts that are well within the limit.
    let char_count = prompt.chars().count();
    if char_count > MAX_PROMPT_CHARS {
        return Err(format!(
            "prompt too long: {char_count} chars (max {MAX_PROMPT_CHARS})"
        ));
    }
    Ok(())
}

/// Graph node that rejects blank or over-long prompts before the LLM node runs.
///
/// It inspects only the most recent message in the state and contributes no state
/// changes of its own. A `&str` is already guaranteed to be valid UTF-8, so no
/// separate encoding check is needed.
#[derive(Clone)]
struct ValidateNode;

#[async_trait]
impl Node for ValidateNode {
    async fn run(
        &self,
        snapshot: StateSnapshot,
        _ctx: NodeContext,
    ) -> Result<NodePartial, NodeError> {
        // A state with no messages is treated like an empty prompt and rejected.
        let prompt = snapshot
            .messages
            .last()
            .map(|m| m.content.as_str())
            .unwrap_or("");
        check_prompt(prompt).map_err(|reason| NodeError::Other(reason.into()))?;
        Ok(NodePartial::new())
    }
}
/// Shared state handed to the axum handlers through `State`.
///
/// `Clone` is needed so the router can give each handler its own copy. Holding the
/// compiled `App` behind an `Arc` makes that copy a reference-count bump instead of
/// a copy of the app.
#[derive(Clone)]
struct AppState {
// Compiled workflow graph; `run_handler` starts a streaming invocation on it.
app: Arc<App>,
}
/// Query-string parameters accepted by `GET /run`.
#[derive(Debug, Deserialize)]
struct RunQuery {
/// Prompt fed to the workflow as the initial user message; `default_prompt()`
/// supplies the value when the `prompt` parameter is omitted.
#[serde(default = "default_prompt")]
prompt: String,
}
/// Prompt used when a `/run` request carries no `prompt` query parameter.
fn default_prompt() -> String {
    String::from("Hello, weavegraph!")
}
async fn run_handler(
State(state): State<AppState>,
Query(query): Query<RunQuery>,
) -> impl IntoResponse {
info!(prompt = %query.prompt, "starting workflow invocation");
let initial_state = VersionedState::new_with_user_message(&query.prompt);
let (handle, event_stream) = state.app.invoke_streaming(initial_state).await;
let sse_stream = build_sse_stream(handle, event_stream);
Sse::new(sse_stream).keep_alive(
KeepAlive::new()
.interval(Duration::from_secs(15))
.text("keep-alive"),
)
}
/// Adapts a workflow's event stream into an SSE response body.
///
/// Each bus event becomes one SSE `data:` frame holding a JSON-encoded
/// [`SsePayload`]; if serialization fails, a fixed error JSON is sent instead. The
/// stream ends right after the event whose scope is [`STREAM_END_SCOPE`] has been
/// forwarded. At that point, or if the bus closes without ever sending that marker,
/// the invocation handle is joined so the workflow's final outcome is logged.
///
/// NOTE(review): if the client disconnects early, this stream is dropped before
/// finishing and `handle` is dropped un-joined. Whether that cancels the workflow
/// depends on `InvocationHandle`'s drop behavior; confirm.
fn build_sse_stream(
    handle: InvocationHandle,
    event_stream: EventStream,
) -> impl Stream<Item = Result<SseEvent, Infallible>> {
    let frames = event_stream
        .into_async_stream()
        .map(|event| {
            let is_end = event.scope_label().is_some_and(|s| s == STREAM_END_SCOPE);
            let payload = serde_json::to_string(&SsePayload::from(&event))
                .unwrap_or_else(|_| r#"{"error":"serialization failed"}"#.to_string());
            (SseEvent::default().data(payload), is_end)
        })
        .boxed();

    // Unfold state: (remaining frames, invocation handle, "end marker was just sent").
    // The state owns the handle outright. No lock is needed, because unfold runs one
    // step at a time and both terminal paths consume the handle.
    futures_util::stream::unfold(
        (frames, handle, false),
        |(mut frames, handle, end_sent)| async move {
            if end_sent {
                finish_invocation(handle).await;
                return None;
            }
            match frames.next().await {
                Some((sse, is_end)) => Some((Ok(sse), (frames, handle, is_end))),
                None => {
                    error!("event stream closed without STREAM_END_SCOPE");
                    // Still join, so the workflow's outcome or error is not lost.
                    finish_invocation(handle).await;
                    None
                }
            }
        },
    )
}

/// Waits for the workflow behind `handle` to finish and logs its outcome.
async fn finish_invocation(handle: InvocationHandle) {
    match handle.join().await {
        Ok(state) => info!(
            messages = state.messages.len(),
            "workflow completed successfully"
        ),
        Err(e) => warn!(error = %e, "workflow ended with error"),
    }
}
/// JSON body of every SSE `data:` frame sent by `/run`.
#[derive(Debug, Serialize)]
struct SsePayload {
    /// Event category: `"node"`, `"diagnostic"` or `"llm"`.
    kind: &'static str,
    /// The event's message text.
    message: String,
    /// The event's scope label, if it has one.
    scope: Option<String>,
}

impl From<&Event> for SsePayload {
    fn from(event: &Event) -> Self {
        let kind = match event {
            Event::Node(_) => "node",
            Event::Diagnostic(_) => "diagnostic",
            Event::LLM(_) => "llm",
        };
        let message = event.message().to_string();
        let scope = event.scope_label().map(String::from);
        SsePayload {
            kind,
            message,
            scope,
        }
    }
}
/// `GET /healthz` liveness probe; the response body is always `ok`.
async fn healthz() -> &'static str {
    const HEALTHY: &str = "ok";
    HEALTHY
}
/// Returns a copy of a database connection URL that is safe to write to logs.
///
/// In `scheme://user:password@host/...` the password is replaced with `***`. Any
/// query string is dropped, because options such as `?password=...` can carry
/// secrets too. Strings that are not URL-shaped (for example key=value connection
/// strings) are replaced wholesale, since their layout is unknown.
fn redact_db_url(url: &str) -> String {
    let Some((scheme, rest)) = url.split_once("://") else {
        return "***".to_string();
    };
    // Split credentials at the *last* '@', so an unescaped '@' inside the password
    // still lands on the redacted side (over-redaction is harmless for logging).
    let (credentials, location) = match rest.rfind('@') {
        Some(at) => (Some(&rest[..at]), &rest[at + 1..]),
        None => (None, rest),
    };
    let location = location.split('?').next().unwrap_or(location);
    match credentials {
        Some(creds) => match creds.split_once(':') {
            Some((user, _password)) => format!("{scheme}://{user}:***@{location}"),
            None => format!("{scheme}://{creds}@{location}"),
        },
        None => format!("{scheme}://{location}"),
    }
}

/// Builds the `Start -> validate -> llm -> End` workflow, checkpointed to Postgres.
///
/// Loads a `.env` file if one is present (a missing or unreadable file is ignored),
/// then reads `DATABASE_URL`, falling back to a local development database. Events
/// go to stdout only.
///
/// # Errors
///
/// Returns an error if connecting the Postgres checkpointer fails or the graph
/// fails to compile.
async fn build_app() -> Result<App, BoxError> {
    dotenvy::dotenv().ok();
    let db_url = std::env::var("DATABASE_URL")
        .unwrap_or_else(|_| "postgres://postgres:postgres@localhost/weavegraph".to_string());
    let pg = PostgresCheckpointer::connect(&db_url).await?;
    let runtime_config = RuntimeConfig::new(None, None)
        .checkpointer_custom(Arc::new(pg))
        .with_event_bus(EventBusConfig::with_stdout_only());
    let app = GraphBuilder::new()
        .add_node(NodeKind::Custom("validate".into()), ValidateNode)
        .add_node(NodeKind::Custom("llm".into()), LlmNode)
        .add_edge(NodeKind::Start, NodeKind::Custom("validate".into()))
        .add_edge(
            NodeKind::Custom("validate".into()),
            NodeKind::Custom("llm".into()),
        )
        .add_edge(NodeKind::Custom("llm".into()), NodeKind::End)
        .with_runtime_config(runtime_config)
        .compile()?;
    // Never log the raw URL: it usually embeds the database password.
    info!(
        db_url = %redact_db_url(&db_url),
        "graph compiled with postgres checkpointing"
    );
    Ok(app)
}
/// Server entry point: configures tracing, builds the workflow app, and serves HTTP
/// on `0.0.0.0:3000`.
///
/// Routes:
/// - `GET /run?prompt=...` streams a workflow invocation as SSE.
/// - `GET /healthz` is a liveness check.
///
/// Log filtering follows `RUST_LOG`, defaulting to `info` when it is unset or
/// invalid.
#[tokio::main]
async fn main() -> Result<(), BoxError> {
    // Fall back to `info` only when RUST_LOG is absent or unparsable. Adding an
    // `info` directive on top of `from_default_env()` would replace RUST_LOG's own
    // target-less level directive, so `RUST_LOG=debug` or `RUST_LOG=warn` would be
    // silently forced to `info`.
    let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
    tracing_subscriber::registry()
        .with(tracing_subscriber::fmt::layer())
        .with(filter)
        .init();

    let app = build_app().await?;
    let state = AppState { app: Arc::new(app) };
    let router = Router::new()
        .route("/run", get(run_handler))
        .route("/healthz", get(healthz))
        .with_state(state);

    let addr = "0.0.0.0:3000";
    let listener = tokio::net::TcpListener::bind(addr).await?;
    // Announce only after the bind succeeded, reporting the address actually bound.
    info!(addr = %listener.local_addr()?, "production_streaming server listening");
    axum::serve(listener, router).await?;
    Ok(())
}