axum_security/rbac/mod.rs

1use std::{convert::Infallible, fmt::Debug, marker::PhantomData};
2
3use axum::{
4    extract::{FromRequestParts, Request, State},
5    http::{StatusCode, request::Parts},
6    middleware::Next,
7    response::{IntoResponse, Response},
8    routing::MethodRouter,
9};
10pub use axum_security_macros::{requires, requires_any};
11
12#[cfg(feature = "cookie")]
13use crate::cookie::CookieSession;
14
15pub fn __requires<T: RBAC>(resource: RolesExtractor<T>, roles: &[T]) -> Option<Response> {
16    if resource.roles.iter().all(|r| roles.contains(r)) {
17        None
18    } else {
19        Some(StatusCode::UNAUTHORIZED.into_response())
20    }
21}
22
23pub fn __requires_any<T: RBAC>(resource: RolesExtractor<T>, roles: &[T]) -> Option<Response> {
24    if resource.roles.iter().any(|r| roles.contains(r)) {
25        None
26    } else {
27        Some(StatusCode::UNAUTHORIZED.into_response())
28    }
29}
30
31pub struct RolesExtractor<T: RBAC> {
32    roles: Vec<T>,
33    _p: PhantomData<T>,
34}
35
36impl<S: Send + Sync, R: RBAC> FromRequestParts<S> for RolesExtractor<R>
37where
38    R::Resource: Clone + Send + Sync + 'static,
39    R: Copy,
40{
41    type Rejection = StatusCode;
42
43    async fn from_request_parts(
44        parts: &mut axum::http::request::Parts,
45        _state: &S,
46    ) -> Result<Self, Self::Rejection> {
47        let Some(roles) = extract_roles(parts) else {
48            return Err(StatusCode::UNAUTHORIZED);
49        };
50
51        Ok(RolesExtractor {
52            roles,
53            _p: PhantomData,
54        })
55    }
56}
57
58fn extract_roles<R>(parts: &mut Parts) -> Option<Vec<R>>
59where
60    R: RBAC,
61    R::Resource: Clone,
62    R: Copy,
63{
64    #[cfg(feature = "jwt")]
65    if let Some(resource) = parts.extensions.remove::<crate::jwt::Jwt<R::Resource>>() {
66        let roles: Vec<R> = R::extract_roles(&resource.0).into_iter().copied().collect();
67        parts.extensions.insert(resource);
68        return Some(roles);
69    }
70
71    #[cfg(feature = "cookie")]
72    if let Some(resource) = parts.extensions.remove::<CookieSession<R::Resource>>() {
73        let roles: Vec<R> = R::extract_roles(&resource.state)
74            .into_iter()
75            .copied()
76            .collect();
77        parts.extensions.insert(resource);
78        return Some(roles);
79    }
80
81    None
82}
83
/// A role type usable for role-based access control.
///
/// Implementors say how to read their roles out of the authenticated
/// `Resource` that this module looks up in the request extensions (the inner
/// value of `Jwt<Resource>` or the `state` of `CookieSession<Resource>`).
pub trait RBAC: Copy + Clone + Eq + Debug + Send + Sync + 'static {
    /// The authenticated value the roles are read from.
    type Resource: Send + Sync + 'static;

    /// Returns the roles held by `resource`.
    fn extract_roles(resource: &Self::Resource) -> impl IntoIterator<Item = &Self>;
}
89
90pub trait RBACExt {
91    fn requires<T: RBAC>(self, rol: T) -> Self;
92
93    fn requires_all<T: RBAC>(self, rol: impl Into<Vec<T>>) -> Self;
94
95    fn requires_any<T: RBAC>(self, rol: impl Into<Vec<T>>) -> Self;
96}
97
98#[derive(Clone)]
99enum AuthType<T: RBAC> {
100    RequiresAll(Vec<T>),
101    RequiresAny(Vec<T>),
102}
103
104impl<S: Clone + 'static> RBACExt for MethodRouter<S, Infallible> {
105    fn requires<T: RBAC>(self, rol: T) -> Self {
106        let auth_type = AuthType::RequiresAll(vec![rol]);
107        let middleware = axum::middleware::from_fn_with_state(auth_type, rbac_layer::<T>);
108
109        self.layer::<_, Infallible>(middleware)
110    }
111
112    fn requires_all<T: RBAC>(self, rol: impl Into<Vec<T>>) -> Self {
113        let auth_type = AuthType::RequiresAll(rol.into());
114        let middleware = axum::middleware::from_fn_with_state(auth_type, rbac_layer::<T>);
115
116        self.layer::<_, Infallible>(middleware)
117    }
118
119    fn requires_any<T: RBAC>(self, rol: impl Into<Vec<T>>) -> Self {
120        let auth_type = AuthType::RequiresAny(rol.into());
121        let middleware = axum::middleware::from_fn_with_state(auth_type, rbac_layer::<T>);
122
123        self.layer::<_, Infallible>(middleware)
124    }
125}
126
127fn extract_resource<R: RBAC>(req: &mut Request) -> Result<R::Resource, Response> {
128    #[cfg(feature = "jwt")]
129    if let Some(user) = req
130        .extensions_mut()
131        .remove::<crate::jwt::Jwt<R::Resource>>()
132    {
133        return Ok(user.0);
134    }
135
136    #[cfg(feature = "cookie")]
137    if let Some(user) = req
138        .extensions_mut()
139        .remove::<crate::cookie::CookieSession<R::Resource>>()
140    {
141        return Ok(user.state);
142    }
143
144    Err(StatusCode::UNAUTHORIZED.into_response())
145}
146
147async fn rbac_layer<R: RBAC>(
148    State(auth_type): State<AuthType<R>>,
149    mut req: Request,
150    next: Next,
151) -> Response {
152    let resource = match extract_resource::<R>(&mut req) {
153        Ok(r) => r,
154        Err(e) => return e,
155    };
156
157    match auth_type {
158        AuthType::RequiresAll(roles) => {
159            let mut extracted_roles = R::extract_roles(&resource).into_iter();
160
161            if extracted_roles.any(|r| !roles.contains(r)) {
162                return StatusCode::UNAUTHORIZED.into_response();
163            }
164        }
165        AuthType::RequiresAny(roles) => {
166            let mut extracted_roles = R::extract_roles(&resource).into_iter();
167
168            if extracted_roles.all(|r| !roles.contains(r)) {
169                return StatusCode::UNAUTHORIZED.into_response();
170            }
171        }
172    }
173
174    next.run(req).await
175}