Skip to main content

astrid_runtime/runtime/
execution.rs

1//! Agent loop: `run_turn_streaming`, `run_subagent_turn`, and `run_loop`.
2
3use astrid_approval::manager::ApprovalHandler;
4use astrid_audit::{AuditAction, AuditOutcome, AuthorizationProof};
5use astrid_core::Frontend;
6use astrid_hooks::{HookEvent, HookResult};
7use astrid_llm::{LlmProvider, LlmToolDefinition, Message, StreamEvent, ToolCall};
8use astrid_tools::ToolContext;
9use futures::StreamExt;
10use std::sync::Arc;
11use tracing::{debug, error};
12
13use crate::error::{RuntimeError, RuntimeResult};
14use crate::session::AgentSession;
15use crate::subagent::SubAgentId;
16
17use super::security::FrontendApprovalHandler;
18use super::{AgentRuntime, tokens_to_usd};
19
20impl<P: LlmProvider + 'static> AgentRuntime<P> {
21    /// Run a single turn with streaming output.
22    ///
23    /// The `frontend` parameter is wrapped in `Arc` so it can be registered as an
24    /// approval handler for the duration of the turn.
25    ///
26    /// # Errors
27    ///
28    /// Returns an error if:
29    /// - The LLM provider fails to generate a response
30    /// - An MCP tool call fails
31    /// - An approval request fails
32    /// - Session persistence fails
33    #[allow(clippy::too_many_lines)]
34    pub async fn run_turn_streaming<F: Frontend + 'static>(
35        &self,
36        session: &mut AgentSession,
37        input: &str,
38        frontend: Arc<F>,
39    ) -> RuntimeResult<()> {
40        // Register the frontend as the approval handler for this turn.
41        let handler: Arc<dyn ApprovalHandler> = Arc::new(FrontendApprovalHandler {
42            frontend: Arc::clone(&frontend),
43        });
44        session.approval_manager.register_handler(handler).await;
45
46        // Add user message
47        session.add_message(Message::user(input));
48
49        // Fire UserPrompt hook
50        {
51            let ctx = self
52                .build_hook_context(session, HookEvent::UserPrompt)
53                .with_data("input", serde_json::json!(input));
54            let result = self.hooks.trigger_simple(HookEvent::UserPrompt, ctx).await;
55            if let HookResult::Block { reason } = result {
56                return Err(RuntimeError::ApprovalDenied { reason });
57            }
58            if let HookResult::ContinueWith { modifications } = &result {
59                debug!(?modifications, "UserPrompt hook modified context");
60            }
61        }
62
63        // Log session activity
64        {
65            let _ = self.audit.append(
66                session.id.clone(),
67                AuditAction::LlmRequest {
68                    model: self.llm.model().to_string(),
69                    input_tokens: session.token_count,
70                    output_tokens: 0,
71                },
72                AuthorizationProof::System {
73                    reason: "user input".to_string(),
74                },
75                AuditOutcome::success(),
76            );
77        }
78
79        // Check context limit and summarize if needed
80        if self.config.auto_summarize && self.context.needs_summarization(session) {
81            frontend.show_status("Summarizing context...");
82            let result = self.context.summarize(session, self.llm.as_ref()).await?;
83
84            // Log summarization
85            {
86                let _ = self.audit.append(
87                    session.id.clone(),
88                    AuditAction::ContextSummarized {
89                        evicted_count: result.messages_evicted,
90                        tokens_freed: result.tokens_freed,
91                    },
92                    AuthorizationProof::System {
93                        reason: "context overflow".to_string(),
94                    },
95                    AuditOutcome::success(),
96                );
97            }
98        }
99
100        // Collect agent context from plugins if not already collected this turn.
101        // It is held in session.capsule_context and dynamically injected into the prompt.
102        #[allow(clippy::collapsible_if)]
103        if session.capsule_context.is_none() {
104            if let Some(ref registry_lock) = self.capsule_registry {
105                let mut combined_context = String::new();
106                let active_plugins: Vec<astrid_capsule::capsule::CapsuleId> = {
107                    let registry = registry_lock.read().await;
108                    registry.list().into_iter().cloned().collect()
109                };
110
111                for capsule_id in active_plugins {
112                    // Discover if it exposes the context tool
113                    let (tool_arc, _tool_config) = {
114                        let registry = registry_lock.read().await;
115                        let tool_name = format!("capsule:{capsule_id}:__astrid_get_agent_context");
116                        match registry.find_tool(&tool_name) {
117                            Some((plugin, t)) => {
118                                let config = plugin
119                                    .manifest()
120                                    .env
121                                    .iter()
122                                    .filter_map(|(k, v)| v.default.clone().map(|d| (k.clone(), d)))
123                                    .collect();
124                                (Some(t), config)
125                            },
126                            None => (None, std::collections::HashMap::new()),
127                        }
128                    };
129
130                    // Execute the tool if present with a 5-second timeout
131                    if let Some(tool) = tool_arc {
132                        let plugin_kv =
133                            {
134                                let kv_key = format!("{}:capsule:{capsule_id}", session.id);
135                                let mut stores = self
136                                    .capsule_kv_stores
137                                    .lock()
138                                    .unwrap_or_else(std::sync::PoisonError::into_inner);
139                                Arc::clone(stores.entry(kv_key).or_insert_with(|| {
140                                    Arc::new(astrid_storage::MemoryKvStore::new())
141                                }))
142                            };
143
144                        let scoped_name = format!("capsule-tool:capsule:{capsule_id}");
145                        if let Ok(scoped_kv) =
146                            astrid_storage::ScopedKvStore::new(plugin_kv, scoped_name)
147                        {
148                            let user_uuid = Self::user_uuid(session.user_id);
149                            let tool_ctx = astrid_capsule::context::CapsuleToolContext::new(
150                                capsule_id.clone(),
151                                self.config.workspace.root.clone(),
152                                scoped_kv,
153                            )
154                            // .with_config(tool_config) // Context tools do not take config directly in capsule implementation
155                            .with_session(session.id.clone())
156                            .with_user(user_uuid);
157
158                            let execute_future = tool.execute(
159                                serde_json::Value::Object(serde_json::Map::default()),
160                                &tool_ctx,
161                            );
162                            if let Ok(Ok(ctx_result)) = tokio::time::timeout(
163                                std::time::Duration::from_secs(5),
164                                execute_future,
165                            )
166                            .await
167                            {
168                                let trimmed = ctx_result.trim();
169                                if !trimmed.is_empty() {
170                                    combined_context.push_str(trimmed);
171                                    combined_context.push_str("\n\n");
172                                }
173                            } else {
174                                tracing::warn!(%capsule_id, "Context tool execution timed out or failed");
175                            }
176                        }
177                    }
178                }
179
180                if combined_context.is_empty() {
181                    session.capsule_context = Some(String::new()); // Mark as collected but empty
182                } else {
183                    session.capsule_context = Some(combined_context);
184                }
185            }
186        }
187
188        // Create per-turn ToolContext (shares cwd, owns its own spawner slot)
189        let tool_ctx = ToolContext::with_shared_cwd(
190            self.config.workspace.root.clone(),
191            Arc::clone(&self.shared_cwd),
192            self.config.spark_file.clone(),
193        );
194
195        // Inject sub-agent spawner (if self_arc is available)
196        self.inject_subagent_spawner(&tool_ctx, session, &frontend, None)
197            .await;
198
199        // Run the agentic loop (tool_ctx is dropped at turn end — no cleanup needed)
200        let loop_result = self.run_loop(session, &*frontend, &tool_ctx).await;
201
202        let save_result = self.sessions.save(session);
203
204        loop_result?;
205        save_result?;
206
207        Ok(())
208    }
209
210    /// Run a single turn for a sub-agent session.
211    ///
212    /// Like `run_turn_streaming` but without hooks, summarization, or session persistence.
213    /// The session is ephemeral and owned by the caller (`SubAgentExecutor`).
214    ///
215    /// `parent_subagent_id` is the pool handle ID of this sub-agent, passed so that
216    /// nested sub-agents (if the sub-agent calls `task` tool) can declare their parent.
217    ///
218    /// # Errors
219    ///
220    /// Returns an error if the LLM or tool execution fails.
221    pub async fn run_subagent_turn<F: Frontend + 'static>(
222        &self,
223        session: &mut AgentSession,
224        prompt: &str,
225        frontend: Arc<F>,
226        parent_subagent_id: Option<SubAgentId>,
227    ) -> RuntimeResult<()> {
228        // Register the frontend as the approval handler for this turn.
229        let handler: Arc<dyn ApprovalHandler> = Arc::new(FrontendApprovalHandler {
230            frontend: Arc::clone(&frontend),
231        });
232        session.approval_manager.register_handler(handler).await;
233
234        // Add user message
235        session.add_message(Message::user(prompt));
236
237        // Log sub-agent LLM request
238        {
239            let _ = self.audit.append(
240                session.id.clone(),
241                AuditAction::LlmRequest {
242                    model: self.llm.model().to_string(),
243                    input_tokens: session.token_count,
244                    output_tokens: 0,
245                },
246                AuthorizationProof::System {
247                    reason: "sub-agent prompt".to_string(),
248                },
249                AuditOutcome::success(),
250            );
251        }
252
253        // Create per-turn ToolContext (shares cwd, owns its own spawner slot)
254        let tool_ctx = ToolContext::with_shared_cwd(
255            self.config.workspace.root.clone(),
256            Arc::clone(&self.shared_cwd),
257            self.config.spark_file.clone(),
258        );
259
260        // Inject sub-agent spawner for nested sub-agents
261        self.inject_subagent_spawner(&tool_ctx, session, &frontend, parent_subagent_id)
262            .await;
263
264        // Run the agentic loop (no hooks, no summarize, no save)
265        // tool_ctx is dropped at turn end — no cleanup needed
266        self.run_loop(session, &*frontend, &tool_ctx).await
267    }
268
269    /// The inner agentic loop: stream LLM → collect tool calls → execute → repeat.
270    ///
271    /// Shared by `run_turn_streaming` and `run_subagent_turn`.
272    #[allow(clippy::too_many_lines)]
273    pub(super) async fn run_loop<F: Frontend>(
274        &self,
275        session: &mut AgentSession,
276        frontend: &F,
277        tool_ctx: &ToolContext,
278    ) -> RuntimeResult<()> {
279        loop {
280            // Get tools: built-in + MCP
281            let mut llm_tools: Vec<LlmToolDefinition> = self.tool_registry.all_definitions();
282
283            let mcp_tools = self.mcp.list_tools().await?;
284            llm_tools.extend(mcp_tools.iter().map(|t| {
285                LlmToolDefinition::new(format!("{}:{}", &t.server, &t.name))
286                    .with_description(t.description.clone().unwrap_or_default())
287                    .with_schema(t.input_schema.clone())
288            }));
289
290            // Capsule tools (snapshot under a brief read lock).
291            if let Some(ref registry) = self.capsule_registry {
292                let registry = registry.read().await;
293                llm_tools.extend(registry.all_tool_definitions().into_iter().map(|td| {
294                    LlmToolDefinition::new(td.name)
295                        .with_description(td.description)
296                        .with_schema(td.input_schema)
297                }));
298            }
299
300            // Re-read spark for hot-reload (cheap: ~1KB file read per loop iteration).
301            // Sub-agents skip this: their identity is baked into session.system_prompt
302            // by SubAgentExecutor to avoid contradictory double injection.
303            let mut effective_prompt = if session.is_subagent {
304                session.system_prompt.clone()
305            } else if let Some(spark) = self.read_effective_spark() {
306                if let Some(preamble) = spark.build_preamble() {
307                    format!("{preamble}\n\n{}", session.system_prompt)
308                } else {
309                    session.system_prompt.clone()
310                }
311            } else {
312                session.system_prompt.clone()
313            };
314
315            // Inject dynamic plugin context if present
316            if let Some(ctx) = session.capsule_context.as_ref().filter(|c| !c.is_empty()) {
317                effective_prompt = format!("{ctx}\n\n{effective_prompt}");
318            }
319
320            // Stream from LLM
321            let mut stream = self
322                .llm
323                .stream(&session.messages, &llm_tools, &effective_prompt)
324                .await?;
325
326            let mut response_text = String::new();
327            let mut tool_calls: Vec<ToolCall> = Vec::new();
328            let mut current_tool_args = String::new();
329
330            while let Some(event) = stream.next().await {
331                match event? {
332                    StreamEvent::TextDelta(text) => {
333                        frontend.show_status(&text);
334                        response_text.push_str(&text);
335                    },
336                    StreamEvent::ToolCallStart { id, name } => {
337                        tool_calls.push(ToolCall::new(id, name));
338                        current_tool_args.clear();
339                    },
340                    StreamEvent::ToolCallDelta { id: _, args_delta } => {
341                        current_tool_args.push_str(&args_delta);
342                    },
343                    StreamEvent::ToolCallEnd { id } => {
344                        // Parse and set arguments for the completed tool call
345                        if let Some(call) = tool_calls.iter_mut().find(|c| c.id == id)
346                            && let Ok(args) = serde_json::from_str(&current_tool_args)
347                        {
348                            call.arguments = args;
349                        }
350                        current_tool_args.clear();
351                    },
352                    StreamEvent::Usage {
353                        input_tokens,
354                        output_tokens,
355                    } => {
356                        debug!(input = input_tokens, output = output_tokens, "Token usage");
357                        // Track cost in the session budget tracker
358                        let cost = tokens_to_usd(input_tokens, output_tokens);
359                        session.budget_tracker.record_cost(cost);
360                        // Track cost in the workspace cumulative budget tracker
361                        if let Some(ref ws_budget) = session.workspace_budget_tracker {
362                            ws_budget.record_cost(cost);
363                        }
364                    },
365                    StreamEvent::ReasoningDelta(_) => {
366                        // Reasoning tokens are informational; not included in final output.
367                    },
368                    StreamEvent::Done => break,
369                    StreamEvent::Error(e) => {
370                        error!(error = %e, "Stream error");
371                        return Err(RuntimeError::LlmError(
372                            astrid_llm::LlmError::StreamingError(e),
373                        ));
374                    },
375                }
376            }
377
378            // If we have tool calls, execute them
379            if !tool_calls.is_empty() {
380                // Add assistant message with tool calls
381                session.add_message(Message::assistant_with_tools(tool_calls.clone()));
382
383                // Execute each tool call
384                for call in &tool_calls {
385                    frontend.tool_started(&call.id, &call.name, &call.arguments);
386                    let result = self
387                        .execute_tool_call(session, call, frontend, tool_ctx)
388                        .await?;
389                    frontend.tool_completed(&call.id, &result.content, result.is_error);
390                    session.add_message(Message::tool_result(result));
391                    session.metadata.tool_call_count =
392                        session.metadata.tool_call_count.saturating_add(1);
393                }
394
395                // Continue the loop for next LLM turn
396                continue;
397            }
398
399            // If we have text and no tool calls, we're done
400            if !response_text.is_empty() {
401                session.add_message(Message::assistant(&response_text));
402                return Ok(());
403            }
404
405            // Empty response, done
406            break;
407        }
408
409        Ok(())
410    }
411}