use super::error::AuthorizationError;
use super::response::AuthResponse;
use crate::auth::Authenticatable;
use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
/// Boxed ability check stored in the registry.
///
/// The callback receives the user being checked and an optional, type-erased
/// resource, and returns the [`AuthResponse`] for that ability. It must be
/// `Send + Sync` because it lives in a process-wide static.
type AbilityCallback =
Box<dyn Fn(&dyn Authenticatable, Option<&dyn Any>) -> AuthResponse + Send + Sync>;
/// Boxed "before" hook that runs ahead of the ability callback.
///
/// The hook receives the user and the ability name. Returning `Some(decision)`
/// short-circuits the check with that decision; returning `None` lets the
/// check continue to the next hook or to the ability callback.
type BeforeCallback = Box<dyn Fn(&dyn Authenticatable, &str) -> Option<bool> + Send + Sync>;
/// Process-wide gate registry.
///
/// Starts as `None` and is populated by `Gate::init`, `Gate::define` or
/// `Gate::before`. Readers treat `None` as "nothing registered".
static GATE_REGISTRY: RwLock<Option<GateRegistry>> = RwLock::new(None);
/// Everything the gate knows about: ability callbacks, before-hooks and policies.
struct GateRegistry {
/// Ability callbacks keyed by ability name.
abilities: HashMap<String, AbilityCallback>,
/// Hooks consulted, in registration order, before any ability callback.
before_hooks: Vec<BeforeCallback>,
/// Type-erased policies keyed by the `TypeId` of the model they apply to.
/// NOTE(review): nothing in this part of the file inserts into this map; it
/// is only read by `Gate::has_policy_for`.
policies: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
}
impl GateRegistry {
    /// Builds a registry with no abilities, no before-hooks and no policies.
    fn new() -> Self {
        let abilities = HashMap::new();
        let before_hooks = Vec::new();
        let policies = HashMap::new();
        Self {
            abilities,
            before_hooks,
            policies,
        }
    }
}
/// Entry point for authorization checks.
///
/// `Gate` carries no state of its own; all of its associated functions
/// operate on the process-wide registry.
pub struct Gate;
impl Gate {
    /// Ensures the global registry exists, creating an empty one on first use.
    ///
    /// An already-initialised registry is left untouched, so calling this
    /// repeatedly is harmless.
    ///
    /// # Panics
    ///
    /// Panics if the registry lock is poisoned.
    pub fn init() {
        let mut registry = GATE_REGISTRY.write().unwrap();
        if registry.is_none() {
            *registry = Some(GateRegistry::new());
        }
    }

    /// Registers `callback` as the check for `ability`.
    ///
    /// A previous definition for the same ability name is replaced. The
    /// registry is created on demand, so no prior call to [`Gate::init`] is
    /// needed.
    ///
    /// # Panics
    ///
    /// Panics if the registry lock is poisoned.
    pub fn define<F>(ability: &str, callback: F)
    where
        F: Fn(&dyn Authenticatable, Option<&dyn Any>) -> AuthResponse + Send + Sync + 'static,
    {
        // Create-if-missing and insert under one write lock, instead of
        // locking once in `init` and a second time for the insert.
        let mut registry = GATE_REGISTRY.write().unwrap();
        registry
            .get_or_insert_with(GateRegistry::new)
            .abilities
            .insert(ability.to_string(), Box::new(callback));
    }

    /// Registers a hook that runs before every ability check.
    ///
    /// Hooks run in registration order. The first hook that returns
    /// `Some(decision)` decides the check; `None` defers to the next hook and
    /// finally to the ability callback. The registry is created on demand.
    ///
    /// # Panics
    ///
    /// Panics if the registry lock is poisoned.
    pub fn before<F>(callback: F)
    where
        F: Fn(&dyn Authenticatable, &str) -> Option<bool> + Send + Sync + 'static,
    {
        // Single write-lock acquisition; see `define`.
        let mut registry = GATE_REGISTRY.write().unwrap();
        registry
            .get_or_insert_with(GateRegistry::new)
            .before_hooks
            .push(Box::new(callback));
    }

    /// Returns whether the currently authenticated user may perform `ability`.
    ///
    /// Returns `false` when `Auth::id()` is `None`.
    ///
    /// NOTE(review): the synchronous path has no user object to pass to the
    /// callbacks, and `allows_for_user_id` returns `false` on every path, so
    /// as written this never returns `true`. Use [`Gate::allows_for`] or
    /// [`Gate::user_allows`] for a real check.
    pub fn allows(ability: &str, resource: Option<&dyn Any>) -> bool {
        crate::auth::Auth::id().is_some() && Self::allows_for_user_id(ability, resource)
    }

    /// Logical negation of [`Gate::allows`].
    pub fn denies(ability: &str, resource: Option<&dyn Any>) -> bool {
        !Self::allows(ability, resource)
    }

    /// Like [`Gate::allows`], but reports the outcome as a `Result`.
    ///
    /// # Errors
    ///
    /// Returns an [`AuthorizationError`] with status 401 when `Auth::id()` is
    /// `None`, and a plain `AuthorizationError::new(ability)` when the check
    /// is denied.
    ///
    /// NOTE(review): because `allows_for_user_id` always returns `false`, an
    /// authenticated caller currently always receives the denial error. Use
    /// [`Gate::authorize_for`] or [`Gate::user_authorize`] for a real check.
    pub fn authorize(ability: &str, resource: Option<&dyn Any>) -> Result<(), AuthorizationError> {
        if crate::auth::Auth::id().is_none() {
            return Err(AuthorizationError::new(ability).with_status(401));
        }
        if Self::allows_for_user_id(ability, resource) {
            Ok(())
        } else {
            Err(AuthorizationError::new(ability))
        }
    }

    /// Returns whether `user` may perform `ability` on the optional `resource`.
    ///
    /// This is [`Gate::inspect`] reduced to its allowed/denied flag.
    pub fn allows_for<U: Authenticatable>(
        user: &U,
        ability: &str,
        resource: Option<&dyn Any>,
    ) -> bool {
        Self::inspect(user, ability, resource).allowed()
    }

    /// Authorizes `user` for `ability`, turning a denial into an error.
    ///
    /// # Errors
    ///
    /// On denial, returns an [`AuthorizationError`] for `ability` whose status
    /// is taken from the response, and whose message is taken from the
    /// response when the response carries one.
    pub fn authorize_for<U: Authenticatable>(
        user: &U,
        ability: &str,
        resource: Option<&dyn Any>,
    ) -> Result<(), AuthorizationError> {
        let response = Self::inspect(user, ability, resource);
        if response.allowed() {
            Ok(())
        } else {
            let mut error = AuthorizationError::new(ability);
            // Only override the message when the response provides one, so
            // whatever `AuthorizationError::new` set is kept otherwise.
            if let Some(msg) = response.message() {
                error.message = Some(msg.to_string());
            }
            error.status = response.status();
            Err(error)
        }
    }

    /// Returns the full [`AuthResponse`] for `user` and `ability`.
    ///
    /// Generic convenience wrapper around [`Gate::inspect`].
    pub fn check_for<U: Authenticatable>(
        user: &U,
        ability: &str,
        resource: Option<&dyn Any>,
    ) -> AuthResponse {
        Self::inspect(user, ability, resource)
    }

    /// Evaluates `ability` for `user` and returns the resulting response.
    ///
    /// Evaluation order:
    /// 1. If the registry has never been initialised, deny silently.
    /// 2. Run the before-hooks in registration order; the first one returning
    ///    `Some(decision)` wins and the decision is converted into the response.
    /// 3. Otherwise run the callback registered for `ability`, if any.
    /// 4. Otherwise deny silently.
    ///
    /// The registry read lock is held while hooks and callbacks run, so they
    /// must not call [`Gate::define`] or [`Gate::before`], which need the
    /// write lock.
    ///
    /// # Panics
    ///
    /// Panics if the registry lock is poisoned.
    pub fn inspect(
        user: &dyn Authenticatable,
        ability: &str,
        resource: Option<&dyn Any>,
    ) -> AuthResponse {
        let registry = GATE_REGISTRY.read().unwrap();
        let reg = match &*registry {
            Some(r) => r,
            None => return AuthResponse::deny_silent(),
        };
        for hook in &reg.before_hooks {
            if let Some(result) = hook(user, ability) {
                return result.into();
            }
        }
        if let Some(callback) = reg.abilities.get(ability) {
            return callback(user, resource);
        }
        AuthResponse::deny_silent()
    }

    /// Placeholder for a check that only knows the authenticated user's id.
    ///
    /// Neither the ability callbacks nor the before-hooks can be invoked here
    /// because both need a user object, so every path returns `false`.
    ///
    /// NOTE(review): this makes [`Gate::allows`] and [`Gate::authorize`]
    /// unable to ever grant access; confirm that this is the intended
    /// behaviour for the synchronous API.
    fn allows_for_user_id(ability: &str, _resource: Option<&dyn Any>) -> bool {
        let registry = GATE_REGISTRY.read().unwrap();
        let reg = match &*registry {
            Some(r) => r,
            None => return false,
        };
        // Nothing registered that could grant this ability.
        if !reg.abilities.contains_key(ability) && reg.before_hooks.is_empty() {
            return false;
        }
        // Something is registered, but it cannot be run without a user object.
        false
    }

    /// Returns whether a policy is registered for the model type `M`.
    ///
    /// Returns `false` when the registry has not been initialised.
    ///
    /// # Panics
    ///
    /// Panics if the registry lock is poisoned.
    pub fn has_policy_for<M: 'static>() -> bool {
        let registry = GATE_REGISTRY.read().unwrap();
        registry
            .as_ref()
            .map(|r| r.policies.contains_key(&TypeId::of::<M>()))
            .unwrap_or(false)
    }

    /// Replaces the registry with a fresh, empty one (test builds only).
    #[cfg(test)]
    pub fn flush() {
        let mut registry = GATE_REGISTRY.write().unwrap();
        *registry = Some(GateRegistry::new());
    }

    /// Serialises tests that touch the global registry (test builds only).
    ///
    /// Hold the returned guard for the duration of the test.
    #[cfg(test)]
    pub fn test_lock() -> std::sync::MutexGuard<'static, ()> {
        use std::sync::Mutex;
        static TEST_LOCK: Mutex<()> = Mutex::new(());
        // A test that fails while holding the guard poisons the mutex. The
        // protected value is `()`, so recover the guard rather than letting
        // one failure cascade into every later test.
        TEST_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }
}
impl Gate {
    /// Returns whether the currently authenticated user may perform `ability`.
    ///
    /// Resolves the user through `Auth::user()` and evaluates the gate for it.
    /// Any failure to resolve a user yields `false`.
    pub async fn user_allows(ability: &str, resource: Option<&dyn Any>) -> bool {
        let outcome = Self::resolve_user_and_check(ability, resource).await;
        if let Ok(response) = outcome {
            response.allowed()
        } else {
            false
        }
    }

    /// Authorizes the currently authenticated user for `ability`.
    ///
    /// # Errors
    ///
    /// Returns an [`AuthorizationError`] with status 401 when no user can be
    /// resolved. On denial, returns an error for `ability` carrying the
    /// response's status, and its message when the response has one.
    pub async fn user_authorize(
        ability: &str,
        resource: Option<&dyn Any>,
    ) -> Result<(), AuthorizationError> {
        let response = Self::resolve_user_and_check(ability, resource).await?;
        if !response.allowed() {
            let mut denial = AuthorizationError::new(ability);
            // Keep the error's own message unless the response supplies one.
            if let Some(text) = response.message() {
                denial.message = Some(text.to_string());
            }
            denial.status = response.status();
            return Err(denial);
        }
        Ok(())
    }

    /// Resolves the current user and evaluates `ability` for it.
    ///
    /// A failed lookup and an absent user are both reported as an
    /// [`AuthorizationError`] with status 401.
    async fn resolve_user_and_check(
        ability: &str,
        resource: Option<&dyn Any>,
    ) -> Result<AuthResponse, AuthorizationError> {
        let user = match crate::auth::Auth::user().await {
            Ok(Some(found)) => found,
            Ok(None) | Err(_) => {
                return Err(AuthorizationError::new(ability).with_status(401));
            }
        };
        Ok(Self::inspect(user.as_ref(), ability, resource))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::any::Any;

    /// Minimal user type for exercising the gate.
    #[derive(Debug, Clone)]
    struct TestUser {
        id: i64,
        is_admin: bool,
    }

    impl Authenticatable for TestUser {
        fn auth_identifier(&self) -> i64 {
            self.id
        }

        fn as_any(&self) -> &dyn Any {
            self
        }
    }

    /// A defined ability is granted or denied according to its callback.
    #[test]
    fn test_define_and_check() {
        // Serialise access to the global registry and start from a clean slate.
        let _guard = Gate::test_lock();
        Gate::flush();
        Gate::define("test-ability", |user, _| {
            user.as_any()
                .downcast_ref::<TestUser>()
                .map(|u| u.is_admin.into())
                .unwrap_or_else(AuthResponse::deny_silent)
        });
        let admin = TestUser {
            id: 1,
            is_admin: true,
        };
        let regular = TestUser {
            id: 2,
            is_admin: false,
        };
        assert!(Gate::allows_for(&admin, "test-ability", None));
        assert!(!Gate::allows_for(&regular, "test-ability", None));
    }

    /// A before-hook returning `Some(true)` overrides a denying ability callback.
    #[test]
    fn test_before_hook() {
        let _guard = Gate::test_lock();
        Gate::flush();
        Gate::before(|user, _| {
            if let Some(u) = user.as_any().downcast_ref::<TestUser>() {
                if u.is_admin {
                    return Some(true);
                }
            }
            None
        });
        Gate::define("restricted", |_, _| AuthResponse::deny("Always denied"));
        let admin = TestUser {
            id: 1,
            is_admin: true,
        };
        let regular = TestUser {
            id: 2,
            is_admin: false,
        };
        assert!(Gate::allows_for(&admin, "restricted", None));
        assert!(!Gate::allows_for(&regular, "restricted", None));
    }

    /// `authorize_for` returns `Ok` when allowed and carries the denial message otherwise.
    #[test]
    fn test_authorize_for() {
        let _guard = Gate::test_lock();
        Gate::flush();
        Gate::define("view-posts", |_, _| AuthResponse::allow());
        Gate::define("admin-only", |user, _| {
            user.as_any()
                .downcast_ref::<TestUser>()
                .map(|u| {
                    if u.is_admin {
                        AuthResponse::allow()
                    } else {
                        AuthResponse::deny("Admin access required")
                    }
                })
                .unwrap_or_else(AuthResponse::deny_silent)
        });
        let admin = TestUser {
            id: 1,
            is_admin: true,
        };
        let regular = TestUser {
            id: 2,
            is_admin: false,
        };
        assert!(Gate::authorize_for(&admin, "view-posts", None).is_ok());
        assert!(Gate::authorize_for(&regular, "view-posts", None).is_ok());
        assert!(Gate::authorize_for(&admin, "admin-only", None).is_ok());
        let err = Gate::authorize_for(&regular, "admin-only", None).unwrap_err();
        assert_eq!(err.message, Some("Admin access required".to_string()));
    }

    /// An ability that was never defined is denied.
    #[test]
    fn test_undefined_ability() {
        let _guard = Gate::test_lock();
        Gate::flush();
        let user = TestUser {
            id: 1,
            is_admin: false,
        };
        assert!(!Gate::allows_for(&user, "undefined-ability", None));
    }
}