use std::time::Duration;
use reqwest::header::HeaderMap;
use reqwest::StatusCode;
use serde_json::Value;
/// Top-level error type returned by this client.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// An error response from the API, described by the wrapped [`ApiError`].
    ///
    /// `Display` and `source()` are forwarded unchanged to the inner value. The payload is
    /// boxed, presumably to keep `Error` itself small (`ApiError` carries a `HeaderMap`).
    #[error(transparent)]
    Api(Box<ApiError>),

    /// The request did not complete within `timeout`; displayed in whole milliseconds.
    #[error("request timed out after {}ms", .timeout.as_millis())]
    Timeout { timeout: Duration },

    /// An input validation failure, optionally tied to a named `field`.
    ///
    /// Displayed as `validation error (field): message`, or `validation error: message`
    /// when `field` is `None`.
    #[error("validation error{}: {message}", field_display(.field.as_deref()))]
    Validation {
        field: Option<String>,
        message: String,
    },

    /// A transport-level failure. `source`, when present, is exposed via `source()`.
    #[error("network error: {message}")]
    Network {
        message: String,
        #[source]
        source: Option<Box<dyn std::error::Error + Send + Sync>>,
    },

    /// JSON serialization or deserialization failed.
    // NOTE(review): `#[from]` also registers the inner error as `source()`, so reporters
    // that walk the error chain print its text twice; consider dropping `{0}` here.
    #[error("serialization error: {0}")]
    Serde(#[from] serde_json::Error),

    /// A URL could not be parsed.
    // NOTE(review): same duplicated-source concern as `Serde`.
    #[error("url error: {0}")]
    Url(#[from] url::ParseError),
}
impl Error {
pub fn validation(field: impl Into<String>, message: impl Into<String>) -> Self {
Self::Validation {
field: Some(field.into()),
message: message.into(),
}
}
pub fn validation_msg(message: impl Into<String>) -> Self {
Self::Validation {
field: None,
message: message.into(),
}
}
pub fn status(&self) -> Option<StatusCode> {
match self {
Error::Api(e) => Some(e.status),
_ => None,
}
}
pub fn as_api_error(&self) -> Option<&ApiError> {
match self {
Error::Api(e) => Some(e.as_ref()),
_ => None,
}
}
}
impl From<ApiError> for Error {
fn from(e: ApiError) -> Self {
Self::Api(Box::new(e))
}
}
/// Format the optional field name used in `Error::Validation`'s message.
///
/// Produces ` (name)` (with a leading space) when a field is given, and an empty string
/// otherwise, so the message reads either `validation error (name): …` or
/// `validation error: …`.
fn field_display(field: Option<&str>) -> String {
    field.map_or_else(String::new, |name| format!(" ({name})"))
}
/// An error response received from the API.
///
/// Displayed as `<status>: <message>`, e.g. `404 Not Found: no such term`.
#[derive(Debug, thiserror::Error)]
#[error("{status}: {message}")]
pub struct ApiError {
    /// Coarse classification derived from `status`.
    pub kind: ApiErrorKind,
    /// HTTP status code of the response.
    pub status: StatusCode,
    /// The body's `message` string, or `HTTP <code>` when none was usable.
    pub message: String,
    /// The body's `error` field, when it is a string.
    pub error_code: Option<String>,
    /// The parsed JSON body, if one was available.
    pub body: Option<Value>,
    /// Response headers; rate-limit details are read from here.
    pub headers: HeaderMap,
}
/// Coarse classification of an [`ApiError`], derived from its HTTP status.
///
/// `Hash` is derived so kinds can be used as keys in `HashMap`/`HashSet`
/// (e.g. per-kind counters or retry policies).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ApiErrorKind {
    /// HTTP 400.
    BadRequest,
    /// HTTP 401.
    Authentication,
    /// HTTP 403.
    Forbidden,
    /// HTTP 404.
    NotFound,
    /// HTTP 409; structured details may be available via `ApiError::constraint_violation`.
    ConstraintViolation,
    /// HTTP 429; header details may be available via `ApiError::rate_limit`.
    RateLimit,
    /// A server-side error status (e.g. HTTP 500, 502, 503).
    InternalServer,
    /// Any status not covered by the variants above.
    Other,
}
/// Structured details of a constraint-violation (HTTP 409) response.
///
/// Each field mirrors the string-valued key of the same name in the response body's
/// `details` object and is `None` when that key is absent or not a string.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ConstraintViolationDetails {
    /// The body's `details.term_id`.
    pub term_id: Option<String>,
    /// The body's `details.feature`.
    pub feature: Option<String>,
    /// The body's `details.constraint`.
    pub constraint: Option<String>,
}
/// Rate-limit information taken from the headers of a rate-limited (HTTP 429) response.
///
/// Each field is `None` when its header is missing or is not a plain unsigned integer.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct RateLimitDetails {
    /// Value of the `Retry-After` header, when given as an integer.
    pub retry_after: Option<u64>,
    /// Value of the `X-RateLimit-Limit` header.
    pub limit: Option<u64>,
    /// Value of the `X-RateLimit-Remaining` header.
    pub remaining: Option<u64>,
}
impl ApiError {
pub(crate) fn from_response(
status: StatusCode,
body: Option<Value>,
headers: HeaderMap,
) -> Self {
let (message, error_code) = extract_error_fields(body.as_ref(), status);
let kind = match status.as_u16() {
400 => ApiErrorKind::BadRequest,
401 => ApiErrorKind::Authentication,
403 => ApiErrorKind::Forbidden,
404 => ApiErrorKind::NotFound,
409 => ApiErrorKind::ConstraintViolation,
429 => ApiErrorKind::RateLimit,
s if s >= 500 => ApiErrorKind::InternalServer,
_ => ApiErrorKind::Other,
};
Self {
kind,
status,
message,
error_code,
body,
headers,
}
}
pub fn constraint_violation(&self) -> Option<ConstraintViolationDetails> {
if self.kind != ApiErrorKind::ConstraintViolation {
return None;
}
let details = self.body.as_ref()?.get("details")?.as_object()?;
Some(ConstraintViolationDetails {
term_id: details
.get("term_id")
.and_then(|v| v.as_str())
.map(str::to_owned),
feature: details
.get("feature")
.and_then(|v| v.as_str())
.map(str::to_owned),
constraint: details
.get("constraint")
.and_then(|v| v.as_str())
.map(str::to_owned),
})
}
pub fn rate_limit(&self) -> Option<RateLimitDetails> {
if self.kind != ApiErrorKind::RateLimit {
return None;
}
Some(RateLimitDetails {
retry_after: parse_numeric_header(&self.headers, "retry-after"),
limit: parse_numeric_header(&self.headers, "x-ratelimit-limit"),
remaining: parse_numeric_header(&self.headers, "x-ratelimit-remaining"),
})
}
}
fn extract_error_fields(body: Option<&Value>, status: StatusCode) -> (String, Option<String>) {
let fallback = format!("HTTP {}", status.as_u16());
let Some(Value::Object(map)) = body else {
return (fallback, None);
};
let message = map
.get("message")
.and_then(|v| v.as_str())
.map(str::to_owned)
.unwrap_or(fallback);
let error_code = map.get("error").and_then(|v| v.as_str()).map(str::to_owned);
(message, error_code)
}
/// Read the header `name` as an unsigned integer.
///
/// If the header is repeated, only the first value is used. Returns `None` when any of
/// these holds:
///
/// - the header is missing;
/// - it contains characters outside visible ASCII;
/// - it does not parse as a `u64` (e.g. a negative number, a decimal, or an HTTP-date).
pub(crate) fn parse_numeric_header(headers: &HeaderMap, name: &str) -> Option<u64> {
    let value = headers.get(name)?;
    let text = value.to_str().ok()?;
    text.parse::<u64>().ok()
}