claude_agent/agent/
execution.rs

1//! Agent execution logic with session-based context management.
2
3use std::path::Path;
4use std::sync::Arc;
5use std::time::Instant;
6
7use tokio::sync::RwLock;
8use tracing::{debug, info, instrument, warn};
9
10use super::common::BudgetContext;
11use super::events::AgentResult;
12use super::executor::Agent;
13use super::request::RequestBuilder;
14use super::state_formatter::collect_compaction_state;
15use super::{AgentMetrics, AgentState};
16use crate::context::PromptOrchestrator;
17use crate::hooks::{HookContext, HookEvent, HookInput};
18use crate::session::ExecutionGuard;
19use crate::types::{
20    CompactResult, ContentBlock, Message, PermissionDenial, StopReason, ToolResultBlock, Usage,
21    context_window,
22};
23
24impl Agent {
25    async fn handle_compaction<'a>(
26        &self,
27        _guard: &ExecutionGuard<'a>,
28        hook_ctx: &HookContext,
29        metrics: &mut AgentMetrics,
30    ) {
31        let pre_compact_input = HookInput::pre_compact(&*self.session_id);
32        if let Err(e) = self
33            .hooks
34            .execute(HookEvent::PreCompact, pre_compact_input, hook_ctx)
35            .await
36        {
37            warn!(error = %e, "PreCompact hook failed");
38        }
39
40        debug!("Compacting session context");
41        let compact_result = self
42            .state
43            .compact(&self.client, self.config.execution.compact_keep_messages)
44            .await;
45
46        match compact_result {
47            Ok(CompactResult::Compacted { saved_tokens, .. }) => {
48                info!(saved_tokens, "Session context compacted");
49                metrics.record_compaction();
50
51                let state_sections = collect_compaction_state(&self.tools).await;
52                if !state_sections.is_empty() {
53                    self.state
54                        .with_session_mut(|session| {
55                            session.add_user_message(format!(
56                                "<system-reminder>\n# State preserved after compaction\n\n{}\n</system-reminder>",
57                                state_sections.join("\n\n")
58                            ));
59                        })
60                        .await;
61                }
62            }
63            Ok(CompactResult::NotNeeded | CompactResult::Skipped { .. }) => {
64                debug!("Compaction skipped or not needed");
65            }
66            Err(e) => {
67                warn!(error = %e, "Session compaction failed");
68            }
69        }
70    }
71
72    fn check_budget(&self) -> crate::Result<()> {
73        BudgetContext {
74            tracker: &self.budget_tracker,
75            tenant: self.tenant_budget.as_deref(),
76            config: &self.config.budget,
77        }
78        .check()
79    }
80
81    pub async fn execute(&self, prompt: &str) -> crate::Result<AgentResult> {
82        let timeout = self
83            .config
84            .execution
85            .timeout
86            .unwrap_or(std::time::Duration::from_secs(600));
87
88        if self.state.is_executing() {
89            self.state
90                .enqueue(prompt)
91                .await
92                .map_err(|e| crate::Error::Session(format!("Queue full: {}", e)))?;
93            return self.wait_for_execution(timeout).await;
94        }
95
96        tokio::time::timeout(timeout, self.execute_inner(prompt))
97            .await
98            .map_err(|_| crate::Error::Timeout(timeout))?
99    }
100
101    async fn wait_for_execution(&self, timeout: std::time::Duration) -> crate::Result<AgentResult> {
102        let start = Instant::now();
103        loop {
104            if start.elapsed() > timeout {
105                return Err(crate::Error::Timeout(timeout));
106            }
107
108            if !self.state.is_executing()
109                && let Some(merged) = self.state.dequeue_or_merge().await
110            {
111                return self.execute_inner(&merged.content).await;
112            }
113
114            tokio::time::sleep(std::time::Duration::from_millis(50)).await;
115        }
116    }
117
118    pub async fn execute_with_messages(
119        &self,
120        previous_messages: Vec<Message>,
121        prompt: &str,
122    ) -> crate::Result<AgentResult> {
123        let context_summary = previous_messages
124            .iter()
125            .filter_map(|m| {
126                m.content
127                    .iter()
128                    .filter_map(|b| match b {
129                        ContentBlock::Text { text, .. } => Some(text.as_str()),
130                        _ => None,
131                    })
132                    .next()
133            })
134            .collect::<Vec<_>>()
135            .join("\n---\n");
136
137        let enriched_prompt = if context_summary.is_empty() {
138            prompt.to_string()
139        } else {
140            format!(
141                "Previous conversation context:\n{}\n\nContinue with: {}",
142                context_summary, prompt
143            )
144        };
145
146        self.execute(&enriched_prompt).await
147    }
148
149    #[instrument(skip(self, prompt), fields(session_id = %self.session_id))]
150    async fn execute_inner(&self, prompt: &str) -> crate::Result<AgentResult> {
151        let guard = self.state.acquire_execution().await;
152        let execution_start = Instant::now();
153        let hook_ctx = self.hook_context();
154
155        let session_start_input = HookInput::session_start(&*self.session_id);
156        if let Err(e) = self
157            .hooks
158            .execute(HookEvent::SessionStart, session_start_input, &hook_ctx)
159            .await
160        {
161            warn!(error = %e, "SessionStart hook failed");
162        }
163
164        let final_prompt = if let Some(merged) = self.state.dequeue_or_merge().await {
165            format!("{}\n{}", prompt, merged.content)
166        } else {
167            prompt.to_string()
168        };
169
170        let prompt_input = HookInput::user_prompt_submit(&*self.session_id, &final_prompt);
171        let prompt_output = self
172            .hooks
173            .execute(HookEvent::UserPromptSubmit, prompt_input, &hook_ctx)
174            .await?;
175
176        if !prompt_output.continue_execution {
177            let session_end_input = HookInput::session_end(&*self.session_id);
178            if let Err(e) = self
179                .hooks
180                .execute(HookEvent::SessionEnd, session_end_input, &hook_ctx)
181                .await
182            {
183                warn!(error = %e, "SessionEnd hook failed");
184            }
185            return Err(crate::Error::Permission(
186                prompt_output
187                    .stop_reason
188                    .unwrap_or_else(|| "Blocked by hook".into()),
189            ));
190        }
191
192        self.state
193            .with_session_mut(|session| {
194                session.add_user_message(&final_prompt);
195            })
196            .await;
197
198        let mut metrics = AgentMetrics::default();
199        let mut final_text = String::new();
200        let mut final_stop_reason = StopReason::EndTurn;
201        let mut dynamic_rules_context = String::new();
202        let mut total_usage = Usage::default();
203
204        let request_builder = RequestBuilder::new(&self.config, Arc::clone(&self.tools));
205        let max_tokens = context_window::for_model(&self.config.model.primary);
206
207        info!(prompt_len = final_prompt.len(), "Starting agent execution");
208
209        loop {
210            metrics.iterations += 1;
211            if metrics.iterations > self.config.execution.max_iterations {
212                warn!(
213                    max = self.config.execution.max_iterations,
214                    "Max iterations reached"
215                );
216                break;
217            }
218
219            self.check_budget()?;
220
221            debug!(iteration = metrics.iterations, "Starting iteration");
222
223            let messages = self
224                .state
225                .with_session(|session| {
226                    session.to_api_messages_with_cache(self.config.cache.message_ttl_option())
227                })
228                .await;
229
230            let api_start = Instant::now();
231            let request = request_builder.build(messages, &dynamic_rules_context);
232            let response = self.client.send_with_auth_retry(request).await?;
233            let api_duration_ms = api_start.elapsed().as_millis() as u64;
234            metrics.record_api_call_with_timing(api_duration_ms);
235            debug!(api_time_ms = api_duration_ms, "API call completed");
236
237            self.state
238                .with_session_mut(|session| {
239                    session.update_usage(&response.usage);
240                })
241                .await;
242
243            total_usage.input_tokens += response.usage.input_tokens;
244            total_usage.output_tokens += response.usage.output_tokens;
245            metrics.add_usage_with_cache(&response.usage);
246            metrics.record_model_usage(&self.config.model.primary, &response.usage);
247
248            if let Some(ref server_usage) = response.usage.server_tool_use {
249                metrics.update_server_tool_use_from_api(server_usage);
250            }
251
252            let cost = self
253                .budget_tracker
254                .record(&self.config.model.primary, &response.usage);
255            metrics.add_cost(cost);
256            if let Some(ref tenant_budget) = self.tenant_budget {
257                tenant_budget.record(&self.config.model.primary, &response.usage);
258            }
259
260            final_text = response.text();
261            final_stop_reason = response.stop_reason.unwrap_or(StopReason::EndTurn);
262
263            self.state
264                .with_session_mut(|session| {
265                    session.add_assistant_message(response.content.clone(), Some(response.usage));
266                })
267                .await;
268
269            if !response.wants_tool_use() {
270                debug!("No tool use requested, ending loop");
271                break;
272            }
273
274            let tool_uses = response.tool_uses();
275            let hook_ctx = self.hook_context();
276
277            let mut prepared = Vec::with_capacity(tool_uses.len());
278            let mut blocked = Vec::with_capacity(tool_uses.len());
279
280            for tool_use in &tool_uses {
281                let pre_input = HookInput::pre_tool_use(
282                    &*self.session_id,
283                    &tool_use.name,
284                    tool_use.input.clone(),
285                );
286                let pre_output = self
287                    .hooks
288                    .execute(HookEvent::PreToolUse, pre_input, &hook_ctx)
289                    .await?;
290
291                if !pre_output.continue_execution {
292                    debug!(tool = %tool_use.name, "Tool blocked by hook");
293                    let reason = pre_output
294                        .stop_reason
295                        .clone()
296                        .unwrap_or_else(|| "Blocked by hook".into());
297                    blocked.push(ToolResultBlock::error(&tool_use.id, reason.clone()));
298                    metrics.record_permission_denial(
299                        PermissionDenial::new(&tool_use.name, &tool_use.id, tool_use.input.clone())
300                            .with_reason(reason),
301                    );
302                } else {
303                    let input = pre_output.updated_input.unwrap_or(tool_use.input.clone());
304                    prepared.push((tool_use.id.clone(), tool_use.name.clone(), input));
305                }
306            }
307
308            let tool_futures = prepared.into_iter().map(|(id, name, input)| {
309                let tools = &self.tools;
310                async move {
311                    let start = Instant::now();
312                    let result = tools.execute(&name, input.clone()).await;
313                    let duration_ms = start.elapsed().as_millis() as u64;
314                    (id, name, input, result, duration_ms)
315                }
316            });
317
318            let parallel_results: Vec<_> = futures::future::join_all(tool_futures).await;
319
320            let all_non_retryable = !parallel_results.is_empty()
321                && parallel_results
322                    .iter()
323                    .all(|(_, _, _, result, _)| result.is_non_retryable());
324
325            let mut results = blocked;
326            for (id, name, input, result, duration_ms) in parallel_results {
327                let is_error = result.is_error();
328                debug!(tool = %name, duration_ms, is_error, "Tool execution completed");
329                metrics.record_tool(&id, &name, duration_ms, is_error);
330
331                if let Some(ref inner_usage) = result.inner_usage {
332                    self.state
333                        .with_session_mut(|session| {
334                            session.update_usage(inner_usage);
335                        })
336                        .await;
337                    total_usage.input_tokens += inner_usage.input_tokens;
338                    total_usage.output_tokens += inner_usage.output_tokens;
339                    metrics.add_usage(inner_usage.input_tokens, inner_usage.output_tokens);
340                    let inner_model = result.inner_model.as_deref().unwrap_or("claude-haiku-4-5");
341                    metrics.record_model_usage(inner_model, inner_usage);
342
343                    let inner_cost = self.budget_tracker.record(inner_model, inner_usage);
344                    metrics.add_cost(inner_cost);
345
346                    debug!(
347                        tool = %name,
348                        model = %inner_model,
349                        input_tokens = inner_usage.input_tokens,
350                        output_tokens = inner_usage.output_tokens,
351                        cost_usd = inner_cost,
352                        "Accumulated inner usage from tool"
353                    );
354                }
355
356                if let Some(file_path) = extract_file_path(&name, &input)
357                    && let Some(ref orchestrator) = self.orchestrator
358                {
359                    let new_rules = activate_rules_for_file(orchestrator, &file_path).await;
360                    if !new_rules.is_empty() {
361                        dynamic_rules_context =
362                            build_dynamic_rules_context(orchestrator, &file_path).await;
363                        debug!(rules = ?new_rules, "Activated rules for file");
364                    }
365                }
366
367                if is_error {
368                    let failure_input = HookInput::post_tool_use_failure(
369                        &*self.session_id,
370                        &name,
371                        result.error_message(),
372                    );
373                    if let Err(e) = self
374                        .hooks
375                        .execute(HookEvent::PostToolUseFailure, failure_input, &hook_ctx)
376                        .await
377                    {
378                        warn!(tool = %name, error = %e, "PostToolUseFailure hook failed");
379                    }
380                } else {
381                    let post_input =
382                        HookInput::post_tool_use(&*self.session_id, &name, result.output.clone());
383                    if let Err(e) = self
384                        .hooks
385                        .execute(HookEvent::PostToolUse, post_input, &hook_ctx)
386                        .await
387                    {
388                        warn!(tool = %name, error = %e, "PostToolUse hook failed");
389                    }
390                }
391                results.push(ToolResultBlock::from_tool_result(&id, result));
392            }
393
394            self.state
395                .with_session_mut(|session| {
396                    session.add_tool_results(results);
397                })
398                .await;
399
400            if all_non_retryable {
401                warn!("All tool calls failed with non-retryable errors, ending execution");
402                break;
403            }
404
405            let should_compact = self
406                .state
407                .with_session(|session| {
408                    self.config.execution.auto_compact
409                        && session.should_compact(
410                            max_tokens,
411                            self.config.execution.compact_threshold,
412                            self.config.execution.compact_keep_messages,
413                        )
414                })
415                .await;
416
417            if should_compact {
418                self.handle_compaction(&guard, &hook_ctx, &mut metrics)
419                    .await;
420            }
421        }
422
423        metrics.execution_time_ms = execution_start.elapsed().as_millis() as u64;
424
425        let stop_input = HookInput::stop(&*self.session_id);
426        if let Err(e) = self
427            .hooks
428            .execute(HookEvent::Stop, stop_input, &hook_ctx)
429            .await
430        {
431            warn!(error = %e, "Stop hook failed");
432        }
433
434        let session_end_input = HookInput::session_end(&*self.session_id);
435        if let Err(e) = self
436            .hooks
437            .execute(HookEvent::SessionEnd, session_end_input, &hook_ctx)
438            .await
439        {
440            warn!(error = %e, "SessionEnd hook failed");
441        }
442
443        info!(
444            iterations = metrics.iterations,
445            tool_calls = metrics.tool_calls,
446            api_calls = metrics.api_calls,
447            total_tokens = metrics.total_tokens(),
448            execution_time_ms = metrics.execution_time_ms,
449            "Agent execution completed"
450        );
451
452        let messages = self
453            .state
454            .with_session(|session| session.to_api_messages())
455            .await;
456
457        drop(guard);
458
459        Ok(AgentResult {
460            text: final_text,
461            usage: total_usage,
462            tool_calls: metrics.tool_calls,
463            iterations: metrics.iterations,
464            stop_reason: final_stop_reason,
465            state: AgentState::Completed,
466            metrics,
467            session_id: self.session_id.to_string(),
468            structured_output: None,
469            messages,
470            uuid: uuid::Uuid::new_v4().to_string(),
471        })
472    }
473
474    pub(crate) fn hook_context(&self) -> HookContext {
475        HookContext::new(&*self.session_id)
476            .with_cwd(self.config.working_dir.clone().unwrap_or_default())
477            .with_env(self.config.security.env.clone())
478    }
479}
480
481pub(crate) fn extract_file_path(tool_name: &str, input: &serde_json::Value) -> Option<String> {
482    match tool_name {
483        "Read" | "Write" | "Edit" => input
484            .get("file_path")
485            .and_then(|v| v.as_str())
486            .map(String::from),
487        "Glob" | "Grep" => input.get("path").and_then(|v| v.as_str()).map(String::from),
488        _ => None,
489    }
490}
491
492pub(crate) async fn activate_rules_for_file(
493    orchestrator: &Arc<RwLock<PromptOrchestrator>>,
494    file_path: &str,
495) -> Vec<String> {
496    let orch = orchestrator.read().await;
497    let path = Path::new(file_path);
498    let rules = orch.rules_engine().find_matching(path);
499    rules.iter().map(|r| r.name.clone()).collect()
500}
501
502pub(crate) async fn build_dynamic_rules_context(
503    orchestrator: &Arc<RwLock<PromptOrchestrator>>,
504    file_path: &str,
505) -> String {
506    let orch = orchestrator.read().await;
507    let path = Path::new(file_path);
508    orch.build_dynamic_context(Some(path)).await
509}
510
#[cfg(test)]
mod tests {
    use super::*;

    /// One representative call per key, plus a tool that has no path argument.
    #[test]
    fn test_extract_file_path() {
        let read_input = serde_json::json!({"file_path": "/src/lib.rs"});
        assert_eq!(
            extract_file_path("Read", &read_input).as_deref(),
            Some("/src/lib.rs")
        );

        let glob_input = serde_json::json!({"path": "/src"});
        assert_eq!(
            extract_file_path("Glob", &glob_input).as_deref(),
            Some("/src")
        );

        let bash_input = serde_json::json!({"command": "ls"});
        assert!(extract_file_path("Bash", &bash_input).is_none());
    }

    /// Every supported tool name resolves its key; unsupported names never do.
    #[test]
    fn test_extract_file_path_all_tools() {
        let file_input = serde_json::json!({"file_path": "/test/file.rs"});
        let path_input = serde_json::json!({"path": "/test/dir"});

        for tool in ["Read", "Write", "Edit"] {
            assert_eq!(
                extract_file_path(tool, &file_input).as_deref(),
                Some("/test/file.rs")
            );
        }

        for tool in ["Glob", "Grep"] {
            assert_eq!(
                extract_file_path(tool, &path_input).as_deref(),
                Some("/test/dir")
            );
        }

        for tool in ["WebFetch", "Task"] {
            assert!(extract_file_path(tool, &file_input).is_none());
        }
    }

    /// A supported tool without its expected key yields `None`.
    #[test]
    fn test_extract_file_path_missing_field() {
        let empty = serde_json::json!({});
        let wrong_field = serde_json::json!({"other": "value"});

        for input in [&empty, &wrong_field] {
            assert!(extract_file_path("Read", input).is_none());
            assert!(extract_file_path("Glob", input).is_none());
        }
    }

    /// A value under the right key that is not a JSON string yields `None`.
    #[test]
    fn test_extract_file_path_non_string() {
        let numeric = serde_json::json!({"file_path": 123});
        assert!(extract_file_path("Read", &numeric).is_none());

        let null_path = serde_json::json!({"path": null});
        assert!(extract_file_path("Glob", &null_path).is_none());
    }
}
580}