Skip to main content

claude_agent/agent/
execution.rs

1//! Agent execution logic with session-based context management.
2
3use std::sync::Arc;
4use std::time::Instant;
5
6use tracing::{debug, info, instrument, warn};
7
8use super::AgentMetrics;
9use super::common::{
10    self, BudgetContext, accumulate_inner_usage, accumulate_response_usage, handle_compaction,
11    run_post_tool_hooks, run_stop_hooks, try_activate_dynamic_rules,
12};
13use super::events::AgentResult;
14use super::executor::Agent;
15use super::request::RequestBuilder;
16use crate::hooks::{HookContext, HookEvent, HookInput};
17use crate::types::{
18    ContentBlock, Message, PermissionDenial, StopReason, ToolResultBlock, Usage, context_window,
19};
20
21impl Agent {
22    fn check_budget(&self) -> crate::Result<()> {
23        BudgetContext {
24            tracker: &self.budget_tracker,
25            tenant: self.tenant_budget.as_deref(),
26            config: &self.config.budget,
27        }
28        .check()
29    }
30
31    pub async fn execute(&self, prompt: &str) -> crate::Result<AgentResult> {
32        let timeout = self
33            .config
34            .execution
35            .timeout
36            .unwrap_or(std::time::Duration::from_secs(600));
37
38        if self.state.is_executing() {
39            self.state
40                .enqueue(prompt)
41                .await
42                .map_err(|e| crate::Error::Session(format!("Queue full: {}", e)))?;
43            return self.wait_for_execution(timeout).await;
44        }
45
46        tokio::time::timeout(timeout, self.execute_inner(prompt))
47            .await
48            .map_err(|_| crate::Error::Timeout(timeout))?
49    }
50
51    async fn wait_for_execution(&self, timeout: std::time::Duration) -> crate::Result<AgentResult> {
52        tokio::time::timeout(timeout, async {
53            loop {
54                self.state.wait_for_queue_signal().await;
55                if !self.state.is_executing()
56                    && let Some(merged) = self.state.dequeue_or_merge().await
57                {
58                    return self.execute_inner(&merged.content).await;
59                }
60            }
61        })
62        .await
63        .map_err(|_| crate::Error::Timeout(timeout))?
64    }
65
66    pub async fn execute_with_messages(
67        &self,
68        previous_messages: Vec<Message>,
69        prompt: &str,
70    ) -> crate::Result<AgentResult> {
71        let context_summary = previous_messages
72            .iter()
73            .filter_map(|m| {
74                m.content
75                    .iter()
76                    .filter_map(|b| match b {
77                        ContentBlock::Text { text, .. } => Some(text.as_str()),
78                        _ => None,
79                    })
80                    .next()
81            })
82            .collect::<Vec<_>>()
83            .join("\n---\n");
84
85        let enriched_prompt = if context_summary.is_empty() {
86            prompt.to_string()
87        } else {
88            format!(
89                "Previous conversation context:\n{}\n\nContinue with: {}",
90                context_summary, prompt
91            )
92        };
93
94        self.execute(&enriched_prompt).await
95    }
96
    /// Runs one agent execution: lifecycle hooks, the request/tool loop, and result assembly.
    ///
    /// Steps, in order:
    /// 1. Acquire the execution guard and fire the `SessionStart` hook (a failure is only logged).
    /// 2. Append any queued prompt content to `prompt` and run the `UserPromptSubmit` hook; if the
    ///    hook declines to continue, fire `SessionEnd` and return `Error::Permission`.
    /// 3. Loop: check the budget, send the session messages to the API, record usage, and, when
    ///    the response wants tool use, run `PreToolUse` hooks, execute the permitted tools
    ///    concurrently, store their results, and call `handle_compaction`.
    /// 4. Run the stop hooks and build the [`AgentResult`].
    ///
    /// The loop ends when a response requests no tool use, when `max_iterations` is exceeded, or
    /// when every executed tool call reports a non-retryable result.
    ///
    /// # Errors
    ///
    /// Propagates errors from the `UserPromptSubmit` and `PreToolUse` hooks, the budget check,
    /// and the API call. Returns `Error::Permission` when the prompt is blocked by a hook.
    #[instrument(skip(self, prompt), fields(session_id = %self.session_id))]
    async fn execute_inner(&self, prompt: &str) -> crate::Result<AgentResult> {
        // Held until this function returns.
        // NOTE(review): presumably this is what makes `is_executing()` report true — confirm.
        let _guard = self.state.acquire_execution().await;
        let execution_start = Instant::now();
        let hook_ctx = self.hook_context();

        // SessionStart is best-effort: a failing hook is logged and execution continues.
        let session_start_input = HookInput::session_start(&*self.session_id);
        if let Err(e) = self
            .hooks
            .execute(HookEvent::SessionStart, session_start_input, &hook_ctx)
            .await
        {
            warn!(error = %e, "SessionStart hook failed");
        }

        // Fold any prompts queued in the meantime into this execution's prompt.
        let final_prompt = if let Some(merged) = self.state.dequeue_or_merge().await {
            format!("{}\n{}", prompt, merged.content)
        } else {
            prompt.to_string()
        };

        // Unlike SessionStart, a UserPromptSubmit hook error aborts the execution.
        let prompt_input = HookInput::user_prompt_submit(&*self.session_id, &final_prompt);
        let prompt_output = self
            .hooks
            .execute(HookEvent::UserPromptSubmit, prompt_input, &hook_ctx)
            .await?;

        // The hook vetoed the prompt: close the session (best-effort) and report the reason.
        if !prompt_output.continue_execution {
            let session_end_input = HookInput::session_end(&*self.session_id);
            if let Err(e) = self
                .hooks
                .execute(HookEvent::SessionEnd, session_end_input, &hook_ctx)
                .await
            {
                warn!(error = %e, "SessionEnd hook failed");
            }
            return Err(crate::Error::Permission(
                prompt_output
                    .stop_reason
                    .unwrap_or_else(|| "Blocked by hook".into()),
            ));
        }

        self.state
            .with_session_mut(|session| {
                session.add_user_message(&final_prompt);
            })
            .await;

        // Accumulators that live across loop iterations and feed the final AgentResult.
        let mut metrics = AgentMetrics::default();
        let mut final_text = String::new();
        let mut final_stop_reason = StopReason::EndTurn;
        let mut dynamic_rules_context = String::new();
        let mut total_usage = Usage::default();

        // Build the request builder once; when a tool search manager is configured, hand it the
        // prepared tool set.
        let mut request_builder = {
            let builder = RequestBuilder::new(&self.config, Arc::clone(&self.tools));

            if let Some(ref tsm) = self.tool_search_manager {
                let prepared = tsm.prepare_tools().await;
                if prepared.use_search {
                    info!(
                        immediate = prepared.immediate.len(),
                        deferred = prepared.deferred.len(),
                        tokens_saved = prepared.token_savings(),
                        "MCP Progressive Disclosure active"
                    );
                }
                builder.prepared_tools(prepared)
            } else {
                builder
            }
        };
        // Context window size for the primary model; passed to `handle_compaction` below.
        let max_tokens = context_window::for_model(&self.config.model.primary);

        info!(prompt_len = final_prompt.len(), "Starting agent execution");

        loop {
            // The counter is bumped before the limit check, so when the cap triggers
            // `metrics.iterations` ends at `max_iterations + 1`.
            metrics.iterations += 1;
            if metrics.iterations > self.config.execution.max_iterations {
                warn!(
                    max = self.config.execution.max_iterations,
                    "Max iterations reached"
                );
                break;
            }

            // A budget failure aborts the whole execution with an error.
            self.check_budget()?;

            // Switch the request model when the budget context proposes a fallback.
            let budget_ctx = BudgetContext {
                tracker: &self.budget_tracker,
                tenant: self.tenant_budget.as_deref(),
                config: &self.config.budget,
            };
            if let Some(fallback) = budget_ctx.fallback_model() {
                request_builder.set_model(fallback);
            }

            debug!(iteration = metrics.iterations, "Starting iteration");

            let messages = self
                .state
                .with_session(|session| {
                    session.to_api_messages_with_cache(self.config.cache.message_ttl_option())
                })
                .await;

            // Time only the API round trip (including any auth retry performed by the client).
            let api_start = Instant::now();
            let request = request_builder.build(messages, &dynamic_rules_context);
            let response = self.client.send_with_auth_retry(request).await?;
            let api_duration_ms = api_start.elapsed().as_millis() as u64;
            metrics.record_api_call_with_timing(api_duration_ms);
            debug!(api_time_ms = api_duration_ms, "API call completed");

            self.state
                .with_session_mut(|session| {
                    session.update_usage(&response.usage);
                })
                .await;

            // Fold this response's usage into the running totals, metrics and budget trackers.
            accumulate_response_usage(
                &mut total_usage,
                &mut metrics,
                &self.budget_tracker,
                self.tenant_budget.as_deref(),
                &self.config.model.primary,
                &response.usage,
            );

            // Overwritten every iteration: the values from the last response win.
            final_text = response.text();
            final_stop_reason = response.stop_reason.unwrap_or(StopReason::EndTurn);

            self.state
                .with_session_mut(|session| {
                    session.add_assistant_message(response.content.clone(), Some(response.usage));
                })
                .await;

            if !response.wants_tool_use() {
                debug!("No tool use requested, ending loop");
                break;
            }

            let tool_uses = response.tool_uses();
            // Shadows the outer `hook_ctx` for the rest of this iteration.
            let hook_ctx = self.hook_context();

            // `prepared` holds (id, name, input) for tools allowed to run; `blocked` holds the
            // error results for tools vetoed by a PreToolUse hook.
            let mut prepared = Vec::with_capacity(tool_uses.len());
            let mut blocked = Vec::with_capacity(tool_uses.len());

            for tool_use in &tool_uses {
                let pre_input = HookInput::pre_tool_use(
                    &*self.session_id,
                    &tool_use.name,
                    tool_use.input.clone(),
                );
                // A PreToolUse hook error aborts the execution.
                let pre_output = self
                    .hooks
                    .execute(HookEvent::PreToolUse, pre_input, &hook_ctx)
                    .await?;

                if !pre_output.continue_execution {
                    debug!(tool = %tool_use.name, "Tool blocked by hook");
                    let reason = pre_output
                        .stop_reason
                        .clone()
                        .unwrap_or_else(|| "Blocked by hook".into());
                    blocked.push(ToolResultBlock::error(&tool_use.id, reason.clone()));
                    metrics.record_permission_denial(
                        PermissionDenial::new(&tool_use.name, &tool_use.id, tool_use.input.clone())
                            .reason(reason),
                    );
                } else {
                    // The hook may substitute the tool input; otherwise use the original one.
                    let input = pre_output.updated_input.unwrap_or(tool_use.input.clone());
                    prepared.push((tool_use.id.clone(), tool_use.name.clone(), input));
                }
            }

            // One future per permitted tool call, each measuring its own duration.
            let tool_futures = prepared.into_iter().map(|(id, name, input)| {
                let tools = &self.tools;
                async move {
                    let start = Instant::now();
                    let result = tools.execute(&name, input.clone()).await;
                    let duration_ms = start.elapsed().as_millis() as u64;
                    (id, name, input, result, duration_ms)
                }
            });

            // Drive all tool futures concurrently and wait for every one to finish.
            let parallel_results: Vec<_> = futures::future::join_all(tool_futures).await;

            // True only when at least one tool actually ran and every executed tool reported a
            // non-retryable result; hook-blocked tools are not considered here.
            let all_non_retryable = !parallel_results.is_empty()
                && parallel_results
                    .iter()
                    .all(|(_, _, _, result, _)| result.is_non_retryable());

            // The result list starts with the blocked entries; executed results are appended.
            let mut results = blocked;
            for (id, name, input, result, duration_ms) in parallel_results {
                let is_error = result.is_error();
                debug!(tool = %name, duration_ms, is_error, "Tool execution completed");
                metrics.record_tool(&id, &name, duration_ms, is_error);

                accumulate_inner_usage(
                    &self.state,
                    &mut total_usage,
                    &mut metrics,
                    &self.budget_tracker,
                    &result,
                    &name,
                )
                .await;

                // May extend `dynamic_rules_context`, which is sent with the next request.
                try_activate_dynamic_rules(
                    &name,
                    &input,
                    &self.orchestrator,
                    &mut dynamic_rules_context,
                )
                .await;

                run_post_tool_hooks(
                    &self.hooks,
                    &hook_ctx,
                    &self.session_id,
                    &name,
                    is_error,
                    &result,
                )
                .await;

                results.push(ToolResultBlock::from_tool_result(&id, &result));
            }

            // Tool results are stored in the session before the non-retryable check, so they are
            // recorded even when the loop is about to stop.
            self.state
                .with_session_mut(|session| {
                    session.add_tool_results(results);
                })
                .await;

            if all_non_retryable {
                warn!("All tool calls failed with non-retryable errors, ending execution");
                break;
            }

            handle_compaction(
                &self.state,
                &self.client,
                &self.tools,
                &self.hooks,
                &hook_ctx,
                &self.session_id,
                &self.config.execution,
                max_tokens,
                &mut metrics,
            )
            .await;
        }

        metrics.execution_time_ms = execution_start.elapsed().as_millis() as u64;

        run_stop_hooks(&self.hooks, &hook_ctx, &self.session_id).await;

        info!(
            iterations = metrics.iterations,
            tool_calls = metrics.tool_calls,
            api_calls = metrics.api_calls,
            total_tokens = metrics.total_tokens(),
            execution_time_ms = metrics.execution_time_ms,
            "Agent execution completed"
        );

        // Snapshot of the session's messages for the caller.
        let messages = self
            .state
            .with_session(|session| session.to_api_messages())
            .await;

        let structured_output = self.extract_structured_output(&final_text);
        Ok(AgentResult::new(
            final_text,
            total_usage,
            metrics.iterations,
            final_stop_reason,
            metrics,
            self.session_id.to_string(),
            structured_output,
            messages,
        ))
    }
383
384    pub(crate) fn hook_context(&self) -> HookContext {
385        HookContext::new(&*self.session_id)
386            .cwd(self.config.working_dir.clone().unwrap_or_default())
387            .env(self.config.security.env.clone())
388    }
389
390    fn extract_structured_output(&self, text: &str) -> Option<serde_json::Value> {
391        common::extract_structured_output(self.config.prompt.output_schema.as_ref(), text)
392    }
393}
394
#[cfg(test)]
mod tests {
    use super::common::extract_file_path;

    /// `extract_file_path` reads `file_path` for `Read`, `path` for `Glob`, and yields `None`
    /// for a `Bash` input that carries only a `command`.
    #[test]
    fn test_extract_file_path() {
        let cases = [
            (
                "Read",
                serde_json::json!({"file_path": "/src/lib.rs"}),
                Some("/src/lib.rs".to_string()),
            ),
            (
                "Glob",
                serde_json::json!({"path": "/src"}),
                Some("/src".to_string()),
            ),
            ("Bash", serde_json::json!({"command": "ls"}), None),
        ];

        for (tool, input, expected) in cases {
            assert_eq!(extract_file_path(tool, &input), expected);
        }
    }
}