1use serde::Deserialize;
2
3pub type OpenAIResult<T> = std::result::Result<T, OpenAIError>;
4
5#[derive(Debug, thiserror::Error)]
7pub enum OpenAIError {
8 #[error("http error: {0}")]
10 Reqwest(reqwest::Error),
11 #[error("failed to deserialize api response {0} with error: {1}")]
13 Serde(String, serde_json::Error),
14 #[error("missing auth token")]
16 MissingAuthToken,
17 #[error("OpenAI API error: {0}")]
19 API(OpenAIAPIError),
20 #[error("OpenAI refused to generate response: {0}")]
22 Refusal(String),
23}
24
25impl From<reqwest::Error> for OpenAIError {
26 fn from(err: reqwest::Error) -> Self {
27 Self::Reqwest(err)
28 }
29}
30
31#[derive(Debug, Deserialize, Clone, thiserror::Error)]
32#[serde(tag = "type", rename_all = "snake_case")]
33pub enum OpenAIAPIError {
34 #[error("model context length exceeded: {0}")]
35 ContextLengthExceeded(OpenAIAPIErrorData),
36 #[error("cloudflare service unavailable: {0}")]
37 CfServiceUnavailable(OpenAIAPIErrorData),
38 #[error("transient server error: {0}")]
39 ServerError(OpenAIAPIErrorData),
40 #[error("cloudflare bad gateway: {0}")]
41 CfBadGateway(OpenAIAPIErrorData),
42 #[error("quota exceeded: {0}")]
43 QuotaExceeded(OpenAIAPIErrorData),
44 #[error("internal error: {0}")]
45 InternalError(OpenAIAPIErrorData),
46 #[error("invalid request error: {0}")]
47 InvalidRequestError(OpenAIAPIErrorData),
48}
49
50#[derive(Debug, Deserialize, Clone)]
52pub struct OpenAIAPIErrorData {
53 pub message: String,
55 pub param: Option<String>,
57 pub code: Option<String>,
59}
60
61impl std::fmt::Display for OpenAIAPIErrorData {
62 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63 write!(f, "message: {}", self.message)?;
64 if let Some(param) = &self.param {
65 write!(f, ", param: {}", param)?;
66 }
67 if let Some(code) = &self.code {
68 write!(f, ", code: {}", code)?;
69 }
70 Ok(())
71 }
72}