axum_htmx/
guard.rs

//! Request guard that protects a router against non-htmx requests.
2
3use std::{
4    future::Future,
5    pin::Pin,
6    task::{Context, Poll},
7};
8
9use futures_core::ready;
10use http::{Request, StatusCode, header::LOCATION, response::Response};
11use pin_project_lite::pin_project;
12use tower::{Layer, Service};
13
14use crate::HX_REQUEST;
15
/// Layer that only lets htmx requests through to the wrapped routes.
///
/// Every request is checked for the `HX-Request` header; requests without it
/// are answered with a `303 See Other` redirect to the configured location.
///
/// This is useful for keeping users from accidentally landing on a route that
/// would otherwise return only a partial HTML fragment.
#[derive(Debug, Clone)]
pub struct HxRequestGuardLayer<'a> {
    /// Location that non-htmx requests are redirected to (sent verbatim as the
    /// `Location` header value).
    redirect_to: &'a str,
}

impl<'a> HxRequestGuardLayer<'a> {
    /// Builds a guard layer that redirects non-htmx requests to `redirect_to`.
    pub fn new(redirect_to: &'a str) -> Self {
        HxRequestGuardLayer { redirect_to }
    }
}

impl Default for HxRequestGuardLayer<'_> {
    /// Builds a guard layer that redirects non-htmx requests to the site root, `/`.
    fn default() -> Self {
        Self::new("/")
    }
}
37
38impl<'a, S> Layer<S> for HxRequestGuardLayer<'a> {
39    type Service = HxRequestGuard<'a, S>;
40
41    fn layer(&self, inner: S) -> Self::Service {
42        HxRequestGuard {
43            inner,
44            hx_request: false,
45            layer: self.clone(),
46        }
47    }
48}
49
50/// Tower service that implements redirecting to non-partial routes.
51#[derive(Debug, Clone)]
52pub struct HxRequestGuard<'a, S> {
53    inner: S,
54    hx_request: bool,
55    layer: HxRequestGuardLayer<'a>,
56}
57
58impl<'a, S, T, U> Service<Request<T>> for HxRequestGuard<'a, S>
59where
60    S: Service<Request<T>, Response = Response<U>>,
61    U: Default,
62{
63    type Response = S::Response;
64    type Error = S::Error;
65    type Future = private::ResponseFuture<'a, S::Future>;
66
67    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
68        self.inner.poll_ready(cx)
69    }
70
71    fn call(&mut self, req: Request<T>) -> Self::Future {
72        // This will always contain a "true" value.
73        if req.headers().contains_key(HX_REQUEST) {
74            self.hx_request = true;
75        }
76
77        let response_future = self.inner.call(req);
78
79        private::ResponseFuture {
80            response_future,
81            hx_request: self.hx_request,
82            layer: self.layer.clone(),
83        }
84    }
85}
86
87mod private {
88    use super::*;
89
90    pin_project! {
91        pub struct ResponseFuture<'a, F> {
92            #[pin]
93            pub(super) response_future: F,
94            pub(super) hx_request: bool,
95            pub(super) layer: HxRequestGuardLayer<'a>,
96        }
97    }
98
99    impl<F, B, E> Future for ResponseFuture<'_, F>
100    where
101        F: Future<Output = Result<Response<B>, E>>,
102        B: Default,
103    {
104        type Output = Result<Response<B>, E>;
105
106        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
107            let this = self.project();
108            let response: Response<B> = ready!(this.response_future.poll(cx))?;
109
110            match *this.hx_request {
111                true => Poll::Ready(Ok(response)),
112                false => {
113                    let res = Response::builder()
114                        .status(StatusCode::SEE_OTHER)
115                        .header(LOCATION, this.layer.redirect_to)
116                        .body(B::default())
117                        .expect("failed to build response");
118
119                    Poll::Ready(Ok(res))
120                }
121            }
122        }
123    }
124}