1use std::path::Path;
4use std::sync::Arc;
5use std::time::Instant;
6
7use tokio::sync::RwLock;
8use tracing::{debug, error, info, instrument, warn};
9
10use super::events::AgentResult;
11use super::request::RequestBuilder;
12use super::state_formatter::collect_compaction_state;
13use super::{AgentMetrics, AgentState, ConversationContext};
14use crate::context::PromptOrchestrator;
15use crate::hooks::{HookContext, HookEvent, HookInput};
16use crate::types::{
17 CompactResult, ContentBlock, Message, PermissionDenial, StopReason, ToolResultBlock,
18};
19
20use super::executor::Agent;
21use crate::types::ApiResponse;
22
23impl Agent {
24 async fn send_with_retry(
25 &self,
26 request_builder: &RequestBuilder,
27 messages: Vec<Message>,
28 dynamic_rules: &str,
29 ) -> crate::Result<ApiResponse> {
30 const MAX_AUTH_RETRIES: u8 = 2;
31 let mut auth_attempts = 0u8;
32
33 loop {
34 let request = request_builder.build(messages.clone(), dynamic_rules);
35 match self.client.send(request).await {
36 Ok(resp) => return Ok(resp),
37 Err(e) if e.is_unauthorized() && auth_attempts < MAX_AUTH_RETRIES => {
38 auth_attempts += 1;
39 debug!(
40 attempt = auth_attempts,
41 "Received 401, attempting credential refresh"
42 );
43 self.client.refresh_credentials().await?;
44 }
45 Err(e) => {
46 error!(error = %e, "API request failed");
47 return Err(e);
48 }
49 }
50 }
51 }
52
53 async fn handle_compaction(
54 &self,
55 context: &mut ConversationContext,
56 hook_ctx: &HookContext,
57 metrics: &mut AgentMetrics,
58 ) {
59 let pre_compact_input = HookInput::pre_compact(&self.session_id);
60 if let Err(e) = self
61 .hooks
62 .execute(HookEvent::PreCompact, pre_compact_input, hook_ctx)
63 .await
64 {
65 warn!(error = %e, "PreCompact hook failed");
66 }
67
68 debug!("Compacting context");
69 match context
70 .compact(&self.client, self.config.execution.compact_keep_messages)
71 .await
72 {
73 Ok(CompactResult::Compacted { saved_tokens, .. }) => {
74 info!(saved_tokens, "Context compacted");
75 metrics.record_compaction();
76
77 let state_sections = collect_compaction_state(&self.tools).await;
78 if !state_sections.is_empty() {
79 context.push(Message::user(format!(
80 "<system-reminder>\n# State preserved after compaction\n\n{}\n</system-reminder>",
81 state_sections.join("\n\n")
82 )));
83 }
84 }
85 Ok(CompactResult::NotNeeded | CompactResult::Skipped { .. }) => {
86 debug!("Compaction skipped or not needed");
87 }
88 Err(e) => {
89 warn!(error = %e, "Context compaction failed");
90 }
91 }
92 }
93
94 fn check_budget(&self) -> crate::Result<()> {
95 if self.budget_tracker.should_stop() {
96 let status = self.budget_tracker.check();
97 warn!(used = ?status.used(), "Budget exceeded, stopping execution");
98 return Err(crate::Error::BudgetExceeded {
99 used: status.used(),
100 limit: self.config.budget.max_cost_usd.unwrap_or(0.0),
101 });
102 }
103 if let Some(ref tenant_budget) = self.tenant_budget
104 && tenant_budget.should_stop()
105 {
106 warn!(
107 tenant_id = %tenant_budget.tenant_id,
108 used = tenant_budget.used_cost_usd(),
109 "Tenant budget exceeded, stopping execution"
110 );
111 return Err(crate::Error::BudgetExceeded {
112 used: tenant_budget.used_cost_usd(),
113 limit: tenant_budget.max_cost_usd(),
114 });
115 }
116 Ok(())
117 }
118
119 pub async fn execute(&self, prompt: &str) -> crate::Result<AgentResult> {
120 let timeout = self
121 .config
122 .execution
123 .timeout
124 .unwrap_or(std::time::Duration::from_secs(600));
125
126 tokio::time::timeout(timeout, self.execute_inner(prompt))
127 .await
128 .map_err(|_| crate::Error::Timeout(timeout))?
129 }
130
131 pub async fn execute_with_messages(
132 &self,
133 previous_messages: Vec<Message>,
134 prompt: &str,
135 ) -> crate::Result<AgentResult> {
136 let context_summary = previous_messages
137 .iter()
138 .filter_map(|m| {
139 m.content
140 .iter()
141 .filter_map(|b| match b {
142 ContentBlock::Text { text, .. } => Some(text.as_str()),
143 _ => None,
144 })
145 .next()
146 })
147 .collect::<Vec<_>>()
148 .join("\n---\n");
149
150 let enriched_prompt = if context_summary.is_empty() {
151 prompt.to_string()
152 } else {
153 format!(
154 "Previous conversation context:\n{}\n\nContinue with: {}",
155 context_summary, prompt
156 )
157 };
158
159 self.execute(&enriched_prompt).await
160 }
161
162 #[instrument(skip(self, prompt), fields(session_id = %self.session_id))]
163 async fn execute_inner(&self, prompt: &str) -> crate::Result<AgentResult> {
164 let execution_start = Instant::now();
165 let hook_ctx = self.hook_context();
166
167 let session_start_input = HookInput::session_start(&self.session_id);
168 if let Err(e) = self
169 .hooks
170 .execute(HookEvent::SessionStart, session_start_input, &hook_ctx)
171 .await
172 {
173 warn!(error = %e, "SessionStart hook failed");
174 }
175
176 let prompt_input = HookInput::user_prompt_submit(&self.session_id, prompt);
177 let prompt_output = self
178 .hooks
179 .execute(HookEvent::UserPromptSubmit, prompt_input, &hook_ctx)
180 .await?;
181
182 if !prompt_output.continue_execution {
183 let session_end_input = HookInput::session_end(&self.session_id);
184 if let Err(e) = self
185 .hooks
186 .execute(HookEvent::SessionEnd, session_end_input, &hook_ctx)
187 .await
188 {
189 warn!(error = %e, "SessionEnd hook failed");
190 }
191 return Err(crate::Error::Permission(
192 prompt_output
193 .stop_reason
194 .unwrap_or_else(|| "Blocked by hook".into()),
195 ));
196 }
197
198 let mut context = ConversationContext::new();
199 context.push(Message::user(prompt));
200
201 let mut metrics = AgentMetrics::default();
202 let mut final_text = String::new();
203 let mut final_stop_reason = StopReason::EndTurn;
204 let mut dynamic_rules_context = String::new();
205
206 let request_builder = RequestBuilder::new(&self.config, Arc::clone(&self.tools));
207
208 info!(prompt_len = prompt.len(), "Starting agent execution");
209
210 loop {
211 metrics.iterations += 1;
212 if metrics.iterations > self.config.execution.max_iterations {
213 warn!(
214 max = self.config.execution.max_iterations,
215 "Max iterations reached"
216 );
217 break;
218 }
219
220 self.check_budget()?;
221
222 debug!(iteration = metrics.iterations, "Starting iteration");
223
224 let api_start = Instant::now();
225 let response = self
226 .send_with_retry(
227 &request_builder,
228 context.messages().to_vec(),
229 &dynamic_rules_context,
230 )
231 .await?;
232 let api_duration_ms = api_start.elapsed().as_millis() as u64;
233 metrics.record_api_call_with_timing(api_duration_ms);
234 debug!(api_time_ms = api_duration_ms, "API call completed");
235
236 context.update_usage(response.usage);
237 metrics.add_usage(response.usage.input_tokens, response.usage.output_tokens);
238 metrics.record_model_usage(&self.config.model.primary, &response.usage);
239
240 if let Some(ref server_usage) = response.usage.server_tool_use {
241 metrics.update_server_tool_use_from_api(server_usage);
242 }
243
244 let cost = self
245 .budget_tracker
246 .record(&self.config.model.primary, &response.usage);
247 metrics.add_cost(cost);
248 if let Some(ref tenant_budget) = self.tenant_budget {
249 tenant_budget.record(&self.config.model.primary, &response.usage);
250 }
251
252 if let Some(ref orchestrator) = self.orchestrator {
253 let mut orch = orchestrator.write().await;
254 orch.update_usage(&crate::types::TokenUsage {
255 input_tokens: response.usage.input_tokens as u64,
256 output_tokens: response.usage.output_tokens as u64,
257 cache_read_input_tokens: response.usage.cache_read_input_tokens.unwrap_or(0)
258 as u64,
259 cache_creation_input_tokens: response
260 .usage
261 .cache_creation_input_tokens
262 .unwrap_or(0) as u64,
263 });
264 }
265
266 final_text = response.text();
267 final_stop_reason = response.stop_reason.unwrap_or(StopReason::EndTurn);
268
269 context.push(Message {
270 role: crate::types::Role::Assistant,
271 content: response.content.clone(),
272 });
273
274 if !response.wants_tool_use() {
275 debug!("No tool use requested, ending loop");
276 break;
277 }
278
279 let tool_uses = response.tool_uses();
280 let hook_ctx = self.hook_context();
281
282 let mut prepared: Vec<(String, String, serde_json::Value)> = Vec::new();
283 let mut blocked: Vec<ToolResultBlock> = Vec::new();
284
285 for tool_use in &tool_uses {
286 let pre_input = HookInput::pre_tool_use(
287 &self.session_id,
288 &tool_use.name,
289 tool_use.input.clone(),
290 );
291 let pre_output = self
292 .hooks
293 .execute(HookEvent::PreToolUse, pre_input, &hook_ctx)
294 .await?;
295
296 if !pre_output.continue_execution {
297 debug!(tool = %tool_use.name, "Tool blocked by hook");
298 let reason = pre_output
299 .stop_reason
300 .clone()
301 .unwrap_or_else(|| "Blocked by hook".into());
302 blocked.push(ToolResultBlock::error(&tool_use.id, reason.clone()));
303 metrics.record_permission_denial(
304 PermissionDenial::new(&tool_use.name, &tool_use.id, tool_use.input.clone())
305 .with_reason(reason),
306 );
307 } else {
308 let input = pre_output.updated_input.unwrap_or(tool_use.input.clone());
309 prepared.push((tool_use.id.clone(), tool_use.name.clone(), input));
310 }
311 }
312
313 let tool_futures = prepared.into_iter().map(|(id, name, input)| {
314 let tools = &self.tools;
315 async move {
316 let start = Instant::now();
317 let result = tools.execute(&name, input.clone()).await;
318 let duration_ms = start.elapsed().as_millis() as u64;
319 (id, name, input, result, duration_ms)
320 }
321 });
322
323 let parallel_results: Vec<_> = futures::future::join_all(tool_futures).await;
324
325 let mut results = blocked;
326 for (id, name, input, result, duration_ms) in parallel_results {
327 let is_error = result.is_error();
328 debug!(tool = %name, duration_ms, is_error, "Tool execution completed");
329 metrics.record_tool(&name, duration_ms, is_error);
330
331 if let Some(ref inner_usage) = result.inner_usage {
332 context.update_usage(*inner_usage);
333 metrics.add_usage(inner_usage.input_tokens, inner_usage.output_tokens);
334 let inner_model = result.inner_model.as_deref().unwrap_or("claude-haiku-4-5");
335 metrics.record_model_usage(inner_model, inner_usage);
336
337 let inner_cost = self.budget_tracker.record(inner_model, inner_usage);
338 metrics.add_cost(inner_cost);
339
340 debug!(
341 tool = %name,
342 model = %inner_model,
343 input_tokens = inner_usage.input_tokens,
344 output_tokens = inner_usage.output_tokens,
345 cost_usd = inner_cost,
346 "Accumulated inner usage from tool"
347 );
348 }
349
350 if let Some(file_path) = extract_file_path(&name, &input)
351 && let Some(ref orchestrator) = self.orchestrator
352 {
353 let new_rules = activate_rules_for_file(orchestrator, &file_path).await;
354 if !new_rules.is_empty() {
355 dynamic_rules_context =
356 build_dynamic_rules_context(orchestrator, &file_path).await;
357 debug!(rules = ?new_rules, "Activated rules for file");
358 }
359 }
360
361 if is_error {
362 let failure_input = HookInput::post_tool_use_failure(
363 &self.session_id,
364 &name,
365 result.error_message(),
366 );
367 if let Err(e) = self
368 .hooks
369 .execute(HookEvent::PostToolUseFailure, failure_input, &hook_ctx)
370 .await
371 {
372 warn!(tool = %name, error = %e, "PostToolUseFailure hook failed");
373 }
374 } else {
375 let post_input =
376 HookInput::post_tool_use(&self.session_id, &name, result.output.clone());
377 if let Err(e) = self
378 .hooks
379 .execute(HookEvent::PostToolUse, post_input, &hook_ctx)
380 .await
381 {
382 warn!(tool = %name, error = %e, "PostToolUse hook failed");
383 }
384 }
385 results.push(ToolResultBlock::from_tool_result(&id, result));
386 }
387
388 context.push(Message::tool_results(results));
389
390 if self.config.execution.auto_compact
391 && context.should_compact(
392 200_000,
393 self.config.execution.compact_threshold,
394 self.config.execution.compact_keep_messages,
395 )
396 {
397 self.handle_compaction(&mut context, &hook_ctx, &mut metrics)
398 .await;
399 }
400 }
401
402 metrics.execution_time_ms = execution_start.elapsed().as_millis() as u64;
403
404 let stop_input = HookInput::stop(&self.session_id);
405 if let Err(e) = self
406 .hooks
407 .execute(HookEvent::Stop, stop_input, &hook_ctx)
408 .await
409 {
410 warn!(error = %e, "Stop hook failed");
411 }
412
413 let session_end_input = HookInput::session_end(&self.session_id);
414 if let Err(e) = self
415 .hooks
416 .execute(HookEvent::SessionEnd, session_end_input, &hook_ctx)
417 .await
418 {
419 warn!(error = %e, "SessionEnd hook failed");
420 }
421
422 info!(
423 iterations = metrics.iterations,
424 tool_calls = metrics.tool_calls,
425 api_calls = metrics.api_calls,
426 total_tokens = metrics.total_tokens(),
427 execution_time_ms = metrics.execution_time_ms,
428 "Agent execution completed"
429 );
430
431 Ok(AgentResult {
432 text: final_text,
433 usage: *context.total_usage(),
434 tool_calls: metrics.tool_calls,
435 iterations: metrics.iterations,
436 stop_reason: final_stop_reason,
437 state: AgentState::Completed,
438 metrics,
439 session_id: self.session_id.clone(),
440 structured_output: None,
441 messages: context.messages().to_vec(),
442 uuid: uuid::Uuid::new_v4().to_string(),
443 })
444 }
445
446 pub(crate) fn hook_context(&self) -> HookContext {
447 HookContext::new(&self.session_id)
448 .with_cwd(self.config.working_dir.clone().unwrap_or_default())
449 .with_env(self.config.security.env.clone())
450 }
451}
452
453pub(crate) fn extract_file_path(tool_name: &str, input: &serde_json::Value) -> Option<String> {
454 match tool_name {
455 "Read" | "Write" | "Edit" => input
456 .get("file_path")
457 .and_then(|v| v.as_str())
458 .map(String::from),
459 "Glob" | "Grep" => input.get("path").and_then(|v| v.as_str()).map(String::from),
460 _ => None,
461 }
462}
463
464pub(crate) async fn activate_rules_for_file(
465 orchestrator: &Arc<RwLock<PromptOrchestrator>>,
466 file_path: &str,
467) -> Vec<String> {
468 let orch = orchestrator.read().await;
469 let path = Path::new(file_path);
470 let rules = orch.rules_engine().find_matching(path);
471 rules.iter().map(|r| r.name.clone()).collect()
472}
473
474pub(crate) async fn build_dynamic_rules_context(
475 orchestrator: &Arc<RwLock<PromptOrchestrator>>,
476 file_path: &str,
477) -> String {
478 let orch = orchestrator.read().await;
479 let path = Path::new(file_path);
480 orch.build_dynamic_context(Some(path)).await
481}
482
#[cfg(test)]
mod tests {
    use super::*;

    /// Tools whose path lives in the `file_path` field.
    const FILE_PATH_TOOLS: [&str; 3] = ["Read", "Write", "Edit"];
    /// Tools whose path lives in the `path` field.
    const PATH_TOOLS: [&str; 2] = ["Glob", "Grep"];

    /// Basic happy paths plus a tool that carries no path at all.
    #[test]
    fn test_extract_file_path() {
        let read_input = serde_json::json!({"file_path": "/src/lib.rs"});
        assert_eq!(
            extract_file_path("Read", &read_input).as_deref(),
            Some("/src/lib.rs")
        );

        let glob_input = serde_json::json!({"path": "/src"});
        assert_eq!(
            extract_file_path("Glob", &glob_input).as_deref(),
            Some("/src")
        );

        let bash_input = serde_json::json!({"command": "ls"});
        assert!(extract_file_path("Bash", &bash_input).is_none());
    }

    /// Every supported tool reads from its own field; unknown tools yield
    /// nothing even when a `file_path` field is present.
    #[test]
    fn test_extract_file_path_all_tools() {
        let file_input = serde_json::json!({"file_path": "/test/file.rs"});
        let path_input = serde_json::json!({"path": "/test/dir"});

        for tool in FILE_PATH_TOOLS {
            assert_eq!(
                extract_file_path(tool, &file_input).as_deref(),
                Some("/test/file.rs"),
                "tool: {tool}"
            );
        }

        for tool in PATH_TOOLS {
            assert_eq!(
                extract_file_path(tool, &path_input).as_deref(),
                Some("/test/dir"),
                "tool: {tool}"
            );
        }

        for tool in ["WebFetch", "Task"] {
            assert!(
                extract_file_path(tool, &file_input).is_none(),
                "tool: {tool}"
            );
        }
    }

    /// An absent expected field results in `None`.
    #[test]
    fn test_extract_file_path_missing_field() {
        let inputs = [
            serde_json::json!({}),
            serde_json::json!({"other": "value"}),
        ];

        for input in &inputs {
            for tool in ["Read", "Glob"] {
                assert_eq!(extract_file_path(tool, input), None, "tool: {tool}");
            }
        }
    }

    /// A field of the right name but a non-string type results in `None`.
    #[test]
    fn test_extract_file_path_non_string() {
        let numeric = serde_json::json!({"file_path": 123});
        assert!(extract_file_path("Read", &numeric).is_none());

        let null_path = serde_json::json!({"path": null});
        assert!(extract_file_path("Glob", &null_path).is_none());
    }
}