claude_agent/agent/
streaming.rs

1//! Agent streaming execution with session-based context management.
2
3use std::path::Path;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::time::Instant;
7
8use futures::{Stream, StreamExt, stream};
9use tokio::sync::RwLock;
10use tracing::{debug, warn};
11
12use super::common::BudgetContext;
13use super::events::{AgentEvent, AgentResult};
14use super::execution::extract_file_path;
15use super::executor::Agent;
16use super::request::RequestBuilder;
17use super::state_formatter::collect_compaction_state;
18use super::{AgentConfig, AgentMetrics, AgentState};
19use crate::budget::{BudgetTracker, TenantBudget};
20use crate::client::{RecoverableStream, StreamItem};
21use crate::context::PromptOrchestrator;
22use crate::hooks::{HookContext, HookEvent, HookInput, HookManager};
23use crate::session::ToolState;
24use crate::types::{
25    CompactResult, ContentBlock, PermissionDenial, StopReason, StreamEvent, ToolResultBlock,
26    ToolUseBlock, Usage, context_window,
27};
28use crate::{Client, ToolRegistry};
29
30type BoxedByteStream =
31    Pin<Box<dyn Stream<Item = std::result::Result<bytes::Bytes, reqwest::Error>> + Send>>;
32
33impl Agent {
34    pub async fn execute_stream(
35        &self,
36        prompt: &str,
37    ) -> crate::Result<impl Stream<Item = crate::Result<AgentEvent>> + Send> {
38        let timeout = self
39            .config
40            .execution
41            .timeout
42            .unwrap_or(std::time::Duration::from_secs(600));
43
44        if self.state.is_executing() {
45            self.state
46                .enqueue(prompt)
47                .await
48                .map_err(|e| crate::Error::Session(format!("Queue full: {}", e)))?;
49        } else {
50            self.state
51                .with_session_mut(|session| {
52                    session.add_user_message(prompt);
53                })
54                .await;
55        }
56
57        let state = StreamState::new(
58            StreamStateConfig {
59                tool_state: self.state.clone(),
60                client: Arc::clone(&self.client),
61                config: Arc::clone(&self.config),
62                tools: Arc::clone(&self.tools),
63                hooks: Arc::clone(&self.hooks),
64                hook_context: self.hook_context(),
65                request_builder: RequestBuilder::new(&self.config, Arc::clone(&self.tools)),
66                orchestrator: self.orchestrator.clone(),
67                session_id: Arc::clone(&self.session_id),
68                budget_tracker: Arc::clone(&self.budget_tracker),
69                tenant_budget: self.tenant_budget.clone(),
70            },
71            timeout,
72        );
73
74        Ok(stream::unfold(state, |mut state| async move {
75            state.next_event().await.map(|event| (event, state))
76        }))
77    }
78}
79
80struct StreamStateConfig {
81    tool_state: ToolState,
82    client: Arc<Client>,
83    config: Arc<AgentConfig>,
84    tools: Arc<ToolRegistry>,
85    hooks: Arc<HookManager>,
86    hook_context: HookContext,
87    request_builder: RequestBuilder,
88    orchestrator: Option<Arc<RwLock<PromptOrchestrator>>>,
89    session_id: Arc<str>,
90    budget_tracker: Arc<BudgetTracker>,
91    tenant_budget: Option<Arc<TenantBudget>>,
92}
93
94enum StreamPollResult {
95    Event(crate::Result<AgentEvent>),
96    Continue,
97    StreamEnded,
98}
99
100enum Phase {
101    StartRequest,
102    Streaming(Box<StreamingPhase>),
103    StreamEnded { accumulated_usage: Usage },
104    ProcessingTools { tool_index: usize },
105    Done,
106}
107
108struct StreamingPhase {
109    stream: RecoverableStream<BoxedByteStream>,
110    accumulated_usage: Usage,
111}
112
113struct StreamState {
114    cfg: StreamStateConfig,
115    timeout: std::time::Duration,
116    chunk_timeout: std::time::Duration,
117    dynamic_rules: String,
118    metrics: AgentMetrics,
119    start_time: Instant,
120    last_chunk_time: Instant,
121    pending_tool_results: Vec<ToolResultBlock>,
122    pending_tool_uses: Vec<ToolUseBlock>,
123    final_text: String,
124    total_usage: Usage,
125    phase: Phase,
126}
127
128impl StreamState {
129    fn new(cfg: StreamStateConfig, timeout: std::time::Duration) -> Self {
130        let chunk_timeout = cfg.config.execution.chunk_timeout;
131        let now = Instant::now();
132        Self {
133            cfg,
134            timeout,
135            chunk_timeout,
136            dynamic_rules: String::new(),
137            metrics: AgentMetrics::default(),
138            start_time: now,
139            last_chunk_time: now,
140            pending_tool_results: Vec::new(),
141            pending_tool_uses: Vec::new(),
142            final_text: String::new(),
143            total_usage: Usage::default(),
144            phase: Phase::StartRequest,
145        }
146    }
147
148    async fn next_event(&mut self) -> Option<crate::Result<AgentEvent>> {
149        loop {
150            if matches!(self.phase, Phase::Done) {
151                return None;
152            }
153
154            if self.start_time.elapsed() > self.timeout {
155                self.phase = Phase::Done;
156                return Some(Err(crate::Error::Timeout(self.timeout)));
157            }
158
159            if let Some(event) = self.check_budget_exceeded() {
160                return Some(event);
161            }
162
163            match std::mem::replace(&mut self.phase, Phase::Done) {
164                Phase::StartRequest => {
165                    if let Some(result) = self.do_start_request().await {
166                        return Some(result);
167                    }
168                }
169                Phase::Streaming(mut streaming) => {
170                    match self
171                        .do_poll_stream(&mut streaming.stream, &mut streaming.accumulated_usage)
172                        .await
173                    {
174                        StreamPollResult::Event(event) => {
175                            self.phase = Phase::Streaming(streaming);
176                            return Some(event);
177                        }
178                        StreamPollResult::Continue => {
179                            self.phase = Phase::Streaming(streaming);
180                        }
181                        StreamPollResult::StreamEnded => {
182                            self.phase = Phase::StreamEnded {
183                                accumulated_usage: streaming.accumulated_usage,
184                            };
185                        }
186                    }
187                }
188                Phase::StreamEnded { accumulated_usage } => {
189                    if let Some(event) = self.do_handle_stream_end(accumulated_usage).await {
190                        return Some(event);
191                    }
192                }
193                Phase::ProcessingTools { tool_index } => {
194                    if let Some(result) = self.do_process_tool(tool_index).await {
195                        return Some(result);
196                    }
197                }
198                Phase::Done => return None,
199            }
200        }
201    }
202
203    fn check_budget_exceeded(&mut self) -> Option<crate::Result<AgentEvent>> {
204        let result = BudgetContext {
205            tracker: &self.cfg.budget_tracker,
206            tenant: self.cfg.tenant_budget.as_deref(),
207            config: &self.cfg.config.budget,
208        }
209        .check();
210
211        if let Err(e) = result {
212            self.phase = Phase::Done;
213            return Some(Err(e));
214        }
215
216        None
217    }
218
219    async fn do_start_request(&mut self) -> Option<crate::Result<AgentEvent>> {
220        self.metrics.iterations += 1;
221        if self.metrics.iterations > self.cfg.config.execution.max_iterations {
222            self.phase = Phase::Done;
223            self.metrics.execution_time_ms = self.start_time.elapsed().as_millis() as u64;
224
225            let messages = self
226                .cfg
227                .tool_state
228                .with_session(|session| session.to_api_messages())
229                .await;
230
231            return Some(Ok(AgentEvent::Complete(Box::new(AgentResult {
232                text: self.final_text.clone(),
233                usage: self.total_usage,
234                tool_calls: self.metrics.tool_calls,
235                iterations: self.metrics.iterations - 1,
236                stop_reason: StopReason::MaxTokens,
237                state: AgentState::Completed,
238                metrics: self.metrics.clone(),
239                session_id: self.cfg.session_id.to_string(),
240                structured_output: None,
241                messages,
242                uuid: uuid::Uuid::new_v4().to_string(),
243            }))));
244        }
245
246        let messages = self
247            .cfg
248            .tool_state
249            .with_session(|session| {
250                session.to_api_messages_with_cache(self.cfg.config.cache.message_ttl_option())
251            })
252            .await;
253
254        let stream_request = self
255            .cfg
256            .request_builder
257            .build(messages, &self.dynamic_rules)
258            .with_stream();
259
260        let response = match self
261            .cfg
262            .client
263            .send_stream_with_auth_retry(stream_request)
264            .await
265        {
266            Ok(r) => r,
267            Err(e) => {
268                self.phase = Phase::Done;
269                return Some(Err(e));
270            }
271        };
272
273        self.metrics.record_api_call();
274
275        let boxed_stream: BoxedByteStream = Box::pin(response.bytes_stream());
276        self.phase = Phase::Streaming(Box::new(StreamingPhase {
277            stream: RecoverableStream::new(boxed_stream),
278            accumulated_usage: Usage::default(),
279        }));
280
281        None
282    }
283
284    async fn do_poll_stream(
285        &mut self,
286        stream: &mut RecoverableStream<BoxedByteStream>,
287        accumulated_usage: &mut Usage,
288    ) -> StreamPollResult {
289        let chunk_result = tokio::time::timeout(self.chunk_timeout, stream.next()).await;
290
291        match chunk_result {
292            Ok(Some(Ok(item))) => {
293                self.last_chunk_time = Instant::now();
294                self.handle_stream_item(item, accumulated_usage)
295            }
296            Ok(Some(Err(e))) => {
297                self.phase = Phase::Done;
298                StreamPollResult::Event(Err(e))
299            }
300            Ok(None) => StreamPollResult::StreamEnded,
301            Err(_) => {
302                self.phase = Phase::Done;
303                StreamPollResult::Event(Err(crate::Error::Stream(format!(
304                    "Chunk timeout after {:?} (no data received)",
305                    self.chunk_timeout
306                ))))
307            }
308        }
309    }
310
311    fn handle_stream_item(
312        &mut self,
313        item: StreamItem,
314        accumulated_usage: &mut Usage,
315    ) -> StreamPollResult {
316        match item {
317            StreamItem::Text(text) => {
318                self.final_text.push_str(&text);
319                StreamPollResult::Event(Ok(AgentEvent::Text(text)))
320            }
321            StreamItem::Thinking(thinking) => {
322                StreamPollResult::Event(Ok(AgentEvent::Thinking(thinking)))
323            }
324            StreamItem::Citation(_) => StreamPollResult::Continue,
325            StreamItem::ToolUseComplete(tool_use) => {
326                self.pending_tool_uses.push(tool_use);
327                StreamPollResult::Continue
328            }
329            StreamItem::Event(event) => self.handle_stream_event(event, accumulated_usage),
330        }
331    }
332
333    fn handle_stream_event(
334        &mut self,
335        event: StreamEvent,
336        accumulated_usage: &mut Usage,
337    ) -> StreamPollResult {
338        match event {
339            StreamEvent::MessageStart { message } => {
340                accumulated_usage.input_tokens = message.usage.input_tokens;
341                accumulated_usage.output_tokens = message.usage.output_tokens;
342                accumulated_usage.cache_creation_input_tokens =
343                    message.usage.cache_creation_input_tokens;
344                accumulated_usage.cache_read_input_tokens = message.usage.cache_read_input_tokens;
345                StreamPollResult::Continue
346            }
347            StreamEvent::ContentBlockStart { .. } => StreamPollResult::Continue,
348            StreamEvent::ContentBlockDelta { .. } => StreamPollResult::Continue,
349            StreamEvent::ContentBlockStop { .. } => StreamPollResult::Continue,
350            StreamEvent::MessageDelta { usage, .. } => {
351                accumulated_usage.output_tokens = usage.output_tokens;
352                StreamPollResult::Continue
353            }
354            StreamEvent::MessageStop => StreamPollResult::StreamEnded,
355            StreamEvent::Ping => StreamPollResult::Continue,
356            StreamEvent::Error { error } => {
357                self.phase = Phase::Done;
358                StreamPollResult::Event(Err(crate::Error::Stream(error.message)))
359            }
360        }
361    }
362
363    async fn do_handle_stream_end(
364        &mut self,
365        accumulated_usage: Usage,
366    ) -> Option<crate::Result<AgentEvent>> {
367        self.total_usage.input_tokens += accumulated_usage.input_tokens;
368        self.total_usage.output_tokens += accumulated_usage.output_tokens;
369        self.metrics.add_usage(
370            accumulated_usage.input_tokens,
371            accumulated_usage.output_tokens,
372        );
373        self.metrics
374            .record_model_usage(&self.cfg.config.model.primary, &accumulated_usage);
375
376        let cost = self
377            .cfg
378            .budget_tracker
379            .record(&self.cfg.config.model.primary, &accumulated_usage);
380        self.metrics.add_cost(cost);
381
382        if let Some(ref tenant_budget) = self.cfg.tenant_budget {
383            tenant_budget.record(&self.cfg.config.model.primary, &accumulated_usage);
384        }
385
386        self.cfg
387            .tool_state
388            .with_session_mut(|session| {
389                session.update_usage(&accumulated_usage);
390
391                let mut content = Vec::new();
392                if !self.final_text.is_empty() {
393                    content.push(ContentBlock::Text {
394                        text: self.final_text.clone(),
395                        citations: None,
396                        cache_control: None,
397                    });
398                }
399                for tool_use in &self.pending_tool_uses {
400                    content.push(ContentBlock::ToolUse(tool_use.clone()));
401                }
402                if !content.is_empty() {
403                    session.add_assistant_message(content, Some(accumulated_usage));
404                }
405            })
406            .await;
407
408        if self.pending_tool_uses.is_empty() {
409            self.phase = Phase::Done;
410            self.metrics.execution_time_ms = self.start_time.elapsed().as_millis() as u64;
411
412            let messages = self
413                .cfg
414                .tool_state
415                .with_session(|session| session.to_api_messages())
416                .await;
417
418            return Some(Ok(AgentEvent::Complete(Box::new(AgentResult {
419                text: self.final_text.clone(),
420                usage: self.total_usage,
421                tool_calls: self.metrics.tool_calls,
422                iterations: self.metrics.iterations,
423                stop_reason: StopReason::EndTurn,
424                state: AgentState::Completed,
425                metrics: self.metrics.clone(),
426                session_id: self.cfg.session_id.to_string(),
427                structured_output: None,
428                messages,
429                uuid: uuid::Uuid::new_v4().to_string(),
430            }))));
431        }
432
433        self.phase = Phase::ProcessingTools { tool_index: 0 };
434        None
435    }
436
437    async fn do_process_tool(&mut self, tool_index: usize) -> Option<crate::Result<AgentEvent>> {
438        if tool_index >= self.pending_tool_uses.len() {
439            if !self.pending_tool_results.is_empty() {
440                self.finalize_tool_results().await;
441            }
442            self.final_text.clear();
443            self.pending_tool_uses.clear();
444            self.phase = Phase::StartRequest;
445            return None;
446        }
447
448        let tool_use = self.pending_tool_uses[tool_index].clone();
449        self.execute_tool(tool_use, tool_index).await
450    }
451
452    async fn execute_tool(
453        &mut self,
454        tool_use: ToolUseBlock,
455        tool_index: usize,
456    ) -> Option<crate::Result<AgentEvent>> {
457        let pre_input = HookInput::pre_tool_use(
458            &*self.cfg.session_id,
459            &tool_use.name,
460            tool_use.input.clone(),
461        );
462        let pre_output = match self
463            .cfg
464            .hooks
465            .execute(HookEvent::PreToolUse, pre_input, &self.cfg.hook_context)
466            .await
467        {
468            Ok(output) => output,
469            Err(e) => {
470                warn!(tool = %tool_use.name, error = %e, "PreToolUse hook failed");
471                crate::hooks::HookOutput::allow()
472            }
473        };
474
475        if !pre_output.continue_execution {
476            let reason = pre_output
477                .stop_reason
478                .clone()
479                .unwrap_or_else(|| "Blocked by hook".into());
480            debug!(tool = %tool_use.name, "Tool blocked by hook");
481
482            self.pending_tool_results
483                .push(ToolResultBlock::error(&tool_use.id, reason.clone()));
484            self.metrics.record_permission_denial(
485                PermissionDenial::new(&tool_use.name, &tool_use.id, tool_use.input.clone())
486                    .with_reason(reason.clone()),
487            );
488            self.phase = Phase::ProcessingTools {
489                tool_index: tool_index + 1,
490            };
491
492            return Some(Ok(AgentEvent::ToolBlocked {
493                id: tool_use.id,
494                name: tool_use.name,
495                reason,
496            }));
497        }
498
499        let actual_input = pre_output.updated_input.unwrap_or(tool_use.input.clone());
500
501        let start = Instant::now();
502        let result = self
503            .cfg
504            .tools
505            .execute(&tool_use.name, actual_input.clone())
506            .await;
507        let duration_ms = start.elapsed().as_millis() as u64;
508
509        let (output, is_error) = match &result.output {
510            crate::types::ToolOutput::Success(s) => (s.clone(), false),
511            crate::types::ToolOutput::SuccessBlocks(blocks) => {
512                let text = blocks
513                    .iter()
514                    .filter_map(|b| match b {
515                        crate::types::ToolOutputBlock::Text { text } => Some(text.as_str()),
516                        _ => None,
517                    })
518                    .collect::<Vec<_>>()
519                    .join("\n");
520                (text, false)
521            }
522            crate::types::ToolOutput::Error(e) => (e.to_string(), true),
523            crate::types::ToolOutput::Empty => (String::new(), false),
524        };
525
526        self.metrics
527            .record_tool(&tool_use.id, &tool_use.name, duration_ms, is_error);
528
529        if let Some(ref inner_usage) = result.inner_usage {
530            self.cfg
531                .tool_state
532                .with_session_mut(|session| {
533                    session.update_usage(inner_usage);
534                })
535                .await;
536            self.total_usage.input_tokens += inner_usage.input_tokens;
537            self.total_usage.output_tokens += inner_usage.output_tokens;
538            self.metrics
539                .add_usage(inner_usage.input_tokens, inner_usage.output_tokens);
540            let inner_model = result.inner_model.as_deref().unwrap_or("claude-haiku-4-5");
541            self.metrics.record_model_usage(inner_model, inner_usage);
542            let inner_cost = self.cfg.budget_tracker.record(inner_model, inner_usage);
543            self.metrics.add_cost(inner_cost);
544        }
545
546        if is_error {
547            let failure_input = HookInput::post_tool_use_failure(
548                &*self.cfg.session_id,
549                &tool_use.name,
550                result.error_message(),
551            );
552            if let Err(e) = self
553                .cfg
554                .hooks
555                .execute(
556                    HookEvent::PostToolUseFailure,
557                    failure_input,
558                    &self.cfg.hook_context,
559                )
560                .await
561            {
562                warn!(tool = %tool_use.name, error = %e, "PostToolUseFailure hook failed");
563            }
564        } else {
565            let post_input = HookInput::post_tool_use(
566                &*self.cfg.session_id,
567                &tool_use.name,
568                result.output.clone(),
569            );
570            if let Err(e) = self
571                .cfg
572                .hooks
573                .execute(HookEvent::PostToolUse, post_input, &self.cfg.hook_context)
574                .await
575            {
576                warn!(tool = %tool_use.name, error = %e, "PostToolUse hook failed");
577            }
578        }
579
580        if let Some(file_path) = extract_file_path(&tool_use.name, &actual_input)
581            && let Some(ref orchestrator) = self.cfg.orchestrator
582        {
583            let orch = orchestrator.read().await;
584            let path = Path::new(&file_path);
585            let rules = orch.rules_engine().find_matching(path);
586            if !rules.is_empty() {
587                let dynamic_ctx = orch.build_dynamic_context(Some(path)).await;
588                if !dynamic_ctx.is_empty() {
589                    self.dynamic_rules = dynamic_ctx;
590                }
591            }
592        }
593
594        self.pending_tool_results
595            .push(ToolResultBlock::from_tool_result(&tool_use.id, result));
596        self.phase = Phase::ProcessingTools {
597            tool_index: tool_index + 1,
598        };
599
600        Some(Ok(AgentEvent::ToolComplete {
601            id: tool_use.id,
602            name: tool_use.name,
603            output,
604            is_error,
605            duration_ms,
606        }))
607    }
608
609    async fn finalize_tool_results(&mut self) {
610        let results = std::mem::take(&mut self.pending_tool_results);
611        let max_tokens = context_window::for_model(&self.cfg.config.model.primary);
612
613        self.cfg
614            .tool_state
615            .with_session_mut(|session| {
616                session.add_tool_results(results);
617            })
618            .await;
619
620        let should_compact = self
621            .cfg
622            .tool_state
623            .with_session(|session| {
624                self.cfg.config.execution.auto_compact
625                    && session.should_compact(
626                        max_tokens,
627                        self.cfg.config.execution.compact_threshold,
628                        self.cfg.config.execution.compact_keep_messages,
629                    )
630            })
631            .await;
632
633        if should_compact {
634            let compact_result = self
635                .cfg
636                .tool_state
637                .compact(
638                    &self.cfg.client,
639                    self.cfg.config.execution.compact_keep_messages,
640                )
641                .await;
642
643            if let Ok(CompactResult::Compacted { .. }) = compact_result {
644                self.metrics.record_compaction();
645                let state_sections = collect_compaction_state(&self.cfg.tools).await;
646                if !state_sections.is_empty() {
647                    self.cfg
648                        .tool_state
649                        .with_session_mut(|session| {
650                            session.add_user_message(format!(
651                                "<system-reminder>\n# State preserved after compaction\n\n{}\n</system-reminder>",
652                                state_sections.join("\n\n")
653                            ));
654                        })
655                        .await;
656                }
657            }
658        }
659    }
660}
661
#[cfg(test)]
mod tests {
    use super::*;

    /// The payload-free phases can be constructed and matched.
    #[test]
    fn test_phase_transitions() {
        let initial = Phase::StartRequest;
        let terminal = Phase::Done;

        assert!(matches!(initial, Phase::StartRequest));
        assert!(matches!(terminal, Phase::Done));
    }

    /// Every `StreamPollResult` variant can be constructed and matched.
    #[test]
    fn test_stream_poll_result_variants() {
        let with_event = StreamPollResult::Event(Ok(AgentEvent::Text("test".into())));
        let keep_polling = StreamPollResult::Continue;
        let finished = StreamPollResult::StreamEnded;

        assert!(matches!(with_event, StreamPollResult::Event(_)));
        assert!(matches!(keep_polling, StreamPollResult::Continue));
        assert!(matches!(finished, StreamPollResult::StreamEnded));
    }
}