use std::sync::Arc;
use std::time::Duration;
use cedar_policy::{Entities, EntityUid};
use moka::future::Cache;
use crate::authz::error::AuthzError;
use crate::authz::provider::RequestEntityProvider;
use crate::session::AuthSession;
/// Key under which one request's entity set is cached.
///
/// Two requests share a cache entry only when all four components match.
#[derive(Clone, PartialEq, Eq, Hash)]
struct EntityCacheKey {
    /// Principal making the request.
    principal: EntityUid,
    /// Tenant of the requesting session, or `None` when the session reports none.
    tenant: Option<String>,
    /// Resource being accessed.
    resource: EntityUid,
    /// Action being performed.
    action: EntityUid,
}
pub struct MokaEntityCache<P>
where
P: RequestEntityProvider,
{
inner: P,
cache: Cache<EntityCacheKey, Arc<Entities>>,
}
impl<P> MokaEntityCache<P>
where
    P: RequestEntityProvider,
{
    /// Maximum number of entries used by `new`.
    const DEFAULT_CAPACITY: u64 = 10_000;
    /// Time-to-live of entries used by `new`.
    const DEFAULT_TTL: Duration = Duration::from_secs(60);

    /// Wraps `inner` with a cache of the default capacity (10 000 entries) and
    /// default time-to-live (60 seconds).
    pub fn new(inner: P) -> Self {
        Self::with_options(inner, Self::DEFAULT_CAPACITY, Self::DEFAULT_TTL)
    }

    /// Wraps `inner` with a cache holding at most `capacity` entries, each
    /// expiring `ttl` after insertion.
    pub fn with_options(inner: P, capacity: u64, ttl: Duration) -> Self {
        Self {
            inner,
            cache: Cache::builder()
                .max_capacity(capacity)
                .time_to_live(ttl)
                // Required for `invalidate_entries_if`. Without it moka rejects
                // every predicate, so `invalidate_principal` and
                // `invalidate_tenant` would always return an error.
                .support_invalidation_closures()
                .build(),
        }
    }

    /// Rebuilds the cache with a new maximum `capacity`, keeping the currently
    /// configured time-to-live.
    ///
    /// Entries cached so far are discarded.
    #[must_use]
    pub fn with_capacity(self, capacity: u64) -> Self {
        // Keep a TTL set earlier (e.g. via `with_ttl`) instead of resetting it
        // to the default.
        let ttl = self
            .cache
            .policy()
            .time_to_live()
            .unwrap_or(Self::DEFAULT_TTL);
        Self::with_options(self.inner, capacity, ttl)
    }

    /// Rebuilds the cache with a new time-to-live `ttl`, keeping the currently
    /// configured maximum capacity.
    ///
    /// Entries cached so far are discarded.
    #[must_use]
    pub fn with_ttl(self, ttl: Duration) -> Self {
        // Keep a capacity set earlier (e.g. via `with_capacity`) instead of
        // resetting it to the default.
        let capacity = self
            .cache
            .policy()
            .max_capacity()
            .unwrap_or(Self::DEFAULT_CAPACITY);
        Self::with_options(self.inner, capacity, ttl)
    }

    /// Removes the single entry cached for exactly this
    /// `(principal, tenant, resource, action)` combination, if present.
    pub async fn invalidate(
        &self,
        principal: &EntityUid,
        tenant: Option<&str>,
        resource: &EntityUid,
        action: &EntityUid,
    ) {
        let key = EntityCacheKey {
            principal: principal.clone(),
            tenant: tenant.map(str::to_string),
            resource: resource.clone(),
            action: action.clone(),
        };
        self.cache.invalidate(&key).await;
    }

    /// Invalidates every cached entry.
    pub fn invalidate_all(&self) {
        self.cache.invalidate_all();
    }

    /// Invalidates every entry cached for `principal`, across all tenants,
    /// resources and actions.
    ///
    /// # Errors
    ///
    /// Returns moka's `PredicateError` if the invalidation predicate cannot be
    /// registered.
    pub fn invalidate_principal(&self, principal: &EntityUid) -> Result<(), moka::PredicateError> {
        // The predicate must be `'static`, so it owns its copy of the UID.
        let principal = principal.clone();
        self.cache
            .invalidate_entries_if(move |key, _| key.principal == principal)
            .map(|_predicate_id| ())
    }

    /// Invalidates every entry cached under `tenant`.
    ///
    /// Entries whose key has no tenant are left untouched.
    ///
    /// # Errors
    ///
    /// Returns moka's `PredicateError` if the invalidation predicate cannot be
    /// registered.
    pub fn invalidate_tenant(&self, tenant: &str) -> Result<(), moka::PredicateError> {
        // The predicate must be `'static`, so it owns its copy of the tenant.
        let tenant = tenant.to_string();
        self.cache
            .invalidate_entries_if(move |key, _| key.tenant.as_deref() == Some(tenant.as_str()))
            .map(|_predicate_id| ())
    }

    /// Returns the wrapped provider.
    pub fn inner(&self) -> &P {
        &self.inner
    }
}
/// Exposes the cache's inherent invalidation methods through the crate's
/// `CacheInvalidator` trait.
impl<P> super::invalidator::CacheInvalidator for MokaEntityCache<P>
where
    P: RequestEntityProvider + 'static,
{
    type Error = moka::PredicateError;

    /// Delegates to the inherent `invalidate_principal`.
    async fn invalidate_principal(&self, principal: &EntityUid) -> Result<(), Self::Error> {
        // Path lookup on `Self` prefers the inherent item over this trait method.
        Self::invalidate_principal(self, principal)
    }

    /// Delegates to the inherent `invalidate_tenant`.
    async fn invalidate_tenant(&self, tenant: &str) -> Result<(), Self::Error> {
        Self::invalidate_tenant(self, tenant)
    }

    /// Delegates to the inherent `invalidate_all`, which cannot fail.
    async fn invalidate_all(&self) -> Result<(), Self::Error> {
        Self::invalidate_all(self);
        Ok(())
    }
}
impl<P> RequestEntityProvider for MokaEntityCache<P>
where
P: RequestEntityProvider,
{
fn entities_for<'a>(
&'a self,
session: &'a AuthSession,
principal: &'a EntityUid,
resource: &'a EntityUid,
action: &'a EntityUid,
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<Entities, AuthzError>> + Send + 'a>,
> {
Box::pin(async move {
let tenant = session.tenant_id().await.map(|t| t.to_string().to_string());
let key = EntityCacheKey {
principal: principal.clone(),
tenant,
resource: resource.clone(),
action: action.clone(),
};
if let Some(cached) = self.cache.get(&key).await {
return Ok((*cached).clone());
}
let entities = self
.inner
.entities_for(session, principal, resource, action)
.await?;
self.cache.insert(key, Arc::new(entities.clone())).await;
Ok(entities)
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::collections::HashSet;
use std::str::FromStr;
use std::sync::atomic::{AtomicUsize, Ordering};
/// Test double that counts how many times its `entities_for` future runs.
/// It returns an entity set holding just the principal and the resource.
struct CountingProvider {
/// Shared with the test so it can read how often the inner provider was hit.
calls: Arc<AtomicUsize>,
}
impl RequestEntityProvider for CountingProvider {
fn entities_for<'a>(
&'a self,
session: &'a AuthSession,
principal: &'a EntityUid,
resource: &'a EntityUid,
action: &'a EntityUid,
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<Entities, AuthzError>> + Send + 'a>,
> {
// `session` and `action` are not used to build the result.
let _ = (session, action);
// Clone what the future needs so it owns its data.
let calls = self.calls.clone();
let principal = principal.clone();
let resource = resource.clone();
Box::pin(async move {
// Counted when the future is polled, not when `entities_for` is called.
calls.fetch_add(1, Ordering::SeqCst);
// Bare entities: empty attribute map and empty parent set.
let p = cedar_policy::Entity::new(
principal,
std::collections::HashMap::new(),
HashSet::new(),
)
.unwrap();
let r = cedar_policy::Entity::new(
resource,
std::collections::HashMap::new(),
HashSet::new(),
)
.unwrap();
Ok(Entities::from_entities(vec![p, r], None).unwrap())
})
}
}
/// Builds a fresh session with a random id, default `SessionData` and no
/// pending changes.
/// NOTE(review): presumably this session has no tenant (default data);
/// `invalidate_evicts_cached_entry` relies on that by invalidating with
/// `tenant = None`.
fn guest_session() -> AuthSession {
use crate::session::SessionData;
use crate::session::id::SessionId;
use crate::session::layer::{SessionHandle, SessionInner};
use tokio::sync::RwLock;
let inner = SessionInner {
id: SessionId::new(&axess_rng::SystemRng),
data: SessionData::default(),
modified: false,
regenerate: false,
pre_cycle_id: None,
pending_fingerprint: None,
max_custom_bytes: 64 * 1024,
};
AuthSession(SessionHandle(Arc::new(RwLock::new(inner))))
}
/// Fixed principal UID used by every test.
fn principal() -> EntityUid {
EntityUid::from_str("App::User::\"alice\"").unwrap()
}
/// Fixed action UID used by every test.
fn action() -> EntityUid {
EntityUid::from_str("App::Action::\"View\"").unwrap()
}
/// Builds an `App::Doc` resource UID with the given id.
fn doc(id: &str) -> EntityUid {
EntityUid::from_str(&format!("App::Doc::\"{id}\"")).unwrap()
}
/// A repeated request hits the cache; a different resource misses.
#[tokio::test]
async fn first_call_misses_then_caches() {
let calls = Arc::new(AtomicUsize::new(0));
let cached = MokaEntityCache::new(CountingProvider {
calls: calls.clone(),
});
let s = guest_session();
let p = principal();
let a = action();
let r1 = doc("doc-1");
let _ = cached.entities_for(&s, &p, &r1, &a).await.unwrap();
assert_eq!(calls.load(Ordering::SeqCst), 1);
let _ = cached.entities_for(&s, &p, &r1, &a).await.unwrap();
assert_eq!(
calls.load(Ordering::SeqCst),
1,
"cache hit should not invoke inner"
);
// The resource is part of the key, so this must go to the inner provider.
let r2 = doc("doc-2");
let _ = cached.entities_for(&s, &p, &r2, &a).await.unwrap();
assert_eq!(calls.load(Ordering::SeqCst), 2);
}
/// Single-key `invalidate` forces the next identical request back to the
/// inner provider.
#[tokio::test]
async fn invalidate_evicts_cached_entry() {
let calls = Arc::new(AtomicUsize::new(0));
let cached = MokaEntityCache::new(CountingProvider {
calls: calls.clone(),
});
let s = guest_session();
let p = principal();
let a = action();
let r = doc("doc-1");
let _ = cached.entities_for(&s, &p, &r, &a).await.unwrap();
let _ = cached.entities_for(&s, &p, &r, &a).await.unwrap();
assert_eq!(calls.load(Ordering::SeqCst), 1);
// `None` must match the tenant recorded for `guest_session()`.
cached.invalidate(&p, None, &r, &a).await;
let _ = cached.entities_for(&s, &p, &r, &a).await.unwrap();
assert_eq!(
calls.load(Ordering::SeqCst),
2,
"after invalidate, next call should re-invoke inner"
);
}
/// `invalidate_all` evicts every warmed entry, not just one.
#[tokio::test]
async fn invalidate_all_evicts_every_entry() {
let calls = Arc::new(AtomicUsize::new(0));
let cached = MokaEntityCache::new(CountingProvider {
calls: calls.clone(),
});
let s = guest_session();
let p = principal();
let a = action();
let r1 = doc("doc-1");
let r2 = doc("doc-2");
// Warm two distinct entries.
let _ = cached.entities_for(&s, &p, &r1, &a).await.unwrap();
let _ = cached.entities_for(&s, &p, &r2, &a).await.unwrap();
assert_eq!(calls.load(Ordering::SeqCst), 2);
let _ = cached.entities_for(&s, &p, &r1, &a).await.unwrap();
let _ = cached.entities_for(&s, &p, &r2, &a).await.unwrap();
assert_eq!(
calls.load(Ordering::SeqCst),
2,
"warmed entries must hit cache"
);
cached.invalidate_all();
// Run moka's pending maintenance tasks before re-querying.
cached.cache.run_pending_tasks().await;
let _ = cached.entities_for(&s, &p, &r1, &a).await.unwrap();
let _ = cached.entities_for(&s, &p, &r2, &a).await.unwrap();
assert_eq!(
calls.load(Ordering::SeqCst),
4,
"after invalidate_all, every entry must re-invoke inner; \
killing the `()` no-op mutation on invalidate_all"
);
}
}