use crate::error::FetchError;
use async_trait::async_trait;
use bytes::Bytes;
use futures::stream::{Stream, StreamExt};
use std::net::SocketAddr;
use std::pin::Pin;
use std::time::Duration;
use url::Url;
/// HTTP method supported by an [`HttpTransport`].
///
/// Only methods the transport layer actually needs are modelled. `ReqwestTransport`
/// maps them one-to-one onto `reqwest::Method::GET` / `reqwest::Method::HEAD`.
/// `Hash` is derived alongside `Eq` so the method can be used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TransportMethod {
    /// `GET`: fetch the resource including its body.
    Get,
    /// `HEAD`: fetch only the response head.
    Head,
}
/// A single HTTP request to be executed by an [`HttpTransport`].
///
/// `Clone` is derived so callers can keep a copy for retries, since
/// [`HttpTransport::execute`] takes the request by value.
///
/// NOTE(review): `Debug` is not derived. If it is added, consider redacting
/// `headers`, which may carry credentials such as `Authorization`.
#[derive(Clone)]
pub struct TransportRequest {
    /// HTTP method to use.
    pub method: TransportMethod,
    /// Target URL.
    pub url: Url,
    /// Request headers as `(name, value)` pairs, added to the request in order.
    pub headers: Vec<(String, String)>,
    /// Optional timeout. `ReqwestTransport` applies it as both the connect
    /// timeout and the overall request timeout.
    pub timeout: Option<Duration>,
    /// When non-empty, `ReqwestTransport` overrides DNS resolution of the URL's
    /// host with these addresses (presumably to pin the connection to addresses
    /// that were validated earlier; confirm with callers).
    pub pinned_addrs: Vec<SocketAddr>,
    /// When `false`, `ReqwestTransport` disables proxying entirely (`no_proxy`).
    /// When `true`, reqwest's default proxy handling (e.g. proxy environment
    /// variables) is left in effect.
    pub respect_proxy_env: bool,
}
/// Streaming response body: a pinned, boxed, `Send` stream whose items are
/// either a chunk of body bytes or a [`TransportError`].
pub type BodyStream =
    Pin<Box<dyn Stream<Item = Result<Bytes, TransportError>> + Send + 'static>>;

/// Response head plus streaming body, as returned by [`HttpTransport::execute`].
pub struct TransportResponse {
    /// Numeric HTTP status code.
    pub status: u16,
    /// URL the response is attributed to by the transport
    /// (for `ReqwestTransport`, the value of `reqwest::Response::url`).
    pub url: Url,
    /// Response headers as `(name, value)` pairs. This is a list rather than a
    /// map, so repeated header names are representable.
    pub headers: Vec<(String, String)>,
    /// The response body, consumed incrementally.
    pub body: BodyStream,
}
/// Errors produced by an [`HttpTransport`] implementation.
///
/// In this file, [`ReqwestTransport`] only constructs `Other` (client build
/// failure) and `Reqwest` (send or body-stream failure). `Connect`, `Timeout`
/// and `Request` are presumably produced by other transport implementations;
/// verify against the rest of the crate.
#[derive(Debug, thiserror::Error)]
pub enum TransportError {
    /// Connection failure. Converted to a `FetchError::RequestError` with a
    /// fixed "failed to connect to server" message.
    #[error("transport connect error")]
    Connect,

    /// Timeout at the transport layer. Converted to `FetchError::FirstByteTimeout`.
    #[error("transport timeout")]
    Timeout,

    /// Request failure with a descriptive message. Converted to
    /// `FetchError::RequestError` with the message unchanged.
    #[error("transport request error: {0}")]
    Request(String),

    /// Any other failure, e.g. `ReqwestTransport` failing to build its client.
    /// Converted to `FetchError::RequestError` with the message unchanged.
    #[error("transport error: {0}")]
    Other(String),

    /// Error reported by `reqwest`. The underlying error is exposed through
    /// `source()` rather than the display message, and is converted with
    /// `FetchError::from_reqwest`.
    #[error("transport reqwest error")]
    Reqwest(#[source] reqwest::Error),
}
/// Converts a transport-layer error into the crate's [`FetchError`].
///
/// - `reqwest` errors are delegated to `FetchError::from_reqwest`.
/// - `Timeout` becomes `FetchError::FirstByteTimeout`.
/// - Every other variant becomes a `FetchError::RequestError` carrying a message.
impl From<TransportError> for FetchError {
    fn from(err: TransportError) -> Self {
        match err {
            TransportError::Timeout => Self::FirstByteTimeout,
            TransportError::Connect => {
                Self::RequestError(String::from("failed to connect to server"))
            }
            // Both carry a caller-supplied message that is forwarded verbatim.
            TransportError::Request(message) | TransportError::Other(message) => {
                Self::RequestError(message)
            }
            TransportError::Reqwest(source) => Self::from_reqwest(source),
        }
    }
}
/// Abstraction over the HTTP client used to perform fetches.
///
/// The `Send + Sync` bound allows a transport to be shared across threads and
/// async tasks. Tests can substitute their own implementation.
#[async_trait]
pub trait HttpTransport: Send + Sync {
    /// Executes `request` and returns the response head together with a
    /// streaming body.
    ///
    /// # Errors
    ///
    /// Returns a [`TransportError`] if the request cannot be performed.
    async fn execute(
        &self,
        request: TransportRequest,
    ) -> Result<TransportResponse, TransportError>;
}
/// [`HttpTransport`] implementation backed by `reqwest`.
///
/// The type is stateless. All client configuration is derived from each
/// [`TransportRequest`] when it is executed.
#[derive(Debug, Default, Clone)]
pub struct ReqwestTransport;

impl ReqwestTransport {
    /// Creates a new transport.
    ///
    /// Equivalent to `ReqwestTransport::default()`, but usable in `const` contexts.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
#[async_trait]
impl HttpTransport for ReqwestTransport {
async fn execute(&self, req: TransportRequest) -> Result<TransportResponse, TransportError> {
let mut builder = reqwest::Client::builder().redirect(reqwest::redirect::Policy::none());
if let Some(timeout) = req.timeout {
builder = builder.connect_timeout(timeout).timeout(timeout);
}
if !req.respect_proxy_env {
builder = builder.no_proxy();
}
if !req.pinned_addrs.is_empty() {
if let Some(host) = req.url.host_str() {
builder = builder.resolve_to_addrs(host, &req.pinned_addrs);
}
}
let client = builder
.build()
.map_err(|e| TransportError::Other(e.to_string()))?;
let method = match req.method {
TransportMethod::Get => reqwest::Method::GET,
TransportMethod::Head => reqwest::Method::HEAD,
};
let mut request = client.request(method, req.url.clone());
for (name, value) in &req.headers {
request = request.header(name, value);
}
let response = request.send().await.map_err(TransportError::Reqwest)?;
let status = response.status().as_u16();
let final_url = response.url().clone();
let headers = response
.headers()
.iter()
.filter_map(|(name, value)| {
value
.to_str()
.ok()
.map(|v| (name.as_str().to_string(), v.to_string()))
})
.collect();
let body: BodyStream = Box::pin(
response
.bytes_stream()
.map(|chunk| chunk.map_err(TransportError::Reqwest)),
);
Ok(TransportResponse {
status,
url: final_url,
headers,
body,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Each transport error variant maps onto the expected `FetchError` variant.
    #[test]
    fn test_transport_error_maps_to_fetch_error() {
        assert!(matches!(
            FetchError::from(TransportError::Timeout),
            FetchError::FirstByteTimeout
        ));
        assert!(matches!(
            FetchError::from(TransportError::Connect),
            FetchError::RequestError(_)
        ));
        assert!(matches!(
            FetchError::from(TransportError::Request("x".into())),
            FetchError::RequestError(_)
        ));
        assert!(matches!(
            FetchError::from(TransportError::Other("y".into())),
            FetchError::RequestError(_)
        ));
    }

    /// Message-carrying variants forward their text unchanged, and `Connect`
    /// uses its fixed message.
    #[test]
    fn test_transport_error_messages_are_preserved() {
        assert!(matches!(
            FetchError::from(TransportError::Connect),
            FetchError::RequestError(msg) if msg == "failed to connect to server"
        ));
        assert!(matches!(
            FetchError::from(TransportError::Request("bad header".into())),
            FetchError::RequestError(msg) if msg == "bad header"
        ));
        assert!(matches!(
            FetchError::from(TransportError::Other("client build failed".into())),
            FetchError::RequestError(msg) if msg == "client build failed"
        ));
    }
}