use moka::future::Cache;
use moka::Expiry;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
/// Identifies one entry in the content cache.
///
/// Every variant is flattened into a `:`-separated string by `to_cache_key`,
/// and that string is what the underlying cache is keyed on. Components are
/// inserted verbatim (no escaping), so a component that itself contains `:`
/// can make two different keys flatten to the same string.
/// NOTE(review): presumably user ids / content ids never contain `:` — confirm.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CacheKey {
    /// A page shared by all visitors, keyed by its path.
    Page(String),
    /// A page keyed by its path and the user it was rendered for.
    UserPage { path: String, user_id: String },
    /// A content item keyed by its type and id.
    Content { content_type: String, id: String },
    /// A content item keyed by user, type and id.
    UserContent {
        user_id: String,
        content_type: String,
        id: String,
    },
    /// An entry keyed by an external URL.
    External { url: String },
    /// A caller-defined key.
    Custom(String),
}

impl CacheKey {
    /// Builds a [`CacheKey::Page`] for `path`.
    pub fn page(path: impl Into<String>) -> Self {
        let path = path.into();
        Self::Page(path)
    }

    /// Builds a [`CacheKey::UserPage`] for `path` as rendered for `user_id`.
    pub fn user_page(path: impl Into<String>, user_id: impl Into<String>) -> Self {
        let (path, user_id) = (path.into(), user_id.into());
        Self::UserPage { path, user_id }
    }

    /// Builds a [`CacheKey::Content`] for the item `id` of `content_type`.
    pub fn content(content_type: impl Into<String>, id: impl Into<String>) -> Self {
        let (content_type, id) = (content_type.into(), id.into());
        Self::Content { content_type, id }
    }

    /// Builds a [`CacheKey::UserContent`] for the item `id` of `content_type`
    /// as rendered for `user_id`.
    pub fn user_content(
        user_id: impl Into<String>,
        content_type: impl Into<String>,
        id: impl Into<String>,
    ) -> Self {
        let (user_id, content_type, id) = (user_id.into(), content_type.into(), id.into());
        Self::UserContent {
            user_id,
            content_type,
            id,
        }
    }

    /// Builds a [`CacheKey::External`] for `url`.
    pub fn external(url: impl Into<String>) -> Self {
        let url = url.into();
        Self::External { url }
    }

    /// Builds a [`CacheKey::Custom`] from an arbitrary string.
    pub fn custom(key: impl Into<String>) -> Self {
        let key = key.into();
        Self::Custom(key)
    }

    /// Flattens the key into the string the underlying cache is keyed on:
    ///
    /// * `Page`        → `page:{path}`
    /// * `UserPage`    → `user:{user_id}:page:{path}`
    /// * `Content`     → `content:{content_type}:{id}`
    /// * `UserContent` → `user:{user_id}:content:{content_type}:{id}`
    /// * `External`    → `external:{url}`
    /// * `Custom`      → `custom:{key}`
    ///
    /// Both per-user variants start with `user:{user_id}:`, which lets callers
    /// find all of a user's entries by prefix.
    fn to_cache_key(&self) -> String {
        match self {
            Self::Page(path) => ["page:", path.as_str()].concat(),
            Self::UserPage { path, user_id } => {
                ["user:", user_id.as_str(), ":page:", path.as_str()].concat()
            }
            Self::Content { content_type, id } => {
                ["content:", content_type.as_str(), ":", id.as_str()].concat()
            }
            Self::UserContent {
                user_id,
                content_type,
                id,
            } => [
                "user:",
                user_id.as_str(),
                ":content:",
                content_type.as_str(),
                ":",
                id.as_str(),
            ]
            .concat(),
            Self::External { url } => ["external:", url.as_str()].concat(),
            Self::Custom(key) => ["custom:", key.as_str()].concat(),
        }
    }
}
/// A cached body together with its MIME type and an optional ETag.
#[derive(Debug, Clone)]
pub struct CachedValue {
    /// The cached body.
    pub content: String,
    /// Optional entity tag; `None` unless set via [`CachedValue::with_etag`].
    pub etag: Option<String>,
    /// MIME type of `content`, e.g. `text/html`.
    pub content_type: String,
}

impl CachedValue {
    /// Wraps `content` with the given MIME type and no ETag.
    fn typed(content: String, content_type: &str) -> Self {
        Self {
            content,
            etag: None,
            content_type: content_type.to_string(),
        }
    }

    /// A `text/html` value without an ETag.
    pub fn html(content: String) -> Self {
        Self::typed(content, "text/html")
    }

    /// An `application/json` value without an ETag.
    pub fn json(content: String) -> Self {
        Self::typed(content, "application/json")
    }

    /// Returns `self` with its ETag replaced by `etag`.
    pub fn with_etag(self, etag: impl Into<String>) -> Self {
        Self {
            etag: Some(etag.into()),
            ..self
        }
    }
}
#[derive(Clone)]
pub struct WhatCache {
content_cache: Cache<String, CachedValue>,
api_cache: Cache<String, String>,
#[allow(dead_code)]
default_ttl: Duration,
#[allow(dead_code)]
api_ttl: Duration,
tag_index: Arc<RwLock<HashMap<String, HashSet<String>>>>,
}
impl WhatCache {
pub fn new() -> Self {
Self::with_config(CacheConfig::default())
}
pub fn with_config(config: CacheConfig) -> Self {
let content_cache = Cache::builder()
.max_capacity(config.max_content_entries)
.time_to_live(Duration::from_secs(config.content_ttl_secs))
.build();
let api_cache = Cache::builder()
.max_capacity(config.max_api_entries)
.time_to_live(Duration::from_secs(config.api_ttl_secs))
.build();
Self {
content_cache,
api_cache,
default_ttl: Duration::from_secs(config.content_ttl_secs),
api_ttl: Duration::from_secs(config.api_ttl_secs),
tag_index: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn get(&self, key: &CacheKey) -> Option<CachedValue> {
self.content_cache.get(&key.to_cache_key()).await
}
pub async fn set(&self, key: &CacheKey, value: CachedValue) {
self.content_cache.insert(key.to_cache_key(), value).await;
}
pub async fn set_with_tags(&self, key: &CacheKey, value: CachedValue, tags: &[&str]) {
let cache_key = key.to_cache_key();
self.content_cache.insert(cache_key.clone(), value).await;
if !tags.is_empty() {
let mut index = self.tag_index.write().await;
for tag in tags {
index
.entry(tag.to_string())
.or_default()
.insert(cache_key.clone());
}
}
}
pub async fn set_with_ttl(&self, key: &CacheKey, value: CachedValue, _ttl: Duration) {
self.content_cache.insert(key.to_cache_key(), value).await;
}
pub async fn get_api(&self, url: &str) -> Option<String> {
self.api_cache.get(&format!("api:{}", url)).await
}
pub async fn set_api(&self, url: &str, response: String) {
self.api_cache
.insert(format!("api:{}", url), response)
.await;
}
pub async fn invalidate(&self, key: &CacheKey) {
self.content_cache.invalidate(&key.to_cache_key()).await;
}
pub async fn invalidate_by_tag(&self, tag: &str) {
let mut index = self.tag_index.write().await;
if let Some(keys) = index.remove(tag) {
for key in &keys {
self.content_cache.invalidate(key).await;
}
self.content_cache.run_pending_tasks().await;
tracing::debug!(
"Cache: invalidated {} entries for tag '{}'",
keys.len(),
tag
);
}
}
pub async fn invalidate_content_type(&self, content_type: &str) {
let index = self.tag_index.read().await;
if index.contains_key(content_type) {
drop(index); self.invalidate_by_tag(content_type).await;
} else {
drop(index);
self.content_cache.invalidate_all();
self.content_cache.run_pending_tasks().await;
}
}
pub async fn invalidate_user(&self, user_id: &str) {
let prefix = format!("user:{}:", user_id);
let index = self.tag_index.read().await;
if let Some(keys) = index.get(user_id) {
let keys_to_remove: Vec<String> = keys.iter().cloned().collect();
drop(index);
for key in &keys_to_remove {
self.content_cache.invalidate(key).await;
}
} else {
drop(index);
let _ = prefix; self.content_cache.invalidate_all();
}
self.content_cache.run_pending_tasks().await;
}
pub async fn clear_all(&self) {
self.content_cache.invalidate_all();
self.api_cache.invalidate_all();
self.content_cache.run_pending_tasks().await;
self.api_cache.run_pending_tasks().await;
self.tag_index.write().await.clear();
}
pub fn stats(&self) -> CacheStats {
CacheStats {
content_entries: self.content_cache.entry_count(),
api_entries: self.api_cache.entry_count(),
}
}
}
impl Default for WhatCache {
fn default() -> Self {
Self::new()
}
}
/// Capacity and lifetime settings consumed by `WhatCache::with_config`.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Maximum number of entries in the content cache.
    pub max_content_entries: u64,
    /// Maximum number of entries in the API-response cache.
    pub max_api_entries: u64,
    /// Time-to-live, in seconds, for content-cache entries.
    pub content_ttl_secs: u64,
    /// Time-to-live, in seconds, for API-response entries.
    pub api_ttl_secs: u64,
}

impl Default for CacheConfig {
    /// 10 000 content entries living 5 minutes, and 1 000 API responses living
    /// 1 minute.
    fn default() -> Self {
        let five_minutes: u64 = 5 * 60;
        let one_minute: u64 = 60;
        Self {
            content_ttl_secs: five_minutes,
            api_ttl_secs: one_minute,
            max_content_entries: 10_000,
            max_api_entries: 1_000,
        }
    }
}
/// Snapshot of entry counts returned by `WhatCache::stats`.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Entries counted in the content cache.
    pub content_entries: u64,
    /// Entries counted in the API-response cache.
    pub api_entries: u64,
}
/// Settings for an HTTP `Cache-Control` header, rendered by
/// [`CacheControl::to_header_value`].
///
/// `Default` renders as a bare `public` with no `max-age`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CacheControl {
    /// Leading directive (`public`, `private`, `no-cache`, `no-store`).
    pub scope: CacheScope,
    /// `max-age` in seconds; omitted from the header when `None`.
    pub max_age: Option<u64>,
    /// Whether to emit `must-revalidate`.
    pub must_revalidate: bool,
    /// Whether to emit `immutable`.
    pub immutable: bool,
}

/// The leading directive of a `Cache-Control` header.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum CacheScope {
    /// `public` (the default).
    #[default]
    Public,
    /// `private`.
    Private,
    /// `no-cache`.
    NoCache,
    /// `no-store`.
    NoStore,
}

impl CacheScope {
    /// The header token for this scope.
    fn directive(self) -> &'static str {
        match self {
            Self::Public => "public",
            Self::Private => "private",
            Self::NoCache => "no-cache",
            Self::NoStore => "no-store",
        }
    }
}

impl CacheControl {
    /// 365 days in seconds, the `max-age` used by [`CacheControl::immutable`].
    const ONE_YEAR_SECS: u64 = 365 * 24 * 60 * 60;

    /// `public, max-age={max_age}`.
    pub fn public(max_age: u64) -> Self {
        Self {
            scope: CacheScope::Public,
            max_age: Some(max_age),
            ..Self::default()
        }
    }

    /// `private, max-age={max_age}`.
    pub fn private(max_age: u64) -> Self {
        Self {
            scope: CacheScope::Private,
            max_age: Some(max_age),
            ..Self::default()
        }
    }

    /// `no-cache, must-revalidate`.
    pub fn no_cache() -> Self {
        Self {
            scope: CacheScope::NoCache,
            must_revalidate: true,
            ..Self::default()
        }
    }

    /// `public, max-age=31536000, immutable`.
    pub fn immutable() -> Self {
        Self {
            scope: CacheScope::Public,
            max_age: Some(Self::ONE_YEAR_SECS),
            immutable: true,
            ..Self::default()
        }
    }

    /// Renders the header value: the scope directive, then `max-age`,
    /// `must-revalidate` and `immutable` when set, separated by `", "`.
    pub fn to_header_value(&self) -> String {
        let mut header = String::from(self.scope.directive());
        if let Some(max_age) = self.max_age {
            header.push_str(", max-age=");
            header.push_str(&max_age.to_string());
        }
        if self.must_revalidate {
            header.push_str(", must-revalidate");
        }
        if self.immutable {
            header.push_str(", immutable");
        }
        header
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an HTML cache value from a string slice.
    fn html(body: &str) -> CachedValue {
        CachedValue::html(body.to_string())
    }

    #[tokio::test]
    async fn test_cache_basic() {
        let cache = WhatCache::new();
        let about = CacheKey::page("/about");
        cache.set(&about, html("<h1>About</h1>")).await;
        let hit = cache.get(&about).await.expect("entry was just inserted");
        assert_eq!(hit.content, "<h1>About</h1>");
    }

    #[tokio::test]
    async fn test_user_page_cache() {
        let cache = WhatCache::new();
        let dashboards = [("user1", "User 1 Dashboard"), ("user2", "User 2 Dashboard")];
        for &(user, body) in &dashboards {
            cache
                .set(&CacheKey::user_page("/dashboard", user), html(body))
                .await;
        }
        // Same path, different users: each user must see their own copy.
        for &(user, body) in &dashboards {
            let hit = cache
                .get(&CacheKey::user_page("/dashboard", user))
                .await
                .expect("each user's dashboard was inserted");
            assert_eq!(hit.content, body);
        }
    }

    #[tokio::test]
    async fn test_set_with_tags_and_invalidate_by_tag() {
        let cache = WhatCache::new();
        let blog = CacheKey::page("/blog");
        let post = CacheKey::page("/blog/post-1");
        let about = CacheKey::page("/about");
        cache.set_with_tags(&blog, html("Blog list"), &["posts"]).await;
        cache.set_with_tags(&post, html("Post 1"), &["posts"]).await;
        cache.set_with_tags(&about, html("About page"), &["pages"]).await;
        for key in [&blog, &post, &about] {
            assert!(cache.get(key).await.is_some());
        }

        cache.invalidate_by_tag("posts").await;

        for key in [&blog, &post] {
            assert!(cache.get(key).await.is_none());
        }
        // Entries under a different tag survive.
        assert!(cache.get(&about).await.is_some());
    }

    #[tokio::test]
    async fn test_invalidate_content_type_targeted() {
        let cache = WhatCache::new();
        let products = CacheKey::page("/products");
        let cart = CacheKey::page("/cart");
        cache.set_with_tags(&products, html("Products"), &["products"]).await;
        cache.set_with_tags(&cart, html("Cart"), &["cart"]).await;

        cache.invalidate_content_type("products").await;

        assert!(cache.get(&products).await.is_none());
        assert!(cache.get(&cart).await.is_some());
    }

    #[tokio::test]
    async fn test_invalidate_content_type_fallback() {
        let cache = WhatCache::new();
        let untagged = CacheKey::page("/test");
        cache.set(&untagged, html("Test")).await;

        // No entry carries this tag, so everything is dropped.
        cache.invalidate_content_type("unknown").await;

        assert!(cache.get(&untagged).await.is_none());
    }

    #[tokio::test]
    async fn test_clear_all_clears_tag_index() {
        let cache = WhatCache::new();
        let blog = CacheKey::page("/blog");
        cache.set_with_tags(&blog, html("Blog"), &["posts"]).await;

        cache.clear_all().await;

        assert!(cache.get(&blog).await.is_none());
        assert!(cache.tag_index.read().await.is_empty());
    }

    #[test]
    fn test_cache_control_header() {
        let cases = [
            (CacheControl::public(3600), "public, max-age=3600"),
            (CacheControl::private(600), "private, max-age=600"),
            (CacheControl::no_cache(), "no-cache, must-revalidate"),
            (
                CacheControl::immutable(),
                "public, max-age=31536000, immutable",
            ),
        ];
        for (control, expected) in cases.iter() {
            assert_eq!(control.to_header_value(), *expected);
        }
    }
}