axum_htmx/
auto_vary.rs

1//! A middleware to automatically add a `Vary` header when needed to address
2//! [htmx caching issue](https://htmx.org/docs/#caching)
3
4use std::{
5    sync::Arc,
6    task::{Context, Poll},
7};
8
9use axum_core::{
10    extract::Request,
11    response::{IntoResponse, Response},
12};
13use futures::future::{BoxFuture, join_all};
14use http::{
15    Extensions,
16    header::{HeaderValue, VARY},
17};
18use tokio::sync::oneshot::{self, Receiver, Sender};
19use tower::{Layer, Service};
20
21use crate::{
22    HxError,
23    headers::{HX_REQUEST_STR, HX_TARGET_STR, HX_TRIGGER_NAME_STR, HX_TRIGGER_STR},
24};
25#[cfg(doc)]
26use crate::{HxRequest, HxTarget, HxTrigger, HxTriggerName};
27
/// Panic message raised when a notifier is inserted into request extensions
/// that already contain one, i.e. the auto-vary layer was applied more than once.
const MIDDLEWARE_DOUBLE_USE: &str =
    "Configuration error: `axum_htmx::AutoVaryLayer` is used twice";
30
/// Addresses [htmx caching issues](https://htmx.org/docs/#caching)
/// by automatically adding a corresponding `Vary` header when
/// [`HxRequest`], [`HxTarget`], [`HxTrigger`], [`HxTriggerName`]
/// or their combination is used.
///
/// The layer carries no configuration, so it is a zero-sized unit type that
/// can be freely copied, defaulted and debug-printed.
#[derive(Clone, Copy, Debug, Default)]
pub struct AutoVaryLayer;
37
/// Tower service for [`AutoVaryLayer`]
///
/// Wraps an inner service `S`; `Debug` is available whenever `S` is `Debug`.
#[derive(Clone, Debug)]
pub struct AutoVaryMiddleware<S> {
    /// The wrapped service that actually produces the response.
    inner: S,
}
43
44pub(crate) trait Notifier {
45    fn sender(&mut self) -> Option<Sender<()>>;
46
47    fn notify(&mut self) {
48        if let Some(sender) = self.sender() {
49            sender.send(()).ok();
50        }
51    }
52
53    fn insert(extensions: &mut Extensions) -> Receiver<()>;
54}
55
// Generates, for every given identifier, a cloneable newtype around an optional
// shared one-shot sender together with its `Notifier` implementation.
macro_rules! define_notifiers {
    ($($name:ident),*) => {
        $(
            #[derive(Clone)]
            pub(crate) struct $name(Option<Arc<Sender<()>>>);

            impl Notifier for $name {
                fn sender(&mut self) -> Option<Sender<()>> {
                    // Taking leaves `None` behind, so a sender is handed out at most once.
                    let shared = self.0.take()?;
                    // Only yields the sender when this was the sole remaining `Arc` handle.
                    Arc::into_inner(shared)
                }

                fn insert(extensions: &mut Extensions) -> Receiver<()> {
                    let (notify_tx, notify_rx) = oneshot::channel();
                    let previous = extensions.insert(Self(Some(Arc::new(notify_tx))));
                    // An already-present value means the extensions were populated before.
                    assert!(previous.is_none(), "{}", MIDDLEWARE_DOUBLE_USE);
                    notify_rx
                }
            }
        )*
    }
}
78
79define_notifiers!(
80    HxRequestExtracted,
81    HxTargetExtracted,
82    HxTriggerExtracted,
83    HxTriggerNameExtracted
84);
85
86impl<S> Layer<S> for AutoVaryLayer {
87    type Service = AutoVaryMiddleware<S>;
88
89    fn layer(&self, inner: S) -> Self::Service {
90        AutoVaryMiddleware { inner }
91    }
92}
93
94impl<S> Service<Request> for AutoVaryMiddleware<S>
95where
96    S: Service<Request, Response = Response> + Send + 'static,
97    S::Future: Send + 'static,
98{
99    type Response = S::Response;
100    type Error = S::Error;
101    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
102
103    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
104        self.inner.poll_ready(cx)
105    }
106
107    fn call(&mut self, mut request: Request) -> Self::Future {
108        let exts = request.extensions_mut();
109        let rx_header = [
110            (HxRequestExtracted::insert(exts), HX_REQUEST_STR),
111            (HxTargetExtracted::insert(exts), HX_TARGET_STR),
112            (HxTriggerExtracted::insert(exts), HX_TRIGGER_STR),
113            (HxTriggerNameExtracted::insert(exts), HX_TRIGGER_NAME_STR),
114        ];
115        let future = self.inner.call(request);
116        Box::pin(async move {
117            let mut response: Response = future.await?;
118            let used_headers: Vec<_> = join_all(
119                rx_header
120                    .into_iter()
121                    .map(|(rx, header)| async move { rx.await.ok().map(|_| header) }),
122            )
123            .await
124            .into_iter()
125            .flatten()
126            .collect();
127
128            if used_headers.is_empty() {
129                return Ok(response);
130            }
131
132            let value = match HeaderValue::from_str(&used_headers.join(", ")) {
133                Ok(x) => x,
134                Err(e) => return Ok(HxError::from(e).into_response()),
135            };
136
137            if let Err(e) = response.headers_mut().try_append(VARY, value) {
138                return Ok(HxError::from(e).into_response());
139            }
140
141            Ok(response)
142        })
143    }
144}
145
#[cfg(test)]
mod tests {
    use axum::{Router, routing::get};

    use super::*;
    use crate::{HxRequest, HxTarget, HxTrigger, HxTriggerName};

    /// Builds a test server with one route per extractor scenario, all wrapped
    /// in [`AutoVaryLayer`].
    fn server() -> axum_test::TestServer {
        let app = Router::new()
            .route("/no-extractors", get(|| async {}))
            .route("/hx-request", get(|_: HxRequest| async {}))
            .route("/hx-target", get(|_: HxTarget| async {}))
            .route("/hx-trigger", get(|_: HxTrigger| async {}))
            .route("/hx-trigger-name", get(|_: HxTriggerName| async {}))
            .route(
                "/repeated-extractor",
                get(|_: HxRequest, _: HxRequest| async {}),
            )
            .route(
                "/multiple-extractors",
                get(|_: HxRequest, _: HxTarget, _: HxTrigger, _: HxTriggerName| async {}),
            )
            .layer(AutoVaryLayer);
        axum_test::TestServer::new(app).unwrap()
    }

    /// Issues a GET for `path` against a fresh server and collects every
    /// `Vary` header value of the response.
    async fn vary_for(path: &str) -> Vec<HeaderValue> {
        let response = server().get(path).await;
        response.iter_headers_by_name("vary").cloned().collect()
    }

    #[tokio::test]
    async fn no_extractors() {
        assert!(vary_for("/no-extractors").await.is_empty());
    }

    #[tokio::test]
    async fn single_hx_request() {
        assert_eq!(vary_for("/hx-request").await, ["hx-request"]);
    }

    #[tokio::test]
    async fn single_hx_target() {
        assert_eq!(vary_for("/hx-target").await, ["hx-target"]);
    }

    #[tokio::test]
    async fn single_hx_trigger() {
        assert_eq!(vary_for("/hx-trigger").await, ["hx-trigger"]);
    }

    #[tokio::test]
    async fn single_hx_trigger_name() {
        assert_eq!(vary_for("/hx-trigger-name").await, ["hx-trigger-name"]);
    }

    // Extractors can be used multiple times e.g. in middlewares
    #[tokio::test]
    async fn repeated_extractor() {
        assert_eq!(vary_for("/repeated-extractor").await, ["hx-request"]);
    }

    // Several different extractors produce one combined `Vary` value.
    #[tokio::test]
    async fn multiple_extractors() {
        assert_eq!(
            vary_for("/multiple-extractors").await,
            ["hx-request, hx-target, hx-trigger, hx-trigger-name"],
        );
    }
}