use super::support::*;
use crate::LlmTerminalReason;
/// Per-provider state (options and serializable configuration) that can be held and
/// duplicated as a `Box<dyn ProviderState>` trait object.
pub trait ProviderState: Send + Sync + std::fmt::Debug {
/// Static identifier for this provider's kind.
/// NOTE(review): presumably a provider/type name — confirm against implementors.
fn kind(&self) -> &'static str;
/// Returns the provider's current options by value.
fn options(&self) -> ProviderOptions;
/// Sets the provider's options (expected to replace the current ones — confirm per implementor).
fn set_options(&mut self, options: ProviderOptions);
/// Serializes this provider's configuration as a JSON value.
fn serialize_config(&self) -> serde_json::Value;
/// Clones `self` into a new box. `Clone` cannot be a supertrait of a trait used as
/// `dyn ProviderState`, so this method provides the duplication instead.
fn clone_boxed(&self) -> Box<dyn ProviderState>;
}
/// Transport that sends an [`LlmRequest`] to a provider and returns its [`LlmResponse`].
///
/// Declared through `#[async_trait]` so that the async method can be used on
/// `Box<dyn ProviderTransport>` trait objects.
#[async_trait]
pub trait ProviderTransport: Send + Sync + std::fmt::Debug {
/// Performs one completion request. Takes `&mut self`, so implementations may update
/// internal state per call. Transport-level failures are reported as `LlmTransportError`.
async fn complete(&mut self, request: LlmRequest) -> Result<LlmResponse, LlmTransportError>;
/// Whether this transport must be driven in streaming mode. Defaults to `false`.
fn requires_streaming(&self) -> bool {
false
}
/// Clones `self` into a new box so that `Box<dyn ProviderTransport>` can be duplicated.
fn clone_boxed(&self) -> Box<dyn ProviderTransport>;
}
/// Provider-specific rules for choosing models, model variants and agent defaults.
///
/// Only the first five methods are mandatory. The remaining hooks default to identity
/// mappings, plus a `false` flag for cached-token accounting.
pub trait ProviderModelPolicy: Send + Sync + std::fmt::Debug {
    /// Model identifier used when the caller does not pick one.
    fn default_model(&self) -> &str;

    /// Variant names accepted for `model` (a static list; possibly empty).
    fn supported_variants(&self, model: &str) -> &'static [&'static str];

    /// Variant to use for `model` when none is requested, if the provider has one.
    fn default_model_variant(&self, model: &str) -> Option<&'static str>;

    /// Request adjustments for the given `model`/`variant` pair, if any apply.
    fn request_variant_config(&self, model: &str, variant: &str) -> Option<VariantRequestConfig>;

    /// Default model selection for an agent `tier`, if the provider defines one.
    fn default_agent_model(&self, tier: &str) -> Option<AgentModelSelection>;

    /// Maps a requested model name to the name sent to the provider.
    /// The default returns `model` unchanged.
    fn resolve_model(&self, model: &str) -> String {
        String::from(model)
    }

    /// Maps a model name to the key used for context-window lookups.
    /// The default returns `model` unchanged.
    fn context_lookup_model(&self, model: &str) -> String {
        model.to_owned()
    }

    /// Whether the provider's reported input-token usage already excludes cached tokens.
    /// The default is `false`.
    fn input_usage_excludes_cached_tokens(&self) -> bool {
        false
    }
}
/// Post-processes a [`ProviderFailure`] to fill in or correct its classification
/// (status, kind, retryability, terminal reason).
pub trait ProviderFailureClassifier: Send + Sync + std::fmt::Debug {
/// Takes the failure by value and returns the (possibly updated) failure.
fn classify(&self, failure: ProviderFailure) -> ProviderFailure;
}
/// Default [`ProviderFailureClassifier`].
///
/// Normalizes a failure using its HTTP status (explicit, or recovered from a status-like
/// numeric `code`) and keyword heuristics over the failure's code, message and raw body.
#[derive(Clone, Debug, Default)]
pub struct DefaultProviderFailureClassifier;

impl ProviderFailureClassifier for DefaultProviderFailureClassifier {
    /// Refines `status`, `kind`, `retryable` and `terminal_reason` of `failure`.
    ///
    /// The rules run in this order, and later rules override earlier ones:
    /// 1. Status: use `failure.status`. If it is absent, use a numeric `failure.code` that
    ///    lies in the HTTP status range `100..=599`.
    ///    - An `Unknown` kind becomes `Http`.
    ///    - `retryable` is set to `true` for 408/409/425/429/500/502/503/504 and `false` otherwise.
    ///    - 401/403 become `Auth`; 400/413/422 become `Validation`.
    /// 2. Without a status, `Transport` and `Timeout` failures are marked retryable.
    /// 3. Keyword rules over the lowercased `code`, `message` and `raw` text:
    ///    - context overflow → `Validation`, not retryable, `ContextOverflow` terminal reason;
    ///    - quota markers → `Quota`, not retryable;
    ///    - content-filter markers → `ContentFilter` terminal reason (kind/retryable untouched);
    ///    - unknown-model markers → `Unsupported`, not retryable.
    fn classify(&self, mut failure: ProviderFailure) -> ProviderFailure {
        // Substrings, matched against the lowercased haystack, for the keyword rules.
        const QUOTA_MARKERS: &[&str] = &[
            "insufficient_quota",
            "usage_limit_reached",
            "usage_not_included",
            "quota",
        ];
        const CONTENT_FILTER_MARKERS: &[&str] =
            &["content_filter", "prohibited_content", "safety", "sensitive"];
        const UNSUPPORTED_MODEL_MARKERS: &[&str] =
            &["model_not_found", "unsupported model", "does not exist"];

        // Returns true if `text` contains any of `needles`.
        fn contains_any(text: &str, needles: &[&str]) -> bool {
            needles.iter().any(|needle| text.contains(*needle))
        }

        // Prefer the explicit status. Otherwise fall back to a numeric `code`, but only when it
        // lies in the HTTP status range. Other numeric codes are provider-specific error codes;
        // they must not be recorded as an HTTP status or force `kind = Http`.
        let status = failure.status.or_else(|| {
            failure
                .code
                .as_deref()
                .and_then(|code| code.parse::<u16>().ok())
                .filter(|code| (100..=599).contains(code))
        });

        if let Some(status) = status {
            failure.status = Some(status);
            if failure.kind == ProviderFailureKind::Unknown {
                failure.kind = ProviderFailureKind::Http;
            }
            // Timeout, conflict, too-early, rate limiting and transient server errors.
            failure.retryable = matches!(status, 408 | 409 | 425 | 429 | 500 | 502 | 503 | 504);
            if matches!(status, 401 | 403) {
                failure.kind = ProviderFailureKind::Auth;
            } else if matches!(status, 400 | 413 | 422) {
                failure.kind = ProviderFailureKind::Validation;
            }
        } else if matches!(
            failure.kind,
            ProviderFailureKind::Transport | ProviderFailureKind::Timeout
        ) {
            // No HTTP status available: network-level failures are treated as transient.
            failure.retryable = true;
        }

        // The keyword rules below run after the status rules and may override them.
        let haystack = format!(
            "{}\n{}\n{}",
            failure.code.as_deref().unwrap_or_default(),
            failure.message,
            failure.raw.as_deref().unwrap_or_default()
        )
        .to_ascii_lowercase();

        if is_context_overflow_text(&haystack) {
            failure.kind = ProviderFailureKind::Validation;
            failure.retryable = false;
            failure.terminal_reason = LlmTerminalReason::ContextOverflow;
        }
        if contains_any(&haystack, QUOTA_MARKERS) {
            failure.kind = ProviderFailureKind::Quota;
            failure.retryable = false;
        }
        if contains_any(&haystack, CONTENT_FILTER_MARKERS) {
            // Only the terminal reason changes; kind and retryable come from earlier rules.
            failure.terminal_reason = LlmTerminalReason::ContentFilter;
        }
        if contains_any(&haystack, UNSUPPORTED_MODEL_MARKERS) {
            failure.kind = ProviderFailureKind::Unsupported;
            failure.retryable = false;
        }
        failure
    }
}
/// Returns `true` when `haystack` looks like a provider "context window exceeded" /
/// "prompt too long" error.
///
/// Matching is ASCII case-insensitive and substring-based. Text that mentions rate limiting,
/// throttling or quotas is rejected first: those errors also talk about tokens and limits,
/// but they are not context-overflow failures.
///
/// The abbreviation `tpm` (tokens per minute) only counts as a standalone word. Identifiers
/// such as `HttpMessage` in a raw response body therefore do not suppress detection.
pub fn is_context_overflow_text(haystack: &str) -> bool {
    // Phrases indicating rate limiting / quota exhaustion rather than context overflow.
    const RATE_LIMIT_MARKERS: &[&str] = &[
        "rate limit",
        "rate_limit",
        "ratelimit",
        "throttle",
        "throttling",
        "too many requests",
        "tokens per minute",
        "quota",
    ];
    // Known provider phrasings of a context-overflow error. Some entries are subsumed by
    // broader ones (e.g. `context_length` covers `context_length_exceeded`). They are kept
    // as a record of the exact messages providers emit.
    const CONTEXT_OVERFLOW_MARKERS: &[&str] = &[
        "context_length_exceeded",
        "context_length",
        "context length",
        "maximum context",
        "max context",
        "context window",
        "context window exceeds limit",
        "exceeds the context window",
        "prompt is too long",
        "prompt too long",
        "request_too_large",
        "maximum prompt length is",
        "reduce the length of the messages",
        "maximum context length is",
        "model's context length",
        "models context length",
        "exceeds the available context size",
        "greater than the context length",
        "exceeded model token limit",
        "too large for model with",
        "model_context_window_exceeded",
        "too many tokens",
        "exceeds the maximum number of tokens",
        "exceeds maximum number of tokens",
        "request too large",
        "input is too long",
        "token limit exceeded",
        "reduce the length of your prompt",
    ];

    // Returns true if `text` contains any of `needles`.
    fn contains_any(text: &str, needles: &[&str]) -> bool {
        needles.iter().any(|needle| text.contains(*needle))
    }

    // Returns true if `word` occurs in `text` with no ASCII alphanumeric byte directly before
    // or after it. Punctuation, whitespace and `_` count as boundaries.
    fn contains_word(text: &str, word: &str) -> bool {
        let bytes = text.as_bytes();
        let is_word_byte = |byte: Option<&u8>| matches!(byte, Some(b) if b.is_ascii_alphanumeric());
        text.match_indices(word).any(|(start, matched)| {
            let before = start.checked_sub(1).and_then(|i| bytes.get(i));
            let after = bytes.get(start + matched.len());
            !is_word_byte(before) && !is_word_byte(after)
        })
    }

    let lower = haystack.to_ascii_lowercase();
    if contains_any(&lower, RATE_LIMIT_MARKERS) || contains_word(&lower, "tpm") {
        return false;
    }
    contains_any(&lower, CONTEXT_OVERFLOW_MARKERS)
        || (lower.contains("input token count") && lower.contains("exceeds the maximum"))
}