1use axum::http::{self, Uri};
2
3fn update_query(uri: &Uri, new_query: String) -> Result<Uri, http::Error> {
4 let query = form_urlencoded::parse(uri.query().map(|q| q.as_bytes()).unwrap_or_default());
5 let updated_query = form_urlencoded::Serializer::new(new_query)
6 .extend_pairs(query)
7 .finish();
8
9 let mut parts = uri.clone().into_parts();
10 parts.path_and_query = Some(format!("{}?{}", uri.path(), updated_query).parse()?);
11
12 Ok(Uri::from_parts(parts)?)
13}
14
15#[doc(hidden)]
18pub fn url_with_redirect_query(
19 url: &str,
20 redirect_field: &str,
21 redirect_uri: Uri,
22) -> Result<Uri, http::Error> {
23 let uri = url.parse::<Uri>()?;
24
25 if uri.query().is_some_and(|q| q.contains(redirect_field)) {
26 return Ok(uri);
27 };
28
29 let redirect_uri_string = redirect_uri.to_string();
30 let redirect_uri_encoded = urlencoding::encode(&redirect_uri_string);
31 let redirect_query = format!("{redirect_field}={redirect_uri_encoded}");
32
33 update_query(&uri, redirect_query)
34}
35
#[macro_export]
/// Builds a middleware that only lets requests from an authenticated user
/// through, i.e. requests whose `AuthSession::user` is `Some`.
///
/// Forms:
///
/// - `login_required!(Backend)`: unauthenticated requests get
///   `401 Unauthorized`.
/// - `login_required!(Backend, login_url = "/login")`: unauthenticated
///   requests are redirected to `login_url`, with the original request URI in
///   the `next` query parameter.
/// - `login_required!(Backend, login_url = "/login", redirect_field = "next_uri")`:
///   the same, but the original URI goes into the `redirect_field` parameter.
///
/// A trailing comma is accepted in every form.
macro_rules! login_required {
    ($backend_type:ty $(,)?) => {{
        async fn is_authenticated(auth_session: $crate::AuthSession<$backend_type>) -> bool {
            auth_session.user.is_some()
        }

        $crate::predicate_required!(
            is_authenticated,
            $crate::axum::http::StatusCode::UNAUTHORIZED
        )
    }};

    (
        $backend_type:ty,
        login_url = $login_url:expr,
        redirect_field = $redirect_field:expr $(,)?
    ) => {{
        async fn is_authenticated(auth_session: $crate::AuthSession<$backend_type>) -> bool {
            auth_session.user.is_some()
        }

        $crate::predicate_required!(
            is_authenticated,
            login_url = $login_url,
            redirect_field = $redirect_field
        )
    }};

    ($backend_type:ty, login_url = $login_url:expr $(,)?) => {
        // No redirect field given: fall back to `next`, the same default
        // `permission_required!` uses.
        $crate::login_required!(
            $backend_type,
            login_url = $login_url,
            redirect_field = "next"
        )
    };
}
72
#[macro_export]
/// Builds a middleware that only lets a request through when its
/// authenticated user has *every* listed permission.
///
/// Each `$perm` is converted with `.into()` and checked with
/// `AuthzBackend::has_perm`, in the order given. Checking stops at the first
/// permission that is missing. A backend error from `has_perm` counts as a
/// missing permission. Requests without a user are always rejected.
///
/// Forms:
///
/// - `permission_required!(Backend, "perm.a", "perm.b")`: rejected requests
///   get `403 Forbidden`.
/// - `permission_required!(Backend, login_url = "/login", "perm.a")`: rejected
///   requests are redirected to `login_url`, with the original URI in the
///   `next` query parameter.
/// - `permission_required!(Backend, login_url = "/login",
///   redirect_field = "next_uri", "perm.a")`: the same, using `redirect_field`
///   as the parameter name.
///
/// In the `login_url` forms, a logged-in user who lacks a permission is also
/// redirected to the login URL. A trailing comma after the last permission is
/// accepted.
macro_rules! permission_required {
    ($backend_type:ty, login_url = $login_url:expr, redirect_field = $redirect_field:expr, $($perm:expr),+ $(,)?) => {{
        use $crate::AuthzBackend;

        async fn is_authorized(auth_session: $crate::AuthSession<$backend_type>) -> bool {
            // Anonymous requests never pass.
            let Some(ref user) = auth_session.user else {
                return false;
            };

            // Check each permission in order. The first missing one (including
            // a backend error, mapped to `false`) rejects the request without
            // querying the remaining permissions.
            $(
                if !auth_session.backend.has_perm(user, $perm.into()).await.unwrap_or(false) {
                    return false;
                }
            )+

            true
        }

        $crate::predicate_required!(
            is_authorized,
            login_url = $login_url,
            redirect_field = $redirect_field
        )
    }};

    ($backend_type:ty, login_url = $login_url:expr, $($perm:expr),+ $(,)?) => {
        // Default the redirect parameter name to `next`.
        $crate::permission_required!(
            $backend_type,
            login_url = $login_url,
            redirect_field = "next",
            $($perm),+
        )
    };

    ($backend_type:ty, $($perm:expr),+ $(,)?) => {{
        use $crate::AuthzBackend;

        async fn is_authorized(auth_session: $crate::AuthSession<$backend_type>) -> bool {
            // Anonymous requests never pass.
            let Some(ref user) = auth_session.user else {
                return false;
            };

            // Same ordered, short-circuiting check as the `login_url` form.
            $(
                if !auth_session.backend.has_perm(user, $perm.into()).await.unwrap_or(false) {
                    return false;
                }
            )+

            true
        }

        $crate::predicate_required!(
            is_authorized,
            $crate::axum::http::StatusCode::FORBIDDEN
        )
    }};
}
133
#[macro_export]
/// Builds an axum middleware (via `from_fn`) that runs the inner service only
/// when an async predicate over the request's `AuthSession` returns `true`.
///
/// `$predicate` is invoked as `$predicate(auth_session).await` and must yield
/// a `bool`.
///
/// Forms:
///
/// - `predicate_required!(predicate, alternative)`: when the predicate fails,
///   responds with `alternative.into_response()` (for example a `StatusCode`).
/// - `predicate_required!(predicate, login_url = "/login", redirect_field = "next")`:
///   when the predicate fails, responds with `307 Temporary Redirect` to
///   `login_url`, carrying the original request URI in the `redirect_field`
///   query parameter. If that URL cannot be built, the error is logged and
///   `500 Internal Server Error` is returned.
/// - `predicate_required!(predicate, login_url = "/login")`: the same, with
///   `redirect_field = "next"`, matching `login_required!` and
///   `permission_required!`.
///
/// A trailing comma is accepted in every form.
macro_rules! predicate_required {
    (
        $predicate:expr,
        login_url = $login_url:expr,
        redirect_field = $redirect_field:expr $(,)?
    ) => {{
        use $crate::axum::{
            extract::OriginalUri,
            middleware::{from_fn, Next},
            response::{IntoResponse, Redirect},
        };

        from_fn(
            |auth_session: $crate::AuthSession<_>,
             OriginalUri(original_uri): OriginalUri,
             req,
             next: Next| async move {
                if $predicate(auth_session).await {
                    next.run(req).await
                } else {
                    match $crate::url_with_redirect_query(
                        $login_url,
                        $redirect_field,
                        original_uri
                    ) {
                        Ok(login_url) => {
                            Redirect::temporary(&login_url.to_string()).into_response()
                        }

                        Err(err) => {
                            $crate::tracing::error!(
                                err = %err,
                                "failed to build login redirect URL"
                            );
                            $crate::axum::http::StatusCode::INTERNAL_SERVER_ERROR.into_response()
                        }
                    }
                }
            },
        )
    }};

    ($predicate:expr, login_url = $login_url:expr $(,)?) => {
        // No redirect field given: default to `next`, like the sibling macros.
        $crate::predicate_required!(
            $predicate,
            login_url = $login_url,
            redirect_field = "next"
        )
    };

    // Must stay last: `$alternative:expr` would otherwise also swallow
    // `login_url = ...` as an assignment expression.
    ($predicate:expr, $alternative:expr $(,)?) => {{
        use $crate::axum::{
            middleware::{from_fn, Next},
            response::IntoResponse,
        };

        from_fn(
            |auth_session: $crate::AuthSession<_>, req, next: Next| async move {
                if $predicate(auth_session).await {
                    next.run(req).await
                } else {
                    $alternative.into_response()
                }
            },
        )
    }};
}
195
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use axum::{
        body::Body,
        http::{header, Request, Response, StatusCode},
        Router,
    };
    use tower::ServiceExt;
    use tower_cookies::cookie;
    use tower_sessions::SessionManagerLayer;
    use tower_sessions_sqlx_store::{sqlx::SqlitePool, SqliteStore};

    use crate::{AuthManagerLayerBuilder, AuthSession, AuthUser, AuthnBackend, AuthzBackend};

    // ------------------------------------------------------------------
    // Fixtures
    // ------------------------------------------------------------------

    /// The single user known to the test backend.
    #[derive(Debug, Clone)]
    struct User;

    impl AuthUser for User {
        type Id = i64;

        fn id(&self) -> Self::Id {
            0
        }

        fn session_auth_hash(&self) -> &[u8] {
            &[]
        }
    }

    /// Credentials placeholder; the test backend ignores it.
    #[derive(Debug, Clone)]
    struct Credentials;

    /// Used as the backend's `Error` associated type. No test backend method
    /// ever returns it.
    #[derive(thiserror::Error, Debug)]
    struct Error;

    impl std::fmt::Display for Error {
        fn fmt(&self, _: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            Ok(())
        }
    }

    /// Backend that resolves every lookup to [`User`] and grants that user
    /// exactly `test.read` and `test.write`.
    #[derive(Clone)]
    struct Backend;

    impl AuthnBackend for Backend {
        type User = User;
        type Credentials = Credentials;
        type Error = Error;

        async fn authenticate(
            &self,
            _credentials: Self::Credentials,
        ) -> Result<Option<Self::User>, Self::Error> {
            Ok(Some(User))
        }

        async fn get_user(
            &self,
            _id: &<<Backend as AuthnBackend>::User as AuthUser>::Id,
        ) -> Result<Option<Self::User>, Self::Error> {
            Ok(Some(User))
        }
    }

    #[derive(Debug, Clone, Eq, PartialEq, Hash)]
    pub struct Permission {
        pub name: String,
    }

    impl From<&str> for Permission {
        fn from(name: &str) -> Self {
            Self {
                name: name.to_owned(),
            }
        }
    }

    impl AuthzBackend for Backend {
        type Permission = Permission;

        async fn get_user_permissions(
            &self,
            _user: &Self::User,
        ) -> Result<HashSet<Self::Permission>, Self::Error> {
            Ok(["test.read", "test.write"]
                .into_iter()
                .map(Permission::from)
                .collect())
        }
    }

    // Expands to an auth manager layer over a fresh in-memory SQLite session
    // store. It awaits the pool connection and migration, so it must be used
    // inside an async context.
    macro_rules! auth_layer {
        () => {{
            let pool = SqlitePool::connect(":memory:").await.unwrap();
            let session_store = SqliteStore::new(pool.clone());
            session_store.migrate().await.unwrap();

            let session_layer = SessionManagerLayer::new(session_store).with_secure(false);

            AuthManagerLayerBuilder::new(Backend, session_layer).build()
        }};
    }

    // ------------------------------------------------------------------
    // Helpers
    // ------------------------------------------------------------------

    /// Returns the response's `Set-Cookie` header, re-serialized through
    /// `cookie::Cookie`. Yields `None` if the header is missing, is not valid
    /// UTF-8, or does not parse as a cookie.
    fn get_session_cookie(res: &Response<Body>) -> Option<String> {
        let raw = res.headers().get(header::SET_COOKIE)?.to_str().ok()?;
        let parsed = cookie::Cookie::parse(raw).ok()?;
        Some(parsed.to_string())
    }

    /// Handler for the guarded routes: does nothing and returns `200 OK`.
    fn home() -> axum::routing::MethodRouter {
        axum::routing::get(|| async {})
    }

    /// Unguarded handler that logs the fixture [`User`] in.
    fn login_route() -> axum::routing::MethodRouter {
        axum::routing::get(|mut auth_session: AuthSession<Backend>| async move {
            auth_session.login(&User).await.unwrap();
        })
    }

    /// Sends `GET uri` to a clone of `app`, with `cookie` as the `Cookie`
    /// header when given.
    async fn send_get(app: &Router, uri: &str, cookie: Option<&str>) -> Response<Body> {
        let mut builder = Request::builder().uri(uri);
        if let Some(cookie) = cookie {
            builder = builder.header(header::COOKIE, cookie);
        }
        let req = builder.body(Body::empty()).unwrap();
        app.clone().oneshot(req).await.unwrap()
    }

    /// Visits the unguarded `login_path` and returns the session cookie it sets.
    async fn log_in(app: &Router, login_path: &str) -> String {
        let res = send_get(app, login_path, None).await;
        get_session_cookie(&res).expect("Response should have a valid session cookie")
    }

    /// The response's `Location` header as a string, if present and valid UTF-8.
    fn location(res: &Response<Body>) -> Option<&str> {
        res.headers()
            .get(header::LOCATION)
            .and_then(|value| value.to_str().ok())
    }

    // ------------------------------------------------------------------
    // login_required!
    // ------------------------------------------------------------------

    #[tokio::test]
    async fn test_login_required() {
        let app = Router::new()
            .route("/", home())
            .route_layer(login_required!(Backend))
            .route("/login", login_route())
            .layer(auth_layer!());

        let anonymous = send_get(&app, "/", None).await;
        assert_eq!(anonymous.status(), StatusCode::UNAUTHORIZED);

        let session_cookie = log_in(&app, "/login").await;
        let authenticated = send_get(&app, "/", Some(&session_cookie)).await;
        assert_eq!(authenticated.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_login_required_with_login_url() {
        let app = Router::new()
            .route("/", home())
            .route_layer(login_required!(Backend, login_url = "/login"))
            .route("/login", login_route())
            .layer(auth_layer!());

        let anonymous = send_get(&app, "/", None).await;
        assert_eq!(anonymous.status(), StatusCode::TEMPORARY_REDIRECT);
        assert_eq!(location(&anonymous), Some("/login?next=%2F"));

        let session_cookie = log_in(&app, "/login").await;
        let authenticated = send_get(&app, "/", Some(&session_cookie)).await;
        assert_eq!(authenticated.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_login_required_with_login_url_and_redirect_field() {
        let app = Router::new()
            .route("/", home())
            .route_layer(login_required!(
                Backend,
                login_url = "/signin",
                redirect_field = "next_uri"
            ))
            .route("/signin", login_route())
            .layer(auth_layer!());

        let anonymous = send_get(&app, "/", None).await;
        assert_eq!(anonymous.status(), StatusCode::TEMPORARY_REDIRECT);
        assert_eq!(location(&anonymous), Some("/signin?next_uri=%2F"));

        let session_cookie = log_in(&app, "/signin").await;
        let authenticated = send_get(&app, "/", Some(&session_cookie)).await;
        assert_eq!(authenticated.status(), StatusCode::OK);
    }

    // ------------------------------------------------------------------
    // permission_required!
    // ------------------------------------------------------------------

    #[tokio::test]
    async fn test_permission_required() {
        let app = Router::new()
            .route("/", home())
            .route_layer(permission_required!(Backend, "test.read"))
            .route("/login", login_route())
            .layer(auth_layer!());

        let anonymous = send_get(&app, "/", None).await;
        assert_eq!(anonymous.status(), StatusCode::FORBIDDEN);

        let session_cookie = log_in(&app, "/login").await;
        let authenticated = send_get(&app, "/", Some(&session_cookie)).await;
        assert_eq!(authenticated.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_permission_required_multiple_permissions() {
        let app = Router::new()
            .route("/", home())
            .route_layer(permission_required!(Backend, "test.read", "test.write"))
            .route("/login", login_route())
            .layer(auth_layer!());

        let anonymous = send_get(&app, "/", None).await;
        assert_eq!(anonymous.status(), StatusCode::FORBIDDEN);

        let session_cookie = log_in(&app, "/login").await;
        let authenticated = send_get(&app, "/", Some(&session_cookie)).await;
        assert_eq!(authenticated.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_permission_required_with_login_url() {
        let app = Router::new()
            .route("/", home())
            .route_layer(permission_required!(
                Backend,
                login_url = "/login",
                "test.read"
            ))
            .route("/login", login_route())
            .layer(auth_layer!());

        let anonymous = send_get(&app, "/", None).await;
        assert_eq!(anonymous.status(), StatusCode::TEMPORARY_REDIRECT);
        assert_eq!(location(&anonymous), Some("/login?next=%2F"));

        let session_cookie = log_in(&app, "/login").await;
        let authenticated = send_get(&app, "/", Some(&session_cookie)).await;
        assert_eq!(authenticated.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_permission_required_with_login_url_and_redirect_field() {
        let app = Router::new()
            .route("/", home())
            .route_layer(permission_required!(
                Backend,
                login_url = "/signin",
                redirect_field = "next_uri",
                "test.read"
            ))
            .route("/signin", login_route())
            .layer(auth_layer!());

        let anonymous = send_get(&app, "/", None).await;
        assert_eq!(anonymous.status(), StatusCode::TEMPORARY_REDIRECT);
        assert_eq!(location(&anonymous), Some("/signin?next_uri=%2F"));

        let session_cookie = log_in(&app, "/signin").await;
        let authenticated = send_get(&app, "/", Some(&session_cookie)).await;
        assert_eq!(authenticated.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_permission_required_missing_permissions() {
        let app = Router::new()
            .route("/", home())
            .route_layer(permission_required!(
                Backend,
                "test.read",
                "test.write",
                "admin.read"
            ))
            .route("/login", login_route())
            .layer(auth_layer!());

        let anonymous = send_get(&app, "/", None).await;
        assert_eq!(anonymous.status(), StatusCode::FORBIDDEN);

        // Logged in, but the backend never grants `admin.read`.
        let session_cookie = log_in(&app, "/login").await;
        let authenticated = send_get(&app, "/", Some(&session_cookie)).await;
        assert_eq!(authenticated.status(), StatusCode::FORBIDDEN);
    }

    // ------------------------------------------------------------------
    // Redirect URL construction
    // ------------------------------------------------------------------

    #[tokio::test]
    async fn test_redirect_uri_query() {
        let app = Router::new()
            .route("/", home())
            .route_layer(login_required!(Backend, login_url = "/login"))
            .layer(auth_layer!());

        let res = send_get(&app, "/?foo=bar&foo=baz", None).await;
        assert_eq!(res.status(), StatusCode::TEMPORARY_REDIRECT);
        assert_eq!(
            location(&res),
            Some("/login?next=%2F%3Ffoo%3Dbar%26foo%3Dbaz")
        );
    }

    #[tokio::test]
    async fn test_login_url_query() {
        let app = Router::new()
            .route("/", home())
            .route_layer(login_required!(
                Backend,
                login_url = "/login?foo=bar&foo=baz"
            ))
            .layer(auth_layer!());

        let plain = send_get(&app, "/", None).await;
        assert_eq!(plain.status(), StatusCode::TEMPORARY_REDIRECT);
        assert_eq!(location(&plain), Some("/login?next=%2F&foo=bar&foo=baz"));

        let with_query = send_get(&app, "/?a=b&a=c", None).await;
        assert_eq!(with_query.status(), StatusCode::TEMPORARY_REDIRECT);
        assert_eq!(
            location(&with_query),
            Some("/login?next=%2F%3Fa%3Db%26a%3Dc&foo=bar&foo=baz")
        );
    }

    #[tokio::test]
    async fn test_login_url_explicit_redirect() {
        // A custom redirect field already present in `login_url` is kept as-is.
        let app = Router::new()
            .route("/", home())
            .route_layer(login_required!(
                Backend,
                login_url = "/login?next_url=%2Fdashboard",
                redirect_field = "next_url"
            ))
            .layer(auth_layer!());

        let res = send_get(&app, "/", None).await;
        assert_eq!(res.status(), StatusCode::TEMPORARY_REDIRECT);
        assert_eq!(location(&res), Some("/login?next_url=%2Fdashboard"));

        // The same holds for the default `next` field.
        let app = Router::new()
            .route("/", home())
            .route_layer(login_required!(
                Backend,
                login_url = "/login?next=%2Fdashboard"
            ))
            .layer(auth_layer!());

        let res = send_get(&app, "/", None).await;
        assert_eq!(res.status(), StatusCode::TEMPORARY_REDIRECT);
        assert_eq!(location(&res), Some("/login?next=%2Fdashboard"));
    }

    #[tokio::test]
    async fn test_nested() {
        let nested = Router::new()
            .route("/foo", home())
            .route_layer(login_required!(Backend, login_url = "/login"));
        let app = Router::new().nest("/nested", nested).layer(auth_layer!());

        // The redirect carries the full original URI, including the nest prefix.
        let res = send_get(&app, "/nested/foo", None).await;
        assert_eq!(res.status(), StatusCode::TEMPORARY_REDIRECT);
        assert_eq!(location(&res), Some("/login?next=%2Fnested%2Ffoo"));
    }
}