Skip to main content

aster/providers/
errors.rs

1use reqwest::StatusCode;
2use std::time::Duration;
3use thiserror::Error;
4
5#[derive(Error, Debug, PartialEq)]
6pub enum ProviderError {
7    #[error("Authentication error: {0}")]
8    Authentication(String),
9
10    #[error("Context length exceeded: {0}")]
11    ContextLengthExceeded(String),
12
13    #[error("Rate limit exceeded: {details}")]
14    RateLimitExceeded {
15        details: String,
16        retry_delay: Option<Duration>,
17    },
18
19    #[error("Server error: {0}")]
20    ServerError(String),
21
22    #[error("Request failed: {0}")]
23    RequestFailed(String),
24
25    #[error("Execution error: {0}")]
26    ExecutionError(String),
27
28    #[error("Usage data error: {0}")]
29    UsageError(String),
30
31    #[error("Unsupported operation: {0}")]
32    NotImplemented(String),
33}
34
35impl ProviderError {
36    pub fn telemetry_type(&self) -> &'static str {
37        match self {
38            ProviderError::Authentication(_) => "auth",
39            ProviderError::ContextLengthExceeded(_) => "context_length",
40            ProviderError::RateLimitExceeded { .. } => "rate_limit",
41            ProviderError::ServerError(_) => "server",
42            ProviderError::RequestFailed(_) => "request",
43            ProviderError::ExecutionError(_) => "execution",
44            ProviderError::UsageError(_) => "usage",
45            ProviderError::NotImplemented(_) => "not_implemented",
46        }
47    }
48}
49
50impl From<anyhow::Error> for ProviderError {
51    fn from(error: anyhow::Error) -> Self {
52        if let Some(reqwest_err) = error.downcast_ref::<reqwest::Error>() {
53            let mut details = vec![];
54
55            if let Some(status) = reqwest_err.status() {
56                details.push(format!("status: {}", status));
57            }
58            if reqwest_err.is_timeout() {
59                details.push("timeout".to_string());
60            }
61            if reqwest_err.is_connect() {
62                if let Some(url) = reqwest_err.url() {
63                    if let Some(host) = url.host_str() {
64                        let port_info = url.port().map(|p| format!(":{}", p)).unwrap_or_default();
65
66                        details.push(format!("failed to connect to {}{}", host, port_info));
67
68                        if url.port().is_some() {
69                            details.push("check that the port is correct".to_string());
70                        }
71                    }
72                } else {
73                    details.push("connection failed".to_string());
74                }
75            }
76            let msg = if details.is_empty() {
77                reqwest_err.to_string()
78            } else {
79                format!("{} ({})", reqwest_err, details.join(", "))
80            };
81            return ProviderError::RequestFailed(msg);
82        }
83        ProviderError::ExecutionError(error.to_string())
84    }
85}
86
87impl From<reqwest::Error> for ProviderError {
88    fn from(error: reqwest::Error) -> Self {
89        ProviderError::RequestFailed(error.to_string())
90    }
91}
92
/// Numeric error codes recognised in Google API error payloads.
///
/// Each variant's discriminant is the numeric code itself, so
/// `GoogleErrorCode::NotFound as u16 == 404`.
///
/// The enum is a plain fieldless value type, so it derives `Clone`, `Copy`,
/// `PartialEq` and `Eq` in addition to `Debug`; this lets callers compare and
/// pass codes by value without ceremony.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GoogleErrorCode {
    /// Code 400.
    BadRequest = 400,
    /// Code 401.
    Unauthorized = 401,
    /// Code 403.
    Forbidden = 403,
    /// Code 404.
    NotFound = 404,
    /// Code 429.
    TooManyRequests = 429,
    /// Code 500.
    InternalServerError = 500,
    /// Code 503.
    ServiceUnavailable = 503,
}
103
104impl GoogleErrorCode {
105    pub fn to_status_code(&self) -> StatusCode {
106        match self {
107            Self::BadRequest => StatusCode::BAD_REQUEST,
108            Self::Unauthorized => StatusCode::UNAUTHORIZED,
109            Self::Forbidden => StatusCode::FORBIDDEN,
110            Self::NotFound => StatusCode::NOT_FOUND,
111            Self::TooManyRequests => StatusCode::TOO_MANY_REQUESTS,
112            Self::InternalServerError => StatusCode::INTERNAL_SERVER_ERROR,
113            Self::ServiceUnavailable => StatusCode::SERVICE_UNAVAILABLE,
114        }
115    }
116
117    pub fn from_code(code: u64) -> Option<Self> {
118        match code {
119            400 => Some(Self::BadRequest),
120            401 => Some(Self::Unauthorized),
121            403 => Some(Self::Forbidden),
122            404 => Some(Self::NotFound),
123            429 => Some(Self::TooManyRequests),
124            500 => Some(Self::InternalServerError),
125            503 => Some(Self::ServiceUnavailable),
126            _ => Some(Self::InternalServerError),
127        }
128    }
129}