use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use crate::{ApiClient, Request, Server, body::Body};
/// A [`tower::Service`] that executes [`Request`]s with a [`reqwest::Client`].
///
/// Cloning the service clones the wrapped client.
#[derive(Debug, Clone)]
pub struct ReqwestService {
    /// Client used (via clones) to execute every request passed to `call`.
    client: reqwest::Client,
}
impl ReqwestService {
pub fn new(client: reqwest::Client) -> Self {
Self { client }
}
pub fn client(&self) -> &reqwest::Client {
&self.client
}
}
impl Default for ReqwestService {
    /// Builds a service around a freshly constructed `reqwest::Client::new()`.
    ///
    /// See `reqwest::Client::new` for its failure conditions; reqwest documents that it
    /// panics if the client cannot be initialized.
    fn default() -> Self {
        let fresh = reqwest::Client::new();
        Self::new(fresh)
    }
}
impl From<reqwest::Client> for ReqwestService {
fn from(client: reqwest::Client) -> Self {
Self::new(client)
}
}
/// Errors returned by the [`tower::Service`] implementation of [`ReqwestService`].
// NOTE(review): both variants print the wrapped error in `Display` (`{0}`) AND expose it
// through `source()` (`#[source]`; `#[from]` implies `#[source]`). Reporters that walk the
// error chain will therefore show the inner message twice. Consider dropping `: {0}` from
// the messages once it is confirmed no caller depends on the current text.
#[derive(Debug, thiserror::Error)]
pub enum ReqwestError {
    /// Converting the incoming [`Request`] into a [`reqwest::Request`] failed
    /// (produced by `build_reqwest_request`).
    #[error("invalid request: {0}")]
    InvalidRequest(#[source] crate::BoxError),
    /// An error from reqwest; in this file it comes from `reqwest::Client::execute`.
    #[error("reqwest error: {0}")]
    Reqwest(#[from] reqwest::Error),
}
impl tower::Service<Request> for ReqwestService {
type Response = ::http::Response<reqwest::Body>;
type Error = ReqwestError;
type Future =
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: Request) -> Self::Future {
let client = self.client.clone();
Box::pin(async move {
let reqwest_req = build_reqwest_request(req)?;
let resp = client.execute(reqwest_req).await?;
Ok(into_http_response(resp))
})
}
}
/// Converts the crate's [`Request`] into a [`reqwest::Request`].
///
/// The request head is reused unchanged and the body is wrapped with
/// `reqwest::Body::wrap`.
///
/// # Errors
///
/// Returns [`ReqwestError::InvalidRequest`] when `reqwest::Request::try_from` rejects the
/// converted `http::Request`.
fn build_reqwest_request(req: Request) -> Result<reqwest::Request, ReqwestError> {
    let (head, payload) = req.into_parts();
    let wrapped = reqwest::Body::wrap(payload);
    let converted = reqwest::Request::try_from(::http::Request::from_parts(head, wrapped));
    converted.map_err(|err| ReqwestError::InvalidRequest(Box::new(err)))
}
/// Converts a [`reqwest::Response`] into an [`http::Response`] that streams the same body.
///
/// Status, HTTP version, headers and extensions are carried over. Headers and extensions
/// are *moved* out of `resp` rather than cloned, because `resp` is consumed right after to
/// produce the body.
///
/// Note: reqwest-specific data kept outside the `http` parts (such as the final URL) is not
/// carried over.
fn into_http_response(mut resp: reqwest::Response) -> ::http::Response<reqwest::Body> {
    let status = resp.status();
    let version = resp.version();
    // `resp` is consumed below, so take ownership instead of cloning the maps.
    let headers = std::mem::take(resp.headers_mut());
    let extensions = std::mem::take(resp.extensions_mut());

    // `Response::new` + setters is infallible, unlike `Response::builder()`, so no
    // `expect` is needed.
    let mut out = ::http::Response::new(reqwest::Body::from(resp));
    *out.status_mut() = status;
    *out.version_mut() = version;
    *out.headers_mut() = headers;
    *out.extensions_mut() = extensions;
    out
}
// Compile-time assertion that the crate's `Body` is `Send + Sync + 'static`.
// `build_reqwest_request` passes the request body to `reqwest::Body::wrap`, and the
// service's future is declared `Send + 'static`. NOTE(review): presumably `Body::wrap`
// requires exactly these bounds; confirm against the reqwest version in use.
// The closure is only coerced to a `fn()` and never called. Type-checking it is enough to
// enforce the bounds, and because `assert_send_sync` is used by the closure there is no
// dead-code warning.
const _: fn() = || {
    fn assert_send_sync<T: Send + Sync + 'static>() {}
    assert_send_sync::<Body>();
};
impl ApiClient<ReqwestService> {
    /// Creates an [`ApiClient`] backed by a [`ReqwestService`] wrapping `client`, configured
    /// with `server`.
    pub fn new_reqwest<Srv: Server>(client: reqwest::Client, server: Srv) -> Self {
        let service = ReqwestService::new(client);
        ApiClient::new(service, server)
    }
}