1use std::num::NonZeroU32;
2use std::sync::Arc;
3use std::time::Duration;
4
5use governor::Quota;
6use reqwest::Method;
7
/// An un-keyed `governor` rate limiter with in-memory state driven by the default clock.
///
/// Every limiter in this module (the per-service default and each endpoint's
/// burst/sustained limiter) is one of these.
type DirectLimiter = governor::RateLimiter<
    governor::state::NotKeyed,
    governor::state::InMemoryState,
    governor::clock::DefaultClock,
>;
13
14struct EndpointLimit {
16 path_prefix: &'static str,
17 method: Option<Method>,
18 burst: DirectLimiter,
19 sustained: Option<DirectLimiter>,
20}
21
22#[derive(Clone)]
27pub struct RateLimiter {
28 inner: Arc<RateLimiterInner>,
29}
30
31impl std::fmt::Debug for RateLimiter {
32 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33 f.debug_struct("RateLimiter")
34 .field("endpoints", &self.inner.limits.len())
35 .finish()
36 }
37}
38
39struct RateLimiterInner {
40 limits: Vec<EndpointLimit>,
41 default: DirectLimiter,
42}
43
44fn quota(count: u32, period: Duration) -> Quota {
49 let count = count.max(1);
50 let interval = period / count;
51 Quota::with_period(interval)
52 .expect("quota interval must be non-zero")
53 .allow_burst(NonZeroU32::new(count).unwrap())
54}
55
56impl RateLimiter {
57 pub async fn acquire(&self, path: &str, method: Option<&Method>) {
62 self.inner.default.until_ready().await;
63
64 for limit in &self.inner.limits {
65 if !path.starts_with(limit.path_prefix) {
66 continue;
67 }
68 if let Some(ref m) = limit.method {
69 if method != Some(m) {
70 continue;
71 }
72 }
73 limit.burst.until_ready().await;
74 if let Some(ref sustained) = limit.sustained {
75 sustained.until_ready().await;
76 }
77 break;
78 }
79 }
80
81 pub fn clob_default() -> Self {
90 let ten_sec = Duration::from_secs(10);
91 let ten_min = Duration::from_secs(600);
92
93 Self {
94 inner: Arc::new(RateLimiterInner {
95 default: DirectLimiter::direct(quota(9_000, ten_sec)),
96 limits: vec![
97 EndpointLimit {
99 path_prefix: "/order",
100 method: Some(Method::POST),
101 burst: DirectLimiter::direct(quota(3_500, ten_sec)),
102 sustained: Some(DirectLimiter::direct(quota(36_000, ten_min))),
103 },
104 EndpointLimit {
106 path_prefix: "/order",
107 method: Some(Method::DELETE),
108 burst: DirectLimiter::direct(quota(3_000, ten_sec)),
109 sustained: None,
110 },
111 EndpointLimit {
113 path_prefix: "/auth",
114 method: None,
115 burst: DirectLimiter::direct(quota(100, ten_sec)),
116 sustained: None,
117 },
118 EndpointLimit {
120 path_prefix: "/trades",
121 method: None,
122 burst: DirectLimiter::direct(quota(900, ten_sec)),
123 sustained: None,
124 },
125 EndpointLimit {
126 path_prefix: "/data/",
127 method: None,
128 burst: DirectLimiter::direct(quota(900, ten_sec)),
129 sustained: None,
130 },
131 EndpointLimit {
133 path_prefix: "/markets",
134 method: None,
135 burst: DirectLimiter::direct(quota(1_500, ten_sec)),
136 sustained: None,
137 },
138 EndpointLimit {
139 path_prefix: "/book",
140 method: None,
141 burst: DirectLimiter::direct(quota(1_500, ten_sec)),
142 sustained: None,
143 },
144 EndpointLimit {
145 path_prefix: "/price",
146 method: None,
147 burst: DirectLimiter::direct(quota(1_500, ten_sec)),
148 sustained: None,
149 },
150 EndpointLimit {
151 path_prefix: "/midpoint",
152 method: None,
153 burst: DirectLimiter::direct(quota(1_500, ten_sec)),
154 sustained: None,
155 },
156 EndpointLimit {
157 path_prefix: "/neg-risk",
158 method: None,
159 burst: DirectLimiter::direct(quota(1_500, ten_sec)),
160 sustained: None,
161 },
162 EndpointLimit {
163 path_prefix: "/tick-size",
164 method: None,
165 burst: DirectLimiter::direct(quota(1_500, ten_sec)),
166 sustained: None,
167 },
168 EndpointLimit {
169 path_prefix: "/prices-history",
170 method: None,
171 burst: DirectLimiter::direct(quota(1_500, ten_sec)),
172 sustained: None,
173 },
174 ],
175 }),
176 }
177 }
178
179 pub fn gamma_default() -> Self {
188 let ten_sec = Duration::from_secs(10);
189
190 Self {
191 inner: Arc::new(RateLimiterInner {
192 default: DirectLimiter::direct(quota(4_000, ten_sec)),
193 limits: vec![
194 EndpointLimit {
195 path_prefix: "/comments",
196 method: None,
197 burst: DirectLimiter::direct(quota(200, ten_sec)),
198 sustained: None,
199 },
200 EndpointLimit {
201 path_prefix: "/tags",
202 method: None,
203 burst: DirectLimiter::direct(quota(200, ten_sec)),
204 sustained: None,
205 },
206 EndpointLimit {
207 path_prefix: "/markets",
208 method: None,
209 burst: DirectLimiter::direct(quota(300, ten_sec)),
210 sustained: None,
211 },
212 EndpointLimit {
213 path_prefix: "/public-search",
214 method: None,
215 burst: DirectLimiter::direct(quota(350, ten_sec)),
216 sustained: None,
217 },
218 EndpointLimit {
219 path_prefix: "/events",
220 method: None,
221 burst: DirectLimiter::direct(quota(500, ten_sec)),
222 sustained: None,
223 },
224 ],
225 }),
226 }
227 }
228
229 pub fn data_default() -> Self {
235 let ten_sec = Duration::from_secs(10);
236
237 Self {
238 inner: Arc::new(RateLimiterInner {
239 default: DirectLimiter::direct(quota(1_000, ten_sec)),
240 limits: vec![
241 EndpointLimit {
242 path_prefix: "/positions",
243 method: None,
244 burst: DirectLimiter::direct(quota(150, ten_sec)),
245 sustained: None,
246 },
247 EndpointLimit {
248 path_prefix: "/closed-positions",
249 method: None,
250 burst: DirectLimiter::direct(quota(150, ten_sec)),
251 sustained: None,
252 },
253 EndpointLimit {
254 path_prefix: "/trades",
255 method: None,
256 burst: DirectLimiter::direct(quota(200, ten_sec)),
257 sustained: None,
258 },
259 ],
260 }),
261 }
262 }
263
264 pub fn relay_default() -> Self {
268 Self {
269 inner: Arc::new(RateLimiterInner {
270 default: DirectLimiter::direct(quota(25, Duration::from_secs(60))),
271 limits: vec![],
272 }),
273 }
274 }
275}
276
/// Retry policy: a retry budget plus the parameters of the exponential backoff
/// computed by `RetryConfig::backoff`.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Maximum number of retries. NOTE(review): not read in this file; confirm
    /// with callers whether the first attempt counts toward it.
    pub max_retries: u32,
    /// Pre-jitter delay for attempt 0, in milliseconds; doubles per attempt.
    pub initial_backoff_ms: u64,
    /// Ceiling on the pre-jitter delay, in milliseconds.
    pub max_backoff_ms: u64,
}
284
285impl Default for RetryConfig {
286 fn default() -> Self {
287 Self {
288 max_retries: 3,
289 initial_backoff_ms: 500,
290 max_backoff_ms: 10_000,
291 }
292 }
293}
294
295impl RetryConfig {
296 pub fn backoff(&self, attempt: u32) -> Duration {
301 let base = self
302 .initial_backoff_ms
303 .saturating_mul(1u64 << attempt.min(10));
304 let capped = base.min(self.max_backoff_ms);
305 let nanos = std::time::SystemTime::now()
307 .duration_since(std::time::UNIX_EPOCH)
308 .unwrap_or_default()
309 .subsec_nanos();
310 let jitter_factor = 0.75 + (nanos as f64 / u32::MAX as f64) * 0.5;
312 let ms = (capped as f64 * jitter_factor) as u64;
313 Duration::from_millis(ms.max(1))
314 }
315}
316
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Instant;

    /// Largest delay `backoff` may return once the cap applies: the cap plus 25% jitter.
    fn capped_ceiling_ms(cfg: &RetryConfig) -> u64 {
        (cfg.max_backoff_ms as f64 * 1.25) as u64
    }

    /// Awaits each `(path, method)` acquisition on `rl` in order, then asserts the
    /// whole sequence took under 50 ms (i.e. no limiter made it wait noticeably).
    async fn assert_acquires_fast(rl: &RateLimiter, calls: &[(&str, Option<&Method>)]) {
        let start = Instant::now();
        for &(path, method) in calls {
            rl.acquire(path, method).await;
        }
        assert!(start.elapsed() < Duration::from_millis(50));
    }

    #[test]
    fn test_retry_config_default() {
        let cfg = RetryConfig::default();
        let actual = (cfg.max_retries, cfg.initial_backoff_ms, cfg.max_backoff_ms);
        assert_eq!(actual, (3, 500, 10_000));
    }

    #[test]
    fn test_backoff_attempt_zero() {
        let ms = RetryConfig::default().backoff(0).as_millis() as u64;
        assert!(
            (375..=625).contains(&ms),
            "attempt 0: {ms}ms not in [375, 625]"
        );
    }

    #[test]
    fn test_backoff_exponential_growth() {
        let cfg = RetryConfig::default();
        let (d0, d1, d2) = (cfg.backoff(0), cfg.backoff(1), cfg.backoff(2));
        assert!(d0 < d1, "d0={d0:?} should be < d1={d1:?}");
        assert!(d1 < d2, "d1={d1:?} should be < d2={d2:?}");
    }

    #[test]
    fn test_backoff_jitter_bounds() {
        let cfg = RetryConfig::default();
        for attempt in 0..20 {
            let ms = cfg.backoff(attempt).as_millis() as u64;
            let capped = cfg
                .initial_backoff_ms
                .saturating_mul(1u64 << attempt.min(10))
                .min(cfg.max_backoff_ms);
            let lower = (capped as f64 * 0.75) as u64;
            let upper = (capped as f64 * 1.25) as u64;
            assert!(
                (lower.max(1)..=upper).contains(&ms),
                "attempt {attempt}: {ms}ms not in [{lower}, {upper}]"
            );
        }
    }

    #[test]
    fn test_backoff_max_capping() {
        let cfg = RetryConfig::default();
        let ceiling = capped_ceiling_ms(&cfg);
        for attempt in 5..=10 {
            let d = cfg.backoff(attempt);
            assert!(
                d.as_millis() as u64 <= ceiling,
                "attempt {attempt}: {d:?} exceeded ceiling {ceiling}ms"
            );
        }
    }

    #[test]
    fn test_backoff_very_high_attempt() {
        let cfg = RetryConfig::default();
        let ms = cfg.backoff(100).as_millis();
        assert!(ms as u64 <= capped_ceiling_ms(&cfg));
        assert!(ms >= 1);
    }

    #[test]
    fn test_quota_creation() {
        // Representative (count, seconds) pairs; construction must not panic.
        for (count, secs) in [(100, 10), (1, 60), (9_000, 10)] {
            let _ = quota(count, Duration::from_secs(secs));
        }
    }

    #[test]
    fn test_quota_edge_zero_count() {
        // `quota` clamps a zero count to one instead of panicking.
        let _ = quota(0, Duration::from_secs(10));
    }

    #[test]
    fn test_clob_default_construction() {
        let rl = RateLimiter::clob_default();
        assert_eq!(rl.inner.limits.len(), 12);
        assert!(format!("{rl:?}").contains("endpoints"));
    }

    #[test]
    fn test_gamma_default_construction() {
        assert_eq!(RateLimiter::gamma_default().inner.limits.len(), 5);
    }

    #[test]
    fn test_data_default_construction() {
        assert_eq!(RateLimiter::data_default().inner.limits.len(), 3);
    }

    #[test]
    fn test_relay_default_construction() {
        assert_eq!(RateLimiter::relay_default().inner.limits.len(), 0);
    }

    #[test]
    fn test_rate_limiter_debug_format() {
        let dbg = format!("{:?}", RateLimiter::clob_default());
        assert!(dbg.contains("RateLimiter"), "missing struct name: {dbg}");
        assert!(dbg.contains("endpoints: 12"), "missing count: {dbg}");
    }

    #[test]
    fn test_clob_endpoint_order_and_methods() {
        let rl = RateLimiter::clob_default();
        let limits = &rl.inner.limits;
        let (post, delete, auth) = (&limits[0], &limits[1], &limits[2]);

        assert_eq!(post.path_prefix, "/order");
        assert_eq!(post.method, Some(Method::POST));
        assert!(post.sustained.is_some());

        assert_eq!(delete.path_prefix, "/order");
        assert_eq!(delete.method, Some(Method::DELETE));
        assert!(delete.sustained.is_none());

        assert_eq!(auth.path_prefix, "/auth");
        assert!(auth.method.is_none());
    }

    #[tokio::test]
    async fn test_acquire_single_completes_immediately() {
        let rl = RateLimiter::clob_default();
        assert_acquires_fast(&rl, &[("/order", Some(&Method::POST))]).await;
    }

    #[tokio::test]
    async fn test_acquire_matches_endpoint_by_prefix() {
        let rl = RateLimiter::clob_default();
        // "/order/123" starts with "/order", so the POST /order rule applies.
        assert_acquires_fast(&rl, &[("/order/123", Some(&Method::POST))]).await;
    }

    #[tokio::test]
    async fn test_acquire_method_filtering() {
        let rl = RateLimiter::clob_default();
        // GET matches neither method-restricted "/order" rule, so only the default applies.
        assert_acquires_fast(&rl, &[("/order", Some(&Method::GET))]).await;
    }

    #[tokio::test]
    async fn test_acquire_no_endpoint_match_uses_default_only() {
        let rl = RateLimiter::clob_default();
        assert_acquires_fast(&rl, &[("/unknown/path", None)]).await;
    }

    #[tokio::test]
    async fn test_acquire_method_none_matches_any_method() {
        let rl = RateLimiter::gamma_default();
        // The "/events" rule has no method filter, so every variant hits the same rule.
        assert_acquires_fast(
            &rl,
            &[
                ("/events", Some(&Method::GET)),
                ("/events", Some(&Method::POST)),
                ("/events", None),
            ],
        )
        .await;
    }
}