Skip to main content

modo/auth/role/
extractor.rs

1use std::ops::Deref;
2
3use axum::extract::{FromRequestParts, OptionalFromRequestParts};
4use http::request::Parts;
5
6use crate::Error;
7
/// Axum extractor that surfaces the resolved role to handlers.
///
/// Pulls the role previously stored in request extensions by the role
/// [`middleware`](super::middleware()). Extracting as `Role` returns `500` if
/// the middleware is not applied — this is a developer misconfiguration, not a
/// user-facing error.
///
/// Use `Option<Role>` on routes that serve both authenticated and anonymous
/// callers; `None` is returned when the middleware is absent or the extractor
/// returned no role.
///
/// `Role` is a thin newtype wrapper around a `String` and is also re-exported
/// from [`modo::prelude`](crate::prelude) and [`modo::extractors`](crate::extractors).
///
/// It derives `Hash` alongside `Eq` (both computed from the inner string), so
/// roles can be stored in a `HashSet` or used as `HashMap` keys.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Role(pub(crate) String);
23
24impl Role {
25    /// Borrows the role as a string slice.
26    pub fn as_str(&self) -> &str {
27        &self.0
28    }
29}
30
31impl Deref for Role {
32    type Target = str;
33    fn deref(&self) -> &str {
34        &self.0
35    }
36}
37
38impl<S: Send + Sync> FromRequestParts<S> for Role {
39    type Rejection = Error;
40
41    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
42        parts
43            .extensions
44            .get::<Role>()
45            .cloned()
46            .ok_or_else(|| Error::internal("role middleware not applied"))
47    }
48}
49
50impl<S: Send + Sync> OptionalFromRequestParts<S> for Role {
51    type Rejection = Error;
52
53    async fn from_request_parts(
54        parts: &mut Parts,
55        _state: &S,
56    ) -> Result<Option<Self>, Self::Rejection> {
57        Ok(parts.extensions.get::<Role>().cloned())
58    }
59}
60
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds request parts with no extensions attached (middleware absent).
    fn empty_parts() -> Parts {
        let (parts, _) = http::Request::builder().body(()).unwrap().into_parts();
        parts
    }

    /// Builds request parts carrying `role` in their extensions, mimicking
    /// what the role middleware stores.
    fn parts_with_role(role: &str) -> Parts {
        let mut parts = empty_parts();
        parts.extensions.insert(Role(role.into()));
        parts
    }

    #[test]
    fn role_as_str() {
        let role = Role("admin".into());
        assert_eq!(role.as_str(), "admin");
    }

    #[test]
    fn role_deref() {
        let role = Role("editor".into());
        let s: &str = &role;
        assert_eq!(s, "editor");
    }

    #[test]
    fn role_clone() {
        let role = Role("admin".into());
        let cloned = role.clone();
        assert_eq!(role, cloned);
    }

    #[test]
    fn role_debug() {
        let role = Role("admin".into());
        assert_eq!(format!("{role:?}"), r#"Role("admin")"#);
    }

    #[tokio::test]
    async fn extract_from_extensions() {
        let mut parts = parts_with_role("admin");

        let role = <Role as FromRequestParts<()>>::from_request_parts(&mut parts, &())
            .await
            .expect("role extension is present");
        assert_eq!(role.as_str(), "admin");
    }

    #[tokio::test]
    async fn extract_leaves_extension_in_place() {
        // Extraction clones; a second extractor on the same request must still
        // see the role.
        let mut parts = parts_with_role("admin");

        let _ = <Role as FromRequestParts<()>>::from_request_parts(&mut parts, &()).await;
        assert_eq!(parts.extensions.get::<Role>(), Some(&Role("admin".into())));
    }

    #[tokio::test]
    async fn extract_missing_returns_500() {
        let mut parts = empty_parts();

        let err = <Role as FromRequestParts<()>>::from_request_parts(&mut parts, &())
            .await
            .expect_err("missing role extension must be rejected");
        assert_eq!(err.status(), http::StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[tokio::test]
    async fn option_role_none_when_missing() {
        let mut parts = empty_parts();

        let role = <Role as OptionalFromRequestParts<()>>::from_request_parts(&mut parts, &())
            .await
            .expect("optional role extraction never rejects");
        assert_eq!(role, None);
    }

    #[tokio::test]
    async fn option_role_some_when_present() {
        let mut parts = parts_with_role("viewer");

        let role = <Role as OptionalFromRequestParts<()>>::from_request_parts(&mut parts, &())
            .await
            .expect("optional role extraction never rejects");
        // Pin the actual value, not just presence.
        assert_eq!(role, Some(Role("viewer".into())));
    }
}