use std::collections::HashMap;
use std::fmt;
use std::time::{Duration, Instant};
use moka::future::Cache;
use tokio::sync::RwLock;
use ans_types::{Badge, Fqdn, Version};
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct CacheConfig {
pub max_entries: u64,
pub default_ttl: Duration,
pub refresh_threshold: Duration,
}
impl Default for CacheConfig {
fn default() -> Self {
Self {
max_entries: 1000,
default_ttl: Duration::from_secs(300), refresh_threshold: Duration::from_secs(60), }
}
}
impl CacheConfig {
pub fn with_ttl(ttl: Duration) -> Self {
Self {
default_ttl: ttl,
refresh_threshold: Duration::from_secs(ttl.as_secs() / 5),
..Default::default()
}
}
}
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
#[non_exhaustive]
pub enum CacheKey {
FqdnVersion(String, Version),
Url(String),
}
impl CacheKey {
pub fn fqdn_version(fqdn: &Fqdn, version: &Version) -> Self {
Self::FqdnVersion(fqdn.as_str().to_lowercase(), version.clone())
}
pub fn url(url: &str) -> Self {
Self::Url(url.to_string())
}
}
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct CachedBadge {
pub badge: Badge,
pub fetched_at: Instant,
pub ttl: Duration,
}
impl CachedBadge {
pub fn new(badge: Badge, ttl: Duration) -> Self {
Self {
badge,
fetched_at: Instant::now(),
ttl,
}
}
pub fn is_valid(&self) -> bool {
self.fetched_at.elapsed() < self.ttl
}
pub fn should_refresh(&self, threshold: Duration) -> bool {
let remaining = self.ttl.saturating_sub(self.fetched_at.elapsed());
remaining < threshold
}
pub fn remaining_ttl(&self) -> Duration {
self.ttl.saturating_sub(self.fetched_at.elapsed())
}
}
pub struct BadgeCache {
cache: Cache<CacheKey, CachedBadge>,
config: CacheConfig,
version_index: RwLock<HashMap<String, Vec<Version>>>,
}
impl fmt::Debug for BadgeCache {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("BadgeCache")
.field("config", &self.config)
.finish_non_exhaustive()
}
}
impl BadgeCache {
pub fn new(config: CacheConfig) -> Self {
let cache = Cache::builder()
.max_capacity(config.max_entries)
.time_to_live(config.default_ttl)
.build();
Self {
cache,
config,
version_index: RwLock::new(HashMap::new()),
}
}
pub fn with_defaults() -> Self {
Self::new(CacheConfig::default())
}
pub async fn get(&self, key: &CacheKey) -> Option<CachedBadge> {
self.cache.get(key).await.filter(CachedBadge::is_valid)
}
pub async fn insert(&self, key: CacheKey, badge: Badge) {
let cached = CachedBadge::new(badge, self.config.default_ttl);
self.cache.insert(key, cached).await;
}
pub async fn insert_with_ttl(&self, key: CacheKey, badge: Badge, ttl: Duration) {
let cached = CachedBadge::new(badge, ttl);
self.cache.insert(key, cached).await;
}
pub async fn invalidate(&self, key: &CacheKey) {
self.cache.invalidate(key).await;
}
pub async fn clear(&self) {
self.cache.invalidate_all();
self.cache.run_pending_tasks().await;
self.version_index.write().await.clear();
}
pub fn should_refresh(&self, cached: &CachedBadge) -> bool {
cached.should_refresh(self.config.refresh_threshold)
}
pub fn entry_count(&self) -> u64 {
self.cache.entry_count()
}
pub async fn get_by_fqdn_version(&self, fqdn: &Fqdn, version: &Version) -> Option<CachedBadge> {
self.get(&CacheKey::fqdn_version(fqdn, version)).await
}
pub async fn insert_for_fqdn_version(&self, fqdn: &Fqdn, version: &Version, badge: Badge) {
self.insert(CacheKey::fqdn_version(fqdn, version), badge)
.await;
let key = fqdn.as_str().to_lowercase();
let mut index = self.version_index.write().await;
let versions = index.entry(key).or_default();
if !versions.contains(version) {
versions.push(version.clone());
}
}
pub async fn get_all_for_fqdn(&self, fqdn: &Fqdn) -> Vec<CachedBadge> {
let key = fqdn.as_str().to_lowercase();
let index = self.version_index.read().await;
let versions = match index.get(&key) {
Some(v) => v.clone(),
None => return Vec::new(),
};
drop(index);
let mut results = Vec::new();
for version in &versions {
if let Some(cached) = self.get(&CacheKey::fqdn_version(fqdn, version)).await {
results.push(cached);
}
}
results
}
pub async fn invalidate_fqdn(&self, fqdn: &Fqdn) {
let key = fqdn.as_str().to_lowercase();
let mut index = self.version_index.write().await;
if let Some(versions) = index.remove(&key) {
for version in &versions {
self.cache
.invalidate(&CacheKey::fqdn_version(fqdn, version))
.await;
}
}
}
pub async fn set_version_index(&self, fqdn: &Fqdn, versions: Vec<Version>) {
let key = fqdn.as_str().to_lowercase();
let mut index = self.version_index.write().await;
index.insert(key, versions);
}
}
impl Default for BadgeCache {
fn default() -> Self {
Self::with_defaults()
}
}
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
#[cfg(test)]
mod tests {
use super::*;
use chrono::Utc;
use uuid::Uuid;
fn create_test_badge() -> Badge {
test_badge_from_json("test.example.com", "v1.0.0", "SHA256:bbb", "SHA256:aaa")
}
fn test_badge_from_json(
host: &str,
version: &str,
server_fp: &str,
identity_fp: &str,
) -> Badge {
serde_json::from_value(serde_json::json!({
"status": "ACTIVE",
"schemaVersion": "V1",
"payload": {
"logId": Uuid::new_v4().to_string(),
"producer": {
"event": {
"ansId": Uuid::new_v4().to_string(),
"ansName": format!("ans://{version}.{host}"),
"eventType": "AGENT_REGISTERED",
"agent": { "host": host, "name": "Test Agent", "version": version },
"attestations": {
"domainValidation": "ACME-DNS-01",
"identityCert": { "fingerprint": identity_fp, "type": "X509-OV-CLIENT" },
"serverCert": { "fingerprint": server_fp, "type": "X509-DV-SERVER" }
},
"expiresAt": (Utc::now() + chrono::Duration::days(365)).to_rfc3339(),
"issuedAt": Utc::now().to_rfc3339(),
"raId": "test-ra",
"timestamp": Utc::now().to_rfc3339()
},
"keyId": "test-key",
"signature": "test-sig"
}
}
})).expect("test badge JSON should be valid")
}
#[tokio::test]
async fn test_cache_insert_and_get() {
let cache = BadgeCache::with_defaults();
let badge = create_test_badge();
let fqdn = Fqdn::new("test.example.com").unwrap();
let version = Version::new(1, 0, 0);
cache
.insert_for_fqdn_version(&fqdn, &version, badge.clone())
.await;
let cached = cache.get_by_fqdn_version(&fqdn, &version).await;
assert!(cached.is_some());
assert_eq!(cached.unwrap().badge.agent_host(), "test.example.com");
}
#[tokio::test]
async fn test_cache_miss() {
let cache = BadgeCache::with_defaults();
let fqdn = Fqdn::new("unknown.example.com").unwrap();
let cached = cache
.get_by_fqdn_version(&fqdn, &Version::new(1, 0, 0))
.await;
assert!(cached.is_none());
}
#[tokio::test]
async fn test_cache_invalidate() {
let cache = BadgeCache::with_defaults();
let badge = create_test_badge();
let fqdn = Fqdn::new("test.example.com").unwrap();
let version = Version::new(1, 0, 0);
cache.insert_for_fqdn_version(&fqdn, &version, badge).await;
assert!(cache.get_by_fqdn_version(&fqdn, &version).await.is_some());
cache
.invalidate(&CacheKey::fqdn_version(&fqdn, &version))
.await;
assert!(cache.get_by_fqdn_version(&fqdn, &version).await.is_none());
}
#[tokio::test]
async fn test_cache_by_version() {
let cache = BadgeCache::with_defaults();
let badge = create_test_badge();
let fqdn = Fqdn::new("test.example.com").unwrap();
let version = Version::new(1, 0, 0);
cache.insert_for_fqdn_version(&fqdn, &version, badge).await;
let cached = cache.get_by_fqdn_version(&fqdn, &version).await;
assert!(cached.is_some());
let cached = cache
.get_by_fqdn_version(&fqdn, &Version::new(2, 0, 0))
.await;
assert!(cached.is_none());
}
#[test]
fn test_cached_badge_validity() {
let badge = create_test_badge();
let cached = CachedBadge::new(badge, Duration::from_secs(60));
assert!(cached.is_valid());
assert!(!cached.should_refresh(Duration::from_secs(10)));
}
#[test]
fn test_cached_badge_should_refresh() {
let badge = create_test_badge();
let cached = CachedBadge::new(badge, Duration::from_secs(30));
assert!(cached.should_refresh(Duration::from_secs(60)));
assert!(!cached.should_refresh(Duration::from_secs(10)));
}
fn create_test_badge_versioned(version: &str) -> Badge {
test_badge_from_json(
"test.example.com",
version,
&format!("SHA256:{version}-server-fp"),
"SHA256:aaa",
)
}
#[tokio::test]
async fn test_version_index_populated_on_tracked_insert() {
let cache = BadgeCache::with_defaults();
let fqdn = Fqdn::new("test.example.com").unwrap();
let v1 = Version::new(1, 0, 0);
let v2 = Version::new(1, 0, 1);
cache
.insert_for_fqdn_version(&fqdn, &v1, create_test_badge_versioned("v1.0.0"))
.await;
cache
.insert_for_fqdn_version(&fqdn, &v2, create_test_badge_versioned("v1.0.1"))
.await;
let index = cache.version_index.read().await;
let versions = index.get("test.example.com").unwrap();
assert_eq!(versions.len(), 2);
assert!(versions.contains(&v1));
assert!(versions.contains(&v2));
}
#[tokio::test]
async fn test_get_all_for_fqdn_returns_all_versions() {
let cache = BadgeCache::with_defaults();
let fqdn = Fqdn::new("test.example.com").unwrap();
let v1 = Version::new(1, 0, 0);
let v2 = Version::new(1, 0, 1);
cache
.insert_for_fqdn_version(&fqdn, &v1, create_test_badge_versioned("v1.0.0"))
.await;
cache
.insert_for_fqdn_version(&fqdn, &v2, create_test_badge_versioned("v1.0.1"))
.await;
let all = cache.get_all_for_fqdn(&fqdn).await;
assert_eq!(all.len(), 2);
let versions: Vec<String> = all
.iter()
.map(|c| c.badge.agent_version().to_string())
.collect();
assert!(versions.contains(&"v1.0.0".to_string()));
assert!(versions.contains(&"v1.0.1".to_string()));
}
#[tokio::test]
async fn test_get_all_for_fqdn_empty_for_unknown() {
let cache = BadgeCache::with_defaults();
let fqdn = Fqdn::new("unknown.example.com").unwrap();
let all = cache.get_all_for_fqdn(&fqdn).await;
assert!(all.is_empty());
}
#[tokio::test]
async fn test_invalidate_fqdn_clears_all_versions() {
let cache = BadgeCache::with_defaults();
let fqdn = Fqdn::new("test.example.com").unwrap();
let v1 = Version::new(1, 0, 0);
let v2 = Version::new(1, 0, 1);
cache
.insert_for_fqdn_version(&fqdn, &v1, create_test_badge_versioned("v1.0.0"))
.await;
cache
.insert_for_fqdn_version(&fqdn, &v2, create_test_badge_versioned("v1.0.1"))
.await;
assert_eq!(cache.get_all_for_fqdn(&fqdn).await.len(), 2);
cache.invalidate_fqdn(&fqdn).await;
assert!(cache.get_all_for_fqdn(&fqdn).await.is_empty());
assert!(cache.get_by_fqdn_version(&fqdn, &v1).await.is_none());
assert!(cache.get_by_fqdn_version(&fqdn, &v2).await.is_none());
}
#[tokio::test]
async fn test_set_version_index() {
let cache = BadgeCache::with_defaults();
let fqdn = Fqdn::new("test.example.com").unwrap();
cache
.set_version_index(&fqdn, vec![Version::new(1, 0, 0), Version::new(2, 0, 0)])
.await;
let all = cache.get_all_for_fqdn(&fqdn).await;
assert!(all.is_empty());
cache
.insert_for_fqdn_version(&fqdn, &Version::new(1, 0, 0), create_test_badge())
.await;
let all = cache.get_all_for_fqdn(&fqdn).await;
assert_eq!(all.len(), 1);
}
#[tokio::test]
async fn test_tracked_insert_idempotent() {
let cache = BadgeCache::with_defaults();
let fqdn = Fqdn::new("test.example.com").unwrap();
let v1 = Version::new(1, 0, 0);
cache
.insert_for_fqdn_version(&fqdn, &v1, create_test_badge())
.await;
cache
.insert_for_fqdn_version(&fqdn, &v1, create_test_badge())
.await;
let index = cache.version_index.read().await;
let versions = index.get("test.example.com").unwrap();
assert_eq!(versions.len(), 1);
}
#[tokio::test]
async fn test_clear_resets_version_index() {
let cache = BadgeCache::with_defaults();
let fqdn = Fqdn::new("test.example.com").unwrap();
cache
.insert_for_fqdn_version(&fqdn, &Version::new(1, 0, 0), create_test_badge())
.await;
assert!(!cache.get_all_for_fqdn(&fqdn).await.is_empty());
cache.clear().await;
assert!(cache.get_all_for_fqdn(&fqdn).await.is_empty());
}
}