pub mod anthropic;
pub mod anyscale;
#[cfg(feature = "aws-bedrock")]
pub mod aws_bedrock;
pub mod custom;
pub mod deepinfra;
pub mod fireworks;
pub mod groq;
pub mod mistral;
pub mod ollama;
pub mod openai;
pub mod together;
use std::{fmt::Debug, time::Duration};
use error_stack::Report;
use reqwest::StatusCode;
use thiserror::Error;
use crate::format::{ChatRequest, StreamingResponseSender};
/// Everything a provider receives for a single
/// [`ChatModelProvider::send_request`] call.
#[derive(Debug)]
pub struct SendRequestOptions {
    /// Time limit for the request. How it is enforced is up to the provider
    /// implementation and is not visible in this module.
    pub timeout: Duration,
    /// Optional URL override.
    /// NOTE(review): presumably replaces the provider's usual endpoint when
    /// set; confirm against the provider implementations.
    pub override_url: Option<String>,
    /// Optional API key to use for this request.
    pub api_key: Option<String>,
    /// The chat request to send.
    pub body: ChatRequest,
}
/// Interface implemented by every chat model provider.
///
/// Implementors must be `Debug + Send + Sync`.
#[async_trait::async_trait]
pub trait ChatModelProvider: Debug + Send + Sync {
    /// The provider's name.
    fn name(&self) -> &str;

    /// The provider's label.
    /// NOTE(review): presumably a human-readable counterpart of `name`; confirm.
    fn label(&self) -> &str;

    /// Sends the request described by `options`.
    ///
    /// The success value is `()`, so response data is presumably delivered
    /// through `chunk_tx` (confirm against the implementations).
    ///
    /// # Errors
    ///
    /// Returns a [`Report`] wrapping a [`ProviderError`] on failure.
    async fn send_request(
        &self,
        options: SendRequestOptions,
        chunk_tx: StreamingResponseSender,
    ) -> Result<(), Report<ProviderError>>;

    /// Whether this provider is the default one for `model`.
    fn is_default_for_model(&self, model: &str) -> bool;
}
#[derive(Debug, Error)]
#[error("{kind}")]
pub struct ProviderError {
pub kind: ProviderErrorKind,
pub status_code: Option<reqwest::StatusCode>,
pub body: Option<serde_json::Value>,
pub latency: std::time::Duration,
}
impl ProviderError {
pub fn from_kind(kind: ProviderErrorKind) -> Self {
Self {
kind,
status_code: None,
body: None,
latency: std::time::Duration::ZERO,
}
}
pub fn transforming_request() -> Self {
Self::from_kind(ProviderErrorKind::TransformingRequest)
}
}
/// Exposes [`ProviderError`] through filigree's HTTP error interface.
#[cfg(feature = "filigree")]
impl filigree::errors::HttpError for ProviderError {
    type Detail = serde_json::Value;

    /// HTTP status to report for this error.
    ///
    /// A non-success status stored in `self.status_code` is passed through
    /// unchanged. In every other case (no stored status, or a stored success
    /// status) the status is derived from `self.kind`, so that kinds which
    /// never carry a stored status (e.g. `Timeout`, `AuthMissing`,
    /// `TransformingRequest`) still map to their dedicated codes instead of a
    /// blanket 500.
    fn status_code(&self) -> StatusCode {
        match self.status_code {
            Some(code) if !code.is_success() => code,
            // Either nothing was recorded, or the recorded status was a
            // success even though we are reporting an error; a success code
            // must never be returned for an error, so use the kind's mapping.
            _ => self.kind.status_code(),
        }
    }

    /// Short identifier of the error kind (see `ProviderErrorKind::as_str`).
    fn error_kind(&self) -> &'static str {
        self.kind.as_str()
    }

    /// A clone of the stored JSON body, or JSON `null` when there is none.
    fn error_detail(&self) -> Self::Detail {
        self.body.clone().unwrap_or(serde_json::Value::Null)
    }
}
#[derive(Debug, Error)]
pub enum ProviderErrorKind {
#[error("Model provider returned an error")]
Generic,
#[error("Model provider encountered a server error")]
Server,
#[error("Failed while trying to send request")]
Sending,
#[error("Failed while parsing response")]
ParsingResponse,
#[error("Error transforming a model request")]
TransformingRequest,
#[error("Error transforming a model response")]
TransformingResponse,
#[error("Provider closed connection prematurely")]
ProviderClosedConnection,
#[error("Model provider rate limited this request")]
RateLimit {
retry_after: Option<std::time::Duration>,
},
#[error("Timed out waiting for model provider's response")]
Timeout,
#[error("Model provider encountered an unrecoverable error")]
Permanent,
#[error("Model provider rejected the request format")]
BadInput,
#[error("Model provider authorization error")]
AuthRejected,
#[error("No API key provided")]
AuthMissing,
#[error("Out of credits with this provider")]
OutOfCredits,
}
impl ProviderErrorKind {
pub fn from_status_code(code: reqwest::StatusCode) -> Option<Self> {
if code.is_success() {
return None;
}
let code = match code {
StatusCode::TOO_MANY_REQUESTS => Self::RateLimit { retry_after: None },
StatusCode::PAYMENT_REQUIRED => Self::OutOfCredits,
StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED => Self::AuthRejected,
StatusCode::BAD_REQUEST
| StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE
| StatusCode::UNPROCESSABLE_ENTITY
| StatusCode::UNSUPPORTED_MEDIA_TYPE
| StatusCode::PAYLOAD_TOO_LARGE
| StatusCode::NOT_FOUND
| StatusCode::METHOD_NOT_ALLOWED
| StatusCode::NOT_ACCEPTABLE => Self::BadInput,
c if c.is_server_error() => Self::Server,
c if c.is_client_error() => Self::Permanent,
_ => Self::Generic,
};
Some(code)
}
pub fn status_code(&self) -> StatusCode {
match self {
ProviderErrorKind::Generic => StatusCode::INTERNAL_SERVER_ERROR,
ProviderErrorKind::Server => StatusCode::SERVICE_UNAVAILABLE,
ProviderErrorKind::Sending => StatusCode::BAD_GATEWAY,
ProviderErrorKind::ParsingResponse => StatusCode::BAD_GATEWAY,
ProviderErrorKind::ProviderClosedConnection => StatusCode::BAD_GATEWAY,
ProviderErrorKind::RateLimit { .. } => StatusCode::TOO_MANY_REQUESTS,
ProviderErrorKind::Timeout => StatusCode::GATEWAY_TIMEOUT,
ProviderErrorKind::Permanent => StatusCode::INTERNAL_SERVER_ERROR,
ProviderErrorKind::BadInput => StatusCode::UNPROCESSABLE_ENTITY,
ProviderErrorKind::AuthRejected => StatusCode::UNAUTHORIZED,
ProviderErrorKind::AuthMissing => StatusCode::UNAUTHORIZED,
ProviderErrorKind::OutOfCredits => StatusCode::PAYMENT_REQUIRED,
ProviderErrorKind::TransformingRequest => StatusCode::BAD_REQUEST,
ProviderErrorKind::TransformingResponse => StatusCode::INTERNAL_SERVER_ERROR,
}
}
pub fn retryable(&self) -> bool {
matches!(
self,
Self::Server
| Self::ParsingResponse
| Self::TransformingResponse
| Self::Sending
| Self::RateLimit { .. }
| Self::Generic
)
}
pub fn as_str(&self) -> &'static str {
match self {
ProviderErrorKind::Generic => "generic",
ProviderErrorKind::Server => "provider_server_error",
ProviderErrorKind::ProviderClosedConnection => "provider_connection_closed",
ProviderErrorKind::Sending => "provider_connection_error",
ProviderErrorKind::ParsingResponse => "parsing_provider_response",
ProviderErrorKind::RateLimit { .. } => "rate_limit",
ProviderErrorKind::Timeout => "timeout",
ProviderErrorKind::Permanent => "unrecoverable_server_error",
ProviderErrorKind::BadInput => "provider_rejected_input",
ProviderErrorKind::AuthRejected => "provider_rejected_token",
ProviderErrorKind::AuthMissing => "auth_missing",
ProviderErrorKind::OutOfCredits => "out_of_credits",
ProviderErrorKind::TransformingRequest => "transforming_request",
ProviderErrorKind::TransformingResponse => "transforming_response",
}
}
}