// rust_rbac/middleware/actix.rs

1use std::future::{ready, Ready};
2
3use actix_web::{
4    dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
5    Error,
6};
7use futures_util::future::LocalBoxFuture;
8
9use crate::storage::RbacStorage;
10use crate::RbacService;
11
12/// Middleware for requiring a specific permission
13pub struct RequirePermission<S: RbacStorage> {
14    permission: String,
15    rbac: std::sync::Arc<RbacService<S>>,
16}
17
18impl<S: RbacStorage> RequirePermission<S> {
19    /// Create a new RequirePermission middleware
20    pub fn new(permission: impl Into<String>, rbac: std::sync::Arc<RbacService<S>>) -> Self {
21        Self {
22            permission: permission.into(),
23            rbac,
24        }
25    }
26}
27
28impl<S, B, R> Transform<R, ServiceRequest> for RequirePermission<S>
29where
30    S: RbacStorage + 'static,
31    R: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static + Clone,
32    R::Future: 'static,
33    B: 'static,
34{
35    type Response = ServiceResponse<B>;
36    type Error = Error;
37    type Transform = RequirePermissionMiddleware<S, R>;
38    type InitError = ();
39    type Future = Ready<Result<Self::Transform, Self::InitError>>;
40
41    fn new_transform(&self, service: R) -> Self::Future {
42        ready(Ok(RequirePermissionMiddleware {
43            service,
44            permission: self.permission.clone(),
45            rbac: self.rbac.clone(),
46        }))
47    }
48}
49
50pub struct RequirePermissionMiddleware<S: RbacStorage, R> {
51    service: R,
52    permission: String,
53    rbac: std::sync::Arc<RbacService<S>>,
54}
55
56impl<S, R, B> Service<ServiceRequest> for RequirePermissionMiddleware<S, R>
57where
58    S: RbacStorage + 'static,
59    R: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static + Clone,
60    R::Future: 'static,
61    B: 'static,
62{
63    type Response = ServiceResponse<B>;
64    type Error = Error;
65    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
66
67    forward_ready!(service);
68
69    fn call(&self, req: ServiceRequest) -> Self::Future {
70        let permission = self.permission.clone();
71        let rbac = self.rbac.clone();
72        
73        // Get the user ID from the request (e.g., from a token)
74        let user_id = match get_user_id_from_request(&req) {
75            Some(id) => id,
76            None => {
77                return Box::pin(async move {
78                    Err(actix_web::error::ErrorUnauthorized("Authentication required"))
79                });
80            }
81        };
82        
83        let service = self.service.clone();
84        
85        Box::pin(async move {
86            // Check if the user has the required permission
87            match rbac.subject_has_permission(&user_id, &permission).await {
88                Ok(true) => {
89                    // User has permission, continue with the request
90                    service.call(req).await
91                }
92                Ok(false) => {
93                    // User doesn't have permission
94                    Err(actix_web::error::ErrorForbidden(
95                        format!("Missing required permission: {}", permission)
96                    ))
97                }
98                Err(e) => {
99                    // Error checking permission
100                    Err(actix_web::error::ErrorInternalServerError(
101                        format!("Error checking permission: {}", e)
102                    ))
103                }
104            }
105        })
106    }
107}
108
109/// Middleware for requiring a specific role
110pub struct RequireRole<S: RbacStorage> {
111    role: String,
112    rbac: std::sync::Arc<RbacService<S>>,
113}
114
115impl<S: RbacStorage> RequireRole<S> {
116    /// Create a new RequireRole middleware
117    pub fn new(role: impl Into<String>, rbac: std::sync::Arc<RbacService<S>>) -> Self {
118        Self {
119            role: role.into(),
120            rbac,
121        }
122    }
123}
124
125impl<S, B, R> Transform<R, ServiceRequest> for RequireRole<S>
126where
127    S: RbacStorage + 'static,
128    R: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static + Clone,
129    R::Future: 'static,
130    B: 'static,
131{
132    type Response = ServiceResponse<B>;
133    type Error = Error;
134    type Transform = RequireRoleMiddleware<S, R>;
135    type InitError = ();
136    type Future = Ready<Result<Self::Transform, Self::InitError>>;
137
138    fn new_transform(&self, service: R) -> Self::Future {
139        ready(Ok(RequireRoleMiddleware {
140            service,
141            role: self.role.clone(),
142            rbac: self.rbac.clone(),
143        }))
144    }
145}
146
147pub struct RequireRoleMiddleware<S: RbacStorage, R> {
148    service: R,
149    role: String,
150    rbac: std::sync::Arc<RbacService<S>>,
151}
152
153impl<S, R, B> Service<ServiceRequest> for RequireRoleMiddleware<S, R>
154where
155    S: RbacStorage + 'static,
156    R: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static + Clone,
157    R::Future: 'static,
158    B: 'static,
159{
160    type Response = ServiceResponse<B>;
161    type Error = Error;
162    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
163
164    forward_ready!(service);
165
166    fn call(&self, req: ServiceRequest) -> Self::Future {
167        let role = self.role.clone();
168        let rbac = self.rbac.clone();
169        
170        // Get the user ID from the request
171        let user_id = match get_user_id_from_request(&req) {
172            Some(id) => id,
173            None => {
174                return Box::pin(async move {
175                    Err(actix_web::error::ErrorUnauthorized("Authentication required"))
176                });
177            }
178        };
179        
180        let service = self.service.clone();
181        
182        Box::pin(async move {
183            // Check if the user has the required role
184            let roles = match rbac.get_roles_for_subject(&user_id).await {
185                Ok(roles) => roles,
186                Err(e) => {
187                    return Err(actix_web::error::ErrorInternalServerError(
188                        format!("Error checking roles: {}", e)
189                    ));
190                }
191            };
192            
193            if roles.iter().any(|r| r.name == role) {
194                // User has the role, continue with the request
195                service.call(req).await
196            } else {
197                // User doesn't have the role
198                Err(actix_web::error::ErrorForbidden(
199                    format!("Missing required role: {}", role)
200                ))
201            }
202        })
203    }
204}
205
206// Helper function to get the user ID from the request
207// This is just a placeholder - you would need to implement this based on your authentication system
208fn get_user_id_from_request(req: &ServiceRequest) -> Option<String> {
209    // Example: Get user ID from a token in the Authorization header
210    req.headers()
211        .get("Authorization")
212        .and_then(|header| header.to_str().ok())
213        .and_then(|auth| {
214            if auth.starts_with("Bearer ") {
215                // Parse the token and extract the user ID
216                // This is just a placeholder - you would need to implement this based on your token format
217                Some("user123".to_string())
218            } else {
219                None
220            }
221        })
222}