#![doc = include_str!("../docs/handlers_intro.md")]
#![doc = include_str!("../docs/debugging_handler_type_errors.md")]
#[cfg(feature = "tokio")]
use crate::extract::connect_info::IntoMakeServiceWithConnectInfo;
use crate::{
body::Body,
extract::{FromRequest, FromRequestParts},
response::{IntoResponse, Response},
routing::IntoMakeService,
};
use http::Request;
use std::{convert::Infallible, fmt, future::Future, marker::PhantomData, pin::Pin};
use tower::ServiceExt;
use tower_layer::Layer;
use tower_service::Service;
pub mod future;
mod service;
pub use self::service::HandlerService;
/// Trait for async functions that can be used to handle requests.
///
/// This module implements it for functions/closures whose arguments are
/// extractors and whose output implements [`IntoResponse`], and for
/// [`Layered`] handlers returned by [`Handler::layer`].
///
/// - `T` is a marker type that keeps the blanket implementations apart.
/// - `S` is the state handed to [`Handler::call`].
/// - `B` is the request body type (defaults to [`Body`]).
#[doc = include_str!("../docs/debugging_handler_type_errors.md")]
#[cfg_attr(
    nightly_error_messages,
    rustc_on_unimplemented(
        note = "Consider using `#[axum::debug_handler]` to improve the error message"
    )
)]
pub trait Handler<T, S, B = Body>: Clone + Send + Sized + 'static {
    /// The future returned by [`Handler::call`]; it always resolves to a [`Response`].
    type Future: Future<Output = Response> + Send + 'static;

    /// Call the handler with the given request and state.
    fn call(self, req: Request<B>, state: S) -> Self::Future;

    /// Wrap this handler with a middleware `layer`.
    ///
    /// The layer is applied to the handler's [`HandlerService`]; the returned
    /// [`Layered`] value is itself a handler that accepts `Request<NewReqBody>`.
    fn layer<L, NewReqBody>(self, layer: L) -> Layered<L, Self, T, S, B, NewReqBody>
    where
        L: Layer<HandlerService<Self, T, S, B>> + Clone,
        L::Service: Service<Request<NewReqBody>>,
    {
        // Only the layer and the handler are stored here; the layered
        // service is built when the returned handler is called.
        let handler = self;
        Layered {
            handler,
            layer,
            _marker: PhantomData,
        }
    }

    /// Bundle the handler with `state`, producing a [`HandlerService`].
    fn with_state(self, state: S) -> HandlerService<Self, T, S, B> {
        let handler = self;
        HandlerService::new(handler, state)
    }
}
/// `Handler` for async functions that take no arguments.
///
/// The request and the state are dropped without being inspected; the
/// function's future is awaited and its output converted with
/// [`IntoResponse::into_response`].
impl<F, Fut, Res, S, B> Handler<((),), S, B> for F
where
    F: FnOnce() -> Fut + Clone + Send + 'static,
    Fut: Future<Output = Res> + Send,
    Res: IntoResponse,
    B: Send + 'static,
{
    type Future = Pin<Box<dyn Future<Output = Response> + Send>>;

    fn call(self, _request: Request<B>, _state: S) -> Self::Future {
        let handler = self;
        Box::pin(async move {
            let output = handler().await;
            output.into_response()
        })
    }
}
// Generates a `Handler` impl for async functions taking one or more arguments.
//
// `[$($ty),*]` are the leading argument types; each is extracted from the
// request head with `FromRequestParts`. `$last` is the final argument; it is
// extracted with `FromRequest` and is the only extractor that receives the
// request body. The extra `M` type parameter is forwarded to
// `FromRequest<S, B, M>` and appears first in the `T` marker tuple.
macro_rules! impl_handler {
(
[$($ty:ident),*], $last:ident
) => {
// `non_snake_case`: the type-parameter idents double as local variable
// names below. `unused_mut`: with no leading extractors, `parts` is never
// mutated.
#[allow(non_snake_case, unused_mut)]
impl<F, Fut, S, B, Res, M, $($ty,)* $last> Handler<(M, $($ty,)* $last,), S, B> for F
where
F: FnOnce($($ty,)* $last,) -> Fut + Clone + Send + 'static,
Fut: Future<Output = Res> + Send,
B: Send + 'static,
S: Send + Sync + 'static,
Res: IntoResponse,
$( $ty: FromRequestParts<S> + Send, )*
$last: FromRequest<S, B, M> + Send,
{
type Future = Pin<Box<dyn Future<Output = Response> + Send>>;
fn call(self, req: Request<B>, state: S) -> Self::Future {
Box::pin(async move {
// Split off the body so the leading extractors can borrow the
// head (`parts`) mutably.
let (mut parts, body) = req.into_parts();
let state = &state;
// Run the `FromRequestParts` extractors in argument order; the
// first rejection becomes the response and the handler is not called.
$(
let $ty = match $ty::from_request_parts(&mut parts, state).await {
Ok(value) => value,
Err(rejection) => return rejection.into_response(),
};
)*
// Reassemble the request (including any changes the extractors
// made to `parts`) and pass it, body included, to the last extractor.
let req = Request::from_parts(parts, body);
let $last = match $last::from_request(req, state).await {
Ok(value) => value,
Err(rejection) => return rejection.into_response(),
};
let res = self($($ty,)* $last,).await;
res.into_response()
})
}
}
};
}
// Expand `impl_handler` once per argument list produced by `all_the_tuples!`
// (defined elsewhere in the crate).
all_the_tuples!(impl_handler);
/// A [`Handler`] built from another handler plus a middleware layer.
///
/// Returned by [`Handler::layer`]. `L` is the layer and `H` the inner handler.
/// `T`/`S` are the inner handler's marker and state types. `B` is the body
/// type the inner handler accepts, and `B2` is the body type the layered
/// handler accepts.
pub struct Layered<L, H, T, S, B, B2> {
layer: L,
handler: H,
// `fn() -> (..)` is `Send + Sync` whatever its type parameters are, so
// `Layered`'s auto traits depend only on `L` and `H`, not on `T`, `S`, `B`, `B2`.
_marker: PhantomData<fn() -> (T, S, B, B2)>,
}
/// Debug output shows only the `layer` field; the handler is omitted, so `H`
/// does not need to implement [`fmt::Debug`].
impl<L, H, T, S, B, B2> fmt::Debug for Layered<L, H, T, S, B, B2>
where
    L: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { layer, .. } = self;
        let mut builder = f.debug_struct("Layered");
        builder.field("layer", layer);
        builder.finish()
    }
}
/// Clones the layer and the wrapped handler.
///
/// Written by hand rather than derived so that only `L` and `H` must be
/// [`Clone`]: the other type parameters only appear in the `_marker`
/// `PhantomData`, which is recreated.
impl<L, H, T, S, B, B2> Clone for Layered<L, H, T, S, B, B2>
where
    L: Clone,
    H: Clone,
{
    fn clone(&self) -> Self {
        let Self { layer, handler, .. } = self;
        Self {
            layer: L::clone(layer),
            handler: H::clone(handler),
            _marker: PhantomData,
        }
    }
}
/// Calling a [`Layered`] handler first turns the inner handler into a service
/// holding `state`. It then wraps that service with the layer and sends the
/// request through the result with `oneshot`.
///
/// The layered service must be infallible (`Error = Infallible`); its
/// response is converted with [`IntoResponse::into_response`].
impl<H, S, T, L, B, B2> Handler<T, S, B2> for Layered<L, H, T, S, B, B2>
where
L: Layer<HandlerService<H, T, S, B>> + Clone + Send + 'static,
H: Handler<T, S, B>,
L::Service: Service<Request<B2>, Error = Infallible> + Clone + Send + 'static,
<L::Service as Service<Request<B2>>>::Response: IntoResponse,
<L::Service as Service<Request<B2>>>::Future: Send,
T: 'static,
S: 'static,
B: Send + 'static,
B2: Send + 'static,
{
type Future = future::LayeredFuture<B2, L::Service>;
fn call(self, req: Request<B2>, state: S) -> Self::Future {
use futures_util::future::{FutureExt, Map};
// A fresh service is built on every call: the state goes into a
// `HandlerService`, which the layer then wraps.
let svc = self.handler.with_state(state);
let svc = self.layer.layer(svc);
// The explicit `Map<_, fn(..) -> _>` annotation coerces the
// non-capturing closure below into a function pointer, so the future's
// type can be written out. Presumably `future::LayeredFuture` (not shown
// here) stores exactly this type. NOTE(review): confirm.
let future: Map<
_,
fn(
Result<
<L::Service as Service<Request<B2>>>::Response,
<L::Service as Service<Request<B2>>>::Error,
>,
) -> _,
> = svc.oneshot(req).map(|result| match result {
Ok(res) => res.into_response(),
// `Error = Infallible`, so this arm cannot run; the empty match
// proves that to the compiler.
Err(err) => match err {},
});
future::LayeredFuture::new(future)
}
}
/// Extension methods for handlers whose state type is `()`.
///
/// A blanket implementation covers every `H: Handler<T, (), B>`.
pub trait HandlerWithoutStateExt<T, B>: Handler<T, (), B> {
/// Convert the handler into a [`HandlerService`] using `()` as the state.
fn into_service(self) -> HandlerService<Self, T, (), B>;
/// Convert the handler into an [`IntoMakeService`] around its [`HandlerService`].
fn into_make_service(self) -> IntoMakeService<HandlerService<Self, T, (), B>>;
/// Convert the handler into an `IntoMakeServiceWithConnectInfo` around its
/// [`HandlerService`], with connection info type `C`.
///
/// Only available with the `tokio` feature.
#[cfg(feature = "tokio")]
fn into_make_service_with_connect_info<C>(
self,
) -> IntoMakeServiceWithConnectInfo<HandlerService<Self, T, (), B>, C>;
}
/// Blanket implementation for every handler whose state type is `()`.
impl<H, T, B> HandlerWithoutStateExt<T, B> for H
where
    H: Handler<T, (), B>,
{
    fn into_service(self) -> HandlerService<Self, T, (), B> {
        // The unit state is supplied here so callers never have to pass it.
        <Self as Handler<T, (), B>>::with_state(self, ())
    }

    fn into_make_service(self) -> IntoMakeService<HandlerService<Self, T, (), B>> {
        let service: HandlerService<Self, T, (), B> = self.into_service();
        service.into_make_service()
    }

    #[cfg(feature = "tokio")]
    fn into_make_service_with_connect_info<C>(
        self,
    ) -> IntoMakeServiceWithConnectInfo<HandlerService<Self, T, (), B>, C> {
        let service: HandlerService<Self, T, (), B> = self.into_service();
        service.into_make_service_with_connect_info()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{body, extract::State, test_helpers::*};
    use http::StatusCode;
    use std::time::Duration;
    use tower_http::{
        compression::CompressionLayer, limit::RequestBodyLimitLayer,
        map_request_body::MapRequestBodyLayer, map_response_body::MapResponseBodyLayer,
        timeout::TimeoutLayer,
    };

    /// A handler converted with `into_service` receives the request body.
    #[crate::test]
    async fn handler_into_service() {
        async fn handle(body: String) -> impl IntoResponse {
            format!("you said: {body}")
        }

        let client = TestClient::new(handle.into_service());

        let res = client.post("/").body("hi there!").send().await;
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(res.text().await, "you said: hi there!");
    }

    /// Stacked layers, including ones that change the request/response body
    /// types, still let the handler read its state, and the request succeeds.
    #[crate::test]
    async fn with_layer_that_changes_request_body_and_state() {
        async fn handle(State(state): State<&'static str>) -> &'static str {
            state
        }

        let svc = handle
            .layer((
                RequestBodyLimitLayer::new(1024),
                TimeoutLayer::new(Duration::from_secs(10)),
                MapResponseBodyLayer::new(body::boxed),
                CompressionLayer::new(),
            ))
            .layer(MapRequestBodyLayer::new(body::boxed))
            .with_state("foo");

        let client = TestClient::new(svc);

        let res = client.get("/").send().await;
        // Check the status as well as the body, matching `handler_into_service`.
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(res.text().await, "foo");
    }
}