axum_security/rbac/
mod.rs1use std::{convert::Infallible, fmt::Debug, marker::PhantomData};
2
3use axum::{
4 extract::{FromRequestParts, Request, State},
5 http::{StatusCode, request::Parts},
6 middleware::Next,
7 response::{IntoResponse, Response},
8 routing::MethodRouter,
9};
10pub use axum_security_macros::{requires, requires_any};
11
12#[cfg(feature = "cookie")]
13use crate::cookie::CookieSession;
14
15pub fn __requires<T: RBAC>(resource: RolesExtractor<T>, roles: &[T]) -> Option<Response> {
16 if resource.roles.iter().all(|r| roles.contains(r)) {
17 None
18 } else {
19 Some(StatusCode::UNAUTHORIZED.into_response())
20 }
21}
22
23pub fn __requires_any<T: RBAC>(resource: RolesExtractor<T>, roles: &[T]) -> Option<Response> {
24 if resource.roles.iter().any(|r| roles.contains(r)) {
25 None
26 } else {
27 Some(StatusCode::UNAUTHORIZED.into_response())
28 }
29}
30
31pub struct RolesExtractor<T: RBAC> {
32 roles: Vec<T>,
33 _p: PhantomData<T>,
34}
35
36impl<S: Send + Sync, R: RBAC> FromRequestParts<S> for RolesExtractor<R>
37where
38 R::Resource: Clone + Send + Sync + 'static,
39 R: Copy,
40{
41 type Rejection = StatusCode;
42
43 async fn from_request_parts(
44 parts: &mut axum::http::request::Parts,
45 _state: &S,
46 ) -> Result<Self, Self::Rejection> {
47 let Some(roles) = extract_roles(parts) else {
48 return Err(StatusCode::UNAUTHORIZED);
49 };
50
51 Ok(RolesExtractor {
52 roles,
53 _p: PhantomData,
54 })
55 }
56}
57
58fn extract_roles<R>(parts: &mut Parts) -> Option<Vec<R>>
59where
60 R: RBAC,
61 R::Resource: Clone,
62 R: Copy,
63{
64 #[cfg(feature = "jwt")]
65 if let Some(resource) = parts.extensions.remove::<crate::jwt::Jwt<R::Resource>>() {
66 let roles: Vec<R> = R::extract_roles(&resource.0).into_iter().copied().collect();
67 parts.extensions.insert(resource);
68 return Some(roles);
69 }
70
71 #[cfg(feature = "cookie")]
72 if let Some(resource) = parts.extensions.remove::<CookieSession<R::Resource>>() {
73 let roles: Vec<R> = R::extract_roles(&resource.state)
74 .into_iter()
75 .copied()
76 .collect();
77 parts.extensions.insert(resource);
78 return Some(roles);
79 }
80
81 None
82}
83
/// A role type usable for role-based access control.
///
/// Implement this on your role type and set [`RBAC::Resource`] to the authenticated type that
/// the JWT / cookie-session layer stores in the request extensions;
/// [`RBAC::extract_roles`] then yields the roles that resource holds.
pub trait RBAC: Copy + Clone + Eq + Debug + Send + Sync + 'static {
    /// The authenticated value (e.g. a user or session) that carries the roles.
    type Resource: Send + Sync + 'static;

    /// Returns the roles held by `resource`.
    fn extract_roles(resource: &Self::Resource) -> impl IntoIterator<Item = &Self>;
}
89
90pub trait RBACExt {
91 fn requires<T: RBAC>(self, rol: T) -> Self;
92
93 fn requires_all<T: RBAC>(self, rol: impl Into<Vec<T>>) -> Self;
94
95 fn requires_any<T: RBAC>(self, rol: impl Into<Vec<T>>) -> Self;
96}
97
/// Role requirement carried as middleware state by the RBAC layer.
///
/// No trait bound on the definition: bounds belong on the code that needs them, and
/// `Clone`/`Debug` are derived conditionally on `T`.
#[derive(Clone, Debug)]
enum AuthType<T> {
    /// The caller must hold every listed role.
    RequiresAll(Vec<T>),
    /// The caller must hold at least one listed role.
    RequiresAny(Vec<T>),
}
103
104impl<S: Clone + 'static> RBACExt for MethodRouter<S, Infallible> {
105 fn requires<T: RBAC>(self, rol: T) -> Self {
106 let auth_type = AuthType::RequiresAll(vec![rol]);
107 let middleware = axum::middleware::from_fn_with_state(auth_type, rbac_layer::<T>);
108
109 self.layer::<_, Infallible>(middleware)
110 }
111
112 fn requires_all<T: RBAC>(self, rol: impl Into<Vec<T>>) -> Self {
113 let auth_type = AuthType::RequiresAll(rol.into());
114 let middleware = axum::middleware::from_fn_with_state(auth_type, rbac_layer::<T>);
115
116 self.layer::<_, Infallible>(middleware)
117 }
118
119 fn requires_any<T: RBAC>(self, rol: impl Into<Vec<T>>) -> Self {
120 let auth_type = AuthType::RequiresAny(rol.into());
121 let middleware = axum::middleware::from_fn_with_state(auth_type, rbac_layer::<T>);
122
123 self.layer::<_, Infallible>(middleware)
124 }
125}
126
127fn extract_resource<R: RBAC>(req: &mut Request) -> Result<R::Resource, Response> {
128 #[cfg(feature = "jwt")]
129 if let Some(user) = req
130 .extensions_mut()
131 .remove::<crate::jwt::Jwt<R::Resource>>()
132 {
133 return Ok(user.0);
134 }
135
136 #[cfg(feature = "cookie")]
137 if let Some(user) = req
138 .extensions_mut()
139 .remove::<crate::cookie::CookieSession<R::Resource>>()
140 {
141 return Ok(user.state);
142 }
143
144 Err(StatusCode::UNAUTHORIZED.into_response())
145}
146
147async fn rbac_layer<R: RBAC>(
148 State(auth_type): State<AuthType<R>>,
149 mut req: Request,
150 next: Next,
151) -> Response {
152 let resource = match extract_resource::<R>(&mut req) {
153 Ok(r) => r,
154 Err(e) => return e,
155 };
156
157 match auth_type {
158 AuthType::RequiresAll(roles) => {
159 let mut extracted_roles = R::extract_roles(&resource).into_iter();
160
161 if extracted_roles.any(|r| !roles.contains(r)) {
162 return StatusCode::UNAUTHORIZED.into_response();
163 }
164 }
165 AuthType::RequiresAny(roles) => {
166 let mut extracted_roles = R::extract_roles(&resource).into_iter();
167
168 if extracted_roles.all(|r| !roles.contains(r)) {
169 return StatusCode::UNAUTHORIZED.into_response();
170 }
171 }
172 }
173
174 next.run(req).await
175}