riglr_core/util/
token_bucket.rs1use super::rate_limit_strategy::{ClientRateInfo, RateLimitStrategy};
7use crate::ToolError;
8use dashmap::DashMap;
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11
/// Rate-limiting strategy that gives every client its own token bucket.
///
/// Each client's bucket holds at most `burst_size` tokens (or `max_requests` when no
/// burst size is set). It refills continuously at `max_requests / time_window` tokens per
/// second, and every admitted request consumes one token.
///
/// Cloning the strategy shares the per-client state, because `clients` is behind an `Arc`.
#[derive(Debug, Clone)]
pub struct TokenBucketStrategy {
    /// Requests allowed per `time_window`; also the bucket capacity when `burst_size` is `None`.
    max_requests: usize,
    /// Window over which `max_requests` tokens are refilled and request timestamps are kept.
    time_window: Duration,
    /// Optional bucket capacity that overrides `max_requests` as the maximum token count.
    burst_size: Option<usize>,
    /// Per-client bucket state, keyed by the `client_id` passed to the trait methods.
    clients: Arc<DashMap<String, ClientRateInfo>>,
}
29
30impl TokenBucketStrategy {
31 pub fn new(max_requests: usize, time_window: Duration) -> Self {
33 Self {
34 max_requests,
35 time_window,
36 burst_size: None,
37 clients: Arc::new(DashMap::new()),
38 }
39 }
40
41 pub fn with_burst(max_requests: usize, time_window: Duration, burst_size: usize) -> Self {
43 Self {
44 max_requests,
45 time_window,
46 burst_size: Some(burst_size),
47 clients: Arc::new(DashMap::new()),
48 }
49 }
50}
51
52impl RateLimitStrategy for TokenBucketStrategy {
53 fn check_rate_limit(&self, client_id: &str) -> Result<(), ToolError> {
54 let now = Instant::now();
55 let mut entry = self
56 .clients
57 .entry(client_id.to_string())
58 .or_insert_with(|| {
59 ClientRateInfo::new(self.burst_size.unwrap_or(self.max_requests) as f64)
60 });
61
62 let elapsed = now.duration_since(entry.last_refill);
64 let elapsed_seconds = elapsed.as_secs_f64();
65
66 let refill_rate = self.max_requests as f64 / self.time_window.as_secs_f64();
68
69 let tokens_to_add = elapsed_seconds * refill_rate;
71
72 let max_tokens = self.burst_size.unwrap_or(self.max_requests) as f64;
74 entry.burst_tokens = (entry.burst_tokens + tokens_to_add).min(max_tokens);
75 entry.last_refill = now;
76
77 entry
79 .request_times
80 .retain(|&time| now.duration_since(time) < self.time_window);
81
82 let has_burst_token = entry.burst_tokens >= 1.0;
84
85 if !has_burst_token {
86 let tokens_needed = 1.0 - entry.burst_tokens;
88 let seconds_until_token = tokens_needed / refill_rate;
89 let retry_after = Duration::from_secs_f64(seconds_until_token);
90
91 return Err(ToolError::RateLimited {
92 source: None,
93 source_message: format!(
94 "Token bucket rate limit: {} requests per {:?}",
95 self.max_requests, self.time_window
96 ),
97 context: format!("User exceeded rate limit of {} requests", self.max_requests),
98 retry_after: Some(retry_after),
99 });
100 }
101
102 entry.burst_tokens -= 1.0;
104 entry.request_times.push(now);
105 Ok(())
106 }
107
108 fn reset_client(&self, client_id: &str) {
109 self.clients.remove(client_id);
110 }
111
112 fn clear_all(&self) {
113 self.clients.clear();
114 }
115
116 fn get_request_count(&self, client_id: &str) -> usize {
117 let now = Instant::now();
118 self.clients
119 .get(client_id)
120 .map(|entry| {
121 entry
122 .request_times
123 .iter()
124 .filter(|&&time| now.duration_since(time) < self.time_window)
125 .count()
126 })
127 .unwrap_or(0)
128 }
129
130 fn strategy_name(&self) -> &str {
131 "TokenBucket"
132 }
133}