actix_htmx/
middleware.rs

1use crate::{headers::ResponseHeaders, Htmx, TriggerType};
2
3use actix_web::http::header::{HeaderName, HeaderValue};
4use actix_web::{
5    dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
6    Error, HttpMessage,
7};
8use futures_util::future::LocalBoxFuture;
9use indexmap::IndexMap;
10use log::warn;
11use std::future::{ready, Ready};
12
/// A middleware for Actix Web that handles htmx-specific request data and response triggers.
///
/// This middleware makes it easy to use htmx in your Actix Web application. It gives handlers
/// access to htmx information about the incoming request and turns the triggers they register
/// into the corresponding htmx response headers.
///
/// [`HtmxMiddleware`] injects an [`Htmx`] struct into every request handled by the routes it
/// wraps. That struct provides helper properties and methods that let your application easily
/// interact with htmx.
///
/// # Example
///
/// ```no_run
/// use actix_web::{web, App, HttpServer, Responder, HttpResponse};
/// use actix_htmx::{Htmx, HtmxMiddleware};
///
/// #[actix_web::main]
/// async fn main() -> std::io::Result<()> {
///     HttpServer::new(|| {
///         App::new()
///             .wrap(HtmxMiddleware)
///             .route("/", web::get().to(index))
///     })
///     .bind("127.0.0.1:8080")?
///     .run()
///     .await
/// }
///
/// async fn index(htmx: Htmx) -> impl Responder {
///     if !htmx.is_htmx {
///         HttpResponse::Ok().body(r##"
///             <!DOCTYPE html>
///             <html>
///                 <head>
///                     <title>htmx example</title>
///                     <script src="https://unpkg.com/htmx.org@2.0.6"></script>
///                 </head>
///                 <body>
///                     <div id="content">
///                         This was not an htmx request! <a href="/" hx-get="/" hx-target="#content" hx-swap="outerHTML">Make it htmx!</a>
///                     </div>
///                 </body>
///             </html>
///         "##)
///     } else {
///         HttpResponse::Ok().body(r##"
///         <div id="content">
///             This was an htmx request! <a href="/">Let's go back to plain old HTML</a>
///         </div>
///         "##)
///     }
/// }
/// ```
///
/// After the wrapped handler has produced its response, the middleware writes any triggers
/// registered on the [`Htmx`] struct to the matching response headers:
/// - `HX-Trigger`: for standard htmx triggers
/// - `HX-Trigger-After-Settle`: for triggers that fire after the settling phase
/// - `HX-Trigger-After-Swap`: for triggers that fire after the content swap
///
/// A trigger header is only set when at least one trigger of that kind was registered. Any other
/// response headers registered on the [`Htmx`] struct are copied onto the response as well. A
/// header whose name or value cannot be represented is skipped and a warning is logged.
#[derive(Debug, Clone, Copy, Default)]
pub struct HtmxMiddleware;
73
74impl<S, B> Transform<S, ServiceRequest> for HtmxMiddleware
75where
76    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
77    S::Future: 'static,
78    B: 'static,
79{
80    type Response = ServiceResponse<B>;
81    type Error = Error;
82    type Transform = InnerHtmxMiddleware<S>;
83    type InitError = ();
84    type Future = Ready<Result<Self::Transform, Self::InitError>>;
85
86    fn new_transform(&self, service: S) -> Self::Future {
87        ready(Ok(InnerHtmxMiddleware { service }))
88    }
89}
90
/// Per-service half of `HtmxMiddleware`, created by `HtmxMiddleware::new_transform`.
#[non_exhaustive]
#[doc(hidden)]
pub struct InnerHtmxMiddleware<S> {
    /// The downstream service that every request is forwarded to.
    service: S,
}
96
97impl<S, B> Service<ServiceRequest> for InnerHtmxMiddleware<S>
98where
99    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
100    S::Future: 'static,
101    B: 'static,
102{
103    type Response = ServiceResponse<B>;
104    type Error = Error;
105    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
106
107    forward_ready!(service);
108
109    fn call(&self, req: ServiceRequest) -> Self::Future {
110        let htmx = Htmx::new(&req);
111
112        req.extensions_mut().insert(htmx);
113
114        let fut = self.service.call(req);
115
116        Box::pin(async move {
117            let res: ServiceResponse<B> = fut.await?;
118
119            let (req, mut res) = res.into_parts();
120
121            let trigger_json = |trigger_map: IndexMap<String, Option<String>>| -> String {
122                let mut triggers = String::new();
123                triggers.push('{');
124                trigger_map.iter().for_each(|(key, value)| {
125                    if let Some(value) = value {
126                        if value.trim().starts_with('{') {
127                            triggers.push_str(&format!("\"{}\": {},", key, value));
128                        } else {
129                            triggers.push_str(&format!("\"{}\": \"{}\",", key, value));
130                        }
131                    } else {
132                        triggers.push_str(&format!("\"{}\": null,", key));
133                    }
134                });
135                triggers.pop();
136                triggers.push('}');
137                triggers
138            };
139
140            let simple_header = |trigger_map: IndexMap<String, Option<String>>| -> String {
141                let mut triggers = trigger_map
142                    .iter()
143                    .map(|(key, _)| key.to_string() + ",")
144                    .collect::<String>();
145                triggers.pop();
146                triggers
147            };
148
149            let mut process_trigger_header =
150                |header_name: HeaderName,
151                 trigger_map: IndexMap<String, Option<String>>,
152                 simple: bool| {
153                    if trigger_map.is_empty() {
154                        return;
155                    }
156
157                    let triggers = if simple {
158                        simple_header(trigger_map)
159                    } else {
160                        trigger_json(trigger_map)
161                    };
162
163                    if let Ok(value) = HeaderValue::from_str(&triggers) {
164                        res.headers_mut().insert(header_name, value);
165                    } else {
166                        warn!("Failed to parse {} header value: {}", header_name, triggers)
167                    }
168                };
169
170            if let Some(htmx_response) = req.extensions().get::<Htmx>() {
171                process_trigger_header(
172                    HeaderName::from_static(ResponseHeaders::HX_TRIGGER),
173                    htmx_response.get_triggers(TriggerType::Standard),
174                    htmx_response.is_simple_trigger(TriggerType::Standard),
175                );
176                process_trigger_header(
177                    HeaderName::from_static(ResponseHeaders::HX_TRIGGER_AFTER_SETTLE),
178                    htmx_response.get_triggers(TriggerType::AfterSettle),
179                    htmx_response.is_simple_trigger(TriggerType::AfterSettle),
180                );
181                process_trigger_header(
182                    HeaderName::from_static(ResponseHeaders::HX_TRIGGER_AFTER_SWAP),
183                    htmx_response.get_triggers(TriggerType::AfterSwap),
184                    htmx_response.is_simple_trigger(TriggerType::AfterSwap),
185                );
186
187                let response_headers = htmx_response.get_response_headers();
188                response_headers
189                    .iter()
190                    .for_each(|(key, value)| match key.parse() {
191                        Ok(key) => {
192                            if let Ok(value) = HeaderValue::from_str(value) {
193                                res.headers_mut().insert(key, value);
194                            } else {
195                                warn!("Failed to parse {} header value: {}", key, value)
196                            }
197                        }
198                        _ => {
199                            warn!("Failed to parse header name: {}", key)
200                        }
201                    });
202            }
203
204            Ok(ServiceResponse::new(req, res))
205        })
206    }
207}