riglr_core/util/
token_bucket.rs

1//! Token bucket rate limiting strategy
2//!
3//! This module implements a time-based token bucket algorithm where tokens
4//! are replenished continuously based on elapsed time.
5
6use super::rate_limit_strategy::{ClientRateInfo, RateLimitStrategy};
7use crate::ToolError;
8use dashmap::DashMap;
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11
12/// Token bucket rate limiting strategy
13///
14/// This strategy implements a time-based token bucket algorithm where:
15/// - Tokens are replenished continuously based on elapsed time
16/// - The replenishment rate is `max_requests` per `time_window`
17/// - Burst capacity allows temporary spikes up to `burst_size` tokens
18#[derive(Debug, Clone)]
19pub struct TokenBucketStrategy {
20    /// Maximum number of requests allowed in the time window
21    max_requests: usize,
22    /// Time window for rate limiting
23    time_window: Duration,
24    /// Optional burst size for allowing temporary spikes
25    burst_size: Option<usize>,
26    /// Map of client ID to their request history
27    clients: Arc<DashMap<String, ClientRateInfo>>,
28}
29
30impl TokenBucketStrategy {
31    /// Create a new token bucket strategy
32    pub fn new(max_requests: usize, time_window: Duration) -> Self {
33        Self {
34            max_requests,
35            time_window,
36            burst_size: None,
37            clients: Arc::new(DashMap::new()),
38        }
39    }
40
41    /// Create with burst capacity
42    pub fn with_burst(max_requests: usize, time_window: Duration, burst_size: usize) -> Self {
43        Self {
44            max_requests,
45            time_window,
46            burst_size: Some(burst_size),
47            clients: Arc::new(DashMap::new()),
48        }
49    }
50}
51
52impl RateLimitStrategy for TokenBucketStrategy {
53    fn check_rate_limit(&self, client_id: &str) -> Result<(), ToolError> {
54        let now = Instant::now();
55        let mut entry = self
56            .clients
57            .entry(client_id.to_string())
58            .or_insert_with(|| {
59                ClientRateInfo::new(self.burst_size.unwrap_or(self.max_requests) as f64)
60            });
61
62        // Time-based token replenishment
63        let elapsed = now.duration_since(entry.last_refill);
64        let elapsed_seconds = elapsed.as_secs_f64();
65
66        // Calculate the refill rate: tokens per second
67        let refill_rate = self.max_requests as f64 / self.time_window.as_secs_f64();
68
69        // Calculate tokens to add based on elapsed time
70        let tokens_to_add = elapsed_seconds * refill_rate;
71
72        // Add tokens up to the burst size limit (or max_requests if no burst size)
73        let max_tokens = self.burst_size.unwrap_or(self.max_requests) as f64;
74        entry.burst_tokens = (entry.burst_tokens + tokens_to_add).min(max_tokens);
75        entry.last_refill = now;
76
77        // Remove old requests outside the time window (for sustained rate tracking)
78        entry
79            .request_times
80            .retain(|&time| now.duration_since(time) < self.time_window);
81
82        // Check if we have tokens available
83        let has_burst_token = entry.burst_tokens >= 1.0;
84
85        if !has_burst_token {
86            // No tokens available, calculate retry time
87            let tokens_needed = 1.0 - entry.burst_tokens;
88            let seconds_until_token = tokens_needed / refill_rate;
89            let retry_after = Duration::from_secs_f64(seconds_until_token);
90
91            return Err(ToolError::RateLimited {
92                source: None,
93                source_message: format!(
94                    "Token bucket rate limit: {} requests per {:?}",
95                    self.max_requests, self.time_window
96                ),
97                context: format!("User exceeded rate limit of {} requests", self.max_requests),
98                retry_after: Some(retry_after),
99            });
100        }
101
102        // We have a token, consume it and allow the request
103        entry.burst_tokens -= 1.0;
104        entry.request_times.push(now);
105        Ok(())
106    }
107
108    fn reset_client(&self, client_id: &str) {
109        self.clients.remove(client_id);
110    }
111
112    fn clear_all(&self) {
113        self.clients.clear();
114    }
115
116    fn get_request_count(&self, client_id: &str) -> usize {
117        let now = Instant::now();
118        self.clients
119            .get(client_id)
120            .map(|entry| {
121                entry
122                    .request_times
123                    .iter()
124                    .filter(|&&time| now.duration_since(time) < self.time_window)
125                    .count()
126            })
127            .unwrap_or(0)
128    }
129
130    fn strategy_name(&self) -> &str {
131        "TokenBucket"
132    }
133}