use std::collections::HashSet;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
use serde_json::json;
use crate::clock::system_clock;
use crate::frames::{Frame, FrameKind};
use crate::pipeline::{PipelineParams, PipelineTask};
use super::base::{Agent, BaseAgent, TaskHandler, TaskRequestCtx};
use super::bus::{AgentBus, BusMessage, BusPayload, LocalAgentBus, TaskStatus};
use super::edges::BusOutputEdge;
use super::runner::AgentRunner;
/// Builds a pipeline task with no processors and default parameters, for
/// agents whose own pipeline is irrelevant to the test.
fn idle_task() -> PipelineTask {
    let no_processors = Vec::new();
    PipelineTask::new(no_processors, PipelineParams::default())
}
/// Spawns `runner.run()` on a background task and waits until every agent in
/// `names` is visible in the runner's registry.
///
/// Returns the background task's join handle so the caller can wait for the
/// runner to exit after calling `end`.
///
/// # Panics
///
/// Panics if the agents are not all registered within 5 seconds, or as soon
/// as the runner task is observed to have exited while some agent is still
/// missing (instead of waiting out the full deadline and reporting a generic
/// timeout).
async fn start_and_wait(runner: &Arc<AgentRunner>, names: &[&str]) -> tokio::task::JoinHandle<()> {
    let r = runner.clone();
    let handle = tokio::spawn(async move {
        // The runner's result is discarded here; an early exit is detected
        // below via `is_finished` instead.
        let _ = r.run().await;
    });
    let deadline = tokio::time::Instant::now() + Duration::from_secs(5);
    'outer: loop {
        assert!(
            tokio::time::Instant::now() < deadline,
            "agents never became ready"
        );
        for name in names {
            if runner.registry().get(name).await.is_none() {
                // A runner that has already returned will never register the
                // missing agent, so fail fast with a specific message.
                assert!(
                    !handle.is_finished(),
                    "runner exited before agents became ready"
                );
                tokio::time::sleep(Duration::from_millis(10)).await;
                continue 'outer;
            }
        }
        break;
    }
    handle
}
/// Builds a [`TaskHandler`] that completes every request with
/// [`TaskStatus::Completed`], sending the request's payload back as the
/// response.
fn echo_handler() -> TaskHandler {
    Arc::new(|request: TaskRequestCtx| {
        Box::pin(async move {
            // Clone rather than move the payload: `request` is still needed
            // to call `complete`.
            let echoed = request.payload.clone();
            request.complete(TaskStatus::Completed, echoed).await;
        })
    })
}
/// A task dispatched from `a` to `b`'s "echo" handler completes successfully
/// and carries the original payload back as its response.
#[tokio::test]
async fn task_round_trip() {
    // `a` only dispatches; `b` answers "echo" requests.
    let caller = Arc::new(BaseAgent::new("a", idle_task(), None, true));
    let echoer =
        Arc::new(BaseAgent::new("b", idle_task(), None, true).on_task("echo", echo_handler()));
    let agent_bus = Arc::new(LocalAgentBus::new());
    let runner = Arc::new(AgentRunner::new("r", agent_bus, system_clock()));
    runner.add_agent(caller.clone()).await.unwrap();
    runner.add_agent(echoer).await.unwrap();
    let runner_task = start_and_wait(&runner, &["a", "b"]).await;

    let task_ctx = caller.task_ctx().unwrap();
    let sent = json!({"x": 42});
    let pending = task_ctx
        .dispatch("a", "b", Some("echo".into()), Some(sent.clone()))
        .await
        .unwrap();
    let outcome = pending
        .await_completion(Some(Duration::from_secs(2)))
        .await
        .unwrap();
    assert_eq!(outcome.status, TaskStatus::Completed);
    assert_eq!(outcome.response, Some(sent));

    runner.end(None).await;
    let _ = tokio::time::timeout(Duration::from_secs(5), runner_task).await;
}
/// A streaming handler's start, three data chunks and end arrive as five
/// stream updates, followed by the final completed result.
#[tokio::test]
async fn task_streaming() {
    // Emits "start", the chunks 0, 1, 2, then "end", and finally completes.
    let streamer: TaskHandler = Arc::new(|req: TaskRequestCtx| {
        Box::pin(async move {
            req.stream_start(Some(json!("start"))).await;
            for chunk in 0..3 {
                req.stream_data(Some(json!(chunk))).await;
            }
            req.stream_end(Some(json!("end"))).await;
            req.complete(TaskStatus::Completed, Some(json!("done"))).await;
        })
    });
    let caller = Arc::new(BaseAgent::new("a", idle_task(), None, true));
    let producer =
        Arc::new(BaseAgent::new("b", idle_task(), None, true).on_task("stream", streamer));
    let agent_bus = Arc::new(LocalAgentBus::new());
    let runner = Arc::new(AgentRunner::new("r", agent_bus, system_clock()));
    runner.add_agent(caller.clone()).await.unwrap();
    runner.add_agent(producer).await.unwrap();
    let runner_task = start_and_wait(&runner, &["a", "b"]).await;

    let task_ctx = caller.task_ctx().unwrap();
    let pending = task_ctx
        .dispatch("a", "b", Some("stream".into()), None)
        .await
        .unwrap();
    let (updates, outcome) = pending
        .stream_updates(Some(Duration::from_secs(2)))
        .await
        .unwrap();
    assert_eq!(updates.len(), 5, "expected 5 stream updates: {updates:?}");
    assert_eq!(outcome.status, TaskStatus::Completed);
    assert_eq!(outcome.response, Some(json!("done")));

    runner.end(None).await;
    let _ = tokio::time::timeout(Duration::from_secs(5), runner_task).await;
}
/// Cancelling a long-running task from the caller resolves its handle with
/// [`TaskStatus::Cancelled`].
#[tokio::test]
async fn task_cancel() {
    // Would take 30 s to complete unless it is cancelled first.
    let slow_handler: TaskHandler = Arc::new(|req: TaskRequestCtx| {
        Box::pin(async move {
            tokio::time::sleep(Duration::from_secs(30)).await;
            req.complete(TaskStatus::Completed, None).await;
        })
    });
    let caller = Arc::new(BaseAgent::new("a", idle_task(), None, true));
    let worker =
        Arc::new(BaseAgent::new("b", idle_task(), None, true).on_task("sleep", slow_handler));
    let agent_bus = Arc::new(LocalAgentBus::new());
    let runner = Arc::new(AgentRunner::new("r", agent_bus, system_clock()));
    runner.add_agent(caller.clone()).await.unwrap();
    runner.add_agent(worker).await.unwrap();
    let runner_task = start_and_wait(&runner, &["a", "b"]).await;

    let task_ctx = caller.task_ctx().unwrap();
    let pending = task_ctx
        .dispatch("a", "b", Some("sleep".into()), None)
        .await
        .unwrap();
    let id = pending.task_id.clone();
    // Give `b` a moment to start running the handler before cancelling.
    tokio::time::sleep(Duration::from_millis(100)).await;
    task_ctx
        .cancel_task("a", "b", id, Some("changed my mind".into()))
        .await;
    let outcome = pending
        .await_completion(Some(Duration::from_secs(2)))
        .await
        .unwrap();
    assert_eq!(outcome.status, TaskStatus::Cancelled);

    runner.end(None).await;
    let _ = tokio::time::timeout(Duration::from_secs(5), runner_task).await;
}
/// Dispatching a task name the target has no handler for resolves the handle
/// with [`TaskStatus::Failed`] within the completion timeout.
#[tokio::test]
async fn task_no_handler_fails_fast() {
    // `b` registers no task handlers at all.
    let caller = Arc::new(BaseAgent::new("a", idle_task(), None, true));
    let bare_target = Arc::new(BaseAgent::new("b", idle_task(), None, true));
    let agent_bus = Arc::new(LocalAgentBus::new());
    let runner = Arc::new(AgentRunner::new("r", agent_bus, system_clock()));
    runner.add_agent(caller.clone()).await.unwrap();
    runner.add_agent(bare_target).await.unwrap();
    let runner_task = start_and_wait(&runner, &["a", "b"]).await;

    let task_ctx = caller.task_ctx().unwrap();
    let pending = task_ctx
        .dispatch("a", "b", Some("nonexistent".into()), None)
        .await
        .unwrap();
    let outcome = pending
        .await_completion(Some(Duration::from_secs(2)))
        .await
        .unwrap();
    assert_eq!(outcome.status, TaskStatus::Failed);

    runner.end(None).await;
    let _ = tokio::time::timeout(Duration::from_secs(5), runner_task).await;
}
/// Ending the executing agent (`b`) while a task is still running must
/// resolve the caller's handle within 1 s, rather than leaving it pending
/// until the handler's 30 s sleep would finish.
#[tokio::test]
async fn task_resolves_when_executor_dies() {
    let sleepy: TaskHandler = Arc::new(|ctx: TaskRequestCtx| {
        Box::pin(async move {
            tokio::time::sleep(Duration::from_secs(30)).await;
            ctx.complete(TaskStatus::Completed, None).await;
        })
    });
    let a = Arc::new(BaseAgent::new("a", idle_task(), None, true));
    let b = Arc::new(BaseAgent::new("b", idle_task(), None, true).on_task("sleep", sleepy));
    let bus = Arc::new(LocalAgentBus::new());
    let runner = Arc::new(AgentRunner::new("r", bus, system_clock()));
    runner.add_agent(a.clone()).await.unwrap();
    runner.add_agent(b.clone()).await.unwrap();
    let run = start_and_wait(&runner, &["a", "b"]).await;
    let ctx = a.task_ctx().unwrap();
    let handle = ctx
        .dispatch("a", "b", Some("sleep".into()), None)
        .await
        .unwrap();
    // Let `b` start executing the handler before it is shut down.
    tokio::time::sleep(Duration::from_millis(100)).await;
    b.end(Some("shutting down".into())).await.unwrap();
    // The inner `await_completion` timeout is deliberately longer than the
    // outer 1 s guard. If they were equal, an inner timeout error could win
    // the race and be accepted below, letting a handle that never resolved
    // pass. With this ordering, anything observed here was produced by the
    // executor going away, not by the inner timeout.
    let result = tokio::time::timeout(
        Duration::from_secs(1),
        handle.await_completion(Some(Duration::from_secs(5))),
    )
    .await
    .expect("handle did not resolve within 1s");
    // An error result still counts as "resolved"; a successful result must
    // report cancellation.
    if let Ok(r) = result {
        assert_eq!(r.status, TaskStatus::Cancelled);
    }
    runner.end(None).await;
    let _ = tokio::time::timeout(Duration::from_secs(5), run).await;
}
/// `dispatch_with` blocks until the target agent appears in the registry.
/// The target is registered only after 200 ms, so the dispatch must not
/// return before roughly that point.
#[tokio::test]
async fn dispatch_waits_for_target_ready() {
    let dispatcher = Arc::new(BaseAgent::new("a", idle_task(), None, true));
    let bus = Arc::new(LocalAgentBus::new());
    let runner = Arc::new(AgentRunner::new("r", Arc::clone(&bus), system_clock()));
    runner.add_agent(dispatcher.clone()).await.unwrap();
    let runner_task = start_and_wait(&runner, &["a"]).await;
    let task_ctx = dispatcher.task_ctx().unwrap();

    // Register the "late" agent from a separate task after a 200 ms delay.
    let registry = runner.registry().clone();
    let late_registration = tokio::spawn(async move {
        tokio::time::sleep(Duration::from_millis(200)).await;
        let info = super::registry::AgentInfo {
            name: "late".into(),
            runner: "r".into(),
            parent: None,
            active: true,
            bridged: false,
            started_at: None,
        };
        registry.register(info).await;
    });

    let t0 = tokio::time::Instant::now();
    let pending = task_ctx
        .dispatch_with("a", "late", None, None, Some(Duration::from_secs(5)))
        .await
        .expect("dispatch should succeed once the target registers");
    let waited = t0.elapsed();
    assert!(
        waited >= Duration::from_millis(150),
        "dispatch returned before the target was ready"
    );

    drop(pending);
    late_registration.await.unwrap();
    runner.end(None).await;
    let _ = tokio::time::timeout(Duration::from_secs(5), runner_task).await;
}
/// `dispatch_with` returns an error once its readiness timeout elapses for a
/// target that is never registered.
#[tokio::test]
async fn dispatch_times_out_when_target_never_ready() {
    let dispatcher = Arc::new(BaseAgent::new("a", idle_task(), None, true));
    let agent_bus = Arc::new(LocalAgentBus::new());
    let runner = Arc::new(AgentRunner::new("r", agent_bus, system_clock()));
    runner.add_agent(dispatcher.clone()).await.unwrap();
    let runner_task = start_and_wait(&runner, &["a"]).await;

    // "ghost" is never registered, so only the 200 ms timeout can end this.
    let outcome = dispatcher
        .task_ctx()
        .unwrap()
        .dispatch_with("a", "ghost", None, None, Some(Duration::from_millis(200)))
        .await;
    assert!(outcome.is_err(), "dispatch to a missing agent should error");

    runner.end(None).await;
    let _ = tokio::time::timeout(Duration::from_secs(5), runner_task).await;
}
#[tokio::test]
async fn end_cascades_child_before_parent() {
let finish_order: Arc<std::sync::Mutex<Vec<&'static str>>> =
Arc::new(std::sync::Mutex::new(Vec::new()));
let parent_task = idle_task();
let order = finish_order.clone();
parent_task.add_on_pipeline_finished(move |_f, _r| {
let order = order.clone();
Box::pin(async move {
order.lock().unwrap().push("parent");
})
});
let child_task = idle_task();
let order = finish_order.clone();
child_task.add_on_pipeline_finished(move |_f, _r| {
let order = order.clone();
Box::pin(async move {
order.lock().unwrap().push("child");
})
});
let parent = Arc::new(BaseAgent::new("parent", parent_task, None, true));
let child = Arc::new(BaseAgent::new("child", child_task, None, true).with_parent("parent"));
let bus = Arc::new(LocalAgentBus::new());
let runner = Arc::new(AgentRunner::new("r", bus, system_clock()));
runner.add_agent(parent).await.unwrap();
runner.add_agent(child).await.unwrap();
let run = start_and_wait(&runner, &["parent", "child"]).await;
runner.end(None).await;
tokio::time::timeout(Duration::from_secs(8), run)
.await
.expect("runner did not exit within timeout")
.unwrap();
let order = finish_order.lock().unwrap().clone();
assert_eq!(
order,
vec!["child", "parent"],
"child must finish before the parent's pipeline"
);
}
/// Builds an agent whose pipeline is a single [`BusOutputEdge`] publishing to
/// `peers`, skipping the frame kinds in `exclude`.
///
/// The pipeline's downstream filter is set to `count_kind`. The returned
/// counter is incremented each time the downstream-reached callback fires, so
/// it tallies the `count_kind` frames that flow through this agent's pipeline.
fn bridged_agent(
    name: &str,
    peers: Vec<String>,
    exclude: HashSet<FrameKind>,
    count_kind: FrameKind,
) -> (Arc<BaseAgent>, Arc<AtomicUsize>) {
    let output_edge = BusOutputEdge::with_exclude(name, peers.clone(), exclude);
    let task = PipelineTask::new(vec![output_edge.to_processor()], PipelineParams::default());
    task.set_downstream_filter(HashSet::from([count_kind]));

    let hits = Arc::new(AtomicUsize::new(0));
    let hits_for_cb = Arc::clone(&hits);
    task.add_on_frame_reached_downstream(move |_frame| {
        let hits = Arc::clone(&hits_for_cb);
        Box::pin(async move {
            hits.fetch_add(1, Ordering::SeqCst);
        })
    });

    let agent = BaseAgent::new(name, task, Some(peers), true).with_output_edge(output_edge);
    (Arc::new(agent), hits)
}
/// Two bridged agents exchange frames over the bus. LLM text pushed into `a`
/// reaches `b`, and a transcription pushed into `b` reaches `a`. Each agent
/// excludes from its edge the kind it counts locally.
#[tokio::test]
async fn bridged_pipelines_exchange_frames() {
    let (a, a_transcripts) = bridged_agent(
        "a",
        vec!["b".into()],
        HashSet::from([FrameKind::Transcription]),
        FrameKind::Transcription,
    );
    let (b, b_texts) = bridged_agent(
        "b",
        vec!["a".into()],
        HashSet::from([FrameKind::LLMText]),
        FrameKind::LLMText,
    );
    let agent_bus = Arc::new(LocalAgentBus::new());
    let runner = Arc::new(AgentRunner::new("r", agent_bus, system_clock()));
    runner.add_agent(a.clone()).await.unwrap();
    runner.add_agent(b.clone()).await.unwrap();
    let runner_task = start_and_wait(&runner, &["a", "b"]).await;

    a.pipeline()
        .push_frame(
            Frame::llm_text("hello b".into()),
            crate::frames::FrameDirection::Downstream,
        )
        .await
        .unwrap();
    b.pipeline()
        .push_frame(
            Frame::transcription(crate::frames::TranscriptionData::new(
                "hello a", "user", "now",
            )),
            crate::frames::FrameDirection::Downstream,
        )
        .await
        .unwrap();

    // Poll for up to 2 s until each side has seen the other's frame.
    let deadline = tokio::time::Instant::now() + Duration::from_secs(2);
    let mut exchanged = false;
    while tokio::time::Instant::now() < deadline {
        if b_texts.load(Ordering::SeqCst) >= 1 && a_transcripts.load(Ordering::SeqCst) >= 1 {
            exchanged = true;
            break;
        }
        tokio::time::sleep(Duration::from_millis(10)).await;
    }
    assert!(
        exchanged,
        "bridge failed: b_texts={}, a_transcripts={}",
        b_texts.load(Ordering::SeqCst),
        a_transcripts.load(Ordering::SeqCst)
    );

    runner.end(None).await;
    let _ = tokio::time::timeout(Duration::from_secs(5), runner_task).await;
}
/// While `a` is deactivated, frames still flow through its local pipeline but
/// are not published to `b`. After reactivation, publishing resumes.
#[tokio::test]
async fn edge_activation_gating() {
    let (a, a_local) = bridged_agent("a", vec!["b".into()], HashSet::new(), FrameKind::LLMText);
    let (b, b_texts) = bridged_agent(
        "b",
        vec![],
        HashSet::from([FrameKind::LLMText]),
        FrameKind::LLMText,
    );
    let bus: Arc<dyn AgentBus> = Arc::new(LocalAgentBus::new());
    let runner = Arc::new(AgentRunner::new("r", bus.clone(), system_clock()));
    runner.add_agent(a.clone()).await.unwrap();
    runner.add_agent(b.clone()).await.unwrap();
    let run = start_and_wait(&runner, &["a", "b"]).await;

    // Phase 1: deactivate `a`; frames must stay local.
    bus.send(BusMessage::new(
        "r",
        Some("a".into()),
        BusPayload::Deactivate,
    ))
    .await;
    let deadline = tokio::time::Instant::now() + Duration::from_secs(2);
    while a.active() && tokio::time::Instant::now() < deadline {
        tokio::time::sleep(Duration::from_millis(10)).await;
    }
    assert!(!a.active(), "Deactivate never took effect");
    a.pipeline()
        .push_frame(
            Frame::llm_text("while inactive".into()),
            crate::frames::FrameDirection::Downstream,
        )
        .await
        .unwrap();
    tokio::time::sleep(Duration::from_millis(200)).await;
    assert_eq!(a_local.load(Ordering::SeqCst), 1, "local flow stalled");
    assert_eq!(
        b_texts.load(Ordering::SeqCst),
        0,
        "published while inactive"
    );

    // Phase 2: reactivate `a`; frames must be published again.
    bus.send(BusMessage::new(
        "r",
        Some("a".into()),
        BusPayload::Activate { args: None },
    ))
    .await;
    let deadline = tokio::time::Instant::now() + Duration::from_secs(2);
    while !a.active() && tokio::time::Instant::now() < deadline {
        tokio::time::sleep(Duration::from_millis(10)).await;
    }
    // Mirror the Deactivate check so a failed activation is reported as such,
    // not as "flow did not resume".
    assert!(a.active(), "Activate never took effect");
    a.pipeline()
        .push_frame(
            Frame::llm_text("while active".into()),
            crate::frames::FrameDirection::Downstream,
        )
        .await
        .unwrap();
    // Wait for both counters. They are updated by different pipelines (`b`
    // via the bus, `a` locally), so seeing one does not imply the other has
    // already been bumped.
    let deadline = tokio::time::Instant::now() + Duration::from_secs(2);
    while (b_texts.load(Ordering::SeqCst) == 0 || a_local.load(Ordering::SeqCst) < 2)
        && tokio::time::Instant::now() < deadline
    {
        tokio::time::sleep(Duration::from_millis(10)).await;
    }
    assert_eq!(b_texts.load(Ordering::SeqCst), 1, "flow did not resume");
    assert_eq!(a_local.load(Ordering::SeqCst), 2);

    runner.end(None).await;
    let _ = tokio::time::timeout(Duration::from_secs(5), run).await;
}
/// A frame kind listed in the edge's exclude set still reaches the local
/// pipeline's end, but is not published to the peer.
#[tokio::test]
async fn edge_exclusion() {
    let (a, a_local) = bridged_agent(
        "a",
        vec!["b".into()],
        HashSet::from([FrameKind::LLMText]),
        FrameKind::LLMText,
    );
    let (b, b_texts) = bridged_agent("b", vec![], HashSet::new(), FrameKind::LLMText);
    let agent_bus = Arc::new(LocalAgentBus::new());
    let runner = Arc::new(AgentRunner::new("r", agent_bus, system_clock()));
    runner.add_agent(a.clone()).await.unwrap();
    runner.add_agent(b.clone()).await.unwrap();
    let runner_task = start_and_wait(&runner, &["a", "b"]).await;

    a.pipeline()
        .push_frame(
            Frame::llm_text("excluded".into()),
            crate::frames::FrameDirection::Downstream,
        )
        .await
        .unwrap();

    let local_count = || a_local.load(Ordering::SeqCst);
    let deadline = tokio::time::Instant::now() + Duration::from_secs(2);
    while local_count() == 0 && tokio::time::Instant::now() < deadline {
        tokio::time::sleep(Duration::from_millis(10)).await;
    }
    assert_eq!(local_count(), 1, "local forward failed");

    // Grace period for a (wrongly) published frame to reach `b`.
    tokio::time::sleep(Duration::from_millis(200)).await;
    assert_eq!(
        b_texts.load(Ordering::SeqCst),
        0,
        "excluded frame published"
    );

    runner.end(None).await;
    let _ = tokio::time::timeout(Duration::from_secs(5), runner_task).await;
}
/// A bridged agent only injects bus frames sent by its declared peers. Here
/// "voice" is a peer and "other" is not, so only one frame is counted.
#[tokio::test]
async fn bridged_peer_filter() {
    let task = idle_task();
    task.set_downstream_filter(HashSet::from([FrameKind::LLMText]));
    let received = Arc::new(AtomicUsize::new(0));
    let counter = Arc::clone(&received);
    task.add_on_frame_reached_downstream(move |_frame| {
        let counter = Arc::clone(&counter);
        Box::pin(async move {
            counter.fetch_add(1, Ordering::SeqCst);
        })
    });

    let listener = Arc::new(BaseAgent::new(
        "listener",
        task,
        Some(vec!["voice".to_string()]),
        true,
    ));
    let bus: Arc<dyn AgentBus> = Arc::new(LocalAgentBus::new());
    let runner = Arc::new(AgentRunner::new("r", bus.clone(), system_clock()));
    runner.add_agent(listener.clone()).await.unwrap();
    let runner_task = start_and_wait(&runner, &["listener"]).await;

    // One frame from the declared peer, then one from an undeclared sender.
    for (sender, text) in [("voice", "hello"), ("other", "spam")] {
        bus.send(BusMessage::new(
            sender,
            Some("listener".into()),
            BusPayload::Frame {
                frame: Frame::llm_text(text.into()),
                direction: crate::frames::FrameDirection::Downstream,
            },
        ))
        .await;
    }

    tokio::time::sleep(Duration::from_millis(300)).await;
    assert_eq!(
        received.load(Ordering::SeqCst),
        1,
        "only the frame from 'voice' should be injected"
    );

    runner.end(None).await;
    let _ = tokio::time::timeout(Duration::from_secs(5), runner_task).await;
}