// grapsus_proxy/memcached_rate_limit.rs
1//! Distributed rate limiting with Memcached backend
2//!
3//! This module provides a Memcached-backed rate limiter for multi-instance deployments.
//! Uses a counter-based fixed window algorithm.
5//!
6//! # Algorithm
7//!
8//! Uses a fixed window counter algorithm with Memcached:
9//! 1. Generate a time-windowed key (current second)
10//! 2. Increment the counter atomically
11//! 3. Allow if count <= max_rps
12//!
13//! Note: This is slightly less accurate than Redis sorted sets but more efficient
14//! for Memcached's simpler data model.
15
16use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
17use std::sync::Arc;
18use std::time::Duration;
19
20#[cfg(feature = "distributed-rate-limit-memcached")]
21use async_memcached::AsciiProtocol;
22use parking_lot::RwLock;
23use tracing::{debug, error, trace, warn};
24
25use grapsus_config::MemcachedBackendConfig;
26
27use crate::rate_limit::{RateLimitConfig, RateLimitOutcome};
28
/// Statistics for Memcached-based distributed rate limiting
///
/// All counters are monotonically increasing telemetry updated with
/// `Ordering::Relaxed`; they are not synchronization points.
#[derive(Debug, Default)]
pub struct MemcachedRateLimitStats {
    /// Total requests checked (allowed + limited)
    pub total_checks: AtomicU64,
    /// Requests allowed
    pub allowed: AtomicU64,
    /// Requests limited
    pub limited: AtomicU64,
    /// Memcached errors (fallback to local)
    pub memcached_errors: AtomicU64,
    /// Local fallback invocations
    pub local_fallbacks: AtomicU64,
}
43
44impl MemcachedRateLimitStats {
45    pub fn record_check(&self, outcome: RateLimitOutcome) {
46        self.total_checks.fetch_add(1, Ordering::Relaxed);
47        match outcome {
48            RateLimitOutcome::Allowed => {
49                self.allowed.fetch_add(1, Ordering::Relaxed);
50            }
51            RateLimitOutcome::Limited => {
52                self.limited.fetch_add(1, Ordering::Relaxed);
53            }
54        }
55    }
56
57    pub fn record_memcached_error(&self) {
58        self.memcached_errors.fetch_add(1, Ordering::Relaxed);
59    }
60
61    pub fn record_local_fallback(&self) {
62        self.local_fallbacks.fetch_add(1, Ordering::Relaxed);
63    }
64}
65
/// Memcached-backed distributed rate limiter
///
/// Shared across tasks via `&self`; the client sits behind a lock because
/// `async_memcached::Client` requires `&mut self` for each operation
/// (see the write-lock taken in `check`).
#[cfg(feature = "distributed-rate-limit-memcached")]
pub struct MemcachedRateLimiter {
    /// Memcached client (write-locked for every operation)
    client: RwLock<async_memcached::Client>,
    /// Configuration, hot-swappable via `update_config`
    config: RwLock<MemcachedConfig>,
    /// Whether Memcached is currently healthy; set on successful `check`,
    /// cleared by `mark_unhealthy`
    healthy: AtomicBool,
    /// Statistics
    pub stats: Arc<MemcachedRateLimitStats>,
}
78
/// Snapshot of limiter settings; cloned out of the lock on every check so
/// no config lock is held during Memcached I/O.
#[cfg(feature = "distributed-rate-limit-memcached")]
#[derive(Debug, Clone)]
struct MemcachedConfig {
    /// Prefix prepended to every counter key
    key_prefix: String,
    /// Maximum requests allowed per window
    max_rps: u32,
    /// Window length in seconds (set to 1 in `new`; not changed by
    /// `update_config`)
    window_secs: u64,
    /// Upper bound on a single Memcached round trip
    timeout: Duration,
    /// Whether to fall back to local limiting when Memcached fails
    fallback_local: bool,
    /// TTL applied when a window counter key is first created
    ttl_secs: u32,
}
89
#[cfg(feature = "distributed-rate-limit-memcached")]
impl MemcachedRateLimiter {
    /// Create a new Memcached rate limiter.
    ///
    /// Accepts `memcache://` or `memcached://` URLs; the scheme prefix is
    /// stripped and the remainder is handed to the client as `host:port`.
    ///
    /// # Errors
    ///
    /// Returns the underlying client error if the initial connection fails.
    pub async fn new(
        backend_config: &MemcachedBackendConfig,
        rate_config: &RateLimitConfig,
    ) -> Result<Self, async_memcached::Error> {
        // Strip the URL scheme; the client expects a bare address.
        let addr = backend_config
            .url
            .trim_start_matches("memcache://")
            .trim_start_matches("memcached://");

        let client = async_memcached::Client::new(addr).await?;

        debug!(
            url = %backend_config.url,
            prefix = %backend_config.key_prefix,
            max_rps = rate_config.max_rps,
            "Memcached rate limiter initialized"
        );

        Ok(Self {
            client: RwLock::new(client),
            config: RwLock::new(MemcachedConfig {
                key_prefix: backend_config.key_prefix.clone(),
                max_rps: rate_config.max_rps,
                // One-second fixed windows; `check` divides epoch seconds
                // by this value to derive the counter key.
                window_secs: 1,
                timeout: Duration::from_millis(backend_config.timeout_ms),
                fallback_local: backend_config.fallback_local,
                ttl_secs: backend_config.ttl_secs,
            }),
            healthy: AtomicBool::new(true),
            stats: Arc::new(MemcachedRateLimitStats::default()),
        })
    }

    /// Check if a request should be rate limited
    ///
    /// Returns the outcome and the current request count in the window.
    ///
    /// # Errors
    ///
    /// Returns the client error on Memcached failure, or an
    /// `Io`/`TimedOut` error when the operation exceeds the configured
    /// timeout. Callers decide whether to fall back (see
    /// `fallback_enabled` / `mark_unhealthy`).
    pub async fn check(
        &self,
        key: &str,
    ) -> Result<(RateLimitOutcome, u64), async_memcached::Error> {
        // Cheap clone so the config lock is released before any I/O.
        let config = self.config.read().clone();

        // Bucket epoch seconds into a window index so every request in the
        // same window shares one counter key. Fix: previously the raw
        // second was used and `window_secs` was dead; dividing by it
        // (guarded against zero) is byte-identical for the default
        // window_secs == 1 but honors larger windows.
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .expect("system clock is before UNIX epoch")
            .as_secs();
        let window = now / config.window_secs.max(1);
        let window_key = format!("{}{}:{}", config.key_prefix, key, window);

        // Increment counter atomically
        // The write guard must be held across awaits because async_memcached::Client
        // requires &mut self for operations and is not internally synchronized.
        // NOTE(review): a parking_lot guard held across .await can pin the
        // executor thread, and the blocking `write()` acquisition below is
        // not preemptible by the tokio timeout wrapping it — consider a
        // tokio::sync::Mutex for the client. Confirm before changing.
        #[allow(clippy::await_holding_lock)]
        let result = tokio::time::timeout(config.timeout, async {
            let mut client = self.client.write();
            // Try to increment; if key doesn't exist, it will return an error
            match client.increment(&window_key, 1).await {
                Ok(count) => Ok(count),
                Err(async_memcached::Error::Protocol(async_memcached::Status::NotFound)) => {
                    // Key doesn't exist, set it to 1 with TTL.
                    // NOTE(review): two concurrent first requests can both
                    // observe NotFound and both `set` to 1, undercounting by
                    // one; an `add` + retry-increment would close that race.
                    // Verify against the async_memcached API first.
                    client
                        .set(&window_key, &b"1"[..], Some(config.ttl_secs as i64), None)
                        .await
                        .map(|_| 1u64)
                }
                Err(e) => Err(e),
            }
        })
        .await
        .map_err(|_| {
            async_memcached::Error::Io(std::io::Error::new(
                std::io::ErrorKind::TimedOut,
                "Memcached operation timed out",
            ))
        })??;

        // Any successful round trip re-marks the backend healthy.
        self.healthy.store(true, Ordering::Relaxed);

        let outcome = if result > config.max_rps as u64 {
            RateLimitOutcome::Limited
        } else {
            RateLimitOutcome::Allowed
        };

        trace!(
            key = key,
            count = result,
            max_rps = config.max_rps,
            outcome = ?outcome,
            "Memcached rate limit check"
        );

        self.stats.record_check(outcome);
        Ok((outcome, result))
    }

    /// Update configuration
    ///
    /// Takes effect on the next `check`. Note: `window_secs` is
    /// intentionally left at its constructor value.
    pub fn update_config(
        &self,
        backend_config: &MemcachedBackendConfig,
        rate_config: &RateLimitConfig,
    ) {
        let mut config = self.config.write();
        config.key_prefix = backend_config.key_prefix.clone();
        config.max_rps = rate_config.max_rps;
        config.timeout = Duration::from_millis(backend_config.timeout_ms);
        config.fallback_local = backend_config.fallback_local;
        config.ttl_secs = backend_config.ttl_secs;
    }

    /// Check if Memcached is currently healthy
    pub fn is_healthy(&self) -> bool {
        self.healthy.load(Ordering::Relaxed)
    }

    /// Mark Memcached as unhealthy (will trigger fallback) and record the
    /// error in stats.
    pub fn mark_unhealthy(&self) {
        self.healthy.store(false, Ordering::Relaxed);
        self.stats.record_memcached_error();
    }

    /// Check if fallback to local is enabled
    pub fn fallback_enabled(&self) -> bool {
        self.config.read().fallback_local
    }
}
220
/// Stub for when distributed-rate-limit-memcached feature is disabled
///
/// Zero-sized placeholder so call sites can name the type regardless of
/// feature flags; its `new` always returns an error.
#[cfg(not(feature = "distributed-rate-limit-memcached"))]
pub struct MemcachedRateLimiter;
224
225#[cfg(not(feature = "distributed-rate-limit-memcached"))]
226impl MemcachedRateLimiter {
227    pub async fn new(
228        _backend_config: &MemcachedBackendConfig,
229        _rate_config: &RateLimitConfig,
230    ) -> Result<Self, String> {
231        Err(
232            "Memcached rate limiting requires the 'distributed-rate-limit-memcached' feature"
233                .to_string(),
234        )
235    }
236}
237
/// Create a Memcached rate limiter from configuration
///
/// On connection failure the error is logged and `None` is returned so
/// the caller can degrade to local rate limiting.
#[cfg(feature = "distributed-rate-limit-memcached")]
pub async fn create_memcached_rate_limiter(
    backend_config: &MemcachedBackendConfig,
    rate_config: &RateLimitConfig,
) -> Option<MemcachedRateLimiter> {
    let attempt = MemcachedRateLimiter::new(backend_config, rate_config).await;
    match attempt {
        Err(e) => {
            error!(
                error = %e,
                url = %backend_config.url,
                "Failed to create Memcached rate limiter"
            );
            if backend_config.fallback_local {
                warn!("Falling back to local rate limiting");
            }
            None
        }
        Ok(limiter) => {
            debug!(
                url = %backend_config.url,
                "Memcached rate limiter created successfully"
            );
            Some(limiter)
        }
    }
}
265
/// Create a Memcached rate limiter from configuration
///
/// Feature-disabled variant: logs a warning and returns `None` so callers
/// fall back to local rate limiting.
#[cfg(not(feature = "distributed-rate-limit-memcached"))]
pub async fn create_memcached_rate_limiter(
    _backend_config: &MemcachedBackendConfig,
    _rate_config: &RateLimitConfig,
) -> Option<MemcachedRateLimiter> {
    warn!("Memcached rate limiting requested but feature is disabled. Using local rate limiting.");
    None
}
274
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks are tallied into per-outcome buckets as well as the total.
    #[test]
    fn test_stats_recording() {
        let stats = MemcachedRateLimitStats::default();

        for outcome in [
            RateLimitOutcome::Allowed,
            RateLimitOutcome::Allowed,
            RateLimitOutcome::Limited,
        ] {
            stats.record_check(outcome);
        }

        assert_eq!(stats.total_checks.load(Ordering::Relaxed), 3);
        assert_eq!(stats.allowed.load(Ordering::Relaxed), 2);
        assert_eq!(stats.limited.load(Ordering::Relaxed), 1);
    }

    /// Backend errors and local fallbacks are counted independently.
    #[test]
    fn test_stats_memcached_errors() {
        let stats = MemcachedRateLimitStats::default();

        stats.record_memcached_error();
        stats.record_memcached_error();
        stats.record_local_fallback();

        assert_eq!(stats.memcached_errors.load(Ordering::Relaxed), 2);
        assert_eq!(stats.local_fallbacks.load(Ordering::Relaxed), 1);
    }
}
303}