Skip to main content

ans_verify/
cache.rs

1//! Badge caching with TTL and background refresh support.
2
3use std::collections::HashMap;
4use std::fmt;
5use std::time::{Duration, Instant};
6
7use moka::future::Cache;
8use tokio::sync::RwLock;
9
10use ans_types::{Badge, Fqdn, Version};
11
/// Cache configuration.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct CacheConfig {
    /// Maximum number of entries in the cache.
    pub max_entries: u64,
    /// Default time-to-live for cached badges.
    pub default_ttl: Duration,
    /// Time before TTL expiry at which a refresh is recommended
    /// (compared against the entry's remaining TTL by `CachedBadge::should_refresh`).
    pub refresh_threshold: Duration,
}

impl Default for CacheConfig {
    /// 1000 entries, a 5-minute TTL, and refresh recommended during the final minute.
    fn default() -> Self {
        Self {
            max_entries: 1000,
            default_ttl: Duration::from_secs(300), // 5 minutes
            refresh_threshold: Duration::from_secs(60), // 1 minute before expiry
        }
    }
}

impl CacheConfig {
    /// Create a new configuration with a custom TTL.
    ///
    /// The refresh threshold is one fifth of `ttl`; every other field takes its
    /// [`Default`] value.
    pub fn with_ttl(ttl: Duration) -> Self {
        Self {
            default_ttl: ttl,
            // Divide the full `Duration` rather than its whole-second count:
            // truncating to seconds would give a zero threshold for any TTL
            // shorter than 5 s and lose precision for the rest.
            refresh_threshold: ttl / 5,
            ..Default::default()
        }
    }
}
44
45/// Cache key for badge lookups.
46#[derive(Debug, Clone, Hash, PartialEq, Eq)]
47#[non_exhaustive]
48pub enum CacheKey {
49    /// Key by FQDN and version.
50    FqdnVersion(String, Version),
51    /// Key by badge URL.
52    Url(String),
53}
54
55impl CacheKey {
56    /// Create a key for FQDN and version.
57    pub fn fqdn_version(fqdn: &Fqdn, version: &Version) -> Self {
58        Self::FqdnVersion(fqdn.as_str().to_lowercase(), version.clone())
59    }
60
61    /// Create a key for URL.
62    pub fn url(url: &str) -> Self {
63        Self::Url(url.to_string())
64    }
65}
66
67/// Cached badge entry with metadata.
68#[derive(Debug, Clone)]
69#[non_exhaustive]
70pub struct CachedBadge {
71    /// The cached badge.
72    pub badge: Badge,
73    /// When the badge was fetched.
74    pub fetched_at: Instant,
75    /// TTL for this entry.
76    pub ttl: Duration,
77}
78
79impl CachedBadge {
80    /// Create a new cached badge.
81    pub fn new(badge: Badge, ttl: Duration) -> Self {
82        Self {
83            badge,
84            fetched_at: Instant::now(),
85            ttl,
86        }
87    }
88
89    /// Check if the cached badge is still valid.
90    pub fn is_valid(&self) -> bool {
91        self.fetched_at.elapsed() < self.ttl
92    }
93
94    /// Check if the badge should be refreshed soon.
95    pub fn should_refresh(&self, threshold: Duration) -> bool {
96        let remaining = self.ttl.saturating_sub(self.fetched_at.elapsed());
97        remaining < threshold
98    }
99
100    /// Get the remaining TTL.
101    pub fn remaining_ttl(&self) -> Duration {
102        self.ttl.saturating_sub(self.fetched_at.elapsed())
103    }
104}
105
106/// Badge cache with TTL support.
107///
108/// All badges are cached by `FqdnVersion`. A secondary version index tracks which
109/// versions are cached per FQDN, enabling `get_all_for_fqdn()` to scan all cached
110/// badges for a given host during rolling deployments.
111pub struct BadgeCache {
112    cache: Cache<CacheKey, CachedBadge>,
113    config: CacheConfig,
114    /// Secondary index: FQDN (lowercased) → set of cached versions.
115    /// Enables `get_all_for_fqdn()` without requiring moka prefix scans.
116    version_index: RwLock<HashMap<String, Vec<Version>>>,
117}
118
119impl fmt::Debug for BadgeCache {
120    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
121        f.debug_struct("BadgeCache")
122            .field("config", &self.config)
123            .finish_non_exhaustive()
124    }
125}
126
127impl BadgeCache {
128    /// Create a new cache with the given configuration.
129    pub fn new(config: CacheConfig) -> Self {
130        let cache = Cache::builder()
131            .max_capacity(config.max_entries)
132            .time_to_live(config.default_ttl)
133            .build();
134
135        Self {
136            cache,
137            config,
138            version_index: RwLock::new(HashMap::new()),
139        }
140    }
141
142    /// Create a new cache with default configuration.
143    pub fn with_defaults() -> Self {
144        Self::new(CacheConfig::default())
145    }
146
147    /// Get a cached badge by key.
148    pub async fn get(&self, key: &CacheKey) -> Option<CachedBadge> {
149        self.cache.get(key).await.filter(CachedBadge::is_valid)
150    }
151
152    /// Insert a badge into the cache.
153    pub async fn insert(&self, key: CacheKey, badge: Badge) {
154        let cached = CachedBadge::new(badge, self.config.default_ttl);
155        self.cache.insert(key, cached).await;
156    }
157
158    /// Insert a badge with a custom soft TTL.
159    ///
160    /// The soft TTL controls when [`CachedBadge::is_valid`] returns false (i.e., when
161    /// reads treat the entry as stale). The underlying moka cache still uses the
162    /// global `default_ttl` for hard eviction. This means entries may be filtered out
163    /// by `is_valid()` before moka evicts them.
164    pub async fn insert_with_ttl(&self, key: CacheKey, badge: Badge, ttl: Duration) {
165        let cached = CachedBadge::new(badge, ttl);
166        self.cache.insert(key, cached).await;
167    }
168
169    /// Invalidate a cache entry.
170    pub async fn invalidate(&self, key: &CacheKey) {
171        self.cache.invalidate(key).await;
172    }
173
174    /// Clear all entries from the cache.
175    pub async fn clear(&self) {
176        self.cache.invalidate_all();
177        self.cache.run_pending_tasks().await;
178        self.version_index.write().await.clear();
179    }
180
181    /// Check if a cached badge should be refreshed.
182    pub fn should_refresh(&self, cached: &CachedBadge) -> bool {
183        cached.should_refresh(self.config.refresh_threshold)
184    }
185
186    /// Get the number of entries in the cache.
187    pub fn entry_count(&self) -> u64 {
188        self.cache.entry_count()
189    }
190
191    /// Get a badge by FQDN and version.
192    pub async fn get_by_fqdn_version(&self, fqdn: &Fqdn, version: &Version) -> Option<CachedBadge> {
193        self.get(&CacheKey::fqdn_version(fqdn, version)).await
194    }
195
196    /// Insert a badge keyed by FQDN and version, updating the version index.
197    ///
198    /// The version index enables `get_all_for_fqdn()` to discover all cached
199    /// badges for a given host.
200    pub async fn insert_for_fqdn_version(&self, fqdn: &Fqdn, version: &Version, badge: Badge) {
201        self.insert(CacheKey::fqdn_version(fqdn, version), badge)
202            .await;
203
204        let key = fqdn.as_str().to_lowercase();
205        let mut index = self.version_index.write().await;
206        let versions = index.entry(key).or_default();
207        if !versions.contains(version) {
208            versions.push(version.clone());
209        }
210    }
211
212    /// Get all cached badges for an FQDN across all known versions.
213    ///
214    /// Reads the version index to find which versions are cached, then fetches
215    /// each one. Filters out expired entries.
216    pub async fn get_all_for_fqdn(&self, fqdn: &Fqdn) -> Vec<CachedBadge> {
217        let key = fqdn.as_str().to_lowercase();
218        let index = self.version_index.read().await;
219        let versions = match index.get(&key) {
220            Some(v) => v.clone(),
221            None => return Vec::new(),
222        };
223        drop(index); // Release read lock before async cache lookups
224
225        let mut results = Vec::new();
226        for version in &versions {
227            if let Some(cached) = self.get(&CacheKey::fqdn_version(fqdn, version)).await {
228                results.push(cached);
229            }
230        }
231        results
232    }
233
234    /// Invalidate all cached badges for an FQDN (all versions).
235    pub async fn invalidate_fqdn(&self, fqdn: &Fqdn) {
236        let key = fqdn.as_str().to_lowercase();
237        let mut index = self.version_index.write().await;
238        if let Some(versions) = index.remove(&key) {
239            for version in &versions {
240                self.cache
241                    .invalidate(&CacheKey::fqdn_version(fqdn, version))
242                    .await;
243            }
244        }
245    }
246
247    /// Set the known versions for an FQDN from DNS records.
248    ///
249    /// Called after DNS lookup to pre-populate the version index with all
250    /// discovered versions, even before badges are fetched.
251    pub async fn set_version_index(&self, fqdn: &Fqdn, versions: Vec<Version>) {
252        let key = fqdn.as_str().to_lowercase();
253        let mut index = self.version_index.write().await;
254        index.insert(key, versions);
255    }
256}
257
258impl Default for BadgeCache {
259    fn default() -> Self {
260        Self::with_defaults()
261    }
262}
263
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use uuid::Uuid;

    /// Host used by most fixtures.
    const HOST: &str = "test.example.com";

    /// Parse the shared test host into an `Fqdn`.
    fn host_fqdn() -> Fqdn {
        Fqdn::new(HOST).unwrap()
    }

    /// Badge for `v1.0.0` of the test host with fixed fingerprints.
    fn default_badge() -> Badge {
        badge_json(HOST, "v1.0.0", "SHA256:bbb", "SHA256:aaa")
    }

    /// Badge for the given version of the test host, with a version-specific
    /// server-certificate fingerprint.
    fn versioned_badge(version: &str) -> Badge {
        let server_fp = format!("SHA256:{version}-server-fp");
        badge_json(HOST, version, &server_fp, "SHA256:aaa")
    }

    /// Build an ACTIVE V1 badge by deserializing a JSON document, so the tests
    /// only depend on `Badge`'s serde representation.
    fn badge_json(host: &str, version: &str, server_fp: &str, identity_fp: &str) -> Badge {
        let expires_at = (Utc::now() + chrono::Duration::days(365)).to_rfc3339();
        let issued_at = Utc::now().to_rfc3339();
        let timestamp = Utc::now().to_rfc3339();

        let doc = serde_json::json!({
            "status": "ACTIVE",
            "schemaVersion": "V1",
            "payload": {
                "logId": Uuid::new_v4().to_string(),
                "producer": {
                    "event": {
                        "ansId": Uuid::new_v4().to_string(),
                        "ansName": format!("ans://{version}.{host}"),
                        "eventType": "AGENT_REGISTERED",
                        "agent": { "host": host, "name": "Test Agent", "version": version },
                        "attestations": {
                            "domainValidation": "ACME-DNS-01",
                            "identityCert": { "fingerprint": identity_fp, "type": "X509-OV-CLIENT" },
                            "serverCert": { "fingerprint": server_fp, "type": "X509-DV-SERVER" }
                        },
                        "expiresAt": expires_at,
                        "issuedAt": issued_at,
                        "raId": "test-ra",
                        "timestamp": timestamp
                    },
                    "keyId": "test-key",
                    "signature": "test-sig"
                }
            }
        });
        serde_json::from_value(doc).expect("test badge JSON should be valid")
    }

    /// A default cache holding tracked badges for `v1.0.0` and `v1.0.1` of the
    /// test host. Returns the cache, the host and both versions.
    async fn cache_with_two_versions() -> (BadgeCache, Fqdn, Version, Version) {
        let cache = BadgeCache::with_defaults();
        let fqdn = host_fqdn();
        let v1 = Version::new(1, 0, 0);
        let v2 = Version::new(1, 0, 1);

        cache
            .insert_for_fqdn_version(&fqdn, &v1, versioned_badge("v1.0.0"))
            .await;
        cache
            .insert_for_fqdn_version(&fqdn, &v2, versioned_badge("v1.0.1"))
            .await;

        (cache, fqdn, v1, v2)
    }

    #[tokio::test]
    async fn test_cache_insert_and_get() {
        let cache = BadgeCache::with_defaults();
        let fqdn = host_fqdn();
        let version = Version::new(1, 0, 0);

        cache
            .insert_for_fqdn_version(&fqdn, &version, default_badge())
            .await;

        let hit = cache
            .get_by_fqdn_version(&fqdn, &version)
            .await
            .expect("badge should be cached");
        assert_eq!(hit.badge.agent_host(), HOST);
    }

    #[tokio::test]
    async fn test_cache_miss() {
        let cache = BadgeCache::with_defaults();
        let unknown = Fqdn::new("unknown.example.com").unwrap();

        let miss = cache
            .get_by_fqdn_version(&unknown, &Version::new(1, 0, 0))
            .await;
        assert!(miss.is_none());
    }

    #[tokio::test]
    async fn test_cache_invalidate() {
        let cache = BadgeCache::with_defaults();
        let fqdn = host_fqdn();
        let version = Version::new(1, 0, 0);

        cache
            .insert_for_fqdn_version(&fqdn, &version, default_badge())
            .await;
        assert!(cache.get_by_fqdn_version(&fqdn, &version).await.is_some());

        let key = CacheKey::fqdn_version(&fqdn, &version);
        cache.invalidate(&key).await;
        assert!(cache.get_by_fqdn_version(&fqdn, &version).await.is_none());
    }

    #[tokio::test]
    async fn test_cache_by_version() {
        let cache = BadgeCache::with_defaults();
        let fqdn = host_fqdn();
        let stored = Version::new(1, 0, 0);
        let other = Version::new(2, 0, 0);

        cache
            .insert_for_fqdn_version(&fqdn, &stored, default_badge())
            .await;

        assert!(cache.get_by_fqdn_version(&fqdn, &stored).await.is_some());
        // A version that was never inserted must not be served.
        assert!(cache.get_by_fqdn_version(&fqdn, &other).await.is_none());
    }

    #[test]
    fn test_cached_badge_validity() {
        let entry = CachedBadge::new(default_badge(), Duration::from_secs(60));

        assert!(entry.is_valid());
        // Freshly created with a 60s TTL, so ~60s remain: a 10s threshold is not reached.
        assert!(!entry.should_refresh(Duration::from_secs(10)));
    }

    #[test]
    fn test_cached_badge_should_refresh() {
        let entry = CachedBadge::new(default_badge(), Duration::from_secs(30));

        // ~30s remain, which is below a 60s threshold.
        assert!(entry.should_refresh(Duration::from_secs(60)));
        // ...but above a 10s threshold.
        assert!(!entry.should_refresh(Duration::from_secs(10)));
    }

    #[tokio::test]
    async fn test_version_index_populated_on_tracked_insert() {
        let (cache, _fqdn, v1, v2) = cache_with_two_versions().await;

        let index = cache.version_index.read().await;
        let versions = index.get(HOST).unwrap();
        assert_eq!(versions.len(), 2);
        for v in [&v1, &v2] {
            assert!(versions.contains(v));
        }
    }

    #[tokio::test]
    async fn test_get_all_for_fqdn_returns_all_versions() {
        let (cache, fqdn, _v1, _v2) = cache_with_two_versions().await;

        let all = cache.get_all_for_fqdn(&fqdn).await;
        assert_eq!(all.len(), 2);

        let seen: Vec<String> = all
            .iter()
            .map(|entry| entry.badge.agent_version().to_string())
            .collect();
        for expected in ["v1.0.0", "v1.0.1"] {
            assert!(seen.iter().any(|s| s == expected));
        }
    }

    #[tokio::test]
    async fn test_get_all_for_fqdn_empty_for_unknown() {
        let cache = BadgeCache::with_defaults();
        let unknown = Fqdn::new("unknown.example.com").unwrap();

        assert!(cache.get_all_for_fqdn(&unknown).await.is_empty());
    }

    #[tokio::test]
    async fn test_invalidate_fqdn_clears_all_versions() {
        let (cache, fqdn, v1, v2) = cache_with_two_versions().await;

        // Both entries are present before invalidation.
        assert_eq!(cache.get_all_for_fqdn(&fqdn).await.len(), 2);

        cache.invalidate_fqdn(&fqdn).await;

        // Nothing is reachable afterwards, via the index or directly.
        assert!(cache.get_all_for_fqdn(&fqdn).await.is_empty());
        for v in [&v1, &v2] {
            assert!(cache.get_by_fqdn_version(&fqdn, v).await.is_none());
        }
    }

    #[tokio::test]
    async fn test_set_version_index() {
        let cache = BadgeCache::with_defaults();
        let fqdn = host_fqdn();
        let known = vec![Version::new(1, 0, 0), Version::new(2, 0, 0)];

        cache.set_version_index(&fqdn, known).await;

        // Versions are known but nothing is cached yet.
        assert!(cache.get_all_for_fqdn(&fqdn).await.is_empty());

        // Caching one of the known versions makes exactly that one visible.
        cache
            .insert_for_fqdn_version(&fqdn, &Version::new(1, 0, 0), default_badge())
            .await;
        assert_eq!(cache.get_all_for_fqdn(&fqdn).await.len(), 1);
    }

    #[tokio::test]
    async fn test_tracked_insert_idempotent() {
        let cache = BadgeCache::with_defaults();
        let fqdn = host_fqdn();
        let v1 = Version::new(1, 0, 0);

        for _ in 0..2 {
            cache
                .insert_for_fqdn_version(&fqdn, &v1, default_badge())
                .await;
        }

        // Re-inserting the same version must not duplicate it in the index.
        let index = cache.version_index.read().await;
        assert_eq!(index.get(HOST).unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_clear_resets_version_index() {
        let cache = BadgeCache::with_defaults();
        let fqdn = host_fqdn();

        cache
            .insert_for_fqdn_version(&fqdn, &Version::new(1, 0, 0), default_badge())
            .await;
        assert!(!cache.get_all_for_fqdn(&fqdn).await.is_empty());

        cache.clear().await;
        assert!(cache.get_all_for_fqdn(&fqdn).await.is_empty());
    }
}