Skip to main content

modo/template/
middleware.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::task::{Context, Poll};
5
6use axum::body::Body;
7use http::{Request, Response};
8use tower::{Layer, Service};
9
10use super::context::TemplateContext;
11use crate::flash::state::FlashState;
12use crate::i18n::Translator;
13
14// --- Layer ---
15
/// Tower middleware layer that populates [`TemplateContext`] for every request.
///
/// Install this layer on your router **before** any handler that uses
/// [`Renderer`](super::Renderer). The layer injects the following keys into
/// the request's [`TemplateContext`]:
///
/// | Key               | Source                                                              |
/// |-------------------|---------------------------------------------------------------------|
/// | `current_url`     | `request.uri().to_string()`                                         |
/// | `is_htmx`         | `HX-Request: true` header                                           |
/// | `request_id`      | `X-Request-Id` header (if present)                                  |
/// | `locale`          | [`Translator`](crate::i18n::Translator) in extensions (present only when [`I18nLayer`](crate::i18n::I18nLayer) is installed upstream) |
/// | `csrf_token`      | [`CsrfToken`](crate::middleware::CsrfToken) extension (if present)  |
/// | `flash_messages`  | Callable returning flash entries; `FlashState` extension must be set by [`FlashLayer`](crate::flash::FlashLayer) |
/// | `tier_name`       | `TierInfo::name` (when `TierInfo` extension is present)             |
/// | `tier_has`        | Template function `tier_has(name) -> bool` (when `TierInfo` is present) |
/// | `tier_enabled`    | Template function `tier_enabled(name) -> bool` (when `TierInfo` is present) |
/// | `tier_limit`      | Template function `tier_limit(name) -> Option<u64>` (when `TierInfo` is present) |
///
/// `locale` is read from the current request's
/// [`Translator`](crate::i18n::Translator), which is installed by
/// [`I18nLayer`](crate::i18n::I18nLayer). If no `I18nLayer` is upstream, the
/// `locale` variable is simply absent from the template context.
///
/// This layer is also re-exported as
/// [`modo::middlewares::TemplateContext`](crate::middlewares::TemplateContext)
/// for convenience at wiring sites.
///
/// # Example
///
/// ```rust,no_run
/// use modo::template::TemplateContextLayer;
///
/// let router: axum::Router = axum::Router::new()
///     // ... routes ...
///     .layer(TemplateContextLayer::new());
/// ```
#[derive(Clone, Debug, Default)]
pub struct TemplateContextLayer;

impl TemplateContextLayer {
    /// Creates a new [`TemplateContextLayer`].
    ///
    /// The layer is stateless; this is equivalent to
    /// [`TemplateContextLayer::default()`].
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}
63
64impl<S> Layer<S> for TemplateContextLayer {
65    type Service = TemplateContextMiddleware<S>;
66
67    fn layer(&self, inner: S) -> Self::Service {
68        TemplateContextMiddleware { inner }
69    }
70}
71
72// --- Service ---
73
/// Tower [`Service`] produced by [`TemplateContextLayer`].
///
/// For each request it builds a [`TemplateContext`] with per-request data,
/// inserts it into the request extensions, and then delegates to the wrapped
/// `inner` service.
#[derive(Clone, Debug)]
pub struct TemplateContextMiddleware<S> {
    /// The wrapped service that handles the request after the context is inserted.
    inner: S,
}
82
83impl<S, ReqBody> Service<Request<ReqBody>> for TemplateContextMiddleware<S>
84where
85    S: Service<Request<ReqBody>, Response = Response<Body>> + Clone + Send + 'static,
86    S::Future: Send + 'static,
87    S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
88    ReqBody: Send + 'static,
89{
90    type Response = Response<Body>;
91    type Error = S::Error;
92    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
93
94    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
95        self.inner.poll_ready(cx)
96    }
97
98    fn call(&mut self, mut request: Request<ReqBody>) -> Self::Future {
99        let mut inner = self.inner.clone();
100        std::mem::swap(&mut self.inner, &mut inner);
101
102        Box::pin(async move {
103            // Build TemplateContext with request-scoped data
104            let mut ctx = TemplateContext::default();
105
106            // current_url
107            ctx.set(
108                "current_url",
109                minijinja::Value::from(request.uri().to_string()),
110            );
111
112            // is_htmx
113            let is_htmx = request
114                .headers()
115                .get("hx-request")
116                .and_then(|v| v.to_str().ok())
117                .is_some_and(|v| v == "true");
118            ctx.set("is_htmx", minijinja::Value::from(is_htmx));
119
120            // request_id (if present)
121            if let Some(req_id) = request
122                .headers()
123                .get("x-request-id")
124                .and_then(|v| v.to_str().ok())
125            {
126                ctx.set("request_id", minijinja::Value::from(req_id.to_string()));
127            }
128
129            // Read request extensions for locale, csrf, flash, tier.
130            {
131                let (mut parts, body) = request.into_parts();
132
133                // locale comes from the Translator installed upstream by I18nLayer.
134                // If no I18nLayer is installed, `locale` is simply absent from
135                // the template context — the template engine's MiniJinja state
136                // then sees `locale` as undefined, which is the correct signal
137                // that `t()` cannot be used.
138                if let Some(translator) = parts.extensions.get::<Translator>() {
139                    ctx.set(
140                        "locale",
141                        minijinja::Value::from(translator.locale().to_string()),
142                    );
143                }
144
145                // csrf_token (if present in extensions)
146                if let Some(csrf) = parts.extensions.get::<crate::middleware::CsrfToken>() {
147                    ctx.set("csrf_token", minijinja::Value::from(csrf.0.clone()));
148                }
149
150                // flash_messages() template function
151                if let Some(flash_state) = parts.extensions.get::<Arc<FlashState>>() {
152                    let state = flash_state.clone();
153                    ctx.set(
154                        "flash_messages",
155                        minijinja::Value::from_function(
156                            move |_args: &[minijinja::Value]| -> Result<minijinja::Value, minijinja::Error> {
157                                state.mark_read();
158                                let entries = state.incoming_as_template_value();
159                                Ok(minijinja::Value::from_serialize(&entries))
160                            },
161                        ),
162                    );
163                }
164
165                // tier info (if tier feature enabled and TierInfo in extensions)
166                if let Some(tier_info) = parts.extensions.get::<crate::tier::TierInfo>() {
167                    ctx.set("tier_name", minijinja::Value::from(tier_info.name.clone()));
168
169                    let ti = Arc::new(tier_info.clone());
170
171                    let t = ti.clone();
172                    ctx.set(
173                        "tier_has",
174                        minijinja::Value::from_function(move |name: &str| -> bool {
175                            t.has_feature(name)
176                        }),
177                    );
178
179                    let t = ti.clone();
180                    ctx.set(
181                        "tier_enabled",
182                        minijinja::Value::from_function(move |name: &str| -> bool {
183                            t.is_enabled(name)
184                        }),
185                    );
186
187                    ctx.set(
188                        "tier_limit",
189                        minijinja::Value::from_function(move |name: &str| -> Option<u64> {
190                            ti.limit(name)
191                        }),
192                    );
193                }
194
195                // Insert TemplateContext into extensions
196                parts.extensions.insert(ctx);
197
198                request = Request::from_parts(parts, body);
199            }
200
201            inner.call(request).await
202        })
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{Router, body::Body, routing::get};
    use http::{Request, Response, StatusCode};
    use tower::ServiceExt;

    use crate::flash::state::FlashEntry;
    use crate::i18n::{I18n, I18nConfig};
    use crate::template::{Engine, TemplateConfig, TemplateContext};

    /// Builds an [`Engine`] over empty `templates`/`static` temp directories.
    ///
    /// The `TempDir` is returned alongside the engine so that files written by
    /// a test persist for the test's lifetime.
    fn test_engine() -> (tempfile::TempDir, Engine) {
        let dir = tempfile::tempdir().unwrap();
        let tpl_dir = dir.path().join("templates");
        let static_dir = dir.path().join("static");
        std::fs::create_dir_all(&tpl_dir).unwrap();
        std::fs::create_dir_all(&static_dir).unwrap();

        let config = TemplateConfig {
            templates_path: tpl_dir.to_str().unwrap().into(),
            static_path: static_dir.to_str().unwrap().into(),
            ..TemplateConfig::default()
        };

        let engine = Engine::builder().config(config).build().unwrap();
        (dir, engine)
    }

    /// Writes `en` and `uk` locale files under `dir/locales` and builds an
    /// [`I18n`] whose default locale is `en`.
    fn test_i18n(dir: &std::path::Path) -> I18n {
        let locales_dir = dir.join("locales");
        std::fs::create_dir_all(locales_dir.join("en")).unwrap();
        std::fs::write(
            locales_dir.join("en").join("common.yaml"),
            "greeting: Hello",
        )
        .unwrap();
        std::fs::create_dir_all(locales_dir.join("uk")).unwrap();
        std::fs::write(
            locales_dir.join("uk").join("common.yaml"),
            "greeting: Привіт",
        )
        .unwrap();

        let config = I18nConfig {
            locales_path: locales_dir.to_str().unwrap().into(),
            default_locale: "en".into(),
            ..I18nConfig::default()
        };
        I18n::new(&config).unwrap()
    }

    /// Collects a response body into a UTF-8 `String`.
    async fn body_string(resp: Response<Body>) -> String {
        let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    /// Returns the request's `TemplateContext`, panicking if the layer did not insert one.
    fn template_ctx(req: &Request<Body>) -> &TemplateContext {
        req.extensions()
            .get::<TemplateContext>()
            .expect("TemplateContextLayer must insert a TemplateContext")
    }

    /// Renders `key` from the request's `TemplateContext`, or returns `""` when it is absent.
    fn ctx_string(req: &Request<Body>, key: &'static str) -> String {
        template_ctx(req)
            .get(key)
            .map(|v| v.to_string())
            .unwrap_or_default()
    }

    /// Reports whether `key` is present in the request's `TemplateContext`.
    fn ctx_has(req: &Request<Body>, key: &'static str) -> bool {
        template_ctx(req).get(key).is_some()
    }

    // Handlers must be module-level async fn per CLAUDE.md gotcha
    async fn extract_url(req: Request<Body>) -> (StatusCode, String) {
        (StatusCode::OK, ctx_string(&req, "current_url"))
    }

    async fn extract_is_htmx(req: Request<Body>) -> (StatusCode, String) {
        (StatusCode::OK, ctx_string(&req, "is_htmx"))
    }

    async fn extract_locale(req: Request<Body>) -> (StatusCode, String) {
        (StatusCode::OK, ctx_string(&req, "locale"))
    }

    async fn extract_request_id(req: Request<Body>) -> (StatusCode, String) {
        (StatusCode::OK, ctx_string(&req, "request_id"))
    }

    async fn has_flash_messages(req: Request<Body>) -> (StatusCode, String) {
        (StatusCode::OK, ctx_has(&req, "flash_messages").to_string())
    }

    #[tokio::test]
    async fn injects_current_url_value() {
        let app = Router::new()
            .route("/test", get(extract_url))
            .layer(TemplateContextLayer::new());

        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(body_string(resp).await, "/test");
    }

    #[tokio::test]
    async fn injects_is_htmx_false() {
        let app = Router::new()
            .route("/test", get(extract_is_htmx))
            .layer(TemplateContextLayer::new());

        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        // Previously only the status was checked; pin the injected value too.
        assert_eq!(body_string(resp).await, "false");
    }

    #[tokio::test]
    async fn injects_is_htmx_true() {
        let app = Router::new()
            .route("/test", get(extract_is_htmx))
            .layer(TemplateContextLayer::new());

        let req = Request::builder()
            .uri("/test")
            .header("hx-request", "true")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(body_string(resp).await, "true");
    }

    #[tokio::test]
    async fn injects_locale_default() {
        let dir = tempfile::tempdir().unwrap();
        let i18n = test_i18n(dir.path());
        // I18nLayer must run before TemplateContextLayer sees the request, so
        // install I18nLayer as an outer layer (outer layers run first in axum).
        let app = Router::new()
            .route("/test", get(extract_locale))
            .layer(TemplateContextLayer::new())
            .layer(i18n.layer());

        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(body_string(resp).await, "en");
    }

    #[tokio::test]
    async fn injects_locale_from_query() {
        let dir = tempfile::tempdir().unwrap();
        let i18n = test_i18n(dir.path());
        let app = Router::new()
            .route("/test", get(extract_locale))
            .layer(TemplateContextLayer::new())
            .layer(i18n.layer());

        let req = Request::builder()
            .uri("/test?lang=uk")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(body_string(resp).await, "uk");
    }

    #[tokio::test]
    async fn no_i18n_layer_omits_locale() {
        // No I18nLayer installed.
        let app = Router::new()
            .route("/test", get(extract_locale))
            .layer(TemplateContextLayer::new());

        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        // locale not set — returns empty string
        assert_eq!(body_string(resp).await, "");
    }

    #[tokio::test]
    async fn injects_request_id() {
        let app = Router::new()
            .route("/test", get(extract_request_id))
            .layer(TemplateContextLayer::new());

        let req = Request::builder()
            .uri("/test")
            .header("x-request-id", "abc123")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(body_string(resp).await, "abc123");
    }

    #[tokio::test]
    async fn middleware_registers_flash_messages_when_state_present() {
        let app = Router::new()
            .route("/test", get(has_flash_messages))
            .layer(TemplateContextLayer::new());

        let mut req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        req.extensions_mut()
            .insert(Arc::new(FlashState::new(Vec::<FlashEntry>::new())));
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(body_string(resp).await, "true");
    }

    #[tokio::test]
    async fn no_flash_state_omits_flash_messages() {
        let app = Router::new()
            .route("/test", get(has_flash_messages))
            .layer(TemplateContextLayer::new());

        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(body_string(resp).await, "false");
    }

    #[tokio::test]
    async fn injects_flash_messages_function() {
        let (dir, engine) = test_engine();
        let tpl_dir = dir.path().join("templates");
        std::fs::write(
            tpl_dir.join("flash_test.html"),
            "{% for msg in flash_messages() %}{% for level, text in msg|items %}{{ level }}:{{ text }};{% endfor %}{% endfor %}",
        )
        .unwrap();

        let entries = vec![
            FlashEntry {
                level: "error".into(),
                message: "bad".into(),
            },
            FlashEntry {
                level: "info".into(),
                message: "ok".into(),
            },
        ];
        let flash_state = Arc::new(FlashState::new(entries));

        // Use the engine directly to render, simulating what Renderer does
        let mut ctx = TemplateContext::default();

        // Register flash_messages function (same logic as middleware)
        let state = flash_state.clone();
        ctx.set(
            "flash_messages",
            minijinja::Value::from_function(
                move |_args: &[minijinja::Value]| -> Result<minijinja::Value, minijinja::Error> {
                    state.mark_read();
                    let entries = state.incoming_as_template_value();
                    Ok(minijinja::Value::from_serialize(&entries))
                },
            ),
        );

        let merged = ctx.merge(minijinja::context! {});
        let result = engine.render("flash_test.html", merged).unwrap();

        assert!(result.contains("error:bad;"));
        assert!(result.contains("info:ok;"));
        assert!(flash_state.was_read());
    }

    mod tier_tests {
        use super::*;
        use std::collections::HashMap;

        use crate::tier::{FeatureAccess, TierInfo};

        fn test_tier() -> TierInfo {
            TierInfo {
                name: "pro".into(),
                features: HashMap::from([
                    ("sso".into(), FeatureAccess::Toggle(true)),
                    ("custom_domain".into(), FeatureAccess::Toggle(false)),
                    ("api_calls".into(), FeatureAccess::Limit(100_000)),
                ]),
            }
        }

        async fn extract_tier_name(req: Request<Body>) -> (StatusCode, String) {
            (StatusCode::OK, ctx_string(&req, "tier_name"))
        }

        async fn has_tier_functions(req: Request<Body>) -> (StatusCode, String) {
            let all_present = ["tier_has", "tier_enabled", "tier_limit"]
                .into_iter()
                .all(|key| ctx_has(&req, key));
            (StatusCode::OK, all_present.to_string())
        }

        #[tokio::test]
        async fn injects_tier_name() {
            let app = Router::new()
                .route("/test", get(extract_tier_name))
                .layer(TemplateContextLayer::new());

            let mut req = Request::builder().uri("/test").body(Body::empty()).unwrap();
            req.extensions_mut().insert(test_tier());
            let resp = app.oneshot(req).await.unwrap();
            assert_eq!(body_string(resp).await, "pro");
        }

        #[tokio::test]
        async fn middleware_registers_tier_functions() {
            let app = Router::new()
                .route("/test", get(has_tier_functions))
                .layer(TemplateContextLayer::new());

            let mut req = Request::builder().uri("/test").body(Body::empty()).unwrap();
            req.extensions_mut().insert(test_tier());
            let resp = app.oneshot(req).await.unwrap();
            assert_eq!(body_string(resp).await, "true");
        }

        #[tokio::test]
        async fn tier_has_function_works() {
            let (dir, engine) = test_engine();
            let tpl_dir = dir.path().join("templates");
            std::fs::write(
                tpl_dir.join("tier_has_test.html"),
                "{% if tier_has('sso') %}yes{% else %}no{% endif %}",
            )
            .unwrap();

            let mut ctx = TemplateContext::default();
            let tier = test_tier();
            ctx.set("tier_name", minijinja::Value::from(tier.name.clone()));

            let ti = tier.clone();
            ctx.set(
                "tier_has",
                minijinja::Value::from_function(move |name: &str| -> bool { ti.has_feature(name) }),
            );

            let merged = ctx.merge(minijinja::context! {});
            let result = engine.render("tier_has_test.html", merged).unwrap();
            assert_eq!(result, "yes");
        }

        #[tokio::test]
        async fn tier_has_returns_false_for_disabled() {
            let (dir, engine) = test_engine();
            let tpl_dir = dir.path().join("templates");
            std::fs::write(
                tpl_dir.join("tier_disabled_test.html"),
                "{% if tier_has('custom_domain') %}yes{% else %}no{% endif %}",
            )
            .unwrap();

            let mut ctx = TemplateContext::default();
            let tier = test_tier();

            let ti = tier.clone();
            ctx.set(
                "tier_has",
                minijinja::Value::from_function(move |name: &str| -> bool { ti.has_feature(name) }),
            );

            let merged = ctx.merge(minijinja::context! {});
            let result = engine.render("tier_disabled_test.html", merged).unwrap();
            assert_eq!(result, "no");
        }

        #[tokio::test]
        async fn tier_limit_function_works() {
            let (dir, engine) = test_engine();
            let tpl_dir = dir.path().join("templates");
            std::fs::write(
                tpl_dir.join("tier_limit_test.html"),
                "{{ tier_limit('api_calls') }}",
            )
            .unwrap();

            let mut ctx = TemplateContext::default();
            let tier = test_tier();

            let ti = tier.clone();
            ctx.set(
                "tier_limit",
                minijinja::Value::from_function(move |name: &str| -> Option<u64> {
                    ti.limit(name)
                }),
            );

            let merged = ctx.merge(minijinja::context! {});
            let result = engine.render("tier_limit_test.html", merged).unwrap();
            assert_eq!(result, "100000");
        }

        #[tokio::test]
        async fn no_tier_info_no_injection() {
            let app = Router::new()
                .route("/test", get(extract_tier_name))
                .layer(TemplateContextLayer::new());

            let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
            let resp = app.oneshot(req).await.unwrap();
            // tier_name not set — returns empty string
            assert_eq!(body_string(resp).await, "");
        }
    }
}