pub mod claude;
pub mod cooldown;
pub mod error_classifier;
pub mod fallback;
pub mod gemini;
pub mod openai;
pub mod plugin;
mod registry;
pub mod retry;
pub mod rotation;
pub mod structured;
mod types;
/// Provider name strings that this module lists as supported at runtime.
///
/// The entries are lowercase identifiers and their order is part of the
/// constant's value.
// NOTE(review): how callers consume this list (e.g. config validation or
// provider resolution) is not visible in this file — confirm before relying
// on ordering semantics.
pub const RUNTIME_SUPPORTED_PROVIDERS: &[&str] = &[
    "anthropic",
    "openai",
    "openrouter",
    "groq",
    "zhipu",
    "vllm",
    "gemini",
    "ollama",
    "nvidia",
];
use crate::error::ProviderError;
pub use claude::ClaudeProvider;
pub use cooldown::{CooldownTracker, FailoverReason};
pub use error_classifier::classify_error_message;
pub use fallback::FallbackProvider;
pub use gemini::GeminiProvider;
pub use openai::OpenAIProvider;
pub use plugin::ProviderPlugin;
pub use registry::{
configured_provider_names, configured_unsupported_provider_names, provider_config_by_name,
resolve_runtime_provider, resolve_runtime_providers, ProviderSpec, RuntimeProviderSelection,
PROVIDER_REGISTRY,
};
pub use retry::RetryProvider;
pub use rotation::{RotationProvider, RotationStrategy};
pub use structured::{validate_json_response, OutputFormat};
pub use types::{
ChatOptions, LLMProvider, LLMResponse, LLMToolCall, StreamEvent, ToolDefinition, Usage,
};
/// Converts an HTTP status code and response body into a [`ProviderError`].
///
/// Mapping:
/// - `401` → `Auth`, `402` → `Billing`, `404` → `ModelNotFound`,
///   `429` → `RateLimit`; each carries the body text unchanged.
/// - `400` → the classifier's result when it reports a `Format` error for the
///   body, otherwise `InvalidRequest` carrying the body.
/// - `500..=599` → the classifier's result when it reports `Overloaded` for
///   the body, otherwise `ServerError` carrying the body.
/// - Any other status → whatever the classifier returns for the message
///   `"HTTP {status}: {body}"`.
pub fn parse_provider_error(status: u16, body: &str) -> ProviderError {
    match status {
        401 => ProviderError::Auth(body.to_string()),
        402 => ProviderError::Billing(body.to_string()),
        404 => ProviderError::ModelNotFound(body.to_string()),
        429 => ProviderError::RateLimit(body.to_string()),
        // A 400 is normally a plain invalid request; only a classifier verdict
        // of `Format` is allowed to override that.
        400 => match error_classifier::classify_error_message(body) {
            format_error @ ProviderError::Format(_) => format_error,
            _ => ProviderError::InvalidRequest(body.to_string()),
        },
        // A 5xx is normally a generic server error; only a classifier verdict
        // of `Overloaded` is allowed to override that.
        500..=599 => match error_classifier::classify_error_message(body) {
            overloaded @ ProviderError::Overloaded(_) => overloaded,
            _ => ProviderError::ServerError(body.to_string()),
        },
        // No dedicated mapping: hand the classifier a message that includes
        // the status code and let it decide.
        other => {
            error_classifier::classify_error_message(&format!("HTTP {}: {}", other, body))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// 401 maps to `Auth` and reports status code 401.
    #[test]
    fn test_parse_provider_error_401() {
        let parsed = parse_provider_error(401, "invalid api key");
        assert!(matches!(parsed, ProviderError::Auth(_)));
        assert_eq!(parsed.status_code(), Some(401));
    }

    /// 402 maps to `Billing` and reports status code 402.
    #[test]
    fn test_parse_provider_error_402() {
        let parsed = parse_provider_error(402, "payment required");
        assert!(matches!(parsed, ProviderError::Billing(_)));
        assert_eq!(parsed.status_code(), Some(402));
    }

    /// 404 maps to `ModelNotFound` and reports status code 404.
    #[test]
    fn test_parse_provider_error_404() {
        let parsed = parse_provider_error(404, "model not found");
        assert!(matches!(parsed, ProviderError::ModelNotFound(_)));
        assert_eq!(parsed.status_code(), Some(404));
    }

    /// 429 maps to `RateLimit`, which is retryable.
    #[test]
    fn test_parse_provider_error_429() {
        let parsed = parse_provider_error(429, "rate limited");
        assert!(matches!(parsed, ProviderError::RateLimit(_)));
        assert!(parsed.is_retryable());
    }

    /// A 400 whose body is not classified as a format error maps to
    /// `InvalidRequest`, which is not retryable.
    #[test]
    fn test_parse_provider_error_400() {
        let parsed = parse_provider_error(400, "bad json");
        assert!(matches!(parsed, ProviderError::InvalidRequest(_)));
        assert!(!parsed.is_retryable());
    }

    /// 500 with a generic body maps to a retryable `ServerError`.
    #[test]
    fn test_parse_provider_error_500() {
        let parsed = parse_provider_error(500, "internal server error");
        assert!(matches!(parsed, ProviderError::ServerError(_)));
        assert!(parsed.is_retryable());
    }

    /// 502 with a generic body maps to a retryable `ServerError`.
    #[test]
    fn test_parse_provider_error_502() {
        let parsed = parse_provider_error(502, "bad gateway");
        assert!(matches!(parsed, ProviderError::ServerError(_)));
        assert!(parsed.is_retryable());
    }

    /// 503 with this body maps to `ServerError`.
    #[test]
    fn test_parse_provider_error_503() {
        let parsed = parse_provider_error(503, "service unavailable");
        assert!(matches!(parsed, ProviderError::ServerError(_)));
    }

    /// 504 with this body maps to `ServerError`.
    #[test]
    fn test_parse_provider_error_504() {
        let parsed = parse_provider_error(504, "gateway timeout");
        assert!(matches!(parsed, ProviderError::ServerError(_)));
    }

    /// A status with no dedicated mapping goes through the classifier with an
    /// "HTTP <status>: <body>" message; here it ends up as `Unknown` and the
    /// rendered error text keeps the status prefix.
    #[test]
    fn test_parse_provider_error_unknown() {
        let parsed = parse_provider_error(418, "i'm a teapot");
        assert!(matches!(parsed, ProviderError::Unknown(_)));
        assert!(parsed.to_string().contains("HTTP 418"));
    }
}