chronicle_proxy/providers/
mod.rs1pub mod anthropic;
2pub mod anyscale;
3#[cfg(feature = "aws-bedrock")]
4pub mod aws_bedrock;
5pub mod custom;
6pub mod deepinfra;
7pub mod fireworks;
8pub mod groq;
9pub mod mistral;
10pub mod ollama;
11pub mod openai;
12pub mod together;
13
14use std::{fmt::Debug, time::Duration};
15
16use error_stack::Report;
17use reqwest::StatusCode;
18use thiserror::Error;
19
20use crate::format::{ChatRequest, StreamingResponseSender};
21
22#[derive(Debug)]
23pub struct SendRequestOptions {
24 pub timeout: Duration,
25 pub override_url: Option<String>,
26 pub api_key: Option<String>,
27 pub body: ChatRequest,
28}
29
30#[async_trait::async_trait]
31pub trait ChatModelProvider: Debug + Send + Sync {
32 fn name(&self) -> &str;
34
35 fn label(&self) -> &str;
37
38 async fn send_request(
42 &self,
43 options: SendRequestOptions,
44 chunk_tx: StreamingResponseSender,
45 ) -> Result<(), Report<ProviderError>>;
46
47 fn is_default_for_model(&self, model: &str) -> bool;
48}
49
50#[derive(Debug, Error)]
51#[error("{kind}")]
52pub struct ProviderError {
53 pub kind: ProviderErrorKind,
55 pub status_code: Option<reqwest::StatusCode>,
57 pub body: Option<serde_json::Value>,
59 pub latency: std::time::Duration,
61}
62
63impl ProviderError {
64 pub fn from_kind(kind: ProviderErrorKind) -> Self {
66 Self {
67 kind,
68 status_code: None,
69 body: None,
70 latency: std::time::Duration::ZERO,
71 }
72 }
73
74 pub fn transforming_request() -> Self {
77 Self::from_kind(ProviderErrorKind::TransformingRequest)
78 }
79}
80
/// Lets a [`ProviderError`] be returned directly as an HTTP error response via `filigree`.
#[cfg(feature = "filigree")]
impl filigree::errors::HttpError for ProviderError {
    type Detail = serde_json::Value;

    /// HTTP status to report for this error.
    ///
    /// A non-success status returned by the provider is passed through unchanged. In every
    /// other case the status comes from [`ProviderErrorKind::status_code`]. "Every other case"
    /// means no status at all (e.g. errors built with `ProviderError::from_kind`), or a 2xx
    /// that still ended in failure.
    fn status_code(&self) -> StatusCode {
        match self.status_code {
            Some(upstream) if !upstream.is_success() => upstream,
            // Previously a missing status always produced 500, hiding e.g. Timeout (504),
            // Sending (502), AuthMissing (401) and TransformingRequest (400) behind a generic error.
            _ => self.kind.status_code(),
        }
    }

    /// Machine-readable error kind string, taken from [`ProviderErrorKind::as_str`].
    fn error_kind(&self) -> &'static str {
        self.kind.as_str()
    }

    /// The provider's response body if one was captured, otherwise JSON `null`.
    fn error_detail(&self) -> Self::Detail {
        self.body.clone().unwrap_or(serde_json::Value::Null)
    }
}
105
106#[derive(Debug, Error)]
107pub enum ProviderErrorKind {
108 #[error("Model provider returned an error")]
110 Generic,
111 #[error("Model provider encountered a server error")]
113 Server,
114 #[error("Failed while trying to send request")]
115 Sending,
116 #[error("Failed while parsing response")]
117 ParsingResponse,
118 #[error("Error transforming a model request")]
119 TransformingRequest,
120 #[error("Error transforming a model response")]
121 TransformingResponse,
122 #[error("Provider closed connection prematurely")]
123 ProviderClosedConnection,
124 #[error("Model provider rate limited this request")]
126 RateLimit {
127 retry_after: Option<std::time::Duration>,
129 },
130
131 #[error("Timed out waiting for model provider's response")]
133 Timeout,
134
135 #[error("Model provider encountered an unrecoverable error")]
137 Permanent,
138 #[error("Model provider rejected the request format")]
140 BadInput,
141 #[error("Model provider authorization error")]
143 AuthRejected,
144 #[error("No API key provided")]
146 AuthMissing,
147 #[error("Out of credits with this provider")]
149 OutOfCredits,
150}
151
152impl ProviderErrorKind {
153 pub fn from_status_code(code: reqwest::StatusCode) -> Option<Self> {
156 if code.is_success() {
157 return None;
158 }
159
160 let code = match code {
161 StatusCode::TOO_MANY_REQUESTS => Self::RateLimit { retry_after: None },
164 StatusCode::PAYMENT_REQUIRED => Self::OutOfCredits,
166 StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED => Self::AuthRejected,
167 StatusCode::BAD_REQUEST
168 | StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE
169 | StatusCode::UNPROCESSABLE_ENTITY
170 | StatusCode::UNSUPPORTED_MEDIA_TYPE
171 | StatusCode::PAYLOAD_TOO_LARGE
172 | StatusCode::NOT_FOUND
173 | StatusCode::METHOD_NOT_ALLOWED
174 | StatusCode::NOT_ACCEPTABLE => Self::BadInput,
175 c if c.is_server_error() => Self::Server,
176 c if c.is_client_error() => Self::Permanent,
178 _ => Self::Generic,
179 };
180
181 Some(code)
182 }
183
184 pub fn status_code(&self) -> StatusCode {
185 match self {
186 ProviderErrorKind::Generic => StatusCode::INTERNAL_SERVER_ERROR,
187 ProviderErrorKind::Server => StatusCode::SERVICE_UNAVAILABLE,
188 ProviderErrorKind::Sending => StatusCode::BAD_GATEWAY,
189 ProviderErrorKind::ParsingResponse => StatusCode::BAD_GATEWAY,
190 ProviderErrorKind::ProviderClosedConnection => StatusCode::BAD_GATEWAY,
191 ProviderErrorKind::RateLimit { .. } => StatusCode::TOO_MANY_REQUESTS,
192 ProviderErrorKind::Timeout => StatusCode::GATEWAY_TIMEOUT,
193 ProviderErrorKind::Permanent => StatusCode::INTERNAL_SERVER_ERROR,
194 ProviderErrorKind::BadInput => StatusCode::UNPROCESSABLE_ENTITY,
195 ProviderErrorKind::AuthRejected => StatusCode::UNAUTHORIZED,
196 ProviderErrorKind::AuthMissing => StatusCode::UNAUTHORIZED,
197 ProviderErrorKind::OutOfCredits => StatusCode::PAYMENT_REQUIRED,
198 ProviderErrorKind::TransformingRequest => StatusCode::BAD_REQUEST,
199 ProviderErrorKind::TransformingResponse => StatusCode::INTERNAL_SERVER_ERROR,
200 }
201 }
202
203 pub fn retryable(&self) -> bool {
205 matches!(
206 self,
207 Self::Server
208 | Self::ParsingResponse
209 | Self::TransformingResponse
210 | Self::Sending
211 | Self::RateLimit { .. }
212 | Self::Generic
213 )
214 }
215
216 pub fn as_str(&self) -> &'static str {
217 match self {
218 ProviderErrorKind::Generic => "generic",
219 ProviderErrorKind::Server => "provider_server_error",
220 ProviderErrorKind::ProviderClosedConnection => "provider_connection_closed",
221 ProviderErrorKind::Sending => "provider_connection_error",
222 ProviderErrorKind::ParsingResponse => "parsing_provider_response",
223 ProviderErrorKind::RateLimit { .. } => "rate_limit",
224 ProviderErrorKind::Timeout => "timeout",
225 ProviderErrorKind::Permanent => "unrecoverable_server_error",
226 ProviderErrorKind::BadInput => "provider_rejected_input",
227 ProviderErrorKind::AuthRejected => "provider_rejected_token",
228 ProviderErrorKind::AuthMissing => "auth_missing",
229 ProviderErrorKind::OutOfCredits => "out_of_credits",
230 ProviderErrorKind::TransformingRequest => "transforming_request",
231 ProviderErrorKind::TransformingResponse => "transforming_response",
232 }
233 }
234}