Skip to main content

astrid_runtime/
subagent_executor.rs

1//! Sub-agent executor — implements `SubAgentSpawner` using the runtime's agentic loop.
2
3use crate::AgentRuntime;
4use crate::session::AgentSession;
5use crate::subagent::{SubAgentId, SubAgentPool};
6
7use astrid_audit::{AuditAction, AuditOutcome, AuthorizationProof};
8use astrid_core::{Frontend, SessionId};
9use astrid_llm::{LlmProvider, Message, MessageContent, MessageRole};
10use astrid_tools::{SubAgentRequest, SubAgentResult, SubAgentSpawner, truncate_at_char_boundary};
11
12use std::sync::Arc;
13use std::time::Duration;
14use tracing::{debug, info, warn};
15
/// Timeout applied to a sub-agent when its request does not specify one: 5 minutes.
pub const DEFAULT_SUBAGENT_TIMEOUT: Duration = Duration::from_secs(5 * 60);
18
19/// Executor that spawns sub-agents through the runtime's agentic loop.
20///
21/// Created per-turn and injected into `ToolContext` as `Arc<dyn SubAgentSpawner>`.
22pub struct SubAgentExecutor<P: LlmProvider, F: Frontend + 'static> {
23    /// The runtime (owns LLM, MCP, audit, etc.).
24    runtime: Arc<AgentRuntime<P>>,
25    /// The shared sub-agent pool (enforces concurrency/depth).
26    pool: Arc<SubAgentPool>,
27    /// The frontend for this turn (for approval forwarding).
28    frontend: Arc<F>,
29    /// Parent user ID (inherited by child sessions).
30    parent_user_id: [u8; 8],
31    /// Parent sub-agent ID (if this executor is itself inside a sub-agent).
32    parent_subagent_id: Option<SubAgentId>,
33    /// Parent session ID (for audit linkage).
34    parent_session_id: SessionId,
35    /// Parent's allowance store (shared with child for permission inheritance).
36    parent_allowance_store: Arc<astrid_approval::AllowanceStore>,
37    /// Parent's capability store (shared with child for capability inheritance).
38    parent_capabilities: Arc<astrid_capabilities::CapabilityStore>,
39    /// Parent's budget tracker (shared with child so spend is visible bidirectionally).
40    parent_budget_tracker: Arc<astrid_approval::budget::BudgetTracker>,
41    /// Default timeout for sub-agents.
42    default_timeout: Duration,
43    /// Parent agent's callsign (inherited for sub-agent identity).
44    parent_callsign: Option<String>,
45    /// Parent plugin context (inherited by child sessions for security rules).
46    parent_capsule_context: Option<String>,
47}
48
49impl<P: LlmProvider, F: Frontend + 'static> SubAgentExecutor<P, F> {
50    /// Create a new sub-agent executor.
51    #[allow(clippy::too_many_arguments)]
52    pub fn new(
53        runtime: Arc<AgentRuntime<P>>,
54        pool: Arc<SubAgentPool>,
55        frontend: Arc<F>,
56        parent_user_id: [u8; 8],
57        parent_subagent_id: Option<SubAgentId>,
58        parent_session_id: SessionId,
59        parent_allowance_store: Arc<astrid_approval::AllowanceStore>,
60        parent_capabilities: Arc<astrid_capabilities::CapabilityStore>,
61        parent_budget_tracker: Arc<astrid_approval::budget::BudgetTracker>,
62        default_timeout: Duration,
63        parent_callsign: Option<String>,
64        parent_capsule_context: Option<String>,
65    ) -> Self {
66        Self {
67            runtime,
68            pool,
69            frontend,
70            parent_user_id,
71            parent_subagent_id,
72            parent_session_id,
73            parent_allowance_store,
74            parent_capabilities,
75            parent_budget_tracker,
76            default_timeout,
77            parent_callsign,
78            parent_capsule_context,
79        }
80    }
81}
82
83#[async_trait::async_trait]
84impl<P: LlmProvider + 'static, F: Frontend + 'static> SubAgentSpawner for SubAgentExecutor<P, F> {
85    #[allow(clippy::too_many_lines)]
86    async fn spawn(&self, request: SubAgentRequest) -> Result<SubAgentResult, String> {
87        let start = std::time::Instant::now();
88        let timeout = request.timeout.unwrap_or(self.default_timeout);
89
90        // 1. Acquire a slot in the pool (enforces concurrency + depth)
91        let handle = self
92            .pool
93            .spawn(&request.description, self.parent_subagent_id.clone())
94            .await
95            .map_err(|e| e.to_string())?;
96
97        let handle_id = handle.id.clone();
98
99        info!(
100            subagent_id = %handle.id,
101            depth = handle.depth,
102            description = %request.description,
103            "Sub-agent spawned"
104        );
105
106        // 2. Mark as running
107        handle.mark_running().await;
108
109        // 3. Create a child session with shared stores from parent
110        //
111        // Sub-agents inherit the parent's AllowanceStore, CapabilityStore, and BudgetTracker
112        // so that project-level permissions and budget are shared. The ApprovalManager and
113        // DeferredResolutionStore are fresh per-child (independent approval handler registration
114        // and independent deferred queue).
115        let session_id = SessionId::new();
116
117        // Truncate description in the system prompt to limit prompt injection surface.
118        // The full description is still logged/audited separately.
119        let safe_description = if request.description.len() > 200 {
120            format!(
121                "{}...",
122                truncate_at_char_boundary(&request.description, 200)
123            )
124        } else {
125            request.description.clone()
126        };
127        let identity = if let Some(ref callsign) = self.parent_callsign {
128            format!("You are {callsign} (sub-agent).")
129        } else {
130            "You are a focused sub-agent.".to_string()
131        };
132        let subagent_system_prompt = format!(
133            "{identity} Your task:\n\n{safe_description}\n\n\
134             Complete this task and provide a clear, concise result. \
135             Do not ask for clarification — work with what you have. \
136             When done, provide your final answer as a clear summary.",
137        );
138
139        let mut session = AgentSession::with_shared_stores(
140            session_id.clone(),
141            self.parent_user_id,
142            subagent_system_prompt,
143            Arc::clone(&self.parent_allowance_store),
144            Arc::clone(&self.parent_capabilities),
145            Arc::clone(&self.parent_budget_tracker),
146        );
147        session.capsule_context = self.parent_capsule_context.clone();
148
149        // 4. Audit: sub-agent spawned (parent→child linkage)
150        {
151            if let Err(e) = self.runtime.audit().append(
152                self.parent_session_id.clone(),
153                AuditAction::SubAgentSpawned {
154                    parent_session_id: self.parent_session_id.0.to_string(),
155                    child_session_id: session_id.0.to_string(),
156                    description: request.description.clone(),
157                },
158                AuthorizationProof::System {
159                    reason: format!("sub-agent spawned for: {}", request.description),
160                },
161                AuditOutcome::success(),
162            ) {
163                warn!(error = %e, "Failed to audit sub-agent spawn linkage");
164            }
165        }
166
167        // 5. Audit: session started
168        {
169            if let Err(e) = self.runtime.audit().append(
170                session_id.clone(),
171                AuditAction::SessionStarted {
172                    user_id: self.parent_user_id,
173                    frontend: "sub-agent".to_string(),
174                },
175                AuthorizationProof::System {
176                    reason: format!("sub-agent for: {}", request.description),
177                },
178                AuditOutcome::success(),
179            ) {
180                warn!(error = %e, "Failed to audit sub-agent session start");
181            }
182        }
183
184        // 6. Run the agentic loop with timeout + cooperative cancellation
185        //
186        // `None` = cancelled via token (treated same as timeout — extract partial output).
187        // `Some(Ok(Ok(())))` = completed successfully.
188        // `Some(Ok(Err(e)))` = runtime error.
189        // `Some(Err(_))` = timed out.
190        let cancel_token = self.pool.cancellation_token();
191        let loop_result = tokio::select! {
192            biased;
193            () = cancel_token.cancelled() => None,
194            result = tokio::time::timeout(
195                timeout,
196                self.runtime.run_subagent_turn(
197                    &mut session,
198                    &request.prompt,
199                    Arc::clone(&self.frontend),
200                    Some(handle_id.clone()),
201                ),
202            ) => Some(result),
203        };
204
205        // 7. Process result
206        let tool_call_count = session.metadata.tool_call_count;
207        // Sub-agent timeout is at most 5 minutes, so millis always fits in u64.
208        #[allow(clippy::cast_possible_truncation)]
209        let duration_ms = start.elapsed().as_millis() as u64;
210
211        let result = match loop_result {
212            Some(Ok(Ok(()))) => {
213                // Extract last assistant message as the output
214                let output = extract_last_assistant_text(&session.messages);
215
216                debug!(
217                    subagent_id = %handle_id,
218                    duration_ms,
219                    tool_calls = tool_call_count,
220                    output_len = output.len(),
221                    "Sub-agent completed successfully"
222                );
223
224                handle.complete(&output).await;
225
226                SubAgentResult {
227                    success: true,
228                    output,
229                    duration_ms,
230                    tool_calls: tool_call_count,
231                    error: None,
232                }
233            },
234            Some(Ok(Err(e))) => {
235                let error_msg = e.to_string();
236                let partial_output = extract_last_assistant_text(&session.messages);
237                warn!(
238                    subagent_id = %handle_id,
239                    error = %error_msg,
240                    partial_output_len = partial_output.len(),
241                    duration_ms,
242                    "Sub-agent failed"
243                );
244
245                handle.fail(&error_msg).await;
246
247                SubAgentResult {
248                    success: false,
249                    output: partial_output,
250                    duration_ms,
251                    tool_calls: tool_call_count,
252                    error: Some(error_msg),
253                }
254            },
255            Some(Err(_elapsed)) => {
256                let partial_output = extract_last_assistant_text(&session.messages);
257                warn!(
258                    subagent_id = %handle_id,
259                    timeout_secs = timeout.as_secs(),
260                    partial_output_len = partial_output.len(),
261                    duration_ms,
262                    "Sub-agent timed out"
263                );
264
265                handle.timeout().await;
266
267                SubAgentResult {
268                    success: false,
269                    output: partial_output,
270                    duration_ms,
271                    tool_calls: tool_call_count,
272                    error: Some(format!(
273                        "Sub-agent timed out after {} seconds",
274                        timeout.as_secs()
275                    )),
276                }
277            },
278            None => {
279                // Cooperative cancellation via CancellationToken
280                let partial_output = extract_last_assistant_text(&session.messages);
281                warn!(
282                    subagent_id = %handle_id,
283                    partial_output_len = partial_output.len(),
284                    duration_ms,
285                    "Sub-agent cancelled via token"
286                );
287
288                handle.cancel().await;
289
290                SubAgentResult {
291                    success: false,
292                    output: partial_output,
293                    duration_ms,
294                    tool_calls: tool_call_count,
295                    error: Some("Sub-agent cancelled".to_string()),
296                }
297            },
298        };
299
300        // 8. Release from pool (releases semaphore permit)
301        self.pool.release(&handle_id).await;
302
303        // 9. Audit: session ended
304        {
305            let reason = if result.success {
306                "completed".to_string()
307            } else {
308                result.error.as_deref().unwrap_or("failed").to_string()
309            };
310            if let Err(e) = self.runtime.audit().append(
311                session_id,
312                AuditAction::SessionEnded {
313                    reason,
314                    duration_secs: duration_ms / 1000,
315                },
316                AuthorizationProof::System {
317                    reason: "sub-agent ended".to_string(),
318                },
319                AuditOutcome::success(),
320            ) {
321                warn!(error = %e, "Failed to audit sub-agent session end");
322            }
323        }
324
325        Ok(result)
326    }
327}
328
329/// Extract the last assistant text message from the session messages.
330fn extract_last_assistant_text(messages: &[Message]) -> String {
331    messages
332        .iter()
333        .rev()
334        .find(|m| m.role == MessageRole::Assistant)
335        .and_then(|m| match &m.content {
336            MessageContent::Text(text) => Some(text.clone()),
337            _ => None,
338        })
339        .unwrap_or_else(|| "(sub-agent produced no text output)".to_string())
340}
341
#[cfg(test)]
mod tests {
    use super::*;

    /// Fallback returned by `extract_last_assistant_text` when there is no assistant text.
    const NO_OUTPUT: &str = "(sub-agent produced no text output)";

    #[test]
    fn test_extract_last_assistant_text() {
        let conversation = [
            Message::user("Hello"),
            Message::assistant("First response"),
            Message::user("Another question"),
            Message::assistant("Final answer"),
        ];
        let extracted = extract_last_assistant_text(&conversation);
        assert_eq!(extracted, "Final answer");
    }

    #[test]
    fn test_extract_last_assistant_text_no_assistant_returns_fallback() {
        let only_user = [Message::user("Hello")];
        assert_eq!(extract_last_assistant_text(&only_user), NO_OUTPUT);
    }

    #[test]
    fn test_extract_last_assistant_text_empty_returns_fallback() {
        let nothing: Vec<Message> = Vec::new();
        assert_eq!(extract_last_assistant_text(&nothing), NO_OUTPUT);
    }

    /// Reproduces the prompt-description truncation done in `spawn()`: inputs of
    /// at most 200 bytes pass through, longer ones are cut on a char boundary and
    /// get `...` appended.
    fn safe_description(desc: &str) -> String {
        if desc.len() <= 200 {
            return desc.to_string();
        }
        let head = truncate_at_char_boundary(desc, 200);
        format!("{head}...")
    }

    /// Regression test for issue #74: must not panic when a 4-byte emoji
    /// straddles the 200-byte boundary.
    #[test]
    fn test_safe_description_multibyte_4byte_emoji() {
        // 198 ASCII bytes + a 4-byte crab = 202 bytes.
        let desc = format!("{}🦀", "x".repeat(198));
        assert!(desc.len() > 200);

        let expected = "x".repeat(198) + "...";
        assert_eq!(safe_description(&desc), expected);
    }

    /// Must not panic when a 3-byte char sits at the boundary.
    #[test]
    fn test_safe_description_multibyte_3byte_char() {
        // 199 ASCII bytes + a 3-byte euro sign = 202 bytes.
        let desc = format!("{}€", "x".repeat(199));
        assert!(desc.len() > 200);

        let expected = "x".repeat(199) + "...";
        assert_eq!(safe_description(&desc), expected);
    }

    /// Must not panic when a 2-byte char sits at the boundary.
    #[test]
    fn test_safe_description_multibyte_2byte_char() {
        // 199 ASCII bytes + a 2-byte n-tilde = 201 bytes.
        let desc = format!("{}ñ", "x".repeat(199));
        assert!(desc.len() > 200);

        let expected = "x".repeat(199) + "...";
        assert_eq!(safe_description(&desc), expected);
    }

    #[test]
    fn test_safe_description_short_passes_through() {
        let short = "short";
        assert_eq!(safe_description(short), short);
    }

    #[test]
    fn test_safe_description_exact_200_bytes() {
        let at_limit = "x".repeat(200);
        assert_eq!(safe_description(&at_limit), at_limit);
    }
}
424    fn test_safe_description_exact_200_bytes() {
425        let desc = "x".repeat(200);
426        assert_eq!(safe_description(&desc), desc);
427    }
428}