#![allow(dead_code)]
use crate::error::AgentError;
use crate::types::TokenUsage;
use crate::types::*;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
/// Default stream idle timeout in milliseconds.
///
/// `StreamWatchdog::from_env` falls back to this value when the timeout environment
/// variable is missing or is not a valid `u64`.
pub const DEFAULT_STREAM_IDLE_TIMEOUT_MS: u64 = 90_000;
/// Default idle-warning threshold in milliseconds (half of the default idle timeout).
///
/// NOTE(review): nothing in this section reads this constant; `StreamWatchdog::new`
/// derives its own threshold as `idle_timeout_ms / 2`.
pub const DEFAULT_STREAM_IDLE_WARNING_MS: u64 = 45_000;
/// Stall threshold in milliseconds.
///
/// NOTE(review): presumably the gap between chunks that counts as a stall (see
/// `StallStats`); no use is visible in this section, so confirm against callers.
pub const STALL_THRESHOLD_MS: u64 = 30_000;
/// Accumulated outcome of one streamed model response.
///
/// The code that fills this struct is outside this section. Field descriptions marked
/// "presumably" are inferred from the field names and should be confirmed against it.
///
/// `Default` is derived: every field starts at its type's own default (empty string or
/// vector, `None`, `0`, `0.0`, `false`), which is exactly what the previous hand-written
/// implementation produced.
#[derive(Debug, Clone, Default)]
pub struct StreamingResult {
    /// Presumably the text content collected from the stream.
    pub content: String,
    /// Presumably the tool-call blocks collected from the stream, as raw JSON.
    pub tool_calls: Vec<serde_json::Value>,
    /// Token usage; this is what `calculate_streaming_cost` takes as input.
    pub usage: TokenUsage,
    /// Presumably the API error text when the stream reported one.
    pub api_error: Option<String>,
    /// Presumably the time to first token, in milliseconds.
    pub ttft_ms: Option<u64>,
    /// Stop reason, if any. `validate_stream_completion` accepts a stream whose blocks
    /// never completed as long as this is set.
    pub stop_reason: Option<String>,
    /// Presumably the monetary cost of the request.
    pub cost: f64,
    /// Whether a message start was seen. `validate_stream_completion` rejects results
    /// where this is `false`.
    pub message_started: bool,
    /// Count of content blocks started; read by `validate_stream_completion`.
    pub content_blocks_started: u32,
    /// Count of content blocks completed; read by `validate_stream_completion`.
    pub content_blocks_completed: u32,
    /// Presumably set once at least one tool-use block finished streaming.
    pub any_tool_use_completed: bool,
    /// Opaque JSON payload. NOTE(review): its meaning is not visible in this section.
    pub research: Option<serde_json::Value>,
}
/// Counters describing stream stalls.
///
/// NOTE(review): nothing in this section writes these fields; the descriptions below are
/// inferred from the names (and from `STALL_THRESHOLD_MS`) and should be confirmed against
/// the code that populates them.
#[derive(Debug, Clone, Default)]
pub struct StallStats {
/// Presumably the number of stalls observed.
pub stall_count: u64,
/// Presumably the summed duration of all stalls, in milliseconds.
pub total_stall_time_ms: u64,
/// Presumably the duration of each individual stall, in milliseconds.
pub stall_durations: Vec<u64>,
}
/// State of the stream idle watchdog.
///
/// This type only records configuration and whether the watchdog has fired; no timer is
/// started by anything in this section.
pub struct StreamWatchdog {
/// Whether the watchdog is enabled. Stored by the constructors; not read in this section.
pub enabled: bool,
/// Idle timeout in milliseconds, reported in the message returned by `fire`.
pub idle_timeout_ms: u64,
/// Warning threshold in milliseconds; `new` sets it to half of `idle_timeout_ms`.
pub warning_threshold_ms: u64,
/// Set to `true` by `fire`.
pub aborted: bool,
/// Milliseconds since the UNIX epoch at which `fire` was last called, if ever.
pub watchdog_fired_at: Option<u128>,
}
impl StreamWatchdog {
    /// Builds a watchdog that has not fired yet.
    ///
    /// The warning threshold is always half of `idle_timeout_ms` (integer division).
    pub fn new(enabled: bool, idle_timeout_ms: u64) -> Self {
        let warning_threshold_ms = idle_timeout_ms / 2;
        Self {
            enabled,
            idle_timeout_ms,
            warning_threshold_ms,
            aborted: false,
            watchdog_fired_at: None,
        }
    }

    /// Builds a watchdog from the process environment.
    ///
    /// The watchdog is enabled only when the enable variable is set to `1`, `true`, `yes`
    /// or `on` (case-insensitive). The timeout comes from the timeout variable when it
    /// parses as a `u64`, otherwise from `DEFAULT_STREAM_IDLE_TIMEOUT_MS`.
    pub fn from_env() -> Self {
        let enabled = match std::env::var(crate::constants::env::ai_code::ENABLE_STREAM_WATCHDOG)
        {
            Ok(raw) => matches!(raw.to_lowercase().as_str(), "1" | "true" | "yes" | "on"),
            Err(_) => false,
        };
        let idle_timeout_ms =
            match std::env::var(crate::constants::env::ai_code::STREAM_IDLE_TIMEOUT_MS) {
                Ok(raw) => raw.parse::<u64>().unwrap_or(DEFAULT_STREAM_IDLE_TIMEOUT_MS),
                Err(_) => DEFAULT_STREAM_IDLE_TIMEOUT_MS,
            };
        Self::new(enabled, idle_timeout_ms)
    }

    /// Returns `true` once `fire` has been called.
    pub fn is_aborted(&self) -> bool {
        self.aborted
    }

    /// Returns the wall-clock time of the last `fire` call, in milliseconds since the UNIX
    /// epoch, or `None` if the watchdog never fired.
    pub fn watchdog_fired_at(&self) -> Option<u128> {
        self.watchdog_fired_at
    }

    /// Marks the watchdog as fired and returns the message describing the timeout.
    ///
    /// The timestamp is `0` if the system clock reports a time before the UNIX epoch.
    pub fn fire(&mut self) -> String {
        let fired_at_ms = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|since_epoch| since_epoch.as_millis())
            .unwrap_or(0);
        self.aborted = true;
        self.watchdog_fired_at = Some(fired_at_ms);
        format!(
            "Stream idle timeout - no chunks received for {}ms",
            self.idle_timeout_ms
        )
    }

    /// Returns the warning text for a stream that has been idle for the warning threshold.
    pub fn warning_message(&self) -> String {
        let threshold_ms = self.warning_threshold_ms;
        format!(
            "Streaming idle warning: no chunks received for {}ms",
            threshold_ms
        )
    }
}
/// Reports whether the non-streaming fallback has been disabled through the environment.
///
/// Two variables are consulted in order: the one named by
/// `crate::constants::env::ai_code::DISABLE_NONSTREAMING_FALLBACK`, then
/// `AI_CODE_TENGU_DISABLE_STREAMING_FALLBACK`. Either one disables the fallback when its
/// value, compared case-insensitively, is `1`, `true`, `yes` or `on`. The second variable
/// is only read when the first one did not already disable the fallback.
pub fn is_nonstreaming_fallback_disabled() -> bool {
    let is_truthy = |lookup: Result<String, std::env::VarError>| match lookup {
        Ok(raw) => matches!(raw.to_lowercase().as_str(), "1" | "true" | "yes" | "on"),
        Err(_) => false,
    };
    is_truthy(std::env::var(
        crate::constants::env::ai_code::DISABLE_NONSTREAMING_FALLBACK,
    )) || is_truthy(std::env::var("AI_CODE_TENGU_DISABLE_STREAMING_FALLBACK"))
}
/// Returns the timeout, in milliseconds, for the non-streaming fallback request.
///
/// An explicit value from the API timeout environment variable wins when it parses as a
/// `u64`. Otherwise the timeout is 120 000 ms when `AI_CODE_REMOTE` is set (to any valid
/// Unicode value) and 300 000 ms when it is not.
pub fn get_nonstreaming_fallback_timeout_ms() -> u64 {
    let configured = std::env::var(crate::constants::env::ai_code::API_TIMEOUT_MS)
        .ok()
        .and_then(|raw| raw.parse::<u64>().ok());
    match configured {
        Some(timeout_ms) => timeout_ms,
        None if std::env::var("AI_CODE_REMOTE").is_ok() => 120_000,
        None => 300_000,
    }
}
/// Signals an in-flight stream to stop by setting its abort flag to `true`.
///
/// Does nothing when no abort handle is supplied.
pub fn cleanup_stream(abort_handle: &Option<Arc<AtomicBool>>) {
    let flag: &AtomicBool = match abort_handle.as_deref() {
        Some(flag) => flag,
        None => return,
    };
    flag.store(true, Ordering::SeqCst);
}
/// Signals the stream to abort by delegating to `cleanup_stream`.
///
/// `_stream_response` is accepted for call-site compatibility but is not touched. A
/// borrowed `reqwest::Response` cannot be closed from here: its connection is released
/// when the owner drops it. The previous body called `error_for_status_ref()` and
/// discarded the result, which only inspected the status code and released nothing.
pub fn release_stream_resources(
    abort_handle: &Option<Arc<AtomicBool>>,
    _stream_response: &Option<reqwest::Response>,
) {
    cleanup_stream(abort_handle);
}
/// Checks that a finished stream actually produced a usable message.
///
/// # Errors
///
/// Returns `AgentError::StreamEndedWithoutEvents` when either
/// - no message start was recorded, or
/// - at least one content block was started, none completed, and no stop reason was
///   recorded.
pub fn validate_stream_completion(result: &StreamingResult) -> Result<(), AgentError> {
    let blocks_opened_but_never_closed = result.content_blocks_started > 0
        && result.content_blocks_completed == 0
        && result.stop_reason.is_none();

    if result.message_started && !blocks_opened_but_never_closed {
        Ok(())
    } else {
        Err(AgentError::StreamEndedWithoutEvents)
    }
}
/// Returns `true` when the rendered error mentions `404` together with either
/// `Not Found` or `streaming` (all matched case-sensitively as substrings).
pub fn is_404_stream_creation_error(error: &AgentError) -> bool {
    let rendered = error.to_string();
    if !rendered.contains("404") {
        return false;
    }
    ["Not Found", "streaming"]
        .iter()
        .any(|marker| rendered.contains(*marker))
}
/// Error value describing a switch from one model to a fallback model.
///
/// Its `Display` output has the form `Model fallback triggered: <original> -> <fallback>`,
/// which is the text `is_fallback_triggered_error` and `extract_fallback_error` look for.
#[derive(Debug, Clone)]
pub struct FallbackTriggeredError {
    /// Name of the model that was being used.
    pub original_model: String,
    /// Name of the model being switched to.
    pub fallback_model: String,
}

impl std::fmt::Display for FallbackTriggeredError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            original_model,
            fallback_model,
        } = self;
        f.write_fmt(format_args!(
            "Model fallback triggered: {} -> {}",
            original_model, fallback_model
        ))
    }
}

impl std::error::Error for FallbackTriggeredError {}
/// Returns `true` when the rendered error contains the text `Model fallback triggered`
/// anywhere in its message.
pub fn is_fallback_triggered_error(error: &AgentError) -> bool {
    let rendered = error.to_string();
    rendered.find("Model fallback triggered").is_some()
}
/// Extracts `(original_model, fallback_model)` from a fallback error message.
///
/// The message is expected to contain `Model fallback triggered: <original> -> <fallback>`.
/// The marker is located anywhere in the rendered error, so errors whose `Display` adds
/// leading text (for example a variant prefix) are parsed too. This keeps the function
/// consistent with `is_fallback_triggered_error`, which also matches by substring.
/// Previously the message had to *start* with the marker, so such errors were detected
/// but could not be parsed.
///
/// Returns `None` when the marker or the ` -> ` separator is missing. Both names are
/// trimmed of surrounding whitespace.
pub fn extract_fallback_error(error: &AgentError) -> Option<(String, String)> {
    const PREFIX: &str = "Model fallback triggered: ";
    const ARROW: &str = " -> ";

    let msg = error.to_string();
    // Both markers are ASCII, so the computed byte offsets always fall on char boundaries.
    let names_start = msg.find(PREFIX)? + PREFIX.len();
    let names = &msg[names_start..];
    let arrow_pos = names.find(ARROW)?;

    let original = names[..arrow_pos].trim().to_string();
    let fallback = names[arrow_pos + ARROW.len()..].trim().to_string();
    Some((original, fallback))
}
/// Retry limit associated with HTTP 529 (overloaded) responses.
///
/// NOTE(review): no retry loop in this section reads this constant; presumably it caps
/// retries of errors recognised by `is_529_error`. Confirm against callers.
pub const MAX_529_RETRIES: u32 = 3;
/// Returns `true` when the rendered error looks like an "overloaded" (529) response.
///
/// The message is lower-cased and searched for `529`, `overloaded`, or the JSON fragment
/// `"type":"overloaded_error"`. Matching is by plain substring.
pub fn is_529_error(error: &AgentError) -> bool {
    const MARKERS: [&str; 3] = ["529", "overloaded", r#""type":"overloaded_error""#];
    let lowered = error.to_string().to_lowercase();
    MARKERS.iter().any(|marker| lowered.contains(*marker))
}
/// Returns `true` when the rendered error, compared case-insensitively, mentions
/// `econnreset`, `epipe` or `connection reset`.
pub fn is_stale_connection_error(error: &AgentError) -> bool {
    const MARKERS: [&str; 3] = ["econnreset", "epipe", "connection reset"];
    let lowered = error.to_string().to_lowercase();
    MARKERS.iter().any(|marker| lowered.contains(*marker))
}
/// Returns `true` when the error indicates an authentication failure.
///
/// - `AgentError::Auth` always counts.
/// - `AgentError::Api` counts when its message, lower-cased, contains `401`,
///   `unauthorized` or `api key`.
/// - `AgentError::Http` counts when the wrapped error carries status 401 Unauthorized.
/// - Every other variant does not count.
pub fn is_auth_error(error: &AgentError) -> bool {
    match error {
        AgentError::Auth(_) => true,
        AgentError::Api(body) => {
            let lowered = body.to_lowercase();
            ["401", "unauthorized", "api key"]
                .iter()
                .any(|marker| lowered.contains(*marker))
        }
        AgentError::Http(http_err) => http_err
            .status()
            .map_or(false, |code| code == reqwest::StatusCode::UNAUTHORIZED),
        _ => false,
    }
}
/// Parses a "context limit exceeded" error into `(input_tokens, max_tokens, context_limit)`.
///
/// The rendered error must contain the text
/// ``input length and `max_tokens` exceed context limit`` and a fragment of the shape
/// `<input> + <max> > <limit>`. Returns `None` when either is missing or when any of the
/// three numbers does not fit in a `u64`.
pub fn parse_max_tokens_context_overflow(error: &AgentError) -> Option<(u64, u64, u64)> {
    let rendered = error.to_string();
    if !rendered.contains("input length and `max_tokens` exceed context limit") {
        return None;
    }

    let pattern = regex::Regex::new(r"(\d+)\s*\+\s*(\d+)\s*>\s*(\d+)").ok()?;
    let groups = pattern.captures(&rendered)?;
    // Reads capture group `index` as a number; `None` if it is absent or overflows u64.
    let number_at = |index: usize| -> Option<u64> { groups.get(index)?.as_str().parse().ok() };

    Some((number_at(1)?, number_at(2)?, number_at(3)?))
}
/// Lower bound on output tokens.
///
/// NOTE(review): not read anywhere in this section; presumably the minimum `max_tokens`
/// kept when shrinking a request after `parse_max_tokens_context_overflow` reports an
/// overflow. Confirm against callers.
pub const FLOOR_OUTPUT_TOKENS: u64 = 3000;
/// Returns `true` for rate-limit errors that are not also overload (529) errors.
///
/// The rendered error is lower-cased; it must contain `429`, `rate_limit` or `rate limit`
/// and must not contain `529`.
pub fn is_429_only_error(error: &AgentError) -> bool {
    let lowered = error.to_string().to_lowercase();
    if lowered.contains("529") {
        return false;
    }
    ["429", "rate_limit", "rate limit"]
        .iter()
        .any(|marker| lowered.contains(*marker))
}
/// Returns `true` only for the `AgentError::UserAborted` variant.
///
/// Unlike the string-based classifiers in this file, this checks the variant itself, so
/// the error's message text is irrelevant.
pub fn is_user_abort_error(error: &AgentError) -> bool {
matches!(error, AgentError::UserAborted)
}
/// Returns `true` only for the `AgentError::ApiConnectionTimeout` variant, whatever
/// payload it carries.
pub fn is_api_timeout_error(error: &AgentError) -> bool {
matches!(error, AgentError::ApiConnectionTimeout(_))
}
/// Computes the cost of a streamed response for `model`.
///
/// The usage counters are copied into the cost module's own `TokenUsage` type and passed
/// to `crate::services::model_cost::calculate_cost`. Missing cache counters count as zero.
///
/// NOTE(review): the counters are narrowed with `as u32`, which wraps silently if a
/// source value exceeds `u32::MAX`. The source field types are declared outside this
/// section.
pub fn calculate_streaming_cost(usage: &TokenUsage, model: &str) -> f64 {
    use crate::services::model_cost::TokenUsage as ModelCostTokenUsage;

    let cache_writes = usage.cache_creation_input_tokens.unwrap_or(0);
    let cache_reads = usage.cache_read_input_tokens.unwrap_or(0);
    let billable = ModelCostTokenUsage {
        input_tokens: usage.input_tokens as u32,
        output_tokens: usage.output_tokens as u32,
        prompt_cache_write_tokens: cache_writes as u32,
        prompt_cache_read_tokens: cache_reads as u32,
    };

    crate::services::model_cost::calculate_cost(model, &billable)
}
use futures_util::{FutureExt, StreamExt};
use std::sync::Mutex;
/// Shared half of a `SharedExecutorFn`: the sending side of the request channel.
struct SharedExecutorInner {
/// Request channel to the dispatcher task spawned by `SharedExecutorFn::new`.
///
/// Each message is `(tool name, arguments, tool call id, reply sender)`; the dispatcher
/// sends the tool's result back on the reply sender.
tx: tokio::sync::mpsc::Sender<(
String,
serde_json::Value,
String,
tokio::sync::mpsc::Sender<crate::types::ToolResult>,
)>,
}
/// Cloneable handle for invoking a tool executor that lives in a background task.
///
/// Cloning is cheap: every clone shares the same request channel through an `Arc`.
#[derive(Clone)]
pub struct SharedExecutorFn {
    /// Shared channel endpoint; see `SharedExecutorInner`.
    inner: Arc<SharedExecutorInner>,
}
impl SharedExecutorFn {
    /// Wraps `executor` in a background dispatcher task and returns a handle to it.
    ///
    /// The dispatcher receives requests from a bounded channel (capacity 256) and awaits
    /// each executor call before taking the next request, so calls made through one
    /// handle (and its clones) run one at a time. The task ends when every handle has
    /// been dropped and the channel closes.
    ///
    /// Returns the handle together with the dispatcher's `JoinHandle`. Must be called from
    /// within a Tokio runtime because it spawns a task.
    pub fn new<F, Fut>(executor: F) -> (Self, tokio::task::JoinHandle<()>)
    where
        F: Fn(String, serde_json::Value, String) -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = crate::types::ToolResult> + Send + 'static,
    {
        let (tx, mut requests) = tokio::sync::mpsc::channel(256);
        // Built before the task is spawned so the channel's message type is known there.
        let inner = Arc::new(SharedExecutorInner { tx });
        let dispatcher = tokio::spawn(async move {
            while let Some((name, args, tool_call_id, reply_to)) = requests.recv().await {
                let outcome = executor(name, args, tool_call_id).await;
                // The caller may have gone away; a failed reply is not an error here.
                let _ = reply_to.send(outcome).await;
            }
        });
        (Self { inner }, dispatcher)
    }

    /// Runs one tool call through the dispatcher and waits for its result.
    ///
    /// # Panics
    ///
    /// Panics if the dispatcher task is gone (the request cannot be sent) or if it drops
    /// the reply channel without answering.
    pub async fn call(
        &self,
        name: String,
        args: serde_json::Value,
        tool_call_id: String,
    ) -> crate::types::ToolResult {
        let (reply_to, mut reply) = tokio::sync::mpsc::channel(1);
        let request = (name, args, tool_call_id, reply_to);
        self.inner
            .tx
            .send(request)
            .await
            .expect("dispatcher disconnected");
        let answer = reply.recv().await;
        answer.expect("dispatcher dropped response")
    }
}
/// Lifecycle stage of a tool tracked by `StreamingToolExecutor`.
#[derive(Debug, Clone, PartialEq)]
pub enum ToolStatus {
/// Registered by `add_tool` and not started yet.
Queued,
/// Selected by `execute_all` and currently running.
Executing,
/// Execution finished; not yet handed out by `get_completed_results`.
Completed,
/// Already handed out by `get_completed_results`.
Yielded,
}
/// One tool call tracked by `StreamingToolExecutor`.
#[derive(Debug)]
pub struct TrackedTool {
/// Value of the block's `"id"` field, or an empty string when the block has none.
pub id: String,
/// The tool-use block as passed to `add_tool`.
pub block: serde_json::Value,
/// Whether the tool may run alongside other concurrency-safe tools.
pub is_concurrency_safe: bool,
/// Current lifecycle stage.
pub status: ToolStatus,
/// Buffered progress events; cleared by `get_completed_results`.
/// NOTE(review): nothing in this section pushes to this list.
pub pending_progress: Vec<AgentEvent>,
/// Set when this tool is marked as errored; queued tools with this flag are skipped by
/// `execute_all`.
pub has_errored: bool,
/// Functions that transform a `ToolContext`.
/// NOTE(review): never populated or applied in this section — confirm intended use.
pub context_modifiers: Vec<fn(crate::types::ToolContext) -> crate::types::ToolContext>,
}
/// Mutable state of a `StreamingToolExecutor`, guarded by its mutex.
struct ExecutorState {
/// Tracked tools in the order they were added.
tools: Vec<TrackedTool>,
/// Set by `discard`; once set, `get_completed_results` returns nothing.
discarded: bool,
/// Set when any tool has errored.
has_errored: bool,
/// Description of the tool that errored; initialised empty by
/// `StreamingToolExecutor::new`.
errored_tool_description: String,
/// Abort flag supplied by the caller of `StreamingToolExecutor::new`; when set, the
/// abort reason is reported as `user_interrupted`.
parent_abort: Arc<AtomicBool>,
/// Upper bound on tools run at the same time by `execute_all`; `new` sets it to 4.
max_concurrency: usize,
}
/// Tracks tool-use blocks and runs them through a `SharedExecutorFn`.
///
/// All state lives behind an `Arc<Mutex<..>>` so that tasks spawned by `execute_all` can
/// update tool status after the tool has run.
pub struct StreamingToolExecutor {
/// Shared, mutex-guarded executor state.
state: Arc<Mutex<ExecutorState>>,
}
impl StreamingToolExecutor {
    /// Creates an executor with no tracked tools and a concurrency limit of 4.
    ///
    /// `parent_abort` is the caller's abort flag; it is only read, never written, here.
    pub fn new(parent_abort: Arc<AtomicBool>) -> Self {
        Self {
            state: Arc::new(Mutex::new(ExecutorState {
                tools: Vec::new(),
                discarded: false,
                has_errored: false,
                errored_tool_description: String::new(),
                parent_abort,
                max_concurrency: 4,
            })),
        }
    }

    /// Locks `state`, panicking with a uniform message if the mutex is poisoned.
    fn lock(state: &Mutex<ExecutorState>) -> std::sync::MutexGuard<'_, ExecutorState> {
        state.lock().expect("StreamingToolExecutor mutex poisoned")
    }

    /// Locks this executor's state.
    fn lock_state(&self) -> std::sync::MutexGuard<'_, ExecutorState> {
        Self::lock(&self.state)
    }

    /// Marks the tool with `tool_id` as `Completed`; when `failed`, also records the error
    /// on the tool and on the executor as a whole. Unknown ids only affect the
    /// executor-wide flag.
    fn finish_tool(state: &Mutex<ExecutorState>, tool_id: &str, failed: bool) {
        let mut guard = Self::lock(state);
        if let Some(tool) = guard.tools.iter_mut().find(|t| t.id == tool_id) {
            tool.status = ToolStatus::Completed;
            if failed {
                tool.has_errored = true;
            }
        }
        if failed {
            guard.has_errored = true;
        }
    }

    /// Returns another handle to the shared state.
    fn clone_state(&self) -> Arc<Mutex<ExecutorState>> {
        Arc::clone(&self.state)
    }

    /// Marks the executor as discarded; `get_completed_results` returns nothing afterwards
    /// and the abort reason becomes `streaming_fallback`.
    pub fn discard(&self) {
        self.lock_state().discarded = true;
    }

    /// Registers a tool-use block as `Queued`.
    ///
    /// The tool id is taken from the block's `"id"` string field; a missing or non-string
    /// id becomes the empty string.
    pub fn add_tool(&self, tool_use_block: serde_json::Value, is_concurrency_safe: bool) {
        let tool_id = tool_use_block
            .get("id")
            .and_then(|v| v.as_str())
            .unwrap_or("")
            .to_string();
        self.lock_state().tools.push(TrackedTool {
            id: tool_id,
            block: tool_use_block,
            is_concurrency_safe,
            status: ToolStatus::Queued,
            pending_progress: Vec::new(),
            has_errored: false,
            context_modifiers: Vec::new(),
        });
    }

    /// Returns whether a tool with the given safety flag may start right now.
    ///
    /// That is the case when nothing is executing, or when the candidate is
    /// concurrency-safe and so is every executing tool.
    fn can_execute_tool(&self, is_concurrency_safe: bool) -> bool {
        let state = self.lock_state();
        // `all` over an empty set is true, which covers the "nothing executing" case.
        state
            .tools
            .iter()
            .filter(|t| t.status == ToolStatus::Executing)
            .all(|t| is_concurrency_safe && t.is_concurrency_safe)
    }

    /// Returns why execution should stop, if it should.
    ///
    /// Checked in priority order: `streaming_fallback` (discarded), `sibling_error`
    /// (some tool errored), `user_interrupted` (parent abort flag set).
    fn get_abort_reason_inner(&self) -> Option<&'static str> {
        let state = self.lock_state();
        if state.discarded {
            Some("streaming_fallback")
        } else if state.has_errored {
            Some("sibling_error")
        } else if state.parent_abort.load(Ordering::SeqCst) {
            Some("user_interrupted")
        } else {
            None
        }
    }

    /// Number of tools currently in the `Executing` state.
    fn executing_count(&self) -> usize {
        let state = self.lock_state();
        state
            .tools
            .iter()
            .filter(|t| t.status == ToolStatus::Executing)
            .count()
    }

    /// Returns `true` while any tracked tool has not been yielded yet.
    pub fn has_unfinished_tools(&self) -> bool {
        let state = self.lock_state();
        state.tools.iter().any(|t| t.status != ToolStatus::Yielded)
    }

    /// Hands out tools that finished since the last call, as `(tool id, tool-use block)`.
    ///
    /// Tools are visited in insertion order. Each `Completed` tool becomes `Yielded` and
    /// is returned. The scan stops at the first tool that is still executing and is not
    /// concurrency-safe, so later results are held back until it finishes. Pending
    /// progress is cleared on every tool visited. Returns an empty list once the executor
    /// has been discarded.
    pub fn get_completed_results(&self) -> Vec<(String, serde_json::Value)> {
        let mut state = self.lock_state();
        if state.discarded {
            return Vec::new();
        }
        let mut results = Vec::new();
        for tool in &mut state.tools {
            tool.pending_progress.clear();
            if tool.status == ToolStatus::Yielded {
                continue;
            }
            if tool.status == ToolStatus::Completed {
                tool.status = ToolStatus::Yielded;
                results.push((tool.id.clone(), tool.block.clone()));
            } else if tool.status == ToolStatus::Executing && !tool.is_concurrency_safe {
                break;
            }
        }
        results
    }

    /// Records that the tool with `tool_id` failed.
    ///
    /// Sets the executor-wide error flag, stores `description` as the errored tool's
    /// description (overwriting any earlier one), and flags the tool itself when the id
    /// is known.
    pub fn mark_tool_errored(&self, tool_id: &str, description: &str) {
        let mut state = self.lock_state();
        state.has_errored = true;
        state.errored_tool_description = description.to_string();
        if let Some(tool) = state.tools.iter_mut().find(|t| t.id == tool_id) {
            tool.has_errored = true;
        }
    }

    /// Returns a one-line count of tools per status plus the discarded flag.
    pub fn summary(&self) -> String {
        let state = self.lock_state();
        let (mut queued, mut executing, mut completed, mut yielded) =
            (0usize, 0usize, 0usize, 0usize);
        for tool in &state.tools {
            if tool.status == ToolStatus::Queued {
                queued += 1;
            } else if tool.status == ToolStatus::Executing {
                executing += 1;
            } else if tool.status == ToolStatus::Completed {
                completed += 1;
            } else {
                yielded += 1;
            }
        }
        let discarded = state.discarded;
        drop(state);
        format!(
            "StreamingToolExecutor: queued={}, executing={}, completed={}, yielded={}, discarded={}",
            queued, executing, completed, yielded, discarded
        )
    }

    /// Runs every eligible queued tool through `executor_fn` and returns their outcomes.
    ///
    /// A queued tool is eligible when it has not been flagged as errored, fewer than
    /// `max_concurrency` tools were already executing when this call started, and, if a
    /// non-concurrency-safe tool was already executing, the candidate is itself
    /// concurrency-safe. Eligible tools are marked `Executing` in the same critical
    /// section that selects them, so two overlapping calls cannot pick the same tool.
    ///
    /// Tools are spawned in batches of at most `max_concurrency`; each batch is awaited
    /// before the next starts. Results are returned in selection order as
    /// `(tool id, outcome)`.
    ///
    /// The tool name comes from the block's `"name"` field and the arguments from
    /// `"arguments"`, falling back to `"input"` and then to JSON `null`.
    ///
    /// When a tool's task fails to join (for example because the executor panicked), the
    /// outcome is `Err(AgentError::Tool(..))` under the tool's real id, and the tool is
    /// marked `Completed` and errored so it does not stay `Executing` forever.
    ///
    /// NOTE(review): batches mix concurrency-safe and unsafe tools. With the dispatcher in
    /// `SharedExecutorFn::new`, which awaits one call at a time, they still never overlap;
    /// revisit if that dispatcher becomes concurrent.
    pub async fn execute_all(
        &self,
        executor_fn: SharedExecutorFn,
    ) -> Vec<(String, Result<crate::types::ToolResult, crate::AgentError>)> {
        // Select runnable tools and mark them Executing under a single lock. The guard is
        // dropped at the end of this block, before any `.await`.
        let (selected, batch_size) = {
            let mut state = self.lock_state();
            let limit = state.max_concurrency;
            let already_executing = state
                .tools
                .iter()
                .filter(|t| t.status == ToolStatus::Executing)
                .count();
            let exclusive_running = state
                .tools
                .iter()
                .any(|t| t.status == ToolStatus::Executing && !t.is_concurrency_safe);

            let mut selected: Vec<(String, String, serde_json::Value)> = Vec::new();
            if already_executing < limit {
                for tool in state.tools.iter_mut() {
                    if tool.status != ToolStatus::Queued || tool.has_errored {
                        continue;
                    }
                    if exclusive_running && !tool.is_concurrency_safe {
                        continue;
                    }
                    let name = tool
                        .block
                        .get("name")
                        .and_then(|n| n.as_str())
                        .unwrap_or("")
                        .to_string();
                    // Tool-use blocks built in this file's tests carry their payload
                    // under "input", so accept it when "arguments" is absent.
                    let args = tool
                        .block
                        .get("arguments")
                        .or_else(|| tool.block.get("input"))
                        .cloned()
                        .unwrap_or(serde_json::Value::Null);
                    tool.status = ToolStatus::Executing;
                    selected.push((tool.id.clone(), name, args));
                }
            }
            // A zero limit selects nothing; clamp so batching can never divide by zero.
            (selected, limit.max(1))
        };

        let mut results: Vec<(String, Result<crate::types::ToolResult, crate::AgentError>)> =
            Vec::with_capacity(selected.len());
        let mut pending = selected.into_iter();
        loop {
            let batch: Vec<(String, String, serde_json::Value)> =
                pending.by_ref().take(batch_size).collect();
            if batch.is_empty() {
                break;
            }

            let mut handles = Vec::with_capacity(batch.len());
            for (tool_id, name, args) in batch {
                let exec = executor_fn.clone();
                let shared = self.clone_state();
                let task_tool_id = tool_id.clone();
                let handle = tokio::spawn(async move {
                    let tool_result = exec.call(name, args, task_tool_id.clone()).await;
                    let failed = tool_result.is_error == Some(true);
                    Self::finish_tool(&shared, &task_tool_id, failed);
                    tool_result
                });
                // Keep the id next to the handle so a failed join can still be attributed.
                handles.push((tool_id, handle));
            }

            for (tool_id, handle) in handles {
                let outcome = match handle.await {
                    Ok(tool_result) => Ok(tool_result),
                    Err(join_error) => {
                        Self::finish_tool(&self.state, &tool_id, true);
                        Err(crate::AgentError::Tool(format!(
                            "Task panicked: {}",
                            join_error
                        )))
                    }
                };
                results.push((tool_id, outcome));
            }
        }
        results
    }
}
/// Unit tests for the streaming helpers, error classifiers and the tool executor's
/// synchronous bookkeeping.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_streaming_result_defaults() {
        let result = StreamingResult::default();
        assert!(!result.message_started);
        assert_eq!(result.content_blocks_started, 0);
        assert_eq!(result.content_blocks_completed, 0);
        assert!(!result.any_tool_use_completed);
        assert!(result.ttft_ms.is_none());
        assert!(result.stop_reason.is_none());
        assert_eq!(result.cost, 0.0);
    }

    #[test]
    fn test_stream_watchdog_defaults() {
        let watchdog = StreamWatchdog::new(false, DEFAULT_STREAM_IDLE_TIMEOUT_MS);
        assert!(!watchdog.is_aborted());
        assert!(watchdog.watchdog_fired_at().is_none());
    }

    #[test]
    fn test_stream_watchdog_fire() {
        let mut watchdog = StreamWatchdog::new(true, 90_000);
        assert!(!watchdog.is_aborted());
        let reason = watchdog.fire();
        assert!(watchdog.is_aborted());
        assert!(watchdog.watchdog_fired_at().is_some());
        assert!(reason.contains("idle timeout"));
    }

    // Relies on neither "disable fallback" variable being set in the test environment.
    #[test]
    fn test_nonstreaming_fallback_disabled_default() {
        assert!(!is_nonstreaming_fallback_disabled());
    }

    #[test]
    fn test_stream_completion_validation_started_but_not_completed() {
        let result = StreamingResult {
            message_started: true,
            content_blocks_started: 1,
            ..StreamingResult::default()
        };
        assert!(validate_stream_completion(&result).is_err());
    }

    #[test]
    fn test_stream_completion_validation_message_not_started() {
        let result = StreamingResult::default();
        assert!(validate_stream_completion(&result).is_err());
    }

    #[test]
    fn test_stream_completion_validation_valid() {
        let result = StreamingResult {
            message_started: true,
            content_blocks_started: 1,
            content_blocks_completed: 1,
            ..StreamingResult::default()
        };
        assert!(validate_stream_completion(&result).is_ok());
    }

    #[test]
    fn test_stream_completion_validation_with_stop_reason() {
        let result = StreamingResult {
            message_started: true,
            content_blocks_started: 1,
            stop_reason: Some("end_turn".to_string()),
            ..StreamingResult::default()
        };
        assert!(validate_stream_completion(&result).is_ok());
    }

    #[test]
    fn test_is_404_stream_creation_error() {
        assert!(is_404_stream_creation_error(&AgentError::Api(
            "Streaming API error 404: Not Found".to_string()
        )));
        assert!(is_404_stream_creation_error(&AgentError::Api(
            "404 streaming endpoint not found".to_string()
        )));
        assert!(!is_404_stream_creation_error(&AgentError::Api(
            "API error: 500".to_string()
        )));
    }

    #[test]
    fn test_is_user_abort_error() {
        assert!(is_user_abort_error(&AgentError::UserAborted));
        assert!(!is_user_abort_error(&AgentError::Api(
            "timeout".to_string()
        )));
    }

    #[test]
    fn test_is_api_timeout_error() {
        assert!(is_api_timeout_error(&AgentError::ApiConnectionTimeout(
            "Request timed out".to_string()
        )));
        assert!(!is_api_timeout_error(&AgentError::Api("other".to_string())));
    }

    #[test]
    fn test_streaming_tool_executor_add_and_summary() {
        let abort = Arc::new(AtomicBool::new(false));
        let executor = StreamingToolExecutor::new(abort);
        executor.add_tool(
            serde_json::json!({"id": "tool_1", "name": "Bash", "input": {"command": "ls"}}),
            true,
        );
        executor.add_tool(
            serde_json::json!({"id": "tool_2", "name": "Read", "input": {"file": "foo.txt"}}),
            false,
        );
        let summary = executor.summary();
        assert!(summary.contains("queued=2"));
        assert!(executor.has_unfinished_tools());
    }

    #[test]
    fn test_streaming_tool_executor_can_execute() {
        let abort = Arc::new(AtomicBool::new(false));
        let executor = StreamingToolExecutor::new(abort);
        assert!(executor.can_execute_tool(true));
        assert!(executor.can_execute_tool(false));
        executor.add_tool(serde_json::json!({"id": "tool_1", "name": "Bash"}), true);
        {
            // Simulate a running concurrency-safe tool.
            let mut state = executor.state.lock().expect("mutex poisoned");
            state.tools[0].status = ToolStatus::Executing;
        }
        assert!(executor.can_execute_tool(true));
        assert!(!executor.can_execute_tool(false));
    }

    #[test]
    fn test_streaming_tool_executor_discard() {
        let abort = Arc::new(AtomicBool::new(false));
        // No `mut`: every executor method used here takes `&self`.
        let executor = StreamingToolExecutor::new(abort);
        executor.add_tool(serde_json::json!({"id": "tool_1", "name": "Bash"}), true);
        executor.discard();
        let results = executor.get_completed_results();
        assert!(results.is_empty());
    }

    #[test]
    fn test_stall_stats_default() {
        let stats = StallStats::default();
        assert_eq!(stats.stall_count, 0);
        assert_eq!(stats.total_stall_time_ms, 0);
    }

    #[test]
    fn test_release_stream_resources() {
        let abort = Arc::new(AtomicBool::new(false));
        release_stream_resources(&Some(abort.clone()), &None);
        assert!(abort.load(Ordering::SeqCst));
    }
}