1use std::path::Path;
4use std::sync::Arc;
5use std::time::Instant;
6
7use tokio::sync::RwLock;
8use tracing::{debug, info, instrument, warn};
9
10use super::common::BudgetContext;
11use super::events::AgentResult;
12use super::executor::Agent;
13use super::request::RequestBuilder;
14use super::state_formatter::collect_compaction_state;
15use super::{AgentMetrics, AgentState};
16use crate::context::PromptOrchestrator;
17use crate::hooks::{HookContext, HookEvent, HookInput};
18use crate::session::ExecutionGuard;
19use crate::types::{
20 CompactResult, ContentBlock, Message, PermissionDenial, StopReason, ToolResultBlock, Usage,
21 context_window,
22};
23
24impl Agent {
25 async fn handle_compaction<'a>(
26 &self,
27 _guard: &ExecutionGuard<'a>,
28 hook_ctx: &HookContext,
29 metrics: &mut AgentMetrics,
30 ) {
31 let pre_compact_input = HookInput::pre_compact(&*self.session_id);
32 if let Err(e) = self
33 .hooks
34 .execute(HookEvent::PreCompact, pre_compact_input, hook_ctx)
35 .await
36 {
37 warn!(error = %e, "PreCompact hook failed");
38 }
39
40 debug!("Compacting session context");
41 let compact_result = self
42 .state
43 .compact(&self.client, self.config.execution.compact_keep_messages)
44 .await;
45
46 match compact_result {
47 Ok(CompactResult::Compacted { saved_tokens, .. }) => {
48 info!(saved_tokens, "Session context compacted");
49 metrics.record_compaction();
50
51 let state_sections = collect_compaction_state(&self.tools).await;
52 if !state_sections.is_empty() {
53 self.state
54 .with_session_mut(|session| {
55 session.add_user_message(format!(
56 "<system-reminder>\n# State preserved after compaction\n\n{}\n</system-reminder>",
57 state_sections.join("\n\n")
58 ));
59 })
60 .await;
61 }
62 }
63 Ok(CompactResult::NotNeeded | CompactResult::Skipped { .. }) => {
64 debug!("Compaction skipped or not needed");
65 }
66 Err(e) => {
67 warn!(error = %e, "Session compaction failed");
68 }
69 }
70 }
71
72 fn check_budget(&self) -> crate::Result<()> {
73 BudgetContext {
74 tracker: &self.budget_tracker,
75 tenant: self.tenant_budget.as_deref(),
76 config: &self.config.budget,
77 }
78 .check()
79 }
80
81 pub async fn execute(&self, prompt: &str) -> crate::Result<AgentResult> {
82 let timeout = self
83 .config
84 .execution
85 .timeout
86 .unwrap_or(std::time::Duration::from_secs(600));
87
88 if self.state.is_executing() {
89 self.state
90 .enqueue(prompt)
91 .await
92 .map_err(|e| crate::Error::Session(format!("Queue full: {}", e)))?;
93 return self.wait_for_execution(timeout).await;
94 }
95
96 tokio::time::timeout(timeout, self.execute_inner(prompt))
97 .await
98 .map_err(|_| crate::Error::Timeout(timeout))?
99 }
100
101 async fn wait_for_execution(&self, timeout: std::time::Duration) -> crate::Result<AgentResult> {
102 let start = Instant::now();
103 loop {
104 if start.elapsed() > timeout {
105 return Err(crate::Error::Timeout(timeout));
106 }
107
108 if !self.state.is_executing()
109 && let Some(merged) = self.state.dequeue_or_merge().await
110 {
111 return self.execute_inner(&merged.content).await;
112 }
113
114 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
115 }
116 }
117
118 pub async fn execute_with_messages(
119 &self,
120 previous_messages: Vec<Message>,
121 prompt: &str,
122 ) -> crate::Result<AgentResult> {
123 let context_summary = previous_messages
124 .iter()
125 .filter_map(|m| {
126 m.content
127 .iter()
128 .filter_map(|b| match b {
129 ContentBlock::Text { text, .. } => Some(text.as_str()),
130 _ => None,
131 })
132 .next()
133 })
134 .collect::<Vec<_>>()
135 .join("\n---\n");
136
137 let enriched_prompt = if context_summary.is_empty() {
138 prompt.to_string()
139 } else {
140 format!(
141 "Previous conversation context:\n{}\n\nContinue with: {}",
142 context_summary, prompt
143 )
144 };
145
146 self.execute(&enriched_prompt).await
147 }
148
149 #[instrument(skip(self, prompt), fields(session_id = %self.session_id))]
150 async fn execute_inner(&self, prompt: &str) -> crate::Result<AgentResult> {
151 let guard = self.state.acquire_execution().await;
152 let execution_start = Instant::now();
153 let hook_ctx = self.hook_context();
154
155 let session_start_input = HookInput::session_start(&*self.session_id);
156 if let Err(e) = self
157 .hooks
158 .execute(HookEvent::SessionStart, session_start_input, &hook_ctx)
159 .await
160 {
161 warn!(error = %e, "SessionStart hook failed");
162 }
163
164 let final_prompt = if let Some(merged) = self.state.dequeue_or_merge().await {
165 format!("{}\n{}", prompt, merged.content)
166 } else {
167 prompt.to_string()
168 };
169
170 let prompt_input = HookInput::user_prompt_submit(&*self.session_id, &final_prompt);
171 let prompt_output = self
172 .hooks
173 .execute(HookEvent::UserPromptSubmit, prompt_input, &hook_ctx)
174 .await?;
175
176 if !prompt_output.continue_execution {
177 let session_end_input = HookInput::session_end(&*self.session_id);
178 if let Err(e) = self
179 .hooks
180 .execute(HookEvent::SessionEnd, session_end_input, &hook_ctx)
181 .await
182 {
183 warn!(error = %e, "SessionEnd hook failed");
184 }
185 return Err(crate::Error::Permission(
186 prompt_output
187 .stop_reason
188 .unwrap_or_else(|| "Blocked by hook".into()),
189 ));
190 }
191
192 self.state
193 .with_session_mut(|session| {
194 session.add_user_message(&final_prompt);
195 })
196 .await;
197
198 let mut metrics = AgentMetrics::default();
199 let mut final_text = String::new();
200 let mut final_stop_reason = StopReason::EndTurn;
201 let mut dynamic_rules_context = String::new();
202 let mut total_usage = Usage::default();
203
204 let request_builder = {
206 let builder = RequestBuilder::new(&self.config, Arc::clone(&self.tools));
207
208 if let Some(ref tsm) = self.tool_search_manager {
210 let prepared = tsm.prepare_tools().await;
211 if prepared.use_search {
212 info!(
213 immediate = prepared.immediate.len(),
214 deferred = prepared.deferred.len(),
215 tokens_saved = prepared.token_savings(),
216 "MCP Progressive Disclosure active"
217 );
218 }
219 builder.with_prepared_tools(prepared)
220 } else {
221 builder
222 }
223 };
224 let max_tokens = context_window::for_model(&self.config.model.primary);
225
226 info!(prompt_len = final_prompt.len(), "Starting agent execution");
227
228 loop {
229 metrics.iterations += 1;
230 if metrics.iterations > self.config.execution.max_iterations {
231 warn!(
232 max = self.config.execution.max_iterations,
233 "Max iterations reached"
234 );
235 break;
236 }
237
238 self.check_budget()?;
239
240 debug!(iteration = metrics.iterations, "Starting iteration");
241
242 let messages = self
243 .state
244 .with_session(|session| {
245 session.to_api_messages_with_cache(self.config.cache.message_ttl_option())
246 })
247 .await;
248
249 let api_start = Instant::now();
250 let request = request_builder.build(messages, &dynamic_rules_context);
251 let response = self.client.send_with_auth_retry(request).await?;
252 let api_duration_ms = api_start.elapsed().as_millis() as u64;
253 metrics.record_api_call_with_timing(api_duration_ms);
254 debug!(api_time_ms = api_duration_ms, "API call completed");
255
256 self.state
257 .with_session_mut(|session| {
258 session.update_usage(&response.usage);
259 })
260 .await;
261
262 total_usage.input_tokens += response.usage.input_tokens;
263 total_usage.output_tokens += response.usage.output_tokens;
264 metrics.add_usage_with_cache(&response.usage);
265 metrics.record_model_usage(&self.config.model.primary, &response.usage);
266
267 if let Some(ref server_usage) = response.usage.server_tool_use {
268 metrics.update_server_tool_use_from_api(server_usage);
269 }
270
271 let cost = self
272 .budget_tracker
273 .record(&self.config.model.primary, &response.usage);
274 metrics.add_cost(cost);
275 if let Some(ref tenant_budget) = self.tenant_budget {
276 tenant_budget.record(&self.config.model.primary, &response.usage);
277 }
278
279 final_text = response.text();
280 final_stop_reason = response.stop_reason.unwrap_or(StopReason::EndTurn);
281
282 self.state
283 .with_session_mut(|session| {
284 session.add_assistant_message(response.content.clone(), Some(response.usage));
285 })
286 .await;
287
288 if !response.wants_tool_use() {
289 debug!("No tool use requested, ending loop");
290 break;
291 }
292
293 let tool_uses = response.tool_uses();
294 let hook_ctx = self.hook_context();
295
296 let mut prepared = Vec::with_capacity(tool_uses.len());
297 let mut blocked = Vec::with_capacity(tool_uses.len());
298
299 for tool_use in &tool_uses {
300 let pre_input = HookInput::pre_tool_use(
301 &*self.session_id,
302 &tool_use.name,
303 tool_use.input.clone(),
304 );
305 let pre_output = self
306 .hooks
307 .execute(HookEvent::PreToolUse, pre_input, &hook_ctx)
308 .await?;
309
310 if !pre_output.continue_execution {
311 debug!(tool = %tool_use.name, "Tool blocked by hook");
312 let reason = pre_output
313 .stop_reason
314 .clone()
315 .unwrap_or_else(|| "Blocked by hook".into());
316 blocked.push(ToolResultBlock::error(&tool_use.id, reason.clone()));
317 metrics.record_permission_denial(
318 PermissionDenial::new(&tool_use.name, &tool_use.id, tool_use.input.clone())
319 .with_reason(reason),
320 );
321 } else {
322 let input = pre_output.updated_input.unwrap_or(tool_use.input.clone());
323 prepared.push((tool_use.id.clone(), tool_use.name.clone(), input));
324 }
325 }
326
327 let tool_futures = prepared.into_iter().map(|(id, name, input)| {
328 let tools = &self.tools;
329 async move {
330 let start = Instant::now();
331 let result = tools.execute(&name, input.clone()).await;
332 let duration_ms = start.elapsed().as_millis() as u64;
333 (id, name, input, result, duration_ms)
334 }
335 });
336
337 let parallel_results: Vec<_> = futures::future::join_all(tool_futures).await;
338
339 let all_non_retryable = !parallel_results.is_empty()
340 && parallel_results
341 .iter()
342 .all(|(_, _, _, result, _)| result.is_non_retryable());
343
344 let mut results = blocked;
345 for (id, name, input, result, duration_ms) in parallel_results {
346 let is_error = result.is_error();
347 debug!(tool = %name, duration_ms, is_error, "Tool execution completed");
348 metrics.record_tool(&id, &name, duration_ms, is_error);
349
350 if let Some(ref inner_usage) = result.inner_usage {
351 self.state
352 .with_session_mut(|session| {
353 session.update_usage(inner_usage);
354 })
355 .await;
356 total_usage.input_tokens += inner_usage.input_tokens;
357 total_usage.output_tokens += inner_usage.output_tokens;
358 metrics.add_usage(inner_usage.input_tokens, inner_usage.output_tokens);
359 let inner_model = result.inner_model.as_deref().unwrap_or("claude-haiku-4-5");
360 metrics.record_model_usage(inner_model, inner_usage);
361
362 let inner_cost = self.budget_tracker.record(inner_model, inner_usage);
363 metrics.add_cost(inner_cost);
364
365 debug!(
366 tool = %name,
367 model = %inner_model,
368 input_tokens = inner_usage.input_tokens,
369 output_tokens = inner_usage.output_tokens,
370 cost_usd = inner_cost,
371 "Accumulated inner usage from tool"
372 );
373 }
374
375 if let Some(file_path) = extract_file_path(&name, &input)
376 && let Some(ref orchestrator) = self.orchestrator
377 {
378 let new_rules = activate_rules_for_file(orchestrator, &file_path).await;
379 if !new_rules.is_empty() {
380 dynamic_rules_context =
381 build_dynamic_rules_context(orchestrator, &file_path).await;
382 debug!(rules = ?new_rules, "Activated rules for file");
383 }
384 }
385
386 if is_error {
387 let failure_input = HookInput::post_tool_use_failure(
388 &*self.session_id,
389 &name,
390 result.error_message(),
391 );
392 if let Err(e) = self
393 .hooks
394 .execute(HookEvent::PostToolUseFailure, failure_input, &hook_ctx)
395 .await
396 {
397 warn!(tool = %name, error = %e, "PostToolUseFailure hook failed");
398 }
399 } else {
400 let post_input =
401 HookInput::post_tool_use(&*self.session_id, &name, result.output.clone());
402 if let Err(e) = self
403 .hooks
404 .execute(HookEvent::PostToolUse, post_input, &hook_ctx)
405 .await
406 {
407 warn!(tool = %name, error = %e, "PostToolUse hook failed");
408 }
409 }
410 results.push(ToolResultBlock::from_tool_result(&id, &result));
411 }
412
413 self.state
414 .with_session_mut(|session| {
415 session.add_tool_results(results);
416 })
417 .await;
418
419 if all_non_retryable {
420 warn!("All tool calls failed with non-retryable errors, ending execution");
421 break;
422 }
423
424 let should_compact = self
425 .state
426 .with_session(|session| {
427 self.config.execution.auto_compact
428 && session.should_compact(
429 max_tokens,
430 self.config.execution.compact_threshold,
431 self.config.execution.compact_keep_messages,
432 )
433 })
434 .await;
435
436 if should_compact {
437 self.handle_compaction(&guard, &hook_ctx, &mut metrics)
438 .await;
439 }
440 }
441
442 metrics.execution_time_ms = execution_start.elapsed().as_millis() as u64;
443
444 let stop_input = HookInput::stop(&*self.session_id);
445 if let Err(e) = self
446 .hooks
447 .execute(HookEvent::Stop, stop_input, &hook_ctx)
448 .await
449 {
450 warn!(error = %e, "Stop hook failed");
451 }
452
453 let session_end_input = HookInput::session_end(&*self.session_id);
454 if let Err(e) = self
455 .hooks
456 .execute(HookEvent::SessionEnd, session_end_input, &hook_ctx)
457 .await
458 {
459 warn!(error = %e, "SessionEnd hook failed");
460 }
461
462 info!(
463 iterations = metrics.iterations,
464 tool_calls = metrics.tool_calls,
465 api_calls = metrics.api_calls,
466 total_tokens = metrics.total_tokens(),
467 execution_time_ms = metrics.execution_time_ms,
468 "Agent execution completed"
469 );
470
471 let messages = self
472 .state
473 .with_session(|session| session.to_api_messages())
474 .await;
475
476 drop(guard);
477
478 Ok(AgentResult {
479 text: final_text,
480 usage: total_usage,
481 tool_calls: metrics.tool_calls,
482 iterations: metrics.iterations,
483 stop_reason: final_stop_reason,
484 state: AgentState::Completed,
485 metrics,
486 session_id: self.session_id.to_string(),
487 structured_output: None,
488 messages,
489 uuid: uuid::Uuid::new_v4().to_string(),
490 })
491 }
492
493 pub(crate) fn hook_context(&self) -> HookContext {
494 HookContext::new(&*self.session_id)
495 .with_cwd(self.config.working_dir.clone().unwrap_or_default())
496 .with_env(self.config.security.env.clone())
497 }
498}
499
500pub(crate) fn extract_file_path(tool_name: &str, input: &serde_json::Value) -> Option<String> {
501 match tool_name {
502 "Read" | "Write" | "Edit" => input
503 .get("file_path")
504 .and_then(|v| v.as_str())
505 .map(String::from),
506 "Glob" | "Grep" => input.get("path").and_then(|v| v.as_str()).map(String::from),
507 _ => None,
508 }
509}
510
511pub(crate) async fn activate_rules_for_file(
512 orchestrator: &Arc<RwLock<PromptOrchestrator>>,
513 file_path: &str,
514) -> Vec<String> {
515 let orch = orchestrator.read().await;
516 let path = Path::new(file_path);
517 let rules = orch.find_matching_rules(path).await;
518 rules.iter().map(|r| r.name.clone()).collect()
519}
520
521pub(crate) async fn build_dynamic_rules_context(
522 orchestrator: &Arc<RwLock<PromptOrchestrator>>,
523 file_path: &str,
524) -> String {
525 let orch = orchestrator.read().await;
526 let path = Path::new(file_path);
527 orch.build_dynamic_context(Some(path)).await
528}
529
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for an expected `Some(path)` result.
    fn some(path: &str) -> Option<String> {
        Some(path.to_owned())
    }

    #[test]
    fn test_extract_file_path() {
        let read_input = serde_json::json!({"file_path": "/src/lib.rs"});
        let glob_input = serde_json::json!({"path": "/src"});
        let bash_input = serde_json::json!({"command": "ls"});

        assert_eq!(extract_file_path("Read", &read_input), some("/src/lib.rs"));
        assert_eq!(extract_file_path("Glob", &glob_input), some("/src"));
        assert_eq!(extract_file_path("Bash", &bash_input), None);
    }

    #[test]
    fn test_extract_file_path_all_tools() {
        let file_input = serde_json::json!({"file_path": "/test/file.rs"});
        let path_input = serde_json::json!({"path": "/test/dir"});

        // Tools that carry their target under `file_path`.
        for tool in ["Read", "Write", "Edit"] {
            assert_eq!(
                extract_file_path(tool, &file_input),
                some("/test/file.rs"),
                "tool {tool}"
            );
        }

        // Tools that carry their target under `path`.
        for tool in ["Glob", "Grep"] {
            assert_eq!(
                extract_file_path(tool, &path_input),
                some("/test/dir"),
                "tool {tool}"
            );
        }

        // Tools without a path argument.
        for tool in ["WebFetch", "Task"] {
            assert_eq!(extract_file_path(tool, &file_input), None, "tool {tool}");
        }
    }

    #[test]
    fn test_extract_file_path_missing_field() {
        let empty = serde_json::json!({});
        let wrong_field = serde_json::json!({"other": "value"});

        for input in [&empty, &wrong_field] {
            for tool in ["Read", "Glob"] {
                assert_eq!(extract_file_path(tool, input), None, "tool {tool}");
            }
        }
    }

    #[test]
    fn test_extract_file_path_non_string() {
        let numeric = serde_json::json!({"file_path": 123});
        let null_path = serde_json::json!({"path": null});

        assert_eq!(extract_file_path("Read", &numeric), None);
        assert_eq!(extract_file_path("Glob", &null_path), None);
    }
}