claude_agent/agent/
streaming.rs

1//! Agent streaming execution with session-based context management.
2
3use std::path::Path;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::time::Instant;
7
8use futures::{Stream, StreamExt, stream};
9use tokio::sync::RwLock;
10use tracing::{debug, warn};
11
12use super::common::BudgetContext;
13use super::events::{AgentEvent, AgentResult};
14use super::execution::extract_file_path;
15use super::executor::Agent;
16use super::request::RequestBuilder;
17use super::state_formatter::collect_compaction_state;
18use super::{AgentConfig, AgentMetrics, AgentState};
19use crate::budget::{BudgetTracker, TenantBudget};
20use crate::client::{RecoverableStream, StreamItem};
21use crate::context::PromptOrchestrator;
22use crate::hooks::{HookContext, HookEvent, HookInput, HookManager};
23use crate::session::ToolState;
24use crate::types::{
25    CompactResult, ContentBlock, PermissionDenial, StopReason, StreamEvent, ToolResultBlock,
26    ToolUseBlock, Usage, context_window,
27};
28use crate::{Client, ToolRegistry};
29
30type BoxedByteStream =
31    Pin<Box<dyn Stream<Item = std::result::Result<bytes::Bytes, reqwest::Error>> + Send>>;
32
33impl Agent {
34    pub async fn execute_stream(
35        &self,
36        prompt: &str,
37    ) -> crate::Result<impl Stream<Item = crate::Result<AgentEvent>> + Send> {
38        let timeout = self
39            .config
40            .execution
41            .timeout
42            .unwrap_or(std::time::Duration::from_secs(600));
43
44        if self.state.is_executing() {
45            self.state
46                .enqueue(prompt)
47                .await
48                .map_err(|e| crate::Error::Session(format!("Queue full: {}", e)))?;
49        } else {
50            self.state
51                .with_session_mut(|session| {
52                    session.add_user_message(prompt);
53                })
54                .await;
55        }
56
57        let state = StreamState::new(
58            StreamStateConfig {
59                tool_state: self.state.clone(),
60                client: Arc::clone(&self.client),
61                config: Arc::clone(&self.config),
62                tools: Arc::clone(&self.tools),
63                hooks: Arc::clone(&self.hooks),
64                hook_context: self.hook_context(),
65                request_builder: RequestBuilder::new(&self.config, Arc::clone(&self.tools)),
66                orchestrator: self.orchestrator.clone(),
67                session_id: Arc::clone(&self.session_id),
68                budget_tracker: Arc::clone(&self.budget_tracker),
69                tenant_budget: self.tenant_budget.clone(),
70            },
71            timeout,
72        );
73
74        Ok(stream::unfold(state, |mut state| async move {
75            state.next_event().await.map(|event| (event, state))
76        }))
77    }
78}
79
80struct StreamStateConfig {
81    tool_state: ToolState,
82    client: Arc<Client>,
83    config: Arc<AgentConfig>,
84    tools: Arc<ToolRegistry>,
85    hooks: Arc<HookManager>,
86    hook_context: HookContext,
87    request_builder: RequestBuilder,
88    orchestrator: Option<Arc<RwLock<PromptOrchestrator>>>,
89    session_id: Arc<str>,
90    budget_tracker: Arc<BudgetTracker>,
91    tenant_budget: Option<Arc<TenantBudget>>,
92}
93
94enum StreamPollResult {
95    Event(crate::Result<AgentEvent>),
96    Continue,
97    StreamEnded,
98}
99
100enum Phase {
101    StartRequest,
102    Streaming(Box<StreamingPhase>),
103    StreamEnded { accumulated_usage: Usage },
104    ProcessingTools { tool_index: usize },
105    Done,
106}
107
108struct StreamingPhase {
109    stream: RecoverableStream<BoxedByteStream>,
110    accumulated_usage: Usage,
111}
112
113struct StreamState {
114    cfg: StreamStateConfig,
115    timeout: std::time::Duration,
116    chunk_timeout: std::time::Duration,
117    dynamic_rules: String,
118    metrics: AgentMetrics,
119    start_time: Instant,
120    last_chunk_time: Instant,
121    pending_tool_results: Vec<ToolResultBlock>,
122    pending_tool_uses: Vec<ToolUseBlock>,
123    final_text: String,
124    total_usage: Usage,
125    phase: Phase,
126}
127
128impl StreamState {
129    fn new(cfg: StreamStateConfig, timeout: std::time::Duration) -> Self {
130        let chunk_timeout = cfg.config.execution.chunk_timeout;
131        let now = Instant::now();
132        Self {
133            cfg,
134            timeout,
135            chunk_timeout,
136            dynamic_rules: String::new(),
137            metrics: AgentMetrics::default(),
138            start_time: now,
139            last_chunk_time: now,
140            pending_tool_results: Vec::new(),
141            pending_tool_uses: Vec::new(),
142            final_text: String::new(),
143            total_usage: Usage::default(),
144            phase: Phase::StartRequest,
145        }
146    }
147
148    async fn next_event(&mut self) -> Option<crate::Result<AgentEvent>> {
149        loop {
150            if matches!(self.phase, Phase::Done) {
151                return None;
152            }
153
154            if self.start_time.elapsed() > self.timeout {
155                self.phase = Phase::Done;
156                return Some(Err(crate::Error::Timeout(self.timeout)));
157            }
158
159            if let Some(event) = self.check_budget_exceeded() {
160                return Some(event);
161            }
162
163            match std::mem::replace(&mut self.phase, Phase::Done) {
164                Phase::StartRequest => {
165                    if let Some(result) = self.do_start_request().await {
166                        return Some(result);
167                    }
168                }
169                Phase::Streaming(mut streaming) => {
170                    match self
171                        .do_poll_stream(&mut streaming.stream, &mut streaming.accumulated_usage)
172                        .await
173                    {
174                        StreamPollResult::Event(event) => {
175                            self.phase = Phase::Streaming(streaming);
176                            return Some(event);
177                        }
178                        StreamPollResult::Continue => {
179                            self.phase = Phase::Streaming(streaming);
180                        }
181                        StreamPollResult::StreamEnded => {
182                            self.phase = Phase::StreamEnded {
183                                accumulated_usage: streaming.accumulated_usage,
184                            };
185                        }
186                    }
187                }
188                Phase::StreamEnded { accumulated_usage } => {
189                    if let Some(event) = self.do_handle_stream_end(accumulated_usage).await {
190                        return Some(event);
191                    }
192                }
193                Phase::ProcessingTools { tool_index } => {
194                    if let Some(result) = self.do_process_tool(tool_index).await {
195                        return Some(result);
196                    }
197                }
198                Phase::Done => return None,
199            }
200        }
201    }
202
203    fn check_budget_exceeded(&mut self) -> Option<crate::Result<AgentEvent>> {
204        let result = BudgetContext {
205            tracker: &self.cfg.budget_tracker,
206            tenant: self.cfg.tenant_budget.as_deref(),
207            config: &self.cfg.config.budget,
208        }
209        .check();
210
211        if let Err(e) = result {
212            self.phase = Phase::Done;
213            return Some(Err(e));
214        }
215
216        None
217    }
218
219    async fn do_start_request(&mut self) -> Option<crate::Result<AgentEvent>> {
220        self.metrics.iterations += 1;
221        if self.metrics.iterations > self.cfg.config.execution.max_iterations {
222            self.phase = Phase::Done;
223            self.metrics.execution_time_ms = self.start_time.elapsed().as_millis() as u64;
224
225            let messages = self
226                .cfg
227                .tool_state
228                .with_session(|session| session.to_api_messages())
229                .await;
230
231            return Some(Ok(AgentEvent::Complete(Box::new(AgentResult {
232                text: self.final_text.clone(),
233                usage: self.total_usage,
234                tool_calls: self.metrics.tool_calls,
235                iterations: self.metrics.iterations - 1,
236                stop_reason: StopReason::MaxTokens,
237                state: AgentState::Completed,
238                metrics: self.metrics.clone(),
239                session_id: self.cfg.session_id.to_string(),
240                structured_output: None,
241                messages,
242                uuid: uuid::Uuid::new_v4().to_string(),
243            }))));
244        }
245
246        let cache_messages = self.cfg.config.cache.enabled && self.cfg.config.cache.message_cache;
247        let messages = self
248            .cfg
249            .tool_state
250            .with_session(|session| session.to_api_messages_with_cache(cache_messages))
251            .await;
252
253        let stream_request = self
254            .cfg
255            .request_builder
256            .build(messages, &self.dynamic_rules)
257            .with_stream();
258
259        let response = match self
260            .cfg
261            .client
262            .send_stream_with_auth_retry(stream_request)
263            .await
264        {
265            Ok(r) => r,
266            Err(e) => {
267                self.phase = Phase::Done;
268                return Some(Err(e));
269            }
270        };
271
272        self.metrics.record_api_call();
273
274        let boxed_stream: BoxedByteStream = Box::pin(response.bytes_stream());
275        self.phase = Phase::Streaming(Box::new(StreamingPhase {
276            stream: RecoverableStream::new(boxed_stream),
277            accumulated_usage: Usage::default(),
278        }));
279
280        None
281    }
282
283    async fn do_poll_stream(
284        &mut self,
285        stream: &mut RecoverableStream<BoxedByteStream>,
286        accumulated_usage: &mut Usage,
287    ) -> StreamPollResult {
288        let chunk_result = tokio::time::timeout(self.chunk_timeout, stream.next()).await;
289
290        match chunk_result {
291            Ok(Some(Ok(item))) => {
292                self.last_chunk_time = Instant::now();
293                self.handle_stream_item(item, accumulated_usage)
294            }
295            Ok(Some(Err(e))) => {
296                self.phase = Phase::Done;
297                StreamPollResult::Event(Err(e))
298            }
299            Ok(None) => StreamPollResult::StreamEnded,
300            Err(_) => {
301                self.phase = Phase::Done;
302                StreamPollResult::Event(Err(crate::Error::Stream(format!(
303                    "Chunk timeout after {:?} (no data received)",
304                    self.chunk_timeout
305                ))))
306            }
307        }
308    }
309
310    fn handle_stream_item(
311        &mut self,
312        item: StreamItem,
313        accumulated_usage: &mut Usage,
314    ) -> StreamPollResult {
315        match item {
316            StreamItem::Text(text) => {
317                self.final_text.push_str(&text);
318                StreamPollResult::Event(Ok(AgentEvent::Text(text)))
319            }
320            StreamItem::Thinking(thinking) => {
321                StreamPollResult::Event(Ok(AgentEvent::Thinking(thinking)))
322            }
323            StreamItem::Citation(_) => StreamPollResult::Continue,
324            StreamItem::ToolUseComplete(tool_use) => {
325                self.pending_tool_uses.push(tool_use);
326                StreamPollResult::Continue
327            }
328            StreamItem::Event(event) => self.handle_stream_event(event, accumulated_usage),
329        }
330    }
331
332    fn handle_stream_event(
333        &mut self,
334        event: StreamEvent,
335        accumulated_usage: &mut Usage,
336    ) -> StreamPollResult {
337        match event {
338            StreamEvent::MessageStart { message } => {
339                accumulated_usage.input_tokens = message.usage.input_tokens;
340                accumulated_usage.output_tokens = message.usage.output_tokens;
341                accumulated_usage.cache_creation_input_tokens =
342                    message.usage.cache_creation_input_tokens;
343                accumulated_usage.cache_read_input_tokens = message.usage.cache_read_input_tokens;
344                StreamPollResult::Continue
345            }
346            StreamEvent::ContentBlockStart { .. } => StreamPollResult::Continue,
347            StreamEvent::ContentBlockDelta { .. } => StreamPollResult::Continue,
348            StreamEvent::ContentBlockStop { .. } => StreamPollResult::Continue,
349            StreamEvent::MessageDelta { usage, .. } => {
350                accumulated_usage.output_tokens = usage.output_tokens;
351                StreamPollResult::Continue
352            }
353            StreamEvent::MessageStop => StreamPollResult::StreamEnded,
354            StreamEvent::Ping => StreamPollResult::Continue,
355            StreamEvent::Error { error } => {
356                self.phase = Phase::Done;
357                StreamPollResult::Event(Err(crate::Error::Stream(error.message)))
358            }
359        }
360    }
361
362    async fn do_handle_stream_end(
363        &mut self,
364        accumulated_usage: Usage,
365    ) -> Option<crate::Result<AgentEvent>> {
366        self.total_usage.input_tokens += accumulated_usage.input_tokens;
367        self.total_usage.output_tokens += accumulated_usage.output_tokens;
368        self.metrics.add_usage(
369            accumulated_usage.input_tokens,
370            accumulated_usage.output_tokens,
371        );
372        self.metrics
373            .record_model_usage(&self.cfg.config.model.primary, &accumulated_usage);
374
375        let cost = self
376            .cfg
377            .budget_tracker
378            .record(&self.cfg.config.model.primary, &accumulated_usage);
379        self.metrics.add_cost(cost);
380
381        if let Some(ref tenant_budget) = self.cfg.tenant_budget {
382            tenant_budget.record(&self.cfg.config.model.primary, &accumulated_usage);
383        }
384
385        self.cfg
386            .tool_state
387            .with_session_mut(|session| {
388                session.update_usage(&accumulated_usage);
389
390                let mut content = Vec::new();
391                if !self.final_text.is_empty() {
392                    content.push(ContentBlock::Text {
393                        text: self.final_text.clone(),
394                        citations: None,
395                        cache_control: None,
396                    });
397                }
398                for tool_use in &self.pending_tool_uses {
399                    content.push(ContentBlock::ToolUse(tool_use.clone()));
400                }
401                if !content.is_empty() {
402                    session.add_assistant_message(content, Some(accumulated_usage));
403                }
404            })
405            .await;
406
407        if self.pending_tool_uses.is_empty() {
408            self.phase = Phase::Done;
409            self.metrics.execution_time_ms = self.start_time.elapsed().as_millis() as u64;
410
411            let messages = self
412                .cfg
413                .tool_state
414                .with_session(|session| session.to_api_messages())
415                .await;
416
417            return Some(Ok(AgentEvent::Complete(Box::new(AgentResult {
418                text: self.final_text.clone(),
419                usage: self.total_usage,
420                tool_calls: self.metrics.tool_calls,
421                iterations: self.metrics.iterations,
422                stop_reason: StopReason::EndTurn,
423                state: AgentState::Completed,
424                metrics: self.metrics.clone(),
425                session_id: self.cfg.session_id.to_string(),
426                structured_output: None,
427                messages,
428                uuid: uuid::Uuid::new_v4().to_string(),
429            }))));
430        }
431
432        self.phase = Phase::ProcessingTools { tool_index: 0 };
433        None
434    }
435
436    async fn do_process_tool(&mut self, tool_index: usize) -> Option<crate::Result<AgentEvent>> {
437        if tool_index >= self.pending_tool_uses.len() {
438            if !self.pending_tool_results.is_empty() {
439                self.finalize_tool_results().await;
440            }
441            self.final_text.clear();
442            self.pending_tool_uses.clear();
443            self.phase = Phase::StartRequest;
444            return None;
445        }
446
447        let tool_use = self.pending_tool_uses[tool_index].clone();
448        self.execute_tool(tool_use, tool_index).await
449    }
450
451    async fn execute_tool(
452        &mut self,
453        tool_use: ToolUseBlock,
454        tool_index: usize,
455    ) -> Option<crate::Result<AgentEvent>> {
456        let pre_input = HookInput::pre_tool_use(
457            &*self.cfg.session_id,
458            &tool_use.name,
459            tool_use.input.clone(),
460        );
461        let pre_output = match self
462            .cfg
463            .hooks
464            .execute(HookEvent::PreToolUse, pre_input, &self.cfg.hook_context)
465            .await
466        {
467            Ok(output) => output,
468            Err(e) => {
469                warn!(tool = %tool_use.name, error = %e, "PreToolUse hook failed");
470                crate::hooks::HookOutput::allow()
471            }
472        };
473
474        if !pre_output.continue_execution {
475            let reason = pre_output
476                .stop_reason
477                .clone()
478                .unwrap_or_else(|| "Blocked by hook".into());
479            debug!(tool = %tool_use.name, "Tool blocked by hook");
480
481            self.pending_tool_results
482                .push(ToolResultBlock::error(&tool_use.id, reason.clone()));
483            self.metrics.record_permission_denial(
484                PermissionDenial::new(&tool_use.name, &tool_use.id, tool_use.input.clone())
485                    .with_reason(reason.clone()),
486            );
487            self.phase = Phase::ProcessingTools {
488                tool_index: tool_index + 1,
489            };
490
491            return Some(Ok(AgentEvent::ToolBlocked {
492                id: tool_use.id,
493                name: tool_use.name,
494                reason,
495            }));
496        }
497
498        let actual_input = pre_output.updated_input.unwrap_or(tool_use.input.clone());
499
500        let start = Instant::now();
501        let result = self
502            .cfg
503            .tools
504            .execute(&tool_use.name, actual_input.clone())
505            .await;
506        let duration_ms = start.elapsed().as_millis() as u64;
507
508        let (output, is_error) = match &result.output {
509            crate::types::ToolOutput::Success(s) => (s.clone(), false),
510            crate::types::ToolOutput::SuccessBlocks(blocks) => {
511                let text = blocks
512                    .iter()
513                    .filter_map(|b| match b {
514                        crate::types::ToolOutputBlock::Text { text } => Some(text.as_str()),
515                        _ => None,
516                    })
517                    .collect::<Vec<_>>()
518                    .join("\n");
519                (text, false)
520            }
521            crate::types::ToolOutput::Error(e) => (e.to_string(), true),
522            crate::types::ToolOutput::Empty => (String::new(), false),
523        };
524
525        self.metrics
526            .record_tool(&tool_use.id, &tool_use.name, duration_ms, is_error);
527
528        if let Some(ref inner_usage) = result.inner_usage {
529            self.cfg
530                .tool_state
531                .with_session_mut(|session| {
532                    session.update_usage(inner_usage);
533                })
534                .await;
535            self.total_usage.input_tokens += inner_usage.input_tokens;
536            self.total_usage.output_tokens += inner_usage.output_tokens;
537            self.metrics
538                .add_usage(inner_usage.input_tokens, inner_usage.output_tokens);
539            let inner_model = result.inner_model.as_deref().unwrap_or("claude-haiku-4-5");
540            self.metrics.record_model_usage(inner_model, inner_usage);
541            let inner_cost = self.cfg.budget_tracker.record(inner_model, inner_usage);
542            self.metrics.add_cost(inner_cost);
543        }
544
545        if is_error {
546            let failure_input = HookInput::post_tool_use_failure(
547                &*self.cfg.session_id,
548                &tool_use.name,
549                result.error_message(),
550            );
551            if let Err(e) = self
552                .cfg
553                .hooks
554                .execute(
555                    HookEvent::PostToolUseFailure,
556                    failure_input,
557                    &self.cfg.hook_context,
558                )
559                .await
560            {
561                warn!(tool = %tool_use.name, error = %e, "PostToolUseFailure hook failed");
562            }
563        } else {
564            let post_input = HookInput::post_tool_use(
565                &*self.cfg.session_id,
566                &tool_use.name,
567                result.output.clone(),
568            );
569            if let Err(e) = self
570                .cfg
571                .hooks
572                .execute(HookEvent::PostToolUse, post_input, &self.cfg.hook_context)
573                .await
574            {
575                warn!(tool = %tool_use.name, error = %e, "PostToolUse hook failed");
576            }
577        }
578
579        if let Some(file_path) = extract_file_path(&tool_use.name, &actual_input)
580            && let Some(ref orchestrator) = self.cfg.orchestrator
581        {
582            let orch = orchestrator.read().await;
583            let path = Path::new(&file_path);
584            let rules = orch.rules_engine().find_matching(path);
585            if !rules.is_empty() {
586                let dynamic_ctx = orch.build_dynamic_context(Some(path)).await;
587                if !dynamic_ctx.is_empty() {
588                    self.dynamic_rules = dynamic_ctx;
589                }
590            }
591        }
592
593        self.pending_tool_results
594            .push(ToolResultBlock::from_tool_result(&tool_use.id, result));
595        self.phase = Phase::ProcessingTools {
596            tool_index: tool_index + 1,
597        };
598
599        Some(Ok(AgentEvent::ToolComplete {
600            id: tool_use.id,
601            name: tool_use.name,
602            output,
603            is_error,
604            duration_ms,
605        }))
606    }
607
608    async fn finalize_tool_results(&mut self) {
609        let results = std::mem::take(&mut self.pending_tool_results);
610        let max_tokens = context_window::for_model(&self.cfg.config.model.primary);
611
612        self.cfg
613            .tool_state
614            .with_session_mut(|session| {
615                session.add_tool_results(results);
616            })
617            .await;
618
619        let should_compact = self
620            .cfg
621            .tool_state
622            .with_session(|session| {
623                self.cfg.config.execution.auto_compact
624                    && session.should_compact(
625                        max_tokens,
626                        self.cfg.config.execution.compact_threshold,
627                        self.cfg.config.execution.compact_keep_messages,
628                    )
629            })
630            .await;
631
632        if should_compact {
633            let compact_result = self
634                .cfg
635                .tool_state
636                .compact(
637                    &self.cfg.client,
638                    self.cfg.config.execution.compact_keep_messages,
639                )
640                .await;
641
642            if let Ok(CompactResult::Compacted { .. }) = compact_result {
643                self.metrics.record_compaction();
644                let state_sections = collect_compaction_state(&self.cfg.tools).await;
645                if !state_sections.is_empty() {
646                    self.cfg
647                        .tool_state
648                        .with_session_mut(|session| {
649                            session.add_user_message(format!(
650                                "<system-reminder>\n# State preserved after compaction\n\n{}\n</system-reminder>",
651                                state_sections.join("\n\n")
652                            ));
653                        })
654                        .await;
655                }
656            }
657        }
658    }
659}
660
#[cfg(test)]
mod tests {
    use super::*;

    /// The data-free phases can be constructed and pattern-matched.
    #[test]
    fn test_phase_transitions() {
        let initial = Phase::StartRequest;
        let terminal = Phase::Done;

        assert!(matches!(initial, Phase::StartRequest));
        assert!(matches!(terminal, Phase::Done));
    }

    /// Each poll-result variant matches only its own pattern.
    #[test]
    fn test_stream_poll_result_variants() {
        let text_event = AgentEvent::Text("test".into());
        assert!(matches!(
            StreamPollResult::Event(Ok(text_event)),
            StreamPollResult::Event(_)
        ));

        assert!(matches!(
            StreamPollResult::Continue,
            StreamPollResult::Continue
        ));

        assert!(matches!(
            StreamPollResult::StreamEnded,
            StreamPollResult::StreamEnded
        ));
    }
}