use crate::providers::ToolDefinition;
use serde_json::{Value, json};
use std::sync::Arc;
use std::time::Duration;
use crate::bg_agent::{
AgentStatus, BgAgentRegistry, BgAgentResult, BgTaskSnapshot, CancelOutcome, WaitOutcome,
};
use crate::tools::ToolResult;
use crate::tools::bg_process::{
BgProcessSnapshot, BgProcessStatus, BgRegistry, ProcessWaitOutcome,
};
/// Upper bound, in seconds, on how long a single `WaitTask` call may block
/// (five minutes). Requested timeouts above this are clamped down to it.
pub const WAIT_TASK_MAX_TIMEOUT_SECS: u32 = 5 * 60;
/// Timeout, in seconds, that `WaitTask` uses when `timeout_secs` is omitted.
pub const WAIT_TASK_DEFAULT_TIMEOUT_SECS: u32 = 30;
/// Returns the schemas for the background-task tools, in this order:
/// `ListBackgroundTasks`, `CancelTask`, `WaitTask`.
///
/// The `WaitTask` description and its `timeout_secs` schema are formatted
/// from [`WAIT_TASK_DEFAULT_TIMEOUT_SECS`] and [`WAIT_TASK_MAX_TIMEOUT_SECS`].
/// That keeps the limits advertised to the model identical to the ones
/// [`clamp_wait_timeout_secs`] enforces.
pub fn definitions() -> Vec<ToolDefinition> {
    // ListBackgroundTasks: no arguments at all.
    let list_tasks = ToolDefinition {
        name: String::from("ListBackgroundTasks"),
        description: String::from(
            "List every background task you have running — both background sub-agents (spawned via \
             InvokeAgent { background: true }) and background shell processes (spawned via Bash \
             { background: true }).\n\n\
             Returns a JSON array of objects, each with:\n\
             - task_id: prefixed string. \"agent:N\" for sub-agent tasks, \"process:N\" for shell processes.\n\
             - task_type: \"agent\" or \"process\".\n\
             - description: agent name + prompt for agents; the original command for processes.\n\
             - status: \"pending\" | \"running\" | \"completed\" | \"errored\" | \"cancelled\" \
             (agents) or \"running\" | \"exited\" | \"killed\" (processes).\n\
             - age_secs: wall-clock seconds since the task was spawned.\n\
             - exit_code: present only for exited processes.\n\n\
             Use this when:\n\
             - You launched background work and want to check progress before doing more.\n\
             - You need a task_id to feed CancelTask or WaitTask.\n\n\
             Do NOT use this when:\n\
             - You're not sure whether you launched anything (you'd see an empty array — \
             cheap, but pointless if you didn't intend to background work).\n\n\
             Scope: returns only YOUR tasks. You will never see another agent's tasks or \
             the user's top-level tasks here.",
        ),
        parameters: json!({
            "type": "object",
            "properties": {},
            "additionalProperties": false
        }),
    };

    // CancelTask: a single required, prefixed task id.
    let cancel_task = ToolDefinition {
        name: String::from("CancelTask"),
        description: String::from(
            "Cancel a single background task by its task_id (from ListBackgroundTasks).\n\n\
             For sub-agent tasks (\"agent:N\"): fires the per-task cancel token. The agent \
             observes it on its next inference iteration and shuts down cleanly. The \
             cancellation result will appear in your conversation as a normal sub-agent \
             completion with a cancelled marker.\n\n\
             For shell processes (\"process:N\"): sends SIGTERM. The process status \
             transitions to \"killed\" immediately; the OS exit code surfaces on the next \
             ListBackgroundTasks / WaitTask call once the process is fully reaped.\n\n\
             Idempotent — calling on an already-cancelled / already-exited task is a \
             successful no-op. Returns an error if the task_id is unknown OR if you don't \
             own the task (Model E scope: you can only cancel tasks you spawned).",
        ),
        parameters: json!({
            "type": "object",
            "properties": {
                "task_id": {
                    "type": "string",
                    "description": "Prefixed task id from ListBackgroundTasks: \
                                    \"agent:N\" or \"process:N\"."
                }
            },
            "required": ["task_id"],
            "additionalProperties": false
        }),
    };

    // WaitTask: prose embeds the default/max timeout so it cannot drift from the code.
    let wait_description = format!(
        "Block until a background task finishes (or timeout fires).\n\n\
         Returns the task's terminal state and result so you don't have to keep \
         polling ListBackgroundTasks. Prefer WaitTask over a polling loop — one \
         tool call instead of many.\n\n\
         For sub-agent tasks (\"agent:N\"): on completion, returns the agent's full \
         output. The result will NOT also appear in the auto-drain on the next \
         iteration — WaitTask consumes it.\n\n\
         For shell processes (\"process:N\"): on exit, returns the OS exit code. \
         Process stdout/stderr is NOT captured — if you need the output, redirect \
         inside the command (e.g. `Bash {{ command: \"cmd > /tmp/out.log 2>&1\", \
         background: true }}`) and Read the file separately.\n\n\
         If the task hasn't finished by `timeout_secs`, returns the current status \
         without consuming the task — you can call again to keep waiting. Default \
         {default}s, max {max}s. Returns an error if the task_id is unknown or \
         doesn't belong to you.",
        default = WAIT_TASK_DEFAULT_TIMEOUT_SECS,
        max = WAIT_TASK_MAX_TIMEOUT_SECS,
    );
    let timeout_help = format!(
        "Maximum seconds to wait. Default {default}, capped at {max} to \
         prevent runaway parks of the inference loop.",
        default = WAIT_TASK_DEFAULT_TIMEOUT_SECS,
        max = WAIT_TASK_MAX_TIMEOUT_SECS,
    );
    let wait_task = ToolDefinition {
        name: String::from("WaitTask"),
        description: wait_description,
        parameters: json!({
            "type": "object",
            "properties": {
                "task_id": {
                    "type": "string",
                    "description": "Prefixed task id: \"agent:N\" or \"process:N\"."
                },
                "timeout_secs": {
                    "type": "integer",
                    "minimum": 1,
                    "maximum": WAIT_TASK_MAX_TIMEOUT_SECS,
                    "description": timeout_help
                }
            },
            "required": ["task_id"],
            "additionalProperties": false
        }),
    };

    vec![list_tasks, cancel_task, wait_task]
}
/// A background-task identifier after parsing its `agent:N` / `process:N`
/// string form (see [`parse_task_id`]).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TaskId {
    /// Background sub-agent task, written `agent:N`.
    Agent(u32),
    /// Background shell process, written `process:N`.
    Process(u32),
}
/// Parses a user/model-supplied task id into a [`TaskId`].
///
/// Accepted forms (surrounding whitespace is ignored):
/// - `agent:N` → [`TaskId::Agent`]
/// - `process:N` → [`TaskId::Process`]
/// - bare `N` → [`TaskId::Agent`] (legacy shorthand)
///
/// # Errors
/// Returns a human-readable message when the input is blank, when the
/// numeric part after a known prefix is not a `u32`, or when the input
/// matches none of the forms above. The last message quotes the original,
/// untrimmed `input`.
pub fn parse_task_id(input: &str) -> Result<TaskId, String> {
    let id = input.trim();
    if id.is_empty() {
        return Err("task_id is empty".to_string());
    }
    // "agent" and "process" contain no ':', so the first colon is exactly the
    // prefix separator — equivalent to `strip_prefix("agent:")` etc.
    match id.split_once(':') {
        Some(("agent", digits)) => digits.parse::<u32>().map(TaskId::Agent).map_err(|_| {
            format!("invalid agent id: '{digits}' (expected non-negative integer)")
        }),
        Some(("process", digits)) => digits.parse::<u32>().map(TaskId::Process).map_err(|_| {
            format!("invalid process id: '{digits}' (expected non-negative integer)")
        }),
        // Unknown prefix or no prefix: only a bare number is still acceptable.
        _ => id.parse::<u32>().map(TaskId::Agent).map_err(|_| {
            format!("unrecognized task_id '{input}'; expected \"agent:N\" or \"process:N\"")
        }),
    }
}
/// Resolves the effective `WaitTask` timeout in seconds.
///
/// `None` falls back to [`WAIT_TASK_DEFAULT_TIMEOUT_SECS`]. The result is
/// always within `1..=WAIT_TASK_MAX_TIMEOUT_SECS`: zero is raised to one and
/// anything above the cap is lowered to the cap.
pub fn clamp_wait_timeout_secs(requested: Option<u32>) -> u32 {
    match requested.unwrap_or(WAIT_TASK_DEFAULT_TIMEOUT_SECS) {
        // u32 has no values below 1 other than 0.
        0 => 1,
        secs if secs > WAIT_TASK_MAX_TIMEOUT_SECS => WAIT_TASK_MAX_TIMEOUT_SECS,
        secs => secs,
    }
}
/// Wire label for a sub-agent status, as shown in `ListBackgroundTasks`.
/// Any payload carried by a variant is ignored.
fn agent_status_str(s: &AgentStatus) -> &'static str {
    match *s {
        AgentStatus::Cancelled => "cancelled",
        AgentStatus::Errored { .. } => "errored",
        AgentStatus::Completed { .. } => "completed",
        AgentStatus::Running { .. } => "running",
        AgentStatus::Pending => "pending",
    }
}
/// Wire label for a background shell-process status. The exit code carried
/// by `Exited` is not part of the label.
fn process_status_str(s: &BgProcessStatus) -> &'static str {
    match *s {
        BgProcessStatus::Killed => "killed",
        BgProcessStatus::Exited { .. } => "exited",
        BgProcessStatus::Running => "running",
    }
}
fn agent_snapshot_to_json(s: &BgTaskSnapshot) -> Value {
json!({
"task_id": format!("agent:{}", s.task_id),
"task_type": "agent",
"description": format!("{}: {}", s.agent_name, s.prompt),
"status": agent_status_str(&s.status),
"age_secs": s.age.as_secs(),
})
}
fn process_snapshot_to_json(s: &BgProcessSnapshot) -> Value {
let mut obj = json!({
"task_id": format!("process:{}", s.pid),
"task_type": "process",
"description": s.command.clone(),
"status": process_status_str(&s.status),
"age_secs": s.age.as_secs(),
});
if let BgProcessStatus::Exited { code } = s.status {
obj.as_object_mut()
.unwrap()
.insert("exit_code".into(), json!(code));
}
obj
}
/// Builds a failed [`ToolResult`] whose output is `msg` verbatim.
fn err(msg: impl Into<String>) -> ToolResult {
    let output: String = msg.into();
    ToolResult {
        success: false,
        full_output: None,
        output,
    }
}
/// Builds a successful [`ToolResult`] whose output is `value` serialized as
/// compact JSON text.
fn ok(value: Value) -> ToolResult {
    let output = value.to_string();
    ToolResult {
        success: true,
        full_output: None,
        output,
    }
}
/// Dispatches one of this module's tool calls by name.
///
/// `arguments` is the raw JSON argument string; `ListBackgroundTasks` ignores
/// it. `caller_spawner` identifies the caller and is passed unchanged to
/// every handler for ownership scoping. An unrecognised `tool_name` produces
/// a failed [`ToolResult`] rather than a panic.
pub async fn execute(
    tool_name: &str,
    arguments: &str,
    bg_agents: &Arc<BgAgentRegistry>,
    bg_processes: &BgRegistry,
    caller_spawner: Option<u32>,
) -> ToolResult {
    let agents: &BgAgentRegistry = bg_agents;
    // Only WaitTask is async; the other two handlers are plain functions.
    match tool_name {
        "WaitTask" => execute_wait(arguments, agents, bg_processes, caller_spawner).await,
        "CancelTask" => execute_cancel(arguments, agents, bg_processes, caller_spawner),
        "ListBackgroundTasks" => execute_list(agents, bg_processes, caller_spawner),
        other => err(format!(
            "bg_task_tools::execute called with unknown tool '{other}' \
             (router bug — should have matched in tool_dispatch)"
        )),
    }
}
/// `ListBackgroundTasks` handler: returns a JSON array containing the
/// caller's sub-agent tasks followed by the caller's background processes.
fn execute_list(
    bg_agents: &BgAgentRegistry,
    bg_processes: &BgRegistry,
    caller_spawner: Option<u32>,
) -> ToolResult {
    // NOTE(review): presumably refreshes process statuses (collects exited
    // children) so the snapshot below is current — confirm in BgRegistry.
    bg_processes.reap();
    let agent_tasks = bg_agents.snapshot_for_caller(caller_spawner);
    let process_tasks = bg_processes.snapshot_for_caller(caller_spawner);
    let entries: Vec<Value> = agent_tasks
        .iter()
        .map(agent_snapshot_to_json)
        .chain(process_tasks.iter().map(process_snapshot_to_json))
        .collect();
    ok(Value::Array(entries))
}
/// `CancelTask` handler.
///
/// Parses `{ "task_id": "<agent:N|process:N|N>" }` and asks the matching
/// registry to cancel (agents) or kill (processes) the task on behalf of
/// `caller_spawner`. On success it echoes the caller's `task_id` string with
/// `"cancelled": true`. Unknown and not-owned ids produce failed results.
fn execute_cancel(
    arguments: &str,
    bg_agents: &BgAgentRegistry,
    bg_processes: &BgRegistry,
    caller_spawner: Option<u32>,
) -> ToolResult {
    let parsed = serde_json::from_str::<Value>(arguments);
    let args = match parsed {
        Err(e) => return err(format!("CancelTask: invalid JSON arguments: {e}")),
        Ok(v) => v,
    };
    // Indexing a `Value` by key yields `Null` (never panics) when the key is
    // absent or `args` is not an object, so `as_str` covers both cases.
    let task_id_str = match args["task_id"].as_str() {
        None => return err("CancelTask: missing required 'task_id' (string)"),
        Some(s) => s,
    };
    let target = match parse_task_id(task_id_str) {
        Err(e) => return err(format!("CancelTask: {e}")),
        Ok(t) => t,
    };
    let outcome = match target {
        TaskId::Process(pid) => bg_processes.kill_as_caller(pid, caller_spawner),
        TaskId::Agent(id) => bg_agents.cancel_as_caller(id, caller_spawner),
    };
    match outcome {
        CancelOutcome::Forbidden => err(format!(
            "CancelTask: task '{task_id_str}' is not owned by this caller"
        )),
        CancelOutcome::NotFound => err(format!(
            "CancelTask: no background task with id '{task_id_str}' \
             (already finished, never existed, or already drained)"
        )),
        CancelOutcome::Cancelled => ok(json!({
            "task_id": task_id_str,
            "cancelled": true,
        })),
    }
}
/// `WaitTask` handler.
///
/// Parses `{ "task_id": ..., "timeout_secs"?: number }`, resolves the timeout
/// through [`clamp_wait_timeout_secs`], and waits on the matching registry
/// for the task on behalf of `caller_spawner`. The outcome is converted into a
/// tool result that echoes the caller's `task_id` string.
///
/// Invalid JSON, a missing or non-string `task_id`, or an unparseable id
/// produce a failed result before any waiting happens.
async fn execute_wait(
    arguments: &str,
    bg_agents: &BgAgentRegistry,
    bg_processes: &BgRegistry,
    caller_spawner: Option<u32>,
) -> ToolResult {
    let args: Value = match serde_json::from_str(arguments) {
        Ok(v) => v,
        Err(e) => return err(format!("WaitTask: invalid JSON arguments: {e}")),
    };
    let task_id_str = match args.get("task_id").and_then(|v| v.as_str()) {
        Some(s) => s,
        None => return err("WaitTask: missing required 'task_id' (string)"),
    };
    let task_id = match parse_task_id(task_id_str) {
        Ok(t) => t,
        Err(e) => return err(format!("WaitTask: {e}")),
    };
    let timeout_secs = clamp_wait_timeout_secs(requested_timeout_secs(&args));
    let timeout = Duration::from_secs(u64::from(timeout_secs));
    match task_id {
        TaskId::Agent(n) => {
            let outcome = bg_agents
                .wait_for_completion(n, caller_spawner, timeout)
                .await;
            agent_wait_to_tool_result(task_id_str, outcome)
        }
        TaskId::Process(n) => {
            let outcome = bg_processes
                .wait_for_exit_as_caller(n, caller_spawner, timeout)
                .await;
            process_wait_to_tool_result(task_id_str, outcome)
        }
    }
}

/// Extracts `timeout_secs` from the `WaitTask` arguments as a `u32`.
///
/// Returns `None` when the key is absent or not a number, which makes the
/// caller use the default timeout. Out-of-range numbers saturate instead of
/// wrapping, so [`clamp_wait_timeout_secs`] can pull them to the nearest bound:
/// - integers above `u32::MAX` become `u32::MAX`. The previous `n as u32`
///   truncated, e.g. 2^32 + 100 became 100.
/// - negative or fractional numbers go through a saturating float-to-int cast,
///   so negatives become 0, and the clamp then raises 0 to 1.
fn requested_timeout_secs(args: &Value) -> Option<u32> {
    let raw = args.get("timeout_secs")?;
    match raw.as_u64() {
        Some(n) => Some(n.min(u64::from(u32::MAX)) as u32),
        // `f64 as u32` saturates (negatives → 0, too large → u32::MAX).
        None => raw.as_f64().map(|f| f as u32),
    }
}
/// Converts a sub-agent [`WaitOutcome`] into a `WaitTask` tool result.
///
/// How each outcome is reported:
/// - `Completed`: a successful result carrying the agent's name, prompt,
///   output and events; `status` is `"completed"` or `"errored"`.
/// - `Cancelled`: a successful result with `status: "cancelled"`.
/// - `TimedOut`: a successful result with `status: "timed_out"` and the
///   current snapshot.
/// - `NotFound` and `Forbidden`: failed results.
fn agent_wait_to_tool_result(task_id_str: &str, outcome: WaitOutcome) -> ToolResult {
    let payload = match outcome {
        WaitOutcome::NotFound => {
            return err(format!("WaitTask: no background task with id '{task_id_str}'"));
        }
        WaitOutcome::Forbidden => {
            return err(format!(
                "WaitTask: task '{task_id_str}' is not owned by this caller"
            ));
        }
        WaitOutcome::TimedOut(snap) => json!({
            "task_id": task_id_str,
            "status": "timed_out",
            "current": agent_snapshot_to_json(&snap),
        }),
        WaitOutcome::Cancelled => json!({
            "task_id": task_id_str,
            "status": "cancelled",
        }),
        WaitOutcome::Completed(BgAgentResult {
            agent_name,
            prompt,
            output,
            success,
            events,
        }) => {
            let status = if success { "completed" } else { "errored" };
            json!({
                "task_id": task_id_str,
                "status": status,
                "agent_name": agent_name,
                "prompt": prompt,
                "output": output,
                "events": events,
            })
        }
    };
    ok(payload)
}
/// Converts a background-process [`ProcessWaitOutcome`] into a `WaitTask`
/// tool result.
///
/// How each outcome is reported:
/// - `Exited`: a successful result carrying the exit code.
/// - `TimedOut`: a successful result with `status: "timed_out"` and the
///   current snapshot.
/// - `NotFound` and `Forbidden`: failed results.
fn process_wait_to_tool_result(task_id_str: &str, outcome: ProcessWaitOutcome) -> ToolResult {
    let payload = match outcome {
        ProcessWaitOutcome::NotFound => {
            return err(format!("WaitTask: no background task with id '{task_id_str}'"));
        }
        ProcessWaitOutcome::Forbidden => {
            return err(format!(
                "WaitTask: task '{task_id_str}' is not owned by this caller"
            ));
        }
        ProcessWaitOutcome::TimedOut(snap) => json!({
            "task_id": task_id_str,
            "status": "timed_out",
            "current": process_snapshot_to_json(&snap),
        }),
        ProcessWaitOutcome::Exited { code } => json!({
            "task_id": task_id_str,
            "status": "exited",
            "exit_code": code,
        }),
    };
    ok(payload)
}
// Unit tests for the background-task tools: schema shape, task-id parsing,
// timeout clamping, and end-to-end `execute` calls against fresh registries.
#[cfg(test)]
mod tests {
    use super::*;
    /// The tool list must be exactly these three names, in this order.
    #[test]
    fn definitions_returns_three_tools_with_expected_names() {
        let defs = definitions();
        let names: Vec<&str> = defs.iter().map(|d| d.name.as_str()).collect();
        assert_eq!(names, vec!["ListBackgroundTasks", "CancelTask", "WaitTask"]);
    }
    /// ListBackgroundTasks has either no `required` key or an empty one.
    #[test]
    fn list_background_tasks_takes_no_arguments() {
        let defs = definitions();
        let list = defs
            .iter()
            .find(|d| d.name == "ListBackgroundTasks")
            .unwrap();
        let required = list.parameters.get("required");
        assert!(
            required.is_none() || required.unwrap().as_array().unwrap().is_empty(),
            "ListBackgroundTasks must take no required args"
        );
    }
    /// Both CancelTask and WaitTask list `task_id` under `required`.
    #[test]
    fn cancel_and_wait_require_task_id() {
        let defs = definitions();
        for name in ["CancelTask", "WaitTask"] {
            let def = defs.iter().find(|d| d.name == name).unwrap();
            let required = def.parameters["required"].as_array().unwrap();
            assert!(
                required.iter().any(|v| v == "task_id"),
                "{name} must require task_id"
            );
        }
    }
    /// Prefixed ids parse, including with surrounding whitespace.
    #[test]
    fn parse_task_id_accepts_prefixed_forms() {
        assert_eq!(parse_task_id("agent:7").unwrap(), TaskId::Agent(7));
        assert_eq!(
            parse_task_id("process:12345").unwrap(),
            TaskId::Process(12345)
        );
        assert_eq!(parse_task_id(" agent:1 ").unwrap(), TaskId::Agent(1));
    }
    /// A bare number is treated as an agent id.
    #[test]
    fn parse_task_id_accepts_bare_numeric_as_agent() {
        assert_eq!(parse_task_id("5").unwrap(), TaskId::Agent(5));
    }
    /// Blank input, bad numbers after a known prefix, and unknown prefixes
    /// are all rejected.
    #[test]
    fn parse_task_id_rejects_bad_input() {
        assert!(parse_task_id("").is_err());
        assert!(parse_task_id(" ").is_err());
        assert!(parse_task_id("agent:").is_err());
        assert!(parse_task_id("agent:abc").is_err());
        assert!(parse_task_id("process:-1").is_err());
        assert!(parse_task_id("foobar").is_err());
        assert!(parse_task_id("mcp:1").is_err());
    }
    /// An omitted timeout resolves to the default.
    #[test]
    fn clamp_wait_timeout_handles_none_default() {
        assert_eq!(
            clamp_wait_timeout_secs(None),
            WAIT_TASK_DEFAULT_TIMEOUT_SECS
        );
    }
    /// Oversized timeouts (one day here) are capped at the maximum.
    #[test]
    fn clamp_wait_timeout_caps_at_max() {
        assert_eq!(
            clamp_wait_timeout_secs(Some(86400)),
            WAIT_TASK_MAX_TIMEOUT_SECS
        );
    }
    /// Zero is raised to the one-second floor.
    #[test]
    fn clamp_wait_timeout_floors_at_one() {
        assert_eq!(clamp_wait_timeout_secs(Some(0)), 1);
    }
    /// In-range values, including the maximum itself, pass through unchanged.
    #[test]
    fn clamp_wait_timeout_passes_through_in_range() {
        assert_eq!(clamp_wait_timeout_secs(Some(45)), 45);
        assert_eq!(
            clamp_wait_timeout_secs(Some(WAIT_TASK_MAX_TIMEOUT_SECS)),
            WAIT_TASK_MAX_TIMEOUT_SECS
        );
    }
    // Empty agent and process registries for each async test.
    fn fresh_registries() -> (Arc<BgAgentRegistry>, BgRegistry) {
        (Arc::new(BgAgentRegistry::new()), BgRegistry::new())
    }
    // NOTE(review): from usage below, `register_test_with_status` appears to
    // return (id, result sender, status sender, cancel observer) — confirm in
    // BgAgentRegistry's test helpers.
    /// With nothing registered, the list output is the literal `[]`.
    #[tokio::test]
    async fn execute_list_returns_empty_array_when_no_tasks() {
        let (agents, processes) = fresh_registries();
        let r = execute("ListBackgroundTasks", "{}", &agents, &processes, None).await;
        assert!(r.success);
        assert_eq!(r.output, "[]");
    }
    /// A freshly registered agent shows up as a pending `agent:N` entry whose
    /// description is "<name>: <prompt>".
    #[tokio::test]
    async fn execute_list_includes_caller_agent_tasks() {
        let (agents, processes) = fresh_registries();
        let (id, _tx, _, _) = agents.register_test_with_status("explore", "map repo", None);
        let r = execute("ListBackgroundTasks", "{}", &agents, &processes, None).await;
        assert!(r.success);
        let arr: Value = serde_json::from_str(&r.output).unwrap();
        let arr = arr.as_array().unwrap();
        assert_eq!(arr.len(), 1);
        assert_eq!(arr[0]["task_id"], format!("agent:{id}"));
        assert_eq!(arr[0]["task_type"], "agent");
        assert_eq!(arr[0]["status"], "pending");
        assert_eq!(arr[0]["description"], "explore: map repo");
    }
    /// Each caller (top level = None, sub-agent = Some(7)) sees only its own task.
    #[tokio::test]
    async fn execute_list_filters_out_other_callers_tasks() {
        let (agents, processes) = fresh_registries();
        agents.register_test_with_status("a", "top", None);
        agents.register_test_with_status("b", "sub", Some(7));
        let top = execute("ListBackgroundTasks", "{}", &agents, &processes, None).await;
        let arr: Value = serde_json::from_str(&top.output).unwrap();
        assert_eq!(arr.as_array().unwrap().len(), 1, "top sees only its own");
        let sub = execute("ListBackgroundTasks", "{}", &agents, &processes, Some(7)).await;
        let arr: Value = serde_json::from_str(&sub.output).unwrap();
        assert_eq!(arr.as_array().unwrap().len(), 1, "sub sees only its own");
    }
    /// Cancelling an owned agent task fires its cancel token and echoes the id.
    #[tokio::test]
    async fn execute_cancel_succeeds_for_owned_agent_task() {
        let (agents, processes) = fresh_registries();
        let (id, _tx, _, observer) = agents.register_test_with_status("x", "y", None);
        let r = execute(
            "CancelTask",
            &json!({ "task_id": format!("agent:{id}") }).to_string(),
            &agents,
            &processes,
            None,
        )
        .await;
        assert!(r.success, "got: {}", r.output);
        assert!(observer.is_cancelled(), "cancel token must fire");
        let payload: Value = serde_json::from_str(&r.output).unwrap();
        assert_eq!(payload["cancelled"], true);
        assert_eq!(payload["task_id"], format!("agent:{id}"));
    }
    /// An id that was never registered yields a "no background task" failure.
    #[tokio::test]
    async fn execute_cancel_returns_not_found_for_unknown_id() {
        let (agents, processes) = fresh_registries();
        let r = execute(
            "CancelTask",
            &json!({ "task_id": "agent:9999" }).to_string(),
            &agents,
            &processes,
            None,
        )
        .await;
        assert!(!r.success);
        assert!(r.output.contains("no background task"), "got: {}", r.output);
    }
    /// The top-level caller cannot cancel a task spawned by Some(5), and the
    /// token stays unfired.
    #[tokio::test]
    async fn execute_cancel_returns_forbidden_for_cross_caller() {
        let (agents, processes) = fresh_registries();
        let (id, _tx, _, observer) = agents.register_test_with_status("x", "y", Some(5));
        let r = execute(
            "CancelTask",
            &json!({ "task_id": format!("agent:{id}") }).to_string(),
            &agents,
            &processes,
            None,
        )
        .await;
        assert!(!r.success);
        assert!(
            r.output.contains("not owned by this caller"),
            "got: {}",
            r.output
        );
        assert!(!observer.is_cancelled(), "forbidden must NOT fire token");
    }
    /// Non-JSON arguments produce an "invalid JSON" failure.
    #[tokio::test]
    async fn execute_cancel_rejects_malformed_json() {
        let (agents, processes) = fresh_registries();
        let r = execute("CancelTask", "not-json", &agents, &processes, None).await;
        assert!(!r.success);
        assert!(r.output.contains("invalid JSON"), "got: {}", r.output);
    }
    /// An empty object (no task_id) produces a "missing required" failure.
    #[tokio::test]
    async fn execute_cancel_rejects_missing_task_id() {
        let (agents, processes) = fresh_registries();
        let r = execute("CancelTask", "{}", &agents, &processes, None).await;
        assert!(!r.success);
        assert!(r.output.contains("missing required"), "got: {}", r.output);
    }
    /// A finished agent's output and events are returned, and the wait
    /// consumes the task (the registry snapshot is empty afterwards).
    #[tokio::test]
    async fn execute_wait_returns_completed_for_finished_agent() {
        let (agents, processes) = fresh_registries();
        let (id, tx, status_tx, _) = agents.register_test_with_status("explore", "map", None);
        tx.send(Ok(("final answer".into(), vec!["e1".into()])))
            .unwrap();
        status_tx
            .send(AgentStatus::Completed {
                summary: "final".into(),
            })
            .unwrap();
        let r = execute(
            "WaitTask",
            &json!({ "task_id": format!("agent:{id}"), "timeout_secs": 1 }).to_string(),
            &agents,
            &processes,
            None,
        )
        .await;
        assert!(r.success, "got: {}", r.output);
        let payload: Value = serde_json::from_str(&r.output).unwrap();
        assert_eq!(payload["status"], "completed");
        assert_eq!(payload["output"], "final answer");
        assert_eq!(payload["events"].as_array().unwrap().len(), 1);
        assert_eq!(agents.snapshot().len(), 0);
    }
    /// A task that never finishes times out after the 1s timeout, and the task
    /// stays registered. This test presumably takes about one second of wall
    /// time.
    #[tokio::test]
    async fn execute_wait_returns_timed_out_with_snapshot() {
        let (agents, processes) = fresh_registries();
        let (id, _tx, _status_tx, _observer) = agents.register_test_with_status("slow", "x", None);
        let r = execute(
            "WaitTask",
            &json!({ "task_id": format!("agent:{id}"), "timeout_secs": 1 }).to_string(),
            &agents,
            &processes,
            None,
        )
        .await;
        assert!(r.success);
        let payload: Value = serde_json::from_str(&r.output).unwrap();
        assert_eq!(payload["status"], "timed_out");
        assert_eq!(payload["current"]["task_id"], format!("agent:{id}"));
        assert_eq!(agents.snapshot().len(), 1);
    }
    /// Waiting on a cancelled task is a success with status "cancelled", and
    /// it consumes the task.
    #[tokio::test]
    async fn execute_wait_returns_cancelled_when_token_fires() {
        let (agents, processes) = fresh_registries();
        let (id, tx, status_tx, observer) = agents.register_test_with_status("slow", "x", None);
        observer.cancel();
        status_tx.send(AgentStatus::Cancelled).unwrap();
        drop(tx);
        let r = execute(
            "WaitTask",
            &json!({ "task_id": format!("agent:{id}"), "timeout_secs": 5 }).to_string(),
            &agents,
            &processes,
            None,
        )
        .await;
        assert!(
            r.success,
            "WaitTask on a cancelled task must still succeed: {}",
            r.output
        );
        let payload: Value = serde_json::from_str(&r.output).unwrap();
        assert_eq!(payload["status"], "cancelled");
        assert_eq!(payload["task_id"], format!("agent:{id}"));
        assert_eq!(agents.snapshot().len(), 0);
    }
    /// The router fallback reports an unknown tool name as a failure.
    #[tokio::test]
    async fn execute_unknown_tool_name_returns_error() {
        let (agents, processes) = fresh_registries();
        let r = execute("NotAToolWeKnow", "{}", &agents, &processes, None).await;
        assert!(!r.success);
        assert!(r.output.contains("unknown tool"));
    }
}