1use std::num::NonZeroU32;
2use std::sync::Arc;
3use std::time::Duration;
4
5use governor::Quota;
6use reqwest::Method;
7
/// Un-keyed, in-memory `governor` rate limiter driven by the default clock.
///
/// Every limiter in this module (the host-wide default and each endpoint's
/// burst/sustained windows) is one of these.
type DirectLimiter = governor::RateLimiter<
    governor::state::NotKeyed,
    governor::state::InMemoryState,
    governor::clock::DefaultClock,
>;
13
/// How an endpoint row's path pattern is compared against a request path.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MatchMode {
    /// The pattern must prefix the request path without splitting a path
    /// segment, so `/price` does not cover `/prices-history`.
    Prefix,
    /// The request path must equal the pattern exactly (a query string
    /// therefore breaks the match).
    // None of the built-in endpoint tables use exact matching, so outside of
    // tests this variant is never constructed. The allow is scoped to this
    // variant only, so dead-code warnings still fire for anything else here.
    #[allow(dead_code)]
    Exact,
}
25
26struct EndpointLimit {
28 path_prefix: &'static str,
29 method: Option<Method>,
30 match_mode: MatchMode,
31 burst: DirectLimiter,
32 sustained: Option<DirectLimiter>,
33}
34
35#[derive(Clone)]
40pub struct RateLimiter {
41 inner: Arc<RateLimiterInner>,
42}
43
44impl std::fmt::Debug for RateLimiter {
45 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46 f.debug_struct("RateLimiter")
47 .field("endpoints", &self.inner.limits.len())
48 .finish()
49 }
50}
51
52struct RateLimiterInner {
53 limits: Vec<EndpointLimit>,
54 default: DirectLimiter,
55}
56
57fn quota(count: u32, period: Duration) -> Quota {
62 let count = count.max(1);
63 let interval = period / count;
64 Quota::with_period(interval)
65 .expect("quota interval must be non-zero")
66 .allow_burst(NonZeroU32::new(count).unwrap())
67}
68
69fn endpoint_limit(
71 path_prefix: &'static str,
72 method: Option<Method>,
73 match_mode: MatchMode,
74 burst_count: u32,
75 burst_period: Duration,
76 sustained: Option<(u32, Duration)>,
77) -> EndpointLimit {
78 EndpointLimit {
79 path_prefix,
80 method,
81 match_mode,
82 burst: DirectLimiter::direct(quota(burst_count, burst_period)),
83 sustained: sustained.map(|(count, period)| DirectLimiter::direct(quota(count, period))),
84 }
85}
86
87impl RateLimiter {
88 pub async fn acquire(&self, path: &str, method: Option<&Method>) {
93 self.inner.default.until_ready().await;
94
95 for limit in &self.inner.limits {
96 let matched = match limit.match_mode {
97 MatchMode::Exact => path == limit.path_prefix,
98 MatchMode::Prefix => {
99 match path.strip_prefix(limit.path_prefix) {
102 Some(rest) => {
103 rest.is_empty() || rest.starts_with('/') || rest.starts_with('?')
104 }
105 None => false,
106 }
107 }
108 };
109 if !matched {
110 continue;
111 }
112 if let Some(ref m) = limit.method {
113 if method != Some(m) {
114 continue;
115 }
116 }
117 limit.burst.until_ready().await;
118 if let Some(ref sustained) = limit.sustained {
119 sustained.until_ready().await;
120 }
121 break;
122 }
123 }
124
125 pub fn clob_default() -> Self {
134 let ten_sec = Duration::from_secs(10);
135 let ten_min = Duration::from_secs(600);
136 let p = MatchMode::Prefix;
137
138 Self {
139 inner: Arc::new(RateLimiterInner {
140 default: DirectLimiter::direct(quota(9_000, ten_sec)),
141 limits: vec![
142 endpoint_limit(
144 "/order",
145 Some(Method::POST),
146 p,
147 3_500,
148 ten_sec,
149 Some((36_000, ten_min)),
150 ),
151 endpoint_limit("/order", Some(Method::DELETE), p, 3_000, ten_sec, None),
153 endpoint_limit("/auth", None, p, 100, ten_sec, None),
155 endpoint_limit("/trades", None, p, 900, ten_sec, None),
157 endpoint_limit("/data/", None, p, 900, ten_sec, None),
158 endpoint_limit("/prices-history", None, p, 1_500, ten_sec, None),
160 endpoint_limit("/markets", None, p, 1_500, ten_sec, None),
161 endpoint_limit("/book", None, p, 1_500, ten_sec, None),
162 endpoint_limit("/price", None, p, 1_500, ten_sec, None),
163 endpoint_limit("/midpoint", None, p, 1_500, ten_sec, None),
164 endpoint_limit("/neg-risk", None, p, 1_500, ten_sec, None),
165 endpoint_limit("/tick-size", None, p, 1_500, ten_sec, None),
166 ],
167 }),
168 }
169 }
170
171 pub fn gamma_default() -> Self {
180 let ten_sec = Duration::from_secs(10);
181 let p = MatchMode::Prefix;
182
183 Self {
184 inner: Arc::new(RateLimiterInner {
185 default: DirectLimiter::direct(quota(4_000, ten_sec)),
186 limits: vec![
187 endpoint_limit("/comments", None, p, 200, ten_sec, None),
188 endpoint_limit("/tags", None, p, 200, ten_sec, None),
189 endpoint_limit("/markets", None, p, 300, ten_sec, None),
190 endpoint_limit("/public-search", None, p, 350, ten_sec, None),
191 endpoint_limit("/events", None, p, 500, ten_sec, None),
192 ],
193 }),
194 }
195 }
196
197 pub fn data_default() -> Self {
203 let ten_sec = Duration::from_secs(10);
204 let p = MatchMode::Prefix;
205
206 Self {
207 inner: Arc::new(RateLimiterInner {
208 default: DirectLimiter::direct(quota(1_000, ten_sec)),
209 limits: vec![
210 endpoint_limit("/closed-positions", None, p, 150, ten_sec, None),
211 endpoint_limit("/positions", None, p, 150, ten_sec, None),
212 endpoint_limit("/trades", None, p, 200, ten_sec, None),
213 ],
214 }),
215 }
216 }
217
218 pub fn relay_default() -> Self {
222 Self {
223 inner: Arc::new(RateLimiterInner {
224 default: DirectLimiter::direct(quota(25, Duration::from_secs(60))),
225 limits: vec![],
226 }),
227 }
228 }
229}
230
231#[derive(Debug, Clone)]
233pub struct RetryConfig {
234 pub max_retries: u32,
236 pub initial_backoff_ms: u64,
238 pub max_backoff_ms: u64,
240}
241
242impl Default for RetryConfig {
243 fn default() -> Self {
244 Self {
245 max_retries: 3,
246 initial_backoff_ms: 500,
247 max_backoff_ms: 10_000,
248 }
249 }
250}
251
252impl RetryConfig {
253 pub fn backoff(&self, attempt: u32) -> Duration {
258 let base = self
259 .initial_backoff_ms
260 .saturating_mul(1u64 << attempt.min(10));
261 let capped = base.min(self.max_backoff_ms);
262 let jitter_factor = 0.75 + (fastrand::f64() * 0.5);
264 let ms = (capped as f64 * jitter_factor) as u64;
265 Duration::from_millis(ms.max(1))
266 }
267}
268
// Unit tests for endpoint-table construction, path/method matching and retry
// backoff. Timing assertions assume a lightly loaded machine: a first acquire
// on a freshly built limiter is expected to return within 50-100 ms.
#[cfg(test)]
mod tests {
    use super::*;

    // Pins the values produced by `RetryConfig::default()`.
    #[test]
    fn test_retry_config_default() {
        let cfg = RetryConfig::default();
        assert_eq!(cfg.max_retries, 3);
        assert_eq!(cfg.initial_backoff_ms, 500);
        assert_eq!(cfg.max_backoff_ms, 10_000);
    }

    #[test]
    fn test_backoff_attempt_zero() {
        let cfg = RetryConfig::default();
        let d = cfg.backoff(0);
        let ms = d.as_millis() as u64;
        // 500 ms base scaled by a jitter factor in [0.75, 1.25).
        assert!(
            (375..=625).contains(&ms),
            "attempt 0: {ms}ms not in [375, 625]"
        );
    }

    // The jitter envelopes for attempts 0, 1 and 2 under the default config
    // ([375, 625], [750, 1250], [1500, 2500]) do not overlap, so strict
    // ordering holds whatever the random draw.
    #[test]
    fn test_backoff_exponential_growth() {
        let cfg = RetryConfig::default();
        let d0 = cfg.backoff(0);
        let d1 = cfg.backoff(1);
        let d2 = cfg.backoff(2);
        assert!(d0 < d1, "d0={d0:?} should be < d1={d1:?}");
        assert!(d1 < d2, "d1={d1:?} should be < d2={d2:?}");
    }

    // Recomputes the expected un-jittered base independently, using a shift
    // capped at 10. With the default config the base is clamped to
    // `max_backoff_ms` long before that cap matters. Checks every attempt in
    // 0..20 stays inside the +/-25% jitter envelope.
    #[test]
    fn test_backoff_jitter_bounds() {
        let cfg = RetryConfig::default();
        for attempt in 0..20 {
            let d = cfg.backoff(attempt);
            let base = cfg
                .initial_backoff_ms
                .saturating_mul(1u64 << attempt.min(10));
            let capped = base.min(cfg.max_backoff_ms);
            let lower = (capped as f64 * 0.75) as u64;
            let upper = (capped as f64 * 1.25) as u64;
            let ms = d.as_millis() as u64;
            assert!(
                ms >= lower.max(1) && ms <= upper,
                "attempt {attempt}: {ms}ms not in [{lower}, {upper}]"
            );
        }
    }

    // From attempt 5 the default base (500 ms * 32) already exceeds the 10 s
    // cap, so only the cap plus the maximum 25% jitter bounds the delay.
    #[test]
    fn test_backoff_max_capping() {
        let cfg = RetryConfig::default();
        for attempt in 5..=10 {
            let d = cfg.backoff(attempt);
            let ceiling = (cfg.max_backoff_ms as f64 * 1.25) as u64;
            assert!(
                d.as_millis() as u64 <= ceiling,
                "attempt {attempt}: {:?} exceeded ceiling {ceiling}ms",
                d
            );
        }
    }

    // An attempt count far beyond any shift width must still return a
    // bounded, non-zero delay.
    #[test]
    fn test_backoff_very_high_attempt() {
        let cfg = RetryConfig::default();
        let d = cfg.backoff(100);
        let ceiling = (cfg.max_backoff_ms as f64 * 1.25) as u64;
        assert!(d.as_millis() as u64 <= ceiling);
        assert!(d.as_millis() >= 1);
    }

    // Jitter is expected to be centred on the 500 ms base, so over 200 samples
    // both sides of it should be well represented. The floor of 20 per side
    // makes a spurious failure very unlikely.
    #[test]
    fn test_backoff_jitter_distribution() {
        let cfg = RetryConfig::default();
        let midpoint = cfg.initial_backoff_ms; let (mut below, mut above) = (0u32, 0u32);
        for _ in 0..200 {
            let ms = cfg.backoff(0).as_millis() as u64;
            if ms < midpoint {
                below += 1;
            } else {
                above += 1;
            }
        }
        assert!(
            below >= 20 && above >= 20,
            "jitter looks degenerate: {below} below midpoint, {above} above"
        );
    }

    // Smoke test: representative quotas must build without panicking.
    #[test]
    fn test_quota_creation() {
        let _ = quota(100, Duration::from_secs(10));
        let _ = quota(1, Duration::from_secs(60));
        let _ = quota(9_000, Duration::from_secs(10));
    }

    // A zero count must be tolerated (`quota` clamps it to one) rather than
    // panicking.
    #[test]
    fn test_quota_edge_zero_count() {
        let _ = quota(0, Duration::from_secs(10));
    }

    // The row counts below pin the size of each preset table.
    #[test]
    fn test_clob_default_construction() {
        let rl = RateLimiter::clob_default();
        assert_eq!(rl.inner.limits.len(), 12);
        assert!(format!("{:?}", rl).contains("endpoints"));
    }

    #[test]
    fn test_gamma_default_construction() {
        let rl = RateLimiter::gamma_default();
        assert_eq!(rl.inner.limits.len(), 5);
    }

    #[test]
    fn test_data_default_construction() {
        let rl = RateLimiter::data_default();
        assert_eq!(rl.inner.limits.len(), 3);
    }

    #[test]
    fn test_relay_default_construction() {
        let rl = RateLimiter::relay_default();
        assert_eq!(rl.inner.limits.len(), 0);
    }

    // The Debug impl prints only the struct name and the row count.
    #[test]
    fn test_rate_limiter_debug_format() {
        let rl = RateLimiter::clob_default();
        let dbg = format!("{:?}", rl);
        assert!(dbg.contains("RateLimiter"), "missing struct name: {dbg}");
        assert!(dbg.contains("endpoints: 12"), "missing count: {dbg}");
    }

    // Lookup is first-match-wins, so pin the leading CLOB rows: the two
    // method-specific `/order` rows and the method-agnostic `/auth` row.
    #[test]
    fn test_clob_endpoint_order_and_methods() {
        let rl = RateLimiter::clob_default();
        let limits = &rl.inner.limits;

        assert_eq!(limits[0].path_prefix, "/order");
        assert_eq!(limits[0].method, Some(Method::POST));
        assert!(limits[0].sustained.is_some());

        assert_eq!(limits[1].path_prefix, "/order");
        assert_eq!(limits[1].method, Some(Method::DELETE));
        assert!(limits[1].sustained.is_none());

        assert_eq!(limits[2].path_prefix, "/auth");
        assert!(limits[2].method.is_none());
    }

    // NOTE(review): the acquire tests below only check that the call returns
    // promptly; they do not observe which row (if any) was selected.
    #[tokio::test]
    async fn test_acquire_single_completes_immediately() {
        let rl = RateLimiter::clob_default();
        let start = std::time::Instant::now();
        rl.acquire("/order", Some(&Method::POST)).await;
        assert!(start.elapsed() < Duration::from_millis(50));
    }

    #[tokio::test]
    async fn test_acquire_matches_endpoint_by_prefix() {
        let rl = RateLimiter::clob_default();
        let start = std::time::Instant::now();
        rl.acquire("/order/123", Some(&Method::POST)).await;
        assert!(start.elapsed() < Duration::from_millis(50));
    }

    // NOTE(review): despite its name, this test only checks table order and
    // never calls `acquire`.
    #[tokio::test]
    async fn test_acquire_prefix_respects_segment_boundary() {
        let rl = RateLimiter::clob_default();
        let limits = &rl.inner.limits;

        let price_idx = limits
            .iter()
            .position(|l| l.path_prefix == "/price")
            .expect("/price endpoint exists");

        let prices_history_idx = limits
            .iter()
            .position(|l| l.path_prefix == "/prices-history")
            .expect("/prices-history endpoint exists");

        assert!(
            prices_history_idx < price_idx,
            "/prices-history (idx {prices_history_idx}) should come before /price (idx {price_idx})"
        );
    }

    // NOTE(review): this closure re-implements the prefix rule locally
    // instead of exercising the production matching code. It documents the
    // intended rule but cannot catch divergence from it.
    #[test]
    fn test_match_mode_prefix_segment_boundary() {
        let pattern = "/price";

        let check = |path: &str| -> bool {
            match path.strip_prefix(pattern) {
                Some(rest) => rest.is_empty() || rest.starts_with('/') || rest.starts_with('?'),
                None => false,
            }
        };

        assert!(check("/price"), "exact match");
        assert!(check("/price/foo"), "sub-path");
        assert!(check("/price?token=abc"), "query params");

        assert!(!check("/prices-history"), "partial word /prices-history");
        assert!(!check("/pricelist"), "partial word /pricelist");
        assert!(!check("/pricing"), "partial word /pricing");

        assert!(!check("/midpoint"), "different prefix");
    }

    // NOTE(review): like the prefix test above, this checks a local closure
    // rather than the production `MatchMode::Exact` path.
    #[test]
    fn test_match_mode_exact() {
        let pattern = "/trades";

        let check = |path: &str| -> bool { path == pattern };

        assert!(check("/trades"), "exact match");
        assert!(!check("/trades/123"), "sub-path should not match");
        assert!(!check("/trades?limit=10"), "query params should not match");
        assert!(!check("/traded"), "different word should not match");
    }

    // GET matches neither method-filtered `/order` row, so only the
    // host-wide default limiter should apply.
    #[tokio::test]
    async fn test_acquire_method_filtering() {
        let rl = RateLimiter::clob_default();
        let start = std::time::Instant::now();
        rl.acquire("/order", Some(&Method::GET)).await;
        assert!(start.elapsed() < Duration::from_millis(50));
    }

    #[tokio::test]
    async fn test_acquire_no_endpoint_match_uses_default_only() {
        let rl = RateLimiter::clob_default();
        let start = std::time::Instant::now();
        rl.acquire("/unknown/path", None).await;
        assert!(start.elapsed() < Duration::from_millis(50));
    }

    // A row without a method filter applies to any method, including `None`.
    #[tokio::test]
    async fn test_acquire_method_none_matches_any_method() {
        let rl = RateLimiter::gamma_default();
        let start = std::time::Instant::now();
        rl.acquire("/events", Some(&Method::GET)).await;
        rl.acquire("/events", Some(&Method::POST)).await;
        rl.acquire("/events", None).await;
        assert!(start.elapsed() < Duration::from_millis(50));
    }

    // Both rows are prefix rows. Verifies that stripping "/price" from
    // "/prices-history" leaves a remainder that is not a segment boundary.
    #[test]
    fn test_clob_price_and_prices_history_are_distinct() {
        let rl = RateLimiter::clob_default();
        let limits = &rl.inner.limits;

        let price = limits.iter().find(|l| l.path_prefix == "/price").unwrap();
        let prices_history = limits
            .iter()
            .find(|l| l.path_prefix == "/prices-history")
            .unwrap();

        assert_eq!(price.match_mode, MatchMode::Prefix);
        assert_eq!(prices_history.match_mode, MatchMode::Prefix);

        if let Some(rest) = "/prices-history".strip_prefix(price.path_prefix) {
            assert!(
                !rest.is_empty() && !rest.starts_with('/') && !rest.starts_with('?'),
                "/prices-history must not match /price pattern, rest = '{rest}'"
            );
        }
    }

    // NOTE(review): the final assertion compares two fixed strings
    // ("/closed-positions" can never start with "/positions"), so it always
    // holds; only the lookups and match-mode checks carry signal.
    #[test]
    fn test_data_positions_and_closed_positions_are_distinct() {
        let rl = RateLimiter::data_default();
        let limits = &rl.inner.limits;

        let positions = limits
            .iter()
            .find(|l| l.path_prefix == "/positions")
            .unwrap();
        let closed = limits
            .iter()
            .find(|l| l.path_prefix == "/closed-positions")
            .unwrap();

        assert_eq!(positions.match_mode, MatchMode::Prefix);
        assert_eq!(closed.match_mode, MatchMode::Prefix);

        assert!(
            !"/closed-positions".starts_with(positions.path_prefix),
            "/closed-positions should not match /positions prefix"
        );
    }

    // NOTE(review): `MatchMode` has exactly these two variants, so this
    // assertion cannot fail.
    #[test]
    fn test_all_clob_endpoints_have_match_mode() {
        let rl = RateLimiter::clob_default();
        for limit in &rl.inner.limits {
            assert!(
                limit.match_mode == MatchMode::Prefix || limit.match_mode == MatchMode::Exact,
                "endpoint {} has no valid match mode",
                limit.path_prefix
            );
        }
    }

    // Ten acquires stay far below the `/markets` burst of 1,500, so none
    // should wait. NOTE(review): `RateLimiter` already derives `Clone` over
    // an `Arc`, so the extra `Arc::new` is redundant. The timer also starts
    // after the tasks are spawned.
    #[tokio::test]
    async fn test_acquire_concurrent_tasks_all_complete() {
        let rl = RateLimiter::clob_default(); let rl = std::sync::Arc::new(rl);

        let mut handles = Vec::new();
        for _ in 0..10 {
            let rl = rl.clone();
            handles.push(tokio::spawn(async move {
                rl.acquire("/markets", None).await;
            }));
        }

        let start = std::time::Instant::now();
        for handle in handles {
            handle.await.unwrap();
        }
        assert!(
            start.elapsed() < Duration::from_millis(100),
            "concurrent acquires took too long: {:?}",
            start.elapsed()
        );
    }

    #[tokio::test]
    async fn test_acquire_concurrent_different_endpoints() {
        let rl = std::sync::Arc::new(RateLimiter::clob_default());

        let rl1 = rl.clone();
        let rl2 = rl.clone();
        let rl3 = rl.clone();

        let start = std::time::Instant::now();
        let (r1, r2, r3) = tokio::join!(
            tokio::spawn(async move { rl1.acquire("/markets", None).await }),
            tokio::spawn(async move { rl2.acquire("/auth", None).await }),
            tokio::spawn(async move { rl3.acquire("/order", Some(&Method::POST)).await }),
        );
        r1.unwrap();
        r2.unwrap();
        r3.unwrap();

        assert!(
            start.elapsed() < Duration::from_millis(50),
            "different endpoints should not block: {:?}",
            start.elapsed()
        );
    }

    // `POST /order` is the only preset row configured with a sustained window.
    #[test]
    fn test_clob_post_order_has_dual_window() {
        let rl = RateLimiter::clob_default();
        let post_order = rl
            .inner
            .limits
            .iter()
            .find(|l| l.path_prefix == "/order" && l.method == Some(Method::POST))
            .expect("POST /order endpoint should exist");

        assert!(
            post_order.sustained.is_some(),
            "POST /order should have a sustained (10-min) window"
        );
    }

    #[test]
    fn test_clob_delete_order_has_no_sustained_window() {
        let rl = RateLimiter::clob_default();
        let delete_order = rl
            .inner
            .limits
            .iter()
            .find(|l| l.path_prefix == "/order" && l.method == Some(Method::DELETE))
            .expect("DELETE /order endpoint should exist");

        assert!(
            delete_order.sustained.is_none(),
            "DELETE /order should only have a burst window"
        );
    }

    // NOTE(review): only elapsed time is checked. The test cannot tell
    // whether both windows were actually awaited.
    #[tokio::test]
    async fn test_dual_window_both_burst_and_sustained_are_awaited() {
        let rl = RateLimiter::clob_default();
        let start = std::time::Instant::now();
        rl.acquire("/order", Some(&Method::POST)).await;
        assert!(
            start.elapsed() < Duration::from_millis(50),
            "dual window single acquire should be fast: {:?}",
            start.elapsed()
        );
    }

    // Exercises `HttpClientBuilder` / `should_retry`, which are defined
    // elsewhere in the crate. With `max_retries = 3`, attempts 0..=2 may
    // retry and attempt 3 may not. The `None` third argument is presumably an
    // absent Retry-After hint; TODO confirm against `should_retry`.
    #[test]
    fn test_should_retry_exhaustion() {
        let client = crate::HttpClientBuilder::new("https://example.com")
            .with_retry_config(RetryConfig {
                max_retries: 3,
                ..RetryConfig::default()
            })
            .build()
            .unwrap();

        for attempt in 0..3 {
            assert!(
                client
                    .should_retry(reqwest::StatusCode::TOO_MANY_REQUESTS, attempt, None)
                    .is_some(),
                "attempt {attempt} should allow retry"
            );
        }
        assert!(
            client
                .should_retry(reqwest::StatusCode::TOO_MANY_REQUESTS, 3, None)
                .is_none(),
            "attempt 3 should exhaust retries"
        );
    }

    #[test]
    fn test_should_retry_zero_max_retries_never_retries() {
        let client = crate::HttpClientBuilder::new("https://example.com")
            .with_retry_config(RetryConfig {
                max_retries: 0,
                ..RetryConfig::default()
            })
            .build()
            .unwrap();

        assert!(
            client
                .should_retry(reqwest::StatusCode::TOO_MANY_REQUESTS, 0, None)
                .is_none(),
            "max_retries=0 should never retry"
        );
    }
}