1use astrid_approval::manager::ApprovalHandler;
4use astrid_audit::{AuditAction, AuditOutcome, AuthorizationProof};
5use astrid_core::Frontend;
6use astrid_hooks::{HookEvent, HookResult};
7use astrid_llm::{LlmProvider, LlmToolDefinition, Message, StreamEvent, ToolCall};
8use astrid_tools::ToolContext;
9use futures::StreamExt;
10use std::sync::Arc;
11use tracing::{debug, error};
12
13use crate::error::{RuntimeError, RuntimeResult};
14use crate::session::AgentSession;
15use crate::subagent::SubAgentId;
16
17use super::security::FrontendApprovalHandler;
18use super::{AgentRuntime, tokens_to_usd};
19
20impl<P: LlmProvider + 'static> AgentRuntime<P> {
21 #[allow(clippy::too_many_lines)]
34 pub async fn run_turn_streaming<F: Frontend + 'static>(
35 &self,
36 session: &mut AgentSession,
37 input: &str,
38 frontend: Arc<F>,
39 ) -> RuntimeResult<()> {
40 let handler: Arc<dyn ApprovalHandler> = Arc::new(FrontendApprovalHandler {
42 frontend: Arc::clone(&frontend),
43 });
44 session.approval_manager.register_handler(handler).await;
45
46 session.add_message(Message::user(input));
48
49 {
51 let ctx = self
52 .build_hook_context(session, HookEvent::UserPrompt)
53 .with_data("input", serde_json::json!(input));
54 let result = self.hooks.trigger_simple(HookEvent::UserPrompt, ctx).await;
55 if let HookResult::Block { reason } = result {
56 return Err(RuntimeError::ApprovalDenied { reason });
57 }
58 if let HookResult::ContinueWith { modifications } = &result {
59 debug!(?modifications, "UserPrompt hook modified context");
60 }
61 }
62
63 {
65 let _ = self.audit.append(
66 session.id.clone(),
67 AuditAction::LlmRequest {
68 model: self.llm.model().to_string(),
69 input_tokens: session.token_count,
70 output_tokens: 0,
71 },
72 AuthorizationProof::System {
73 reason: "user input".to_string(),
74 },
75 AuditOutcome::success(),
76 );
77 }
78
79 if self.config.auto_summarize && self.context.needs_summarization(session) {
81 frontend.show_status("Summarizing context...");
82 let result = self.context.summarize(session, self.llm.as_ref()).await?;
83
84 {
86 let _ = self.audit.append(
87 session.id.clone(),
88 AuditAction::ContextSummarized {
89 evicted_count: result.messages_evicted,
90 tokens_freed: result.tokens_freed,
91 },
92 AuthorizationProof::System {
93 reason: "context overflow".to_string(),
94 },
95 AuditOutcome::success(),
96 );
97 }
98 }
99
100 #[allow(clippy::collapsible_if)]
103 if session.capsule_context.is_none() {
104 if let Some(ref registry_lock) = self.capsule_registry {
105 let mut combined_context = String::new();
106 let active_plugins: Vec<astrid_capsule::capsule::CapsuleId> = {
107 let registry = registry_lock.read().await;
108 registry.list().into_iter().cloned().collect()
109 };
110
111 for capsule_id in active_plugins {
112 let (tool_arc, _tool_config) = {
114 let registry = registry_lock.read().await;
115 let tool_name = format!("capsule:{capsule_id}:__astrid_get_agent_context");
116 match registry.find_tool(&tool_name) {
117 Some((plugin, t)) => {
118 let config = plugin
119 .manifest()
120 .env
121 .iter()
122 .filter_map(|(k, v)| v.default.clone().map(|d| (k.clone(), d)))
123 .collect();
124 (Some(t), config)
125 },
126 None => (None, std::collections::HashMap::new()),
127 }
128 };
129
130 if let Some(tool) = tool_arc {
132 let plugin_kv =
133 {
134 let kv_key = format!("{}:capsule:{capsule_id}", session.id);
135 let mut stores = self
136 .capsule_kv_stores
137 .lock()
138 .unwrap_or_else(std::sync::PoisonError::into_inner);
139 Arc::clone(stores.entry(kv_key).or_insert_with(|| {
140 Arc::new(astrid_storage::MemoryKvStore::new())
141 }))
142 };
143
144 let scoped_name = format!("capsule-tool:capsule:{capsule_id}");
145 if let Ok(scoped_kv) =
146 astrid_storage::ScopedKvStore::new(plugin_kv, scoped_name)
147 {
148 let user_uuid = Self::user_uuid(session.user_id);
149 let tool_ctx = astrid_capsule::context::CapsuleToolContext::new(
150 capsule_id.clone(),
151 self.config.workspace.root.clone(),
152 scoped_kv,
153 )
154 .with_session(session.id.clone())
156 .with_user(user_uuid);
157
158 let execute_future = tool.execute(
159 serde_json::Value::Object(serde_json::Map::default()),
160 &tool_ctx,
161 );
162 if let Ok(Ok(ctx_result)) = tokio::time::timeout(
163 std::time::Duration::from_secs(5),
164 execute_future,
165 )
166 .await
167 {
168 let trimmed = ctx_result.trim();
169 if !trimmed.is_empty() {
170 combined_context.push_str(trimmed);
171 combined_context.push_str("\n\n");
172 }
173 } else {
174 tracing::warn!(%capsule_id, "Context tool execution timed out or failed");
175 }
176 }
177 }
178 }
179
180 if combined_context.is_empty() {
181 session.capsule_context = Some(String::new()); } else {
183 session.capsule_context = Some(combined_context);
184 }
185 }
186 }
187
188 let tool_ctx = ToolContext::with_shared_cwd(
190 self.config.workspace.root.clone(),
191 Arc::clone(&self.shared_cwd),
192 self.config.spark_file.clone(),
193 );
194
195 self.inject_subagent_spawner(&tool_ctx, session, &frontend, None)
197 .await;
198
199 let loop_result = self.run_loop(session, &*frontend, &tool_ctx).await;
201
202 let save_result = self.sessions.save(session);
203
204 loop_result?;
205 save_result?;
206
207 Ok(())
208 }
209
210 pub async fn run_subagent_turn<F: Frontend + 'static>(
222 &self,
223 session: &mut AgentSession,
224 prompt: &str,
225 frontend: Arc<F>,
226 parent_subagent_id: Option<SubAgentId>,
227 ) -> RuntimeResult<()> {
228 let handler: Arc<dyn ApprovalHandler> = Arc::new(FrontendApprovalHandler {
230 frontend: Arc::clone(&frontend),
231 });
232 session.approval_manager.register_handler(handler).await;
233
234 session.add_message(Message::user(prompt));
236
237 {
239 let _ = self.audit.append(
240 session.id.clone(),
241 AuditAction::LlmRequest {
242 model: self.llm.model().to_string(),
243 input_tokens: session.token_count,
244 output_tokens: 0,
245 },
246 AuthorizationProof::System {
247 reason: "sub-agent prompt".to_string(),
248 },
249 AuditOutcome::success(),
250 );
251 }
252
253 let tool_ctx = ToolContext::with_shared_cwd(
255 self.config.workspace.root.clone(),
256 Arc::clone(&self.shared_cwd),
257 self.config.spark_file.clone(),
258 );
259
260 self.inject_subagent_spawner(&tool_ctx, session, &frontend, parent_subagent_id)
262 .await;
263
264 self.run_loop(session, &*frontend, &tool_ctx).await
267 }
268
269 #[allow(clippy::too_many_lines)]
273 pub(super) async fn run_loop<F: Frontend>(
274 &self,
275 session: &mut AgentSession,
276 frontend: &F,
277 tool_ctx: &ToolContext,
278 ) -> RuntimeResult<()> {
279 loop {
280 let mut llm_tools: Vec<LlmToolDefinition> = self.tool_registry.all_definitions();
282
283 let mcp_tools = self.mcp.list_tools().await?;
284 llm_tools.extend(mcp_tools.iter().map(|t| {
285 LlmToolDefinition::new(format!("{}:{}", &t.server, &t.name))
286 .with_description(t.description.clone().unwrap_or_default())
287 .with_schema(t.input_schema.clone())
288 }));
289
290 if let Some(ref registry) = self.capsule_registry {
292 let registry = registry.read().await;
293 llm_tools.extend(registry.all_tool_definitions().into_iter().map(|td| {
294 LlmToolDefinition::new(td.name)
295 .with_description(td.description)
296 .with_schema(td.input_schema)
297 }));
298 }
299
300 let mut effective_prompt = if session.is_subagent {
304 session.system_prompt.clone()
305 } else if let Some(spark) = self.read_effective_spark() {
306 if let Some(preamble) = spark.build_preamble() {
307 format!("{preamble}\n\n{}", session.system_prompt)
308 } else {
309 session.system_prompt.clone()
310 }
311 } else {
312 session.system_prompt.clone()
313 };
314
315 if let Some(ctx) = session.capsule_context.as_ref().filter(|c| !c.is_empty()) {
317 effective_prompt = format!("{ctx}\n\n{effective_prompt}");
318 }
319
320 let mut stream = self
322 .llm
323 .stream(&session.messages, &llm_tools, &effective_prompt)
324 .await?;
325
326 let mut response_text = String::new();
327 let mut tool_calls: Vec<ToolCall> = Vec::new();
328 let mut current_tool_args = String::new();
329
330 while let Some(event) = stream.next().await {
331 match event? {
332 StreamEvent::TextDelta(text) => {
333 frontend.show_status(&text);
334 response_text.push_str(&text);
335 },
336 StreamEvent::ToolCallStart { id, name } => {
337 tool_calls.push(ToolCall::new(id, name));
338 current_tool_args.clear();
339 },
340 StreamEvent::ToolCallDelta { id: _, args_delta } => {
341 current_tool_args.push_str(&args_delta);
342 },
343 StreamEvent::ToolCallEnd { id } => {
344 if let Some(call) = tool_calls.iter_mut().find(|c| c.id == id)
346 && let Ok(args) = serde_json::from_str(¤t_tool_args)
347 {
348 call.arguments = args;
349 }
350 current_tool_args.clear();
351 },
352 StreamEvent::Usage {
353 input_tokens,
354 output_tokens,
355 } => {
356 debug!(input = input_tokens, output = output_tokens, "Token usage");
357 let cost = tokens_to_usd(input_tokens, output_tokens);
359 session.budget_tracker.record_cost(cost);
360 if let Some(ref ws_budget) = session.workspace_budget_tracker {
362 ws_budget.record_cost(cost);
363 }
364 },
365 StreamEvent::ReasoningDelta(_) => {
366 },
368 StreamEvent::Done => break,
369 StreamEvent::Error(e) => {
370 error!(error = %e, "Stream error");
371 return Err(RuntimeError::LlmError(
372 astrid_llm::LlmError::StreamingError(e),
373 ));
374 },
375 }
376 }
377
378 if !tool_calls.is_empty() {
380 session.add_message(Message::assistant_with_tools(tool_calls.clone()));
382
383 for call in &tool_calls {
385 frontend.tool_started(&call.id, &call.name, &call.arguments);
386 let result = self
387 .execute_tool_call(session, call, frontend, tool_ctx)
388 .await?;
389 frontend.tool_completed(&call.id, &result.content, result.is_error);
390 session.add_message(Message::tool_result(result));
391 session.metadata.tool_call_count =
392 session.metadata.tool_call_count.saturating_add(1);
393 }
394
395 continue;
397 }
398
399 if !response_text.is_empty() {
401 session.add_message(Message::assistant(&response_text));
402 return Ok(());
403 }
404
405 break;
407 }
408
409 Ok(())
410 }
411}