claude_agent/agent/
streaming.rs

1//! Agent streaming execution.
2
3use std::collections::VecDeque;
4use std::path::Path;
5use std::sync::Arc;
6use std::time::Instant;
7
8use futures::Stream;
9use tokio::sync::{Mutex, RwLock};
10use tracing::{debug, warn};
11
12use super::events::{AgentEvent, AgentResult};
13use super::execution::extract_file_path;
14use super::executor::Agent;
15use super::request::RequestBuilder;
16use super::state_formatter::collect_compaction_state;
17use super::{AgentConfig, AgentMetrics, AgentState, ConversationContext};
18use crate::budget::{BudgetTracker, TenantBudget};
19use crate::context::PromptOrchestrator;
20use crate::hooks::{HookContext, HookEvent, HookInput, HookManager};
21use crate::types::{
22    CompactResult, ContentBlock, Message, PermissionDenial, StopReason, ToolResultBlock,
23    ToolUseBlock,
24};
25use crate::{Client, ToolRegistry};
26
27impl Agent {
28    pub async fn execute_stream(
29        &self,
30        prompt: &str,
31    ) -> crate::Result<impl Stream<Item = crate::Result<AgentEvent>>> {
32        use futures::stream;
33
34        let context = Arc::new(Mutex::new(ConversationContext::new()));
35        {
36            let mut history = context.lock().await;
37            history.push(Message::user(prompt));
38        }
39
40        let request_builder = RequestBuilder::new(&self.config, Arc::clone(&self.tools));
41
42        let event_stream = stream::unfold(
43            StreamState::new(
44                context,
45                self.client.clone(),
46                self.config.clone(),
47                Arc::clone(&self.tools),
48                self.hooks.clone(),
49                self.hook_context(),
50                request_builder,
51                self.orchestrator.clone(),
52                self.session_id.clone(),
53                self.budget_tracker.clone(),
54                self.tenant_budget.clone(),
55            ),
56            |mut state| async move { state.next_event().await.map(|event| (event, state)) },
57        );
58
59        Ok(event_stream)
60    }
61}
62
63pub(crate) struct StreamState {
64    context: Arc<Mutex<ConversationContext>>,
65    client: Client,
66    config: AgentConfig,
67    tools: Arc<ToolRegistry>,
68    hooks: HookManager,
69    hook_context: HookContext,
70    request_builder: RequestBuilder,
71    dynamic_rules: String,
72    orchestrator: Option<Arc<RwLock<PromptOrchestrator>>>,
73    metrics: AgentMetrics,
74    start_time: Instant,
75    pending_events: VecDeque<crate::Result<AgentEvent>>,
76    pending_tool_results: Vec<ToolResultBlock>,
77    pending_tool_uses: Vec<ToolUseBlock>,
78    final_text: String,
79    done: bool,
80    session_id: String,
81    budget_tracker: BudgetTracker,
82    tenant_budget: Option<Arc<TenantBudget>>,
83}
84
85impl StreamState {
86    #[allow(clippy::too_many_arguments)]
87    fn new(
88        context: Arc<Mutex<ConversationContext>>,
89        client: Client,
90        config: AgentConfig,
91        tools: Arc<ToolRegistry>,
92        hooks: HookManager,
93        hook_context: HookContext,
94        request_builder: RequestBuilder,
95        orchestrator: Option<Arc<RwLock<PromptOrchestrator>>>,
96        session_id: String,
97        budget_tracker: BudgetTracker,
98        tenant_budget: Option<Arc<TenantBudget>>,
99    ) -> Self {
100        Self {
101            context,
102            client,
103            config,
104            tools,
105            hooks,
106            hook_context,
107            request_builder,
108            dynamic_rules: String::new(),
109            orchestrator,
110            metrics: AgentMetrics::default(),
111            start_time: Instant::now(),
112            pending_events: VecDeque::new(),
113            pending_tool_results: Vec::new(),
114            pending_tool_uses: Vec::new(),
115            final_text: String::new(),
116            done: false,
117            session_id,
118            budget_tracker,
119            tenant_budget,
120        }
121    }
122
123    async fn next_event(&mut self) -> Option<crate::Result<AgentEvent>> {
124        if let Some(event) = self.pending_events.pop_front() {
125            return Some(event);
126        }
127
128        if self.done {
129            return None;
130        }
131
132        if !self.pending_tool_uses.is_empty() {
133            return self.process_tool_use().await;
134        }
135
136        if self.budget_tracker.should_stop() {
137            self.done = true;
138            let status = self.budget_tracker.check();
139            return Some(Err(crate::Error::BudgetExceeded {
140                used: status.used(),
141                limit: self.config.budget.max_cost_usd.unwrap_or(0.0),
142            }));
143        }
144        if let Some(ref tenant_budget) = self.tenant_budget
145            && tenant_budget.should_stop()
146        {
147            self.done = true;
148            return Some(Err(crate::Error::BudgetExceeded {
149                used: tenant_budget.used_cost_usd(),
150                limit: tenant_budget.max_cost_usd(),
151            }));
152        }
153
154        self.metrics.iterations += 1;
155        if self.metrics.iterations > self.config.execution.max_iterations {
156            return self.complete_with_max_iterations().await;
157        }
158
159        self.fetch_and_process_response().await
160    }
161
162    async fn process_tool_use(&mut self) -> Option<crate::Result<AgentEvent>> {
163        let tool_use = self.pending_tool_uses.remove(0);
164
165        let pre_input =
166            HookInput::pre_tool_use(&self.session_id, &tool_use.name, tool_use.input.clone());
167        let pre_output = match self
168            .hooks
169            .execute(HookEvent::PreToolUse, pre_input, &self.hook_context)
170            .await
171        {
172            Ok(output) => output,
173            Err(e) => {
174                warn!(tool = %tool_use.name, error = %e, "PreToolUse hook failed");
175                crate::hooks::HookOutput::allow()
176            }
177        };
178
179        if !pre_output.continue_execution {
180            let reason = pre_output
181                .stop_reason
182                .clone()
183                .unwrap_or_else(|| "Blocked by hook".into());
184            debug!(tool = %tool_use.name, "Tool blocked by hook");
185            self.pending_events.push_back(Ok(AgentEvent::ToolStart {
186                id: tool_use.id.clone(),
187                name: tool_use.name.clone(),
188                input: tool_use.input.clone(),
189            }));
190            self.pending_events.push_back(Ok(AgentEvent::ToolEnd {
191                id: tool_use.id.clone(),
192                output: reason.clone(),
193                is_error: true,
194            }));
195            self.pending_tool_results
196                .push(ToolResultBlock::error(&tool_use.id, reason.clone()));
197            self.metrics.record_permission_denial(
198                PermissionDenial::new(&tool_use.name, &tool_use.id, tool_use.input.clone())
199                    .with_reason(reason),
200            );
201            return self.pending_events.pop_front();
202        }
203
204        let actual_input = pre_output.updated_input.unwrap_or(tool_use.input.clone());
205
206        self.pending_events.push_back(Ok(AgentEvent::ToolStart {
207            id: tool_use.id.clone(),
208            name: tool_use.name.clone(),
209            input: actual_input.clone(),
210        }));
211
212        let start = Instant::now();
213        let result = self
214            .tools
215            .execute(&tool_use.name, actual_input.clone())
216            .await;
217        let duration_ms = start.elapsed().as_millis() as u64;
218
219        let (output, is_error) = match &result.output {
220            crate::types::ToolOutput::Success(s) => (s.clone(), false),
221            crate::types::ToolOutput::SuccessBlocks(blocks) => {
222                let text = blocks
223                    .iter()
224                    .filter_map(|b| match b {
225                        crate::types::ToolOutputBlock::Text { text } => Some(text.as_str()),
226                        _ => None,
227                    })
228                    .collect::<Vec<_>>()
229                    .join("\n");
230                (text, false)
231            }
232            crate::types::ToolOutput::Error(e) => (e.to_string(), true),
233            crate::types::ToolOutput::Empty => (String::new(), false),
234        };
235
236        self.metrics
237            .record_tool(&tool_use.name, duration_ms, is_error);
238
239        if let Some(ref inner_usage) = result.inner_usage {
240            let mut history = self.context.lock().await;
241            history.update_usage(*inner_usage);
242            self.metrics
243                .add_usage(inner_usage.input_tokens, inner_usage.output_tokens);
244            let inner_model = result.inner_model.as_deref().unwrap_or("claude-haiku-4-5");
245            self.metrics.record_model_usage(inner_model, inner_usage);
246            let inner_cost = self.budget_tracker.record(inner_model, inner_usage);
247            self.metrics.add_cost(inner_cost);
248        }
249
250        if is_error {
251            let failure_input = HookInput::post_tool_use_failure(
252                &self.session_id,
253                &tool_use.name,
254                result.error_message(),
255            );
256            if let Err(e) = self
257                .hooks
258                .execute(
259                    HookEvent::PostToolUseFailure,
260                    failure_input,
261                    &self.hook_context,
262                )
263                .await
264            {
265                warn!(tool = %tool_use.name, error = %e, "PostToolUseFailure hook failed");
266            }
267        } else {
268            let post_input =
269                HookInput::post_tool_use(&self.session_id, &tool_use.name, result.output.clone());
270            if let Err(e) = self
271                .hooks
272                .execute(HookEvent::PostToolUse, post_input, &self.hook_context)
273                .await
274            {
275                warn!(tool = %tool_use.name, error = %e, "PostToolUse hook failed");
276            }
277        }
278
279        if let Some(file_path) = extract_file_path(&tool_use.name, &actual_input)
280            && let Some(ref orchestrator) = self.orchestrator
281        {
282            let orch = orchestrator.read().await;
283            let path = Path::new(&file_path);
284            let rules = orch.rules_engine().find_matching(path);
285            if !rules.is_empty() {
286                let rule_names: Vec<String> = rules.iter().map(|r| r.name.clone()).collect();
287                let dynamic_ctx = orch.build_dynamic_context(Some(path)).await;
288                if !dynamic_ctx.is_empty() {
289                    self.dynamic_rules = dynamic_ctx;
290                }
291                self.pending_events
292                    .push_back(Ok(AgentEvent::RulesActivated {
293                        file_path,
294                        rule_names,
295                    }));
296            }
297        }
298
299        self.pending_events.push_back(Ok(AgentEvent::ToolEnd {
300            id: tool_use.id.clone(),
301            output: output.clone(),
302            is_error,
303        }));
304
305        self.pending_tool_results
306            .push(ToolResultBlock::from_tool_result(&tool_use.id, result));
307
308        if self.pending_tool_uses.is_empty() && !self.pending_tool_results.is_empty() {
309            self.finalize_tool_results().await;
310        }
311
312        self.pending_events.pop_front()
313    }
314
315    async fn finalize_tool_results(&mut self) {
316        let results = std::mem::take(&mut self.pending_tool_results);
317        let mut history = self.context.lock().await;
318        history.push(Message::tool_results(results));
319
320        let used_tokens = history.estimated_tokens() as u64;
321        let max_tokens = 200_000u64;
322        self.pending_events.push_back(Ok(AgentEvent::ContextUpdate {
323            used_tokens,
324            max_tokens,
325        }));
326
327        if self.config.execution.auto_compact
328            && history.should_compact(
329                max_tokens as usize,
330                self.config.execution.compact_threshold,
331                self.config.execution.compact_keep_messages,
332            )
333        {
334            self.pending_events
335                .push_back(Ok(AgentEvent::CompactStarted));
336            let previous_tokens = history.estimated_tokens() as u64;
337
338            if let Ok(CompactResult::Compacted { .. }) = history
339                .compact(&self.client, self.config.execution.compact_keep_messages)
340                .await
341            {
342                let current_tokens = history.estimated_tokens() as u64;
343                self.pending_events
344                    .push_back(Ok(AgentEvent::CompactCompleted {
345                        previous_tokens,
346                        current_tokens,
347                    }));
348                self.metrics.record_compaction();
349
350                let state_sections = collect_compaction_state(&self.tools).await;
351                if !state_sections.is_empty() {
352                    history.push(Message::user(format!(
353                        "<system-reminder>\n# State preserved after compaction\n\n{}\n</system-reminder>",
354                        state_sections.join("\n\n")
355                    )));
356                }
357            }
358        }
359    }
360
361    async fn complete_with_max_iterations(&mut self) -> Option<crate::Result<AgentEvent>> {
362        self.done = true;
363        self.metrics.execution_time_ms = self.start_time.elapsed().as_millis() as u64;
364        let history = self.context.lock().await;
365        Some(Ok(AgentEvent::Complete(Box::new(AgentResult {
366            text: self.final_text.clone(),
367            usage: *history.total_usage(),
368            tool_calls: self.metrics.tool_calls,
369            iterations: self.metrics.iterations - 1,
370            stop_reason: StopReason::MaxTokens,
371            state: AgentState::Completed,
372            metrics: self.metrics.clone(),
373            session_id: self.session_id.clone(),
374            structured_output: None,
375            messages: history.messages().to_vec(),
376            uuid: uuid::Uuid::new_v4().to_string(),
377        }))))
378    }
379
380    async fn fetch_and_process_response(&mut self) -> Option<crate::Result<AgentEvent>> {
381        let request = {
382            let history = self.context.lock().await;
383            self.request_builder
384                .build(history.messages().to_vec(), &self.dynamic_rules)
385        };
386
387        let response = match self.client.send(request.clone()).await {
388            Ok(r) => r,
389            Err(e) if e.is_unauthorized() => {
390                if let Err(refresh_err) = self.client.refresh_credentials().await {
391                    self.done = true;
392                    return Some(Err(refresh_err));
393                }
394                match self.client.send(request).await {
395                    Ok(r) => r,
396                    Err(e) => {
397                        self.done = true;
398                        return Some(Err(e));
399                    }
400                }
401            }
402            Err(e) => {
403                self.done = true;
404                return Some(Err(e));
405            }
406        };
407
408        self.metrics.record_api_call();
409        self.metrics
410            .add_usage(response.usage.input_tokens, response.usage.output_tokens);
411        self.metrics
412            .record_model_usage(&self.config.model.primary, &response.usage);
413
414        let cost = self
415            .budget_tracker
416            .record(&self.config.model.primary, &response.usage);
417        self.metrics.add_cost(cost);
418        if let Some(ref tenant_budget) = self.tenant_budget {
419            tenant_budget.record(&self.config.model.primary, &response.usage);
420        }
421
422        {
423            let mut history = self.context.lock().await;
424            history.update_usage(response.usage);
425        }
426
427        let mut text_content = String::new();
428        let mut tool_uses = Vec::new();
429
430        for block in &response.content {
431            match block {
432                ContentBlock::Text { text, .. } => {
433                    text_content.push_str(text);
434                    self.pending_events
435                        .push_back(Ok(AgentEvent::Text(text.clone())));
436                }
437                ContentBlock::ToolUse(tool_use) => {
438                    tool_uses.push(tool_use.clone());
439                }
440                _ => {}
441            }
442        }
443
444        self.final_text = text_content;
445
446        {
447            let mut history = self.context.lock().await;
448            history.push(Message {
449                role: crate::types::Role::Assistant,
450                content: response.content.clone(),
451            });
452        }
453
454        if response.wants_tool_use() && !tool_uses.is_empty() {
455            self.pending_tool_uses = tool_uses;
456        } else {
457            self.done = true;
458            self.metrics.execution_time_ms = self.start_time.elapsed().as_millis() as u64;
459            let history = self.context.lock().await;
460            self.pending_events
461                .push_back(Ok(AgentEvent::Complete(Box::new(AgentResult {
462                    text: self.final_text.clone(),
463                    usage: *history.total_usage(),
464                    tool_calls: self.metrics.tool_calls,
465                    iterations: self.metrics.iterations,
466                    stop_reason: response.stop_reason.unwrap_or(StopReason::EndTurn),
467                    state: AgentState::Completed,
468                    metrics: self.metrics.clone(),
469                    session_id: self.session_id.clone(),
470                    structured_output: None,
471                    messages: history.messages().to_vec(),
472                    uuid: uuid::Uuid::new_v4().to_string(),
473                }))));
474        }
475
476        self.pending_events.pop_front()
477    }
478}