claude_agent/agent/
streaming.rs

1//! Agent streaming execution with session-based context management.
2
3use std::path::Path;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::time::Instant;
7
8use futures::{Stream, StreamExt, stream};
9use tokio::sync::RwLock;
10use tracing::{debug, warn};
11
12use super::common::BudgetContext;
13use super::events::{AgentEvent, AgentResult};
14use super::execution::extract_file_path;
15use super::executor::Agent;
16use super::request::RequestBuilder;
17use super::state_formatter::collect_compaction_state;
18use super::{AgentConfig, AgentMetrics, AgentState};
19use crate::budget::{BudgetTracker, TenantBudget};
20use crate::client::{RecoverableStream, StreamItem};
21use crate::context::PromptOrchestrator;
22use crate::hooks::{HookContext, HookEvent, HookInput, HookManager};
23use crate::session::ToolState;
24use crate::types::{
25    CompactResult, ContentBlock, PermissionDenial, StopReason, StreamEvent, ToolResultBlock,
26    ToolUseBlock, Usage, context_window,
27};
28use crate::{Client, ToolRegistry};
29
30type BoxedByteStream =
31    Pin<Box<dyn Stream<Item = std::result::Result<bytes::Bytes, reqwest::Error>> + Send>>;
32
33impl Agent {
34    pub async fn execute_stream(
35        &self,
36        prompt: &str,
37    ) -> crate::Result<impl Stream<Item = crate::Result<AgentEvent>> + Send> {
38        let timeout = self
39            .config
40            .execution
41            .timeout
42            .unwrap_or(std::time::Duration::from_secs(600));
43
44        if self.state.is_executing() {
45            self.state
46                .enqueue(prompt)
47                .await
48                .map_err(|e| crate::Error::Session(format!("Queue full: {}", e)))?;
49        } else {
50            self.state
51                .with_session_mut(|session| {
52                    session.add_user_message(prompt);
53                })
54                .await;
55        }
56
57        let state = StreamState::new(
58            StreamStateConfig {
59                tool_state: self.state.clone(),
60                client: Arc::clone(&self.client),
61                config: Arc::clone(&self.config),
62                tools: Arc::clone(&self.tools),
63                hooks: Arc::clone(&self.hooks),
64                hook_context: self.hook_context(),
65                request_builder: RequestBuilder::new(&self.config, Arc::clone(&self.tools)),
66                orchestrator: self.orchestrator.clone(),
67                session_id: Arc::clone(&self.session_id),
68                budget_tracker: Arc::clone(&self.budget_tracker),
69                tenant_budget: self.tenant_budget.clone(),
70            },
71            timeout,
72        );
73
74        Ok(stream::unfold(state, |mut state| async move {
75            state.next_event().await.map(|event| (event, state))
76        }))
77    }
78}
79
80struct StreamStateConfig {
81    tool_state: ToolState,
82    client: Arc<Client>,
83    config: Arc<AgentConfig>,
84    tools: Arc<ToolRegistry>,
85    hooks: Arc<HookManager>,
86    hook_context: HookContext,
87    request_builder: RequestBuilder,
88    orchestrator: Option<Arc<RwLock<PromptOrchestrator>>>,
89    session_id: Arc<str>,
90    budget_tracker: Arc<BudgetTracker>,
91    tenant_budget: Option<Arc<TenantBudget>>,
92}
93
94enum StreamPollResult {
95    Event(crate::Result<AgentEvent>),
96    Continue,
97    StreamEnded,
98}
99
100enum Phase {
101    StartRequest,
102    Streaming(Box<StreamingPhase>),
103    StreamEnded { accumulated_usage: Usage },
104    ProcessingTools { tool_index: usize },
105    Done,
106}
107
108struct StreamingPhase {
109    stream: RecoverableStream<BoxedByteStream>,
110    accumulated_usage: Usage,
111}
112
113struct StreamState {
114    cfg: StreamStateConfig,
115    timeout: std::time::Duration,
116    chunk_timeout: std::time::Duration,
117    dynamic_rules: String,
118    metrics: AgentMetrics,
119    start_time: Instant,
120    last_chunk_time: Instant,
121    pending_tool_results: Vec<ToolResultBlock>,
122    pending_tool_uses: Vec<ToolUseBlock>,
123    final_text: String,
124    total_usage: Usage,
125    phase: Phase,
126}
127
impl StreamState {
    /// Creates a machine positioned at [`Phase::StartRequest`] with empty buffers.
    ///
    /// `timeout` bounds the whole execution. The per-chunk timeout is read from
    /// `cfg.config.execution.chunk_timeout`.
    fn new(cfg: StreamStateConfig, timeout: std::time::Duration) -> Self {
        let chunk_timeout = cfg.config.execution.chunk_timeout;
        let started_at = Instant::now();
        Self {
            phase: Phase::StartRequest,
            total_usage: Usage::default(),
            final_text: String::new(),
            pending_tool_uses: Vec::new(),
            pending_tool_results: Vec::new(),
            start_time: started_at,
            last_chunk_time: started_at,
            metrics: AgentMetrics::default(),
            dynamic_rules: String::new(),
            chunk_timeout,
            timeout,
            cfg,
        }
    }

147
148    /// Extract structured output from text if output_schema is configured.
149    ///
150    /// When structured outputs are enabled via `output_format`, the API returns
151    /// pure JSON in `response.content[0].text`. This method parses that JSON.
152    ///
153    /// Returns `None` if:
154    /// - `output_schema` is not configured
155    /// - The response is not valid JSON (e.g., `stop_reason: "refusal"` or `"max_tokens"`)
156    fn extract_structured_output(&self, text: &str) -> Option<serde_json::Value> {
157        // Only extract if output_schema is configured
158        self.cfg.config.prompt.output_schema.as_ref()?;
159
160        // With structured outputs enabled, API returns pure JSON directly
161        serde_json::from_str(text).ok()
162    }
163
164    async fn next_event(&mut self) -> Option<crate::Result<AgentEvent>> {
165        loop {
166            if matches!(self.phase, Phase::Done) {
167                return None;
168            }
169
170            if self.start_time.elapsed() > self.timeout {
171                self.phase = Phase::Done;
172                return Some(Err(crate::Error::Timeout(self.timeout)));
173            }
174
175            if let Some(event) = self.check_budget_exceeded() {
176                return Some(event);
177            }
178
179            match std::mem::replace(&mut self.phase, Phase::Done) {
180                Phase::StartRequest => {
181                    if let Some(result) = self.do_start_request().await {
182                        return Some(result);
183                    }
184                }
185                Phase::Streaming(mut streaming) => {
186                    match self
187                        .do_poll_stream(&mut streaming.stream, &mut streaming.accumulated_usage)
188                        .await
189                    {
190                        StreamPollResult::Event(event) => {
191                            self.phase = Phase::Streaming(streaming);
192                            return Some(event);
193                        }
194                        StreamPollResult::Continue => {
195                            self.phase = Phase::Streaming(streaming);
196                        }
197                        StreamPollResult::StreamEnded => {
198                            self.phase = Phase::StreamEnded {
199                                accumulated_usage: streaming.accumulated_usage,
200                            };
201                        }
202                    }
203                }
204                Phase::StreamEnded { accumulated_usage } => {
205                    if let Some(event) = self.do_handle_stream_end(accumulated_usage).await {
206                        return Some(event);
207                    }
208                }
209                Phase::ProcessingTools { tool_index } => {
210                    if let Some(result) = self.do_process_tool(tool_index).await {
211                        return Some(result);
212                    }
213                }
214                Phase::Done => return None,
215            }
216        }
217    }
218
219    fn check_budget_exceeded(&mut self) -> Option<crate::Result<AgentEvent>> {
220        let result = BudgetContext {
221            tracker: &self.cfg.budget_tracker,
222            tenant: self.cfg.tenant_budget.as_deref(),
223            config: &self.cfg.config.budget,
224        }
225        .check();
226
227        if let Err(e) = result {
228            self.phase = Phase::Done;
229            return Some(Err(e));
230        }
231
232        None
233    }
234
235    async fn do_start_request(&mut self) -> Option<crate::Result<AgentEvent>> {
236        self.metrics.iterations += 1;
237        if self.metrics.iterations > self.cfg.config.execution.max_iterations {
238            self.phase = Phase::Done;
239            self.metrics.execution_time_ms = self.start_time.elapsed().as_millis() as u64;
240
241            let messages = self
242                .cfg
243                .tool_state
244                .with_session(|session| session.to_api_messages())
245                .await;
246
247            let structured_output = self.extract_structured_output(&self.final_text);
248            return Some(Ok(AgentEvent::Complete(Box::new(AgentResult {
249                text: self.final_text.clone(),
250                usage: self.total_usage,
251                tool_calls: self.metrics.tool_calls,
252                iterations: self.metrics.iterations - 1,
253                stop_reason: StopReason::MaxTokens,
254                state: AgentState::Completed,
255                metrics: self.metrics.clone(),
256                session_id: self.cfg.session_id.to_string(),
257                structured_output,
258                messages,
259                uuid: uuid::Uuid::new_v4().to_string(),
260            }))));
261        }
262
263        let messages = self
264            .cfg
265            .tool_state
266            .with_session(|session| {
267                session.to_api_messages_with_cache(self.cfg.config.cache.message_ttl_option())
268            })
269            .await;
270
271        let stream_request = self
272            .cfg
273            .request_builder
274            .build(messages, &self.dynamic_rules)
275            .with_stream();
276
277        let response = match self
278            .cfg
279            .client
280            .send_stream_with_auth_retry(stream_request)
281            .await
282        {
283            Ok(r) => r,
284            Err(e) => {
285                self.phase = Phase::Done;
286                return Some(Err(e));
287            }
288        };
289
290        self.metrics.record_api_call();
291
292        let boxed_stream: BoxedByteStream = Box::pin(response.bytes_stream());
293        self.phase = Phase::Streaming(Box::new(StreamingPhase {
294            stream: RecoverableStream::new(boxed_stream),
295            accumulated_usage: Usage::default(),
296        }));
297
298        None
299    }
300
301    async fn do_poll_stream(
302        &mut self,
303        stream: &mut RecoverableStream<BoxedByteStream>,
304        accumulated_usage: &mut Usage,
305    ) -> StreamPollResult {
306        let chunk_result = tokio::time::timeout(self.chunk_timeout, stream.next()).await;
307
308        match chunk_result {
309            Ok(Some(Ok(item))) => {
310                self.last_chunk_time = Instant::now();
311                self.handle_stream_item(item, accumulated_usage)
312            }
313            Ok(Some(Err(e))) => {
314                self.phase = Phase::Done;
315                StreamPollResult::Event(Err(e))
316            }
317            Ok(None) => StreamPollResult::StreamEnded,
318            Err(_) => {
319                self.phase = Phase::Done;
320                StreamPollResult::Event(Err(crate::Error::Stream(format!(
321                    "Chunk timeout after {:?} (no data received)",
322                    self.chunk_timeout
323                ))))
324            }
325        }
326    }
327
328    fn handle_stream_item(
329        &mut self,
330        item: StreamItem,
331        accumulated_usage: &mut Usage,
332    ) -> StreamPollResult {
333        match item {
334            StreamItem::Text(text) => {
335                self.final_text.push_str(&text);
336                StreamPollResult::Event(Ok(AgentEvent::Text(text)))
337            }
338            StreamItem::Thinking(thinking) => {
339                StreamPollResult::Event(Ok(AgentEvent::Thinking(thinking)))
340            }
341            StreamItem::Citation(_) => StreamPollResult::Continue,
342            StreamItem::ToolUseComplete(tool_use) => {
343                self.pending_tool_uses.push(tool_use);
344                StreamPollResult::Continue
345            }
346            StreamItem::Event(event) => self.handle_stream_event(event, accumulated_usage),
347        }
348    }
349
350    fn handle_stream_event(
351        &mut self,
352        event: StreamEvent,
353        accumulated_usage: &mut Usage,
354    ) -> StreamPollResult {
355        match event {
356            StreamEvent::MessageStart { message } => {
357                accumulated_usage.input_tokens = message.usage.input_tokens;
358                accumulated_usage.output_tokens = message.usage.output_tokens;
359                accumulated_usage.cache_creation_input_tokens =
360                    message.usage.cache_creation_input_tokens;
361                accumulated_usage.cache_read_input_tokens = message.usage.cache_read_input_tokens;
362                StreamPollResult::Continue
363            }
364            StreamEvent::ContentBlockStart { .. } => StreamPollResult::Continue,
365            StreamEvent::ContentBlockDelta { .. } => StreamPollResult::Continue,
366            StreamEvent::ContentBlockStop { .. } => StreamPollResult::Continue,
367            StreamEvent::MessageDelta { usage, .. } => {
368                accumulated_usage.output_tokens = usage.output_tokens;
369                StreamPollResult::Continue
370            }
371            StreamEvent::MessageStop => StreamPollResult::StreamEnded,
372            StreamEvent::Ping => StreamPollResult::Continue,
373            StreamEvent::Error { error } => {
374                self.phase = Phase::Done;
375                StreamPollResult::Event(Err(crate::Error::Stream(error.message)))
376            }
377        }
378    }
379
380    async fn do_handle_stream_end(
381        &mut self,
382        accumulated_usage: Usage,
383    ) -> Option<crate::Result<AgentEvent>> {
384        self.total_usage.input_tokens += accumulated_usage.input_tokens;
385        self.total_usage.output_tokens += accumulated_usage.output_tokens;
386        self.metrics.add_usage(
387            accumulated_usage.input_tokens,
388            accumulated_usage.output_tokens,
389        );
390        self.metrics
391            .record_model_usage(&self.cfg.config.model.primary, &accumulated_usage);
392
393        let cost = self
394            .cfg
395            .budget_tracker
396            .record(&self.cfg.config.model.primary, &accumulated_usage);
397        self.metrics.add_cost(cost);
398
399        if let Some(ref tenant_budget) = self.cfg.tenant_budget {
400            tenant_budget.record(&self.cfg.config.model.primary, &accumulated_usage);
401        }
402
403        self.cfg
404            .tool_state
405            .with_session_mut(|session| {
406                session.update_usage(&accumulated_usage);
407
408                let mut content = Vec::new();
409                if !self.final_text.is_empty() {
410                    content.push(ContentBlock::Text {
411                        text: self.final_text.clone(),
412                        citations: None,
413                        cache_control: None,
414                    });
415                }
416                for tool_use in &self.pending_tool_uses {
417                    content.push(ContentBlock::ToolUse(tool_use.clone()));
418                }
419                if !content.is_empty() {
420                    session.add_assistant_message(content, Some(accumulated_usage));
421                }
422            })
423            .await;
424
425        if self.pending_tool_uses.is_empty() {
426            self.phase = Phase::Done;
427            self.metrics.execution_time_ms = self.start_time.elapsed().as_millis() as u64;
428
429            let messages = self
430                .cfg
431                .tool_state
432                .with_session(|session| session.to_api_messages())
433                .await;
434
435            let structured_output = self.extract_structured_output(&self.final_text);
436            return Some(Ok(AgentEvent::Complete(Box::new(AgentResult {
437                text: self.final_text.clone(),
438                usage: self.total_usage,
439                tool_calls: self.metrics.tool_calls,
440                iterations: self.metrics.iterations,
441                stop_reason: StopReason::EndTurn,
442                state: AgentState::Completed,
443                metrics: self.metrics.clone(),
444                session_id: self.cfg.session_id.to_string(),
445                structured_output,
446                messages,
447                uuid: uuid::Uuid::new_v4().to_string(),
448            }))));
449        }
450
451        self.phase = Phase::ProcessingTools { tool_index: 0 };
452        None
453    }
454
455    async fn do_process_tool(&mut self, tool_index: usize) -> Option<crate::Result<AgentEvent>> {
456        if tool_index >= self.pending_tool_uses.len() {
457            if !self.pending_tool_results.is_empty() {
458                self.finalize_tool_results().await;
459            }
460            self.final_text.clear();
461            self.pending_tool_uses.clear();
462            self.phase = Phase::StartRequest;
463            return None;
464        }
465
466        let tool_use = self.pending_tool_uses[tool_index].clone();
467        self.execute_tool(tool_use, tool_index).await
468    }
469
470    async fn execute_tool(
471        &mut self,
472        tool_use: ToolUseBlock,
473        tool_index: usize,
474    ) -> Option<crate::Result<AgentEvent>> {
475        let pre_input = HookInput::pre_tool_use(
476            &*self.cfg.session_id,
477            &tool_use.name,
478            tool_use.input.clone(),
479        );
480        let pre_output = match self
481            .cfg
482            .hooks
483            .execute(HookEvent::PreToolUse, pre_input, &self.cfg.hook_context)
484            .await
485        {
486            Ok(output) => output,
487            Err(e) => {
488                warn!(tool = %tool_use.name, error = %e, "PreToolUse hook failed");
489                crate::hooks::HookOutput::allow()
490            }
491        };
492
493        if !pre_output.continue_execution {
494            let reason = pre_output
495                .stop_reason
496                .clone()
497                .unwrap_or_else(|| "Blocked by hook".into());
498            debug!(tool = %tool_use.name, "Tool blocked by hook");
499
500            self.pending_tool_results
501                .push(ToolResultBlock::error(&tool_use.id, reason.clone()));
502            self.metrics.record_permission_denial(
503                PermissionDenial::new(&tool_use.name, &tool_use.id, tool_use.input.clone())
504                    .with_reason(reason.clone()),
505            );
506            self.phase = Phase::ProcessingTools {
507                tool_index: tool_index + 1,
508            };
509
510            return Some(Ok(AgentEvent::ToolBlocked {
511                id: tool_use.id,
512                name: tool_use.name,
513                reason,
514            }));
515        }
516
517        let actual_input = pre_output.updated_input.unwrap_or(tool_use.input.clone());
518
519        let start = Instant::now();
520        let result = self
521            .cfg
522            .tools
523            .execute(&tool_use.name, actual_input.clone())
524            .await;
525        let duration_ms = start.elapsed().as_millis() as u64;
526
527        let (output, is_error) = match &result.output {
528            crate::types::ToolOutput::Success(s) => (s.clone(), false),
529            crate::types::ToolOutput::SuccessBlocks(blocks) => {
530                let text = blocks
531                    .iter()
532                    .filter_map(|b| match b {
533                        crate::types::ToolOutputBlock::Text { text } => Some(text.as_str()),
534                        _ => None,
535                    })
536                    .collect::<Vec<_>>()
537                    .join("\n");
538                (text, false)
539            }
540            crate::types::ToolOutput::Error(e) => (e.to_string(), true),
541            crate::types::ToolOutput::Empty => (String::new(), false),
542        };
543
544        self.metrics
545            .record_tool(&tool_use.id, &tool_use.name, duration_ms, is_error);
546
547        if let Some(ref inner_usage) = result.inner_usage {
548            self.cfg
549                .tool_state
550                .with_session_mut(|session| {
551                    session.update_usage(inner_usage);
552                })
553                .await;
554            self.total_usage.input_tokens += inner_usage.input_tokens;
555            self.total_usage.output_tokens += inner_usage.output_tokens;
556            self.metrics
557                .add_usage(inner_usage.input_tokens, inner_usage.output_tokens);
558            let inner_model = result.inner_model.as_deref().unwrap_or("claude-haiku-4-5");
559            self.metrics.record_model_usage(inner_model, inner_usage);
560            let inner_cost = self.cfg.budget_tracker.record(inner_model, inner_usage);
561            self.metrics.add_cost(inner_cost);
562        }
563
564        if is_error {
565            let failure_input = HookInput::post_tool_use_failure(
566                &*self.cfg.session_id,
567                &tool_use.name,
568                result.error_message(),
569            );
570            if let Err(e) = self
571                .cfg
572                .hooks
573                .execute(
574                    HookEvent::PostToolUseFailure,
575                    failure_input,
576                    &self.cfg.hook_context,
577                )
578                .await
579            {
580                warn!(tool = %tool_use.name, error = %e, "PostToolUseFailure hook failed");
581            }
582        } else {
583            let post_input = HookInput::post_tool_use(
584                &*self.cfg.session_id,
585                &tool_use.name,
586                result.output.clone(),
587            );
588            if let Err(e) = self
589                .cfg
590                .hooks
591                .execute(HookEvent::PostToolUse, post_input, &self.cfg.hook_context)
592                .await
593            {
594                warn!(tool = %tool_use.name, error = %e, "PostToolUse hook failed");
595            }
596        }
597
598        if let Some(file_path) = extract_file_path(&tool_use.name, &actual_input)
599            && let Some(ref orchestrator) = self.cfg.orchestrator
600        {
601            let orch = orchestrator.read().await;
602            let path = Path::new(&file_path);
603            if orch.has_matching_rules(path).await {
604                let dynamic_ctx = orch.build_dynamic_context(Some(path)).await;
605                if !dynamic_ctx.is_empty() {
606                    self.dynamic_rules = dynamic_ctx;
607                }
608            }
609        }
610
611        self.pending_tool_results
612            .push(ToolResultBlock::from_tool_result(&tool_use.id, &result));
613        self.phase = Phase::ProcessingTools {
614            tool_index: tool_index + 1,
615        };
616
617        Some(Ok(AgentEvent::ToolComplete {
618            id: tool_use.id,
619            name: tool_use.name,
620            output,
621            is_error,
622            duration_ms,
623        }))
624    }
625
626    async fn finalize_tool_results(&mut self) {
627        let results = std::mem::take(&mut self.pending_tool_results);
628        let max_tokens = context_window::for_model(&self.cfg.config.model.primary);
629
630        self.cfg
631            .tool_state
632            .with_session_mut(|session| {
633                session.add_tool_results(results);
634            })
635            .await;
636
637        let should_compact = self
638            .cfg
639            .tool_state
640            .with_session(|session| {
641                self.cfg.config.execution.auto_compact
642                    && session.should_compact(
643                        max_tokens,
644                        self.cfg.config.execution.compact_threshold,
645                        self.cfg.config.execution.compact_keep_messages,
646                    )
647            })
648            .await;
649
650        if should_compact {
651            let compact_result = self
652                .cfg
653                .tool_state
654                .compact(
655                    &self.cfg.client,
656                    self.cfg.config.execution.compact_keep_messages,
657                )
658                .await;
659
660            if let Ok(CompactResult::Compacted { .. }) = compact_result {
661                self.metrics.record_compaction();
662                let state_sections = collect_compaction_state(&self.cfg.tools).await;
663                if !state_sections.is_empty() {
664                    self.cfg
665                        .tool_state
666                        .with_session_mut(|session| {
667                            session.add_user_message(format!(
668                                "<system-reminder>\n# State preserved after compaction\n\n{}\n</system-reminder>",
669                                state_sections.join("\n\n")
670                            ));
671                        })
672                        .await;
673                }
674            }
675        }
676    }
677}
678
#[cfg(test)]
mod tests {
    use super::*;

    /// Sanity check that the phase variants used as start/end states exist.
    #[test]
    fn test_phase_transitions() {
        let initial = Phase::StartRequest;
        let terminal = Phase::Done;
        assert!(matches!(initial, Phase::StartRequest));
        assert!(matches!(terminal, Phase::Done));
    }

    /// Each poll-result variant can be constructed and matched.
    #[test]
    fn test_stream_poll_result_variants() {
        let [event, cont, ended] = [
            StreamPollResult::Event(Ok(AgentEvent::Text("test".into()))),
            StreamPollResult::Continue,
            StreamPollResult::StreamEnded,
        ];
        assert!(matches!(event, StreamPollResult::Event(_)));
        assert!(matches!(cont, StreamPollResult::Continue));
        assert!(matches!(ended, StreamPollResult::StreamEnded));
    }
}
700}