use std::string::FromUtf8Error;
use thiserror::Error;
use reqwest::StatusCode;
/// Failures from the HTTP layer: building the client, sending a request,
/// and reading or decoding the response.
// NOTE(review): several variants interpolate their source error into the
// `Display` message *and* expose it through `source()`. Reporters that walk the
// source chain (e.g. `anyhow`'s `{:#}` / `{:?}`) will print those messages twice.
// Kept as-is because callers logging with plain `{}` rely on the detail.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum TransportError {
    /// A configuration value was rejected; the payload is a static description.
    #[error("invalid configuration: {0}")]
    InvalidConfig(&'static str),

    /// Constructing the `reqwest` HTTP client failed.
    #[error("failed to build http client: {0}")]
    BuildClient(#[source] reqwest::Error),

    /// Sending the request failed, as reported by `reqwest`.
    #[error("request failed: {0}")]
    Transport(#[source] reqwest::Error),

    /// The API responded with `status`.
    ///
    /// The response `body` is retained on the variant but is not part of the
    /// `Display` message.
    #[error("api returned {status}")]
    HttpStatus { status: StatusCode, body: String },

    /// The response body could not be deserialized with `serde_json`.
    ///
    /// `body` presumably holds the text that failed to parse (confirm at the
    /// construction sites).
    #[error("failed to deserialize response body: {source}")]
    Deserialize {
        #[source]
        source: serde_json::Error,
        body: String,
    },

    /// A streamed response could not be decoded as UTF-8.
    #[error("failed to decode streamed response as UTF-8: {0}")]
    Utf8(#[source] FromUtf8Error),

    /// The response body grew past a limit of `limit` bytes.
    #[error("response body exceeded {limit}-byte limit")]
    BodyTooLarge { limit: usize },

    /// The response was rejected for some other reason, described by the string.
    #[error("invalid response: {0}")]
    InvalidResponse(String),
}
/// Errors surfaced at the provider level: transport failures, request
/// validation, and JSON (de)serialization.
// NOTE(review): like `TransportError`, the `Transport`, `Serialize` and
// `Deserialize` variants both embed their source in `Display` and return it from
// `source()`, so chain-walking reporters show the inner message more than once.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum ProviderError {
    /// Wraps a [`TransportError`].
    ///
    /// `#[from]` also makes the wrapped error this variant's `source()`, so it
    /// remains reachable by downcasting while walking the error chain.
    #[error("transport error: {0}")]
    Transport(#[from] TransportError),

    /// The request was rejected as invalid; the string explains why.
    #[error("invalid request: {0}")]
    InvalidRequest(String),

    /// Serializing with `serde_json` failed.
    ///
    /// Because this variant carries `#[from]`, using `?` on *any*
    /// `serde_json::Error` (including one produced while parsing) converts into
    /// `Serialize`. Construct [`ProviderError::Deserialize`] explicitly when decoding.
    #[error("serialization failed: {0}")]
    Serialize(#[from] serde_json::Error),

    /// A response body could not be deserialized with `serde_json`.
    ///
    /// `body` presumably holds the offending response text (confirm at the
    /// construction sites).
    #[error("failed to deserialize response body: {source}")]
    Deserialize {
        #[source]
        source: serde_json::Error,
        body: String,
    },
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::error::Error as StdError;

    /// Converting a `TransportError` into a `ProviderError` must keep the original
    /// error discoverable by walking `source()` and downcasting each link.
    #[test]
    fn transport_error_reachable_via_source_chain() {
        let provider_err = ProviderError::from(TransportError::HttpStatus {
            status: StatusCode::TOO_MANY_REQUESTS,
            body: "rate limited".into(),
        });

        // Start from the outermost error itself, then follow `source()` links.
        // `|&e|` copies the `&dyn Error` out so `source()` borrows for the chain's
        // lifetime rather than the closure argument's.
        let root: &(dyn StdError + 'static) = &provider_err;
        let found = std::iter::successors(Some(root), |&e| e.source())
            .find_map(|e| e.downcast_ref::<TransportError>())
            .expect("TransportError must be reachable via the source chain");

        let is_rate_limited = matches!(
            found,
            TransportError::HttpStatus { status, .. } if *status == StatusCode::TOO_MANY_REQUESTS
        );
        assert!(is_rate_limited, "expected HttpStatus 429, got {found:?}");
    }
}