1use std::ops::Deref;
2
3use axum::extract::{FromRequestParts, OptionalFromRequestParts};
4use http::request::Parts;
5
6use crate::Error;
7
/// The role attached to the current request.
///
/// Handlers obtain it through the [`FromRequestParts`] / [`OptionalFromRequestParts`]
/// extractors in this module, which read it from the request extensions (the rejection
/// message there indicates it is expected to be inserted by the RBAC middleware).
///
/// The inner `String` is `pub(crate)`, so code outside this crate cannot construct a `Role`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Role(pub(crate) String);

impl std::fmt::Display for Role {
    /// Writes the bare role name (e.g. `admin`).
    ///
    /// Uses `Formatter::pad` so width/fill/precision flags behave exactly as they do for `str`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(&self.0)
    }
}
16
17impl Role {
18 pub fn as_str(&self) -> &str {
20 &self.0
21 }
22}
23
24impl Deref for Role {
25 type Target = str;
26 fn deref(&self) -> &str {
27 &self.0
28 }
29}
30
31impl<S: Send + Sync> FromRequestParts<S> for Role {
32 type Rejection = Error;
33
34 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
35 parts
36 .extensions
37 .get::<Role>()
38 .cloned()
39 .ok_or_else(|| Error::internal("RBAC middleware not applied"))
40 }
41}
42
43impl<S: Send + Sync> OptionalFromRequestParts<S> for Role {
44 type Rejection = Error;
45
46 async fn from_request_parts(
47 parts: &mut Parts,
48 _state: &S,
49 ) -> Result<Option<Self>, Self::Rejection> {
50 Ok(parts.extensions.get::<Role>().cloned())
51 }
52}
53
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds request parts whose extension map is empty.
    fn empty_parts() -> Parts {
        let (parts, ()) = http::Request::builder()
            .body(())
            .expect("an empty request is always valid")
            .into_parts();
        parts
    }

    /// Builds request parts with a `Role` already inserted, mimicking the RBAC middleware.
    fn parts_with_role(name: &str) -> Parts {
        let mut parts = empty_parts();
        parts.extensions.insert(Role(name.into()));
        parts
    }

    #[test]
    fn role_as_str() {
        let role = Role("admin".into());
        assert_eq!(role.as_str(), "admin");
    }

    #[test]
    fn role_deref() {
        let role = Role("editor".into());
        let s: &str = &role;
        assert_eq!(s, "editor");
    }

    #[test]
    fn role_clone() {
        let role = Role("admin".into());
        let cloned = role.clone();
        assert_eq!(role, cloned);
    }

    #[test]
    fn role_debug() {
        let role = Role("admin".into());
        assert_eq!(format!("{role:?}"), r#"Role("admin")"#);
    }

    #[tokio::test]
    async fn extract_from_extensions() {
        let mut parts = parts_with_role("admin");

        let role = <Role as FromRequestParts<()>>::from_request_parts(&mut parts, &())
            .await
            .expect("role extension is present");
        assert_eq!(role.as_str(), "admin");
    }

    #[tokio::test]
    async fn extract_leaves_extension_in_place() {
        let mut parts = parts_with_role("admin");

        // The extractor clones rather than removes, so repeated extraction must succeed.
        for _ in 0..2 {
            let role = <Role as FromRequestParts<()>>::from_request_parts(&mut parts, &())
                .await
                .expect("role extension is still present");
            assert_eq!(role.as_str(), "admin");
        }
        assert_eq!(parts.extensions.get::<Role>(), Some(&Role("admin".into())));
    }

    #[tokio::test]
    async fn extract_missing_returns_500() {
        let mut parts = empty_parts();

        let err = <Role as FromRequestParts<()>>::from_request_parts(&mut parts, &())
            .await
            .expect_err("extraction must fail without the RBAC middleware");
        assert_eq!(err.status(), http::StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[tokio::test]
    async fn option_role_none_when_missing() {
        let mut parts = empty_parts();

        let role = <Role as OptionalFromRequestParts<()>>::from_request_parts(&mut parts, &())
            .await
            .expect("optional extraction never rejects");
        assert!(role.is_none());
    }

    #[tokio::test]
    async fn option_role_some_when_present() {
        let mut parts = parts_with_role("viewer");

        let role = <Role as OptionalFromRequestParts<()>>::from_request_parts(&mut parts, &())
            .await
            .expect("optional extraction never rejects");
        // Pin the actual value, not just presence.
        assert_eq!(role.as_deref(), Some("viewer"));
    }
}