// astrid_runtime/subagent_executor.rs

//! Sub-agent executor — implements `SubAgentSpawner` using the runtime's agentic loop.

use crate::AgentRuntime;
use crate::session::AgentSession;
use crate::subagent::{SubAgentId, SubAgentPool};

use astrid_audit::{AuditAction, AuditOutcome, AuthorizationProof};
use astrid_core::{Frontend, SessionId};
use astrid_llm::{LlmProvider, Message, MessageContent, MessageRole};
use astrid_tools::{SubAgentRequest, SubAgentResult, SubAgentSpawner};

use std::sync::Arc;
use std::time::Duration;
use tracing::{debug, info, warn};

/// Default sub-agent timeout (5 minutes).
///
/// Used by the executor whenever a spawn request does not carry its own timeout.
pub const DEFAULT_SUBAGENT_TIMEOUT: Duration = Duration::from_secs(5 * 60);

19/// Executor that spawns sub-agents through the runtime's agentic loop.
20///
21/// Created per-turn and injected into `ToolContext` as `Arc<dyn SubAgentSpawner>`.
22pub struct SubAgentExecutor<P: LlmProvider, F: Frontend + 'static> {
23    /// The runtime (owns LLM, MCP, audit, etc.).
24    runtime: Arc<AgentRuntime<P>>,
25    /// The shared sub-agent pool (enforces concurrency/depth).
26    pool: Arc<SubAgentPool>,
27    /// The frontend for this turn (for approval forwarding).
28    frontend: Arc<F>,
29    /// Parent user ID (inherited by child sessions).
30    parent_user_id: [u8; 8],
31    /// Parent sub-agent ID (if this executor is itself inside a sub-agent).
32    parent_subagent_id: Option<SubAgentId>,
33    /// Parent session ID (for audit linkage).
34    parent_session_id: SessionId,
35    /// Parent's allowance store (shared with child for permission inheritance).
36    parent_allowance_store: Arc<astrid_approval::AllowanceStore>,
37    /// Parent's capability store (shared with child for capability inheritance).
38    parent_capabilities: Arc<astrid_capabilities::CapabilityStore>,
39    /// Parent's budget tracker (shared with child so spend is visible bidirectionally).
40    parent_budget_tracker: Arc<astrid_approval::budget::BudgetTracker>,
41    /// Default timeout for sub-agents.
42    default_timeout: Duration,
43}
44
45impl<P: LlmProvider, F: Frontend + 'static> SubAgentExecutor<P, F> {
46    /// Create a new sub-agent executor.
47    #[allow(clippy::too_many_arguments)]
48    pub fn new(
49        runtime: Arc<AgentRuntime<P>>,
50        pool: Arc<SubAgentPool>,
51        frontend: Arc<F>,
52        parent_user_id: [u8; 8],
53        parent_subagent_id: Option<SubAgentId>,
54        parent_session_id: SessionId,
55        parent_allowance_store: Arc<astrid_approval::AllowanceStore>,
56        parent_capabilities: Arc<astrid_capabilities::CapabilityStore>,
57        parent_budget_tracker: Arc<astrid_approval::budget::BudgetTracker>,
58        default_timeout: Duration,
59    ) -> Self {
60        Self {
61            runtime,
62            pool,
63            frontend,
64            parent_user_id,
65            parent_subagent_id,
66            parent_session_id,
67            parent_allowance_store,
68            parent_capabilities,
69            parent_budget_tracker,
70            default_timeout,
71        }
72    }
73}
74
75#[async_trait::async_trait]
76impl<P: LlmProvider + 'static, F: Frontend + 'static> SubAgentSpawner for SubAgentExecutor<P, F> {
77    #[allow(clippy::too_many_lines)]
78    async fn spawn(&self, request: SubAgentRequest) -> Result<SubAgentResult, String> {
79        let start = std::time::Instant::now();
80        let timeout = request.timeout.unwrap_or(self.default_timeout);
81
82        // 1. Acquire a slot in the pool (enforces concurrency + depth)
83        let handle = self
84            .pool
85            .spawn(&request.description, self.parent_subagent_id.clone())
86            .await
87            .map_err(|e| e.to_string())?;
88
89        let handle_id = handle.id.clone();
90
91        info!(
92            subagent_id = %handle.id,
93            depth = handle.depth,
94            description = %request.description,
95            "Sub-agent spawned"
96        );
97
98        // 2. Mark as running
99        handle.mark_running().await;
100
101        // 3. Create a child session with shared stores from parent
102        //
103        // Sub-agents inherit the parent's AllowanceStore, CapabilityStore, and BudgetTracker
104        // so that project-level permissions and budget are shared. The ApprovalManager and
105        // DeferredResolutionStore are fresh per-child (independent approval handler registration
106        // and independent deferred queue).
107        let session_id = SessionId::new();
108
109        // Truncate description in the system prompt to limit prompt injection surface.
110        // The full description is still logged/audited separately.
111        let safe_description = if request.description.len() > 200 {
112            format!("{}...", &request.description[..200])
113        } else {
114            request.description.clone()
115        };
116        let subagent_system_prompt = format!(
117            "You are a focused sub-agent. Your task:\n\n{safe_description}\n\n\
118             Complete this task and provide a clear, concise result. \
119             Do not ask for clarification — work with what you have. \
120             When done, provide your final answer as a clear summary.",
121        );
122
123        let mut session = AgentSession::with_shared_stores(
124            session_id.clone(),
125            self.parent_user_id,
126            subagent_system_prompt,
127            Arc::clone(&self.parent_allowance_store),
128            Arc::clone(&self.parent_capabilities),
129            Arc::clone(&self.parent_budget_tracker),
130        );
131
132        // 4. Audit: sub-agent spawned (parent→child linkage)
133        {
134            if let Err(e) = self.runtime.audit().append(
135                self.parent_session_id.clone(),
136                AuditAction::SubAgentSpawned {
137                    parent_session_id: self.parent_session_id.0.to_string(),
138                    child_session_id: session_id.0.to_string(),
139                    description: request.description.clone(),
140                },
141                AuthorizationProof::System {
142                    reason: format!("sub-agent spawned for: {}", request.description),
143                },
144                AuditOutcome::success(),
145            ) {
146                warn!(error = %e, "Failed to audit sub-agent spawn linkage");
147            }
148        }
149
150        // 5. Audit: session started
151        {
152            if let Err(e) = self.runtime.audit().append(
153                session_id.clone(),
154                AuditAction::SessionStarted {
155                    user_id: self.parent_user_id,
156                    frontend: "sub-agent".to_string(),
157                },
158                AuthorizationProof::System {
159                    reason: format!("sub-agent for: {}", request.description),
160                },
161                AuditOutcome::success(),
162            ) {
163                warn!(error = %e, "Failed to audit sub-agent session start");
164            }
165        }
166
167        // 6. Run the agentic loop with timeout + cooperative cancellation
168        //
169        // `None` = cancelled via token (treated same as timeout — extract partial output).
170        // `Some(Ok(Ok(())))` = completed successfully.
171        // `Some(Ok(Err(e)))` = runtime error.
172        // `Some(Err(_))` = timed out.
173        let cancel_token = self.pool.cancellation_token();
174        let loop_result = tokio::select! {
175            biased;
176            () = cancel_token.cancelled() => None,
177            result = tokio::time::timeout(
178                timeout,
179                self.runtime.run_subagent_turn(
180                    &mut session,
181                    &request.prompt,
182                    Arc::clone(&self.frontend),
183                    Some(handle_id.clone()),
184                ),
185            ) => Some(result),
186        };
187
188        // 7. Process result
189        let tool_call_count = session.metadata.tool_call_count;
190        // Sub-agent timeout is at most 5 minutes, so millis always fits in u64.
191        #[allow(clippy::cast_possible_truncation)]
192        let duration_ms = start.elapsed().as_millis() as u64;
193
194        let result = match loop_result {
195            Some(Ok(Ok(()))) => {
196                // Extract last assistant message as the output
197                let output = extract_last_assistant_text(&session.messages);
198
199                debug!(
200                    subagent_id = %handle_id,
201                    duration_ms,
202                    tool_calls = tool_call_count,
203                    output_len = output.len(),
204                    "Sub-agent completed successfully"
205                );
206
207                handle.complete(&output).await;
208
209                SubAgentResult {
210                    success: true,
211                    output,
212                    duration_ms,
213                    tool_calls: tool_call_count,
214                    error: None,
215                }
216            },
217            Some(Ok(Err(e))) => {
218                let error_msg = e.to_string();
219                let partial_output = extract_last_assistant_text(&session.messages);
220                warn!(
221                    subagent_id = %handle_id,
222                    error = %error_msg,
223                    partial_output_len = partial_output.len(),
224                    duration_ms,
225                    "Sub-agent failed"
226                );
227
228                handle.fail(&error_msg).await;
229
230                SubAgentResult {
231                    success: false,
232                    output: partial_output,
233                    duration_ms,
234                    tool_calls: tool_call_count,
235                    error: Some(error_msg),
236                }
237            },
238            Some(Err(_elapsed)) => {
239                let partial_output = extract_last_assistant_text(&session.messages);
240                warn!(
241                    subagent_id = %handle_id,
242                    timeout_secs = timeout.as_secs(),
243                    partial_output_len = partial_output.len(),
244                    duration_ms,
245                    "Sub-agent timed out"
246                );
247
248                handle.timeout().await;
249
250                SubAgentResult {
251                    success: false,
252                    output: partial_output,
253                    duration_ms,
254                    tool_calls: tool_call_count,
255                    error: Some(format!(
256                        "Sub-agent timed out after {} seconds",
257                        timeout.as_secs()
258                    )),
259                }
260            },
261            None => {
262                // Cooperative cancellation via CancellationToken
263                let partial_output = extract_last_assistant_text(&session.messages);
264                warn!(
265                    subagent_id = %handle_id,
266                    partial_output_len = partial_output.len(),
267                    duration_ms,
268                    "Sub-agent cancelled via token"
269                );
270
271                handle.cancel().await;
272
273                SubAgentResult {
274                    success: false,
275                    output: partial_output,
276                    duration_ms,
277                    tool_calls: tool_call_count,
278                    error: Some("Sub-agent cancelled".to_string()),
279                }
280            },
281        };
282
283        // 8. Release from pool (releases semaphore permit)
284        self.pool.release(&handle_id).await;
285
286        // 9. Audit: session ended
287        {
288            let reason = if result.success {
289                "completed".to_string()
290            } else {
291                result.error.as_deref().unwrap_or("failed").to_string()
292            };
293            if let Err(e) = self.runtime.audit().append(
294                session_id,
295                AuditAction::SessionEnded {
296                    reason,
297                    duration_secs: duration_ms / 1000,
298                },
299                AuthorizationProof::System {
300                    reason: "sub-agent ended".to_string(),
301                },
302                AuditOutcome::success(),
303            ) {
304                warn!(error = %e, "Failed to audit sub-agent session end");
305            }
306        }
307
308        Ok(result)
309    }
310}
311
312/// Extract the last assistant text message from the session messages.
313fn extract_last_assistant_text(messages: &[Message]) -> String {
314    messages
315        .iter()
316        .rev()
317        .find(|m| m.role == MessageRole::Assistant)
318        .and_then(|m| match &m.content {
319            MessageContent::Text(text) => Some(text.clone()),
320            _ => None,
321        })
322        .unwrap_or_else(|| "(sub-agent produced no text output)".to_string())
323}
324
#[cfg(test)]
mod tests {
    use super::*;

    /// The most recent assistant message wins over earlier ones.
    #[test]
    fn test_extract_last_assistant_text() {
        let messages = vec![
            Message::user("Hello"),
            Message::assistant("First response"),
            Message::user("Another question"),
            Message::assistant("Final answer"),
        ];
        assert_eq!(extract_last_assistant_text(&messages), "Final answer");
    }

    /// With only user messages present, the placeholder text is returned.
    #[test]
    fn test_extract_last_assistant_text_no_assistant_returns_fallback() {
        let messages = vec![Message::user("Hello")];
        assert_eq!(
            extract_last_assistant_text(&messages),
            "(sub-agent produced no text output)"
        );
    }

    /// An empty message list also yields the placeholder text.
    #[test]
    fn test_extract_last_assistant_text_empty_returns_fallback() {
        let messages: Vec<Message> = vec![];
        assert_eq!(
            extract_last_assistant_text(&messages),
            "(sub-agent produced no text output)"
        );
    }
}