use std::task::{Context as TaskContext, Poll};
use bytes::Bytes;
use http::{Request, Response};
use tower::{Layer, Service};
use crate::BoxError;
use crate::BoxFuture;
/// A terminal [`tower::Service`] that sends `http::Request<Bytes>` values through a
/// [`reqwest::Client`] and hands back the response with its body fully buffered as `Bytes`.
#[derive(Clone)]
pub struct ReqwestService {
    /// Client used to build and execute every outgoing request.
    client: reqwest::Client,
}
impl ReqwestService {
pub fn new(client: reqwest::Client) -> Self {
Self { client }
}
}
// Manual `Debug` that does not print the wrapped `reqwest::Client`.
impl std::fmt::Debug for ReqwestService {
    /// Renders as `ReqwestService { .. }`.
    ///
    /// `finish_non_exhaustive` (rather than `finish`) makes it visible in logs that the
    /// `client` field exists but was intentionally left out of the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ReqwestService").finish_non_exhaustive()
    }
}
impl Service<Request<Bytes>> for ReqwestService {
    type Response = Response<Bytes>;
    type Error = BoxError;
    type Future = BoxFuture<Result<Self::Response, Self::Error>>;

    /// Always reports ready; this service imposes no capacity limit of its own.
    fn poll_ready(&mut self, _cx: &mut TaskContext<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Converts `req`, sends it with the stored client and buffers the reply.
    ///
    /// All work, including the request conversion, happens inside the returned future, so
    /// conversion, transport and body-read failures are all reported as the future's `Err`.
    fn call(&mut self, req: Request<Bytes>) -> Self::Future {
        let http_client = self.client.clone();
        Box::pin(async move {
            let outgoing = request_to_reqwest(&http_client, req)?;
            let incoming = http_client.execute(outgoing).await?;
            response_from_reqwest(incoming).await
        })
    }
}
/// A [`tower::Layer`] that produces a [`ReqwestService`] backed by a stored [`reqwest::Client`].
///
/// This is a terminal layer: the service it is applied to is discarded, so anything layered
/// beneath it never receives a request. Its derived `Debug` output includes the client.
#[derive(Clone, Debug)]
pub struct ReqwestLayer {
    /// Client cloned into every service this layer builds.
    client: reqwest::Client,
}
impl ReqwestLayer {
pub fn new(client: reqwest::Client) -> Self {
Self { client }
}
}
impl<S> Layer<S> for ReqwestLayer {
    type Service = ReqwestService;

    /// Drops the wrapped service and returns a fresh [`ReqwestService`] using this
    /// layer's client.
    fn layer(&self, _wrapped: S) -> Self::Service {
        let client = self.client.clone();
        ReqwestService::new(client)
    }
}
/// Converts an `http::Request<Bytes>` into a [`reqwest::Request`] built with `client`.
///
/// The following are carried over:
/// - the method and URI;
/// - the full header map, including repeated names and any `HeaderValue::is_sensitive` markers;
/// - the HTTP version;
/// - the body.
///
/// # Errors
///
/// Returns an error if the URI cannot be parsed as an absolute URL (e.g. a path-only URI such
/// as `/foo`), or if the reqwest request builder rejects the assembled request.
fn request_to_reqwest(
    client: &reqwest::Client,
    req: Request<Bytes>,
) -> Result<reqwest::Request, BoxError> {
    let (parts, body) = req.into_parts();
    // `Url::parse` needs an absolute URL, so relative URIs are rejected here.
    let url = reqwest::Url::parse(&parts.uri.to_string())?;
    let reqwest_req = client
        .request(parts.method, url)
        // Move the whole map in. Re-building each header from `as_str()`/`as_bytes()`
        // re-validates every entry and loses the caller's `set_sensitive` flags.
        .headers(parts.headers)
        // Preserve the requested protocol version instead of silently dropping it.
        .version(parts.version)
        .body(reqwest::Body::from(body))
        .build()?;
    Ok(reqwest_req)
}
/// Buffers a [`reqwest::Response`] into an `http::Response<Bytes>`.
///
/// The status, HTTP version and full header map are copied unchanged. The body is read to
/// completion. Metadata is captured first because `reqwest::Response::bytes` consumes the
/// response.
///
/// # Errors
///
/// Returns an error if reading the response body fails.
async fn response_from_reqwest(resp: reqwest::Response) -> Result<Response<Bytes>, BoxError> {
    let status = resp.status();
    let version = resp.version();
    // Clone the already-validated map rather than re-parsing each name/value pair.
    let headers = resp.headers().clone();
    let body = resp.bytes().await?;

    let mut response = Response::new(body);
    *response.status_mut() = status;
    *response.version_mut() = version;
    *response.headers_mut() = headers;
    Ok(response)
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{routing::get, Router};
    use bytes::Bytes;
    use http::{Request, StatusCode};
    use tower::ServiceExt;

    /// Starts an axum server on an ephemeral localhost port and returns its address.
    ///
    /// Routes:
    /// - `GET /` responds with `hello`.
    /// - `GET /headers` echoes the `x-test` request header, or `missing` if it is absent.
    ///
    /// No startup delay is needed: the listener is already bound when this returns, so the OS
    /// queues incoming connections until the spawned server task accepts them.
    async fn spawn_echo_server() -> std::net::SocketAddr {
        let app = Router::new()
            .route("/", get(|| async { "hello" }))
            .route(
                "/headers",
                get(|headers: axum::http::HeaderMap| async move {
                    headers
                        .get("x-test")
                        .and_then(|v| v.to_str().ok())
                        .unwrap_or("missing")
                        .to_string()
                }),
            );
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move { axum::serve(listener, app).await.unwrap() });
        addr
    }

    #[tokio::test]
    async fn test_reqwest_service_hits_local_server() {
        let addr = spawn_echo_server().await;
        let client = reqwest::Client::new();
        let mut service = ReqwestService::new(client);
        let req = Request::get(format!("http://{addr}/"))
            .body(Bytes::new())
            .unwrap();
        let resp = service.ready().await.unwrap().call(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(resp.into_body().as_ref(), b"hello");
    }

    #[tokio::test]
    async fn test_reqwest_service_preserves_headers() {
        let addr = spawn_echo_server().await;
        let client = reqwest::Client::new();
        let mut service = ReqwestService::new(client);
        let req = Request::get(format!("http://{addr}/headers"))
            .header("x-test", "value")
            .body(Bytes::new())
            .unwrap();
        let resp = service.ready().await.unwrap().call(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(resp.into_body().as_ref(), b"value");
    }

    #[tokio::test]
    async fn test_reqwest_service_rejects_relative_uri() {
        let mut service = ReqwestService::new(reqwest::Client::new());
        // A path-only URI is a valid `http::Uri` but not an absolute URL, so conversion fails
        // before any network activity.
        let req = Request::get("/relative").body(Bytes::new()).unwrap();
        let result = service.ready().await.unwrap().call(req).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_reqwest_layer() {
        let addr = spawn_echo_server().await;
        let client = reqwest::Client::new();
        let layer = ReqwestLayer::new(client);
        let mut service = layer.layer(tower::service_fn(|_req: Request<Bytes>| async {
            Ok::<_, BoxError>(Response::new(Bytes::from_static(b"inner")))
        }));
        let req = Request::get(format!("http://{addr}/"))
            .body(Bytes::new())
            .unwrap();
        let resp = service.ready().await.unwrap().call(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        // The wrapped service must be bypassed: the body comes from the real server.
        assert_eq!(resp.into_body().as_ref(), b"hello");
    }
}