use async_trait::async_trait;
use futures_util::StreamExt;
use rustc_hash::FxHashMap;
use serde_json::Value;
use weavegraph::channels::Channel;
use weavegraph::event_bus::STREAM_END_SCOPE;
use weavegraph::graphs::GraphBuilder;
use weavegraph::message::{Message, Role};
use weavegraph::node::{Node, NodeContext, NodeError, NodePartial};
use weavegraph::state::VersionedState;
use weavegraph::types::NodeKind;
mod common;
use common::*;
/// Builds a minimal graph consisting of a single `Start -> End` edge.
///
/// Used by the `apply_barrier` tests, which drive the barrier directly and
/// do not need any custom nodes registered.
fn make_app() -> weavegraph::app::App {
    let builder = GraphBuilder::new().add_edge(NodeKind::Start, NodeKind::End);
    builder.compile().unwrap()
}
/// Builds a `Start -> "node" -> End` graph whose single custom node is a
/// `SimpleMessageNode` constructed with the text `"response"`.
fn message_app() -> weavegraph::app::App {
    // Fresh `NodeKind` value per use, so the same node id is spelled once.
    let node = || NodeKind::Custom("node".into());
    GraphBuilder::new()
        .add_node(node(), SimpleMessageNode::new("response"))
        .add_edge(NodeKind::Start, node())
        .add_edge(node(), NodeKind::End)
        .compile()
        .unwrap()
}
/// A single assistant message applied at the barrier is appended to the
/// `messages` channel and bumps its version once; `extra` keeps version 1.
#[tokio::test]
async fn apply_barrier_appends_messages_and_bumps_version() {
    let app = make_app();
    let mut state = state_with_user("hi");
    let assistant_reply = Message::with_role(Role::Assistant, "foo");
    let partials = vec![NodePartial::new().with_messages(vec![assistant_reply])];

    let outcome = app
        .apply_barrier(&mut state, &[NodeKind::Start], partials)
        .await
        .unwrap();

    assert!(outcome.updated_channels.contains(&"messages"));
    assert!(outcome.errors.is_empty());

    let messages = state.messages.snapshot();
    assert_eq!(messages.last().unwrap().content, "foo");
    assert_eq!(state.messages.version(), 2);
    assert_eq!(state.extra.version(), 1);
}
/// An empty partial leaves every channel untouched. No channel is reported
/// as updated, no errors are produced, and both versions stay at 1.
#[tokio::test]
async fn apply_barrier_with_empty_partial_changes_nothing() {
    let app = make_app();
    let mut state = state_with_user("hi");
    let empty = NodePartial::new();

    let outcome = app
        .apply_barrier(&mut state, &[NodeKind::Start], vec![empty])
        .await
        .unwrap();

    assert!(outcome.updated_channels.is_empty());
    assert!(outcome.errors.is_empty());
    assert_eq!(state.messages.version(), 1);
    assert_eq!(state.extra.version(), 1);
}
/// A `messages` channel already at `u32::MAX` stays at `u32::MAX` after a
/// further update is applied, rather than wrapping.
#[tokio::test]
async fn apply_barrier_version_saturates_at_max() {
    let app = make_app();
    let mut state = state_with_user("hi");
    state.messages.set_version(u32::MAX);

    let reply = Message::with_role(Role::Assistant, "x");
    let partial = NodePartial::new().with_messages(vec![reply]);
    app.apply_barrier(&mut state, &[NodeKind::Start], vec![partial])
        .await
        .unwrap();

    assert_eq!(state.messages.version(), u32::MAX);
}
/// When partials touch both channels and carry an error:
/// - `updated_channels` lists `messages` then `extra`;
/// - the error event is passed through unchanged;
/// - each touched channel's version is bumped once;
/// - both `extra` keys are present afterwards.
#[tokio::test]
async fn apply_barrier_channel_order_is_messages_then_extra() {
    use weavegraph::channels::errors::{ErrorEvent, ErrorScope};

    let app = make_app();
    let mut state = state_with_user("hi");

    let messages_partial =
        NodePartial::new().with_messages(vec![Message::with_role(Role::Assistant, "a")]);

    let mut extra_entries = FxHashMap::default();
    extra_entries.insert("z".into(), Value::String("1".into()));
    extra_entries.insert("a".into(), Value::String("2".into()));
    let extra_partial = NodePartial::new().with_extra(extra_entries);

    let err_event = ErrorEvent {
        scope: ErrorScope::Node {
            kind: "anode".into(),
            step: 2,
        },
        when: chrono::Utc::now(),
        ..Default::default()
    };
    let errors_partial = NodePartial::new().with_errors(vec![err_event.clone()]);

    let partials = vec![messages_partial, extra_partial, errors_partial];
    let outcome = app
        .apply_barrier(&mut state, &[NodeKind::Start], partials)
        .await
        .unwrap();

    assert_eq!(outcome.updated_channels, vec!["messages", "extra"]);
    assert_eq!(outcome.errors, vec![err_event]);
    assert_eq!(state.messages.version(), 2);
    assert_eq!(state.extra.version(), 2);

    // Hash-map iteration order is unspecified, so compare the keys sorted.
    let mut keys: Vec<_> = state.extra.snapshot().keys().cloned().collect();
    keys.sort_unstable();
    assert_eq!(keys, vec!["a".to_string(), "z".to_string()]);
}
/// Test node that calls `ctx.emit("test", "event")` exactly once and returns
/// an empty `NodePartial`. It panics if the emit call returns an error.
struct EmitOnce;

#[async_trait]
impl Node for EmitOnce {
    async fn run(
        &self,
        _state: weavegraph::state::StateSnapshot,
        ctx: NodeContext,
    ) -> Result<NodePartial, NodeError> {
        let emit_result = ctx.emit("test", "event");
        emit_result.unwrap();
        Ok(NodePartial::default())
    }
}
/// Streaming a graph with one `EmitOnce` node checks three things:
/// - exactly one non-sentinel event is delivered;
/// - the stream carries exactly one `STREAM_END_SCOPE` sentinel;
/// - the invocation handle joins successfully.
#[tokio::test]
async fn invoke_streaming_closes_stream() {
    let emit = || NodeKind::Custom("emit".into());
    let app = GraphBuilder::new()
        .add_node(emit(), EmitOnce)
        .add_edge(NodeKind::Start, emit())
        .add_edge(emit(), NodeKind::End)
        .compile()
        .unwrap();

    let (invocation, events) = app
        .invoke_streaming(VersionedState::new_with_user_message("hello"))
        .await;
    let mut stream = events.into_async_stream();

    let mut non_terminal = 0;
    let mut sentinel_seen = false;
    while let Some(event) = stream.next().await {
        let is_sentinel = event.scope_label() == Some(STREAM_END_SCOPE);
        if !is_sentinel {
            non_terminal += 1;
            continue;
        }
        // Fail immediately on a duplicate sentinel.
        assert!(
            !sentinel_seen,
            "STREAM_END_SCOPE should appear exactly once"
        );
        sentinel_seen = true;
    }

    assert_eq!(non_terminal, 1);
    assert!(sentinel_seen, "expected terminal sentinel event");
    invocation.join().await.unwrap();
}
/// Messages from two partials in one barrier are both appended, in partial
/// order (`foo` then `bar`). The `messages` version ends at 2.
#[tokio::test]
async fn apply_barrier_from_multiple_partials_appends_all_messages() {
    let app = make_app();
    let mut state = state_with_user("hi");
    let assistant = |text: &str| {
        NodePartial::new().with_messages(vec![Message::with_role(Role::Assistant, text)])
    };

    let outcome = app
        .apply_barrier(
            &mut state,
            &[NodeKind::Start, NodeKind::End],
            vec![assistant("foo"), assistant("bar")],
        )
        .await
        .unwrap();

    let snap = state.messages.snapshot();
    assert!(outcome.updated_channels.contains(&"messages"));
    let tail = &snap[snap.len() - 2..];
    assert_eq!(tail[0].content, "foo");
    assert_eq!(tail[1].content, "bar");
    assert_eq!(state.messages.version(), 2);
}
/// Partials carrying empty collections (an empty message list, an empty
/// `extra` map) are no-ops. No channel is reported as updated and neither
/// version moves from 1.
#[tokio::test]
async fn apply_barrier_empty_collections_are_no_ops() {
    let app = make_app();
    let mut state = state_with_user("hi");
    let no_messages = NodePartial::new().with_messages(vec![]);
    let no_extra = NodePartial::new().with_extra(FxHashMap::default());

    let outcome = app
        .apply_barrier(
            &mut state,
            &[NodeKind::Start, NodeKind::End],
            vec![no_messages, no_extra],
        )
        .await
        .unwrap();

    assert!(outcome.updated_channels.is_empty());
    assert_eq!(state.messages.version(), 1);
    assert_eq!(state.extra.version(), 1);
}
/// `extra` maps from two partials are merged. For a key supplied by both,
/// the value from the later partial is kept. The `extra` version ends at 2.
#[tokio::test]
async fn apply_barrier_extra_partials_merge_and_later_key_wins() {
    let app = make_app();
    let mut state = state_with_user("hi");

    let mut first = FxHashMap::default();
    first.insert("k1".into(), Value::String("v1".into()));

    let mut second = FxHashMap::default();
    second.insert("k2".into(), Value::String("v2".into()));
    // Same key as in `first`; this later value is expected to win.
    second.insert("k1".into(), Value::String("v3".into()));

    let partials = vec![
        NodePartial::new().with_extra(first),
        NodePartial::new().with_extra(second),
    ];
    let outcome = app
        .apply_barrier(&mut state, &[NodeKind::Start, NodeKind::End], partials)
        .await
        .unwrap();

    assert!(outcome.updated_channels.contains(&"extra"));
    let merged = state.extra.snapshot();
    assert_eq!(merged.get("k1").cloned(), Some(Value::String("v3".into())));
    assert_eq!(merged.get("k2").cloned(), Some(Value::String("v2".into())));
    assert_eq!(state.extra.version(), 2);
}
/// An error event carried by a partial shows up in the outcome's `errors`
/// list. It does not mark any channel as updated.
#[tokio::test]
async fn apply_barrier_error_partials_appear_in_outcome() {
    use weavegraph::channels::errors::ErrorEvent;

    let app = make_app();
    let mut state = state_with_user("hi");
    let failing = NodePartial::new().with_errors(vec![ErrorEvent::default()]);

    let outcome = app
        .apply_barrier(&mut state, &[NodeKind::Start], vec![failing])
        .await
        .unwrap();

    assert!(outcome.updated_channels.is_empty());
    assert_eq!(outcome.errors.len(), 1);
}
/// `invoke_with_channel` runs `message_app` to completion. The returned
/// final state's last message is `"response"`.
#[tokio::test]
async fn invoke_with_channel_returns_completed_state() {
    let app = message_app();
    let initial = state_with_user("prompt");
    // `_events` (not `_`) keeps the event channel alive for the whole test.
    let (result, _events) = app.invoke_with_channel(initial).await;

    let final_state = result.expect("workflow completes");
    let messages = final_state.messages.snapshot();
    assert_eq!(messages.last().unwrap().content, "response");
}
/// Re-invoking `message_app` with the state produced by a first run
/// advances the `messages` version from 2 to 3. The `extra` version is the
/// same after both runs.
#[tokio::test]
async fn invoke_with_channel_second_run_increments_message_version() {
    let app = message_app();

    let (first_result, _first_events) = app
        .invoke_with_channel(VersionedState::new_with_user_message("first"))
        .await;
    let after_first = first_result.expect("first run succeeds");
    assert_eq!(after_first.messages.version(), 2);

    // Clone so `after_first` is still available for the final comparison.
    let (second_result, _second_events) = app.invoke_with_channel(after_first.clone()).await;
    let after_second = second_result.expect("second run succeeds");
    assert_eq!(after_second.messages.version(), 3);
    assert_eq!(after_second.extra.version(), after_first.extra.version());
}
/// Streaming a graph with one `EmitterNode` checks three things:
/// - at least one `Event::Node` is delivered;
/// - every such event arrives *before* the `STREAM_END_SCOPE` sentinel;
/// - the sentinel itself is delivered and the invocation joins cleanly.
#[tokio::test]
async fn invoke_streaming_delivers_node_events_before_close() {
    use weavegraph::event_bus::Event;
    let app = GraphBuilder::new()
        .add_node(NodeKind::Custom("emitter".into()), EmitterNode)
        .add_edge(NodeKind::Start, NodeKind::Custom("emitter".into()))
        .add_edge(NodeKind::Custom("emitter".into()), NodeKind::End)
        .compile()
        .unwrap();
    let (invocation, events) = app.invoke_streaming(state_with_user("go")).await;
    let mut stream = events.into_async_stream();
    let mut node_events = 0;
    let mut stream_closed = false;
    while let Some(event) = stream.next().await {
        if event.scope_label() == Some(STREAM_END_SCOPE) {
            stream_closed = true;
        } else if matches!(event, Event::Node(_)) {
            // The test promises ordering. Without this check, a node event
            // delivered after the sentinel would still be counted and pass.
            assert!(
                !stream_closed,
                "node event arrived after STREAM_END_SCOPE"
            );
            node_events += 1;
        }
    }
    assert!(stream_closed, "stream must close with STREAM_END_SCOPE");
    assert!(
        node_events > 0,
        "node events must arrive before stream closes"
    );
    invocation.join().await.unwrap();
}
/// `invoke_with_sinks` with a single `MemorySink` runs `message_app` to
/// completion. The final state's last message is `"response"`.
#[tokio::test]
async fn invoke_with_sinks_completes_workflow() {
    use weavegraph::event_bus::MemorySink;

    let app = message_app();
    let initial = state_with_user("prompt");
    // Sinks are built inline so each `Box` coerces to the parameter's
    // element type at the call site.
    let final_state = app
        .invoke_with_sinks(initial, vec![Box::new(MemorySink::new())])
        .await
        .expect("workflow completes");

    let messages = final_state.messages.snapshot();
    assert_eq!(messages.last().unwrap().content, "response");
}
/// `invoke_with_sinks` accepts heterogeneous sinks (stdout, channel and
/// in-memory) in one call and still runs `message_app` to completion.
#[tokio::test]
async fn invoke_with_sinks_accepts_multiple_sink_types() {
    use weavegraph::event_bus::{ChannelSink, MemorySink, StdOutSink};

    let app = message_app();
    // `_rx` stays bound until the end of the test, so the channel keeps a
    // live receiver (presumably so ChannelSink sends don't hit a
    // disconnected channel).
    let (tx, _rx) = flume::unbounded();
    let initial = state_with_user("prompt");

    let final_state = app
        .invoke_with_sinks(
            initial,
            vec![
                Box::new(StdOutSink::default()),
                Box::new(ChannelSink::new(tx)),
                Box::new(MemorySink::new()),
            ],
        )
        .await
        .expect("workflow completes with multiple sinks");

    let messages = final_state.messages.snapshot();
    assert_eq!(messages.last().unwrap().content, "response");
}