1use std::path::Path;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::time::Instant;
7
8use futures::{Stream, StreamExt, stream};
9use tokio::sync::RwLock;
10use tracing::{debug, warn};
11
12use super::common::BudgetContext;
13use super::events::{AgentEvent, AgentResult};
14use super::execution::extract_file_path;
15use super::executor::Agent;
16use super::request::RequestBuilder;
17use super::state_formatter::collect_compaction_state;
18use super::{AgentConfig, AgentMetrics, AgentState};
19use crate::budget::{BudgetTracker, TenantBudget};
20use crate::client::{RecoverableStream, StreamItem};
21use crate::context::PromptOrchestrator;
22use crate::hooks::{HookContext, HookEvent, HookInput, HookManager};
23use crate::session::ToolState;
24use crate::types::{
25 CompactResult, ContentBlock, PermissionDenial, StopReason, StreamEvent, ToolResultBlock,
26 ToolUseBlock, Usage, context_window,
27};
28use crate::{Client, ToolRegistry};
29
30type BoxedByteStream =
31 Pin<Box<dyn Stream<Item = std::result::Result<bytes::Bytes, reqwest::Error>> + Send>>;
32
33impl Agent {
34 pub async fn execute_stream(
35 &self,
36 prompt: &str,
37 ) -> crate::Result<impl Stream<Item = crate::Result<AgentEvent>> + Send> {
38 let timeout = self
39 .config
40 .execution
41 .timeout
42 .unwrap_or(std::time::Duration::from_secs(600));
43
44 if self.state.is_executing() {
45 self.state
46 .enqueue(prompt)
47 .await
48 .map_err(|e| crate::Error::Session(format!("Queue full: {}", e)))?;
49 } else {
50 self.state
51 .with_session_mut(|session| {
52 session.add_user_message(prompt);
53 })
54 .await;
55 }
56
57 let state = StreamState::new(
58 StreamStateConfig {
59 tool_state: self.state.clone(),
60 client: Arc::clone(&self.client),
61 config: Arc::clone(&self.config),
62 tools: Arc::clone(&self.tools),
63 hooks: Arc::clone(&self.hooks),
64 hook_context: self.hook_context(),
65 request_builder: RequestBuilder::new(&self.config, Arc::clone(&self.tools)),
66 orchestrator: self.orchestrator.clone(),
67 session_id: Arc::clone(&self.session_id),
68 budget_tracker: Arc::clone(&self.budget_tracker),
69 tenant_budget: self.tenant_budget.clone(),
70 },
71 timeout,
72 );
73
74 Ok(stream::unfold(state, |mut state| async move {
75 state.next_event().await.map(|event| (event, state))
76 }))
77 }
78}
79
80struct StreamStateConfig {
81 tool_state: ToolState,
82 client: Arc<Client>,
83 config: Arc<AgentConfig>,
84 tools: Arc<ToolRegistry>,
85 hooks: Arc<HookManager>,
86 hook_context: HookContext,
87 request_builder: RequestBuilder,
88 orchestrator: Option<Arc<RwLock<PromptOrchestrator>>>,
89 session_id: Arc<str>,
90 budget_tracker: Arc<BudgetTracker>,
91 tenant_budget: Option<Arc<TenantBudget>>,
92}
93
94enum StreamPollResult {
95 Event(crate::Result<AgentEvent>),
96 Continue,
97 StreamEnded,
98}
99
100enum Phase {
101 StartRequest,
102 Streaming(Box<StreamingPhase>),
103 StreamEnded { accumulated_usage: Usage },
104 ProcessingTools { tool_index: usize },
105 Done,
106}
107
108struct StreamingPhase {
109 stream: RecoverableStream<BoxedByteStream>,
110 accumulated_usage: Usage,
111}
112
113struct StreamState {
114 cfg: StreamStateConfig,
115 timeout: std::time::Duration,
116 chunk_timeout: std::time::Duration,
117 dynamic_rules: String,
118 metrics: AgentMetrics,
119 start_time: Instant,
120 last_chunk_time: Instant,
121 pending_tool_results: Vec<ToolResultBlock>,
122 pending_tool_uses: Vec<ToolUseBlock>,
123 final_text: String,
124 total_usage: Usage,
125 phase: Phase,
126}
127
128impl StreamState {
129 fn new(cfg: StreamStateConfig, timeout: std::time::Duration) -> Self {
130 let chunk_timeout = cfg.config.execution.chunk_timeout;
131 let now = Instant::now();
132 Self {
133 cfg,
134 timeout,
135 chunk_timeout,
136 dynamic_rules: String::new(),
137 metrics: AgentMetrics::default(),
138 start_time: now,
139 last_chunk_time: now,
140 pending_tool_results: Vec::new(),
141 pending_tool_uses: Vec::new(),
142 final_text: String::new(),
143 total_usage: Usage::default(),
144 phase: Phase::StartRequest,
145 }
146 }
147
148 async fn next_event(&mut self) -> Option<crate::Result<AgentEvent>> {
149 loop {
150 if matches!(self.phase, Phase::Done) {
151 return None;
152 }
153
154 if self.start_time.elapsed() > self.timeout {
155 self.phase = Phase::Done;
156 return Some(Err(crate::Error::Timeout(self.timeout)));
157 }
158
159 if let Some(event) = self.check_budget_exceeded() {
160 return Some(event);
161 }
162
163 match std::mem::replace(&mut self.phase, Phase::Done) {
164 Phase::StartRequest => {
165 if let Some(result) = self.do_start_request().await {
166 return Some(result);
167 }
168 }
169 Phase::Streaming(mut streaming) => {
170 match self
171 .do_poll_stream(&mut streaming.stream, &mut streaming.accumulated_usage)
172 .await
173 {
174 StreamPollResult::Event(event) => {
175 self.phase = Phase::Streaming(streaming);
176 return Some(event);
177 }
178 StreamPollResult::Continue => {
179 self.phase = Phase::Streaming(streaming);
180 }
181 StreamPollResult::StreamEnded => {
182 self.phase = Phase::StreamEnded {
183 accumulated_usage: streaming.accumulated_usage,
184 };
185 }
186 }
187 }
188 Phase::StreamEnded { accumulated_usage } => {
189 if let Some(event) = self.do_handle_stream_end(accumulated_usage).await {
190 return Some(event);
191 }
192 }
193 Phase::ProcessingTools { tool_index } => {
194 if let Some(result) = self.do_process_tool(tool_index).await {
195 return Some(result);
196 }
197 }
198 Phase::Done => return None,
199 }
200 }
201 }
202
203 fn check_budget_exceeded(&mut self) -> Option<crate::Result<AgentEvent>> {
204 let result = BudgetContext {
205 tracker: &self.cfg.budget_tracker,
206 tenant: self.cfg.tenant_budget.as_deref(),
207 config: &self.cfg.config.budget,
208 }
209 .check();
210
211 if let Err(e) = result {
212 self.phase = Phase::Done;
213 return Some(Err(e));
214 }
215
216 None
217 }
218
219 async fn do_start_request(&mut self) -> Option<crate::Result<AgentEvent>> {
220 self.metrics.iterations += 1;
221 if self.metrics.iterations > self.cfg.config.execution.max_iterations {
222 self.phase = Phase::Done;
223 self.metrics.execution_time_ms = self.start_time.elapsed().as_millis() as u64;
224
225 let messages = self
226 .cfg
227 .tool_state
228 .with_session(|session| session.to_api_messages())
229 .await;
230
231 return Some(Ok(AgentEvent::Complete(Box::new(AgentResult {
232 text: self.final_text.clone(),
233 usage: self.total_usage,
234 tool_calls: self.metrics.tool_calls,
235 iterations: self.metrics.iterations - 1,
236 stop_reason: StopReason::MaxTokens,
237 state: AgentState::Completed,
238 metrics: self.metrics.clone(),
239 session_id: self.cfg.session_id.to_string(),
240 structured_output: None,
241 messages,
242 uuid: uuid::Uuid::new_v4().to_string(),
243 }))));
244 }
245
246 let messages = self
247 .cfg
248 .tool_state
249 .with_session(|session| {
250 session.to_api_messages_with_cache(self.cfg.config.cache.message_ttl_option())
251 })
252 .await;
253
254 let stream_request = self
255 .cfg
256 .request_builder
257 .build(messages, &self.dynamic_rules)
258 .with_stream();
259
260 let response = match self
261 .cfg
262 .client
263 .send_stream_with_auth_retry(stream_request)
264 .await
265 {
266 Ok(r) => r,
267 Err(e) => {
268 self.phase = Phase::Done;
269 return Some(Err(e));
270 }
271 };
272
273 self.metrics.record_api_call();
274
275 let boxed_stream: BoxedByteStream = Box::pin(response.bytes_stream());
276 self.phase = Phase::Streaming(Box::new(StreamingPhase {
277 stream: RecoverableStream::new(boxed_stream),
278 accumulated_usage: Usage::default(),
279 }));
280
281 None
282 }
283
284 async fn do_poll_stream(
285 &mut self,
286 stream: &mut RecoverableStream<BoxedByteStream>,
287 accumulated_usage: &mut Usage,
288 ) -> StreamPollResult {
289 let chunk_result = tokio::time::timeout(self.chunk_timeout, stream.next()).await;
290
291 match chunk_result {
292 Ok(Some(Ok(item))) => {
293 self.last_chunk_time = Instant::now();
294 self.handle_stream_item(item, accumulated_usage)
295 }
296 Ok(Some(Err(e))) => {
297 self.phase = Phase::Done;
298 StreamPollResult::Event(Err(e))
299 }
300 Ok(None) => StreamPollResult::StreamEnded,
301 Err(_) => {
302 self.phase = Phase::Done;
303 StreamPollResult::Event(Err(crate::Error::Stream(format!(
304 "Chunk timeout after {:?} (no data received)",
305 self.chunk_timeout
306 ))))
307 }
308 }
309 }
310
311 fn handle_stream_item(
312 &mut self,
313 item: StreamItem,
314 accumulated_usage: &mut Usage,
315 ) -> StreamPollResult {
316 match item {
317 StreamItem::Text(text) => {
318 self.final_text.push_str(&text);
319 StreamPollResult::Event(Ok(AgentEvent::Text(text)))
320 }
321 StreamItem::Thinking(thinking) => {
322 StreamPollResult::Event(Ok(AgentEvent::Thinking(thinking)))
323 }
324 StreamItem::Citation(_) => StreamPollResult::Continue,
325 StreamItem::ToolUseComplete(tool_use) => {
326 self.pending_tool_uses.push(tool_use);
327 StreamPollResult::Continue
328 }
329 StreamItem::Event(event) => self.handle_stream_event(event, accumulated_usage),
330 }
331 }
332
333 fn handle_stream_event(
334 &mut self,
335 event: StreamEvent,
336 accumulated_usage: &mut Usage,
337 ) -> StreamPollResult {
338 match event {
339 StreamEvent::MessageStart { message } => {
340 accumulated_usage.input_tokens = message.usage.input_tokens;
341 accumulated_usage.output_tokens = message.usage.output_tokens;
342 accumulated_usage.cache_creation_input_tokens =
343 message.usage.cache_creation_input_tokens;
344 accumulated_usage.cache_read_input_tokens = message.usage.cache_read_input_tokens;
345 StreamPollResult::Continue
346 }
347 StreamEvent::ContentBlockStart { .. } => StreamPollResult::Continue,
348 StreamEvent::ContentBlockDelta { .. } => StreamPollResult::Continue,
349 StreamEvent::ContentBlockStop { .. } => StreamPollResult::Continue,
350 StreamEvent::MessageDelta { usage, .. } => {
351 accumulated_usage.output_tokens = usage.output_tokens;
352 StreamPollResult::Continue
353 }
354 StreamEvent::MessageStop => StreamPollResult::StreamEnded,
355 StreamEvent::Ping => StreamPollResult::Continue,
356 StreamEvent::Error { error } => {
357 self.phase = Phase::Done;
358 StreamPollResult::Event(Err(crate::Error::Stream(error.message)))
359 }
360 }
361 }
362
363 async fn do_handle_stream_end(
364 &mut self,
365 accumulated_usage: Usage,
366 ) -> Option<crate::Result<AgentEvent>> {
367 self.total_usage.input_tokens += accumulated_usage.input_tokens;
368 self.total_usage.output_tokens += accumulated_usage.output_tokens;
369 self.metrics.add_usage(
370 accumulated_usage.input_tokens,
371 accumulated_usage.output_tokens,
372 );
373 self.metrics
374 .record_model_usage(&self.cfg.config.model.primary, &accumulated_usage);
375
376 let cost = self
377 .cfg
378 .budget_tracker
379 .record(&self.cfg.config.model.primary, &accumulated_usage);
380 self.metrics.add_cost(cost);
381
382 if let Some(ref tenant_budget) = self.cfg.tenant_budget {
383 tenant_budget.record(&self.cfg.config.model.primary, &accumulated_usage);
384 }
385
386 self.cfg
387 .tool_state
388 .with_session_mut(|session| {
389 session.update_usage(&accumulated_usage);
390
391 let mut content = Vec::new();
392 if !self.final_text.is_empty() {
393 content.push(ContentBlock::Text {
394 text: self.final_text.clone(),
395 citations: None,
396 cache_control: None,
397 });
398 }
399 for tool_use in &self.pending_tool_uses {
400 content.push(ContentBlock::ToolUse(tool_use.clone()));
401 }
402 if !content.is_empty() {
403 session.add_assistant_message(content, Some(accumulated_usage));
404 }
405 })
406 .await;
407
408 if self.pending_tool_uses.is_empty() {
409 self.phase = Phase::Done;
410 self.metrics.execution_time_ms = self.start_time.elapsed().as_millis() as u64;
411
412 let messages = self
413 .cfg
414 .tool_state
415 .with_session(|session| session.to_api_messages())
416 .await;
417
418 return Some(Ok(AgentEvent::Complete(Box::new(AgentResult {
419 text: self.final_text.clone(),
420 usage: self.total_usage,
421 tool_calls: self.metrics.tool_calls,
422 iterations: self.metrics.iterations,
423 stop_reason: StopReason::EndTurn,
424 state: AgentState::Completed,
425 metrics: self.metrics.clone(),
426 session_id: self.cfg.session_id.to_string(),
427 structured_output: None,
428 messages,
429 uuid: uuid::Uuid::new_v4().to_string(),
430 }))));
431 }
432
433 self.phase = Phase::ProcessingTools { tool_index: 0 };
434 None
435 }
436
437 async fn do_process_tool(&mut self, tool_index: usize) -> Option<crate::Result<AgentEvent>> {
438 if tool_index >= self.pending_tool_uses.len() {
439 if !self.pending_tool_results.is_empty() {
440 self.finalize_tool_results().await;
441 }
442 self.final_text.clear();
443 self.pending_tool_uses.clear();
444 self.phase = Phase::StartRequest;
445 return None;
446 }
447
448 let tool_use = self.pending_tool_uses[tool_index].clone();
449 self.execute_tool(tool_use, tool_index).await
450 }
451
452 async fn execute_tool(
453 &mut self,
454 tool_use: ToolUseBlock,
455 tool_index: usize,
456 ) -> Option<crate::Result<AgentEvent>> {
457 let pre_input = HookInput::pre_tool_use(
458 &*self.cfg.session_id,
459 &tool_use.name,
460 tool_use.input.clone(),
461 );
462 let pre_output = match self
463 .cfg
464 .hooks
465 .execute(HookEvent::PreToolUse, pre_input, &self.cfg.hook_context)
466 .await
467 {
468 Ok(output) => output,
469 Err(e) => {
470 warn!(tool = %tool_use.name, error = %e, "PreToolUse hook failed");
471 crate::hooks::HookOutput::allow()
472 }
473 };
474
475 if !pre_output.continue_execution {
476 let reason = pre_output
477 .stop_reason
478 .clone()
479 .unwrap_or_else(|| "Blocked by hook".into());
480 debug!(tool = %tool_use.name, "Tool blocked by hook");
481
482 self.pending_tool_results
483 .push(ToolResultBlock::error(&tool_use.id, reason.clone()));
484 self.metrics.record_permission_denial(
485 PermissionDenial::new(&tool_use.name, &tool_use.id, tool_use.input.clone())
486 .with_reason(reason.clone()),
487 );
488 self.phase = Phase::ProcessingTools {
489 tool_index: tool_index + 1,
490 };
491
492 return Some(Ok(AgentEvent::ToolBlocked {
493 id: tool_use.id,
494 name: tool_use.name,
495 reason,
496 }));
497 }
498
499 let actual_input = pre_output.updated_input.unwrap_or(tool_use.input.clone());
500
501 let start = Instant::now();
502 let result = self
503 .cfg
504 .tools
505 .execute(&tool_use.name, actual_input.clone())
506 .await;
507 let duration_ms = start.elapsed().as_millis() as u64;
508
509 let (output, is_error) = match &result.output {
510 crate::types::ToolOutput::Success(s) => (s.clone(), false),
511 crate::types::ToolOutput::SuccessBlocks(blocks) => {
512 let text = blocks
513 .iter()
514 .filter_map(|b| match b {
515 crate::types::ToolOutputBlock::Text { text } => Some(text.as_str()),
516 _ => None,
517 })
518 .collect::<Vec<_>>()
519 .join("\n");
520 (text, false)
521 }
522 crate::types::ToolOutput::Error(e) => (e.to_string(), true),
523 crate::types::ToolOutput::Empty => (String::new(), false),
524 };
525
526 self.metrics
527 .record_tool(&tool_use.id, &tool_use.name, duration_ms, is_error);
528
529 if let Some(ref inner_usage) = result.inner_usage {
530 self.cfg
531 .tool_state
532 .with_session_mut(|session| {
533 session.update_usage(inner_usage);
534 })
535 .await;
536 self.total_usage.input_tokens += inner_usage.input_tokens;
537 self.total_usage.output_tokens += inner_usage.output_tokens;
538 self.metrics
539 .add_usage(inner_usage.input_tokens, inner_usage.output_tokens);
540 let inner_model = result.inner_model.as_deref().unwrap_or("claude-haiku-4-5");
541 self.metrics.record_model_usage(inner_model, inner_usage);
542 let inner_cost = self.cfg.budget_tracker.record(inner_model, inner_usage);
543 self.metrics.add_cost(inner_cost);
544 }
545
546 if is_error {
547 let failure_input = HookInput::post_tool_use_failure(
548 &*self.cfg.session_id,
549 &tool_use.name,
550 result.error_message(),
551 );
552 if let Err(e) = self
553 .cfg
554 .hooks
555 .execute(
556 HookEvent::PostToolUseFailure,
557 failure_input,
558 &self.cfg.hook_context,
559 )
560 .await
561 {
562 warn!(tool = %tool_use.name, error = %e, "PostToolUseFailure hook failed");
563 }
564 } else {
565 let post_input = HookInput::post_tool_use(
566 &*self.cfg.session_id,
567 &tool_use.name,
568 result.output.clone(),
569 );
570 if let Err(e) = self
571 .cfg
572 .hooks
573 .execute(HookEvent::PostToolUse, post_input, &self.cfg.hook_context)
574 .await
575 {
576 warn!(tool = %tool_use.name, error = %e, "PostToolUse hook failed");
577 }
578 }
579
580 if let Some(file_path) = extract_file_path(&tool_use.name, &actual_input)
581 && let Some(ref orchestrator) = self.cfg.orchestrator
582 {
583 let orch = orchestrator.read().await;
584 let path = Path::new(&file_path);
585 let rules = orch.rules_engine().find_matching(path);
586 if !rules.is_empty() {
587 let dynamic_ctx = orch.build_dynamic_context(Some(path)).await;
588 if !dynamic_ctx.is_empty() {
589 self.dynamic_rules = dynamic_ctx;
590 }
591 }
592 }
593
594 self.pending_tool_results
595 .push(ToolResultBlock::from_tool_result(&tool_use.id, result));
596 self.phase = Phase::ProcessingTools {
597 tool_index: tool_index + 1,
598 };
599
600 Some(Ok(AgentEvent::ToolComplete {
601 id: tool_use.id,
602 name: tool_use.name,
603 output,
604 is_error,
605 duration_ms,
606 }))
607 }
608
609 async fn finalize_tool_results(&mut self) {
610 let results = std::mem::take(&mut self.pending_tool_results);
611 let max_tokens = context_window::for_model(&self.cfg.config.model.primary);
612
613 self.cfg
614 .tool_state
615 .with_session_mut(|session| {
616 session.add_tool_results(results);
617 })
618 .await;
619
620 let should_compact = self
621 .cfg
622 .tool_state
623 .with_session(|session| {
624 self.cfg.config.execution.auto_compact
625 && session.should_compact(
626 max_tokens,
627 self.cfg.config.execution.compact_threshold,
628 self.cfg.config.execution.compact_keep_messages,
629 )
630 })
631 .await;
632
633 if should_compact {
634 let compact_result = self
635 .cfg
636 .tool_state
637 .compact(
638 &self.cfg.client,
639 self.cfg.config.execution.compact_keep_messages,
640 )
641 .await;
642
643 if let Ok(CompactResult::Compacted { .. }) = compact_result {
644 self.metrics.record_compaction();
645 let state_sections = collect_compaction_state(&self.cfg.tools).await;
646 if !state_sections.is_empty() {
647 self.cfg
648 .tool_state
649 .with_session_mut(|session| {
650 session.add_user_message(format!(
651 "<system-reminder>\n# State preserved after compaction\n\n{}\n</system-reminder>",
652 state_sections.join("\n\n")
653 ));
654 })
655 .await;
656 }
657 }
658 }
659 }
660}
661
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke check: the field-less `Phase` variants can be constructed and matched.
    #[test]
    fn test_phase_transitions() {
        let initial = Phase::StartRequest;
        assert!(matches!(initial, Phase::StartRequest));

        let terminal = Phase::Done;
        assert!(matches!(terminal, Phase::Done));
    }

    /// Smoke check: every `StreamPollResult` variant can be constructed and matched.
    #[test]
    fn test_stream_poll_result_variants() {
        assert!(matches!(
            StreamPollResult::Event(Ok(AgentEvent::Text("test".into()))),
            StreamPollResult::Event(_)
        ));
        assert!(matches!(
            StreamPollResult::Continue,
            StreamPollResult::Continue
        ));
        assert!(matches!(
            StreamPollResult::StreamEnded,
            StreamPollResult::StreamEnded
        ));
    }
}