use rustapi_core::{ApiError, FromRequestParts, Request};
/// Request extractor carrying the caller's role, read from JWT claims.
///
/// Extraction reads the string-valued `"role"` claim and fails with
/// `ApiError::forbidden` when no such claim is available.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RoleGuard {
    /// The role string taken verbatim from the `"role"` claim.
    pub role: String,
}
impl FromRequestParts for RoleGuard {
    /// Extracts the caller's role from claims stored in the request extensions.
    ///
    /// When the `jwt` feature is enabled, sources are consulted in this order:
    /// 1. `ValidatedClaims<serde_json::Value>`: its `"role"` claim, if it is a string.
    /// 2. `AuthUser<serde_json::Value>`: its `"role"` claim, if it is a string.
    ///
    /// # Errors
    ///
    /// Returns `ApiError::forbidden` when neither source yields a string role. It
    /// always returns this error when the `jwt` feature is disabled.
    fn from_request_parts(req: &Request) -> rustapi_core::Result<Self> {
        let extensions = req.extensions();

        #[cfg(feature = "jwt")]
        {
            use crate::jwt::{AuthUser, ValidatedClaims};

            // Reads `claims["role"]` as a string; absent or non-string values give `None`.
            fn role_claim(claims: &serde_json::Value) -> Option<&str> {
                claims.get("role").and_then(serde_json::Value::as_str)
            }

            // Validated claims win; the authenticated user is only a fallback.
            let found = extensions
                .get::<ValidatedClaims<serde_json::Value>>()
                .and_then(|validated| role_claim(&validated.0))
                .or_else(|| {
                    extensions
                        .get::<AuthUser<serde_json::Value>>()
                        .and_then(|user| role_claim(&user.0))
                });

            if let Some(role) = found {
                return Ok(Self {
                    role: role.to_owned(),
                });
            }
        }

        #[cfg(not(feature = "jwt"))]
        {
            // No claim source exists without JWT support; silence the unused binding.
            let _ = extensions;
        }

        Err(ApiError::forbidden(
            "Authentication required: missing or invalid role",
        ))
    }
}
impl RoleGuard {
    /// Returns `true` when the extracted role equals `role` exactly (case-sensitive).
    pub fn has_role(&self, role: &str) -> bool {
        role == self.role
    }

    /// Succeeds only when the extracted role equals `role`.
    ///
    /// # Errors
    ///
    /// Returns `ApiError::forbidden` naming the required role when it does not match.
    pub fn require_role(&self, role: &str) -> Result<(), ApiError> {
        if !self.has_role(role) {
            return Err(ApiError::forbidden(format!("Required role: {role}")));
        }
        Ok(())
    }
}
/// Request extractor carrying the caller's permissions, read from JWT claims.
///
/// Extraction collects the string entries of the `"permissions"` array claim.
/// It fails with `ApiError::forbidden` when no non-empty list is available.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PermissionGuard {
    /// Permission names in the order they appeared in the claim (non-strings dropped).
    pub permissions: Vec<String>,
}
impl FromRequestParts for PermissionGuard {
    /// Extracts the caller's permissions from claims stored in the request extensions.
    ///
    /// When the `jwt` feature is enabled, sources are consulted in this order, and the
    /// first one yielding a non-empty list wins:
    /// 1. `ValidatedClaims<serde_json::Value>`: its `"permissions"` array claim.
    /// 2. `AuthUser<serde_json::Value>`: its `"permissions"` array claim.
    ///
    /// Non-string array entries are skipped. A source whose list ends up empty is
    /// treated as absent, so lookup falls through to the next source.
    ///
    /// # Errors
    ///
    /// Returns `ApiError::forbidden` when no source yields a non-empty list. It always
    /// returns this error when the `jwt` feature is disabled.
    fn from_request_parts(req: &Request) -> rustapi_core::Result<Self> {
        let extensions = req.extensions();

        #[cfg(feature = "jwt")]
        {
            use crate::jwt::{AuthUser, ValidatedClaims};

            // Single definition of how a claim set is turned into permissions, so
            // both sources below are guaranteed to be parsed identically.
            // Returns `None` when the claim is absent, is not an array, or has no
            // string entries.
            fn permissions_claim(claims: &serde_json::Value) -> Option<Vec<String>> {
                let permissions: Vec<String> = claims
                    .get("permissions")?
                    .as_array()?
                    .iter()
                    .filter_map(serde_json::Value::as_str)
                    .map(str::to_owned)
                    .collect();
                (!permissions.is_empty()).then_some(permissions)
            }

            // Validated claims win; the authenticated user is only a fallback.
            let found = extensions
                .get::<ValidatedClaims<serde_json::Value>>()
                .and_then(|validated| permissions_claim(&validated.0))
                .or_else(|| {
                    extensions
                        .get::<AuthUser<serde_json::Value>>()
                        .and_then(|user| permissions_claim(&user.0))
                });

            if let Some(permissions) = found {
                return Ok(Self { permissions });
            }
        }

        #[cfg(not(feature = "jwt"))]
        {
            // No claim source exists without JWT support; silence the unused binding.
            let _ = extensions;
        }

        Err(ApiError::forbidden(
            "Authentication required: missing or invalid permissions",
        ))
    }
}
impl PermissionGuard {
    /// Returns `true` when `permission` is among the extracted permissions (exact match).
    pub fn has_permission(&self, permission: &str) -> bool {
        self.permissions
            .iter()
            .any(|granted| granted.as_str() == permission)
    }

    /// Succeeds only when `permission` is among the extracted permissions.
    ///
    /// # Errors
    ///
    /// Returns `ApiError::forbidden` naming the missing permission otherwise.
    pub fn require_permission(&self, permission: &str) -> Result<(), ApiError> {
        if !self.has_permission(permission) {
            return Err(ApiError::forbidden(format!(
                "Required permission: {permission}"
            )));
        }
        Ok(())
    }

    /// Returns `true` when at least one entry of `permissions` is granted.
    ///
    /// An empty `permissions` slice yields `false`.
    pub fn has_any_permission(&self, permissions: &[&str]) -> bool {
        permissions
            .iter()
            .any(|wanted| self.has_permission(wanted))
    }

    /// Returns `true` when every entry of `permissions` is granted.
    ///
    /// An empty `permissions` slice yields `true` (vacuously).
    pub fn has_all_permissions(&self, permissions: &[&str]) -> bool {
        // De Morgan: "all granted" is "none missing".
        !permissions
            .iter()
            .any(|wanted| !self.has_permission(wanted))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;

    /// Builds a bare `GET /` request with no extensions, i.e. carrying no auth data.
    fn unauthenticated_request() -> Request {
        Request::from_http_request(
            http::Request::builder()
                .method("GET")
                .uri("/")
                .body(())
                .unwrap(),
            Bytes::new(),
        )
    }

    fn sample_permission_guard() -> PermissionGuard {
        PermissionGuard {
            permissions: vec!["users.read".to_string(), "users.write".to_string()],
        }
    }

    #[tokio::test]
    async fn role_guard_without_auth_fails() {
        let req = unauthenticated_request();
        let result = RoleGuard::from_request_parts(&req);
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn permission_guard_without_auth_fails() {
        let req = unauthenticated_request();
        let result = PermissionGuard::from_request_parts(&req);
        assert!(result.is_err());
    }

    #[test]
    fn role_guard_has_role_works() {
        let guard = RoleGuard {
            role: "admin".to_string(),
        };
        assert!(guard.has_role("admin"));
        assert!(!guard.has_role("user"));
    }

    #[test]
    fn role_guard_require_role_works() {
        let guard = RoleGuard {
            role: "admin".to_string(),
        };
        assert!(guard.require_role("admin").is_ok());
        assert!(guard.require_role("user").is_err());
    }

    #[test]
    fn permission_guard_has_permission_works() {
        let guard = sample_permission_guard();
        assert!(guard.has_permission("users.read"));
        assert!(guard.has_permission("users.write"));
        assert!(!guard.has_permission("users.delete"));
    }

    #[test]
    fn permission_guard_require_permission_works() {
        let guard = sample_permission_guard();
        assert!(guard.require_permission("users.read").is_ok());
        assert!(guard.require_permission("users.delete").is_err());
    }

    #[test]
    fn permission_guard_has_any_permission_works() {
        let guard = sample_permission_guard();
        assert!(guard.has_any_permission(&["users.delete", "users.read"]));
        assert!(!guard.has_any_permission(&["users.delete"]));
        // An empty request set matches nothing.
        assert!(!guard.has_any_permission(&[]));
    }

    #[test]
    fn permission_guard_has_all_permissions_works() {
        let guard = sample_permission_guard();
        assert!(guard.has_all_permissions(&["users.read", "users.write"]));
        assert!(!guard.has_all_permissions(&["users.read", "users.delete"]));
        // An empty requirement set is vacuously satisfied.
        assert!(guard.has_all_permissions(&[]));
    }
}