1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::task::{Context, Poll};
5
6use axum::body::Body;
7use http::{Request, Response};
8use tower::{Layer, Service};
9
10use super::context::TemplateContext;
11use crate::flash::state::FlashState;
12use crate::i18n::Translator;
13
/// Tower [`Layer`] that wraps a service in [`TemplateContextMiddleware`], which inserts a
/// per-request [`TemplateContext`] into the request extensions.
///
/// The layer carries no configuration, so [`TemplateContextLayer::new`] and
/// [`Default::default`] produce the same value.
#[derive(Clone, Default)]
pub struct TemplateContextLayer;

impl TemplateContextLayer {
    /// Creates the layer (equivalent to `TemplateContextLayer::default()`).
    pub fn new() -> Self {
        Self::default()
    }
}
63
64impl<S> Layer<S> for TemplateContextLayer {
65 type Service = TemplateContextMiddleware<S>;
66
67 fn layer(&self, inner: S) -> Self::Service {
68 TemplateContextMiddleware { inner }
69 }
70}
71
72#[derive(Clone)]
79pub struct TemplateContextMiddleware<S> {
80 inner: S,
81}
82
83impl<S, ReqBody> Service<Request<ReqBody>> for TemplateContextMiddleware<S>
84where
85 S: Service<Request<ReqBody>, Response = Response<Body>> + Clone + Send + 'static,
86 S::Future: Send + 'static,
87 S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
88 ReqBody: Send + 'static,
89{
90 type Response = Response<Body>;
91 type Error = S::Error;
92 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
93
94 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
95 self.inner.poll_ready(cx)
96 }
97
98 fn call(&mut self, mut request: Request<ReqBody>) -> Self::Future {
99 let mut inner = self.inner.clone();
100 std::mem::swap(&mut self.inner, &mut inner);
101
102 Box::pin(async move {
103 let mut ctx = TemplateContext::default();
105
106 ctx.set(
108 "current_url",
109 minijinja::Value::from(request.uri().to_string()),
110 );
111
112 let is_htmx = request
114 .headers()
115 .get("hx-request")
116 .and_then(|v| v.to_str().ok())
117 .is_some_and(|v| v == "true");
118 ctx.set("is_htmx", minijinja::Value::from(is_htmx));
119
120 if let Some(req_id) = request
122 .headers()
123 .get("x-request-id")
124 .and_then(|v| v.to_str().ok())
125 {
126 ctx.set("request_id", minijinja::Value::from(req_id.to_string()));
127 }
128
129 {
131 let (mut parts, body) = request.into_parts();
132
133 if let Some(translator) = parts.extensions.get::<Translator>() {
139 ctx.set(
140 "locale",
141 minijinja::Value::from(translator.locale().to_string()),
142 );
143 }
144
145 if let Some(csrf) = parts.extensions.get::<crate::middleware::CsrfToken>() {
147 ctx.set("csrf_token", minijinja::Value::from(csrf.0.clone()));
148 }
149
150 if let Some(flash_state) = parts.extensions.get::<Arc<FlashState>>() {
152 let state = flash_state.clone();
153 ctx.set(
154 "flash_messages",
155 minijinja::Value::from_function(
156 move |_args: &[minijinja::Value]| -> Result<minijinja::Value, minijinja::Error> {
157 state.mark_read();
158 let entries = state.incoming_as_template_value();
159 Ok(minijinja::Value::from_serialize(&entries))
160 },
161 ),
162 );
163 }
164
165 if let Some(tier_info) = parts.extensions.get::<crate::tier::TierInfo>() {
167 ctx.set("tier_name", minijinja::Value::from(tier_info.name.clone()));
168
169 let ti = Arc::new(tier_info.clone());
170
171 let t = ti.clone();
172 ctx.set(
173 "tier_has",
174 minijinja::Value::from_function(move |name: &str| -> bool {
175 t.has_feature(name)
176 }),
177 );
178
179 let t = ti.clone();
180 ctx.set(
181 "tier_enabled",
182 minijinja::Value::from_function(move |name: &str| -> bool {
183 t.is_enabled(name)
184 }),
185 );
186
187 ctx.set(
188 "tier_limit",
189 minijinja::Value::from_function(move |name: &str| -> Option<u64> {
190 ti.limit(name)
191 }),
192 );
193 }
194
195 parts.extensions.insert(ctx);
197
198 request = Request::from_parts(parts, body);
199 }
200
201 inner.call(request).await
202 })
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{Router, body::Body, routing::get};
    use http::{Request, StatusCode};
    use tower::ServiceExt;

    use crate::i18n::{I18n, I18nConfig};
    use crate::template::{Engine, TemplateConfig, TemplateContext};

    /// Creates a temp dir with empty `templates/` and `static/` subdirectories and an [`Engine`]
    /// pointed at them.
    ///
    /// Keep the returned `TempDir` alive while the directory is needed; dropping it deletes the
    /// directory.
    fn test_engine() -> (tempfile::TempDir, Engine) {
        let dir = tempfile::tempdir().unwrap();
        let tpl_dir = dir.path().join("templates");
        let static_dir = dir.path().join("static");
        std::fs::create_dir_all(&tpl_dir).unwrap();
        std::fs::create_dir_all(&static_dir).unwrap();

        let config = TemplateConfig {
            templates_path: tpl_dir.to_str().unwrap().into(),
            static_path: static_dir.to_str().unwrap().into(),
            ..TemplateConfig::default()
        };

        let engine = Engine::builder().config(config).build().unwrap();
        (dir, engine)
    }

    /// Writes `en` and `uk` locale files under `dir/locales` and builds an [`I18n`] whose
    /// default locale is `en`.
    fn test_i18n(dir: &std::path::Path) -> I18n {
        let locales_dir = dir.join("locales");
        std::fs::create_dir_all(locales_dir.join("en")).unwrap();
        std::fs::write(
            locales_dir.join("en").join("common.yaml"),
            "greeting: Hello",
        )
        .unwrap();
        std::fs::create_dir_all(locales_dir.join("uk")).unwrap();
        std::fs::write(
            locales_dir.join("uk").join("common.yaml"),
            "greeting: Привіт",
        )
        .unwrap();

        let config = I18nConfig {
            locales_path: locales_dir.to_str().unwrap().into(),
            default_locale: "en".into(),
            ..I18nConfig::default()
        };
        I18n::new(&config).unwrap()
    }

    /// Sends `req` through `app` and returns the response status and the body as UTF-8 text.
    async fn send(app: Router, req: Request<Body>) -> (StatusCode, String) {
        let resp = app.oneshot(req).await.unwrap();
        let status = resp.status();
        let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        (status, String::from_utf8(bytes.to_vec()).unwrap())
    }

    /// Returns entry `key` of the request's [`TemplateContext`] as a string, or `""` when the
    /// key is unset. Panics if no context was inserted at all.
    fn context_value(req: &Request<Body>, key: &str) -> String {
        let ctx = req
            .extensions()
            .get::<TemplateContext>()
            .expect("TemplateContextLayer should insert a TemplateContext");
        ctx.get(key).map(|v| v.to_string()).unwrap_or_default()
    }

    async fn extract_url(req: Request<Body>) -> (StatusCode, String) {
        (StatusCode::OK, context_value(&req, "current_url"))
    }

    async fn extract_is_htmx(req: Request<Body>) -> (StatusCode, String) {
        (StatusCode::OK, context_value(&req, "is_htmx"))
    }

    async fn extract_locale(req: Request<Body>) -> (StatusCode, String) {
        (StatusCode::OK, context_value(&req, "locale"))
    }

    async fn extract_request_id(req: Request<Body>) -> (StatusCode, String) {
        (StatusCode::OK, context_value(&req, "request_id"))
    }

    #[tokio::test]
    async fn injects_current_url_value() {
        let (_dir, _engine) = test_engine();
        let app = Router::new()
            .route("/test", get(extract_url))
            .layer(TemplateContextLayer::new());

        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let (status, body) = send(app, req).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, "/test");
    }

    #[tokio::test]
    async fn injects_is_htmx_false() {
        let (_dir, _engine) = test_engine();
        let app = Router::new()
            .route("/test", get(extract_is_htmx))
            .layer(TemplateContextLayer::new());

        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let (status, body) = send(app, req).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, "false");
    }

    #[tokio::test]
    async fn injects_is_htmx_true() {
        let (_dir, _engine) = test_engine();
        let app = Router::new()
            .route("/test", get(extract_is_htmx))
            .layer(TemplateContextLayer::new());

        let req = Request::builder()
            .uri("/test")
            .header("hx-request", "true")
            .body(Body::empty())
            .unwrap();
        let (_, body) = send(app, req).await;
        assert_eq!(body, "true");
    }

    #[tokio::test]
    async fn is_htmx_false_for_other_header_values() {
        let (_dir, _engine) = test_engine();
        let app = Router::new()
            .route("/test", get(extract_is_htmx))
            .layer(TemplateContextLayer::new());

        let req = Request::builder()
            .uri("/test")
            .header("hx-request", "false")
            .body(Body::empty())
            .unwrap();
        let (_, body) = send(app, req).await;
        assert_eq!(body, "false");
    }

    #[tokio::test]
    async fn injects_locale_default() {
        let (dir, _engine) = test_engine();
        let i18n = test_i18n(dir.path());
        let app = Router::new()
            .route("/test", get(extract_locale))
            .layer(TemplateContextLayer::new())
            .layer(i18n.layer());

        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let (_, body) = send(app, req).await;
        assert_eq!(body, "en");
    }

    #[tokio::test]
    async fn injects_locale_from_query() {
        let (dir, _engine) = test_engine();
        let i18n = test_i18n(dir.path());
        let app = Router::new()
            .route("/test", get(extract_locale))
            .layer(TemplateContextLayer::new())
            .layer(i18n.layer());

        let req = Request::builder()
            .uri("/test?lang=uk")
            .body(Body::empty())
            .unwrap();
        let (_, body) = send(app, req).await;
        assert_eq!(body, "uk");
    }

    #[tokio::test]
    async fn no_i18n_layer_omits_locale() {
        let (_dir, _engine) = test_engine();
        let app = Router::new()
            .route("/test", get(extract_locale))
            .layer(TemplateContextLayer::new());

        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let (_, body) = send(app, req).await;
        assert_eq!(body, "");
    }

    #[tokio::test]
    async fn injects_request_id() {
        let (_dir, _engine) = test_engine();
        let app = Router::new()
            .route("/test", get(extract_request_id))
            .layer(TemplateContextLayer::new());

        let req = Request::builder()
            .uri("/test")
            .header("x-request-id", "abc123")
            .body(Body::empty())
            .unwrap();
        let (_, body) = send(app, req).await;
        assert_eq!(body, "abc123");
    }

    #[tokio::test]
    async fn omits_request_id_without_header() {
        let (_dir, _engine) = test_engine();
        let app = Router::new()
            .route("/test", get(extract_request_id))
            .layer(TemplateContextLayer::new());

        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let (_, body) = send(app, req).await;
        assert_eq!(body, "");
    }

    #[tokio::test]
    async fn injects_flash_messages_function() {
        use crate::flash::state::{FlashEntry, FlashState};

        let (dir, engine) = test_engine();
        let tpl_dir = dir.path().join("templates");
        std::fs::write(
            tpl_dir.join("flash_test.html"),
            "{% for msg in flash_messages() %}{% for level, text in msg|items %}{{ level }}:{{ text }};{% endfor %}{% endfor %}",
        ).unwrap();

        let entries = vec![
            FlashEntry {
                level: "error".into(),
                message: "bad".into(),
            },
            FlashEntry {
                level: "info".into(),
                message: "ok".into(),
            },
        ];
        let flash_state = Arc::new(FlashState::new(entries));

        let mut ctx = TemplateContext::default();

        let state = flash_state.clone();
        ctx.set(
            "flash_messages",
            minijinja::Value::from_function(
                move |_args: &[minijinja::Value]| -> Result<minijinja::Value, minijinja::Error> {
                    state.mark_read();
                    let entries = state.incoming_as_template_value();
                    Ok(minijinja::Value::from_serialize(&entries))
                },
            ),
        );

        let merged = ctx.merge(minijinja::context! {});
        let result = engine.render("flash_test.html", merged).unwrap();

        assert!(result.contains("error:bad;"));
        assert!(result.contains("info:ok;"));
        assert!(flash_state.was_read());
    }

    mod tier_tests {
        use super::*;
        use std::collections::HashMap;

        use crate::tier::{FeatureAccess, TierInfo};

        fn test_tier() -> TierInfo {
            TierInfo {
                name: "pro".into(),
                features: HashMap::from([
                    ("sso".into(), FeatureAccess::Toggle(true)),
                    ("custom_domain".into(), FeatureAccess::Toggle(false)),
                    ("api_calls".into(), FeatureAccess::Limit(100_000)),
                ]),
            }
        }

        async fn extract_tier_name(req: Request<Body>) -> (StatusCode, String) {
            (StatusCode::OK, context_value(&req, "tier_name"))
        }

        #[tokio::test]
        async fn injects_tier_name() {
            let (_dir, _engine) = test_engine();
            let app = Router::new()
                .route("/test", get(extract_tier_name))
                .layer(TemplateContextLayer::new());

            let mut req = Request::builder().uri("/test").body(Body::empty()).unwrap();
            req.extensions_mut().insert(test_tier());
            let (_, body) = send(app, req).await;
            assert_eq!(body, "pro");
        }

        #[tokio::test]
        async fn tier_has_function_works() {
            let (dir, engine) = test_engine();
            let tpl_dir = dir.path().join("templates");
            std::fs::write(
                tpl_dir.join("tier_has_test.html"),
                "{% if tier_has('sso') %}yes{% else %}no{% endif %}",
            )
            .unwrap();

            let mut ctx = TemplateContext::default();
            let tier = test_tier();
            ctx.set("tier_name", minijinja::Value::from(tier.name.clone()));
            ctx.set(
                "tier_has",
                minijinja::Value::from_function(move |name: &str| -> bool {
                    tier.has_feature(name)
                }),
            );

            let merged = ctx.merge(minijinja::context! {});
            let result = engine.render("tier_has_test.html", merged).unwrap();
            assert_eq!(result, "yes");
        }

        #[tokio::test]
        async fn tier_has_returns_false_for_disabled() {
            let (dir, engine) = test_engine();
            let tpl_dir = dir.path().join("templates");
            std::fs::write(
                tpl_dir.join("tier_disabled_test.html"),
                "{% if tier_has('custom_domain') %}yes{% else %}no{% endif %}",
            )
            .unwrap();

            let mut ctx = TemplateContext::default();
            let tier = test_tier();
            ctx.set(
                "tier_has",
                minijinja::Value::from_function(move |name: &str| -> bool {
                    tier.has_feature(name)
                }),
            );

            let merged = ctx.merge(minijinja::context! {});
            let result = engine.render("tier_disabled_test.html", merged).unwrap();
            assert_eq!(result, "no");
        }

        #[tokio::test]
        async fn tier_limit_function_works() {
            let (dir, engine) = test_engine();
            let tpl_dir = dir.path().join("templates");
            std::fs::write(
                tpl_dir.join("tier_limit_test.html"),
                "{{ tier_limit('api_calls') }}",
            )
            .unwrap();

            let mut ctx = TemplateContext::default();
            let tier = test_tier();
            ctx.set(
                "tier_limit",
                minijinja::Value::from_function(move |name: &str| -> Option<u64> {
                    tier.limit(name)
                }),
            );

            let merged = ctx.merge(minijinja::context! {});
            let result = engine.render("tier_limit_test.html", merged).unwrap();
            assert_eq!(result, "100000");
        }

        #[tokio::test]
        async fn no_tier_info_no_injection() {
            let (_dir, _engine) = test_engine();
            let app = Router::new()
                .route("/test", get(extract_tier_name))
                .layer(TemplateContextLayer::new());

            let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
            let (_, body) = send(app, req).await;
            assert_eq!(body, "");
        }
    }
}