use async_trait::async_trait;
use parking_lot::RwLock;
use serde_json::Value;
use std::cell::RefCell;
use std::collections::HashMap;
use std::sync::{Arc, LazyLock};
use crate::context::Context;
tokio::task_local! {
// Task-scoped slot for the most recent handler error message. It is only
// accessible inside a `HANDLER_ERROR.scope(..)` call, which
// `with_handler_error_capture` sets up; outside one, `try_with` returns `Err`.
pub(crate) static HANDLER_ERROR: RefCell<Option<String>>;
}
thread_local! {
// Per-thread fallback used when no task-local scope is active.
// `.0` is the number of `with_handler_error_capture_sync` scopes currently open
// on this thread; `.1` is the message reported inside the innermost open scope.
static HANDLER_ERROR_SYNC: RefCell<(u32, Option<String>)> = const { RefCell::new((0, None)) };
}
/// Records `message` as the error for the handler evaluation currently in progress.
///
/// The message is written to the task-local slot when the caller runs inside
/// [`with_handler_error_capture`]. Otherwise it is written to the thread-local
/// slot, but only while a [`with_handler_error_capture_sync`] scope is open on
/// this thread. With no capture scope active the message is discarded.
/// A later report replaces an earlier one in the same scope.
pub fn report_handler_error(message: impl Into<String>) {
    let text = message.into();

    // `try_with` fails when no task-local scope is active. The clone keeps `text`
    // available for the thread-local fallback in that case.
    let stored_in_task = HANDLER_ERROR
        .try_with(|slot| {
            slot.borrow_mut().replace(text.clone());
        })
        .is_ok();
    if stored_in_task {
        return;
    }

    HANDLER_ERROR_SYNC.with(|slot| {
        let mut state = slot.borrow_mut();
        let (depth, pending) = &mut *state;
        // A depth of zero means nobody is listening on this thread.
        if *depth > 0 {
            *pending = Some(text);
        }
    });
}
/// Returns a copy of the handler error currently recorded, if any.
///
/// The task-local slot is consulted first; when it is unavailable or empty the
/// thread-local slot is returned instead. Neither slot is cleared by this call.
pub(crate) fn current_handler_error() -> Option<String> {
    let from_task = HANDLER_ERROR
        .try_with(|slot| slot.borrow().clone())
        .ok()
        .flatten();

    match from_task {
        Some(message) => Some(message),
        None => HANDLER_ERROR_SYNC.with(|slot| slot.borrow().1.clone()),
    }
}
/// Runs `fut` inside a fresh task-local error scope and returns its output
/// together with the last message passed to [`report_handler_error`] while it ran.
///
/// The scope starts empty, so messages reported before this call are not visible
/// in the returned `Option`.
pub async fn with_handler_error_capture<F, T>(fut: F) -> (T, Option<String>)
where
    F: std::future::Future<Output = T>,
{
    HANDLER_ERROR
        .scope(RefCell::new(None), async move {
            let output = fut.await;
            // Read the slot while still inside the scope; it is gone afterwards.
            let reported = HANDLER_ERROR.with(|slot| slot.borrow().clone());
            (output, reported)
        })
        .await
}
/// Runs `f` inside a thread-local error scope and returns its result together
/// with the last message passed to [`report_handler_error`] while it ran.
///
/// Scopes nest: the message pending in the enclosing scope (if any) is set aside
/// on entry and put back on exit, so an inner scope never leaks its message to,
/// or steals one from, the outer scope.
///
/// The scope is closed by a drop guard, so the depth counter and the outer
/// message are restored even when `f` panics and the panic is caught further up
/// the stack. Without that, a caught panic would leave the thread permanently
/// "inside" a scope and later stray reports would be recorded as if captured.
pub fn with_handler_error_capture_sync<F, T>(f: F) -> (T, Option<String>)
where
    F: FnOnce() -> T,
{
    /// Closes the capture scope when dropped, on both the normal and unwind paths.
    struct ScopeGuard {
        previous: Option<String>,
    }

    impl Drop for ScopeGuard {
        fn drop(&mut self) {
            let previous = self.previous.take();
            // Use the non-panicking accessors: this may run during unwinding or
            // thread teardown, where a second panic would abort the process.
            let _ = HANDLER_ERROR_SYNC.try_with(|cell| {
                if let Ok(mut state) = cell.try_borrow_mut() {
                    state.0 = state.0.saturating_sub(1);
                    state.1 = previous;
                }
            });
        }
    }

    // Open the scope: bump the depth and set the outer pending message aside.
    let previous = HANDLER_ERROR_SYNC.with(|cell| {
        let mut state = cell.borrow_mut();
        state.0 += 1;
        state.1.take()
    });
    let guard = ScopeGuard { previous };

    let value = f();

    // Collect what was reported inside the scope before the guard restores the
    // outer message over it.
    let captured = HANDLER_ERROR_SYNC.with(|cell| cell.borrow_mut().1.take());
    drop(guard);
    (value, captured)
}
/// Extracts a human-readable message from a caught panic payload.
///
/// Payloads that are a `&str` or a `String` are returned as-is; any other
/// payload type yields a fixed placeholder.
pub(crate) fn panic_message(payload: &(dyn std::any::Any + Send)) -> String {
    payload
        .downcast_ref::<&str>()
        .map(|text| (*text).to_owned())
        .or_else(|| payload.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| "<non-string panic payload>".to_owned())
}
/// A pluggable check for a single ACL condition key.
///
/// Handlers are stored as `Arc<dyn ACLConditionHandler>` in the global
/// registries and looked up by condition name in `evaluate_conditions_async`.
#[async_trait]
pub trait ACLConditionHandler: Send + Sync {
/// Decides whether the condition holds.
///
/// `value` is the JSON configured for this condition key and `ctx` is the
/// context of the call being checked. Returns `true` when the condition is
/// satisfied.
async fn evaluate(&self, value: &Value, ctx: &Context<Value>) -> bool;
}
/// Global registry filled by `register_condition`, keyed by condition name.
///
/// `evaluate_conditions_async` consults it only for keys that have no entry in
/// `ASYNC_CONDITION_HANDLERS`.
pub static CONDITION_HANDLERS: LazyLock<RwLock<HashMap<String, Arc<dyn ACLConditionHandler>>>> =
LazyLock::new(|| RwLock::new(HashMap::new()));
/// Global registry filled by `register_async_condition`, keyed by condition name.
///
/// `evaluate_conditions_async` looks here first, so an entry in this map takes
/// precedence over one with the same key in `CONDITION_HANDLERS`.
pub static ASYNC_CONDITION_HANDLERS: LazyLock<
RwLock<HashMap<String, Arc<dyn ACLConditionHandler>>>,
> = LazyLock::new(|| RwLock::new(HashMap::new()));
/// Registers `handler` under `key` in [`CONDITION_HANDLERS`].
///
/// A handler already registered under the same key is replaced.
pub fn register_condition(key: impl Into<String>, handler: Arc<dyn ACLConditionHandler>) {
    CONDITION_HANDLERS.write().insert(key.into(), handler);
}
/// Registers `handler` under `key` in [`ASYNC_CONDITION_HANDLERS`].
///
/// A handler already registered under the same key is replaced. Entries in this
/// registry are preferred over [`CONDITION_HANDLERS`] during evaluation.
pub fn register_async_condition(key: impl Into<String>, handler: Arc<dyn ACLConditionHandler>) {
    ASYNC_CONDITION_HANDLERS.write().insert(key.into(), handler);
}
/// Evaluates every entry of `conditions` against `ctx`; all must hold (logical AND).
///
/// For each key the handler is taken from [`ASYNC_CONDITION_HANDLERS`] if present,
/// otherwise from [`CONDITION_HANDLERS`]. Evaluation is fail-closed:
///
/// * a key with no registered handler makes the whole check return `false`
///   before any handler runs;
/// * the first handler that returns `false` stops evaluation with `false`;
/// * a handler that panics is logged, reported through [`report_handler_error`],
///   and treated as `false`.
///
/// An empty `conditions` map returns `true`.
pub async fn evaluate_conditions_async<S: ::std::hash::BuildHasher>(
    conditions: &HashMap<String, Value, S>,
    ctx: &Context<Value>,
) -> bool {
    use futures_util::FutureExt;

    // Resolve all handlers up front so the registry read locks are released
    // before the first `.await`. Keys and values stay borrowed from `conditions`;
    // only the handler `Arc` is cloned, which avoids deep-copying the JSON
    // values on every check.
    let mut to_evaluate: Vec<(&str, Arc<dyn ACLConditionHandler>, &Value)> =
        Vec::with_capacity(conditions.len());
    {
        let async_handlers = ASYNC_CONDITION_HANDLERS.read();
        let sync_handlers = CONDITION_HANDLERS.read();
        for (key, value) in conditions {
            let Some(handler) = async_handlers
                .get(key.as_str())
                .or_else(|| sync_handlers.get(key.as_str()))
            else {
                tracing::warn!("Unknown ACL condition '{}' — treated as unsatisfied", key);
                return false;
            };
            to_evaluate.push((key.as_str(), Arc::clone(handler), value));
        }
    }

    for (key, handler, value) in to_evaluate {
        // A panicking handler must deny access rather than take down the caller.
        let outcome = std::panic::AssertUnwindSafe(handler.evaluate(value, ctx))
            .catch_unwind()
            .await;
        match outcome {
            Ok(true) => {}
            Ok(false) => return false,
            Err(payload) => {
                let msg = panic_message(payload.as_ref());
                tracing::error!(
                    condition = %key,
                    panic = %msg,
                    "ACL condition handler panicked — denying (fail-closed)"
                );
                report_handler_error(format!("{key}: handler panicked: {msg}"));
                return false;
            }
        }
    }
    true
}
pub struct IdentityTypesHandler;
#[async_trait]
impl ACLConditionHandler for IdentityTypesHandler {
async fn evaluate(&self, value: &Value, ctx: &Context<Value>) -> bool {
let Some(arr) = value.as_array() else {
return false;
};
let Some(identity) = &ctx.identity else {
return false;
};
arr.iter()
.any(|v| v.as_str().is_some_and(|s| s == identity.identity_type()))
}
}
pub struct RolesHandler;
#[async_trait]
impl ACLConditionHandler for RolesHandler {
async fn evaluate(&self, value: &Value, ctx: &Context<Value>) -> bool {
let Some(arr) = value.as_array() else {
return false;
};
let Some(identity) = &ctx.identity else {
return false;
};
arr.iter().any(|v| {
v.as_str()
.is_some_and(|s| identity.roles().contains(&s.to_string()))
})
}
}
/// Condition handler that limits how long the current call chain may be.
pub struct MaxCallDepthHandler;

#[async_trait]
impl ACLConditionHandler for MaxCallDepthHandler {
    /// Returns `true` when `ctx.call_chain.len()` is less than or equal to the
    /// configured limit.
    ///
    /// The limit is either `value` itself (a JSON number readable as `u64`) or
    /// the `"lte"` field of `value` when it is an object. Any other shape, or a
    /// number that does not fit `u64`, yields `false`.
    async fn evaluate(&self, value: &Value, ctx: &Context<Value>) -> bool {
        let limit = if let Value::Number(number) = value {
            number.as_u64()
        } else if let Value::Object(fields) = value {
            fields.get("lte").and_then(Value::as_u64)
        } else {
            None
        };
        limit.is_some_and(|max| (ctx.call_chain.len() as u64) <= max)
    }
}
pub(crate) struct OrHandler;
impl OrHandler {
pub(crate) fn new() -> Self {
Self
}
}
#[async_trait]
impl ACLConditionHandler for OrHandler {
async fn evaluate(&self, value: &Value, ctx: &Context<Value>) -> bool {
let Some(arr) = value.as_array() else {
return false;
};
for sub in arr {
if let Some(obj) = sub.as_object() {
let map: HashMap<String, Value> =
obj.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
if evaluate_conditions_async(&map, ctx).await {
return true;
}
}
}
false
}
}
/// Compound handler: passes when its nested condition map does not pass.
pub(crate) struct NotHandler;

impl NotHandler {
    /// Creates the handler; it carries no state.
    pub(crate) fn new() -> Self {
        NotHandler
    }
}

#[async_trait]
impl ACLConditionHandler for NotHandler {
    /// Expects `value` to be an object holding a condition map and returns the
    /// negation of [`evaluate_conditions_async`] on it. A non-object `value`
    /// yields `false` (it is not negated).
    async fn evaluate(&self, value: &Value, ctx: &Context<Value>) -> bool {
        let Some(inner) = value.as_object() else {
            return false;
        };
        // `evaluate_conditions_async` takes a `HashMap`, so copy the JSON object over.
        let conditions = inner
            .iter()
            .map(|(name, spec)| (name.clone(), spec.clone()))
            .collect::<HashMap<String, Value>>();
        let satisfied = evaluate_conditions_async(&conditions, ctx).await;
        !satisfied
    }
}
pub fn register_builtin_handlers() {
register_condition("identity_types", Arc::new(IdentityTypesHandler));
register_condition("roles", Arc::new(RolesHandler));
register_condition("max_call_depth", Arc::new(MaxCallDepthHandler));
register_condition("$or", Arc::new(OrHandler::new()));
register_condition("$not", Arc::new(NotHandler::new()));
}
// Unit tests for the built-in handlers, the registries and the error-capture helpers.
// NOTE: the handler registries are process-global statics, so tests that register
// handlers use their own keys (or re-register identical handlers) to stay independent.
#[cfg(test)]
mod tests {
use super::*;
use crate::context::{Context, Identity};
// Builds a context with the given identity type and roles, and a call chain of
// `call_depth` synthetic entries.
fn make_ctx(identity_type: &str, roles: Vec<&str>, call_depth: usize) -> Context<Value> {
let identity = Identity::new(
"test-id".to_string(),
identity_type.to_string(),
roles.into_iter().map(String::from).collect(),
HashMap::new(),
);
let mut ctx = Context::new(identity);
for i in 0..call_depth {
ctx.call_chain.push(format!("module.{i}"));
}
ctx
}
// Context produced by `Context::anonymous()`; the handlers under test reject it
// in the `*_rejects_no_identity` tests.
fn anon_ctx() -> Context<Value> {
Context::<Value>::anonymous()
}
// --- IdentityTypesHandler ---
#[tokio::test]
async fn identity_types_matches_correct_type() {
let handler = IdentityTypesHandler;
let ctx = make_ctx("user", vec![], 0);
let value = serde_json::json!(["user", "service"]);
assert!(handler.evaluate(&value, &ctx).await);
}
#[tokio::test]
async fn identity_types_rejects_wrong_type() {
let handler = IdentityTypesHandler;
let ctx = make_ctx("agent", vec![], 0);
let value = serde_json::json!(["user", "service"]);
assert!(!handler.evaluate(&value, &ctx).await);
}
#[tokio::test]
async fn identity_types_rejects_non_array_value() {
let handler = IdentityTypesHandler;
let ctx = make_ctx("user", vec![], 0);
// A bare string is not an array, so the handler must deny.
let value = serde_json::json!("user"); assert!(!handler.evaluate(&value, &ctx).await);
}
#[tokio::test]
async fn identity_types_rejects_no_identity() {
let handler = IdentityTypesHandler;
let ctx = anon_ctx();
let value = serde_json::json!(["user"]);
assert!(!handler.evaluate(&value, &ctx).await);
}
// --- RolesHandler ---
#[tokio::test]
async fn roles_matches_overlapping_role() {
let handler = RolesHandler;
let ctx = make_ctx("user", vec!["admin", "viewer"], 0);
let value = serde_json::json!(["admin"]);
assert!(handler.evaluate(&value, &ctx).await);
}
#[tokio::test]
async fn roles_rejects_no_overlap() {
let handler = RolesHandler;
let ctx = make_ctx("user", vec!["viewer"], 0);
let value = serde_json::json!(["admin"]);
assert!(!handler.evaluate(&value, &ctx).await);
}
#[tokio::test]
async fn roles_rejects_no_identity() {
let handler = RolesHandler;
let ctx = anon_ctx();
let value = serde_json::json!(["admin"]);
assert!(!handler.evaluate(&value, &ctx).await);
}
// --- MaxCallDepthHandler: the limit is inclusive ---
#[tokio::test]
async fn max_call_depth_allows_under_limit() {
let handler = MaxCallDepthHandler;
let ctx = make_ctx("user", vec![], 3);
let value = serde_json::json!(5u64);
assert!(handler.evaluate(&value, &ctx).await);
}
#[tokio::test]
async fn max_call_depth_allows_at_limit() {
let handler = MaxCallDepthHandler;
let ctx = make_ctx("user", vec![], 5);
let value = serde_json::json!(5u64);
assert!(handler.evaluate(&value, &ctx).await);
}
#[tokio::test]
async fn max_call_depth_rejects_over_limit() {
let handler = MaxCallDepthHandler;
let ctx = make_ctx("user", vec![], 6);
let value = serde_json::json!(5u64);
assert!(!handler.evaluate(&value, &ctx).await);
}
#[tokio::test]
async fn max_call_depth_rejects_non_numeric_value() {
let handler = MaxCallDepthHandler;
let ctx = make_ctx("user", vec![], 0);
// A string is neither a number nor an object with "lte", so the handler must deny.
let value = serde_json::json!("five"); assert!(!handler.evaluate(&value, &ctx).await);
}
// Test-only handler whose result is the boolean in its condition value
// (anything that is not a JSON boolean counts as `false`).
struct PassHandler;
#[async_trait]
impl ACLConditionHandler for PassHandler {
async fn evaluate(&self, value: &Value, _ctx: &Context<Value>) -> bool {
value.as_bool().unwrap_or(false)
}
}
// Registers "pass" plus the built-ins so nested conditions can be resolved.
fn setup_compound_test_handlers() {
register_condition("pass", Arc::new(PassHandler));
register_builtin_handlers();
}
// --- OrHandler ---
#[tokio::test]
async fn or_handler_true_if_any_sub_passes() {
setup_compound_test_handlers();
let handler = OrHandler::new();
let ctx = anon_ctx();
let value = serde_json::json!([
{"pass": false},
{"pass": true},
]);
assert!(handler.evaluate(&value, &ctx).await);
}
#[tokio::test]
async fn or_handler_false_if_none_pass() {
setup_compound_test_handlers();
let handler = OrHandler::new();
let ctx = anon_ctx();
let value = serde_json::json!([
{"pass": false},
{"pass": false},
]);
assert!(!handler.evaluate(&value, &ctx).await);
}
#[tokio::test]
async fn or_handler_rejects_non_array_value() {
let handler = OrHandler::new();
let ctx = anon_ctx();
// An object instead of an array of objects must be denied.
let value = serde_json::json!({"pass": true}); assert!(!handler.evaluate(&value, &ctx).await);
}
// --- NotHandler ---
#[tokio::test]
async fn not_handler_inverts_passing_condition() {
setup_compound_test_handlers();
let handler = NotHandler::new();
let ctx = anon_ctx();
let value = serde_json::json!({"pass": true});
assert!(!handler.evaluate(&value, &ctx).await);
}
#[tokio::test]
async fn not_handler_inverts_failing_condition() {
setup_compound_test_handlers();
let handler = NotHandler::new();
let ctx = anon_ctx();
let value = serde_json::json!({"pass": false});
assert!(handler.evaluate(&value, &ctx).await);
}
#[tokio::test]
async fn not_handler_rejects_non_object_value() {
let handler = NotHandler::new();
let ctx = anon_ctx();
// An array instead of an object must be denied, not negated.
let value = serde_json::json!([{"pass": true}]); assert!(!handler.evaluate(&value, &ctx).await);
}
// --- Registries ---
// Registering the same key twice must not panic and must leave the key present.
#[test]
fn register_condition_stores_and_overwrites() {
register_condition("_test_handler", Arc::new(MaxCallDepthHandler));
register_condition("_test_handler", Arc::new(MaxCallDepthHandler));
let map = CONDITION_HANDLERS.read();
assert!(map.contains_key("_test_handler"));
}
// Test-only handler that always passes.
struct AsyncOnlyTrue;
#[async_trait]
impl ACLConditionHandler for AsyncOnlyTrue {
async fn evaluate(&self, _value: &Value, _ctx: &Context<Value>) -> bool {
true
}
}
// With the same key in both registries, the async registry's handler must win:
// the sync one denies, the async one allows, and the overall result is "allow".
#[tokio::test]
async fn register_async_condition_uses_separate_registry() {
struct SyncDeny;
#[async_trait]
impl ACLConditionHandler for SyncDeny {
async fn evaluate(&self, _value: &Value, _ctx: &Context<Value>) -> bool {
false
}
}
let key = "_test_async_vs_sync";
register_condition(key, Arc::new(SyncDeny));
register_async_condition(key, Arc::new(AsyncOnlyTrue));
let mut conditions: HashMap<String, Value> = HashMap::new();
conditions.insert(key.to_string(), Value::Null);
let ctx = anon_ctx();
assert!(evaluate_conditions_async(&conditions, &ctx).await);
// Both registries still hold their own entry for the key.
let sync_map = CONDITION_HANDLERS.read();
assert!(sync_map.contains_key(key));
let async_map = ASYNC_CONDITION_HANDLERS.read();
assert!(async_map.contains_key(key));
}
// A key registered only in the sync registry is still found by the async evaluator.
#[tokio::test]
async fn async_check_falls_back_to_sync_registry_when_no_async_handler() {
let key = "_test_async_fallback";
register_condition(key, Arc::new(AsyncOnlyTrue));
let mut conditions: HashMap<String, Value> = HashMap::new();
conditions.insert(key.to_string(), Value::Null);
let ctx = anon_ctx();
assert!(evaluate_conditions_async(&conditions, &ctx).await);
}
// --- Error capture ---
#[tokio::test]
async fn handler_error_capture_returns_reported_message() {
let (decision, captured) = with_handler_error_capture(async {
report_handler_error("simulated handler failure");
false
})
.await;
assert!(!decision);
assert_eq!(captured.as_deref(), Some("simulated handler failure"));
}
// A message reported in one capture scope must not show up in the next one.
#[tokio::test]
async fn handler_error_capture_isolates_per_scope() {
let ((), first) = with_handler_error_capture(async {
report_handler_error("first call");
})
.await;
let ((), second) = with_handler_error_capture(async {
})
.await;
assert_eq!(first.as_deref(), Some("first call"));
assert!(second.is_none());
}
// Only checks that reporting with no capture scope active does not panic.
#[test]
fn report_handler_error_outside_scope_is_noop() {
report_handler_error("dropped on the floor");
}
}