use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
#[cfg(feature = "distributed-rate-limit-memcached")]
use async_memcached::AsciiProtocol;
use parking_lot::RwLock;
use tracing::{debug, error, trace, warn};
use zentinel_config::MemcachedBackendConfig;
use crate::rate_limit::{RateLimitConfig, RateLimitOutcome};
/// Monotonic counters describing Memcached rate-limiter activity.
///
/// All counters use `Ordering::Relaxed`: they are observability-only totals
/// with no cross-field consistency requirement.
#[derive(Debug, Default)]
pub struct MemcachedRateLimitStats {
    // Total number of rate-limit checks recorded (allowed + limited).
    pub total_checks: AtomicU64,
    // Checks that resulted in `RateLimitOutcome::Allowed`.
    pub allowed: AtomicU64,
    // Checks that resulted in `RateLimitOutcome::Limited`.
    pub limited: AtomicU64,
    // Backend failures recorded via `record_memcached_error`.
    pub memcached_errors: AtomicU64,
    // Times a caller reported falling back to local rate limiting.
    pub local_fallbacks: AtomicU64,
}
impl MemcachedRateLimitStats {
    /// Record one completed rate-limit check and bump the counter matching
    /// its outcome.
    pub fn record_check(&self, outcome: RateLimitOutcome) {
        self.total_checks.fetch_add(1, Ordering::Relaxed);
        // Pick the per-outcome counter, then increment it once.
        let bucket = match outcome {
            RateLimitOutcome::Allowed => &self.allowed,
            RateLimitOutcome::Limited => &self.limited,
        };
        bucket.fetch_add(1, Ordering::Relaxed);
    }

    /// Record a Memcached backend failure.
    pub fn record_memcached_error(&self) {
        self.memcached_errors.fetch_add(1, Ordering::Relaxed);
    }

    /// Record that a caller degraded to local rate limiting.
    pub fn record_local_fallback(&self) {
        self.local_fallbacks.fetch_add(1, Ordering::Relaxed);
    }
}
/// Distributed rate limiter backed by a shared Memcached counter per
/// `(key, window)` pair. Only compiled with the
/// `distributed-rate-limit-memcached` feature.
#[cfg(feature = "distributed-rate-limit-memcached")]
pub struct MemcachedRateLimiter {
    // Single client connection, write-locked for each backend round trip.
    // NOTE(review): `check()` holds this parking_lot guard across `.await`;
    // confirm that is acceptable for the runtime in use.
    client: RwLock<async_memcached::Client>,
    // Hot-swappable settings; mutated in place by `update_config()`.
    config: RwLock<MemcachedConfig>,
    // Last-known backend health: set on successful checks, cleared by
    // `mark_unhealthy()`.
    healthy: AtomicBool,
    // Shared activity counters; `Arc` so callers can keep a handle.
    pub stats: Arc<MemcachedRateLimitStats>,
}
/// Snapshot of tunable limiter settings. `check()` clones this out of the
/// lock so the read guard is released before any network I/O.
#[cfg(feature = "distributed-rate-limit-memcached")]
#[derive(Debug, Clone)]
struct MemcachedConfig {
    // Prepended to every cache key to namespace this limiter's counters.
    key_prefix: String,
    // Requests allowed per window before the outcome flips to Limited.
    max_rps: u32,
    // Window length in seconds (set to 1 by the constructor).
    window_secs: u64,
    // Per-operation deadline for Memcached round trips.
    timeout: Duration,
    // Whether callers should fall back to local limiting on backend failure.
    fallback_local: bool,
    // TTL applied when a new window counter is first created.
    ttl_secs: u32,
}
#[cfg(feature = "distributed-rate-limit-memcached")]
impl MemcachedRateLimiter {
    /// Connect to Memcached and build a limiter from the given configs.
    ///
    /// Accepts `memcache://` / `memcached://` URLs as well as bare
    /// `host:port` addresses.
    ///
    /// # Errors
    /// Propagates any `async_memcached` connection error.
    pub async fn new(
        backend_config: &MemcachedBackendConfig,
        rate_config: &RateLimitConfig,
    ) -> Result<Self, async_memcached::Error> {
        // async_memcached expects a bare address, so strip known URL schemes.
        let addr = backend_config
            .url
            .trim_start_matches("memcache://")
            .trim_start_matches("memcached://");
        let client = async_memcached::Client::new(addr).await?;
        debug!(
            url = %backend_config.url,
            prefix = %backend_config.key_prefix,
            max_rps = rate_config.max_rps,
            "Memcached rate limiter initialized"
        );
        Ok(Self {
            client: RwLock::new(client),
            config: RwLock::new(MemcachedConfig {
                key_prefix: backend_config.key_prefix.clone(),
                max_rps: rate_config.max_rps,
                // 1-second fixed window keeps `max_rps` semantics literal.
                window_secs: 1,
                timeout: Duration::from_millis(backend_config.timeout_ms),
                fallback_local: backend_config.fallback_local,
                ttl_secs: backend_config.ttl_secs,
            }),
            healthy: AtomicBool::new(true),
            stats: Arc::new(MemcachedRateLimitStats::default()),
        })
    }

    /// Perform one fixed-window check for `key`.
    ///
    /// Increments the shared counter for the current window, seeding it with
    /// a TTL when absent, then compares the count against `max_rps`.
    ///
    /// Returns the outcome plus the post-increment count for the window.
    ///
    /// # Errors
    /// Returns the underlying Memcached error, or an `Io`/`TimedOut` error
    /// when the operation exceeds the configured timeout.
    pub async fn check(
        &self,
        key: &str,
    ) -> Result<(RateLimitOutcome, u64), async_memcached::Error> {
        // Clone the settings so the config read lock is dropped before I/O.
        let config = self.config.read().clone();
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .expect("system clock is before the UNIX epoch")
            .as_secs();
        // FIX: bucket by window number so `window_secs` actually determines
        // the window length; the old code used the raw epoch second, silently
        // pinning the window to 1s regardless of the setting. With the
        // constructor's window_secs == 1 the generated keys are unchanged.
        // max(1) guards against a zero divisor if the field is ever exposed.
        let window = now / config.window_secs.max(1);
        let window_key = format!("{}{}:{}", config.key_prefix, key, window);
        // The parking_lot write guard below is held across `.await`, which
        // serializes all checks through the single client connection.
        // NOTE(review): a sync guard held across an await point can stall
        // other tasks on the same worker thread — consider a tokio Mutex or
        // a connection pool if contention shows up. Suppression kept
        // deliberately until then.
        #[allow(clippy::await_holding_lock)]
        let result = tokio::time::timeout(config.timeout, async {
            let mut client = self.client.write();
            match client.increment(&window_key, 1).await {
                Ok(count) => Ok(count),
                // First request in this window: create the counter with a TTL
                // so stale windows expire on their own.
                Err(async_memcached::Error::Protocol(async_memcached::Status::NotFound)) => {
                    client
                        .set(&window_key, &b"1"[..], Some(config.ttl_secs as i64), None)
                        .await
                        .map(|_| 1u64)
                }
                Err(e) => Err(e),
            }
        })
        .await
        .map_err(|_| {
            async_memcached::Error::Io(std::io::Error::new(
                std::io::ErrorKind::TimedOut,
                "Memcached operation timed out",
            ))
        })??;
        // Any successful round trip proves the backend is reachable again.
        self.healthy.store(true, Ordering::Relaxed);
        let outcome = if result > config.max_rps as u64 {
            RateLimitOutcome::Limited
        } else {
            RateLimitOutcome::Allowed
        };
        trace!(
            key = key,
            count = result,
            max_rps = config.max_rps,
            outcome = ?outcome,
            "Memcached rate limit check"
        );
        self.stats.record_check(outcome);
        Ok((outcome, result))
    }

    /// Replace the tunable settings in place; takes effect on the next
    /// `check()`. `window_secs` is intentionally left as set at creation.
    pub fn update_config(
        &self,
        backend_config: &MemcachedBackendConfig,
        rate_config: &RateLimitConfig,
    ) {
        let mut config = self.config.write();
        config.key_prefix = backend_config.key_prefix.clone();
        config.max_rps = rate_config.max_rps;
        config.timeout = Duration::from_millis(backend_config.timeout_ms);
        config.fallback_local = backend_config.fallback_local;
        config.ttl_secs = backend_config.ttl_secs;
    }

    /// Last-known backend health: set by successful checks, cleared by
    /// `mark_unhealthy`.
    pub fn is_healthy(&self) -> bool {
        self.healthy.load(Ordering::Relaxed)
    }

    /// Flag the backend as down and count the error in the stats.
    pub fn mark_unhealthy(&self) {
        self.healthy.store(false, Ordering::Relaxed);
        self.stats.record_memcached_error();
    }

    /// Whether callers should degrade to local rate limiting when the
    /// backend is unavailable.
    pub fn fallback_enabled(&self) -> bool {
        self.config.read().fallback_local
    }
}
/// Zero-sized stand-in used when the feature is compiled out, so callers can
/// name the type unconditionally.
#[cfg(not(feature = "distributed-rate-limit-memcached"))]
pub struct MemcachedRateLimiter;

#[cfg(not(feature = "distributed-rate-limit-memcached"))]
impl MemcachedRateLimiter {
    /// Stub constructor: always fails, pointing at the missing cargo feature.
    pub async fn new(
        _backend_config: &MemcachedBackendConfig,
        _rate_config: &RateLimitConfig,
    ) -> Result<Self, String> {
        let message =
            "Memcached rate limiting requires the 'distributed-rate-limit-memcached' feature";
        Err(String::from(message))
    }
}
/// Build a `MemcachedRateLimiter`, logging success or failure.
///
/// Returns `None` when construction fails; the caller decides whether to
/// degrade to local rate limiting (a warning is logged if the config asks
/// for fallback).
#[cfg(feature = "distributed-rate-limit-memcached")]
pub async fn create_memcached_rate_limiter(
    backend_config: &MemcachedBackendConfig,
    rate_config: &RateLimitConfig,
) -> Option<MemcachedRateLimiter> {
    let limiter = match MemcachedRateLimiter::new(backend_config, rate_config).await {
        Ok(limiter) => limiter,
        Err(e) => {
            error!(
                error = %e,
                url = %backend_config.url,
                "Failed to create Memcached rate limiter"
            );
            if backend_config.fallback_local {
                warn!("Falling back to local rate limiting");
            }
            return None;
        }
    };
    debug!(
        url = %backend_config.url,
        "Memcached rate limiter created successfully"
    );
    Some(limiter)
}
/// Feature-disabled stub: warns that distributed limiting is unavailable and
/// yields `None` so the caller uses local rate limiting instead.
#[cfg(not(feature = "distributed-rate-limit-memcached"))]
pub async fn create_memcached_rate_limiter(
    _backend_config: &MemcachedBackendConfig,
    _rate_config: &RateLimitConfig,
) -> Option<MemcachedRateLimiter> {
    warn!("Memcached rate limiting requested but feature is disabled. Using local rate limiting.");
    None
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Outcomes are tallied into both the total and the per-outcome counter.
    #[test]
    fn test_stats_recording() {
        let stats = MemcachedRateLimitStats::default();
        let outcomes = [
            RateLimitOutcome::Allowed,
            RateLimitOutcome::Allowed,
            RateLimitOutcome::Limited,
        ];
        for outcome in outcomes {
            stats.record_check(outcome);
        }
        assert_eq!(stats.total_checks.load(Ordering::Relaxed), 3);
        assert_eq!(stats.allowed.load(Ordering::Relaxed), 2);
        assert_eq!(stats.limited.load(Ordering::Relaxed), 1);
    }

    /// Error and fallback counters are independent of the check counters.
    #[test]
    fn test_stats_memcached_errors() {
        let stats = MemcachedRateLimitStats::default();
        for _ in 0..2 {
            stats.record_memcached_error();
        }
        stats.record_local_fallback();
        assert_eq!(stats.memcached_errors.load(Ordering::Relaxed), 2);
        assert_eq!(stats.local_fallbacks.load(Ordering::Relaxed), 1);
    }
}