celers_core/
rate_limit.rs

// The rate arithmetic in this module converts between `f64` rates and the
// integer types used for burst sizes, window sizes and permit counts (for
// example `rate.ceil() as u32`, `tokens.floor() as u32`, `... as usize`,
// `window_size as f64`). Those casts are intentional, so the pedantic
// numeric-cast lints are silenced module-wide.
#![allow(
    clippy::cast_possible_truncation,
    clippy::cast_sign_loss,
    clippy::cast_precision_loss,
    clippy::cast_possible_wrap
)]
7//! Rate Limiting for Task Execution
8//!
9//! This module provides rate limiting capabilities for controlling task execution rates.
10//! It supports multiple algorithms:
11//!
12//! - **Token Bucket**: Classic algorithm that allows bursts up to bucket capacity
13//! - **Sliding Window**: Tracks actual execution counts within a time window
14//!
15//! # Example
16//!
17//! ```rust
18//! use celers_core::rate_limit::{RateLimiter, TokenBucket, RateLimitConfig};
19//! use std::time::Duration;
20//!
21//! // Create a rate limiter allowing 10 tasks per second with burst capacity of 20
22//! let config = RateLimitConfig::new(10.0).with_burst(20);
23//! let mut limiter = TokenBucket::new(config);
24//!
25//! // Check if we can execute a task
26//! if limiter.try_acquire() {
27//!     println!("Task can execute");
28//! } else {
29//!     println!("Rate limited, wait {:?}", limiter.time_until_available());
30//! }
31//! ```
32
33use serde::{Deserialize, Serialize};
34use std::collections::HashMap;
35use std::sync::{Arc, RwLock};
36use std::time::{Duration, Instant};
37
/// Configuration for rate limiting
///
/// Shared by all rate limiters in this module. Build one with
/// [`RateLimitConfig::new`] plus the `with_*` builders, or use [`Default`]
/// (100 tasks per second, token bucket, 1-second window).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitConfig {
    /// Maximum tasks per second
    pub rate: f64,
    /// Burst capacity (max tokens in bucket)
    ///
    /// If `None`, [`RateLimitConfig::effective_burst`] uses `rate` rounded up
    /// to the next whole number.
    pub burst: Option<u32>,
    /// Whether to use sliding window algorithm instead of token bucket
    ///
    /// Consulted by [`create_rate_limiter`]; [`TaskRateLimiter`] always uses
    /// a token bucket regardless of this flag.
    pub sliding_window: bool,
    /// Window size for sliding window algorithm (in seconds); defaults to 1
    pub window_size: u64,
}
51
52impl RateLimitConfig {
53    /// Create a new rate limit configuration
54    ///
55    /// # Arguments
56    ///
57    /// * `rate` - Maximum tasks per second
58    #[must_use]
59    pub fn new(rate: f64) -> Self {
60        Self {
61            rate,
62            burst: None,
63            sliding_window: false,
64            window_size: 1,
65        }
66    }
67
68    /// Set the burst capacity (max tokens in bucket)
69    #[must_use]
70    pub fn with_burst(mut self, burst: u32) -> Self {
71        self.burst = Some(burst);
72        self
73    }
74
75    /// Use sliding window algorithm instead of token bucket
76    #[must_use]
77    pub fn with_sliding_window(mut self, window_size: u64) -> Self {
78        self.sliding_window = true;
79        self.window_size = window_size;
80        self
81    }
82
83    /// Get the effective burst capacity
84    #[must_use]
85    #[inline]
86    pub fn effective_burst(&self) -> u32 {
87        self.burst.unwrap_or(self.rate.ceil() as u32)
88    }
89}
90
91impl Default for RateLimitConfig {
92    fn default() -> Self {
93        Self {
94            rate: 100.0, // 100 tasks per second default
95            burst: None,
96            sliding_window: false,
97            window_size: 1,
98        }
99    }
100}
101
/// Trait for rate limiter implementations
///
/// Implemented in this module by [`TokenBucket`] and [`SlidingWindow`];
/// [`create_rate_limiter`] picks one based on a [`RateLimitConfig`].
/// State-changing methods take `&mut self`, so sharing one limiter between
/// threads requires external synchronization (a lock, for example).
pub trait RateLimiter: Send + Sync {
    /// Try to acquire a permit to execute a task
    ///
    /// Returns `true` if the task can execute immediately, `false` if rate limited
    fn try_acquire(&mut self) -> bool;

    /// Acquire a permit, blocking until available
    ///
    /// Returns the time waited
    fn acquire(&mut self) -> Duration;

    /// Get the time until a permit will be available
    ///
    /// `Duration::ZERO` means a permit can be acquired now.
    fn time_until_available(&self) -> Duration;

    /// Get the current number of available permits
    fn available_permits(&self) -> u32;

    /// Reset the rate limiter to its initial state
    fn reset(&mut self);

    /// Update the rate limit configuration
    ///
    /// Only the `rate` field is changed.
    fn set_rate(&mut self, rate: f64);

    /// Get current configuration
    fn config(&self) -> &RateLimitConfig;
}
129
/// Token bucket rate limiter
///
/// The token bucket algorithm works by:
/// - Adding tokens at a fixed rate (rate per second), capped at
///   [`RateLimitConfig::effective_burst`]
/// - Consuming one token per task execution
/// - Allowing bursts up to the bucket capacity
///
/// A newly created bucket starts full. Refills are computed lazily from the
/// time elapsed since the last refill, and tokens are kept as `f64` so
/// fractional refills accumulate between calls.
///
/// This is the default and recommended rate limiter for most use cases.
#[derive(Debug)]
pub struct TokenBucket {
    /// Rate and burst settings
    config: RateLimitConfig,
    /// Current number of tokens in the bucket (may be fractional)
    tokens: f64,
    /// Last time tokens were refilled
    last_refill: Instant,
}
146
147impl TokenBucket {
148    /// Create a new token bucket rate limiter
149    #[must_use]
150    pub fn new(config: RateLimitConfig) -> Self {
151        let tokens = f64::from(config.effective_burst());
152        Self {
153            config,
154            tokens,
155            last_refill: Instant::now(),
156        }
157    }
158
159    /// Create a token bucket with default configuration
160    #[must_use]
161    pub fn with_rate(rate: f64) -> Self {
162        Self::new(RateLimitConfig::new(rate))
163    }
164
165    /// Refill tokens based on elapsed time
166    #[inline]
167    fn refill(&mut self) {
168        let now = Instant::now();
169        let elapsed = now.duration_since(self.last_refill);
170        let new_tokens = elapsed.as_secs_f64() * self.config.rate;
171        let max_tokens = f64::from(self.config.effective_burst());
172        self.tokens = (self.tokens + new_tokens).min(max_tokens);
173        self.last_refill = now;
174    }
175}
176
177impl RateLimiter for TokenBucket {
178    fn try_acquire(&mut self) -> bool {
179        self.refill();
180        if self.tokens >= 1.0 {
181            self.tokens -= 1.0;
182            true
183        } else {
184            false
185        }
186    }
187
188    fn acquire(&mut self) -> Duration {
189        let start = Instant::now();
190        while !self.try_acquire() {
191            let wait_time = self.time_until_available();
192            if wait_time > Duration::ZERO {
193                std::thread::sleep(wait_time);
194            }
195        }
196        start.elapsed()
197    }
198
199    fn time_until_available(&self) -> Duration {
200        if self.tokens >= 1.0 {
201            Duration::ZERO
202        } else {
203            let tokens_needed = 1.0 - self.tokens;
204            let seconds = tokens_needed / self.config.rate;
205            Duration::from_secs_f64(seconds)
206        }
207    }
208
209    fn available_permits(&self) -> u32 {
210        self.tokens.floor() as u32
211    }
212
213    fn reset(&mut self) {
214        self.tokens = f64::from(self.config.effective_burst());
215        self.last_refill = Instant::now();
216    }
217
218    fn set_rate(&mut self, rate: f64) {
219        self.config.rate = rate;
220    }
221
222    fn config(&self) -> &RateLimitConfig {
223        &self.config
224    }
225}
226
/// Sliding window rate limiter
///
/// Tracks actual execution timestamps and counts executions within a sliding window.
/// More accurate than token bucket but uses more memory: one `Instant` is
/// stored per execution admitted within the current window.
#[derive(Debug)]
pub struct SlidingWindow {
    /// Rate and window settings
    config: RateLimitConfig,
    /// Timestamps of recent executions, in the order they were admitted.
    /// Each is `Instant::now()` at admission, so the list is non-decreasing.
    timestamps: Vec<Instant>,
}
237
238impl SlidingWindow {
239    /// Create a new sliding window rate limiter
240    #[must_use]
241    pub fn new(config: RateLimitConfig) -> Self {
242        Self {
243            config,
244            timestamps: Vec::new(),
245        }
246    }
247
248    /// Create a sliding window limiter with default configuration
249    #[must_use]
250    pub fn with_rate(rate: f64, window_size: u64) -> Self {
251        let config = RateLimitConfig::new(rate).with_sliding_window(window_size);
252        Self::new(config)
253    }
254
255    /// Clean up old timestamps outside the window
256    #[inline]
257    fn cleanup(&mut self) {
258        let window = Duration::from_secs(self.config.window_size);
259        let cutoff = Instant::now()
260            .checked_sub(window)
261            .expect("window duration should be valid for subtraction");
262        self.timestamps.retain(|&t| t > cutoff);
263    }
264
265    /// Get the maximum allowed executions in the window
266    #[inline]
267    fn max_executions(&self) -> usize {
268        (self.config.rate * self.config.window_size as f64).ceil() as usize
269    }
270}
271
272impl RateLimiter for SlidingWindow {
273    fn try_acquire(&mut self) -> bool {
274        self.cleanup();
275        if self.timestamps.len() < self.max_executions() {
276            self.timestamps.push(Instant::now());
277            true
278        } else {
279            false
280        }
281    }
282
283    fn acquire(&mut self) -> Duration {
284        let start = Instant::now();
285        while !self.try_acquire() {
286            let wait_time = self.time_until_available();
287            if wait_time > Duration::ZERO {
288                std::thread::sleep(wait_time);
289            }
290        }
291        start.elapsed()
292    }
293
294    fn time_until_available(&self) -> Duration {
295        if self.timestamps.len() < self.max_executions() {
296            Duration::ZERO
297        } else if let Some(&oldest) = self.timestamps.first() {
298            let window = Duration::from_secs(self.config.window_size);
299            let expires = oldest + window;
300            let now = Instant::now();
301            if expires > now {
302                expires - now
303            } else {
304                Duration::ZERO
305            }
306        } else {
307            Duration::ZERO
308        }
309    }
310
311    fn available_permits(&self) -> u32 {
312        let max = self.max_executions();
313        let current = self.timestamps.len();
314        (max.saturating_sub(current)) as u32
315    }
316
317    fn reset(&mut self) {
318        self.timestamps.clear();
319    }
320
321    fn set_rate(&mut self, rate: f64) {
322        self.config.rate = rate;
323    }
324
325    fn config(&self) -> &RateLimitConfig {
326        &self.config
327    }
328}
329
/// Per-task rate limiter manager
///
/// Manages rate limiters for multiple task types, allowing different
/// rate limits per task name. Every task is limited with a [`TokenBucket`];
/// the `sliding_window` flag of a [`RateLimitConfig`] is not consulted here.
///
/// Mutating methods take `&mut self`; [`WorkerRateLimiter`] wraps this type
/// in an `Arc<RwLock<_>>` for shared use.
#[derive(Debug)]
pub struct TaskRateLimiter {
    /// Per-task rate limiters (`task_name` -> limiter). Tasks covered only by
    /// the default config get an entry lazily on their first `try_acquire`.
    limiters: HashMap<String, TokenBucket>,
    /// Default rate limit for tasks without specific configuration;
    /// `None` means such tasks are not rate limited at all.
    default_config: Option<RateLimitConfig>,
}
341
342impl TaskRateLimiter {
343    /// Create a new task rate limiter manager
344    #[must_use]
345    pub fn new() -> Self {
346        Self {
347            limiters: HashMap::new(),
348            default_config: None,
349        }
350    }
351
352    /// Create with a default rate limit for all tasks
353    #[must_use]
354    pub fn with_default(config: RateLimitConfig) -> Self {
355        Self {
356            limiters: HashMap::new(),
357            default_config: Some(config),
358        }
359    }
360
361    /// Set rate limit for a specific task type
362    pub fn set_task_rate(&mut self, task_name: impl Into<String>, config: RateLimitConfig) {
363        let name = task_name.into();
364        self.limiters.insert(name, TokenBucket::new(config));
365    }
366
367    /// Remove rate limit for a specific task type
368    pub fn remove_task_rate(&mut self, task_name: &str) {
369        self.limiters.remove(task_name);
370    }
371
372    /// Try to acquire a permit for a specific task
373    ///
374    /// Returns `true` if the task can execute, `false` if rate limited
375    pub fn try_acquire(&mut self, task_name: &str) -> bool {
376        if let Some(limiter) = self.limiters.get_mut(task_name) {
377            limiter.try_acquire()
378        } else if let Some(ref config) = self.default_config {
379            // Create a limiter for this task using default config
380            let mut limiter = TokenBucket::new(config.clone());
381            let result = limiter.try_acquire();
382            self.limiters.insert(task_name.to_string(), limiter);
383            result
384        } else {
385            // No rate limit configured
386            true
387        }
388    }
389
390    /// Get time until a task can be executed
391    #[must_use]
392    pub fn time_until_available(&self, task_name: &str) -> Duration {
393        if let Some(limiter) = self.limiters.get(task_name) {
394            limiter.time_until_available()
395        } else {
396            Duration::ZERO
397        }
398    }
399
400    /// Check if a task type has a rate limit configured
401    #[inline]
402    #[must_use]
403    pub fn has_rate_limit(&self, task_name: &str) -> bool {
404        self.limiters.contains_key(task_name) || self.default_config.is_some()
405    }
406
407    /// Get the rate limit configuration for a task
408    #[inline]
409    pub fn get_rate_limit(&self, task_name: &str) -> Option<&RateLimitConfig> {
410        self.limiters
411            .get(task_name)
412            .map(RateLimiter::config)
413            .or(self.default_config.as_ref())
414    }
415
416    /// Reset all rate limiters
417    pub fn reset_all(&mut self) {
418        for limiter in self.limiters.values_mut() {
419            limiter.reset();
420        }
421    }
422}
423
424impl Default for TaskRateLimiter {
425    fn default() -> Self {
426        Self::new()
427    }
428}
429
/// Thread-safe per-worker rate limiter
///
/// Wraps a rate limiter for safe concurrent access from multiple worker threads.
/// Cloning is cheap, and clones share the same limiter state through the
/// inner `Arc`.
#[derive(Debug, Clone)]
pub struct WorkerRateLimiter {
    /// Shared per-task limiter manager; writers (acquisition, configuration)
    /// take the write lock, queries take the read lock.
    inner: Arc<RwLock<TaskRateLimiter>>,
}
437
438impl WorkerRateLimiter {
439    /// Create a new worker rate limiter
440    #[must_use]
441    pub fn new() -> Self {
442        Self {
443            inner: Arc::new(RwLock::new(TaskRateLimiter::new())),
444        }
445    }
446
447    /// Create with a default rate limit
448    #[must_use]
449    pub fn with_default(config: RateLimitConfig) -> Self {
450        Self {
451            inner: Arc::new(RwLock::new(TaskRateLimiter::with_default(config))),
452        }
453    }
454
455    /// Set rate limit for a specific task type
456    pub fn set_task_rate(&self, task_name: impl Into<String>, config: RateLimitConfig) {
457        if let Ok(mut guard) = self.inner.write() {
458            guard.set_task_rate(task_name, config);
459        }
460    }
461
462    /// Remove rate limit for a specific task type
463    pub fn remove_task_rate(&self, task_name: &str) {
464        if let Ok(mut guard) = self.inner.write() {
465            guard.remove_task_rate(task_name);
466        }
467    }
468
469    /// Try to acquire a permit for a specific task
470    #[must_use]
471    pub fn try_acquire(&self, task_name: &str) -> bool {
472        if let Ok(mut guard) = self.inner.write() {
473            guard.try_acquire(task_name)
474        } else {
475            // If lock is poisoned, allow execution
476            true
477        }
478    }
479
480    /// Get time until a task can be executed
481    #[must_use]
482    pub fn time_until_available(&self, task_name: &str) -> Duration {
483        if let Ok(guard) = self.inner.read() {
484            guard.time_until_available(task_name)
485        } else {
486            Duration::ZERO
487        }
488    }
489
490    /// Check if a task type has a rate limit configured
491    #[inline]
492    #[must_use]
493    pub fn has_rate_limit(&self, task_name: &str) -> bool {
494        if let Ok(guard) = self.inner.read() {
495            guard.has_rate_limit(task_name)
496        } else {
497            false
498        }
499    }
500
501    /// Reset all rate limiters
502    pub fn reset_all(&self) {
503        if let Ok(mut guard) = self.inner.write() {
504            guard.reset_all();
505        }
506    }
507}
508
509impl Default for WorkerRateLimiter {
510    fn default() -> Self {
511        Self::new()
512    }
513}
514
515/// Create a rate limiter from configuration
516#[must_use]
517pub fn create_rate_limiter(config: RateLimitConfig) -> Box<dyn RateLimiter> {
518    if config.sliding_window {
519        Box::new(SlidingWindow::new(config))
520    } else {
521        Box::new(TokenBucket::new(config))
522    }
523}
524
// Distributed rate limiting coordination
//
// The items below support distributed rate limiting across multiple workers,
// so that rate limits can be enforced cluster-wide rather than per-worker.
//
// Features:
//
// - Cluster-wide rate limiting: coordinate rate limits across all workers
// - Redis backend: Lua scripts for storing distributed state in Redis
// - Token bucket algorithm: distributed token bucket with atomic operations
// - Sliding window algorithm: distributed sliding window using sorted sets
// - Fallback support: graceful degradation to local rate limiting on failure
// - TTL support: automatic cleanup of stale data
//
// Example (illustrative only: the `DistributedRateLimiter` trait below
// defines no `redis` constructor; a backend implementation must supply one):
//
// ```rust,ignore
// use celers_core::rate_limit::{DistributedRateLimiter, RateLimitConfig};
//
// // Create a distributed rate limiter backed by Redis
// let config = RateLimitConfig::new(100.0).with_burst(200);
// let limiter = DistributedRateLimiter::redis(
//     "redis://localhost:6379",
//     "my_task",
//     config,
// ).await?;
//
// // Try to acquire a permit across all workers
// if limiter.try_acquire().await? {
//     println!("Task can execute");
// } else {
//     println!("Rate limited cluster-wide");
// }
// ```
559use async_trait::async_trait;
560
/// Trait for distributed rate limiter backends
///
/// Implementations should provide atomic operations for rate limiting
/// across multiple processes/workers. All methods take `&self`, so an
/// implementation that supports `set_rate` needs interior mutability for its
/// configuration.
#[async_trait]
pub trait DistributedRateLimiter: Send + Sync {
    /// Try to acquire a permit atomically
    ///
    /// Returns `Ok(true)` if acquired, `Ok(false)` if rate limited,
    /// or an error if the backend is unavailable.
    async fn try_acquire(&self) -> crate::Result<bool>;

    /// Get the time until a permit will be available
    ///
    /// Returns `Ok(Duration)` with the wait time, or an error if unavailable.
    async fn time_until_available(&self) -> crate::Result<Duration>;

    /// Get the current number of available permits
    ///
    /// Returns `Ok(count)` or an error if unavailable.
    async fn available_permits(&self) -> crate::Result<u32>;

    /// Reset the distributed rate limiter
    ///
    /// Clears all state in the distributed backend.
    async fn reset(&self) -> crate::Result<()>;

    /// Update the rate limit configuration
    ///
    /// Note: This updates the local configuration. Distributed backends
    /// may need additional coordination to sync configuration changes.
    async fn set_rate(&self, rate: f64) -> crate::Result<()>;

    /// Get current configuration
    fn config(&self) -> &RateLimitConfig;

    /// Get the backend name (for diagnostics)
    fn backend_name(&self) -> &str;
}
600
/// Distributed rate limiter state
///
/// Stores rate limiting state in a distributed backend (e.g., Redis)
/// for coordination across multiple workers.
///
/// Cloning shares the local `fallback` bucket (it sits behind an `Arc`), while
/// `key` and `config` are copied.
#[derive(Debug, Clone)]
pub struct DistributedRateLimiterState {
    /// Base Redis key; the per-purpose keys (`{key}:tokens`, `{key}:refill`,
    /// `{key}:window`) are derived from it.
    pub key: String,
    /// Rate limit configuration
    pub config: RateLimitConfig,
    /// Local fallback limiter (used if distributed backend is unavailable);
    /// a token bucket built from the same `config`.
    pub fallback: Arc<RwLock<TokenBucket>>,
}
614
615impl DistributedRateLimiterState {
616    /// Create a new distributed rate limiter state
617    ///
618    /// # Arguments
619    ///
620    /// * `key` - Redis key for storing rate limit state
621    /// * `config` - Rate limit configuration
622    #[must_use]
623    pub fn new(key: String, config: RateLimitConfig) -> Self {
624        let fallback = Arc::new(RwLock::new(TokenBucket::new(config.clone())));
625        Self {
626            key,
627            config,
628            fallback,
629        }
630    }
631
632    /// Get the Redis key for token count
633    ///
634    /// Used by backend implementations to store token count.
635    #[inline]
636    #[must_use]
637    pub fn token_key(&self) -> String {
638        format!("{}:tokens", self.key)
639    }
640
641    /// Get the Redis key for last refill timestamp
642    ///
643    /// Used by backend implementations to store last refill time.
644    #[inline]
645    #[must_use]
646    pub fn refill_key(&self) -> String {
647        format!("{}:refill", self.key)
648    }
649
650    /// Get the Redis key for sliding window
651    ///
652    /// Used by backend implementations to store sliding window data.
653    #[inline]
654    #[must_use]
655    pub fn window_key(&self) -> String {
656        format!("{}:window", self.key)
657    }
658
659    /// Try to acquire using local fallback
660    fn try_acquire_fallback(&self) -> bool {
661        if let Ok(mut guard) = self.fallback.write() {
662            guard.try_acquire()
663        } else {
664            // If lock is poisoned, allow execution
665            true
666        }
667    }
668}
669
670/// Distributed token bucket implementation
671///
672/// Uses atomic operations in a distributed backend (e.g., Redis Lua scripts)
673/// to implement token bucket algorithm across multiple workers.
674///
675/// # Redis Implementation
676///
677/// The token bucket is implemented using Redis with two keys:
678/// - `{key}:tokens` - Current token count (float)
679/// - `{key}:refill` - Last refill timestamp (integer, milliseconds since epoch)
680///
681/// A Lua script performs atomic token refill and acquisition:
682/// 1. Calculate elapsed time since last refill
683/// 2. Add tokens based on elapsed time and rate
684/// 3. Cap tokens at burst capacity
685/// 4. Attempt to consume 1 token
686/// 5. Update last refill timestamp
687///
688/// # Example Lua Script
689///
690/// ```lua
691/// local tokens_key = KEYS[1]
692/// local refill_key = KEYS[2]
693/// local rate = tonumber(ARGV[1])
694/// local burst = tonumber(ARGV[2])
695/// local now = tonumber(ARGV[3])
696///
697/// local last_refill = redis.call('GET', refill_key)
698/// local tokens = redis.call('GET', tokens_key)
699///
700/// if not tokens then
701///     tokens = burst
702/// else
703///     tokens = tonumber(tokens)
704/// end
705///
706/// if last_refill then
707///     local elapsed = (now - tonumber(last_refill)) / 1000.0
708///     tokens = math.min(tokens + elapsed * rate, burst)
709/// end
710///
711/// if tokens >= 1.0 then
712///     tokens = tokens - 1.0
713///     redis.call('SET', tokens_key, tostring(tokens))
714///     redis.call('SET', refill_key, tostring(now))
715///     return 1
716/// else
717///     redis.call('SET', tokens_key, tostring(tokens))
718///     redis.call('SET', refill_key, tostring(now))
719///     return 0
720/// end
721/// ```
722#[derive(Debug, Clone)]
723pub struct DistributedTokenBucketSpec {
724    state: DistributedRateLimiterState,
725}
726
727impl DistributedTokenBucketSpec {
728    /// Create a new distributed token bucket specification
729    ///
730    /// This creates the specification for a distributed token bucket.
731    /// Actual implementation requires a backend (e.g., Redis client).
732    #[must_use]
733    pub fn new(key: String, config: RateLimitConfig) -> Self {
734        Self {
735            state: DistributedRateLimiterState::new(key, config),
736        }
737    }
738
739    /// Get the Lua script for atomic token acquisition
740    ///
741    /// This script should be loaded into Redis using SCRIPT LOAD
742    /// and executed with EVALSHA for better performance.
743    #[must_use]
744    pub fn lua_acquire_script() -> &'static str {
745        r"
746        local tokens_key = KEYS[1]
747        local refill_key = KEYS[2]
748        local rate = tonumber(ARGV[1])
749        local burst = tonumber(ARGV[2])
750        local now = tonumber(ARGV[3])
751        local ttl = tonumber(ARGV[4])
752
753        local last_refill = redis.call('GET', refill_key)
754        local tokens = redis.call('GET', tokens_key)
755
756        if not tokens then
757            tokens = burst
758        else
759            tokens = tonumber(tokens)
760        end
761
762        if last_refill then
763            local elapsed = (now - tonumber(last_refill)) / 1000.0
764            tokens = math.min(tokens + elapsed * rate, burst)
765        end
766
767        if tokens >= 1.0 then
768            tokens = tokens - 1.0
769            redis.call('SET', tokens_key, tostring(tokens), 'EX', ttl)
770            redis.call('SET', refill_key, tostring(now), 'EX', ttl)
771            return {1, tokens}
772        else
773            redis.call('SET', tokens_key, tostring(tokens), 'EX', ttl)
774            redis.call('SET', refill_key, tostring(now), 'EX', ttl)
775            return {0, tokens}
776        end
777        "
778    }
779
780    /// Get the Lua script for querying available permits
781    #[must_use]
782    pub fn lua_available_script() -> &'static str {
783        r"
784        local tokens_key = KEYS[1]
785        local refill_key = KEYS[2]
786        local rate = tonumber(ARGV[1])
787        local burst = tonumber(ARGV[2])
788        local now = tonumber(ARGV[3])
789
790        local last_refill = redis.call('GET', refill_key)
791        local tokens = redis.call('GET', tokens_key)
792
793        if not tokens then
794            return burst
795        else
796            tokens = tonumber(tokens)
797        end
798
799        if last_refill then
800            local elapsed = (now - tonumber(last_refill)) / 1000.0
801            tokens = math.min(tokens + elapsed * rate, burst)
802        end
803
804        return math.floor(tokens)
805        "
806    }
807
808    /// Get the state for implementing the distributed backend
809    #[inline]
810    #[must_use]
811    pub fn state(&self) -> &DistributedRateLimiterState {
812        &self.state
813    }
814
815    /// Try to acquire using local fallback
816    #[must_use]
817    pub fn try_acquire_fallback(&self) -> bool {
818        self.state.try_acquire_fallback()
819    }
820}
821
/// Distributed sliding window implementation
///
/// Uses sorted sets in a distributed backend (e.g., Redis ZSET)
/// to implement sliding window algorithm across multiple workers.
///
/// # Redis Implementation
///
/// The sliding window is implemented using Redis sorted set:
/// - `{key}:window` - Sorted set of timestamps (score = timestamp, member = UUID)
///
/// Timestamps are in milliseconds; the scripts compute the cutoff as
/// `now - window_size * 1000`, with `window_size` in seconds.
///
/// Operations:
/// 1. **Acquire**: Add current timestamp to sorted set if count < limit
/// 2. **Cleanup**: Remove timestamps outside the window using ZREMRANGEBYSCORE
/// 3. **Count**: Count timestamps within window using ZCARD
///
/// NOTE: the member (`ARGV[4]`) must be unique per acquisition. A `ZADD` with
/// an existing member only updates its score, so a reused member would not
/// be counted again.
///
/// # Example Lua Script
///
/// ```lua
/// local window_key = KEYS[1]
/// local now = tonumber(ARGV[1])
/// local window_size = tonumber(ARGV[2])
/// local max_count = tonumber(ARGV[3])
/// local uuid = ARGV[4]
///
/// local cutoff = now - window_size * 1000
/// redis.call('ZREMRANGEBYSCORE', window_key, '-inf', cutoff)
///
/// local count = redis.call('ZCARD', window_key)
/// if count < max_count then
///     redis.call('ZADD', window_key, now, uuid)
///     redis.call('EXPIRE', window_key, window_size * 2)
///     return 1
/// else
///     return 0
/// end
/// ```
#[derive(Debug, Clone)]
pub struct DistributedSlidingWindowSpec {
    /// Key, configuration and local fallback shared with the backend
    state: DistributedRateLimiterState,
}
862
863impl DistributedSlidingWindowSpec {
864    /// Create a new distributed sliding window specification
865    #[must_use]
866    pub fn new(key: String, config: RateLimitConfig) -> Self {
867        Self {
868            state: DistributedRateLimiterState::new(key, config),
869        }
870    }
871
872    /// Get the Lua script for atomic window acquisition
873    #[must_use]
874    pub fn lua_acquire_script() -> &'static str {
875        r"
876        local window_key = KEYS[1]
877        local now = tonumber(ARGV[1])
878        local window_size = tonumber(ARGV[2])
879        local max_count = tonumber(ARGV[3])
880        local uuid = ARGV[4]
881
882        local cutoff = now - window_size * 1000
883        redis.call('ZREMRANGEBYSCORE', window_key, '-inf', cutoff)
884
885        local count = redis.call('ZCARD', window_key)
886        if count < max_count then
887            redis.call('ZADD', window_key, now, uuid)
888            redis.call('EXPIRE', window_key, window_size * 2)
889            return {1, max_count - count - 1}
890        else
891            return {0, 0}
892        end
893        "
894    }
895
896    /// Get the Lua script for querying available permits
897    #[must_use]
898    pub fn lua_available_script() -> &'static str {
899        r"
900        local window_key = KEYS[1]
901        local now = tonumber(ARGV[1])
902        local window_size = tonumber(ARGV[2])
903        local max_count = tonumber(ARGV[3])
904
905        local cutoff = now - window_size * 1000
906        redis.call('ZREMRANGEBYSCORE', window_key, '-inf', cutoff)
907
908        local count = redis.call('ZCARD', window_key)
909        return math.max(0, max_count - count)
910        "
911    }
912
913    /// Get the Lua script for querying time until available
914    #[must_use]
915    pub fn lua_time_until_script() -> &'static str {
916        r"
917        local window_key = KEYS[1]
918        local now = tonumber(ARGV[1])
919        local window_size = tonumber(ARGV[2])
920        local max_count = tonumber(ARGV[3])
921
922        local cutoff = now - window_size * 1000
923        redis.call('ZREMRANGEBYSCORE', window_key, '-inf', cutoff)
924
925        local count = redis.call('ZCARD', window_key)
926        if count < max_count then
927            return 0
928        else
929            local oldest = redis.call('ZRANGE', window_key, 0, 0, 'WITHSCORES')
930            if #oldest >= 2 then
931                local oldest_timestamp = tonumber(oldest[2])
932                local expires = oldest_timestamp + window_size * 1000
933                return math.max(0, expires - now)
934            else
935                return 0
936            end
937    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss, clippy::cast_precision_loss)]
938        "
939    }
940
941    /// Get the maximum number of executions allowed in the window
942    #[must_use]
943    #[inline]
944    pub fn max_executions(&self) -> usize {
945        (self.state.config.rate * self.state.config.window_size as f64).ceil() as usize
946    }
947
948    /// Get the state for implementing the distributed backend
949    #[inline]
950    #[must_use]
951    pub fn state(&self) -> &DistributedRateLimiterState {
952        &self.state
953    }
954
955    /// Try to acquire using local fallback
956    #[must_use]
957    pub fn try_acquire_fallback(&self) -> bool {
958        self.state.try_acquire_fallback()
959    }
960}
961
/// Distributed rate limiter coordinator
///
/// Manages distributed rate limiters for multiple task types,
/// allowing cluster-wide rate limit enforcement.
///
/// # Features
///
/// - **Per-task rate limits**: Different rate limits for each task type
/// - **Cluster-wide coordination**: Rate limits enforced across all workers
/// - **Automatic fallback**: Gracefully degrade to local rate limiting on backend failure
/// - **Configuration management**: Dynamic rate limit updates
///
/// Cloning shares the per-task spec maps (they sit behind `Arc`), while
/// `namespace` and `default_config` are copied into each clone.
///
/// # Example
///
/// ```rust,ignore
/// use celers_core::rate_limit::{DistributedRateLimiterCoordinator, RateLimitConfig};
///
/// let coordinator = DistributedRateLimiterCoordinator::new("myapp");
///
/// // Set cluster-wide rate limit for a task
/// coordinator.set_task_rate(
///     "send_email",
///     RateLimitConfig::new(100.0).with_burst(200),
/// );
///
/// // Try to acquire across the cluster
/// if coordinator.try_acquire("send_email").await? {
///     send_email().await?;
/// }
/// ```
#[derive(Debug, Clone)]
pub struct DistributedRateLimiterCoordinator {
    /// Application namespace for Redis keys
    namespace: String,
    /// Per-task token bucket specs (task name -> spec)
    token_buckets: Arc<RwLock<HashMap<String, DistributedTokenBucketSpec>>>,
    /// Per-task sliding window specs (task name -> spec)
    sliding_windows: Arc<RwLock<HashMap<String, DistributedSlidingWindowSpec>>>,
    /// Default configuration for tasks without specific limits
    default_config: Option<RateLimitConfig>,
}
1003
1004impl DistributedRateLimiterCoordinator {
1005    /// Create a new distributed rate limiter coordinator
1006    ///
1007    /// # Arguments
1008    ///
1009    /// * `namespace` - Application namespace for Redis keys (e.g., "myapp")
1010    pub fn new(namespace: impl Into<String>) -> Self {
1011        Self {
1012            namespace: namespace.into(),
1013            token_buckets: Arc::new(RwLock::new(HashMap::new())),
1014            sliding_windows: Arc::new(RwLock::new(HashMap::new())),
1015            default_config: None,
1016        }
1017    }
1018
1019    /// Create with a default rate limit for all tasks
1020    pub fn with_default(namespace: impl Into<String>, config: RateLimitConfig) -> Self {
1021        Self {
1022            namespace: namespace.into(),
1023            token_buckets: Arc::new(RwLock::new(HashMap::new())),
1024            sliding_windows: Arc::new(RwLock::new(HashMap::new())),
1025            default_config: Some(config),
1026        }
1027    }
1028
1029    /// Set distributed rate limit for a specific task type
1030    ///
1031    /// Creates a distributed rate limiter spec that can be used by
1032    /// backend implementations (e.g., Redis).
1033    pub fn set_task_rate(&self, task_name: impl Into<String>, config: RateLimitConfig) {
1034        let name = task_name.into();
1035        let key = format!("{}:ratelimit:{}", self.namespace, name);
1036
1037        if config.sliding_window {
1038            if let Ok(mut guard) = self.sliding_windows.write() {
1039                guard.insert(name.clone(), DistributedSlidingWindowSpec::new(key, config));
1040            }
1041        } else if let Ok(mut guard) = self.token_buckets.write() {
1042            guard.insert(name.clone(), DistributedTokenBucketSpec::new(key, config));
1043        }
1044    }
1045
1046    /// Remove rate limit for a specific task type
1047    pub fn remove_task_rate(&self, task_name: &str) {
1048        if let Ok(mut guard) = self.token_buckets.write() {
1049            guard.remove(task_name);
1050        }
1051        if let Ok(mut guard) = self.sliding_windows.write() {
1052            guard.remove(task_name);
1053        }
1054    }
1055
1056    /// Get the token bucket spec for a task (if using token bucket)
1057    #[inline]
1058    #[must_use]
1059    pub fn get_token_bucket_spec(&self, task_name: &str) -> Option<DistributedTokenBucketSpec> {
1060        if let Ok(guard) = self.token_buckets.read() {
1061            guard.get(task_name).cloned()
1062        } else {
1063            None
1064        }
1065    }
1066
1067    /// Get the sliding window spec for a task (if using sliding window)
1068    #[inline]
1069    #[must_use]
1070    pub fn get_sliding_window_spec(&self, task_name: &str) -> Option<DistributedSlidingWindowSpec> {
1071        if let Ok(guard) = self.sliding_windows.read() {
1072            guard.get(task_name).cloned()
1073        } else {
1074            None
1075        }
1076    }
1077
1078    /// Check if a task has a distributed rate limit configured
1079    #[inline]
1080    #[must_use]
1081    pub fn has_rate_limit(&self, task_name: &str) -> bool {
1082        let has_bucket = if let Ok(guard) = self.token_buckets.read() {
1083            guard.contains_key(task_name)
1084        } else {
1085            false
1086        };
1087
1088        let has_window = if let Ok(guard) = self.sliding_windows.read() {
1089            guard.contains_key(task_name)
1090        } else {
1091            false
1092        };
1093
1094        has_bucket || has_window || self.default_config.is_some()
1095    }
1096
1097    /// Try to acquire using local fallback for a task
1098    ///
1099    /// This method is useful when the distributed backend is unavailable.
1100    #[must_use]
1101    pub fn try_acquire_fallback(&self, task_name: &str) -> bool {
1102        // Try token bucket first
1103        if let Some(spec) = self.get_token_bucket_spec(task_name) {
1104            return spec.try_acquire_fallback();
1105        }
1106
1107        // Try sliding window
1108        if let Some(spec) = self.get_sliding_window_spec(task_name) {
1109            return spec.try_acquire_fallback();
1110        }
1111
1112        // Use default config if available
1113        if let Some(ref config) = self.default_config {
1114            let key = format!("{}:ratelimit:{}", self.namespace, task_name);
1115            let spec = DistributedTokenBucketSpec::new(key, config.clone());
1116            return spec.try_acquire_fallback();
1117        }
1118
1119        // No rate limit configured
1120        true
1121    }
1122
1123    /// Get the Redis key for a task's rate limiter
1124    #[must_use]
1125    pub fn redis_key(&self, task_name: &str) -> String {
1126        format!("{}:ratelimit:{}", self.namespace, task_name)
1127    }
1128}
1129
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn test_token_bucket_basic() {
        let config = RateLimitConfig::new(10.0).with_burst(5);
        let mut limiter = TokenBucket::new(config);

        // Should be able to acquire up to burst capacity immediately
        for _ in 0..5 {
            assert!(limiter.try_acquire());
        }

        // Next acquisition should fail
        assert!(!limiter.try_acquire());
    }

    #[test]
    fn test_token_bucket_refill() {
        let config = RateLimitConfig::new(100.0).with_burst(10);
        let mut limiter = TokenBucket::new(config);

        // Exhaust all tokens
        for _ in 0..10 {
            assert!(limiter.try_acquire());
        }
        assert!(!limiter.try_acquire());

        // Wait for refill (10ms should give us ~1 token at 100/sec)
        thread::sleep(Duration::from_millis(15));

        // Should have at least 1 token now
        assert!(limiter.try_acquire());
    }

    #[test]
    fn test_sliding_window_basic() {
        let config = RateLimitConfig::new(5.0).with_sliding_window(1);
        let mut limiter = SlidingWindow::new(config);

        // Should be able to execute 5 tasks in 1 second window
        for _ in 0..5 {
            assert!(limiter.try_acquire());
        }

        // Next acquisition should fail
        assert!(!limiter.try_acquire());
    }

    #[test]
    fn test_task_rate_limiter() {
        let mut manager = TaskRateLimiter::new();

        // Set rate limit for task_a
        manager.set_task_rate("task_a", RateLimitConfig::new(10.0).with_burst(2));

        // task_a should be rate limited
        assert!(manager.try_acquire("task_a"));
        assert!(manager.try_acquire("task_a"));
        assert!(!manager.try_acquire("task_a"));

        // task_b has no rate limit, should always pass
        assert!(manager.try_acquire("task_b"));
        assert!(manager.try_acquire("task_b"));
        assert!(manager.try_acquire("task_b"));
    }

    #[test]
    fn test_task_rate_limiter_default() {
        let mut manager = TaskRateLimiter::with_default(RateLimitConfig::new(10.0).with_burst(2));

        // All tasks should use default rate limit
        assert!(manager.try_acquire("task_a"));
        assert!(manager.try_acquire("task_a"));
        assert!(!manager.try_acquire("task_a"));

        assert!(manager.try_acquire("task_b"));
        assert!(manager.try_acquire("task_b"));
        assert!(!manager.try_acquire("task_b"));
    }

    #[test]
    fn test_worker_rate_limiter_thread_safe() {
        let limiter = WorkerRateLimiter::new();
        // Use a very low rate (0.1 tokens/sec = 1 token per 10 seconds) to prevent
        // token regeneration during test execution, which would cause flaky results
        limiter.set_task_rate("task_a", RateLimitConfig::new(0.1).with_burst(10));

        // Spawn multiple threads, each with its own clone of the shared limiter
        let handles: Vec<_> = (0..4)
            .map(|_| {
                let l = limiter.clone();
                thread::spawn(move || {
                    let mut count = 0;
                    for _ in 0..5 {
                        if l.try_acquire("task_a") {
                            count += 1;
                        }
                    }
                    count
                })
            })
            .collect();

        let total: usize = handles.into_iter().map(|h| h.join().unwrap()).sum();

        // Total acquisitions should not exceed burst capacity
        assert!(total <= 10);
    }

    #[test]
    fn test_rate_limit_config_serialization() {
        let config = RateLimitConfig::new(50.0)
            .with_burst(100)
            .with_sliding_window(10);

        let json = serde_json::to_string(&config).unwrap();
        let parsed: RateLimitConfig = serde_json::from_str(&json).unwrap();

        assert!((parsed.rate - 50.0).abs() < f64::EPSILON);
        assert_eq!(parsed.burst, Some(100));
        assert!(parsed.sliding_window);
        assert_eq!(parsed.window_size, 10);
    }

    #[test]
    fn test_time_until_available() {
        let config = RateLimitConfig::new(10.0).with_burst(1);
        let mut limiter = TokenBucket::new(config);

        // Exhaust the token
        assert!(limiter.try_acquire());

        // Time until available should be around 100ms (1 token at 10/sec)
        let wait_time = limiter.time_until_available();
        assert!(wait_time > Duration::ZERO);
        assert!(wait_time <= Duration::from_millis(150));
    }

    #[test]
    fn test_reset() {
        let config = RateLimitConfig::new(10.0).with_burst(5);
        let mut limiter = TokenBucket::new(config);

        // Exhaust all tokens (a fresh bucket grants exactly the burst)
        for _ in 0..5 {
            assert!(limiter.try_acquire());
        }
        assert!(!limiter.try_acquire());

        // Reset should restore tokens
        limiter.reset();
        assert!(limiter.try_acquire());
    }

    #[test]
    fn test_set_rate() {
        let config = RateLimitConfig::new(10.0).with_burst(10);
        let mut limiter = TokenBucket::new(config);

        // Update rate
        limiter.set_rate(100.0);
        assert!((limiter.config().rate - 100.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_create_rate_limiter() {
        // Token bucket
        let config = RateLimitConfig::new(10.0);
        let mut limiter = create_rate_limiter(config);
        assert!(limiter.try_acquire());

        // Sliding window
        let config = RateLimitConfig::new(10.0).with_sliding_window(1);
        let mut limiter = create_rate_limiter(config);
        assert!(limiter.try_acquire());
    }

    #[test]
    fn test_distributed_coordinator_task_specs() {
        let coordinator = DistributedRateLimiterCoordinator::new("myapp");

        assert_eq!(
            coordinator.redis_key("send_email"),
            "myapp:ratelimit:send_email"
        );

        // Nothing configured: no limit, local fallback always succeeds
        assert!(!coordinator.has_rate_limit("send_email"));
        assert!(coordinator.try_acquire_fallback("send_email"));

        // A token bucket config registers a token bucket spec only
        coordinator.set_task_rate("send_email", RateLimitConfig::new(10.0).with_burst(5));
        assert!(coordinator.has_rate_limit("send_email"));
        assert!(coordinator.get_token_bucket_spec("send_email").is_some());
        assert!(coordinator.get_sliding_window_spec("send_email").is_none());

        // A sliding window config registers a sliding window spec only
        coordinator.set_task_rate(
            "resize_image",
            RateLimitConfig::new(10.0).with_sliding_window(1),
        );
        assert!(coordinator.has_rate_limit("resize_image"));
        assert!(coordinator.get_sliding_window_spec("resize_image").is_some());
        assert!(coordinator.get_token_bucket_spec("resize_image").is_none());

        // Removing a task's limit clears it regardless of algorithm
        coordinator.remove_task_rate("send_email");
        coordinator.remove_task_rate("resize_image");
        assert!(!coordinator.has_rate_limit("send_email"));
        assert!(!coordinator.has_rate_limit("resize_image"));
        assert!(coordinator.get_token_bucket_spec("send_email").is_none());
        assert!(coordinator.get_sliding_window_spec("resize_image").is_none());
    }

    #[test]
    fn test_distributed_coordinator_default() {
        let coordinator =
            DistributedRateLimiterCoordinator::with_default("myapp", RateLimitConfig::new(10.0));

        // A default config limits every task without registering per-task specs
        assert!(coordinator.has_rate_limit("any_task"));
        assert!(coordinator.get_token_bucket_spec("any_task").is_none());
        assert!(coordinator.get_sliding_window_spec("any_task").is_none());
    }
}