mod config;
mod error;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{OnceCell, RwLock};
use crate::{LocalClient, RedisClient};
pub use config::{CacheMode, MultilevelConfig};
pub use error::{MultilevelError, Result};
/// Process-wide multilevel cache handle.
///
/// Populated at most once by [`init`] and read by [`get_client`].
pub static MULTILEVEL_CACHE: OnceCell<Arc<MultilevelClient>> = OnceCell::const_new();
/// Failure-tracking state for the Redis circuit breaker.
///
/// The default value is the "closed" state: not open, no failures recorded
/// and no timestamps set.
#[derive(Debug, Clone, Default)]
struct CircuitBreaker {
    /// `true` while Redis calls are being short-circuited.
    is_open: bool,
    /// When the breaker was last opened; `None` while it is closed.
    opened_at: Option<Instant>,
    /// Number of failures counted in the current failure window.
    failure_count: u32,
    /// Start of the current failure window; `None` when no failure is recorded.
    first_failure: Option<Instant>,
}

impl CircuitBreaker {
    /// Returns the breaker to its default (closed, zero failures) state.
    fn reset(&mut self) {
        // Reuse `Default` so the two definitions of "clean state" cannot drift apart.
        *self = Self::default();
    }
}
/// Cache client that combines a local cache and a Redis cache.
///
/// Which backend(s) an operation touches is selected by `config.mode`.
/// The backends and the circuit-breaker state are held in `Arc`s, so cloned
/// clients share them.
#[derive(Clone)]
pub struct MultilevelClient {
/// Local cache backend.
local: Arc<LocalClient>,
/// Redis cache backend.
redis: Arc<RedisClient>,
/// Mode, write-through/read-through flags and circuit-breaker settings.
config: MultilevelConfig,
/// Circuit-breaker state guarding Redis calls; shared between clones.
circuit_breaker: Arc<RwLock<CircuitBreaker>>,
}
impl MultilevelClient {
/// Builds a client over the given backends with a fresh, closed circuit breaker.
pub fn new(local: Arc<LocalClient>, redis: Arc<RedisClient>, config: MultilevelConfig) -> Self {
    let circuit_breaker = Arc::new(RwLock::new(CircuitBreaker::default()));
    Self {
        circuit_breaker,
        config,
        redis,
        local,
    }
}
/// Reports whether a Redis call may be attempted right now.
///
/// Always returns `true` when the circuit breaker is disabled. When the
/// breaker is open and `reset_timeout_secs` has elapsed since it opened, the
/// breaker is reset (closed, counters cleared) and `true` is returned.
async fn is_redis_available(&self) -> bool {
    if !self.config.enable_circuit_breaker {
        return true;
    }

    // Fast path: a closed breaker only needs a shared read lock, so concurrent
    // cache operations are not serialized behind an exclusive lock.
    // The guard is a temporary and is released at the end of this statement.
    let open = self.circuit_breaker.read().await.is_open;
    if !open {
        return true;
    }

    let mut breaker = self.circuit_breaker.write().await;
    // Re-check under the write lock: another task may have reset the breaker
    // between releasing the read lock and acquiring the write lock.
    if !breaker.is_open {
        return true;
    }

    let reset_timeout = Duration::from_secs(self.config.reset_timeout_secs);
    match breaker.opened_at {
        Some(opened_at) if opened_at.elapsed() >= reset_timeout => {
            log::info!("Attempting to reset circuit breaker after timeout");
            breaker.reset();
            true
        }
        _ => false,
    }
}
/// Records one failed Redis operation and opens the breaker once
/// `failure_threshold` failures have been counted in the current window.
///
/// Does nothing when the circuit breaker is disabled.
async fn record_redis_failure(&self) {
    if !self.config.enable_circuit_breaker {
        return;
    }

    let mut breaker = self.circuit_breaker.write().await;
    let now = Instant::now();
    let window = Duration::from_secs(self.config.failure_window_secs);

    match breaker.first_failure {
        // Still inside the current failure window: keep accumulating.
        Some(first_failure) if first_failure.elapsed() < window => {}
        // No window yet, or the previous one expired: start a new window with
        // this failure. The failure is counted below rather than discarded,
        // and the open/closed state is left untouched.
        _ => {
            breaker.failure_count = 0;
            breaker.first_failure = Some(now);
        }
    }

    breaker.failure_count += 1;

    if breaker.failure_count >= self.config.failure_threshold {
        breaker.is_open = true;
        breaker.opened_at = Some(now);
        log::warn!(
            "Circuit breaker opened due to Redis failures: count={}",
            breaker.failure_count
        );
    }
}
/// Records a successful Redis operation, clearing any accumulated failures.
///
/// Does nothing when the circuit breaker is disabled or when the breaker is
/// already in its clean state.
async fn record_redis_success(&self) {
    if !self.config.enable_circuit_breaker {
        return;
    }

    // Common case: nothing to reset. Checking under a read lock avoids taking
    // the exclusive lock and emitting an info log on every successful call.
    {
        let breaker = self.circuit_breaker.read().await;
        let already_clean = !breaker.is_open
            && breaker.failure_count == 0
            && breaker.first_failure.is_none()
            && breaker.opened_at.is_none();
        if already_clean {
            return;
        }
    }

    let mut breaker = self.circuit_breaker.write().await;
    breaker.reset();
    log::info!("Circuit breaker reset after successful Redis operation");
}
/// Stores `value` under `key` in the backend(s) selected by the cache mode.
///
/// * `Local` - writes to the local cache only.
/// * `Redis` - writes to Redis only; fails with `RedisUnavailable` when the
///   circuit breaker is open.
/// * `Multi` - tries Redis first; the local cache is written when
///   `write_through` is enabled or when the Redis write did not happen.
///
/// # Errors
/// Returns `MultilevelError::Local` / `MultilevelError::Redis` for backend
/// failures as described above. In `Multi` mode a Redis failure is only logged.
pub async fn set<T: ToString + Clone>(&self, key: &str, value: T) -> Result<()> {
    match self.config.mode {
        CacheMode::Local => self.local.set(key, value).await.map_err(MultilevelError::Local),
        CacheMode::Redis => {
            if !self.is_redis_available().await {
                return Err(MultilevelError::RedisUnavailable);
            }
            match self.redis.set(key, value).await {
                Ok(_) => {
                    self.record_redis_success().await;
                    Ok(())
                }
                Err(e) => {
                    self.record_redis_failure().await;
                    Err(MultilevelError::Redis(e))
                }
            }
        }
        CacheMode::Multi => {
            let stored_in_redis = if self.is_redis_available().await {
                match self.redis.set(key, value.clone()).await {
                    Ok(_) => {
                        self.record_redis_success().await;
                        true
                    }
                    Err(e) => {
                        self.record_redis_failure().await;
                        log::warn!("Redis set failed, falling back to local cache: {}", e);
                        false
                    }
                }
            } else {
                false
            };

            // The local cache acts as the fallback whenever Redis did not take the write.
            let write_local = self.config.write_through || !stored_in_redis;
            if write_local {
                self.local
                    .set(key, value)
                    .await
                    .map_err(MultilevelError::Local)?;
            }
            Ok(())
        }
    }
}
/// Looks up `key` in the backend(s) selected by the cache mode.
///
/// * `Local` - reads the local cache only.
/// * `Redis` - reads Redis only; fails with `RedisUnavailable` when the
///   circuit breaker is open.
/// * `Multi` - reads the local cache first and falls back to Redis on a miss.
///   A Redis hit is copied into the local cache when `read_through` is
///   enabled. Redis errors and an open breaker yield `Ok(None)`.
///
/// # Errors
/// Returns `MultilevelError::Local` for local-cache failures and
/// `MultilevelError::Redis` for Redis failures in `Redis` mode.
pub async fn get(&self, key: &str) -> Result<Option<String>> {
    match self.config.mode {
        CacheMode::Local => self.local.get(key).await.map_err(MultilevelError::Local),
        CacheMode::Redis => {
            if !self.is_redis_available().await {
                return Err(MultilevelError::RedisUnavailable);
            }
            match self.redis.get(key).await {
                Ok(found) => {
                    self.record_redis_success().await;
                    Ok(found)
                }
                Err(e) => {
                    self.record_redis_failure().await;
                    Err(MultilevelError::Redis(e))
                }
            }
        }
        CacheMode::Multi => {
            let local_hit = self.local.get(key).await.map_err(MultilevelError::Local)?;
            if local_hit.is_some() {
                return Ok(local_hit);
            }
            if !self.is_redis_available().await {
                return Ok(None);
            }
            match self.redis.get(key).await {
                Ok(found) => {
                    self.record_redis_success().await;
                    if let Some(value) = &found {
                        if self.config.read_through {
                            // Best effort: failing to populate the local cache must not fail the read.
                            if let Err(e) = self.local.set(key, value).await {
                                log::warn!("Failed to read-through to local cache: {}", e);
                            }
                        }
                    }
                    Ok(found)
                }
                Err(e) => {
                    self.record_redis_failure().await;
                    log::warn!("Redis get failed: {}", e);
                    Ok(None)
                }
            }
        }
    }
}
/// Deletes `key` from the backend(s) selected by the cache mode.
///
/// Returns whether anything was deleted. In `Multi` mode both backends are
/// attempted, failures are only logged, and the result is `true` when either
/// backend reported a deletion.
///
/// # Errors
/// Returns `MultilevelError::Local` in `Local` mode, and
/// `MultilevelError::Redis` / `MultilevelError::RedisUnavailable` in `Redis` mode.
pub async fn del(&self, key: &str) -> Result<bool> {
    match self.config.mode {
        CacheMode::Local => self.local.del(key).await.map_err(MultilevelError::Local),
        CacheMode::Redis => {
            if !self.is_redis_available().await {
                return Err(MultilevelError::RedisUnavailable);
            }
            match self.redis.del(key).await {
                Ok(removed) => {
                    self.record_redis_success().await;
                    Ok(removed)
                }
                Err(e) => {
                    self.record_redis_failure().await;
                    Err(MultilevelError::Redis(e))
                }
            }
        }
        CacheMode::Multi => {
            let removed_from_redis = if self.is_redis_available().await {
                match self.redis.del(key).await {
                    Ok(removed) => {
                        self.record_redis_success().await;
                        removed
                    }
                    Err(e) => {
                        self.record_redis_failure().await;
                        log::warn!("Redis delete failed: {}", e);
                        false
                    }
                }
            } else {
                false
            };

            let removed_from_local = match self.local.del(key).await {
                Ok(removed) => removed,
                Err(e) => {
                    log::warn!("Local cache delete failed: {}", e);
                    false
                }
            };

            Ok(removed_from_redis || removed_from_local)
        }
    }
}
/// Removes all entries from the backend(s) selected by the cache mode.
///
/// * `Local` - clears the local cache.
/// * `Redis` - clears Redis; fails with `RedisUnavailable` when the circuit
///   breaker is open.
/// * `Multi` - clears Redis on a best-effort basis (failures are logged) and
///   then clears the local cache.
///
/// # Errors
/// In `Redis` mode returns `MultilevelError::RedisUnavailable` or
/// `MultilevelError::Redis`. The other modes always return `Ok(())`.
pub async fn clear(&self) -> Result<()> {
    match self.config.mode {
        CacheMode::Local => {
            self.local.clear().await;
            // NOTE(review): presumably gives the local cache time to finish
            // clearing before the caller continues - confirm whether it is needed.
            tokio::time::sleep(Duration::from_millis(100)).await;
            Ok(())
        }
        CacheMode::Redis => {
            if !self.is_redis_available().await {
                return Err(MultilevelError::RedisUnavailable);
            }
            // Actually clear Redis; previously this arm returned `Ok(())`
            // without touching the backend.
            match self.redis.clear().await {
                Ok(_) => {
                    self.record_redis_success().await;
                    Ok(())
                }
                Err(e) => {
                    self.record_redis_failure().await;
                    Err(MultilevelError::Redis(e))
                }
            }
        }
        CacheMode::Multi => {
            if self.is_redis_available().await {
                if let Err(e) = self.redis.clear().await {
                    self.record_redis_failure().await;
                    log::warn!("Failed to clear Redis cache: {}", e);
                } else {
                    self.record_redis_success().await;
                }
            }
            self.local.clear().await;
            Ok(())
        }
    }
}
/// Lists the keys matching `pattern`.
///
/// `Local` mode queries the local cache. `Redis` and `Multi` modes query
/// Redis only, so entries that exist solely in the local cache are not
/// listed in `Multi` mode.
///
/// Redis queries go through the circuit breaker like every other Redis
/// operation: they are skipped while the breaker is open and their outcome
/// is recorded.
///
/// # Errors
/// Returns `MultilevelError::Local` for local-cache failures,
/// `MultilevelError::RedisUnavailable` when the breaker is open, and
/// `MultilevelError::Redis` when the Redis query fails.
pub async fn keys(&self, pattern: &str) -> Result<Vec<String>> {
    match self.config.mode {
        CacheMode::Local => self.local.keys(pattern).await.map_err(MultilevelError::Local),
        CacheMode::Redis | CacheMode::Multi => {
            if !self.is_redis_available().await {
                return Err(MultilevelError::RedisUnavailable);
            }
            match self.redis.keys(pattern).await {
                Ok(found) => {
                    self.record_redis_success().await;
                    Ok(found)
                }
                Err(e) => {
                    self.record_redis_failure().await;
                    Err(MultilevelError::Redis(e))
                }
            }
        }
    }
}
/// Returns a snapshot of the local cache counters together with the
/// configured mode and whether the circuit breaker is enabled.
pub fn stats(&self) -> MultilevelCacheStats {
    let config = &self.config;
    let snapshot = self.local.stats();
    MultilevelCacheStats {
        circuit_breaker_enabled: config.enable_circuit_breaker,
        local_size: snapshot.size,
        local_misses: snapshot.misses,
        local_hits: snapshot.hits,
        mode: config.mode,
    }
}
/// Returns a point-in-time copy of the circuit-breaker state.
pub async fn circuit_breaker_status(&self) -> CircuitBreakerStatus {
    let enabled = self.config.enable_circuit_breaker;
    let guard = self.circuit_breaker.read().await;
    let (is_open, failure_count, opened_at) =
        (guard.is_open, guard.failure_count, guard.opened_at);
    CircuitBreakerStatus {
        enabled,
        is_open,
        failure_count,
        opened_at,
    }
}
}
/// Snapshot returned by `MultilevelClient::stats`.
#[derive(Debug, Clone)]
pub struct MultilevelCacheStats {
/// Cache mode the client is configured with.
pub mode: CacheMode,
/// Hit counter reported by the local cache.
pub local_hits: u64,
/// Miss counter reported by the local cache.
pub local_misses: u64,
/// Size reported by the local cache.
pub local_size: u64,
/// Whether the circuit breaker is enabled in the configuration.
pub circuit_breaker_enabled: bool,
}
/// Point-in-time view of the circuit breaker, returned by
/// `MultilevelClient::circuit_breaker_status`.
#[derive(Debug, Clone)]
pub struct CircuitBreakerStatus {
/// Whether the circuit breaker is enabled in the configuration.
pub enabled: bool,
/// Whether the breaker is currently open.
pub is_open: bool,
/// Failures counted by the breaker at the time of the snapshot.
pub failure_count: u32,
/// When the breaker was opened, if it has been.
pub opened_at: Option<Instant>,
}
/// Returns the globally registered multilevel cache client.
///
/// # Panics
/// Panics if the global client has not been set (see `init`).
pub async fn get_client() -> Arc<MultilevelClient> {
    let registered = MULTILEVEL_CACHE
        .get()
        .expect("Multilevel cache not initialized");
    Arc::clone(registered)
}
/// Creates a client from the given backends and configuration and registers
/// it as the global client.
///
/// # Errors
/// Returns `MultilevelError::AlreadyInitialized` when the global client
/// could not be set because one is already registered.
pub async fn init(
    local: Arc<LocalClient>,
    redis: Arc<RedisClient>,
    config: MultilevelConfig,
) -> Result<Arc<MultilevelClient>> {
    let client = Arc::new(MultilevelClient::new(local, redis, config));
    match MULTILEVEL_CACHE.set(Arc::clone(&client)) {
        Ok(()) => Ok(client),
        Err(_) => Err(MultilevelError::AlreadyInitialized),
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{LocalConfig, PoolConfig, RedisConfig, RedisMode, RedisNode};
/// Builds the local and Redis clients shared by the tests.
///
/// Connects to a standalone Redis on `localhost:6379`; panics if the Redis
/// client cannot be created.
async fn create_test_clients() -> (Arc<LocalClient>, Arc<RedisClient>) {
    let node = RedisNode {
        host: "localhost".to_string(),
        port: 6379,
    };
    let redis_settings = RedisConfig {
        mode: RedisMode::Standalone,
        node: Some(node),
        sentinel: None,
        cluster: None,
        pool: PoolConfig::default(),
        password: Some("1qaz!QAZ".to_string()),
        database: None,
    };
    let local_settings = LocalConfig {
        max_capacity: 100,
        ttl: 60,
        tti: 30,
    };

    let redis = RedisClient::new(redis_settings).await.unwrap();
    let local = LocalClient::new(local_settings);
    (Arc::new(local), Arc::new(redis))
}
/// `Local` mode: values are served from the local cache and never reach Redis.
#[tokio::test]
async fn test_local_mode() {
    let (local, redis) = create_test_clients().await;
    let client = MultilevelClient::new(
        local,
        redis,
        MultilevelConfig {
            mode: CacheMode::Local,
            write_through: true,
            read_through: true,
            enable_circuit_breaker: true,
            failure_threshold: 5,
            failure_window_secs: 60,
            reset_timeout_secs: 300,
        },
    );

    client.set("local_key", "local_value").await.unwrap();
    let fetched = client.get("local_key").await.unwrap();
    assert_eq!(fetched.as_deref(), Some("local_value"));

    // Nothing may have been written to Redis.
    let in_redis = client.redis.get("local_key").await.unwrap();
    assert_eq!(in_redis, None);

    let removed = client.del("local_key").await.unwrap();
    assert!(removed);
    let after_delete = client.get("local_key").await.unwrap();
    assert_eq!(after_delete, None);
}
/// `Redis` mode: values are served from Redis and never reach the local cache.
#[tokio::test]
async fn test_redis_mode() {
    let (local, redis) = create_test_clients().await;
    let client = MultilevelClient::new(
        local,
        redis,
        MultilevelConfig {
            mode: CacheMode::Redis,
            write_through: true,
            read_through: true,
            enable_circuit_breaker: true,
            failure_threshold: 5,
            failure_window_secs: 60,
            reset_timeout_secs: 300,
        },
    );

    client.set("redis_key", "redis_value").await.unwrap();
    let fetched = client.get("redis_key").await.unwrap();
    assert_eq!(fetched.as_deref(), Some("redis_value"));

    // Nothing may have been written to the local cache.
    let in_local = client.local.get("redis_key").await.unwrap();
    assert_eq!(in_local, None);

    let removed = client.del("redis_key").await.unwrap();
    assert!(removed);
    let after_delete = client.get("redis_key").await.unwrap();
    assert_eq!(after_delete, None);
}
/// `Multi` mode with write-through and read-through enabled: writes land in
/// both backends, and a Redis hit repopulates the local cache.
#[tokio::test]
async fn test_multi_mode() {
    let (local, redis) = create_test_clients().await;
    let client = MultilevelClient::new(
        local,
        redis,
        MultilevelConfig {
            mode: CacheMode::Multi,
            write_through: true,
            read_through: true,
            enable_circuit_breaker: true,
            failure_threshold: 5,
            failure_window_secs: 60,
            reset_timeout_secs: 300,
        },
    );

    // Write-through: the value is present in both backends.
    client.set("multi_key", "multi_value").await.unwrap();
    let in_local = client.local.get("multi_key").await.unwrap();
    let in_redis = client.redis.get("multi_key").await.unwrap();
    assert_eq!(in_local.as_deref(), Some("multi_value"));
    assert_eq!(in_redis.as_deref(), Some("multi_value"));

    // A freshly written key is readable through the client.
    client.local.clear().await;
    client.set("local_hit", "hit_value").await.unwrap();
    let hit = client.get("local_hit").await.unwrap();
    assert_eq!(hit.as_deref(), Some("hit_value"));

    // Read-through: after wiping the local cache, the Redis value is
    // returned and copied back into the local cache.
    client.local.clear().await;
    let from_redis = client.get("multi_key").await.unwrap();
    assert_eq!(from_redis.as_deref(), Some("multi_value"));
    let repopulated = client.local.get("multi_key").await.unwrap();
    assert_eq!(repopulated.as_deref(), Some("multi_value"));
}
/// Checks the local hit/miss/size counters in `Multi` mode and that `clear`
/// empties the cache.
///
/// The expected counts depend on the exact sequence of reads below, so the
/// order of operations must not be changed.
#[tokio::test]
async fn test_stats_and_clear() {
let (local_client, redis_client) = create_test_clients().await;
let config = MultilevelConfig {
mode: CacheMode::Multi,
write_through: true,
read_through: true,
enable_circuit_breaker: true,
failure_threshold: 5,
failure_window_secs: 60,
reset_timeout_secs: 300,
};
let client = MultilevelClient::new(local_client, redis_client, config);
// Each iteration performs three reads that reach the local cache: one
// direct local get plus two client gets (Multi mode checks local first).
for i in 0..5 {
let key = format!("key_{}", i);
let value = format!("value_{}", i);
client.set(&key, &value).await.unwrap();
let local_value = client.local.get(&key).await.unwrap();
assert_eq!(
local_value,
Some(value.clone()),
"Data should be in local cache immediately after set: key={}",
key
);
let _ = client.get(&key).await.unwrap();
let _ = client.get(&key).await.unwrap();
}
client.local.run_pending_tasks().await;
let local_stats = client.local.stats();
// 5 keys x 3 local reads each.
assert_eq!(local_stats.hits, 15);
assert_eq!(local_stats.misses, 0);
assert_eq!(local_stats.size, 5);
client.clear().await.unwrap();
// After clearing, every lookup is expected to miss.
for i in 0..5 {
let key = format!("key_{}", i);
let value = client.get(&key).await.unwrap();
assert_eq!(
value, None,
"Cache should be empty after clear: key={}",
key
);
}
let final_stats = client.local.stats();
println!("final_stats: {:?}", final_stats);
// Hits are unchanged; the five lookups after `clear` are counted as misses.
assert_eq!(final_stats.hits, 15);
assert_eq!(final_stats.misses, 5);
assert_eq!(final_stats.size, 0);
}
/// Exercises the circuit breaker in `Multi` mode: repeated Redis failures
/// open it, operations keep working through the local cache while it is
/// open, and it closes again once the reset timeout has elapsed.
#[tokio::test]
async fn test_circuit_breaker() {
let (local_client, redis_client) = create_test_clients().await;
// Low threshold and a 1-second reset timeout keep the test short.
let config = MultilevelConfig {
mode: CacheMode::Multi,
write_through: true,
read_through: true,
enable_circuit_breaker: true,
failure_threshold: 2, failure_window_secs: 60,
reset_timeout_secs: 1, };
let client =
MultilevelClient::new(local_client.clone(), redis_client.clone(), config.clone());
client.set("cb_key", "cb_value").await.unwrap();
// Same endpoint but with a wrong password.
// NOTE(review): presumably this makes every Redis operation fail - confirm.
let bad_redis_config = RedisConfig {
mode: RedisMode::Standalone,
node: Some(RedisNode {
host: "localhost".to_string(),
port: 6379,
}),
sentinel: None,
cluster: None,
pool: PoolConfig::default(),
password: Some("wrong_password".to_string()),
database: None,
};
let bad_redis_client = Arc::new(RedisClient::new(bad_redis_config).await.unwrap());
// Shadow the healthy client with one backed by the failing Redis client.
let client = MultilevelClient::new(local_client, bad_redis_client, config);
// Each set is expected to fail on Redis but still succeed via the local cache.
for i in 0..3 {
let result = client.set("trigger_key", "trigger_value").await;
assert!(
result.is_ok(),
"Set operation should succeed with local fallback"
);
let status = client.circuit_breaker_status().await;
log::info!(
"Circuit breaker status after failure {}: open={}, count={}, opened_at={:?}",
i + 1,
status.is_open,
status.failure_count,
status.opened_at
);
}
let status = client.circuit_breaker_status().await;
assert!(status.is_open, "Expected circuit breaker to be open");
assert!(status.failure_count >= 2, "Expected at least 2 failures");
log::info!(
"Circuit breaker status before reset: open={}, count={}, opened_at={:?}",
status.is_open,
status.failure_count,
status.opened_at
);
// With the breaker open, reads and writes are served by the local cache.
client.set("local_key", "local_value").await.unwrap();
let value = client.get("local_key").await.unwrap();
assert_eq!(value, Some("local_value".to_string()));
client.local.run_pending_tasks().await;
let status = client.circuit_breaker_status().await;
log::info!(
"Circuit breaker status after reset: open={}, count={}, opened_at={:?}",
status.is_open,
status.failure_count,
status.opened_at
);
// Wait past the 1-second reset timeout, then probe availability so the
// breaker gets the chance to reset itself.
tokio::time::sleep(Duration::from_secs(2)).await;
let _ = client.is_redis_available().await;
let final_status = client.circuit_breaker_status().await;
log::info!(
"Circuit breaker final status: open={}, count={}, opened_at={:?}",
final_status.is_open,
final_status.failure_count,
final_status.opened_at
);
assert!(
!final_status.is_open,
"Expected circuit breaker to be closed after reset timeout"
);
assert_eq!(
final_status.failure_count, 0,
"Expected failure count to be reset to 0"
);
// The client must remain usable after the reset.
client.set("after_reset", "value").await.unwrap();
let value = client.get("after_reset").await.unwrap();
assert_eq!(value, Some("value".to_string()));
}
/// `Multi` mode with write-through and read-through both disabled: writes go
/// to Redis only, and reads served from Redis do not populate the local cache.
#[tokio::test]
async fn test_write_through_and_read_through() {
    let (local, redis) = create_test_clients().await;
    let client = MultilevelClient::new(
        local,
        redis,
        MultilevelConfig {
            mode: CacheMode::Multi,
            write_through: false,
            read_through: false,
            enable_circuit_breaker: true,
            failure_threshold: 5,
            failure_window_secs: 60,
            reset_timeout_secs: 300,
        },
    );

    // Without write-through, only Redis receives the value.
    client.set("wt_key", "wt_value").await.unwrap();
    let in_redis = client.redis.get("wt_key").await.unwrap();
    let in_local = client.local.get("wt_key").await.unwrap();
    assert_eq!(in_redis.as_deref(), Some("wt_value"));
    assert_eq!(in_local, None);

    // Without read-through, a Redis hit is returned but not cached locally.
    let fetched = client.get("wt_key").await.unwrap();
    assert_eq!(fetched.as_deref(), Some("wt_value"));
    let still_absent = client.local.get("wt_key").await.unwrap();
    assert_eq!(still_absent, None);
}
}