ai_lib/rate_limiter/token_bucket.rs

//! Token bucket rate limiter implementation

use crate::metrics::Metrics;
use crate::rate_limiter::RateLimiterConfig;
use serde::{Deserialize, Serialize};
use std::sync::{
    atomic::{AtomicBool, AtomicU64, Ordering},
    Arc,
};
use std::time::{Duration, Instant};
use tokio::time::sleep;

/// Rate limiter metrics for monitoring
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimiterMetrics {
    pub current_tokens: u64,
    pub capacity: u64,
    pub refill_rate: u64,
    pub total_requests: u64,
    pub successful_requests: u64,
    pub rejected_requests: u64,
    pub adaptive_rate: Option<u64>,
    pub is_adaptive: bool,
    pub uptime: Duration,
}

/// Token bucket rate limiter
pub struct TokenBucket {
    capacity: u64,
    tokens: AtomicU64,
    refill_rate: u64,
    last_refill: AtomicU64,
    // Adaptive rate limiting
    adaptive: bool,
    adaptive_rate: AtomicU64,
    min_rate: u64,
    max_rate: u64,
    // Metrics
    total_requests: AtomicU64,
    successful_requests: AtomicU64,
    rejected_requests: AtomicU64,
    start_time: Instant,
    // Optional metrics collector
    metrics: Option<Arc<dyn Metrics>>,
    // Rate limiter enabled flag
    enabled: AtomicBool,
}

impl TokenBucket {
    /// Create a new token bucket with the given configuration
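    ///
    /// A construction sketch (not compiled; the field names follow what this
    /// constructor reads from `RateLimiterConfig`, and filling the remaining
    /// fields via `Default` is an assumption):
    ///
    /// ```ignore
    /// let config = RateLimiterConfig {
    ///     requests_per_second: 100, // steady-state refill rate
    ///     burst_capacity: 200,      // bucket capacity
    ///     adaptive: false,
    ///     initial_rate: None,
    ///     ..Default::default()
    /// };
    /// let bucket = TokenBucket::new(config);
    /// assert_eq!(bucket.tokens(), 200); // the bucket starts full
    /// ```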
    pub fn new(config: RateLimiterConfig) -> Self {
        let capacity = config.burst_capacity;
        let refill_rate = config.requests_per_second;
        let adaptive = config.adaptive;
        let initial_rate = config.initial_rate.unwrap_or(refill_rate);

        Self {
            capacity,
            tokens: AtomicU64::new(capacity),
            refill_rate,
            // Refill timestamps are milliseconds since `start_time`; the bucket
            // starts full, so it also starts out "just refilled" at 0.
            last_refill: AtomicU64::new(0),
            adaptive,
            adaptive_rate: AtomicU64::new(initial_rate),
            min_rate: (refill_rate / 4).max(1), // Minimum 25% of original rate
            max_rate: refill_rate * 2,          // Maximum 200% of original rate
            total_requests: AtomicU64::new(0),
            successful_requests: AtomicU64::new(0),
            rejected_requests: AtomicU64::new(0),
            start_time: Instant::now(),
            metrics: None,
            enabled: AtomicBool::new(true),
        }
    }

    /// Create a new token bucket with metrics collection
    pub fn with_metrics(config: RateLimiterConfig, metrics: Arc<dyn Metrics>) -> Self {
        let mut bucket = Self::new(config);
        bucket.metrics = Some(metrics);
        bucket
    }

    /// Enable or disable the rate limiter
    pub fn set_enabled(&self, enabled: bool) {
        self.enabled.store(enabled, Ordering::Release);
    }

    /// Acquire the specified number of tokens
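    ///
    /// Waits asynchronously until enough tokens are available, or returns
    /// [`RateLimitError::RequestTooLarge`] if `tokens` exceeds the bucket capacity.
    ///
    /// Usage sketch (not compiled; the downstream call is hypothetical):
    ///
    /// ```ignore
    /// // Take one token per outgoing request; back-pressure is applied by awaiting.
    /// bucket.acquire(1).await?;
    /// send_request().await?; // hypothetical downstream call
    /// ```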
    pub async fn acquire(&self, tokens: u64) -> Result<(), RateLimitError> {
        // Check if rate limiter is enabled
        if !self.enabled.load(Ordering::Acquire) {
            return Ok(());
        }

        // Increment total requests counter
        self.total_requests.fetch_add(1, Ordering::Relaxed);

        if tokens > self.capacity {
            self.rejected_requests.fetch_add(1, Ordering::Relaxed);
            return Err(RateLimitError::RequestTooLarge {
                requested: tokens,
                max_allowed: self.capacity,
            });
        }

        loop {
            self.refill_tokens();

            let current = self.tokens.load(Ordering::Acquire);
            if current >= tokens {
                if self
                    .tokens
                    .compare_exchange_weak(
                        current,
                        current - tokens,
                        Ordering::Release,
                        Ordering::Relaxed,
                    )
                    .is_ok()
                {
                    self.successful_requests.fetch_add(1, Ordering::Relaxed);

                    // Record metrics
                    if let Some(metrics) = &self.metrics {
                        metrics
                            .incr_counter("rate_limiter.requests_successful", 1)
                            .await;
                    }

                    return Ok(());
                }
            } else {
                // Estimate how long the refill needs to cover the deficit at the
                // current rate; guard against a zero rate and sleep at least 1 ms
                // to avoid busy-looping.
                let current_rate = if self.adaptive {
                    self.adaptive_rate.load(Ordering::Acquire)
                } else {
                    self.refill_rate
                };

                let wait_time = ((tokens - current) * 1000 / current_rate.max(1)).max(1);
                sleep(Duration::from_millis(wait_time)).await;
            }
        }
    }

    /// Refill tokens based on elapsed time
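    ///
    /// Tokens accrue at the current rate (requests per second):
    /// `tokens_to_add = elapsed_ms * rate / 1000`, capped at `capacity`.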
    fn refill_tokens(&self) {
        // Milliseconds since the bucket was created; this is the same timebase
        // that `last_refill` is stored in.
        let now = self.start_time.elapsed().as_millis() as u64;
        let last_refill = self.last_refill.load(Ordering::Acquire);
        let elapsed = now.saturating_sub(last_refill);

        if elapsed > 0 {
            let current_rate = if self.adaptive {
                self.adaptive_rate.load(Ordering::Acquire)
            } else {
                self.refill_rate
            };

            let tokens_to_add = (elapsed * current_rate) / 1000;
            if tokens_to_add > 0 {
                // Claim this refill window; if another thread won the race, skip.
                if self
                    .last_refill
                    .compare_exchange(last_refill, now, Ordering::AcqRel, Ordering::Relaxed)
                    .is_ok()
                {
                    // Add tokens atomically so concurrent acquires are not lost.
                    let _ = self.tokens.fetch_update(
                        Ordering::AcqRel,
                        Ordering::Acquire,
                        |current| Some((current + tokens_to_add).min(self.capacity)),
                    );
                }
            }
        }
    }

    /// Get current token count
    pub fn tokens(&self) -> u64 {
        self.tokens.load(Ordering::Acquire)
    }

    /// Get comprehensive metrics
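    ///
    /// Snapshot sketch (not compiled), e.g. for exporting to a monitoring endpoint:
    ///
    /// ```ignore
    /// let m = bucket.get_metrics();
    /// println!(
    ///     "tokens {}/{}, success rate {:.1}%",
    ///     m.current_tokens, m.capacity, bucket.success_rate()
    /// );
    /// ```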
    pub fn get_metrics(&self) -> RateLimiterMetrics {
        RateLimiterMetrics {
            current_tokens: self.tokens(),
            capacity: self.capacity,
            refill_rate: self.refill_rate,
            total_requests: self.total_requests.load(Ordering::Relaxed),
            successful_requests: self.successful_requests.load(Ordering::Relaxed),
            rejected_requests: self.rejected_requests.load(Ordering::Relaxed),
            adaptive_rate: if self.adaptive {
                Some(self.adaptive_rate.load(Ordering::Relaxed))
            } else {
                None
            },
            is_adaptive: self.adaptive,
            uptime: self.start_time.elapsed(),
        }
    }

    /// Get success rate as a percentage
    pub fn success_rate(&self) -> f64 {
        let total = self.total_requests.load(Ordering::Relaxed);
        if total == 0 {
            return 100.0;
        }
        let successful = self.successful_requests.load(Ordering::Relaxed);
        (successful as f64 / total as f64) * 100.0
    }

    /// Get rejection rate as a percentage
    pub fn rejection_rate(&self) -> f64 {
        let total = self.total_requests.load(Ordering::Relaxed);
        if total == 0 {
            return 0.0;
        }
        let rejected = self.rejected_requests.load(Ordering::Relaxed);
        (rejected as f64 / total as f64) * 100.0
    }

    /// Adjust adaptive rate based on success/failure patterns
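    ///
    /// Feedback sketch (not compiled): callers report downstream outcomes so the
    /// bucket can speed up on success and back off on failure (only when
    /// `adaptive` is enabled in the config).
    ///
    /// ```ignore
    /// match send_request().await { // hypothetical downstream call
    ///     Ok(_) => bucket.adjust_rate(true),
    ///     Err(_) => bucket.adjust_rate(false),
    /// }
    /// ```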
    pub fn adjust_rate(&self, success: bool) {
        if !self.adaptive {
            return;
        }

        let current_rate = self.adaptive_rate.load(Ordering::Acquire);
        let new_rate = if success {
            // Increase rate gradually on success (10%, and by at least 1 so very
            // low rates can still recover)
            ((current_rate * 11) / 10).max(current_rate + 1)
        } else {
            // Decrease rate on failure (10% decrease)
            (current_rate * 9) / 10
        };

        let clamped_rate = new_rate.clamp(self.min_rate, self.max_rate);
        self.adaptive_rate.store(clamped_rate, Ordering::Release);

        // Record metrics
        if let Some(metrics) = &self.metrics {
            tokio::spawn({
                let metrics = metrics.clone();
                async move {
                    metrics
                        .record_gauge("rate_limiter.adaptive_rate", clamped_rate as f64)
                        .await;
                }
            });
        }
    }

    /// Reset adaptive rate to the base refill rate
    pub fn reset_adaptive_rate(&self) {
        if self.adaptive {
            self.adaptive_rate
                .store(self.refill_rate, Ordering::Release);
        }
    }

    /// Set adaptive rate manually
    pub fn set_adaptive_rate(&self, rate: u64) {
        if self.adaptive {
            let clamped_rate = rate.clamp(self.min_rate, self.max_rate);
            self.adaptive_rate.store(clamped_rate, Ordering::Release);
        }
    }

    /// Check if rate limiter is healthy
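    ///
    /// "Healthy" here means a success rate above 80% and a rejection rate below 20%.
    ///
    /// ```ignore
    /// if !bucket.is_healthy() {
    ///     // e.g. surface an alert or shed load upstream
    ///     eprintln!("rate limiter unhealthy: {:.1}% success", bucket.success_rate());
    /// }
    /// ```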
    pub fn is_healthy(&self) -> bool {
        self.success_rate() > 80.0 && self.rejection_rate() < 20.0
    }

    /// Reset all counters
    pub fn reset(&self) {
        self.total_requests.store(0, Ordering::Relaxed);
        self.successful_requests.store(0, Ordering::Relaxed);
        self.rejected_requests.store(0, Ordering::Relaxed);
        self.tokens.store(self.capacity, Ordering::Relaxed);
        self.reset_adaptive_rate();
    }
}

/// Rate limiter error types
#[derive(Debug, thiserror::Error)]
pub enum RateLimitError {
    #[error("Request size {requested} exceeds maximum allowed {max_allowed}")]
    RequestTooLarge { requested: u64, max_allowed: u64 },
    #[error("Rate limiter is disabled")]
    Disabled,
}