1use std::{
4 future::Future,
5 pin::Pin,
6 task::{Context, Poll},
7};
8
9use futures_core::ready;
10use http::{Request, StatusCode, header::LOCATION, response::Response};
11use pin_project_lite::pin_project;
12use tower::{Layer, Service};
13
14use crate::HX_REQUEST;
15
/// Tower [`Layer`] that only lets htmx requests through to the wrapped service.
///
/// Requests carrying the [`HX_REQUEST`] header are forwarded unchanged; any
/// other request is answered with a `303 See Other` redirect whose `Location`
/// header is `redirect_to` (`"/"` when built via [`Default`]).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HxRequestGuardLayer<'a> {
    /// Value sent verbatim in the `Location` header of the redirect.
    redirect_to: &'a str,
}
25
26impl<'a> HxRequestGuardLayer<'a> {
27 pub fn new(redirect_to: &'a str) -> Self {
28 Self { redirect_to }
29 }
30}
31
32impl Default for HxRequestGuardLayer<'_> {
33 fn default() -> Self {
34 Self { redirect_to: "/" }
35 }
36}
37
38impl<'a, S> Layer<S> for HxRequestGuardLayer<'a> {
39 type Service = HxRequestGuard<'a, S>;
40
41 fn layer(&self, inner: S) -> Self::Service {
42 HxRequestGuard {
43 inner,
44 hx_request: false,
45 layer: self.clone(),
46 }
47 }
48}
49
50#[derive(Debug, Clone)]
52pub struct HxRequestGuard<'a, S> {
53 inner: S,
54 hx_request: bool,
55 layer: HxRequestGuardLayer<'a>,
56}
57
58impl<'a, S, T, U> Service<Request<T>> for HxRequestGuard<'a, S>
59where
60 S: Service<Request<T>, Response = Response<U>>,
61 U: Default,
62{
63 type Response = S::Response;
64 type Error = S::Error;
65 type Future = private::ResponseFuture<'a, S::Future>;
66
67 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
68 self.inner.poll_ready(cx)
69 }
70
71 fn call(&mut self, req: Request<T>) -> Self::Future {
72 if req.headers().contains_key(HX_REQUEST) {
74 self.hx_request = true;
75 }
76
77 let response_future = self.inner.call(req);
78
79 private::ResponseFuture {
80 response_future,
81 hx_request: self.hx_request,
82 layer: self.layer.clone(),
83 }
84 }
85}
86
87mod private {
88 use super::*;
89
90 pin_project! {
91 pub struct ResponseFuture<'a, F> {
92 #[pin]
93 pub(super) response_future: F,
94 pub(super) hx_request: bool,
95 pub(super) layer: HxRequestGuardLayer<'a>,
96 }
97 }
98
99 impl<F, B, E> Future for ResponseFuture<'_, F>
100 where
101 F: Future<Output = Result<Response<B>, E>>,
102 B: Default,
103 {
104 type Output = Result<Response<B>, E>;
105
106 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
107 let this = self.project();
108 let response: Response<B> = ready!(this.response_future.poll(cx))?;
109
110 match *this.hx_request {
111 true => Poll::Ready(Ok(response)),
112 false => {
113 let res = Response::builder()
114 .status(StatusCode::SEE_OTHER)
115 .header(LOCATION, this.layer.redirect_to)
116 .body(B::default())
117 .expect("failed to build response");
118
119 Poll::Ready(Ok(res))
120 }
121 }
122 }
123 }
124}