use std::task::{Context, Poll};
use bytes::Bytes;
use futures_core::future::BoxFuture;
use futures_util::TryFutureExt;
use http_body_util::BodyExt;
use http_body_util::combinators::BoxBody;
use tower::Service;
use tower_sessions::{MemoryStore, SessionManagerLayer};
use crate::error::ErrorRepr;
use crate::project::MiddlewareContext;
use crate::request::Request;
use crate::response::Response;
use crate::{Body, Error};
/// Tower layer that adapts the wrapped service's response into a Cot
/// [`Response`].
///
/// The produced service is [`IntoCotResponse`]; it re-wraps whatever
/// `http_body::Body` the inner service returns as a Cot [`Body`].
#[derive(Debug, Copy, Clone, Default)]
pub struct IntoCotResponseLayer;

impl IntoCotResponseLayer {
    /// Creates a new [`IntoCotResponseLayer`].
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}

impl<S> tower::Layer<S> for IntoCotResponseLayer {
    type Service = IntoCotResponse<S>;

    fn layer(&self, inner: S) -> Self::Service {
        IntoCotResponse { inner }
    }
}
/// Service produced by [`IntoCotResponseLayer`].
///
/// Requests are forwarded to the wrapped service untouched. On success, the
/// response body is boxed and wrapped as a Cot [`Body`]. Errors returned by the
/// inner service are passed through unchanged (`Self::Error == S::Error`).
#[derive(Debug, Clone)]
pub struct IntoCotResponse<S> {
    inner: S,
}

impl<S, ResBody, E> Service<Request> for IntoCotResponse<S>
where
    S: Service<Request, Response = http::Response<ResBody>>,
    ResBody: http_body::Body<Data = Bytes, Error = E> + Send + Sync + 'static,
    E: std::error::Error + Send + Sync + 'static,
{
    type Response = Response;
    type Error = S::Error;
    type Future = futures_util::future::MapOk<S::Future, fn(http::Response<ResBody>) -> Response>;

    #[inline]
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Readiness is decided solely by the wrapped service.
        self.inner.poll_ready(cx)
    }

    #[inline]
    fn call(&mut self, request: Request) -> Self::Future {
        // Coerce the generic conversion into a plain function pointer so the
        // resulting future is exactly `Self::Future`.
        let convert: fn(http::Response<ResBody>) -> Response = map_response;
        let pending = self.inner.call(request);
        pending.map_ok(convert)
    }
}
/// Converts an `http` response with an arbitrary body into a Cot [`Response`].
///
/// The response head (status, headers, extensions, ...) is kept as is. The body
/// is boxed into a `BoxBody`, with its errors converted through `map_err`, and
/// wrapped with `Body::wrapper`.
fn map_response<ResBody, E>(response: http::response::Response<ResBody>) -> Response
where
    ResBody: http_body::Body<Data = Bytes, Error = E> + Send + Sync + 'static,
    E: std::error::Error + Send + Sync + 'static,
{
    let (parts, body) = response.into_parts();
    let boxed = BoxBody::new(body.map_err(map_err));
    http::Response::from_parts(parts, Body::wrapper(boxed))
}
/// Tower layer that converts the wrapped service's error type into Cot's
/// [`Error`].
///
/// The produced service is [`IntoCotError`].
#[derive(Debug, Copy, Clone, Default)]
pub struct IntoCotErrorLayer;

impl IntoCotErrorLayer {
    /// Creates a new [`IntoCotErrorLayer`].
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}

impl<S> tower::Layer<S> for IntoCotErrorLayer {
    type Service = IntoCotError<S>;

    fn layer(&self, inner: S) -> Self::Service {
        IntoCotError { inner }
    }
}
#[derive(Debug, Clone)]
pub struct IntoCotError<S> {
inner: S,
}
impl<S> Service<Request> for IntoCotError<S>
where
S: Service<Request>,
<S as Service<Request>>::Error: std::error::Error + Send + Sync + 'static,
{
type Response = S::Response;
type Error = Error;
type Future = futures_util::future::MapErr<S::Future, fn(S::Error) -> Error>;
#[inline]
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx).map_err(map_err)
}
#[inline]
fn call(&mut self, request: Request) -> Self::Future {
self.inner.call(request).map_err(map_err)
}
}
fn map_err<E>(error: E) -> Error
where
E: std::error::Error + Send + Sync + 'static,
{
Error::new(ErrorRepr::MiddlewareWrapped {
source: Box::new(error),
})
}
/// Middleware that provides a Cot [`crate::session::Session`] to every request.
///
/// Sessions are managed by `tower-sessions` and stored in a [`MemoryStore`].
/// Internally, a [`SessionManagerLayer`] is stacked on top of a
/// [`SessionWrapperLayer`], which swaps the raw `tower_sessions::Session`
/// extension for Cot's wrapper type.
#[derive(Debug, Clone)]
pub struct SessionMiddleware {
    inner: SessionManagerLayer<MemoryStore>,
}

impl SessionMiddleware {
    /// Creates a session middleware backed by a new [`MemoryStore`], with the
    /// default `SessionManagerLayer` cookie settings.
    #[must_use]
    pub fn new() -> Self {
        Self {
            inner: SessionManagerLayer::new(MemoryStore::default()),
        }
    }

    /// Creates a session middleware configured from the project settings.
    ///
    /// Only `middlewares.session.secure` is read from the configuration.
    #[must_use]
    pub fn from_context(context: &MiddlewareContext) -> Self {
        let middleware = Self::new();
        middleware.secure(context.config().middlewares.session.secure)
    }

    /// Sets whether the session cookie carries the `Secure` attribute.
    #[must_use]
    pub fn secure(self, secure: bool) -> Self {
        let inner = self.inner.with_secure(secure);
        Self { inner }
    }
}

impl Default for SessionMiddleware {
    fn default() -> Self {
        Self::new()
    }
}

impl<S> tower::Layer<S> for SessionMiddleware {
    type Service = <SessionManagerLayer<MemoryStore> as tower::Layer<
        <SessionWrapperLayer as tower::Layer<S>>::Service,
    >>::Service;

    fn layer(&self, inner: S) -> Self::Service {
        // The session wrapper sits directly around `inner`; the session
        // manager wraps that, so the session extension already exists when
        // the wrapper runs.
        let wrapped = tower::Layer::layer(&SessionWrapperLayer::new(), inner);
        tower::Layer::layer(&self.inner, wrapped)
    }
}
/// Tower layer producing [`SessionWrapper`] services.
///
/// It is meant to sit directly inside a `tower_sessions` session manager, as
/// done by [`SessionMiddleware`].
#[derive(Debug, Copy, Clone, Default)]
pub struct SessionWrapperLayer;

impl SessionWrapperLayer {
    /// Creates a new [`SessionWrapperLayer`].
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}

impl<S> tower::Layer<S> for SessionWrapperLayer {
    type Service = SessionWrapper<S>;

    fn layer(&self, inner: S) -> Self::Service {
        SessionWrapper { inner }
    }
}
/// Service produced by [`SessionWrapperLayer`].
///
/// Before forwarding a request, it removes the raw `tower_sessions::Session`
/// extension and inserts Cot's [`crate::session::Session`] wrapper in its place.
#[derive(Debug, Clone)]
pub struct SessionWrapper<S> {
    inner: S,
}

// Only the request/response shape is required here. The `Clone`, `Send`,
// `'static` and `Default` bounds that `tower_sessions::SessionManager` needs are
// enforced by the session manager itself, not by this wrapper.
impl<ReqBody, ResBody, S> Service<http::Request<ReqBody>> for SessionWrapper<S>
where
    S: Service<http::Request<ReqBody>, Response = http::Response<ResBody>>,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = S::Future;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Replaces the `tower_sessions::Session` extension with Cot's wrapper and
    /// calls the inner service.
    ///
    /// # Panics
    ///
    /// Panics if the request has no `tower_sessions::Session` extension, i.e.
    /// when this service is not wrapped by a session manager layer.
    fn call(&mut self, mut req: http::Request<ReqBody>) -> Self::Future {
        let session = req
            .extensions_mut()
            .remove::<tower_sessions::Session>()
            .expect("session extension must be present");
        let session_wrapped = crate::session::Session::new(session);
        req.extensions_mut().insert(session_wrapped);
        self.inner.call(req)
    }
}
/// Middleware that builds a [`crate::auth::Auth`] for every request and stores
/// it in the request extensions.
///
/// Requests must already carry a session. Without one, the call panics (see the
/// `auth_middleware_requires_session` test), so [`SessionMiddleware`] has to be
/// layered outside of this middleware.
#[derive(Debug, Copy, Clone, Default)]
pub struct AuthMiddleware;

impl AuthMiddleware {
    /// Creates a new [`AuthMiddleware`].
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}

impl<S> tower::Layer<S> for AuthMiddleware {
    type Service = AuthService<S>;

    fn layer(&self, inner: S) -> Self::Service {
        AuthService::new(inner)
    }
}
/// Service produced by [`AuthMiddleware`].
///
/// Each request gets a [`crate::auth::Auth`] built from it, which is inserted
/// into the request extensions before the inner service is called.
#[derive(Debug, Clone)]
pub struct AuthService<S> {
    inner: S,
}

impl<S> AuthService<S> {
    /// Wraps `inner` in an [`AuthService`].
    fn new(inner: S) -> Self {
        Self { inner }
    }
}

impl<S> Service<Request> for AuthService<S>
where
    S: Service<Request, Response = Response, Error = Error> + Clone + Send + 'static,
    S::Future: Send,
{
    type Response = S::Response;
    type Error = Error;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Resolves [`crate::auth::Auth`] for the request, then forwards it.
    ///
    /// # Errors
    ///
    /// Propagates errors from `Auth::from_request` and from the inner service.
    fn call(&mut self, mut req: Request) -> Self::Future {
        // `poll_ready` was driven on `self.inner`, so that instance is the one
        // allowed to handle this call. Move it into the future and leave a
        // fresh clone behind for the next request.
        let replacement = self.inner.clone();
        let mut ready_inner = std::mem::replace(&mut self.inner, replacement);

        Box::pin(async move {
            let auth = crate::auth::Auth::from_request(&mut req).await?;
            req.extensions_mut().insert(auth);
            ready_inner.call(req).await
        })
    }
}
/// Layers applied when live reload is enabled, outermost first. The live-reload
/// layer is wrapped in Cot's error and response adapters.
#[cfg(feature = "live-reload")]
type LiveReloadStack = (
    IntoCotErrorLayer,
    IntoCotResponseLayer,
    tower_livereload::LiveReloadLayer,
);

/// Either the full [`LiveReloadStack`] or a no-op identity layer, depending on
/// whether live reload is enabled.
#[cfg(feature = "live-reload")]
type LiveReloadLayerType = tower::util::Either<LiveReloadStack, tower::layer::util::Identity>;
/// Middleware that wraps the application in `tower_livereload::LiveReloadLayer`
/// when enabled, and is a no-op otherwise.
#[cfg(feature = "live-reload")]
#[derive(Debug, Clone)]
pub struct LiveReloadMiddleware(LiveReloadLayerType);

#[cfg(feature = "live-reload")]
impl LiveReloadMiddleware {
    /// Creates a live-reload middleware that is always enabled.
    #[must_use]
    pub fn new() -> Self {
        Self::with_enabled(true)
    }

    /// Creates a live-reload middleware enabled or disabled according to
    /// `middlewares.live_reload.enabled` in the project configuration.
    #[must_use]
    pub fn from_context(context: &MiddlewareContext) -> Self {
        let enabled = context.config().middlewares.live_reload.enabled;
        Self::with_enabled(enabled)
    }

    /// Builds the middleware. The live-reload layers are only constructed
    /// when `enabled` is `true`; otherwise an identity layer is used.
    fn with_enabled(enabled: bool) -> Self {
        let stack = if enabled {
            Some((
                IntoCotErrorLayer::new(),
                IntoCotResponseLayer::new(),
                tower_livereload::LiveReloadLayer::new(),
            ))
        } else {
            None
        };
        Self(tower::util::option_layer(stack))
    }
}

#[cfg(feature = "live-reload")]
impl Default for LiveReloadMiddleware {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(feature = "live-reload")]
impl<S> tower::Layer<S> for LiveReloadMiddleware {
    type Service = <LiveReloadLayerType as tower::Layer<S>>::Service;

    fn layer(&self, inner: S) -> Self::Service {
        tower::Layer::layer(&self.0, inner)
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use http::Request;
    use tower::{Layer, ServiceExt};

    use super::*;
    use crate::auth::Auth;
    use crate::session::Session;
    use crate::test::TestRequestBuilder;

    /// The session middleware exposes Cot's `Session` to the inner service.
    #[tokio::test]
    async fn session_middleware_adds_session() {
        let svc = tower::service_fn(|req: Request<Body>| async move {
            assert!(req.extensions().get::<Session>().is_some());
            Ok::<_, Error>(Response::new(Body::empty()))
        });
        let mut svc = SessionMiddleware::new().layer(svc);
        let request = TestRequestBuilder::get("/").build();
        svc.ready().await.unwrap().call(request).await.unwrap();
    }

    /// Writing to the session yields a `set-cookie` header with the default
    /// (secure) cookie attributes.
    #[tokio::test]
    async fn session_middleware_adds_cookie() {
        let svc = tower::service_fn(|req: Request<Body>| async move {
            let session = req.extensions().get::<Session>().unwrap();
            session.insert("test", "test").await.unwrap();
            Ok::<_, Error>(Response::new(Body::empty()))
        });
        let mut svc = SessionMiddleware::new().layer(svc);
        let request = TestRequestBuilder::get("/").build();
        let response = svc.ready().await.unwrap().call(request).await.unwrap();
        assert!(response.headers().contains_key("set-cookie"));
        let cookie_value = response
            .headers()
            .get("set-cookie")
            .unwrap()
            .to_str()
            .unwrap();
        assert!(cookie_value.contains("id="));
        assert!(cookie_value.contains("HttpOnly;"));
        assert!(cookie_value.contains("SameSite=Strict;"));
        assert!(cookie_value.contains("Secure;"));
        assert!(cookie_value.contains("Path=/"));
    }

    /// With `secure(false)`, the session cookie has no `Secure` attribute.
    #[tokio::test]
    async fn session_middleware_adds_cookie_not_secure() {
        let svc = tower::service_fn(|req: Request<Body>| async move {
            let session = req.extensions().get::<Session>().unwrap();
            session.insert("test", "test").await.unwrap();
            Ok::<_, Error>(Response::new(Body::empty()))
        });
        let mut svc = SessionMiddleware::new().secure(false).layer(svc);
        let request = TestRequestBuilder::get("/").build();
        let response = svc.ready().await.unwrap().call(request).await.unwrap();
        let cookie_value = response
            .headers()
            .get("set-cookie")
            .unwrap()
            .to_str()
            .unwrap();
        assert!(!cookie_value.contains("Secure;"));
    }

    /// The auth middleware inserts an (anonymous) `Auth` into the request.
    #[tokio::test]
    async fn auth_middleware_adds_auth() {
        let svc = tower::service_fn(|req: Request<Body>| async move {
            let auth = req
                .extensions()
                .get::<Auth>()
                .expect("Auth should be present");
            assert!(!auth.user().is_authenticated());
            Ok::<_, Error>(Response::new(Body::empty()))
        });
        let mut svc = AuthMiddleware::new().layer(svc);
        let request = TestRequestBuilder::get("/").with_session().build();
        svc.ready().await.unwrap().call(request).await.unwrap();
    }

    /// Using the auth middleware without a session is a programming error.
    #[tokio::test]
    #[should_panic(
        expected = "Session extension missing. Did you forget to add the SessionMiddleware?"
    )]
    async fn auth_middleware_requires_session() {
        let svc = tower::service_fn(|_req: Request<Body>| async move {
            Ok::<_, Error>(Response::new(Body::empty()))
        });
        let mut svc = AuthMiddleware::new().layer(svc);
        let request = TestRequestBuilder::get("/").build();
        let _result = svc.ready().await.unwrap().call(request).await;
    }

    /// Concurrent requests through an `AuthService` and its clone both reach
    /// the inner service with an `Auth` extension.
    #[tokio::test]
    async fn auth_service_cloning() {
        let counter = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let counter_clone = counter.clone();
        let svc = tower::service_fn(move |req: Request<Body>| {
            let counter = counter_clone.clone();
            async move {
                counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                assert!(req.extensions().get::<Auth>().is_some());
                Ok::<_, Error>(Response::new(Body::empty()))
            }
        });
        let mut svc = AuthMiddleware::new().layer(svc);
        let svc = svc.ready().await.unwrap();
        let request1 = TestRequestBuilder::get("/").with_session().build();
        let request2 = TestRequestBuilder::get("/").with_session().build();
        // The clone has never been polled for readiness. `oneshot` drives
        // `poll_ready` before `call`, as the `Service` contract requires.
        let (res1, res2) = tokio::join!(svc.clone().oneshot(request1), svc.call(request2));
        assert!(res1.is_ok());
        assert!(res2.is_ok());
        assert_eq!(counter.load(std::sync::atomic::Ordering::SeqCst), 2);
    }
}