1use std::path::Path;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::time::Instant;
7
8use futures::{Stream, StreamExt, stream};
9use tokio::sync::RwLock;
10use tracing::{debug, warn};
11
12use super::common::BudgetContext;
13use super::events::{AgentEvent, AgentResult};
14use super::execution::extract_file_path;
15use super::executor::Agent;
16use super::request::RequestBuilder;
17use super::state_formatter::collect_compaction_state;
18use super::{AgentConfig, AgentMetrics, AgentState};
19use crate::budget::{BudgetTracker, TenantBudget};
20use crate::client::{RecoverableStream, StreamItem};
21use crate::context::PromptOrchestrator;
22use crate::hooks::{HookContext, HookEvent, HookInput, HookManager};
23use crate::session::ToolState;
24use crate::types::{
25 CompactResult, ContentBlock, PermissionDenial, StopReason, StreamEvent, ToolResultBlock,
26 ToolUseBlock, Usage, context_window,
27};
28use crate::{Client, ToolRegistry};
29
/// Type-erased HTTP body stream (`reqwest` byte chunks) that is fed into
/// `RecoverableStream` while a response is being streamed.
type BoxedByteStream =
    Pin<Box<dyn Stream<Item = std::result::Result<bytes::Bytes, reqwest::Error>> + Send>>;
32
33impl Agent {
34 pub async fn execute_stream(
35 &self,
36 prompt: &str,
37 ) -> crate::Result<impl Stream<Item = crate::Result<AgentEvent>> + Send> {
38 let timeout = self
39 .config
40 .execution
41 .timeout
42 .unwrap_or(std::time::Duration::from_secs(600));
43
44 if self.state.is_executing() {
45 self.state
46 .enqueue(prompt)
47 .await
48 .map_err(|e| crate::Error::Session(format!("Queue full: {}", e)))?;
49 } else {
50 self.state
51 .with_session_mut(|session| {
52 session.add_user_message(prompt);
53 })
54 .await;
55 }
56
57 let state = StreamState::new(
58 StreamStateConfig {
59 tool_state: self.state.clone(),
60 client: Arc::clone(&self.client),
61 config: Arc::clone(&self.config),
62 tools: Arc::clone(&self.tools),
63 hooks: Arc::clone(&self.hooks),
64 hook_context: self.hook_context(),
65 request_builder: RequestBuilder::new(&self.config, Arc::clone(&self.tools)),
66 orchestrator: self.orchestrator.clone(),
67 session_id: Arc::clone(&self.session_id),
68 budget_tracker: Arc::clone(&self.budget_tracker),
69 tenant_budget: self.tenant_budget.clone(),
70 },
71 timeout,
72 );
73
74 Ok(stream::unfold(state, |mut state| async move {
75 state.next_event().await.map(|event| (event, state))
76 }))
77 }
78}
79
/// Shared handles and settings captured from the `Agent` for one streaming run.
struct StreamStateConfig {
    /// Session/state handle used to read and mutate the conversation and to compact it.
    tool_state: ToolState,
    /// API client used for streaming requests and for compaction.
    client: Arc<Client>,
    /// Agent configuration (model, execution limits, budget, cache, prompt).
    config: Arc<AgentConfig>,
    /// Registry used to execute tool calls and collect post-compaction state.
    tools: Arc<ToolRegistry>,
    /// Hook manager invoked around each tool execution.
    hooks: Arc<HookManager>,
    /// Context passed to every hook invocation.
    hook_context: HookContext,
    /// Builds each streaming request from session messages plus dynamic rules.
    request_builder: RequestBuilder,
    /// Optional orchestrator consulted for file-path-specific dynamic rules.
    orchestrator: Option<Arc<RwLock<PromptOrchestrator>>>,
    /// Session identifier reported to hooks and in the final result.
    session_id: Arc<str>,
    /// Budget tracker charged for every model call.
    budget_tracker: Arc<BudgetTracker>,
    /// Optional per-tenant budget, checked alongside the tracker.
    tenant_budget: Option<Arc<TenantBudget>>,
}
93
/// Outcome of polling the in-flight response stream once.
enum StreamPollResult {
    /// An event (or error) to hand to the consumer.
    Event(crate::Result<AgentEvent>),
    /// Nothing to emit yet; keep polling the same stream.
    Continue,
    /// The response finished (end of stream or `MessageStop`).
    StreamEnded,
}
99
/// Phases of the streaming state machine driven by `StreamState::next_event`.
enum Phase {
    /// Build and send the next model request.
    StartRequest,
    /// Consuming a model response (boxed to keep `Phase` small).
    Streaming(Box<StreamingPhase>),
    /// The response finished; its usage still has to be recorded.
    StreamEnded { accumulated_usage: Usage },
    /// Executing pending tool uses; `tool_index` is the next one to run.
    ProcessingTools { tool_index: usize },
    /// Terminal phase. It also serves as a placeholder while
    /// `next_event` has taken ownership of the previous phase.
    Done,
}
107
/// State of a single in-flight model response.
struct StreamingPhase {
    /// Parsed stream of response items.
    stream: RecoverableStream<BoxedByteStream>,
    /// Usage reported so far by `MessageStart` / `MessageDelta` events.
    accumulated_usage: Usage,
}
112
/// Mutable state for one streaming run of the agent loop.
struct StreamState {
    /// Handles and settings captured from the agent.
    cfg: StreamStateConfig,
    /// Upper bound on the whole run, measured from `start_time`.
    timeout: std::time::Duration,
    /// Maximum wait for a single chunk of the response stream.
    chunk_timeout: std::time::Duration,
    /// Extra rule text passed to the request builder; refreshed after file-touching tools.
    dynamic_rules: String,
    /// Metrics accumulated across iterations and tool calls.
    metrics: AgentMetrics,
    /// When the run started.
    start_time: Instant,
    /// Updated whenever a chunk arrives.
    /// NOTE(review): not read by any code visible here.
    last_chunk_time: Instant,
    /// Tool results waiting to be appended to the session.
    pending_tool_results: Vec<ToolResultBlock>,
    /// Tool uses requested by the current response.
    pending_tool_uses: Vec<ToolUseBlock>,
    /// Text streamed in the current turn; cleared before the next request.
    final_text: String,
    /// Input/output tokens summed over the run, including nested tool usage.
    total_usage: Usage,
    /// Current phase of the state machine.
    phase: Phase,
}
127
128impl StreamState {
129 fn new(cfg: StreamStateConfig, timeout: std::time::Duration) -> Self {
130 let chunk_timeout = cfg.config.execution.chunk_timeout;
131 let now = Instant::now();
132 Self {
133 cfg,
134 timeout,
135 chunk_timeout,
136 dynamic_rules: String::new(),
137 metrics: AgentMetrics::default(),
138 start_time: now,
139 last_chunk_time: now,
140 pending_tool_results: Vec::new(),
141 pending_tool_uses: Vec::new(),
142 final_text: String::new(),
143 total_usage: Usage::default(),
144 phase: Phase::StartRequest,
145 }
146 }
147
148 fn extract_structured_output(&self, text: &str) -> Option<serde_json::Value> {
157 self.cfg.config.prompt.output_schema.as_ref()?;
159
160 serde_json::from_str(text).ok()
162 }
163
164 async fn next_event(&mut self) -> Option<crate::Result<AgentEvent>> {
165 loop {
166 if matches!(self.phase, Phase::Done) {
167 return None;
168 }
169
170 if self.start_time.elapsed() > self.timeout {
171 self.phase = Phase::Done;
172 return Some(Err(crate::Error::Timeout(self.timeout)));
173 }
174
175 if let Some(event) = self.check_budget_exceeded() {
176 return Some(event);
177 }
178
179 match std::mem::replace(&mut self.phase, Phase::Done) {
180 Phase::StartRequest => {
181 if let Some(result) = self.do_start_request().await {
182 return Some(result);
183 }
184 }
185 Phase::Streaming(mut streaming) => {
186 match self
187 .do_poll_stream(&mut streaming.stream, &mut streaming.accumulated_usage)
188 .await
189 {
190 StreamPollResult::Event(event) => {
191 self.phase = Phase::Streaming(streaming);
192 return Some(event);
193 }
194 StreamPollResult::Continue => {
195 self.phase = Phase::Streaming(streaming);
196 }
197 StreamPollResult::StreamEnded => {
198 self.phase = Phase::StreamEnded {
199 accumulated_usage: streaming.accumulated_usage,
200 };
201 }
202 }
203 }
204 Phase::StreamEnded { accumulated_usage } => {
205 if let Some(event) = self.do_handle_stream_end(accumulated_usage).await {
206 return Some(event);
207 }
208 }
209 Phase::ProcessingTools { tool_index } => {
210 if let Some(result) = self.do_process_tool(tool_index).await {
211 return Some(result);
212 }
213 }
214 Phase::Done => return None,
215 }
216 }
217 }
218
219 fn check_budget_exceeded(&mut self) -> Option<crate::Result<AgentEvent>> {
220 let result = BudgetContext {
221 tracker: &self.cfg.budget_tracker,
222 tenant: self.cfg.tenant_budget.as_deref(),
223 config: &self.cfg.config.budget,
224 }
225 .check();
226
227 if let Err(e) = result {
228 self.phase = Phase::Done;
229 return Some(Err(e));
230 }
231
232 None
233 }
234
235 async fn do_start_request(&mut self) -> Option<crate::Result<AgentEvent>> {
236 self.metrics.iterations += 1;
237 if self.metrics.iterations > self.cfg.config.execution.max_iterations {
238 self.phase = Phase::Done;
239 self.metrics.execution_time_ms = self.start_time.elapsed().as_millis() as u64;
240
241 let messages = self
242 .cfg
243 .tool_state
244 .with_session(|session| session.to_api_messages())
245 .await;
246
247 let structured_output = self.extract_structured_output(&self.final_text);
248 return Some(Ok(AgentEvent::Complete(Box::new(AgentResult {
249 text: self.final_text.clone(),
250 usage: self.total_usage,
251 tool_calls: self.metrics.tool_calls,
252 iterations: self.metrics.iterations - 1,
253 stop_reason: StopReason::MaxTokens,
254 state: AgentState::Completed,
255 metrics: self.metrics.clone(),
256 session_id: self.cfg.session_id.to_string(),
257 structured_output,
258 messages,
259 uuid: uuid::Uuid::new_v4().to_string(),
260 }))));
261 }
262
263 let messages = self
264 .cfg
265 .tool_state
266 .with_session(|session| {
267 session.to_api_messages_with_cache(self.cfg.config.cache.message_ttl_option())
268 })
269 .await;
270
271 let stream_request = self
272 .cfg
273 .request_builder
274 .build(messages, &self.dynamic_rules)
275 .with_stream();
276
277 let response = match self
278 .cfg
279 .client
280 .send_stream_with_auth_retry(stream_request)
281 .await
282 {
283 Ok(r) => r,
284 Err(e) => {
285 self.phase = Phase::Done;
286 return Some(Err(e));
287 }
288 };
289
290 self.metrics.record_api_call();
291
292 let boxed_stream: BoxedByteStream = Box::pin(response.bytes_stream());
293 self.phase = Phase::Streaming(Box::new(StreamingPhase {
294 stream: RecoverableStream::new(boxed_stream),
295 accumulated_usage: Usage::default(),
296 }));
297
298 None
299 }
300
301 async fn do_poll_stream(
302 &mut self,
303 stream: &mut RecoverableStream<BoxedByteStream>,
304 accumulated_usage: &mut Usage,
305 ) -> StreamPollResult {
306 let chunk_result = tokio::time::timeout(self.chunk_timeout, stream.next()).await;
307
308 match chunk_result {
309 Ok(Some(Ok(item))) => {
310 self.last_chunk_time = Instant::now();
311 self.handle_stream_item(item, accumulated_usage)
312 }
313 Ok(Some(Err(e))) => {
314 self.phase = Phase::Done;
315 StreamPollResult::Event(Err(e))
316 }
317 Ok(None) => StreamPollResult::StreamEnded,
318 Err(_) => {
319 self.phase = Phase::Done;
320 StreamPollResult::Event(Err(crate::Error::Stream(format!(
321 "Chunk timeout after {:?} (no data received)",
322 self.chunk_timeout
323 ))))
324 }
325 }
326 }
327
328 fn handle_stream_item(
329 &mut self,
330 item: StreamItem,
331 accumulated_usage: &mut Usage,
332 ) -> StreamPollResult {
333 match item {
334 StreamItem::Text(text) => {
335 self.final_text.push_str(&text);
336 StreamPollResult::Event(Ok(AgentEvent::Text(text)))
337 }
338 StreamItem::Thinking(thinking) => {
339 StreamPollResult::Event(Ok(AgentEvent::Thinking(thinking)))
340 }
341 StreamItem::Citation(_) => StreamPollResult::Continue,
342 StreamItem::ToolUseComplete(tool_use) => {
343 self.pending_tool_uses.push(tool_use);
344 StreamPollResult::Continue
345 }
346 StreamItem::Event(event) => self.handle_stream_event(event, accumulated_usage),
347 }
348 }
349
350 fn handle_stream_event(
351 &mut self,
352 event: StreamEvent,
353 accumulated_usage: &mut Usage,
354 ) -> StreamPollResult {
355 match event {
356 StreamEvent::MessageStart { message } => {
357 accumulated_usage.input_tokens = message.usage.input_tokens;
358 accumulated_usage.output_tokens = message.usage.output_tokens;
359 accumulated_usage.cache_creation_input_tokens =
360 message.usage.cache_creation_input_tokens;
361 accumulated_usage.cache_read_input_tokens = message.usage.cache_read_input_tokens;
362 StreamPollResult::Continue
363 }
364 StreamEvent::ContentBlockStart { .. } => StreamPollResult::Continue,
365 StreamEvent::ContentBlockDelta { .. } => StreamPollResult::Continue,
366 StreamEvent::ContentBlockStop { .. } => StreamPollResult::Continue,
367 StreamEvent::MessageDelta { usage, .. } => {
368 accumulated_usage.output_tokens = usage.output_tokens;
369 StreamPollResult::Continue
370 }
371 StreamEvent::MessageStop => StreamPollResult::StreamEnded,
372 StreamEvent::Ping => StreamPollResult::Continue,
373 StreamEvent::Error { error } => {
374 self.phase = Phase::Done;
375 StreamPollResult::Event(Err(crate::Error::Stream(error.message)))
376 }
377 }
378 }
379
380 async fn do_handle_stream_end(
381 &mut self,
382 accumulated_usage: Usage,
383 ) -> Option<crate::Result<AgentEvent>> {
384 self.total_usage.input_tokens += accumulated_usage.input_tokens;
385 self.total_usage.output_tokens += accumulated_usage.output_tokens;
386 self.metrics.add_usage(
387 accumulated_usage.input_tokens,
388 accumulated_usage.output_tokens,
389 );
390 self.metrics
391 .record_model_usage(&self.cfg.config.model.primary, &accumulated_usage);
392
393 let cost = self
394 .cfg
395 .budget_tracker
396 .record(&self.cfg.config.model.primary, &accumulated_usage);
397 self.metrics.add_cost(cost);
398
399 if let Some(ref tenant_budget) = self.cfg.tenant_budget {
400 tenant_budget.record(&self.cfg.config.model.primary, &accumulated_usage);
401 }
402
403 self.cfg
404 .tool_state
405 .with_session_mut(|session| {
406 session.update_usage(&accumulated_usage);
407
408 let mut content = Vec::new();
409 if !self.final_text.is_empty() {
410 content.push(ContentBlock::Text {
411 text: self.final_text.clone(),
412 citations: None,
413 cache_control: None,
414 });
415 }
416 for tool_use in &self.pending_tool_uses {
417 content.push(ContentBlock::ToolUse(tool_use.clone()));
418 }
419 if !content.is_empty() {
420 session.add_assistant_message(content, Some(accumulated_usage));
421 }
422 })
423 .await;
424
425 if self.pending_tool_uses.is_empty() {
426 self.phase = Phase::Done;
427 self.metrics.execution_time_ms = self.start_time.elapsed().as_millis() as u64;
428
429 let messages = self
430 .cfg
431 .tool_state
432 .with_session(|session| session.to_api_messages())
433 .await;
434
435 let structured_output = self.extract_structured_output(&self.final_text);
436 return Some(Ok(AgentEvent::Complete(Box::new(AgentResult {
437 text: self.final_text.clone(),
438 usage: self.total_usage,
439 tool_calls: self.metrics.tool_calls,
440 iterations: self.metrics.iterations,
441 stop_reason: StopReason::EndTurn,
442 state: AgentState::Completed,
443 metrics: self.metrics.clone(),
444 session_id: self.cfg.session_id.to_string(),
445 structured_output,
446 messages,
447 uuid: uuid::Uuid::new_v4().to_string(),
448 }))));
449 }
450
451 self.phase = Phase::ProcessingTools { tool_index: 0 };
452 None
453 }
454
455 async fn do_process_tool(&mut self, tool_index: usize) -> Option<crate::Result<AgentEvent>> {
456 if tool_index >= self.pending_tool_uses.len() {
457 if !self.pending_tool_results.is_empty() {
458 self.finalize_tool_results().await;
459 }
460 self.final_text.clear();
461 self.pending_tool_uses.clear();
462 self.phase = Phase::StartRequest;
463 return None;
464 }
465
466 let tool_use = self.pending_tool_uses[tool_index].clone();
467 self.execute_tool(tool_use, tool_index).await
468 }
469
470 async fn execute_tool(
471 &mut self,
472 tool_use: ToolUseBlock,
473 tool_index: usize,
474 ) -> Option<crate::Result<AgentEvent>> {
475 let pre_input = HookInput::pre_tool_use(
476 &*self.cfg.session_id,
477 &tool_use.name,
478 tool_use.input.clone(),
479 );
480 let pre_output = match self
481 .cfg
482 .hooks
483 .execute(HookEvent::PreToolUse, pre_input, &self.cfg.hook_context)
484 .await
485 {
486 Ok(output) => output,
487 Err(e) => {
488 warn!(tool = %tool_use.name, error = %e, "PreToolUse hook failed");
489 crate::hooks::HookOutput::allow()
490 }
491 };
492
493 if !pre_output.continue_execution {
494 let reason = pre_output
495 .stop_reason
496 .clone()
497 .unwrap_or_else(|| "Blocked by hook".into());
498 debug!(tool = %tool_use.name, "Tool blocked by hook");
499
500 self.pending_tool_results
501 .push(ToolResultBlock::error(&tool_use.id, reason.clone()));
502 self.metrics.record_permission_denial(
503 PermissionDenial::new(&tool_use.name, &tool_use.id, tool_use.input.clone())
504 .with_reason(reason.clone()),
505 );
506 self.phase = Phase::ProcessingTools {
507 tool_index: tool_index + 1,
508 };
509
510 return Some(Ok(AgentEvent::ToolBlocked {
511 id: tool_use.id,
512 name: tool_use.name,
513 reason,
514 }));
515 }
516
517 let actual_input = pre_output.updated_input.unwrap_or(tool_use.input.clone());
518
519 let start = Instant::now();
520 let result = self
521 .cfg
522 .tools
523 .execute(&tool_use.name, actual_input.clone())
524 .await;
525 let duration_ms = start.elapsed().as_millis() as u64;
526
527 let (output, is_error) = match &result.output {
528 crate::types::ToolOutput::Success(s) => (s.clone(), false),
529 crate::types::ToolOutput::SuccessBlocks(blocks) => {
530 let text = blocks
531 .iter()
532 .filter_map(|b| match b {
533 crate::types::ToolOutputBlock::Text { text } => Some(text.as_str()),
534 _ => None,
535 })
536 .collect::<Vec<_>>()
537 .join("\n");
538 (text, false)
539 }
540 crate::types::ToolOutput::Error(e) => (e.to_string(), true),
541 crate::types::ToolOutput::Empty => (String::new(), false),
542 };
543
544 self.metrics
545 .record_tool(&tool_use.id, &tool_use.name, duration_ms, is_error);
546
547 if let Some(ref inner_usage) = result.inner_usage {
548 self.cfg
549 .tool_state
550 .with_session_mut(|session| {
551 session.update_usage(inner_usage);
552 })
553 .await;
554 self.total_usage.input_tokens += inner_usage.input_tokens;
555 self.total_usage.output_tokens += inner_usage.output_tokens;
556 self.metrics
557 .add_usage(inner_usage.input_tokens, inner_usage.output_tokens);
558 let inner_model = result.inner_model.as_deref().unwrap_or("claude-haiku-4-5");
559 self.metrics.record_model_usage(inner_model, inner_usage);
560 let inner_cost = self.cfg.budget_tracker.record(inner_model, inner_usage);
561 self.metrics.add_cost(inner_cost);
562 }
563
564 if is_error {
565 let failure_input = HookInput::post_tool_use_failure(
566 &*self.cfg.session_id,
567 &tool_use.name,
568 result.error_message(),
569 );
570 if let Err(e) = self
571 .cfg
572 .hooks
573 .execute(
574 HookEvent::PostToolUseFailure,
575 failure_input,
576 &self.cfg.hook_context,
577 )
578 .await
579 {
580 warn!(tool = %tool_use.name, error = %e, "PostToolUseFailure hook failed");
581 }
582 } else {
583 let post_input = HookInput::post_tool_use(
584 &*self.cfg.session_id,
585 &tool_use.name,
586 result.output.clone(),
587 );
588 if let Err(e) = self
589 .cfg
590 .hooks
591 .execute(HookEvent::PostToolUse, post_input, &self.cfg.hook_context)
592 .await
593 {
594 warn!(tool = %tool_use.name, error = %e, "PostToolUse hook failed");
595 }
596 }
597
598 if let Some(file_path) = extract_file_path(&tool_use.name, &actual_input)
599 && let Some(ref orchestrator) = self.cfg.orchestrator
600 {
601 let orch = orchestrator.read().await;
602 let path = Path::new(&file_path);
603 if orch.has_matching_rules(path).await {
604 let dynamic_ctx = orch.build_dynamic_context(Some(path)).await;
605 if !dynamic_ctx.is_empty() {
606 self.dynamic_rules = dynamic_ctx;
607 }
608 }
609 }
610
611 self.pending_tool_results
612 .push(ToolResultBlock::from_tool_result(&tool_use.id, &result));
613 self.phase = Phase::ProcessingTools {
614 tool_index: tool_index + 1,
615 };
616
617 Some(Ok(AgentEvent::ToolComplete {
618 id: tool_use.id,
619 name: tool_use.name,
620 output,
621 is_error,
622 duration_ms,
623 }))
624 }
625
626 async fn finalize_tool_results(&mut self) {
627 let results = std::mem::take(&mut self.pending_tool_results);
628 let max_tokens = context_window::for_model(&self.cfg.config.model.primary);
629
630 self.cfg
631 .tool_state
632 .with_session_mut(|session| {
633 session.add_tool_results(results);
634 })
635 .await;
636
637 let should_compact = self
638 .cfg
639 .tool_state
640 .with_session(|session| {
641 self.cfg.config.execution.auto_compact
642 && session.should_compact(
643 max_tokens,
644 self.cfg.config.execution.compact_threshold,
645 self.cfg.config.execution.compact_keep_messages,
646 )
647 })
648 .await;
649
650 if should_compact {
651 let compact_result = self
652 .cfg
653 .tool_state
654 .compact(
655 &self.cfg.client,
656 self.cfg.config.execution.compact_keep_messages,
657 )
658 .await;
659
660 if let Ok(CompactResult::Compacted { .. }) = compact_result {
661 self.metrics.record_compaction();
662 let state_sections = collect_compaction_state(&self.cfg.tools).await;
663 if !state_sections.is_empty() {
664 self.cfg
665 .tool_state
666 .with_session_mut(|session| {
667 session.add_user_message(format!(
668 "<system-reminder>\n# State preserved after compaction\n\n{}\n</system-reminder>",
669 state_sections.join("\n\n")
670 ));
671 })
672 .await;
673 }
674 }
675 }
676 }
677}
678
#[cfg(test)]
mod tests {
    use super::*;

    /// The initial and terminal phases are distinguishable by pattern.
    #[test]
    fn test_phase_transitions() {
        let initial = Phase::StartRequest;
        let terminal = Phase::Done;
        assert!(matches!(initial, Phase::StartRequest));
        assert!(matches!(terminal, Phase::Done));
    }

    /// Each poll-result variant can be constructed and matched.
    #[test]
    fn test_stream_poll_result_variants() {
        let samples = [
            StreamPollResult::Event(Ok(AgentEvent::Text("test".into()))),
            StreamPollResult::Continue,
            StreamPollResult::StreamEnded,
        ];

        assert!(matches!(samples[0], StreamPollResult::Event(_)));
        assert!(matches!(samples[1], StreamPollResult::Continue));
        assert!(matches!(samples[2], StreamPollResult::StreamEnded));
    }
}