1use std::path::Path;
4use std::sync::Arc;
5use std::time::Instant;
6
7use tokio::sync::RwLock;
8use tracing::{debug, info, instrument, warn};
9
10use super::common::BudgetContext;
11use super::events::AgentResult;
12use super::executor::Agent;
13use super::request::RequestBuilder;
14use super::state_formatter::collect_compaction_state;
15use super::{AgentMetrics, AgentState};
16use crate::context::PromptOrchestrator;
17use crate::hooks::{HookContext, HookEvent, HookInput};
18use crate::session::ExecutionGuard;
19use crate::types::{
20 CompactResult, ContentBlock, Message, PermissionDenial, StopReason, ToolResultBlock, Usage,
21 context_window,
22};
23
24impl Agent {
25 async fn handle_compaction<'a>(
26 &self,
27 _guard: &ExecutionGuard<'a>,
28 hook_ctx: &HookContext,
29 metrics: &mut AgentMetrics,
30 ) {
31 let pre_compact_input = HookInput::pre_compact(&*self.session_id);
32 if let Err(e) = self
33 .hooks
34 .execute(HookEvent::PreCompact, pre_compact_input, hook_ctx)
35 .await
36 {
37 warn!(error = %e, "PreCompact hook failed");
38 }
39
40 debug!("Compacting session context");
41 let compact_result = self
42 .state
43 .compact(&self.client, self.config.execution.compact_keep_messages)
44 .await;
45
46 match compact_result {
47 Ok(CompactResult::Compacted { saved_tokens, .. }) => {
48 info!(saved_tokens, "Session context compacted");
49 metrics.record_compaction();
50
51 let state_sections = collect_compaction_state(&self.tools).await;
52 if !state_sections.is_empty() {
53 self.state
54 .with_session_mut(|session| {
55 session.add_user_message(format!(
56 "<system-reminder>\n# State preserved after compaction\n\n{}\n</system-reminder>",
57 state_sections.join("\n\n")
58 ));
59 })
60 .await;
61 }
62 }
63 Ok(CompactResult::NotNeeded | CompactResult::Skipped { .. }) => {
64 debug!("Compaction skipped or not needed");
65 }
66 Err(e) => {
67 warn!(error = %e, "Session compaction failed");
68 }
69 }
70 }
71
72 fn check_budget(&self) -> crate::Result<()> {
73 BudgetContext {
74 tracker: &self.budget_tracker,
75 tenant: self.tenant_budget.as_deref(),
76 config: &self.config.budget,
77 }
78 .check()
79 }
80
81 pub async fn execute(&self, prompt: &str) -> crate::Result<AgentResult> {
82 let timeout = self
83 .config
84 .execution
85 .timeout
86 .unwrap_or(std::time::Duration::from_secs(600));
87
88 if self.state.is_executing() {
89 self.state
90 .enqueue(prompt)
91 .await
92 .map_err(|e| crate::Error::Session(format!("Queue full: {}", e)))?;
93 return self.wait_for_execution(timeout).await;
94 }
95
96 tokio::time::timeout(timeout, self.execute_inner(prompt))
97 .await
98 .map_err(|_| crate::Error::Timeout(timeout))?
99 }
100
101 async fn wait_for_execution(&self, timeout: std::time::Duration) -> crate::Result<AgentResult> {
102 let start = Instant::now();
103 loop {
104 if start.elapsed() > timeout {
105 return Err(crate::Error::Timeout(timeout));
106 }
107
108 if !self.state.is_executing()
109 && let Some(merged) = self.state.dequeue_or_merge().await
110 {
111 return self.execute_inner(&merged.content).await;
112 }
113
114 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
115 }
116 }
117
118 pub async fn execute_with_messages(
119 &self,
120 previous_messages: Vec<Message>,
121 prompt: &str,
122 ) -> crate::Result<AgentResult> {
123 let context_summary = previous_messages
124 .iter()
125 .filter_map(|m| {
126 m.content
127 .iter()
128 .filter_map(|b| match b {
129 ContentBlock::Text { text, .. } => Some(text.as_str()),
130 _ => None,
131 })
132 .next()
133 })
134 .collect::<Vec<_>>()
135 .join("\n---\n");
136
137 let enriched_prompt = if context_summary.is_empty() {
138 prompt.to_string()
139 } else {
140 format!(
141 "Previous conversation context:\n{}\n\nContinue with: {}",
142 context_summary, prompt
143 )
144 };
145
146 self.execute(&enriched_prompt).await
147 }
148
149 #[instrument(skip(self, prompt), fields(session_id = %self.session_id))]
150 async fn execute_inner(&self, prompt: &str) -> crate::Result<AgentResult> {
151 let guard = self.state.acquire_execution().await;
152 let execution_start = Instant::now();
153 let hook_ctx = self.hook_context();
154
155 let session_start_input = HookInput::session_start(&*self.session_id);
156 if let Err(e) = self
157 .hooks
158 .execute(HookEvent::SessionStart, session_start_input, &hook_ctx)
159 .await
160 {
161 warn!(error = %e, "SessionStart hook failed");
162 }
163
164 let final_prompt = if let Some(merged) = self.state.dequeue_or_merge().await {
165 format!("{}\n{}", prompt, merged.content)
166 } else {
167 prompt.to_string()
168 };
169
170 let prompt_input = HookInput::user_prompt_submit(&*self.session_id, &final_prompt);
171 let prompt_output = self
172 .hooks
173 .execute(HookEvent::UserPromptSubmit, prompt_input, &hook_ctx)
174 .await?;
175
176 if !prompt_output.continue_execution {
177 let session_end_input = HookInput::session_end(&*self.session_id);
178 if let Err(e) = self
179 .hooks
180 .execute(HookEvent::SessionEnd, session_end_input, &hook_ctx)
181 .await
182 {
183 warn!(error = %e, "SessionEnd hook failed");
184 }
185 return Err(crate::Error::Permission(
186 prompt_output
187 .stop_reason
188 .unwrap_or_else(|| "Blocked by hook".into()),
189 ));
190 }
191
192 self.state
193 .with_session_mut(|session| {
194 session.add_user_message(&final_prompt);
195 })
196 .await;
197
198 let mut metrics = AgentMetrics::default();
199 let mut final_text = String::new();
200 let mut final_stop_reason = StopReason::EndTurn;
201 let mut dynamic_rules_context = String::new();
202 let mut total_usage = Usage::default();
203
204 let request_builder = RequestBuilder::new(&self.config, Arc::clone(&self.tools));
205 let max_tokens = context_window::for_model(&self.config.model.primary);
206
207 info!(prompt_len = final_prompt.len(), "Starting agent execution");
208
209 loop {
210 metrics.iterations += 1;
211 if metrics.iterations > self.config.execution.max_iterations {
212 warn!(
213 max = self.config.execution.max_iterations,
214 "Max iterations reached"
215 );
216 break;
217 }
218
219 self.check_budget()?;
220
221 debug!(iteration = metrics.iterations, "Starting iteration");
222
223 let cache_messages = self.config.cache.enabled && self.config.cache.message_cache;
224 let messages = self
225 .state
226 .with_session(|session| session.to_api_messages_with_cache(cache_messages))
227 .await;
228
229 let api_start = Instant::now();
230 let request = request_builder.build(messages, &dynamic_rules_context);
231 let response = self.client.send_with_auth_retry(request).await?;
232 let api_duration_ms = api_start.elapsed().as_millis() as u64;
233 metrics.record_api_call_with_timing(api_duration_ms);
234 debug!(api_time_ms = api_duration_ms, "API call completed");
235
236 self.state
237 .with_session_mut(|session| {
238 session.update_usage(&response.usage);
239 })
240 .await;
241
242 total_usage.input_tokens += response.usage.input_tokens;
243 total_usage.output_tokens += response.usage.output_tokens;
244 metrics.add_usage_with_cache(&response.usage);
245 metrics.record_model_usage(&self.config.model.primary, &response.usage);
246
247 if let Some(ref server_usage) = response.usage.server_tool_use {
248 metrics.update_server_tool_use_from_api(server_usage);
249 }
250
251 let cost = self
252 .budget_tracker
253 .record(&self.config.model.primary, &response.usage);
254 metrics.add_cost(cost);
255 if let Some(ref tenant_budget) = self.tenant_budget {
256 tenant_budget.record(&self.config.model.primary, &response.usage);
257 }
258
259 final_text = response.text();
260 final_stop_reason = response.stop_reason.unwrap_or(StopReason::EndTurn);
261
262 self.state
263 .with_session_mut(|session| {
264 session.add_assistant_message(response.content.clone(), Some(response.usage));
265 })
266 .await;
267
268 if !response.wants_tool_use() {
269 debug!("No tool use requested, ending loop");
270 break;
271 }
272
273 let tool_uses = response.tool_uses();
274 let hook_ctx = self.hook_context();
275
276 let mut prepared = Vec::with_capacity(tool_uses.len());
277 let mut blocked = Vec::with_capacity(tool_uses.len());
278
279 for tool_use in &tool_uses {
280 let pre_input = HookInput::pre_tool_use(
281 &*self.session_id,
282 &tool_use.name,
283 tool_use.input.clone(),
284 );
285 let pre_output = self
286 .hooks
287 .execute(HookEvent::PreToolUse, pre_input, &hook_ctx)
288 .await?;
289
290 if !pre_output.continue_execution {
291 debug!(tool = %tool_use.name, "Tool blocked by hook");
292 let reason = pre_output
293 .stop_reason
294 .clone()
295 .unwrap_or_else(|| "Blocked by hook".into());
296 blocked.push(ToolResultBlock::error(&tool_use.id, reason.clone()));
297 metrics.record_permission_denial(
298 PermissionDenial::new(&tool_use.name, &tool_use.id, tool_use.input.clone())
299 .with_reason(reason),
300 );
301 } else {
302 let input = pre_output.updated_input.unwrap_or(tool_use.input.clone());
303 prepared.push((tool_use.id.clone(), tool_use.name.clone(), input));
304 }
305 }
306
307 let tool_futures = prepared.into_iter().map(|(id, name, input)| {
308 let tools = &self.tools;
309 async move {
310 let start = Instant::now();
311 let result = tools.execute(&name, input.clone()).await;
312 let duration_ms = start.elapsed().as_millis() as u64;
313 (id, name, input, result, duration_ms)
314 }
315 });
316
317 let parallel_results: Vec<_> = futures::future::join_all(tool_futures).await;
318
319 let all_non_retryable = !parallel_results.is_empty()
320 && parallel_results
321 .iter()
322 .all(|(_, _, _, result, _)| result.is_non_retryable());
323
324 let mut results = blocked;
325 for (id, name, input, result, duration_ms) in parallel_results {
326 let is_error = result.is_error();
327 debug!(tool = %name, duration_ms, is_error, "Tool execution completed");
328 metrics.record_tool(&id, &name, duration_ms, is_error);
329
330 if let Some(ref inner_usage) = result.inner_usage {
331 self.state
332 .with_session_mut(|session| {
333 session.update_usage(inner_usage);
334 })
335 .await;
336 total_usage.input_tokens += inner_usage.input_tokens;
337 total_usage.output_tokens += inner_usage.output_tokens;
338 metrics.add_usage(inner_usage.input_tokens, inner_usage.output_tokens);
339 let inner_model = result.inner_model.as_deref().unwrap_or("claude-haiku-4-5");
340 metrics.record_model_usage(inner_model, inner_usage);
341
342 let inner_cost = self.budget_tracker.record(inner_model, inner_usage);
343 metrics.add_cost(inner_cost);
344
345 debug!(
346 tool = %name,
347 model = %inner_model,
348 input_tokens = inner_usage.input_tokens,
349 output_tokens = inner_usage.output_tokens,
350 cost_usd = inner_cost,
351 "Accumulated inner usage from tool"
352 );
353 }
354
355 if let Some(file_path) = extract_file_path(&name, &input)
356 && let Some(ref orchestrator) = self.orchestrator
357 {
358 let new_rules = activate_rules_for_file(orchestrator, &file_path).await;
359 if !new_rules.is_empty() {
360 dynamic_rules_context =
361 build_dynamic_rules_context(orchestrator, &file_path).await;
362 debug!(rules = ?new_rules, "Activated rules for file");
363 }
364 }
365
366 if is_error {
367 let failure_input = HookInput::post_tool_use_failure(
368 &*self.session_id,
369 &name,
370 result.error_message(),
371 );
372 if let Err(e) = self
373 .hooks
374 .execute(HookEvent::PostToolUseFailure, failure_input, &hook_ctx)
375 .await
376 {
377 warn!(tool = %name, error = %e, "PostToolUseFailure hook failed");
378 }
379 } else {
380 let post_input =
381 HookInput::post_tool_use(&*self.session_id, &name, result.output.clone());
382 if let Err(e) = self
383 .hooks
384 .execute(HookEvent::PostToolUse, post_input, &hook_ctx)
385 .await
386 {
387 warn!(tool = %name, error = %e, "PostToolUse hook failed");
388 }
389 }
390 results.push(ToolResultBlock::from_tool_result(&id, result));
391 }
392
393 self.state
394 .with_session_mut(|session| {
395 session.add_tool_results(results);
396 })
397 .await;
398
399 if all_non_retryable {
400 warn!("All tool calls failed with non-retryable errors, ending execution");
401 break;
402 }
403
404 let should_compact = self
405 .state
406 .with_session(|session| {
407 self.config.execution.auto_compact
408 && session.should_compact(
409 max_tokens,
410 self.config.execution.compact_threshold,
411 self.config.execution.compact_keep_messages,
412 )
413 })
414 .await;
415
416 if should_compact {
417 self.handle_compaction(&guard, &hook_ctx, &mut metrics)
418 .await;
419 }
420 }
421
422 metrics.execution_time_ms = execution_start.elapsed().as_millis() as u64;
423
424 let stop_input = HookInput::stop(&*self.session_id);
425 if let Err(e) = self
426 .hooks
427 .execute(HookEvent::Stop, stop_input, &hook_ctx)
428 .await
429 {
430 warn!(error = %e, "Stop hook failed");
431 }
432
433 let session_end_input = HookInput::session_end(&*self.session_id);
434 if let Err(e) = self
435 .hooks
436 .execute(HookEvent::SessionEnd, session_end_input, &hook_ctx)
437 .await
438 {
439 warn!(error = %e, "SessionEnd hook failed");
440 }
441
442 info!(
443 iterations = metrics.iterations,
444 tool_calls = metrics.tool_calls,
445 api_calls = metrics.api_calls,
446 total_tokens = metrics.total_tokens(),
447 execution_time_ms = metrics.execution_time_ms,
448 "Agent execution completed"
449 );
450
451 let messages = self
452 .state
453 .with_session(|session| session.to_api_messages())
454 .await;
455
456 drop(guard);
457
458 Ok(AgentResult {
459 text: final_text,
460 usage: total_usage,
461 tool_calls: metrics.tool_calls,
462 iterations: metrics.iterations,
463 stop_reason: final_stop_reason,
464 state: AgentState::Completed,
465 metrics,
466 session_id: self.session_id.to_string(),
467 structured_output: None,
468 messages,
469 uuid: uuid::Uuid::new_v4().to_string(),
470 })
471 }
472
473 pub(crate) fn hook_context(&self) -> HookContext {
474 HookContext::new(&*self.session_id)
475 .with_cwd(self.config.working_dir.clone().unwrap_or_default())
476 .with_env(self.config.security.env.clone())
477 }
478}
479
480pub(crate) fn extract_file_path(tool_name: &str, input: &serde_json::Value) -> Option<String> {
481 match tool_name {
482 "Read" | "Write" | "Edit" => input
483 .get("file_path")
484 .and_then(|v| v.as_str())
485 .map(String::from),
486 "Glob" | "Grep" => input.get("path").and_then(|v| v.as_str()).map(String::from),
487 _ => None,
488 }
489}
490
491pub(crate) async fn activate_rules_for_file(
492 orchestrator: &Arc<RwLock<PromptOrchestrator>>,
493 file_path: &str,
494) -> Vec<String> {
495 let orch = orchestrator.read().await;
496 let path = Path::new(file_path);
497 let rules = orch.rules_engine().find_matching(path);
498 rules.iter().map(|r| r.name.clone()).collect()
499}
500
501pub(crate) async fn build_dynamic_rules_context(
502 orchestrator: &Arc<RwLock<PromptOrchestrator>>,
503 file_path: &str,
504) -> String {
505 let orch = orchestrator.read().await;
506 let path = Path::new(file_path);
507 orch.build_dynamic_context(Some(path)).await
508}
509
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks every `(tool, input, expected)` case against `extract_file_path`.
    fn assert_paths(cases: &[(&str, &serde_json::Value, Option<&str>)]) {
        for &(tool, input, expected) in cases {
            assert_eq!(
                extract_file_path(tool, input),
                expected.map(String::from),
                "tool={tool} input={input}"
            );
        }
    }

    #[test]
    fn test_extract_file_path() {
        let read = serde_json::json!({"file_path": "/src/lib.rs"});
        let glob = serde_json::json!({"path": "/src"});
        let bash = serde_json::json!({"command": "ls"});

        assert_paths(&[
            ("Read", &read, Some("/src/lib.rs")),
            ("Glob", &glob, Some("/src")),
            ("Bash", &bash, None),
        ]);
    }

    #[test]
    fn test_extract_file_path_all_tools() {
        let file_input = serde_json::json!({"file_path": "/test/file.rs"});
        let path_input = serde_json::json!({"path": "/test/dir"});

        assert_paths(&[
            ("Read", &file_input, Some("/test/file.rs")),
            ("Write", &file_input, Some("/test/file.rs")),
            ("Edit", &file_input, Some("/test/file.rs")),
            ("Glob", &path_input, Some("/test/dir")),
            ("Grep", &path_input, Some("/test/dir")),
            ("WebFetch", &file_input, None),
            ("Task", &file_input, None),
        ]);
    }

    #[test]
    fn test_extract_file_path_missing_field() {
        let empty = serde_json::json!({});
        let wrong_field = serde_json::json!({"other": "value"});

        assert_paths(&[
            ("Read", &empty, None),
            ("Glob", &empty, None),
            ("Read", &wrong_field, None),
            ("Glob", &wrong_field, None),
        ]);
    }

    #[test]
    fn test_extract_file_path_non_string() {
        let numeric = serde_json::json!({"file_path": 123});
        let null_path = serde_json::json!({"path": null});

        assert_paths(&[("Read", &numeric, None), ("Glob", &null_path, None)]);
    }
}