1use std::path::Path;
4use std::sync::Arc;
5use std::time::Instant;
6
7use tokio::sync::RwLock;
8use tracing::{debug, info, instrument, warn};
9
10use super::common::BudgetContext;
11use super::events::AgentResult;
12use super::executor::Agent;
13use super::request::RequestBuilder;
14use super::state_formatter::collect_compaction_state;
15use super::{AgentMetrics, AgentState};
16use crate::context::PromptOrchestrator;
17use crate::hooks::{HookContext, HookEvent, HookInput};
18use crate::session::ExecutionGuard;
19use crate::types::{
20 CompactResult, ContentBlock, Message, PermissionDenial, StopReason, ToolResultBlock, Usage,
21 context_window,
22};
23
24impl Agent {
25 async fn handle_compaction<'a>(
26 &self,
27 _guard: &ExecutionGuard<'a>,
28 hook_ctx: &HookContext,
29 metrics: &mut AgentMetrics,
30 ) {
31 let pre_compact_input = HookInput::pre_compact(&*self.session_id);
32 if let Err(e) = self
33 .hooks
34 .execute(HookEvent::PreCompact, pre_compact_input, hook_ctx)
35 .await
36 {
37 warn!(error = %e, "PreCompact hook failed");
38 }
39
40 debug!("Compacting session context");
41 let compact_result = self
42 .state
43 .compact(&self.client, self.config.execution.compact_keep_messages)
44 .await;
45
46 match compact_result {
47 Ok(CompactResult::Compacted { saved_tokens, .. }) => {
48 info!(saved_tokens, "Session context compacted");
49 metrics.record_compaction();
50
51 let state_sections = collect_compaction_state(&self.tools).await;
52 if !state_sections.is_empty() {
53 self.state
54 .with_session_mut(|session| {
55 session.add_user_message(format!(
56 "<system-reminder>\n# State preserved after compaction\n\n{}\n</system-reminder>",
57 state_sections.join("\n\n")
58 ));
59 })
60 .await;
61 }
62 }
63 Ok(CompactResult::NotNeeded | CompactResult::Skipped { .. }) => {
64 debug!("Compaction skipped or not needed");
65 }
66 Err(e) => {
67 warn!(error = %e, "Session compaction failed");
68 }
69 }
70 }
71
72 fn check_budget(&self) -> crate::Result<()> {
73 BudgetContext {
74 tracker: &self.budget_tracker,
75 tenant: self.tenant_budget.as_deref(),
76 config: &self.config.budget,
77 }
78 .check()
79 }
80
81 pub async fn execute(&self, prompt: &str) -> crate::Result<AgentResult> {
82 let timeout = self
83 .config
84 .execution
85 .timeout
86 .unwrap_or(std::time::Duration::from_secs(600));
87
88 if self.state.is_executing() {
89 self.state
90 .enqueue(prompt)
91 .await
92 .map_err(|e| crate::Error::Session(format!("Queue full: {}", e)))?;
93 return self.wait_for_execution(timeout).await;
94 }
95
96 tokio::time::timeout(timeout, self.execute_inner(prompt))
97 .await
98 .map_err(|_| crate::Error::Timeout(timeout))?
99 }
100
101 async fn wait_for_execution(&self, timeout: std::time::Duration) -> crate::Result<AgentResult> {
102 let start = Instant::now();
103 loop {
104 if start.elapsed() > timeout {
105 return Err(crate::Error::Timeout(timeout));
106 }
107
108 if !self.state.is_executing()
109 && let Some(merged) = self.state.dequeue_or_merge().await
110 {
111 return self.execute_inner(&merged.content).await;
112 }
113
114 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
115 }
116 }
117
118 pub async fn execute_with_messages(
119 &self,
120 previous_messages: Vec<Message>,
121 prompt: &str,
122 ) -> crate::Result<AgentResult> {
123 let context_summary = previous_messages
124 .iter()
125 .filter_map(|m| {
126 m.content
127 .iter()
128 .filter_map(|b| match b {
129 ContentBlock::Text { text, .. } => Some(text.as_str()),
130 _ => None,
131 })
132 .next()
133 })
134 .collect::<Vec<_>>()
135 .join("\n---\n");
136
137 let enriched_prompt = if context_summary.is_empty() {
138 prompt.to_string()
139 } else {
140 format!(
141 "Previous conversation context:\n{}\n\nContinue with: {}",
142 context_summary, prompt
143 )
144 };
145
146 self.execute(&enriched_prompt).await
147 }
148
149 #[instrument(skip(self, prompt), fields(session_id = %self.session_id))]
150 async fn execute_inner(&self, prompt: &str) -> crate::Result<AgentResult> {
151 let guard = self.state.acquire_execution().await;
152 let execution_start = Instant::now();
153 let hook_ctx = self.hook_context();
154
155 let session_start_input = HookInput::session_start(&*self.session_id);
156 if let Err(e) = self
157 .hooks
158 .execute(HookEvent::SessionStart, session_start_input, &hook_ctx)
159 .await
160 {
161 warn!(error = %e, "SessionStart hook failed");
162 }
163
164 let final_prompt = if let Some(merged) = self.state.dequeue_or_merge().await {
165 format!("{}\n{}", prompt, merged.content)
166 } else {
167 prompt.to_string()
168 };
169
170 let prompt_input = HookInput::user_prompt_submit(&*self.session_id, &final_prompt);
171 let prompt_output = self
172 .hooks
173 .execute(HookEvent::UserPromptSubmit, prompt_input, &hook_ctx)
174 .await?;
175
176 if !prompt_output.continue_execution {
177 let session_end_input = HookInput::session_end(&*self.session_id);
178 if let Err(e) = self
179 .hooks
180 .execute(HookEvent::SessionEnd, session_end_input, &hook_ctx)
181 .await
182 {
183 warn!(error = %e, "SessionEnd hook failed");
184 }
185 return Err(crate::Error::Permission(
186 prompt_output
187 .stop_reason
188 .unwrap_or_else(|| "Blocked by hook".into()),
189 ));
190 }
191
192 self.state
193 .with_session_mut(|session| {
194 session.add_user_message(&final_prompt);
195 })
196 .await;
197
198 let mut metrics = AgentMetrics::default();
199 let mut final_text = String::new();
200 let mut final_stop_reason = StopReason::EndTurn;
201 let mut dynamic_rules_context = String::new();
202 let mut total_usage = Usage::default();
203
204 let request_builder = RequestBuilder::new(&self.config, Arc::clone(&self.tools));
205 let max_tokens = context_window::for_model(&self.config.model.primary);
206
207 info!(prompt_len = final_prompt.len(), "Starting agent execution");
208
209 loop {
210 metrics.iterations += 1;
211 if metrics.iterations > self.config.execution.max_iterations {
212 warn!(
213 max = self.config.execution.max_iterations,
214 "Max iterations reached"
215 );
216 break;
217 }
218
219 self.check_budget()?;
220
221 debug!(iteration = metrics.iterations, "Starting iteration");
222
223 let messages = self
224 .state
225 .with_session(|session| {
226 session.to_api_messages_with_cache(self.config.cache.message_ttl_option())
227 })
228 .await;
229
230 let api_start = Instant::now();
231 let request = request_builder.build(messages, &dynamic_rules_context);
232 let response = self.client.send_with_auth_retry(request).await?;
233 let api_duration_ms = api_start.elapsed().as_millis() as u64;
234 metrics.record_api_call_with_timing(api_duration_ms);
235 debug!(api_time_ms = api_duration_ms, "API call completed");
236
237 self.state
238 .with_session_mut(|session| {
239 session.update_usage(&response.usage);
240 })
241 .await;
242
243 total_usage.input_tokens += response.usage.input_tokens;
244 total_usage.output_tokens += response.usage.output_tokens;
245 metrics.add_usage_with_cache(&response.usage);
246 metrics.record_model_usage(&self.config.model.primary, &response.usage);
247
248 if let Some(ref server_usage) = response.usage.server_tool_use {
249 metrics.update_server_tool_use_from_api(server_usage);
250 }
251
252 let cost = self
253 .budget_tracker
254 .record(&self.config.model.primary, &response.usage);
255 metrics.add_cost(cost);
256 if let Some(ref tenant_budget) = self.tenant_budget {
257 tenant_budget.record(&self.config.model.primary, &response.usage);
258 }
259
260 final_text = response.text();
261 final_stop_reason = response.stop_reason.unwrap_or(StopReason::EndTurn);
262
263 self.state
264 .with_session_mut(|session| {
265 session.add_assistant_message(response.content.clone(), Some(response.usage));
266 })
267 .await;
268
269 if !response.wants_tool_use() {
270 debug!("No tool use requested, ending loop");
271 break;
272 }
273
274 let tool_uses = response.tool_uses();
275 let hook_ctx = self.hook_context();
276
277 let mut prepared = Vec::with_capacity(tool_uses.len());
278 let mut blocked = Vec::with_capacity(tool_uses.len());
279
280 for tool_use in &tool_uses {
281 let pre_input = HookInput::pre_tool_use(
282 &*self.session_id,
283 &tool_use.name,
284 tool_use.input.clone(),
285 );
286 let pre_output = self
287 .hooks
288 .execute(HookEvent::PreToolUse, pre_input, &hook_ctx)
289 .await?;
290
291 if !pre_output.continue_execution {
292 debug!(tool = %tool_use.name, "Tool blocked by hook");
293 let reason = pre_output
294 .stop_reason
295 .clone()
296 .unwrap_or_else(|| "Blocked by hook".into());
297 blocked.push(ToolResultBlock::error(&tool_use.id, reason.clone()));
298 metrics.record_permission_denial(
299 PermissionDenial::new(&tool_use.name, &tool_use.id, tool_use.input.clone())
300 .with_reason(reason),
301 );
302 } else {
303 let input = pre_output.updated_input.unwrap_or(tool_use.input.clone());
304 prepared.push((tool_use.id.clone(), tool_use.name.clone(), input));
305 }
306 }
307
308 let tool_futures = prepared.into_iter().map(|(id, name, input)| {
309 let tools = &self.tools;
310 async move {
311 let start = Instant::now();
312 let result = tools.execute(&name, input.clone()).await;
313 let duration_ms = start.elapsed().as_millis() as u64;
314 (id, name, input, result, duration_ms)
315 }
316 });
317
318 let parallel_results: Vec<_> = futures::future::join_all(tool_futures).await;
319
320 let all_non_retryable = !parallel_results.is_empty()
321 && parallel_results
322 .iter()
323 .all(|(_, _, _, result, _)| result.is_non_retryable());
324
325 let mut results = blocked;
326 for (id, name, input, result, duration_ms) in parallel_results {
327 let is_error = result.is_error();
328 debug!(tool = %name, duration_ms, is_error, "Tool execution completed");
329 metrics.record_tool(&id, &name, duration_ms, is_error);
330
331 if let Some(ref inner_usage) = result.inner_usage {
332 self.state
333 .with_session_mut(|session| {
334 session.update_usage(inner_usage);
335 })
336 .await;
337 total_usage.input_tokens += inner_usage.input_tokens;
338 total_usage.output_tokens += inner_usage.output_tokens;
339 metrics.add_usage(inner_usage.input_tokens, inner_usage.output_tokens);
340 let inner_model = result.inner_model.as_deref().unwrap_or("claude-haiku-4-5");
341 metrics.record_model_usage(inner_model, inner_usage);
342
343 let inner_cost = self.budget_tracker.record(inner_model, inner_usage);
344 metrics.add_cost(inner_cost);
345
346 debug!(
347 tool = %name,
348 model = %inner_model,
349 input_tokens = inner_usage.input_tokens,
350 output_tokens = inner_usage.output_tokens,
351 cost_usd = inner_cost,
352 "Accumulated inner usage from tool"
353 );
354 }
355
356 if let Some(file_path) = extract_file_path(&name, &input)
357 && let Some(ref orchestrator) = self.orchestrator
358 {
359 let new_rules = activate_rules_for_file(orchestrator, &file_path).await;
360 if !new_rules.is_empty() {
361 dynamic_rules_context =
362 build_dynamic_rules_context(orchestrator, &file_path).await;
363 debug!(rules = ?new_rules, "Activated rules for file");
364 }
365 }
366
367 if is_error {
368 let failure_input = HookInput::post_tool_use_failure(
369 &*self.session_id,
370 &name,
371 result.error_message(),
372 );
373 if let Err(e) = self
374 .hooks
375 .execute(HookEvent::PostToolUseFailure, failure_input, &hook_ctx)
376 .await
377 {
378 warn!(tool = %name, error = %e, "PostToolUseFailure hook failed");
379 }
380 } else {
381 let post_input =
382 HookInput::post_tool_use(&*self.session_id, &name, result.output.clone());
383 if let Err(e) = self
384 .hooks
385 .execute(HookEvent::PostToolUse, post_input, &hook_ctx)
386 .await
387 {
388 warn!(tool = %name, error = %e, "PostToolUse hook failed");
389 }
390 }
391 results.push(ToolResultBlock::from_tool_result(&id, result));
392 }
393
394 self.state
395 .with_session_mut(|session| {
396 session.add_tool_results(results);
397 })
398 .await;
399
400 if all_non_retryable {
401 warn!("All tool calls failed with non-retryable errors, ending execution");
402 break;
403 }
404
405 let should_compact = self
406 .state
407 .with_session(|session| {
408 self.config.execution.auto_compact
409 && session.should_compact(
410 max_tokens,
411 self.config.execution.compact_threshold,
412 self.config.execution.compact_keep_messages,
413 )
414 })
415 .await;
416
417 if should_compact {
418 self.handle_compaction(&guard, &hook_ctx, &mut metrics)
419 .await;
420 }
421 }
422
423 metrics.execution_time_ms = execution_start.elapsed().as_millis() as u64;
424
425 let stop_input = HookInput::stop(&*self.session_id);
426 if let Err(e) = self
427 .hooks
428 .execute(HookEvent::Stop, stop_input, &hook_ctx)
429 .await
430 {
431 warn!(error = %e, "Stop hook failed");
432 }
433
434 let session_end_input = HookInput::session_end(&*self.session_id);
435 if let Err(e) = self
436 .hooks
437 .execute(HookEvent::SessionEnd, session_end_input, &hook_ctx)
438 .await
439 {
440 warn!(error = %e, "SessionEnd hook failed");
441 }
442
443 info!(
444 iterations = metrics.iterations,
445 tool_calls = metrics.tool_calls,
446 api_calls = metrics.api_calls,
447 total_tokens = metrics.total_tokens(),
448 execution_time_ms = metrics.execution_time_ms,
449 "Agent execution completed"
450 );
451
452 let messages = self
453 .state
454 .with_session(|session| session.to_api_messages())
455 .await;
456
457 drop(guard);
458
459 Ok(AgentResult {
460 text: final_text,
461 usage: total_usage,
462 tool_calls: metrics.tool_calls,
463 iterations: metrics.iterations,
464 stop_reason: final_stop_reason,
465 state: AgentState::Completed,
466 metrics,
467 session_id: self.session_id.to_string(),
468 structured_output: None,
469 messages,
470 uuid: uuid::Uuid::new_v4().to_string(),
471 })
472 }
473
474 pub(crate) fn hook_context(&self) -> HookContext {
475 HookContext::new(&*self.session_id)
476 .with_cwd(self.config.working_dir.clone().unwrap_or_default())
477 .with_env(self.config.security.env.clone())
478 }
479}
480
481pub(crate) fn extract_file_path(tool_name: &str, input: &serde_json::Value) -> Option<String> {
482 match tool_name {
483 "Read" | "Write" | "Edit" => input
484 .get("file_path")
485 .and_then(|v| v.as_str())
486 .map(String::from),
487 "Glob" | "Grep" => input.get("path").and_then(|v| v.as_str()).map(String::from),
488 _ => None,
489 }
490}
491
492pub(crate) async fn activate_rules_for_file(
493 orchestrator: &Arc<RwLock<PromptOrchestrator>>,
494 file_path: &str,
495) -> Vec<String> {
496 let orch = orchestrator.read().await;
497 let path = Path::new(file_path);
498 let rules = orch.rules_engine().find_matching(path);
499 rules.iter().map(|r| r.name.clone()).collect()
500}
501
502pub(crate) async fn build_dynamic_rules_context(
503 orchestrator: &Arc<RwLock<PromptOrchestrator>>,
504 file_path: &str,
505) -> String {
506 let orch = orchestrator.read().await;
507 let path = Path::new(file_path);
508 orch.build_dynamic_context(Some(path)).await
509}
510
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `extract_file_path(tool, input)` returns `expected`, reporting the
    /// tool and input on failure.
    fn assert_path(tool: &str, input: &serde_json::Value, expected: Option<&str>) {
        assert_eq!(
            extract_file_path(tool, input),
            expected.map(String::from),
            "tool={tool} input={input}"
        );
    }

    #[test]
    fn test_extract_file_path() {
        assert_path("Read", &serde_json::json!({"file_path": "/src/lib.rs"}), Some("/src/lib.rs"));
        assert_path("Glob", &serde_json::json!({"path": "/src"}), Some("/src"));
        assert_path("Bash", &serde_json::json!({"command": "ls"}), None);
    }

    #[test]
    fn test_extract_file_path_all_tools() {
        let file_input = serde_json::json!({"file_path": "/test/file.rs"});
        let path_input = serde_json::json!({"path": "/test/dir"});

        for tool in ["Read", "Write", "Edit"] {
            assert_path(tool, &file_input, Some("/test/file.rs"));
        }
        for tool in ["Glob", "Grep"] {
            assert_path(tool, &path_input, Some("/test/dir"));
        }
        // Tools without a path argument never yield a path, even if the key exists.
        for tool in ["WebFetch", "Task"] {
            assert_path(tool, &file_input, None);
        }
    }

    #[test]
    fn test_extract_file_path_missing_field() {
        let empty = serde_json::json!({});
        let wrong_field = serde_json::json!({"other": "value"});
        for input in [&empty, &wrong_field] {
            assert_path("Read", input, None);
            assert_path("Glob", input, None);
        }
    }

    #[test]
    fn test_extract_file_path_non_string() {
        assert_path("Read", &serde_json::json!({"file_path": 123}), None);
        assert_path("Glob", &serde_json::json!({"path": null}), None);
    }
}