1use std::collections::HashMap;
4use std::fmt;
5use std::time::{Duration, Instant};
6
7use moka::future::Cache;
8use tokio::sync::RwLock;
9
10use ans_types::{Badge, Fqdn, Version};
11
/// Tuning knobs for [`BadgeCache`].
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct CacheConfig {
    /// Maximum number of entries the underlying cache may hold.
    pub max_entries: u64,
    /// TTL given to entries inserted without an explicit TTL.
    pub default_ttl: Duration,
    /// An entry whose remaining TTL falls below this value is reported as
    /// due for refresh.
    pub refresh_threshold: Duration,
}

impl Default for CacheConfig {
    /// 1000 entries, a 300 s default TTL and a 60 s refresh threshold.
    fn default() -> Self {
        Self {
            max_entries: 1000,
            default_ttl: Duration::from_secs(300),
            refresh_threshold: Duration::from_secs(60),
        }
    }
}

impl CacheConfig {
    /// Builds a config whose default TTL is `ttl` and whose refresh threshold
    /// is one fifth of it; every other field keeps its [`Default`] value.
    #[must_use]
    pub fn with_ttl(ttl: Duration) -> Self {
        Self {
            default_ttl: ttl,
            // Divide the full Duration (not just whole seconds) so short or
            // fractional TTLs still get a non-zero refresh threshold.
            refresh_threshold: ttl / 5,
            ..Self::default()
        }
    }
}
44
45#[derive(Debug, Clone, Hash, PartialEq, Eq)]
47#[non_exhaustive]
48pub enum CacheKey {
49 FqdnVersion(String, Version),
51 Url(String),
53}
54
55impl CacheKey {
56 pub fn fqdn_version(fqdn: &Fqdn, version: &Version) -> Self {
58 Self::FqdnVersion(fqdn.as_str().to_lowercase(), version.clone())
59 }
60
61 pub fn url(url: &str) -> Self {
63 Self::Url(url.to_string())
64 }
65}
66
67#[derive(Debug, Clone)]
69#[non_exhaustive]
70pub struct CachedBadge {
71 pub badge: Badge,
73 pub fetched_at: Instant,
75 pub ttl: Duration,
77}
78
79impl CachedBadge {
80 pub fn new(badge: Badge, ttl: Duration) -> Self {
82 Self {
83 badge,
84 fetched_at: Instant::now(),
85 ttl,
86 }
87 }
88
89 pub fn is_valid(&self) -> bool {
91 self.fetched_at.elapsed() < self.ttl
92 }
93
94 pub fn should_refresh(&self, threshold: Duration) -> bool {
96 let remaining = self.ttl.saturating_sub(self.fetched_at.elapsed());
97 remaining < threshold
98 }
99
100 pub fn remaining_ttl(&self) -> Duration {
102 self.ttl.saturating_sub(self.fetched_at.elapsed())
103 }
104}
105
106pub struct BadgeCache {
112 cache: Cache<CacheKey, CachedBadge>,
113 config: CacheConfig,
114 version_index: RwLock<HashMap<String, Vec<Version>>>,
117}
118
119impl fmt::Debug for BadgeCache {
120 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
121 f.debug_struct("BadgeCache")
122 .field("config", &self.config)
123 .finish_non_exhaustive()
124 }
125}
126
127impl BadgeCache {
128 pub fn new(config: CacheConfig) -> Self {
130 let cache = Cache::builder()
131 .max_capacity(config.max_entries)
132 .time_to_live(config.default_ttl)
133 .build();
134
135 Self {
136 cache,
137 config,
138 version_index: RwLock::new(HashMap::new()),
139 }
140 }
141
142 pub fn with_defaults() -> Self {
144 Self::new(CacheConfig::default())
145 }
146
147 pub async fn get(&self, key: &CacheKey) -> Option<CachedBadge> {
149 self.cache.get(key).await.filter(CachedBadge::is_valid)
150 }
151
152 pub async fn insert(&self, key: CacheKey, badge: Badge) {
154 let cached = CachedBadge::new(badge, self.config.default_ttl);
155 self.cache.insert(key, cached).await;
156 }
157
158 pub async fn insert_with_ttl(&self, key: CacheKey, badge: Badge, ttl: Duration) {
165 let cached = CachedBadge::new(badge, ttl);
166 self.cache.insert(key, cached).await;
167 }
168
169 pub async fn invalidate(&self, key: &CacheKey) {
171 self.cache.invalidate(key).await;
172 }
173
174 pub async fn clear(&self) {
176 self.cache.invalidate_all();
177 self.cache.run_pending_tasks().await;
178 self.version_index.write().await.clear();
179 }
180
181 pub fn should_refresh(&self, cached: &CachedBadge) -> bool {
183 cached.should_refresh(self.config.refresh_threshold)
184 }
185
186 pub fn entry_count(&self) -> u64 {
188 self.cache.entry_count()
189 }
190
191 pub async fn get_by_fqdn_version(&self, fqdn: &Fqdn, version: &Version) -> Option<CachedBadge> {
193 self.get(&CacheKey::fqdn_version(fqdn, version)).await
194 }
195
196 pub async fn insert_for_fqdn_version(&self, fqdn: &Fqdn, version: &Version, badge: Badge) {
201 self.insert(CacheKey::fqdn_version(fqdn, version), badge)
202 .await;
203
204 let key = fqdn.as_str().to_lowercase();
205 let mut index = self.version_index.write().await;
206 let versions = index.entry(key).or_default();
207 if !versions.contains(version) {
208 versions.push(version.clone());
209 }
210 }
211
212 pub async fn get_all_for_fqdn(&self, fqdn: &Fqdn) -> Vec<CachedBadge> {
217 let key = fqdn.as_str().to_lowercase();
218 let index = self.version_index.read().await;
219 let versions = match index.get(&key) {
220 Some(v) => v.clone(),
221 None => return Vec::new(),
222 };
223 drop(index); let mut results = Vec::new();
226 for version in &versions {
227 if let Some(cached) = self.get(&CacheKey::fqdn_version(fqdn, version)).await {
228 results.push(cached);
229 }
230 }
231 results
232 }
233
234 pub async fn invalidate_fqdn(&self, fqdn: &Fqdn) {
236 let key = fqdn.as_str().to_lowercase();
237 let mut index = self.version_index.write().await;
238 if let Some(versions) = index.remove(&key) {
239 for version in &versions {
240 self.cache
241 .invalidate(&CacheKey::fqdn_version(fqdn, version))
242 .await;
243 }
244 }
245 }
246
247 pub async fn set_version_index(&self, fqdn: &Fqdn, versions: Vec<Version>) {
252 let key = fqdn.as_str().to_lowercase();
253 let mut index = self.version_index.write().await;
254 index.insert(key, versions);
255 }
256}
257
258impl Default for BadgeCache {
259 fn default() -> Self {
260 Self::with_defaults()
261 }
262}
263
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use uuid::Uuid;

    /// Default fixture: host `test.example.com`, version `v1.0.0`.
    fn create_test_badge() -> Badge {
        test_badge_from_json("test.example.com", "v1.0.0", "SHA256:bbb", "SHA256:aaa")
    }

    /// Builds a badge by deserializing a minimal ACTIVE/V1 JSON document.
    fn test_badge_from_json(
        host: &str,
        version: &str,
        server_fp: &str,
        identity_fp: &str,
    ) -> Badge {
        serde_json::from_value(serde_json::json!({
            "status": "ACTIVE",
            "schemaVersion": "V1",
            "payload": {
                "logId": Uuid::new_v4().to_string(),
                "producer": {
                    "event": {
                        "ansId": Uuid::new_v4().to_string(),
                        "ansName": format!("ans://{version}.{host}"),
                        "eventType": "AGENT_REGISTERED",
                        "agent": { "host": host, "name": "Test Agent", "version": version },
                        "attestations": {
                            "domainValidation": "ACME-DNS-01",
                            "identityCert": { "fingerprint": identity_fp, "type": "X509-OV-CLIENT" },
                            "serverCert": { "fingerprint": server_fp, "type": "X509-DV-SERVER" }
                        },
                        "expiresAt": (Utc::now() + chrono::Duration::days(365)).to_rfc3339(),
                        "issuedAt": Utc::now().to_rfc3339(),
                        "raId": "test-ra",
                        "timestamp": Utc::now().to_rfc3339()
                    },
                    "keyId": "test-key",
                    "signature": "test-sig"
                }
            }
        }))
        .expect("test badge JSON should be valid")
    }

    /// Fixture for `test.example.com` at `version`, with a version-specific server fingerprint.
    fn create_test_badge_versioned(version: &str) -> Badge {
        test_badge_from_json(
            "test.example.com",
            version,
            &format!("SHA256:{version}-server-fp"),
            "SHA256:aaa",
        )
    }

    #[tokio::test]
    async fn test_cache_insert_and_get() {
        let cache = BadgeCache::with_defaults();
        let badge = create_test_badge();
        let fqdn = Fqdn::new("test.example.com").unwrap();
        let version = Version::new(1, 0, 0);

        cache
            .insert_for_fqdn_version(&fqdn, &version, badge.clone())
            .await;

        let cached = cache
            .get_by_fqdn_version(&fqdn, &version)
            .await
            .expect("freshly inserted badge should be cached");
        assert_eq!(cached.badge.agent_host(), "test.example.com");
    }

    #[tokio::test]
    async fn test_cache_miss() {
        let cache = BadgeCache::with_defaults();
        let fqdn = Fqdn::new("unknown.example.com").unwrap();

        let cached = cache
            .get_by_fqdn_version(&fqdn, &Version::new(1, 0, 0))
            .await;
        assert!(cached.is_none());
    }

    #[tokio::test]
    async fn test_cache_invalidate() {
        let cache = BadgeCache::with_defaults();
        let badge = create_test_badge();
        let fqdn = Fqdn::new("test.example.com").unwrap();
        let version = Version::new(1, 0, 0);

        cache.insert_for_fqdn_version(&fqdn, &version, badge).await;
        assert!(cache.get_by_fqdn_version(&fqdn, &version).await.is_some());

        cache
            .invalidate(&CacheKey::fqdn_version(&fqdn, &version))
            .await;
        assert!(cache.get_by_fqdn_version(&fqdn, &version).await.is_none());
    }

    #[tokio::test]
    async fn test_cache_by_version() {
        let cache = BadgeCache::with_defaults();
        let badge = create_test_badge();
        let fqdn = Fqdn::new("test.example.com").unwrap();
        let version = Version::new(1, 0, 0);

        cache.insert_for_fqdn_version(&fqdn, &version, badge).await;

        assert!(cache.get_by_fqdn_version(&fqdn, &version).await.is_some());

        // A different version of the same FQDN must not hit.
        let other = cache
            .get_by_fqdn_version(&fqdn, &Version::new(2, 0, 0))
            .await;
        assert!(other.is_none());
    }

    #[test]
    fn test_cached_badge_validity() {
        let cached = CachedBadge::new(create_test_badge(), Duration::from_secs(60));

        assert!(cached.is_valid());
        // ~60 s remain, well above a 10 s threshold.
        assert!(!cached.should_refresh(Duration::from_secs(10)));
    }

    #[test]
    fn test_cached_badge_should_refresh() {
        let cached = CachedBadge::new(create_test_badge(), Duration::from_secs(30));

        // ~30 s remaining is below a 60 s threshold...
        assert!(cached.should_refresh(Duration::from_secs(60)));
        // ...but above a 10 s threshold.
        assert!(!cached.should_refresh(Duration::from_secs(10)));
    }

    #[tokio::test]
    async fn test_version_index_populated_on_tracked_insert() {
        let cache = BadgeCache::with_defaults();
        let fqdn = Fqdn::new("test.example.com").unwrap();
        let v1 = Version::new(1, 0, 0);
        let v2 = Version::new(1, 0, 1);

        cache
            .insert_for_fqdn_version(&fqdn, &v1, create_test_badge_versioned("v1.0.0"))
            .await;
        cache
            .insert_for_fqdn_version(&fqdn, &v2, create_test_badge_versioned("v1.0.1"))
            .await;

        let index = cache.version_index.read().await;
        let versions = index
            .get("test.example.com")
            .expect("tracked insert should create an index entry");
        assert_eq!(versions.len(), 2);
        assert!(versions.contains(&v1));
        assert!(versions.contains(&v2));
    }

    #[tokio::test]
    async fn test_get_all_for_fqdn_returns_all_versions() {
        let cache = BadgeCache::with_defaults();
        let fqdn = Fqdn::new("test.example.com").unwrap();
        let v1 = Version::new(1, 0, 0);
        let v2 = Version::new(1, 0, 1);

        cache
            .insert_for_fqdn_version(&fqdn, &v1, create_test_badge_versioned("v1.0.0"))
            .await;
        cache
            .insert_for_fqdn_version(&fqdn, &v2, create_test_badge_versioned("v1.0.1"))
            .await;

        let all = cache.get_all_for_fqdn(&fqdn).await;
        assert_eq!(all.len(), 2);

        let versions: Vec<String> = all
            .iter()
            .map(|c| c.badge.agent_version().to_string())
            .collect();
        assert!(versions.iter().any(|v| v.as_str() == "v1.0.0"));
        assert!(versions.iter().any(|v| v.as_str() == "v1.0.1"));
    }

    #[tokio::test]
    async fn test_get_all_for_fqdn_empty_for_unknown() {
        let cache = BadgeCache::with_defaults();
        let fqdn = Fqdn::new("unknown.example.com").unwrap();

        assert!(cache.get_all_for_fqdn(&fqdn).await.is_empty());
    }

    #[tokio::test]
    async fn test_invalidate_fqdn_clears_all_versions() {
        let cache = BadgeCache::with_defaults();
        let fqdn = Fqdn::new("test.example.com").unwrap();
        let v1 = Version::new(1, 0, 0);
        let v2 = Version::new(1, 0, 1);

        cache
            .insert_for_fqdn_version(&fqdn, &v1, create_test_badge_versioned("v1.0.0"))
            .await;
        cache
            .insert_for_fqdn_version(&fqdn, &v2, create_test_badge_versioned("v1.0.1"))
            .await;

        assert_eq!(cache.get_all_for_fqdn(&fqdn).await.len(), 2);

        cache.invalidate_fqdn(&fqdn).await;

        // Both the index and the underlying entries must be gone.
        assert!(cache.get_all_for_fqdn(&fqdn).await.is_empty());
        assert!(cache.get_by_fqdn_version(&fqdn, &v1).await.is_none());
        assert!(cache.get_by_fqdn_version(&fqdn, &v2).await.is_none());
    }

    #[tokio::test]
    async fn test_set_version_index() {
        let cache = BadgeCache::with_defaults();
        let fqdn = Fqdn::new("test.example.com").unwrap();

        cache
            .set_version_index(&fqdn, vec![Version::new(1, 0, 0), Version::new(2, 0, 0)])
            .await;

        // Indexed versions without cached badges yield nothing.
        assert!(cache.get_all_for_fqdn(&fqdn).await.is_empty());

        // Once one indexed version is cached, exactly that one is returned.
        cache
            .insert_for_fqdn_version(&fqdn, &Version::new(1, 0, 0), create_test_badge())
            .await;

        assert_eq!(cache.get_all_for_fqdn(&fqdn).await.len(), 1);
    }

    #[tokio::test]
    async fn test_tracked_insert_idempotent() {
        let cache = BadgeCache::with_defaults();
        let fqdn = Fqdn::new("test.example.com").unwrap();
        let v1 = Version::new(1, 0, 0);

        // Inserting the same version twice must not duplicate the index entry.
        for _ in 0..2 {
            cache
                .insert_for_fqdn_version(&fqdn, &v1, create_test_badge())
                .await;
        }

        let index = cache.version_index.read().await;
        let versions = index
            .get("test.example.com")
            .expect("tracked insert should create an index entry");
        assert_eq!(versions.len(), 1);
    }

    #[tokio::test]
    async fn test_clear_resets_version_index() {
        let cache = BadgeCache::with_defaults();
        let fqdn = Fqdn::new("test.example.com").unwrap();

        cache
            .insert_for_fqdn_version(&fqdn, &Version::new(1, 0, 0), create_test_badge())
            .await;
        assert!(!cache.get_all_for_fqdn(&fqdn).await.is_empty());

        cache.clear().await;
        assert!(cache.get_all_for_fqdn(&fqdn).await.is_empty());
    }
}