actix_htmx/
middleware.rs

1use crate::{headers::ResponseHeaders, Htmx, TriggerPayload, TriggerType};
2
3use actix_web::http::header::{HeaderName, HeaderValue};
4use actix_web::{
5    dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
6    Error, HttpMessage,
7};
8use futures_util::future::LocalBoxFuture;
9use indexmap::IndexMap;
10use log::warn;
11use serde_json::{Map, Value};
12use std::future::{ready, Ready};
13
/// A middleware for Actix Web that handles htmx specific headers and triggers.
///
/// [`HtmxMiddleware`] stores an [`Htmx`] struct in the request extensions of every
/// route that it wraps. This struct provides helper properties and methods that
/// allow your application to easily interact with htmx.
///
/// # Example
///
/// ```no_run
/// use actix_web::{web, App, HttpServer, Responder, HttpResponse};
/// use actix_htmx::{Htmx, HtmxMiddleware};
///
/// #[actix_web::main]
/// async fn main() -> std::io::Result<()> {
///     HttpServer::new(|| {
///         App::new()
///             .wrap(HtmxMiddleware)
///             .route("/", web::get().to(index))
///     })
///     .bind("127.0.0.1:8080")?
///     .run()
///     .await
/// }
///
/// async fn index(htmx: Htmx) -> impl Responder {
///     if !htmx.is_htmx {
///         HttpResponse::Ok().body(r##"
///             <!DOCTYPE html>
///             <html>
///                 <head>
///                     <title>htmx example</title>
///                     <script src="https://unpkg.com/htmx.org@2.0.6"></script>
///                 </head>
///                 <body>
///                     <div id="content">
///                         This was not an htmx request! <a href="/" hx-get="/" hx-target="#content">Make it htmx!</a>
///                     </div>
///                 </body>
///             </html>
///         "##)
///     } else {
///         HttpResponse::Ok().body(r##"
///         <div id="content">
///             This was an htmx request! <a href="/">Let's go back to plain old HTML</a>
///         </div>
///         "##)
///     }
/// }
/// ```
///
/// After the wrapped service has produced a response, the middleware writes the
/// following htmx response headers whenever matching triggers were registered on
/// the request's [`Htmx`] value:
/// - `HX-Trigger`: For standard htmx triggers
/// - `HX-Trigger-After-Settle`: For triggers that fire after the settling phase
/// - `HX-Trigger-After-Swap`: For triggers that fire after content swap
///
/// Any additional response headers collected on the [`Htmx`] value are written to
/// the response as well.
#[derive(Debug, Clone, Copy, Default)]
pub struct HtmxMiddleware;
74
75impl<S, B> Transform<S, ServiceRequest> for HtmxMiddleware
76where
77    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
78    S::Future: 'static,
79    B: 'static,
80{
81    type Response = ServiceResponse<B>;
82    type Error = Error;
83    type Transform = InnerHtmxMiddleware<S>;
84    type InitError = ();
85    type Future = Ready<Result<Self::Transform, Self::InitError>>;
86
87    fn new_transform(&self, service: S) -> Self::Future {
88        ready(Ok(InnerHtmxMiddleware { service }))
89    }
90}
91
/// Service wrapper constructed by the `Transform` implementation of
/// `HtmxMiddleware`; it owns the next service in the chain.
#[non_exhaustive]
#[doc(hidden)]
pub struct InnerHtmxMiddleware<S> {
    /// The wrapped service that requests are forwarded to.
    service: S,
}
97
98impl<S, B> Service<ServiceRequest> for InnerHtmxMiddleware<S>
99where
100    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
101    S::Future: 'static,
102    B: 'static,
103{
104    type Response = ServiceResponse<B>;
105    type Error = Error;
106    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
107
108    forward_ready!(service);
109
110    fn call(&self, req: ServiceRequest) -> Self::Future {
111        let htmx = Htmx::new(&req);
112
113        req.extensions_mut().insert(htmx);
114
115        let fut = self.service.call(req);
116
117        Box::pin(async move {
118            let res: ServiceResponse<B> = fut.await?;
119
120            let (req, mut res) = res.into_parts();
121
122            let trigger_json =
123                |trigger_map: &IndexMap<String, Option<TriggerPayload>>| -> Option<String> {
124                    let mut object = Map::new();
125                    for (key, value) in trigger_map.iter() {
126                        let json_value = match value {
127                            Some(payload) => payload.as_json_value(),
128                            None => Value::Null,
129                        };
130                        object.insert(key.clone(), json_value);
131                    }
132                    serde_json::to_string(&Value::Object(object)).ok()
133                };
134
135            let simple_header = |trigger_map: &IndexMap<String, Option<TriggerPayload>>| -> String {
136                trigger_map.keys().cloned().collect::<Vec<_>>().join(",")
137            };
138
139            let mut process_trigger_header =
140                |header_name: HeaderName,
141                 trigger_map: IndexMap<String, Option<TriggerPayload>>,
142                 simple: bool| {
143                    if trigger_map.is_empty() {
144                        return;
145                    }
146
147                    let triggers = if simple {
148                        simple_header(&trigger_map)
149                    } else if let Some(json) = trigger_json(&trigger_map) {
150                        json
151                    } else {
152                        warn!("Failed to serialize HX-Trigger header");
153                        return;
154                    };
155
156                    if let Ok(value) = HeaderValue::from_str(&triggers) {
157                        res.headers_mut().insert(header_name, value);
158                    } else {
159                        warn!("Failed to parse {} header value: {}", header_name, triggers)
160                    }
161                };
162
163            if let Some(htmx_response) = req.extensions().get::<Htmx>() {
164                process_trigger_header(
165                    HeaderName::from_static(ResponseHeaders::HX_TRIGGER),
166                    htmx_response.get_triggers(TriggerType::Standard),
167                    htmx_response.is_simple_trigger(TriggerType::Standard),
168                );
169                process_trigger_header(
170                    HeaderName::from_static(ResponseHeaders::HX_TRIGGER_AFTER_SETTLE),
171                    htmx_response.get_triggers(TriggerType::AfterSettle),
172                    htmx_response.is_simple_trigger(TriggerType::AfterSettle),
173                );
174                process_trigger_header(
175                    HeaderName::from_static(ResponseHeaders::HX_TRIGGER_AFTER_SWAP),
176                    htmx_response.get_triggers(TriggerType::AfterSwap),
177                    htmx_response.is_simple_trigger(TriggerType::AfterSwap),
178                );
179
180                let response_headers = htmx_response.get_response_headers();
181                response_headers
182                    .iter()
183                    .for_each(|(key, value)| match key.parse() {
184                        Ok(key) => {
185                            if let Ok(value) = HeaderValue::from_str(value) {
186                                res.headers_mut().insert(key, value);
187                            } else {
188                                warn!("Failed to parse {} header value: {}", key, value)
189                            }
190                        }
191                        _ => {
192                            warn!("Failed to parse header name: {}", key)
193                        }
194                    });
195            }
196
197            Ok(ServiceResponse::new(req, res))
198        })
199    }
200}