1#![allow(dead_code)]
4
5use crate::error::AgentError;
6use crate::types::TokenUsage;
7use crate::types::*;
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::sync::atomic::{AtomicBool, Ordering};
11
/// Idle timeout, in milliseconds, that [`StreamWatchdog::from_env`] falls back
/// to when the timeout environment variable is unset or is not a valid `u64`.
pub const DEFAULT_STREAM_IDLE_TIMEOUT_MS: u64 = 90_000;
/// Default idle-warning threshold in milliseconds; numerically half of
/// [`DEFAULT_STREAM_IDLE_TIMEOUT_MS`].
///
/// NOTE(review): not referenced in this part of the file — presumably read by
/// the streaming loop; confirm.
pub const DEFAULT_STREAM_IDLE_WARNING_MS: u64 = 45_000;
/// Stall threshold in milliseconds.
///
/// NOTE(review): presumably the gap between chunks that is counted as a stall
/// (see [`StallStats`]); not referenced in this part of the file — confirm.
pub const STALL_THRESHOLD_MS: u64 = 30_000;
20
/// Outcome and bookkeeping for a single streamed response.
///
/// [`validate_stream_completion`] reads `message_started`, the two
/// `content_blocks_*` counters and `stop_reason` to decide whether the stream
/// ended cleanly. The remaining fields are not read in this part of the file;
/// their descriptions are inferred from the field names and should be confirmed
/// against the code that populates them.
#[derive(Debug, Clone)]
pub struct StreamingResult {
    /// Text content (presumably the concatenated streamed text — confirm).
    pub content: String,
    /// Tool calls as raw JSON values (presumably one per tool-use block — confirm).
    pub tool_calls: Vec<serde_json::Value>,
    /// Token usage attached to this response.
    pub usage: TokenUsage,
    /// API error message, if one was recorded.
    pub api_error: Option<String>,
    /// Presumably time-to-first-token in milliseconds — confirm.
    pub ttft_ms: Option<u64>,
    /// Stop reason, if one was received. When this is `Some`, a stream whose
    /// content blocks started but never completed still passes validation.
    pub stop_reason: Option<String>,
    /// Cost of the request (NOTE(review): unit is not visible here — confirm).
    pub cost: f64,
    /// Whether the start of the message was observed; validation fails when
    /// this is `false`.
    pub message_started: bool,
    /// Number of content blocks that were started.
    pub content_blocks_started: u32,
    /// Number of content blocks that were completed.
    pub content_blocks_completed: u32,
    /// Presumably set once any tool-use block has finished — confirm.
    pub any_tool_use_completed: bool,
    /// Optional raw JSON payload (NOTE(review): purpose not visible here — confirm).
    pub research: Option<serde_json::Value>,
}
52
53impl Default for StreamingResult {
54 fn default() -> Self {
55 Self {
56 content: String::new(),
57 tool_calls: Vec::new(),
58 usage: TokenUsage::default(),
59 api_error: None,
60 ttft_ms: None,
61 stop_reason: None,
62 cost: 0.0,
63 message_started: false,
64 content_blocks_started: 0,
65 content_blocks_completed: 0,
66 any_tool_use_completed: false,
67 research: None,
68 }
69 }
70}
71
/// Counters describing stalls seen on a stream; all fields default to
/// zero/empty.
///
/// NOTE(review): nothing in this part of the file updates these fields, so the
/// per-field descriptions are inferred from the names — confirm against the
/// code that records stalls.
#[derive(Debug, Clone, Default)]
pub struct StallStats {
    /// Number of stalls recorded.
    pub stall_count: u64,
    /// Sum of all stall durations, in milliseconds.
    pub total_stall_time_ms: u64,
    /// Individual stall durations (presumably milliseconds, one entry per stall — confirm).
    pub stall_durations: Vec<u64>,
}
84
/// Idle-timeout bookkeeping for a stream.
///
/// This type only stores configuration and the "fired" state; it contains no
/// timer of its own. Whatever drives the stream is expected to call
/// [`StreamWatchdog::fire`] when the idle timeout elapses.
pub struct StreamWatchdog {
    /// Whether the watchdog is enabled (read from the environment by `from_env`).
    pub enabled: bool,
    /// Idle timeout in milliseconds; echoed in the message returned by `fire`.
    pub idle_timeout_ms: u64,
    /// Warning threshold in milliseconds; `new` sets it to half of
    /// `idle_timeout_ms`.
    pub warning_threshold_ms: u64,
    /// Set to `true` by `fire`.
    pub aborted: bool,
    /// Wall-clock time of the call to `fire`, in milliseconds since the Unix
    /// epoch; `None` until the watchdog fires.
    pub watchdog_fired_at: Option<u128>,
}
101
102impl StreamWatchdog {
103 pub fn new(enabled: bool, idle_timeout_ms: u64) -> Self {
104 Self {
105 enabled,
106 idle_timeout_ms,
107 warning_threshold_ms: idle_timeout_ms / 2,
108 aborted: false,
109 watchdog_fired_at: None,
110 }
111 }
112
113 pub fn from_env() -> Self {
114 let enabled = std::env::var(crate::constants::env::ai_code::ENABLE_STREAM_WATCHDOG)
115 .map(|v| matches!(v.to_lowercase().as_str(), "1" | "true" | "yes" | "on"))
116 .unwrap_or(false);
117
118 let timeout_ms = std::env::var(crate::constants::env::ai_code::STREAM_IDLE_TIMEOUT_MS)
119 .ok()
120 .and_then(|s| s.parse::<u64>().ok())
121 .unwrap_or(DEFAULT_STREAM_IDLE_TIMEOUT_MS);
122
123 Self::new(enabled, timeout_ms)
124 }
125
126 pub fn is_aborted(&self) -> bool {
128 self.aborted
129 }
130
131 pub fn watchdog_fired_at(&self) -> Option<u128> {
133 self.watchdog_fired_at
134 }
135
136 pub fn fire(&mut self) -> String {
139 self.aborted = true;
140 self.watchdog_fired_at = Some(
141 std::time::SystemTime::now()
142 .duration_since(std::time::UNIX_EPOCH)
143 .unwrap_or_default()
144 .as_millis(),
145 );
146 format!(
147 "Stream idle timeout - no chunks received for {}ms",
148 self.idle_timeout_ms
149 )
150 }
151
152 pub fn warning_message(&self) -> String {
154 format!(
155 "Streaming idle warning: no chunks received for {}ms",
156 self.warning_threshold_ms
157 )
158 }
159}
160
161pub fn is_nonstreaming_fallback_disabled() -> bool {
168 if std::env::var(crate::constants::env::ai_code::DISABLE_NONSTREAMING_FALLBACK)
170 .map(|v| matches!(v.to_lowercase().as_str(), "1" | "true" | "yes" | "on"))
171 .unwrap_or(false)
172 {
173 return true;
174 }
175
176 if let Ok(value) = std::env::var("AI_CODE_TENGU_DISABLE_STREAMING_FALLBACK") {
178 if matches!(value.to_lowercase().as_str(), "1" | "true" | "yes" | "on") {
179 return true;
180 }
181 }
182
183 false
184}
185
186pub fn get_nonstreaming_fallback_timeout_ms() -> u64 {
191 if let Ok(ms) = std::env::var(crate::constants::env::ai_code::API_TIMEOUT_MS) {
193 if let Ok(val) = ms.parse::<u64>() {
194 return val;
195 }
196 }
197
198 if std::env::var("AI_CODE_REMOTE").is_ok() {
200 120_000
201 } else {
202 300_000
203 }
204}
205
/// Raises the abort flag, if one was supplied.
///
/// The flag is stored with `SeqCst` ordering; passing `None` is a no-op.
pub fn cleanup_stream(abort_handle: &Option<Arc<AtomicBool>>) {
    if let Some(abort_flag) = abort_handle.as_deref() {
        abort_flag.store(true, Ordering::SeqCst);
    }
}
215
216pub fn release_stream_resources(
217 abort_handle: &Option<Arc<AtomicBool>>,
218 _stream_response: &Option<reqwest::Response>,
219) {
220 cleanup_stream(abort_handle);
221 if let Some(response) = _stream_response {
225 let _ = response.error_for_status_ref();
227 }
228}
229
230pub fn validate_stream_completion(result: &StreamingResult) -> Result<(), AgentError> {
237 if !result.message_started {
238 return Err(AgentError::StreamEndedWithoutEvents);
239 }
240
241 if result.content_blocks_started > 0
244 && result.content_blocks_completed == 0
245 && result.stop_reason.is_none()
246 {
247 return Err(AgentError::StreamEndedWithoutEvents);
248 }
249
250 Ok(())
251}
252
253pub fn is_404_stream_creation_error(error: &AgentError) -> bool {
259 let error_str = error.to_string();
260 error_str.contains("404")
261 && (error_str.contains("Not Found") || error_str.contains("streaming"))
262}
263
/// Error value announcing that a request was moved from one model to another.
///
/// Its `Display` output has the form
/// `Model fallback triggered: <original> -> <fallback>`.
#[derive(Debug, Clone)]
pub struct FallbackTriggeredError {
    /// Model that was requested originally.
    pub original_model: String,
    /// Model that was switched to.
    pub fallback_model: String,
}

impl std::fmt::Display for FallbackTriggeredError {
    /// Writes `Model fallback triggered: <original> -> <fallback>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            original_model,
            fallback_model,
        } = self;
        write!(
            f,
            "Model fallback triggered: {} -> {}",
            original_model, fallback_model
        )
    }
}

impl std::error::Error for FallbackTriggeredError {}
285
286pub fn is_fallback_triggered_error(error: &AgentError) -> bool {
288 let msg = error.to_string();
289 msg.contains("Model fallback triggered")
290}
291
292pub fn extract_fallback_error(error: &AgentError) -> Option<(String, String)> {
294 let msg = error.to_string();
295 const PREFIX: &str = "Model fallback triggered: ";
296 if msg.contains(PREFIX) {
297 if let Some(remainder) = msg.strip_prefix(PREFIX) {
299 if let Some(arrow_pos) = remainder.find(" -> ") {
300 let original = remainder[..arrow_pos].trim().to_string();
301 let fallback = remainder[arrow_pos + 4..].trim().to_string();
302 return Some((original, fallback));
303 }
304 }
305 }
306 None
307}
308
/// Retry budget for 529 (overloaded) responses.
///
/// NOTE(review): not referenced in this part of the file — presumably the
/// retry loop stops after this many consecutive 529s; confirm.
pub const MAX_529_RETRIES: u32 = 3;
313
314pub fn is_529_error(error: &AgentError) -> bool {
317 let msg = error.to_string();
318 let lower = msg.to_lowercase();
319 lower.contains("529")
320 || lower.contains("overloaded")
321 || lower.contains(r#""type":"overloaded_error""#)
322}
323
324pub fn is_stale_connection_error(error: &AgentError) -> bool {
327 let msg = error.to_string();
328 let lower = msg.to_lowercase();
329 lower.contains("econnreset") || lower.contains("epipe") || lower.contains("connection reset")
330}
331
332pub fn is_auth_error(error: &AgentError) -> bool {
335 match error {
336 AgentError::Auth(_) => true,
337 AgentError::Api(msg) => {
338 let s = msg.to_lowercase();
339 s.contains("401") || s.contains("unauthorized") || s.contains("api key")
340 }
341 AgentError::Http(http_err) => {
342 let status = http_err.status();
343 status == Some(reqwest::StatusCode::UNAUTHORIZED)
344 }
345 _ => false,
346 }
347}
348
349pub fn parse_max_tokens_context_overflow(error: &AgentError) -> Option<(u64, u64, u64)> {
355 let msg = error.to_string();
356 if !msg.contains("input length and `max_tokens` exceed context limit") {
357 return None;
358 }
359
360 let regex = regex::Regex::new(r"(\d+)\s*\+\s*(\d+)\s*>\s*(\d+)").ok()?;
362 let caps = regex.captures(&msg)?;
363 let input_tokens: u64 = caps.get(1)?.as_str().parse().ok()?;
364 let max_tokens: u64 = caps.get(2)?.as_str().parse().ok()?;
365 let context_limit: u64 = caps.get(3)?.as_str().parse().ok()?;
366
367 Some((input_tokens, max_tokens, context_limit))
368}
369
/// Lower bound on output tokens.
///
/// NOTE(review): not referenced in this part of the file — presumably the
/// minimum `max_tokens` kept when shrinking a request after a context
/// overflow; confirm.
pub const FLOOR_OUTPUT_TOKENS: u64 = 3000;
372
373pub fn is_429_only_error(error: &AgentError) -> bool {
381 let msg = error.to_string();
382 let lower = msg.to_lowercase();
383 (lower.contains("429") || lower.contains("rate_limit") || lower.contains("rate limit"))
385 && !lower.contains("529")
386}
387
/// Returns `true` when `error` is the [`AgentError::UserAborted`] variant.
pub fn is_user_abort_error(error: &AgentError) -> bool {
    matches!(error, AgentError::UserAborted)
}
393
/// Returns `true` when `error` is an [`AgentError::ApiConnectionTimeout`],
/// regardless of the message it carries.
pub fn is_api_timeout_error(error: &AgentError) -> bool {
    matches!(error, AgentError::ApiConnectionTimeout(_))
}
398
399pub fn calculate_streaming_cost(usage: &TokenUsage, model: &str) -> f64 {
404 use crate::services::model_cost::TokenUsage as ModelCostTokenUsage;
405
406 let model_usage = ModelCostTokenUsage {
408 input_tokens: usage.input_tokens as u32,
409 output_tokens: usage.output_tokens as u32,
410 prompt_cache_write_tokens: usage.cache_creation_input_tokens.unwrap_or(0) as u32,
411 prompt_cache_read_tokens: usage.cache_read_input_tokens.unwrap_or(0) as u32,
412 };
413
414 crate::services::model_cost::calculate_cost(model, &model_usage)
415}
416
417use futures_util::{FutureExt, StreamExt};
420use std::sync::Mutex;
421
/// Shared half of a [`SharedExecutorFn`]: the sending side of the request
/// channel consumed by the dispatcher task.
struct SharedExecutorInner {
    /// Request channel. Each message is
    /// `(tool name, arguments, tool call id, reply sender)`; the dispatcher
    /// sends the tool's result back on the reply sender.
    tx: tokio::sync::mpsc::Sender<(
        String,
        serde_json::Value,
        String,
        tokio::sync::mpsc::Sender<crate::types::ToolResult>,
    )>,
}
432
433pub struct SharedExecutorFn {
436 inner: Arc<SharedExecutorInner>,
437}
438
439impl Clone for SharedExecutorFn {
440 fn clone(&self) -> Self {
441 Self {
442 inner: Arc::clone(&self.inner),
443 }
444 }
445}
446
447impl SharedExecutorFn {
448 pub fn new<F, Fut>(executor: F) -> (Self, tokio::task::JoinHandle<()>)
451 where
452 F: Fn(String, serde_json::Value, String) -> Fut + Send + Sync + 'static,
453 Fut: std::future::Future<Output = crate::types::ToolResult> + Send + 'static,
454 {
455 let (tx, mut rx) = tokio::sync::mpsc::channel(256);
456 let inner = Arc::new(SharedExecutorInner { tx });
457 let handle = tokio::spawn(async move {
458 while let Some((name, args, tool_call_id, resp_tx)) = rx.recv().await {
459 let result = executor(name, args, tool_call_id).await;
460 let _ = resp_tx.send(result).await;
461 }
462 });
463 (Self { inner }, handle)
464 }
465
466 pub async fn call(
467 &self,
468 name: String,
469 args: serde_json::Value,
470 tool_call_id: String,
471 ) -> crate::types::ToolResult {
472 let (resp_tx, mut resp_rx) = tokio::sync::mpsc::channel(1);
473 self.inner
474 .tx
475 .send((name, args, tool_call_id, resp_tx))
476 .await
477 .expect("dispatcher disconnected");
478 resp_rx.recv().await.expect("dispatcher dropped response")
479 }
480}
481
/// Lifecycle stage of a tool tracked by [`StreamingToolExecutor`].
#[derive(Debug, Clone, PartialEq)]
pub enum ToolStatus {
    /// Registered via `add_tool` and not yet picked up for execution.
    Queued,
    /// Selected by `execute_all`; its call is in flight.
    Executing,
    /// The call has finished but the tool has not been handed out by
    /// `get_completed_results` yet.
    Completed,
    /// Already returned from `get_completed_results`.
    Yielded,
}
490
/// One tool-use block tracked by [`StreamingToolExecutor`], plus its
/// execution state.
#[derive(Debug)]
pub struct TrackedTool {
    /// Value of the block's `"id"` field, or an empty string when the field is
    /// missing or not a string.
    pub id: String,
    /// The raw tool-use block as received.
    pub block: serde_json::Value,
    /// Whether this tool may run while other tools are executing.
    pub is_concurrency_safe: bool,
    /// Current lifecycle stage.
    pub status: ToolStatus,
    /// Buffered progress events; cleared whenever completed results are
    /// collected. (Nothing in this part of the file pushes to it.)
    pub pending_progress: Vec<AgentEvent>,
    /// Set when the tool is marked as errored (see `mark_tool_errored`);
    /// queued tools with this flag are skipped by `execute_all`.
    pub has_errored: bool,
    /// Context transformations attached to this tool; created empty.
    /// NOTE(review): not applied anywhere in this part of the file — confirm
    /// where they are consumed.
    pub context_modifiers: Vec<fn(crate::types::ToolContext) -> crate::types::ToolContext>,
}
509
/// Mutable state of a [`StreamingToolExecutor`], kept behind a mutex.
struct ExecutorState {
    /// Tracked tools, in the order they were added.
    tools: Vec<TrackedTool>,
    /// Set by `discard`; once set, `get_completed_results` returns nothing.
    discarded: bool,
    /// Set when any tool has errored.
    has_errored: bool,
    /// Description associated with an errored tool; empty until one is recorded.
    errored_tool_description: String,
    /// Abort flag shared with the caller; checked when computing the abort reason.
    parent_abort: Arc<AtomicBool>,
    /// Cap on simultaneously executing tools; also the batch size used by
    /// `execute_all`. Initialised to 4.
    max_concurrency: usize,
}
519
/// Tracks tool-use blocks as they are added and runs the queued ones through
/// a [`SharedExecutorFn`].
///
/// All state lives behind an `Arc<Mutex<_>>` so that the tasks spawned by
/// `execute_all` can update tool status when their call finishes.
pub struct StreamingToolExecutor {
    /// Shared, lock-protected executor state.
    state: Arc<Mutex<ExecutorState>>,
}
531
532impl StreamingToolExecutor {
533 pub fn new(parent_abort: Arc<AtomicBool>) -> Self {
534 Self {
535 state: Arc::new(Mutex::new(ExecutorState {
536 tools: Vec::new(),
537 discarded: false,
538 has_errored: false,
539 errored_tool_description: String::new(),
540 parent_abort,
541 max_concurrency: 4,
542 })),
543 }
544 }
545
546 fn clone_state(&self) -> Arc<Mutex<ExecutorState>> {
547 Arc::clone(&self.state)
548 }
549
550 pub fn discard(&self) {
553 self.state
554 .lock()
555 .expect("StreamingToolExecutor mutex poisoned")
556 .discarded = true;
557 }
558
559 pub fn add_tool(&self, tool_use_block: serde_json::Value, is_concurrency_safe: bool) {
561 let tool_id = tool_use_block
562 .get("id")
563 .and_then(|v| v.as_str())
564 .unwrap_or("")
565 .to_string();
566
567 let mut state = self
568 .state
569 .lock()
570 .expect("StreamingToolExecutor mutex poisoned");
571 state.tools.push(TrackedTool {
572 id: tool_id,
573 block: tool_use_block,
574 is_concurrency_safe,
575 status: ToolStatus::Queued,
576 pending_progress: Vec::new(),
577 has_errored: false,
578 context_modifiers: Vec::new(),
579 });
580 }
581
582 fn can_execute_tool(&self, is_concurrency_safe: bool) -> bool {
584 let state = self
585 .state
586 .lock()
587 .expect("StreamingToolExecutor mutex poisoned");
588 let executing_safe: Vec<bool> = state
589 .tools
590 .iter()
591 .filter(|t| t.status == ToolStatus::Executing)
592 .map(|t| t.is_concurrency_safe)
593 .collect();
594 drop(state);
595
596 executing_safe.is_empty() || (is_concurrency_safe && executing_safe.iter().all(|s| *s))
597 }
598
599 fn get_abort_reason_inner(&self) -> Option<&'static str> {
601 let state = self
602 .state
603 .lock()
604 .expect("StreamingToolExecutor mutex poisoned");
605 if state.discarded {
606 return Some("streaming_fallback");
607 }
608 if state.has_errored {
609 return Some("sibling_error");
610 }
611 if state.parent_abort.load(Ordering::SeqCst) {
612 return Some("user_interrupted");
613 }
614 None
615 }
616
617 fn executing_count(&self) -> usize {
619 let state = self
620 .state
621 .lock()
622 .expect("StreamingToolExecutor mutex poisoned");
623 state
624 .tools
625 .iter()
626 .filter(|t| t.status == ToolStatus::Executing)
627 .count()
628 }
629
630 pub fn has_unfinished_tools(&self) -> bool {
632 let state = self
633 .state
634 .lock()
635 .expect("StreamingToolExecutor mutex poisoned");
636 state.tools.iter().any(|t| t.status != ToolStatus::Yielded)
637 }
638
639 pub fn get_completed_results(&self) -> Vec<(String, serde_json::Value)> {
642 let mut state = self
643 .state
644 .lock()
645 .expect("StreamingToolExecutor mutex poisoned");
646 if state.discarded {
647 return Vec::new();
648 }
649
650 let mut results = Vec::new();
651
652 for tool in &mut state.tools {
653 tool.pending_progress.clear();
654
655 if tool.status == ToolStatus::Yielded {
656 continue;
657 }
658
659 if tool.status == ToolStatus::Completed {
660 tool.status = ToolStatus::Yielded;
661 results.push((tool.id.clone(), tool.block.clone()));
662 } else if tool.status == ToolStatus::Executing && !tool.is_concurrency_safe {
663 break;
664 }
665 }
666
667 results
668 }
669
670 pub fn mark_tool_errored(&self, tool_id: &str, _description: &str) {
672 let mut state = self
673 .state
674 .lock()
675 .expect("StreamingToolExecutor mutex poisoned");
676 state.has_errored = true;
677
678 if let Some(tool) = state.tools.iter_mut().find(|t| t.id == tool_id) {
679 tool.has_errored = true;
680 }
681 }
682
683 pub fn summary(&self) -> String {
685 let state = self
686 .state
687 .lock()
688 .expect("StreamingToolExecutor mutex poisoned");
689 let queued = state
690 .tools
691 .iter()
692 .filter(|t| t.status == ToolStatus::Queued)
693 .count();
694 let executing = state
695 .tools
696 .iter()
697 .filter(|t| t.status == ToolStatus::Executing)
698 .count();
699 let completed = state
700 .tools
701 .iter()
702 .filter(|t| t.status == ToolStatus::Completed)
703 .count();
704 let yielded = state
705 .tools
706 .iter()
707 .filter(|t| t.status == ToolStatus::Yielded)
708 .count();
709 let discarded = state.discarded;
710 drop(state);
711 format!(
712 "StreamingToolExecutor: queued={}, executing={}, completed={}, yielded={}, discarded={}",
713 queued, executing, completed, yielded, discarded
714 )
715 }
716
717 pub async fn execute_all(
721 &self,
722 executor_fn: SharedExecutorFn,
723 ) -> Vec<(String, Result<crate::types::ToolResult, crate::AgentError>)> {
724 let (can_run, max_concurrency) = {
726 let state = self
727 .state
728 .lock()
729 .expect("StreamingToolExecutor mutex poisoned");
730
731 let mut can_run: Vec<(String, serde_json::Value, serde_json::Value, bool)> = Vec::new();
732
733 for tool in &state.tools {
734 if tool.status != ToolStatus::Queued {
735 continue;
736 }
737 if tool.has_errored {
738 continue;
739 }
740
741 let block = tool.block.clone();
742 let tool_id = tool.id.clone();
743
744 let blocked = state
745 .tools
746 .iter()
747 .any(|t| t.status == ToolStatus::Executing && !t.is_concurrency_safe);
748 if blocked && !tool.is_concurrency_safe {
749 continue;
750 }
751
752 let executing_in_state = state
753 .tools
754 .iter()
755 .filter(|t| t.status == ToolStatus::Executing)
756 .count();
757 if executing_in_state >= state.max_concurrency {
758 continue;
759 }
760
761 let name = block
762 .get("name")
763 .and_then(|n| n.as_str())
764 .unwrap_or("")
765 .to_string();
766 let args = block
767 .get("arguments")
768 .cloned()
769 .unwrap_or(serde_json::Value::Null);
770 can_run.push((tool_id, block, args, tool.is_concurrency_safe));
771 }
772
773 let max_concurrency = state.max_concurrency;
774 drop(state);
775
776 {
778 let mut state = self
779 .state
780 .lock()
781 .expect("StreamingToolExecutor mutex poisoned");
782 for (tool_id, _, _, _) in &can_run {
783 if let Some(tool) = state.tools.iter_mut().find(|t| t.id == *tool_id) {
784 tool.status = ToolStatus::Executing;
785 }
786 }
787 }
788
789 (can_run, max_concurrency)
790 };
791
792 let mut results: Vec<(String, Result<crate::types::ToolResult, crate::AgentError>)> =
794 Vec::with_capacity(can_run.len());
795
796 let state_arc = self.clone_state();
797 let total = can_run.len();
798
799 for chunk_start in (0..total).step_by(max_concurrency) {
800 let chunk_end = (chunk_start + max_concurrency).min(total);
801 let mut handles = Vec::new();
802
803 for (tool_id, block, args, _is_safe) in &can_run[chunk_start..chunk_end] {
804 let name = block
805 .get("name")
806 .and_then(|n| n.as_str())
807 .unwrap_or("")
808 .to_string();
809 let tid = tool_id.clone();
810 let args = args.clone();
811 let exec = executor_fn.clone();
812 let state_arc = Arc::clone(&state_arc);
813
814 let handle = tokio::spawn(async move {
815 let tool_result = exec.call(name, args, tid.clone()).await;
816
817 {
819 let mut st = state_arc
820 .lock()
821 .expect("StreamingToolExecutor mutex poisoned");
822 if let Some(tool) = st.tools.iter_mut().find(|t| t.id == tid) {
823 tool.status = ToolStatus::Completed;
824 }
825 }
826
827 let result = Ok(tool_result);
828 if result
829 .as_ref()
830 .map(|r| r.is_error == Some(true))
831 .unwrap_or(false)
832 {
833 state_arc
834 .lock()
835 .expect("StreamingToolExecutor mutex poisoned")
836 .has_errored = true;
837 }
838
839 (tid, result)
840 });
841 handles.push(handle);
842 }
843
844 for handle in handles {
846 let (tool_id, result) = handle.await.unwrap_or_else(|e| {
847 (
848 "unknown".to_string(),
849 Err(crate::AgentError::Tool(format!("Task panicked: {}", e))),
850 )
851 });
852 results.push((tool_id, result));
853 }
854 }
855
856 results
857 }
858}
859
#[cfg(test)]
mod tests {
    use super::*;

    /// A default result reports nothing started, nothing completed, no cost.
    #[test]
    fn test_streaming_result_defaults() {
        let result = StreamingResult::default();
        assert!(!result.message_started);
        assert_eq!(result.content_blocks_started, 0);
        assert_eq!(result.content_blocks_completed, 0);
        assert!(!result.any_tool_use_completed);
        assert!(result.ttft_ms.is_none());
        assert!(result.stop_reason.is_none());
        assert_eq!(result.cost, 0.0);
    }

    /// A freshly built watchdog has not fired.
    #[test]
    fn test_stream_watchdog_defaults() {
        let watchdog = StreamWatchdog::new(false, DEFAULT_STREAM_IDLE_TIMEOUT_MS);
        assert!(!watchdog.is_aborted());
        assert!(watchdog.watchdog_fired_at().is_none());
    }

    /// Firing sets the aborted flag, records a timestamp and names the reason.
    #[test]
    fn test_stream_watchdog_fire() {
        let mut watchdog = StreamWatchdog::new(true, 90_000);
        assert!(!watchdog.is_aborted());

        let reason = watchdog.fire();
        assert!(watchdog.is_aborted());
        assert!(watchdog.watchdog_fired_at().is_some());
        assert!(reason.contains("idle timeout"));
    }

    /// With neither disable variable set, the fallback stays enabled.
    /// This reads the real process environment, so it assumes the variables
    /// are not exported where the tests run.
    #[test]
    fn test_nonstreaming_fallback_disabled_default() {
        assert!(!is_nonstreaming_fallback_disabled());
    }

    /// A block that started but never completed, with no stop reason, is invalid.
    #[test]
    fn test_stream_completion_validation_started_but_not_completed() {
        let result = StreamingResult {
            message_started: true,
            content_blocks_started: 1,
            ..StreamingResult::default()
        };
        assert!(validate_stream_completion(&result).is_err());
    }

    /// A stream whose message never started is invalid.
    #[test]
    fn test_stream_completion_validation_message_not_started() {
        let result = StreamingResult::default();
        assert!(validate_stream_completion(&result).is_err());
    }

    /// A started message with a completed block is valid.
    #[test]
    fn test_stream_completion_validation_valid() {
        let result = StreamingResult {
            message_started: true,
            content_blocks_started: 1,
            content_blocks_completed: 1,
            ..StreamingResult::default()
        };
        assert!(validate_stream_completion(&result).is_ok());
    }

    /// A stop reason makes an unfinished block acceptable.
    #[test]
    fn test_stream_completion_validation_with_stop_reason() {
        let result = StreamingResult {
            message_started: true,
            content_blocks_started: 1,
            stop_reason: Some("end_turn".to_string()),
            ..StreamingResult::default()
        };
        assert!(validate_stream_completion(&result).is_ok());
    }

    #[test]
    fn test_is_404_stream_creation_error() {
        assert!(is_404_stream_creation_error(&AgentError::Api(
            "Streaming API error 404: Not Found".to_string()
        )));
        assert!(is_404_stream_creation_error(&AgentError::Api(
            "404 streaming endpoint not found".to_string()
        )));
        assert!(!is_404_stream_creation_error(&AgentError::Api(
            "API error: 500".to_string()
        )));
    }

    #[test]
    fn test_is_user_abort_error() {
        assert!(is_user_abort_error(&AgentError::UserAborted));
        assert!(!is_user_abort_error(&AgentError::Api(
            "timeout".to_string()
        )));
    }

    #[test]
    fn test_is_api_timeout_error() {
        assert!(is_api_timeout_error(&AgentError::ApiConnectionTimeout(
            "Request timed out".to_string()
        )));
        assert!(!is_api_timeout_error(&AgentError::Api("other".to_string())));
    }

    /// Added tools start out queued and count as unfinished.
    #[test]
    fn test_streaming_tool_executor_add_and_summary() {
        let abort = Arc::new(AtomicBool::new(false));
        let executor = StreamingToolExecutor::new(abort);

        executor.add_tool(
            serde_json::json!({"id": "tool_1", "name": "Bash", "input": {"command": "ls"}}),
            true,
        );
        executor.add_tool(
            serde_json::json!({"id": "tool_2", "name": "Read", "input": {"file": "foo.txt"}}),
            false,
        );

        let summary = executor.summary();
        assert!(summary.contains("queued=2"));
        assert!(executor.has_unfinished_tools());
    }

    /// Anything may start on an idle executor; only concurrency-safe tools may
    /// join a running concurrency-safe tool.
    #[test]
    fn test_streaming_tool_executor_can_execute() {
        let abort = Arc::new(AtomicBool::new(false));
        let executor = StreamingToolExecutor::new(abort);

        assert!(executor.can_execute_tool(true));
        assert!(executor.can_execute_tool(false));

        executor.add_tool(serde_json::json!({"id": "tool_1", "name": "Bash"}), true);
        {
            // Simulate the tool having been started.
            let mut state = executor.state.lock().expect("mutex poisoned");
            state.tools[0].status = ToolStatus::Executing;
        }

        assert!(executor.can_execute_tool(true));
        assert!(!executor.can_execute_tool(false));
    }

    /// A discarded executor hands out no results.
    #[test]
    fn test_streaming_tool_executor_discard() {
        let abort = Arc::new(AtomicBool::new(false));
        let executor = StreamingToolExecutor::new(abort);

        executor.add_tool(serde_json::json!({"id": "tool_1", "name": "Bash"}), true);
        executor.discard();

        let results = executor.get_completed_results();
        assert!(results.is_empty());
    }

    #[test]
    fn test_stall_stats_default() {
        let stats = StallStats::default();
        assert_eq!(stats.stall_count, 0);
        assert_eq!(stats.total_stall_time_ms, 0);
    }

    /// Releasing resources raises the abort flag even without a response.
    #[test]
    fn test_release_stream_resources() {
        let abort = Arc::new(AtomicBool::new(false));
        release_stream_resources(&Some(abort.clone()), &None);
        assert!(abort.load(Ordering::SeqCst));
    }
}