chronicle_proxy/providers/mod.rs

1pub mod anthropic;
2pub mod anyscale;
3#[cfg(feature = "aws-bedrock")]
4pub mod aws_bedrock;
5pub mod custom;
6pub mod deepinfra;
7pub mod fireworks;
8pub mod groq;
9pub mod mistral;
10pub mod ollama;
11pub mod openai;
12pub mod together;
13
14use std::{fmt::Debug, time::Duration};
15
16use error_stack::Report;
17use reqwest::StatusCode;
18use thiserror::Error;
19
20use crate::format::{ChatRequest, StreamingResponseSender};
21
22#[derive(Debug)]
23pub struct SendRequestOptions {
24    pub timeout: Duration,
25    pub override_url: Option<String>,
26    pub api_key: Option<String>,
27    pub body: ChatRequest,
28}
29
30#[async_trait::async_trait]
31pub trait ChatModelProvider: Debug + Send + Sync {
32    /// Internal name for the provider
33    fn name(&self) -> &str;
34
35    /// A readable name for the provider
36    fn label(&self) -> &str;
37
38    /// Send a request and return the response. If there's any chance of retryable failures for
39    /// this provider (e.g. almost every provider), then this function should handle retrying with
40    /// the behavior specified in `options.retry`. The `request_with_retry` function can assist with that.
41    async fn send_request(
42        &self,
43        options: SendRequestOptions,
44        chunk_tx: StreamingResponseSender,
45    ) -> Result<(), Report<ProviderError>>;
46
47    fn is_default_for_model(&self, model: &str) -> bool;
48}
49
50#[derive(Debug, Error)]
51#[error("{kind}")]
52pub struct ProviderError {
53    /// What type of error this is
54    pub kind: ProviderErrorKind,
55    /// The HTTP status code, if there was one.
56    pub status_code: Option<reqwest::StatusCode>,
57    /// The returned body, if there was one
58    pub body: Option<serde_json::Value>,
59    /// How much time it took before we received the error
60    pub latency: std::time::Duration,
61}
62
63impl ProviderError {
64    /// A simple constructor for a [ProviderError] that only needs a kind
65    pub fn from_kind(kind: ProviderErrorKind) -> Self {
66        Self {
67            kind,
68            status_code: None,
69            body: None,
70            latency: std::time::Duration::ZERO,
71        }
72    }
73
74    /// A helper for creating a `ProviderError` with the TransformingRequest error kind. This is by
75    /// far the most common case in the codebase.
76    pub fn transforming_request() -> Self {
77        Self::from_kind(ProviderErrorKind::TransformingRequest)
78    }
79}
80
#[cfg(feature = "filigree")]
impl filigree::errors::HttpError for ProviderError {
    type Detail = serde_json::Value;

    /// The HTTP status code to report for this error.
    ///
    /// An error status received from the provider is passed through unchanged. In every other
    /// case the status comes from [`ProviderErrorKind::status_code`]:
    /// - no status at all, as with errors built via [`ProviderError::from_kind`] (request
    ///   transformation failures, missing auth, timeouts, ...);
    /// - a success status on an error.
    fn status_code(&self) -> StatusCode {
        match self.status_code {
            Some(code) if !code.is_success() => code,
            // Previously a missing status always became 500, which made the kind-specific
            // mapping unreachable for errors that never received an HTTP response.
            _ => self.kind.status_code(),
        }
    }

    /// The machine-readable error kind, taken from [`ProviderErrorKind::as_str`].
    fn error_kind(&self) -> &'static str {
        self.kind.as_str()
    }

    /// The provider's response body, or JSON `null` when there was none.
    fn error_detail(&self) -> Self::Detail {
        self.body.clone().unwrap_or(serde_json::Value::Null)
    }
}
105
106#[derive(Debug, Error)]
107pub enum ProviderErrorKind {
108    /// A generic error not otherwise specified. These will be retried
109    #[error("Model provider returned an error")]
110    Generic,
111    /// a 5xx HTTP status code or similar error
112    #[error("Model provider encountered a server error")]
113    Server,
114    #[error("Failed while trying to send request")]
115    Sending,
116    #[error("Failed while parsing response")]
117    ParsingResponse,
118    #[error("Error transforming a model request")]
119    TransformingRequest,
120    #[error("Error transforming a model response")]
121    TransformingResponse,
122    #[error("Provider closed connection prematurely")]
123    ProviderClosedConnection,
124    /// The provider returned a rate limit error.
125    #[error("Model provider rate limited this request")]
126    RateLimit {
127        /// How soon we can retry, if the response specified
128        retry_after: Option<std::time::Duration>,
129    },
130
131    /// The request took longer than the conifgured timeout
132    #[error("Timed out waiting for model provider's response")]
133    Timeout,
134
135    /// Some non-retryable error not covered below
136    #[error("Model provider encountered an unrecoverable error")]
137    Permanent,
138    /// The model provider didn't like our input
139    #[error("Model provider rejected the request format")]
140    BadInput,
141    /// The API token was rejected or not allowed to perform the requested operation
142    #[error("Model provider authorization error")]
143    AuthRejected,
144    /// The API token was rejected or not allowed to perform the requested operation
145    #[error("No API key provided")]
146    AuthMissing,
147    /// The provider needs more money.
148    #[error("Out of credits with this provider")]
149    OutOfCredits,
150}
151
152impl ProviderErrorKind {
153    /// Convert an HTTP status code into a `ProviderError`. Returns `None` if the request
154    /// succeeded.
155    pub fn from_status_code(code: reqwest::StatusCode) -> Option<Self> {
156        if code.is_success() {
157            return None;
158        }
159
160        let code = match code {
161            // We don't have the information on how long to wait here, but it can be extracted
162            // later by the provider if it is present.
163            StatusCode::TOO_MANY_REQUESTS => Self::RateLimit { retry_after: None },
164            // Not all providers will return a 402, but if we do see one then it's `OutOfCredits`.
165            StatusCode::PAYMENT_REQUIRED => Self::OutOfCredits,
166            StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED => Self::AuthRejected,
167            StatusCode::BAD_REQUEST
168            | StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE
169            | StatusCode::UNPROCESSABLE_ENTITY
170            | StatusCode::UNSUPPORTED_MEDIA_TYPE
171            | StatusCode::PAYLOAD_TOO_LARGE
172            | StatusCode::NOT_FOUND
173            | StatusCode::METHOD_NOT_ALLOWED
174            | StatusCode::NOT_ACCEPTABLE => Self::BadInput,
175            c if c.is_server_error() => Self::Server,
176            // Some other client error but these tend to indicate that a retry won't work.
177            c if c.is_client_error() => Self::Permanent,
178            _ => Self::Generic,
179        };
180
181        Some(code)
182    }
183
184    pub fn status_code(&self) -> StatusCode {
185        match self {
186            ProviderErrorKind::Generic => StatusCode::INTERNAL_SERVER_ERROR,
187            ProviderErrorKind::Server => StatusCode::SERVICE_UNAVAILABLE,
188            ProviderErrorKind::Sending => StatusCode::BAD_GATEWAY,
189            ProviderErrorKind::ParsingResponse => StatusCode::BAD_GATEWAY,
190            ProviderErrorKind::ProviderClosedConnection => StatusCode::BAD_GATEWAY,
191            ProviderErrorKind::RateLimit { .. } => StatusCode::TOO_MANY_REQUESTS,
192            ProviderErrorKind::Timeout => StatusCode::GATEWAY_TIMEOUT,
193            ProviderErrorKind::Permanent => StatusCode::INTERNAL_SERVER_ERROR,
194            ProviderErrorKind::BadInput => StatusCode::UNPROCESSABLE_ENTITY,
195            ProviderErrorKind::AuthRejected => StatusCode::UNAUTHORIZED,
196            ProviderErrorKind::AuthMissing => StatusCode::UNAUTHORIZED,
197            ProviderErrorKind::OutOfCredits => StatusCode::PAYMENT_REQUIRED,
198            ProviderErrorKind::TransformingRequest => StatusCode::BAD_REQUEST,
199            ProviderErrorKind::TransformingResponse => StatusCode::INTERNAL_SERVER_ERROR,
200        }
201    }
202
203    /// If the request is retryable after a short delay.
204    pub fn retryable(&self) -> bool {
205        matches!(
206            self,
207            Self::Server
208                | Self::ParsingResponse
209                | Self::TransformingResponse
210                | Self::Sending
211                | Self::RateLimit { .. }
212                | Self::Generic
213        )
214    }
215
216    pub fn as_str(&self) -> &'static str {
217        match self {
218            ProviderErrorKind::Generic => "generic",
219            ProviderErrorKind::Server => "provider_server_error",
220            ProviderErrorKind::ProviderClosedConnection => "provider_connection_closed",
221            ProviderErrorKind::Sending => "provider_connection_error",
222            ProviderErrorKind::ParsingResponse => "parsing_provider_response",
223            ProviderErrorKind::RateLimit { .. } => "rate_limit",
224            ProviderErrorKind::Timeout => "timeout",
225            ProviderErrorKind::Permanent => "unrecoverable_server_error",
226            ProviderErrorKind::BadInput => "provider_rejected_input",
227            ProviderErrorKind::AuthRejected => "provider_rejected_token",
228            ProviderErrorKind::AuthMissing => "auth_missing",
229            ProviderErrorKind::OutOfCredits => "out_of_credits",
230            ProviderErrorKind::TransformingRequest => "transforming_request",
231            ProviderErrorKind::TransformingResponse => "transforming_response",
232        }
233    }
234}