Skip to main content

modo/template/
middleware.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::task::{Context, Poll};
5
6use axum::body::Body;
7use http::{Request, Response};
8use tower::{Layer, Service};
9
10use super::context::TemplateContext;
11use super::engine::Engine;
12use super::locale;
13use crate::flash::state::FlashState;
14
15// --- Layer ---
16
17/// Tower middleware layer that populates [`TemplateContext`] for every request.
18///
19/// Install this layer on your router **before** any handler that uses
20/// [`Renderer`](super::Renderer). The layer injects the following keys into
21/// the request's [`TemplateContext`]:
22///
23/// | Key               | Source                                                              |
24/// |-------------------|---------------------------------------------------------------------|
25/// | `current_url`     | `request.uri().to_string()`                                         |
26/// | `is_htmx`         | `HX-Request: true` header                                           |
27/// | `request_id`      | `X-Request-Id` header (if present)                                  |
28/// | `locale`          | Locale resolver chain (falls back to [`TemplateConfig::default_locale`](super::TemplateConfig::default_locale)) |
29/// | `csrf_token`      | [`CsrfToken`](crate::middleware::CsrfToken) extension (if present)  |
30/// | `flash_messages`  | Callable returning flash entries; `FlashState` extension must be set by [`FlashLayer`](crate::flash::FlashLayer) |
31/// | `tier_name`       | `TierInfo::name` (when `TierInfo` extension is present)             |
32/// | `tier_has`        | Template function `tier_has(name) -> bool` (when `TierInfo` is present) |
33/// | `tier_enabled`    | Template function `tier_enabled(name) -> bool` (when `TierInfo` is present) |
34/// | `tier_limit`      | Template function `tier_limit(name) -> Option<u64>` (when `TierInfo` is present) |
35///
36/// This layer is also re-exported as
37/// [`modo::middlewares::TemplateContext`](crate::middlewares::TemplateContext)
38/// for convenience at wiring sites.
39///
40/// # Example
41///
42/// ```rust,no_run
43/// use modo::template::{Engine, TemplateContextLayer};
44///
45/// # fn example(engine: Engine) {
46/// let router: axum::Router = axum::Router::new()
47///     // ... routes ...
48///     .layer(TemplateContextLayer::new(engine));
49/// # }
50/// ```
51#[derive(Clone)]
52pub struct TemplateContextLayer {
53    engine: Engine,
54}
55
56impl TemplateContextLayer {
57    /// Creates a new layer backed by the given [`Engine`].
58    pub fn new(engine: Engine) -> Self {
59        Self { engine }
60    }
61}
62
63impl<S> Layer<S> for TemplateContextLayer {
64    type Service = TemplateContextMiddleware<S>;
65
66    fn layer(&self, inner: S) -> Self::Service {
67        TemplateContextMiddleware {
68            inner,
69            engine: self.engine.clone(),
70        }
71    }
72}
73
74// --- Service ---
75
76/// Tower [`Service`] produced by [`TemplateContextLayer`].
77///
78/// Populates a [`TemplateContext`] with per-request data and inserts it into
79/// request extensions before delegating to the inner service.
80#[derive(Clone)]
81pub struct TemplateContextMiddleware<S> {
82    inner: S,
83    engine: Engine,
84}
85
86impl<S, ReqBody> Service<Request<ReqBody>> for TemplateContextMiddleware<S>
87where
88    S: Service<Request<ReqBody>, Response = Response<Body>> + Clone + Send + 'static,
89    S::Future: Send + 'static,
90    S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
91    ReqBody: Send + 'static,
92{
93    type Response = Response<Body>;
94    type Error = S::Error;
95    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
96
97    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
98        self.inner.poll_ready(cx)
99    }
100
101    fn call(&mut self, mut request: Request<ReqBody>) -> Self::Future {
102        let engine = self.engine.clone();
103        let mut inner = self.inner.clone();
104        std::mem::swap(&mut self.inner, &mut inner);
105
106        Box::pin(async move {
107            // Build TemplateContext with request-scoped data
108            let mut ctx = TemplateContext::default();
109
110            // current_url
111            ctx.set(
112                "current_url",
113                minijinja::Value::from(request.uri().to_string()),
114            );
115
116            // is_htmx
117            let is_htmx = request
118                .headers()
119                .get("hx-request")
120                .and_then(|v| v.to_str().ok())
121                .is_some_and(|v| v == "true");
122            ctx.set("is_htmx", minijinja::Value::from(is_htmx));
123
124            // request_id (if present)
125            if let Some(req_id) = request
126                .headers()
127                .get("x-request-id")
128                .and_then(|v| v.to_str().ok())
129            {
130                ctx.set("request_id", minijinja::Value::from(req_id.to_string()));
131            }
132
133            // locale resolution
134            {
135                // We need to extract Parts temporarily for locale resolution
136                // Since we can't split the request here, read the values we need from headers
137                let (mut parts, body) = request.into_parts();
138
139                let resolved_locale = locale::resolve_locale(engine.locale_chain(), &parts);
140                let locale_value =
141                    resolved_locale.unwrap_or_else(|| engine.default_locale().to_string());
142                ctx.set("locale", minijinja::Value::from(locale_value));
143
144                // csrf_token (if present in extensions)
145                if let Some(csrf) = parts.extensions.get::<crate::middleware::CsrfToken>() {
146                    ctx.set("csrf_token", minijinja::Value::from(csrf.0.clone()));
147                }
148
149                // flash_messages() template function
150                if let Some(flash_state) = parts.extensions.get::<Arc<FlashState>>() {
151                    let state = flash_state.clone();
152                    ctx.set(
153                        "flash_messages",
154                        minijinja::Value::from_function(
155                            move |_args: &[minijinja::Value]| -> Result<minijinja::Value, minijinja::Error> {
156                                state.mark_read();
157                                let entries = state.incoming_as_template_value();
158                                Ok(minijinja::Value::from_serialize(&entries))
159                            },
160                        ),
161                    );
162                }
163
164                // tier info (if tier feature enabled and TierInfo in extensions)
165                if let Some(tier_info) = parts.extensions.get::<crate::tier::TierInfo>() {
166                    ctx.set("tier_name", minijinja::Value::from(tier_info.name.clone()));
167
168                    let ti = Arc::new(tier_info.clone());
169
170                    let t = ti.clone();
171                    ctx.set(
172                        "tier_has",
173                        minijinja::Value::from_function(move |name: &str| -> bool {
174                            t.has_feature(name)
175                        }),
176                    );
177
178                    let t = ti.clone();
179                    ctx.set(
180                        "tier_enabled",
181                        minijinja::Value::from_function(move |name: &str| -> bool {
182                            t.is_enabled(name)
183                        }),
184                    );
185
186                    ctx.set(
187                        "tier_limit",
188                        minijinja::Value::from_function(move |name: &str| -> Option<u64> {
189                            ti.limit(name)
190                        }),
191                    );
192                }
193
194                // Insert TemplateContext into extensions
195                parts.extensions.insert(ctx);
196
197                request = Request::from_parts(parts, body);
198            }
199
200            inner.call(request).await
201        })
202    }
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{Router, routing::get};
    use http::{Request, StatusCode};
    use tower::ServiceExt;

    use crate::template::{TemplateConfig, TemplateContext};

    /// Builds an [`Engine`] over a temporary directory holding an empty
    /// `templates/` folder, `en` and `uk` locale files and a `static/` folder.
    ///
    /// The `TempDir` is returned alongside the engine so the files stay on
    /// disk for the lifetime of the calling test.
    fn test_engine() -> (tempfile::TempDir, Engine) {
        let dir = tempfile::tempdir().unwrap();
        let tpl_dir = dir.path().join("templates");
        let locales_dir = dir.path().join("locales/en");
        let static_dir = dir.path().join("static");
        std::fs::create_dir_all(&tpl_dir).unwrap();
        std::fs::create_dir_all(&locales_dir).unwrap();
        std::fs::create_dir_all(&static_dir).unwrap();
        std::fs::write(locales_dir.join("common.yaml"), "greeting: Hello").unwrap();

        let uk_locales_dir = dir.path().join("locales/uk");
        std::fs::create_dir_all(&uk_locales_dir).unwrap();
        std::fs::write(uk_locales_dir.join("common.yaml"), "greeting: Привіт").unwrap();

        let config = TemplateConfig {
            templates_path: tpl_dir.to_str().unwrap().into(),
            locales_path: dir.path().join("locales").to_str().unwrap().into(),
            static_path: static_dir.to_str().unwrap().into(),
            ..TemplateConfig::default()
        };

        let engine = Engine::builder().config(config).build().unwrap();
        (dir, engine)
    }

    /// Writes `source` to `templates/<name>` inside the test directory.
    fn write_template(dir: &tempfile::TempDir, name: &str, source: &str) {
        std::fs::write(dir.path().join("templates").join(name), source).unwrap();
    }

    /// Builds a body-less request for `uri`.
    fn get_request(uri: &str) -> Request<Body> {
        Request::builder().uri(uri).body(Body::empty()).unwrap()
    }

    /// Reads the whole response body as UTF-8 text.
    async fn body_text(resp: http::Response<Body>) -> String {
        let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    /// Returns the string form of `key` from the request's `TemplateContext`,
    /// or an empty string when the key is not set.
    ///
    /// Panics if the middleware did not insert a `TemplateContext`.
    fn context_value(req: &Request<Body>, key: &str) -> String {
        let ctx = req.extensions().get::<TemplateContext>().unwrap();
        ctx.get(key).map(|v| v.to_string()).unwrap_or_default()
    }

    // Handlers must be module-level async fn per CLAUDE.md gotcha
    async fn extract_url(req: Request<Body>) -> (StatusCode, String) {
        (StatusCode::OK, context_value(&req, "current_url"))
    }

    async fn extract_is_htmx(req: Request<Body>) -> (StatusCode, String) {
        (StatusCode::OK, context_value(&req, "is_htmx"))
    }

    async fn extract_locale(req: Request<Body>) -> (StatusCode, String) {
        (StatusCode::OK, context_value(&req, "locale"))
    }

    async fn extract_request_id(req: Request<Body>) -> (StatusCode, String) {
        (StatusCode::OK, context_value(&req, "request_id"))
    }

    #[tokio::test]
    async fn injects_current_url_value() {
        let (_dir, engine) = test_engine();
        let app = Router::new()
            .route("/test", get(extract_url))
            .layer(TemplateContextLayer::new(engine));

        let resp = app.oneshot(get_request("/test")).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(body_text(resp).await, "/test");
    }

    #[tokio::test]
    async fn injects_is_htmx_false() {
        let (_dir, engine) = test_engine();
        let app = Router::new()
            .route("/test", get(extract_is_htmx))
            .layer(TemplateContextLayer::new(engine));

        let resp = app.oneshot(get_request("/test")).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        // Without the header the key must still be set, and set to false.
        assert_eq!(body_text(resp).await, "false");
    }

    #[tokio::test]
    async fn injects_is_htmx_true() {
        let (_dir, engine) = test_engine();
        let app = Router::new()
            .route("/test", get(extract_is_htmx))
            .layer(TemplateContextLayer::new(engine));

        let req = Request::builder()
            .uri("/test")
            .header("hx-request", "true")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(body_text(resp).await, "true");
    }

    #[tokio::test]
    async fn injects_locale_default() {
        let (_dir, engine) = test_engine();
        let app = Router::new()
            .route("/test", get(extract_locale))
            .layer(TemplateContextLayer::new(engine));

        let resp = app.oneshot(get_request("/test")).await.unwrap();
        assert_eq!(body_text(resp).await, "en");
    }

    #[tokio::test]
    async fn injects_locale_from_query() {
        let (_dir, engine) = test_engine();
        let app = Router::new()
            .route("/test", get(extract_locale))
            .layer(TemplateContextLayer::new(engine));

        let resp = app.oneshot(get_request("/test?lang=uk")).await.unwrap();
        assert_eq!(body_text(resp).await, "uk");
    }

    #[tokio::test]
    async fn injects_request_id() {
        let (_dir, engine) = test_engine();
        let app = Router::new()
            .route("/test", get(extract_request_id))
            .layer(TemplateContextLayer::new(engine));

        let req = Request::builder()
            .uri("/test")
            .header("x-request-id", "abc123")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(body_text(resp).await, "abc123");
    }

    #[tokio::test]
    async fn injects_flash_messages_function() {
        use crate::flash::state::{FlashEntry, FlashState};

        let (dir, engine) = test_engine();
        write_template(
            &dir,
            "flash_test.html",
            "{% for msg in flash_messages() %}{% for level, text in msg|items %}{{ level }}:{{ text }};{% endfor %}{% endfor %}",
        );

        let entries = vec![
            FlashEntry {
                level: "error".into(),
                message: "bad".into(),
            },
            FlashEntry {
                level: "info".into(),
                message: "ok".into(),
            },
        ];
        let flash_state = Arc::new(FlashState::new(entries));

        // Use the engine directly to render, simulating what Renderer does
        let mut ctx = TemplateContext::default();

        // Register flash_messages function (same logic as middleware)
        let state = flash_state.clone();
        ctx.set(
            "flash_messages",
            minijinja::Value::from_function(
                move |_args: &[minijinja::Value]| -> Result<minijinja::Value, minijinja::Error> {
                    state.mark_read();
                    let entries = state.incoming_as_template_value();
                    Ok(minijinja::Value::from_serialize(&entries))
                },
            ),
        );

        let merged = ctx.merge(minijinja::context! {});
        let result = engine.render("flash_test.html", merged).unwrap();

        assert!(result.contains("error:bad;"));
        assert!(result.contains("info:ok;"));
        assert!(flash_state.was_read());
    }

    mod tier_tests {
        use super::*;
        use std::collections::HashMap;

        use crate::tier::{FeatureAccess, TierInfo};

        /// A "pro" tier with one enabled toggle, one disabled toggle and one limit.
        fn test_tier() -> TierInfo {
            TierInfo {
                name: "pro".into(),
                features: HashMap::from([
                    ("sso".into(), FeatureAccess::Toggle(true)),
                    ("custom_domain".into(), FeatureAccess::Toggle(false)),
                    ("api_calls".into(), FeatureAccess::Limit(100_000)),
                ]),
            }
        }

        async fn extract_tier_name(req: Request<Body>) -> (StatusCode, String) {
            (StatusCode::OK, context_value(&req, "tier_name"))
        }

        #[tokio::test]
        async fn injects_tier_name() {
            let (_dir, engine) = test_engine();
            let app = Router::new()
                .route("/test", get(extract_tier_name))
                .layer(TemplateContextLayer::new(engine));

            let mut req = get_request("/test");
            req.extensions_mut().insert(test_tier());
            let resp = app.oneshot(req).await.unwrap();
            assert_eq!(body_text(resp).await, "pro");
        }

        #[tokio::test]
        async fn tier_has_function_works() {
            let (dir, engine) = test_engine();
            write_template(
                &dir,
                "tier_has_test.html",
                "{% if tier_has('sso') %}yes{% else %}no{% endif %}",
            );

            let mut ctx = TemplateContext::default();
            let tier = test_tier();
            ctx.set("tier_name", minijinja::Value::from(tier.name.clone()));
            ctx.set(
                "tier_has",
                minijinja::Value::from_function(move |name: &str| -> bool {
                    tier.has_feature(name)
                }),
            );

            let merged = ctx.merge(minijinja::context! {});
            let result = engine.render("tier_has_test.html", merged).unwrap();
            assert_eq!(result, "yes");
        }

        #[tokio::test]
        async fn tier_has_returns_false_for_disabled() {
            let (dir, engine) = test_engine();
            write_template(
                &dir,
                "tier_disabled_test.html",
                "{% if tier_has('custom_domain') %}yes{% else %}no{% endif %}",
            );

            let mut ctx = TemplateContext::default();
            let tier = test_tier();
            ctx.set(
                "tier_has",
                minijinja::Value::from_function(move |name: &str| -> bool {
                    tier.has_feature(name)
                }),
            );

            let merged = ctx.merge(minijinja::context! {});
            let result = engine.render("tier_disabled_test.html", merged).unwrap();
            assert_eq!(result, "no");
        }

        #[tokio::test]
        async fn tier_limit_function_works() {
            let (dir, engine) = test_engine();
            write_template(&dir, "tier_limit_test.html", "{{ tier_limit('api_calls') }}");

            let mut ctx = TemplateContext::default();
            let tier = test_tier();
            ctx.set(
                "tier_limit",
                minijinja::Value::from_function(move |name: &str| -> Option<u64> {
                    tier.limit(name)
                }),
            );

            let merged = ctx.merge(minijinja::context! {});
            let result = engine.render("tier_limit_test.html", merged).unwrap();
            assert_eq!(result, "100000");
        }

        #[tokio::test]
        async fn no_tier_info_no_injection() {
            let (_dir, engine) = test_engine();
            let app = Router::new()
                .route("/test", get(extract_tier_name))
                .layer(TemplateContextLayer::new(engine));

            let resp = app.oneshot(get_request("/test")).await.unwrap();
            // tier_name not set — returns empty string
            assert_eq!(body_text(resp).await, "");
        }
    }
}