Skip to main content

car_engine/
rate_limit.rs

1//! Token bucket rate limiter for tool calls with backpressure support.
2
3use std::collections::HashMap;
4use std::time::Instant;
5use tokio::sync::Mutex;
6
/// Per-tool rate limit configuration.
///
/// Describes a budget of `max_calls` calls per `interval_secs` seconds.
/// Implements `PartialEq` so configured limits can be compared directly
/// (for example in tests or when detecting configuration changes).
#[derive(Debug, Clone, PartialEq)]
pub struct RateLimit {
    /// Number of calls permitted per interval.
    pub max_calls: u32,
    /// Length of the interval, in seconds.
    pub interval_secs: f64,
}
13
14/// Token bucket rate limiter for tool calls.
15///
16/// Each tool can have an independent rate limit. When a tool's bucket is empty,
17/// `acquire()` blocks until a token becomes available (backpressure).
18pub struct RateLimiter {
19    limits: Mutex<HashMap<String, RateLimit>>,
20    buckets: Mutex<HashMap<String, TokenBucket>>,
21}
22
/// Token bucket state for a single tool.
///
/// The balance is fractional and is topped up lazily: every query first
/// credits the tokens earned since the previous refill.
struct TokenBucket {
    /// Current token balance (may be fractional).
    tokens: f64,
    /// Upper bound on `tokens`, i.e. the burst capacity.
    max_tokens: f64,
    /// Tokens credited per second. Always `>= 0.0` and never NaN.
    refill_rate: f64,
    /// Instant at which `tokens` was last brought up to date.
    last_refill: Instant,
}

impl TokenBucket {
    /// Create a bucket that starts full with `max_tokens` tokens and refills
    /// at `refill_rate` tokens per second.
    ///
    /// Negative or NaN inputs are treated as `0.0`, so a misconfigured bucket
    /// simply never yields tokens instead of draining itself or producing
    /// negative/NaN wait times. An infinite `refill_rate` is kept and means
    /// the bucket is always full.
    fn new(max_tokens: f64, refill_rate: f64) -> Self {
        // `x > 0.0` is false for NaN, negatives and `-0.0`, so all of those
        // collapse to a clean `+0.0`.
        let max_tokens = if max_tokens > 0.0 { max_tokens } else { 0.0 };
        let refill_rate = if refill_rate > 0.0 { refill_rate } else { 0.0 };
        Self {
            tokens: max_tokens,
            max_tokens,
            refill_rate,
            last_refill: Instant::now(),
        }
    }

    /// Refill tokens based on elapsed time since last refill.
    ///
    /// The balance is capped at `max_tokens`.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
        self.last_refill = now;

        let gained = elapsed * self.refill_rate;
        self.tokens = if gained.is_nan() {
            // Only reachable as `0.0 * INFINITY`: an infinite rate keeps the
            // bucket permanently full.
            self.max_tokens
        } else {
            (self.tokens + gained).min(self.max_tokens)
        };
    }

    /// Try to consume one token. Returns true if successful.
    fn try_consume(&mut self) -> bool {
        self.refill();
        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }

    /// Seconds until one token is available (0.0 if already available).
    ///
    /// The result is never negative and never NaN. If the bucket does not
    /// refill at all (`refill_rate == 0.0`), the result is `f64::INFINITY`.
    fn time_until_available(&mut self) -> f64 {
        self.refill();
        if self.tokens >= 1.0 {
            return 0.0;
        }
        if self.refill_rate > 0.0 {
            let deficit = 1.0 - self.tokens;
            deficit / self.refill_rate
        } else {
            f64::INFINITY
        }
    }
}
69
70impl RateLimiter {
71    /// Create an empty rate limiter with no limits configured.
72    pub fn new() -> Self {
73        Self {
74            limits: Mutex::new(HashMap::new()),
75            buckets: Mutex::new(HashMap::new()),
76        }
77    }
78
79    /// Configure a rate limit for a specific tool.
80    ///
81    /// `max_calls` tokens over `interval_secs` seconds. The refill rate is
82    /// `max_calls / interval_secs` tokens per second.
83    pub async fn set_limit(&self, tool: &str, limit: RateLimit) {
84        let max_tokens = limit.max_calls as f64;
85        let refill_rate = max_tokens / limit.interval_secs;
86
87        self.limits
88            .lock()
89            .await
90            .insert(tool.to_string(), limit);
91
92        self.buckets
93            .lock()
94            .await
95            .insert(tool.to_string(), TokenBucket::new(max_tokens, refill_rate));
96    }
97
98    /// Wait until a token is available for the given tool, then consume it.
99    ///
100    /// If no rate limit is configured for the tool, returns immediately.
101    /// This provides backpressure: callers block until capacity is available.
102    pub async fn acquire(&self, tool: &str) {
103        loop {
104            let wait_time = {
105                let mut buckets = self.buckets.lock().await;
106                let bucket = match buckets.get_mut(tool) {
107                    Some(b) => b,
108                    None => return, // no limit configured
109                };
110
111                if bucket.try_consume() {
112                    return;
113                }
114
115                bucket.time_until_available()
116            };
117
118            // Sleep outside the lock to allow other tasks to proceed.
119            tokio::time::sleep(std::time::Duration::from_secs_f64(wait_time)).await;
120        }
121    }
122
123    /// Non-blocking attempt to acquire a token for the given tool.
124    ///
125    /// Returns `true` if a token was consumed, `false` if the bucket is empty.
126    /// Returns `true` if no rate limit is configured for the tool.
127    pub async fn try_acquire(&self, tool: &str) -> bool {
128        let mut buckets = self.buckets.lock().await;
129        match buckets.get_mut(tool) {
130            Some(bucket) => bucket.try_consume(),
131            None => true, // no limit configured
132        }
133    }
134}
135
136impl Default for RateLimiter {
137    fn default() -> Self {
138        Self::new()
139    }
140}
141
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `RateLimit`.
    fn limit(max_calls: u32, interval_secs: f64) -> RateLimit {
        RateLimit {
            max_calls,
            interval_secs,
        }
    }

    /// A drained bucket yields a token again once enough time has passed.
    #[tokio::test]
    async fn test_token_bucket_refills_correctly() {
        let limiter = RateLimiter::new();
        limiter.set_limit("tool_a", limit(2, 1.0)).await;

        // Consume both tokens.
        assert!(limiter.try_acquire("tool_a").await);
        assert!(limiter.try_acquire("tool_a").await);
        // Bucket is empty.
        assert!(!limiter.try_acquire("tool_a").await);

        // Wait for refill (0.5s should refill 1 token at rate 2/s).
        tokio::time::sleep(std::time::Duration::from_millis(550)).await;
        assert!(limiter.try_acquire("tool_a").await);
    }

    /// `acquire()` on an empty bucket waits for the refill instead of failing.
    #[tokio::test]
    async fn test_acquire_blocks_when_empty() {
        let limiter = RateLimiter::new();
        limiter.set_limit("tool_b", limit(1, 0.2)).await;

        // Drain the single token.
        assert!(limiter.try_acquire("tool_b").await);
        assert!(!limiter.try_acquire("tool_b").await);

        // acquire() should block then return after ~0.2s refill.
        let start = Instant::now();
        limiter.acquire("tool_b").await;
        let elapsed = start.elapsed();

        assert!(
            elapsed.as_millis() >= 100,
            "acquire should have blocked; elapsed={}ms",
            elapsed.as_millis()
        );
    }

    /// Draining one tool's bucket does not affect another tool's bucket.
    #[tokio::test]
    async fn test_independent_tool_limits() {
        let limiter = RateLimiter::new();
        limiter.set_limit("fast", limit(10, 1.0)).await;
        limiter.set_limit("slow", limit(1, 1.0)).await;

        // Drain the slow bucket.
        assert!(limiter.try_acquire("slow").await);
        assert!(!limiter.try_acquire("slow").await);

        // fast bucket should still have tokens.
        for _ in 0..10 {
            assert!(limiter.try_acquire("fast").await);
        }
        assert!(!limiter.try_acquire("fast").await);
    }

    /// Tools without a configured limit are never throttled.
    #[tokio::test]
    async fn test_no_limit_always_passes() {
        let limiter = RateLimiter::new();
        // No limit set for "unconfigured".
        assert!(limiter.try_acquire("unconfigured").await);
        limiter.acquire("unconfigured").await; // should return immediately
    }

    /// Idle time never accumulates more than `max_calls` tokens.
    #[tokio::test]
    async fn test_max_tokens_cap() {
        let limiter = RateLimiter::new();
        limiter.set_limit("capped", limit(2, 1.0)).await;

        // Wait extra time -- tokens should not exceed max_tokens.
        tokio::time::sleep(std::time::Duration::from_millis(600)).await;

        assert!(limiter.try_acquire("capped").await);
        assert!(limiter.try_acquire("capped").await);
        assert!(!limiter.try_acquire("capped").await);
    }
}