use std::time::Duration;
use super::backend::{BackendStats, CacheBackend, CacheResult};
use super::invalidation::EntityTag;
use super::key::{CacheKey, KeyPattern};
/// Per-tier behaviour of a `TieredCache`: which tiers `set` writes to, the TTLs
/// applied to each tier, and whether a tier's failures are returned to the caller.
///
/// The default writes through to both tiers, uses a 60 s L1 TTL and a 300 s L2 TTL,
/// and tolerates failures on either tier.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TieredCacheConfig {
    /// Whether `set` writes values into L1.
    pub write_through_l1: bool,
    /// Whether `set` writes values into L2.
    pub write_through_l2: bool,
    /// TTL for L1 entries: used when the caller gives none, and otherwise caps the
    /// caller's TTL so L1 entries never outlive it.
    pub l1_ttl: Option<Duration>,
    /// TTL for L2 entries, used when the caller gives none.
    pub l2_ttl: Option<Duration>,
    /// When `true`, L1 failures are returned to the caller instead of being tolerated.
    pub l1_required: bool,
    /// When `true`, L2 failures are returned to the caller instead of being tolerated.
    pub l2_required: bool,
}

impl Default for TieredCacheConfig {
    fn default() -> Self {
        Self {
            write_through_l1: true,
            write_through_l2: true,
            l1_ttl: Some(Duration::from_secs(60)),
            l2_ttl: Some(Duration::from_secs(300)),
            l1_required: false,
            l2_required: false,
        }
    }
}

// The setters consume and return `Self`; `#[must_use]` catches a call whose result
// is dropped (e.g. `cfg.require_l1();`), which would otherwise discard the change.
impl TieredCacheConfig {
    /// Sets the L1 TTL.
    #[must_use]
    pub fn with_l1_ttl(mut self, ttl: Duration) -> Self {
        self.l1_ttl = Some(ttl);
        self
    }

    /// Sets the L2 TTL.
    #[must_use]
    pub fn with_l2_ttl(mut self, ttl: Duration) -> Self {
        self.l2_ttl = Some(ttl);
        self
    }

    /// Makes L1 failures fatal (propagated to the caller).
    #[must_use]
    pub fn require_l1(mut self) -> Self {
        self.l1_required = true;
        self
    }

    /// Makes L2 failures fatal (propagated to the caller).
    #[must_use]
    pub fn require_l2(mut self) -> Self {
        self.l2_required = true;
        self
    }

    /// Stops `set` from writing into L1.
    #[must_use]
    pub fn no_write_l1(mut self) -> Self {
        self.write_through_l1 = false;
        self
    }

    /// Stops `set` from writing into L2.
    #[must_use]
    pub fn no_write_l2(mut self) -> Self {
        self.write_through_l2 = false;
        self
    }
}
/// A two-level cache that consults a first-tier backend (`L1`) before a
/// second-tier backend (`L2`).
///
/// How writes, TTLs, and per-tier failures are handled is governed by the
/// attached `TieredCacheConfig`.
pub struct TieredCache<L1, L2>
where
    L1: CacheBackend,
    L2: CacheBackend,
{
    /// First tier; consulted before `l2` on reads.
    l1: L1,
    /// Second tier; consulted when `l1` misses.
    l2: L2,
    /// Per-tier write, TTL, and failure policy.
    config: TieredCacheConfig,
}

impl<L1, L2> TieredCache<L1, L2>
where
    L1: CacheBackend,
    L2: CacheBackend,
{
    /// Creates a tiered cache over `l1` and `l2` using `TieredCacheConfig::default()`.
    pub fn new(l1: L1, l2: L2) -> Self {
        Self::with_config(l1, l2, TieredCacheConfig::default())
    }

    /// Creates a tiered cache over `l1` and `l2` with an explicit configuration.
    pub fn with_config(l1: L1, l2: L2, config: TieredCacheConfig) -> Self {
        Self { config, l1, l2 }
    }

    /// Borrows the first-tier backend.
    pub fn l1(&self) -> &L1 {
        &self.l1
    }

    /// Borrows the second-tier backend.
    pub fn l2(&self) -> &L2 {
        &self.l2
    }

    /// Borrows the active configuration.
    pub fn config(&self) -> &TieredCacheConfig {
        &self.config
    }
}
/// Applies a tier's failure policy to the outcome of one operation on that tier.
///
/// `Ok` results pass through unchanged. An `Err` is propagated when the tier is
/// `required`; otherwise it is tolerated and replaced by `Ok(fallback)`, so the
/// tiered cache keeps working on the remaining tier.
fn apply_tier_policy<V>(result: CacheResult<V>, required: bool, fallback: V) -> CacheResult<V> {
    match result {
        Ok(value) => Ok(value),
        Err(e) if required => Err(e),
        Err(_) => Ok(fallback),
    }
}

/// `CacheBackend` implementation that layers `L1` in front of `L2`.
///
/// Every operation applies the same failure policy: errors from a tier are tolerated
/// unless that tier's `l1_required` / `l2_required` flag is set, in which case they
/// are returned to the caller.
impl<L1, L2> CacheBackend for TieredCache<L1, L2>
where
    L1: CacheBackend,
    L2: CacheBackend,
{
    /// Looks `key` up in L1, falling back to L2 on a miss or a tolerated L1 failure.
    ///
    /// An L2 hit is *not* copied into L1. `T` is only `DeserializeOwned` here, so it
    /// cannot be re-serialized for an L1 `set`.
    async fn get<T>(&self, key: &CacheKey) -> CacheResult<Option<T>>
    where
        T: serde::de::DeserializeOwned,
    {
        let l1_result = self.l1.get::<T>(key).await;
        if let Some(value) = apply_tier_policy(l1_result, self.config.l1_required, None)? {
            return Ok(Some(value));
        }
        let l2_result = self.l2.get::<T>(key).await;
        apply_tier_policy(l2_result, self.config.l2_required, None)
    }

    /// Writes `value` to each tier enabled by `write_through_l2` / `write_through_l1`.
    ///
    /// L2 receives the caller's `ttl`, or `l2_ttl` when none is given. L1 receives the
    /// shorter of `ttl` and `l1_ttl`, or whichever one is present. L2 is written first,
    /// so a failure on a required L2 aborts before L1 is touched.
    async fn set<T>(&self, key: &CacheKey, value: &T, ttl: Option<Duration>) -> CacheResult<()>
    where
        T: serde::Serialize + Sync,
    {
        if self.config.write_through_l2 {
            let l2_ttl = ttl.or(self.config.l2_ttl);
            let result = self.l2.set(key, value, l2_ttl).await;
            apply_tier_policy(result, self.config.l2_required, ())?;
        }
        if self.config.write_through_l1 {
            // `l1_ttl` caps how long an L1 entry may live, even if the caller asked
            // for longer.
            let l1_ttl = match (ttl, self.config.l1_ttl) {
                (Some(requested), Some(cap)) => Some(requested.min(cap)),
                (requested, cap) => requested.or(cap),
            };
            let result = self.l1.set(key, value, l1_ttl).await;
            apply_tier_policy(result, self.config.l1_required, ())?;
        }
        Ok(())
    }

    /// Removes `key` from both tiers and reports whether either tier held it.
    ///
    /// Both tiers are always attempted before any error is reported, so a failing
    /// L2 never leaves a stale value behind in L1.
    async fn delete(&self, key: &CacheKey) -> CacheResult<bool> {
        let l2_result = self.l2.delete(key).await;
        let l1_result = self.l1.delete(key).await;
        let l2_deleted = apply_tier_policy(l2_result, self.config.l2_required, false)?;
        let l1_deleted = apply_tier_policy(l1_result, self.config.l1_required, false)?;
        Ok(l1_deleted || l2_deleted)
    }

    /// Reports whether `key` is present in L1 or, failing that, in L2.
    ///
    /// A tolerated failure on a tier counts as "not present in that tier".
    async fn exists(&self, key: &CacheKey) -> CacheResult<bool> {
        let l1_result = self.l1.exists(key).await;
        if apply_tier_policy(l1_result, self.config.l1_required, false)? {
            return Ok(true);
        }
        let l2_result = self.l2.exists(key).await;
        apply_tier_policy(l2_result, self.config.l2_required, false)
    }

    /// Invalidates every key matching `pattern` in both tiers.
    ///
    /// Returns the larger of the two per-tier counts. A key held by both tiers is
    /// counted once rather than twice.
    async fn invalidate_pattern(&self, pattern: &KeyPattern) -> CacheResult<u64> {
        let l2_result = self.l2.invalidate_pattern(pattern).await;
        let l1_result = self.l1.invalidate_pattern(pattern).await;
        let l2_count = apply_tier_policy(l2_result, self.config.l2_required, 0)?;
        let l1_count = apply_tier_policy(l1_result, self.config.l1_required, 0)?;
        Ok(l1_count.max(l2_count))
    }

    /// Invalidates every entry tagged with any of `tags` in both tiers.
    ///
    /// Returns the larger of the two per-tier counts.
    async fn invalidate_tags(&self, tags: &[EntityTag]) -> CacheResult<u64> {
        let l2_result = self.l2.invalidate_tags(tags).await;
        let l1_result = self.l1.invalidate_tags(tags).await;
        let l2_count = apply_tier_policy(l2_result, self.config.l2_required, 0)?;
        let l1_count = apply_tier_policy(l1_result, self.config.l1_required, 0)?;
        Ok(l1_count.max(l2_count))
    }

    /// Clears both tiers. Both are attempted before any error is reported.
    async fn clear(&self) -> CacheResult<()> {
        let l2_result = self.l2.clear().await;
        let l1_result = self.l1.clear().await;
        apply_tier_policy(l2_result, self.config.l2_required, ())?;
        apply_tier_policy(l1_result, self.config.l1_required, ())?;
        Ok(())
    }

    /// Reports the number of entries in L2. L1 is not counted.
    async fn len(&self) -> CacheResult<usize> {
        self.l2.len().await
    }

    /// Combines per-tier statistics.
    ///
    /// `entries` and `connections` come from L2, and `memory_bytes` comes from L1.
    /// A tier whose stats cannot be read contributes `BackendStats::default()`.
    async fn stats(&self) -> CacheResult<BackendStats> {
        let l1_stats = self.l1.stats().await.unwrap_or_default();
        let l2_stats = self.l2.stats().await.unwrap_or_default();
        Ok(BackendStats {
            entries: l2_stats.entries,
            memory_bytes: l1_stats.memory_bytes,
            connections: l2_stats.connections,
            info: Some(format!(
                "Tiered: L1={} entries, L2={} entries",
                l1_stats.entries, l2_stats.entries
            )),
        })
    }
}
/// Step-by-step constructor for a `TieredCache`.
///
/// Both tiers must be supplied before building. `try_build` reports a missing tier
/// as `None`, while `build` panics.
pub struct TieredCacheBuilder<L1, L2>
where
    L1: CacheBackend,
    L2: CacheBackend,
{
    /// First-tier backend, once supplied.
    l1: Option<L1>,
    /// Second-tier backend, once supplied.
    l2: Option<L2>,
    /// Configuration handed to the built cache; starts as `TieredCacheConfig::default()`.
    config: TieredCacheConfig,
}

impl<L1, L2> Default for TieredCacheBuilder<L1, L2>
where
    L1: CacheBackend,
    L2: CacheBackend,
{
    fn default() -> Self {
        Self {
            l1: None,
            l2: None,
            config: TieredCacheConfig::default(),
        }
    }
}

impl<L1, L2> TieredCacheBuilder<L1, L2>
where
    L1: CacheBackend,
    L2: CacheBackend,
{
    /// Creates a builder with no tiers and the default configuration.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the first-tier backend.
    #[must_use]
    pub fn l1(mut self, cache: L1) -> Self {
        self.l1 = Some(cache);
        self
    }

    /// Sets the second-tier backend.
    #[must_use]
    pub fn l2(mut self, cache: L2) -> Self {
        self.l2 = Some(cache);
        self
    }

    /// Replaces the whole configuration. This overwrites any TTLs set earlier.
    #[must_use]
    pub fn config(mut self, config: TieredCacheConfig) -> Self {
        self.config = config;
        self
    }

    /// Sets the L1 TTL in the configuration.
    #[must_use]
    pub fn l1_ttl(mut self, ttl: Duration) -> Self {
        self.config.l1_ttl = Some(ttl);
        self
    }

    /// Sets the L2 TTL in the configuration.
    #[must_use]
    pub fn l2_ttl(mut self, ttl: Duration) -> Self {
        self.config.l2_ttl = Some(ttl);
        self
    }

    /// Builds the cache, or returns `None` if either tier was never supplied.
    pub fn try_build(self) -> Option<TieredCache<L1, L2>> {
        Some(TieredCache {
            l1: self.l1?,
            l2: self.l2?,
            config: self.config,
        })
    }

    /// Builds the cache.
    ///
    /// # Panics
    ///
    /// Panics if `l1` or `l2` was never called. Use `try_build` to handle that case
    /// without panicking.
    pub fn build(self) -> TieredCache<L1, L2> {
        TieredCache {
            l1: self.l1.expect("L1 cache must be set"),
            l2: self.l2.expect("L2 cache must be set"),
            config: self.config,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::super::backend::NoopCache;
    use super::super::memory::{MemoryCache, MemoryCacheConfig};
    use super::*;

    /// A value written through the tiered cache is readable back from it.
    #[tokio::test]
    async fn test_tiered_cache_l1_hit() {
        let l1 = MemoryCache::new(MemoryCacheConfig::new(100));
        let l2 = MemoryCache::new(MemoryCacheConfig::new(100));
        let cache = TieredCache::new(l1, l2);
        let key = CacheKey::new("test", "key1");
        cache.set(&key, &"hello", None).await.unwrap();
        let value: Option<String> = cache.get(&key).await.unwrap();
        assert_eq!(value, Some("hello".to_string()));
    }

    /// A key present only in L2 is still found through the tiered cache.
    #[tokio::test]
    async fn test_tiered_cache_l2_fallback() {
        let l1 = MemoryCache::new(MemoryCacheConfig::new(100));
        let l2 = MemoryCache::new(MemoryCacheConfig::new(100));
        let key = CacheKey::new("test", "key1");
        // Seed L2 directly so the lookup must miss L1 and fall through.
        l2.set(&key, &"from l2", None).await.unwrap();
        let cache = TieredCache::new(l1, l2);
        let value: Option<String> = cache.get(&key).await.unwrap();
        assert_eq!(value, Some("from l2".to_string()));
    }

    /// Pattern invalidation removes the entry so a later read misses.
    #[tokio::test]
    async fn test_tiered_cache_invalidation() {
        let l1 = MemoryCache::new(MemoryCacheConfig::new(100));
        let l2 = MemoryCache::new(MemoryCacheConfig::new(100));
        let cache = TieredCache::new(l1, l2);
        let key = CacheKey::new("User", "id:1");
        cache.set(&key, &"user data", None).await.unwrap();
        let count = cache
            .invalidate_pattern(&KeyPattern::entity("User"))
            .await
            .unwrap();
        assert!(count >= 1);
        let value: Option<String> = cache.get(&key).await.unwrap();
        assert!(value.is_none());
    }

    /// With a no-op L2, reads are served entirely from L1.
    #[tokio::test]
    async fn test_tiered_cache_with_noop_l2() {
        let l1 = MemoryCache::new(MemoryCacheConfig::new(100));
        let l2 = NoopCache;
        let cache = TieredCache::new(l1, l2);
        let key = CacheKey::new("test", "key1");
        cache.set(&key, &"hello", None).await.unwrap();
        let value: Option<String> = cache.get(&key).await.unwrap();
        assert_eq!(value, Some("hello".to_string()));
    }

    /// `delete` removes the entry from both tiers, not just the first one.
    #[tokio::test]
    async fn test_tiered_cache_delete_clears_both_tiers() {
        let l1 = MemoryCache::new(MemoryCacheConfig::new(100));
        let l2 = MemoryCache::new(MemoryCacheConfig::new(100));
        let cache = TieredCache::new(l1, l2);
        let key = CacheKey::new("test", "key1");
        cache.set(&key, &"hello", None).await.unwrap();
        assert!(cache.delete(&key).await.unwrap());
        let in_l1: Option<String> = cache.l1().get(&key).await.unwrap();
        let in_l2: Option<String> = cache.l2().get(&key).await.unwrap();
        assert!(in_l1.is_none());
        assert!(in_l2.is_none());
    }

    /// With `no_write_l1`, `set` populates only L2.
    #[tokio::test]
    async fn test_tiered_cache_no_write_l1_skips_l1() {
        let l1 = MemoryCache::new(MemoryCacheConfig::new(100));
        let l2 = MemoryCache::new(MemoryCacheConfig::new(100));
        let cache = TieredCache::with_config(l1, l2, TieredCacheConfig::default().no_write_l1());
        let key = CacheKey::new("test", "key1");
        cache.set(&key, &"hello", None).await.unwrap();
        let in_l1: Option<String> = cache.l1().get(&key).await.unwrap();
        let in_l2: Option<String> = cache.l2().get(&key).await.unwrap();
        assert!(in_l1.is_none());
        assert_eq!(in_l2, Some("hello".to_string()));
    }

    /// The builder forwards TTL overrides into the built cache's configuration.
    #[tokio::test]
    async fn test_tiered_cache_builder_applies_ttls() {
        let cache = TieredCacheBuilder::new()
            .l1(MemoryCache::new(MemoryCacheConfig::new(100)))
            .l2(NoopCache)
            .l1_ttl(Duration::from_secs(5))
            .l2_ttl(Duration::from_secs(50))
            .build();
        assert_eq!(cache.config().l1_ttl, Some(Duration::from_secs(5)));
        assert_eq!(cache.config().l2_ttl, Some(Duration::from_secs(50)));
        let key = CacheKey::new("test", "key1");
        cache.set(&key, &"hello", None).await.unwrap();
        let value: Option<String> = cache.get(&key).await.unwrap();
        assert_eq!(value, Some("hello".to_string()));
    }
}