use crate::providers::ToolDefinition;
use serde_json::{Value, json};
use std::sync::Arc;
use std::time::Duration;
use tokio_util::sync::CancellationToken;
use crate::bg_agent::{
AgentStatus, BgAgentRegistry, BgAgentResult, BgTaskSnapshot, CancelOutcome, WaitOutcome,
};
use crate::tools::ToolResult;
use crate::tools::bg_process::{
BgProcessSnapshot, BgProcessStatus, BgRegistry, ProcessWaitOutcome,
};
/// Hard upper bound on a single WaitTask park, in seconds; requested values
/// above this are clamped (see `clamp_wait_timeout_secs`).
pub const WAIT_TASK_MAX_TIMEOUT_SECS: u32 = 300;
/// Timeout applied when the model omits `timeout_secs` in a WaitTask call.
pub const WAIT_TASK_DEFAULT_TIMEOUT_SECS: u32 = 60;
/// Tool definitions for the background-task family:
/// `ListBackgroundTasks`, `CancelTask`, and `WaitTask`.
///
/// The long `description` strings are model-facing prompt text, and several
/// unit tests assert on their exact wording — treat them as load-bearing
/// strings, not free-form comments.
pub fn definitions() -> Vec<ToolDefinition> {
    vec![
        // Read-only enumeration of the caller's own background work.
        ToolDefinition {
            name: "ListBackgroundTasks".to_string(),
            description:
                "List every background task you have running — both background sub-agents (spawned via \
                 InvokeAgent { background: true }) and background shell processes (spawned via Bash \
                 { background: true }).\n\n\
                 Returns a JSON array of objects, each with:\n\
                 - task_id: prefixed string. \"agent:N\" for sub-agent tasks, \"process:N\" for shell processes.\n\
                 - task_type: \"agent\" or \"process\".\n\
                 - description: agent name + prompt for agents; the original command for processes.\n\
                 - status: \"pending\" | \"running\" | \"completed\" | \"errored\" | \"cancelled\" \
                 (agents) or \"running\" | \"exited\" | \"killed\" (processes).\n\
                 - age_secs: wall-clock seconds since the task was spawned.\n\
                 - exit_code: present only for exited processes.\n\n\
                 Use this when:\n\
                 - You launched background work and want to check progress before doing more.\n\
                 - You need a task_id to feed CancelTask, or one or more task_ids \
                 to feed WaitTask.\n\n\
                 Do NOT use this when:\n\
                 - You're not sure whether you launched anything (you'd see an empty array — \
                 cheap, but pointless if you didn't intend to background work).\n\n\
                 Scope: returns only YOUR tasks. You will never see another agent's tasks or \
                 the user's top-level tasks here."
                    .to_string(),
            // No arguments at all — schema is an empty, closed object.
            parameters: json!({
                "type": "object",
                "properties": {},
                "additionalProperties": false
            }),
        },
        // Single-task cancellation; idempotent by contract.
        ToolDefinition {
            name: "CancelTask".to_string(),
            description:
                "Cancel a single background task by its task_id (from ListBackgroundTasks).\n\n\
                 For sub-agent tasks (\"agent:N\"): fires the per-task cancel token. The agent \
                 observes it on its next inference iteration and shuts down cleanly. The \
                 cancellation result will appear in your conversation as a normal sub-agent \
                 completion with a cancelled marker.\n\n\
                 For shell processes (\"process:N\"): sends SIGTERM. The process status \
                 transitions to \"killed\" immediately; the OS exit code surfaces on the next \
                 ListBackgroundTasks / WaitTask call once the process is fully reaped.\n\n\
                 Idempotent — calling on an already-cancelled / already-exited task is a \
                 successful no-op. Returns an error if the task_id is unknown OR if you don't \
                 own the task (Model E scope: you can only cancel tasks you spawned)."
                    .to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "task_id": {
                        "type": "string",
                        "description": "Prefixed task id from ListBackgroundTasks: \
                                        \"agent:N\" or \"process:N\"."
                    }
                },
                "required": ["task_id"],
                "additionalProperties": false
            }),
        },
        // Atomic multi-task rendezvous (#1157); description is a format! so
        // the default/max timeout constants stay in sync with the code.
        ToolDefinition {
            name: "WaitTask".to_string(),
            description: format!(
                "Block until ONE OR MORE background tasks finish (or timeout fires).\n\n\
                 Accepts an array of task ids — wait for one task by passing a single-element \
                 array, or for many tasks atomically by passing the full set. The atomic-gather \
                 shape (#1157) means N parallel sub-agents land in ONE structured response \
                 instead of one tool-result + N-1 synthetic user-message injections.\n\n\
                 Returns a JSON object: `{{tasks: [...], summary: {{...}}}}`. The `tasks` array \
                 contains one entry per requested task_id in input order, each carrying its \
                 terminal state and result. The `summary` object carries counts \
                 (`total`, `completed`, `errored`, `timed_out`, `cancelled`, `exited`, `killed`, \
                 `not_found`, `forbidden`) so you can branch on \"all done\" vs \"some still \
                 running\" without re-iterating the array.\n\n\
                 Per-task entry shape:\n\
                 - For sub-agent tasks (\"agent:N\") that complete: `{{task_id, status: \
                 \"completed\"|\"errored\", agent_name, prompt, output, events}}`. The agent's \
                 full output is in `output`. The result will NOT also appear in the auto-drain \
                 on the next iteration — WaitTask consumes it.\n\
                 - For shell processes (\"process:N\") that exit: `{{task_id, status: \"exited\", \
                 exit_code}}`. Process stdout/stderr is NOT captured — if you need the output, \
                 redirect inside the command (e.g. `Bash {{{{ command: \"cmd > /tmp/out.log 2>&1\", \
                 background: true }}}}`) and Read the file separately.\n\
                 - For tasks still running at deadline: `{{task_id, status: \"timed_out\", \
                 current: <ListBackgroundTasks-style snapshot>}}`. The task remains in the \
                 registry; call again to keep waiting.\n\
                 - For unknown / cancelled tasks: `{{task_id, status: \"cancelled\"|\"not_found\"\
                 |\"forbidden\", error: <message>}}` (status \"forbidden\" only fires for tasks \
                 you didn't spawn — Model E scope).\n\n\
                 Prefer WaitTask over a polling loop. Pick a duration that means 'I'm willing to \
                 wait this long', not 'check back quickly' — for sub-agent tasks (multi-iteration \
                 inference) prefer 120-300 s; the default suits short shell-process waits. The \
                 same `timeout_secs` applies to every task in the request (per-task, not total). \
                 Default {default}s, max {max}s.\n\n\
                 **When NOT to call WaitTask.** Background sub-agents and processes drain \
                 automatically: completed tasks inject as a user message at the start of the \
                 next iteration. So if you've just spawned bg work and have ANY other useful \
                 work to do (more file edits, follow-up searches, summarizing progress for the \
                 user, even \"acknowledge the spawn and stop\"), do that instead and let the \
                 result arrive on its own. Calling WaitTask immediately after a bg spawn turns \
                 the parallel/non-blocking pattern back into a serial blocking call — the \
                 parent's TUI goes silent for the entire duration with no per-tool feedback \
                 from the child. Only call WaitTask when (a) the next step you'd take strictly \
                 depends on this task's result and you cannot make progress without it, (b) \
                 you're at the end of your work and need the result to finalize the user-facing \
                 reply, or (c) you're rendezvousing on N parallel agents you fanned out earlier \
                 and you've already exhausted other work.",
                default = WAIT_TASK_DEFAULT_TIMEOUT_SECS,
                max = WAIT_TASK_MAX_TIMEOUT_SECS,
            ),
            parameters: json!({
                "type": "object",
                "properties": {
                    "task_ids": {
                        "type": "array",
                        "items": { "type": "string" },
                        "minItems": 1,
                        "description": "Prefixed task ids to wait on. Each entry is \
                                        \"agent:N\" or \"process:N\". Pass a single-element \
                                        array to wait on one task; pass many to gather \
                                        N parallel sub-agents atomically."
                    },
                    "timeout_secs": {
                        "type": "integer",
                        "minimum": 1,
                        "maximum": WAIT_TASK_MAX_TIMEOUT_SECS,
                        "description": format!(
                            "Per-task max seconds to wait (not total — a 60s timeout with \
                             4 task_ids waits up to 60s for ALL of them in parallel, not 240s). \
                             Default {default}, capped at {max} to prevent runaway parks of \
                             the inference loop.",
                            default = WAIT_TASK_DEFAULT_TIMEOUT_SECS,
                            max = WAIT_TASK_MAX_TIMEOUT_SECS,
                        )
                    }
                },
                "required": ["task_ids"],
                "additionalProperties": false
            }),
        },
    ]
}
/// A parsed background-task identifier: a sub-agent task or a background
/// shell process, each keyed by a numeric id.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TaskId {
    Agent(u32),
    Process(u32),
}

/// Parse a model-supplied task id string.
///
/// Accepted forms: `"agent:N"`, `"process:N"`, and — as a leniency for
/// models that drop the prefix — a bare non-negative integer, interpreted
/// as an agent id. Surrounding whitespace is ignored. Anything else yields
/// an `Err` with a human-readable message.
pub fn parse_task_id(input: &str) -> Result<TaskId, String> {
    let s = input.trim();
    if s.is_empty() {
        return Err("task_id is empty".to_string());
    }
    // Shared numeric parser so both prefixes report identical error text.
    let parse_num = |digits: &str, kind: &str| {
        digits
            .parse::<u32>()
            .map_err(|_| format!("invalid {kind} id: '{digits}' (expected non-negative integer)"))
    };
    if let Some(digits) = s.strip_prefix("agent:") {
        parse_num(digits, "agent").map(TaskId::Agent)
    } else if let Some(digits) = s.strip_prefix("process:") {
        parse_num(digits, "process").map(TaskId::Process)
    } else if let Ok(n) = s.parse::<u32>() {
        // Bare integers are tolerated and treated as agent task ids.
        Ok(TaskId::Agent(n))
    } else {
        Err(format!(
            "unrecognized task_id '{input}'; expected \"agent:N\" or \"process:N\""
        ))
    }
}
/// Resolve the effective WaitTask timeout: substitute the default when the
/// model omitted `timeout_secs`, then clamp into [1, WAIT_TASK_MAX_TIMEOUT_SECS].
pub fn clamp_wait_timeout_secs(requested: Option<u32>) -> u32 {
    match requested {
        // The default constant is already in range; no clamping needed.
        None => WAIT_TASK_DEFAULT_TIMEOUT_SECS,
        Some(secs) => secs.clamp(1, WAIT_TASK_MAX_TIMEOUT_SECS),
    }
}
/// Map an agent lifecycle state to the wire-format status string exposed in
/// ListBackgroundTasks / WaitTask payloads.
fn agent_status_str(status: &AgentStatus) -> &'static str {
    match status {
        AgentStatus::Cancelled => "cancelled",
        AgentStatus::Errored { .. } => "errored",
        AgentStatus::Completed { .. } => "completed",
        AgentStatus::Running { .. } => "running",
        AgentStatus::Pending => "pending",
    }
}
/// Map a background-process state to the wire-format status string exposed in
/// ListBackgroundTasks / WaitTask payloads.
fn process_status_str(status: &BgProcessStatus) -> &'static str {
    match status {
        BgProcessStatus::Killed => "killed",
        BgProcessStatus::Exited { .. } => "exited",
        BgProcessStatus::Running => "running",
    }
}
/// Render a sub-agent task snapshot as one ListBackgroundTasks entry.
fn agent_snapshot_to_json(snap: &BgTaskSnapshot) -> Value {
    // "agent:" prefix distinguishes agent ids from process ids.
    let task_id = format!("agent:{}", snap.task_id);
    // Agents describe themselves as "<name>: <prompt>".
    let description = format!("{}: {}", snap.agent_name, snap.prompt);
    json!({
        "task_id": task_id,
        "task_type": "agent",
        "description": description,
        "status": agent_status_str(&snap.status),
        "age_secs": snap.age.as_secs(),
    })
}
fn process_snapshot_to_json(s: &BgProcessSnapshot) -> Value {
let mut obj = json!({
"task_id": format!("process:{}", s.pid),
"task_type": "process",
"description": s.command.clone(),
"status": process_status_str(&s.status),
"age_secs": s.age.as_secs(),
});
if let BgProcessStatus::Exited { code } = s.status {
obj.as_object_mut()
.unwrap()
.insert("exit_code".into(), json!(code));
}
obj
}
/// Build a failed `ToolResult` carrying `msg` as the model-visible output.
fn err(msg: impl Into<String>) -> ToolResult {
    let output = msg.into();
    ToolResult {
        output,
        success: false,
        full_output: None,
    }
}
/// Build a successful `ToolResult` whose output is `value` serialized as
/// compact JSON.
fn ok(value: Value) -> ToolResult {
    ToolResult {
        success: true,
        output: value.to_string(),
        full_output: None,
    }
}
/// Entry point for the background-task tool family: dispatch `tool_name` to
/// the matching handler.
///
/// `caller_spawner` scopes every operation to the caller's own tasks
/// (Model E scope); `cancel` is the master token that unblocks WaitTask
/// early when the whole run is being torn down.
pub async fn execute(
    tool_name: &str,
    arguments: &str,
    bg_agents: &Arc<BgAgentRegistry>,
    bg_processes: &BgRegistry,
    caller_spawner: Option<u32>,
    cancel: &CancellationToken,
) -> ToolResult {
    if tool_name == "ListBackgroundTasks" {
        return execute_list(bg_agents, bg_processes, caller_spawner);
    }
    if tool_name == "CancelTask" {
        return execute_cancel(arguments, bg_agents, bg_processes, caller_spawner);
    }
    if tool_name == "WaitTask" {
        return execute_wait(arguments, bg_agents, bg_processes, caller_spawner, cancel).await;
    }
    // Reaching this arm means the router sent us a tool we never registered.
    err(format!(
        "bg_task_tools::execute called with unknown tool '{tool_name}' \
         (router bug — should have matched in tool_dispatch)"
    ))
}
/// Execute `ListBackgroundTasks`: return the caller's agent tasks followed by
/// the caller's shell processes as one JSON array.
fn execute_list(
    bg_agents: &BgAgentRegistry,
    bg_processes: &BgRegistry,
    caller_spawner: Option<u32>,
) -> ToolResult {
    // Reap finished children first so process statuses/exit codes are fresh.
    bg_processes.reap();
    let agent_snaps = bg_agents.snapshot_for_caller(caller_spawner);
    let process_snaps = bg_processes.snapshot_for_caller(caller_spawner);
    // Agents first, then processes — same order the model is documented to see.
    let entries: Vec<Value> = agent_snaps
        .iter()
        .map(agent_snapshot_to_json)
        .chain(process_snaps.iter().map(process_snapshot_to_json))
        .collect();
    ok(Value::Array(entries))
}
fn execute_cancel(
arguments: &str,
bg_agents: &BgAgentRegistry,
bg_processes: &BgRegistry,
caller_spawner: Option<u32>,
) -> ToolResult {
let args: Value = match serde_json::from_str(arguments) {
Ok(v) => v,
Err(e) => return err(format!("CancelTask: invalid JSON arguments: {e}")),
};
let task_id_str = match args.get("task_id").and_then(|v| v.as_str()) {
Some(s) => s,
None => return err("CancelTask: missing required 'task_id' (string)"),
};
let task_id = match parse_task_id(task_id_str) {
Ok(t) => t,
Err(e) => return err(format!("CancelTask: {e}")),
};
let outcome = match task_id {
TaskId::Agent(n) => bg_agents.cancel_as_caller(n, caller_spawner),
TaskId::Process(n) => bg_processes.kill_as_caller(n, caller_spawner),
};
match outcome {
CancelOutcome::Cancelled => ok(json!({
"task_id": task_id_str,
"cancelled": true,
})),
CancelOutcome::NotFound => err(format!(
"CancelTask: no background task with id '{task_id_str}' \
(already finished, never existed, or already drained)"
)),
CancelOutcome::Forbidden => err(format!(
"CancelTask: task '{task_id_str}' is not owned by this caller"
)),
}
}
async fn execute_wait(
arguments: &str,
bg_agents: &Arc<BgAgentRegistry>,
bg_processes: &BgRegistry,
caller_spawner: Option<u32>,
cancel: &CancellationToken,
) -> ToolResult {
let args: Value = match serde_json::from_str(arguments) {
Ok(v) => v,
Err(e) => return err(format!("WaitTask: invalid JSON arguments: {e}")),
};
let task_ids: Vec<String> = match args.get("task_ids").and_then(|v| v.as_array()) {
Some(arr) if !arr.is_empty() => {
let mut out = Vec::with_capacity(arr.len());
for (i, v) in arr.iter().enumerate() {
match v.as_str() {
Some(s) => out.push(s.to_string()),
None => {
return err(format!(
"WaitTask: 'task_ids[{i}]' must be a string, got {}",
v
));
}
}
}
out
}
Some(_) => return err("WaitTask: 'task_ids' must be a non-empty array"),
None => return err("WaitTask: missing required 'task_ids' (array of strings)"),
};
let timeout_secs = clamp_wait_timeout_secs(
args.get("timeout_secs")
.and_then(|v| v.as_u64())
.map(|n| n as u32),
);
let timeout = Duration::from_secs(timeout_secs as u64);
let futures = task_ids.iter().map(|id| {
wait_one_task(
id.clone(),
bg_agents,
bg_processes,
caller_spawner,
timeout,
cancel,
)
});
let task_results: Vec<Value> = futures_util::future::join_all(futures).await;
let mut summary = serde_json::Map::new();
summary.insert("total".into(), json!(task_results.len()));
for entry in &task_results {
if let Some(status) = entry.get("status").and_then(|v| v.as_str()) {
let counter = summary.entry(status.to_string()).or_insert(json!(0));
if let Some(n) = counter.as_u64() {
*counter = json!(n + 1);
}
}
}
ok(json!({
"tasks": task_results,
"summary": Value::Object(summary),
}))
}
/// Wait for a single task to finish, racing the registry wait against the
/// master cancellation token; always resolves to one `tasks[]` entry.
async fn wait_one_task(
    task_id_str: String,
    bg_agents: &Arc<BgAgentRegistry>,
    bg_processes: &BgRegistry,
    caller_spawner: Option<u32>,
    timeout: Duration,
    cancel: &CancellationToken,
) -> Value {
    // Unparseable ids become a per-task entry rather than a call-level error.
    let parsed = match parse_task_id(&task_id_str) {
        Ok(t) => t,
        Err(e) => {
            return json!({
                "task_id": task_id_str,
                "status": "parse_error",
                "error": e,
            });
        }
    };
    // tokio::select! polls branches in random order, so listing the cancel
    // arm first is purely cosmetic — semantics are unchanged.
    match parsed {
        TaskId::Agent(id) => {
            tokio::select! {
                _ = cancel.cancelled() => agent_wait_to_value(&task_id_str, WaitOutcome::Cancelled),
                outcome = bg_agents.wait_for_completion(id, caller_spawner, timeout) => {
                    agent_wait_to_value(&task_id_str, outcome)
                }
            }
        }
        TaskId::Process(id) => {
            tokio::select! {
                _ = cancel.cancelled() => json!({
                    "task_id": task_id_str,
                    "status": "cancelled",
                }),
                outcome = bg_processes.wait_for_exit_as_caller(id, caller_spawner, timeout) => {
                    process_wait_to_value(&task_id_str, outcome)
                }
            }
        }
    }
}
/// Convert a sub-agent `WaitOutcome` into one WaitTask `tasks[]` entry.
fn agent_wait_to_value(task_id_str: &str, outcome: WaitOutcome) -> Value {
    match outcome {
        WaitOutcome::NotFound => json!({
            "task_id": task_id_str,
            "status": "not_found",
            "error": format!("no background task with id '{task_id_str}'"),
        }),
        WaitOutcome::Forbidden => json!({
            "task_id": task_id_str,
            "status": "forbidden",
            "error": format!("task '{task_id_str}' is not owned by this caller"),
        }),
        WaitOutcome::Cancelled => json!({
            "task_id": task_id_str,
            "status": "cancelled",
        }),
        // Still running at the deadline: hand back a live snapshot so the
        // model can decide whether to keep waiting.
        WaitOutcome::TimedOut(snap) => json!({
            "task_id": task_id_str,
            "status": "timed_out",
            "current": agent_snapshot_to_json(&snap),
        }),
        // Terminal result: success maps to "completed", failure to "errored";
        // parent_tool_call_id is internal routing state and is not surfaced.
        WaitOutcome::Completed(result) => {
            let status = if result.success { "completed" } else { "errored" };
            json!({
                "task_id": task_id_str,
                "status": status,
                "agent_name": result.agent_name,
                "prompt": result.prompt,
                "output": result.output,
                "events": result.events,
            })
        }
    }
}
/// Convert a background-process `ProcessWaitOutcome` into one WaitTask
/// `tasks[]` entry. Stdout/stderr are never captured — only the exit code.
fn process_wait_to_value(task_id_str: &str, outcome: ProcessWaitOutcome) -> Value {
    match outcome {
        ProcessWaitOutcome::NotFound => json!({
            "task_id": task_id_str,
            "status": "not_found",
            "error": format!("no background task with id '{task_id_str}'"),
        }),
        ProcessWaitOutcome::Forbidden => json!({
            "task_id": task_id_str,
            "status": "forbidden",
            "error": format!("task '{task_id_str}' is not owned by this caller"),
        }),
        // Still running at the deadline: include a live snapshot.
        ProcessWaitOutcome::TimedOut(snapshot) => json!({
            "task_id": task_id_str,
            "status": "timed_out",
            "current": process_snapshot_to_json(&snapshot),
        }),
        ProcessWaitOutcome::Exited { code } => json!({
            "task_id": task_id_str,
            "status": "exited",
            "exit_code": code,
        }),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // A cancellation token that never fires — the common case for tests.
    fn no_cancel() -> CancellationToken {
        CancellationToken::new()
    }

    // ---- tool-schema tests ------------------------------------------------

    #[test]
    fn definitions_returns_three_tools_with_expected_names() {
        let defs = definitions();
        let names: Vec<&str> = defs.iter().map(|d| d.name.as_str()).collect();
        assert_eq!(names, vec!["ListBackgroundTasks", "CancelTask", "WaitTask"]);
    }

    #[test]
    fn list_background_tasks_takes_no_arguments() {
        let defs = definitions();
        let list = defs
            .iter()
            .find(|d| d.name == "ListBackgroundTasks")
            .unwrap();
        let required = list.parameters.get("required");
        assert!(
            required.is_none() || required.unwrap().as_array().unwrap().is_empty(),
            "ListBackgroundTasks must take no required args"
        );
    }

    // Guards the prompt-engineering contract: the WaitTask description must
    // keep discouraging an immediate post-spawn wait (#1201 A1).
    #[test]
    fn wait_task_description_discourages_immediate_post_spawn_call() {
        let defs = definitions();
        let wait = defs.iter().find(|d| d.name == "WaitTask").unwrap();
        let desc = &wait.description;
        assert!(
            desc.contains("drain"),
            "WaitTask description must mention auto-drain so the model knows \
             results arrive on the next iteration without an explicit wait \
             (#1201 A1)"
        );
        assert!(
            desc.contains("immediately") || desc.contains("immediate"),
            "WaitTask description must call out the immediate-post-spawn \
             anti-pattern by name (#1201 A1)"
        );
        assert!(
            desc.contains("blocking") || desc.contains("silent"),
            "WaitTask description must explain the cost of immediate \
             waiting (parent goes silent / blocks) so the model has a \
             reason to avoid it (#1201 A1)"
        );
    }

    #[test]
    fn cancel_requires_task_id_and_wait_requires_task_ids() {
        let defs = definitions();
        let cancel = defs.iter().find(|d| d.name == "CancelTask").unwrap();
        let cancel_req = cancel.parameters["required"].as_array().unwrap();
        assert!(
            cancel_req.iter().any(|v| v == "task_id"),
            "CancelTask must require task_id"
        );
        let wait = defs.iter().find(|d| d.name == "WaitTask").unwrap();
        let wait_req = wait.parameters["required"].as_array().unwrap();
        assert!(
            wait_req.iter().any(|v| v == "task_ids"),
            "WaitTask must require task_ids (array, #1157)"
        );
        assert_eq!(wait.parameters["properties"]["task_ids"]["type"], "array");
        assert_eq!(
            wait.parameters["properties"]["task_ids"]["items"]["type"],
            "string"
        );
    }

    // ---- parse_task_id / clamp tests --------------------------------------

    #[test]
    fn parse_task_id_accepts_prefixed_forms() {
        assert_eq!(parse_task_id("agent:7").unwrap(), TaskId::Agent(7));
        assert_eq!(
            parse_task_id("process:12345").unwrap(),
            TaskId::Process(12345)
        );
        // Surrounding whitespace is trimmed before parsing.
        assert_eq!(parse_task_id(" agent:1 ").unwrap(), TaskId::Agent(1));
    }

    #[test]
    fn parse_task_id_accepts_bare_numeric_as_agent() {
        assert_eq!(parse_task_id("5").unwrap(), TaskId::Agent(5));
    }

    #[test]
    fn parse_task_id_rejects_bad_input() {
        assert!(parse_task_id("").is_err());
        assert!(parse_task_id(" ").is_err());
        assert!(parse_task_id("agent:").is_err());
        assert!(parse_task_id("agent:abc").is_err());
        assert!(parse_task_id("process:-1").is_err());
        assert!(parse_task_id("foobar").is_err());
        assert!(parse_task_id("mcp:1").is_err());
    }

    #[test]
    fn clamp_wait_timeout_handles_none_default() {
        assert_eq!(
            clamp_wait_timeout_secs(None),
            WAIT_TASK_DEFAULT_TIMEOUT_SECS
        );
    }

    #[test]
    fn clamp_wait_timeout_caps_at_max() {
        assert_eq!(
            clamp_wait_timeout_secs(Some(86400)),
            WAIT_TASK_MAX_TIMEOUT_SECS
        );
    }

    #[test]
    fn clamp_wait_timeout_floors_at_one() {
        assert_eq!(clamp_wait_timeout_secs(Some(0)), 1);
    }

    #[test]
    fn clamp_wait_timeout_passes_through_in_range() {
        assert_eq!(clamp_wait_timeout_secs(Some(45)), 45);
        assert_eq!(
            clamp_wait_timeout_secs(Some(WAIT_TASK_MAX_TIMEOUT_SECS)),
            WAIT_TASK_MAX_TIMEOUT_SECS
        );
    }

    // ---- execute() integration tests --------------------------------------

    // Fresh, empty registries for each async test.
    fn fresh_registries() -> (Arc<BgAgentRegistry>, BgRegistry) {
        (Arc::new(BgAgentRegistry::new()), BgRegistry::new())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn execute_list_returns_empty_array_when_no_tasks() {
        let (agents, processes) = fresh_registries();
        let r = execute(
            "ListBackgroundTasks",
            "{}",
            &agents,
            &processes,
            None,
            &no_cancel(),
        )
        .await;
        assert!(r.success);
        assert_eq!(r.output, "[]");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn execute_list_includes_caller_agent_tasks() {
        let (agents, processes) = fresh_registries();
        let (id, _tx, _, _) = agents.register_test_with_status("explore", "map repo", None);
        let r = execute(
            "ListBackgroundTasks",
            "{}",
            &agents,
            &processes,
            None,
            &no_cancel(),
        )
        .await;
        assert!(r.success);
        let arr: Value = serde_json::from_str(&r.output).unwrap();
        let arr = arr.as_array().unwrap();
        assert_eq!(arr.len(), 1);
        assert_eq!(arr[0]["task_id"], format!("agent:{id}"));
        assert_eq!(arr[0]["task_type"], "agent");
        assert_eq!(arr[0]["status"], "pending");
        assert_eq!(arr[0]["description"], "explore: map repo");
    }

    // Model E scope: a caller only ever sees its own tasks.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn execute_list_filters_out_other_callers_tasks() {
        let (agents, processes) = fresh_registries();
        agents.register_test_with_status("a", "top", None);
        agents.register_test_with_status("b", "sub", Some(7));
        let top = execute(
            "ListBackgroundTasks",
            "{}",
            &agents,
            &processes,
            None,
            &no_cancel(),
        )
        .await;
        let arr: Value = serde_json::from_str(&top.output).unwrap();
        assert_eq!(arr.as_array().unwrap().len(), 1, "top sees only its own");
        let sub = execute(
            "ListBackgroundTasks",
            "{}",
            &agents,
            &processes,
            Some(7),
            &no_cancel(),
        )
        .await;
        let arr: Value = serde_json::from_str(&sub.output).unwrap();
        assert_eq!(arr.as_array().unwrap().len(), 1, "sub sees only its own");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn execute_cancel_succeeds_for_owned_agent_task() {
        let (agents, processes) = fresh_registries();
        let (id, _tx, _, observer) = agents.register_test_with_status("x", "y", None);
        let r = execute(
            "CancelTask",
            &json!({ "task_id": format!("agent:{id}") }).to_string(),
            &agents,
            &processes,
            None,
            &no_cancel(),
        )
        .await;
        assert!(r.success, "got: {}", r.output);
        assert!(observer.is_cancelled(), "cancel token must fire");
        let payload: Value = serde_json::from_str(&r.output).unwrap();
        assert_eq!(payload["cancelled"], true);
        assert_eq!(payload["task_id"], format!("agent:{id}"));
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn execute_cancel_returns_not_found_for_unknown_id() {
        let (agents, processes) = fresh_registries();
        let r = execute(
            "CancelTask",
            &json!({ "task_id": "agent:9999" }).to_string(),
            &agents,
            &processes,
            None,
            &no_cancel(),
        )
        .await;
        assert!(!r.success);
        assert!(r.output.contains("no background task"), "got: {}", r.output);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn execute_cancel_returns_forbidden_for_cross_caller() {
        let (agents, processes) = fresh_registries();
        let (id, _tx, _, observer) = agents.register_test_with_status("x", "y", Some(5));
        let r = execute(
            "CancelTask",
            &json!({ "task_id": format!("agent:{id}") }).to_string(),
            &agents,
            &processes,
            None,
            &no_cancel(),
        )
        .await;
        assert!(!r.success);
        assert!(
            r.output.contains("not owned by this caller"),
            "got: {}",
            r.output
        );
        assert!(!observer.is_cancelled(), "forbidden must NOT fire token");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn execute_cancel_rejects_malformed_json() {
        let (agents, processes) = fresh_registries();
        let r = execute(
            "CancelTask",
            "not-json",
            &agents,
            &processes,
            None,
            &no_cancel(),
        )
        .await;
        assert!(!r.success);
        assert!(r.output.contains("invalid JSON"), "got: {}", r.output);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn execute_cancel_rejects_missing_task_id() {
        let (agents, processes) = fresh_registries();
        let r = execute("CancelTask", "{}", &agents, &processes, None, &no_cancel()).await;
        assert!(!r.success);
        assert!(r.output.contains("missing required"), "got: {}", r.output);
    }

    // ---- WaitTask tests ----------------------------------------------------

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn execute_wait_returns_completed_for_finished_agent() {
        let (agents, processes) = fresh_registries();
        let (id, tx, status_tx, _) = agents.register_test_with_status("explore", "map", None);
        tx.send(Ok(("final answer".into(), vec!["e1".into()])))
            .unwrap();
        status_tx
            .send(AgentStatus::Completed {
                summary: "final".into(),
            })
            .unwrap();
        let r = execute(
            "WaitTask",
            &json!({ "task_ids": [format!("agent:{id}")], "timeout_secs": 1 }).to_string(),
            &agents,
            &processes,
            None,
            &no_cancel(),
        )
        .await;
        assert!(r.success, "got: {}", r.output);
        let payload: Value = serde_json::from_str(&r.output).unwrap();
        let task = &payload["tasks"][0];
        assert_eq!(task["status"], "completed");
        assert_eq!(task["output"], "final answer");
        assert_eq!(task["events"].as_array().unwrap().len(), 1);
        assert_eq!(payload["summary"]["total"], 1);
        assert_eq!(payload["summary"]["completed"], 1);
        // WaitTask consumes the result: it must not drain again later.
        assert_eq!(agents.snapshot().len(), 0);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn execute_wait_returns_timed_out_with_snapshot() {
        let (agents, processes) = fresh_registries();
        let (id, _tx, _status_tx, _observer) = agents.register_test_with_status("slow", "x", None);
        let r = execute(
            "WaitTask",
            &json!({ "task_ids": [format!("agent:{id}")], "timeout_secs": 1 }).to_string(),
            &agents,
            &processes,
            None,
            &no_cancel(),
        )
        .await;
        assert!(r.success);
        let payload: Value = serde_json::from_str(&r.output).unwrap();
        let task = &payload["tasks"][0];
        assert_eq!(task["status"], "timed_out");
        assert_eq!(task["current"]["task_id"], format!("agent:{id}"));
        assert_eq!(payload["summary"]["timed_out"], 1);
        // Timed-out tasks stay registered so the model can wait again.
        assert_eq!(agents.snapshot().len(), 1);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn execute_wait_returns_cancelled_when_token_fires() {
        let (agents, processes) = fresh_registries();
        let (id, tx, status_tx, observer) = agents.register_test_with_status("slow", "x", None);
        observer.cancel();
        status_tx.send(AgentStatus::Cancelled).unwrap();
        drop(tx);
        let r = execute(
            "WaitTask",
            &json!({ "task_ids": [format!("agent:{id}")], "timeout_secs": 5 }).to_string(),
            &agents,
            &processes,
            None,
            &no_cancel(),
        )
        .await;
        assert!(
            r.success,
            "WaitTask on a cancelled task must still succeed: {}",
            r.output
        );
        let payload: Value = serde_json::from_str(&r.output).unwrap();
        let task = &payload["tasks"][0];
        assert_eq!(task["status"], "cancelled");
        assert_eq!(task["task_id"], format!("agent:{id}"));
        assert_eq!(payload["summary"]["cancelled"], 1);
        assert_eq!(agents.snapshot().len(), 0);
    }

    // The master cancel token must unblock a pending wait well before the
    // requested timeout elapses.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn execute_wait_unblocks_immediately_on_master_cancel() {
        let (agents, processes) = fresh_registries();
        let (id, _tx, _status_tx, _observer) =
            agents.register_test_with_status("wedged", "x", None);
        let cancel = CancellationToken::new();
        let cancel_for_fire = cancel.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            cancel_for_fire.cancel();
        });
        let started = std::time::Instant::now();
        let r = execute(
            "WaitTask",
            &json!({ "task_ids": [format!("agent:{id}")], "timeout_secs": 5 }).to_string(),
            &agents,
            &processes,
            None,
            &cancel,
        )
        .await;
        let elapsed = started.elapsed();
        assert!(
            elapsed < Duration::from_secs(1),
            "WaitTask must unblock on master cancel, not wait out the 5s timeout (took {elapsed:?})"
        );
        assert!(r.success, "WaitTask payload must still parse: {}", r.output);
        let payload: Value = serde_json::from_str(&r.output).unwrap();
        assert_eq!(payload["tasks"][0]["status"], "cancelled");
        assert_eq!(payload["summary"]["cancelled"], 1);
    }

    // Atomic gather (#1157): N tasks come back in one structured response.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn execute_wait_atomically_gathers_multiple_tasks() {
        let (agents, processes) = fresh_registries();
        let mut ids = Vec::new();
        for i in 0..3 {
            let (id, tx, status_tx, _) =
                agents.register_test_with_status("explore", &format!("task-{i}"), None);
            tx.send(Ok((format!("output-{i}"), vec![]))).unwrap();
            status_tx
                .send(AgentStatus::Completed {
                    summary: format!("sum-{i}"),
                })
                .unwrap();
            ids.push(id);
        }
        let task_ids: Vec<String> = ids.iter().map(|i| format!("agent:{i}")).collect();
        let r = execute(
            "WaitTask",
            &json!({ "task_ids": task_ids, "timeout_secs": 2 }).to_string(),
            &agents,
            &processes,
            None,
            &no_cancel(),
        )
        .await;
        assert!(r.success, "got: {}", r.output);
        let payload: Value = serde_json::from_str(&r.output).unwrap();
        let tasks = payload["tasks"].as_array().unwrap();
        assert_eq!(tasks.len(), 3, "all 3 tasks must come back: {payload:#}");
        // Entries must preserve the input order of task_ids.
        for (i, task) in tasks.iter().enumerate() {
            assert_eq!(task["task_id"], format!("agent:{}", ids[i]));
            assert_eq!(task["status"], "completed");
            assert_eq!(task["output"], format!("output-{i}"));
        }
        assert_eq!(payload["summary"]["total"], 3);
        assert_eq!(payload["summary"]["completed"], 3);
        assert_eq!(
            agents.snapshot().len(),
            0,
            "all 3 entries must be consumed atomically"
        );
    }

    // One bad/slow id must never fail the whole call — each task reports its
    // own terminal state inside the tasks array.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn execute_wait_mixes_outcomes_without_failing_whole_call() {
        let (agents, processes) = fresh_registries();
        let (done_id, done_tx, done_status, _done_obs) =
            agents.register_test_with_status("fast", "a", None);
        done_tx.send(Ok(("done!".into(), vec![]))).unwrap();
        done_status
            .send(AgentStatus::Completed {
                summary: "s".into(),
            })
            .unwrap();
        let (slow_id, _slow_tx, _slow_status, _slow_obs) =
            agents.register_test_with_status("slow", "b", None);
        let r = execute(
            "WaitTask",
            &json!({
                "task_ids": [
                    format!("agent:{done_id}"),
                    format!("agent:{slow_id}"),
                    "agent:99999",
                ],
                "timeout_secs": 1,
            })
            .to_string(),
            &agents,
            &processes,
            None,
            &no_cancel(),
        )
        .await;
        assert!(r.success, "top-level call must succeed: {}", r.output);
        let payload: Value = serde_json::from_str(&r.output).unwrap();
        let tasks = payload["tasks"].as_array().unwrap();
        assert_eq!(tasks.len(), 3);
        assert_eq!(tasks[0]["status"], "completed");
        assert_eq!(tasks[1]["status"], "timed_out");
        assert_eq!(tasks[2]["status"], "not_found");
        assert!(
            tasks[2]["error"].as_str().unwrap().contains("agent:99999"),
            "not_found entry must mention the offending id: {payload:#}"
        );
        assert_eq!(payload["summary"]["completed"], 1);
        assert_eq!(payload["summary"]["timed_out"], 1);
        assert_eq!(payload["summary"]["not_found"], 1);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn execute_wait_rejects_empty_task_ids_array() {
        let (agents, processes) = fresh_registries();
        let r = execute(
            "WaitTask",
            &json!({ "task_ids": [] }).to_string(),
            &agents,
            &processes,
            None,
            &no_cancel(),
        )
        .await;
        assert!(!r.success);
        assert!(r.output.contains("non-empty array"), "got: {}", r.output);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn execute_wait_rejects_non_string_task_id_entry() {
        let (agents, processes) = fresh_registries();
        let r = execute(
            "WaitTask",
            &json!({ "task_ids": ["agent:1", 42] }).to_string(),
            &agents,
            &processes,
            None,
            &no_cancel(),
        )
        .await;
        assert!(!r.success);
        assert!(
            r.output.contains("task_ids[1]") && r.output.contains("must be a string"),
            "error must point at the bad index: {}",
            r.output
        );
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn execute_unknown_tool_name_returns_error() {
        let (agents, processes) = fresh_registries();
        let r = execute(
            "NotAToolWeKnow",
            "{}",
            &agents,
            &processes,
            None,
            &no_cancel(),
        )
        .await;
        assert!(!r.success);
        assert!(r.output.contains("unknown tool"));
    }
}