1use crate::AgentRuntime;
4use crate::session::AgentSession;
5use crate::subagent::{SubAgentId, SubAgentPool};
6
7use astrid_audit::{AuditAction, AuditOutcome, AuthorizationProof};
8use astrid_core::{Frontend, SessionId};
9use astrid_llm::{LlmProvider, Message, MessageContent, MessageRole};
10use astrid_tools::{SubAgentRequest, SubAgentResult, SubAgentSpawner, truncate_at_char_boundary};
11
12use std::sync::Arc;
13use std::time::Duration;
14use tracing::{debug, info, warn};
15
/// Wall-clock budget for a single sub-agent run (five minutes), used when the
/// incoming request does not specify its own timeout.
pub const DEFAULT_SUBAGENT_TIMEOUT: Duration = Duration::from_secs(5 * 60);
18
19pub struct SubAgentExecutor<P: LlmProvider, F: Frontend + 'static> {
23 runtime: Arc<AgentRuntime<P>>,
25 pool: Arc<SubAgentPool>,
27 frontend: Arc<F>,
29 parent_user_id: [u8; 8],
31 parent_subagent_id: Option<SubAgentId>,
33 parent_session_id: SessionId,
35 parent_allowance_store: Arc<astrid_approval::AllowanceStore>,
37 parent_capabilities: Arc<astrid_capabilities::CapabilityStore>,
39 parent_budget_tracker: Arc<astrid_approval::budget::BudgetTracker>,
41 default_timeout: Duration,
43 parent_callsign: Option<String>,
45 parent_capsule_context: Option<String>,
47}
48
49impl<P: LlmProvider, F: Frontend + 'static> SubAgentExecutor<P, F> {
50 #[allow(clippy::too_many_arguments)]
52 pub fn new(
53 runtime: Arc<AgentRuntime<P>>,
54 pool: Arc<SubAgentPool>,
55 frontend: Arc<F>,
56 parent_user_id: [u8; 8],
57 parent_subagent_id: Option<SubAgentId>,
58 parent_session_id: SessionId,
59 parent_allowance_store: Arc<astrid_approval::AllowanceStore>,
60 parent_capabilities: Arc<astrid_capabilities::CapabilityStore>,
61 parent_budget_tracker: Arc<astrid_approval::budget::BudgetTracker>,
62 default_timeout: Duration,
63 parent_callsign: Option<String>,
64 parent_capsule_context: Option<String>,
65 ) -> Self {
66 Self {
67 runtime,
68 pool,
69 frontend,
70 parent_user_id,
71 parent_subagent_id,
72 parent_session_id,
73 parent_allowance_store,
74 parent_capabilities,
75 parent_budget_tracker,
76 default_timeout,
77 parent_callsign,
78 parent_capsule_context,
79 }
80 }
81}
82
83#[async_trait::async_trait]
84impl<P: LlmProvider + 'static, F: Frontend + 'static> SubAgentSpawner for SubAgentExecutor<P, F> {
85 #[allow(clippy::too_many_lines)]
86 async fn spawn(&self, request: SubAgentRequest) -> Result<SubAgentResult, String> {
87 let start = std::time::Instant::now();
88 let timeout = request.timeout.unwrap_or(self.default_timeout);
89
90 let handle = self
92 .pool
93 .spawn(&request.description, self.parent_subagent_id.clone())
94 .await
95 .map_err(|e| e.to_string())?;
96
97 let handle_id = handle.id.clone();
98
99 info!(
100 subagent_id = %handle.id,
101 depth = handle.depth,
102 description = %request.description,
103 "Sub-agent spawned"
104 );
105
106 handle.mark_running().await;
108
109 let session_id = SessionId::new();
116
117 let safe_description = if request.description.len() > 200 {
120 format!(
121 "{}...",
122 truncate_at_char_boundary(&request.description, 200)
123 )
124 } else {
125 request.description.clone()
126 };
127 let identity = if let Some(ref callsign) = self.parent_callsign {
128 format!("You are {callsign} (sub-agent).")
129 } else {
130 "You are a focused sub-agent.".to_string()
131 };
132 let subagent_system_prompt = format!(
133 "{identity} Your task:\n\n{safe_description}\n\n\
134 Complete this task and provide a clear, concise result. \
135 Do not ask for clarification — work with what you have. \
136 When done, provide your final answer as a clear summary.",
137 );
138
139 let mut session = AgentSession::with_shared_stores(
140 session_id.clone(),
141 self.parent_user_id,
142 subagent_system_prompt,
143 Arc::clone(&self.parent_allowance_store),
144 Arc::clone(&self.parent_capabilities),
145 Arc::clone(&self.parent_budget_tracker),
146 );
147 session.capsule_context = self.parent_capsule_context.clone();
148
149 {
151 if let Err(e) = self.runtime.audit().append(
152 self.parent_session_id.clone(),
153 AuditAction::SubAgentSpawned {
154 parent_session_id: self.parent_session_id.0.to_string(),
155 child_session_id: session_id.0.to_string(),
156 description: request.description.clone(),
157 },
158 AuthorizationProof::System {
159 reason: format!("sub-agent spawned for: {}", request.description),
160 },
161 AuditOutcome::success(),
162 ) {
163 warn!(error = %e, "Failed to audit sub-agent spawn linkage");
164 }
165 }
166
167 {
169 if let Err(e) = self.runtime.audit().append(
170 session_id.clone(),
171 AuditAction::SessionStarted {
172 user_id: self.parent_user_id,
173 frontend: "sub-agent".to_string(),
174 },
175 AuthorizationProof::System {
176 reason: format!("sub-agent for: {}", request.description),
177 },
178 AuditOutcome::success(),
179 ) {
180 warn!(error = %e, "Failed to audit sub-agent session start");
181 }
182 }
183
184 let cancel_token = self.pool.cancellation_token();
191 let loop_result = tokio::select! {
192 biased;
193 () = cancel_token.cancelled() => None,
194 result = tokio::time::timeout(
195 timeout,
196 self.runtime.run_subagent_turn(
197 &mut session,
198 &request.prompt,
199 Arc::clone(&self.frontend),
200 Some(handle_id.clone()),
201 ),
202 ) => Some(result),
203 };
204
205 let tool_call_count = session.metadata.tool_call_count;
207 #[allow(clippy::cast_possible_truncation)]
209 let duration_ms = start.elapsed().as_millis() as u64;
210
211 let result = match loop_result {
212 Some(Ok(Ok(()))) => {
213 let output = extract_last_assistant_text(&session.messages);
215
216 debug!(
217 subagent_id = %handle_id,
218 duration_ms,
219 tool_calls = tool_call_count,
220 output_len = output.len(),
221 "Sub-agent completed successfully"
222 );
223
224 handle.complete(&output).await;
225
226 SubAgentResult {
227 success: true,
228 output,
229 duration_ms,
230 tool_calls: tool_call_count,
231 error: None,
232 }
233 },
234 Some(Ok(Err(e))) => {
235 let error_msg = e.to_string();
236 let partial_output = extract_last_assistant_text(&session.messages);
237 warn!(
238 subagent_id = %handle_id,
239 error = %error_msg,
240 partial_output_len = partial_output.len(),
241 duration_ms,
242 "Sub-agent failed"
243 );
244
245 handle.fail(&error_msg).await;
246
247 SubAgentResult {
248 success: false,
249 output: partial_output,
250 duration_ms,
251 tool_calls: tool_call_count,
252 error: Some(error_msg),
253 }
254 },
255 Some(Err(_elapsed)) => {
256 let partial_output = extract_last_assistant_text(&session.messages);
257 warn!(
258 subagent_id = %handle_id,
259 timeout_secs = timeout.as_secs(),
260 partial_output_len = partial_output.len(),
261 duration_ms,
262 "Sub-agent timed out"
263 );
264
265 handle.timeout().await;
266
267 SubAgentResult {
268 success: false,
269 output: partial_output,
270 duration_ms,
271 tool_calls: tool_call_count,
272 error: Some(format!(
273 "Sub-agent timed out after {} seconds",
274 timeout.as_secs()
275 )),
276 }
277 },
278 None => {
279 let partial_output = extract_last_assistant_text(&session.messages);
281 warn!(
282 subagent_id = %handle_id,
283 partial_output_len = partial_output.len(),
284 duration_ms,
285 "Sub-agent cancelled via token"
286 );
287
288 handle.cancel().await;
289
290 SubAgentResult {
291 success: false,
292 output: partial_output,
293 duration_ms,
294 tool_calls: tool_call_count,
295 error: Some("Sub-agent cancelled".to_string()),
296 }
297 },
298 };
299
300 self.pool.release(&handle_id).await;
302
303 {
305 let reason = if result.success {
306 "completed".to_string()
307 } else {
308 result.error.as_deref().unwrap_or("failed").to_string()
309 };
310 if let Err(e) = self.runtime.audit().append(
311 session_id,
312 AuditAction::SessionEnded {
313 reason,
314 duration_secs: duration_ms / 1000,
315 },
316 AuthorizationProof::System {
317 reason: "sub-agent ended".to_string(),
318 },
319 AuditOutcome::success(),
320 ) {
321 warn!(error = %e, "Failed to audit sub-agent session end");
322 }
323 }
324
325 Ok(result)
326 }
327}
328
329fn extract_last_assistant_text(messages: &[Message]) -> String {
331 messages
332 .iter()
333 .rev()
334 .find(|m| m.role == MessageRole::Assistant)
335 .and_then(|m| match &m.content {
336 MessageContent::Text(text) => Some(text.clone()),
337 _ => None,
338 })
339 .unwrap_or_else(|| "(sub-agent produced no text output)".to_string())
340}
341
#[cfg(test)]
mod tests {
    use super::*;

    /// Placeholder returned by `extract_last_assistant_text` when no assistant
    /// text is available.
    const NO_OUTPUT: &str = "(sub-agent produced no text output)";

    #[test]
    fn test_extract_last_assistant_text() {
        let conversation = [
            Message::user("Hello"),
            Message::assistant("First response"),
            Message::user("Another question"),
            Message::assistant("Final answer"),
        ];
        let extracted = extract_last_assistant_text(&conversation);
        assert_eq!(extracted, "Final answer");
    }

    #[test]
    fn test_extract_last_assistant_text_no_assistant_returns_fallback() {
        let conversation = [Message::user("Hello")];
        assert_eq!(extract_last_assistant_text(&conversation), NO_OUTPUT);
    }

    #[test]
    fn test_extract_last_assistant_text_empty_returns_fallback() {
        let conversation: [Message; 0] = [];
        assert_eq!(extract_last_assistant_text(&conversation), NO_OUTPUT);
    }

    /// Mirrors the description truncation used when building the sub-agent
    /// prompt: inputs longer than 200 bytes are cut back to a char boundary and
    /// suffixed with `...`; shorter inputs pass through untouched.
    fn safe_description(desc: &str) -> String {
        if desc.len() <= 200 {
            return desc.to_owned();
        }
        format!("{}...", truncate_at_char_boundary(desc, 200))
    }

    #[test]
    fn test_safe_description_multibyte_4byte_emoji() {
        // 198 ASCII bytes + a 4-byte char = 202 bytes; the emoji straddles byte 200.
        let desc = format!("{}🦀", "x".repeat(198));
        assert!(desc.len() > 200);

        assert_eq!(safe_description(&desc), format!("{}...", "x".repeat(198)));
    }

    #[test]
    fn test_safe_description_multibyte_3byte_char() {
        // 199 ASCII bytes + a 3-byte char = 202 bytes.
        let desc = format!("{}€", "x".repeat(199));
        assert!(desc.len() > 200);

        assert_eq!(safe_description(&desc), format!("{}...", "x".repeat(199)));
    }

    #[test]
    fn test_safe_description_multibyte_2byte_char() {
        // 199 ASCII bytes + a 2-byte char = 201 bytes.
        let desc = format!("{}ñ", "x".repeat(199));
        assert!(desc.len() > 200);

        assert_eq!(safe_description(&desc), format!("{}...", "x".repeat(199)));
    }

    #[test]
    fn test_safe_description_short_passes_through() {
        let short = "short";
        assert_eq!(safe_description(short), short);
    }

    #[test]
    fn test_safe_description_exact_200_bytes() {
        let boundary = "x".repeat(200);
        assert_eq!(safe_description(&boundary), boundary);
    }
}