1use std::pin::Pin;
4use std::sync::Arc;
5use std::time::Instant;
6
7use futures::{Stream, StreamExt, stream};
8use tokio::sync::RwLock;
9use tracing::{debug, warn};
10
11use super::common::{
12 BudgetContext, accumulate_inner_usage, accumulate_response_usage, handle_compaction,
13 run_post_tool_hooks, run_stop_hooks, try_activate_dynamic_rules,
14};
15use super::events::{AgentEvent, AgentResult};
16use super::executor::Agent;
17use super::request::RequestBuilder;
18use super::{AgentConfig, AgentMetrics};
19use crate::budget::{BudgetTracker, TenantBudget};
20use crate::client::{RecoverableStream, StreamItem};
21use crate::context::PromptOrchestrator;
22use crate::hooks::{HookContext, HookEvent, HookInput, HookManager};
23use crate::session::ToolState;
24use crate::types::{
25 ContentBlock, PermissionDenial, StopReason, StreamEvent, ToolResultBlock, ToolUseBlock, Usage,
26 context_window,
27};
28use crate::{Client, ToolRegistry};
29
30type BoxedByteStream =
31 Pin<Box<dyn Stream<Item = std::result::Result<bytes::Bytes, reqwest::Error>> + Send>>;
32
33impl Agent {
34 pub async fn execute_stream(
35 &self,
36 prompt: &str,
37 ) -> crate::Result<impl Stream<Item = crate::Result<AgentEvent>> + Send> {
38 let timeout = self
39 .config
40 .execution
41 .timeout
42 .unwrap_or(std::time::Duration::from_secs(600));
43
44 if self.state.is_executing() {
45 self.state
46 .enqueue(prompt)
47 .await
48 .map_err(|e| crate::Error::Session(format!("Queue full: {}", e)))?;
49 }
50 let state = StreamState::new(
51 StreamStateConfig {
52 tool_state: self.state.clone(),
53 client: Arc::clone(&self.client),
54 config: Arc::clone(&self.config),
55 tools: Arc::clone(&self.tools),
56 hooks: Arc::clone(&self.hooks),
57 hook_context: self.hook_context(),
58 request_builder: RequestBuilder::new(&self.config, Arc::clone(&self.tools)),
59 orchestrator: self.orchestrator.clone(),
60 session_id: Arc::clone(&self.session_id),
61 budget_tracker: Arc::clone(&self.budget_tracker),
62 tenant_budget: self.tenant_budget.clone(),
63 },
64 timeout,
65 prompt.to_string(),
66 );
67
68 Ok(stream::unfold(state, |mut state| async move {
69 state.next_event().await.map(|event| (event, state))
70 }))
71 }
72}
73
/// Dependencies a streaming run borrows from its [`Agent`], bundled so `StreamState::new`
/// takes a single argument. Built once per `execute_stream` call.
struct StreamStateConfig {
    /// Session/tool state handle (cloned from the agent's `state`); messages and usage
    /// are read from and written to the session through it.
    tool_state: ToolState,
    /// API client used to send each streaming request (and passed to compaction).
    client: Arc<Client>,
    /// Agent configuration (execution limits, model, budget, cache, prompt schema).
    config: Arc<AgentConfig>,
    /// Registry used to execute tool calls requested by the model.
    tools: Arc<ToolRegistry>,
    /// Hook manager for SessionStart / UserPromptSubmit / PreToolUse / Stop hooks.
    hooks: Arc<HookManager>,
    /// Context passed to every hook invocation.
    hook_context: HookContext,
    /// Builds each API request; its model may be switched to a budget fallback.
    request_builder: RequestBuilder,
    /// Optional prompt orchestrator consulted when activating dynamic rules after tools run.
    orchestrator: Option<Arc<RwLock<PromptOrchestrator>>>,
    /// Identifier of the current session, used in hook inputs and the final result.
    session_id: Arc<str>,
    /// Agent-level budget tracker checked before every state-machine step.
    budget_tracker: Arc<BudgetTracker>,
    /// Optional per-tenant budget checked alongside `budget_tracker`.
    tenant_budget: Option<Arc<TenantBudget>>,
}
87
/// Outcome of polling the response stream once.
enum StreamPollResult {
    /// Something to hand to the caller (text/thinking delta, or a fatal error).
    Event(crate::Result<AgentEvent>),
    /// An item was consumed internally (usage, tool use, ping, ...); poll again.
    Continue,
    /// The response is finished (`message_stop` or end of the byte stream).
    StreamEnded,
}
93
/// Current step of the streaming agent state machine driven by `StreamState::next_event`.
enum Phase {
    /// Send the next API request (first turn, or after tool results were recorded).
    StartRequest,
    /// Reading a response stream. Boxed, so this variant stores only a pointer inline.
    Streaming(Box<StreamingPhase>),
    /// The response finished; its usage is carried over for bookkeeping.
    StreamEnded { accumulated_usage: Usage },
    /// Running requested tools; `tool_index` is the next entry of `pending_tool_uses`.
    ProcessingTools { tool_index: usize },
    /// Terminal: `next_event` returns `None` from now on.
    Done,
}
101
/// State owned while a single API response is being streamed.
struct StreamingPhase {
    /// Parsed item stream over the raw HTTP body.
    stream: RecoverableStream<BoxedByteStream>,
    /// Token usage reported so far by this response's `message_start`/`message_delta`.
    accumulated_usage: Usage,
}
106
/// Mutable state of one streaming agent run; advanced by `next_event`.
struct StreamState {
    /// Shared dependencies borrowed from the agent.
    cfg: StreamStateConfig,
    /// Overall wall-clock limit for the run, checked between steps.
    timeout: std::time::Duration,
    /// Maximum wait for a single item from the response stream.
    chunk_timeout: std::time::Duration,
    /// Dynamic rules text passed to each request; extended after tool calls.
    dynamic_rules: String,
    /// Metrics reported in the final result.
    metrics: AgentMetrics,
    /// When the run started (for the overall timeout and execution time).
    start_time: Instant,
    /// Time of the last successfully received stream item.
    /// NOTE(review): written but never read in this file view.
    last_chunk_time: Instant,
    /// Tool results for the current turn, flushed to the session after all tools run.
    pending_tool_results: Vec<ToolResultBlock>,
    /// Tool calls requested by the current assistant turn.
    pending_tool_uses: Vec<ToolUseBlock>,
    /// Text of the current assistant turn; cleared after its tools are processed.
    final_text: String,
    /// Usage summed across all responses (and inner tool usage).
    total_usage: Usage,
    /// Current state-machine phase.
    phase: Phase,
    /// Whether the SessionStart hook has been fired.
    session_started: bool,
    /// Whether the initial prompt has been submitted to the session.
    prompt_submitted: bool,
    /// The user prompt; taken on the first request.
    initial_prompt: Option<String>,
}
124
125impl StreamState {
126 fn new(cfg: StreamStateConfig, timeout: std::time::Duration, prompt: String) -> Self {
127 let chunk_timeout = cfg.config.execution.chunk_timeout;
128 let now = Instant::now();
129 Self {
130 cfg,
131 timeout,
132 chunk_timeout,
133 dynamic_rules: String::new(),
134 metrics: AgentMetrics::default(),
135 start_time: now,
136 last_chunk_time: now,
137 pending_tool_results: Vec::new(),
138 pending_tool_uses: Vec::new(),
139 final_text: String::new(),
140 total_usage: Usage::default(),
141 phase: Phase::StartRequest,
142 session_started: false,
143 prompt_submitted: false,
144 initial_prompt: Some(prompt),
145 }
146 }
147
148 fn extract_structured_output(&self, text: &str) -> Option<serde_json::Value> {
149 super::common::extract_structured_output(
150 self.cfg.config.prompt.output_schema.as_ref(),
151 text,
152 )
153 }
154
155 fn build_result(
156 &self,
157 iterations: usize,
158 stop_reason: StopReason,
159 messages: Vec<crate::types::Message>,
160 ) -> AgentResult {
161 let structured_output = self.extract_structured_output(&self.final_text);
162 AgentResult::new(
163 self.final_text.clone(),
164 self.total_usage,
165 iterations,
166 stop_reason,
167 self.metrics.clone(),
168 self.cfg.session_id.to_string(),
169 structured_output,
170 messages,
171 )
172 }
173
174 async fn next_event(&mut self) -> Option<crate::Result<AgentEvent>> {
175 loop {
176 if matches!(self.phase, Phase::Done) {
177 return None;
178 }
179
180 if self.start_time.elapsed() > self.timeout {
181 self.phase = Phase::Done;
182 return Some(Err(crate::Error::Timeout(self.timeout)));
183 }
184
185 if let Some(event) = self.check_budget_exceeded() {
186 return Some(event);
187 }
188
189 match std::mem::replace(&mut self.phase, Phase::Done) {
190 Phase::StartRequest => {
191 if let Some(result) = self.do_start_request().await {
192 return Some(result);
193 }
194 }
195 Phase::Streaming(mut streaming) => {
196 match self
197 .do_poll_stream(&mut streaming.stream, &mut streaming.accumulated_usage)
198 .await
199 {
200 StreamPollResult::Event(event) => {
201 self.phase = Phase::Streaming(streaming);
202 return Some(event);
203 }
204 StreamPollResult::Continue => {
205 self.phase = Phase::Streaming(streaming);
206 }
207 StreamPollResult::StreamEnded => {
208 self.phase = Phase::StreamEnded {
209 accumulated_usage: streaming.accumulated_usage,
210 };
211 }
212 }
213 }
214 Phase::StreamEnded { accumulated_usage } => {
215 if let Some(event) = self.do_handle_stream_end(accumulated_usage).await {
216 return Some(event);
217 }
218 }
219 Phase::ProcessingTools { tool_index } => {
220 if let Some(result) = self.do_process_tool(tool_index).await {
221 return Some(result);
222 }
223 }
224 Phase::Done => return None,
225 }
226 }
227 }
228
229 fn check_budget_exceeded(&mut self) -> Option<crate::Result<AgentEvent>> {
230 let result = BudgetContext {
231 tracker: &self.cfg.budget_tracker,
232 tenant: self.cfg.tenant_budget.as_deref(),
233 config: &self.cfg.config.budget,
234 }
235 .check();
236
237 if let Err(e) = result {
238 self.phase = Phase::Done;
239 return Some(Err(e));
240 }
241
242 None
243 }
244
245 async fn do_start_request(&mut self) -> Option<crate::Result<AgentEvent>> {
246 if !self.session_started {
247 self.session_started = true;
248 let session_start_input = HookInput::session_start(&*self.cfg.session_id);
249 if let Err(e) = self
250 .cfg
251 .hooks
252 .execute(
253 HookEvent::SessionStart,
254 session_start_input,
255 &self.cfg.hook_context,
256 )
257 .await
258 {
259 warn!(error = %e, "SessionStart hook failed");
260 }
261 }
262
263 if !self.prompt_submitted {
264 if let Some(prompt) = self.initial_prompt.take() {
265 let prompt_input = HookInput::user_prompt_submit(&*self.cfg.session_id, &prompt);
266 let prompt_output = match self
267 .cfg
268 .hooks
269 .execute(
270 HookEvent::UserPromptSubmit,
271 prompt_input,
272 &self.cfg.hook_context,
273 )
274 .await
275 {
276 Ok(output) => output,
277 Err(e) => {
278 self.phase = Phase::Done;
279 return Some(Err(e));
280 }
281 };
282
283 if !prompt_output.continue_execution {
284 self.phase = Phase::Done;
285 return Some(Err(crate::Error::Permission(
286 prompt_output
287 .stop_reason
288 .unwrap_or_else(|| "Blocked by hook".into()),
289 )));
290 }
291
292 self.cfg
293 .tool_state
294 .with_session_mut(|session| {
295 session.add_user_message(&prompt);
296 })
297 .await;
298 }
299 self.prompt_submitted = true;
300 }
301
302 self.metrics.iterations += 1;
303 if self.metrics.iterations > self.cfg.config.execution.max_iterations {
304 self.phase = Phase::Done;
305 self.metrics.execution_time_ms = self.start_time.elapsed().as_millis() as u64;
306
307 run_stop_hooks(
308 &self.cfg.hooks,
309 &self.cfg.hook_context,
310 &self.cfg.session_id,
311 )
312 .await;
313
314 let messages = self
315 .cfg
316 .tool_state
317 .with_session(|session| session.to_api_messages())
318 .await;
319 let result =
320 self.build_result(self.metrics.iterations - 1, StopReason::MaxTokens, messages);
321 return Some(Ok(AgentEvent::Complete(Box::new(result))));
322 }
323
324 let budget_ctx = BudgetContext {
325 tracker: &self.cfg.budget_tracker,
326 tenant: self.cfg.tenant_budget.as_deref(),
327 config: &self.cfg.config.budget,
328 };
329 if let Some(fallback) = budget_ctx.fallback_model() {
330 self.cfg.request_builder.set_model(fallback);
331 }
332
333 let messages = self
334 .cfg
335 .tool_state
336 .with_session(|session| {
337 session.to_api_messages_with_cache(self.cfg.config.cache.message_ttl_option())
338 })
339 .await;
340
341 let stream_request = self
342 .cfg
343 .request_builder
344 .build(messages, &self.dynamic_rules)
345 .stream();
346
347 let response = match self
348 .cfg
349 .client
350 .send_stream_with_auth_retry(stream_request)
351 .await
352 {
353 Ok(r) => r,
354 Err(e) => {
355 self.phase = Phase::Done;
356 return Some(Err(e));
357 }
358 };
359
360 self.metrics.record_api_call();
361
362 let boxed_stream: BoxedByteStream = Box::pin(response.bytes_stream());
363 self.phase = Phase::Streaming(Box::new(StreamingPhase {
364 stream: RecoverableStream::new(boxed_stream),
365 accumulated_usage: Usage::default(),
366 }));
367
368 None
369 }
370
371 async fn do_poll_stream(
372 &mut self,
373 stream: &mut RecoverableStream<BoxedByteStream>,
374 accumulated_usage: &mut Usage,
375 ) -> StreamPollResult {
376 let chunk_result = tokio::time::timeout(self.chunk_timeout, stream.next()).await;
377
378 match chunk_result {
379 Ok(Some(Ok(item))) => {
380 self.last_chunk_time = Instant::now();
381 self.handle_stream_item(item, accumulated_usage)
382 }
383 Ok(Some(Err(e))) => {
384 self.phase = Phase::Done;
385 StreamPollResult::Event(Err(e))
386 }
387 Ok(None) => StreamPollResult::StreamEnded,
388 Err(_) => {
389 self.phase = Phase::Done;
390 StreamPollResult::Event(Err(crate::Error::Stream(format!(
391 "Chunk timeout after {:?} (no data received)",
392 self.chunk_timeout
393 ))))
394 }
395 }
396 }
397
398 fn handle_stream_item(
399 &mut self,
400 item: StreamItem,
401 accumulated_usage: &mut Usage,
402 ) -> StreamPollResult {
403 match item {
404 StreamItem::Text(text) => {
405 self.final_text.push_str(&text);
406 StreamPollResult::Event(Ok(AgentEvent::Text(text)))
407 }
408 StreamItem::Thinking(thinking) => {
409 StreamPollResult::Event(Ok(AgentEvent::Thinking(thinking)))
410 }
411 StreamItem::Citation(_) => StreamPollResult::Continue,
412 StreamItem::ToolUseComplete(tool_use) => {
413 self.pending_tool_uses.push(tool_use);
414 StreamPollResult::Continue
415 }
416 StreamItem::Event(event) => self.handle_stream_event(event, accumulated_usage),
417 }
418 }
419
420 fn handle_stream_event(
421 &mut self,
422 event: StreamEvent,
423 accumulated_usage: &mut Usage,
424 ) -> StreamPollResult {
425 match event {
426 StreamEvent::MessageStart { message } => {
427 accumulated_usage.input_tokens = message.usage.input_tokens;
428 accumulated_usage.output_tokens = message.usage.output_tokens;
429 accumulated_usage.cache_creation_input_tokens =
430 message.usage.cache_creation_input_tokens;
431 accumulated_usage.cache_read_input_tokens = message.usage.cache_read_input_tokens;
432 StreamPollResult::Continue
433 }
434 StreamEvent::ContentBlockStart { .. } => StreamPollResult::Continue,
435 StreamEvent::ContentBlockDelta { .. } => StreamPollResult::Continue,
436 StreamEvent::ContentBlockStop { .. } => StreamPollResult::Continue,
437 StreamEvent::MessageDelta { usage, .. } => {
438 accumulated_usage.output_tokens = usage.output_tokens;
439 StreamPollResult::Continue
440 }
441 StreamEvent::MessageStop => StreamPollResult::StreamEnded,
442 StreamEvent::Ping => StreamPollResult::Continue,
443 StreamEvent::Error { error } => {
444 self.phase = Phase::Done;
445 StreamPollResult::Event(Err(crate::Error::Stream(error.message)))
446 }
447 }
448 }
449
450 async fn do_handle_stream_end(
451 &mut self,
452 accumulated_usage: Usage,
453 ) -> Option<crate::Result<AgentEvent>> {
454 self.cfg
455 .tool_state
456 .with_session_mut(|session| {
457 session.update_usage(&accumulated_usage);
458 })
459 .await;
460
461 accumulate_response_usage(
462 &mut self.total_usage,
463 &mut self.metrics,
464 &self.cfg.budget_tracker,
465 self.cfg.tenant_budget.as_deref(),
466 &self.cfg.config.model.primary,
467 &accumulated_usage,
468 );
469
470 self.cfg
471 .tool_state
472 .with_session_mut(|session| {
473 let text_count = if self.final_text.is_empty() { 0 } else { 1 };
474 let mut content = Vec::with_capacity(text_count + self.pending_tool_uses.len());
475 if !self.final_text.is_empty() {
476 content.push(ContentBlock::Text {
477 text: self.final_text.clone(),
478 citations: None,
479 cache_control: None,
480 });
481 }
482 for tool_use in &self.pending_tool_uses {
483 content.push(ContentBlock::ToolUse(tool_use.clone()));
484 }
485 if !content.is_empty() {
486 session.add_assistant_message(content, Some(accumulated_usage));
487 }
488 })
489 .await;
490
491 if self.pending_tool_uses.is_empty() {
492 self.phase = Phase::Done;
493 self.metrics.execution_time_ms = self.start_time.elapsed().as_millis() as u64;
494
495 run_stop_hooks(
496 &self.cfg.hooks,
497 &self.cfg.hook_context,
498 &self.cfg.session_id,
499 )
500 .await;
501
502 let messages = self
503 .cfg
504 .tool_state
505 .with_session(|session| session.to_api_messages())
506 .await;
507 let result = self.build_result(self.metrics.iterations, StopReason::EndTurn, messages);
508 return Some(Ok(AgentEvent::Complete(Box::new(result))));
509 }
510
511 self.phase = Phase::ProcessingTools { tool_index: 0 };
512 None
513 }
514
515 async fn do_process_tool(&mut self, tool_index: usize) -> Option<crate::Result<AgentEvent>> {
516 if tool_index >= self.pending_tool_uses.len() {
517 if !self.pending_tool_results.is_empty() {
518 self.finalize_tool_results().await;
519 }
520 self.final_text.clear();
521 self.pending_tool_uses.clear();
522 self.phase = Phase::StartRequest;
523 return None;
524 }
525
526 let tool_use = self.pending_tool_uses[tool_index].clone();
527 self.execute_tool(tool_use, tool_index).await
528 }
529
530 async fn execute_tool(
531 &mut self,
532 tool_use: ToolUseBlock,
533 tool_index: usize,
534 ) -> Option<crate::Result<AgentEvent>> {
535 let pre_input = HookInput::pre_tool_use(
536 &*self.cfg.session_id,
537 &tool_use.name,
538 tool_use.input.clone(),
539 );
540 let pre_output = match self
541 .cfg
542 .hooks
543 .execute(HookEvent::PreToolUse, pre_input, &self.cfg.hook_context)
544 .await
545 {
546 Ok(output) => output,
547 Err(e) => {
548 self.phase = Phase::Done;
549 return Some(Err(e));
550 }
551 };
552
553 if !pre_output.continue_execution {
554 let reason = pre_output
555 .stop_reason
556 .clone()
557 .unwrap_or_else(|| "Blocked by hook".into());
558 debug!(tool = %tool_use.name, "Tool blocked by hook");
559
560 self.pending_tool_results
561 .push(ToolResultBlock::error(&tool_use.id, reason.clone()));
562 self.metrics.record_permission_denial(
563 PermissionDenial::new(&tool_use.name, &tool_use.id, tool_use.input.clone())
564 .reason(reason.clone()),
565 );
566 self.phase = Phase::ProcessingTools {
567 tool_index: tool_index + 1,
568 };
569
570 return Some(Ok(AgentEvent::ToolBlocked {
571 id: tool_use.id,
572 name: tool_use.name,
573 reason,
574 }));
575 }
576
577 let actual_input = pre_output.updated_input.unwrap_or(tool_use.input.clone());
578
579 let start = Instant::now();
580 let result = self
581 .cfg
582 .tools
583 .execute(&tool_use.name, actual_input.clone())
584 .await;
585 let duration_ms = start.elapsed().as_millis() as u64;
586
587 let (output, is_error) = match &result.output {
588 crate::types::ToolOutput::Success(s) => (s.clone(), false),
589 crate::types::ToolOutput::SuccessBlocks(blocks) => {
590 let text = blocks
591 .iter()
592 .filter_map(|b| match b {
593 crate::types::ToolOutputBlock::Text { text } => Some(text.as_str()),
594 _ => None,
595 })
596 .collect::<Vec<_>>()
597 .join("\n");
598 (text, false)
599 }
600 crate::types::ToolOutput::Error(e) => (e.to_string(), true),
601 crate::types::ToolOutput::Empty => (String::new(), false),
602 };
603
604 self.metrics
605 .record_tool(&tool_use.id, &tool_use.name, duration_ms, is_error);
606
607 accumulate_inner_usage(
608 &self.cfg.tool_state,
609 &mut self.total_usage,
610 &mut self.metrics,
611 &self.cfg.budget_tracker,
612 &result,
613 &tool_use.name,
614 )
615 .await;
616
617 run_post_tool_hooks(
618 &self.cfg.hooks,
619 &self.cfg.hook_context,
620 &self.cfg.session_id,
621 &tool_use.name,
622 is_error,
623 &result,
624 )
625 .await;
626
627 try_activate_dynamic_rules(
628 &tool_use.name,
629 &actual_input,
630 &self.cfg.orchestrator,
631 &mut self.dynamic_rules,
632 )
633 .await;
634
635 self.pending_tool_results
636 .push(ToolResultBlock::from_tool_result(&tool_use.id, &result));
637 self.phase = Phase::ProcessingTools {
638 tool_index: tool_index + 1,
639 };
640
641 Some(Ok(AgentEvent::ToolComplete {
642 id: tool_use.id,
643 name: tool_use.name,
644 output,
645 is_error,
646 duration_ms,
647 }))
648 }
649
650 async fn finalize_tool_results(&mut self) {
651 let results = std::mem::take(&mut self.pending_tool_results);
652 let max_tokens = context_window::for_model(&self.cfg.config.model.primary);
653
654 self.cfg
655 .tool_state
656 .with_session_mut(|session| {
657 session.add_tool_results(results);
658 })
659 .await;
660
661 handle_compaction(
662 &self.cfg.tool_state,
663 &self.cfg.client,
664 &self.cfg.tools,
665 &self.cfg.hooks,
666 &self.cfg.hook_context,
667 &self.cfg.session_id,
668 &self.cfg.config.execution,
669 max_tokens,
670 &mut self.metrics,
671 )
672 .await;
673 }
674}
675
#[cfg(test)]
mod tests {
    use super::*;

    /// Maps a poll result to a short label so variants can be compared as data.
    fn poll_label(result: &StreamPollResult) -> &'static str {
        match result {
            StreamPollResult::Event(_) => "event",
            StreamPollResult::Continue => "continue",
            StreamPollResult::StreamEnded => "ended",
        }
    }

    #[test]
    fn test_phase_transitions() {
        let start = Phase::StartRequest;
        let done = Phase::Done;
        assert!(matches!(start, Phase::StartRequest));
        assert!(matches!(done, Phase::Done));
    }

    #[test]
    fn test_stream_poll_result_variants() {
        let cases = [
            (
                StreamPollResult::Event(Ok(AgentEvent::Text("test".into()))),
                "event",
            ),
            (StreamPollResult::Continue, "continue"),
            (StreamPollResult::StreamEnded, "ended"),
        ];
        for (result, expected) in &cases {
            assert_eq!(poll_label(result), *expected);
        }
    }
}