1use crate::{headers::ResponseHeaders, Htmx, TriggerPayload, TriggerType};
2
3use actix_web::http::header::{HeaderName, HeaderValue};
4use actix_web::{
5 dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
6 Error, HttpMessage,
7};
8use futures_util::future::LocalBoxFuture;
9use indexmap::IndexMap;
10use log::warn;
11use serde_json::{Map, Value};
12use std::future::{ready, Ready};
13
/// Actix-web middleware factory that adds htmx support to the services it wraps.
///
/// Each request gets an [`Htmx`] value inserted into its extensions before it reaches the
/// wrapped service. When that service responds successfully, the trigger events recorded on
/// the [`Htmx`] value are written as the `ResponseHeaders::HX_TRIGGER`,
/// `HX_TRIGGER_AFTER_SETTLE` and `HX_TRIGGER_AFTER_SWAP` headers. Any additional response
/// headers recorded on it are written onto the response as well.
pub struct HtmxMiddleware;
74
75impl<S, B> Transform<S, ServiceRequest> for HtmxMiddleware
76where
77 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
78 S::Future: 'static,
79 B: 'static,
80{
81 type Response = ServiceResponse<B>;
82 type Error = Error;
83 type Transform = InnerHtmxMiddleware<S>;
84 type InitError = ();
85 type Future = Ready<Result<Self::Transform, Self::InitError>>;
86
87 fn new_transform(&self, service: S) -> Self::Future {
88 ready(Ok(InnerHtmxMiddleware { service }))
89 }
90}
91
/// Per-service middleware created by [`HtmxMiddleware`]'s `Transform` implementation.
///
/// It wraps the next service in the chain and adds htmx request/response handling around
/// every call. `#[doc(hidden)]` keeps it out of the generated documentation.
#[doc(hidden)]
#[non_exhaustive]
pub struct InnerHtmxMiddleware<S> {
    // The wrapped (next) service that each request is forwarded to.
    service: S,
}
97
98impl<S, B> Service<ServiceRequest> for InnerHtmxMiddleware<S>
99where
100 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
101 S::Future: 'static,
102 B: 'static,
103{
104 type Response = ServiceResponse<B>;
105 type Error = Error;
106 type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
107
108 forward_ready!(service);
109
110 fn call(&self, req: ServiceRequest) -> Self::Future {
111 let htmx = Htmx::new(&req);
112
113 req.extensions_mut().insert(htmx);
114
115 let fut = self.service.call(req);
116
117 Box::pin(async move {
118 let res: ServiceResponse<B> = fut.await?;
119
120 let (req, mut res) = res.into_parts();
121
122 let trigger_json =
123 |trigger_map: &IndexMap<String, Option<TriggerPayload>>| -> Option<String> {
124 let mut object = Map::new();
125 for (key, value) in trigger_map.iter() {
126 let json_value = match value {
127 Some(payload) => payload.as_json_value(),
128 None => Value::Null,
129 };
130 object.insert(key.clone(), json_value);
131 }
132 serde_json::to_string(&Value::Object(object)).ok()
133 };
134
135 let simple_header = |trigger_map: &IndexMap<String, Option<TriggerPayload>>| -> String {
136 trigger_map.keys().cloned().collect::<Vec<_>>().join(",")
137 };
138
139 let mut process_trigger_header =
140 |header_name: HeaderName,
141 trigger_map: IndexMap<String, Option<TriggerPayload>>,
142 simple: bool| {
143 if trigger_map.is_empty() {
144 return;
145 }
146
147 let triggers = if simple {
148 simple_header(&trigger_map)
149 } else if let Some(json) = trigger_json(&trigger_map) {
150 json
151 } else {
152 warn!("Failed to serialize HX-Trigger header");
153 return;
154 };
155
156 if let Ok(value) = HeaderValue::from_str(&triggers) {
157 res.headers_mut().insert(header_name, value);
158 } else {
159 warn!("Failed to parse {} header value: {}", header_name, triggers)
160 }
161 };
162
163 if let Some(htmx_response) = req.extensions().get::<Htmx>() {
164 process_trigger_header(
165 HeaderName::from_static(ResponseHeaders::HX_TRIGGER),
166 htmx_response.get_triggers(TriggerType::Standard),
167 htmx_response.is_simple_trigger(TriggerType::Standard),
168 );
169 process_trigger_header(
170 HeaderName::from_static(ResponseHeaders::HX_TRIGGER_AFTER_SETTLE),
171 htmx_response.get_triggers(TriggerType::AfterSettle),
172 htmx_response.is_simple_trigger(TriggerType::AfterSettle),
173 );
174 process_trigger_header(
175 HeaderName::from_static(ResponseHeaders::HX_TRIGGER_AFTER_SWAP),
176 htmx_response.get_triggers(TriggerType::AfterSwap),
177 htmx_response.is_simple_trigger(TriggerType::AfterSwap),
178 );
179
180 let response_headers = htmx_response.get_response_headers();
181 response_headers
182 .iter()
183 .for_each(|(key, value)| match key.parse() {
184 Ok(key) => {
185 if let Ok(value) = HeaderValue::from_str(value) {
186 res.headers_mut().insert(key, value);
187 } else {
188 warn!("Failed to parse {} header value: {}", key, value)
189 }
190 }
191 _ => {
192 warn!("Failed to parse header name: {}", key)
193 }
194 });
195 }
196
197 Ok(ServiceResponse::new(req, res))
198 })
199 }
200}