use crate::body::{Body, Bytes, HttpBody};
use crate::response::{IntoResponse, Response};
use crate::BoxError;
use axum_core::extract::{FromRequest, FromRequestParts};
use futures_util::future::BoxFuture;
use http::Request;
use std::{
any::type_name,
convert::Infallible,
fmt,
future::Future,
marker::PhantomData,
pin::Pin,
task::{Context, Poll},
};
use tower_layer::Layer;
use tower_service::Service;
pub fn map_request<F, T>(f: F) -> MapRequestLayer<F, (), T> {
map_request_with_state((), f)
}
pub fn map_request_with_state<F, S, T>(state: S, f: F) -> MapRequestLayer<F, S, T> {
MapRequestLayer {
f,
state,
_extractor: PhantomData,
}
}
#[must_use]
pub struct MapRequestLayer<F, S, T> {
f: F,
state: S,
_extractor: PhantomData<fn() -> T>,
}
impl<F, S, T> Clone for MapRequestLayer<F, S, T>
where
F: Clone,
S: Clone,
{
fn clone(&self) -> Self {
Self {
f: self.f.clone(),
state: self.state.clone(),
_extractor: self._extractor,
}
}
}
impl<S, I, F, T> Layer<I> for MapRequestLayer<F, S, T>
where
F: Clone,
S: Clone,
{
type Service = MapRequest<F, S, I, T>;
fn layer(&self, inner: I) -> Self::Service {
MapRequest {
f: self.f.clone(),
state: self.state.clone(),
inner,
_extractor: PhantomData,
}
}
}
impl<F, S, T> fmt::Debug for MapRequestLayer<F, S, T>
where
S: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("MapRequestLayer")
.field("f", &format_args!("{}", type_name::<F>()))
.field("state", &self.state)
.finish()
}
}
pub struct MapRequest<F, S, I, T> {
f: F,
inner: I,
state: S,
_extractor: PhantomData<fn() -> T>,
}
impl<F, S, I, T> Clone for MapRequest<F, S, I, T>
where
F: Clone,
I: Clone,
S: Clone,
{
fn clone(&self) -> Self {
Self {
f: self.f.clone(),
inner: self.inner.clone(),
state: self.state.clone(),
_extractor: self._extractor,
}
}
}
macro_rules! impl_service {
(
[$($ty:ident),*], $last:ident
) => {
#[allow(non_snake_case, unused_mut)]
impl<F, Fut, S, I, B, $($ty,)* $last> Service<Request<B>> for MapRequest<F, S, I, ($($ty,)* $last,)>
where
F: FnMut($($ty,)* $last) -> Fut + Clone + Send + 'static,
$( $ty: FromRequestParts<S> + Send, )*
$last: FromRequest<S> + Send,
Fut: Future + Send + 'static,
Fut::Output: IntoMapRequestResult<B> + Send + 'static,
I: Service<Request<B>, Error = Infallible>
+ Clone
+ Send
+ 'static,
I::Response: IntoResponse,
I::Future: Send + 'static,
B: HttpBody<Data = Bytes> + Send + 'static,
B::Error: Into<BoxError>,
S: Clone + Send + Sync + 'static,
{
type Response = Response;
type Error = Infallible;
type Future = ResponseFuture;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: Request<B>) -> Self::Future {
let req = req.map(Body::new);
let not_ready_inner = self.inner.clone();
let mut ready_inner = std::mem::replace(&mut self.inner, not_ready_inner);
let mut f = self.f.clone();
let state = self.state.clone();
let (mut parts, body) = req.into_parts();
let future = Box::pin(async move {
$(
let $ty = match $ty::from_request_parts(&mut parts, &state).await {
Ok(value) => value,
Err(rejection) => return rejection.into_response(),
};
)*
let req = Request::from_parts(parts, body);
let $last = match $last::from_request(req, &state).await {
Ok(value) => value,
Err(rejection) => return rejection.into_response(),
};
match f($($ty,)* $last).await.into_map_request_result() {
Ok(req) => {
ready_inner.call(req).await.into_response()
}
Err(res) => {
res
}
}
});
ResponseFuture {
inner: future
}
}
}
};
}
all_the_tuples!(impl_service);
impl<F, S, I, T> fmt::Debug for MapRequest<F, S, I, T>
where
S: fmt::Debug,
I: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("MapRequest")
.field("f", &format_args!("{}", type_name::<F>()))
.field("inner", &self.inner)
.field("state", &self.state)
.finish()
}
}
pub struct ResponseFuture {
inner: BoxFuture<'static, Response>,
}
impl Future for ResponseFuture {
type Output = Result<Response, Infallible>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.inner.as_mut().poll(cx).map(Ok)
}
}
impl fmt::Debug for ResponseFuture {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ResponseFuture").finish()
}
}
mod private {
use crate::{http::Request, response::IntoResponse};
pub trait Sealed<B> {}
impl<B, E> Sealed<B> for Result<Request<B>, E> where E: IntoResponse {}
impl<B> Sealed<B> for Request<B> {}
}
pub trait IntoMapRequestResult<B>: private::Sealed<B> {
#[allow(clippy::result_large_err)]
fn into_map_request_result(self) -> Result<Request<B>, Response>;
}
impl<B, E> IntoMapRequestResult<B> for Result<Request<B>, E>
where
E: IntoResponse,
{
fn into_map_request_result(self) -> Result<Request<B>, Response> {
self.map_err(IntoResponse::into_response)
}
}
impl<B> IntoMapRequestResult<B> for Request<B> {
fn into_map_request_result(self) -> Result<Request<B>, Response> {
Ok(self)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{routing::get, test_helpers::TestClient, Router};
use http::{HeaderMap, StatusCode};
#[crate::test]
async fn works() {
async fn add_header<B>(mut req: Request<B>) -> Request<B> {
req.headers_mut().insert("x-foo", "foo".parse().unwrap());
req
}
async fn handler(headers: HeaderMap) -> Response {
headers["x-foo"]
.to_str()
.unwrap()
.to_owned()
.into_response()
}
let app = Router::new()
.route("/", get(handler))
.layer(map_request(add_header));
let client = TestClient::new(app);
let res = client.get("/").await;
assert_eq!(res.text().await, "foo");
}
#[crate::test]
async fn works_for_short_circutting() {
async fn add_header<B>(_req: Request<B>) -> Result<Request<B>, (StatusCode, &'static str)> {
Err((StatusCode::INTERNAL_SERVER_ERROR, "something went wrong"))
}
async fn handler(_headers: HeaderMap) -> Response {
unreachable!()
}
let app = Router::new()
.route("/", get(handler))
.layer(map_request(add_header));
let client = TestClient::new(app);
let res = client.get("/").await;
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(res.text().await, "something went wrong");
}
}