1use std::{
5 sync::Arc,
6 task::{Context, Poll},
7};
8
9use axum_core::{
10 extract::Request,
11 response::{IntoResponse, Response},
12};
13use futures::future::{BoxFuture, join_all};
14use http::{
15 Extensions,
16 header::{HeaderValue, VARY},
17};
18use tokio::sync::oneshot::{self, Receiver, Sender};
19use tower::{Layer, Service};
20
21use crate::{
22 HxError,
23 headers::{HX_REQUEST_STR, HX_TARGET_STR, HX_TRIGGER_NAME_STR, HX_TRIGGER_STR},
24};
25#[cfg(doc)]
26use crate::{HxRequest, HxTarget, HxTrigger, HxTriggerName};
27
/// Panic message raised when a request reaches a second [`AutoVaryLayer`], i.e. the
/// layer was applied more than once to the same service.
const MIDDLEWARE_DOUBLE_USE: &str =
    "Configuration error: `axum_htmx::AutoVaryLayer` is used twice";
30
/// A [`tower::Layer`] that appends a `Vary` response header naming the htmx request
/// headers the handler actually consulted.
///
/// A header is listed when its extractor ([`HxRequest`], [`HxTarget`], [`HxTrigger`]
/// or [`HxTriggerName`]) is used while handling the request. For example, a handler
/// taking [`HxRequest`] yields `Vary: hx-request`. Caches can then keep htmx and
/// non-htmx variants of a response apart.
///
/// # Panics
///
/// Handling a request panics if more than one `AutoVaryLayer` wraps the same service.
#[derive(Clone, Copy, Debug, Default)]
pub struct AutoVaryLayer;
37
/// Service produced by [`AutoVaryLayer`].
///
/// It wraps `inner` and appends a `Vary` header listing the htmx request headers
/// that were extracted while `inner` handled the request.
#[derive(Clone, Debug)]
pub struct AutoVaryMiddleware<S> {
    inner: S,
}
43
44pub(crate) trait Notifier {
45 fn sender(&mut self) -> Option<Sender<()>>;
46
47 fn notify(&mut self) {
48 if let Some(sender) = self.sender() {
49 sender.send(()).ok();
50 }
51 }
52
53 fn insert(extensions: &mut Extensions) -> Receiver<()>;
54}
55
/// Declares one tuple-struct notifier type per identifier, each implementing
/// [`Notifier`] over a one-shot channel.
///
/// The identifiers are comma separated; a trailing comma is accepted.
macro_rules! define_notifiers {
    ($($name:ident),* $(,)?) => {
        $(
            #[derive(Clone)]
            pub(crate) struct $name(Option<Arc<Sender<()>>>);

            impl Notifier for $name {
                fn sender(&mut self) -> Option<Sender<()>> {
                    // `Arc::into_inner` succeeds only for the last strong reference;
                    // any other clone just releases its share and yields `None`.
                    self.0.take().and_then(Arc::into_inner)
                }

                fn insert(extensions: &mut Extensions) -> Receiver<()> {
                    let (tx, rx) = oneshot::channel();
                    // `Sender` is not `Clone`, so it is wrapped in an `Arc` to let the
                    // notifier derive `Clone`, which `Extensions::insert` requires.
                    if extensions.insert(Self(Some(Arc::new(tx)))).is_some() {
                        // A notifier of this type was already installed: the
                        // middleware is stacked on top of itself.
                        panic!("{}", MIDDLEWARE_DOUBLE_USE);
                    }
                    rx
                }
            }
        )*
    }
}
78
79define_notifiers!(
80 HxRequestExtracted,
81 HxTargetExtracted,
82 HxTriggerExtracted,
83 HxTriggerNameExtracted
84);
85
86impl<S> Layer<S> for AutoVaryLayer {
87 type Service = AutoVaryMiddleware<S>;
88
89 fn layer(&self, inner: S) -> Self::Service {
90 AutoVaryMiddleware { inner }
91 }
92}
93
94impl<S> Service<Request> for AutoVaryMiddleware<S>
95where
96 S: Service<Request, Response = Response> + Send + 'static,
97 S::Future: Send + 'static,
98{
99 type Response = S::Response;
100 type Error = S::Error;
101 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
102
103 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
104 self.inner.poll_ready(cx)
105 }
106
107 fn call(&mut self, mut request: Request) -> Self::Future {
108 let exts = request.extensions_mut();
109 let rx_header = [
110 (HxRequestExtracted::insert(exts), HX_REQUEST_STR),
111 (HxTargetExtracted::insert(exts), HX_TARGET_STR),
112 (HxTriggerExtracted::insert(exts), HX_TRIGGER_STR),
113 (HxTriggerNameExtracted::insert(exts), HX_TRIGGER_NAME_STR),
114 ];
115 let future = self.inner.call(request);
116 Box::pin(async move {
117 let mut response: Response = future.await?;
118 let used_headers: Vec<_> = join_all(
119 rx_header
120 .into_iter()
121 .map(|(rx, header)| async move { rx.await.ok().map(|_| header) }),
122 )
123 .await
124 .into_iter()
125 .flatten()
126 .collect();
127
128 if used_headers.is_empty() {
129 return Ok(response);
130 }
131
132 let value = match HeaderValue::from_str(&used_headers.join(", ")) {
133 Ok(x) => x,
134 Err(e) => return Ok(HxError::from(e).into_response()),
135 };
136
137 if let Err(e) = response.headers_mut().try_append(VARY, value) {
138 return Ok(HxError::from(e).into_response());
139 }
140
141 Ok(response)
142 })
143 }
144}
145
#[cfg(test)]
mod tests {
    use axum::{Router, routing::get};

    use super::*;
    use crate::{HxRequest, HxTarget, HxTrigger, HxTriggerName};

    /// Returns every `Vary` header value on `resp`, in the order they appear.
    fn vary_headers(resp: &axum_test::TestResponse) -> Vec<HeaderValue> {
        resp.iter_headers_by_name("vary").cloned().collect()
    }

    /// Builds a server with one route per extractor combination, all wrapped in
    /// [`AutoVaryLayer`].
    fn server() -> axum_test::TestServer {
        let app = Router::new()
            .route("/no-extractors", get(|| async {}))
            .route("/hx-request", get(|_: HxRequest| async {}))
            .route("/hx-target", get(|_: HxTarget| async {}))
            .route("/hx-trigger", get(|_: HxTrigger| async {}))
            .route("/hx-trigger-name", get(|_: HxTriggerName| async {}))
            .route(
                "/repeated-extractor",
                get(|_: HxRequest, _: HxRequest| async {}),
            )
            .route(
                "/multiple-extractors",
                get(|_: HxRequest, _: HxTarget, _: HxTrigger, _: HxTriggerName| async {}),
            )
            .layer(AutoVaryLayer);
        axum_test::TestServer::new(app).unwrap()
    }

    /// Sends a GET for `path` to a fresh server and returns the response's `Vary` values.
    async fn vary_for(path: &str) -> Vec<HeaderValue> {
        vary_headers(&server().get(path).await)
    }

    #[tokio::test]
    async fn no_extractors() {
        assert!(vary_for("/no-extractors").await.is_empty());
    }

    #[tokio::test]
    async fn single_hx_request() {
        assert_eq!(vary_for("/hx-request").await, ["hx-request"]);
    }

    #[tokio::test]
    async fn single_hx_target() {
        assert_eq!(vary_for("/hx-target").await, ["hx-target"]);
    }

    #[tokio::test]
    async fn single_hx_trigger() {
        assert_eq!(vary_for("/hx-trigger").await, ["hx-trigger"]);
    }

    #[tokio::test]
    async fn single_hx_trigger_name() {
        assert_eq!(vary_for("/hx-trigger-name").await, ["hx-trigger-name"]);
    }

    #[tokio::test]
    async fn repeated_extractor() {
        assert_eq!(vary_for("/repeated-extractor").await, ["hx-request"]);
    }

    #[tokio::test]
    async fn multiple_extractors() {
        assert_eq!(
            vary_for("/multiple-extractors").await,
            ["hx-request, hx-target, hx-trigger, hx-trigger-name"],
        );
    }
}