use crate::{
body::{self, Bytes, HttpBody},
response::{IntoResponse, Response},
BoxError,
};
use http::Request;
use pin_project_lite::pin_project;
use std::{
any::type_name,
convert::Infallible,
fmt,
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tower::{util::BoxCloneService, ServiceBuilder};
use tower_http::ServiceBuilderExt;
use tower_layer::Layer;
use tower_service::Service;
/// Create a middleware layer from a function.
///
/// The returned [`FromFnLayer`] wraps a service so that every request is
/// first handed to `f` together with a [`Next`]. `f` decides what to do with
/// the request and may forward it to the wrapped service with [`Next::run`].
/// The output of the future returned by `f` must implement [`IntoResponse`].
pub fn from_fn<F>(f: F) -> FromFnLayer<F> {
    FromFnLayer { f }
}
/// A [`Layer`] that wraps services in [`FromFn`].
///
/// Created with [`from_fn`]. Each call to `layer` clones the stored function
/// into the new service.
#[derive(Clone, Copy)]
pub struct FromFnLayer<F> {
    f: F,
}
impl<S, F> Layer<S> for FromFnLayer<F>
where
    F: Clone,
{
    type Service = FromFn<F, S>;

    /// Wrap `inner` in a [`FromFn`] that owns its own copy of the function.
    fn layer(&self, inner: S) -> Self::Service {
        let f = self.f.clone();
        FromFn { f, inner }
    }
}
impl<F> fmt::Debug for FromFnLayer<F> {
    /// The function is usually a closure or fn item that is not itself
    /// `Debug`, so its type name is printed in its place.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let fn_name = type_name::<F>();
        let mut builder = f.debug_struct("FromFnLayer");
        builder.field("f", &format_args!("{}", fn_name));
        builder.finish()
    }
}
/// Middleware service created by [`from_fn`].
///
/// Every request is passed to `f` along with a [`Next`] that can forward it to
/// `inner`.
#[derive(Clone, Copy)]
pub struct FromFn<F, S> {
    f: F,
    inner: S,
}
impl<F, Fut, Out, S, ReqBody, ResBody> Service<Request<ReqBody>> for FromFn<F, S>
where
    F: FnMut(Request<ReqBody>, Next<ReqBody>) -> Fut,
    Fut: Future<Output = Out>,
    Out: IntoResponse,
    S: Service<Request<ReqBody>, Response = Response<ResBody>, Error = Infallible>
        + Clone
        + Send
        + 'static,
    S::Future: Send + 'static,
    ResBody: HttpBody<Data = Bytes> + Send + 'static,
    ResBody::Error: Into<BoxError>,
{
    type Response = Response;
    type Error = Infallible;
    type Future = ResponseFuture<Fut>;

    /// Readiness is delegated entirely to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Pass `req` to the middleware function together with a [`Next`] that
    /// owns the service instance `poll_ready` was just called on.
    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
        // Only the instance that was polled is known to be ready, so move it
        // out and leave a fresh clone behind for subsequent requests.
        let replacement = self.inner.clone();
        let ready = std::mem::replace(&mut self.inner, replacement);

        // Type-erase the service and box its response body so `Next` has a
        // single concrete type independent of `S` and `ResBody`.
        let erased = ServiceBuilder::new()
            .boxed_clone()
            .map_response_body(body::boxed)
            .service(ready);

        let fut = (self.f)(req, Next { inner: erased });
        ResponseFuture { inner: fut }
    }
}
impl<F, S> fmt::Debug for FromFn<F, S>
where
    S: fmt::Debug,
{
    /// Print the function's type name (closures are generally not `Debug`)
    /// and the wrapped service.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Use this type's own name; the previous "FromFnLayer" label was a
        // copy-paste error that made the service indistinguishable from the layer.
        f.debug_struct("FromFn")
            .field("f", &format_args!("{}", type_name::<F>()))
            .field("inner", &self.inner)
            .finish()
    }
}
/// The remainder of the middleware stack, passed to the function given to
/// [`from_fn`].
///
/// Holds a type-erased (boxed) clone-able service; call [`Next::run`] to
/// forward a request to it.
pub struct Next<ReqBody> {
    inner: BoxCloneService<Request<ReqBody>, Response, Infallible>,
}
impl<ReqBody> Next<ReqBody> {
    /// Forward `req` to the rest of the stack and return its response.
    ///
    /// `poll_ready` is not called here: in this module, `Next` is only built
    /// from the service instance `FromFn::poll_ready` already drove to ready.
    pub async fn run(mut self, req: Request<ReqBody>) -> Response {
        let outcome = self.inner.call(req).await;
        match outcome {
            Ok(response) => response,
            // The error type is `Infallible`, so this arm is statically unreachable.
            Err(never) => match never {},
        }
    }
}
impl<ReqBody> fmt::Debug for Next<ReqBody> {
    /// Print the wrapped (boxed) service under this type's own name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Previously labelled "FromFnLayer" by copy-paste; `Next` is a different type.
        f.debug_struct("Next")
            .field("inner", &self.inner)
            .finish()
    }
}
pin_project! {
    /// Response future for [`FromFn`].
    ///
    /// Polls the future returned by the middleware function and converts its
    /// output into a [`Response`] via [`IntoResponse`]; it never yields an error.
    pub struct ResponseFuture<F> {
        // Pinned structurally because `F` carries no `Unpin` bound and is
        // polled in place.
        #[pin]
        inner: F,
    }
}
impl<F, Out> Future for ResponseFuture<F>
where
    F: Future<Output = Out>,
    Out: IntoResponse,
{
    type Output = Result<Response, Infallible>;

    /// Drive the middleware's future; once it completes, convert its output
    /// into a `Response` and wrap it in `Ok`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        match this.inner.poll(cx) {
            Poll::Ready(out) => Poll::Ready(Ok(out.into_response())),
            Poll::Pending => Poll::Pending,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{body::Empty, routing::get, Router};
    use http::{HeaderMap, StatusCode};
    use tower::ServiceExt;

    #[tokio::test]
    async fn basic() {
        // Middleware under test: stamp a header on the request, then continue.
        async fn insert_header<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
            req.headers_mut().insert("x-axum-test", "ok".parse().unwrap());
            next.run(req).await
        }

        // Handler that echoes back the header the middleware inserted.
        async fn handle(headers: HeaderMap) -> String {
            headers["x-axum-test"].to_str().unwrap().to_owned()
        }

        let request = Request::builder()
            .uri("/")
            .body(body::boxed(Empty::new()))
            .unwrap();

        let app = Router::new()
            .route("/", get(handle))
            .layer(from_fn(insert_header));

        let response = app.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let bytes = hyper::body::to_bytes(response).await.unwrap();
        assert_eq!(&bytes[..], b"ok");
    }
}