use crate::types::*;
use async_trait::async_trait;
use tokio::sync::mpsc;
use super::model::ModelConfig;
/// An event sent over the `mpsc::UnboundedSender<StreamEvent>` channel that is
/// handed to [`StreamProvider::stream`].
///
/// NOTE(review): the variant notes below are inferred from variant and field
/// names only. Confirm the exact emission order against the provider impls.
#[derive(Debug, Clone)]
pub enum StreamEvent {
/// Presumably emitted once when the provider starts streaming.
Start,
/// A text `delta` for the content block at `content_index`.
TextDelta { content_index: usize, delta: String },
/// A thinking/reasoning `delta` for the content block at `content_index`.
ThinkingDelta { content_index: usize, delta: String },
/// Opens a tool-call content block at `content_index` with the call's `id` and tool `name`.
ToolCallStart {
content_index: usize,
id: String,
name: String,
},
/// A `delta` for the tool call at `content_index`. Presumably a fragment of the
/// serialized arguments; confirm against providers.
ToolCallDelta { content_index: usize, delta: String },
/// Closes the tool-call content block at `content_index`.
ToolCallEnd { content_index: usize },
/// Carries the completed `Message`. Presumably the last event of a successful stream.
Done { message: Message },
/// Carries a `Message` describing a failed stream. TODO confirm what the message contains.
Error { message: Message },
}
/// The full set of inputs for a single [`StreamProvider::stream`] call.
///
/// Field order is significant only for the derived `Debug` output.
#[derive(Debug, Clone)]
pub struct StreamConfig {
/// Model identifier sent to the provider.
pub model: String,
pub system_prompt: String,
/// Conversation history to send.
pub messages: Vec<Message>,
/// Tools made available for this request.
pub tools: Vec<ToolDefinition>,
pub thinking_level: ThinkingLevel,
pub api_key: String,
/// Optional output-token cap. NOTE(review): `None` presumably means "provider default"; confirm.
pub max_tokens: Option<u32>,
/// Optional sampling temperature. NOTE(review): `None` presumably means "provider default"; confirm.
pub temperature: Option<f32>,
/// Optional model metadata (`super::model::ModelConfig`).
pub model_config: Option<ModelConfig>,
pub cache_config: CacheConfig,
}
/// A tool listed in [`StreamConfig::tools`].
///
/// Serde serializes fields in declaration order, so reordering them changes the
/// emitted JSON key order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
pub name: String,
pub description: String,
/// Parameter description as raw JSON. NOTE(review): presumably a JSON Schema object; confirm.
pub parameters: serde_json::Value,
}
use serde::{Deserialize, Serialize};
/// A backend that can run one streaming request.
///
/// The `Send + Sync` bound allows implementations to be shared across tasks.
/// `#[async_trait]` is used so the trait can have an `async fn`.
#[async_trait]
pub trait StreamProvider: Send + Sync {
/// Runs one request described by `config`.
///
/// Implementations receive `tx` for [`StreamEvent`]s and a cancellation token.
/// They return the final `Message` or a [`ProviderError`].
/// NOTE(review): whether `Done`/`Error` events are always sent in addition to the
/// return value is implementation-defined; confirm in the providers.
async fn stream(
&self,
config: StreamConfig,
tx: mpsc::UnboundedSender<StreamEvent>,
cancel: tokio_util::sync::CancellationToken,
) -> Result<Message, ProviderError>;
}
#[derive(Debug, thiserror::Error)]
pub enum ProviderError {
#[error("API error: {0}")]
Api(String),
#[error("Network error: {0}")]
Network(String),
#[error("Auth error: {0}")]
Auth(String),
#[error("Rate limited, retry after {retry_after_ms:?}ms")]
RateLimited { retry_after_ms: Option<u64> },
#[error("Context overflow: {message}")]
ContextOverflow { message: String },
#[error("Cancelled")]
Cancelled,
#[error("{0}")]
Other(String),
}
impl ProviderError {
    /// Maps an HTTP `status` and error `message` to a [`ProviderError`] variant.
    ///
    /// Checks are applied in this order:
    /// 1. `429` → [`ProviderError::RateLimited`]. `retry_after_ms` is always
    ///    `None`, because only the status and message are available here.
    /// 2. Context overflow → [`ProviderError::ContextOverflow`]. Overflow means a
    ///    known overflow phrase appears in `message`, or a 400/413 arrives with a
    ///    blank message.
    /// 3. `401` / `403` → [`ProviderError::Auth`].
    /// 4. Anything else → [`ProviderError::Api`].
    pub fn classify(status: u16, message: &str) -> Self {
        // Rate limiting wins over phrase matching. Throttling responses can
        // mention tokens ("too many tokens", "token limit exceeded"), and those
        // must be retried rather than treated as an oversized prompt.
        if status == 429 {
            return Self::RateLimited {
                retry_after_ms: None,
            };
        }
        if is_context_overflow(status, message) {
            return Self::ContextOverflow {
                message: message.to_string(),
            };
        }
        match status {
            401 | 403 => Self::Auth(message.to_string()),
            _ => Self::Api(message.to_string()),
        }
    }

    /// Returns `true` if this is a [`ProviderError::ContextOverflow`].
    pub fn is_context_overflow(&self) -> bool {
        matches!(self, Self::ContextOverflow { .. })
    }
}
/// Converts a `reqwest_eventsource` error into a [`ProviderError`].
///
/// - `InvalidStatusCode`: the response body is read (best-effort; a failed read
///   counts as an empty body) and classified with [`ProviderError::classify`].
///   The resulting error carries the message `HTTP <code> <reason>: <body>`.
/// - `Transport`: mapped to [`ProviderError::Network`] using the error's `Debug`
///   representation.
/// - Anything else: mapped to [`ProviderError::Other`] with its `Display` text.
pub async fn classify_eventsource_error(error: reqwest_eventsource::Error) -> ProviderError {
    match error {
        reqwest_eventsource::Error::InvalidStatusCode(status, response) => {
            let status_code = status.as_u16();
            // Best-effort: an unreadable body is treated like an empty one.
            let body = response.text().await.unwrap_or_default();
            let message = format!(
                "HTTP {} {}: {}",
                status_code,
                status.canonical_reason().unwrap_or(""),
                body
            );
            // Classify on the raw body, not on `message`. The "HTTP ..." prefix
            // makes `message` non-empty, which would hide the empty-body 400/413
            // overflow signal that `classify` relies on. The formatted message is
            // attached afterwards.
            match ProviderError::classify(status_code, &body) {
                ProviderError::ContextOverflow { .. } => ProviderError::ContextOverflow { message },
                ProviderError::Auth(_) => ProviderError::Auth(message),
                ProviderError::Api(_) => ProviderError::Api(message),
                // Variants without a message slot (e.g. `RateLimited`) pass through.
                other => other,
            }
        }
        reqwest_eventsource::Error::Transport(e) => ProviderError::Network(format!("{:?}", e)),
        other => ProviderError::Other(other.to_string()),
    }
}
/// Classifies the text of an SSE `error` event, which arrives without an HTTP status.
///
/// Returns [`ProviderError::ContextOverflow`] when the text contains a known
/// overflow phrase (case-insensitive), and [`ProviderError::Api`] otherwise.
/// Either way the original text is kept as the message.
pub fn classify_sse_error_event(message: &str) -> ProviderError {
    if !is_context_overflow_message(message) {
        return ProviderError::Api(message.to_owned());
    }
    ProviderError::ContextOverflow {
        message: message.to_owned(),
    }
}
/// Substrings that identify a context-window overflow in a provider error message.
///
/// Matching lowercases the message but not these entries, so every entry must
/// itself be lowercase.
const OVERFLOW_PHRASES: &[&str] = &[
    // e.g. Anthropic
    "prompt is too long",
    // e.g. Bedrock
    "input is too long",
    // e.g. OpenAI
    "exceeds the context window",
    // e.g. Google
    "exceeds the maximum",
    // e.g. xAI
    "maximum prompt length",
    // e.g. Groq
    "reduce the length of the messages",
    "maximum context length",
    "exceeds the limit of",
    "exceeds the available context size",
    "greater than the context length",
    "context window exceeds limit",
    "exceeded model token limit",
    // Generic spellings.
    "context length exceeded",
    "context_length_exceeded",
    "too many tokens",
    "token limit exceeded",
];
/// Returns `true` if `message` contains any entry of `OVERFLOW_PHRASES`,
/// compared case-insensitively (the message is lowercased before searching).
pub(crate) fn is_context_overflow_message(message: &str) -> bool {
    let haystack = message.to_lowercase();
    for phrase in OVERFLOW_PHRASES {
        if haystack.contains(*phrase) {
            return true;
        }
    }
    false
}
/// Decides whether an HTTP error indicates a context-window overflow.
///
/// True when the status is 400 or 413 and the message is blank (whitespace
/// only), or when the message contains a known overflow phrase.
fn is_context_overflow(status: u16, message: &str) -> bool {
    // A bare 400/413 with no body counts as overflow on its own.
    // NOTE(review): presumably some providers reply this way to oversized
    // prompts; confirm which ones.
    let blank_oversize_reply = matches!(status, 400 | 413) && message.trim().is_empty();
    blank_oversize_reply || is_context_overflow_message(message)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn classify_anthropic_overflow() {
        let err =
            ProviderError::classify(400, "prompt is too long: 213462 tokens > 200000 maximum");
        assert!(err.is_context_overflow());
    }

    #[test]
    fn classify_openai_overflow() {
        let err =
            ProviderError::classify(400, "Your input exceeds the context window of this model");
        assert!(err.is_context_overflow());
    }

    #[test]
    fn classify_google_overflow() {
        let err = ProviderError::classify(
            400,
            "The input token count (1196265) exceeds the maximum number of tokens allowed",
        );
        assert!(err.is_context_overflow());
    }

    #[test]
    fn classify_bedrock_overflow() {
        let err = ProviderError::classify(400, "input is too long for requested model");
        assert!(err.is_context_overflow());
    }

    #[test]
    fn classify_xai_overflow() {
        let err = ProviderError::classify(
            400,
            "This model's maximum prompt length is 131072 but request contains 537812 tokens",
        );
        assert!(err.is_context_overflow());
    }

    #[test]
    fn classify_groq_overflow() {
        let err = ProviderError::classify(
            400,
            "Please reduce the length of the messages or completion",
        );
        assert!(err.is_context_overflow());
    }

    #[test]
    fn classify_empty_body_overflow() {
        let err = ProviderError::classify(413, "");
        assert!(err.is_context_overflow());
        let err = ProviderError::classify(400, " ");
        assert!(err.is_context_overflow());
    }

    #[test]
    fn classify_empty_body_on_other_status_is_not_overflow() {
        // The empty-body overflow rule only applies to 400 and 413.
        let err = ProviderError::classify(500, "");
        assert!(matches!(err, ProviderError::Api(_)));
        assert!(!err.is_context_overflow());
    }

    #[test]
    fn classify_rate_limit() {
        let err = ProviderError::classify(429, "rate limit exceeded");
        assert!(matches!(err, ProviderError::RateLimited { .. }));
    }

    #[test]
    fn classify_auth_error() {
        let err = ProviderError::classify(401, "invalid api key");
        assert!(matches!(err, ProviderError::Auth(_)));
        let err = ProviderError::classify(403, "forbidden");
        assert!(matches!(err, ProviderError::Auth(_)));
    }

    #[test]
    fn classify_regular_api_error() {
        let err = ProviderError::classify(400, "invalid request format");
        assert!(matches!(err, ProviderError::Api(_)));
        assert!(!err.is_context_overflow());
    }

    #[test]
    fn classify_server_error_is_api() {
        let err = ProviderError::classify(500, "internal server error");
        assert!(matches!(err, ProviderError::Api(_)));
    }

    #[test]
    fn sse_error_event_classification() {
        assert!(classify_sse_error_event("prompt is too long").is_context_overflow());
        assert!(classify_sse_error_event("CONTEXT_LENGTH_EXCEEDED").is_context_overflow());
        assert!(matches!(
            classify_sse_error_event("overloaded"),
            ProviderError::Api(_)
        ));
    }

    #[test]
    fn overflow_message_case_insensitive() {
        assert!(is_context_overflow_message("PROMPT IS TOO LONG"));
        assert!(is_context_overflow_message("Too Many Tokens in request"));
    }

    #[test]
    fn non_overflow_messages() {
        assert!(!is_context_overflow_message("invalid api key"));
        assert!(!is_context_overflow_message("internal server error"));
        assert!(!is_context_overflow_message(""));
    }

    #[test]
    fn overflow_phrases_are_lowercase() {
        // Matching lowercases only the message, so an uppercase phrase could never match.
        for phrase in OVERFLOW_PHRASES {
            assert_eq!(*phrase, phrase.to_lowercase(), "phrase must be lowercase: {}", phrase);
        }
    }
}