use std::collections::HashSet;
use std::sync::{Arc, OnceLock, Weak};
use std::time::Duration;
use async_trait::async_trait;
use once_cell::sync::Lazy;
use regex::Regex;
use serde::Deserialize;
use serde_json::{json, Value};
use tokio::sync::{mpsc, Mutex};
use tracing::{info, warn};
use crate::agent::{Agent, StatusUpdate};
use crate::channels::ChannelHub;
use crate::traits::{AgentRole, StateStore, Tool, ToolCapabilities};
use crate::types::{ChannelContext, ChannelVisibility, UserRole};
/// Tool that lets an agent spawn a child agent, either synchronously or in the
/// background with progress/completion notifications.
pub struct SpawnAgentTool {
    /// Parent agent that children are spawned from. Set once (via `new` or
    /// `set_agent`) and held weakly — presumably to avoid a reference cycle
    /// with the agent that owns this tool; TODO confirm.
    agent: OnceLock<Weak<Agent>>,
    /// Optional channel hub used to deliver background notifications directly.
    hub: OnceLock<Weak<ChannelHub>>,
    /// Optional state store used as a queued fallback when direct hub delivery
    /// is unavailable or fails.
    state: Option<Arc<dyn StateStore>>,
    /// Maximum child response length, measured in bytes by `truncate_utf8`.
    max_response_chars: usize,
    /// Timeout applied to each child run, in seconds.
    timeout_secs: u64,
    /// Task ids that currently have an executor running; used to reject a
    /// second concurrent executor for the same task. Shared (`Arc`) with
    /// background tasks so they can release their slot on completion.
    executor_task_runs: Arc<Mutex<HashSet<String>>>,
}
/// Interval between "still running" progress notifications for background
/// sub-agents. Shortened to 1s under `cfg(test)` so tests do not wait long.
#[cfg(test)]
const BACKGROUND_PROGRESS_INTERVAL_SECS: u64 = 1;
/// Interval between "still running" progress notifications for background
/// sub-agents in normal builds.
#[cfg(not(test))]
const BACKGROUND_PROGRESS_INTERVAL_SECS: u64 = 20;
impl SpawnAgentTool {
    /// Creates a tool already bound to its parent `agent`.
    #[allow(dead_code)] // kept for callers that have the agent up front
    pub fn new(agent: Weak<Agent>, max_response_chars: usize, timeout_secs: u64) -> Self {
        let tool = Self::new_deferred(max_response_chars, timeout_secs);
        // The lock was just created empty, so this set always succeeds.
        let _ = tool.agent.set(agent);
        tool
    }

    /// Creates a tool whose parent agent is supplied later via `set_agent`.
    pub fn new_deferred(max_response_chars: usize, timeout_secs: u64) -> Self {
        Self {
            agent: OnceLock::new(),
            hub: OnceLock::new(),
            state: None,
            max_response_chars,
            timeout_secs,
            executor_task_runs: Arc::new(Mutex::new(HashSet::new())),
        }
    }

    /// Attaches a state store used as the queued notification fallback.
    pub fn with_state(self, state: Arc<dyn StateStore>) -> Self {
        Self {
            state: Some(state),
            ..self
        }
    }

    /// Binds the parent agent.
    ///
    /// # Panics
    /// Panics if the agent has already been set.
    pub fn set_agent(&self, agent: Weak<Agent>) {
        self.agent
            .set(agent)
            .expect("SpawnAgentTool::set_agent called more than once");
    }

    /// Returns a strong handle to the parent agent, or an error if it was never
    /// set or has since been dropped.
    fn get_agent(&self) -> anyhow::Result<Arc<Agent>> {
        let Some(weak_agent) = self.agent.get() else {
            return Err(anyhow::anyhow!("SpawnAgentTool: agent reference not set"));
        };
        match weak_agent.upgrade() {
            Some(strong) => Ok(strong),
            None => Err(anyhow::anyhow!(
                "SpawnAgentTool: parent agent has been dropped"
            )),
        }
    }

    /// Registers the channel hub; the first registration wins and later calls
    /// are ignored.
    pub fn set_hub(&self, hub: Weak<ChannelHub>) {
        let _ = self.hub.set(hub);
    }

    /// Returns the hub if one was registered and it is still alive.
    fn get_hub(&self) -> Option<Arc<ChannelHub>> {
        self.hub.get()?.upgrade()
    }

    /// Claims the executor slot for `task_id`; returns `false` if another
    /// executor already holds it.
    async fn try_begin_executor_task(&self, task_id: &str) -> bool {
        let mut active = self.executor_task_runs.lock().await;
        if active.contains(task_id) {
            false
        } else {
            active.insert(task_id.to_owned())
        }
    }

    /// Releases the executor slot for `task_id` (no-op if not held).
    async fn finish_executor_task(&self, task_id: &str) {
        let mut active = self.executor_task_runs.lock().await;
        active.remove(task_id);
    }
}
/// Returns the longest prefix of `s` that is at most `max_chars` **bytes** long
/// and ends on a UTF-8 character boundary.
///
/// Despite the parameter name, the limit is measured in bytes (`str::len`), not
/// in `char`s. A multi-byte character that would straddle the limit is dropped
/// entirely rather than split.
fn truncate_utf8(s: &str, max_chars: usize) -> &str {
    if s.len() <= max_chars {
        return s;
    }
    // Here `max_chars < s.len()`, so every probed index is in range, and index 0
    // is always a boundary, so the walk-back terminates.
    let mut end = max_chars;
    while !s.is_char_boundary(end) {
        end -= 1;
    }
    &s[..end]
}
/// Parses a leading delay such as `wait 5 min`, `wait for 2 minutes`,
/// `in 45 sec`, or `after 1 hour` (case-insensitive, after trimming) and
/// returns it in seconds.
///
/// Returns `None` when `task` has no such prefix or the number does not fit in
/// a `u64`. Minute and hour conversions saturate at `u64::MAX`.
fn parse_leading_wait_seconds(task: &str) -> Option<u64> {
    static LEADING_WAIT_RE: Lazy<Regex> = Lazy::new(|| {
        Regex::new(
            r"(?i)^\s*(?:wait\s+(?:for\s+)?|in\s+|after\s+)(\d+)\s*(seconds?|secs?|s|minutes?|mins?|min|m|hours?|hrs?|h)\b",
        )
        .expect("leading wait regex should compile")
    });
    let captures = LEADING_WAIT_RE.captures(task.trim())?;
    let amount = captures.get(1)?.as_str().parse::<u64>().ok()?;
    let unit = captures.get(2)?.as_str().to_ascii_lowercase();
    let seconds_per_unit: u64 = match unit.as_str() {
        "s" | "sec" | "secs" | "second" | "seconds" => 1,
        "m" | "min" | "mins" | "minute" | "minutes" => 60,
        "h" | "hr" | "hrs" | "hour" | "hours" => 3600,
        _ => return None,
    };
    Some(amount.saturating_mul(seconds_per_unit))
}
/// Removes a leading delay clause (e.g. `wait for 2 minutes then `) from
/// `task` and returns the trimmed remainder.
///
/// Returns an empty string when the remainder is shorter than 3 bytes, which
/// callers treat as "the task was only a wait".
fn strip_leading_wait(task: &str) -> String {
    static STRIP_WAIT_RE: Lazy<Regex> = Lazy::new(|| {
        Regex::new(
            r"(?i)^\s*(?:wait\s+(?:for\s+)?|in\s+|after\s+)\d+\s*(?:seconds?|secs?|s|minutes?|mins?|min|m|hours?|hrs?|h)\s*[,;]?\s*(?:then\s+|and\s+|,\s*)?",
        )
        .expect("strip wait regex should compile")
    });
    let without_wait = STRIP_WAIT_RE.replace(task.trim(), "");
    match without_wait.trim() {
        rest if rest.len() >= 3 => rest.to_string(),
        _ => String::new(),
    }
}
/// Delivers a background-agent notification, preferring direct hub delivery
/// and falling back to enqueueing it in the state store.
///
/// Failures are logged with `warn!` and never returned. If there is neither a
/// hub nor a state store, the update is dropped (and logged).
async fn deliver_background_notification(
    hub: Option<&Arc<ChannelHub>>,
    state: Option<&Arc<dyn StateStore>>,
    goal_id: &str,
    session_id: &str,
    notification_type: &str,
    message: &str,
    context: &str,
) {
    if let Some(hub_arc) = hub {
        match hub_arc.send_text(session_id, message).await {
            // Delivered directly; no need for the queue fallback.
            Ok(_) => return,
            Err(e) => {
                warn!(
                    session_id = %session_id,
                    goal_id = %goal_id,
                    notification_type = %notification_type,
                    error = %e,
                    "{context}: direct hub delivery failed"
                );
            }
        }
    }

    let Some(state_store) = state else {
        warn!(
            session_id = %session_id,
            goal_id = %goal_id,
            notification_type = %notification_type,
            "{context}: no hub and no queue fallback configured; update dropped"
        );
        return;
    };

    let queued =
        crate::traits::NotificationEntry::new(goal_id, session_id, notification_type, message);
    if let Err(e) = state_store.enqueue_notification(&queued).await {
        warn!(
            session_id = %session_id,
            goal_id = %goal_id,
            notification_type = %notification_type,
            error = %e,
            "{context}: enqueue fallback failed"
        );
    }
}
/// Arguments accepted by `spawn_agent`.
///
/// `mission`, `task`, `background` and `task_id` appear in the tool schema.
/// The underscore-prefixed fields are not in that schema; presumably they are
/// injected by the runtime rather than the model (TODO confirm against caller).
#[derive(Deserialize)]
struct SpawnArgs {
    /// High-level mission/role for the child agent.
    mission: String,
    /// Concrete task; may start with a wait clause that is intercepted locally.
    task: String,
    /// Run the child in the background (defaults to `false`).
    #[serde(default)]
    background: bool,
    /// Task id supplied by the model; takes precedence over `_task_id`.
    #[serde(default)]
    task_id: Option<String>,
    /// Session to route background notifications to.
    #[serde(default)]
    _session_id: Option<String>,
    /// Channel visibility, parsed with `ChannelVisibility::from_str_lossy`.
    #[serde(default)]
    _channel_visibility: Option<String>,
    /// Caller role; only `"Owner"` maps to `UserRole::Owner`.
    #[serde(default)]
    _user_role: Option<String>,
    /// Fallback task id used when `task_id` is absent.
    #[serde(default)]
    _task_id: Option<String>,
    /// Goal id forwarded to the child and used to tag notifications.
    #[serde(default)]
    _goal_id: Option<String>,
    /// Whether the originating session is trusted (defaults to untrusted).
    #[serde(default)]
    _trusted_session: Option<bool>,
    /// Project scope forwarded to the child agent.
    #[serde(default)]
    _project_scope: Option<String>,
}
/// Builds the internal channel context handed to a child agent.
///
/// Visibility defaults to `Internal` when not supplied. The session is treated
/// as trusted only when `_trusted_session` is explicitly `true`. All
/// sender/channel identity fields are left empty.
fn build_child_channel_context(args: &SpawnArgs) -> ChannelContext {
    let visibility = match args._channel_visibility.as_deref() {
        Some(raw) => ChannelVisibility::from_str_lossy(raw),
        None => ChannelVisibility::Internal,
    };
    let trusted = args._trusted_session == Some(true);
    ChannelContext {
        visibility,
        platform: String::from("internal"),
        channel_name: None,
        channel_id: None,
        sender_name: None,
        sender_id: None,
        channel_member_names: Vec::new(),
        user_id_map: Default::default(),
        trusted,
    }
}
#[async_trait]
impl Tool for SpawnAgentTool {
    fn name(&self) -> &str {
        "spawn_agent"
    }

    fn description(&self) -> &str {
        "Spawn a sub-agent to handle a focused task autonomously. \
         The sub-agent has access to all tools and runs its own reasoning loop. \
         Use this for tasks that benefit from isolated, parallel reasoning."
    }

    fn schema(&self) -> Value {
        json!({
            "name": "spawn_agent",
            "description": "Spawn a sub-agent to handle a focused task autonomously. \
                The sub-agent has access to all tools and runs its own reasoning loop. \
                Use this for complex sub-tasks that benefit from isolated context and focused reasoning.",
            "parameters": {
                "type": "object",
                "properties": {
                    "mission": {
                        "type": "string",
                        "description": "High-level mission or role for the sub-agent \
                            (e.g. 'Research assistant focused on Python packaging')"
                    },
                    "task": {
                        "type": "string",
                        "description": "The specific task or question the sub-agent should accomplish"
                    },
                    "background": {
                        "type": "boolean",
                        "description": "When true, spawn the sub-agent in the background and return immediately. \
                            The result will be sent as a message when the sub-agent finishes. \
                            Use this for long-running tasks where the user doesn't need to wait.",
                        "default": false
                    },
                    "task_id": {
                        "type": "string",
                        "description": "Task ID to associate with this executor (used by task leads to connect executor work to task tracking)"
                    }
                },
                "required": ["mission", "task"],
                "additionalProperties": false
            }
        })
    }

    fn capabilities(&self) -> ToolCapabilities {
        ToolCapabilities {
            read_only: false,
            external_side_effect: false,
            needs_approval: false,
            idempotent: false,
            high_impact_write: true,
        }
    }

    async fn call(&self, arguments: &str) -> anyhow::Result<String> {
        self.call_with_status(arguments, None).await
    }

    /// Spawns a child agent for `arguments` (a JSON-encoded `SpawnArgs`).
    ///
    /// Flow:
    /// 1. A TaskLead parent may only spawn an Executor. It must pass a task id
    ///    that validates, and that task must not already have a running
    ///    executor. Violations are returned as `Ok("Blocked...")` strings,
    ///    not errors.
    /// 2. A leading wait clause in `task` ("wait 5 min then ...") is slept
    ///    locally. If nothing remains after it, the call returns right away.
    /// 3. The child then runs synchronously via `run_sync`. It runs in the
    ///    background only when `background` is set, a hub or queue notification
    ///    path exists, and a non-empty `_session_id` was supplied.
    ///
    /// # Errors
    /// Returns `Err` only when the arguments fail to parse or the parent agent
    /// is unavailable.
    async fn call_with_status(
        &self,
        arguments: &str,
        status_tx: Option<mpsc::Sender<StatusUpdate>>,
    ) -> anyhow::Result<String> {
        let args: SpawnArgs = serde_json::from_str(arguments)?;
        let agent = self.get_agent()?;
        info!(
            depth = agent.depth(),
            max_depth = agent.max_depth(),
            mission = %args.mission,
            background = args.background,
            "spawn_agent tool invoked"
        );

        let channel_ctx = build_child_channel_context(&args);
        // Only an explicit "Owner" role is propagated; anything else, including
        // a missing role, is downgraded to Guest.
        let user_role = match args._user_role.as_deref() {
            Some("Owner") => UserRole::Owner,
            _ => UserRole::Guest,
        };
        // A TaskLead spawns Executors; other roles spawn role-less children.
        let child_role = if agent.role() == AgentRole::TaskLead {
            Some(AgentRole::Executor)
        } else {
            None
        };
        // The model-visible `task_id` wins over the injected `_task_id`. Owned
        // copies are used so `args` is never partially moved.
        let task_id_ref = args.task_id.clone().or_else(|| args._task_id.clone());
        let goal_id_ref = args._goal_id.clone();
        let project_scope = args._project_scope.clone();

        // For executors, claim the per-task slot. From here on, every exit path
        // must release it via `finish_executor_task`.
        // NOTE(review): if this future is dropped mid-await (e.g. during the
        // wait below) or the background task panics, the slot is never released.
        let executor_task_id = if child_role == Some(AgentRole::Executor) {
            let Some(task_id) = task_id_ref.clone() else {
                return Ok(
                    "Blocked: TaskLead must pass task_id when spawning an executor. Claim a task first with manage_goal_tasks(action='claim_task')."
                        .to_string(),
                );
            };
            if let Err(e) = agent
                .validate_executor_task_for_spawn(&task_id, goal_id_ref.as_deref())
                .await
            {
                return Ok(format!(
                    "Blocked executor spawn for task {}: {}",
                    task_id, e
                ));
            }
            if !self.try_begin_executor_task(&task_id).await {
                return Ok(format!(
                    "Blocked: task {} already has an executor running. Wait for it to finish before spawning another.",
                    task_id
                ));
            }
            Some(task_id)
        } else {
            None
        };

        // Intercept a leading wait so the child agent does not burn its
        // reasoning budget sleeping. The wait is not bounded by `timeout_secs`.
        let mut effective_mission = args.mission.clone();
        let mut effective_task = args.task.clone();
        if let Some(wait_secs) = parse_leading_wait_seconds(&effective_task) {
            let remainder = strip_leading_wait(&effective_task);
            info!(
                wait_secs,
                has_remainder = !remainder.is_empty(),
                "Intercepted leading wait in spawn_agent task; sleeping locally"
            );
            tokio::time::sleep(Duration::from_secs(wait_secs)).await;
            if remainder.is_empty() {
                if let Some(ref task_id) = executor_task_id {
                    self.finish_executor_task(task_id).await;
                }
                return Ok(format!("Waited for {} second(s).", wait_secs));
            }
            effective_task = remainder.clone();
            // If the mission also carried the wait (or just echoed the task),
            // replace it with the post-wait remainder as well.
            if parse_leading_wait_seconds(&effective_mission).is_some()
                || effective_mission
                    .trim()
                    .eq_ignore_ascii_case(args.task.trim())
            {
                effective_mission = remainder;
            }
        }

        // Background delivery needs a notification path (hub or queue) and a
        // session to address. Without both, degrade to a synchronous run.
        let background_target = if !args.background {
            None
        } else {
            let hub = self.get_hub();
            let state = self.state.clone();
            if hub.is_none() && state.is_none() {
                info!(
                    "Background mode requested but no hub/state notification path is available, falling back to sync"
                );
                None
            } else {
                match args._session_id.as_deref() {
                    Some(id) if !id.is_empty() => Some((id.to_string(), hub, state)),
                    _ => {
                        info!("Background mode requested but no session_id, falling back to sync");
                        None
                    }
                }
            }
        };

        let Some((session_id, hub, state)) = background_target else {
            let result = self
                .run_sync(
                    agent,
                    &effective_mission,
                    &effective_task,
                    status_tx,
                    channel_ctx,
                    user_role,
                    child_role,
                    goal_id_ref.as_deref(),
                    task_id_ref.as_deref(),
                    project_scope.as_deref(),
                )
                .await;
            if let Some(ref task_id) = executor_task_id {
                self.finish_executor_task(task_id).await;
            }
            return result;
        };

        let mission = effective_mission;
        let task = effective_task;
        let timeout_secs = self.timeout_secs;
        let max_response_chars = self.max_response_chars;
        let executor_task_runs = Arc::clone(&self.executor_task_runs);
        let notify_goal_id = goal_id_ref.clone().unwrap_or_else(|| "global".to_string());
        let notify_status_tx = status_tx.clone();
        tokio::spawn(async move {
            let started_at = std::time::Instant::now();
            let mut progress_interval =
                tokio::time::interval(Duration::from_secs(BACKGROUND_PROGRESS_INTERVAL_SECS));
            // A tokio interval's first tick completes immediately. Consume it so
            // the first progress update fires one full interval after start.
            progress_interval.tick().await;
            let timeout_duration = Duration::from_secs(timeout_secs);
            let mut result_fut = std::pin::pin!(tokio::time::timeout(
                timeout_duration,
                agent.spawn_child(
                    &mission,
                    &task,
                    status_tx.clone(),
                    channel_ctx,
                    user_role,
                    child_role,
                    goal_id_ref.as_deref(),
                    task_id_ref.as_deref(),
                    project_scope.as_deref(),
                ),
            ));
            // Race the child against the progress ticker, emitting a progress
            // notification on every tick until the child finishes or times out.
            let result = loop {
                tokio::select! {
                    res = &mut result_fut => break res,
                    _ = progress_interval.tick() => {
                        let elapsed_secs = started_at.elapsed().as_secs();
                        let progress_message = format!(
                            "Background sub-agent still running after {}s.\nMission: {}",
                            elapsed_secs, mission
                        );
                        if let Some(ref tx) = notify_status_tx {
                            // Best-effort: drop the update if the channel is full.
                            let _ = tx.try_send(StatusUpdate::ToolProgress {
                                name: "spawn_agent".to_string(),
                                chunk: format!(
                                    "Background sub-agent still running ({}s): {}",
                                    elapsed_secs, mission
                                ),
                            });
                        }
                        deliver_background_notification(
                            hub.as_ref(),
                            state.as_ref(),
                            &notify_goal_id,
                            &session_id,
                            "progress",
                            &progress_message,
                            "spawn_agent background progress notifier",
                        )
                        .await;
                    }
                }
            };

            // Mirror `run_sync`: a timed-out executor must be recorded against
            // its task, not just reported to the user.
            if result.is_err() && child_role == Some(AgentRole::Executor) {
                if let Some(task_id) = task_id_ref.as_deref() {
                    agent.mark_executor_task_timeout(task_id, timeout_secs).await;
                }
            }

            let (notification_type, message) = match result {
                Ok(Ok(response)) => {
                    // Flag truncation the same way `run_sync` does.
                    let text = if response.len() > max_response_chars {
                        format!(
                            "{}\n\n[Sub-agent response truncated at {} chars]",
                            truncate_utf8(&response, max_response_chars),
                            max_response_chars
                        )
                    } else {
                        response
                    };
                    (
                        "completed",
                        format!(
                            "\u{2705} Background task complete\nMission: {}\n\n{}",
                            mission, text
                        ),
                    )
                }
                Ok(Err(e)) => (
                    "failed",
                    format!(
                        "\u{274c} Background task failed\nMission: {}\nError: {}",
                        mission, e
                    ),
                ),
                Err(_) => (
                    "failed",
                    format!(
                        "\u{23f1} Background task timed out\nMission: {}\nTimed out after {}s",
                        mission, timeout_secs
                    ),
                ),
            };
            deliver_background_notification(
                hub.as_ref(),
                state.as_ref(),
                &notify_goal_id,
                &session_id,
                notification_type,
                &message,
                "spawn_agent background completion notifier",
            )
            .await;
            if let Some(task_id) = executor_task_id {
                executor_task_runs.lock().await.remove(&task_id);
            }
        });

        Ok(format!(
            "Sub-agent spawned in background for mission: \"{}\". \
             The result will be sent as a message when it completes.",
            args.mission
        ))
    }
}
impl SpawnAgentTool {
    /// Runs a child agent in the foreground and waits up to `timeout_secs` for it.
    ///
    /// Child errors and timeouts are reported as `Ok` strings so the calling
    /// model sees them as tool output. On timeout, an Executor child's task is
    /// marked as timed out on the parent agent. Responses longer than
    /// `max_response_chars` bytes are truncated and annotated.
    #[allow(clippy::too_many_arguments)]
    async fn run_sync(
        &self,
        agent: Arc<Agent>,
        mission: &str,
        task: &str,
        status_tx: Option<mpsc::Sender<StatusUpdate>>,
        channel_ctx: ChannelContext,
        user_role: UserRole,
        child_role: Option<AgentRole>,
        goal_id: Option<&str>,
        task_id: Option<&str>,
        project_scope: Option<&str>,
    ) -> anyhow::Result<String> {
        let limit_secs = self.timeout_secs;
        let child_run = agent.spawn_child(
            mission,
            task,
            status_tx,
            channel_ctx,
            user_role,
            child_role,
            goal_id,
            task_id,
            project_scope,
        );

        let Ok(outcome) = tokio::time::timeout(Duration::from_secs(limit_secs), child_run).await
        else {
            // Timed out: make sure executor work is recorded against its task.
            if child_role == Some(AgentRole::Executor) {
                if let Some(id) = task_id {
                    agent.mark_executor_task_timeout(id, limit_secs).await;
                }
            }
            return Ok(format!("Sub-agent timed out after {} seconds", limit_secs));
        };

        match outcome {
            Err(e) => Ok(format!("Sub-agent error: {}", e)),
            Ok(response) => {
                let limit = self.max_response_chars;
                if response.len() <= limit {
                    Ok(response)
                } else {
                    Ok(format!(
                        "{}\n\n[Sub-agent response truncated at {} chars]",
                        truncate_utf8(&response, limit),
                        limit
                    ))
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::embeddings::EmbeddingService;
    use crate::state::SqliteStateStore;
    use crate::traits::{NotificationStore, StateStore};
    use std::sync::Arc;

    // truncate_utf8 limits are in bytes; these pin the ASCII behavior.
    #[test]
    fn truncate_utf8_ascii() {
        assert_eq!(truncate_utf8("hello world", 5), "hello");
        assert_eq!(truncate_utf8("hello", 10), "hello");
        assert_eq!(truncate_utf8("hello", 5), "hello");
    }

    // Each emoji is 4 bytes; limits that fall mid-character round down.
    #[test]
    fn truncate_utf8_multibyte() {
        let s = "🔥🔥🔥";
        assert_eq!(s.len(), 12);
        assert_eq!(truncate_utf8(s, 4), "🔥");
        assert_eq!(truncate_utf8(s, 5), "🔥");
        assert_eq!(truncate_utf8(s, 8), "🔥🔥");
        assert_eq!(truncate_utf8(s, 1), "");
    }

    // Mixed ASCII and 4-byte characters.
    #[test]
    fn truncate_utf8_mixed() {
        let s = "hi🌍!";
        assert_eq!(truncate_utf8(s, 3), "hi");
        assert_eq!(truncate_utf8(s, 6), "hi🌍");
        assert_eq!(truncate_utf8(s, 7), "hi🌍!");
    }

    #[test]
    fn truncate_utf8_empty() {
        assert_eq!(truncate_utf8("", 10), "");
        assert_eq!(truncate_utf8("", 0), "");
    }

    // A deferred tool reports a clear error until set_agent is called.
    #[test]
    fn deferred_initialization_not_set() {
        let tool = SpawnAgentTool::new_deferred(8000, 300);
        let result = tool.get_agent();
        assert!(result.is_err());
        assert!(result.err().unwrap().to_string().contains("not set"));
    }

    // Pins the default sub-agent configuration values.
    #[test]
    fn config_defaults() {
        use crate::config::SubagentsConfig;
        let cfg = SubagentsConfig::default();
        assert!(cfg.enabled);
        assert_eq!(cfg.max_depth, 3);
        assert_eq!(cfg.max_iterations, 10);
        assert_eq!(cfg.max_response_chars, 8000);
        assert_eq!(cfg.timeout_secs, 300);
    }

    #[test]
    fn deferred_hub_not_set() {
        let tool = SpawnAgentTool::new_deferred(8000, 300);
        assert!(tool.get_hub().is_none());
    }

    // Optional fields default to false/None when omitted.
    #[test]
    fn spawn_args_background_default() {
        let json = r#"{"mission": "test", "task": "do stuff"}"#;
        let args: SpawnArgs = serde_json::from_str(json).unwrap();
        assert!(!args.background);
        assert!(args._session_id.is_none());
        assert!(args._channel_visibility.is_none());
    }

    #[test]
    fn spawn_args_background_true() {
        let json = r#"{"mission": "test", "task": "do stuff", "background": true, "_session_id": "tg:123"}"#;
        let args: SpawnArgs = serde_json::from_str(json).unwrap();
        assert!(args.background);
        assert_eq!(args._session_id.as_deref(), Some("tg:123"));
    }

    #[test]
    fn spawn_args_with_channel_visibility() {
        let json = r#"{"mission": "test", "task": "do stuff", "_channel_visibility": "public"}"#;
        let args: SpawnArgs = serde_json::from_str(json).unwrap();
        assert_eq!(args._channel_visibility.as_deref(), Some("public"));
    }

    // The trusted flag and visibility flow into the child channel context.
    #[test]
    fn spawn_args_with_trusted_session() {
        let json = r#"{"mission":"test","task":"do stuff","_trusted_session":true,"_channel_visibility":"internal"}"#;
        let args: SpawnArgs = serde_json::from_str(json).unwrap();
        let channel_ctx = build_child_channel_context(&args);
        assert_eq!(channel_ctx.visibility, ChannelVisibility::Internal);
        assert!(channel_ctx.trusted);
    }

    // Wait parsing and stripping agree on the common phrasings.
    #[test]
    fn parse_and_strip_leading_wait() {
        assert_eq!(
            parse_leading_wait_seconds("wait for 2 minutes then run df"),
            Some(120)
        );
        assert_eq!(
            strip_leading_wait("wait for 2 minutes then run df"),
            "run df"
        );
        assert_eq!(parse_leading_wait_seconds("in 45 sec check disk"), Some(45));
        assert_eq!(strip_leading_wait("after 1 hour, reboot"), "reboot");
    }

    // A task that is only a wait leaves no remainder.
    #[test]
    fn strip_leading_wait_pure_wait_returns_empty() {
        assert_eq!(parse_leading_wait_seconds("wait 5 min"), Some(300));
        assert!(strip_leading_wait("wait 5 min").is_empty());
    }

    // The per-task executor slot rejects duplicates until it is released.
    #[tokio::test]
    async fn executor_task_lock_deduplicates_concurrent_spawns() {
        let tool = SpawnAgentTool::new_deferred(8000, 300);
        assert!(tool.try_begin_executor_task("task-1").await);
        assert!(
            !tool.try_begin_executor_task("task-1").await,
            "Second acquire should be rejected while first is active"
        );
        tool.finish_executor_task("task-1").await;
        assert!(
            tool.try_begin_executor_task("task-1").await,
            "Task lock should be reusable after release"
        );
    }

    // With no hub, notifications must land in the state store's queue.
    #[tokio::test]
    async fn background_notification_falls_back_to_queue_when_hub_missing() {
        let db_file = tempfile::NamedTempFile::new().unwrap();
        let db_path = db_file.path().display().to_string();
        let embedding_service = Arc::new(EmbeddingService::new().unwrap());
        let state = Arc::new(
            SqliteStateStore::new(&db_path, 100, None, embedding_service)
                .await
                .unwrap(),
        );
        let state_dyn: Arc<dyn StateStore> = state.clone();
        deliver_background_notification(
            None,
            Some(&state_dyn),
            "goal_spawn_test",
            "sess_spawn_test",
            "progress",
            "Background sub-agent still running after 20s.\nMission: test",
            "spawn_test",
        )
        .await;
        let pending = state.get_pending_notifications(10).await.unwrap();
        assert!(pending.iter().any(|entry| {
            entry.goal_id == "goal_spawn_test"
                && entry.session_id == "sess_spawn_test"
                && entry.notification_type == "progress"
                && entry.message.contains("still running")
        }));
    }
}