// rust-rbac 0.1.0
//
// A flexible Role-Based Access Control (RBAC) system for Rust applications.
// Documentation
use std::future::{ready, Ready};

use actix_web::{
    dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
    Error,
};
use futures_util::future::LocalBoxFuture;

use crate::storage::RbacStorage;
use crate::RbacService;

/// Middleware factory that gates a route on a single named permission.
///
/// The wrapped service is only invoked when the requesting subject holds
/// `permission` according to the shared [`RbacService`].
pub struct RequirePermission<S>
where
    S: RbacStorage,
{
    permission: String,
    rbac: std::sync::Arc<RbacService<S>>,
}

impl<S> RequirePermission<S>
where
    S: RbacStorage,
{
    /// Builds the middleware factory.
    ///
    /// * `permission` - name of the permission every request must hold;
    ///   converted into an owned `String`.
    /// * `rbac` - shared RBAC service used to look up the caller's permissions.
    pub fn new(permission: impl Into<String>, rbac: std::sync::Arc<RbacService<S>>) -> Self {
        let permission: String = permission.into();
        Self { permission, rbac }
    }
}

impl<S, B, R> Transform<R, ServiceRequest> for RequirePermission<S>
where
    S: RbacStorage + 'static,
    R: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static + Clone,
    R::Future: 'static,
    B: 'static,
{
    type Response = ServiceResponse<B>;
    type Error = Error;
    type Transform = RequirePermissionMiddleware<S, R>;
    type InitError = ();
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    fn new_transform(&self, service: R) -> Self::Future {
        ready(Ok(RequirePermissionMiddleware {
            service,
            permission: self.permission.clone(),
            rbac: self.rbac.clone(),
        }))
    }
}

pub struct RequirePermissionMiddleware<S: RbacStorage, R> {
    service: R,
    permission: String,
    rbac: std::sync::Arc<RbacService<S>>,
}

impl<S, R, B> Service<ServiceRequest> for RequirePermissionMiddleware<S, R>
where
    S: RbacStorage + 'static,
    R: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static + Clone,
    R::Future: 'static,
    B: 'static,
{
    type Response = ServiceResponse<B>;
    type Error = Error;
    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;

    forward_ready!(service);

    fn call(&self, req: ServiceRequest) -> Self::Future {
        let permission = self.permission.clone();
        let rbac = self.rbac.clone();
        
        // Get the user ID from the request (e.g., from a token)
        let user_id = match get_user_id_from_request(&req) {
            Some(id) => id,
            None => {
                return Box::pin(async move {
                    Err(actix_web::error::ErrorUnauthorized("Authentication required"))
                });
            }
        };
        
        let service = self.service.clone();
        
        Box::pin(async move {
            // Check if the user has the required permission
            match rbac.subject_has_permission(&user_id, &permission).await {
                Ok(true) => {
                    // User has permission, continue with the request
                    service.call(req).await
                }
                Ok(false) => {
                    // User doesn't have permission
                    Err(actix_web::error::ErrorForbidden(
                        format!("Missing required permission: {}", permission)
                    ))
                }
                Err(e) => {
                    // Error checking permission
                    Err(actix_web::error::ErrorInternalServerError(
                        format!("Error checking permission: {}", e)
                    ))
                }
            }
        })
    }
}

/// Middleware factory that gates a route on membership in a single named role.
///
/// The wrapped service is only invoked when one of the requesting subject's
/// roles (as reported by the shared [`RbacService`]) has the name `role`.
pub struct RequireRole<S>
where
    S: RbacStorage,
{
    role: String,
    rbac: std::sync::Arc<RbacService<S>>,
}

impl<S> RequireRole<S>
where
    S: RbacStorage,
{
    /// Builds the middleware factory.
    ///
    /// * `role` - name of the role every request must hold; converted into an
    ///   owned `String`.
    /// * `rbac` - shared RBAC service used to look up the caller's roles.
    pub fn new(role: impl Into<String>, rbac: std::sync::Arc<RbacService<S>>) -> Self {
        let role: String = role.into();
        Self { role, rbac }
    }
}

impl<S, B, R> Transform<R, ServiceRequest> for RequireRole<S>
where
    S: RbacStorage + 'static,
    R: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static + Clone,
    R::Future: 'static,
    B: 'static,
{
    type Response = ServiceResponse<B>;
    type Error = Error;
    type Transform = RequireRoleMiddleware<S, R>;
    type InitError = ();
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    fn new_transform(&self, service: R) -> Self::Future {
        ready(Ok(RequireRoleMiddleware {
            service,
            role: self.role.clone(),
            rbac: self.rbac.clone(),
        }))
    }
}

pub struct RequireRoleMiddleware<S: RbacStorage, R> {
    service: R,
    role: String,
    rbac: std::sync::Arc<RbacService<S>>,
}

impl<S, R, B> Service<ServiceRequest> for RequireRoleMiddleware<S, R>
where
    S: RbacStorage + 'static,
    R: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static + Clone,
    R::Future: 'static,
    B: 'static,
{
    type Response = ServiceResponse<B>;
    type Error = Error;
    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;

    forward_ready!(service);

    fn call(&self, req: ServiceRequest) -> Self::Future {
        let role = self.role.clone();
        let rbac = self.rbac.clone();
        
        // Get the user ID from the request
        let user_id = match get_user_id_from_request(&req) {
            Some(id) => id,
            None => {
                return Box::pin(async move {
                    Err(actix_web::error::ErrorUnauthorized("Authentication required"))
                });
            }
        };
        
        let service = self.service.clone();
        
        Box::pin(async move {
            // Check if the user has the required role
            let roles = match rbac.get_roles_for_subject(&user_id).await {
                Ok(roles) => roles,
                Err(e) => {
                    return Err(actix_web::error::ErrorInternalServerError(
                        format!("Error checking roles: {}", e)
                    ));
                }
            };
            
            if roles.iter().any(|r| r.name == role) {
                // User has the role, continue with the request
                service.call(req).await
            } else {
                // User doesn't have the role
                Err(actix_web::error::ErrorForbidden(
                    format!("Missing required role: {}", role)
                ))
            }
        })
    }
}

/// Extracts the authenticated subject (user) ID from the request.
///
/// PLACEHOLDER — NOT SECURE. This only checks that the `Authorization` header
/// carries a non-empty `Bearer` token. The token itself is never validated,
/// and every such request is attributed to the fixed subject `"user123"`.
/// Replace the marked section with real token verification before using the
/// RBAC middlewares in production.
///
/// Returns `None` when the header is missing, cannot be read as a string,
/// does not use the `Bearer` scheme, or carries an empty token.
fn get_user_id_from_request(req: &ServiceRequest) -> Option<String> {
    let header = req.headers().get("Authorization")?.to_str().ok()?;

    // Authentication scheme names are case-insensitive (RFC 7235 §2.1), so
    // "bearer" and "BEARER" are accepted as well as "Bearer".
    let (scheme, token) = header.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }

    // An empty credential (e.g. "Bearer ") must not count as authenticated.
    let token = token.trim();
    if token.is_empty() {
        return None;
    }

    // TODO: verify `token` and derive the subject ID from its contents.
    Some("user123".to_string())
}