Skip to main content

ai_agent/services/
streaming.rs

1// Source: ~/claudecode/openclaudecode/src/services/api/claude.ts (streaming logic)
2// Source: ~/claudecode/openclaudecode/src/services/tools/StreamingToolExecutor.ts
3#![allow(dead_code)]
4
5use crate::error::AgentError;
6use crate::types::TokenUsage;
7use crate::types::*;
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::sync::atomic::{AtomicBool, Ordering};
11
12// ─── Streaming Constants (matching TypeScript) ───
13
/// How long, in milliseconds, a stream may stay silent before it is treated as
/// timed out (90 seconds).
pub const DEFAULT_STREAM_IDLE_TIMEOUT_MS: u64 = 90 * 1_000;
/// Idle time, in milliseconds, after which a warning is due (45 seconds, i.e.
/// half of the default idle timeout).
pub const DEFAULT_STREAM_IDLE_WARNING_MS: u64 = 45 * 1_000;
/// Gap between chunks, in milliseconds, that counts as a stall (30 seconds).
pub const STALL_THRESHOLD_MS: u64 = 30 * 1_000;
20
21// ─── Streaming Result (complete, matching TypeScript) ───
22
23/// Streaming result containing accumulated content, tool calls, and metadata.
24/// Matches TypeScript's partialMessage + newMessages + usage + cost accumulation.
25#[derive(Debug, Clone)]
26pub struct StreamingResult {
27    /// Accumulated text content from all content blocks
28    pub content: String,
29    /// Accumulated tool calls (completed tool_use blocks)
30    pub tool_calls: Vec<serde_json::Value>,
31    /// Token usage information
32    pub usage: TokenUsage,
33    /// API error type if any (e.g., "max_output_tokens", "prompt_too_long")
34    pub api_error: Option<String>,
35    /// Time to first token in milliseconds
36    pub ttft_ms: Option<u64>,
37    /// The stop_reason from message_delta (e.g., "end_turn", "tool_use", "max_tokens")
38    pub stop_reason: Option<String>,
39    /// Total cost in USD for this request
40    pub cost: f64,
41    /// Whether message_start event was received
42    pub message_started: bool,
43    /// Number of content blocks that were started
44    pub content_blocks_started: u32,
45    /// Number of content blocks that were completed
46    pub content_blocks_completed: u32,
47    /// Whether any tool_use blocks were completed
48    pub any_tool_use_completed: bool,
49    /// Research data from message_start (internal only, for ant userType)
50    pub research: Option<serde_json::Value>,
51}
52
53impl Default for StreamingResult {
54    fn default() -> Self {
55        Self {
56            content: String::new(),
57            tool_calls: Vec::new(),
58            usage: TokenUsage::default(),
59            api_error: None,
60            ttft_ms: None,
61            stop_reason: None,
62            cost: 0.0,
63            message_started: false,
64            content_blocks_started: 0,
65            content_blocks_completed: 0,
66            any_tool_use_completed: false,
67            research: None,
68        }
69    }
70}
71
72// ─── Stall Tracking ───
73
/// Aggregate statistics about stalls observed while streaming.
#[derive(Debug, Clone, Default)]
pub struct StallStats {
    /// How many stalls were detected.
    pub stall_count: u64,
    /// Combined duration of all stalls, in milliseconds.
    pub total_stall_time_ms: u64,
    /// Duration of each individual stall, in milliseconds.
    pub stall_durations: Vec<u64>,
}
84
85// ─── Stream Watchdog (idle timeout) ───
86
/// State of the stream idle-timeout watchdog.
/// Matches TypeScript's streamIdleTimer/streamIdleWarningTimer logic.
///
/// The struct only records configuration and whether/when the watchdog
/// fired; the timer that decides when to fire lives elsewhere.
#[derive(Debug, Clone)]
pub struct StreamWatchdog {
    /// Whether the watchdog is enabled
    pub enabled: bool,
    /// Idle timeout in milliseconds
    pub idle_timeout_ms: u64,
    /// Warning threshold in milliseconds
    pub warning_threshold_ms: u64,
    /// Whether the stream was aborted by the watchdog
    pub aborted: bool,
    /// When the watchdog fired, as milliseconds since the Unix epoch
    /// (`None` until it fires)
    pub watchdog_fired_at: Option<u128>,
}
101
102impl StreamWatchdog {
103    pub fn new(enabled: bool, idle_timeout_ms: u64) -> Self {
104        Self {
105            enabled,
106            idle_timeout_ms,
107            warning_threshold_ms: idle_timeout_ms / 2,
108            aborted: false,
109            watchdog_fired_at: None,
110        }
111    }
112
113    pub fn from_env() -> Self {
114        let enabled = std::env::var(crate::constants::env::ai_code::ENABLE_STREAM_WATCHDOG)
115            .map(|v| matches!(v.to_lowercase().as_str(), "1" | "true" | "yes" | "on"))
116            .unwrap_or(false);
117
118        let timeout_ms = std::env::var(crate::constants::env::ai_code::STREAM_IDLE_TIMEOUT_MS)
119            .ok()
120            .and_then(|s| s.parse::<u64>().ok())
121            .unwrap_or(DEFAULT_STREAM_IDLE_TIMEOUT_MS);
122
123        Self::new(enabled, timeout_ms)
124    }
125
126    /// Check if the watchdog has aborted the stream
127    pub fn is_aborted(&self) -> bool {
128        self.aborted
129    }
130
131    /// Get when the watchdog fired (for measuring abort propagation delay)
132    pub fn watchdog_fired_at(&self) -> Option<u128> {
133        self.watchdog_fired_at
134    }
135
136    /// Mark the watchdog as having fired (called by the actual timeout logic).
137    /// Returns the abort reason message.
138    pub fn fire(&mut self) -> String {
139        self.aborted = true;
140        self.watchdog_fired_at = Some(
141            std::time::SystemTime::now()
142                .duration_since(std::time::UNIX_EPOCH)
143                .unwrap_or_default()
144                .as_millis(),
145        );
146        format!(
147            "Stream idle timeout - no chunks received for {}ms",
148            self.idle_timeout_ms
149        )
150    }
151
152    /// Log warning when stream has been idle for half the timeout
153    pub fn warning_message(&self) -> String {
154        format!(
155            "Streaming idle warning: no chunks received for {}ms",
156            self.warning_threshold_ms
157        )
158    }
159}
160
161// ─── Non-Streaming Fallback Control ───
162
163/// Determines whether non-streaming fallback should be disabled.
164/// Matches TypeScript's disableFallback logic:
165/// - AI_CODE_DISABLE_NONSTREAMING_FALLBACK env var
166/// - GrowthBook feature flag 'tengu_disable_streaming_to_non_streaming_fallback'
167pub fn is_nonstreaming_fallback_disabled() -> bool {
168    // Check env var first
169    if std::env::var(crate::constants::env::ai_code::DISABLE_NONSTREAMING_FALLBACK)
170        .map(|v| matches!(v.to_lowercase().as_str(), "1" | "true" | "yes" | "on"))
171        .unwrap_or(false)
172    {
173        return true;
174    }
175
176    // Check GrowthBook feature flag
177    if let Ok(value) = std::env::var("AI_CODE_TENGU_DISABLE_STREAMING_FALLBACK") {
178        if matches!(value.to_lowercase().as_str(), "1" | "true" | "yes" | "on") {
179            return true;
180        }
181    }
182
183    false
184}
185
186// ─── Non-Streaming Fallback Timeout ───
187
188/// Get the timeout for non-streaming fallback in milliseconds.
189/// Matches TypeScript's getNonstreamingFallbackTimeoutMs().
190pub fn get_nonstreaming_fallback_timeout_ms() -> u64 {
191    // Check for explicit override
192    if let Ok(ms) = std::env::var(crate::constants::env::ai_code::API_TIMEOUT_MS) {
193        if let Ok(val) = ms.parse::<u64>() {
194            return val;
195        }
196    }
197
198    // Default: 120s for remote (bridge) mode, 300s for local
199    if std::env::var("AI_CODE_REMOTE").is_ok() {
200        120_000
201    } else {
202        300_000
203    }
204}
205
206// ─── Stream Resource Cleanup ───
207
/// Signals stream teardown by setting the abort flag, if one was supplied.
/// Matches TypeScript's releaseStreamResources() + cleanupStream().
///
/// Passing `None` is a no-op.
pub fn cleanup_stream(abort_handle: &Option<Arc<AtomicBool>>) {
    if let Some(flag) = abort_handle.as_deref() {
        flag.store(true, Ordering::SeqCst);
    }
}
215
216pub fn release_stream_resources(
217    abort_handle: &Option<Arc<AtomicBool>>,
218    _stream_response: &Option<reqwest::Response>,
219) {
220    cleanup_stream(abort_handle);
221    // reqwest::Response body will be dropped when the Option is set to None
222    // The Response object holds native TLS/socket buffers outside the heap,
223    // so we must explicitly cancel it (matching TypeScript's streamResponse.body?.cancel()).
224    if let Some(response) = _stream_response {
225        // Abort the underlying connection if possible
226        let _ = response.error_for_status_ref();
227    }
228}
229
230// ─── Stream Completion Validation ───
231
232/// Validates that a stream completed properly.
233/// Matches TypeScript's check:
234///   if (!partialMessage || (newMessages.length === 0 && !stopReason))
235///     throw new Error('Stream ended without receiving any events')
236pub fn validate_stream_completion(result: &StreamingResult) -> Result<(), AgentError> {
237    if !result.message_started {
238        return Err(AgentError::StreamEndedWithoutEvents);
239    }
240
241    // If message_start was received but no content blocks completed AND no stop_reason,
242    // the stream ended prematurely (proxy returned message_start but dropped connection)
243    if result.content_blocks_started > 0
244        && result.content_blocks_completed == 0
245        && result.stop_reason.is_none()
246    {
247        return Err(AgentError::StreamEndedWithoutEvents);
248    }
249
250    Ok(())
251}
252
253// ─── 404 Stream Creation Error Detection ───
254
255/// Check if an error is a 404 during stream creation that should trigger
256/// non-streaming fallback.
257/// Matches TypeScript's is404StreamCreationError check.
258pub fn is_404_stream_creation_error(error: &AgentError) -> bool {
259    let error_str = error.to_string();
260    error_str.contains("404")
261        && (error_str.contains("Not Found") || error_str.contains("streaming"))
262}
263
264// ─── Fallback Triggered Error ───
265
/// Raised after MAX_529_RETRIES consecutive 529 errors when a fallback model
/// is available. Matches TypeScript's FallbackTriggeredError.
#[derive(Debug, Clone)]
pub struct FallbackTriggeredError {
    /// Model that kept failing.
    pub original_model: String,
    /// Model to switch to.
    pub fallback_model: String,
}

impl std::fmt::Display for FallbackTriggeredError {
    /// Renders as `Model fallback triggered: <original> -> <fallback>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Model fallback triggered: ")?;
        f.write_str(&self.original_model)?;
        f.write_str(" -> ")?;
        f.write_str(&self.fallback_model)
    }
}

impl std::error::Error for FallbackTriggeredError {}
285
286/// Check if an AgentError wraps a FallbackTriggeredError.
287pub fn is_fallback_triggered_error(error: &AgentError) -> bool {
288    let msg = error.to_string();
289    msg.contains("Model fallback triggered")
290}
291
292/// Extract fallback info from an AgentError, if it's a FallbackTriggeredError.
293pub fn extract_fallback_error(error: &AgentError) -> Option<(String, String)> {
294    let msg = error.to_string();
295    const PREFIX: &str = "Model fallback triggered: ";
296    if msg.contains(PREFIX) {
297        // Parse "Model fallback triggered: original -> fallback"
298        if let Some(remainder) = msg.strip_prefix(PREFIX) {
299            if let Some(arrow_pos) = remainder.find(" -> ") {
300                let original = remainder[..arrow_pos].trim().to_string();
301                let fallback = remainder[arrow_pos + 4..].trim().to_string();
302                return Some((original, fallback));
303            }
304        }
305    }
306    None
307}
308
309// ─── 529 Error Detection ───
310
/// Number of consecutive 529 responses tolerated before model fallback is
/// triggered.
pub const MAX_529_RETRIES: u32 = 3;
313
314/// Check if an error is a 529 (server overload) error.
315/// Matches TypeScript's is529Error.
316pub fn is_529_error(error: &AgentError) -> bool {
317    let msg = error.to_string();
318    let lower = msg.to_lowercase();
319    lower.contains("529")
320        || lower.contains("overloaded")
321        || lower.contains(r#""type":"overloaded_error""#)
322}
323
324/// Check if an error is a stale connection (ECONNRESET/EPIPE) that should
325/// trigger HTTP client recreation. Matches TypeScript's isStaleConnectionError.
326pub fn is_stale_connection_error(error: &AgentError) -> bool {
327    let msg = error.to_string();
328    let lower = msg.to_lowercase();
329    lower.contains("econnreset") || lower.contains("epipe") || lower.contains("connection reset")
330}
331
332/// Check if an error is an authentication failure (401) that should trigger
333/// client recreation / token refresh.
334pub fn is_auth_error(error: &AgentError) -> bool {
335    match error {
336        AgentError::Auth(_) => true,
337        AgentError::Api(msg) => {
338            let s = msg.to_lowercase();
339            s.contains("401") || s.contains("unauthorized") || s.contains("api key")
340        }
341        AgentError::Http(http_err) => {
342            let status = http_err.status();
343            status == Some(reqwest::StatusCode::UNAUTHORIZED)
344        }
345        _ => false,
346    }
347}
348
349// ─── Max Tokens Context Overflow ───
350
351/// Parse a max-tokens-context-overflow API error (400).
352/// Matches TypeScript's parseMaxTokensContextOverflowError.
353/// Example: "input length and `max_tokens` exceed context limit: 188059 + 20000 > 200000"
354pub fn parse_max_tokens_context_overflow(error: &AgentError) -> Option<(u64, u64, u64)> {
355    let msg = error.to_string();
356    if !msg.contains("input length and `max_tokens` exceed context limit") {
357        return None;
358    }
359
360    // Find pattern: N + M > L
361    let regex = regex::Regex::new(r"(\d+)\s*\+\s*(\d+)\s*>\s*(\d+)").ok()?;
362    let caps = regex.captures(&msg)?;
363    let input_tokens: u64 = caps.get(1)?.as_str().parse().ok()?;
364    let max_tokens: u64 = caps.get(2)?.as_str().parse().ok()?;
365    let context_limit: u64 = caps.get(3)?.as_str().parse().ok()?;
366
367    Some((input_tokens, max_tokens, context_limit))
368}
369
/// Lower bound for output tokens when `max_tokens` is being adjusted.
pub const FLOOR_OUTPUT_TOKENS: u64 = 3_000;
372
// ─── Rate Limit, Abort and Timeout Detection ───
374
375/// Check if an error is a 429 rate limit error specifically (NOT 529).
376/// 529 is server overload, 429 is client rate limit. They have different
377/// retry semantics. This separates pure 429 from 529 which was being
378/// caught by a broad "429" or "529" check.
379/// Matches TypeScript's is429OnlyError.
380pub fn is_429_only_error(error: &AgentError) -> bool {
381    let msg = error.to_string();
382    let lower = msg.to_lowercase();
383    // Match 429 but explicitly exclude 529
384    (lower.contains("429") || lower.contains("rate_limit") || lower.contains("rate limit"))
385        && !lower.contains("529")
386}
387
388/// Check if an error is a user-initiated abort.
389/// Matches TypeScript's APIUserAbortError handling.
390pub fn is_user_abort_error(error: &AgentError) -> bool {
391    matches!(error, AgentError::UserAborted)
392}
393
394/// Check if an error is an API connection timeout.
395pub fn is_api_timeout_error(error: &AgentError) -> bool {
396    matches!(error, AgentError::ApiConnectionTimeout(_))
397}
398
399// ─── Cost Calculation ───
400
401/// Calculate cost based on token usage and model.
402/// Matches TypeScript's cost tracking in message_delta.
403pub fn calculate_streaming_cost(usage: &TokenUsage, model: &str) -> f64 {
404    use crate::services::model_cost::TokenUsage as ModelCostTokenUsage;
405
406    // Convert from types::TokenUsage to model_cost::TokenUsage
407    let model_usage = ModelCostTokenUsage {
408        input_tokens: usage.input_tokens as u32,
409        output_tokens: usage.output_tokens as u32,
410        prompt_cache_write_tokens: usage.cache_creation_input_tokens.unwrap_or(0) as u32,
411        prompt_cache_read_tokens: usage.cache_read_input_tokens.unwrap_or(0) as u32,
412    };
413
414    crate::services::model_cost::calculate_cost(model, &model_usage)
415}
416
417// ─── Streaming Tool Executor ───
418
419use futures_util::{FutureExt, StreamExt};
420use std::sync::Mutex;
421
422/// Thread-safe shared executor state.
423/// Uses a tokio async channel to route tool execution requests.
424struct SharedExecutorInner {
425    tx: tokio::sync::mpsc::Sender<(
426        String,
427        serde_json::Value,
428        String,
429        tokio::sync::mpsc::Sender<crate::types::ToolResult>,
430    )>,
431}
432
433/// A clonable, thread-safe tool executor function wrapper.
434/// Uses an async channel to dispatch tool execution.
435pub struct SharedExecutorFn {
436    inner: Arc<SharedExecutorInner>,
437}
438
439impl Clone for SharedExecutorFn {
440    fn clone(&self) -> Self {
441        Self {
442            inner: Arc::clone(&self.inner),
443        }
444    }
445}
446
447impl SharedExecutorFn {
448    /// Create a new executor and spawn the dispatcher task.
449    /// Returns the executor and the dispatcher join handle.
450    pub fn new<F, Fut>(executor: F) -> (Self, tokio::task::JoinHandle<()>)
451    where
452        F: Fn(String, serde_json::Value, String) -> Fut + Send + Sync + 'static,
453        Fut: std::future::Future<Output = crate::types::ToolResult> + Send + 'static,
454    {
455        let (tx, mut rx) = tokio::sync::mpsc::channel(256);
456        let inner = Arc::new(SharedExecutorInner { tx });
457        let handle = tokio::spawn(async move {
458            while let Some((name, args, tool_call_id, resp_tx)) = rx.recv().await {
459                let result = executor(name, args, tool_call_id).await;
460                let _ = resp_tx.send(result).await;
461            }
462        });
463        (Self { inner }, handle)
464    }
465
466    pub async fn call(
467        &self,
468        name: String,
469        args: serde_json::Value,
470        tool_call_id: String,
471    ) -> crate::types::ToolResult {
472        let (resp_tx, mut resp_rx) = tokio::sync::mpsc::channel(1);
473        self.inner
474            .tx
475            .send((name, args, tool_call_id, resp_tx))
476            .await
477            .expect("dispatcher disconnected");
478        resp_rx.recv().await.expect("dispatcher dropped response")
479    }
480}
481
/// Lifecycle stage of a tool tracked by the streaming executor.
#[derive(Debug, Clone, PartialEq)]
pub enum ToolStatus {
    /// Added to the queue, not yet started.
    Queued,
    /// Picked up for execution.
    Executing,
    /// Execution finished; result not yet handed out.
    Completed,
    /// Result has been handed out by `get_completed_results`.
    Yielded,
}
490
491/// A tool being tracked by the streaming executor.
492#[derive(Debug)]
493pub struct TrackedTool {
494    /// Unique tool ID
495    pub id: String,
496    /// The tool_use block from the API
497    pub block: serde_json::Value,
498    /// Whether this tool is concurrency-safe
499    pub is_concurrency_safe: bool,
500    /// Current status
501    pub status: ToolStatus,
502    /// Pending progress messages to be yielded
503    pub pending_progress: Vec<AgentEvent>,
504    /// Whether this tool has errored
505    pub has_errored: bool,
506    /// Context modifiers collected during tool execution (for contextModifier support)
507    pub context_modifiers: Vec<fn(crate::types::ToolContext) -> crate::types::ToolContext>,
508}
509
510/// Internal shared state for the streaming executor.
511struct ExecutorState {
512    tools: Vec<TrackedTool>,
513    discarded: bool,
514    has_errored: bool,
515    errored_tool_description: String,
516    parent_abort: Arc<AtomicBool>,
517    max_concurrency: usize,
518}
519
520/// Executes tools as they stream in with concurrency control.
521/// Rust port of TypeScript's StreamingToolExecutor class.
522/// - Concurrency-safe tools can execute in parallel
523/// - Non-concurrent tools must execute exclusively
524///
525/// Uses Arc<Mutex<ExecutorState>> for interior mutability so it can be shared
526/// via Arc and called from the SSE parsing loop, while spawned tasks can
527/// also access the state for marking completions.
528pub struct StreamingToolExecutor {
529    state: Arc<Mutex<ExecutorState>>,
530}
531
532impl StreamingToolExecutor {
533    pub fn new(parent_abort: Arc<AtomicBool>) -> Self {
534        Self {
535            state: Arc::new(Mutex::new(ExecutorState {
536                tools: Vec::new(),
537                discarded: false,
538                has_errored: false,
539                errored_tool_description: String::new(),
540                parent_abort,
541                max_concurrency: 4,
542            })),
543        }
544    }
545
546    fn clone_state(&self) -> Arc<Mutex<ExecutorState>> {
547        Arc::clone(&self.state)
548    }
549
550    /// Discard all pending and in-progress tools.
551    /// Called when streaming fallback occurs.
552    pub fn discard(&self) {
553        self.state
554            .lock()
555            .expect("StreamingToolExecutor mutex poisoned")
556            .discarded = true;
557    }
558
559    /// Add a tool to the execution queue.
560    pub fn add_tool(&self, tool_use_block: serde_json::Value, is_concurrency_safe: bool) {
561        let tool_id = tool_use_block
562            .get("id")
563            .and_then(|v| v.as_str())
564            .unwrap_or("")
565            .to_string();
566
567        let mut state = self
568            .state
569            .lock()
570            .expect("StreamingToolExecutor mutex poisoned");
571        state.tools.push(TrackedTool {
572            id: tool_id,
573            block: tool_use_block,
574            is_concurrency_safe,
575            status: ToolStatus::Queued,
576            pending_progress: Vec::new(),
577            has_errored: false,
578            context_modifiers: Vec::new(),
579        });
580    }
581
582    /// Check if a tool can execute based on current concurrency state.
583    fn can_execute_tool(&self, is_concurrency_safe: bool) -> bool {
584        let state = self
585            .state
586            .lock()
587            .expect("StreamingToolExecutor mutex poisoned");
588        let executing_safe: Vec<bool> = state
589            .tools
590            .iter()
591            .filter(|t| t.status == ToolStatus::Executing)
592            .map(|t| t.is_concurrency_safe)
593            .collect();
594        drop(state);
595
596        executing_safe.is_empty() || (is_concurrency_safe && executing_safe.iter().all(|s| *s))
597    }
598
599    /// Check abort reasons for a tool.
600    fn get_abort_reason_inner(&self) -> Option<&'static str> {
601        let state = self
602            .state
603            .lock()
604            .expect("StreamingToolExecutor mutex poisoned");
605        if state.discarded {
606            return Some("streaming_fallback");
607        }
608        if state.has_errored {
609            return Some("sibling_error");
610        }
611        if state.parent_abort.load(Ordering::SeqCst) {
612            return Some("user_interrupted");
613        }
614        None
615    }
616
617    /// Get the number of currently executing tools
618    fn executing_count(&self) -> usize {
619        let state = self
620            .state
621            .lock()
622            .expect("StreamingToolExecutor mutex poisoned");
623        state
624            .tools
625            .iter()
626            .filter(|t| t.status == ToolStatus::Executing)
627            .count()
628    }
629
630    /// Check if there are any unfinished tools
631    pub fn has_unfinished_tools(&self) -> bool {
632        let state = self
633            .state
634            .lock()
635            .expect("StreamingToolExecutor mutex poisoned");
636        state.tools.iter().any(|t| t.status != ToolStatus::Yielded)
637    }
638
639    /// Get completed results that haven't been yielded.
640    /// Stops on non-concurrency-safe executing tool (yielding order).
641    pub fn get_completed_results(&self) -> Vec<(String, serde_json::Value)> {
642        let mut state = self
643            .state
644            .lock()
645            .expect("StreamingToolExecutor mutex poisoned");
646        if state.discarded {
647            return Vec::new();
648        }
649
650        let mut results = Vec::new();
651
652        for tool in &mut state.tools {
653            tool.pending_progress.clear();
654
655            if tool.status == ToolStatus::Yielded {
656                continue;
657            }
658
659            if tool.status == ToolStatus::Completed {
660                tool.status = ToolStatus::Yielded;
661                results.push((tool.id.clone(), tool.block.clone()));
662            } else if tool.status == ToolStatus::Executing && !tool.is_concurrency_safe {
663                break;
664            }
665        }
666
667        results
668    }
669
670    /// Mark a tool as having errored (cascading error for sibling tools).
671    pub fn mark_tool_errored(&self, tool_id: &str, _description: &str) {
672        let mut state = self
673            .state
674            .lock()
675            .expect("StreamingToolExecutor mutex poisoned");
676        state.has_errored = true;
677
678        if let Some(tool) = state.tools.iter_mut().find(|t| t.id == tool_id) {
679            tool.has_errored = true;
680        }
681    }
682
683    /// Get the current state summary for debugging
684    pub fn summary(&self) -> String {
685        let state = self
686            .state
687            .lock()
688            .expect("StreamingToolExecutor mutex poisoned");
689        let queued = state
690            .tools
691            .iter()
692            .filter(|t| t.status == ToolStatus::Queued)
693            .count();
694        let executing = state
695            .tools
696            .iter()
697            .filter(|t| t.status == ToolStatus::Executing)
698            .count();
699        let completed = state
700            .tools
701            .iter()
702            .filter(|t| t.status == ToolStatus::Completed)
703            .count();
704        let yielded = state
705            .tools
706            .iter()
707            .filter(|t| t.status == ToolStatus::Yielded)
708            .count();
709        let discarded = state.discarded;
710        drop(state);
711        format!(
712            "StreamingToolExecutor: queued={}, executing={}, completed={}, yielded={}, discarded={}",
713            queued, executing, completed, yielded, discarded
714        )
715    }
716
717    /// Execute queued tools with concurrency control.
718    /// Spawns each tool as a task and waits for results respecting concurrency limits.
719    /// Returns list of (tool_id, result) pairs in execution order.
720    pub async fn execute_all(
721        &self,
722        executor_fn: SharedExecutorFn,
723    ) -> Vec<(String, Result<crate::types::ToolResult, crate::AgentError>)> {
724        // ── Synchronous phase: collect can-run tools and mark them executing ──
725        let (can_run, max_concurrency) = {
726            let state = self
727                .state
728                .lock()
729                .expect("StreamingToolExecutor mutex poisoned");
730
731            let mut can_run: Vec<(String, serde_json::Value, serde_json::Value, bool)> = Vec::new();
732
733            for tool in &state.tools {
734                if tool.status != ToolStatus::Queued {
735                    continue;
736                }
737                if tool.has_errored {
738                    continue;
739                }
740
741                let block = tool.block.clone();
742                let tool_id = tool.id.clone();
743
744                let blocked = state
745                    .tools
746                    .iter()
747                    .any(|t| t.status == ToolStatus::Executing && !t.is_concurrency_safe);
748                if blocked && !tool.is_concurrency_safe {
749                    continue;
750                }
751
752                let executing_in_state = state
753                    .tools
754                    .iter()
755                    .filter(|t| t.status == ToolStatus::Executing)
756                    .count();
757                if executing_in_state >= state.max_concurrency {
758                    continue;
759                }
760
761                let name = block
762                    .get("name")
763                    .and_then(|n| n.as_str())
764                    .unwrap_or("")
765                    .to_string();
766                let args = block
767                    .get("arguments")
768                    .cloned()
769                    .unwrap_or(serde_json::Value::Null);
770                can_run.push((tool_id, block, args, tool.is_concurrency_safe));
771            }
772
773            let max_concurrency = state.max_concurrency;
774            drop(state);
775
776            // Mark can-run tools as executing
777            {
778                let mut state = self
779                    .state
780                    .lock()
781                    .expect("StreamingToolExecutor mutex poisoned");
782                for (tool_id, _, _, _) in &can_run {
783                    if let Some(tool) = state.tools.iter_mut().find(|t| t.id == *tool_id) {
784                        tool.status = ToolStatus::Executing;
785                    }
786                }
787            }
788
789            (can_run, max_concurrency)
790        };
791
792        // ── Async phase: execute tools ──
793        let mut results: Vec<(String, Result<crate::types::ToolResult, crate::AgentError>)> =
794            Vec::with_capacity(can_run.len());
795
796        let state_arc = self.clone_state();
797        let total = can_run.len();
798
799        for chunk_start in (0..total).step_by(max_concurrency) {
800            let chunk_end = (chunk_start + max_concurrency).min(total);
801            let mut handles = Vec::new();
802
803            for (tool_id, block, args, _is_safe) in &can_run[chunk_start..chunk_end] {
804                let name = block
805                    .get("name")
806                    .and_then(|n| n.as_str())
807                    .unwrap_or("")
808                    .to_string();
809                let tid = tool_id.clone();
810                let args = args.clone();
811                let exec = executor_fn.clone();
812                let state_arc = Arc::clone(&state_arc);
813
814                let handle = tokio::spawn(async move {
815                    let tool_result = exec.call(name, args, tid.clone()).await;
816
817                    // Mark as completed
818                    {
819                        let mut st = state_arc
820                            .lock()
821                            .expect("StreamingToolExecutor mutex poisoned");
822                        if let Some(tool) = st.tools.iter_mut().find(|t| t.id == tid) {
823                            tool.status = ToolStatus::Completed;
824                        }
825                    }
826
827                    let result = Ok(tool_result);
828                    if result
829                        .as_ref()
830                        .map(|r| r.is_error == Some(true))
831                        .unwrap_or(false)
832                    {
833                        state_arc
834                            .lock()
835                            .expect("StreamingToolExecutor mutex poisoned")
836                            .has_errored = true;
837                    }
838
839                    (tid, result)
840                });
841                handles.push(handle);
842            }
843
844            // Collect results for this chunk
845            for handle in handles {
846                let (tool_id, result) = handle.await.unwrap_or_else(|e| {
847                    (
848                        "unknown".to_string(),
849                        Err(crate::AgentError::Tool(format!("Task panicked: {}", e))),
850                    )
851                });
852                results.push((tool_id, result));
853            }
854        }
855
856        results
857    }
858}
859
/// Unit tests for the streaming result, watchdog, error-classification and
/// tool-executor helpers defined in the parent module.
#[cfg(test)]
mod tests {
    use super::*;

    /// A default `StreamingResult` has nothing started or completed, no
    /// time-to-first-token, no stop reason, and zero cost.
    #[test]
    fn test_streaming_result_defaults() {
        let fresh = StreamingResult::default();

        // Progress flags and counters.
        assert!(!fresh.message_started);
        assert_eq!(fresh.content_blocks_started, 0);
        assert_eq!(fresh.content_blocks_completed, 0);
        assert!(!fresh.any_tool_use_completed);

        // Optional metadata and cost.
        assert!(fresh.ttft_ms.is_none());
        assert!(fresh.stop_reason.is_none());
        assert_eq!(fresh.cost, 0.0);
    }

    /// A newly constructed watchdog is neither aborted nor marked as fired.
    #[test]
    fn test_stream_watchdog_defaults() {
        let idle_guard = StreamWatchdog::new(false, DEFAULT_STREAM_IDLE_TIMEOUT_MS);

        assert!(!idle_guard.is_aborted());
        assert!(idle_guard.watchdog_fired_at().is_none());
    }

    /// Firing the watchdog flips it to aborted, records a fired-at value, and
    /// returns a reason that mentions the idle timeout.
    #[test]
    fn test_stream_watchdog_fire() {
        let mut idle_guard = StreamWatchdog::new(true, 90_000);
        assert!(!idle_guard.is_aborted());

        let fire_reason = idle_guard.fire();

        assert!(idle_guard.is_aborted());
        assert!(idle_guard.watchdog_fired_at().is_some());
        assert!(fire_reason.contains("idle timeout"));
    }

    /// The non-streaming fallback is expected to be reported as not disabled
    /// when nothing has turned it off.
    #[test]
    fn test_nonstreaming_fallback_disabled_default() {
        let fallback_disabled = is_nonstreaming_fallback_disabled();
        assert!(!fallback_disabled);
    }

    /// A started message with an opened block, but no completed block and no
    /// stop reason, must be rejected by validation.
    #[test]
    fn test_stream_completion_validation_started_but_not_completed() {
        let truncated = StreamingResult {
            message_started: true,
            content_blocks_started: 1,
            ..StreamingResult::default()
        };

        assert!(validate_stream_completion(&truncated).is_err());
    }

    /// A result whose message never started must be rejected by validation.
    #[test]
    fn test_stream_completion_validation_message_not_started() {
        let never_started = StreamingResult::default();

        assert!(validate_stream_completion(&never_started).is_err());
    }

    /// A started message whose single opened block was also completed passes
    /// validation.
    #[test]
    fn test_stream_completion_validation_valid() {
        let finished = StreamingResult {
            message_started: true,
            content_blocks_started: 1,
            content_blocks_completed: 1,
            ..StreamingResult::default()
        };

        assert!(validate_stream_completion(&finished).is_ok());
    }

    /// A started message with an opened block passes validation when a stop
    /// reason is present, even though no block was counted as completed.
    #[test]
    fn test_stream_completion_validation_with_stop_reason() {
        let stopped = StreamingResult {
            message_started: true,
            content_blocks_started: 1,
            stop_reason: Some(String::from("end_turn")),
            ..StreamingResult::default()
        };

        assert!(validate_stream_completion(&stopped).is_ok());
    }

    /// API errors whose message carries a streaming 404 are classified as
    /// stream-creation 404s; a 500 error is not.
    #[test]
    fn test_is_404_stream_creation_error() {
        let prefixed_404 = AgentError::Api(String::from("Streaming API error 404: Not Found"));
        let leading_404 = AgentError::Api(String::from("404 streaming endpoint not found"));
        let server_500 = AgentError::Api(String::from("API error: 500"));

        assert!(is_404_stream_creation_error(&prefixed_404));
        assert!(is_404_stream_creation_error(&leading_404));
        assert!(!is_404_stream_creation_error(&server_500));
    }

    /// Only `AgentError::UserAborted` counts as a user abort.
    #[test]
    fn test_is_user_abort_error() {
        let aborted = AgentError::UserAborted;
        let api_failure = AgentError::Api(String::from("timeout"));

        assert!(is_user_abort_error(&aborted));
        assert!(!is_user_abort_error(&api_failure));
    }

    /// `AgentError::ApiConnectionTimeout` counts as an API timeout; a plain
    /// `AgentError::Api` does not.
    #[test]
    fn test_is_api_timeout_error() {
        let timed_out = AgentError::ApiConnectionTimeout(String::from("Request timed out"));
        let unrelated = AgentError::Api(String::from("other"));

        assert!(is_api_timeout_error(&timed_out));
        assert!(!is_api_timeout_error(&unrelated));
    }

    /// Two tools added to the executor show up as queued in the summary and
    /// leave the executor with unfinished work.
    #[test]
    fn test_streaming_tool_executor_add_and_summary() {
        let executor = StreamingToolExecutor::new(Arc::new(AtomicBool::new(false)));

        let bash_call =
            serde_json::json!({"id": "tool_1", "name": "Bash", "input": {"command": "ls"}});
        let read_call =
            serde_json::json!({"id": "tool_2", "name": "Read", "input": {"file": "foo.txt"}});
        executor.add_tool(bash_call, true);
        executor.add_tool(read_call, false);

        assert!(executor.summary().contains("queued=2"));
        assert!(executor.has_unfinished_tools());
    }

    /// Checks the concurrency gate: anything may run on an idle executor, but
    /// once a concurrency-safe tool is executing only other concurrency-safe
    /// tools may join it.
    #[test]
    fn test_streaming_tool_executor_can_execute() {
        let executor = StreamingToolExecutor::new(Arc::new(AtomicBool::new(false)));

        // Idle executor: both kinds of tool are admitted.
        assert!(executor.can_execute_tool(true));
        assert!(executor.can_execute_tool(false));

        // Queue a concurrency-safe tool and force it into the executing state.
        // The lock guard is a temporary, so it is released at the end of the
        // statement before the assertions below take the lock again.
        executor.add_tool(serde_json::json!({"id": "tool_1", "name": "Bash"}), true);
        executor.state.lock().expect("mutex poisoned").tools[0].status = ToolStatus::Executing;

        // A second concurrency-safe tool may run alongside it...
        assert!(executor.can_execute_tool(true));
        // ...but a non-concurrency-safe tool has to wait.
        assert!(!executor.can_execute_tool(false));
    }

    /// After `discard`, the executor reports no completed results for a tool
    /// that was only queued.
    #[test]
    fn test_streaming_tool_executor_discard() {
        let mut executor = StreamingToolExecutor::new(Arc::new(AtomicBool::new(false)));

        executor.add_tool(serde_json::json!({"id": "tool_1", "name": "Bash"}), true);
        executor.discard();

        assert!(executor.get_completed_results().is_empty());
    }

    /// Default stall statistics start with no stalls and no accumulated time.
    #[test]
    fn test_stall_stats_default() {
        let baseline = StallStats::default();

        assert_eq!(baseline.stall_count, 0);
        assert_eq!(baseline.total_stall_time_ms, 0);
    }

    /// Releasing stream resources sets the supplied abort flag.
    #[test]
    fn test_release_stream_resources() {
        let abort_flag = Arc::new(AtomicBool::new(false));

        // Hand over a second handle to the same flag so it can be inspected
        // through `abort_flag` afterwards.
        release_stream_resources(&Some(Arc::clone(&abort_flag)), &None);

        assert!(abort_flag.load(Ordering::SeqCst));
    }
}