1use std::path::Path;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::time::Instant;
7
8use futures::{Stream, StreamExt, stream};
9use tokio::sync::RwLock;
10use tracing::{debug, warn};
11
12use super::common::BudgetContext;
13use super::events::{AgentEvent, AgentResult};
14use super::execution::extract_file_path;
15use super::executor::Agent;
16use super::request::RequestBuilder;
17use super::state_formatter::collect_compaction_state;
18use super::{AgentConfig, AgentMetrics, AgentState};
19use crate::budget::{BudgetTracker, TenantBudget};
20use crate::client::{RecoverableStream, StreamItem};
21use crate::context::PromptOrchestrator;
22use crate::hooks::{HookContext, HookEvent, HookInput, HookManager};
23use crate::session::ToolState;
24use crate::types::{
25 CompactResult, ContentBlock, PermissionDenial, StopReason, StreamEvent, ToolResultBlock,
26 ToolUseBlock, Usage, context_window,
27};
28use crate::{Client, ToolRegistry};
29
/// Pinned, boxed, `Send` stream of raw response chunks.
///
/// Each item is either a `bytes::Bytes` buffer or the `reqwest::Error` that
/// interrupted the transfer. Boxing erases the concrete stream type so it can
/// be stored inside [`StreamingPhase`].
type BoxedByteStream =
    Pin<Box<dyn Stream<Item = std::result::Result<bytes::Bytes, reqwest::Error>> + Send>>;
32
33impl Agent {
34 pub async fn execute_stream(
35 &self,
36 prompt: &str,
37 ) -> crate::Result<impl Stream<Item = crate::Result<AgentEvent>> + Send> {
38 let timeout = self
39 .config
40 .execution
41 .timeout
42 .unwrap_or(std::time::Duration::from_secs(600));
43
44 if self.state.is_executing() {
45 self.state
46 .enqueue(prompt)
47 .await
48 .map_err(|e| crate::Error::Session(format!("Queue full: {}", e)))?;
49 } else {
50 self.state
51 .with_session_mut(|session| {
52 session.add_user_message(prompt);
53 })
54 .await;
55 }
56
57 let state = StreamState::new(
58 StreamStateConfig {
59 tool_state: self.state.clone(),
60 client: Arc::clone(&self.client),
61 config: Arc::clone(&self.config),
62 tools: Arc::clone(&self.tools),
63 hooks: Arc::clone(&self.hooks),
64 hook_context: self.hook_context(),
65 request_builder: RequestBuilder::new(&self.config, Arc::clone(&self.tools)),
66 orchestrator: self.orchestrator.clone(),
67 session_id: Arc::clone(&self.session_id),
68 budget_tracker: Arc::clone(&self.budget_tracker),
69 tenant_budget: self.tenant_budget.clone(),
70 },
71 timeout,
72 );
73
74 Ok(stream::unfold(state, |mut state| async move {
75 state.next_event().await.map(|event| (event, state))
76 }))
77 }
78}
79
/// Shared handles and settings a [`StreamState`] needs for one streaming run.
///
/// Built once in `Agent::execute_stream` from the agent's own fields.
struct StreamStateConfig {
    /// Session access: reading messages, recording usage, and compaction.
    tool_state: ToolState,
    /// API client used to send the streaming request and to compact the session.
    client: Arc<Client>,
    /// Agent configuration (execution limits, cache, model, budget settings).
    config: Arc<AgentConfig>,
    /// Registry through which requested tools are executed.
    tools: Arc<ToolRegistry>,
    /// Hook manager invoked around every tool execution.
    hooks: Arc<HookManager>,
    /// Context passed to every hook invocation.
    hook_context: HookContext,
    /// Builds the API request from the session messages and dynamic rules.
    request_builder: RequestBuilder,
    /// Optional prompt orchestrator consulted after tool runs to refresh the
    /// dynamic rules for the file a tool touched.
    orchestrator: Option<Arc<RwLock<PromptOrchestrator>>>,
    /// Session identifier used in hook inputs and in the final result.
    session_id: Arc<str>,
    /// Records usage per model and returns the cost of each record.
    budget_tracker: Arc<BudgetTracker>,
    /// Optional tenant-level budget that also records primary-model usage.
    tenant_budget: Option<Arc<TenantBudget>>,
}
93
/// Outcome of polling the response stream once.
enum StreamPollResult {
    /// Something to hand to the consumer (an event or an error).
    Event(crate::Result<AgentEvent>),
    /// The item was absorbed internally; poll again.
    Continue,
    /// The response is finished (stream exhausted or `MessageStop` received).
    StreamEnded,
}
99
/// State of the streaming execution state machine.
enum Phase {
    /// Build and send the next API request.
    StartRequest,
    /// A response stream is open and being polled. Boxed to keep `Phase` small.
    Streaming(Box<StreamingPhase>),
    /// The response finished; usage and the assistant message must be recorded.
    StreamEnded { accumulated_usage: Usage },
    /// Executing the pending tool uses one at a time; `tool_index` is the next
    /// entry of `pending_tool_uses` to run.
    ProcessingTools { tool_index: usize },
    /// Terminal state: no further events are produced.
    Done,
}
107
/// Payload of [`Phase::Streaming`]: the open response stream and its usage so far.
struct StreamingPhase {
    /// Response byte stream wrapped in `RecoverableStream`, which yields `StreamItem`s.
    stream: RecoverableStream<BoxedByteStream>,
    /// Token usage reported by the current response so far.
    accumulated_usage: Usage,
}
112
/// Mutable state of one streaming execution, advanced by `next_event`.
struct StreamState {
    /// Shared handles and configuration for this run.
    cfg: StreamStateConfig,
    /// Overall time limit for the whole execution.
    timeout: std::time::Duration,
    /// Maximum wait for a single stream item before the run fails.
    chunk_timeout: std::time::Duration,
    /// Dynamic rule context passed to the request builder; refreshed after
    /// tool runs that touch a file with matching rules.
    dynamic_rules: String,
    /// Metrics accumulated over the run (iterations, tool calls, usage, cost).
    metrics: AgentMetrics,
    /// When the run started; basis for the overall timeout and execution time.
    start_time: Instant,
    /// When the last stream item arrived. Written on every item; not read
    /// anywhere in the code visible here.
    last_chunk_time: Instant,
    /// Tool results collected during the current tool-processing round.
    pending_tool_results: Vec<ToolResultBlock>,
    /// Tool uses requested by the current response, awaiting execution.
    pending_tool_uses: Vec<ToolUseBlock>,
    /// Text streamed by the current response; cleared after each tool round.
    final_text: String,
    /// Input/output token totals across all responses and inner tool usage.
    total_usage: Usage,
    /// Current state machine phase.
    phase: Phase,
}
127
128impl StreamState {
129 fn new(cfg: StreamStateConfig, timeout: std::time::Duration) -> Self {
130 let chunk_timeout = cfg.config.execution.chunk_timeout;
131 let now = Instant::now();
132 Self {
133 cfg,
134 timeout,
135 chunk_timeout,
136 dynamic_rules: String::new(),
137 metrics: AgentMetrics::default(),
138 start_time: now,
139 last_chunk_time: now,
140 pending_tool_results: Vec::new(),
141 pending_tool_uses: Vec::new(),
142 final_text: String::new(),
143 total_usage: Usage::default(),
144 phase: Phase::StartRequest,
145 }
146 }
147
148 async fn next_event(&mut self) -> Option<crate::Result<AgentEvent>> {
149 loop {
150 if matches!(self.phase, Phase::Done) {
151 return None;
152 }
153
154 if self.start_time.elapsed() > self.timeout {
155 self.phase = Phase::Done;
156 return Some(Err(crate::Error::Timeout(self.timeout)));
157 }
158
159 if let Some(event) = self.check_budget_exceeded() {
160 return Some(event);
161 }
162
163 match std::mem::replace(&mut self.phase, Phase::Done) {
164 Phase::StartRequest => {
165 if let Some(result) = self.do_start_request().await {
166 return Some(result);
167 }
168 }
169 Phase::Streaming(mut streaming) => {
170 match self
171 .do_poll_stream(&mut streaming.stream, &mut streaming.accumulated_usage)
172 .await
173 {
174 StreamPollResult::Event(event) => {
175 self.phase = Phase::Streaming(streaming);
176 return Some(event);
177 }
178 StreamPollResult::Continue => {
179 self.phase = Phase::Streaming(streaming);
180 }
181 StreamPollResult::StreamEnded => {
182 self.phase = Phase::StreamEnded {
183 accumulated_usage: streaming.accumulated_usage,
184 };
185 }
186 }
187 }
188 Phase::StreamEnded { accumulated_usage } => {
189 if let Some(event) = self.do_handle_stream_end(accumulated_usage).await {
190 return Some(event);
191 }
192 }
193 Phase::ProcessingTools { tool_index } => {
194 if let Some(result) = self.do_process_tool(tool_index).await {
195 return Some(result);
196 }
197 }
198 Phase::Done => return None,
199 }
200 }
201 }
202
203 fn check_budget_exceeded(&mut self) -> Option<crate::Result<AgentEvent>> {
204 let result = BudgetContext {
205 tracker: &self.cfg.budget_tracker,
206 tenant: self.cfg.tenant_budget.as_deref(),
207 config: &self.cfg.config.budget,
208 }
209 .check();
210
211 if let Err(e) = result {
212 self.phase = Phase::Done;
213 return Some(Err(e));
214 }
215
216 None
217 }
218
219 async fn do_start_request(&mut self) -> Option<crate::Result<AgentEvent>> {
220 self.metrics.iterations += 1;
221 if self.metrics.iterations > self.cfg.config.execution.max_iterations {
222 self.phase = Phase::Done;
223 self.metrics.execution_time_ms = self.start_time.elapsed().as_millis() as u64;
224
225 let messages = self
226 .cfg
227 .tool_state
228 .with_session(|session| session.to_api_messages())
229 .await;
230
231 return Some(Ok(AgentEvent::Complete(Box::new(AgentResult {
232 text: self.final_text.clone(),
233 usage: self.total_usage,
234 tool_calls: self.metrics.tool_calls,
235 iterations: self.metrics.iterations - 1,
236 stop_reason: StopReason::MaxTokens,
237 state: AgentState::Completed,
238 metrics: self.metrics.clone(),
239 session_id: self.cfg.session_id.to_string(),
240 structured_output: None,
241 messages,
242 uuid: uuid::Uuid::new_v4().to_string(),
243 }))));
244 }
245
246 let cache_messages = self.cfg.config.cache.enabled && self.cfg.config.cache.message_cache;
247 let messages = self
248 .cfg
249 .tool_state
250 .with_session(|session| session.to_api_messages_with_cache(cache_messages))
251 .await;
252
253 let stream_request = self
254 .cfg
255 .request_builder
256 .build(messages, &self.dynamic_rules)
257 .with_stream();
258
259 let response = match self
260 .cfg
261 .client
262 .send_stream_with_auth_retry(stream_request)
263 .await
264 {
265 Ok(r) => r,
266 Err(e) => {
267 self.phase = Phase::Done;
268 return Some(Err(e));
269 }
270 };
271
272 self.metrics.record_api_call();
273
274 let boxed_stream: BoxedByteStream = Box::pin(response.bytes_stream());
275 self.phase = Phase::Streaming(Box::new(StreamingPhase {
276 stream: RecoverableStream::new(boxed_stream),
277 accumulated_usage: Usage::default(),
278 }));
279
280 None
281 }
282
283 async fn do_poll_stream(
284 &mut self,
285 stream: &mut RecoverableStream<BoxedByteStream>,
286 accumulated_usage: &mut Usage,
287 ) -> StreamPollResult {
288 let chunk_result = tokio::time::timeout(self.chunk_timeout, stream.next()).await;
289
290 match chunk_result {
291 Ok(Some(Ok(item))) => {
292 self.last_chunk_time = Instant::now();
293 self.handle_stream_item(item, accumulated_usage)
294 }
295 Ok(Some(Err(e))) => {
296 self.phase = Phase::Done;
297 StreamPollResult::Event(Err(e))
298 }
299 Ok(None) => StreamPollResult::StreamEnded,
300 Err(_) => {
301 self.phase = Phase::Done;
302 StreamPollResult::Event(Err(crate::Error::Stream(format!(
303 "Chunk timeout after {:?} (no data received)",
304 self.chunk_timeout
305 ))))
306 }
307 }
308 }
309
310 fn handle_stream_item(
311 &mut self,
312 item: StreamItem,
313 accumulated_usage: &mut Usage,
314 ) -> StreamPollResult {
315 match item {
316 StreamItem::Text(text) => {
317 self.final_text.push_str(&text);
318 StreamPollResult::Event(Ok(AgentEvent::Text(text)))
319 }
320 StreamItem::Thinking(thinking) => {
321 StreamPollResult::Event(Ok(AgentEvent::Thinking(thinking)))
322 }
323 StreamItem::Citation(_) => StreamPollResult::Continue,
324 StreamItem::ToolUseComplete(tool_use) => {
325 self.pending_tool_uses.push(tool_use);
326 StreamPollResult::Continue
327 }
328 StreamItem::Event(event) => self.handle_stream_event(event, accumulated_usage),
329 }
330 }
331
332 fn handle_stream_event(
333 &mut self,
334 event: StreamEvent,
335 accumulated_usage: &mut Usage,
336 ) -> StreamPollResult {
337 match event {
338 StreamEvent::MessageStart { message } => {
339 accumulated_usage.input_tokens = message.usage.input_tokens;
340 accumulated_usage.output_tokens = message.usage.output_tokens;
341 accumulated_usage.cache_creation_input_tokens =
342 message.usage.cache_creation_input_tokens;
343 accumulated_usage.cache_read_input_tokens = message.usage.cache_read_input_tokens;
344 StreamPollResult::Continue
345 }
346 StreamEvent::ContentBlockStart { .. } => StreamPollResult::Continue,
347 StreamEvent::ContentBlockDelta { .. } => StreamPollResult::Continue,
348 StreamEvent::ContentBlockStop { .. } => StreamPollResult::Continue,
349 StreamEvent::MessageDelta { usage, .. } => {
350 accumulated_usage.output_tokens = usage.output_tokens;
351 StreamPollResult::Continue
352 }
353 StreamEvent::MessageStop => StreamPollResult::StreamEnded,
354 StreamEvent::Ping => StreamPollResult::Continue,
355 StreamEvent::Error { error } => {
356 self.phase = Phase::Done;
357 StreamPollResult::Event(Err(crate::Error::Stream(error.message)))
358 }
359 }
360 }
361
362 async fn do_handle_stream_end(
363 &mut self,
364 accumulated_usage: Usage,
365 ) -> Option<crate::Result<AgentEvent>> {
366 self.total_usage.input_tokens += accumulated_usage.input_tokens;
367 self.total_usage.output_tokens += accumulated_usage.output_tokens;
368 self.metrics.add_usage(
369 accumulated_usage.input_tokens,
370 accumulated_usage.output_tokens,
371 );
372 self.metrics
373 .record_model_usage(&self.cfg.config.model.primary, &accumulated_usage);
374
375 let cost = self
376 .cfg
377 .budget_tracker
378 .record(&self.cfg.config.model.primary, &accumulated_usage);
379 self.metrics.add_cost(cost);
380
381 if let Some(ref tenant_budget) = self.cfg.tenant_budget {
382 tenant_budget.record(&self.cfg.config.model.primary, &accumulated_usage);
383 }
384
385 self.cfg
386 .tool_state
387 .with_session_mut(|session| {
388 session.update_usage(&accumulated_usage);
389
390 let mut content = Vec::new();
391 if !self.final_text.is_empty() {
392 content.push(ContentBlock::Text {
393 text: self.final_text.clone(),
394 citations: None,
395 cache_control: None,
396 });
397 }
398 for tool_use in &self.pending_tool_uses {
399 content.push(ContentBlock::ToolUse(tool_use.clone()));
400 }
401 if !content.is_empty() {
402 session.add_assistant_message(content, Some(accumulated_usage));
403 }
404 })
405 .await;
406
407 if self.pending_tool_uses.is_empty() {
408 self.phase = Phase::Done;
409 self.metrics.execution_time_ms = self.start_time.elapsed().as_millis() as u64;
410
411 let messages = self
412 .cfg
413 .tool_state
414 .with_session(|session| session.to_api_messages())
415 .await;
416
417 return Some(Ok(AgentEvent::Complete(Box::new(AgentResult {
418 text: self.final_text.clone(),
419 usage: self.total_usage,
420 tool_calls: self.metrics.tool_calls,
421 iterations: self.metrics.iterations,
422 stop_reason: StopReason::EndTurn,
423 state: AgentState::Completed,
424 metrics: self.metrics.clone(),
425 session_id: self.cfg.session_id.to_string(),
426 structured_output: None,
427 messages,
428 uuid: uuid::Uuid::new_v4().to_string(),
429 }))));
430 }
431
432 self.phase = Phase::ProcessingTools { tool_index: 0 };
433 None
434 }
435
436 async fn do_process_tool(&mut self, tool_index: usize) -> Option<crate::Result<AgentEvent>> {
437 if tool_index >= self.pending_tool_uses.len() {
438 if !self.pending_tool_results.is_empty() {
439 self.finalize_tool_results().await;
440 }
441 self.final_text.clear();
442 self.pending_tool_uses.clear();
443 self.phase = Phase::StartRequest;
444 return None;
445 }
446
447 let tool_use = self.pending_tool_uses[tool_index].clone();
448 self.execute_tool(tool_use, tool_index).await
449 }
450
451 async fn execute_tool(
452 &mut self,
453 tool_use: ToolUseBlock,
454 tool_index: usize,
455 ) -> Option<crate::Result<AgentEvent>> {
456 let pre_input = HookInput::pre_tool_use(
457 &*self.cfg.session_id,
458 &tool_use.name,
459 tool_use.input.clone(),
460 );
461 let pre_output = match self
462 .cfg
463 .hooks
464 .execute(HookEvent::PreToolUse, pre_input, &self.cfg.hook_context)
465 .await
466 {
467 Ok(output) => output,
468 Err(e) => {
469 warn!(tool = %tool_use.name, error = %e, "PreToolUse hook failed");
470 crate::hooks::HookOutput::allow()
471 }
472 };
473
474 if !pre_output.continue_execution {
475 let reason = pre_output
476 .stop_reason
477 .clone()
478 .unwrap_or_else(|| "Blocked by hook".into());
479 debug!(tool = %tool_use.name, "Tool blocked by hook");
480
481 self.pending_tool_results
482 .push(ToolResultBlock::error(&tool_use.id, reason.clone()));
483 self.metrics.record_permission_denial(
484 PermissionDenial::new(&tool_use.name, &tool_use.id, tool_use.input.clone())
485 .with_reason(reason.clone()),
486 );
487 self.phase = Phase::ProcessingTools {
488 tool_index: tool_index + 1,
489 };
490
491 return Some(Ok(AgentEvent::ToolBlocked {
492 id: tool_use.id,
493 name: tool_use.name,
494 reason,
495 }));
496 }
497
498 let actual_input = pre_output.updated_input.unwrap_or(tool_use.input.clone());
499
500 let start = Instant::now();
501 let result = self
502 .cfg
503 .tools
504 .execute(&tool_use.name, actual_input.clone())
505 .await;
506 let duration_ms = start.elapsed().as_millis() as u64;
507
508 let (output, is_error) = match &result.output {
509 crate::types::ToolOutput::Success(s) => (s.clone(), false),
510 crate::types::ToolOutput::SuccessBlocks(blocks) => {
511 let text = blocks
512 .iter()
513 .filter_map(|b| match b {
514 crate::types::ToolOutputBlock::Text { text } => Some(text.as_str()),
515 _ => None,
516 })
517 .collect::<Vec<_>>()
518 .join("\n");
519 (text, false)
520 }
521 crate::types::ToolOutput::Error(e) => (e.to_string(), true),
522 crate::types::ToolOutput::Empty => (String::new(), false),
523 };
524
525 self.metrics
526 .record_tool(&tool_use.id, &tool_use.name, duration_ms, is_error);
527
528 if let Some(ref inner_usage) = result.inner_usage {
529 self.cfg
530 .tool_state
531 .with_session_mut(|session| {
532 session.update_usage(inner_usage);
533 })
534 .await;
535 self.total_usage.input_tokens += inner_usage.input_tokens;
536 self.total_usage.output_tokens += inner_usage.output_tokens;
537 self.metrics
538 .add_usage(inner_usage.input_tokens, inner_usage.output_tokens);
539 let inner_model = result.inner_model.as_deref().unwrap_or("claude-haiku-4-5");
540 self.metrics.record_model_usage(inner_model, inner_usage);
541 let inner_cost = self.cfg.budget_tracker.record(inner_model, inner_usage);
542 self.metrics.add_cost(inner_cost);
543 }
544
545 if is_error {
546 let failure_input = HookInput::post_tool_use_failure(
547 &*self.cfg.session_id,
548 &tool_use.name,
549 result.error_message(),
550 );
551 if let Err(e) = self
552 .cfg
553 .hooks
554 .execute(
555 HookEvent::PostToolUseFailure,
556 failure_input,
557 &self.cfg.hook_context,
558 )
559 .await
560 {
561 warn!(tool = %tool_use.name, error = %e, "PostToolUseFailure hook failed");
562 }
563 } else {
564 let post_input = HookInput::post_tool_use(
565 &*self.cfg.session_id,
566 &tool_use.name,
567 result.output.clone(),
568 );
569 if let Err(e) = self
570 .cfg
571 .hooks
572 .execute(HookEvent::PostToolUse, post_input, &self.cfg.hook_context)
573 .await
574 {
575 warn!(tool = %tool_use.name, error = %e, "PostToolUse hook failed");
576 }
577 }
578
579 if let Some(file_path) = extract_file_path(&tool_use.name, &actual_input)
580 && let Some(ref orchestrator) = self.cfg.orchestrator
581 {
582 let orch = orchestrator.read().await;
583 let path = Path::new(&file_path);
584 let rules = orch.rules_engine().find_matching(path);
585 if !rules.is_empty() {
586 let dynamic_ctx = orch.build_dynamic_context(Some(path)).await;
587 if !dynamic_ctx.is_empty() {
588 self.dynamic_rules = dynamic_ctx;
589 }
590 }
591 }
592
593 self.pending_tool_results
594 .push(ToolResultBlock::from_tool_result(&tool_use.id, result));
595 self.phase = Phase::ProcessingTools {
596 tool_index: tool_index + 1,
597 };
598
599 Some(Ok(AgentEvent::ToolComplete {
600 id: tool_use.id,
601 name: tool_use.name,
602 output,
603 is_error,
604 duration_ms,
605 }))
606 }
607
608 async fn finalize_tool_results(&mut self) {
609 let results = std::mem::take(&mut self.pending_tool_results);
610 let max_tokens = context_window::for_model(&self.cfg.config.model.primary);
611
612 self.cfg
613 .tool_state
614 .with_session_mut(|session| {
615 session.add_tool_results(results);
616 })
617 .await;
618
619 let should_compact = self
620 .cfg
621 .tool_state
622 .with_session(|session| {
623 self.cfg.config.execution.auto_compact
624 && session.should_compact(
625 max_tokens,
626 self.cfg.config.execution.compact_threshold,
627 self.cfg.config.execution.compact_keep_messages,
628 )
629 })
630 .await;
631
632 if should_compact {
633 let compact_result = self
634 .cfg
635 .tool_state
636 .compact(
637 &self.cfg.client,
638 self.cfg.config.execution.compact_keep_messages,
639 )
640 .await;
641
642 if let Ok(CompactResult::Compacted { .. }) = compact_result {
643 self.metrics.record_compaction();
644 let state_sections = collect_compaction_state(&self.cfg.tools).await;
645 if !state_sections.is_empty() {
646 self.cfg
647 .tool_state
648 .with_session_mut(|session| {
649 session.add_user_message(format!(
650 "<system-reminder>\n# State preserved after compaction\n\n{}\n</system-reminder>",
651 state_sections.join("\n\n")
652 ));
653 })
654 .await;
655 }
656 }
657 }
658 }
659}
660
#[cfg(test)]
mod tests {
    use super::*;

    /// The unit-like phases can be constructed and matched.
    #[test]
    fn test_phase_transitions() {
        let initial = Phase::StartRequest;
        assert!(matches!(initial, Phase::StartRequest));

        let terminal = Phase::Done;
        assert!(matches!(terminal, Phase::Done));
    }

    /// Every `StreamPollResult` variant can be constructed and matched.
    #[test]
    fn test_stream_poll_result_variants() {
        assert!(matches!(
            StreamPollResult::Event(Ok(AgentEvent::Text("test".into()))),
            StreamPollResult::Event(_)
        ));
        assert!(matches!(
            StreamPollResult::Continue,
            StreamPollResult::Continue
        ));
        assert!(matches!(
            StreamPollResult::StreamEnded,
            StreamPollResult::StreamEnded
        ));
    }
}