weavegraph 0.3.0

Graph-driven, concurrent agent workflow framework with versioned state, deterministic barrier merges, and rich diagnostics.
Documentation
use std::{convert::Infallible, sync::Arc, time::Duration};

use async_stream::stream;
use async_trait::async_trait;
use axum::{
    Router,
    extract::State,
    response::sse::{Event as SseEvent, Sse},
    routing::get,
};
use bytes::Bytes;
use futures_util::StreamExt;
use reqwest::Client;
use serde_json::json;
use tokio::{net::TcpListener, time::timeout};
use weavegraph::event_bus::STREAM_END_SCOPE;
use weavegraph::graphs::GraphBuilder;
use weavegraph::node::{Node, NodeContext, NodeError, NodePartial};
use weavegraph::state::{StateSnapshot, VersionedState};
use weavegraph::types::NodeKind;

/// Minimal workflow node used to drive the SSE endpoint in this test.
///
/// It carries no state; all of its behavior lives in [`Node::run`].
struct TestNode;

#[async_trait]
impl Node for TestNode {
    /// Emits progress events on the `"sse"` scope and returns an empty partial.
    ///
    /// If the snapshot holds at least one message, a `processing: <content>`
    /// event is emitted for the first one. A `stream-complete` event is always
    /// emitted afterwards.
    ///
    /// # Errors
    ///
    /// Propagates any failure reported by `ctx.emit`.
    async fn run(
        &self,
        snapshot: StateSnapshot,
        ctx: NodeContext,
    ) -> Result<NodePartial, NodeError> {
        let first_message = snapshot.messages.first();
        if let Some(message) = first_message {
            let progress = format!("processing: {}", message.content);
            ctx.emit("sse", progress)?;
        }

        ctx.emit("sse", "stream-complete")?;

        // The node contributes no new messages to the state.
        let partial = NodePartial::new().with_messages(Vec::new());
        Ok(partial)
    }
}

/// Axum handler that runs the workflow and relays its events as Server-Sent Events.
///
/// The workflow is started via `invoke_streaming` with a single `"ping"` user
/// message. Its completion is awaited on a separate task so that a failure is
/// logged rather than silently dropped. Every event is forwarded as a JSON
/// payload with `scope` and `message` fields, and the SSE stream finishes right
/// after the event whose scope label equals `STREAM_END_SCOPE` has been sent.
async fn handler(
    State(app): State<Arc<weavegraph::app::App>>,
) -> Sse<impl futures_util::Stream<Item = Result<SseEvent, Infallible>>> {
    let initial_state = VersionedState::new_with_user_message("ping");
    let (invocation, events) = app.invoke_streaming(initial_state).await;

    // Await the workflow result in the background; only failures are reported.
    tokio::spawn(async move {
        let outcome = invocation.join().await;
        if let Err(err) = outcome {
            tracing::error!("test workflow failed: {err:?}");
        }
    });

    let mut event_stream = events.into_async_stream();
    let sse_stream = stream! {
        loop {
            let event = match event_stream.next().await {
                Some(event) => event,
                None => break,
            };

            // Decide up front whether this is the last event; it is still
            // forwarded to the client before the stream is closed.
            let is_terminal = event.scope_label() == Some(STREAM_END_SCOPE);

            let payload = json!({
                "scope": event.scope_label(),
                "message": event.message(),
            });
            yield Ok(SseEvent::default().json_data(payload).unwrap());

            if is_terminal {
                break;
            }
        }
    };

    Sse::new(sse_stream)
}

/// End-to-end check that the SSE endpoint streams workflow events until the
/// terminal `STREAM_END_SCOPE` event reaches the client.
///
/// The test compiles a one-node graph (`Start -> test -> End`), serves
/// `handler` on an ephemeral loopback port, and reads the response body with
/// `reqwest` until the end marker shows up. It is marked `#[ignore]`, so it
/// only runs when ignored tests are requested explicitly.
///
/// # Errors
///
/// Returns an error if the graph fails to compile, the listener cannot be
/// bound, the HTTP request or a body chunk fails, or no chunk arrives within
/// one second.
#[tokio::test(flavor = "multi_thread")]
#[ignore]
async fn axum_sse_example_streams_until_completion() -> Result<(), Box<dyn std::error::Error>> {
    let app = Arc::new(
        GraphBuilder::new()
            .add_node(NodeKind::Custom("test".into()), TestNode)
            .add_edge(NodeKind::Start, NodeKind::Custom("test".into()))
            .add_edge(NodeKind::Custom("test".into()), NodeKind::End)
            .compile()?,
    );

    let router = Router::new().route("/stream", get(handler)).with_state(app);
    // Port 0 lets the OS pick a free port, so parallel test runs do not collide.
    let listener = TcpListener::bind("127.0.0.1:0").await?;
    let addr = listener.local_addr()?;

    let server = tokio::spawn(async move {
        if let Err(err) = axum::serve(listener, router.into_make_service()).await {
            tracing::error!("axum server error: {err:?}");
        }
    });

    let client = Client::builder().build()?;
    let response = client.get(format!("http://{addr}/stream")).send().await?;
    let mut body = response.bytes_stream();

    // Transport chunks are not aligned with SSE frames, so the end marker may
    // be split across two chunks. Accumulate everything received so far and
    // search the whole buffer instead of each chunk in isolation.
    let mut received: Vec<u8> = Vec::new();
    let mut saw_end = false;

    // Each read is bounded by a timeout so a stalled stream fails the test
    // instead of hanging it.
    while let Some(chunk_result) = timeout(Duration::from_secs(1), body.next()).await? {
        let chunk: Bytes = chunk_result?;
        received.extend_from_slice(&chunk);
        if String::from_utf8_lossy(&received).contains(STREAM_END_SCOPE) {
            saw_end = true;
            break;
        }
    }

    assert!(saw_end, "stream should include termination event");

    server.abort();
    Ok(())
}