Skip to main content

claude_agent/agent/
streaming.rs

//! Agent streaming execution with session-based context management.
2
3use std::pin::Pin;
4use std::sync::Arc;
5use std::time::Instant;
6
7use futures::{Stream, StreamExt, stream};
8use tokio::sync::RwLock;
9use tracing::{debug, warn};
10
11use super::common::{
12    BudgetContext, accumulate_inner_usage, accumulate_response_usage, handle_compaction,
13    run_post_tool_hooks, run_stop_hooks, try_activate_dynamic_rules,
14};
15use super::events::{AgentEvent, AgentResult};
16use super::executor::Agent;
17use super::request::RequestBuilder;
18use super::{AgentConfig, AgentMetrics};
19use crate::budget::{BudgetTracker, TenantBudget};
20use crate::client::{RecoverableStream, StreamItem};
21use crate::context::PromptOrchestrator;
22use crate::hooks::{HookContext, HookEvent, HookInput, HookManager};
23use crate::session::ToolState;
24use crate::types::{
25    ContentBlock, PermissionDenial, StopReason, StreamEvent, ToolResultBlock, ToolUseBlock, Usage,
26    context_window,
27};
28use crate::{Client, ToolRegistry};
29
30type BoxedByteStream =
31    Pin<Box<dyn Stream<Item = std::result::Result<bytes::Bytes, reqwest::Error>> + Send>>;
32
33impl Agent {
34    pub async fn execute_stream(
35        &self,
36        prompt: &str,
37    ) -> crate::Result<impl Stream<Item = crate::Result<AgentEvent>> + Send> {
38        let timeout = self
39            .config
40            .execution
41            .timeout
42            .unwrap_or(std::time::Duration::from_secs(600));
43
44        if self.state.is_executing() {
45            self.state
46                .enqueue(prompt)
47                .await
48                .map_err(|e| crate::Error::Session(format!("Queue full: {}", e)))?;
49        }
50        let state = StreamState::new(
51            StreamStateConfig {
52                tool_state: self.state.clone(),
53                client: Arc::clone(&self.client),
54                config: Arc::clone(&self.config),
55                tools: Arc::clone(&self.tools),
56                hooks: Arc::clone(&self.hooks),
57                hook_context: self.hook_context(),
58                request_builder: RequestBuilder::new(&self.config, Arc::clone(&self.tools)),
59                orchestrator: self.orchestrator.clone(),
60                session_id: Arc::clone(&self.session_id),
61                budget_tracker: Arc::clone(&self.budget_tracker),
62                tenant_budget: self.tenant_budget.clone(),
63            },
64            timeout,
65            prompt.to_string(),
66        );
67
68        Ok(stream::unfold(state, |mut state| async move {
69            state.next_event().await.map(|event| (event, state))
70        }))
71    }
72}
73
74struct StreamStateConfig {
75    tool_state: ToolState,
76    client: Arc<Client>,
77    config: Arc<AgentConfig>,
78    tools: Arc<ToolRegistry>,
79    hooks: Arc<HookManager>,
80    hook_context: HookContext,
81    request_builder: RequestBuilder,
82    orchestrator: Option<Arc<RwLock<PromptOrchestrator>>>,
83    session_id: Arc<str>,
84    budget_tracker: Arc<BudgetTracker>,
85    tenant_budget: Option<Arc<TenantBudget>>,
86}
87
88enum StreamPollResult {
89    Event(crate::Result<AgentEvent>),
90    Continue,
91    StreamEnded,
92}
93
94enum Phase {
95    StartRequest,
96    Streaming(Box<StreamingPhase>),
97    StreamEnded { accumulated_usage: Usage },
98    ProcessingTools { tool_index: usize },
99    Done,
100}
101
102struct StreamingPhase {
103    stream: RecoverableStream<BoxedByteStream>,
104    accumulated_usage: Usage,
105}
106
107struct StreamState {
108    cfg: StreamStateConfig,
109    timeout: std::time::Duration,
110    chunk_timeout: std::time::Duration,
111    dynamic_rules: String,
112    metrics: AgentMetrics,
113    start_time: Instant,
114    last_chunk_time: Instant,
115    pending_tool_results: Vec<ToolResultBlock>,
116    pending_tool_uses: Vec<ToolUseBlock>,
117    final_text: String,
118    total_usage: Usage,
119    phase: Phase,
120    session_started: bool,
121    prompt_submitted: bool,
122    initial_prompt: Option<String>,
123}
124
125impl StreamState {
126    fn new(cfg: StreamStateConfig, timeout: std::time::Duration, prompt: String) -> Self {
127        let chunk_timeout = cfg.config.execution.chunk_timeout;
128        let now = Instant::now();
129        Self {
130            cfg,
131            timeout,
132            chunk_timeout,
133            dynamic_rules: String::new(),
134            metrics: AgentMetrics::default(),
135            start_time: now,
136            last_chunk_time: now,
137            pending_tool_results: Vec::new(),
138            pending_tool_uses: Vec::new(),
139            final_text: String::new(),
140            total_usage: Usage::default(),
141            phase: Phase::StartRequest,
142            session_started: false,
143            prompt_submitted: false,
144            initial_prompt: Some(prompt),
145        }
146    }
147
148    fn extract_structured_output(&self, text: &str) -> Option<serde_json::Value> {
149        super::common::extract_structured_output(
150            self.cfg.config.prompt.output_schema.as_ref(),
151            text,
152        )
153    }
154
155    fn build_result(
156        &self,
157        iterations: usize,
158        stop_reason: StopReason,
159        messages: Vec<crate::types::Message>,
160    ) -> AgentResult {
161        let structured_output = self.extract_structured_output(&self.final_text);
162        AgentResult::new(
163            self.final_text.clone(),
164            self.total_usage,
165            iterations,
166            stop_reason,
167            self.metrics.clone(),
168            self.cfg.session_id.to_string(),
169            structured_output,
170            messages,
171        )
172    }
173
174    async fn next_event(&mut self) -> Option<crate::Result<AgentEvent>> {
175        loop {
176            if matches!(self.phase, Phase::Done) {
177                return None;
178            }
179
180            if self.start_time.elapsed() > self.timeout {
181                self.phase = Phase::Done;
182                return Some(Err(crate::Error::Timeout(self.timeout)));
183            }
184
185            if let Some(event) = self.check_budget_exceeded() {
186                return Some(event);
187            }
188
189            match std::mem::replace(&mut self.phase, Phase::Done) {
190                Phase::StartRequest => {
191                    if let Some(result) = self.do_start_request().await {
192                        return Some(result);
193                    }
194                }
195                Phase::Streaming(mut streaming) => {
196                    match self
197                        .do_poll_stream(&mut streaming.stream, &mut streaming.accumulated_usage)
198                        .await
199                    {
200                        StreamPollResult::Event(event) => {
201                            self.phase = Phase::Streaming(streaming);
202                            return Some(event);
203                        }
204                        StreamPollResult::Continue => {
205                            self.phase = Phase::Streaming(streaming);
206                        }
207                        StreamPollResult::StreamEnded => {
208                            self.phase = Phase::StreamEnded {
209                                accumulated_usage: streaming.accumulated_usage,
210                            };
211                        }
212                    }
213                }
214                Phase::StreamEnded { accumulated_usage } => {
215                    if let Some(event) = self.do_handle_stream_end(accumulated_usage).await {
216                        return Some(event);
217                    }
218                }
219                Phase::ProcessingTools { tool_index } => {
220                    if let Some(result) = self.do_process_tool(tool_index).await {
221                        return Some(result);
222                    }
223                }
224                Phase::Done => return None,
225            }
226        }
227    }
228
229    fn check_budget_exceeded(&mut self) -> Option<crate::Result<AgentEvent>> {
230        let result = BudgetContext {
231            tracker: &self.cfg.budget_tracker,
232            tenant: self.cfg.tenant_budget.as_deref(),
233            config: &self.cfg.config.budget,
234        }
235        .check();
236
237        if let Err(e) = result {
238            self.phase = Phase::Done;
239            return Some(Err(e));
240        }
241
242        None
243    }
244
245    async fn do_start_request(&mut self) -> Option<crate::Result<AgentEvent>> {
246        if !self.session_started {
247            self.session_started = true;
248            let session_start_input = HookInput::session_start(&*self.cfg.session_id);
249            if let Err(e) = self
250                .cfg
251                .hooks
252                .execute(
253                    HookEvent::SessionStart,
254                    session_start_input,
255                    &self.cfg.hook_context,
256                )
257                .await
258            {
259                warn!(error = %e, "SessionStart hook failed");
260            }
261        }
262
263        if !self.prompt_submitted {
264            if let Some(prompt) = self.initial_prompt.take() {
265                let prompt_input = HookInput::user_prompt_submit(&*self.cfg.session_id, &prompt);
266                let prompt_output = match self
267                    .cfg
268                    .hooks
269                    .execute(
270                        HookEvent::UserPromptSubmit,
271                        prompt_input,
272                        &self.cfg.hook_context,
273                    )
274                    .await
275                {
276                    Ok(output) => output,
277                    Err(e) => {
278                        self.phase = Phase::Done;
279                        return Some(Err(e));
280                    }
281                };
282
283                if !prompt_output.continue_execution {
284                    self.phase = Phase::Done;
285                    return Some(Err(crate::Error::Permission(
286                        prompt_output
287                            .stop_reason
288                            .unwrap_or_else(|| "Blocked by hook".into()),
289                    )));
290                }
291
292                self.cfg
293                    .tool_state
294                    .with_session_mut(|session| {
295                        session.add_user_message(&prompt);
296                    })
297                    .await;
298            }
299            self.prompt_submitted = true;
300        }
301
302        self.metrics.iterations += 1;
303        if self.metrics.iterations > self.cfg.config.execution.max_iterations {
304            self.phase = Phase::Done;
305            self.metrics.execution_time_ms = self.start_time.elapsed().as_millis() as u64;
306
307            run_stop_hooks(
308                &self.cfg.hooks,
309                &self.cfg.hook_context,
310                &self.cfg.session_id,
311            )
312            .await;
313
314            let messages = self
315                .cfg
316                .tool_state
317                .with_session(|session| session.to_api_messages())
318                .await;
319            let result =
320                self.build_result(self.metrics.iterations - 1, StopReason::MaxTokens, messages);
321            return Some(Ok(AgentEvent::Complete(Box::new(result))));
322        }
323
324        let budget_ctx = BudgetContext {
325            tracker: &self.cfg.budget_tracker,
326            tenant: self.cfg.tenant_budget.as_deref(),
327            config: &self.cfg.config.budget,
328        };
329        if let Some(fallback) = budget_ctx.fallback_model() {
330            self.cfg.request_builder.set_model(fallback);
331        }
332
333        let messages = self
334            .cfg
335            .tool_state
336            .with_session(|session| {
337                session.to_api_messages_with_cache(self.cfg.config.cache.message_ttl_option())
338            })
339            .await;
340
341        let stream_request = self
342            .cfg
343            .request_builder
344            .build(messages, &self.dynamic_rules)
345            .stream();
346
347        let response = match self
348            .cfg
349            .client
350            .send_stream_with_auth_retry(stream_request)
351            .await
352        {
353            Ok(r) => r,
354            Err(e) => {
355                self.phase = Phase::Done;
356                return Some(Err(e));
357            }
358        };
359
360        self.metrics.record_api_call();
361
362        let boxed_stream: BoxedByteStream = Box::pin(response.bytes_stream());
363        self.phase = Phase::Streaming(Box::new(StreamingPhase {
364            stream: RecoverableStream::new(boxed_stream),
365            accumulated_usage: Usage::default(),
366        }));
367
368        None
369    }
370
371    async fn do_poll_stream(
372        &mut self,
373        stream: &mut RecoverableStream<BoxedByteStream>,
374        accumulated_usage: &mut Usage,
375    ) -> StreamPollResult {
376        let chunk_result = tokio::time::timeout(self.chunk_timeout, stream.next()).await;
377
378        match chunk_result {
379            Ok(Some(Ok(item))) => {
380                self.last_chunk_time = Instant::now();
381                self.handle_stream_item(item, accumulated_usage)
382            }
383            Ok(Some(Err(e))) => {
384                self.phase = Phase::Done;
385                StreamPollResult::Event(Err(e))
386            }
387            Ok(None) => StreamPollResult::StreamEnded,
388            Err(_) => {
389                self.phase = Phase::Done;
390                StreamPollResult::Event(Err(crate::Error::Stream(format!(
391                    "Chunk timeout after {:?} (no data received)",
392                    self.chunk_timeout
393                ))))
394            }
395        }
396    }
397
398    fn handle_stream_item(
399        &mut self,
400        item: StreamItem,
401        accumulated_usage: &mut Usage,
402    ) -> StreamPollResult {
403        match item {
404            StreamItem::Text(text) => {
405                self.final_text.push_str(&text);
406                StreamPollResult::Event(Ok(AgentEvent::Text(text)))
407            }
408            StreamItem::Thinking(thinking) => {
409                StreamPollResult::Event(Ok(AgentEvent::Thinking(thinking)))
410            }
411            StreamItem::Citation(_) => StreamPollResult::Continue,
412            StreamItem::ToolUseComplete(tool_use) => {
413                self.pending_tool_uses.push(tool_use);
414                StreamPollResult::Continue
415            }
416            StreamItem::Event(event) => self.handle_stream_event(event, accumulated_usage),
417        }
418    }
419
420    fn handle_stream_event(
421        &mut self,
422        event: StreamEvent,
423        accumulated_usage: &mut Usage,
424    ) -> StreamPollResult {
425        match event {
426            StreamEvent::MessageStart { message } => {
427                accumulated_usage.input_tokens = message.usage.input_tokens;
428                accumulated_usage.output_tokens = message.usage.output_tokens;
429                accumulated_usage.cache_creation_input_tokens =
430                    message.usage.cache_creation_input_tokens;
431                accumulated_usage.cache_read_input_tokens = message.usage.cache_read_input_tokens;
432                StreamPollResult::Continue
433            }
434            StreamEvent::ContentBlockStart { .. } => StreamPollResult::Continue,
435            StreamEvent::ContentBlockDelta { .. } => StreamPollResult::Continue,
436            StreamEvent::ContentBlockStop { .. } => StreamPollResult::Continue,
437            StreamEvent::MessageDelta { usage, .. } => {
438                accumulated_usage.output_tokens = usage.output_tokens;
439                StreamPollResult::Continue
440            }
441            StreamEvent::MessageStop => StreamPollResult::StreamEnded,
442            StreamEvent::Ping => StreamPollResult::Continue,
443            StreamEvent::Error { error } => {
444                self.phase = Phase::Done;
445                StreamPollResult::Event(Err(crate::Error::Stream(error.message)))
446            }
447        }
448    }
449
450    async fn do_handle_stream_end(
451        &mut self,
452        accumulated_usage: Usage,
453    ) -> Option<crate::Result<AgentEvent>> {
454        self.cfg
455            .tool_state
456            .with_session_mut(|session| {
457                session.update_usage(&accumulated_usage);
458            })
459            .await;
460
461        accumulate_response_usage(
462            &mut self.total_usage,
463            &mut self.metrics,
464            &self.cfg.budget_tracker,
465            self.cfg.tenant_budget.as_deref(),
466            &self.cfg.config.model.primary,
467            &accumulated_usage,
468        );
469
470        self.cfg
471            .tool_state
472            .with_session_mut(|session| {
473                let text_count = if self.final_text.is_empty() { 0 } else { 1 };
474                let mut content = Vec::with_capacity(text_count + self.pending_tool_uses.len());
475                if !self.final_text.is_empty() {
476                    content.push(ContentBlock::Text {
477                        text: self.final_text.clone(),
478                        citations: None,
479                        cache_control: None,
480                    });
481                }
482                for tool_use in &self.pending_tool_uses {
483                    content.push(ContentBlock::ToolUse(tool_use.clone()));
484                }
485                if !content.is_empty() {
486                    session.add_assistant_message(content, Some(accumulated_usage));
487                }
488            })
489            .await;
490
491        if self.pending_tool_uses.is_empty() {
492            self.phase = Phase::Done;
493            self.metrics.execution_time_ms = self.start_time.elapsed().as_millis() as u64;
494
495            run_stop_hooks(
496                &self.cfg.hooks,
497                &self.cfg.hook_context,
498                &self.cfg.session_id,
499            )
500            .await;
501
502            let messages = self
503                .cfg
504                .tool_state
505                .with_session(|session| session.to_api_messages())
506                .await;
507            let result = self.build_result(self.metrics.iterations, StopReason::EndTurn, messages);
508            return Some(Ok(AgentEvent::Complete(Box::new(result))));
509        }
510
511        self.phase = Phase::ProcessingTools { tool_index: 0 };
512        None
513    }
514
515    async fn do_process_tool(&mut self, tool_index: usize) -> Option<crate::Result<AgentEvent>> {
516        if tool_index >= self.pending_tool_uses.len() {
517            if !self.pending_tool_results.is_empty() {
518                self.finalize_tool_results().await;
519            }
520            self.final_text.clear();
521            self.pending_tool_uses.clear();
522            self.phase = Phase::StartRequest;
523            return None;
524        }
525
526        let tool_use = self.pending_tool_uses[tool_index].clone();
527        self.execute_tool(tool_use, tool_index).await
528    }
529
530    async fn execute_tool(
531        &mut self,
532        tool_use: ToolUseBlock,
533        tool_index: usize,
534    ) -> Option<crate::Result<AgentEvent>> {
535        let pre_input = HookInput::pre_tool_use(
536            &*self.cfg.session_id,
537            &tool_use.name,
538            tool_use.input.clone(),
539        );
540        let pre_output = match self
541            .cfg
542            .hooks
543            .execute(HookEvent::PreToolUse, pre_input, &self.cfg.hook_context)
544            .await
545        {
546            Ok(output) => output,
547            Err(e) => {
548                self.phase = Phase::Done;
549                return Some(Err(e));
550            }
551        };
552
553        if !pre_output.continue_execution {
554            let reason = pre_output
555                .stop_reason
556                .clone()
557                .unwrap_or_else(|| "Blocked by hook".into());
558            debug!(tool = %tool_use.name, "Tool blocked by hook");
559
560            self.pending_tool_results
561                .push(ToolResultBlock::error(&tool_use.id, reason.clone()));
562            self.metrics.record_permission_denial(
563                PermissionDenial::new(&tool_use.name, &tool_use.id, tool_use.input.clone())
564                    .reason(reason.clone()),
565            );
566            self.phase = Phase::ProcessingTools {
567                tool_index: tool_index + 1,
568            };
569
570            return Some(Ok(AgentEvent::ToolBlocked {
571                id: tool_use.id,
572                name: tool_use.name,
573                reason,
574            }));
575        }
576
577        let actual_input = pre_output.updated_input.unwrap_or(tool_use.input.clone());
578
579        let start = Instant::now();
580        let result = self
581            .cfg
582            .tools
583            .execute(&tool_use.name, actual_input.clone())
584            .await;
585        let duration_ms = start.elapsed().as_millis() as u64;
586
587        let (output, is_error) = match &result.output {
588            crate::types::ToolOutput::Success(s) => (s.clone(), false),
589            crate::types::ToolOutput::SuccessBlocks(blocks) => {
590                let text = blocks
591                    .iter()
592                    .filter_map(|b| match b {
593                        crate::types::ToolOutputBlock::Text { text } => Some(text.as_str()),
594                        _ => None,
595                    })
596                    .collect::<Vec<_>>()
597                    .join("\n");
598                (text, false)
599            }
600            crate::types::ToolOutput::Error(e) => (e.to_string(), true),
601            crate::types::ToolOutput::Empty => (String::new(), false),
602        };
603
604        self.metrics
605            .record_tool(&tool_use.id, &tool_use.name, duration_ms, is_error);
606
607        accumulate_inner_usage(
608            &self.cfg.tool_state,
609            &mut self.total_usage,
610            &mut self.metrics,
611            &self.cfg.budget_tracker,
612            &result,
613            &tool_use.name,
614        )
615        .await;
616
617        run_post_tool_hooks(
618            &self.cfg.hooks,
619            &self.cfg.hook_context,
620            &self.cfg.session_id,
621            &tool_use.name,
622            is_error,
623            &result,
624        )
625        .await;
626
627        try_activate_dynamic_rules(
628            &tool_use.name,
629            &actual_input,
630            &self.cfg.orchestrator,
631            &mut self.dynamic_rules,
632        )
633        .await;
634
635        self.pending_tool_results
636            .push(ToolResultBlock::from_tool_result(&tool_use.id, &result));
637        self.phase = Phase::ProcessingTools {
638            tool_index: tool_index + 1,
639        };
640
641        Some(Ok(AgentEvent::ToolComplete {
642            id: tool_use.id,
643            name: tool_use.name,
644            output,
645            is_error,
646            duration_ms,
647        }))
648    }
649
650    async fn finalize_tool_results(&mut self) {
651        let results = std::mem::take(&mut self.pending_tool_results);
652        let max_tokens = context_window::for_model(&self.cfg.config.model.primary);
653
654        self.cfg
655            .tool_state
656            .with_session_mut(|session| {
657                session.add_tool_results(results);
658            })
659            .await;
660
661        handle_compaction(
662            &self.cfg.tool_state,
663            &self.cfg.client,
664            &self.cfg.tools,
665            &self.cfg.hooks,
666            &self.cfg.hook_context,
667            &self.cfg.session_id,
668            &self.cfg.config.execution,
669            max_tokens,
670            &mut self.metrics,
671        )
672        .await;
673    }
674}
675
#[cfg(test)]
mod tests {
    use super::*;

    /// The marker phases of the state machine are distinguishable by pattern.
    #[test]
    fn test_phase_transitions() {
        let start = Phase::StartRequest;
        let done = Phase::Done;
        assert!(matches!(start, Phase::StartRequest));
        assert!(matches!(done, Phase::Done));
    }

    /// Each poll outcome maps to its own variant. The exhaustive match in `label`
    /// also fails to compile if a variant is added without updating this test.
    #[test]
    fn test_stream_poll_result_variants() {
        fn label(result: &StreamPollResult) -> &'static str {
            match result {
                StreamPollResult::Event(_) => "event",
                StreamPollResult::Continue => "continue",
                StreamPollResult::StreamEnded => "ended",
            }
        }

        let text_event = StreamPollResult::Event(Ok(AgentEvent::Text("test".into())));
        assert_eq!(label(&text_event), "event");
        assert_eq!(label(&StreamPollResult::Continue), "continue");
        assert_eq!(label(&StreamPollResult::StreamEnded), "ended");
    }
}
697}