use super::tiered_cache::TieredCacheControl;
use super::ttl_control::TtlControl;
use super::{db_loader::DbFallbackManager, l2::L2Client, CacheOps};
use crate::backend::l1::L1Backend;
use crate::bloom_filter::{BloomFilterManager, BloomFilterOptions, BloomFilterShared};
use crate::config::TwoLevelConfig;
use crate::error::Result;
use crate::metrics::GLOBAL_METRICS;
use crate::recovery::{
health::{HealthChecker, HealthState},
wal::{Operation, WalEntry, WalManager},
};
use crate::serialization::{Serializer, SerializerEnum};
use crate::smart_strategy::SmartStrategyManager;
use crate::sync::{
common::{BatchOperation, BatchWriterConfig},
invalidation::{InvalidationPublisher, InvalidationSubscriber},
optimized_batch_writer::OptimizedBatchWriter,
promotion::PromotionManager,
warmup::WarmupManager,
};
use crate::utils::{validate_cache_key, validate_key_length, validate_value_size};
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio::task::JoinHandle;
use tracing::{debug, info, instrument, warn};
/// Two-level (L1 in-process + L2 remote) cache client for one service.
///
/// Every subsystem is optional and driven by `TwoLevelConfig`: promotion,
/// batch writing, invalidation publishing, DB fallback, bloom filtering,
/// warmup and (feature-gated) smart strategy.
pub struct TwoLevelClient {
    /// Logical service name; used as metrics label and WAL namespace.
    service_name: String,
    config: TwoLevelConfig,
    /// In-process cache tier.
    l1: Option<Arc<L1Backend>>,
    /// Remote cache tier.
    l2: Option<Arc<L2Client>>,
    serializer: SerializerEnum,
    /// Shared health state machine driven by the background health checker.
    health_state: Arc<RwLock<HealthState>>,
    /// Write-ahead log used to buffer L2 writes while degraded/replaying.
    wal: Arc<WalManager>,
    promotion_mgr: Option<Arc<PromotionManager>>,
    batch_writer: Option<Arc<OptimizedBatchWriter>>,
    /// Publishes invalidation messages after deletes (None when unsupported).
    publisher: Option<Arc<InvalidationPublisher>>,
    db_fallback_mgr: Option<Arc<DbFallbackManager>>,
    bloom_filter: Option<BloomFilterShared>,
    bloom_filter_mgr: Option<Arc<BloomFilterManager>>,
    warmup_mgr: Option<Arc<WarmupManager>>,
    #[cfg(feature = "smart-strategy")]
    smart_strategy: Option<Arc<SmartStrategyManager>>,
    // Handles to spawned background tasks; kept only so `shutdown` can abort
    // them. Not cloned (see the manual `Clone` impl).
    #[allow(dead_code)]
    health_checker_handle: Option<JoinHandle<()>>,
    #[allow(dead_code)]
    batch_writer_handle: Option<JoinHandle<()>>,
}
// Manual Clone: `JoinHandle` is not `Clone`, so derived Clone is impossible.
// Clones share all Arc-backed state (caches, WAL, health state, managers) but
// do NOT own the background task handles — those stay with the original
// instance, which remains responsible for aborting them on shutdown.
impl Clone for TwoLevelClient {
    fn clone(&self) -> Self {
        Self {
            service_name: self.service_name.clone(),
            config: self.config.clone(),
            l1: self.l1.clone(),
            l2: self.l2.clone(),
            serializer: self.serializer.clone(),
            health_state: self.health_state.clone(),
            wal: self.wal.clone(),
            promotion_mgr: self.promotion_mgr.clone(),
            batch_writer: self.batch_writer.clone(),
            publisher: self.publisher.clone(),
            db_fallback_mgr: self.db_fallback_mgr.clone(),
            bloom_filter: self.bloom_filter.clone(),
            bloom_filter_mgr: self.bloom_filter_mgr.clone(),
            warmup_mgr: self.warmup_mgr.clone(),
            #[cfg(feature = "smart-strategy")]
            smart_strategy: self.smart_strategy.clone(),
            // Task handles are intentionally not cloned.
            health_checker_handle: None,
            batch_writer_handle: None,
        }
    }
}
impl TwoLevelClient {
#[allow(clippy::too_many_arguments)]
#[instrument(
skip(config, l1, l2_backend, serializer),
level = "info",
name = "init_two_level_client"
)]
pub async fn new(
service_name: String,
config: TwoLevelConfig,
l1: Arc<L1Backend>,
l2_backend: Arc<crate::backend::l2::L2Backend>,
serializer: SerializerEnum,
) -> Result<Self> {
let health_state = Arc::new(RwLock::new(HealthState::Healthy));
let wal = Arc::new(WalManager::new(&service_name).await?);
let l2 = Arc::new(
L2Client::new(service_name.clone(), l2_backend.clone(), serializer.clone()).await?,
);
let command_timeout_ms = l2_backend.command_timeout_ms();
let checker = HealthChecker::new(
l2_backend.clone(),
health_state.clone(),
wal.clone(),
service_name.clone(),
command_timeout_ms,
);
let health_checker_handle = tokio::spawn(async move { checker.start().await });
let channel_name = Self::resolve_channel_name(&service_name, &config);
let (publisher, _sub_handle) = match l2_backend.get_raw_client() {
Ok(client) => {
let sub = InvalidationSubscriber::new(
client.clone(),
l1.clone(),
channel_name.clone(),
health_state.clone(),
);
let publisher = Arc::new(InvalidationPublisher::new(
client.get_connection_manager().await?,
channel_name,
));
let sub_clone = sub;
tokio::spawn(async move {
if let Err(e) = sub_clone.start().await {
warn!("Invalidation subscriber error: {}", e);
}
});
(Some(publisher), Option::<tokio::task::JoinHandle<()>>::None) }
Err(crate::error::CacheError::NotSupported(_)) => {
warn!(
"Invalidation not supported for this backend mode (likely Cluster), skipping"
);
(None, Option::<tokio::task::JoinHandle<()>>::None)
}
Err(e) => return Err(e),
};
let promotion_mgr = if config.promote_on_hit {
Some(Arc::new(PromotionManager::new(
l1.clone(),
l2_backend.clone(),
health_state.clone(),
)))
} else {
None
};
let (batch_writer, batch_writer_handle) = if config.enable_batch_write {
let batch_config = crate::sync::optimized_batch_writer::OptimizedBatchWriterConfig {
base: BatchWriterConfig {
max_batch_size: config.batch_size,
flush_interval_ms: config.batch_interval_ms,
max_buffer_size: config.batch_size * 10,
},
max_retry_count: 3,
retry_delay_ms: 1000,
max_buffer_size: config.batch_size * 10,
high_water_mark: (config.batch_size as f64 * 0.8) as usize,
low_water_mark: (config.batch_size as f64 * 0.2) as usize,
enable_wal: true,
enable_compression: true,
compression_threshold: 1024,
};
let bw = Arc::new(OptimizedBatchWriter::new(
service_name.clone(),
l2_backend.clone(),
batch_config,
wal.clone(),
));
let bw_clone = bw.clone();
let handle = tokio::spawn(async move { bw_clone.start().await });
(Some(bw), Some(handle))
} else {
(None, None)
};
let (bloom_filter, bloom_filter_mgr) = if let Some(bloom_config) = &config.bloom_filter {
let options = BloomFilterOptions::new(
bloom_config.name.clone(),
bloom_config.expected_elements,
bloom_config.false_positive_rate,
);
let mgr = Arc::new(BloomFilterManager::new());
let filter = mgr.get_or_create(options).await?;
(Some(filter), Some(mgr))
} else {
(None, None)
};
let warmup_mgr = config.warmup.as_ref().map(|warmup_config| {
Arc::new(WarmupManager::new(
service_name.clone(),
warmup_config.clone(),
))
});
Ok(Self {
service_name: service_name.to_string(),
config,
l1: Some(l1),
l2: Some(l2),
serializer,
health_state,
wal,
promotion_mgr,
batch_writer,
publisher,
db_fallback_mgr: None,
bloom_filter,
bloom_filter_mgr,
warmup_mgr,
#[cfg(feature = "smart-strategy")]
smart_strategy: None,
health_checker_handle: Some(health_checker_handle),
batch_writer_handle,
})
}
/// Degrades the health state in response to an observed L2 failure.
///
/// Transitions:
/// - Healthy      -> Degraded (failure_count = 1); health metric set to 0.
/// - Degraded     -> Degraded, failure_count + 1 (original `since` kept).
/// - Recovering   -> Degraded (failure_count reset to 1); metric set to 0.
/// - WalReplaying -> unchanged; failures during replay are ignored.
#[instrument(skip(self), level = "warn")]
async fn handle_l2_failure(&self) {
    warn!("L2 failure detected for service: {}", self.service_name);
    // Hold the write lock across the whole read-modify-write so concurrent
    // failures cannot interleave.
    let mut state_guard = self.health_state.write().await;
    let current_state = *state_guard;
    match current_state {
        HealthState::Healthy => {
            warn!(
                "Service {} transitioning from Healthy to Degraded",
                self.service_name
            );
            *state_guard = HealthState::Degraded {
                since: std::time::Instant::now(),
                failure_count: 1,
            };
            crate::metrics::GLOBAL_METRICS.set_health(&self.service_name, 0);
        }
        HealthState::Degraded {
            since,
            failure_count,
        } => {
            let new_failure_count = failure_count + 1;
            warn!(
                "Service {} remains Degraded, failure count increased: {} -> {}",
                self.service_name, failure_count, new_failure_count
            );
            // Keep the original `since` so degradation duration is preserved.
            *state_guard = HealthState::Degraded {
                since,
                failure_count: new_failure_count,
            };
        }
        HealthState::Recovering {
            since: _,
            success_count: _,
        } => {
            warn!(
                "Service {} recovery failed, transitioning back to Degraded from Recovering",
                self.service_name
            );
            *state_guard = HealthState::Degraded {
                since: std::time::Instant::now(),
                failure_count: 1,
            };
            crate::metrics::GLOBAL_METRICS.set_health(&self.service_name, 0);
        }
        HealthState::WalReplaying { .. } => {
            warn!(
                "Service {} is replaying WAL, ignoring failure during replay",
                self.service_name
            );
        }
    }
    info!(
        "Service {} degradation strategy applied, current state: {:?}",
        self.service_name, *state_guard
    );
}
/// Returns a copy of the current health state.
pub async fn get_health_state(&self) -> HealthState {
    let guard = self.health_state.read().await;
    *guard
}
/// Derives the invalidation pub/sub channel name from configuration.
///
/// Precedence: explicit custom name > structured prefix (optionally suffixed
/// with the service name) > default `cache:invalidate:<service>`.
fn resolve_channel_name(service_name: &str, config: &TwoLevelConfig) -> String {
    use crate::config::InvalidationChannelConfig;
    match &config.invalidation_channel {
        Some(InvalidationChannelConfig::Custom(name)) => name.clone(),
        Some(InvalidationChannelConfig::Structured {
            prefix,
            use_service_name,
        }) => {
            let base = prefix.as_deref().unwrap_or("cache:invalidate");
            match *use_service_name {
                true => format!("{}:{}", base, service_name),
                false => base.to_string(),
            }
        }
        None => format!("cache:invalidate:{}", service_name),
    }
}
/// Queries the bloom filter for `key`.
///
/// Returns `Ok(None)` when no bloom filter is configured, otherwise
/// `Ok(Some(membership))`.
pub async fn check_bloom_filter(&self, key: &str) -> Result<Option<bool>> {
    self.bloom_filter
        .as_ref()
        .map(|filter| filter.contains(key.as_bytes()))
        .transpose()
}
/// Pre-populates the cache: fetches `keys` through `loader` and writes every
/// returned pair via [`Self::set`] with the given TTL.
///
/// # Errors
/// Propagates loader failures and any failing cache write.
#[instrument(skip(self, loader), level = "info", fields(key_count = keys.len()))]
pub async fn warmup<T, F, Fut>(
    &self,
    keys: Vec<String>,
    loader: F,
    ttl: Option<u64>,
) -> Result<()>
where
    T: serde::Serialize + Send + Sync,
    F: Fn(Vec<String>) -> Fut + Send + Sync,
    Fut: std::future::Future<Output = Result<Vec<(String, T)>>> + Send,
{
    // Nothing to do for an empty key set.
    if keys.is_empty() {
        return Ok(());
    }
    let loaded = loader(keys).await?;
    for (key, value) in loaded.into_iter() {
        self.set(&key, &value, ttl).await?;
    }
    Ok(())
}
/// Runs the configured warmup plan (no-op when no warmup manager is set).
///
/// The warmup manager is handed a loader closure that reads each key through
/// `get_bytes` on a cloned client; per-key misses and errors are logged and
/// skipped rather than aborting the whole warmup.
pub async fn run_warmup(&self) -> Result<()> {
    if let Some(warmup_mgr) = &self.warmup_mgr {
        // Arc-wrap a clone so the 'static loader closure can own the client.
        let client: Arc<Self> = Arc::new(self.clone());
        let result = warmup_mgr
            .run_warmup(move |keys: Vec<String>| {
                let client = Arc::clone(&client);
                Box::pin(async move {
                    let mut result = HashMap::new();
                    for key in keys {
                        match client.get_bytes(&key).await {
                            Ok(Some(value)) => {
                                result.insert(key, value);
                            }
                            Ok(None) => {
                                debug!("Warmup: key not found in L2: {}", key);
                            }
                            Err(e) => {
                                warn!("Warmup: failed to get key {} from L2: {}", key, e);
                            }
                        }
                    }
                    Ok(result)
                })
            })
            .await?;
        if result.success {
            info!(
                "Cache warmup completed successfully, loaded {} items",
                result.loaded
            );
        } else {
            warn!(
                "Cache warmup completed with some failures, loaded: {}, failed: {}",
                result.loaded, result.failed
            );
        }
    }
    Ok(())
}
/// Internal accessor for the warmup manager (testing/diagnostics only).
#[doc(hidden)]
pub fn warmup_manager(&self) -> Option<&Arc<WarmupManager>> {
    self.warmup_mgr.as_ref()
}
/// Gracefully shuts the client down: aborts the health-checker and
/// batch-writer background tasks, then shuts down the L2 client.
///
/// L1 needs no explicit teardown beyond logging. (Log messages are in
/// Chinese — runtime strings, intentionally left untouched.)
///
/// # Errors
/// Propagates any error from the L2 shutdown.
#[instrument(skip(self), level = "info", fields(service = %self.service_name))]
pub async fn shutdown(&self) -> Result<()> {
    info!("正在关闭TwoLevelClient...");
    if let Some(handle) = &self.health_checker_handle {
        info!("停止健康检查器");
        // abort() is fire-and-forget; the task is cancelled at next await.
        handle.abort();
    }
    if let Some(handle) = &self.batch_writer_handle {
        info!("停止批处理写入器");
        handle.abort();
    }
    if let Some(_l1) = &self.l1 {
        info!("关闭L1缓存");
    }
    if let Some(l2) = &self.l2 {
        info!("关闭L2缓存");
        l2.shutdown().await?;
    }
    info!("WAL日志已处理");
    info!("TwoLevelClient已关闭");
    Ok(())
}
}
#[async_trait]
impl CacheOps for TwoLevelClient {
/// Serializer used for typed get/set wrappers.
fn serializer(&self) -> &SerializerEnum {
    &self.serializer
}
/// Downcast support for callers holding `dyn CacheOps`.
fn as_any(&self) -> &dyn std::any::Any {
    self
}
/// Arc-preserving downcast support.
fn into_any_arc(self: Arc<Self>) -> Arc<dyn std::any::Any + Send + Sync> {
    self
}
/// Attempts to acquire a distributed lock on `key` via L2.
///
/// Returns `Ok(None)` (after a warning) when L2 is unavailable; a zero TTL
/// is rejected as invalid input.
#[instrument(skip(self), level = "debug", fields(service = %self.service_name))]
async fn lock(&self, key: &str, ttl: u64) -> Result<Option<String>> {
    validate_cache_key(key)?;
    validate_key_length(key, self.config.max_key_length.unwrap_or(256))?;
    if ttl == 0 {
        return Err(crate::error::CacheError::InvalidInput(
            "Lock TTL must be greater than 0".to_string(),
        ));
    }
    debug!("TwoLevelClient lock called: key={}, ttl={}", key, ttl);
    match &self.l2 {
        Some(l2) => {
            debug!("L2 backend available, attempting lock acquisition");
            let outcome = l2.lock(key, ttl).await;
            debug!("L2 lock result: {:?}", outcome);
            outcome
        }
        None => {
            warn!("Cannot acquire lock, L2 unavailable or not configured");
            Ok(None)
        }
    }
}
/// Releases a distributed lock previously acquired with [`Self::lock`].
///
/// `value` must be the token returned by `lock`; an empty token is rejected.
/// Returns `Ok(false)` (after a warning) when L2 is unavailable.
#[instrument(skip(self), level = "debug", fields(service = %self.service_name))]
async fn unlock(&self, key: &str, value: &str) -> Result<bool> {
    validate_cache_key(key)?;
    validate_key_length(key, self.config.max_key_length.unwrap_or(256))?;
    if value.is_empty() {
        return Err(crate::error::CacheError::InvalidInput(
            "Unlock value cannot be empty".to_string(),
        ));
    }
    match &self.l2 {
        Some(l2) => l2.unlock(key, value).await,
        None => {
            warn!("Cannot release lock, L2 unavailable or not configured");
            Ok(false)
        }
    }
}
/// Tiered read path: bloom filter -> L1 -> L2 (with async promotion) -> DB
/// fallback. Requires both L1 and L2 to be configured; otherwise returns
/// `Ok(None)` after only validating the key.
#[instrument(skip(self), level = "debug", fields(service = %self.service_name))]
async fn get_bytes(&self, key: &str) -> Result<Option<Vec<u8>>> {
    validate_cache_key(key)?;
    let max_key_length = self.config.max_key_length.unwrap_or(256);
    validate_key_length(key, max_key_length)?;
    if let (Some(l1), Some(l2)) = (&self.l1, &self.l2) {
        if let Some(bloom_filter) = &self.bloom_filter {
            // A definite bloom miss short-circuits the whole lookup; filter
            // errors are logged and the lookup proceeds normally.
            let key_bytes = key.as_bytes();
            match bloom_filter.contains(key_bytes) {
                Ok(contains) => {
                    if !contains {
                        GLOBAL_METRICS.record_request(
                            &self.service_name,
                            "BloomFilter",
                            "get",
                            "miss",
                        );
                        return Ok(None);
                    }
                    GLOBAL_METRICS.record_request(
                        &self.service_name,
                        "BloomFilter",
                        "get",
                        "hit",
                    );
                }
                Err(e) => {
                    warn!("BloomFilter check failed: {}", e);
                }
            }
        }
        GLOBAL_METRICS.record_request(&self.service_name, "L1", "get", "attempt");
        let start = std::time::Instant::now();
        if let Some((bytes, _)) = l1.get_with_metadata(key).await? {
            let duration = start.elapsed().as_secs_f64();
            GLOBAL_METRICS.record_duration(&self.service_name, "L1", "get", duration);
            GLOBAL_METRICS.record_request(&self.service_name, "L1", "get", "hit");
            #[cfg(feature = "smart-strategy")]
            if let Some(strategy) = &self.smart_strategy {
                strategy.record_access(true);
            }
            return Ok(Some(bytes));
        }
        let duration = start.elapsed().as_secs_f64();
        GLOBAL_METRICS.record_duration(&self.service_name, "L1", "get", duration);
        GLOBAL_METRICS.record_request(&self.service_name, "L1", "get", "miss");
        // L2 is only consulted when not degraded.
        let state = self.health_state.read().await;
        let is_degraded = matches!(*state, HealthState::Degraded { .. });
        drop(state);
        if !is_degraded {
            GLOBAL_METRICS.record_request(&self.service_name, "L2", "get", "attempt");
            let start = std::time::Instant::now();
            match l2.get_bytes(key).await {
                Ok(Some(value)) => {
                    let duration = start.elapsed().as_secs_f64();
                    GLOBAL_METRICS.record_duration(&self.service_name, "L2", "get", duration);
                    GLOBAL_METRICS.record_request(&self.service_name, "L2", "get", "hit");
                    #[cfg(feature = "smart-strategy")]
                    if let Some(strategy) = &self.smart_strategy {
                        strategy.record_access(true);
                    }
                    if self.config.promote_on_hit {
                        if let Some(promotion_mgr) = &self.promotion_mgr {
                            // Fire-and-forget promotion to L1; failures are
                            // dropped on purpose.
                            // NOTE(review): promote is called with ttl 0 —
                            // presumably meaning "no/derived TTL"; confirm
                            // against PromotionManager::promote.
                            let promo = promotion_mgr.clone();
                            let k = key.to_string();
                            let v = value.clone();
                            tokio::spawn(async move {
                                let _ = promo.promote(k, v, 0).await;
                            });
                        }
                    }
                    return Ok(Some(value));
                }
                Ok(None) => {
                    let duration = start.elapsed().as_secs_f64();
                    GLOBAL_METRICS.record_duration(&self.service_name, "L2", "get", duration);
                    GLOBAL_METRICS.record_request(&self.service_name, "L2", "get", "miss");
                }
                Err(_e) => {
                    // L2 failure degrades health; the lookup then falls
                    // through to the DB fallback (if configured).
                    let duration = start.elapsed().as_secs_f64();
                    GLOBAL_METRICS.record_duration(&self.service_name, "L2", "get", duration);
                    self.handle_l2_failure().await;
                }
            }
        }
        if let Some(db_fallback_mgr) = &self.db_fallback_mgr {
            GLOBAL_METRICS.record_request(&self.service_name, "DB", "fallback", "attempt");
            let start = std::time::Instant::now();
            match db_fallback_mgr.fallback_load(key).await {
                Ok(Some(data)) => {
                    let duration = start.elapsed().as_secs_f64();
                    GLOBAL_METRICS.record_duration(
                        &self.service_name,
                        "DB",
                        "fallback",
                        duration,
                    );
                    GLOBAL_METRICS.record_request(&self.service_name, "DB", "fallback", "hit");
                    #[cfg(feature = "smart-strategy")]
                    if let Some(strategy) = &self.smart_strategy {
                        strategy.record_access(true);
                    }
                    // Write-back to the cache tiers is best effort.
                    if let Err(e) = self.set_bytes(key, data.clone(), None).await {
                        warn!("Failed to write fallback data to cache: {}", e);
                    }
                    return Ok(Some(data));
                }
                Ok(None) => {
                    let duration = start.elapsed().as_secs_f64();
                    GLOBAL_METRICS.record_duration(
                        &self.service_name,
                        "DB",
                        "fallback",
                        duration,
                    );
                    GLOBAL_METRICS.record_request(&self.service_name, "DB", "fallback", "miss");
                    debug!("Database fallback miss for key: {}", key);
                }
                Err(e) => {
                    let duration = start.elapsed().as_secs_f64();
                    GLOBAL_METRICS.record_duration(
                        &self.service_name,
                        "DB",
                        "fallback",
                        duration,
                    );
                    warn!("Database fallback failed for key {}: {}", key, e);
                }
            }
        }
    }
    #[cfg(feature = "smart-strategy")]
    if let Some(strategy) = &self.smart_strategy {
        strategy.record_access(false);
    }
    Ok(None)
}
/// Writes raw bytes into both tiers, routing the L2 write by health state.
///
/// Flow: validate key/value -> record key in bloom filter (best effort) ->
/// synchronous L1 write -> L2 write either direct, via the batch writer, or
/// into the WAL when L2 is degraded/replaying.
///
/// # Errors
/// Validation, L1, L2, batch-writer and WAL errors are propagated; a
/// bloom-filter failure is only logged.
#[instrument(skip(self, value), level = "debug", fields(service = %self.service_name))]
async fn set_bytes(&self, key: &str, value: Vec<u8>, ttl: Option<u64>) -> Result<()> {
    validate_cache_key(key)?;
    let max_key_length = self.config.max_key_length.unwrap_or(256);
    validate_key_length(key, max_key_length)?;
    let max_value_size = self.config.max_value_size.unwrap_or(10 * 1024 * 1024);
    validate_value_size(&value, max_value_size)?;
    let bytes = value;
    if let Some(bloom_filter) = &self.bloom_filter {
        // Best effort: a bloom-filter error must never block the write path.
        let key_bytes = key.as_bytes().to_vec();
        match bloom_filter.add(&key_bytes).await {
            Ok(_) => {
                GLOBAL_METRICS.record_request(&self.service_name, "BloomFilter", "set", "add");
            }
            Err(e) => {
                warn!("Failed to add key to BloomFilter: {}", e);
            }
        }
    }
    if let (Some(l1), Some(l2)) = (&self.l1, &self.l2) {
        let start = std::time::Instant::now();
        debug!("Writing to L1: key={}", key);
        l1.set_bytes(key, bytes.clone(), ttl).await?;
        let duration = start.elapsed().as_secs_f64();
        GLOBAL_METRICS.record_duration(&self.service_name, "L1", "set", duration);
        debug!("L1 write successful: key={}", key);
        // Snapshot the state and release the lock before any L2/WAL I/O.
        let state = self.health_state.read().await;
        let current_state = *state;
        drop(state);
        debug!("Current health state: {:?}", current_state);
        match current_state {
            HealthState::Healthy | HealthState::Recovering { .. } => {
                if self.config.enable_batch_write {
                    // NOTE(review): if batch writing is enabled but no batch
                    // writer was constructed the L2 write is silently skipped;
                    // presumed unreachable since both are created together in
                    // `new` — confirm.
                    if let Some(batch_writer) = &self.batch_writer {
                        batch_writer
                            .enqueue_operation(
                                BatchOperation::Set {
                                    key: key.to_string(),
                                    value: bytes,
                                    ttl,
                                },
                                100, // default priority for regular writes
                            )
                            .await?;
                    }
                } else {
                    l2.set_bytes(key, bytes, ttl).await?;
                }
            }
            // Degraded and WAL-replay share the same fallback: persist the
            // write to the WAL for replay once L2 recovers (this merges two
            // previously duplicated arms; the log messages are preserved).
            HealthState::Degraded { .. } | HealthState::WalReplaying { .. } => {
                if matches!(current_state, HealthState::Degraded { .. }) {
                    debug!("L2 is degraded, writing to WAL: key={}", key);
                } else {
                    debug!("L2 is replaying WAL, writing to WAL: key={}", key);
                }
                self.wal
                    .append(WalEntry {
                        timestamp: std::time::SystemTime::now(),
                        operation: Operation::Set,
                        key: key.to_string(),
                        value: Some(bytes),
                        ttl: ttl.map(|t| t as i64),
                    })
                    .await?;
                debug!("WAL write successful: key={}", key);
            }
        }
    }
    Ok(())
}
/// Writes raw bytes to L1 only, recording the write duration. A missing L1
/// backend is a silent no-op.
#[instrument(skip(self, value), level = "debug", fields(service = %self.service_name))]
async fn set_l1_bytes(&self, key: &str, value: Vec<u8>, ttl: Option<u64>) -> Result<()> {
    let l1 = match &self.l1 {
        Some(l1) => l1,
        None => return Ok(()),
    };
    let started = std::time::Instant::now();
    l1.set_bytes(key, value, ttl).await?;
    GLOBAL_METRICS.record_duration(
        &self.service_name,
        "L1",
        "set",
        started.elapsed().as_secs_f64(),
    );
    Ok(())
}
#[instrument(skip(self, value), level = "debug", fields(service = %self.service_name))]
async fn set_l2_bytes(&self, key: &str, value: Vec<u8>, ttl: Option<u64>) -> Result<()> {
if let Some(l2) = &self.l2 {
let state = self.health_state.read().await;
match *state {
HealthState::Healthy | HealthState::Recovering { .. } => {
drop(state);
l2.set_bytes(key, value, ttl).await?;
}
HealthState::Degraded { .. } => {
return Err(crate::error::CacheError::L2Error(
"L2 is degraded".to_string(),
));
}
HealthState::WalReplaying { .. } => {
return Err(crate::error::CacheError::L2Error(
"L2 is replaying WAL".to_string(),
));
}
}
}
Ok(())
}
/// Reads raw bytes from L1 only, recording the read duration. Returns
/// `Ok(None)` when no L1 backend is configured.
#[instrument(skip(self), level = "debug", fields(service = %self.service_name))]
async fn get_l1_bytes(&self, key: &str) -> Result<Option<Vec<u8>>> {
    match &self.l1 {
        Some(l1) => {
            let started = std::time::Instant::now();
            let found = l1.get_bytes(key).await?;
            GLOBAL_METRICS.record_duration(
                &self.service_name,
                "L1",
                "get",
                started.elapsed().as_secs_f64(),
            );
            Ok(found)
        }
        None => Ok(None),
    }
}
/// Reads raw bytes from L2 only, recording the read duration. Returns
/// `Ok(None)` when no L2 client is configured.
#[instrument(skip(self), level = "debug", fields(service = %self.service_name))]
async fn get_l2_bytes(&self, key: &str) -> Result<Option<Vec<u8>>> {
    match &self.l2 {
        Some(l2) => {
            let started = std::time::Instant::now();
            let found = l2.get_bytes(key).await?;
            GLOBAL_METRICS.record_duration(
                &self.service_name,
                "L2",
                "get",
                started.elapsed().as_secs_f64(),
            );
            Ok(found)
        }
        None => Ok(None),
    }
}
/// Deletes `key` from both tiers. L1 is always deleted first; the L2 delete
/// is routed by health state: direct (plus invalidation publish) when
/// healthy/recovering, WAL-logged when degraded, rejected during WAL replay.
#[instrument(skip(self), level = "debug", fields(service = %self.service_name))]
async fn delete(&self, key: &str) -> Result<()> {
    validate_cache_key(key)?;
    let max_key_length = self.config.max_key_length.unwrap_or(256);
    validate_key_length(key, max_key_length)?;
    if let (Some(l1), Some(l2)) = (&self.l1, &self.l2) {
        l1.delete(key).await?;
        let state = self.health_state.read().await;
        match *state {
            HealthState::Healthy | HealthState::Recovering { .. } => {
                drop(state);
                match l2.delete(key).await {
                    Ok(_) => {
                        // Best-effort invalidation broadcast so other nodes
                        // drop their L1 copy; publish errors are ignored.
                        if let Some(publisher) = &self.publisher {
                            let _ = publisher.publish(key).await;
                        }
                    }
                    Err(e) => {
                        // Degrade health and surface the error to the caller.
                        self.handle_l2_failure().await;
                        return Err(e);
                    }
                }
            }
            HealthState::Degraded { .. } => {
                drop(state);
                // Record the delete in the WAL for replay after recovery.
                self.wal
                    .append(WalEntry {
                        timestamp: std::time::SystemTime::now(),
                        operation: Operation::Delete,
                        key: key.to_string(),
                        value: None,
                        ttl: None,
                    })
                    .await?;
            }
            HealthState::WalReplaying { .. } => {
                drop(state);
                // Deleting mid-replay could race the replayed writes, so it
                // is refused outright.
                tracing::warn!(
                    "Cannot delete during WAL replay, service={}",
                    self.service_name
                );
                return Err(crate::error::CacheError::L2Error(
                    "L2 is replaying WAL".to_string(),
                ));
            }
        }
    }
    Ok(())
}
/// Clears the entire L1 cache (no-op when L1 is absent).
#[instrument(skip(self), level = "debug", fields(service = %self.service_name))]
async fn clear_l1(&self) -> Result<()> {
    match &self.l1 {
        None => Ok(()),
        Some(l1) => {
            l1.clear()?;
            GLOBAL_METRICS.record_request(&self.service_name, "L1", "clear", "success");
            Ok(())
        }
    }
}
/// Clears the entire L2 cache (no-op when L2 is absent).
#[instrument(skip(self), level = "debug", fields(service = %self.service_name))]
async fn clear_l2(&self) -> Result<()> {
    match &self.l2 {
        None => Ok(()),
        Some(l2) => {
            l2.clear().await?;
            GLOBAL_METRICS.record_request(&self.service_name, "L2", "clear", "success");
            Ok(())
        }
    }
}
/// Drops all pending WAL entries.
#[instrument(skip(self), level = "debug", fields(service = %self.service_name))]
async fn clear_wal(&self) -> Result<()> {
    self.wal.clear().await?;
    GLOBAL_METRICS.record_request(&self.service_name, "WAL", "clear", "success");
    Ok(())
}
}
impl TwoLevelClient {
/// Typed read through the full tiered path; deserializes with the
/// configured serializer.
#[instrument(skip(self), level = "debug", fields(service = %self.service_name))]
pub async fn get<T: serde::de::DeserializeOwned + Send>(&self, key: &str) -> Result<Option<T>> {
    self.get_bytes(key)
        .await?
        .map(|bytes| self.serializer.deserialize(&bytes))
        .transpose()
}
/// Typed write through the full tiered path; serializes with the configured
/// serializer.
#[instrument(skip(self, value), level = "debug", fields(service = %self.service_name))]
pub async fn set<T: serde::Serialize + Send + Sync>(
    &self,
    key: &str,
    value: &T,
    ttl: Option<u64>,
) -> Result<()> {
    let payload = self.serializer.serialize(value)?;
    self.set_bytes(key, payload, ttl).await
}
/// Typed write targeting L1 only.
#[instrument(skip(self, value), level = "debug", fields(service = %self.service_name))]
pub async fn set_l1_only<T: serde::Serialize + Send + Sync>(
    &self,
    key: &str,
    value: &T,
    ttl: Option<u64>,
) -> Result<()> {
    let payload = self.serializer.serialize(value)?;
    CacheOps::set_l1_bytes(self, key, payload, ttl).await
}
/// Typed write targeting L2 only.
#[instrument(skip(self, value), level = "debug", fields(service = %self.service_name))]
pub async fn set_l2_only<T: serde::Serialize + Send + Sync>(
    &self,
    key: &str,
    value: &T,
    ttl: Option<u64>,
) -> Result<()> {
    let payload = self.serializer.serialize(value)?;
    CacheOps::set_l2_bytes(self, key, payload, ttl).await
}
/// Typed read targeting L1 only.
#[instrument(skip(self), level = "debug", fields(service = %self.service_name))]
pub async fn get_l1_only<T: serde::de::DeserializeOwned + Send>(
    &self,
    key: &str,
) -> Result<Option<T>> {
    CacheOps::get_l1_bytes(self, key)
        .await?
        .map(|bytes| self.serializer.deserialize(&bytes))
        .transpose()
}
/// Typed read targeting L2 only.
#[instrument(skip(self), level = "debug", fields(service = %self.service_name))]
pub async fn get_l2_only<T: serde::de::DeserializeOwned + Send>(
    &self,
    key: &str,
) -> Result<Option<T>> {
    CacheOps::get_l2_bytes(self, key)
        .await?
        .map(|bytes| self.serializer.deserialize(&bytes))
        .transpose()
}
#[instrument(skip(self), level = "debug", fields(service = %self.service_name))]
pub async fn ping_l2(&self) -> Result<()> {
if let Some(l2) = &self.l2 {
l2.ping().await
} else {
Err(crate::error::CacheError::L2Error(
"L2 client not available".to_string(),
))
}
}
/// Installs the database fallback manager used by `get_bytes` when both
/// cache tiers miss.
#[instrument(skip(self), level = "info", fields(service = %self.service_name))]
pub fn set_db_fallback_manager(&mut self, db_fallback_mgr: Arc<DbFallbackManager>) {
    info!(
        "Setting database fallback manager for service: {}",
        self.service_name
    );
    self.db_fallback_mgr = Some(db_fallback_mgr);
}
/// Internal accessor for the DB fallback manager (testing/diagnostics only).
#[doc(hidden)]
pub fn get_db_fallback_manager(&self) -> Option<Arc<DbFallbackManager>> {
    self.db_fallback_mgr.clone()
}
/// Enables the smart-strategy manager (access-pattern tracking used by
/// `get_bytes`); `None` config uses the manager's defaults.
#[cfg(feature = "smart-strategy")]
pub fn enable_smart_strategy(
    &mut self,
    config: Option<crate::smart_strategy::SmartStrategyConfig>,
) {
    self.smart_strategy = Some(Arc::new(crate::smart_strategy::SmartStrategyManager::new(
        config,
    )));
}
/// Whether the smart strategy has been enabled on this client.
#[cfg(feature = "smart-strategy")]
pub fn is_smart_strategy_enabled(&self) -> bool {
    self.smart_strategy.is_some()
}
/// Accessor for the smart-strategy manager, if enabled.
#[cfg(feature = "smart-strategy")]
pub fn smart_strategy(&self) -> Option<&Arc<crate::smart_strategy::SmartStrategyManager>> {
    self.smart_strategy.as_ref()
}
/// Deletes all L2 keys matching `pattern`, returning the number removed.
/// Returns `Ok(0)` when no L2 client is configured. Note: L1 is untouched.
pub async fn del_pattern(&self, pattern: &str) -> Result<u64> {
    match &self.l2 {
        None => Ok(0),
        Some(l2) => {
            let deleted = l2.backend().del_pattern(pattern).await?;
            GLOBAL_METRICS.record_request(&self.service_name, "L2", "del_pattern", "success");
            Ok(deleted)
        }
    }
}
}
#[async_trait]
impl TtlControl for TwoLevelClient {
/// Remaining TTL of `key` in L1, or `Ok(None)` when L1 is absent.
async fn get_l1_ttl(&self, key: &str) -> Result<Option<u64>> {
    match &self.l1 {
        Some(l1) => l1.ttl(key).await,
        None => Ok(None),
    }
}
/// Remaining TTL of `key` in L2; `Ok(None)` when L2 is absent or the
/// backend reports no TTL.
///
/// Non-positive TTLs from the backend are normalized to `Some(0)`.
/// NOTE(review): for a u64 TTL the `Some(_) -> Some(0)` arm only matches 0,
/// making the normalization an identity — confirm the backend TTL type.
async fn get_l2_ttl(&self, key: &str) -> Result<Option<u64>> {
    if let Some(l2) = &self.l2 {
        match l2.backend().ttl(key).await {
            Ok(Some(ttl)) if ttl > 0 => Ok(Some(ttl)),
            Ok(Some(_)) => Ok(Some(0)), Ok(None) => Ok(None), Err(e) => Err(e),
        }
    } else {
        Ok(None)
    }
}
/// Combined TTL delegates to L2, the authoritative tier.
async fn get_ttl(&self, key: &str) -> Result<Option<u64>> {
    self.get_l2_ttl(key).await
}
/// Resets the L1 TTL for `key`; `Ok(false)` when L1 is absent.
async fn refresh_l1_ttl(&self, key: &str, ttl: u64) -> Result<bool> {
    match &self.l1 {
        Some(l1) => l1.refresh_ttl(key, ttl).await,
        None => Ok(false),
    }
}
/// Resets the L2 TTL for `key`. Returns `Ok(false)` when L2 is absent or
/// currently degraded/replaying (no L2 writes allowed in those states).
async fn refresh_l2_ttl(&self, key: &str, ttl: u64) -> Result<bool> {
    let l2 = match &self.l2 {
        Some(l2) => l2,
        None => return Ok(false),
    };
    // Copy the state out so the lock is released before the expire call.
    let current = *self.health_state.read().await;
    match current {
        HealthState::Healthy | HealthState::Recovering { .. } => {
            let refreshed = l2.backend().expire(key, ttl).await?;
            if refreshed {
                GLOBAL_METRICS.record_request(
                    &self.service_name,
                    "L2",
                    "expire",
                    "success",
                );
            }
            Ok(refreshed)
        }
        HealthState::Degraded { .. } | HealthState::WalReplaying { .. } => Ok(false),
    }
}
/// Refreshes the TTL on both tiers; succeeds if either tier accepted it.
/// Both refreshes always run (no short-circuit), matching each tier's own
/// error propagation.
async fn refresh_ttl(&self, key: &str, ttl: u64) -> Result<bool> {
    let refreshed_l1 = self.refresh_l1_ttl(key, ttl).await?;
    let refreshed_l2 = self.refresh_l2_ttl(key, ttl).await?;
    Ok(refreshed_l1 || refreshed_l2)
}
/// Re-arms the current L2 TTL of `key` ("touch"). Returns `Ok(false)` while
/// degraded/replaying or when the key's TTL is already 0.
///
/// NOTE(review): when `get_l2_ttl` returns `None` (key missing / no TTL) or
/// when no L2 client exists, this falls through to `Ok(true)` — that reads
/// like "touched" for a key that was never refreshed; confirm intended.
async fn touch(&self, key: &str) -> Result<bool> {
    let state = self.health_state.read().await;
    match *state {
        HealthState::Healthy | HealthState::Recovering { .. } => {
            drop(state);
            if let Some(current_ttl) = self.get_l2_ttl(key).await? {
                if current_ttl == 0 {
                    return Ok(false);
                }
                if let Some(l2) = &self.l2 {
                    // Re-apply the remaining TTL to reset the expiry clock.
                    return l2.backend().expire(key, current_ttl).await;
                }
            }
            Ok(true)
        }
        HealthState::Degraded { .. } | HealthState::WalReplaying { .. } => {
            drop(state);
            Ok(false)
        }
    }
}
}
#[async_trait]
impl TieredCacheControl for TwoLevelClient {
/// Validated direct L1 read with attempt/hit/miss metrics. Returns
/// `Ok(None)` when L1 is absent.
async fn get_l1_direct(&self, key: &str) -> Result<Option<Vec<u8>>> {
    let l1 = match &self.l1 {
        Some(l1) => l1,
        None => return Ok(None),
    };
    validate_cache_key(key)?;
    validate_key_length(key, self.config.max_key_length.unwrap_or(256))?;
    GLOBAL_METRICS.record_request(&self.service_name, "L1", "get_direct", "attempt");
    let started = std::time::Instant::now();
    let result = l1.get_bytes(key).await;
    // Duration is recorded for both success and failure.
    GLOBAL_METRICS.record_duration(
        &self.service_name,
        "L1",
        "get_direct",
        started.elapsed().as_secs_f64(),
    );
    let found = result?;
    GLOBAL_METRICS.record_request(
        &self.service_name,
        "L1",
        "get_direct",
        if found.is_some() { "hit" } else { "miss" },
    );
    Ok(found)
}
async fn set_l1_direct(&self, key: &str, value: Vec<u8>, ttl: Option<u64>) -> Result<()> {
if let Some(l1) = &self.l1 {
validate_cache_key(key)?;
let max_key_length = self.config.max_key_length.unwrap_or(256);
validate_key_length(key, max_key_length)?;
let max_value_size = self.config.max_value_size.unwrap_or(10 * 1024 * 1024);
validate_value_size(&value, max_value_size)?;
let start = std::time::Instant::now();
l1.set_bytes(key, value, ttl).await?;
let duration = start.elapsed().as_secs_f64();
GLOBAL_METRICS.record_duration(&self.service_name, "L1", "set_direct", duration);
GLOBAL_METRICS.record_request(&self.service_name, "L1", "set_direct", "success");
Ok(())
} else {
Err(crate::error::CacheError::NotSupported(
"L1 not available".to_string(),
))
}
}
/// Validated direct L1 delete; `Ok(false)` when L1 is absent.
async fn delete_l1_direct(&self, key: &str) -> Result<bool> {
    let l1 = match &self.l1 {
        Some(l1) => l1,
        None => return Ok(false),
    };
    validate_cache_key(key)?;
    validate_key_length(key, self.config.max_key_length.unwrap_or(256))?;
    l1.delete(key).await?;
    GLOBAL_METRICS.record_request(&self.service_name, "L1", "delete_direct", "success");
    Ok(true)
}
/// Validated direct L2 read with attempt/hit/miss metrics. Returns
/// `Ok(None)` when L2 is absent.
async fn get_l2_direct(&self, key: &str) -> Result<Option<Vec<u8>>> {
    let l2 = match &self.l2 {
        Some(l2) => l2,
        None => return Ok(None),
    };
    validate_cache_key(key)?;
    validate_key_length(key, self.config.max_key_length.unwrap_or(256))?;
    GLOBAL_METRICS.record_request(&self.service_name, "L2", "get_direct", "attempt");
    let started = std::time::Instant::now();
    let result = l2.get_bytes(key).await;
    // Duration is recorded for both success and failure.
    GLOBAL_METRICS.record_duration(
        &self.service_name,
        "L2",
        "get_direct",
        started.elapsed().as_secs_f64(),
    );
    let found = result?;
    GLOBAL_METRICS.record_request(
        &self.service_name,
        "L2",
        "get_direct",
        if found.is_some() { "hit" } else { "miss" },
    );
    Ok(found)
}
async fn set_l2_direct(&self, key: &str, value: Vec<u8>, ttl: Option<u64>) -> Result<()> {
if let Some(l2) = &self.l2 {
validate_cache_key(key)?;
let max_key_length = self.config.max_key_length.unwrap_or(256);
validate_key_length(key, max_key_length)?;
let max_value_size = self.config.max_value_size.unwrap_or(10 * 1024 * 1024);
validate_value_size(&value, max_value_size)?;
l2.set_bytes(key, value, ttl).await?;
GLOBAL_METRICS.record_request(&self.service_name, "L2", "set_direct", "success");
Ok(())
} else {
Err(crate::error::CacheError::NotSupported(
"L2 not available".to_string(),
))
}
}
/// Validated direct L2 delete; `Ok(false)` when L2 is absent.
async fn delete_l2_direct(&self, key: &str) -> Result<bool> {
    let l2 = match &self.l2 {
        Some(l2) => l2,
        None => return Ok(false),
    };
    validate_cache_key(key)?;
    validate_key_length(key, self.config.max_key_length.unwrap_or(256))?;
    l2.delete(key).await?;
    GLOBAL_METRICS.record_request(&self.service_name, "L2", "delete_direct", "success");
    Ok(true)
}
/// Copies `key` from L2 into L1, mirroring the remaining L2 TTL so the
/// promoted entry does not outlive the source entry.
///
/// Returns `Ok(true)` on promotion, `Ok(false)` on an L2 miss or when
/// either tier is absent.
async fn promote_to_l1(&self, key: &str) -> Result<bool> {
    validate_cache_key(key)?;
    let max_key_length = self.config.max_key_length.unwrap_or(256);
    validate_key_length(key, max_key_length)?;
    if let (Some(l1), Some(l2)) = (&self.l1, &self.l2) {
        match l2.get_bytes(key).await {
            Ok(Some(value)) => {
                let ttl = l2.backend().ttl(key).await?;
                let start = std::time::Instant::now();
                // `value` is moved here — the previous `value.clone()` was a
                // redundant full-buffer copy (value was never read again).
                l1.set_bytes(key, value, ttl).await?;
                let duration = start.elapsed().as_secs_f64();
                GLOBAL_METRICS.record_duration(&self.service_name, "L1", "promote", duration);
                GLOBAL_METRICS.record_request(&self.service_name, "L2", "promote", "success");
                Ok(true)
            }
            Ok(None) => {
                GLOBAL_METRICS.record_request(&self.service_name, "L2", "promote", "miss");
                Ok(false)
            }
            Err(e) => Err(e),
        }
    } else {
        Ok(false)
    }
}
/// Copies `key` from L1 down into L2 with the caller-supplied TTL.
///
/// Returns `Ok(true)` on demotion, `Ok(false)` on an L1 miss or when either
/// tier is absent. The L1 entry is left in place.
async fn demote_to_l2(&self, key: &str, ttl: Option<u64>) -> Result<bool> {
    validate_cache_key(key)?;
    let max_key_length = self.config.max_key_length.unwrap_or(256);
    validate_key_length(key, max_key_length)?;
    if let (Some(l1), Some(l2)) = (&self.l1, &self.l2) {
        match l1.get_bytes(key).await {
            Ok(Some(value)) => {
                let start = std::time::Instant::now();
                // `value` is moved here — the previous `value.clone()` was a
                // redundant full-buffer copy (value was never read again).
                l2.set_bytes(key, value, ttl).await?;
                let duration = start.elapsed().as_secs_f64();
                GLOBAL_METRICS.record_duration(&self.service_name, "L2", "demote", duration);
                GLOBAL_METRICS.record_request(&self.service_name, "L1", "demote", "success");
                Ok(true)
            }
            Ok(None) => {
                GLOBAL_METRICS.record_request(&self.service_name, "L1", "demote", "miss");
                Ok(false)
            }
            Err(e) => Err(e),
        }
    } else {
        Ok(false)
    }
}
/// Evicts `key` from every configured tier. Per-tier failures are logged
/// and tolerated; returns `Ok(true)` if at least one tier evicted.
async fn evict_all(&self, key: &str) -> Result<bool> {
    validate_cache_key(key)?;
    validate_key_length(key, self.config.max_key_length.unwrap_or(256))?;
    let mut evicted = false;
    if let Some(l1) = &self.l1 {
        if let Err(e) = l1.delete(key).await {
            warn!("Failed to evict from L1: {}", e);
        } else {
            GLOBAL_METRICS.record_request(&self.service_name, "L1", "evict", "success");
            evicted = true;
        }
    }
    if let Some(l2) = &self.l2 {
        if let Err(e) = l2.delete(key).await {
            warn!("Failed to evict from L2: {}", e);
        } else {
            GLOBAL_METRICS.record_request(&self.service_name, "L2", "evict", "success");
            evicted = true;
        }
    }
    Ok(evicted)
}
}