claude_agent/agent/
execution.rs

1//! Agent execution logic.
2
3use std::path::Path;
4use std::sync::Arc;
5use std::time::Instant;
6
7use tokio::sync::RwLock;
8use tracing::{debug, error, info, instrument, warn};
9
10use super::events::AgentResult;
11use super::request::RequestBuilder;
12use super::state_formatter::collect_compaction_state;
13use super::{AgentMetrics, AgentState, ConversationContext};
14use crate::context::PromptOrchestrator;
15use crate::hooks::{HookContext, HookEvent, HookInput};
16use crate::types::{
17    CompactResult, ContentBlock, Message, PermissionDenial, StopReason, ToolResultBlock,
18};
19
20use super::executor::Agent;
21use crate::types::ApiResponse;
22
23impl Agent {
24    async fn send_with_retry(
25        &self,
26        request_builder: &RequestBuilder,
27        messages: Vec<Message>,
28        dynamic_rules: &str,
29    ) -> crate::Result<ApiResponse> {
30        const MAX_AUTH_RETRIES: u8 = 2;
31        let mut auth_attempts = 0u8;
32
33        loop {
34            let request = request_builder.build(messages.clone(), dynamic_rules);
35            match self.client.send(request).await {
36                Ok(resp) => return Ok(resp),
37                Err(e) if e.is_unauthorized() && auth_attempts < MAX_AUTH_RETRIES => {
38                    auth_attempts += 1;
39                    debug!(
40                        attempt = auth_attempts,
41                        "Received 401, attempting credential refresh"
42                    );
43                    self.client.refresh_credentials().await?;
44                }
45                Err(e) => {
46                    error!(error = %e, "API request failed");
47                    return Err(e);
48                }
49            }
50        }
51    }
52
53    async fn handle_compaction(
54        &self,
55        context: &mut ConversationContext,
56        hook_ctx: &HookContext,
57        metrics: &mut AgentMetrics,
58    ) {
59        let pre_compact_input = HookInput::pre_compact(&self.session_id);
60        if let Err(e) = self
61            .hooks
62            .execute(HookEvent::PreCompact, pre_compact_input, hook_ctx)
63            .await
64        {
65            warn!(error = %e, "PreCompact hook failed");
66        }
67
68        debug!("Compacting context");
69        match context
70            .compact(&self.client, self.config.execution.compact_keep_messages)
71            .await
72        {
73            Ok(CompactResult::Compacted { saved_tokens, .. }) => {
74                info!(saved_tokens, "Context compacted");
75                metrics.record_compaction();
76
77                let state_sections = collect_compaction_state(&self.tools).await;
78                if !state_sections.is_empty() {
79                    context.push(Message::user(format!(
80                        "<system-reminder>\n# State preserved after compaction\n\n{}\n</system-reminder>",
81                        state_sections.join("\n\n")
82                    )));
83                }
84            }
85            Ok(CompactResult::NotNeeded | CompactResult::Skipped { .. }) => {
86                debug!("Compaction skipped or not needed");
87            }
88            Err(e) => {
89                warn!(error = %e, "Context compaction failed");
90            }
91        }
92    }
93
94    fn check_budget(&self) -> crate::Result<()> {
95        if self.budget_tracker.should_stop() {
96            let status = self.budget_tracker.check();
97            warn!(used = ?status.used(), "Budget exceeded, stopping execution");
98            return Err(crate::Error::BudgetExceeded {
99                used: status.used(),
100                limit: self.config.budget.max_cost_usd.unwrap_or(0.0),
101            });
102        }
103        if let Some(ref tenant_budget) = self.tenant_budget
104            && tenant_budget.should_stop()
105        {
106            warn!(
107                tenant_id = %tenant_budget.tenant_id,
108                used = tenant_budget.used_cost_usd(),
109                "Tenant budget exceeded, stopping execution"
110            );
111            return Err(crate::Error::BudgetExceeded {
112                used: tenant_budget.used_cost_usd(),
113                limit: tenant_budget.max_cost_usd(),
114            });
115        }
116        Ok(())
117    }
118
119    pub async fn execute(&self, prompt: &str) -> crate::Result<AgentResult> {
120        let timeout = self
121            .config
122            .execution
123            .timeout
124            .unwrap_or(std::time::Duration::from_secs(600));
125
126        tokio::time::timeout(timeout, self.execute_inner(prompt))
127            .await
128            .map_err(|_| crate::Error::Timeout(timeout))?
129    }
130
131    pub async fn execute_with_messages(
132        &self,
133        previous_messages: Vec<Message>,
134        prompt: &str,
135    ) -> crate::Result<AgentResult> {
136        let context_summary = previous_messages
137            .iter()
138            .filter_map(|m| {
139                m.content
140                    .iter()
141                    .filter_map(|b| match b {
142                        ContentBlock::Text { text, .. } => Some(text.as_str()),
143                        _ => None,
144                    })
145                    .next()
146            })
147            .collect::<Vec<_>>()
148            .join("\n---\n");
149
150        let enriched_prompt = if context_summary.is_empty() {
151            prompt.to_string()
152        } else {
153            format!(
154                "Previous conversation context:\n{}\n\nContinue with: {}",
155                context_summary, prompt
156            )
157        };
158
159        self.execute(&enriched_prompt).await
160    }
161
162    #[instrument(skip(self, prompt), fields(session_id = %self.session_id))]
163    async fn execute_inner(&self, prompt: &str) -> crate::Result<AgentResult> {
164        let execution_start = Instant::now();
165        let hook_ctx = self.hook_context();
166
167        let session_start_input = HookInput::session_start(&self.session_id);
168        if let Err(e) = self
169            .hooks
170            .execute(HookEvent::SessionStart, session_start_input, &hook_ctx)
171            .await
172        {
173            warn!(error = %e, "SessionStart hook failed");
174        }
175
176        let prompt_input = HookInput::user_prompt_submit(&self.session_id, prompt);
177        let prompt_output = self
178            .hooks
179            .execute(HookEvent::UserPromptSubmit, prompt_input, &hook_ctx)
180            .await?;
181
182        if !prompt_output.continue_execution {
183            let session_end_input = HookInput::session_end(&self.session_id);
184            if let Err(e) = self
185                .hooks
186                .execute(HookEvent::SessionEnd, session_end_input, &hook_ctx)
187                .await
188            {
189                warn!(error = %e, "SessionEnd hook failed");
190            }
191            return Err(crate::Error::Permission(
192                prompt_output
193                    .stop_reason
194                    .unwrap_or_else(|| "Blocked by hook".into()),
195            ));
196        }
197
198        let mut context = ConversationContext::new();
199        context.push(Message::user(prompt));
200
201        let mut metrics = AgentMetrics::default();
202        let mut final_text = String::new();
203        let mut final_stop_reason = StopReason::EndTurn;
204        let mut dynamic_rules_context = String::new();
205
206        let request_builder = RequestBuilder::new(&self.config, Arc::clone(&self.tools));
207
208        info!(prompt_len = prompt.len(), "Starting agent execution");
209
210        loop {
211            metrics.iterations += 1;
212            if metrics.iterations > self.config.execution.max_iterations {
213                warn!(
214                    max = self.config.execution.max_iterations,
215                    "Max iterations reached"
216                );
217                break;
218            }
219
220            self.check_budget()?;
221
222            debug!(iteration = metrics.iterations, "Starting iteration");
223
224            let api_start = Instant::now();
225            let response = self
226                .send_with_retry(
227                    &request_builder,
228                    context.messages().to_vec(),
229                    &dynamic_rules_context,
230                )
231                .await?;
232            let api_duration_ms = api_start.elapsed().as_millis() as u64;
233            metrics.record_api_call_with_timing(api_duration_ms);
234            debug!(api_time_ms = api_duration_ms, "API call completed");
235
236            context.update_usage(response.usage);
237            metrics.add_usage(response.usage.input_tokens, response.usage.output_tokens);
238            metrics.record_model_usage(&self.config.model.primary, &response.usage);
239
240            if let Some(ref server_usage) = response.usage.server_tool_use {
241                metrics.update_server_tool_use_from_api(server_usage);
242            }
243
244            let cost = self
245                .budget_tracker
246                .record(&self.config.model.primary, &response.usage);
247            metrics.add_cost(cost);
248            if let Some(ref tenant_budget) = self.tenant_budget {
249                tenant_budget.record(&self.config.model.primary, &response.usage);
250            }
251
252            if let Some(ref orchestrator) = self.orchestrator {
253                let mut orch = orchestrator.write().await;
254                orch.update_usage(&crate::types::TokenUsage {
255                    input_tokens: response.usage.input_tokens as u64,
256                    output_tokens: response.usage.output_tokens as u64,
257                    cache_read_input_tokens: response.usage.cache_read_input_tokens.unwrap_or(0)
258                        as u64,
259                    cache_creation_input_tokens: response
260                        .usage
261                        .cache_creation_input_tokens
262                        .unwrap_or(0) as u64,
263                });
264            }
265
266            final_text = response.text();
267            final_stop_reason = response.stop_reason.unwrap_or(StopReason::EndTurn);
268
269            context.push(Message {
270                role: crate::types::Role::Assistant,
271                content: response.content.clone(),
272            });
273
274            if !response.wants_tool_use() {
275                debug!("No tool use requested, ending loop");
276                break;
277            }
278
279            let tool_uses = response.tool_uses();
280            let hook_ctx = self.hook_context();
281
282            let mut prepared: Vec<(String, String, serde_json::Value)> = Vec::new();
283            let mut blocked: Vec<ToolResultBlock> = Vec::new();
284
285            for tool_use in &tool_uses {
286                let pre_input = HookInput::pre_tool_use(
287                    &self.session_id,
288                    &tool_use.name,
289                    tool_use.input.clone(),
290                );
291                let pre_output = self
292                    .hooks
293                    .execute(HookEvent::PreToolUse, pre_input, &hook_ctx)
294                    .await?;
295
296                if !pre_output.continue_execution {
297                    debug!(tool = %tool_use.name, "Tool blocked by hook");
298                    let reason = pre_output
299                        .stop_reason
300                        .clone()
301                        .unwrap_or_else(|| "Blocked by hook".into());
302                    blocked.push(ToolResultBlock::error(&tool_use.id, reason.clone()));
303                    metrics.record_permission_denial(
304                        PermissionDenial::new(&tool_use.name, &tool_use.id, tool_use.input.clone())
305                            .with_reason(reason),
306                    );
307                } else {
308                    let input = pre_output.updated_input.unwrap_or(tool_use.input.clone());
309                    prepared.push((tool_use.id.clone(), tool_use.name.clone(), input));
310                }
311            }
312
313            let tool_futures = prepared.into_iter().map(|(id, name, input)| {
314                let tools = &self.tools;
315                async move {
316                    let start = Instant::now();
317                    let result = tools.execute(&name, input.clone()).await;
318                    let duration_ms = start.elapsed().as_millis() as u64;
319                    (id, name, input, result, duration_ms)
320                }
321            });
322
323            let parallel_results: Vec<_> = futures::future::join_all(tool_futures).await;
324
325            let mut results = blocked;
326            for (id, name, input, result, duration_ms) in parallel_results {
327                let is_error = result.is_error();
328                debug!(tool = %name, duration_ms, is_error, "Tool execution completed");
329                metrics.record_tool(&name, duration_ms, is_error);
330
331                if let Some(ref inner_usage) = result.inner_usage {
332                    context.update_usage(*inner_usage);
333                    metrics.add_usage(inner_usage.input_tokens, inner_usage.output_tokens);
334                    let inner_model = result.inner_model.as_deref().unwrap_or("claude-haiku-4-5");
335                    metrics.record_model_usage(inner_model, inner_usage);
336
337                    let inner_cost = self.budget_tracker.record(inner_model, inner_usage);
338                    metrics.add_cost(inner_cost);
339
340                    debug!(
341                        tool = %name,
342                        model = %inner_model,
343                        input_tokens = inner_usage.input_tokens,
344                        output_tokens = inner_usage.output_tokens,
345                        cost_usd = inner_cost,
346                        "Accumulated inner usage from tool"
347                    );
348                }
349
350                if let Some(file_path) = extract_file_path(&name, &input)
351                    && let Some(ref orchestrator) = self.orchestrator
352                {
353                    let new_rules = activate_rules_for_file(orchestrator, &file_path).await;
354                    if !new_rules.is_empty() {
355                        dynamic_rules_context =
356                            build_dynamic_rules_context(orchestrator, &file_path).await;
357                        debug!(rules = ?new_rules, "Activated rules for file");
358                    }
359                }
360
361                if is_error {
362                    let failure_input = HookInput::post_tool_use_failure(
363                        &self.session_id,
364                        &name,
365                        result.error_message(),
366                    );
367                    if let Err(e) = self
368                        .hooks
369                        .execute(HookEvent::PostToolUseFailure, failure_input, &hook_ctx)
370                        .await
371                    {
372                        warn!(tool = %name, error = %e, "PostToolUseFailure hook failed");
373                    }
374                } else {
375                    let post_input =
376                        HookInput::post_tool_use(&self.session_id, &name, result.output.clone());
377                    if let Err(e) = self
378                        .hooks
379                        .execute(HookEvent::PostToolUse, post_input, &hook_ctx)
380                        .await
381                    {
382                        warn!(tool = %name, error = %e, "PostToolUse hook failed");
383                    }
384                }
385                results.push(ToolResultBlock::from_tool_result(&id, result));
386            }
387
388            context.push(Message::tool_results(results));
389
390            if self.config.execution.auto_compact
391                && context.should_compact(
392                    200_000,
393                    self.config.execution.compact_threshold,
394                    self.config.execution.compact_keep_messages,
395                )
396            {
397                self.handle_compaction(&mut context, &hook_ctx, &mut metrics)
398                    .await;
399            }
400        }
401
402        metrics.execution_time_ms = execution_start.elapsed().as_millis() as u64;
403
404        let stop_input = HookInput::stop(&self.session_id);
405        if let Err(e) = self
406            .hooks
407            .execute(HookEvent::Stop, stop_input, &hook_ctx)
408            .await
409        {
410            warn!(error = %e, "Stop hook failed");
411        }
412
413        let session_end_input = HookInput::session_end(&self.session_id);
414        if let Err(e) = self
415            .hooks
416            .execute(HookEvent::SessionEnd, session_end_input, &hook_ctx)
417            .await
418        {
419            warn!(error = %e, "SessionEnd hook failed");
420        }
421
422        info!(
423            iterations = metrics.iterations,
424            tool_calls = metrics.tool_calls,
425            api_calls = metrics.api_calls,
426            total_tokens = metrics.total_tokens(),
427            execution_time_ms = metrics.execution_time_ms,
428            "Agent execution completed"
429        );
430
431        Ok(AgentResult {
432            text: final_text,
433            usage: *context.total_usage(),
434            tool_calls: metrics.tool_calls,
435            iterations: metrics.iterations,
436            stop_reason: final_stop_reason,
437            state: AgentState::Completed,
438            metrics,
439            session_id: self.session_id.clone(),
440            structured_output: None,
441            messages: context.messages().to_vec(),
442            uuid: uuid::Uuid::new_v4().to_string(),
443        })
444    }
445
446    pub(crate) fn hook_context(&self) -> HookContext {
447        HookContext::new(&self.session_id)
448            .with_cwd(self.config.working_dir.clone().unwrap_or_default())
449            .with_env(self.config.security.env.clone())
450    }
451}
452
453pub(crate) fn extract_file_path(tool_name: &str, input: &serde_json::Value) -> Option<String> {
454    match tool_name {
455        "Read" | "Write" | "Edit" => input
456            .get("file_path")
457            .and_then(|v| v.as_str())
458            .map(String::from),
459        "Glob" | "Grep" => input.get("path").and_then(|v| v.as_str()).map(String::from),
460        _ => None,
461    }
462}
463
464pub(crate) async fn activate_rules_for_file(
465    orchestrator: &Arc<RwLock<PromptOrchestrator>>,
466    file_path: &str,
467) -> Vec<String> {
468    let orch = orchestrator.read().await;
469    let path = Path::new(file_path);
470    let rules = orch.rules_engine().find_matching(path);
471    rules.iter().map(|r| r.name.clone()).collect()
472}
473
474pub(crate) async fn build_dynamic_rules_context(
475    orchestrator: &Arc<RwLock<PromptOrchestrator>>,
476    file_path: &str,
477) -> String {
478    let orch = orchestrator.read().await;
479    let path = Path::new(file_path);
480    orch.build_dynamic_context(Some(path)).await
481}
482
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for the `Some(String)` values the extractor should return.
    fn found(path: &str) -> Option<String> {
        Some(path.to_string())
    }

    /// One representative call per key family, plus an unrelated tool.
    #[test]
    fn test_extract_file_path() {
        let read_args = serde_json::json!({"file_path": "/src/lib.rs"});
        assert_eq!(extract_file_path("Read", &read_args), found("/src/lib.rs"));

        let glob_args = serde_json::json!({"path": "/src"});
        assert_eq!(extract_file_path("Glob", &glob_args), found("/src"));

        let bash_args = serde_json::json!({"command": "ls"});
        assert_eq!(extract_file_path("Bash", &bash_args), None);
    }

    /// Every supported tool name maps to its key; unsupported names map to
    /// nothing even when the key is present.
    #[test]
    fn test_extract_file_path_all_tools() {
        let file_input = serde_json::json!({"file_path": "/test/file.rs"});
        let path_input = serde_json::json!({"path": "/test/dir"});

        for tool in ["Read", "Write", "Edit"] {
            assert_eq!(extract_file_path(tool, &file_input), found("/test/file.rs"));
        }
        for tool in ["Glob", "Grep"] {
            assert_eq!(extract_file_path(tool, &path_input), found("/test/dir"));
        }
        for tool in ["WebFetch", "Task"] {
            assert_eq!(extract_file_path(tool, &file_input), None);
        }
    }

    /// A missing key yields `None` for both key families.
    #[test]
    fn test_extract_file_path_missing_field() {
        let payloads = [
            serde_json::json!({}),
            serde_json::json!({"other": "value"}),
        ];

        for payload in &payloads {
            for tool in ["Read", "Glob"] {
                assert_eq!(extract_file_path(tool, payload), None);
            }
        }
    }

    /// Values that are not JSON strings are rejected.
    #[test]
    fn test_extract_file_path_non_string() {
        let numeric = serde_json::json!({"file_path": 123});
        assert_eq!(extract_file_path("Read", &numeric), None);

        let null_path = serde_json::json!({"path": null});
        assert_eq!(extract_file_path("Glob", &null_path), None);
    }
}