1use std::sync::Arc;
4use std::time::Instant;
5
6use tracing::{debug, info, instrument, warn};
7
8use super::AgentMetrics;
9use super::common::{
10 self, BudgetContext, accumulate_inner_usage, accumulate_response_usage, handle_compaction,
11 run_post_tool_hooks, run_stop_hooks, try_activate_dynamic_rules,
12};
13use super::events::AgentResult;
14use super::executor::Agent;
15use super::request::RequestBuilder;
16use crate::hooks::{HookContext, HookEvent, HookInput};
17use crate::types::{
18 ContentBlock, Message, PermissionDenial, StopReason, ToolResultBlock, Usage, context_window,
19};
20
21impl Agent {
22 fn check_budget(&self) -> crate::Result<()> {
23 BudgetContext {
24 tracker: &self.budget_tracker,
25 tenant: self.tenant_budget.as_deref(),
26 config: &self.config.budget,
27 }
28 .check()
29 }
30
31 pub async fn execute(&self, prompt: &str) -> crate::Result<AgentResult> {
32 let timeout = self
33 .config
34 .execution
35 .timeout
36 .unwrap_or(std::time::Duration::from_secs(600));
37
38 if self.state.is_executing() {
39 self.state
40 .enqueue(prompt)
41 .await
42 .map_err(|e| crate::Error::Session(format!("Queue full: {}", e)))?;
43 return self.wait_for_execution(timeout).await;
44 }
45
46 tokio::time::timeout(timeout, self.execute_inner(prompt))
47 .await
48 .map_err(|_| crate::Error::Timeout(timeout))?
49 }
50
51 async fn wait_for_execution(&self, timeout: std::time::Duration) -> crate::Result<AgentResult> {
52 tokio::time::timeout(timeout, async {
53 loop {
54 self.state.wait_for_queue_signal().await;
55 if !self.state.is_executing()
56 && let Some(merged) = self.state.dequeue_or_merge().await
57 {
58 return self.execute_inner(&merged.content).await;
59 }
60 }
61 })
62 .await
63 .map_err(|_| crate::Error::Timeout(timeout))?
64 }
65
66 pub async fn execute_with_messages(
67 &self,
68 previous_messages: Vec<Message>,
69 prompt: &str,
70 ) -> crate::Result<AgentResult> {
71 let context_summary = previous_messages
72 .iter()
73 .filter_map(|m| {
74 m.content
75 .iter()
76 .filter_map(|b| match b {
77 ContentBlock::Text { text, .. } => Some(text.as_str()),
78 _ => None,
79 })
80 .next()
81 })
82 .collect::<Vec<_>>()
83 .join("\n---\n");
84
85 let enriched_prompt = if context_summary.is_empty() {
86 prompt.to_string()
87 } else {
88 format!(
89 "Previous conversation context:\n{}\n\nContinue with: {}",
90 context_summary, prompt
91 )
92 };
93
94 self.execute(&enriched_prompt).await
95 }
96
97 #[instrument(skip(self, prompt), fields(session_id = %self.session_id))]
98 async fn execute_inner(&self, prompt: &str) -> crate::Result<AgentResult> {
99 let _guard = self.state.acquire_execution().await;
100 let execution_start = Instant::now();
101 let hook_ctx = self.hook_context();
102
103 let session_start_input = HookInput::session_start(&*self.session_id);
104 if let Err(e) = self
105 .hooks
106 .execute(HookEvent::SessionStart, session_start_input, &hook_ctx)
107 .await
108 {
109 warn!(error = %e, "SessionStart hook failed");
110 }
111
112 let final_prompt = if let Some(merged) = self.state.dequeue_or_merge().await {
113 format!("{}\n{}", prompt, merged.content)
114 } else {
115 prompt.to_string()
116 };
117
118 let prompt_input = HookInput::user_prompt_submit(&*self.session_id, &final_prompt);
119 let prompt_output = self
120 .hooks
121 .execute(HookEvent::UserPromptSubmit, prompt_input, &hook_ctx)
122 .await?;
123
124 if !prompt_output.continue_execution {
125 let session_end_input = HookInput::session_end(&*self.session_id);
126 if let Err(e) = self
127 .hooks
128 .execute(HookEvent::SessionEnd, session_end_input, &hook_ctx)
129 .await
130 {
131 warn!(error = %e, "SessionEnd hook failed");
132 }
133 return Err(crate::Error::Permission(
134 prompt_output
135 .stop_reason
136 .unwrap_or_else(|| "Blocked by hook".into()),
137 ));
138 }
139
140 self.state
141 .with_session_mut(|session| {
142 session.add_user_message(&final_prompt);
143 })
144 .await;
145
146 let mut metrics = AgentMetrics::default();
147 let mut final_text = String::new();
148 let mut final_stop_reason = StopReason::EndTurn;
149 let mut dynamic_rules_context = String::new();
150 let mut total_usage = Usage::default();
151
152 let mut request_builder = {
153 let builder = RequestBuilder::new(&self.config, Arc::clone(&self.tools));
154
155 if let Some(ref tsm) = self.tool_search_manager {
156 let prepared = tsm.prepare_tools().await;
157 if prepared.use_search {
158 info!(
159 immediate = prepared.immediate.len(),
160 deferred = prepared.deferred.len(),
161 tokens_saved = prepared.token_savings(),
162 "MCP Progressive Disclosure active"
163 );
164 }
165 builder.prepared_tools(prepared)
166 } else {
167 builder
168 }
169 };
170 let max_tokens = context_window::for_model(&self.config.model.primary);
171
172 info!(prompt_len = final_prompt.len(), "Starting agent execution");
173
174 loop {
175 metrics.iterations += 1;
176 if metrics.iterations > self.config.execution.max_iterations {
177 warn!(
178 max = self.config.execution.max_iterations,
179 "Max iterations reached"
180 );
181 break;
182 }
183
184 self.check_budget()?;
185
186 let budget_ctx = BudgetContext {
187 tracker: &self.budget_tracker,
188 tenant: self.tenant_budget.as_deref(),
189 config: &self.config.budget,
190 };
191 if let Some(fallback) = budget_ctx.fallback_model() {
192 request_builder.set_model(fallback);
193 }
194
195 debug!(iteration = metrics.iterations, "Starting iteration");
196
197 let messages = self
198 .state
199 .with_session(|session| {
200 session.to_api_messages_with_cache(self.config.cache.message_ttl_option())
201 })
202 .await;
203
204 let api_start = Instant::now();
205 let request = request_builder.build(messages, &dynamic_rules_context);
206 let response = self.client.send_with_auth_retry(request).await?;
207 let api_duration_ms = api_start.elapsed().as_millis() as u64;
208 metrics.record_api_call_with_timing(api_duration_ms);
209 debug!(api_time_ms = api_duration_ms, "API call completed");
210
211 self.state
212 .with_session_mut(|session| {
213 session.update_usage(&response.usage);
214 })
215 .await;
216
217 accumulate_response_usage(
218 &mut total_usage,
219 &mut metrics,
220 &self.budget_tracker,
221 self.tenant_budget.as_deref(),
222 &self.config.model.primary,
223 &response.usage,
224 );
225
226 final_text = response.text();
227 final_stop_reason = response.stop_reason.unwrap_or(StopReason::EndTurn);
228
229 self.state
230 .with_session_mut(|session| {
231 session.add_assistant_message(response.content.clone(), Some(response.usage));
232 })
233 .await;
234
235 if !response.wants_tool_use() {
236 debug!("No tool use requested, ending loop");
237 break;
238 }
239
240 let tool_uses = response.tool_uses();
241 let hook_ctx = self.hook_context();
242
243 let mut prepared = Vec::with_capacity(tool_uses.len());
244 let mut blocked = Vec::with_capacity(tool_uses.len());
245
246 for tool_use in &tool_uses {
247 let pre_input = HookInput::pre_tool_use(
248 &*self.session_id,
249 &tool_use.name,
250 tool_use.input.clone(),
251 );
252 let pre_output = self
253 .hooks
254 .execute(HookEvent::PreToolUse, pre_input, &hook_ctx)
255 .await?;
256
257 if !pre_output.continue_execution {
258 debug!(tool = %tool_use.name, "Tool blocked by hook");
259 let reason = pre_output
260 .stop_reason
261 .clone()
262 .unwrap_or_else(|| "Blocked by hook".into());
263 blocked.push(ToolResultBlock::error(&tool_use.id, reason.clone()));
264 metrics.record_permission_denial(
265 PermissionDenial::new(&tool_use.name, &tool_use.id, tool_use.input.clone())
266 .reason(reason),
267 );
268 } else {
269 let input = pre_output.updated_input.unwrap_or(tool_use.input.clone());
270 prepared.push((tool_use.id.clone(), tool_use.name.clone(), input));
271 }
272 }
273
274 let tool_futures = prepared.into_iter().map(|(id, name, input)| {
275 let tools = &self.tools;
276 async move {
277 let start = Instant::now();
278 let result = tools.execute(&name, input.clone()).await;
279 let duration_ms = start.elapsed().as_millis() as u64;
280 (id, name, input, result, duration_ms)
281 }
282 });
283
284 let parallel_results: Vec<_> = futures::future::join_all(tool_futures).await;
285
286 let all_non_retryable = !parallel_results.is_empty()
287 && parallel_results
288 .iter()
289 .all(|(_, _, _, result, _)| result.is_non_retryable());
290
291 let mut results = blocked;
292 for (id, name, input, result, duration_ms) in parallel_results {
293 let is_error = result.is_error();
294 debug!(tool = %name, duration_ms, is_error, "Tool execution completed");
295 metrics.record_tool(&id, &name, duration_ms, is_error);
296
297 accumulate_inner_usage(
298 &self.state,
299 &mut total_usage,
300 &mut metrics,
301 &self.budget_tracker,
302 &result,
303 &name,
304 )
305 .await;
306
307 try_activate_dynamic_rules(
308 &name,
309 &input,
310 &self.orchestrator,
311 &mut dynamic_rules_context,
312 )
313 .await;
314
315 run_post_tool_hooks(
316 &self.hooks,
317 &hook_ctx,
318 &self.session_id,
319 &name,
320 is_error,
321 &result,
322 )
323 .await;
324
325 results.push(ToolResultBlock::from_tool_result(&id, &result));
326 }
327
328 self.state
329 .with_session_mut(|session| {
330 session.add_tool_results(results);
331 })
332 .await;
333
334 if all_non_retryable {
335 warn!("All tool calls failed with non-retryable errors, ending execution");
336 break;
337 }
338
339 handle_compaction(
340 &self.state,
341 &self.client,
342 &self.tools,
343 &self.hooks,
344 &hook_ctx,
345 &self.session_id,
346 &self.config.execution,
347 max_tokens,
348 &mut metrics,
349 )
350 .await;
351 }
352
353 metrics.execution_time_ms = execution_start.elapsed().as_millis() as u64;
354
355 run_stop_hooks(&self.hooks, &hook_ctx, &self.session_id).await;
356
357 info!(
358 iterations = metrics.iterations,
359 tool_calls = metrics.tool_calls,
360 api_calls = metrics.api_calls,
361 total_tokens = metrics.total_tokens(),
362 execution_time_ms = metrics.execution_time_ms,
363 "Agent execution completed"
364 );
365
366 let messages = self
367 .state
368 .with_session(|session| session.to_api_messages())
369 .await;
370
371 let structured_output = self.extract_structured_output(&final_text);
372 Ok(AgentResult::new(
373 final_text,
374 total_usage,
375 metrics.iterations,
376 final_stop_reason,
377 metrics,
378 self.session_id.to_string(),
379 structured_output,
380 messages,
381 ))
382 }
383
384 pub(crate) fn hook_context(&self) -> HookContext {
385 HookContext::new(&*self.session_id)
386 .cwd(self.config.working_dir.clone().unwrap_or_default())
387 .env(self.config.security.env.clone())
388 }
389
390 fn extract_structured_output(&self, text: &str) -> Option<serde_json::Value> {
391 common::extract_structured_output(self.config.prompt.output_schema.as_ref(), text)
392 }
393}
394
#[cfg(test)]
mod tests {
    use super::common::extract_file_path;

    /// `extract_file_path` picks the path argument for file-oriented tools and
    /// yields `None` for tools whose input carries no path.
    #[test]
    fn test_extract_file_path() {
        let cases = [
            ("Read", serde_json::json!({"file_path": "/src/lib.rs"}), Some("/src/lib.rs")),
            ("Glob", serde_json::json!({"path": "/src"}), Some("/src")),
            ("Bash", serde_json::json!({"command": "ls"}), None),
        ];

        for (tool, input, expected) in cases {
            assert_eq!(
                extract_file_path(tool, &input),
                expected.map(str::to_string),
                "unexpected path for tool {tool}"
            );
        }
    }
}