use crate::{AuthError, AuthUser, Result};
use armature_core::HttpRequest;
use async_trait::async_trait;
/// Asynchronous access check run against an incoming [`HttpRequest`].
///
/// `Send + Sync` are required so a single guard value can be shared across
/// threads and async tasks.
#[async_trait]
pub trait Guard: Sync + Send {
    /// Decides whether `request` may proceed.
    ///
    /// The guards defined in this module answer `Ok(true)` when the request is
    /// admitted and return an `Err` describing the failure when it is rejected.
    async fn can_activate(&self, request: &HttpRequest) -> Result<bool>;
}
/// Guard that admits requests whose `authorization` header uses the `Bearer` scheme.
///
/// The guard carries no state, so it is cheap to construct, copy and share.
#[derive(Debug, Clone, Copy)]
pub struct AuthGuard;

impl AuthGuard {
    /// Creates a new (stateless) `AuthGuard`.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }

    /// Extracts the authenticated user from `request`.
    ///
    /// # Errors
    ///
    /// This is currently a placeholder: it ignores `request` and always returns
    /// [`AuthError::Unauthorized`], so no user can be obtained through it yet.
    pub fn extract_user<T: AuthUser>(&self, _request: &HttpRequest) -> Result<T> {
        Err(AuthError::Unauthorized)
    }
}
impl Default for AuthGuard {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Guard for AuthGuard {
    /// Admits the request if its `authorization` header carries a non-empty bearer token.
    ///
    /// Only the header's shape (`Bearer <token>`) is checked; the token itself is not
    /// verified here. On success this always returns `Ok(true)`.
    ///
    /// # Errors
    ///
    /// - [`AuthError::Unauthorized`] if there is no `authorization` header.
    /// - [`AuthError::InvalidToken`] if the header does not start with `Bearer `, or if
    ///   nothing but whitespace follows that prefix.
    async fn can_activate(&self, request: &HttpRequest) -> Result<bool> {
        // NOTE(review): only the lowercase key is looked up. HTTP field names are
        // case-insensitive, so confirm that `HttpRequest` normalizes header names.
        let auth_header = request
            .headers
            .get("authorization")
            .ok_or(AuthError::Unauthorized)?;

        let token = auth_header.strip_prefix("Bearer ").ok_or_else(|| {
            AuthError::InvalidToken("Invalid authorization format".to_string())
        })?;

        // A bare `Bearer ` prefix carries no credential at all; a bearer token must be
        // non-empty, so reject it instead of admitting the request.
        if token.trim().is_empty() {
            return Err(AuthError::InvalidToken(
                "Missing bearer token".to_string(),
            ));
        }

        Ok(true)
    }
}
/// Role requirement that a user can be checked against.
///
/// Build one with [`RoleGuard::any`] (holding at least one listed role suffices) or
/// [`RoleGuard::all`] (every listed role is required). Then evaluate a user with
/// [`RoleGuard::check_roles`].
#[derive(Debug, Clone)]
pub struct RoleGuard {
    /// Role names the user is tested against.
    required_roles: Vec<String>,
    /// `true`: the user must hold every role; `false`: any single role is enough.
    require_all: bool,
}

impl RoleGuard {
    /// Creates a guard that passes when the user holds at least one of `roles`.
    #[must_use]
    pub fn any(roles: Vec<String>) -> Self {
        Self {
            required_roles: roles,
            require_all: false,
        }
    }

    /// Creates a guard that passes only when the user holds every one of `roles`.
    #[must_use]
    pub fn all(roles: Vec<String>) -> Self {
        Self {
            required_roles: roles,
            require_all: true,
        }
    }

    /// Returns whether `user` satisfies this guard's role requirement.
    ///
    /// Delegates to `AuthUser::has_all_roles` or `AuthUser::has_any_role`, depending on
    /// which constructor built the guard. How an empty role list is treated is decided
    /// by those `AuthUser` methods.
    #[must_use]
    pub fn check_roles<T: AuthUser>(&self, user: &T) -> bool {
        // The `AuthUser` role checks take string slices, so borrow the owned names.
        let role_refs: Vec<&str> = self.required_roles.iter().map(String::as_str).collect();
        if self.require_all {
            user.has_all_roles(&role_refs)
        } else {
            user.has_any_role(&role_refs)
        }
    }
}
#[async_trait]
impl Guard for RoleGuard {
    /// Admits the request exactly when [`AuthGuard`] admits it.
    ///
    /// NOTE(review): `required_roles` / `require_all` are NOT consulted here. Any
    /// request accepted by `AuthGuard` passes, whatever roles this guard was built
    /// with. Until a user can be extracted from the request, role enforcement must be
    /// done by the caller via [`RoleGuard::check_roles`]. Confirm this is intended.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `AuthGuard::can_activate` unchanged.
    async fn can_activate(&self, request: &HttpRequest) -> Result<bool> {
        AuthGuard::new()
            .can_activate(request)
            .await
            .map(|_authenticated| true)
    }
}
/// Permission requirement that a user can be checked against.
///
/// Build one with [`PermissionGuard::any`] (holding at least one listed permission
/// suffices) or [`PermissionGuard::all`] (every listed permission is required). Then
/// evaluate a user with [`PermissionGuard::check_permissions`].
#[derive(Debug, Clone)]
pub struct PermissionGuard {
    /// Permission names the user is tested against.
    required_permissions: Vec<String>,
    /// `true`: every permission is required; `false`: any single one is enough.
    require_all: bool,
}

impl PermissionGuard {
    /// Creates a guard that passes when the user holds at least one of `permissions`.
    #[must_use]
    pub fn any(permissions: Vec<String>) -> Self {
        Self {
            required_permissions: permissions,
            require_all: false,
        }
    }

    /// Creates a guard that passes only when the user holds every one of `permissions`.
    #[must_use]
    pub fn all(permissions: Vec<String>) -> Self {
        Self {
            required_permissions: permissions,
            require_all: true,
        }
    }

    /// Returns whether `user` satisfies this guard's permission requirement.
    ///
    /// Each required permission is tested with `AuthUser::has_permission`. With an
    /// empty permission list, an `all` guard passes and an `any` guard fails, following
    /// [`Iterator::all`] / [`Iterator::any`].
    #[must_use]
    pub fn check_permissions<T: AuthUser>(&self, user: &T) -> bool {
        let mut required = self.required_permissions.iter();
        if self.require_all {
            required.all(|perm| user.has_permission(perm))
        } else {
            required.any(|perm| user.has_permission(perm))
        }
    }
}
#[async_trait]
impl Guard for PermissionGuard {
    /// Admits the request exactly when [`AuthGuard`] admits it.
    ///
    /// NOTE(review): `required_permissions` / `require_all` are NOT consulted here.
    /// Any request accepted by `AuthGuard` passes, whatever permissions this guard was
    /// built with. Until a user can be extracted from the request, permission
    /// enforcement must be done by the caller via
    /// [`PermissionGuard::check_permissions`]. Confirm this is intended.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `AuthGuard::can_activate` unchanged.
    async fn can_activate(&self, request: &HttpRequest) -> Result<bool> {
        AuthGuard::new()
            .can_activate(request)
            .await
            .map(|_authenticated| true)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::UserContext;

    /// Turns borrowed names into the owned `Vec<String>` the guard constructors expect.
    fn owned(names: &[&str]) -> Vec<String> {
        names.iter().copied().map(String::from).collect()
    }

    #[test]
    fn test_role_guard() {
        let user = UserContext::new("user123".to_string()).with_roles(owned(&["admin", "user"]));

        // `any`: a single matching role is enough.
        assert!(RoleGuard::any(owned(&["admin"])).check_roles(&user));
        assert!(!RoleGuard::any(owned(&["guest"])).check_roles(&user));

        // `all`: every listed role must be held.
        assert!(RoleGuard::all(owned(&["admin", "user"])).check_roles(&user));
        assert!(!RoleGuard::all(owned(&["admin", "guest"])).check_roles(&user));
    }

    #[test]
    fn test_permission_guard() {
        let user = UserContext::new("user123".to_string())
            .with_permissions(owned(&["read", "write"]));

        // `any`: a single matching permission is enough.
        assert!(PermissionGuard::any(owned(&["read"])).check_permissions(&user));
        assert!(!PermissionGuard::any(owned(&["delete"])).check_permissions(&user));

        // `all`: every listed permission must be held.
        assert!(PermissionGuard::all(owned(&["read", "write"])).check_permissions(&user));
        assert!(!PermissionGuard::all(owned(&["read", "delete"])).check_permissions(&user));
    }
}