axum_login/
middleware.rs

1use axum::http::{self, Uri};
2
3fn update_query(uri: &Uri, new_query: String) -> Result<Uri, http::Error> {
4    let query = form_urlencoded::parse(uri.query().map(|q| q.as_bytes()).unwrap_or_default());
5    let updated_query = form_urlencoded::Serializer::new(new_query)
6        .extend_pairs(query)
7        .finish();
8
9    let mut parts = uri.clone().into_parts();
10    parts.path_and_query = Some(format!("{}?{}", uri.path(), updated_query).parse()?);
11
12    Ok(Uri::from_parts(parts)?)
13}
14
15/// This is intended for internal use only and subject to change in the future
16/// without warning!
17#[doc(hidden)]
18pub fn url_with_redirect_query(
19    url: &str,
20    redirect_field: &str,
21    redirect_uri: Uri,
22) -> Result<Uri, http::Error> {
23    let uri = url.parse::<Uri>()?;
24
25    if uri.query().is_some_and(|q| q.contains(redirect_field)) {
26        return Ok(uri);
27    };
28
29    let redirect_uri_string = redirect_uri.to_string();
30    let redirect_uri_encoded = urlencoding::encode(&redirect_uri_string);
31    let redirect_query = format!("{redirect_field}={redirect_uri_encoded}");
32
33    update_query(&uri, redirect_query)
34}
35
/// Login predicate middleware.
///
/// Requires that the user is authenticated.
///
/// Accepted forms (a trailing comma is permitted in each):
///
/// - `login_required!(Backend)`: unauthenticated requests receive a
///   `401 Unauthorized` response.
/// - `login_required!(Backend, login_url = "/login")`: unauthenticated
///   requests are redirected to the login URL, with the original URI passed
///   in the `next` query field.
/// - `login_required!(Backend, login_url = "/login", redirect_field = "next_uri")`:
///   as above, but with a custom redirect query field.
#[macro_export]
macro_rules! login_required {
    ($backend_type:ty $(,)?) => {{
        async fn is_authenticated(auth_session: $crate::AuthSession<$backend_type>) -> bool {
            auth_session.user.is_some()
        }

        $crate::predicate_required!(
            is_authenticated,
            $crate::axum::http::StatusCode::UNAUTHORIZED
        )
    }};

    ($backend_type:ty, login_url = $login_url:expr, redirect_field = $redirect_field:expr $(,)?) => {{
        async fn is_authenticated(auth_session: $crate::AuthSession<$backend_type>) -> bool {
            auth_session.user.is_some()
        }

        $crate::predicate_required!(
            is_authenticated,
            login_url = $login_url,
            redirect_field = $redirect_field
        )
    }};

    // No redirect field given: fall back to `next`.
    ($backend_type:ty, login_url = $login_url:expr $(,)?) => {
        $crate::login_required!(
            $backend_type,
            login_url = $login_url,
            redirect_field = "next"
        )
    };
}
72
/// Permission predicate middleware.
///
/// Requires that every listed permission, whether granted via the user, a
/// group, or both, is assigned to the user.
///
/// Without a `login_url`, a failed check yields `403 Forbidden`; with one, the
/// request is redirected to the login URL (the redirect field defaults to
/// `next`). A backend error while checking a permission counts as the
/// permission being absent.
#[macro_export]
macro_rules! permission_required {
    ($backend_type:ty, login_url = $login_url:expr, redirect_field = $redirect_field:expr, $($perm:expr),+ $(,)?) => {{
        use $crate::AuthzBackend;

        async fn is_authorized(auth_session: $crate::AuthSession<$backend_type>) -> bool {
            let user = match auth_session.user {
                Some(ref user) => user,
                None => return false,
            };

            // Stop at the first permission that is missing or cannot be checked.
            $(
                if !auth_session.backend.has_perm(user, $perm.into()).await.unwrap_or(false) {
                    return false;
                }
            )+

            true
        }

        $crate::predicate_required!(
            is_authorized,
            login_url = $login_url,
            redirect_field = $redirect_field
        )
    }};

    ($backend_type:ty, login_url = $login_url:expr, $($perm:expr),+ $(,)?) => {
        $crate::permission_required!(
            $backend_type,
            login_url = $login_url,
            redirect_field = "next",
            $($perm),+
        )
    };

    ($backend_type:ty, $($perm:expr),+ $(,)?) => {{
        use $crate::AuthzBackend;

        async fn is_authorized(auth_session: $crate::AuthSession<$backend_type>) -> bool {
            let user = match auth_session.user {
                Some(ref user) => user,
                None => return false,
            };

            // Stop at the first permission that is missing or cannot be checked.
            $(
                if !auth_session.backend.has_perm(user, $perm.into()).await.unwrap_or(false) {
                    return false;
                }
            )+

            true
        }

        $crate::predicate_required!(
            is_authorized,
            $crate::axum::http::StatusCode::FORBIDDEN
        )
    }};
}
133
/// Predicate middleware.
///
/// Takes an async predicate over the auth session plus either an alternative
/// response, which must implement
/// [`IntoResponse`](axum::response::IntoResponse), or a login URL together
/// with the name of the redirect query field.
///
/// If the predicate resolves to `true` the request continues down the stack.
/// Otherwise the alternative is returned as the response, or a temporary
/// redirect to the login URL is issued.
#[macro_export]
macro_rules! predicate_required {
    ($predicate:expr, $alternative:expr) => {{
        use $crate::axum::{
            middleware::{from_fn, Next},
            response::IntoResponse,
        };

        from_fn(
            |auth_session: $crate::AuthSession<_>, req, next: Next| async move {
                if !$predicate(auth_session).await {
                    return $alternative.into_response();
                }

                next.run(req).await
            },
        )
    }};

    ($predicate:expr, login_url = $login_url:expr, redirect_field = $redirect_field:expr) => {{
        use $crate::axum::{
            extract::OriginalUri,
            middleware::{from_fn, Next},
            response::{IntoResponse, Redirect},
        };

        from_fn(
            |auth_session: $crate::AuthSession<_>,
             OriginalUri(original_uri): OriginalUri,
             req,
             next: Next| async move {
                if $predicate(auth_session).await {
                    return next.run(req).await;
                }

                // Predicate failed: send the client to the login URL, carrying
                // the originally requested URI in the redirect field.
                let redirect_target = $crate::url_with_redirect_query(
                    $login_url,
                    $redirect_field,
                    original_uri
                );

                match redirect_target {
                    Err(err) => {
                        $crate::tracing::error!(err = %err);
                        $crate::axum::http::StatusCode::INTERNAL_SERVER_ERROR.into_response()
                    }

                    Ok(login_url) => Redirect::temporary(&login_url.to_string()).into_response(),
                }
            },
        )
    }};
}
195
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use axum::{
        body::Body,
        http::{header, Request, Response, StatusCode},
        Router,
    };
    use tower::ServiceExt;
    use tower_cookies::cookie;
    use tower_sessions::SessionManagerLayer;
    use tower_sessions_sqlx_store::{sqlx::SqlitePool, SqliteStore};

    use crate::{AuthManagerLayerBuilder, AuthSession, AuthUser, AuthnBackend, AuthzBackend};

    /// Fixture user with a fixed id and an empty session auth hash.
    #[derive(Debug, Clone)]
    struct User;

    impl AuthUser for User {
        type Id = i64;

        fn id(&self) -> Self::Id {
            0
        }

        fn session_auth_hash(&self) -> &[u8] {
            &[]
        }
    }

    #[derive(Debug, Clone)]
    struct Credentials;

    #[derive(thiserror::Error, Debug)]
    struct Error;

    impl std::fmt::Display for Error {
        fn fmt(&self, _: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            Ok(())
        }
    }

    /// Fixture backend: authentication and user lookup always yield [`User`].
    #[derive(Clone)]
    struct Backend;

    impl AuthnBackend for Backend {
        type User = User;
        type Credentials = Credentials;
        type Error = Error;

        async fn authenticate(
            &self,
            _: Self::Credentials,
        ) -> Result<Option<Self::User>, Self::Error> {
            Ok(Some(User))
        }

        async fn get_user(
            &self,
            _: &<<Backend as AuthnBackend>::User as AuthUser>::Id,
        ) -> Result<Option<Self::User>, Self::Error> {
            Ok(Some(User))
        }
    }

    #[derive(Debug, Clone, Eq, PartialEq, Hash)]
    pub struct Permission {
        pub name: String,
    }

    impl From<&str> for Permission {
        fn from(name: &str) -> Self {
            Permission {
                name: name.to_string(),
            }
        }
    }

    impl AuthzBackend for Backend {
        type Permission = Permission;

        /// Every user holds exactly `test.read` and `test.write`.
        async fn get_user_permissions(
            &self,
            _user: &Self::User,
        ) -> Result<HashSet<Self::Permission>, Self::Error> {
            let perms: HashSet<Self::Permission> =
                HashSet::from_iter(["test.read".into(), "test.write".into()]);
            Ok(perms)
        }
    }

    /// Builds an auth manager layer backed by a fresh in-memory SQLite session
    /// store. Must be expanded inside an async context.
    macro_rules! auth_layer {
        () => {{
            let pool = SqlitePool::connect(":memory:").await.unwrap();
            let session_store = SqliteStore::new(pool.clone());
            session_store.migrate().await.unwrap();

            let session_layer = SessionManagerLayer::new(session_store).with_secure(false);

            AuthManagerLayerBuilder::new(Backend, session_layer).build()
        }};
    }

    /// Returns the response's `Set-Cookie` header re-serialized as a cookie
    /// string, or `None` if the header is absent or cannot be parsed.
    fn get_session_cookie(res: &Response<Body>) -> Option<String> {
        res.headers()
            .get(header::SET_COOKIE)
            .and_then(|h| h.to_str().ok())
            .and_then(|cookie_str| {
                let cookie = cookie::Cookie::parse(cookie_str);
                cookie.map(|c| c.to_string()).ok()
            })
    }

    /// Route handler that logs the fixture user in.
    async fn login_handler(mut auth_session: AuthSession<Backend>) {
        auth_session.login(&User).await.unwrap();
    }

    /// Sends a request with an empty body for `uri` through a clone of `app`,
    /// attaching `session_cookie` as the `Cookie` header when given.
    async fn send_get(app: &Router, uri: &str, session_cookie: Option<&str>) -> Response<Body> {
        let mut builder = Request::builder().uri(uri);
        if let Some(session_cookie) = session_cookie {
            builder = builder.header(header::COOKIE, session_cookie);
        }

        let req = builder.body(Body::empty()).unwrap();
        app.clone().oneshot(req).await.unwrap()
    }

    /// Requests `login_uri` to obtain a session cookie, then requests `uri`
    /// with that cookie and returns the second response.
    async fn get_after_login(app: &Router, login_uri: &str, uri: &str) -> Response<Body> {
        let res = send_get(app, login_uri, None).await;
        let session_cookie =
            get_session_cookie(&res).expect("Response should have a valid session cookie");

        send_get(app, uri, Some(&session_cookie)).await
    }

    /// Asserts that `res` is a temporary redirect whose `Location` header is
    /// exactly `expected_location`.
    #[track_caller]
    fn assert_redirect(res: &Response<Body>, expected_location: &str) {
        assert_eq!(res.status(), StatusCode::TEMPORARY_REDIRECT);
        assert_eq!(
            res.headers()
                .get(header::LOCATION)
                .and_then(|h| h.to_str().ok()),
            Some(expected_location)
        );
    }

    #[tokio::test]
    async fn test_login_required() {
        let app = Router::new()
            .route("/", axum::routing::get(|| async {}))
            .route_layer(login_required!(Backend))
            .route("/login", axum::routing::get(login_handler))
            .layer(auth_layer!());

        let res = send_get(&app, "/", None).await;
        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);

        let res = get_after_login(&app, "/login", "/").await;
        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_login_required_with_login_url() {
        let app = Router::new()
            .route("/", axum::routing::get(|| async {}))
            .route_layer(login_required!(Backend, login_url = "/login"))
            .route("/login", axum::routing::get(login_handler))
            .layer(auth_layer!());

        let res = send_get(&app, "/", None).await;
        assert_redirect(&res, "/login?next=%2F");

        let res = get_after_login(&app, "/login", "/").await;
        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_login_required_with_login_url_and_redirect_field() {
        let app = Router::new()
            .route("/", axum::routing::get(|| async {}))
            .route_layer(login_required!(
                Backend,
                login_url = "/signin",
                redirect_field = "next_uri"
            ))
            .route("/signin", axum::routing::get(login_handler))
            .layer(auth_layer!());

        let res = send_get(&app, "/", None).await;
        assert_redirect(&res, "/signin?next_uri=%2F");

        let res = get_after_login(&app, "/signin", "/").await;
        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_permission_required() {
        let app = Router::new()
            .route("/", axum::routing::get(|| async {}))
            .route_layer(permission_required!(Backend, "test.read"))
            .route("/login", axum::routing::get(login_handler))
            .layer(auth_layer!());

        let res = send_get(&app, "/", None).await;
        assert_eq!(res.status(), StatusCode::FORBIDDEN);

        let res = get_after_login(&app, "/login", "/").await;
        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_permission_required_multiple_permissions() {
        let app = Router::new()
            .route("/", axum::routing::get(|| async {}))
            .route_layer(permission_required!(Backend, "test.read", "test.write"))
            .route("/login", axum::routing::get(login_handler))
            .layer(auth_layer!());

        let res = send_get(&app, "/", None).await;
        assert_eq!(res.status(), StatusCode::FORBIDDEN);

        let res = get_after_login(&app, "/login", "/").await;
        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_permission_required_with_login_url() {
        let app = Router::new()
            .route("/", axum::routing::get(|| async {}))
            .route_layer(permission_required!(
                Backend,
                login_url = "/login",
                "test.read"
            ))
            .route("/login", axum::routing::get(login_handler))
            .layer(auth_layer!());

        let res = send_get(&app, "/", None).await;
        assert_redirect(&res, "/login?next=%2F");

        let res = get_after_login(&app, "/login", "/").await;
        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_permission_required_with_login_url_and_redirect_field() {
        let app = Router::new()
            .route("/", axum::routing::get(|| async {}))
            .route_layer(permission_required!(
                Backend,
                login_url = "/signin",
                redirect_field = "next_uri",
                "test.read"
            ))
            .route("/signin", axum::routing::get(login_handler))
            .layer(auth_layer!());

        let res = send_get(&app, "/", None).await;
        assert_redirect(&res, "/signin?next_uri=%2F");

        let res = get_after_login(&app, "/signin", "/").await;
        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_permission_required_missing_permissions() {
        let app = Router::new()
            .route("/", axum::routing::get(|| async {}))
            .route_layer(permission_required!(
                Backend,
                "test.read",
                "test.write",
                "admin.read"
            ))
            .route("/login", axum::routing::get(login_handler))
            .layer(auth_layer!());

        let res = send_get(&app, "/", None).await;
        assert_eq!(res.status(), StatusCode::FORBIDDEN);

        // The fixture backend never grants `admin.read`, so logging in does
        // not lift the restriction.
        let res = get_after_login(&app, "/login", "/").await;
        assert_eq!(res.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn test_redirect_uri_query() {
        let app = Router::new()
            .route("/", axum::routing::get(|| async {}))
            .route_layer(login_required!(Backend, login_url = "/login"))
            .layer(auth_layer!());

        let res = send_get(&app, "/?foo=bar&foo=baz", None).await;
        assert_redirect(&res, "/login?next=%2F%3Ffoo%3Dbar%26foo%3Dbaz");
    }

    #[tokio::test]
    async fn test_login_url_query() {
        let app = Router::new()
            .route("/", axum::routing::get(|| async {}))
            .route_layer(login_required!(
                Backend,
                login_url = "/login?foo=bar&foo=baz"
            ))
            .layer(auth_layer!());

        let res = send_get(&app, "/", None).await;
        assert_redirect(&res, "/login?next=%2F&foo=bar&foo=baz");

        let res = send_get(&app, "/?a=b&a=c", None).await;
        assert_redirect(&res, "/login?next=%2F%3Fa%3Db%26a%3Dc&foo=bar&foo=baz");
    }

    #[tokio::test]
    async fn test_login_url_explicit_redirect() {
        let app = Router::new()
            .route("/", axum::routing::get(|| async {}))
            .route_layer(login_required!(
                Backend,
                login_url = "/login?next_url=%2Fdashboard",
                redirect_field = "next_url"
            ))
            .layer(auth_layer!());

        let res = send_get(&app, "/", None).await;
        assert_redirect(&res, "/login?next_url=%2Fdashboard");

        let app = Router::new()
            .route("/", axum::routing::get(|| async {}))
            .route_layer(login_required!(
                Backend,
                login_url = "/login?next=%2Fdashboard"
            ))
            .layer(auth_layer!());

        let res = send_get(&app, "/", None).await;
        assert_redirect(&res, "/login?next=%2Fdashboard");
    }

    #[tokio::test]
    async fn test_nested() {
        let nested = Router::new()
            .route("/foo", axum::routing::get(|| async {}))
            .route_layer(login_required!(Backend, login_url = "/login"));
        let app = Router::new().nest("/nested", nested).layer(auth_layer!());

        let res = send_get(&app, "/nested/foo", None).await;
        assert_redirect(&res, "/login?next=%2Fnested%2Ffoo");
    }
}