// car_engine/rate_limit.rs

//! Token bucket rate limiter for tool calls with backpressure support.

use std::collections::HashMap;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;

/// Per-tool rate limit configuration.
///
/// A limit allows bursts of up to `max_calls` calls and is replenished at an
/// average rate of `max_calls / interval_secs` calls per second.
#[derive(Debug, Clone)]
pub struct RateLimit {
    /// Bucket capacity: the largest number of calls that may be made back to back.
    pub max_calls: u32,
    /// Length, in seconds, of the window over which `max_calls` tokens are refilled.
    pub interval_secs: f64,
}

14/// Token bucket rate limiter for tool calls.
15///
16/// Each tool can have an independent rate limit. When a tool's bucket is empty,
17/// `acquire()` blocks until a token becomes available (backpressure).
18pub struct RateLimiter {
19    limits: Mutex<HashMap<String, RateLimit>>,
20    buckets: Mutex<HashMap<String, TokenBucket>>,
21}
22
/// Token bucket state for a single tool.
struct TokenBucket {
    /// Tokens currently available; fractional while a token is partially refilled.
    tokens: f64,
    /// Capacity of the bucket; refilling never raises `tokens` above this.
    max_tokens: f64,
    /// Refill speed in tokens per second.
    refill_rate: f64,
    /// Moment at which `tokens` was last brought up to date.
    last_refill: Instant,
}

impl TokenBucket {
    /// Create a bucket that starts full (`tokens == max_tokens`).
    fn new(max_tokens: f64, refill_rate: f64) -> Self {
        Self {
            tokens: max_tokens,
            max_tokens,
            refill_rate,
            last_refill: Instant::now(),
        }
    }

    /// Refill tokens based on elapsed time since last refill.
    ///
    /// Only a strictly positive rate adds tokens. A zero, negative or NaN rate
    /// leaves the balance untouched, so a misconfigured bucket can neither
    /// drain itself nor (via NaN arithmetic) silently refill to capacity.
    fn refill(&mut self) {
        let now = Instant::now();
        if self.refill_rate > 0.0 {
            let elapsed = now.duration_since(self.last_refill).as_secs_f64();
            self.tokens = (self.tokens + elapsed * self.refill_rate).min(self.max_tokens);
        }
        self.last_refill = now;
    }

    /// Try to consume one token. Returns true if successful.
    fn try_consume(&mut self) -> bool {
        self.refill();
        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }

    /// Seconds until one token is available (0.0 if already available).
    ///
    /// Returns `f64::INFINITY` when a token can never become available: the
    /// refill rate is not strictly positive, or the capacity is below one
    /// whole token. The result is therefore never negative and never NaN.
    fn time_until_available(&mut self) -> f64 {
        self.refill();
        if self.tokens >= 1.0 {
            return 0.0;
        }
        if self.refill_rate > 0.0 && self.max_tokens >= 1.0 {
            (1.0 - self.tokens) / self.refill_rate
        } else {
            f64::INFINITY
        }
    }
}

70impl RateLimiter {
71    /// Create an empty rate limiter with no limits configured.
72    pub fn new() -> Self {
73        Self {
74            limits: Mutex::new(HashMap::new()),
75            buckets: Mutex::new(HashMap::new()),
76        }
77    }
78
79    /// Configure a rate limit for a specific tool.
80    ///
81    /// `max_calls` tokens over `interval_secs` seconds. The refill rate is
82    /// `max_calls / interval_secs` tokens per second.
83    pub async fn set_limit(&self, tool: &str, limit: RateLimit) {
84        let max_tokens = limit.max_calls as f64;
85        let refill_rate = max_tokens / limit.interval_secs;
86
87        self.limits.lock().await.insert(tool.to_string(), limit);
88
89        self.buckets
90            .lock()
91            .await
92            .insert(tool.to_string(), TokenBucket::new(max_tokens, refill_rate));
93    }
94
95    /// Wait until a token is available for the given tool, then consume it.
96    ///
97    /// If no rate limit is configured for the tool, returns immediately.
98    /// This provides backpressure: callers block until capacity is available.
99    pub async fn acquire(&self, tool: &str) {
100        loop {
101            let wait_time = {
102                let mut buckets = self.buckets.lock().await;
103                let bucket = match buckets.get_mut(tool) {
104                    Some(b) => b,
105                    None => return, // no limit configured
106                };
107
108                if bucket.try_consume() {
109                    return;
110                }
111
112                bucket.time_until_available()
113            };
114
115            // Sleep outside the lock to allow other tasks to proceed.
116            tokio::time::sleep(std::time::Duration::from_secs_f64(wait_time)).await;
117        }
118    }
119
120    /// Non-blocking attempt to acquire a token for the given tool.
121    ///
122    /// Returns `true` if a token was consumed, `false` if the bucket is empty.
123    /// Returns `true` if no rate limit is configured for the tool.
124    pub async fn try_acquire(&self, tool: &str) -> bool {
125        let mut buckets = self.buckets.lock().await;
126        match buckets.get_mut(tool) {
127            Some(bucket) => bucket.try_consume(),
128            None => true, // no limit configured
129        }
130    }
131}
132
133impl Default for RateLimiter {
134    fn default() -> Self {
135        Self::new()
136    }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// An emptied bucket accepts a call again once enough time has passed.
    #[tokio::test]
    async fn test_token_bucket_refills_correctly() {
        let limiter = RateLimiter::new();
        limiter
            .set_limit(
                "tool_a",
                RateLimit {
                    max_calls: 2,
                    interval_secs: 1.0,
                },
            )
            .await;

        // Consume both tokens.
        assert!(limiter.try_acquire("tool_a").await);
        assert!(limiter.try_acquire("tool_a").await);
        // Bucket is empty.
        assert!(!limiter.try_acquire("tool_a").await);

        // Wait for refill (0.5s should refill 1 token at rate 2/s).
        tokio::time::sleep(std::time::Duration::from_millis(550)).await;
        assert!(limiter.try_acquire("tool_a").await);
    }

    /// `acquire()` on an empty bucket waits for the refill instead of failing.
    #[tokio::test]
    async fn test_acquire_blocks_when_empty() {
        let limiter = Arc::new(RateLimiter::new());
        limiter
            .set_limit(
                "tool_b",
                RateLimit {
                    max_calls: 1,
                    interval_secs: 0.2,
                },
            )
            .await;

        // Drain the single token.
        assert!(limiter.try_acquire("tool_b").await);
        assert!(!limiter.try_acquire("tool_b").await);

        // acquire() should block then return after ~0.2s refill.
        let start = Instant::now();
        limiter.acquire("tool_b").await;
        let elapsed = start.elapsed();

        assert!(
            elapsed.as_millis() >= 100,
            "acquire should have blocked; elapsed={}ms",
            elapsed.as_millis()
        );
    }

    /// Draining one tool's bucket leaves another tool's bucket untouched.
    #[tokio::test]
    async fn test_independent_tool_limits() {
        let limiter = RateLimiter::new();
        limiter
            .set_limit(
                "fast",
                RateLimit {
                    max_calls: 10,
                    interval_secs: 1.0,
                },
            )
            .await;
        limiter
            .set_limit(
                "slow",
                RateLimit {
                    max_calls: 1,
                    interval_secs: 1.0,
                },
            )
            .await;

        // Drain the slow bucket.
        assert!(limiter.try_acquire("slow").await);
        assert!(!limiter.try_acquire("slow").await);

        // fast bucket should still have tokens.
        for _ in 0..10 {
            assert!(limiter.try_acquire("fast").await);
        }
        assert!(!limiter.try_acquire("fast").await);
    }

    /// Tools without a configured limit are never throttled.
    #[tokio::test]
    async fn test_no_limit_always_passes() {
        let limiter = RateLimiter::new();
        // No limit set for "unconfigured".
        assert!(limiter.try_acquire("unconfigured").await);
        limiter.acquire("unconfigured").await; // should return immediately
    }

    /// Idle time must not accumulate more than `max_calls` tokens.
    #[tokio::test]
    async fn test_max_tokens_cap() {
        let limiter = RateLimiter::new();
        limiter
            .set_limit(
                "capped",
                RateLimit {
                    max_calls: 2,
                    interval_secs: 1.0,
                },
            )
            .await;

        // Wait extra time -- tokens should not exceed max_tokens.
        tokio::time::sleep(std::time::Duration::from_millis(600)).await;

        assert!(limiter.try_acquire("capped").await);
        assert!(limiter.try_acquire("capped").await);
        assert!(!limiter.try_acquire("capped").await);
    }
}