use async_trait::async_trait;
use dashmap::DashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use super::{CacheBackend, CacheEntry, CacheRead};
use crate::error::CacheError;
/// Policy deciding when an entry that keeps hitting in L2 should be copied
/// ("promoted") into the faster L1 tier.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PromotionStrategy {
    /// Promote once a key has accumulated `threshold` L2 hits.
    HitCount { threshold: u64 },
    /// Promote based on hits per minute.
    ///
    /// NOTE(review): `should_promote` currently ignores
    /// `threshold_per_minute` and falls back to a fixed hit count of 3
    /// (no timing data is kept per key) — confirm whether this is intended.
    HitRate { threshold_per_minute: u64 },
}
impl Default for PromotionStrategy {
fn default() -> Self {
Self::HitCount { threshold: 3 }
}
}
/// Aggregate hit/miss/promotion counters for the tiered cache.
///
/// NOTE(review): within this file these counters are never incremented —
/// the backend only holds them behind an immutable `Arc<TierStats>` — so
/// `MultiTierBackend::stats()` always reports zeros. Confirm whether they
/// are updated elsewhere, or whether the fields should become atomics.
#[derive(Debug, Clone, Default)]
pub struct TierStats {
    /// Lookups served from the fast (L1) tier.
    pub l1_hits: u64,
    /// Lookups that missed L1 but hit the slow (L2) tier.
    pub l2_hits: u64,
    /// Lookups that missed both tiers.
    pub misses: u64,
    /// Entries copied from L2 into L1.
    pub promotions: u64,
}
/// Tuning knobs for [`MultiTierBackend`].
#[derive(Debug, Clone)]
pub struct MultiTierConfig {
    /// When to copy hot L2 entries into L1.
    pub promotion_strategy: PromotionStrategy,
    /// If true, `set` writes to both tiers; otherwise only to L2.
    pub write_through: bool,
    /// Entries with bodies larger than this (in bytes) are kept out of L1;
    /// `None` disables the size cap.
    pub max_l1_entry_size: Option<usize>,
}
impl Default for MultiTierConfig {
fn default() -> Self {
Self {
promotion_strategy: PromotionStrategy::default(),
write_through: true,
max_l1_entry_size: Some(256 * 1024), }
}
}
/// Per-key accounting of L2 hits; drives promotion decisions.
struct KeyStats {
    /// L2 hits observed since creation or the last `reset`.
    l2_hits: AtomicU64,
}

impl KeyStats {
    /// Fresh counter starting at zero.
    fn new() -> Self {
        KeyStats {
            l2_hits: AtomicU64::new(0),
        }
    }

    /// Bump the counter and return the post-increment value.
    fn record_hit(&self) -> u64 {
        let previous = self.l2_hits.fetch_add(1, Ordering::Relaxed);
        previous + 1
    }

    /// Zero the counter (used after a successful promotion).
    fn reset(&self) {
        self.l2_hits.store(0, Ordering::Relaxed);
    }

    /// Current hit count.
    fn hits(&self) -> u64 {
        self.l2_hits.load(Ordering::Relaxed)
    }
}
/// A two-tier cache: a fast L1 in front of a larger, slower L2, with
/// hit-count-based promotion of hot L2 entries into L1.
#[derive(Clone)]
pub struct MultiTierBackend<L1, L2> {
    /// Fast tier, consulted first on every read.
    l1: L1,
    /// Slow tier; all writes land here.
    l2: L2,
    /// Promotion / write-through / size-cap settings.
    config: MultiTierConfig,
    /// Per-key L2 hit counters driving promotion decisions.
    key_stats: Arc<DashMap<String, Arc<KeyStats>>>,
    /// Aggregate stats.
    /// NOTE(review): never mutated anywhere in this file — see `TierStats`.
    tier_stats: Arc<TierStats>,
}
impl<L1, L2> MultiTierBackend<L1, L2>
where
    L1: CacheBackend,
    L2: CacheBackend,
{
    /// Build a two-tier backend with the default configuration
    /// (write-through on, promote after 3 L2 hits, 256 KiB L1 cap).
    pub fn new(l1: L1, l2: L2) -> Self {
        Self {
            l1,
            l2,
            config: MultiTierConfig::default(),
            key_stats: Arc::new(DashMap::new()),
            tier_stats: Arc::new(TierStats::default()),
        }
    }

    /// Start building a backend with custom configuration.
    pub fn builder() -> MultiTierBuilder<L1, L2> {
        MultiTierBuilder::new()
    }

    /// The fast (L1) tier.
    pub fn l1(&self) -> &L1 {
        &self.l1
    }

    /// The slow (L2) tier.
    pub fn l2(&self) -> &L2 {
        &self.l2
    }

    /// Aggregate tier statistics.
    ///
    /// NOTE(review): nothing in this file increments these counters, so
    /// they currently always read zero — confirm intended usage.
    pub fn stats(&self) -> &TierStats {
        &self.tier_stats
    }

    /// Whether `key` has accumulated enough L2 hits to be copied into L1.
    ///
    /// Read-only lookup: `record_hit` runs before this on the L2 hit path,
    /// so the entry normally exists; a missing entry counts as zero hits,
    /// which yields the same decision as the previous implementation that
    /// inserted a fresh zero counter. This avoids taking a DashMap write
    /// lock and allocating a `String` via `entry()` just to read a number.
    fn should_promote(&self, key: &str) -> bool {
        let hits = self.key_stats.get(key).map_or(0, |stats| stats.hits());
        match self.config.promotion_strategy {
            PromotionStrategy::HitCount { threshold } => hits >= threshold,
            // NOTE(review): `HitRate` is a stub — `threshold_per_minute` is
            // ignored and a fixed count of 3 is used, because `KeyStats`
            // keeps no timing information. A true rate needs timestamps.
            PromotionStrategy::HitRate {
                threshold_per_minute: _,
            } => hits >= 3,
        }
    }

    /// Record one L2 hit for `key`, returning the post-increment count.
    /// Creates the per-key counter on first hit.
    fn record_hit(&self, key: &str) -> u64 {
        self.key_stats
            .entry(key.to_string())
            .or_insert_with(|| Arc::new(KeyStats::new()))
            .record_hit()
    }

    /// Synchronously copy `entry` into L1 and reset the key's hit counter.
    // Kept for callers that want promotion without a spawned task; the hot
    // path in `get` promotes via `tokio::spawn` instead.
    #[allow(dead_code)]
    async fn promote(
        &self,
        key: &str,
        entry: CacheEntry,
        ttl: Duration,
        stale_for: Duration,
    ) -> Result<(), CacheError> {
        self.l1.set(key.to_string(), entry, ttl, stale_for).await?;
        if let Some(stats) = self.key_stats.get(key) {
            stats.reset();
        }
        Ok(())
    }
}
#[async_trait]
impl<L1, L2> CacheBackend for MultiTierBackend<L1, L2>
where
    L1: CacheBackend,
    L2: CacheBackend,
{
    /// Read-through lookup: try L1, fall back to L2, and promote hot L2
    /// entries into L1 in the background once they cross the threshold.
    async fn get(&self, key: &str) -> Result<Option<CacheRead>, CacheError> {
        // Fast path: served straight from L1.
        if let Some(entry) = self.l1.get(key).await? {
            #[cfg(feature = "metrics")]
            metrics::counter!("tower_http_cache.tier.l1_hit").increment(1);
            return Ok(Some(entry));
        }
        if let Some(read) = self.l2.get(key).await? {
            #[cfg(feature = "metrics")]
            metrics::counter!("tower_http_cache.tier.l2_hit").increment(1);
            // Count this L2 hit, then check whether the key just crossed
            // the promotion threshold.
            self.record_hit(key);
            if self.should_promote(key) {
                // Respect the L1 size cap: oversized bodies stay L2-only.
                let entry_size = read.entry.body.len();
                let should_promote_l1 = if let Some(max_size) = self.config.max_l1_entry_size {
                    entry_size <= max_size
                } else {
                    true
                };
                if should_promote_l1 {
                    #[cfg(feature = "metrics")]
                    metrics::counter!("tower_http_cache.tier.promoted").increment(1);
                    // Reconstruct a relative TTL from the absolute expiry.
                    // NOTE(review): if `expires_at` is already in the past,
                    // `duration_since` errors and we fall back to 60s, which
                    // re-freshens a stale entry in L1 — confirm intended.
                    let ttl = if let Some(expires_at) = read.expires_at {
                        expires_at
                            .duration_since(std::time::SystemTime::now())
                            .unwrap_or(Duration::from_secs(60))
                    } else {
                        Duration::from_secs(60)
                    };
                    // Stale window = gap between expiry and stale-until.
                    let stale_for = if let (Some(stale_until), Some(expires_at)) =
                        (read.stale_until, read.expires_at)
                    {
                        stale_until.duration_since(expires_at).unwrap_or_default()
                    } else {
                        Duration::ZERO
                    };
                    // Promote on a spawned task so the caller's read latency
                    // is unaffected; failures are deliberately ignored
                    // (promotion is best-effort).
                    let entry = read.entry.clone();
                    let key = key.to_string();
                    let l1 = self.l1.clone();
                    let key_stats = self.key_stats.clone();
                    tokio::spawn(async move {
                        let _ = l1.set(key.clone(), entry, ttl, stale_for).await;
                        // Reset the counter so the next promotion cycle
                        // starts from zero.
                        if let Some(stats) = key_stats.get(&key) {
                            stats.reset();
                        }
                    });
                } else {
                    #[cfg(feature = "metrics")]
                    metrics::counter!("tower_http_cache.tier.promotion_skipped_large").increment(1);
                    #[cfg(feature = "tracing")]
                    tracing::debug!(
                        key = %key,
                        size = entry_size,
                        max_l1_size = ?self.config.max_l1_entry_size,
                        "skipping promotion for large entry"
                    );
                }
            }
            return Ok(Some(read));
        }
        // Missed both tiers.
        Ok(None)
    }

    /// Write to L2 (authoritative), then optionally write-through to L1,
    /// skipping L1 for bodies over the configured size cap.
    async fn set(
        &self,
        key: String,
        entry: CacheEntry,
        ttl: Duration,
        stale_for: Duration,
    ) -> Result<(), CacheError> {
        let entry_size = entry.body.len();
        // An L2 write failure is propagated; L1 failures below are not.
        self.l2
            .set(key.clone(), entry.clone(), ttl, stale_for)
            .await?;
        if self.config.write_through {
            let should_write_l1 = if let Some(max_size) = self.config.max_l1_entry_size {
                if entry_size <= max_size {
                    true
                } else {
                    #[cfg(feature = "metrics")]
                    metrics::counter!("tower_http_cache.tier.l1_skipped_large").increment(1);
                    #[cfg(feature = "tracing")]
                    tracing::debug!(
                        key = %key,
                        size = entry_size,
                        max_l1_size = max_size,
                        "skipping L1 write for large entry"
                    );
                    false
                }
            } else {
                true
            };
            if should_write_l1 {
                // Best-effort: an L1 write failure must not fail the set.
                let _ = self.l1.set(key.clone(), entry, ttl, stale_for).await;
            }
        }
        Ok(())
    }

    /// Remove `key` from both tiers and drop its promotion counter.
    async fn invalidate(&self, key: &str) -> Result<(), CacheError> {
        // Attempt both tiers even if the first fails.
        let l1_result = self.l1.invalidate(key).await;
        let l2_result = self.l2.invalidate(key).await;
        self.key_stats.remove(key);
        // `and` yields L1's error if any, otherwise L2's result.
        l1_result.and(l2_result)
    }

    /// Union of the tag's keys across both tiers, sorted and deduplicated.
    async fn get_keys_by_tag(&self, tag: &str) -> Result<Vec<String>, CacheError> {
        let mut keys = self.l1.get_keys_by_tag(tag).await?;
        let l2_keys = self.l2.get_keys_by_tag(tag).await?;
        keys.extend(l2_keys);
        keys.sort();
        keys.dedup();
        Ok(keys)
    }

    /// Invalidate the tag in both tiers, returning the combined count.
    ///
    /// NOTE(review): keys present in both tiers are counted twice here.
    async fn invalidate_by_tag(&self, tag: &str) -> Result<usize, CacheError> {
        let l1_count = self.l1.invalidate_by_tag(tag).await?;
        let l2_count = self.l2.invalidate_by_tag(tag).await?;
        Ok(l1_count + l2_count)
    }

    /// Union of tags across both tiers, sorted and deduplicated.
    async fn list_tags(&self) -> Result<Vec<String>, CacheError> {
        let mut tags = self.l1.list_tags().await?;
        let l2_tags = self.l2.list_tags().await?;
        tags.extend(l2_tags);
        tags.sort();
        tags.dedup();
        Ok(tags)
    }
}
/// Fluent builder for [`MultiTierBackend`]; both tiers must be supplied
/// before calling `build` (which panics otherwise).
pub struct MultiTierBuilder<L1, L2> {
    /// Fast tier (required).
    l1: Option<L1>,
    /// Slow tier (required).
    l2: Option<L2>,
    /// Settings, starting from `MultiTierConfig::default()`.
    config: MultiTierConfig,
}
impl<L1, L2> MultiTierBuilder<L1, L2> {
pub fn new() -> Self {
Self {
l1: None,
l2: None,
config: MultiTierConfig::default(),
}
}
pub fn l1(mut self, backend: L1) -> Self {
self.l1 = Some(backend);
self
}
pub fn l2(mut self, backend: L2) -> Self {
self.l2 = Some(backend);
self
}
pub fn promotion_strategy(mut self, strategy: PromotionStrategy) -> Self {
self.config.promotion_strategy = strategy;
self
}
pub fn promotion_threshold(mut self, threshold: u64) -> Self {
self.config.promotion_strategy = PromotionStrategy::HitCount { threshold };
self
}
pub fn write_through(mut self, enabled: bool) -> Self {
self.config.write_through = enabled;
self
}
pub fn max_l1_entry_size(mut self, size: Option<usize>) -> Self {
self.config.max_l1_entry_size = size;
self
}
pub fn build(self) -> MultiTierBackend<L1, L2> {
MultiTierBackend {
l1: self.l1.expect("L1 backend is required"),
l2: self.l2.expect("L2 backend is required"),
config: self.config,
key_stats: Arc::new(DashMap::new()),
tier_stats: Arc::new(TierStats::default()),
}
}
}
impl<L1, L2> Default for MultiTierBuilder<L1, L2> {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::memory::InMemoryBackend;
    use bytes::Bytes;
    use http::{StatusCode, Version};

    /// Minimal 200 OK entry with a 4-byte body (well under the L1 cap).
    fn test_entry() -> CacheEntry {
        CacheEntry::new(
            StatusCode::OK,
            Version::HTTP_11,
            Vec::new(),
            Bytes::from_static(b"test"),
        )
    }

    /// An entry seeded directly into L1 is served by the tiered `get`.
    #[tokio::test]
    async fn multi_tier_l1_hit() {
        let l1 = InMemoryBackend::new(100);
        let l2 = InMemoryBackend::new(1000);
        let backend = MultiTierBackend::new(l1.clone(), l2);
        l1.set(
            "key".to_string(),
            test_entry(),
            Duration::from_secs(60),
            Duration::ZERO,
        )
        .await
        .unwrap();
        let result = backend.get("key").await.unwrap();
        assert!(result.is_some());
    }

    /// Three L2 hits reach the threshold; the entry should then appear in
    /// L1. Promotion runs on a spawned task, hence the 50ms sleep before
    /// asserting (timing-sensitive by design).
    #[tokio::test]
    async fn multi_tier_l2_hit_and_promote() {
        let l1 = InMemoryBackend::new(100);
        let l2 = InMemoryBackend::new(1000);
        let backend = MultiTierBackend::builder()
            .l1(l1.clone())
            .l2(l2.clone())
            .promotion_threshold(3)
            .build();
        l2.set(
            "key".to_string(),
            test_entry(),
            Duration::from_secs(60),
            Duration::ZERO,
        )
        .await
        .unwrap();
        // Each get records an L2 hit; the third crosses the threshold.
        for _ in 0..3 {
            let result = backend.get("key").await.unwrap();
            assert!(result.is_some());
        }
        // Give the spawned promotion task time to write into L1.
        tokio::time::sleep(Duration::from_millis(50)).await;
        let l1_result = l1.get("key").await.unwrap();
        assert!(l1_result.is_some());
    }

    /// With write-through enabled, `set` lands in both tiers.
    #[tokio::test]
    async fn multi_tier_set_writes_to_both_tiers() {
        let l1 = InMemoryBackend::new(100);
        let l2 = InMemoryBackend::new(1000);
        let backend = MultiTierBackend::builder()
            .l1(l1.clone())
            .l2(l2.clone())
            .write_through(true)
            .build();
        backend
            .set(
                "key".to_string(),
                test_entry(),
                Duration::from_secs(60),
                Duration::ZERO,
            )
            .await
            .unwrap();
        assert!(l1.get("key").await.unwrap().is_some());
        assert!(l2.get("key").await.unwrap().is_some());
    }

    /// `invalidate` removes the key from both tiers.
    #[tokio::test]
    async fn multi_tier_invalidate_both_tiers() {
        let l1 = InMemoryBackend::new(100);
        let l2 = InMemoryBackend::new(1000);
        let backend = MultiTierBackend::new(l1.clone(), l2.clone());
        l1.set(
            "key".to_string(),
            test_entry(),
            Duration::from_secs(60),
            Duration::ZERO,
        )
        .await
        .unwrap();
        l2.set(
            "key".to_string(),
            test_entry(),
            Duration::from_secs(60),
            Duration::ZERO,
        )
        .await
        .unwrap();
        backend.invalidate("key").await.unwrap();
        assert!(l1.get("key").await.unwrap().is_none());
        assert!(l2.get("key").await.unwrap().is_none());
    }

    /// A key absent from both tiers yields `None`.
    #[tokio::test]
    async fn multi_tier_miss() {
        let l1 = InMemoryBackend::new(100);
        let l2 = InMemoryBackend::new(1000);
        let backend = MultiTierBackend::new(l1, l2);
        let result = backend.get("nonexistent").await.unwrap();
        assert!(result.is_none());
    }

    /// A custom hit-count threshold (5) is honored: five L2 hits trigger
    /// promotion into L1.
    #[tokio::test]
    async fn promotion_strategy_hit_count() {
        let strategy = PromotionStrategy::HitCount { threshold: 5 };
        let l1 = InMemoryBackend::new(100);
        let l2 = InMemoryBackend::new(1000);
        let backend = MultiTierBackend::builder()
            .l1(l1.clone())
            .l2(l2.clone())
            .promotion_strategy(strategy)
            .build();
        l2.set(
            "key".to_string(),
            test_entry(),
            Duration::from_secs(60),
            Duration::ZERO,
        )
        .await
        .unwrap();
        for _ in 0..5 {
            backend.get("key").await.unwrap();
        }
        // Wait for the background promotion task.
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert!(l1.get("key").await.unwrap().is_some());
    }
}