use crate::AgentRuntime;
use crate::session::AgentSession;
use crate::subagent::{SubAgentId, SubAgentPool};
use astrid_audit::{AuditAction, AuditOutcome, AuthorizationProof};
use astrid_core::{Frontend, SessionId};
use astrid_llm::{LlmProvider, Message, MessageContent, MessageRole};
use astrid_tools::{SubAgentRequest, SubAgentResult, SubAgentSpawner, truncate_at_char_boundary};
use std::sync::Arc;
use std::time::Duration;
use tracing::{debug, info, warn};
/// Default sub-agent timeout: five minutes (300 seconds).
pub const DEFAULT_SUBAGENT_TIMEOUT: Duration = Duration::from_secs(5 * 60);
/// Spawns and runs sub-agents on behalf of a parent agent session.
///
/// Everything a sub-agent inherits from its parent (identity, stores, timeout,
/// callsign, capsule context) is captured here once at construction time.
pub struct SubAgentExecutor<P, F>
where
    P: LlmProvider,
    F: Frontend + 'static,
{
    /// Runtime used to run sub-agent turns and to reach the audit log via `audit()`.
    runtime: Arc<AgentRuntime<P>>,
    /// Pool that hands out sub-agent handles, exposes the cancellation token,
    /// and is told to release each handle when its run ends.
    pool: Arc<SubAgentPool>,
    /// Frontend handed to every sub-agent turn.
    frontend: Arc<F>,
    /// User id stamped on each sub-agent session and its session-start audit entry.
    parent_user_id: [u8; 8],
    /// Id passed to the pool as the parent of newly spawned sub-agents.
    /// NOTE(review): `None` presumably means the parent is the top-level agent — confirm.
    parent_subagent_id: Option<SubAgentId>,
    /// Parent session under which the spawn-linkage audit entry is recorded.
    parent_session_id: SessionId,
    /// Allowance store shared (not copied) with every sub-agent session.
    parent_allowance_store: Arc<astrid_approval::AllowanceStore>,
    /// Capability store shared (not copied) with every sub-agent session.
    parent_capabilities: Arc<astrid_capabilities::CapabilityStore>,
    /// Budget tracker shared (not copied) with every sub-agent session.
    parent_budget_tracker: Arc<astrid_approval::budget::BudgetTracker>,
    /// Timeout used when a request does not carry its own.
    default_timeout: Duration,
    /// Optional name used in the sub-agent's identity line in its system prompt.
    parent_callsign: Option<String>,
    /// Copied into each sub-agent session's `capsule_context`.
    parent_capsule_context: Option<String>,
}
impl<P, F> SubAgentExecutor<P, F>
where
    P: LlmProvider,
    F: Frontend + 'static,
{
    /// Build an executor that spawns sub-agents for a parent session.
    ///
    /// The allowance, capability, and budget stores are kept as `Arc`s and are
    /// shared with every sub-agent session this executor creates.
    /// `default_timeout` applies to requests that do not specify a timeout.
    // Each argument is a distinct piece of parent context; the allow silences
    // clippy rather than forcing an ad-hoc parameter struct.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        runtime: Arc<AgentRuntime<P>>,
        pool: Arc<SubAgentPool>,
        frontend: Arc<F>,
        parent_user_id: [u8; 8],
        parent_subagent_id: Option<SubAgentId>,
        parent_session_id: SessionId,
        parent_allowance_store: Arc<astrid_approval::AllowanceStore>,
        parent_capabilities: Arc<astrid_capabilities::CapabilityStore>,
        parent_budget_tracker: Arc<astrid_approval::budget::BudgetTracker>,
        default_timeout: Duration,
        parent_callsign: Option<String>,
        parent_capsule_context: Option<String>,
    ) -> Self {
        Self {
            // Execution plumbing.
            runtime,
            pool,
            frontend,
            default_timeout,
            // Parent identity.
            parent_user_id,
            parent_subagent_id,
            parent_session_id,
            parent_callsign,
            parent_capsule_context,
            // Stores shared with child sessions.
            parent_allowance_store,
            parent_capabilities,
            parent_budget_tracker,
        }
    }
}
#[async_trait::async_trait]
impl<P: LlmProvider + 'static, F: Frontend + 'static> SubAgentSpawner for SubAgentExecutor<P, F> {
    /// Run one sub-agent until it completes, fails, times out, or is cancelled.
    ///
    /// Steps:
    /// 1. Acquire a handle from the pool.
    /// 2. Build a fresh [`AgentSession`] that shares the parent's allowance,
    ///    capability, and budget stores.
    /// 3. Run a single sub-agent turn under `request.timeout` (or the executor's
    ///    default), racing the pool's cancellation token.
    /// 4. Release the pool handle.
    ///
    /// # Errors
    ///
    /// Returns `Err` only when `SubAgentPool::spawn` fails. Sub-agent failures
    /// (turn error, timeout, cancellation) come back as
    /// `Ok(SubAgentResult { success: false, .. })`, carrying any partial output.
    /// Audit-log failures are logged with `warn!` and never abort the run.
    // The body is a single linear spawn -> run -> classify -> release -> audit
    // sequence; keeping it together makes the side-effect ordering easy to review.
    #[allow(clippy::too_many_lines)]
    async fn spawn(&self, request: SubAgentRequest) -> Result<SubAgentResult, String> {
        let start = std::time::Instant::now();
        let timeout = request.timeout.unwrap_or(self.default_timeout);
        let handle = self
            .pool
            .spawn(&request.description, self.parent_subagent_id.clone())
            .await
            .map_err(|e| e.to_string())?;
        let handle_id = handle.id.clone();
        info!(
            subagent_id = %handle.id,
            depth = handle.depth,
            description = %request.description,
            "Sub-agent spawned"
        );
        handle.mark_running().await;
        let session_id = SessionId::new();
        // Bound the description before it goes into the system prompt.
        let safe_description = truncate_description(&request.description);
        let identity = if let Some(ref callsign) = self.parent_callsign {
            format!("You are {callsign} (sub-agent).")
        } else {
            "You are a focused sub-agent.".to_string()
        };
        let subagent_system_prompt = format!(
            "{identity} Your task:\n\n{safe_description}\n\n\
             Complete this task and provide a clear, concise result. \
             Do not ask for clarification — work with what you have. \
             When done, provide your final answer as a clear summary.",
        );
        let mut session = AgentSession::with_shared_stores(
            session_id.clone(),
            self.parent_user_id,
            subagent_system_prompt,
            Arc::clone(&self.parent_allowance_store),
            Arc::clone(&self.parent_capabilities),
            Arc::clone(&self.parent_budget_tracker),
        );
        session.capsule_context = self.parent_capsule_context.clone();

        // Record the parent -> child linkage under the parent's session.
        if let Err(e) = self.runtime.audit().append(
            self.parent_session_id.clone(),
            AuditAction::SubAgentSpawned {
                parent_session_id: self.parent_session_id.0.to_string(),
                child_session_id: session_id.0.to_string(),
                description: request.description.clone(),
            },
            AuthorizationProof::System {
                reason: format!("sub-agent spawned for: {}", request.description),
            },
            AuditOutcome::success(),
        ) {
            warn!(error = %e, "Failed to audit sub-agent spawn linkage");
        }

        // Open the child session's own audit trail.
        if let Err(e) = self.runtime.audit().append(
            session_id.clone(),
            AuditAction::SessionStarted {
                user_id: self.parent_user_id,
                frontend: "sub-agent".to_string(),
            },
            AuthorizationProof::System {
                reason: format!("sub-agent for: {}", request.description),
            },
            AuditOutcome::success(),
        ) {
            warn!(error = %e, "Failed to audit sub-agent session start");
        }

        // `biased` polls the cancellation branch first, so a fired token wins
        // over a turn that happens to be ready in the same poll.
        // `None` = cancelled, `Some(Err(_))` = timed out,
        // `Some(Ok(r))` = the turn finished with `r`.
        let cancel_token = self.pool.cancellation_token();
        let loop_result = tokio::select! {
            biased;
            () = cancel_token.cancelled() => None,
            result = tokio::time::timeout(
                timeout,
                self.runtime.run_subagent_turn(
                    &mut session,
                    &request.prompt,
                    Arc::clone(&self.frontend),
                    Some(handle_id.clone()),
                ),
            ) => Some(result),
        };
        let tool_call_count = session.metadata.tool_call_count;
        // Saturate rather than `as`-truncate the u128 millisecond count.
        let duration_ms = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
        let result = match loop_result {
            Some(Ok(Ok(()))) => {
                let output = extract_last_assistant_text(&session.messages);
                debug!(
                    subagent_id = %handle_id,
                    duration_ms,
                    tool_calls = tool_call_count,
                    output_len = output.len(),
                    "Sub-agent completed successfully"
                );
                handle.complete(&output).await;
                SubAgentResult {
                    success: true,
                    output,
                    duration_ms,
                    tool_calls: tool_call_count,
                    error: None,
                }
            },
            Some(Ok(Err(e))) => {
                let error_msg = e.to_string();
                let partial_output = extract_last_assistant_text(&session.messages);
                warn!(
                    subagent_id = %handle_id,
                    error = %error_msg,
                    partial_output_len = partial_output.len(),
                    duration_ms,
                    "Sub-agent failed"
                );
                handle.fail(&error_msg).await;
                SubAgentResult {
                    success: false,
                    output: partial_output,
                    duration_ms,
                    tool_calls: tool_call_count,
                    error: Some(error_msg),
                }
            },
            Some(Err(_elapsed)) => {
                let partial_output = extract_last_assistant_text(&session.messages);
                warn!(
                    subagent_id = %handle_id,
                    timeout_secs = timeout.as_secs(),
                    partial_output_len = partial_output.len(),
                    duration_ms,
                    "Sub-agent timed out"
                );
                handle.timeout().await;
                SubAgentResult {
                    success: false,
                    output: partial_output,
                    duration_ms,
                    tool_calls: tool_call_count,
                    error: Some(format!(
                        "Sub-agent timed out after {} seconds",
                        timeout.as_secs()
                    )),
                }
            },
            None => {
                let partial_output = extract_last_assistant_text(&session.messages);
                warn!(
                    subagent_id = %handle_id,
                    partial_output_len = partial_output.len(),
                    duration_ms,
                    "Sub-agent cancelled via token"
                );
                handle.cancel().await;
                SubAgentResult {
                    success: false,
                    output: partial_output,
                    duration_ms,
                    tool_calls: tool_call_count,
                    error: Some("Sub-agent cancelled".to_string()),
                }
            },
        };
        // Every outcome above falls through to here, so the pool handle is
        // always released once the pool has handed one out.
        self.pool.release(&handle_id).await;

        // Close the child session's audit trail with the outcome as the reason.
        let reason = if result.success {
            "completed".to_string()
        } else {
            result.error.as_deref().unwrap_or("failed").to_string()
        };
        if let Err(e) = self.runtime.audit().append(
            session_id,
            AuditAction::SessionEnded {
                reason,
                duration_secs: duration_ms / 1000,
            },
            AuthorizationProof::System {
                reason: "sub-agent ended".to_string(),
            },
            AuditOutcome::success(),
        ) {
            warn!(error = %e, "Failed to audit sub-agent session end");
        }
        Ok(result)
    }
}

/// Maximum number of bytes of the task description embedded in a sub-agent's
/// system prompt before it is truncated.
const MAX_PROMPT_DESCRIPTION_BYTES: usize = 200;

/// Bound `description` to [`MAX_PROMPT_DESCRIPTION_BYTES`] bytes for the
/// sub-agent system prompt, appending `...` when anything was cut.
///
/// Cutting goes through `truncate_at_char_boundary`, so a multi-byte character
/// straddling the limit is dropped rather than split. A description of exactly
/// the limit passes through unchanged.
fn truncate_description(description: &str) -> String {
    if description.len() > MAX_PROMPT_DESCRIPTION_BYTES {
        format!(
            "{}...",
            truncate_at_char_boundary(description, MAX_PROMPT_DESCRIPTION_BYTES)
        )
    } else {
        description.to_string()
    }
}
/// Return the text of the most recent assistant message whose content is plain
/// text, or `"(sub-agent produced no text output)"` when there is none.
///
/// Assistant messages whose content is not `MessageContent::Text` are skipped
/// instead of ending the search. A typical case is a turn that only requested
/// tool calls; NOTE(review): confirm which non-text variants exist. Skipping
/// them means a sub-agent that times out or is cancelled mid tool-call still
/// surfaces the last thing it actually said as partial output.
fn extract_last_assistant_text(messages: &[Message]) -> String {
    messages
        .iter()
        .rev()
        .filter(|m| m.role == MessageRole::Assistant)
        .find_map(|m| match &m.content {
            MessageContent::Text(text) => Some(text.clone()),
            _ => None,
        })
        .unwrap_or_else(|| "(sub-agent produced no text output)".to_string())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Placeholder returned when no assistant text is available.
    const NO_OUTPUT: &str = "(sub-agent produced no text output)";

    #[test]
    fn test_extract_last_assistant_text() {
        let conversation = vec![
            Message::user("Hello"),
            Message::assistant("First response"),
            Message::user("Another question"),
            Message::assistant("Final answer"),
        ];
        let extracted = extract_last_assistant_text(&conversation);
        assert_eq!(extracted, "Final answer");
    }

    #[test]
    fn test_extract_last_assistant_text_no_assistant_returns_fallback() {
        let conversation = vec![Message::user("Hello")];
        let extracted = extract_last_assistant_text(&conversation);
        assert_eq!(extracted, NO_OUTPUT);
    }

    #[test]
    fn test_extract_last_assistant_text_empty_returns_fallback() {
        let conversation: Vec<Message> = Vec::new();
        let extracted = extract_last_assistant_text(&conversation);
        assert_eq!(extracted, NO_OUTPUT);
    }

    /// Local mirror of the description truncation `spawn` applies before
    /// embedding the task description in the system prompt (200-byte cap,
    /// `...` suffix when cut).
    /// NOTE(review): this duplicates production logic; keep the two in sync.
    fn safe_description(desc: &str) -> String {
        const LIMIT: usize = 200;
        if desc.len() <= LIMIT {
            return desc.to_string();
        }
        format!("{}...", truncate_at_char_boundary(desc, LIMIT))
    }

    #[test]
    fn test_safe_description_multibyte_4byte_emoji() {
        // 198 ASCII bytes + a 4-byte emoji = 202 bytes; the emoji straddles the cap.
        let input = format!("{}🦀", "x".repeat(198));
        assert!(input.len() > 200);
        assert_eq!(safe_description(&input), format!("{}...", "x".repeat(198)));
    }

    #[test]
    fn test_safe_description_multibyte_3byte_char() {
        // 199 ASCII bytes + a 3-byte '€' = 202 bytes.
        let input = format!("{}€", "x".repeat(199));
        assert!(input.len() > 200);
        assert_eq!(safe_description(&input), format!("{}...", "x".repeat(199)));
    }

    #[test]
    fn test_safe_description_multibyte_2byte_char() {
        // 199 ASCII bytes + a 2-byte 'ñ' = 201 bytes.
        let input = format!("{}ñ", "x".repeat(199));
        assert!(input.len() > 200);
        assert_eq!(safe_description(&input), format!("{}...", "x".repeat(199)));
    }

    #[test]
    fn test_safe_description_short_passes_through() {
        let input = "short";
        assert_eq!(safe_description(input), input);
    }

    #[test]
    fn test_safe_description_exact_200_bytes() {
        let at_limit = "x".repeat(200);
        assert_eq!(safe_description(&at_limit), at_limit);
    }
}