1use std::collections::{HashMap, HashSet};
25use std::sync::atomic::{AtomicU64, Ordering};
26use std::sync::RwLock;
27use std::time::{Duration, Instant};
28
29use super::Role;
30
/// Default entry time-to-live (60 seconds), suitable for passing to
/// [`AuthCache::new`].
pub const DEFAULT_TTL: Duration = Duration::from_secs(60);
33
34#[derive(Debug, Clone, Hash, Eq, PartialEq)]
36pub struct ScopeKey {
37 pub tenant: Option<String>,
38 pub principal: String,
39 pub role: Role,
40}
41
42impl ScopeKey {
43 pub fn new(tenant: Option<&str>, principal: &str, role: Role) -> Self {
44 Self {
45 tenant: tenant.map(|s| s.to_string()),
46 principal: principal.to_string(),
47 role,
48 }
49 }
50}
51
/// A cached collection set together with the instant it was stored.
#[derive(Debug, Clone)]
struct ScopeEntry {
    /// Collection names cached for the owning key.
    collections: HashSet<String>,
    /// When the entry was stored; its age is compared against the cache TTL
    /// on lookup.
    inserted_at: Instant,
}
59
/// Point-in-time copy of the cache counters.
#[derive(Debug, Default, Clone, Copy)]
pub struct AuthCacheStats {
    /// Number of recorded cache hits.
    pub hits: u64,
    /// Number of recorded cache misses.
    pub misses: u64,
    /// Number of recorded invalidation calls.
    pub invalidations: u64,
}

impl AuthCacheStats {
    /// Fraction of recorded lookups that were hits, in the range `0.0..=1.0`.
    ///
    /// Returns `0.0` when neither a hit nor a miss has been recorded, so the
    /// result is never NaN.
    pub fn hit_rate(&self) -> f64 {
        // The fields are public, so arbitrary values can be supplied. Widen to
        // u128 before adding so the sum cannot overflow (which would panic
        // with overflow checks enabled and wrap to a wrong ratio otherwise).
        let total = u128::from(self.hits) + u128::from(self.misses);
        if total == 0 {
            0.0
        } else {
            self.hits as f64 / total as f64
        }
    }
}
80
81#[derive(Debug, Default)]
85pub struct AuthCache {
86 entries: RwLock<HashMap<ScopeKey, ScopeEntry>>,
87 ttl: Duration,
88 hits: AtomicU64,
89 misses: AtomicU64,
90 invalidations: AtomicU64,
91}
92
93impl AuthCache {
94 pub fn new(ttl: Duration) -> Self {
95 Self {
96 entries: RwLock::new(HashMap::new()),
97 ttl,
98 hits: AtomicU64::new(0),
99 misses: AtomicU64::new(0),
100 invalidations: AtomicU64::new(0),
101 }
102 }
103
104 pub fn get(&self, key: &ScopeKey) -> Option<HashSet<String>> {
108 let guard = self.entries.read().ok()?;
109 let entry = guard.get(key)?;
110 if entry.inserted_at.elapsed() >= self.ttl {
111 self.misses.fetch_add(1, Ordering::Relaxed);
113 tracing::trace!(
114 target: "auth_cache",
115 tenant = ?key.tenant,
116 principal = %key.principal,
117 role = ?key.role,
118 "scope_cache miss (TTL expired)"
119 );
120 return None;
121 }
122 self.hits.fetch_add(1, Ordering::Relaxed);
123 tracing::trace!(
124 target: "auth_cache",
125 tenant = ?key.tenant,
126 principal = %key.principal,
127 role = ?key.role,
128 "scope_cache hit"
129 );
130 Some(entry.collections.clone())
131 }
132
133 pub fn insert(&self, key: ScopeKey, collections: HashSet<String>) {
137 self.misses.fetch_add(1, Ordering::Relaxed);
138 tracing::trace!(
139 target: "auth_cache",
140 tenant = ?key.tenant,
141 principal = %key.principal,
142 role = ?key.role,
143 n = collections.len(),
144 "scope_cache miss → insert"
145 );
146 if let Ok(mut guard) = self.entries.write() {
147 guard.insert(
148 key,
149 ScopeEntry {
150 collections,
151 inserted_at: Instant::now(),
152 },
153 );
154 }
155 }
156
157 pub fn invalidate_all(&self) {
161 if let Ok(mut guard) = self.entries.write() {
162 guard.clear();
163 }
164 self.invalidations.fetch_add(1, Ordering::Relaxed);
165 tracing::debug!(target: "auth_cache", "scope_cache invalidate_all");
166 }
167
168 pub fn invalidate_tenant(&self, tenant: Option<&str>) {
171 if let Ok(mut guard) = self.entries.write() {
172 guard.retain(|k, _| k.tenant.as_deref() != tenant);
173 }
174 self.invalidations.fetch_add(1, Ordering::Relaxed);
175 tracing::debug!(target: "auth_cache", tenant = ?tenant, "scope_cache invalidate_tenant");
176 }
177
178 pub fn stats(&self) -> AuthCacheStats {
179 AuthCacheStats {
180 hits: self.hits.load(Ordering::Relaxed),
181 misses: self.misses.load(Ordering::Relaxed),
182 invalidations: self.invalidations.load(Ordering::Relaxed),
183 }
184 }
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread::sleep;

    /// Shorthand for a key that always carries a tenant.
    fn key(tenant: &str, principal: &str, role: Role) -> ScopeKey {
        ScopeKey::new(Some(tenant), principal, role)
    }

    /// Builds an owned collection set from string literals.
    fn set(items: &[&str]) -> HashSet<String> {
        items.iter().copied().map(String::from).collect()
    }

    #[test]
    fn miss_then_hit() {
        let cache = AuthCache::new(DEFAULT_TTL);
        let alice = key("acme", "alice", Role::Read);
        let expected = set(&["orders", "customers"]);

        assert!(cache.get(&alice).is_none(), "first lookup is a miss");
        cache.insert(alice.clone(), expected.clone());

        let cached = cache.get(&alice).expect("post-insert hit");
        assert_eq!(cached, expected);

        let AuthCacheStats { hits, misses, .. } = cache.stats();
        assert_eq!(hits, 1);
        // The failed lookup plus the insert must have recorded at least one
        // miss; the exact count is deliberately not pinned here.
        assert!(misses >= 1);
    }

    #[test]
    fn ttl_evicts() {
        let ttl = Duration::from_millis(20);
        let cache = AuthCache::new(ttl);
        let alice = key("acme", "alice", Role::Read);
        cache.insert(alice.clone(), set(&["x"]));

        // Wait for twice the TTL so the entry is certainly past its lifetime.
        sleep(ttl * 2);

        let looked_up = cache.get(&alice);
        assert!(
            looked_up.is_none(),
            "TTL'd entry must be treated as a miss"
        );
    }

    #[test]
    fn invalidate_tenant_drops_only_matching() {
        let cache = AuthCache::new(DEFAULT_TTL);
        let acme = key("acme", "alice", Role::Read);
        let globex = key("globex", "alice", Role::Read);
        cache.insert(acme.clone(), set(&["a"]));
        cache.insert(globex.clone(), set(&["b"]));

        cache.invalidate_tenant(Some("acme"));

        assert!(cache.get(&acme).is_none());
        assert!(cache.get(&globex).is_some());
        assert_eq!(cache.stats().invalidations, 1);
    }

    #[test]
    fn same_tenant_and_role_do_not_share_between_principals() {
        let cache = AuthCache::new(DEFAULT_TTL);
        cache.insert(key("acme", "alice", Role::Read), set(&["orders"]));

        let bob = key("acme", "bob", Role::Read);
        assert!(
            cache.get(&bob).is_none(),
            "direct grants are principal-specific"
        );
    }

    #[test]
    fn invalidate_all_drops_every_entry() {
        let cache = AuthCache::new(DEFAULT_TTL);
        let acme = key("acme", "alice", Role::Read);
        let globex = key("globex", "alice", Role::Write);
        cache.insert(acme.clone(), set(&["a"]));
        cache.insert(globex.clone(), set(&["b"]));

        cache.invalidate_all();

        assert!(cache.get(&acme).is_none());
        assert!(cache.get(&globex).is_none());
    }

    #[test]
    fn hit_rate_handles_zero_lookups() {
        // A fresh cache has no hits and no misses; the rate must be 0.0
        // rather than NaN.
        let fresh = AuthCache::new(DEFAULT_TTL).stats();
        assert_eq!(fresh.hit_rate(), 0.0);
    }
}