1use std::collections::{HashMap, HashSet};
25use std::sync::atomic::{AtomicU64, Ordering};
26use std::sync::RwLock;
27use std::time::{Duration, Instant};
28
29use super::Role;
30
/// Suggested time-to-live to pass to `AuthCache::new`: one minute.
pub const DEFAULT_TTL: Duration = Duration::from_millis(60_000);
33
34#[derive(Debug, Clone, Hash, Eq, PartialEq)]
36pub struct ScopeKey {
37 pub tenant: Option<String>,
38 pub role: Role,
39}
40
41impl ScopeKey {
42 pub fn new(tenant: Option<&str>, role: Role) -> Self {
43 Self {
44 tenant: tenant.map(|s| s.to_string()),
45 role,
46 }
47 }
48}
49
/// A cached collection set together with the moment it was stored.
#[derive(Debug, Clone)]
struct ScopeEntry {
    /// Collection names cached for the entry's key; cloned out on every cache hit.
    collections: HashSet<String>,
    /// Insertion time; the entry counts as expired once `inserted_at.elapsed() >= ttl`.
    inserted_at: Instant,
}
57
/// Point-in-time snapshot of the cache's hit / miss / invalidation counters.
#[derive(Debug, Default, Clone, Copy)]
pub struct AuthCacheStats {
    /// Lookups answered from the cache.
    pub hits: u64,
    /// Lookups recorded as misses.
    pub misses: u64,
    /// Number of invalidation calls.
    pub invalidations: u64,
}

impl AuthCacheStats {
    /// Fraction of recorded lookups that were hits, in `[0.0, 1.0]`.
    ///
    /// Returns `0.0` (not `NaN`) when no lookups have been recorded.
    pub fn hit_rate(&self) -> f64 {
        if self.hits == 0 && self.misses == 0 {
            return 0.0;
        }
        // Sum in f64: the fields are public and may hold any u64, so `hits + misses`
        // in integer arithmetic could overflow (and panic in debug builds).
        let hits = self.hits as f64;
        hits / (hits + self.misses as f64)
    }
}
78
79#[derive(Debug, Default)]
83pub struct AuthCache {
84 entries: RwLock<HashMap<ScopeKey, ScopeEntry>>,
85 ttl: Duration,
86 hits: AtomicU64,
87 misses: AtomicU64,
88 invalidations: AtomicU64,
89}
90
91impl AuthCache {
92 pub fn new(ttl: Duration) -> Self {
93 Self {
94 entries: RwLock::new(HashMap::new()),
95 ttl,
96 hits: AtomicU64::new(0),
97 misses: AtomicU64::new(0),
98 invalidations: AtomicU64::new(0),
99 }
100 }
101
102 pub fn get(&self, key: &ScopeKey) -> Option<HashSet<String>> {
106 let guard = self.entries.read().ok()?;
107 let entry = guard.get(key)?;
108 if entry.inserted_at.elapsed() >= self.ttl {
109 self.misses.fetch_add(1, Ordering::Relaxed);
111 tracing::trace!(
112 target: "auth_cache",
113 tenant = ?key.tenant,
114 role = ?key.role,
115 "scope_cache miss (TTL expired)"
116 );
117 return None;
118 }
119 self.hits.fetch_add(1, Ordering::Relaxed);
120 tracing::trace!(
121 target: "auth_cache",
122 tenant = ?key.tenant,
123 role = ?key.role,
124 "scope_cache hit"
125 );
126 Some(entry.collections.clone())
127 }
128
129 pub fn insert(&self, key: ScopeKey, collections: HashSet<String>) {
133 self.misses.fetch_add(1, Ordering::Relaxed);
134 tracing::trace!(
135 target: "auth_cache",
136 tenant = ?key.tenant,
137 role = ?key.role,
138 n = collections.len(),
139 "scope_cache miss → insert"
140 );
141 if let Ok(mut guard) = self.entries.write() {
142 guard.insert(
143 key,
144 ScopeEntry {
145 collections,
146 inserted_at: Instant::now(),
147 },
148 );
149 }
150 }
151
152 pub fn invalidate_all(&self) {
156 if let Ok(mut guard) = self.entries.write() {
157 guard.clear();
158 }
159 self.invalidations.fetch_add(1, Ordering::Relaxed);
160 tracing::debug!(target: "auth_cache", "scope_cache invalidate_all");
161 }
162
163 pub fn invalidate_tenant(&self, tenant: Option<&str>) {
166 if let Ok(mut guard) = self.entries.write() {
167 guard.retain(|k, _| k.tenant.as_deref() != tenant);
168 }
169 self.invalidations.fetch_add(1, Ordering::Relaxed);
170 tracing::debug!(target: "auth_cache", tenant = ?tenant, "scope_cache invalidate_tenant");
171 }
172
173 pub fn stats(&self) -> AuthCacheStats {
174 AuthCacheStats {
175 hits: self.hits.load(Ordering::Relaxed),
176 misses: self.misses.load(Ordering::Relaxed),
177 invalidations: self.invalidations.load(Ordering::Relaxed),
178 }
179 }
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread::sleep;

    fn key(tenant: &str, role: Role) -> ScopeKey {
        ScopeKey::new(Some(tenant), role)
    }

    fn set(items: &[&str]) -> HashSet<String> {
        items.iter().map(|s| s.to_string()).collect()
    }

    #[test]
    fn miss_then_hit() {
        let cache = AuthCache::new(DEFAULT_TTL);
        let k = key("acme", Role::Read);
        assert!(cache.get(&k).is_none(), "first lookup is a miss");
        cache.insert(k.clone(), set(&["orders", "customers"]));
        let hit = cache.get(&k).expect("post-insert hit");
        assert_eq!(hit, set(&["orders", "customers"]));
        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        // A get-miss / insert / get-hit sequence records exactly one miss.
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hit_rate(), 0.5);
    }

    #[test]
    fn ttl_evicts() {
        let cache = AuthCache::new(Duration::from_millis(20));
        let k = key("acme", Role::Read);
        cache.insert(k.clone(), set(&["x"]));
        sleep(Duration::from_millis(40));
        assert!(
            cache.get(&k).is_none(),
            "TTL'd entry must be treated as a miss"
        );
    }

    #[test]
    fn invalidate_tenant_drops_only_matching() {
        let cache = AuthCache::new(DEFAULT_TTL);
        cache.insert(key("acme", Role::Read), set(&["a"]));
        cache.insert(key("globex", Role::Read), set(&["b"]));
        cache.invalidate_tenant(Some("acme"));
        assert!(cache.get(&key("acme", Role::Read)).is_none());
        assert!(cache.get(&key("globex", Role::Read)).is_some());
        assert_eq!(cache.stats().invalidations, 1);
    }

    #[test]
    fn invalidate_all_drops_every_entry() {
        let cache = AuthCache::new(DEFAULT_TTL);
        cache.insert(key("acme", Role::Read), set(&["a"]));
        cache.insert(key("globex", Role::Write), set(&["b"]));
        cache.invalidate_all();
        assert!(cache.get(&key("acme", Role::Read)).is_none());
        assert!(cache.get(&key("globex", Role::Write)).is_none());
        assert_eq!(cache.stats().invalidations, 1);
    }

    #[test]
    fn hit_rate_handles_zero_lookups() {
        let cache = AuthCache::new(DEFAULT_TTL);
        assert_eq!(cache.stats().hit_rate(), 0.0);
    }
}