use std::num::NonZeroUsize;
use std::sync::Arc;
use std::time::Duration;
use cedar_policy::{Entities, EntityUid};
use axess_cache::ClockTtlCache;
use axess_clock::{Clock, SystemClock};
use crate::authz::error::AuthzError;
use crate::authz::provider::RequestEntityProvider;
use crate::session::AuthSession;
/// Number of entries `EntityCache::new` allows the cache to hold.
const DEFAULT_CAPACITY: usize = 10_000;
/// Time-to-live, in seconds, that `EntityCache::new` gives each cached entry.
const DEFAULT_TTL_SECS: u64 = 60;
/// Identity of a single memoised `entities_for` result.
///
/// `tenant` is the session's tenant id converted to a `String`, or `None` when
/// the session reports no tenant. Including it keeps entries belonging to
/// different tenants apart.
#[derive(Clone, PartialEq, Eq, Hash)]
struct EntityCacheKey {
    principal: EntityUid,
    tenant: Option<String>,
    resource: EntityUid,
    action: EntityUid,
}
/// Caching decorator for a [`RequestEntityProvider`].
///
/// Results of the wrapped provider's `entities_for` are memoised in a
/// [`ClockTtlCache`] keyed by principal, session tenant, resource and action.
/// Entries expire after a TTL measured with an injectable [`Clock`], and the
/// cache is sized by `capacity` (capacity evictions are reported in
/// [`axess_cache::CacheStats`] as `capacity_evictions`).
pub struct EntityCache<P>
where
    P: RequestEntityProvider,
{
    /// Provider consulted on cache misses.
    inner: P,
    /// Memoised `entities_for` results.
    cache: ClockTtlCache<EntityCacheKey, Arc<Entities>>,
    /// Capacity the cache was built with, already clamped to at least 1.
    /// Kept so the `with_*` builders can rebuild without losing it.
    capacity: usize,
    /// TTL the cache was built with. Kept for the same reason as `capacity`.
    ttl: Duration,
    /// Clock the cache was built with. Kept for the same reason as `capacity`.
    clock: Arc<dyn Clock>,
}

impl<P> EntityCache<P>
where
    P: RequestEntityProvider,
{
    /// Wraps `inner` using the default capacity (`DEFAULT_CAPACITY` entries),
    /// the default TTL (`DEFAULT_TTL_SECS` seconds) and the system clock.
    pub fn new(inner: P) -> Self {
        Self::with_options(
            inner,
            DEFAULT_CAPACITY,
            Duration::from_secs(DEFAULT_TTL_SECS),
            Arc::new(SystemClock) as Arc<dyn Clock>,
        )
    }

    /// Wraps `inner` with an explicit capacity, TTL and clock.
    ///
    /// A `capacity` of `0` is treated as `1`.
    pub fn with_options(inner: P, capacity: usize, ttl: Duration, clock: Arc<dyn Clock>) -> Self {
        let capacity = capacity.max(1);
        let cap = NonZeroUsize::new(capacity).expect("capacity is at least 1");
        Self {
            inner,
            cache: ClockTtlCache::new(cap, ttl, Arc::clone(&clock)),
            capacity,
            ttl,
            clock,
        }
    }

    /// Rebuilds the cache with a different capacity.
    ///
    /// The TTL and clock configured so far are kept, so the `with_*` builders
    /// compose in any order. Entries cached so far are discarded.
    pub fn with_capacity(self, capacity: usize) -> Self {
        Self::with_options(self.inner, capacity, self.ttl, self.clock)
    }

    /// Rebuilds the cache with a different TTL.
    ///
    /// The capacity and clock configured so far are kept. Entries cached so
    /// far are discarded.
    pub fn with_ttl(self, ttl: Duration) -> Self {
        Self::with_options(self.inner, self.capacity, ttl, self.clock)
    }

    /// Rebuilds the cache with a different clock, e.g. a mock clock in tests.
    ///
    /// The capacity and TTL configured so far are kept. Entries cached so far
    /// are discarded.
    pub fn with_clock(self, clock: Arc<dyn Clock>) -> Self {
        Self::with_options(self.inner, self.capacity, self.ttl, clock)
    }

    /// Removes the single entry for `(principal, tenant, resource, action)`,
    /// if one is cached.
    ///
    /// `tenant` must match the tenant the session reported when the entry was
    /// inserted (`None` for sessions without a tenant), otherwise nothing is
    /// removed.
    pub fn invalidate(
        &self,
        principal: &EntityUid,
        tenant: Option<&str>,
        resource: &EntityUid,
        action: &EntityUid,
    ) {
        let key = EntityCacheKey {
            principal: principal.clone(),
            tenant: tenant.map(str::to_string),
            resource: resource.clone(),
            action: action.clone(),
        };
        self.cache.invalidate(&key);
    }

    /// Removes every cached entry.
    pub fn invalidate_all(&self) {
        self.cache.invalidate_all();
    }

    /// Removes every entry whose principal equals `principal`, across all
    /// tenants, resources and actions.
    ///
    /// Returns the count reported by the cache's `invalidate_by` (presumably
    /// the number of entries removed).
    pub fn invalidate_principal(&self, principal: &EntityUid) -> usize {
        self.cache.invalidate_by(|k| &k.principal == principal)
    }

    /// Removes every entry cached under `tenant`. Entries with no tenant are
    /// left untouched.
    ///
    /// Returns the count reported by the cache's `invalidate_by` (presumably
    /// the number of entries removed).
    pub fn invalidate_tenant(&self, tenant: &str) -> usize {
        self.cache
            .invalidate_by(|k| k.tenant.as_deref() == Some(tenant))
    }

    /// Borrows the wrapped provider.
    pub fn inner(&self) -> &P {
        &self.inner
    }

    /// Returns the cache's current counters (zeroed by [`Self::reset_stats`]).
    pub fn stats(&self) -> axess_cache::CacheStats {
        self.cache.stats()
    }

    /// Zeroes the cache's counters.
    pub fn reset_stats(&self) {
        self.cache.reset_stats();
    }

    /// Reports the counters accumulated since the last reset to `metrics`,
    /// with one call per recorded event, and zeroes them.
    ///
    /// The counters are reset immediately after the snapshot is taken, before
    /// the emission loops run. The loops scale with the counts, so events
    /// recorded while they run are kept for the next flush instead of being
    /// wiped. A small window between `stats()` and `reset_stats()` remains;
    /// closing it would need an atomic take-and-reset on the underlying cache.
    pub fn flush_metrics(&self, metrics: &dyn crate::metrics::AuthnMetrics) {
        let snapshot = self.stats();
        self.reset_stats();
        for _ in 0..snapshot.hits {
            metrics.authz_cache_hit();
        }
        for _ in 0..snapshot.misses {
            metrics.authz_cache_miss();
        }
        for _ in 0..snapshot.capacity_evictions {
            metrics.authz_cache_eviction();
        }
        for _ in 0..snapshot.invalidations {
            metrics.authz_cache_invalidation();
        }
    }
}
/// Exposes the synchronous inherent invalidation methods through the async
/// `CacheInvalidator` trait.
///
/// The error type is `Infallible`, so every method always returns `Ok(())`.
impl<P> super::invalidator::CacheInvalidator for EntityCache<P>
where
    P: RequestEntityProvider + 'static,
{
    type Error = std::convert::Infallible;

    /// Drops every entry for `principal`. The eviction count is discarded
    /// because the trait does not report it.
    async fn invalidate_principal(&self, principal: &EntityUid) -> Result<(), Self::Error> {
        let _evicted: usize = EntityCache::invalidate_principal(self, principal);
        Ok(())
    }

    /// Drops every entry cached under `tenant`. The eviction count is
    /// discarded because the trait does not report it.
    async fn invalidate_tenant(&self, tenant: &str) -> Result<(), Self::Error> {
        let _evicted: usize = EntityCache::invalidate_tenant(self, tenant);
        Ok(())
    }

    /// Drops every cached entry.
    async fn invalidate_all(&self) -> Result<(), Self::Error> {
        EntityCache::invalidate_all(self);
        Ok(())
    }
}
impl<P> RequestEntityProvider for EntityCache<P>
where
    P: RequestEntityProvider,
{
    /// Returns the entities for this request, consulting the cache first.
    ///
    /// The cache key is `(principal, session tenant, resource, action)`. No
    /// other session state is part of the key, so the wrapped provider's
    /// result is assumed to depend only on those inputs. On a miss, the inner
    /// provider is called through `get_or_try_insert_with`, and any error it
    /// returns is propagated. Whether errors are cached depends on that
    /// method (presumably they are not).
    fn entities_for<'a>(
        &'a self,
        session: &'a AuthSession,
        principal: &'a EntityUid,
        resource: &'a EntityUid,
        action: &'a EntityUid,
    ) -> std::pin::Pin<
        Box<dyn std::future::Future<Output = Result<Entities, AuthzError>> + Send + 'a>,
    > {
        Box::pin(async move {
            // A single `to_string` is enough; the original second call only
            // cloned the freshly built `String`.
            let tenant = session.tenant_id().await.map(|t| t.to_string());
            let key = EntityCacheKey {
                principal: principal.clone(),
                tenant,
                resource: resource.clone(),
                action: action.clone(),
            };
            let cached = self
                .cache
                .get_or_try_insert_with(key, || async {
                    let entities = self
                        .inner
                        .entities_for(session, principal, resource, action)
                        .await?;
                    Ok::<Arc<Entities>, AuthzError>(Arc::new(entities))
                })
                .await?;
            // The trait returns an owned `Entities`, so every call (hit or
            // miss) deep-clones the cached value out of its `Arc`.
            Ok((*cached).clone())
        })
    }
}
// Unit tests for `EntityCache`: cache hits vs. inner-provider calls, explicit
// invalidation, collapsing of concurrent cold misses into one inner call, and
// TTL expiry under an injected mock clock.
#[cfg(test)]
mod tests {
use super::*;
use axess_clock::testing::MockClock;
use std::collections::HashSet;
use std::str::FromStr;
use std::sync::atomic::{AtomicUsize, Ordering};
// Inner provider that counts its invocations, so tests can tell cache hits
// (count unchanged) from misses (count incremented).
struct CountingProvider {
calls: Arc<AtomicUsize>,
}
impl RequestEntityProvider for CountingProvider {
// When the returned future runs, increments `calls` and yields an
// `Entities` set holding just the principal and resource entities, each
// built with an empty attribute map and an empty parent set. `session`
// and `action` are ignored.
fn entities_for<'a>(
&'a self,
session: &'a AuthSession,
principal: &'a EntityUid,
resource: &'a EntityUid,
action: &'a EntityUid,
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<Entities, AuthzError>> + Send + 'a>,
> {
let _ = (session, action);
// Owned copies moved into the `async move` block below.
let calls = self.calls.clone();
let principal = principal.clone();
let resource = resource.clone();
Box::pin(async move {
calls.fetch_add(1, Ordering::SeqCst);
let p = cedar_policy::Entity::new(
principal,
std::collections::HashMap::new(),
HashSet::new(),
)
.unwrap();
let r = cedar_policy::Entity::new(
resource,
std::collections::HashMap::new(),
HashSet::new(),
)
.unwrap();
Ok(Entities::from_entities(vec![p, r], None).unwrap())
})
}
}
// Builds a fresh session directly from `SessionInner`, with a random id and
// default `SessionData`. `invalidate_evicts_cached_entry` calls
// `invalidate(&p, None, ..)` and expects a re-fetch, so it relies on this
// session reporting no tenant.
fn guest_session() -> AuthSession {
use crate::session::SessionData;
use crate::session::id::SessionId;
use crate::session::layer::{SessionHandle, SessionInner};
use tokio::sync::RwLock;
let inner = SessionInner {
id: SessionId::new(&axess_rng::SystemRng),
data: SessionData::default(),
modified: false,
regenerate: false,
pre_cycle_id: None,
pending_fingerprint: None,
max_custom_bytes: 64 * 1024,
};
AuthSession(SessionHandle(Arc::new(RwLock::new(inner))))
}
// Fixed Cedar UIDs shared by the tests. `doc(id)` yields `App::Doc::"<id>"`.
fn principal() -> EntityUid {
EntityUid::from_str("App::User::\"alice\"").unwrap()
}
fn action() -> EntityUid {
EntityUid::from_str("App::Action::\"View\"").unwrap()
}
fn doc(id: &str) -> EntityUid {
EntityUid::from_str(&format!("App::Doc::\"{id}\"")).unwrap()
}
// A repeated key is served from the cache; a different resource is a new key.
#[tokio::test]
async fn first_call_misses_then_caches() {
let calls = Arc::new(AtomicUsize::new(0));
let cached = EntityCache::new(CountingProvider {
calls: calls.clone(),
});
let s = guest_session();
let p = principal();
let a = action();
let r1 = doc("doc-1");
let _ = cached.entities_for(&s, &p, &r1, &a).await.unwrap();
assert_eq!(calls.load(Ordering::SeqCst), 1);
let _ = cached.entities_for(&s, &p, &r1, &a).await.unwrap();
assert_eq!(
calls.load(Ordering::SeqCst),
1,
"cache hit should not invoke inner"
);
let r2 = doc("doc-2");
let _ = cached.entities_for(&s, &p, &r2, &a).await.unwrap();
assert_eq!(calls.load(Ordering::SeqCst), 2);
}
// Explicit invalidation of one key forces the next lookup back to `inner`.
#[tokio::test]
async fn invalidate_evicts_cached_entry() {
let calls = Arc::new(AtomicUsize::new(0));
let cached = EntityCache::new(CountingProvider {
calls: calls.clone(),
});
let s = guest_session();
let p = principal();
let a = action();
let r = doc("doc-1");
let _ = cached.entities_for(&s, &p, &r, &a).await.unwrap();
let _ = cached.entities_for(&s, &p, &r, &a).await.unwrap();
assert_eq!(calls.load(Ordering::SeqCst), 1);
// `None` tenant: must match the tenant `guest_session` reports.
cached.invalidate(&p, None, &r, &a);
let _ = cached.entities_for(&s, &p, &r, &a).await.unwrap();
assert_eq!(
calls.load(Ordering::SeqCst),
2,
"after invalidate, next call should re-invoke inner"
);
}
// N spawned tasks race on one cold key; exactly one inner call may happen.
#[tokio::test]
async fn concurrent_cold_misses_share_one_inner_call() {
let calls = Arc::new(AtomicUsize::new(0));
let cached = Arc::new(EntityCache::new(CountingProvider {
calls: calls.clone(),
}));
let p = principal();
let a = action();
let r = doc("doc-1");
const N: usize = 8;
let mut handles = Vec::with_capacity(N);
for _ in 0..N {
let cached = cached.clone();
let p = p.clone();
let a = a.clone();
let r = r.clone();
// Each task gets its own session. The key takes only the tenant
// from the session, and guest sessions presumably share the same
// (absent) tenant, so all tasks target one key.
let s = guest_session();
handles.push(tokio::spawn(async move {
cached.entities_for(&s, &p, &r, &a).await.map(|_| ())
}));
}
for h in handles {
h.await.unwrap().unwrap();
}
assert_eq!(
calls.load(Ordering::SeqCst),
1,
"single-flight must collapse N concurrent cold misses into 1 inner call"
);
}
// 60-second TTL driven by `MockClock`: at +30s the entry is still a hit;
// at +61s in total it has expired and `inner` is called again.
#[tokio::test]
async fn entries_expire_under_injected_clock() {
let clock = Arc::new(MockClock::now());
let calls = Arc::new(AtomicUsize::new(0));
let cached = EntityCache::with_options(
CountingProvider {
calls: calls.clone(),
},
DEFAULT_CAPACITY,
Duration::from_secs(60),
clock.clone() as Arc<dyn Clock>,
);
let s = guest_session();
let p = principal();
let a = action();
let r = doc("doc-1");
let _ = cached.entities_for(&s, &p, &r, &a).await.unwrap();
assert_eq!(calls.load(Ordering::SeqCst), 1);
clock.advance_secs(30);
let _ = cached.entities_for(&s, &p, &r, &a).await.unwrap();
assert_eq!(calls.load(Ordering::SeqCst), 1, "still inside TTL");
clock.advance_secs(31);
let _ = cached.entities_for(&s, &p, &r, &a).await.unwrap();
assert_eq!(
calls.load(Ordering::SeqCst),
2,
"TTL expired under MockClock; must re-fetch from inner"
);
}
}