claude_agent/agent/
execution.rs

1//! Agent execution logic with session-based context management.
2
3use std::path::Path;
4use std::sync::Arc;
5use std::time::Instant;
6
7use tokio::sync::RwLock;
8use tracing::{debug, info, instrument, warn};
9
10use super::common::BudgetContext;
11use super::events::AgentResult;
12use super::executor::Agent;
13use super::request::RequestBuilder;
14use super::state_formatter::collect_compaction_state;
15use super::{AgentMetrics, AgentState};
16use crate::context::PromptOrchestrator;
17use crate::hooks::{HookContext, HookEvent, HookInput};
18use crate::session::ExecutionGuard;
19use crate::types::{
20    CompactResult, ContentBlock, Message, PermissionDenial, StopReason, ToolResultBlock, Usage,
21    context_window,
22};
23
24impl Agent {
25    async fn handle_compaction<'a>(
26        &self,
27        _guard: &ExecutionGuard<'a>,
28        hook_ctx: &HookContext,
29        metrics: &mut AgentMetrics,
30    ) {
31        let pre_compact_input = HookInput::pre_compact(&*self.session_id);
32        if let Err(e) = self
33            .hooks
34            .execute(HookEvent::PreCompact, pre_compact_input, hook_ctx)
35            .await
36        {
37            warn!(error = %e, "PreCompact hook failed");
38        }
39
40        debug!("Compacting session context");
41        let compact_result = self
42            .state
43            .compact(&self.client, self.config.execution.compact_keep_messages)
44            .await;
45
46        match compact_result {
47            Ok(CompactResult::Compacted { saved_tokens, .. }) => {
48                info!(saved_tokens, "Session context compacted");
49                metrics.record_compaction();
50
51                let state_sections = collect_compaction_state(&self.tools).await;
52                if !state_sections.is_empty() {
53                    self.state
54                        .with_session_mut(|session| {
55                            session.add_user_message(format!(
56                                "<system-reminder>\n# State preserved after compaction\n\n{}\n</system-reminder>",
57                                state_sections.join("\n\n")
58                            ));
59                        })
60                        .await;
61                }
62            }
63            Ok(CompactResult::NotNeeded | CompactResult::Skipped { .. }) => {
64                debug!("Compaction skipped or not needed");
65            }
66            Err(e) => {
67                warn!(error = %e, "Session compaction failed");
68            }
69        }
70    }
71
72    fn check_budget(&self) -> crate::Result<()> {
73        BudgetContext {
74            tracker: &self.budget_tracker,
75            tenant: self.tenant_budget.as_deref(),
76            config: &self.config.budget,
77        }
78        .check()
79    }
80
81    pub async fn execute(&self, prompt: &str) -> crate::Result<AgentResult> {
82        let timeout = self
83            .config
84            .execution
85            .timeout
86            .unwrap_or(std::time::Duration::from_secs(600));
87
88        if self.state.is_executing() {
89            self.state
90                .enqueue(prompt)
91                .await
92                .map_err(|e| crate::Error::Session(format!("Queue full: {}", e)))?;
93            return self.wait_for_execution(timeout).await;
94        }
95
96        tokio::time::timeout(timeout, self.execute_inner(prompt))
97            .await
98            .map_err(|_| crate::Error::Timeout(timeout))?
99    }
100
101    async fn wait_for_execution(&self, timeout: std::time::Duration) -> crate::Result<AgentResult> {
102        let start = Instant::now();
103        loop {
104            if start.elapsed() > timeout {
105                return Err(crate::Error::Timeout(timeout));
106            }
107
108            if !self.state.is_executing()
109                && let Some(merged) = self.state.dequeue_or_merge().await
110            {
111                return self.execute_inner(&merged.content).await;
112            }
113
114            tokio::time::sleep(std::time::Duration::from_millis(50)).await;
115        }
116    }
117
118    pub async fn execute_with_messages(
119        &self,
120        previous_messages: Vec<Message>,
121        prompt: &str,
122    ) -> crate::Result<AgentResult> {
123        let context_summary = previous_messages
124            .iter()
125            .filter_map(|m| {
126                m.content
127                    .iter()
128                    .filter_map(|b| match b {
129                        ContentBlock::Text { text, .. } => Some(text.as_str()),
130                        _ => None,
131                    })
132                    .next()
133            })
134            .collect::<Vec<_>>()
135            .join("\n---\n");
136
137        let enriched_prompt = if context_summary.is_empty() {
138            prompt.to_string()
139        } else {
140            format!(
141                "Previous conversation context:\n{}\n\nContinue with: {}",
142                context_summary, prompt
143            )
144        };
145
146        self.execute(&enriched_prompt).await
147    }
148
149    #[instrument(skip(self, prompt), fields(session_id = %self.session_id))]
150    async fn execute_inner(&self, prompt: &str) -> crate::Result<AgentResult> {
151        let guard = self.state.acquire_execution().await;
152        let execution_start = Instant::now();
153        let hook_ctx = self.hook_context();
154
155        let session_start_input = HookInput::session_start(&*self.session_id);
156        if let Err(e) = self
157            .hooks
158            .execute(HookEvent::SessionStart, session_start_input, &hook_ctx)
159            .await
160        {
161            warn!(error = %e, "SessionStart hook failed");
162        }
163
164        let final_prompt = if let Some(merged) = self.state.dequeue_or_merge().await {
165            format!("{}\n{}", prompt, merged.content)
166        } else {
167            prompt.to_string()
168        };
169
170        let prompt_input = HookInput::user_prompt_submit(&*self.session_id, &final_prompt);
171        let prompt_output = self
172            .hooks
173            .execute(HookEvent::UserPromptSubmit, prompt_input, &hook_ctx)
174            .await?;
175
176        if !prompt_output.continue_execution {
177            let session_end_input = HookInput::session_end(&*self.session_id);
178            if let Err(e) = self
179                .hooks
180                .execute(HookEvent::SessionEnd, session_end_input, &hook_ctx)
181                .await
182            {
183                warn!(error = %e, "SessionEnd hook failed");
184            }
185            return Err(crate::Error::Permission(
186                prompt_output
187                    .stop_reason
188                    .unwrap_or_else(|| "Blocked by hook".into()),
189            ));
190        }
191
192        self.state
193            .with_session_mut(|session| {
194                session.add_user_message(&final_prompt);
195            })
196            .await;
197
198        let mut metrics = AgentMetrics::default();
199        let mut final_text = String::new();
200        let mut final_stop_reason = StopReason::EndTurn;
201        let mut dynamic_rules_context = String::new();
202        let mut total_usage = Usage::default();
203
204        // Build request builder with optional MCP Progressive Disclosure
205        let request_builder = {
206            let builder = RequestBuilder::new(&self.config, Arc::clone(&self.tools));
207
208            // Connect ToolSearchManager for Progressive Disclosure of MCP tools
209            if let Some(ref tsm) = self.tool_search_manager {
210                let prepared = tsm.prepare_tools().await;
211                if prepared.use_search {
212                    info!(
213                        immediate = prepared.immediate.len(),
214                        deferred = prepared.deferred.len(),
215                        tokens_saved = prepared.token_savings(),
216                        "MCP Progressive Disclosure active"
217                    );
218                }
219                builder.with_prepared_tools(prepared)
220            } else {
221                builder
222            }
223        };
224        let max_tokens = context_window::for_model(&self.config.model.primary);
225
226        info!(prompt_len = final_prompt.len(), "Starting agent execution");
227
228        loop {
229            metrics.iterations += 1;
230            if metrics.iterations > self.config.execution.max_iterations {
231                warn!(
232                    max = self.config.execution.max_iterations,
233                    "Max iterations reached"
234                );
235                break;
236            }
237
238            self.check_budget()?;
239
240            debug!(iteration = metrics.iterations, "Starting iteration");
241
242            let messages = self
243                .state
244                .with_session(|session| {
245                    session.to_api_messages_with_cache(self.config.cache.message_ttl_option())
246                })
247                .await;
248
249            let api_start = Instant::now();
250            let request = request_builder.build(messages, &dynamic_rules_context);
251            let response = self.client.send_with_auth_retry(request).await?;
252            let api_duration_ms = api_start.elapsed().as_millis() as u64;
253            metrics.record_api_call_with_timing(api_duration_ms);
254            debug!(api_time_ms = api_duration_ms, "API call completed");
255
256            self.state
257                .with_session_mut(|session| {
258                    session.update_usage(&response.usage);
259                })
260                .await;
261
262            total_usage.input_tokens += response.usage.input_tokens;
263            total_usage.output_tokens += response.usage.output_tokens;
264            metrics.add_usage_with_cache(&response.usage);
265            metrics.record_model_usage(&self.config.model.primary, &response.usage);
266
267            if let Some(ref server_usage) = response.usage.server_tool_use {
268                metrics.update_server_tool_use_from_api(server_usage);
269            }
270
271            let cost = self
272                .budget_tracker
273                .record(&self.config.model.primary, &response.usage);
274            metrics.add_cost(cost);
275            if let Some(ref tenant_budget) = self.tenant_budget {
276                tenant_budget.record(&self.config.model.primary, &response.usage);
277            }
278
279            final_text = response.text();
280            final_stop_reason = response.stop_reason.unwrap_or(StopReason::EndTurn);
281
282            self.state
283                .with_session_mut(|session| {
284                    session.add_assistant_message(response.content.clone(), Some(response.usage));
285                })
286                .await;
287
288            if !response.wants_tool_use() {
289                debug!("No tool use requested, ending loop");
290                break;
291            }
292
293            let tool_uses = response.tool_uses();
294            let hook_ctx = self.hook_context();
295
296            let mut prepared = Vec::with_capacity(tool_uses.len());
297            let mut blocked = Vec::with_capacity(tool_uses.len());
298
299            for tool_use in &tool_uses {
300                let pre_input = HookInput::pre_tool_use(
301                    &*self.session_id,
302                    &tool_use.name,
303                    tool_use.input.clone(),
304                );
305                let pre_output = self
306                    .hooks
307                    .execute(HookEvent::PreToolUse, pre_input, &hook_ctx)
308                    .await?;
309
310                if !pre_output.continue_execution {
311                    debug!(tool = %tool_use.name, "Tool blocked by hook");
312                    let reason = pre_output
313                        .stop_reason
314                        .clone()
315                        .unwrap_or_else(|| "Blocked by hook".into());
316                    blocked.push(ToolResultBlock::error(&tool_use.id, reason.clone()));
317                    metrics.record_permission_denial(
318                        PermissionDenial::new(&tool_use.name, &tool_use.id, tool_use.input.clone())
319                            .with_reason(reason),
320                    );
321                } else {
322                    let input = pre_output.updated_input.unwrap_or(tool_use.input.clone());
323                    prepared.push((tool_use.id.clone(), tool_use.name.clone(), input));
324                }
325            }
326
327            let tool_futures = prepared.into_iter().map(|(id, name, input)| {
328                let tools = &self.tools;
329                async move {
330                    let start = Instant::now();
331                    let result = tools.execute(&name, input.clone()).await;
332                    let duration_ms = start.elapsed().as_millis() as u64;
333                    (id, name, input, result, duration_ms)
334                }
335            });
336
337            let parallel_results: Vec<_> = futures::future::join_all(tool_futures).await;
338
339            let all_non_retryable = !parallel_results.is_empty()
340                && parallel_results
341                    .iter()
342                    .all(|(_, _, _, result, _)| result.is_non_retryable());
343
344            let mut results = blocked;
345            for (id, name, input, result, duration_ms) in parallel_results {
346                let is_error = result.is_error();
347                debug!(tool = %name, duration_ms, is_error, "Tool execution completed");
348                metrics.record_tool(&id, &name, duration_ms, is_error);
349
350                if let Some(ref inner_usage) = result.inner_usage {
351                    self.state
352                        .with_session_mut(|session| {
353                            session.update_usage(inner_usage);
354                        })
355                        .await;
356                    total_usage.input_tokens += inner_usage.input_tokens;
357                    total_usage.output_tokens += inner_usage.output_tokens;
358                    metrics.add_usage(inner_usage.input_tokens, inner_usage.output_tokens);
359                    let inner_model = result.inner_model.as_deref().unwrap_or("claude-haiku-4-5");
360                    metrics.record_model_usage(inner_model, inner_usage);
361
362                    let inner_cost = self.budget_tracker.record(inner_model, inner_usage);
363                    metrics.add_cost(inner_cost);
364
365                    debug!(
366                        tool = %name,
367                        model = %inner_model,
368                        input_tokens = inner_usage.input_tokens,
369                        output_tokens = inner_usage.output_tokens,
370                        cost_usd = inner_cost,
371                        "Accumulated inner usage from tool"
372                    );
373                }
374
375                if let Some(file_path) = extract_file_path(&name, &input)
376                    && let Some(ref orchestrator) = self.orchestrator
377                {
378                    let new_rules = activate_rules_for_file(orchestrator, &file_path).await;
379                    if !new_rules.is_empty() {
380                        dynamic_rules_context =
381                            build_dynamic_rules_context(orchestrator, &file_path).await;
382                        debug!(rules = ?new_rules, "Activated rules for file");
383                    }
384                }
385
386                if is_error {
387                    let failure_input = HookInput::post_tool_use_failure(
388                        &*self.session_id,
389                        &name,
390                        result.error_message(),
391                    );
392                    if let Err(e) = self
393                        .hooks
394                        .execute(HookEvent::PostToolUseFailure, failure_input, &hook_ctx)
395                        .await
396                    {
397                        warn!(tool = %name, error = %e, "PostToolUseFailure hook failed");
398                    }
399                } else {
400                    let post_input =
401                        HookInput::post_tool_use(&*self.session_id, &name, result.output.clone());
402                    if let Err(e) = self
403                        .hooks
404                        .execute(HookEvent::PostToolUse, post_input, &hook_ctx)
405                        .await
406                    {
407                        warn!(tool = %name, error = %e, "PostToolUse hook failed");
408                    }
409                }
410                results.push(ToolResultBlock::from_tool_result(&id, &result));
411            }
412
413            self.state
414                .with_session_mut(|session| {
415                    session.add_tool_results(results);
416                })
417                .await;
418
419            if all_non_retryable {
420                warn!("All tool calls failed with non-retryable errors, ending execution");
421                break;
422            }
423
424            let should_compact = self
425                .state
426                .with_session(|session| {
427                    self.config.execution.auto_compact
428                        && session.should_compact(
429                            max_tokens,
430                            self.config.execution.compact_threshold,
431                            self.config.execution.compact_keep_messages,
432                        )
433                })
434                .await;
435
436            if should_compact {
437                self.handle_compaction(&guard, &hook_ctx, &mut metrics)
438                    .await;
439            }
440        }
441
442        metrics.execution_time_ms = execution_start.elapsed().as_millis() as u64;
443
444        let stop_input = HookInput::stop(&*self.session_id);
445        if let Err(e) = self
446            .hooks
447            .execute(HookEvent::Stop, stop_input, &hook_ctx)
448            .await
449        {
450            warn!(error = %e, "Stop hook failed");
451        }
452
453        let session_end_input = HookInput::session_end(&*self.session_id);
454        if let Err(e) = self
455            .hooks
456            .execute(HookEvent::SessionEnd, session_end_input, &hook_ctx)
457            .await
458        {
459            warn!(error = %e, "SessionEnd hook failed");
460        }
461
462        info!(
463            iterations = metrics.iterations,
464            tool_calls = metrics.tool_calls,
465            api_calls = metrics.api_calls,
466            total_tokens = metrics.total_tokens(),
467            execution_time_ms = metrics.execution_time_ms,
468            "Agent execution completed"
469        );
470
471        let messages = self
472            .state
473            .with_session(|session| session.to_api_messages())
474            .await;
475
476        drop(guard);
477
478        Ok(AgentResult {
479            text: final_text,
480            usage: total_usage,
481            tool_calls: metrics.tool_calls,
482            iterations: metrics.iterations,
483            stop_reason: final_stop_reason,
484            state: AgentState::Completed,
485            metrics,
486            session_id: self.session_id.to_string(),
487            structured_output: None,
488            messages,
489            uuid: uuid::Uuid::new_v4().to_string(),
490        })
491    }
492
493    pub(crate) fn hook_context(&self) -> HookContext {
494        HookContext::new(&*self.session_id)
495            .with_cwd(self.config.working_dir.clone().unwrap_or_default())
496            .with_env(self.config.security.env.clone())
497    }
498}
499
500pub(crate) fn extract_file_path(tool_name: &str, input: &serde_json::Value) -> Option<String> {
501    match tool_name {
502        "Read" | "Write" | "Edit" => input
503            .get("file_path")
504            .and_then(|v| v.as_str())
505            .map(String::from),
506        "Glob" | "Grep" => input.get("path").and_then(|v| v.as_str()).map(String::from),
507        _ => None,
508    }
509}
510
511pub(crate) async fn activate_rules_for_file(
512    orchestrator: &Arc<RwLock<PromptOrchestrator>>,
513    file_path: &str,
514) -> Vec<String> {
515    let orch = orchestrator.read().await;
516    let path = Path::new(file_path);
517    let rules = orch.find_matching_rules(path).await;
518    rules.iter().map(|r| r.name.clone()).collect()
519}
520
521pub(crate) async fn build_dynamic_rules_context(
522    orchestrator: &Arc<RwLock<PromptOrchestrator>>,
523    file_path: &str,
524) -> String {
525    let orch = orchestrator.read().await;
526    let path = Path::new(file_path);
527    orch.build_dynamic_context(Some(path)).await
528}
529
#[cfg(test)]
mod tests {
    use super::*;

    // Tools that carry their target under the `file_path` key.
    const FILE_TOOLS: [&str; 3] = ["Read", "Write", "Edit"];
    // Tools that carry their target under the `path` key.
    const DIR_TOOLS: [&str; 2] = ["Glob", "Grep"];

    /// One representative call per category: file tool, directory tool, unrelated tool.
    #[test]
    fn test_extract_file_path() {
        let read_input = serde_json::json!({"file_path": "/src/lib.rs"});
        assert_eq!(
            extract_file_path("Read", &read_input).as_deref(),
            Some("/src/lib.rs")
        );

        let glob_input = serde_json::json!({"path": "/src"});
        assert_eq!(
            extract_file_path("Glob", &glob_input).as_deref(),
            Some("/src")
        );

        let bash_input = serde_json::json!({"command": "ls"});
        assert_eq!(extract_file_path("Bash", &bash_input), None);
    }

    /// Every supported tool resolves its path; unsupported tools never do.
    #[test]
    fn test_extract_file_path_all_tools() {
        let file_input = serde_json::json!({"file_path": "/test/file.rs"});
        let path_input = serde_json::json!({"path": "/test/dir"});

        for tool in FILE_TOOLS {
            assert_eq!(
                extract_file_path(tool, &file_input),
                Some("/test/file.rs".to_string())
            );
        }

        for tool in DIR_TOOLS {
            assert_eq!(
                extract_file_path(tool, &path_input),
                Some("/test/dir".to_string())
            );
        }

        for tool in ["WebFetch", "Task"] {
            assert_eq!(extract_file_path(tool, &file_input), None);
        }
    }

    /// A supported tool without its expected key yields `None`.
    #[test]
    fn test_extract_file_path_missing_field() {
        let inputs = [
            serde_json::json!({}),
            serde_json::json!({"other": "value"}),
        ];

        for input in &inputs {
            for tool in ["Read", "Glob"] {
                assert_eq!(extract_file_path(tool, input), None);
            }
        }
    }

    /// A value of the wrong JSON type under the expected key yields `None`.
    #[test]
    fn test_extract_file_path_non_string() {
        let numeric = serde_json::json!({"file_path": 123});
        assert_eq!(extract_file_path("Read", &numeric), None);

        let null_path = serde_json::json!({"path": null});
        assert_eq!(extract_file_path("Glob", &null_path), None);
    }
}