use reqwest::Client as AsyncClient;
use reqwest::StatusCode;
use reqwest::blocking::Client;
use reqwest::header::CONTENT_TYPE;
use roas::loader::{AsyncResourceFetcher, FetchFuture, LoaderError, ResourceFetcher};
#[cfg(feature = "yaml")]
use serde::de::Error as _;
use serde_json::Value;
use std::time::Duration;
use url::Url;
/// Timeout passed to `ClientBuilder::timeout` by the `new`/`try_new` constructors
/// of both fetcher flavours. Clients supplied through `with_client` keep their own settings.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
/// HTTP(S) resource fetcher that wraps a `reqwest` client of type `C`.
///
/// Use the [`HttpFetcher`] alias for the blocking client and [`AsyncHttpFetcher`]
/// for the async one. Cloning a fetcher clones the client value it wraps.
#[derive(Debug, Clone)]
pub struct Fetcher<C> {
    /// Client used to issue every request made by this fetcher.
    client: C,
}
pub type HttpFetcher = Fetcher<Client>;
pub type AsyncHttpFetcher = Fetcher<AsyncClient>;
impl Fetcher<Client> {
pub fn new() -> Self {
Self::try_new().expect("default reqwest::blocking::Client must build")
}
pub fn try_new() -> Result<Self, reqwest::Error> {
Ok(Self::with_client(
Client::builder().timeout(DEFAULT_TIMEOUT).build()?,
))
}
pub fn with_client(client: Client) -> Self {
Self { client }
}
}
impl Default for Fetcher<Client> {
fn default() -> Self {
Self::new()
}
}
impl Fetcher<AsyncClient> {
pub fn new() -> Self {
Self::try_new().expect("default reqwest::Client must build")
}
pub fn try_new() -> Result<Self, reqwest::Error> {
Ok(Self::with_client(
AsyncClient::builder().timeout(DEFAULT_TIMEOUT).build()?,
))
}
pub fn with_client(client: AsyncClient) -> Self {
Self { client }
}
}
impl Default for Fetcher<AsyncClient> {
fn default() -> Self {
Self::new()
}
}
impl ResourceFetcher for Fetcher<Client> {
fn fetch(&mut self, uri: &Url) -> Result<Value, LoaderError> {
check_scheme(uri)?;
let response = self.client.get(uri.as_str()).send().map_err(|source| {
fetch_error(uri.as_str().to_string(), HttpFetchError::Request { source })
})?;
let status = response.status();
if !status.is_success() {
return Err(fetch_error(
uri.as_str().to_string(),
HttpFetchError::Status { status },
));
}
let content_type = response
.headers()
.get(CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
let bytes = response.bytes().map_err(|source| {
fetch_error(uri.as_str().to_string(), HttpFetchError::Body { source })
})?;
parse_body(uri, content_type.as_deref(), &bytes)
}
}
impl AsyncResourceFetcher for Fetcher<AsyncClient> {
fn fetch<'a>(&'a mut self, uri: &'a Url) -> FetchFuture<'a> {
let client = self.client.clone();
Box::pin(async move {
check_scheme(uri)?;
let response = client.get(uri.as_str()).send().await.map_err(|source| {
fetch_error(uri.as_str().to_string(), HttpFetchError::Request { source })
})?;
let status = response.status();
if !status.is_success() {
return Err(fetch_error(
uri.as_str().to_string(),
HttpFetchError::Status { status },
));
}
let content_type = response
.headers()
.get(CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
let bytes = response.bytes().await.map_err(|source| {
fetch_error(uri.as_str().to_string(), HttpFetchError::Body { source })
})?;
parse_body(uri, content_type.as_deref(), &bytes)
})
}
}
/// Accepts only `http` and `https` URIs.
///
/// # Errors
///
/// Returns [`LoaderError::UnsupportedFetcherUri`] carrying the full URI text for any
/// other scheme.
fn check_scheme(uri: &Url) -> Result<(), LoaderError> {
    if matches!(uri.scheme(), "http" | "https") {
        Ok(())
    } else {
        Err(LoaderError::UnsupportedFetcherUri(String::from(uri.as_str())))
    }
}
/// Decodes a response body as YAML or JSON, using `is_yaml` to choose the format.
///
/// # Errors
///
/// Returns [`LoaderError::Parse`] (tagged with `uri`) when the body does not decode.
fn parse_body(uri: &Url, content_type: Option<&str>, bytes: &[u8]) -> Result<Value, LoaderError> {
    if !is_yaml(content_type, uri) {
        return serde_json::from_slice::<Value>(bytes).map_err(|source| LoaderError::Parse {
            uri: String::from(uri.as_str()),
            source,
        });
    }
    parse_yaml(uri, bytes)
}
/// Decides whether a fetched body should be decoded as YAML rather than JSON.
///
/// With the `yaml` feature enabled, the rules are applied in order:
/// - A `Content-Type` whose media type (parameters such as `; charset=...`
///   stripped, compared case-insensitively) contains `yaml` selects YAML.
/// - Any other specific media type (e.g. `application/json`) selects JSON.
/// - A missing or empty media type, or a generic one (`application/octet-stream`,
///   `binary/octet-stream`, `text/plain`), defers to the URI path. YAML is selected
///   when the path ends in `.yaml` or `.yml` (case-insensitive).
///
/// Without the `yaml` feature this always returns `false`.
fn is_yaml(content_type: Option<&str>, uri: &Url) -> bool {
    #[cfg(feature = "yaml")]
    {
        // These media types say nothing about the document format, so the file
        // extension decides. Static and raw-file hosts commonly serve `.yaml` files
        // as `text/plain`. Without this list such specs would be parsed as JSON.
        const GENERIC_MEDIA_TYPES: [&str; 3] =
            ["application/octet-stream", "binary/octet-stream", "text/plain"];

        if let Some(ct) = content_type {
            let mime = ct
                .split(';')
                .next()
                .unwrap_or("")
                .trim()
                .to_ascii_lowercase();
            if mime.contains("yaml") {
                return true;
            }
            if !mime.is_empty() && !GENERIC_MEDIA_TYPES.contains(&mime.as_str()) {
                return false;
            }
        }
        let path = uri.path().to_ascii_lowercase();
        path.ends_with(".yaml") || path.ends_with(".yml")
    }
    #[cfg(not(feature = "yaml"))]
    {
        // The inputs only matter when YAML support is compiled in.
        let _ = (content_type, uri);
        false
    }
}
/// Decodes a YAML body into a JSON [`Value`].
///
/// # Errors
///
/// Returns [`LoaderError::Parse`] tagged with `uri`. Its `source` field holds a
/// `serde_json::Error`, so the YAML error is carried over via `Error::custom` with
/// its message text preserved.
#[cfg(feature = "yaml")]
fn parse_yaml(uri: &Url, bytes: &[u8]) -> Result<Value, LoaderError> {
    serde_yaml_ng::from_slice::<Value>(bytes).map_err(|yaml_err| {
        let source = serde_json::Error::custom(yaml_err.to_string());
        LoaderError::Parse {
            uri: uri.as_str().to_string(),
            source,
        }
    })
}
/// Stand-in for builds without the `yaml` feature.
///
/// In these builds `is_yaml` always returns `false`, so `parse_body` never takes the
/// branch that calls this function. It exists only so that call site compiles. It
/// is referenced by `parse_body`, so no dead-code allowance is needed.
///
/// # Panics
///
/// Always panics if reached.
#[cfg(not(feature = "yaml"))]
fn parse_yaml(_uri: &Url, _bytes: &[u8]) -> Result<Value, LoaderError> {
    unreachable!("parse_yaml is only reached when the `yaml` feature is enabled")
}
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum HttpFetchError {
#[error("HTTP request failed")]
Request {
#[source]
source: reqwest::Error,
},
#[error("non-success HTTP response: {status}")]
Status { status: StatusCode },
#[error("failed to read response body")]
Body {
#[source]
source: reqwest::Error,
},
}
/// Packages an [`HttpFetchError`] for `uri` as [`LoaderError::Fetch`].
///
/// The error is boxed into the variant's `source`, which callers can downcast
/// back to [`HttpFetchError`].
fn fetch_error(uri: String, source: HttpFetchError) -> LoaderError {
    let boxed = Box::new(source);
    LoaderError::Fetch { uri, source: boxed }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a literal test URL. Panics on malformed input, which would be a test bug.
    fn url(raw: &str) -> Url {
        Url::parse(raw).unwrap()
    }

    #[test]
    fn http_fetcher_default_constructs() {
        // Each constructor path must build its client without panicking.
        drop(HttpFetcher::default());
        drop(HttpFetcher::new());
        drop(AsyncHttpFetcher::default());
        drop(AsyncHttpFetcher::new());
    }

    #[test]
    fn http_fetcher_try_new_succeeds_for_default_config() {
        drop(HttpFetcher::try_new().expect("blocking client must build"));
        drop(AsyncHttpFetcher::try_new().expect("async client must build"));
    }

    #[test]
    fn http_fetcher_is_clone_and_shares_pool() {
        // Checks that both fetchers implement `Clone`. Pool sharing between clones
        // is not observed here. It relies on the reqwest client's own clone semantics.
        let blocking = HttpFetcher::new();
        let _blocking_copy = blocking.clone();
        let asynchronous = AsyncHttpFetcher::new();
        let _async_copy = asynchronous.clone();
    }

    #[test]
    fn fetch_error_helper_boxes_into_loader_error_fetch() {
        const URI: &str = "https://example.test/x.json";
        let not_found = HttpFetchError::Status {
            status: StatusCode::NOT_FOUND,
        };
        match fetch_error(URI.into(), not_found) {
            LoaderError::Fetch { uri, source } => {
                assert_eq!(uri, URI);
                let inner = source
                    .downcast_ref::<HttpFetchError>()
                    .expect("source must downcast to HttpFetchError");
                assert!(matches!(
                    inner,
                    HttpFetchError::Status {
                        status: StatusCode::NOT_FOUND
                    }
                ));
            }
            other => panic!("expected LoaderError::Fetch, got {other:?}"),
        }
    }

    #[test]
    fn check_scheme_accepts_http_and_https() {
        for raw in ["http://example.test/x.json", "https://example.test/x.json"] {
            check_scheme(&url(raw)).unwrap();
        }
    }

    #[test]
    fn check_scheme_rejects_file_uri_with_unsupported_fetcher_uri() {
        let rejected =
            check_scheme(&url("file:///tmp/x.json")).expect_err("file:// must be rejected");
        assert!(
            matches!(rejected, LoaderError::UnsupportedFetcherUri(s) if s.starts_with("file://"))
        );
    }
}