1use crate::{headers::ResponseHeaders, Htmx, TriggerType};
2
3use actix_web::http::header::{HeaderName, HeaderValue};
4use actix_web::{
5 dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
6 Error, HttpMessage,
7};
8use futures_util::future::LocalBoxFuture;
9use indexmap::IndexMap;
10use log::warn;
11use std::future::{ready, Ready};
12
/// Actix-web middleware that inserts an [`Htmx`] value into each request's
/// extensions before calling the wrapped service, and afterwards writes the
/// htmx trigger headers and any extra response headers recorded on that
/// value onto the outgoing response.
pub struct HtmxMiddleware;
73
74impl<S, B> Transform<S, ServiceRequest> for HtmxMiddleware
75where
76 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
77 S::Future: 'static,
78 B: 'static,
79{
80 type Response = ServiceResponse<B>;
81 type Error = Error;
82 type Transform = InnerHtmxMiddleware<S>;
83 type InitError = ();
84 type Future = Ready<Result<Self::Transform, Self::InitError>>;
85
86 fn new_transform(&self, service: S) -> Self::Future {
87 ready(Ok(InnerHtmxMiddleware { service }))
88 }
89}
90
/// Service produced by [`HtmxMiddleware`]'s `new_transform`. It wraps the next
/// service in the chain and decorates that service's responses with htmx
/// headers.
#[doc(hidden)]
#[non_exhaustive]
pub struct InnerHtmxMiddleware<S> {
    // The downstream service every request is forwarded to.
    service: S,
}
96
97impl<S, B> Service<ServiceRequest> for InnerHtmxMiddleware<S>
98where
99 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
100 S::Future: 'static,
101 B: 'static,
102{
103 type Response = ServiceResponse<B>;
104 type Error = Error;
105 type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
106
107 forward_ready!(service);
108
109 fn call(&self, req: ServiceRequest) -> Self::Future {
110 let htmx = Htmx::new(&req);
111
112 req.extensions_mut().insert(htmx);
113
114 let fut = self.service.call(req);
115
116 Box::pin(async move {
117 let res: ServiceResponse<B> = fut.await?;
118
119 let (req, mut res) = res.into_parts();
120
121 let trigger_json = |trigger_map: IndexMap<String, Option<String>>| -> String {
122 let mut triggers = String::new();
123 triggers.push('{');
124 trigger_map.iter().for_each(|(key, value)| {
125 if let Some(value) = value {
126 if value.trim().starts_with('{') {
127 triggers.push_str(&format!("\"{}\": {},", key, value));
128 } else {
129 triggers.push_str(&format!("\"{}\": \"{}\",", key, value));
130 }
131 } else {
132 triggers.push_str(&format!("\"{}\": null,", key));
133 }
134 });
135 triggers.pop();
136 triggers.push('}');
137 triggers
138 };
139
140 let simple_header = |trigger_map: IndexMap<String, Option<String>>| -> String {
141 let mut triggers = trigger_map
142 .iter()
143 .map(|(key, _)| key.to_string() + ",")
144 .collect::<String>();
145 triggers.pop();
146 triggers
147 };
148
149 let mut process_trigger_header =
150 |header_name: HeaderName,
151 trigger_map: IndexMap<String, Option<String>>,
152 simple: bool| {
153 if trigger_map.is_empty() {
154 return;
155 }
156
157 let triggers = if simple {
158 simple_header(trigger_map)
159 } else {
160 trigger_json(trigger_map)
161 };
162
163 if let Ok(value) = HeaderValue::from_str(&triggers) {
164 res.headers_mut().insert(header_name, value);
165 } else {
166 warn!("Failed to parse {} header value: {}", header_name, triggers)
167 }
168 };
169
170 if let Some(htmx_response) = req.extensions().get::<Htmx>() {
171 process_trigger_header(
172 HeaderName::from_static(ResponseHeaders::HX_TRIGGER),
173 htmx_response.get_triggers(TriggerType::Standard),
174 htmx_response.is_simple_trigger(TriggerType::Standard),
175 );
176 process_trigger_header(
177 HeaderName::from_static(ResponseHeaders::HX_TRIGGER_AFTER_SETTLE),
178 htmx_response.get_triggers(TriggerType::AfterSettle),
179 htmx_response.is_simple_trigger(TriggerType::AfterSettle),
180 );
181 process_trigger_header(
182 HeaderName::from_static(ResponseHeaders::HX_TRIGGER_AFTER_SWAP),
183 htmx_response.get_triggers(TriggerType::AfterSwap),
184 htmx_response.is_simple_trigger(TriggerType::AfterSwap),
185 );
186
187 let response_headers = htmx_response.get_response_headers();
188 response_headers
189 .iter()
190 .for_each(|(key, value)| match key.parse() {
191 Ok(key) => {
192 if let Ok(value) = HeaderValue::from_str(value) {
193 res.headers_mut().insert(key, value);
194 } else {
195 warn!("Failed to parse {} header value: {}", key, value)
196 }
197 }
198 _ => {
199 warn!("Failed to parse header name: {}", key)
200 }
201 });
202 }
203
204 Ok(ServiceResponse::new(req, res))
205 })
206 }
207}