use std::future::{ready, Ready};
use actix_web::{
dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
Error,
};
use futures_util::future::LocalBoxFuture;
use crate::storage::RbacStorage;
use crate::RbacService;
/// Route guard that forwards a request only when its subject holds a given permission.
///
/// Register it on a service/scope via its `Transform` implementation; the actual
/// permission lookup is delegated to the shared [`RbacService`].
pub struct RequirePermission<S: RbacStorage> {
    /// Name of the permission the requesting subject must hold.
    permission: String,
    /// Shared RBAC service used to evaluate the permission check.
    rbac: std::sync::Arc<RbacService<S>>,
}

impl<S: RbacStorage> RequirePermission<S> {
    /// Creates a guard that requires `permission`, checked against `rbac`.
    pub fn new(permission: impl Into<String>, rbac: std::sync::Arc<RbacService<S>>) -> Self {
        let permission: String = permission.into();
        Self { rbac, permission }
    }
}
impl<S, B, R> Transform<R, ServiceRequest> for RequirePermission<S>
where
S: RbacStorage + 'static,
R: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static + Clone,
R::Future: 'static,
B: 'static,
{
type Response = ServiceResponse<B>;
type Error = Error;
type Transform = RequirePermissionMiddleware<S, R>;
type InitError = ();
type Future = Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: R) -> Self::Future {
ready(Ok(RequirePermissionMiddleware {
service,
permission: self.permission.clone(),
rbac: self.rbac.clone(),
}))
}
}
pub struct RequirePermissionMiddleware<S: RbacStorage, R> {
service: R,
permission: String,
rbac: std::sync::Arc<RbacService<S>>,
}
impl<S, R, B> Service<ServiceRequest> for RequirePermissionMiddleware<S, R>
where
S: RbacStorage + 'static,
R: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static + Clone,
R::Future: 'static,
B: 'static,
{
type Response = ServiceResponse<B>;
type Error = Error;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
forward_ready!(service);
fn call(&self, req: ServiceRequest) -> Self::Future {
let permission = self.permission.clone();
let rbac = self.rbac.clone();
let user_id = match get_user_id_from_request(&req) {
Some(id) => id,
None => {
return Box::pin(async move {
Err(actix_web::error::ErrorUnauthorized("Authentication required"))
});
}
};
let service = self.service.clone();
Box::pin(async move {
match rbac.subject_has_permission(&user_id, &permission).await {
Ok(true) => {
service.call(req).await
}
Ok(false) => {
Err(actix_web::error::ErrorForbidden(
format!("Missing required permission: {}", permission)
))
}
Err(e) => {
Err(actix_web::error::ErrorInternalServerError(
format!("Error checking permission: {}", e)
))
}
}
})
}
}
/// Route guard that forwards a request only when its subject has a given role.
///
/// Register it on a service/scope via its `Transform` implementation; role lookups
/// are delegated to the shared [`RbacService`].
pub struct RequireRole<S: RbacStorage> {
    /// Name of the role the requesting subject must have.
    role: String,
    /// Shared RBAC service used to fetch the subject's roles.
    rbac: std::sync::Arc<RbacService<S>>,
}

impl<S: RbacStorage> RequireRole<S> {
    /// Creates a guard that requires `role`, checked against `rbac`.
    pub fn new(role: impl Into<String>, rbac: std::sync::Arc<RbacService<S>>) -> Self {
        let role: String = role.into();
        Self { rbac, role }
    }
}
impl<S, B, R> Transform<R, ServiceRequest> for RequireRole<S>
where
S: RbacStorage + 'static,
R: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static + Clone,
R::Future: 'static,
B: 'static,
{
type Response = ServiceResponse<B>;
type Error = Error;
type Transform = RequireRoleMiddleware<S, R>;
type InitError = ();
type Future = Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: R) -> Self::Future {
ready(Ok(RequireRoleMiddleware {
service,
role: self.role.clone(),
rbac: self.rbac.clone(),
}))
}
}
pub struct RequireRoleMiddleware<S: RbacStorage, R> {
service: R,
role: String,
rbac: std::sync::Arc<RbacService<S>>,
}
impl<S, R, B> Service<ServiceRequest> for RequireRoleMiddleware<S, R>
where
S: RbacStorage + 'static,
R: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static + Clone,
R::Future: 'static,
B: 'static,
{
type Response = ServiceResponse<B>;
type Error = Error;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
forward_ready!(service);
fn call(&self, req: ServiceRequest) -> Self::Future {
let role = self.role.clone();
let rbac = self.rbac.clone();
let user_id = match get_user_id_from_request(&req) {
Some(id) => id,
None => {
return Box::pin(async move {
Err(actix_web::error::ErrorUnauthorized("Authentication required"))
});
}
};
let service = self.service.clone();
Box::pin(async move {
let roles = match rbac.get_roles_for_subject(&user_id).await {
Ok(roles) => roles,
Err(e) => {
return Err(actix_web::error::ErrorInternalServerError(
format!("Error checking roles: {}", e)
));
}
};
if roles.iter().any(|r| r.name == role) {
service.call(req).await
} else {
Err(actix_web::error::ErrorForbidden(
format!("Missing required role: {}", role)
))
}
})
}
}
/// Extracts the authenticated subject id from the request.
///
/// Returns `None` unless the request carries an `Authorization` header of the form
/// `Bearer <token>` with a non-empty token. The middlewares in this file map `None`
/// to `401 Unauthorized`. The scheme name is matched case-insensitively, as HTTP
/// authentication schemes are (RFC 7235 §2.1). A header value that is not visible
/// ASCII is treated as absent.
///
/// # Security
///
/// PLACEHOLDER: the token is NOT validated. Every request with any non-empty bearer
/// token is mapped to the fixed subject `"user123"`. Replace this with real token
/// verification (or read the subject id produced by an upstream authentication
/// layer) before relying on these middlewares to protect anything.
fn get_user_id_from_request(req: &ServiceRequest) -> Option<String> {
    let header = req.headers().get("Authorization")?.to_str().ok()?;
    let (scheme, token) = header.split_once(' ')?;

    // An empty credential is not a credential: reject "Bearer " / "Bearer    ".
    if !scheme.eq_ignore_ascii_case("Bearer") || token.trim().is_empty() {
        return None;
    }

    // TODO(security): validate `token` and derive the subject id from it.
    Some("user123".to_string())
}