Skip to main content

chainrpc_core/policy/
rate_limiter.rs

//! Token bucket rate limiter.
//!
//! Models a token bucket: tokens accrue at `refill_rate` tokens/second up to
//! `capacity`. Each request consumes `cost` tokens. If insufficient tokens
//! are available, `try_acquire` returns `false` and the caller should back off;
//! `wait_time` estimates how long to wait.
6
7use std::sync::Mutex;
8use std::time::{Duration, Instant};
9
10use crate::cu_tracker::CuCostTable;
11
/// Rate limiter configuration.
///
/// `capacity` and `refill_rate` are expressed in the same unit as request
/// costs (the defaults treat one token as one compute unit).
#[derive(Debug, Clone)]
pub struct RateLimiterConfig {
    /// Maximum tokens in the bucket; a fresh bucket starts with this many.
    pub capacity: f64,
    /// Token refill rate (tokens per second).
    pub refill_rate: f64,
}

impl Default for RateLimiterConfig {
    /// 300 tokens of capacity, refilled at 300 tokens/second.
    fn default() -> Self {
        Self {
            capacity: 300.0,    // 300 CU capacity (Alchemy default)
            refill_rate: 300.0, // 300 CU/s
        }
    }
}

/// Mutable bucket state, guarded by the [`TokenBucket`] mutex.
struct BucketState {
    /// Tokens available as of `last_refill`.
    tokens: f64,
    /// When `tokens` was last brought up to date by a refill.
    last_refill: Instant,
}

/// Thread-safe token bucket rate limiter.
///
/// Tokens are refilled lazily: every method that reads or consumes tokens first
/// credits `refill_rate * elapsed_seconds` tokens, capped at `capacity`.
pub struct TokenBucket {
    config: RateLimiterConfig,
    state: Mutex<BucketState>,
}

impl TokenBucket {
    /// Create a bucket that starts full (`config.capacity` tokens).
    pub fn new(config: RateLimiterConfig) -> Self {
        Self {
            state: Mutex::new(BucketState {
                tokens: config.capacity,
                last_refill: Instant::now(),
            }),
            config,
        }
    }

    /// Try to acquire `cost` tokens.
    ///
    /// Returns `true` if tokens were available and consumed.
    /// Returns `false`, consuming nothing, if fewer than `cost` tokens are
    /// available (rate limit exceeded).
    ///
    /// `cost` is expected to be non-negative: a negative cost would add tokens.
    pub fn try_acquire(&self, cost: f64) -> bool {
        let mut state = self.lock_state();
        self.refill(&mut state);

        if state.tokens >= cost {
            state.tokens -= cost;
            true
        } else {
            false
        }
    }

    /// Returns the estimated wait time before `cost` tokens are available.
    ///
    /// The bucket is refilled first, so time elapsed since the previous call is
    /// accounted for. Returns [`Duration::ZERO`] if `cost` tokens are already
    /// available. Returns [`Duration::MAX`] if the wait is unbounded or not
    /// representable, e.g. when `refill_rate` is zero.
    ///
    /// If `cost` exceeds `capacity`, the returned finite duration is never
    /// actually sufficient, because the bucket never holds more than `capacity`.
    pub fn wait_time(&self, cost: f64) -> Duration {
        let mut state = self.lock_state();
        // Without this refill the deficit would be computed from a stale token
        // count, overstating the wait (or reporting one when none is needed).
        self.refill(&mut state);

        let deficit = cost - state.tokens;
        if deficit <= 0.0 {
            return Duration::ZERO;
        }
        // A zero rate gives +inf and a negative or NaN rate gives a value that
        // `Duration::from_secs_f64` would panic on, so saturate instead.
        Duration::try_from_secs_f64(deficit / self.config.refill_rate).unwrap_or(Duration::MAX)
    }

    /// Returns currently available tokens (after refilling).
    pub fn available(&self) -> f64 {
        let mut state = self.lock_state();
        self.refill(&mut state);
        state.tokens
    }

    /// Lock the bucket state, recovering the guard if the mutex was poisoned.
    ///
    /// No critical section in this type can panic partway through an update, so
    /// the state is still consistent after poisoning. Continuing is therefore
    /// safer than propagating the panic to every caller.
    fn lock_state(&self) -> std::sync::MutexGuard<'_, BucketState> {
        self.state
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Credit tokens accrued since `last_refill`, capped at `capacity`.
    fn refill(&self, state: &mut BucketState) {
        let now = Instant::now();
        let elapsed = now.duration_since(state.last_refill).as_secs_f64();
        let new_tokens = elapsed * self.config.refill_rate;
        state.tokens = (state.tokens + new_tokens).min(self.config.capacity);
        state.last_refill = now;
    }
}
94
95/// A rate limiter wrapping the token bucket.
96pub struct RateLimiter {
97    bucket: TokenBucket,
98    /// Cost per standard request (can be overridden per method).
99    pub default_cost: f64,
100}
101
102impl RateLimiter {
103    pub fn new(config: RateLimiterConfig) -> Self {
104        Self {
105            bucket: TokenBucket::new(config),
106            default_cost: 1.0,
107        }
108    }
109
110    /// Try to acquire the default cost.
111    pub fn try_acquire(&self) -> bool {
112        self.bucket.try_acquire(self.default_cost)
113    }
114
115    /// Try to acquire a specific cost (for expensive methods like eth_getLogs).
116    pub fn try_acquire_cost(&self, cost: f64) -> bool {
117        self.bucket.try_acquire(cost)
118    }
119
120    /// Wait time before the default cost is available.
121    pub fn wait_time(&self) -> Duration {
122        self.bucket.wait_time(self.default_cost)
123    }
124}
125
126/// A method-aware rate limiter that automatically looks up CU costs per RPC method.
127///
128/// Wraps a [`TokenBucket`] with a [`CuCostTable`] so callers only need to
129/// supply the method name — the correct compute-unit cost is resolved
130/// internally.
131pub struct MethodAwareRateLimiter {
132    bucket: TokenBucket,
133    cost_table: CuCostTable,
134}
135
136impl MethodAwareRateLimiter {
137    /// Create a new method-aware rate limiter.
138    pub fn new(config: RateLimiterConfig, cost_table: CuCostTable) -> Self {
139        Self {
140            bucket: TokenBucket::new(config),
141            cost_table,
142        }
143    }
144
145    /// Acquire tokens for a specific RPC method, using its CU cost.
146    ///
147    /// Returns `true` if the method's cost was successfully consumed from the
148    /// bucket, `false` if the bucket has insufficient tokens (rate limited).
149    pub fn try_acquire_method(&self, method: &str) -> bool {
150        let cost = self.cost_table.cost_for(method) as f64;
151        self.bucket.try_acquire(cost)
152    }
153
154    /// Wait time before the given method can be called.
155    pub fn wait_time_for_method(&self, method: &str) -> Duration {
156        let cost = self.cost_table.cost_for(method) as f64;
157        self.bucket.wait_time(cost)
158    }
159
160    /// Access the underlying bucket for manual control.
161    pub fn bucket(&self) -> &TokenBucket {
162        &self.bucket
163    }
164}
165
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn acquire_within_capacity() {
        let rl = RateLimiter::new(RateLimiterConfig {
            capacity: 10.0,
            refill_rate: 1.0,
        });
        for _ in 0..10 {
            assert!(rl.try_acquire(), "should succeed within capacity");
        }
    }

    #[test]
    fn reject_when_empty() {
        let rl = RateLimiter::new(RateLimiterConfig {
            capacity: 3.0,
            refill_rate: 0.0001, // almost no refill
        });
        // Assert the draining calls succeed; otherwise the final check would
        // also pass for a limiter that rejects everything.
        for _ in 0..3 {
            assert!(rl.try_acquire(), "should succeed within capacity");
        }
        // Now empty
        assert!(!rl.try_acquire(), "should be rate limited");
    }

    #[test]
    fn wait_time_when_empty() {
        let rl = RateLimiter::new(RateLimiterConfig {
            capacity: 1.0,
            refill_rate: 10.0, // 10 tokens/sec
        });
        assert!(rl.try_acquire(), "full bucket should grant one token"); // drain
        let wait = rl.wait_time();
        // Should be ~100ms (1 token / 10 tokens per sec)
        assert!(
            wait.as_millis() >= 50 && wait.as_millis() <= 200,
            "unexpected wait time: {wait:?}"
        );
    }

    // ---- MethodAwareRateLimiter tests ----

    #[test]
    fn method_aware_uses_cu_costs() {
        // eth_getLogs = 75 CU, eth_blockNumber = 10 CU.
        // With capacity 150, eth_getLogs can be called 2 times (2*75=150),
        // while eth_blockNumber can be called 15 times (15*10=150).
        let table = CuCostTable::alchemy_defaults();

        // Test expensive method: eth_getLogs (75 CU each)
        let rl_expensive = MethodAwareRateLimiter::new(
            RateLimiterConfig {
                capacity: 150.0,
                refill_rate: 0.0001, // near-zero refill so bucket drains
            },
            table.clone(),
        );
        assert!(rl_expensive.try_acquire_method("eth_getLogs")); // 75 consumed, 75 left
        assert!(rl_expensive.try_acquire_method("eth_getLogs")); // 150 consumed, 0 left
        assert!(
            !rl_expensive.try_acquire_method("eth_getLogs"),
            "should be rate limited after 2 expensive calls"
        );

        // Test cheap method: eth_blockNumber (10 CU each)
        let rl_cheap = MethodAwareRateLimiter::new(
            RateLimiterConfig {
                capacity: 150.0,
                refill_rate: 0.0001,
            },
            table,
        );
        let mut count = 0;
        while rl_cheap.try_acquire_method("eth_blockNumber") {
            count += 1;
            if count > 20 {
                break; // safety valve
            }
        }
        assert_eq!(
            count, 15,
            "cheap method (10 CU) should fit 15 times in 150 capacity"
        );
    }

    #[test]
    fn method_aware_wait_time() {
        // refill_rate = 100 tokens/sec.
        // Drain the bucket, then check wait times scale with method cost.
        let table = CuCostTable::alchemy_defaults();
        let rl = MethodAwareRateLimiter::new(
            RateLimiterConfig {
                capacity: 300.0,
                refill_rate: 100.0, // 100 CU/sec
            },
            table,
        );
        // Drain the bucket completely.
        while rl.bucket().try_acquire(100.0) {}

        // eth_blockNumber = 10 CU → ~100ms wait at 100 CU/sec
        let wait_cheap = rl.wait_time_for_method("eth_blockNumber");
        // eth_getLogs = 75 CU → ~750ms wait at 100 CU/sec
        let wait_expensive = rl.wait_time_for_method("eth_getLogs");

        assert!(
            wait_expensive > wait_cheap,
            "expensive method should have longer wait: expensive={wait_expensive:?}, cheap={wait_cheap:?}"
        );

        // Verify approximate scale: expensive wait should be roughly 7.5x the cheap wait
        let ratio = wait_expensive.as_secs_f64() / wait_cheap.as_secs_f64();
        assert!(
            ratio > 5.0 && ratio < 10.0,
            "wait time ratio should be ~7.5, got {ratio:.2}"
        );
    }

    #[test]
    fn method_aware_unknown_method_uses_default() {
        // Default CU cost for Alchemy table = 50.
        // Capacity 100 → unknown method (50 CU) fits exactly 2 times.
        let table = CuCostTable::alchemy_defaults();
        let rl = MethodAwareRateLimiter::new(
            RateLimiterConfig {
                capacity: 100.0,
                refill_rate: 0.0001,
            },
            table,
        );

        assert!(rl.try_acquire_method("some_unknown_rpc_method")); // 50 consumed
        assert!(rl.try_acquire_method("another_unknown_method")); // 100 consumed
        assert!(
            !rl.try_acquire_method("yet_another_unknown"),
            "unknown method should use default cost (50) and be rate limited"
        );
    }
}