1use std::num::NonZeroU32;
2use std::sync::Arc;
3use std::time::Duration;
4
5use governor::Quota;
6use reqwest::Method;
7
8type DirectLimiter = governor::RateLimiter<
9 governor::state::NotKeyed,
10 governor::state::InMemoryState,
11 governor::clock::DefaultClock,
12>;
13
/// How an [`EndpointLimit`]'s pattern is compared against a request path.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
// No built-in preset constructs `Exact`, so it would otherwise trip `dead_code`.
#[allow(dead_code)]
enum MatchMode {
    /// The path starts with the pattern on a path-segment boundary.
    Prefix,
    /// The path must equal the pattern byte-for-byte (a query string prevents a match).
    Exact,
}
25
26struct EndpointLimit {
28 path_prefix: &'static str,
29 method: Option<Method>,
30 match_mode: MatchMode,
31 burst: DirectLimiter,
32 sustained: Option<DirectLimiter>,
33}
34
35#[derive(Clone)]
40pub struct RateLimiter {
41 inner: Arc<RateLimiterInner>,
42}
43
44impl std::fmt::Debug for RateLimiter {
45 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46 f.debug_struct("RateLimiter")
47 .field("endpoints", &self.inner.limits.len())
48 .finish()
49 }
50}
51
52struct RateLimiterInner {
53 limits: Vec<EndpointLimit>,
54 default: DirectLimiter,
55}
56
57fn quota(count: u32, period: Duration) -> Quota {
62 let count = count.max(1);
63 let interval = period / count;
64 Quota::with_period(interval)
65 .expect("quota interval must be non-zero")
66 .allow_burst(NonZeroU32::new(count).unwrap())
67}
68
69impl RateLimiter {
70 pub async fn acquire(&self, path: &str, method: Option<&Method>) {
75 self.inner.default.until_ready().await;
76
77 for limit in &self.inner.limits {
78 let matched = match limit.match_mode {
79 MatchMode::Exact => path == limit.path_prefix,
80 MatchMode::Prefix => {
81 match path.strip_prefix(limit.path_prefix) {
84 Some(rest) => {
85 rest.is_empty() || rest.starts_with('/') || rest.starts_with('?')
86 }
87 None => false,
88 }
89 }
90 };
91 if !matched {
92 continue;
93 }
94 if let Some(ref m) = limit.method {
95 if method != Some(m) {
96 continue;
97 }
98 }
99 limit.burst.until_ready().await;
100 if let Some(ref sustained) = limit.sustained {
101 sustained.until_ready().await;
102 }
103 break;
104 }
105 }
106
107 pub fn clob_default() -> Self {
116 let ten_sec = Duration::from_secs(10);
117 let ten_min = Duration::from_secs(600);
118
119 Self {
120 inner: Arc::new(RateLimiterInner {
121 default: DirectLimiter::direct(quota(9_000, ten_sec)),
122 limits: vec![
123 EndpointLimit {
125 path_prefix: "/order",
126 method: Some(Method::POST),
127 match_mode: MatchMode::Prefix,
128 burst: DirectLimiter::direct(quota(3_500, ten_sec)),
129 sustained: Some(DirectLimiter::direct(quota(36_000, ten_min))),
130 },
131 EndpointLimit {
133 path_prefix: "/order",
134 method: Some(Method::DELETE),
135 match_mode: MatchMode::Prefix,
136 burst: DirectLimiter::direct(quota(3_000, ten_sec)),
137 sustained: None,
138 },
139 EndpointLimit {
141 path_prefix: "/auth",
142 method: None,
143 match_mode: MatchMode::Prefix,
144 burst: DirectLimiter::direct(quota(100, ten_sec)),
145 sustained: None,
146 },
147 EndpointLimit {
149 path_prefix: "/trades",
150 method: None,
151 match_mode: MatchMode::Prefix,
152 burst: DirectLimiter::direct(quota(900, ten_sec)),
153 sustained: None,
154 },
155 EndpointLimit {
156 path_prefix: "/data/",
157 method: None,
158 match_mode: MatchMode::Prefix,
159 burst: DirectLimiter::direct(quota(900, ten_sec)),
160 sustained: None,
161 },
162 EndpointLimit {
165 path_prefix: "/prices-history",
166 method: None,
167 match_mode: MatchMode::Prefix,
168 burst: DirectLimiter::direct(quota(1_500, ten_sec)),
169 sustained: None,
170 },
171 EndpointLimit {
172 path_prefix: "/markets",
173 method: None,
174 match_mode: MatchMode::Prefix,
175 burst: DirectLimiter::direct(quota(1_500, ten_sec)),
176 sustained: None,
177 },
178 EndpointLimit {
179 path_prefix: "/book",
180 method: None,
181 match_mode: MatchMode::Prefix,
182 burst: DirectLimiter::direct(quota(1_500, ten_sec)),
183 sustained: None,
184 },
185 EndpointLimit {
186 path_prefix: "/price",
187 method: None,
188 match_mode: MatchMode::Prefix,
189 burst: DirectLimiter::direct(quota(1_500, ten_sec)),
190 sustained: None,
191 },
192 EndpointLimit {
193 path_prefix: "/midpoint",
194 method: None,
195 match_mode: MatchMode::Prefix,
196 burst: DirectLimiter::direct(quota(1_500, ten_sec)),
197 sustained: None,
198 },
199 EndpointLimit {
200 path_prefix: "/neg-risk",
201 method: None,
202 match_mode: MatchMode::Prefix,
203 burst: DirectLimiter::direct(quota(1_500, ten_sec)),
204 sustained: None,
205 },
206 EndpointLimit {
207 path_prefix: "/tick-size",
208 method: None,
209 match_mode: MatchMode::Prefix,
210 burst: DirectLimiter::direct(quota(1_500, ten_sec)),
211 sustained: None,
212 },
213 ],
214 }),
215 }
216 }
217
218 pub fn gamma_default() -> Self {
227 let ten_sec = Duration::from_secs(10);
228
229 Self {
230 inner: Arc::new(RateLimiterInner {
231 default: DirectLimiter::direct(quota(4_000, ten_sec)),
232 limits: vec![
233 EndpointLimit {
234 path_prefix: "/comments",
235 method: None,
236 match_mode: MatchMode::Prefix,
237 burst: DirectLimiter::direct(quota(200, ten_sec)),
238 sustained: None,
239 },
240 EndpointLimit {
241 path_prefix: "/tags",
242 method: None,
243 match_mode: MatchMode::Prefix,
244 burst: DirectLimiter::direct(quota(200, ten_sec)),
245 sustained: None,
246 },
247 EndpointLimit {
248 path_prefix: "/markets",
249 method: None,
250 match_mode: MatchMode::Prefix,
251 burst: DirectLimiter::direct(quota(300, ten_sec)),
252 sustained: None,
253 },
254 EndpointLimit {
255 path_prefix: "/public-search",
256 method: None,
257 match_mode: MatchMode::Prefix,
258 burst: DirectLimiter::direct(quota(350, ten_sec)),
259 sustained: None,
260 },
261 EndpointLimit {
262 path_prefix: "/events",
263 method: None,
264 match_mode: MatchMode::Prefix,
265 burst: DirectLimiter::direct(quota(500, ten_sec)),
266 sustained: None,
267 },
268 ],
269 }),
270 }
271 }
272
273 pub fn data_default() -> Self {
279 let ten_sec = Duration::from_secs(10);
280
281 Self {
282 inner: Arc::new(RateLimiterInner {
283 default: DirectLimiter::direct(quota(1_000, ten_sec)),
284 limits: vec![
285 EndpointLimit {
286 path_prefix: "/closed-positions",
287 method: None,
288 match_mode: MatchMode::Prefix,
289 burst: DirectLimiter::direct(quota(150, ten_sec)),
290 sustained: None,
291 },
292 EndpointLimit {
293 path_prefix: "/positions",
294 method: None,
295 match_mode: MatchMode::Prefix,
296 burst: DirectLimiter::direct(quota(150, ten_sec)),
297 sustained: None,
298 },
299 EndpointLimit {
300 path_prefix: "/trades",
301 method: None,
302 match_mode: MatchMode::Prefix,
303 burst: DirectLimiter::direct(quota(200, ten_sec)),
304 sustained: None,
305 },
306 ],
307 }),
308 }
309 }
310
311 pub fn relay_default() -> Self {
315 Self {
316 inner: Arc::new(RateLimiterInner {
317 default: DirectLimiter::direct(quota(25, Duration::from_secs(60))),
318 limits: vec![],
319 }),
320 }
321 }
322}
323
/// Retry policy: how many retries are allowed and how the delay between them grows.
#[derive(Clone, Debug)]
pub struct RetryConfig {
    /// Maximum number of retries permitted.
    pub max_retries: u32,
    /// Delay in milliseconds for attempt 0, before doubling and jitter are applied.
    pub initial_backoff_ms: u64,
    /// Cap in milliseconds on the pre-jitter delay. Jitter can push the actual
    /// delay up to 25% above this value.
    pub max_backoff_ms: u64,
}
331
332impl Default for RetryConfig {
333 fn default() -> Self {
334 Self {
335 max_retries: 3,
336 initial_backoff_ms: 500,
337 max_backoff_ms: 10_000,
338 }
339 }
340}
341
342impl RetryConfig {
343 pub fn backoff(&self, attempt: u32) -> Duration {
348 let base = self
349 .initial_backoff_ms
350 .saturating_mul(1u64 << attempt.min(10));
351 let capped = base.min(self.max_backoff_ms);
352 let jitter_factor = 0.75 + (fastrand::f64() * 0.5);
354 let ms = (capped as f64 * jitter_factor) as u64;
355 Duration::from_millis(ms.max(1))
356 }
357}
358
#[cfg(test)]
mod tests {
    // Unit tests for retry backoff, limiter presets and endpoint matching.
    use super::*;

    #[test]
    fn test_retry_config_default() {
        let cfg = RetryConfig::default();
        assert_eq!(cfg.max_retries, 3);
        assert_eq!(cfg.initial_backoff_ms, 500);
        assert_eq!(cfg.max_backoff_ms, 10_000);
    }

    #[test]
    fn test_backoff_attempt_zero() {
        let cfg = RetryConfig::default();
        let d = cfg.backoff(0);
        let ms = d.as_millis() as u64;
        // 500 ms base with ±25% jitter.
        assert!(
            (375..=625).contains(&ms),
            "attempt 0: {ms}ms not in [375, 625]"
        );
    }

    #[test]
    fn test_backoff_exponential_growth() {
        let cfg = RetryConfig::default();
        // The jittered ranges for attempts 0, 1, 2 do not overlap:
        // [375, 625), [750, 1250), [1500, 2500). Strict ordering therefore always holds.
        let d0 = cfg.backoff(0);
        let d1 = cfg.backoff(1);
        let d2 = cfg.backoff(2);
        assert!(d0 < d1, "d0={d0:?} should be < d1={d1:?}");
        assert!(d1 < d2, "d1={d1:?} should be < d2={d2:?}");
    }

    #[test]
    fn test_backoff_jitter_bounds() {
        let cfg = RetryConfig::default();
        for attempt in 0..20 {
            let d = cfg.backoff(attempt);
            let base = cfg
                .initial_backoff_ms
                .saturating_mul(1u64 << attempt.min(10));
            let capped = base.min(cfg.max_backoff_ms);
            let lower = (capped as f64 * 0.75) as u64;
            let upper = (capped as f64 * 1.25) as u64;
            let ms = d.as_millis() as u64;
            assert!(
                ms >= lower.max(1) && ms <= upper,
                "attempt {attempt}: {ms}ms not in [{lower}, {upper}]"
            );
        }
    }

    #[test]
    fn test_backoff_max_capping() {
        let cfg = RetryConfig::default();
        for attempt in 5..=10 {
            let d = cfg.backoff(attempt);
            let ceiling = (cfg.max_backoff_ms as f64 * 1.25) as u64;
            assert!(
                d.as_millis() as u64 <= ceiling,
                "attempt {attempt}: {:?} exceeded ceiling {ceiling}ms",
                d
            );
        }
    }

    #[test]
    fn test_backoff_very_high_attempt() {
        let cfg = RetryConfig::default();
        let d = cfg.backoff(100);
        let ceiling = (cfg.max_backoff_ms as f64 * 1.25) as u64;
        assert!(d.as_millis() as u64 <= ceiling);
        assert!(d.as_millis() >= 1);
    }

    #[test]
    fn test_backoff_jitter_distribution() {
        let cfg = RetryConfig::default();
        // With jitter in [0.75, 1.25), the unjittered base is the midpoint of the range.
        let midpoint = cfg.initial_backoff_ms;
        let (mut below, mut above) = (0u32, 0u32);
        for _ in 0..200 {
            let ms = cfg.backoff(0).as_millis() as u64;
            if ms < midpoint {
                below += 1;
            } else {
                above += 1;
            }
        }
        assert!(
            below >= 20 && above >= 20,
            "jitter looks degenerate: {below} below midpoint, {above} above"
        );
    }

    #[test]
    fn test_quota_creation() {
        let _ = quota(100, Duration::from_secs(10));
        let _ = quota(1, Duration::from_secs(60));
        let _ = quota(9_000, Duration::from_secs(10));
    }

    #[test]
    fn test_quota_edge_zero_count() {
        // A zero count is clamped to 1 rather than panicking.
        let _ = quota(0, Duration::from_secs(10));
    }

    #[test]
    fn test_clob_default_construction() {
        let rl = RateLimiter::clob_default();
        assert_eq!(rl.inner.limits.len(), 12);
        assert!(format!("{:?}", rl).contains("endpoints"));
    }

    #[test]
    fn test_gamma_default_construction() {
        let rl = RateLimiter::gamma_default();
        assert_eq!(rl.inner.limits.len(), 5);
    }

    #[test]
    fn test_data_default_construction() {
        let rl = RateLimiter::data_default();
        assert_eq!(rl.inner.limits.len(), 3);
    }

    #[test]
    fn test_relay_default_construction() {
        let rl = RateLimiter::relay_default();
        assert_eq!(rl.inner.limits.len(), 0);
    }

    #[test]
    fn test_rate_limiter_debug_format() {
        let rl = RateLimiter::clob_default();
        let dbg = format!("{:?}", rl);
        assert!(dbg.contains("RateLimiter"), "missing struct name: {dbg}");
        assert!(dbg.contains("endpoints: 12"), "missing count: {dbg}");
    }

    #[test]
    fn test_clob_endpoint_order_and_methods() {
        let rl = RateLimiter::clob_default();
        let limits = &rl.inner.limits;

        assert_eq!(limits[0].path_prefix, "/order");
        assert_eq!(limits[0].method, Some(Method::POST));
        assert!(limits[0].sustained.is_some());

        assert_eq!(limits[1].path_prefix, "/order");
        assert_eq!(limits[1].method, Some(Method::DELETE));
        assert!(limits[1].sustained.is_none());

        assert_eq!(limits[2].path_prefix, "/auth");
        assert!(limits[2].method.is_none());
    }

    #[tokio::test]
    async fn test_acquire_single_completes_immediately() {
        let rl = RateLimiter::clob_default();
        let start = std::time::Instant::now();
        rl.acquire("/order", Some(&Method::POST)).await;
        assert!(start.elapsed() < Duration::from_millis(50));
    }

    #[tokio::test]
    async fn test_acquire_matches_endpoint_by_prefix() {
        let rl = RateLimiter::clob_default();
        let start = std::time::Instant::now();
        rl.acquire("/order/123", Some(&Method::POST)).await;
        assert!(start.elapsed() < Duration::from_millis(50));
    }

    /// Checks that the more specific "/prices-history" is listed before "/price".
    ///
    /// Segment-boundary matching already keeps the two apart, so this is a
    /// defensive ordering check on the table. It only inspects the table, so it
    /// needs no async runtime.
    #[test]
    fn test_clob_prices_history_listed_before_price() {
        let rl = RateLimiter::clob_default();
        let limits = &rl.inner.limits;

        let price_idx = limits
            .iter()
            .position(|l| l.path_prefix == "/price")
            .expect("/price endpoint exists");

        let prices_history_idx = limits
            .iter()
            .position(|l| l.path_prefix == "/prices-history")
            .expect("/prices-history endpoint exists");

        assert!(
            prices_history_idx < price_idx,
            "/prices-history (idx {prices_history_idx}) should come before /price (idx {price_idx})"
        );
    }

    #[test]
    fn test_match_mode_prefix_segment_boundary() {
        let pattern = "/price";

        let check = |path: &str| -> bool {
            match path.strip_prefix(pattern) {
                Some(rest) => rest.is_empty() || rest.starts_with('/') || rest.starts_with('?'),
                None => false,
            }
        };

        assert!(check("/price"), "exact match");
        assert!(check("/price/foo"), "sub-path");
        assert!(check("/price?token=abc"), "query params");

        assert!(!check("/prices-history"), "partial word /prices-history");
        assert!(!check("/pricelist"), "partial word /pricelist");
        assert!(!check("/pricing"), "partial word /pricing");

        assert!(!check("/midpoint"), "different prefix");
    }

    #[test]
    fn test_match_mode_exact() {
        let pattern = "/trades";

        let check = |path: &str| -> bool { path == pattern };

        assert!(check("/trades"), "exact match");
        assert!(!check("/trades/123"), "sub-path should not match");
        assert!(!check("/trades?limit=10"), "query params should not match");
        assert!(!check("/traded"), "different word should not match");
    }

    #[tokio::test]
    async fn test_acquire_method_filtering() {
        let rl = RateLimiter::clob_default();
        let start = std::time::Instant::now();
        // GET /order matches neither the POST nor the DELETE entry, so only the default applies.
        rl.acquire("/order", Some(&Method::GET)).await;
        assert!(start.elapsed() < Duration::from_millis(50));
    }

    #[tokio::test]
    async fn test_acquire_no_endpoint_match_uses_default_only() {
        let rl = RateLimiter::clob_default();
        let start = std::time::Instant::now();
        rl.acquire("/unknown/path", None).await;
        assert!(start.elapsed() < Duration::from_millis(50));
    }

    #[tokio::test]
    async fn test_acquire_method_none_matches_any_method() {
        let rl = RateLimiter::gamma_default();
        let start = std::time::Instant::now();
        rl.acquire("/events", Some(&Method::GET)).await;
        rl.acquire("/events", Some(&Method::POST)).await;
        rl.acquire("/events", None).await;
        assert!(start.elapsed() < Duration::from_millis(50));
    }

    #[test]
    fn test_clob_price_and_prices_history_are_distinct() {
        let rl = RateLimiter::clob_default();
        let limits = &rl.inner.limits;

        let price = limits.iter().find(|l| l.path_prefix == "/price").unwrap();
        let prices_history = limits
            .iter()
            .find(|l| l.path_prefix == "/prices-history")
            .unwrap();

        assert_eq!(price.match_mode, MatchMode::Prefix);
        assert_eq!(prices_history.match_mode, MatchMode::Prefix);

        if let Some(rest) = "/prices-history".strip_prefix(price.path_prefix) {
            assert!(
                !rest.is_empty() && !rest.starts_with('/') && !rest.starts_with('?'),
                "/prices-history must not match /price pattern, rest = '{rest}'"
            );
        }
    }

    #[test]
    fn test_data_positions_and_closed_positions_are_distinct() {
        let rl = RateLimiter::data_default();
        let limits = &rl.inner.limits;

        let positions = limits
            .iter()
            .find(|l| l.path_prefix == "/positions")
            .unwrap();
        let closed = limits
            .iter()
            .find(|l| l.path_prefix == "/closed-positions")
            .unwrap();

        assert_eq!(positions.match_mode, MatchMode::Prefix);
        assert_eq!(closed.match_mode, MatchMode::Prefix);

        assert!(
            !"/closed-positions".starts_with(positions.path_prefix),
            "/closed-positions should not match /positions prefix"
        );
    }

    #[test]
    fn test_all_clob_endpoints_have_match_mode() {
        let rl = RateLimiter::clob_default();
        for limit in &rl.inner.limits {
            assert!(
                limit.match_mode == MatchMode::Prefix || limit.match_mode == MatchMode::Exact,
                "endpoint {} has no valid match mode",
                limit.path_prefix
            );
        }
    }

    #[tokio::test]
    async fn test_acquire_concurrent_tasks_all_complete() {
        let rl = RateLimiter::clob_default();
        // Start timing before spawning, so the measurement covers the acquires
        // themselves rather than only the final joins.
        let start = std::time::Instant::now();

        let handles: Vec<_> = (0..10)
            .map(|_| {
                // `RateLimiter` clones share state through their internal `Arc`.
                let rl = rl.clone();
                tokio::spawn(async move {
                    rl.acquire("/markets", None).await;
                })
            })
            .collect();

        for handle in handles {
            handle.await.unwrap();
        }
        assert!(
            start.elapsed() < Duration::from_millis(100),
            "concurrent acquires took too long: {:?}",
            start.elapsed()
        );
    }

    #[tokio::test]
    async fn test_acquire_concurrent_different_endpoints() {
        let rl = RateLimiter::clob_default();
        let (rl1, rl2, rl3) = (rl.clone(), rl.clone(), rl.clone());

        let start = std::time::Instant::now();
        let (r1, r2, r3) = tokio::join!(
            tokio::spawn(async move { rl1.acquire("/markets", None).await }),
            tokio::spawn(async move { rl2.acquire("/auth", None).await }),
            tokio::spawn(async move { rl3.acquire("/order", Some(&Method::POST)).await }),
        );
        r1.unwrap();
        r2.unwrap();
        r3.unwrap();

        assert!(
            start.elapsed() < Duration::from_millis(50),
            "different endpoints should not block: {:?}",
            start.elapsed()
        );
    }

    #[test]
    fn test_clob_post_order_has_dual_window() {
        let rl = RateLimiter::clob_default();
        let post_order = rl
            .inner
            .limits
            .iter()
            .find(|l| l.path_prefix == "/order" && l.method == Some(Method::POST))
            .expect("POST /order endpoint should exist");

        assert!(
            post_order.sustained.is_some(),
            "POST /order should have a sustained (10-min) window"
        );
    }

    #[test]
    fn test_clob_delete_order_has_no_sustained_window() {
        let rl = RateLimiter::clob_default();
        let delete_order = rl
            .inner
            .limits
            .iter()
            .find(|l| l.path_prefix == "/order" && l.method == Some(Method::DELETE))
            .expect("DELETE /order endpoint should exist");

        assert!(
            delete_order.sustained.is_none(),
            "DELETE /order should only have a burst window"
        );
    }

    #[tokio::test]
    async fn test_dual_window_both_burst_and_sustained_are_awaited() {
        let rl = RateLimiter::clob_default();
        let start = std::time::Instant::now();
        rl.acquire("/order", Some(&Method::POST)).await;
        assert!(
            start.elapsed() < Duration::from_millis(50),
            "dual window single acquire should be fast: {:?}",
            start.elapsed()
        );
    }

    #[test]
    fn test_should_retry_exhaustion() {
        let client = crate::HttpClientBuilder::new("https://example.com")
            .with_retry_config(RetryConfig {
                max_retries: 3,
                ..RetryConfig::default()
            })
            .build()
            .unwrap();

        for attempt in 0..3 {
            assert!(
                client
                    .should_retry(reqwest::StatusCode::TOO_MANY_REQUESTS, attempt, None)
                    .is_some(),
                "attempt {attempt} should allow retry"
            );
        }
        assert!(
            client
                .should_retry(reqwest::StatusCode::TOO_MANY_REQUESTS, 3, None)
                .is_none(),
            "attempt 3 should exhaust retries"
        );
    }

    #[test]
    fn test_should_retry_zero_max_retries_never_retries() {
        let client = crate::HttpClientBuilder::new("https://example.com")
            .with_retry_config(RetryConfig {
                max_retries: 0,
                ..RetryConfig::default()
            })
            .build()
            .unwrap();

        assert!(
            client
                .should_retry(reqwest::StatusCode::TOO_MANY_REQUESTS, 0, None)
                .is_none(),
            "max_retries=0 should never retry"
        );
    }
}