use std::collections::{HashMap, HashSet};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::RwLock;
use std::time::{Duration, Instant};
use super::Role;
/// Default time-to-live for cached scope entries: one minute.
pub const DEFAULT_TTL: Duration = Duration::from_millis(60_000);
/// Cache key: an optional tenant, a principal, and a role.
///
/// All three components participate in equality and hashing, so entries are
/// never shared between different principals, tenants, or roles.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct ScopeKey {
    /// Tenant component; `None` identifies tenant-less lookups.
    pub tenant: Option<String>,
    /// Principal whose scopes are cached under this key.
    pub principal: String,
    /// Role component of the key.
    pub role: Role,
}

impl ScopeKey {
    /// Builds a key, copying the borrowed tenant and principal into owned strings.
    pub fn new(tenant: Option<&str>, principal: &str, role: Role) -> Self {
        let tenant = tenant.map(str::to_owned);
        let principal = principal.to_owned();
        Self {
            tenant,
            principal,
            role,
        }
    }
}
/// One cached value: the collection set plus the moment it was stored.
#[derive(Clone, Debug)]
struct ScopeEntry {
    /// Collection names cached for the key.
    collections: HashSet<String>,
    /// Insertion time; the entry's age is compared against the cache TTL.
    inserted_at: Instant,
}
/// Plain-value snapshot of an auth cache's counters.
#[derive(Clone, Copy, Debug, Default)]
pub struct AuthCacheStats {
    /// Lookups served from the cache.
    pub hits: u64,
    /// Misses recorded by the cache.
    pub misses: u64,
    /// Number of invalidation calls (`invalidate_all` / `invalidate_tenant`).
    pub invalidations: u64,
}

impl AuthCacheStats {
    /// Share of recorded hits among hits plus misses, in `[0.0, 1.0]`.
    ///
    /// Returns `0.0` when neither a hit nor a miss has been recorded, rather
    /// than dividing by zero.
    pub fn hit_rate(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            lookups => self.hits as f64 / lookups as f64,
        }
    }
}
#[derive(Debug, Default)]
pub struct AuthCache {
entries: RwLock<HashMap<ScopeKey, ScopeEntry>>,
ttl: Duration,
hits: AtomicU64,
misses: AtomicU64,
invalidations: AtomicU64,
}
impl AuthCache {
    /// Creates an empty cache whose entries are served until they reach `ttl` in age.
    ///
    /// With a zero `ttl` every entry is already expired, so every lookup misses.
    pub fn new(ttl: Duration) -> Self {
        Self {
            entries: RwLock::new(HashMap::new()),
            ttl,
            hits: AtomicU64::new(0),
            misses: AtomicU64::new(0),
            invalidations: AtomicU64::new(0),
        }
    }

    /// Returns a clone of the collections cached for `key` if a fresh entry exists.
    ///
    /// Each call is counted exactly once in [`AuthCache::stats`]: as a hit when
    /// a fresh entry is returned, otherwise as a miss. A miss covers three cases:
    /// - an absent key,
    /// - an entry whose age has reached the TTL,
    /// - a poisoned lock. The cache then fails closed, returning `None` so the
    ///   caller falls back to its source of truth.
    pub fn get(&self, key: &ScopeKey) -> Option<HashSet<String>> {
        // Copy the answer out under the read lock; the guard is dropped at the
        // end of this statement, before any logging happens.
        let (found, expired) = match self.entries.read() {
            Ok(guard) => match guard.get(key) {
                Some(entry) if entry.inserted_at.elapsed() < self.ttl => {
                    (Some(entry.collections.clone()), false)
                }
                Some(_) => (None, true),
                None => (None, false),
            },
            Err(_) => (None, false),
        };

        match found {
            Some(collections) => {
                self.hits.fetch_add(1, Ordering::Relaxed);
                tracing::trace!(
                    target: "auth_cache",
                    tenant = ?key.tenant,
                    principal = %key.principal,
                    role = ?key.role,
                    "scope_cache hit"
                );
                Some(collections)
            }
            None => {
                // Misses are counted here and only here. Counting them in
                // `insert` as well would double-count an expired lookup that is
                // followed by a refill, and would skip lookups that are never
                // refilled.
                self.misses.fetch_add(1, Ordering::Relaxed);
                if expired {
                    tracing::trace!(
                        target: "auth_cache",
                        tenant = ?key.tenant,
                        principal = %key.principal,
                        role = ?key.role,
                        "scope_cache miss (TTL expired)"
                    );
                } else {
                    tracing::trace!(
                        target: "auth_cache",
                        tenant = ?key.tenant,
                        principal = %key.principal,
                        role = ?key.role,
                        "scope_cache miss"
                    );
                }
                None
            }
        }
    }

    /// Stores `collections` for `key`, replacing any previous entry and
    /// restarting its TTL.
    ///
    /// Expired entries are swept out while the write lock is held. `get` only
    /// takes a read lock and cannot evict, so without this sweep, entries for
    /// keys that are never looked up again would stay resident until the next
    /// invalidation. The sweep is linear in the number of cached keys.
    ///
    /// Does not modify the hit/miss counters; `get` records misses. If the lock
    /// is poisoned the insert is dropped.
    ///
    /// NOTE(review): suppose a caller misses, reads scopes from its source,
    /// and an invalidation lands before this `insert`. The pre-invalidation
    /// scopes are then re-cached and served for up to one TTL. Confirm that
    /// window is acceptable.
    pub fn insert(&self, key: ScopeKey, collections: HashSet<String>) {
        tracing::trace!(
            target: "auth_cache",
            tenant = ?key.tenant,
            principal = %key.principal,
            role = ?key.role,
            n = collections.len(),
            "scope_cache miss → insert"
        );
        if let Ok(mut guard) = self.entries.write() {
            let now = Instant::now();
            let ttl = self.ttl;
            // Keep exactly the entries `get` would still serve (age < ttl).
            guard.retain(|_, entry| now.duration_since(entry.inserted_at) < ttl);
            guard.insert(
                key,
                ScopeEntry {
                    collections,
                    inserted_at: now,
                },
            );
        }
    }

    /// Drops every cached entry and bumps the invalidation counter.
    ///
    /// The counter is bumped even if the lock is poisoned. In that state `get`
    /// already returns `None` for every key.
    pub fn invalidate_all(&self) {
        if let Ok(mut guard) = self.entries.write() {
            guard.clear();
        }
        self.invalidations.fetch_add(1, Ordering::Relaxed);
        tracing::debug!(target: "auth_cache", "scope_cache invalidate_all");
    }

    /// Drops every entry whose tenant equals `tenant` and bumps the
    /// invalidation counter. Passing `None` drops the tenant-less entries.
    pub fn invalidate_tenant(&self, tenant: Option<&str>) {
        if let Ok(mut guard) = self.entries.write() {
            guard.retain(|k, _| k.tenant.as_deref() != tenant);
        }
        self.invalidations.fetch_add(1, Ordering::Relaxed);
        tracing::debug!(target: "auth_cache", tenant = ?tenant, "scope_cache invalidate_tenant");
    }

    /// Snapshot of the counters.
    ///
    /// Each counter is loaded independently with `Relaxed` ordering, so the
    /// three values are not guaranteed to be mutually consistent while other
    /// threads are using the cache.
    pub fn stats(&self) -> AuthCacheStats {
        AuthCacheStats {
            hits: self.hits.load(Ordering::Relaxed),
            misses: self.misses.load(Ordering::Relaxed),
            invalidations: self.invalidations.load(Ordering::Relaxed),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Builds a tenant-scoped key.
    fn scope(tenant: &str, principal: &str, role: Role) -> ScopeKey {
        ScopeKey::new(Some(tenant), principal, role)
    }

    /// Builds an owned set of collection names from string literals.
    fn names(items: &[&str]) -> HashSet<String> {
        items.iter().copied().map(String::from).collect()
    }

    #[test]
    fn miss_then_hit() {
        let cache = AuthCache::new(DEFAULT_TTL);
        let alice = scope("acme", "alice", Role::Read);

        assert!(cache.get(&alice).is_none(), "first lookup is a miss");

        cache.insert(alice.clone(), names(&["orders", "customers"]));
        let cached = cache.get(&alice).expect("post-insert hit");
        assert_eq!(cached, names(&["orders", "customers"]));

        let AuthCacheStats { hits, misses, .. } = cache.stats();
        assert_eq!(hits, 1);
        assert!(misses >= 1);
    }

    #[test]
    fn ttl_evicts() {
        let cache = AuthCache::new(Duration::from_millis(20));
        let alice = scope("acme", "alice", Role::Read);
        cache.insert(alice.clone(), names(&["x"]));

        // Wait for twice the TTL so the entry is unambiguously stale.
        thread::sleep(Duration::from_millis(40));

        let after_expiry = cache.get(&alice);
        assert!(after_expiry.is_none(), "TTL'd entry must be treated as a miss");
    }

    #[test]
    fn invalidate_tenant_drops_only_matching() {
        let cache = AuthCache::new(DEFAULT_TTL);
        let acme = scope("acme", "alice", Role::Read);
        let globex = scope("globex", "alice", Role::Read);
        cache.insert(acme.clone(), names(&["a"]));
        cache.insert(globex.clone(), names(&["b"]));

        cache.invalidate_tenant(Some("acme"));

        assert!(cache.get(&acme).is_none());
        assert!(cache.get(&globex).is_some());
        assert_eq!(cache.stats().invalidations, 1);
    }

    #[test]
    fn same_tenant_and_role_do_not_share_between_principals() {
        let cache = AuthCache::new(DEFAULT_TTL);
        cache.insert(scope("acme", "alice", Role::Read), names(&["orders"]));

        let bob = scope("acme", "bob", Role::Read);
        assert!(cache.get(&bob).is_none(), "direct grants are principal-specific");
    }

    #[test]
    fn invalidate_all_drops_every_entry() {
        let cache = AuthCache::new(DEFAULT_TTL);
        let reader = scope("acme", "alice", Role::Read);
        let writer = scope("globex", "alice", Role::Write);
        cache.insert(reader.clone(), names(&["a"]));
        cache.insert(writer.clone(), names(&["b"]));

        cache.invalidate_all();

        assert!(cache.get(&reader).is_none());
        assert!(cache.get(&writer).is_none());
    }

    #[test]
    fn hit_rate_handles_zero_lookups() {
        let fresh = AuthCache::new(DEFAULT_TTL).stats();
        assert_eq!(fresh.hit_rate(), 0.0);
    }
}