use super::tiered_cache::TieredCacheControl;
use super::ttl_control::TtlControl;
use super::CacheOps;
use crate::backend::l2::L2Backend;
use crate::config::TwoLevelConfig;
use crate::error::Result;
use crate::metrics::GLOBAL_METRICS;
use crate::recovery::{
health::{HealthChecker, HealthState},
wal::{Operation, WalEntry, WalManager},
};
use crate::serialization::SerializerEnum;
use crate::sync::invalidation::InvalidationPublisher;
use async_trait::async_trait;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{instrument, warn};
/// Cache client that operates on the L2 backend only (its L1 accessors are no-ops).
///
/// Writes are gated by a shared health state: while the backend is considered
/// unavailable, mutations are appended to a write-ahead log instead of being
/// sent to L2.
pub struct L2Client {
/// Service identifier; used as the metrics label, to open the WAL, to scope
/// `clear`, and as the suffix of the default invalidation channel name.
service_name: String,
/// Handle to the L2 backend that all reads and writes are delegated to.
l2: Arc<L2Backend>,
/// Serializer handed back to callers through `CacheOps::serializer`.
serializer: SerializerEnum,
/// Current backend health; the same `Arc` is handed to the `HealthChecker`
/// created in `new`, and the client itself flips it to `Degraded` on failures.
health_state: Arc<RwLock<HealthState>>,
/// Write-ahead log that receives set/delete operations while L2 is not writable;
/// the same `Arc` is handed to the `HealthChecker`.
wal: Arc<WalManager>,
/// Publisher for key-invalidation messages; `None` when the backend reports
/// `NotSupported` for raw-client access or the connection manager could not be created.
publisher: Option<Arc<InvalidationPublisher>>,
/// Per-key wall-clock timestamp (`chrono::Utc::now().timestamp()`) of the last
/// successful TTL refresh, compared against the current time to detect clock rollback.
timestamp_cache: Arc<dashmap::DashMap<String, i64>>,
}
impl L2Client {
    /// Builds an L2-only client for `service_name`.
    ///
    /// Steps, in order:
    /// 1. open the write-ahead log for the service;
    /// 2. try to create the invalidation publisher (optional — see below);
    /// 3. start the background health checker.
    ///
    /// The health checker is spawned only after every fallible step has
    /// succeeded. Spawning it earlier would leave a detached task running
    /// (holding the backend, WAL and health-state handles) if construction
    /// subsequently returned an error.
    ///
    /// The publisher ends up `None` when the backend answers `NotSupported`
    /// for raw-client access or when the connection manager cannot be created;
    /// both cases are logged and are not treated as construction failures.
    ///
    /// # Errors
    /// Returns the error from `WalManager::new`, or any error from
    /// `get_raw_client` other than `CacheError::NotSupported`.
    pub async fn new(
        service_name: String,
        l2: Arc<L2Backend>,
        serializer: SerializerEnum,
    ) -> Result<Self> {
        let health_state = Arc::new(RwLock::new(HealthState::Healthy));
        let wal = Arc::new(WalManager::new(&service_name).await?);

        // NOTE(review): the channel name is resolved from a default
        // `TwoLevelConfig`, so a customised invalidation-channel setting is not
        // visible here — confirm this is intended.
        let config = TwoLevelConfig::default();
        let channel_name = Self::resolve_channel_name(&service_name, &config);

        let publisher = match l2.get_raw_client() {
            Ok(client) => match client.get_connection_manager().await {
                Ok(manager) => Some(Arc::new(InvalidationPublisher::new(manager, channel_name))),
                Err(e) => {
                    tracing::warn!(
                        "Failed to create connection manager for invalidation publisher: {}",
                        e
                    );
                    None
                }
            },
            Err(crate::error::CacheError::NotSupported(_)) => {
                tracing::warn!("Invalidation publisher not supported for this backend mode (likely Cluster), skipping");
                None
            }
            Err(e) => return Err(e),
        };

        // All fallible work is done; it is now safe to start the detached
        // background task.
        let command_timeout_ms = l2.command_timeout_ms();
        let checker = HealthChecker::new(
            l2.clone(),
            health_state.clone(),
            wal.clone(),
            service_name.clone(),
            command_timeout_ms,
        );
        tokio::spawn(async move { checker.start().await });

        Ok(Self {
            service_name,
            l2,
            serializer,
            health_state,
            wal,
            publisher,
            timestamp_cache: Arc::new(dashmap::DashMap::new()),
        })
    }

    /// Computes the pub/sub channel name used for invalidation messages.
    ///
    /// * `Custom(name)` — `name` is used verbatim.
    /// * `Structured { prefix, use_service_name }` — the prefix defaults to
    ///   `"cache:invalidate"`; the service name is appended after a `:` only
    ///   when `use_service_name` is set.
    /// * no configuration — `"cache:invalidate:<service_name>"`.
    fn resolve_channel_name(service_name: &str, config: &TwoLevelConfig) -> String {
        use crate::config::InvalidationChannelConfig;
        match &config.invalidation_channel {
            Some(InvalidationChannelConfig::Custom(name)) => name.clone(),
            Some(InvalidationChannelConfig::Structured {
                prefix,
                use_service_name,
            }) => {
                let prefix = prefix.as_deref().unwrap_or("cache:invalidate");
                if *use_service_name {
                    format!("{}:{}", prefix, service_name)
                } else {
                    prefix.to_string()
                }
            }
            None => format!("cache:invalidate:{}", service_name),
        }
    }

    /// Records an L2 failure in the shared health state.
    ///
    /// If the state is already `Degraded`, the original `since` timestamp is
    /// kept and `failure_count` is incremented (saturating), so repeated
    /// failures neither restart the degradation clock nor pin the counter at 1.
    /// From any other state the client transitions to a fresh
    /// `Degraded { since: now, failure_count: 1 }`.
    async fn handle_l2_failure(&self) {
        tracing::warn!("L2 failure detected for service: {}", self.service_name);
        let mut state = self.health_state.write().await;
        if let HealthState::Degraded { failure_count, .. } = &mut *state {
            *failure_count = failure_count.saturating_add(1);
            return;
        }
        *state = HealthState::Degraded {
            since: std::time::Instant::now(),
            failure_count: 1,
        };
    }

    /// Pings the L2 backend and returns its result unchanged.
    #[instrument(skip(self), level = "debug", fields(service = %self.service_name))]
    pub async fn ping(&self) -> Result<()> {
        self.l2.ping().await
    }

    /// Asks the L2 backend to clear the entries belonging to this service.
    #[instrument(skip(self), level = "debug", fields(service = %self.service_name))]
    pub async fn clear(&self) -> Result<()> {
        self.l2.clear(&self.service_name).await
    }

    /// Borrows the underlying L2 backend.
    pub fn backend(&self) -> &L2Backend {
        &self.l2
    }
}
#[async_trait]
impl CacheOps for L2Client {
fn serializer(&self) -> &SerializerEnum {
&self.serializer
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn into_any_arc(self: Arc<Self>) -> Arc<dyn std::any::Any + Send + Sync> {
self
}
#[instrument(skip(self), level = "debug", fields(service = %self.service_name))]
async fn get_bytes(&self, key: &str) -> Result<Option<Vec<u8>>> {
GLOBAL_METRICS.record_request(&self.service_name, "L2", "get", "attempt");
let start = std::time::Instant::now();
match self.l2.get_with_version(key).await {
Ok(Some((value, _))) => {
let duration = start.elapsed().as_secs_f64();
GLOBAL_METRICS.record_duration(&self.service_name, "L2", "get", duration);
GLOBAL_METRICS.record_request(&self.service_name, "L2", "get", "hit");
Ok(Some(value))
}
Ok(None) => {
let duration = start.elapsed().as_secs_f64();
GLOBAL_METRICS.record_duration(&self.service_name, "L2", "get", duration);
GLOBAL_METRICS.record_request(&self.service_name, "L2", "get", "miss");
Ok(None)
}
Err(e) => {
let duration = start.elapsed().as_secs_f64();
GLOBAL_METRICS.record_duration(&self.service_name, "L2", "get", duration);
self.handle_l2_failure().await;
Err(e)
}
}
}
#[instrument(skip(self, value), level = "debug", fields(service = %self.service_name))]
async fn set_bytes(&self, key: &str, value: Vec<u8>, ttl: Option<u64>) -> Result<()> {
let state = self.health_state.read().await;
tracing::info!("set_bytes: current health state = {:?}", *state);
match *state {
HealthState::Healthy | HealthState::Recovering { .. } => {
drop(state);
let start = std::time::Instant::now();
let key_exists = match self.l2.get_with_version(key).await {
Ok(Some(_)) => true,
Ok(None) => false,
Err(_) => true, };
match self.l2.set_with_version(key, value.clone(), ttl).await {
Ok(_) => {
let duration = start.elapsed().as_secs_f64();
GLOBAL_METRICS.record_duration(&self.service_name, "L2", "set", duration);
if key_exists {
if let Some(publisher) = &self.publisher {
let _ = publisher.publish(key).await;
}
}
Ok(())
}
Err(e) => {
let duration = start.elapsed().as_secs_f64();
GLOBAL_METRICS.record_duration(&self.service_name, "L2", "set", duration);
tracing::warn!("L2 set failed during set_bytes, writing to WAL: {}", e);
self.handle_l2_failure().await;
self.wal
.append(WalEntry {
timestamp: std::time::SystemTime::now(),
operation: Operation::Set,
key: key.to_string(),
value: Some(value),
ttl: ttl.map(|t| t as i64),
})
.await?;
Ok(())
}
}
}
HealthState::Degraded { .. } => {
tracing::info!("set_bytes: L2 is degraded, writing to WAL and returning success");
drop(state);
self.wal
.append(WalEntry {
timestamp: std::time::SystemTime::now(),
operation: Operation::Set,
key: key.to_string(),
value: Some(value),
ttl: ttl.map(|t| t as i64),
})
.await?;
Ok(())
}
HealthState::WalReplaying { .. } => {
tracing::info!(
"set_bytes: L2 is replaying WAL, writing to WAL and returning success"
);
drop(state);
self.wal
.append(WalEntry {
timestamp: std::time::SystemTime::now(),
operation: Operation::Set,
key: key.to_string(),
value: Some(value),
ttl: ttl.map(|t| t as i64),
})
.await?;
Ok(())
}
}
}
#[instrument(skip(self, value), level = "debug", fields(service = %self.service_name))]
async fn set_l2_bytes(&self, key: &str, value: Vec<u8>, ttl: Option<u64>) -> Result<()> {
self.set_bytes(key, value, ttl).await
}
#[instrument(skip(self), level = "debug", fields(service = %self.service_name))]
async fn get_l1_bytes(&self, _key: &str) -> Result<Option<Vec<u8>>> {
Ok(None)
}
#[instrument(skip(self), level = "debug", fields(service = %self.service_name))]
async fn get_l2_bytes(&self, key: &str) -> Result<Option<Vec<u8>>> {
self.get_bytes(key).await
}
#[instrument(skip(self), level = "debug", fields(service = %self.service_name))]
async fn delete(&self, key: &str) -> Result<()> {
let state = self.health_state.read().await;
match *state {
HealthState::Healthy | HealthState::Recovering { .. } => {
drop(state);
match self.l2.delete(key).await {
Ok(_) => {
if let Some(publisher) = &self.publisher {
let _ = publisher.publish(key).await;
}
Ok(())
}
Err(e) => {
self.handle_l2_failure().await;
Err(e)
}
}
}
HealthState::Degraded { .. } => {
drop(state);
self.wal
.append(WalEntry {
timestamp: std::time::SystemTime::now(),
operation: Operation::Delete,
key: key.to_string(),
value: None,
ttl: None,
})
.await
}
HealthState::WalReplaying { .. } => {
drop(state);
self.wal
.append(WalEntry {
timestamp: std::time::SystemTime::now(),
operation: Operation::Delete,
key: key.to_string(),
value: None,
ttl: None,
})
.await
}
}
}
#[instrument(skip(self), level = "debug", fields(service = %self.service_name))]
async fn lock(&self, key: &str, ttl: u64) -> Result<Option<String>> {
let state = self.health_state.read().await;
match *state {
HealthState::Healthy | HealthState::Recovering { .. } => {
drop(state);
match self.l2.lock(key, ttl).await {
Ok(result) => {
if result.is_some() {
GLOBAL_METRICS.record_request(&self.service_name, "L2", "lock", "hit");
} else {
GLOBAL_METRICS.record_request(&self.service_name, "L2", "lock", "miss");
}
Ok(result)
}
Err(e) => {
self.handle_l2_failure().await;
Err(e)
}
}
}
HealthState::Degraded { .. } => {
drop(state);
warn!(
"Cannot acquire lock in degraded state, service={}",
self.service_name
);
Ok(None)
}
HealthState::WalReplaying { .. } => {
drop(state);
warn!(
"Cannot acquire lock during WAL replay, service={}",
self.service_name
);
Ok(None)
}
}
}
#[instrument(skip(self), level = "debug", fields(service = %self.service_name))]
async fn unlock(&self, key: &str, value: &str) -> Result<bool> {
let state = self.health_state.read().await;
match *state {
HealthState::Healthy | HealthState::Recovering { .. } => {
drop(state);
match self.l2.unlock(key, value).await {
Ok(result) => {
if result {
GLOBAL_METRICS.record_request(
&self.service_name,
"L2",
"unlock",
"hit",
);
} else {
GLOBAL_METRICS.record_request(
&self.service_name,
"L2",
"unlock",
"miss",
);
}
Ok(result)
}
Err(e) => {
self.handle_l2_failure().await;
Err(e)
}
}
}
HealthState::Degraded { .. } => {
drop(state);
warn!(
"Cannot release lock in degraded state, service={}",
self.service_name
);
Ok(false)
}
HealthState::WalReplaying { .. } => {
drop(state);
warn!(
"Cannot release lock during WAL replay, service={}",
self.service_name
);
Ok(false)
}
}
}
#[instrument(skip(self), level = "debug", fields(service = %self.service_name))]
async fn clear_l2(&self) -> Result<()> {
self.l2.clear(&self.service_name).await?;
GLOBAL_METRICS.record_request(&self.service_name, "L2", "clear", "success");
Ok(())
}
}
#[async_trait]
impl TtlControl for L2Client {
    /// Always `Ok(None)`: this client has no L1 tier.
    async fn get_l1_ttl(&self, _key: &str) -> Result<Option<u64>> {
        Ok(None)
    }

    /// Returns the TTL the L2 backend reports for `key`.
    ///
    /// `None` from the backend is passed through; a reported value that is not
    /// greater than zero is returned as `Some(0)`.
    async fn get_l2_ttl(&self, key: &str) -> Result<Option<u64>> {
        let reported = self.l2.ttl(key).await?;
        Ok(reported.map(|ttl| if ttl > 0 { ttl } else { 0 }))
    }

    /// Same as [`TtlControl::get_l2_ttl`].
    async fn get_ttl(&self, key: &str) -> Result<Option<u64>> {
        self.get_l2_ttl(key).await
    }

    /// Always `Ok(false)`: this client has no L1 tier.
    async fn refresh_l1_ttl(&self, _key: &str, _ttl: u64) -> Result<bool> {
        Ok(false)
    }

    /// Sets the TTL of `key` in L2 to `ttl`.
    ///
    /// Returns `Ok(false)` without contacting L2 while the backend is
    /// `Degraded` or `WalReplaying`. Otherwise the backend's `expire` result is
    /// returned; on `true` an "expire"/"success" request is recorded and the
    /// current timestamp is remembered for the key.
    ///
    /// # Errors
    /// `CacheError::InvalidInput` when the wall clock is behind the timestamp
    /// remembered for this key, or any error from the backend's `expire`.
    async fn refresh_l2_ttl(&self, key: &str, ttl: u64) -> Result<bool> {
        if !self.ttl_writes_allowed().await {
            return Ok(false);
        }

        let now = chrono::Utc::now().timestamp();
        self.reject_clock_rollback(key, now)?;

        let applied = self.l2.expire(key, ttl).await?;
        if applied {
            GLOBAL_METRICS.record_request(&self.service_name, "L2", "expire", "success");
            self.timestamp_cache.insert(key.to_string(), now);
        }
        Ok(applied)
    }

    /// Same as [`TtlControl::refresh_l2_ttl`].
    async fn refresh_ttl(&self, key: &str, ttl: u64) -> Result<bool> {
        self.refresh_l2_ttl(key, ttl).await
    }

    /// Re-applies to `key` the TTL value currently reported by L2.
    ///
    /// Returns `Ok(false)` when the backend is `Degraded` or `WalReplaying`,
    /// when no TTL is reported for the key, or when the reported TTL is zero.
    /// Otherwise the backend's `expire` result is returned and, on `true`, the
    /// current timestamp is remembered for the key.
    ///
    /// # Errors
    /// `CacheError::InvalidInput` when the wall clock is behind the timestamp
    /// remembered for this key, or any error from the backend.
    async fn touch(&self, key: &str) -> Result<bool> {
        if !self.ttl_writes_allowed().await {
            return Ok(false);
        }

        let now = chrono::Utc::now().timestamp();
        self.reject_clock_rollback(key, now)?;

        let remaining = match self.get_l2_ttl(key).await? {
            Some(ttl) if ttl != 0 => ttl,
            _ => return Ok(false),
        };

        let applied = self.l2.expire(key, remaining).await?;
        if applied {
            self.timestamp_cache.insert(key.to_string(), now);
        }
        Ok(applied)
    }
}

/// Private helpers shared by the `TtlControl` methods.
impl L2Client {
    /// Reports whether TTL mutations may be sent to L2, i.e. whether the health
    /// state is `Healthy` or `Recovering`. The read lock is released before
    /// this returns.
    async fn ttl_writes_allowed(&self) -> bool {
        match *self.health_state.read().await {
            HealthState::Healthy | HealthState::Recovering { .. } => true,
            HealthState::Degraded { .. } | HealthState::WalReplaying { .. } => false,
        }
    }

    /// Fails with `InvalidInput` when `now` is earlier than the timestamp
    /// remembered for `key`; keys without a remembered timestamp always pass.
    fn reject_clock_rollback(&self, key: &str, now: i64) -> Result<()> {
        if let Some(remembered) = self.timestamp_cache.get(key) {
            let last_ts = *remembered;
            if now < last_ts {
                warn!(
                    "Detected time rollback attack for key '{}': current={}, last={}",
                    key, now, last_ts
                );
                return Err(crate::error::CacheError::InvalidInput(
                    "Time rollback detected".to_string(),
                ));
            }
        }
        Ok(())
    }
}
#[async_trait]
impl TieredCacheControl for L2Client {
async fn get_l1_direct(&self, _key: &str) -> Result<Option<Vec<u8>>> {
Ok(None)
}
async fn set_l1_direct(&self, _key: &str, _value: Vec<u8>, _ttl: Option<u64>) -> Result<()> {
Err(crate::error::CacheError::NotSupported(
"L1 operations not supported in L2-only mode".to_string(),
))
}
async fn delete_l1_direct(&self, _key: &str) -> Result<bool> {
Ok(false)
}
async fn get_l2_direct(&self, key: &str) -> Result<Option<Vec<u8>>> {
self.get_bytes(key).await
}
async fn set_l2_direct(&self, key: &str, value: Vec<u8>, ttl: Option<u64>) -> Result<()> {
self.set_bytes(key, value, ttl).await
}
async fn delete_l2_direct(&self, key: &str) -> Result<bool> {
self.delete(key).await?;
Ok(true)
}
async fn promote_to_l1(&self, _key: &str) -> Result<bool> {
Ok(false)
}
async fn demote_to_l2(&self, _key: &str, _ttl: Option<u64>) -> Result<bool> {
Ok(false)
}
async fn evict_all(&self, key: &str) -> Result<bool> {
self.delete(key).await?;
Ok(true)
}
}