1use std::future::Future;
60use std::num::NonZeroUsize;
61use std::sync::Arc;
62use std::time::Duration;
63
64use chrono::{DateTime, Utc};
65
66use axess_cache::ClockTtlCache;
67use axess_clock::{Clock, SystemClock};
68
69use crate::authn::ids::{DeviceId, TenantId, UserId};
70use crate::device::store::{DeviceStore, SweepCounts};
71use crate::device::types::{Device, DeviceTrustLevel, FingerprintHash};
72
/// Number of device rows the cache keeps when the caller does not choose a
/// capacity (see `CachedDeviceStore::new`).
const DEFAULT_CAPACITY: usize = 10_000;

/// Lifetime of a cached device row, in whole seconds, when the caller does
/// not choose a TTL (one minute).
const DEFAULT_TTL_SECS: u64 = 60;
84
85type CacheKey = (TenantId, DeviceId);
89
90pub struct CachedDeviceStore<S>
96where
97 S: DeviceStore,
98{
99 inner: S,
100 cache: Arc<ClockTtlCache<CacheKey, Device>>,
101}
102
103impl<S> Clone for CachedDeviceStore<S>
104where
105 S: DeviceStore,
106{
107 fn clone(&self) -> Self {
108 Self {
109 inner: self.inner.clone(),
110 cache: self.cache.clone(),
111 }
112 }
113}
114
115impl<S> CachedDeviceStore<S>
116where
117 S: DeviceStore,
118{
119 pub fn new(inner: S) -> Self {
122 Self::with_options(
123 inner,
124 DEFAULT_CAPACITY,
125 Duration::from_secs(DEFAULT_TTL_SECS),
126 Arc::new(SystemClock),
127 )
128 }
129
130 pub fn with_options(inner: S, capacity: usize, ttl: Duration, clock: Arc<dyn Clock>) -> Self {
132 let capacity = NonZeroUsize::new(capacity.max(1)).expect("capacity >= 1");
133 let cache = Arc::new(ClockTtlCache::new(capacity, ttl, clock));
134 Self { inner, cache }
135 }
136
137 pub fn with_capacity(mut self, capacity: usize) -> Self {
139 let cap = NonZeroUsize::new(capacity.max(1)).expect("capacity >= 1");
140 let ttl = Duration::from_secs(DEFAULT_TTL_SECS);
144 self.cache = Arc::new(ClockTtlCache::new(
145 cap,
146 ttl,
147 Arc::new(SystemClock) as Arc<dyn Clock>,
148 ));
149 self
150 }
151
152 pub fn with_ttl(self, ttl: Duration) -> Self {
154 let cap = self.cache.capacity();
155 let cache = Arc::new(ClockTtlCache::new(
156 cap,
157 ttl,
158 Arc::new(SystemClock) as Arc<dyn Clock>,
159 ));
160 Self {
161 inner: self.inner,
162 cache,
163 }
164 }
165
166 pub fn with_clock(self, clock: Arc<dyn Clock>) -> Self {
168 let cap = self.cache.capacity();
169 let ttl = Duration::from_secs(DEFAULT_TTL_SECS);
171 let cache = Arc::new(ClockTtlCache::new(cap, ttl, clock));
172 Self {
173 inner: self.inner,
174 cache,
175 }
176 }
177
178 pub fn stats(&self) -> axess_cache::CacheStats {
181 self.cache.stats()
182 }
183
184 pub fn invalidate_all(&self) {
187 self.cache.invalidate_all();
188 }
189
190 pub fn invalidate_tenant(&self, tenant_id: &TenantId) {
193 let target = *tenant_id;
194 self.cache.invalidate_by(|k| k.0 == target);
195 }
196}
197
198impl<S> DeviceStore for CachedDeviceStore<S>
199where
200 S: DeviceStore,
201{
202 type Error = S::Error;
203
204 fn load(
205 &self,
206 tenant_id: &TenantId,
207 id: &DeviceId,
208 ) -> impl Future<Output = Result<Option<Device>, Self::Error>> + Send {
209 let key = (*tenant_id, *id);
210 let cache = self.cache.clone();
211 let inner = self.inner.clone();
212 let tenant = *tenant_id;
213 let device = *id;
214 async move {
215 if let Some(d) = cache.get(&key) {
217 return Ok(Some(d));
218 }
219 let result = inner.load(&tenant, &device).await?;
225 if let Some(ref d) = result {
226 cache.insert(key, d.clone());
227 }
228 Ok(result)
229 }
230 }
231
232 fn find_by_fingerprint(
233 &self,
234 tenant_id: &TenantId,
235 hash: &FingerprintHash,
236 ) -> impl Future<Output = Result<Option<Device>, Self::Error>> + Send {
237 let cache = self.cache.clone();
238 let inner = self.inner.clone();
239 let tenant = *tenant_id;
240 let hash = *hash;
241 async move {
242 let result = inner.find_by_fingerprint(&tenant, &hash).await?;
243 if let Some(ref d) = result {
246 cache.insert((tenant, d.id), d.clone());
247 }
248 Ok(result)
249 }
250 }
251
252 fn find_for_user(
253 &self,
254 tenant_id: &TenantId,
255 user_id: &UserId,
256 limit: usize,
257 ) -> impl Future<Output = Result<Vec<Device>, Self::Error>> + Send {
258 self.inner.find_for_user(tenant_id, user_id, limit)
263 }
264
265 fn find_by_refresh_family(
266 &self,
267 tenant_id: &TenantId,
268 family_id: &str,
269 ) -> impl Future<Output = Result<Vec<Device>, Self::Error>> + Send {
270 self.inner.find_by_refresh_family(tenant_id, family_id)
273 }
274
275 fn save(&self, device: &Device) -> impl Future<Output = Result<(), Self::Error>> + Send {
276 let key = (device.tenant_id, device.id);
277 let cache = self.cache.clone();
278 let inner = self.inner.clone();
279 let device = device.clone();
280 async move {
281 cache.invalidate(&key);
287 inner.save(&device).await?;
288 cache.insert(key, device);
292 Ok(())
293 }
294 }
295
296 fn record_sighting(
297 &self,
298 tenant_id: &TenantId,
299 id: &DeviceId,
300 now: DateTime<Utc>,
301 ) -> impl Future<Output = Result<(), Self::Error>> + Send {
302 self.inner.record_sighting(tenant_id, id, now)
314 }
315
316 fn set_trust_level(
317 &self,
318 tenant_id: &TenantId,
319 id: &DeviceId,
320 level: DeviceTrustLevel,
321 now: DateTime<Utc>,
322 ) -> impl Future<Output = Result<(), Self::Error>> + Send {
323 let key = (*tenant_id, *id);
324 let cache = self.cache.clone();
325 let inner = self.inner.clone();
326 let tenant = *tenant_id;
327 let device = *id;
328 async move {
329 cache.invalidate(&key);
330 inner.set_trust_level(&tenant, &device, level, now).await
331 }
332 }
333
334 fn delete(
335 &self,
336 tenant_id: &TenantId,
337 id: &DeviceId,
338 ) -> impl Future<Output = Result<(), Self::Error>> + Send {
339 let key = (*tenant_id, *id);
340 let cache = self.cache.clone();
341 let inner = self.inner.clone();
342 let tenant = *tenant_id;
343 let device = *id;
344 async move {
345 cache.invalidate(&key);
346 inner.delete(&tenant, &device).await
347 }
348 }
349
350 fn sweep(
351 &self,
352 tenant_id: &TenantId,
353 now: DateTime<Utc>,
354 ) -> impl Future<Output = Result<SweepCounts, Self::Error>> + Send {
355 self.inner.sweep(tenant_id, now)
360 }
361}
362
#[cfg(test)]
mod tests {
    use super::*;
    use crate::authn::ids::testing as fixtures;
    use crate::device::store::MemoryDeviceStore;
    use crate::device::types::{Device, FingerprintHash};
    use axess_clock::testing::MockClock;
    use chrono::TimeZone;

    /// `2026-01-01T00:{minute}:00Z`; every timestamp in these tests falls in
    /// that hour.
    fn at_minute(minute: u32) -> DateTime<Utc> {
        Utc.with_ymd_and_hms(2026, 1, 1, 0, minute, 0).unwrap()
    }

    /// Mock clock set to the first instant of the test hour.
    fn fixed_clock() -> Arc<MockClock> {
        Arc::new(MockClock::at(at_minute(0)))
    }

    /// Wraps `inner` in a default-configured cache driven by `fixed_clock`.
    fn cached_over<S: DeviceStore>(inner: S) -> CachedDeviceStore<S> {
        CachedDeviceStore::new(inner).with_clock(fixed_clock() as _)
    }

    /// The single tenant / user / device triple most tests share.
    fn ids() -> (TenantId, UserId, DeviceId) {
        (
            fixtures::tenant("tenant-1"),
            fixtures::user("user-1"),
            fixtures::device("device-1"),
        )
    }

    /// A `Seen` device with a zeroed fingerprint and no bindings.
    fn build_device(t: &TenantId, u: &UserId, d: &DeviceId) -> Device {
        let epoch = at_minute(0);
        Device {
            id: *d,
            tenant_id: *t,
            user_id: Some(*u),
            trust_level: DeviceTrustLevel::Seen,
            fingerprint_hash: FingerprintHash::from_bytes([0u8; 32]),
            first_seen_at: epoch,
            last_seen_at: epoch,
            revoked_at: None,
            bindings: Vec::new(),
        }
    }

    #[tokio::test]
    async fn load_caches_after_first_hit() {
        let (tenant, user, device) = ids();
        let store = MemoryDeviceStore::new();
        store.save(&build_device(&tenant, &user, &device)).await.unwrap();
        let cached = cached_over(store);

        let _ = cached.load(&tenant, &device).await.unwrap().expect("first load");
        let cold = cached.stats();
        assert_eq!(cold.misses, 1);
        assert_eq!(cold.hits, 0);

        let _ = cached.load(&tenant, &device).await.unwrap().expect("second load");
        assert_eq!(cached.stats().hits, 1, "second load must hit cache");
    }

    #[tokio::test]
    async fn load_does_not_cache_none_results() {
        let (tenant, _user, device) = ids();
        let cached = cached_over(MemoryDeviceStore::new());

        assert!(cached.load(&tenant, &device).await.unwrap().is_none());
        assert_eq!(
            cached.stats().inserts,
            0,
            "None results must not be cached"
        );
    }

    #[tokio::test]
    async fn save_invalidates_and_repopulates() {
        let (tenant, user, device) = ids();
        let store = MemoryDeviceStore::new();
        store.save(&build_device(&tenant, &user, &device)).await.unwrap();
        let cached = cached_over(store);
        let _ = cached.load(&tenant, &device).await.unwrap();

        let promoted = Device {
            trust_level: DeviceTrustLevel::Trusted,
            ..build_device(&tenant, &user, &device)
        };
        cached.save(&promoted).await.unwrap();

        let reloaded = cached.load(&tenant, &device).await.unwrap().unwrap();
        assert_eq!(
            reloaded.trust_level,
            DeviceTrustLevel::Trusted,
            "save must invalidate the cached row so load sees the update"
        );
    }

    #[tokio::test]
    async fn set_trust_level_invalidates_cached_row() {
        let (tenant, user, device) = ids();
        let store = MemoryDeviceStore::new();
        store.save(&build_device(&tenant, &user, &device)).await.unwrap();
        let cached = cached_over(store);
        let _ = cached.load(&tenant, &device).await.unwrap();

        cached
            .set_trust_level(&tenant, &device, DeviceTrustLevel::Revoked, at_minute(5))
            .await
            .unwrap();

        let reloaded = cached.load(&tenant, &device).await.unwrap().unwrap();
        assert_eq!(
            reloaded.trust_level,
            DeviceTrustLevel::Revoked,
            "set_trust_level must invalidate the cached row"
        );
    }

    #[tokio::test]
    async fn delete_invalidates_and_subsequent_load_is_none() {
        let (tenant, user, device) = ids();
        let store = MemoryDeviceStore::new();
        store.save(&build_device(&tenant, &user, &device)).await.unwrap();
        let cached = cached_over(store);
        let _ = cached.load(&tenant, &device).await.unwrap();

        cached.delete(&tenant, &device).await.unwrap();

        let after_delete = cached.load(&tenant, &device).await.unwrap();
        assert!(
            after_delete.is_none(),
            "delete must invalidate so the next load reflects absence"
        );
    }

    #[tokio::test]
    async fn record_sighting_does_not_invalidate() {
        let (tenant, user, device) = ids();
        let store = MemoryDeviceStore::new();
        store.save(&build_device(&tenant, &user, &device)).await.unwrap();
        let cached = cached_over(store);
        let _ = cached.load(&tenant, &device).await.unwrap();
        let hits_before = cached.stats().hits;

        cached
            .record_sighting(&tenant, &device, at_minute(5))
            .await
            .unwrap();
        let _ = cached.load(&tenant, &device).await.unwrap();

        assert_eq!(
            cached.stats().hits,
            hits_before + 1,
            "record_sighting must not invalidate the cache"
        );
    }

    #[tokio::test]
    async fn find_by_fingerprint_primes_by_id_cache() {
        let (tenant, user, device) = ids();
        let row = build_device(&tenant, &user, &device);
        let fingerprint = row.fingerprint_hash;
        let store = MemoryDeviceStore::new();
        store.save(&row).await.unwrap();
        let cached = cached_over(store);

        let _ = cached
            .find_by_fingerprint(&tenant, &fingerprint)
            .await
            .unwrap()
            .expect("device found by fingerprint");
        let _ = cached.load(&tenant, &device).await.unwrap();

        assert_eq!(
            cached.stats().hits,
            1,
            "find_by_fingerprint must prime the by-id cache so load is warm"
        );
    }

    #[tokio::test]
    async fn refresh_cascade_revocation_propagates_through_cache() {
        use crate::device::cascade::cascade_revoke_by_refresh_family;
        use crate::device::types::DeviceBinding;

        let tenant = fixtures::tenant("tenant-1");
        let user = fixtures::user("user-1");
        let dev_a = fixtures::device("dev-a");
        let dev_b = fixtures::device("dev-b");
        let now = at_minute(0);

        let store = MemoryDeviceStore::new();
        for (id, fp_byte) in [(dev_a, 0xa1u8), (dev_b, 0xb2u8)] {
            let bound = Device {
                id,
                tenant_id: tenant,
                user_id: Some(user),
                trust_level: DeviceTrustLevel::Trusted,
                fingerprint_hash: FingerprintHash::from_bytes([fp_byte; 32]),
                first_seen_at: now,
                last_seen_at: now,
                revoked_at: None,
                bindings: vec![DeviceBinding::Refresh {
                    family_id: "fam-stolen".to_string(),
                    issued_at: now,
                    last_used_at: now,
                }],
            };
            store.save(&bound).await.unwrap();
        }
        let cached = cached_over(store);

        // Warm both entries so a stale cache would answer `Trusted`.
        for id in [&dev_a, &dev_b] {
            let warm = cached.load(&tenant, id).await.unwrap().unwrap();
            assert_eq!(warm.trust_level, DeviceTrustLevel::Trusted);
        }

        let count = cascade_revoke_by_refresh_family(&cached, &tenant, "fam-stolen", at_minute(5))
            .await
            .unwrap();
        assert_eq!(count, 2, "both refresh-bound devices must be revoked");

        for id in [&dev_a, &dev_b] {
            let after = cached.load(&tenant, id).await.unwrap().unwrap();
            assert_eq!(
                after.trust_level,
                DeviceTrustLevel::Revoked,
                "cache must not serve stale Trusted after cascade revocation"
            );
        }
    }

    #[tokio::test]
    async fn invalidate_all_drops_every_entry() {
        let t1 = fixtures::tenant("t1");
        let t2 = fixtures::tenant("t2");
        let u = fixtures::user("u1");
        let d1 = fixtures::device("d1");
        let d2 = fixtures::device("d2");
        let pairs = [(&t1, &d1), (&t2, &d2)];

        let store = MemoryDeviceStore::new();
        for (t, d) in pairs {
            store.save(&build_device(t, &u, d)).await.unwrap();
        }
        let cached = cached_over(store);

        for (t, d) in pairs {
            let _ = cached.load(t, d).await.unwrap();
        }
        let warm = cached.stats();
        assert_eq!(warm.misses, 2, "two cold loads landed two misses");

        cached.invalidate_all();

        for (t, d) in pairs {
            let _ = cached.load(t, d).await.unwrap();
        }
        let after = cached.stats();
        assert_eq!(
            after.misses,
            warm.misses + 2,
            "invalidate_all must drop every entry; a no-op mutant would \
             let the second pair of loads hit cache"
        );
    }

    #[tokio::test]
    async fn invalidate_tenant_drops_only_matching_entries() {
        let t1 = fixtures::tenant("t1");
        let t2 = fixtures::tenant("t2");
        let u = fixtures::user("u1");
        let d1 = fixtures::device("d1");
        let d2 = fixtures::device("d2");
        let pairs = [(&t1, &d1), (&t2, &d2)];

        let store = MemoryDeviceStore::new();
        for (t, d) in pairs {
            store.save(&build_device(t, &u, d)).await.unwrap();
        }
        let cached = cached_over(store);
        for (t, d) in pairs {
            let _ = cached.load(t, d).await.unwrap();
        }

        cached.invalidate_tenant(&t1);

        let before = cached.stats();
        let _ = cached.load(&t1, &d1).await.unwrap();
        assert_eq!(
            cached.stats().misses,
            before.misses + 1,
            "t1 entry should have been invalidated"
        );

        let before = cached.stats();
        let _ = cached.load(&t2, &d2).await.unwrap();
        assert_eq!(
            cached.stats().hits,
            before.hits + 1,
            "t2 entry must survive invalidate_tenant(t1)"
        );
    }
}