1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::task::{Context, Poll};
5
6use axum::body::Body;
7use http::{Request, Response};
8use tower::{Layer, Service};
9
10use super::context::TemplateContext;
11use super::engine::Engine;
12use super::locale;
13use crate::flash::state::FlashState;
14
15#[derive(Clone)]
52pub struct TemplateContextLayer {
53 engine: Engine,
54}
55
56impl TemplateContextLayer {
57 pub fn new(engine: Engine) -> Self {
59 Self { engine }
60 }
61}
62
63impl<S> Layer<S> for TemplateContextLayer {
64 type Service = TemplateContextMiddleware<S>;
65
66 fn layer(&self, inner: S) -> Self::Service {
67 TemplateContextMiddleware {
68 inner,
69 engine: self.engine.clone(),
70 }
71 }
72}
73
74#[derive(Clone)]
81pub struct TemplateContextMiddleware<S> {
82 inner: S,
83 engine: Engine,
84}
85
86impl<S, ReqBody> Service<Request<ReqBody>> for TemplateContextMiddleware<S>
87where
88 S: Service<Request<ReqBody>, Response = Response<Body>> + Clone + Send + 'static,
89 S::Future: Send + 'static,
90 S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
91 ReqBody: Send + 'static,
92{
93 type Response = Response<Body>;
94 type Error = S::Error;
95 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
96
97 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
98 self.inner.poll_ready(cx)
99 }
100
101 fn call(&mut self, mut request: Request<ReqBody>) -> Self::Future {
102 let engine = self.engine.clone();
103 let mut inner = self.inner.clone();
104 std::mem::swap(&mut self.inner, &mut inner);
105
106 Box::pin(async move {
107 let mut ctx = TemplateContext::default();
109
110 ctx.set(
112 "current_url",
113 minijinja::Value::from(request.uri().to_string()),
114 );
115
116 let is_htmx = request
118 .headers()
119 .get("hx-request")
120 .and_then(|v| v.to_str().ok())
121 .is_some_and(|v| v == "true");
122 ctx.set("is_htmx", minijinja::Value::from(is_htmx));
123
124 if let Some(req_id) = request
126 .headers()
127 .get("x-request-id")
128 .and_then(|v| v.to_str().ok())
129 {
130 ctx.set("request_id", minijinja::Value::from(req_id.to_string()));
131 }
132
133 {
135 let (mut parts, body) = request.into_parts();
138
139 let resolved_locale = locale::resolve_locale(engine.locale_chain(), &parts);
140 let locale_value =
141 resolved_locale.unwrap_or_else(|| engine.default_locale().to_string());
142 ctx.set("locale", minijinja::Value::from(locale_value));
143
144 if let Some(csrf) = parts.extensions.get::<crate::middleware::CsrfToken>() {
146 ctx.set("csrf_token", minijinja::Value::from(csrf.0.clone()));
147 }
148
149 if let Some(flash_state) = parts.extensions.get::<Arc<FlashState>>() {
151 let state = flash_state.clone();
152 ctx.set(
153 "flash_messages",
154 minijinja::Value::from_function(
155 move |_args: &[minijinja::Value]| -> Result<minijinja::Value, minijinja::Error> {
156 state.mark_read();
157 let entries = state.incoming_as_template_value();
158 Ok(minijinja::Value::from_serialize(&entries))
159 },
160 ),
161 );
162 }
163
164 if let Some(tier_info) = parts.extensions.get::<crate::tier::TierInfo>() {
166 ctx.set("tier_name", minijinja::Value::from(tier_info.name.clone()));
167
168 let ti = Arc::new(tier_info.clone());
169
170 let t = ti.clone();
171 ctx.set(
172 "tier_has",
173 minijinja::Value::from_function(move |name: &str| -> bool {
174 t.has_feature(name)
175 }),
176 );
177
178 let t = ti.clone();
179 ctx.set(
180 "tier_enabled",
181 minijinja::Value::from_function(move |name: &str| -> bool {
182 t.is_enabled(name)
183 }),
184 );
185
186 ctx.set(
187 "tier_limit",
188 minijinja::Value::from_function(move |name: &str| -> Option<u64> {
189 ti.limit(name)
190 }),
191 );
192 }
193
194 parts.extensions.insert(ctx);
196
197 request = Request::from_parts(parts, body);
198 }
199
200 inner.call(request).await
201 })
202 }
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{Router, routing::get};
    use http::{Request, StatusCode};
    use tower::ServiceExt;

    use crate::template::{TemplateConfig, TemplateContext};

    /// Builds an engine over a temporary directory with `en` and `uk` locales.
    /// The returned `TempDir` must outlive the engine's use in the test.
    fn test_engine() -> (tempfile::TempDir, Engine) {
        let dir = tempfile::tempdir().unwrap();
        let tpl_dir = dir.path().join("templates");
        let locales_dir = dir.path().join("locales/en");
        let static_dir = dir.path().join("static");
        std::fs::create_dir_all(&tpl_dir).unwrap();
        std::fs::create_dir_all(&locales_dir).unwrap();
        std::fs::create_dir_all(&static_dir).unwrap();
        std::fs::write(locales_dir.join("common.yaml"), "greeting: Hello").unwrap();

        let uk_locales_dir = dir.path().join("locales/uk");
        std::fs::create_dir_all(&uk_locales_dir).unwrap();
        std::fs::write(uk_locales_dir.join("common.yaml"), "greeting: Привіт").unwrap();

        let config = TemplateConfig {
            templates_path: tpl_dir.to_str().unwrap().into(),
            locales_path: dir.path().join("locales").to_str().unwrap().into(),
            static_path: static_dir.to_str().unwrap().into(),
            ..TemplateConfig::default()
        };

        let engine = Engine::builder().config(config).build().unwrap();
        (dir, engine)
    }

    async fn extract_url(req: Request<Body>) -> (StatusCode, String) {
        let ctx = req.extensions().get::<TemplateContext>().unwrap();
        let url = ctx
            .get("current_url")
            .map(|v| v.to_string())
            .unwrap_or_default();
        (StatusCode::OK, url)
    }

    async fn extract_is_htmx(req: Request<Body>) -> (StatusCode, String) {
        let ctx = req.extensions().get::<TemplateContext>().unwrap();
        let is_htmx = ctx
            .get("is_htmx")
            .map(|v| v.to_string())
            .unwrap_or_default();
        (StatusCode::OK, is_htmx)
    }

    async fn extract_locale(req: Request<Body>) -> (StatusCode, String) {
        let ctx = req.extensions().get::<TemplateContext>().unwrap();
        let locale = ctx.get("locale").map(|v| v.to_string()).unwrap_or_default();
        (StatusCode::OK, locale)
    }

    async fn extract_request_id(req: Request<Body>) -> (StatusCode, String) {
        let ctx = req.extensions().get::<TemplateContext>().unwrap();
        let request_id = ctx
            .get("request_id")
            .map(|v| v.to_string())
            .unwrap_or_default();
        (StatusCode::OK, request_id)
    }

    /// Responds with "true"/"false" depending on whether `flash_messages` was set.
    async fn extract_has_flash(req: Request<Body>) -> (StatusCode, String) {
        let ctx = req.extensions().get::<TemplateContext>().unwrap();
        let present = ctx.get("flash_messages").is_some();
        (StatusCode::OK, present.to_string())
    }

    #[tokio::test]
    async fn injects_current_url_value() {
        let (_dir, engine) = test_engine();
        let app = Router::new()
            .route("/test", get(extract_url))
            .layer(TemplateContextLayer::new(engine));

        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        assert_eq!(body, "/test");
    }

    #[tokio::test]
    async fn injects_is_htmx_false() {
        let (_dir, engine) = test_engine();
        let app = Router::new()
            .route("/test", get(extract_is_htmx))
            .layer(TemplateContextLayer::new(engine));

        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        // Without an `hx-request` header the flag must be present and false.
        let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        assert_eq!(body, "false");
    }

    #[tokio::test]
    async fn injects_is_htmx_true() {
        let (_dir, engine) = test_engine();
        let app = Router::new()
            .route("/test", get(extract_is_htmx))
            .layer(TemplateContextLayer::new(engine));

        let req = Request::builder()
            .uri("/test")
            .header("hx-request", "true")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        assert_eq!(body, "true");
    }

    #[tokio::test]
    async fn injects_locale_default() {
        let (_dir, engine) = test_engine();
        let app = Router::new()
            .route("/test", get(extract_locale))
            .layer(TemplateContextLayer::new(engine));

        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        assert_eq!(body, "en");
    }

    #[tokio::test]
    async fn injects_locale_from_query() {
        let (_dir, engine) = test_engine();
        let app = Router::new()
            .route("/test", get(extract_locale))
            .layer(TemplateContextLayer::new(engine));

        let req = Request::builder()
            .uri("/test?lang=uk")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        assert_eq!(body, "uk");
    }

    #[tokio::test]
    async fn injects_request_id() {
        let (_dir, engine) = test_engine();
        let app = Router::new()
            .route("/test", get(extract_request_id))
            .layer(TemplateContextLayer::new(engine));

        let req = Request::builder()
            .uri("/test")
            .header("x-request-id", "abc123")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        assert_eq!(body, "abc123");
    }

    #[tokio::test]
    async fn injects_flash_messages_lazily() {
        use crate::flash::state::{FlashEntry, FlashState};

        let (_dir, engine) = test_engine();
        let app = Router::new()
            .route("/test", get(extract_has_flash))
            .layer(TemplateContextLayer::new(engine));

        let flash_state = Arc::new(FlashState::new(vec![FlashEntry {
            level: "info".into(),
            message: "ok".into(),
        }]));
        let mut req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        req.extensions_mut().insert(Arc::clone(&flash_state));

        let resp = app.oneshot(req).await.unwrap();
        let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        assert_eq!(body, "true");
        // The middleware only registers the function; nothing has invoked it.
        assert!(!flash_state.was_read());
    }

    #[tokio::test]
    async fn injects_flash_messages_function() {
        use crate::flash::state::{FlashEntry, FlashState};

        let (dir, engine) = test_engine();
        let tpl_dir = dir.path().join("templates");
        std::fs::write(
            tpl_dir.join("flash_test.html"),
            "{% for msg in flash_messages() %}{% for level, text in msg|items %}{{ level }}:{{ text }};{% endfor %}{% endfor %}",
        )
        .unwrap();

        let entries = vec![
            FlashEntry {
                level: "error".into(),
                message: "bad".into(),
            },
            FlashEntry {
                level: "info".into(),
                message: "ok".into(),
            },
        ];
        let flash_state = Arc::new(FlashState::new(entries));

        let mut ctx = TemplateContext::default();

        let state = flash_state.clone();
        ctx.set(
            "flash_messages",
            minijinja::Value::from_function(
                move |_args: &[minijinja::Value]| -> Result<minijinja::Value, minijinja::Error> {
                    state.mark_read();
                    let entries = state.incoming_as_template_value();
                    Ok(minijinja::Value::from_serialize(&entries))
                },
            ),
        );

        let merged = ctx.merge(minijinja::context! {});
        let result = engine.render("flash_test.html", merged).unwrap();

        assert!(result.contains("error:bad;"));
        assert!(result.contains("info:ok;"));
        assert!(flash_state.was_read());
    }

    mod tier_tests {
        use super::*;
        use std::collections::HashMap;

        use crate::tier::{FeatureAccess, TierInfo};

        fn test_tier() -> TierInfo {
            TierInfo {
                name: "pro".into(),
                features: HashMap::from([
                    ("sso".into(), FeatureAccess::Toggle(true)),
                    ("custom_domain".into(), FeatureAccess::Toggle(false)),
                    ("api_calls".into(), FeatureAccess::Limit(100_000)),
                ]),
            }
        }

        async fn extract_tier_name(req: Request<Body>) -> (StatusCode, String) {
            let ctx = req.extensions().get::<TemplateContext>().unwrap();
            let name = ctx
                .get("tier_name")
                .map(|v| v.to_string())
                .unwrap_or_default();
            (StatusCode::OK, name)
        }

        #[tokio::test]
        async fn injects_tier_name() {
            let (_dir, engine) = test_engine();
            let app = Router::new()
                .route("/test", get(extract_tier_name))
                .layer(TemplateContextLayer::new(engine));

            let mut req = Request::builder().uri("/test").body(Body::empty()).unwrap();
            req.extensions_mut().insert(test_tier());
            let resp = app.oneshot(req).await.unwrap();
            let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
                .await
                .unwrap();
            assert_eq!(body, "pro");
        }

        #[tokio::test]
        async fn tier_has_function_works() {
            let (dir, engine) = test_engine();
            let tpl_dir = dir.path().join("templates");
            std::fs::write(
                tpl_dir.join("tier_has_test.html"),
                "{% if tier_has('sso') %}yes{% else %}no{% endif %}",
            )
            .unwrap();

            let mut ctx = TemplateContext::default();
            let tier = test_tier();
            ctx.set("tier_name", minijinja::Value::from(tier.name.clone()));

            let ti = tier.clone();
            ctx.set(
                "tier_has",
                minijinja::Value::from_function(move |name: &str| -> bool { ti.has_feature(name) }),
            );

            let merged = ctx.merge(minijinja::context! {});
            let result = engine.render("tier_has_test.html", merged).unwrap();
            assert_eq!(result, "yes");
        }

        #[tokio::test]
        async fn tier_has_returns_false_for_disabled() {
            let (dir, engine) = test_engine();
            let tpl_dir = dir.path().join("templates");
            std::fs::write(
                tpl_dir.join("tier_disabled_test.html"),
                "{% if tier_has('custom_domain') %}yes{% else %}no{% endif %}",
            )
            .unwrap();

            let mut ctx = TemplateContext::default();
            let tier = test_tier();

            let ti = tier.clone();
            ctx.set(
                "tier_has",
                minijinja::Value::from_function(move |name: &str| -> bool { ti.has_feature(name) }),
            );

            let merged = ctx.merge(minijinja::context! {});
            let result = engine.render("tier_disabled_test.html", merged).unwrap();
            assert_eq!(result, "no");
        }

        #[tokio::test]
        async fn tier_limit_function_works() {
            let (dir, engine) = test_engine();
            let tpl_dir = dir.path().join("templates");
            std::fs::write(
                tpl_dir.join("tier_limit_test.html"),
                "{{ tier_limit('api_calls') }}",
            )
            .unwrap();

            let mut ctx = TemplateContext::default();
            let tier = test_tier();

            let ti = tier.clone();
            ctx.set(
                "tier_limit",
                minijinja::Value::from_function(move |name: &str| -> Option<u64> {
                    ti.limit(name)
                }),
            );

            let merged = ctx.merge(minijinja::context! {});
            let result = engine.render("tier_limit_test.html", merged).unwrap();
            assert_eq!(result, "100000");
        }

        #[tokio::test]
        async fn no_tier_info_no_injection() {
            let (_dir, engine) = test_engine();
            let app = Router::new()
                .route("/test", get(extract_tier_name))
                .layer(TemplateContextLayer::new(engine));

            let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
            let resp = app.oneshot(req).await.unwrap();
            let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
                .await
                .unwrap();
            assert_eq!(body, "");
        }
    }
}