Skip to main content

polyoxide_core/
rate_limit.rs

1use std::num::NonZeroU32;
2use std::sync::Arc;
3use std::time::Duration;
4
5use governor::Quota;
6use reqwest::Method;
7
8type DirectLimiter = governor::RateLimiter<
9    governor::state::NotKeyed,
10    governor::state::InMemoryState,
11    governor::clock::DefaultClock,
12>;
13
14/// Rate limit configuration for a specific endpoint pattern.
15struct EndpointLimit {
16    path_prefix: &'static str,
17    method: Option<Method>,
18    burst: DirectLimiter,
19    sustained: Option<DirectLimiter>,
20}
21
22/// Holds all rate limiters for one API surface.
23///
24/// Created via factory methods like [`RateLimiter::clob_default()`] which
25/// configure hardcoded limits matching Polymarket's documented rate limits.
26#[derive(Clone)]
27pub struct RateLimiter {
28    inner: Arc<RateLimiterInner>,
29}
30
31impl std::fmt::Debug for RateLimiter {
32    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33        f.debug_struct("RateLimiter")
34            .field("endpoints", &self.inner.limits.len())
35            .finish()
36    }
37}
38
39struct RateLimiterInner {
40    limits: Vec<EndpointLimit>,
41    default: DirectLimiter,
42}
43
44/// Helper to create a quota: `count` requests per `period`.
45///
46/// Uses `Quota::with_period` for exact rate enforcement rather than
47/// ceiling-based `per_second`, which can over-permit for non-round windows.
48fn quota(count: u32, period: Duration) -> Quota {
49    let count = count.max(1);
50    let interval = period / count;
51    Quota::with_period(interval)
52        .expect("quota interval must be non-zero")
53        .allow_burst(NonZeroU32::new(count).unwrap())
54}
55
56impl RateLimiter {
57    /// Await the appropriate limiter(s) for this endpoint.
58    ///
59    /// Always awaits the default (general) limiter, then additionally awaits
60    /// the first matching endpoint-specific limiter (burst + sustained).
61    pub async fn acquire(&self, path: &str, method: Option<&Method>) {
62        self.inner.default.until_ready().await;
63
64        for limit in &self.inner.limits {
65            if !path.starts_with(limit.path_prefix) {
66                continue;
67            }
68            if let Some(ref m) = limit.method {
69                if method != Some(m) {
70                    continue;
71                }
72            }
73            limit.burst.until_ready().await;
74            if let Some(ref sustained) = limit.sustained {
75                sustained.until_ready().await;
76            }
77            break;
78        }
79    }
80
81    /// CLOB API rate limits.
82    ///
83    /// - General: 9,000/10s
84    /// - POST /order: 3,500/10s burst + 36,000/10min sustained
85    /// - DELETE /order: 3,000/10s
86    /// - Market data (/markets, /book, /price, /midpoint, /prices-history, /neg-risk, /tick-size): 1,500/10s
87    /// - Ledger (/trades, /data/): 900/10s
88    /// - Auth (/auth): 100/10s
89    pub fn clob_default() -> Self {
90        let ten_sec = Duration::from_secs(10);
91        let ten_min = Duration::from_secs(600);
92
93        Self {
94            inner: Arc::new(RateLimiterInner {
95                default: DirectLimiter::direct(quota(9_000, ten_sec)),
96                limits: vec![
97                    // POST /order — dual window
98                    EndpointLimit {
99                        path_prefix: "/order",
100                        method: Some(Method::POST),
101                        burst: DirectLimiter::direct(quota(3_500, ten_sec)),
102                        sustained: Some(DirectLimiter::direct(quota(36_000, ten_min))),
103                    },
104                    // DELETE /order
105                    EndpointLimit {
106                        path_prefix: "/order",
107                        method: Some(Method::DELETE),
108                        burst: DirectLimiter::direct(quota(3_000, ten_sec)),
109                        sustained: None,
110                    },
111                    // Auth
112                    EndpointLimit {
113                        path_prefix: "/auth",
114                        method: None,
115                        burst: DirectLimiter::direct(quota(100, ten_sec)),
116                        sustained: None,
117                    },
118                    // Ledger — must come before market data since /trades is more specific
119                    EndpointLimit {
120                        path_prefix: "/trades",
121                        method: None,
122                        burst: DirectLimiter::direct(quota(900, ten_sec)),
123                        sustained: None,
124                    },
125                    EndpointLimit {
126                        path_prefix: "/data/",
127                        method: None,
128                        burst: DirectLimiter::direct(quota(900, ten_sec)),
129                        sustained: None,
130                    },
131                    // Market data endpoints
132                    EndpointLimit {
133                        path_prefix: "/markets",
134                        method: None,
135                        burst: DirectLimiter::direct(quota(1_500, ten_sec)),
136                        sustained: None,
137                    },
138                    EndpointLimit {
139                        path_prefix: "/book",
140                        method: None,
141                        burst: DirectLimiter::direct(quota(1_500, ten_sec)),
142                        sustained: None,
143                    },
144                    EndpointLimit {
145                        path_prefix: "/price",
146                        method: None,
147                        burst: DirectLimiter::direct(quota(1_500, ten_sec)),
148                        sustained: None,
149                    },
150                    EndpointLimit {
151                        path_prefix: "/midpoint",
152                        method: None,
153                        burst: DirectLimiter::direct(quota(1_500, ten_sec)),
154                        sustained: None,
155                    },
156                    EndpointLimit {
157                        path_prefix: "/neg-risk",
158                        method: None,
159                        burst: DirectLimiter::direct(quota(1_500, ten_sec)),
160                        sustained: None,
161                    },
162                    EndpointLimit {
163                        path_prefix: "/tick-size",
164                        method: None,
165                        burst: DirectLimiter::direct(quota(1_500, ten_sec)),
166                        sustained: None,
167                    },
168                    EndpointLimit {
169                        path_prefix: "/prices-history",
170                        method: None,
171                        burst: DirectLimiter::direct(quota(1_500, ten_sec)),
172                        sustained: None,
173                    },
174                ],
175            }),
176        }
177    }
178
179    /// Gamma API rate limits.
180    ///
181    /// - General: 4,000/10s
182    /// - /events: 500/10s
183    /// - /markets: 300/10s
184    /// - /public-search: 350/10s
185    /// - /comments: 200/10s
186    /// - /tags: 200/10s
187    pub fn gamma_default() -> Self {
188        let ten_sec = Duration::from_secs(10);
189
190        Self {
191            inner: Arc::new(RateLimiterInner {
192                default: DirectLimiter::direct(quota(4_000, ten_sec)),
193                limits: vec![
194                    EndpointLimit {
195                        path_prefix: "/comments",
196                        method: None,
197                        burst: DirectLimiter::direct(quota(200, ten_sec)),
198                        sustained: None,
199                    },
200                    EndpointLimit {
201                        path_prefix: "/tags",
202                        method: None,
203                        burst: DirectLimiter::direct(quota(200, ten_sec)),
204                        sustained: None,
205                    },
206                    EndpointLimit {
207                        path_prefix: "/markets",
208                        method: None,
209                        burst: DirectLimiter::direct(quota(300, ten_sec)),
210                        sustained: None,
211                    },
212                    EndpointLimit {
213                        path_prefix: "/public-search",
214                        method: None,
215                        burst: DirectLimiter::direct(quota(350, ten_sec)),
216                        sustained: None,
217                    },
218                    EndpointLimit {
219                        path_prefix: "/events",
220                        method: None,
221                        burst: DirectLimiter::direct(quota(500, ten_sec)),
222                        sustained: None,
223                    },
224                ],
225            }),
226        }
227    }
228
229    /// Data API rate limits.
230    ///
231    /// - General: 1,000/10s
232    /// - /trades: 200/10s
233    /// - /positions and /closed-positions: 150/10s
234    pub fn data_default() -> Self {
235        let ten_sec = Duration::from_secs(10);
236
237        Self {
238            inner: Arc::new(RateLimiterInner {
239                default: DirectLimiter::direct(quota(1_000, ten_sec)),
240                limits: vec![
241                    EndpointLimit {
242                        path_prefix: "/positions",
243                        method: None,
244                        burst: DirectLimiter::direct(quota(150, ten_sec)),
245                        sustained: None,
246                    },
247                    EndpointLimit {
248                        path_prefix: "/closed-positions",
249                        method: None,
250                        burst: DirectLimiter::direct(quota(150, ten_sec)),
251                        sustained: None,
252                    },
253                    EndpointLimit {
254                        path_prefix: "/trades",
255                        method: None,
256                        burst: DirectLimiter::direct(quota(200, ten_sec)),
257                        sustained: None,
258                    },
259                ],
260            }),
261        }
262    }
263
264    /// Relay API rate limits.
265    ///
266    /// - 25 requests per 1 minute (single limiter, no endpoint-specific limits)
267    pub fn relay_default() -> Self {
268        Self {
269            inner: Arc::new(RateLimiterInner {
270                default: DirectLimiter::direct(quota(25, Duration::from_secs(60))),
271                limits: vec![],
272            }),
273        }
274    }
275}
276
/// Configuration for retry-on-429 with exponential backoff.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Maximum number of retries. Not consulted by [`RetryConfig::backoff`];
    /// presumably enforced by the caller's retry loop.
    pub max_retries: u32,
    /// Delay in milliseconds for attempt 0, doubled on each later attempt.
    pub initial_backoff_ms: u64,
    /// Cap in milliseconds applied to the exponential delay before jitter.
    pub max_backoff_ms: u64,
}

impl Default for RetryConfig {
    /// 3 retries, 500 ms initial backoff, 10 s cap.
    fn default() -> Self {
        Self {
            max_retries: 3,
            initial_backoff_ms: 500,
            max_backoff_ms: 10_000,
        }
    }
}

impl RetryConfig {
    /// Backoff duration with jitter for retry attempt `attempt` (0-based).
    ///
    /// Steps:
    /// - The base delay is `initial_backoff_ms * 2^min(attempt, 10)` (saturating).
    /// - It is capped at `max_backoff_ms`.
    /// - It is scaled by a jitter factor in `[0.75, 1.25]` derived from the
    ///   system clock's sub-second nanoseconds.
    ///
    /// The jitter avoids a thundering herd when several clients are throttled
    /// at the same moment. The result is never shorter than 1 ms.
    pub fn backoff(&self, attempt: u32) -> Duration {
        // Clamping the exponent keeps the shift in range; saturating_mul guards
        // against very large `initial_backoff_ms` values.
        let base = self
            .initial_backoff_ms
            .saturating_mul(1u64 << attempt.min(10));
        let capped = base.min(self.max_backoff_ms);
        // Sub-second nanos are a cheap, non-deterministic jitter source. A clock
        // reading before the epoch falls back to 0 (the low end of the range).
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .subsec_nanos();
        let ms = (capped as f64 * Self::jitter_factor(nanos)) as u64;
        Duration::from_millis(ms.max(1))
    }

    /// Map a sub-second nanosecond reading (`0..1_000_000_000`) linearly onto a
    /// jitter factor in `[0.75, 1.25]`.
    ///
    /// The reading is normalised by one second's worth of nanoseconds.
    /// `subsec_nanos` never reaches `u32::MAX`, so normalising by that instead
    /// would confine the factor to roughly `[0.75, 0.87)`.
    fn jitter_factor(subsec_nanos: u32) -> f64 {
        const NANOS_PER_SEC: f64 = 1_000_000_000.0;
        let unit = (f64::from(subsec_nanos) / NANOS_PER_SEC).min(1.0);
        0.75 + unit * 0.5
    }
}
316
#[cfg(test)]
mod tests {
    use super::*;

    /// Wall-clock budget within which an un-throttled `acquire` must resolve.
    const FAST: Duration = Duration::from_millis(50);

    /// Largest delay jitter may produce for `cfg`: 125% of `max_backoff_ms`.
    fn jitter_ceiling_ms(cfg: &RetryConfig) -> u64 {
        (cfg.max_backoff_ms as f64 * 1.25) as u64
    }

    // ── RetryConfig ──────────────────────────────────────────────

    #[test]
    fn test_retry_config_default() {
        let cfg = RetryConfig::default();
        assert_eq!(
            (cfg.max_retries, cfg.initial_backoff_ms, cfg.max_backoff_ms),
            (3, 500, 10_000)
        );
    }

    #[test]
    fn test_backoff_attempt_zero() {
        // 500ms base (500 * 2^0) is below the cap; jitter keeps it in [375, 625]ms.
        let ms = RetryConfig::default().backoff(0).as_millis() as u64;
        assert!(
            (375..=625).contains(&ms),
            "attempt 0: {ms}ms not in [375, 625]"
        );
    }

    #[test]
    fn test_backoff_exponential_growth() {
        let cfg = RetryConfig::default();
        let [d0, d1, d2] = [0, 1, 2].map(|attempt| cfg.backoff(attempt));
        assert!(d0 < d1, "d0={d0:?} should be < d1={d1:?}");
        assert!(d1 < d2, "d1={d1:?} should be < d2={d2:?}");
    }

    #[test]
    fn test_backoff_jitter_bounds() {
        let cfg = RetryConfig::default();
        for attempt in 0..20u32 {
            let capped = cfg
                .initial_backoff_ms
                .saturating_mul(1u64 << attempt.min(10))
                .min(cfg.max_backoff_ms);
            let (lower, upper) = (
                (capped as f64 * 0.75) as u64,
                (capped as f64 * 1.25) as u64,
            );
            let ms = cfg.backoff(attempt).as_millis() as u64;
            assert!(
                (lower.max(1)..=upper).contains(&ms),
                "attempt {attempt}: {ms}ms not in [{lower}, {upper}]"
            );
        }
    }

    #[test]
    fn test_backoff_max_capping() {
        let cfg = RetryConfig::default();
        let ceiling = jitter_ceiling_ms(&cfg);
        for attempt in 5..=10 {
            let d = cfg.backoff(attempt);
            assert!(
                d.as_millis() as u64 <= ceiling,
                "attempt {attempt}: {:?} exceeded ceiling {ceiling}ms",
                d
            );
        }
    }

    #[test]
    fn test_backoff_very_high_attempt() {
        let cfg = RetryConfig::default();
        let ms = cfg.backoff(100).as_millis();
        assert!(ms as u64 <= jitter_ceiling_ms(&cfg));
        assert!(ms >= 1);
    }

    // ── quota() ──────────────────────────────────────────────────

    #[test]
    fn test_quota_creation() {
        // Representative (count, seconds) pairs must build without panicking.
        for (count, secs) in [(100, 10), (1, 60), (9_000, 10)] {
            let _ = quota(count, Duration::from_secs(secs));
        }
    }

    #[test]
    fn test_quota_edge_zero_count() {
        // count = 0 is clamped to 1 inside `quota`, so this must not panic.
        let _ = quota(0, Duration::from_secs(10));
    }

    // ── Factory methods ──────────────────────────────────────────

    #[test]
    fn test_clob_default_construction() {
        let limiter = RateLimiter::clob_default();
        assert_eq!(limiter.inner.limits.len(), 12);
        assert!(format!("{limiter:?}").contains("endpoints"));
    }

    #[test]
    fn test_gamma_default_construction() {
        assert_eq!(RateLimiter::gamma_default().inner.limits.len(), 5);
    }

    #[test]
    fn test_data_default_construction() {
        assert_eq!(RateLimiter::data_default().inner.limits.len(), 3);
    }

    #[test]
    fn test_relay_default_construction() {
        assert!(RateLimiter::relay_default().inner.limits.is_empty());
    }

    #[test]
    fn test_rate_limiter_debug_format() {
        let dbg = format!("{:?}", RateLimiter::clob_default());
        assert!(dbg.contains("RateLimiter"), "missing struct name: {dbg}");
        assert!(dbg.contains("endpoints: 12"), "missing count: {dbg}");
    }

    // ── Endpoint matching internals ──────────────────────────────

    #[test]
    fn test_clob_endpoint_order_and_methods() {
        let limiter = RateLimiter::clob_default();
        let limits = &limiter.inner.limits;
        let (post_order, delete_order, auth) = (&limits[0], &limits[1], &limits[2]);

        // POST /order carries the dual (burst + sustained) window.
        assert_eq!(post_order.path_prefix, "/order");
        assert_eq!(post_order.method, Some(Method::POST));
        assert!(post_order.sustained.is_some());

        // DELETE /order has only a burst window.
        assert_eq!(delete_order.path_prefix, "/order");
        assert_eq!(delete_order.method, Some(Method::DELETE));
        assert!(delete_order.sustained.is_none());

        // /auth applies to every method.
        assert_eq!(auth.path_prefix, "/auth");
        assert!(auth.method.is_none());
    }

    // ── acquire() async behavior ─────────────────────────────────

    #[tokio::test]
    async fn test_acquire_single_completes_immediately() {
        let limiter = RateLimiter::clob_default();
        let started = std::time::Instant::now();
        limiter.acquire("/order", Some(&Method::POST)).await;
        assert!(started.elapsed() < FAST);
    }

    #[tokio::test]
    async fn test_acquire_matches_endpoint_by_prefix() {
        let limiter = RateLimiter::clob_default();
        let started = std::time::Instant::now();
        // "/order/123" is covered by the "/order" prefix.
        limiter.acquire("/order/123", Some(&Method::POST)).await;
        assert!(started.elapsed() < FAST);
    }

    #[tokio::test]
    async fn test_acquire_method_filtering() {
        let limiter = RateLimiter::clob_default();
        let started = std::time::Instant::now();
        // GET matches neither the POST nor the DELETE /order entry, so only the
        // general limiter applies.
        limiter.acquire("/order", Some(&Method::GET)).await;
        assert!(started.elapsed() < FAST);
    }

    #[tokio::test]
    async fn test_acquire_no_endpoint_match_uses_default_only() {
        let limiter = RateLimiter::clob_default();
        let started = std::time::Instant::now();
        limiter.acquire("/unknown/path", None).await;
        assert!(started.elapsed() < FAST);
    }

    #[tokio::test]
    async fn test_acquire_method_none_matches_any_method() {
        let limiter = RateLimiter::gamma_default();
        let started = std::time::Instant::now();
        // The /events entry has no method filter: GET, POST and "no method" all match.
        limiter.acquire("/events", Some(&Method::GET)).await;
        limiter.acquire("/events", Some(&Method::POST)).await;
        limiter.acquire("/events", None).await;
        assert!(started.elapsed() < FAST);
    }
}