1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
use crate::{headers::ResponseHeaders, HtmxDetails, TriggerType};

use actix_web::http::header::{HeaderName, HeaderValue};
use actix_web::{
    dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
    Error, HttpMessage,
};
use futures_util::future::LocalBoxFuture;
use indexmap::IndexMap;
use log::warn;
use std::future::{ready, Ready};

/// Actix-web middleware factory for htmx support.
///
/// For every request it inserts an [`HtmxDetails`] into the request extensions. After the
/// wrapped service responds, it copies the triggers and response headers exposed by that
/// `HtmxDetails` onto the outgoing response.
#[derive(Debug, Clone, Copy)]
pub struct HtmxMiddleware;

// Wraps the next service in the chain in an `InnerHtmxMiddleware`. Building it cannot fail
// (`InitError = ()`), so the returned future is already resolved.
impl<S, B: 'static> Transform<S, ServiceRequest> for HtmxMiddleware
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
    S::Future: 'static,
{
    type Transform = InnerHtmxMiddleware<S>;
    type Response = ServiceResponse<B>;
    type Error = Error;
    type InitError = ();
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    fn new_transform(&self, service: S) -> Self::Future {
        let middleware = InnerHtmxMiddleware { service };
        ready(Ok(middleware))
    }
}

/// The per-service middleware produced by [`HtmxMiddleware`]; it wraps the next service.
///
/// Marked `#[doc(hidden)]`, so it does not appear in rendered documentation.
#[non_exhaustive]
#[doc(hidden)]
pub struct InnerHtmxMiddleware<S> {
    /// The wrapped (downstream) service that requests are forwarded to.
    service: S,
}

impl<S, B> Service<ServiceRequest> for InnerHtmxMiddleware<S>
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
    S::Future: 'static,
    B: 'static,
{
    type Response = ServiceResponse<B>;
    type Error = Error;
    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;

    forward_ready!(service);

    /// Inserts an [`HtmxDetails`] built from the request into its extensions and forwards the
    /// request to the wrapped service.
    ///
    /// Once the wrapped service responds, the `HtmxDetails` found in the request extensions
    /// is used to set the `HX-Trigger`, `HX-Trigger-After-Settle` and `HX-Trigger-After-Swap`
    /// headers (when their trigger maps are non-empty) and any extra response headers.
    /// Existing headers with the same names are replaced. Names or values that cannot form a
    /// valid header are logged with `warn!` and skipped rather than failing the response.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the wrapped service.
    fn call(&self, req: ServiceRequest) -> Self::Future {
        let htmx_details = HtmxDetails::new(&req);

        req.extensions_mut().insert(htmx_details);

        let fut = self.service.call(req);

        Box::pin(async move {
            let res: ServiceResponse<B> = fut.await?;

            let (req, mut res) = res.into_parts();

            // Serializes a non-empty trigger map to JSON and stores it under `header_name`.
            let mut process_trigger_header =
                |header_name: HeaderName, trigger_map: IndexMap<String, String>| {
                    if trigger_map.is_empty() {
                        return;
                    }
                    let triggers = trigger_json(&trigger_map);
                    if let Ok(value) = HeaderValue::from_str(&triggers) {
                        res.headers_mut().insert(header_name, value);
                    } else {
                        warn!("Failed to parse {} header value: {}", header_name, triggers)
                    }
                };

            if let Some(htmx_response_details) = req.extensions().get::<HtmxDetails>() {
                process_trigger_header(
                    HeaderName::from_static(ResponseHeaders::HX_TRIGGER),
                    htmx_response_details.get_triggers(TriggerType::Standard),
                );
                process_trigger_header(
                    HeaderName::from_static(ResponseHeaders::HX_TRIGGER_AFTER_SETTLE),
                    htmx_response_details.get_triggers(TriggerType::AfterSettle),
                );
                process_trigger_header(
                    HeaderName::from_static(ResponseHeaders::HX_TRIGGER_AFTER_SWAP),
                    htmx_response_details.get_triggers(TriggerType::AfterSwap),
                );

                let response_headers = htmx_response_details.get_response_headers();
                for (key, value) in response_headers.iter() {
                    match key.parse::<HeaderName>() {
                        Ok(key) => {
                            if let Ok(value) = HeaderValue::from_str(value) {
                                res.headers_mut().insert(key, value);
                            } else {
                                warn!("Failed to parse {} header value: {}", key, value)
                            }
                        }
                        _ => {
                            warn!("Failed to parse header name: {}", key)
                        }
                    }
                }
            }

            Ok(ServiceResponse::new(req, res))
        })
    }
}

/// Serializes a trigger map into the JSON object form used by the `HX-Trigger*` headers,
/// e.g. `{"name": "detail","other": {"a": 1}}`.
///
/// Keys are always emitted as escaped JSON strings. A value whose trimmed text starts with
/// `{` is assumed to already be a JSON object and is inserted verbatim; every other value is
/// emitted as an escaped JSON string. Escaping keeps the output ASCII-only, which header
/// values require.
fn trigger_json(trigger_map: &IndexMap<String, String>) -> String {
    let mut json = String::from("{");
    for (index, (key, value)) in trigger_map.iter().enumerate() {
        if index > 0 {
            json.push(',');
        }
        push_json_string(&mut json, key);
        json.push_str(": ");
        if value.trim().starts_with('{') {
            // NOTE(review): raw JSON is trusted as-is; it is not validated here.
            json.push_str(value);
        } else {
            push_json_string(&mut json, value);
        }
    }
    json.push('}');
    json
}

/// Appends `text` to `out` as a double-quoted JSON string literal.
///
/// Quotes, backslashes and control characters are escaped as JSON requires. Non-ASCII
/// characters are written as `\uXXXX` escapes (using UTF-16 surrogate pairs where needed),
/// so the result contains only visible ASCII. `HeaderValue::from_str` accepts nothing else,
/// and a JSON parser decodes these escapes back to the original text.
fn push_json_string(out: &mut String, text: &str) {
    out.push('"');
    for ch in text.chars() {
        match ch {
            '"' => out.push_str("\\\""),
            '\\' => out.push_str("\\\\"),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\t' => out.push_str("\\t"),
            c if c.is_ascii() && !c.is_ascii_control() => out.push(c),
            c => {
                let mut units = [0u16; 2];
                for unit in c.encode_utf16(&mut units) {
                    out.push_str(&format!("\\u{:04x}", unit));
                }
            }
        }
    }
    out.push('"');
}