use super::new_backend::CacheBackend;
use crate::error::{CacheError, Result};
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
/// Two-level cache that reads from `l1` first and falls back to `l2`.
///
/// Cloning is cheap and shares state: both backends and the degraded flag sit
/// behind `Arc`s, so every clone talks to the same tiers and observes the same
/// degraded status.
#[derive(Clone)]
pub struct TieredBackend {
    /// Tier consulted first on reads.
    l1: Arc<dyn CacheBackend>,
    /// Fallback tier; failures here put the cache into degraded mode.
    l2: Arc<dyn CacheBackend>,
    /// When `true`, a value found only in `l2` is copied into `l1` on `get`.
    auto_promote: bool,
    /// `true` while `l2` is considered unavailable; shared by all clones.
    degraded: Arc<tokio::sync::RwLock<bool>>,
}
impl TieredBackend {
    /// Creates a tiered cache that takes ownership of both backends.
    ///
    /// Auto-promotion of L2 hits into L1 is enabled.
    pub fn new(l1: impl CacheBackend + 'static, l2: impl CacheBackend + 'static) -> Self {
        Self::from_arc(Arc::new(l1), Arc::new(l2))
    }

    /// Creates a tiered cache from backends that are already shared via `Arc`.
    ///
    /// Auto-promotion of L2 hits into L1 is enabled, and the cache starts out
    /// not degraded.
    pub fn from_arc(l1: Arc<dyn CacheBackend>, l2: Arc<dyn CacheBackend>) -> Self {
        let degraded = Arc::new(tokio::sync::RwLock::new(false));
        Self {
            l1,
            l2,
            auto_promote: true,
            degraded,
        }
    }

    /// Returns a fresh [`TieredBackendBuilder`] for step-by-step construction.
    pub fn builder() -> TieredBackendBuilder {
        TieredBackendBuilder::default()
    }

    /// Reports whether the cache currently considers L2 unavailable.
    pub async fn is_degraded(&self) -> bool {
        let flag = self.degraded.read().await;
        *flag
    }

    /// Records the current L2 availability for all clones of this cache.
    async fn set_degraded(&self, degraded: bool) {
        let mut flag = self.degraded.write().await;
        *flag = degraded;
    }
}
/// Step-by-step constructor for [`TieredBackend`].
///
/// Both tiers must be supplied before `build` succeeds. Auto-promotion starts
/// out enabled, matching [`TieredBackend::new`] and [`TieredBackend::from_arc`].
pub struct TieredBackendBuilder {
    l1: Option<Arc<dyn CacheBackend>>,
    l2: Option<Arc<dyn CacheBackend>>,
    auto_promote: bool,
}

impl Default for TieredBackendBuilder {
    fn default() -> Self {
        Self {
            l1: None,
            l2: None,
            // A derived `Default` would yield `false` here and silently turn
            // promotion off for builder users, unlike the direct constructors.
            auto_promote: true,
        }
    }
}
impl TieredBackendBuilder {
    /// Sets the L1 tier (consulted first on reads), taking ownership of it.
    #[must_use]
    pub fn l1(mut self, l1: impl CacheBackend + 'static) -> Self {
        self.l1 = Some(Arc::new(l1));
        self
    }

    /// Sets the L1 tier from a backend that is already shared via `Arc`.
    #[must_use]
    pub fn l1_arc(mut self, l1: Arc<dyn CacheBackend>) -> Self {
        self.l1 = Some(l1);
        self
    }

    /// Sets the L2 (fallback) tier, taking ownership of it.
    #[must_use]
    pub fn l2(mut self, l2: impl CacheBackend + 'static) -> Self {
        self.l2 = Some(Arc::new(l2));
        self
    }

    /// Sets the L2 tier from a backend that is already shared via `Arc`.
    #[must_use]
    pub fn l2_arc(mut self, l2: Arc<dyn CacheBackend>) -> Self {
        self.l2 = Some(l2);
        self
    }

    /// Enables or disables copying L2 hits into L1 on `get`.
    #[must_use]
    pub fn auto_promote(mut self, auto_promote: bool) -> Self {
        self.auto_promote = auto_promote;
        self
    }

    /// Assembles the [`TieredBackend`]; it starts out not degraded.
    ///
    /// # Errors
    /// Returns `CacheError::ConfigError` if either tier was never set.
    pub fn build(self) -> Result<TieredBackend> {
        let l1 = self
            .l1
            .ok_or_else(|| CacheError::ConfigError("L1 backend is required".to_string()))?;
        let l2 = self
            .l2
            .ok_or_else(|| CacheError::ConfigError("L2 backend is required".to_string()))?;
        Ok(TieredBackend {
            l1,
            l2,
            auto_promote: self.auto_promote,
            degraded: Arc::new(tokio::sync::RwLock::new(false)),
        })
    }
}
#[async_trait]
impl CacheBackend for TieredBackend {
    /// Looks `key` up in L1, falling back to L2 on a miss.
    ///
    /// When `auto_promote` is set, an L2 hit is copied into L1 using the TTL
    /// that L2 reports for the key. This keeps a promoted entry from outliving
    /// its L2 original and being served stale forever. If the L2 TTL cannot be
    /// read, promotion is skipped. Promotion problems are logged and never fail
    /// the read.
    ///
    /// # Errors
    /// Propagates errors from the L1 and L2 `get` calls.
    async fn get(&self, key: &str) -> Result<Option<Vec<u8>>> {
        if let Some(value) = self.l1.get(key).await? {
            return Ok(Some(value));
        }
        let value = match self.l2.get(key).await? {
            Some(value) => value,
            None => return Ok(None),
        };
        if self.auto_promote {
            // NOTE(review): assumes `ttl` yields `Ok(None)` for "no expiry";
            // confirm how each backend reports a key that vanished between
            // the `get` above and this call.
            match self.l2.ttl(key).await {
                Ok(ttl) => {
                    if let Err(e) = self.l1.set(key, value.clone(), ttl).await {
                        tracing::warn!("Failed to promote value to L1: {}", e);
                    }
                }
                Err(e) => {
                    tracing::warn!("Skipping L1 promotion, failed to read L2 TTL: {}", e);
                }
            }
        }
        Ok(Some(value))
    }

    /// Writes `value` with `ttl` to L1, then to L2.
    ///
    /// An L1 failure is only logged. The L2 outcome drives the degraded flag:
    /// success clears it and failure sets it.
    ///
    /// # Errors
    /// Returns `CacheError::Degraded` if the L2 write fails. The value may
    /// still have been stored in L1.
    async fn set(&self, key: &str, value: Vec<u8>, ttl: Option<Duration>) -> Result<()> {
        if let Err(e) = self.l1.set(key, value.clone(), ttl).await {
            tracing::warn!("Failed to set value in L1: {}", e);
        }
        match self.l2.set(key, value, ttl).await {
            Ok(_) => {
                self.set_degraded(false).await;
                Ok(())
            }
            Err(e) => {
                tracing::warn!("L2 backend failed, entering degraded mode: {}", e);
                self.set_degraded(true).await;
                Err(CacheError::Degraded(
                    "L2 backend unavailable, operating in L1-only mode".to_string(),
                ))
            }
        }
    }

    /// Removes `key` from both tiers; both deletions are always attempted.
    ///
    /// The call only succeeds if both tiers succeed. A copy left in either
    /// tier would still be returned by `get`, and an L2 copy would even be
    /// re-promoted into L1. The L2 outcome updates the degraded flag.
    ///
    /// # Errors
    /// Returns the L2 error if L2 failed, otherwise the L1 error.
    async fn delete(&self, key: &str) -> Result<()> {
        let l1_result = self.l1.delete(key).await;
        let l2_result = self.l2.delete(key).await;
        self.set_degraded(l2_result.is_err()).await;
        l2_result.and(l1_result)
    }

    /// Reports whether `key` is present in L1 or, failing that, in L2.
    ///
    /// # Errors
    /// Propagates errors from either tier's `exists`.
    async fn exists(&self, key: &str) -> Result<bool> {
        if self.l1.exists(key).await? {
            return Ok(true);
        }
        self.l2.exists(key).await
    }

    /// Clears both tiers; both clears are always attempted.
    ///
    /// The call only succeeds if both tiers succeed, for the same stale-data
    /// reason as `delete`. The L2 outcome updates the degraded flag.
    ///
    /// # Errors
    /// Returns the L2 error if L2 failed, otherwise the L1 error.
    async fn clear(&self) -> Result<()> {
        let l1_result = self.l1.clear().await;
        let l2_result = self.l2.clear().await;
        self.set_degraded(l2_result.is_err()).await;
        l2_result.and(l1_result)
    }

    /// Closes both tiers; both closes are always attempted.
    ///
    /// # Errors
    /// Returns the L2 error if L2 failed, otherwise the L1 error.
    async fn close(&self) -> Result<()> {
        let l1_result = self.l1.close().await;
        let l2_result = self.l2.close().await;
        l2_result.and(l1_result)
    }

    /// Returns the TTL L1 reports for `key`.
    ///
    /// Falls back to L2 when L1 reports no TTL or errors.
    async fn ttl(&self, key: &str) -> Result<Option<Duration>> {
        if let Ok(Some(ttl)) = self.l1.ttl(key).await {
            return Ok(Some(ttl));
        }
        self.l2.ttl(key).await
    }

    /// Applies `ttl` to `key` in both tiers.
    ///
    /// Returns `true` if either tier reports success. Tier errors count as
    /// `false` and are not returned. The L2 outcome updates the degraded flag.
    async fn expire(&self, key: &str, ttl: Duration) -> Result<bool> {
        let l1_result = self.l1.expire(key, ttl).await;
        let l2_result = self.l2.expire(key, ttl).await;
        self.set_degraded(l2_result.is_err()).await;
        Ok(l1_result.unwrap_or(false) || l2_result.unwrap_or(false))
    }

    /// Probes both tiers and returns L1's health.
    ///
    /// A probe error counts as unhealthy. The degraded flag is set exactly
    /// when L2 is unhealthy, including when L1 is down too. The write paths
    /// already treat any L2 failure as degraded.
    async fn health_check(&self) -> Result<bool> {
        let l1_healthy = self.l1.health_check().await.unwrap_or(false);
        let l2_healthy = self.l2.health_check().await.unwrap_or(false);
        self.set_degraded(!l2_healthy).await;
        Ok(l1_healthy)
    }

    /// Combines tiered metadata with each tier's stats.
    ///
    /// Tier keys are prefixed with `l1_` or `l2_`. A tier whose `stats` call
    /// fails is simply omitted.
    async fn stats(&self) -> Result<HashMap<String, String>> {
        let mut stats = HashMap::new();
        stats.insert("type".to_string(), "tiered".to_string());
        stats.insert("degraded".to_string(), self.is_degraded().await.to_string());
        stats.insert("auto_promote".to_string(), self.auto_promote.to_string());
        if let Ok(l1_stats) = self.l1.stats().await {
            for (key, value) in l1_stats {
                stats.insert(format!("l1_{}", key), value);
            }
        }
        if let Ok(l2_stats) = self.l2.stats().await {
            for (key, value) in l2_stats {
                stats.insert(format!("l2_{}", key), value);
            }
        }
        Ok(stats)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::memory::MemoryBackend;

    /// Round-trips a key through set/get/exists/delete on the tiered cache.
    #[tokio::test]
    async fn test_tiered_backend_basic() {
        let front = MemoryBackend::new();
        let back = MemoryBackend::new();
        let cache = TieredBackend::new(front, back);

        cache.set("key1", b"value1".to_vec(), None).await.unwrap();
        let fetched = cache.get("key1").await.unwrap();
        assert_eq!(fetched, Some(b"value1".to_vec()));

        assert!(cache.exists("key1").await.unwrap());
        assert!(!cache.exists("key2").await.unwrap());

        cache.delete("key1").await.unwrap();
        assert!(!cache.exists("key1").await.unwrap());
    }

    /// A value present only in L2 is returned and copied into L1.
    #[tokio::test]
    async fn test_tiered_backend_l1_miss_l2_hit() {
        let front = MemoryBackend::new();
        let back = MemoryBackend::new();
        let cache = TieredBackend::new(front.clone(), back.clone());

        // Seed only the second tier so the read must fall through.
        back.set("key1", b"value1".to_vec(), None).await.unwrap();

        let fetched: Option<Vec<u8>> = cache.get("key1").await.unwrap();
        assert_eq!(fetched, Some(b"value1".to_vec()));

        let promoted: bool = front.exists("key1").await.unwrap();
        assert!(promoted);
    }

    /// Stats expose tiered metadata plus prefixed per-tier entries.
    #[tokio::test]
    async fn test_tiered_backend_stats() {
        let front = MemoryBackend::new();
        let back = MemoryBackend::new();
        let cache = TieredBackend::new(front.clone(), back.clone());

        let report = cache.stats().await.unwrap();
        assert_eq!(report.get("type").map(String::as_str), Some("tiered"));
        assert_eq!(report.get("degraded").map(String::as_str), Some("false"));
        assert!(report.contains_key("l1_type"));
        assert!(report.contains_key("l2_type"));
    }
}