use cedar_policy::{Context, Entities, EntityUid};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use tracing::warn;
use super::{
context::{BuildRequestContext, NoContext},
error::{AuthzDenied, AuthzError},
provider::AuthzEntityProvider,
store::{AuthzDecision, PolicyEvaluator, make_uid},
};
/// Authorization entry point that pairs a policy evaluator with an
/// application-supplied entity provider and the namespace used when building
/// entity UIDs.
///
/// Per-user sessions are opened from an `Arc<AuthzStore<P>>` through
/// `for_user_id` / `for_user_id_with_context`.
pub struct AuthzStore<P>
where
    P: AuthzEntityProvider,
{
    /// Evaluates authorization requests; may also expose a schema that
    /// `validate` checks the provider against.
    pub(super) evaluator: Arc<dyn PolicyEvaluator>,
    /// Resolves resource UIDs and the entity sets each request needs.
    pub(super) provider: Arc<P>,
    /// Namespace handed to `make_uid` for every UID this store builds.
    pub(super) namespace: Arc<str>,
}
impl<P: AuthzEntityProvider> AuthzStore<P> {
    /// Creates a store from a policy evaluator, an entity provider and the
    /// namespace that qualifies every UID built by this store.
    ///
    /// `namespace` is converted into an `Arc<str>` once and then passed to
    /// `make_uid` by all the `*_uid` helpers below.
    pub fn new(
        evaluator: Arc<dyn PolicyEvaluator>,
        provider: Arc<P>,
        namespace: impl Into<Arc<str>>,
    ) -> Self {
        Self {
            evaluator,
            provider,
            namespace: namespace.into(),
        }
    }

    /// Validates the entity provider against the evaluator's schema.
    ///
    /// When the evaluator exposes no schema there is nothing to check and this
    /// returns `Ok(())`.
    ///
    /// # Errors
    /// Propagates the error returned by the provider's
    /// `validate_against_schema`.
    pub fn validate(&self) -> Result<(), AuthzError> {
        if let Some(schema) = self.evaluator.schema() {
            self.provider.validate_against_schema(schema)?;
        }
        Ok(())
    }

    /// Builds the UID of the `User` entity with the given `id`.
    ///
    /// # Errors
    /// Returns whatever error `make_uid` reports for this type/id pair.
    pub fn user_uid(&self, id: &str) -> Result<EntityUid, AuthzError> {
        self.entity_uid("User", id)
    }

    /// Builds the UID of the `Role` entity with the given `name`.
    ///
    /// # Errors
    /// Returns whatever error `make_uid` reports for this type/id pair.
    pub fn role_uid(&self, name: &str) -> Result<EntityUid, AuthzError> {
        self.entity_uid("Role", name)
    }

    /// Builds the UID of the `Action` entity with the given `name`.
    ///
    /// # Errors
    /// Returns whatever error `make_uid` reports for this type/id pair.
    pub fn action_uid(&self, name: &str) -> Result<EntityUid, AuthzError> {
        self.entity_uid("Action", name)
    }

    /// Builds the UID of the `Tenant` entity with the given `id`.
    ///
    /// # Errors
    /// Returns whatever error `make_uid` reports for this type/id pair.
    pub fn tenant_uid(&self, id: &str) -> Result<EntityUid, AuthzError> {
        self.entity_uid("Tenant", id)
    }

    /// Builds the UID of the `Platform` entity with the given `id`.
    ///
    /// # Errors
    /// Returns whatever error `make_uid` reports for this type/id pair.
    pub fn platform_uid(&self, id: &str) -> Result<EntityUid, AuthzError> {
        self.entity_uid("Platform", id)
    }

    /// Builds the UID of an arbitrary entity type within this store's
    /// namespace. All the fixed-type helpers above delegate here.
    ///
    /// # Errors
    /// Returns whatever error `make_uid` reports for this type/id pair.
    pub fn entity_uid(&self, type_name: &str, id: &str) -> Result<EntityUid, AuthzError> {
        make_uid(&self.namespace, type_name, id)
    }

    /// Opens a session for `user_id` with an empty request context.
    ///
    /// # Errors
    /// Fails if the user UID cannot be built (see `user_uid`).
    pub fn for_user_id(
        self: &Arc<Self>,
        user_id: &str,
    ) -> Result<AuthzSession<P, NoContext>, AuthzError> {
        let principal = self.user_uid(user_id)?;
        Ok(self.session_for_principal(principal, Context::empty()))
    }

    /// Opens a session for `user_id` whose checks all carry the request
    /// context produced by `ctx`.
    ///
    /// # Errors
    /// Fails if the user UID cannot be built, or if `ctx` cannot be turned into
    /// a request context (that error is converted via `?`).
    pub fn for_user_id_with_context<Ctx: BuildRequestContext>(
        self: &Arc<Self>,
        user_id: &str,
        ctx: Ctx,
    ) -> Result<AuthzSession<P, Ctx>, AuthzError> {
        let principal = self.user_uid(user_id)?;
        let context = ctx.to_cedar_context()?;
        Ok(self.session_for_principal(principal, context))
    }

    /// Single place that assembles an `AuthzSession`: it shares this store via
    /// `Arc::clone` and starts with an empty entity cache.
    fn session_for_principal<Ctx>(
        self: &Arc<Self>,
        principal: EntityUid,
        context: Context,
    ) -> AuthzSession<P, Ctx> {
        AuthzSession {
            store: Arc::clone(self),
            principal,
            context,
            cache: Mutex::new(HashMap::new()),
            _ctx: std::marker::PhantomData,
        }
    }
}
/// Authorization session bound to one principal and one request context.
///
/// Every check made through the session uses the same principal and context.
/// Entity sets fetched from the provider are memoised in `cache`, keyed by the
/// string forms of the (action UID, resource UID) pair. The `Ctx` parameter
/// only records which context type the session was built from; no value of it
/// is stored.
pub struct AuthzSession<P, Ctx = NoContext>
where
    P: AuthzEntityProvider,
{
    store: Arc<AuthzStore<P>>,
    principal: EntityUid,
    context: Context,
    cache: Mutex<HashMap<(String, String), Arc<Entities>>>,
    _ctx: std::marker::PhantomData<Ctx>,
}
impl<P: AuthzEntityProvider, Ctx> AuthzSession<P, Ctx> {
#[tracing::instrument(skip(self, resource))]
pub async fn require(&self, action: &str, resource: &P::ResourceId) -> Result<(), AuthzDenied> {
match self.check(action, resource).await {
AuthzDecision::Allow => Ok(()),
AuthzDecision::Deny => Err(AuthzDenied),
}
}
#[tracing::instrument(skip(self, resource))]
pub async fn is_permitted(&self, action: &str, resource: &P::ResourceId) -> bool {
matches!(self.check(action, resource).await, AuthzDecision::Allow)
}
#[tracing::instrument(skip(self, checks))]
pub async fn batch_check(
&self,
checks: &[(&str, &P::ResourceId)],
) -> Vec<(String, AuthzDecision)> {
let mut results = Vec::with_capacity(checks.len());
for (action, resource) in checks {
let decision = self.check(action, resource).await;
results.push(((*action).to_string(), decision));
}
results
}
pub fn principal(&self) -> &EntityUid {
&self.principal
}
async fn check(&self, action: &str, resource: &P::ResourceId) -> AuthzDecision {
let action_uid = match self.store.action_uid(action) {
Ok(uid) => uid,
Err(e) => {
warn!("authz: invalid action UID '{}': {e}", action);
return AuthzDecision::Deny;
}
};
let resource_uid = match self.store.provider.resource_uid(resource) {
Ok(uid) => uid,
Err(e) => {
warn!("authz: invalid resource UID: {e}");
return AuthzDecision::Deny;
}
};
let cache_key = (action_uid.to_string(), resource_uid.to_string());
let entities = {
let cached = self
.cache
.lock()
.unwrap_or_else(|e| e.into_inner())
.get(&cache_key)
.cloned();
if let Some(arc) = cached {
arc
} else {
match self
.store
.provider
.entities_for(&self.principal, resource, &action_uid)
.await
{
Ok(ent) => {
let arc = Arc::new(ent);
self.cache
.lock()
.unwrap_or_else(|e| e.into_inner())
.insert(cache_key, Arc::clone(&arc));
arc
}
Err(e) => {
warn!("authz: entity provider error: {e}");
return AuthzDecision::Deny;
}
}
}
};
self.store.evaluator.is_authorized(
&entities,
self.principal.clone(),
action_uid,
resource_uid,
self.context.clone(),
)
}
}
#[cfg(test)]
mod authz_session_tests {
    use super::*;
    use crate::authz::store::PolicyStore;
    use cedar_policy::{Entities, EntityUid, Schema};

    // Minimal schema under the `TestApp` namespace: `User` principals, a
    // `Resource` type, and one `View` action connecting them.
    const SCHEMA_JSON: &str = r#"{
"TestApp": {
"entityTypes": {
"User": { "memberOfTypes": [] },
"Resource": { "memberOfTypes": [] }
},
"actions": {
"View": {
"appliesTo": {
"principalTypes": ["User"],
"resourceTypes": ["Resource"]
}
}
}
}
}"#;

    // Single policy: alice may View doc1.
    const POLICY_TEXT: &str = r#"permit(
principal == TestApp::User::"alice",
action == TestApp::Action::"View",
resource == TestApp::Resource::"doc1"
);"#;

    /// Stub provider whose only failing operation is schema validation;
    /// `entities_for` always succeeds with an empty entity set.
    struct ErroringProvider;

    impl AuthzEntityProvider for ErroringProvider {
        type ResourceId = String;
        type Error = std::convert::Infallible;

        async fn entities_for(
            &self,
            principal: &EntityUid,
            resource_id: &Self::ResourceId,
            action: &EntityUid,
        ) -> Result<Entities, Self::Error> {
            tracing::trace!(
                target: "axess::authz::test_stub",
                principal = ?principal,
                resource_id = ?resource_id,
                action = ?action,
                "ErroringProvider::entities_for: returning empty entities"
            );
            Ok(Entities::empty())
        }

        fn resource_uid(&self, id: &Self::ResourceId) -> Result<EntityUid, AuthzError> {
            // `make_uid` is in scope through the parent module's imports.
            make_uid("TestApp", "Resource", id)
        }

        fn validate_against_schema(&self, schema: &Schema) -> Result<(), AuthzError> {
            tracing::trace!(
                target: "axess::authz::test_stub",
                schema = ?schema,
                "ErroringProvider::validate_against_schema: synthetic failure"
            );
            let reason = String::from("synthetic schema validation failure");
            Err(AuthzError::SchemaParse(reason))
        }
    }

    #[test]
    fn validate_propagates_provider_validation_error() {
        let policies = PolicyStore::from_text(POLICY_TEXT, SCHEMA_JSON).unwrap();
        let evaluator: Arc<dyn PolicyEvaluator> = Arc::new(policies);
        let provider = Arc::new(ErroringProvider);
        let store = AuthzStore::new(evaluator, provider, "TestApp");

        let result = store.validate();

        assert!(
            matches!(result, Err(AuthzError::SchemaParse(_))),
            "validate() must propagate provider's validate_against_schema error, got {result:?}"
        );
    }
}