use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use async_trait::async_trait;
use bytes::Bytes;
use http::Response as HttpResponse;
use http_body::Body;
use tower::{Service, ServiceExt};
use crate::core::next::Next;
use crate::core::req_body::ReqBody;
use crate::core::res_body::ResBody;
use crate::error::BoxedError;
use crate::{Handler, MiddleWareHandler, Request, Response, SilentError, StatusCode};
/// Silent-specific request data that a plain `http::Request` has no field for.
///
/// Inserted into the `http::Request` extensions by `into_http_request` and
/// removed again by `from_http_request`, so it survives the trip through a
/// tower layer.
#[derive(Clone)]
struct SilentExtras {
    /// Path parameters of the original Silent request.
    path_params: std::collections::HashMap<String, crate::core::path_param::PathParam>,
    /// The original request's state (as returned by `Request::state`).
    state: crate::State,
}
/// Wraps a `tower::Layer` so it can be registered as a Silent middleware.
#[doc(hidden)]
pub struct TowerLayerAdapter<L> {
    /// The tower layer applied around the downstream Silent handlers.
    layer: L,
}

impl<L> TowerLayerAdapter<L> {
    /// Creates an adapter that owns `layer`.
    #[doc(hidden)]
    pub fn new(layer: L) -> Self {
        TowerLayerAdapter { layer }
    }
}
/// Tower `Service` that hands a request back to the remaining Silent handler
/// chain.
///
/// It is the inner service handed to a wrapped tower layer. The incoming
/// `http::Request` is converted back into a Silent [`Request`] and run through
/// [`Next`]. The resulting [`Response`] is converted into an `http::Response`.
#[derive(Clone)]
#[doc(hidden)]
pub struct NextServicePublic {
    /// Remaining Silent handlers to run.
    pub(crate) next: Next,
}

impl Service<http::Request<ReqBody>> for NextServicePublic {
    type Response = HttpResponse<ResBody>;
    type Error = BoxedError;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Unconditionally reports readiness.
    fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Runs the downstream Silent chain for `http_req`.
    ///
    /// The conversion happens lazily inside the returned future. An error from
    /// the chain is boxed into a [`BoxedError`].
    fn call(&mut self, http_req: http::Request<ReqBody>) -> Self::Future {
        let downstream = self.next.clone();
        let fut = async move {
            match downstream.call(from_http_request(http_req)).await {
                Ok(silent_res) => Ok(into_http_response(silent_res)),
                Err(err) => Err(Box::new(err) as BoxedError),
            }
        };
        Box::pin(fut)
    }
}
#[async_trait]
impl<L> MiddleWareHandler for TowerLayerAdapter<L>
where
L: tower::Layer<NextServicePublic> + Clone + Send + Sync + 'static,
L::Service: Service<http::Request<ReqBody>> + Clone + Send + 'static,
<L::Service as Service<http::Request<ReqBody>>>::Response: IntoSilentResponse + Send,
<L::Service as Service<http::Request<ReqBody>>>::Error: Into<BoxedError> + Send,
<L::Service as Service<http::Request<ReqBody>>>::Future: Send,
{
async fn handle(&self, req: Request, next: &Next) -> crate::Result<Response> {
let next_svc = NextServicePublic { next: next.clone() };
let svc = self.layer.clone().layer(next_svc);
let http_req = into_http_request(req);
let tower_res = svc.oneshot(http_req).await.map_err(|e| {
let err: BoxedError = e.into();
SilentError::business_error(StatusCode::INTERNAL_SERVER_ERROR, err.to_string())
})?;
Ok(tower_res.into_silent_response())
}
}
/// Converts a Silent [`Request`] into an `http::Request` for the tower stack.
///
/// The request's state and path parameters are captured before the
/// conversion. They are attached as a `SilentExtras` extension so that
/// `from_http_request` can restore them on the other side.
fn into_http_request(req: Request) -> http::Request<ReqBody> {
    let extras = SilentExtras {
        state: req.state(),
        path_params: req.path_params().clone(),
    };
    let mut converted = req.into_http();
    converted.extensions_mut().insert(extras);
    converted
}
fn from_http_request(mut req: http::Request<ReqBody>) -> Request {
let extras = req.extensions_mut().remove::<SilentExtras>();
let (parts, body) = req.into_parts();
let mut silent_req = Request::from_parts(parts, body);
if let Some(extras) = extras {
*silent_req.state_mut() = extras.state;
for (key, value) in extras.path_params {
silent_req.set_path_params(key, value);
}
}
silent_req
}
/// Converts a Silent [`Response`] into an `http::Response`.
///
/// Only the status, version, headers and body are transferred. The headers
/// are moved out with `std::mem::take` rather than cloned.
fn into_http_response(mut res: Response) -> HttpResponse<ResBody> {
    let body = res.take_body();
    // Construct the response directly instead of via `http::response::Builder`.
    // Every part is already a typed value, so there is no fallible conversion
    // and no `unwrap()` panic path.
    let mut http_res = HttpResponse::new(body);
    *http_res.status_mut() = res.status;
    *http_res.version_mut() = res.version;
    *http_res.headers_mut() = std::mem::take(&mut res.headers);
    http_res
}
/// Conversion from a tower/`http` response back into a Silent [`Response`].
///
/// `TowerLayerAdapter` uses this to turn the layered service's response into
/// the middleware's result.
#[doc(hidden)]
pub trait IntoSilentResponse {
    /// Consumes `self` and produces the equivalent Silent [`Response`].
    fn into_silent_response(self) -> Response;
}

/// Any `http::Response` whose body yields `Bytes` can be converted.
///
/// The body is boxed into `ResBody::Boxed`, with its error type normalised to
/// [`BoxedError`]. Status, version and headers are copied across. The
/// response extensions are not carried over.
impl<B> IntoSilentResponse for HttpResponse<B>
where
    B: Body<Data = Bytes> + Send + 'static,
    B::Error: Into<BoxedError> + 'static,
{
    fn into_silent_response(self) -> Response {
        use http_body_util::BodyExt;

        let (head, raw_body) = self.into_parts();
        let boxed_body =
            ResBody::Boxed(Box::pin(raw_body.map_err(|err| -> BoxedError { err.into() })));

        let mut silent = Response::empty();
        silent.set_status(head.status);
        silent.version = head.version;
        silent.headers = head.headers;
        silent.set_body(boxed_body);
        silent
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::Handler;
use crate::route::Route;
/// Minimal tower layer for the tests: wraps the inner service in an
/// `AddHeaderService` that sets one fixed header.
#[derive(Clone)]
struct AddHeaderLayer {
name: &'static str,
value: &'static str,
}
impl<S: Clone> tower::Layer<S> for AddHeaderLayer {
type Service = AddHeaderService<S>;
fn layer(&self, inner: S) -> Self::Service {
AddHeaderService {
inner,
name: self.name,
value: self.value,
}
}
}
/// Calls the inner service and, only when it succeeds, inserts the
/// `name: value` header into the response. Errors from the inner service
/// are propagated via `?` without adding the header.
#[derive(Clone)]
struct AddHeaderService<S> {
inner: S,
name: &'static str,
value: &'static str,
}
impl<S> Service<http::Request<ReqBody>> for AddHeaderService<S>
where
S: Service<http::Request<ReqBody>, Response = HttpResponse<ResBody>, Error = BoxedError>
+ Clone
+ Send
+ 'static,
S::Future: Send,
{
type Response = HttpResponse<ResBody>;
type Error = BoxedError;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
// Clone the inner service so the returned future owns it ('static).
let mut inner = self.inner.clone();
let name = self.name;
let value = self.value;
Box::pin(async move {
let mut res = inner.call(req).await?;
res.headers_mut()
.insert(name, http::HeaderValue::from_static(value));
Ok(res)
})
}
}
/// A header set by the tower layer must appear on the final Silent response.
#[tokio::test]
async fn test_tower_layer_adds_header() {
let layer = AddHeaderLayer {
name: "x-tower-test",
value: "hello",
};
let route = Route::new("")
.hook_layer(layer)
.get(|_req: Request| async { Ok("ok") });
let route = Route::new_root().append(route);
let req = Request::empty();
let res = route.call(req).await.unwrap();
assert_eq!(
res.headers().get("x-tower-test").unwrap().to_str().unwrap(),
"hello"
);
}
/// The handler reads the route state with `get_state` and `?`. If the state
/// were lost in the Silent -> http -> Silent round trip, the handler would
/// fail. `AddHeaderService` would then not add `x-check`, or `unwrap()` would
/// panic, depending on how `Route` reports handler errors.
/// NOTE(review): the body ("state=42") itself is not asserted.
#[tokio::test]
async fn test_tower_layer_preserves_state() {
let layer = AddHeaderLayer {
name: "x-check",
value: "passed",
};
let route =
Route::new("")
.with_state(42i32)
.hook_layer(layer)
.get(|req: Request| async move {
let num = req.get_state::<i32>()?;
Ok(format!("state={}", num))
});
let route = Route::new_root().append(route);
let req = Request::empty();
let res = route.call(req).await.unwrap();
assert!(res.headers().get("x-check").is_some());
}
/// Two stacked tower layers must both contribute their header. The relative
/// order in which the layers run is not asserted.
#[tokio::test]
async fn test_tower_layer_chain() {
let layer1 = AddHeaderLayer {
name: "x-first",
value: "1",
};
let layer2 = AddHeaderLayer {
name: "x-second",
value: "2",
};
let route = Route::new("")
.hook_layer(layer1)
.hook_layer(layer2)
.get(|_req: Request| async { Ok("chained") });
let route = Route::new_root().append(route);
let req = Request::empty();
let res = route.call(req).await.unwrap();
assert_eq!(res.headers().get("x-first").unwrap().to_str().unwrap(), "1");
assert_eq!(
res.headers().get("x-second").unwrap().to_str().unwrap(),
"2"
);
}
}