1use std::num::NonZeroU32;
2use std::sync::Arc;
3use std::time::Duration;
4
5use governor::Quota;
6use reqwest::Method;
7
8type DirectLimiter = governor::RateLimiter<
9 governor::state::NotKeyed,
10 governor::state::InMemoryState,
11 governor::clock::DefaultClock,
12>;
13
/// How an endpoint limit's `path_prefix` is compared with a request path.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
// None of the built-in presets construct `Exact`; keep the variant available
// without tripping the dead-code lint.
#[allow(dead_code)]
enum MatchMode {
    /// Match when the path starts with the pattern at a path-segment boundary
    /// (e.g. `/price` matches `/price`, `/price/x` and `/price?y`, but not
    /// `/prices-history`).
    Prefix,
    /// Match only when the path is exactly equal to the pattern.
    Exact,
}
25
26struct EndpointLimit {
28 path_prefix: &'static str,
29 method: Option<Method>,
30 match_mode: MatchMode,
31 burst: DirectLimiter,
32 sustained: Option<DirectLimiter>,
33}
34
35#[derive(Clone)]
40pub struct RateLimiter {
41 inner: Arc<RateLimiterInner>,
42}
43
44impl std::fmt::Debug for RateLimiter {
45 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46 f.debug_struct("RateLimiter")
47 .field("endpoints", &self.inner.limits.len())
48 .finish()
49 }
50}
51
52struct RateLimiterInner {
53 limits: Vec<EndpointLimit>,
54 default: DirectLimiter,
55}
56
57fn quota(count: u32, period: Duration) -> Quota {
62 let count = count.max(1);
63 let interval = period / count;
64 Quota::with_period(interval)
65 .expect("quota interval must be non-zero")
66 .allow_burst(NonZeroU32::new(count).unwrap())
67}
68
69fn endpoint_limit(
71 path_prefix: &'static str,
72 method: Option<Method>,
73 match_mode: MatchMode,
74 burst_count: u32,
75 burst_period: Duration,
76 sustained: Option<(u32, Duration)>,
77) -> EndpointLimit {
78 EndpointLimit {
79 path_prefix,
80 method,
81 match_mode,
82 burst: DirectLimiter::direct(quota(burst_count, burst_period)),
83 sustained: sustained.map(|(count, period)| DirectLimiter::direct(quota(count, period))),
84 }
85}
86
87impl RateLimiter {
88 pub async fn acquire(&self, path: &str, method: Option<&Method>) {
93 self.inner.default.until_ready().await;
94
95 for limit in &self.inner.limits {
96 let matched = match limit.match_mode {
97 MatchMode::Exact => path == limit.path_prefix,
98 MatchMode::Prefix => {
99 match path.strip_prefix(limit.path_prefix) {
102 Some(rest) => {
103 rest.is_empty() || rest.starts_with('/') || rest.starts_with('?')
104 }
105 None => false,
106 }
107 }
108 };
109 if !matched {
110 continue;
111 }
112 if let Some(ref m) = limit.method {
113 if method != Some(m) {
114 continue;
115 }
116 }
117 limit.burst.until_ready().await;
118 if let Some(ref sustained) = limit.sustained {
119 sustained.until_ready().await;
120 }
121 break;
122 }
123 }
124
125 pub fn clob_default() -> Self {
134 let ten_sec = Duration::from_secs(10);
135 let ten_min = Duration::from_secs(600);
136 let p = MatchMode::Prefix;
137
138 Self {
139 inner: Arc::new(RateLimiterInner {
140 default: DirectLimiter::direct(quota(9_000, ten_sec)),
141 limits: vec![
142 endpoint_limit(
144 "/order",
145 Some(Method::POST),
146 p,
147 3_500,
148 ten_sec,
149 Some((36_000, ten_min)),
150 ),
151 endpoint_limit("/order", Some(Method::DELETE), p, 3_000, ten_sec, None),
153 endpoint_limit("/auth", None, p, 100, ten_sec, None),
155 endpoint_limit("/trades", None, p, 900, ten_sec, None),
157 endpoint_limit("/data/", None, p, 900, ten_sec, None),
158 endpoint_limit("/prices-history", None, p, 1_500, ten_sec, None),
160 endpoint_limit("/markets", None, p, 1_500, ten_sec, None),
161 endpoint_limit("/book", None, p, 1_500, ten_sec, None),
162 endpoint_limit("/price", None, p, 1_500, ten_sec, None),
163 endpoint_limit("/midpoint", None, p, 1_500, ten_sec, None),
164 endpoint_limit("/neg-risk", None, p, 1_500, ten_sec, None),
165 endpoint_limit("/tick-size", None, p, 1_500, ten_sec, None),
166 ],
167 }),
168 }
169 }
170
171 pub fn gamma_default() -> Self {
180 let ten_sec = Duration::from_secs(10);
181 let p = MatchMode::Prefix;
182
183 Self {
184 inner: Arc::new(RateLimiterInner {
185 default: DirectLimiter::direct(quota(4_000, ten_sec)),
186 limits: vec![
187 endpoint_limit("/comments", None, p, 200, ten_sec, None),
188 endpoint_limit("/tags", None, p, 200, ten_sec, None),
189 endpoint_limit("/markets", None, p, 300, ten_sec, None),
190 endpoint_limit("/public-search", None, p, 350, ten_sec, None),
191 endpoint_limit("/events", None, p, 500, ten_sec, None),
192 ],
193 }),
194 }
195 }
196
197 pub fn data_default() -> Self {
203 let ten_sec = Duration::from_secs(10);
204 let p = MatchMode::Prefix;
205
206 Self {
207 inner: Arc::new(RateLimiterInner {
208 default: DirectLimiter::direct(quota(1_000, ten_sec)),
209 limits: vec![
210 endpoint_limit("/closed-positions", None, p, 150, ten_sec, None),
211 endpoint_limit("/positions", None, p, 150, ten_sec, None),
212 endpoint_limit("/trades", None, p, 200, ten_sec, None),
213 ],
214 }),
215 }
216 }
217
218 pub fn relay_default() -> Self {
222 Self {
223 inner: Arc::new(RateLimiterInner {
224 default: DirectLimiter::direct(quota(25, Duration::from_secs(60))),
225 limits: vec![],
226 }),
227 }
228 }
229}
230
/// Retry policy: how many retries to allow and how long to back off between them.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Upper bound on the number of retries.
    pub max_retries: u32,
    /// Base delay in milliseconds before the first retry.
    pub initial_backoff_ms: u64,
    /// Ceiling in milliseconds applied to the exponential delay before jitter.
    pub max_backoff_ms: u64,
}

impl Default for RetryConfig {
    /// Three retries, starting at 500 ms and capped at 10 s before jitter.
    fn default() -> Self {
        let max_retries = 3;
        let initial_backoff_ms = 500;
        let max_backoff_ms = 10_000;
        Self {
            max_retries,
            initial_backoff_ms,
            max_backoff_ms,
        }
    }
}
248
249impl RetryConfig {
250 pub fn backoff(&self, attempt: u32) -> Duration {
255 let base = self
256 .initial_backoff_ms
257 .saturating_mul(1u64 << attempt.min(10));
258 let capped = base.min(self.max_backoff_ms);
259 let jitter_factor = 0.75 + (fastrand::f64() * 0.5);
261 let ms = (capped as f64 * jitter_factor) as u64;
262 Duration::from_millis(ms.max(1))
263 }
264}
265
#[cfg(test)]
mod tests {
    use super::*;

    // ----- RetryConfig / backoff -----

    #[test]
    fn test_retry_config_default() {
        let cfg = RetryConfig::default();
        assert_eq!(cfg.max_retries, 3);
        assert_eq!(cfg.initial_backoff_ms, 500);
        assert_eq!(cfg.max_backoff_ms, 10_000);
    }

    #[test]
    fn test_backoff_attempt_zero() {
        let cfg = RetryConfig::default();
        let d = cfg.backoff(0);
        let ms = d.as_millis() as u64;
        // 500 ms base with +/-25% jitter.
        assert!(
            (375..=625).contains(&ms),
            "attempt 0: {ms}ms not in [375, 625]"
        );
    }

    #[test]
    fn test_backoff_exponential_growth() {
        let cfg = RetryConfig::default();
        let d0 = cfg.backoff(0);
        let d1 = cfg.backoff(1);
        let d2 = cfg.backoff(2);
        // Jitter ranges [375,625], [750,1250] and [1500,2500] do not overlap.
        assert!(d0 < d1, "d0={d0:?} should be < d1={d1:?}");
        assert!(d1 < d2, "d1={d1:?} should be < d2={d2:?}");
    }

    #[test]
    fn test_backoff_jitter_bounds() {
        let cfg = RetryConfig::default();
        for attempt in 0..20 {
            let d = cfg.backoff(attempt);
            let base = cfg
                .initial_backoff_ms
                .saturating_mul(1u64 << attempt.min(10));
            let capped = base.min(cfg.max_backoff_ms);
            let lower = (capped as f64 * 0.75) as u64;
            let upper = (capped as f64 * 1.25) as u64;
            let ms = d.as_millis() as u64;
            assert!(
                ms >= lower.max(1) && ms <= upper,
                "attempt {attempt}: {ms}ms not in [{lower}, {upper}]"
            );
        }
    }

    #[test]
    fn test_backoff_max_capping() {
        let cfg = RetryConfig::default();
        for attempt in 5..=10 {
            let d = cfg.backoff(attempt);
            let ceiling = (cfg.max_backoff_ms as f64 * 1.25) as u64;
            assert!(
                d.as_millis() as u64 <= ceiling,
                "attempt {attempt}: {:?} exceeded ceiling {ceiling}ms",
                d
            );
        }
    }

    #[test]
    fn test_backoff_very_high_attempt() {
        let cfg = RetryConfig::default();
        let d = cfg.backoff(100);
        let ceiling = (cfg.max_backoff_ms as f64 * 1.25) as u64;
        assert!(d.as_millis() as u64 <= ceiling);
        assert!(d.as_millis() >= 1);
    }

    #[test]
    fn test_backoff_jitter_distribution() {
        let cfg = RetryConfig::default();
        // With symmetric jitter around the base, samples should land on both
        // sides of it.
        let midpoint = cfg.initial_backoff_ms;
        let (mut below, mut above) = (0u32, 0u32);
        for _ in 0..200 {
            let ms = cfg.backoff(0).as_millis() as u64;
            if ms < midpoint {
                below += 1;
            } else {
                above += 1;
            }
        }
        assert!(
            below >= 20 && above >= 20,
            "jitter looks degenerate: {below} below midpoint, {above} above"
        );
    }

    // ----- quota construction -----

    #[test]
    fn test_quota_creation() {
        let _ = quota(100, Duration::from_secs(10));
        let _ = quota(1, Duration::from_secs(60));
        let _ = quota(9_000, Duration::from_secs(10));
    }

    #[test]
    fn test_quota_edge_zero_count() {
        // A zero count is clamped to one instead of panicking.
        let _ = quota(0, Duration::from_secs(10));
    }

    // ----- preset construction -----

    #[test]
    fn test_clob_default_construction() {
        let rl = RateLimiter::clob_default();
        assert_eq!(rl.inner.limits.len(), 12);
        assert!(format!("{:?}", rl).contains("endpoints"));
    }

    #[test]
    fn test_gamma_default_construction() {
        let rl = RateLimiter::gamma_default();
        assert_eq!(rl.inner.limits.len(), 5);
    }

    #[test]
    fn test_data_default_construction() {
        let rl = RateLimiter::data_default();
        assert_eq!(rl.inner.limits.len(), 3);
    }

    #[test]
    fn test_relay_default_construction() {
        let rl = RateLimiter::relay_default();
        assert_eq!(rl.inner.limits.len(), 0);
    }

    #[test]
    fn test_rate_limiter_debug_format() {
        let rl = RateLimiter::clob_default();
        let dbg = format!("{:?}", rl);
        assert!(dbg.contains("RateLimiter"), "missing struct name: {dbg}");
        assert!(dbg.contains("endpoints: 12"), "missing count: {dbg}");
    }

    #[test]
    fn test_clob_endpoint_order_and_methods() {
        let rl = RateLimiter::clob_default();
        let limits = &rl.inner.limits;

        assert_eq!(limits[0].path_prefix, "/order");
        assert_eq!(limits[0].method, Some(Method::POST));
        assert!(limits[0].sustained.is_some());

        assert_eq!(limits[1].path_prefix, "/order");
        assert_eq!(limits[1].method, Some(Method::DELETE));
        assert!(limits[1].sustained.is_none());

        assert_eq!(limits[2].path_prefix, "/auth");
        assert!(limits[2].method.is_none());
    }

    // ----- acquire -----

    #[tokio::test]
    async fn test_acquire_single_completes_immediately() {
        let rl = RateLimiter::clob_default();
        let start = std::time::Instant::now();
        rl.acquire("/order", Some(&Method::POST)).await;
        assert!(start.elapsed() < Duration::from_millis(50));
    }

    #[tokio::test]
    async fn test_acquire_matches_endpoint_by_prefix() {
        let rl = RateLimiter::clob_default();
        let start = std::time::Instant::now();
        rl.acquire("/order/123", Some(&Method::POST)).await;
        assert!(start.elapsed() < Duration::from_millis(50));
    }

    /// Guards the declaration order. `acquire` applies the first matching
    /// limit, so the more specific `/prices-history` stays ahead of `/price`.
    #[test]
    fn test_acquire_prefix_respects_segment_boundary() {
        let rl = RateLimiter::clob_default();
        let limits = &rl.inner.limits;

        let price_idx = limits
            .iter()
            .position(|l| l.path_prefix == "/price")
            .expect("/price endpoint exists");

        let prices_history_idx = limits
            .iter()
            .position(|l| l.path_prefix == "/prices-history")
            .expect("/prices-history endpoint exists");

        assert!(
            prices_history_idx < price_idx,
            "/prices-history (idx {prices_history_idx}) should come before /price (idx {price_idx})"
        );
    }

    /// NOTE: this re-implements the segment-boundary rule locally instead of
    /// calling into `RateLimiter`, so it can drift from the real matcher.
    #[test]
    fn test_match_mode_prefix_segment_boundary() {
        let pattern = "/price";

        let check = |path: &str| -> bool {
            match path.strip_prefix(pattern) {
                Some(rest) => rest.is_empty() || rest.starts_with('/') || rest.starts_with('?'),
                None => false,
            }
        };

        assert!(check("/price"), "exact match");
        assert!(check("/price/foo"), "sub-path");
        assert!(check("/price?token=abc"), "query params");

        assert!(!check("/prices-history"), "partial word /prices-history");
        assert!(!check("/pricelist"), "partial word /pricelist");
        assert!(!check("/pricing"), "partial word /pricing");

        assert!(!check("/midpoint"), "different prefix");
    }

    #[test]
    fn test_match_mode_exact() {
        let pattern = "/trades";

        let check = |path: &str| -> bool { path == pattern };

        assert!(check("/trades"), "exact match");
        assert!(!check("/trades/123"), "sub-path should not match");
        assert!(!check("/trades?limit=10"), "query params should not match");
        assert!(!check("/traded"), "different word should not match");
    }

    #[tokio::test]
    async fn test_acquire_method_filtering() {
        let rl = RateLimiter::clob_default();
        let start = std::time::Instant::now();
        // GET /order matches neither the POST nor the DELETE limit.
        rl.acquire("/order", Some(&Method::GET)).await;
        assert!(start.elapsed() < Duration::from_millis(50));
    }

    #[tokio::test]
    async fn test_acquire_no_endpoint_match_uses_default_only() {
        let rl = RateLimiter::clob_default();
        let start = std::time::Instant::now();
        rl.acquire("/unknown/path", None).await;
        assert!(start.elapsed() < Duration::from_millis(50));
    }

    #[tokio::test]
    async fn test_acquire_method_none_matches_any_method() {
        let rl = RateLimiter::gamma_default();
        let start = std::time::Instant::now();
        rl.acquire("/events", Some(&Method::GET)).await;
        rl.acquire("/events", Some(&Method::POST)).await;
        rl.acquire("/events", None).await;
        assert!(start.elapsed() < Duration::from_millis(50));
    }

    #[test]
    fn test_clob_price_and_prices_history_are_distinct() {
        let rl = RateLimiter::clob_default();
        let limits = &rl.inner.limits;

        let price = limits.iter().find(|l| l.path_prefix == "/price").unwrap();
        let prices_history = limits
            .iter()
            .find(|l| l.path_prefix == "/prices-history")
            .unwrap();

        assert_eq!(price.match_mode, MatchMode::Prefix);
        assert_eq!(prices_history.match_mode, MatchMode::Prefix);

        if let Some(rest) = "/prices-history".strip_prefix(price.path_prefix) {
            assert!(
                !rest.is_empty() && !rest.starts_with('/') && !rest.starts_with('?'),
                "/prices-history must not match /price pattern, rest = '{rest}'"
            );
        }
    }

    #[test]
    fn test_data_positions_and_closed_positions_are_distinct() {
        let rl = RateLimiter::data_default();
        let limits = &rl.inner.limits;

        let positions = limits
            .iter()
            .find(|l| l.path_prefix == "/positions")
            .unwrap();
        let closed = limits
            .iter()
            .find(|l| l.path_prefix == "/closed-positions")
            .unwrap();

        assert_eq!(positions.match_mode, MatchMode::Prefix);
        assert_eq!(closed.match_mode, MatchMode::Prefix);

        // Neither configured prefix may be a prefix of the other.
        assert!(
            closed.path_prefix.strip_prefix(positions.path_prefix).is_none(),
            "/closed-positions should not match /positions prefix"
        );
        assert!(
            positions.path_prefix.strip_prefix(closed.path_prefix).is_none(),
            "/positions should not match /closed-positions prefix"
        );
    }

    #[test]
    fn test_all_clob_endpoints_have_match_mode() {
        let rl = RateLimiter::clob_default();
        for limit in &rl.inner.limits {
            assert!(
                limit.match_mode == MatchMode::Prefix || limit.match_mode == MatchMode::Exact,
                "endpoint {} has no valid match mode",
                limit.path_prefix
            );
        }
    }

    #[tokio::test]
    async fn test_acquire_concurrent_tasks_all_complete() {
        // `RateLimiter` is an `Arc`-backed handle; clones share limiter state.
        let rl = RateLimiter::clob_default();

        // Start the clock before spawning: tasks may begin running immediately.
        let start = std::time::Instant::now();
        let handles: Vec<_> = (0..10)
            .map(|_| {
                let rl = rl.clone();
                tokio::spawn(async move {
                    rl.acquire("/markets", None).await;
                })
            })
            .collect();

        for handle in handles {
            handle.await.unwrap();
        }
        assert!(
            start.elapsed() < Duration::from_millis(100),
            "concurrent acquires took too long: {:?}",
            start.elapsed()
        );
    }

    #[tokio::test]
    async fn test_acquire_concurrent_different_endpoints() {
        let rl = RateLimiter::clob_default();
        let (rl1, rl2, rl3) = (rl.clone(), rl.clone(), rl.clone());

        let start = std::time::Instant::now();
        let (r1, r2, r3) = tokio::join!(
            tokio::spawn(async move { rl1.acquire("/markets", None).await }),
            tokio::spawn(async move { rl2.acquire("/auth", None).await }),
            tokio::spawn(async move { rl3.acquire("/order", Some(&Method::POST)).await }),
        );
        r1.unwrap();
        r2.unwrap();
        r3.unwrap();

        assert!(
            start.elapsed() < Duration::from_millis(50),
            "different endpoints should not block: {:?}",
            start.elapsed()
        );
    }

    #[test]
    fn test_clob_post_order_has_dual_window() {
        let rl = RateLimiter::clob_default();
        let post_order = rl
            .inner
            .limits
            .iter()
            .find(|l| l.path_prefix == "/order" && l.method == Some(Method::POST))
            .expect("POST /order endpoint should exist");

        assert!(
            post_order.sustained.is_some(),
            "POST /order should have a sustained (10-min) window"
        );
    }

    #[test]
    fn test_clob_delete_order_has_no_sustained_window() {
        let rl = RateLimiter::clob_default();
        let delete_order = rl
            .inner
            .limits
            .iter()
            .find(|l| l.path_prefix == "/order" && l.method == Some(Method::DELETE))
            .expect("DELETE /order endpoint should exist");

        assert!(
            delete_order.sustained.is_none(),
            "DELETE /order should only have a burst window"
        );
    }

    #[tokio::test]
    async fn test_dual_window_both_burst_and_sustained_are_awaited() {
        let rl = RateLimiter::clob_default();
        // Both windows start full, so awaiting both must still be immediate.
        let start = std::time::Instant::now();
        rl.acquire("/order", Some(&Method::POST)).await;
        assert!(
            start.elapsed() < Duration::from_millis(50),
            "dual window single acquire should be fast: {:?}",
            start.elapsed()
        );
    }

    // ----- retry decisions (HttpClient, defined elsewhere in the crate) -----

    #[test]
    fn test_should_retry_exhaustion() {
        let client = crate::HttpClientBuilder::new("https://example.com")
            .with_retry_config(RetryConfig {
                max_retries: 3,
                ..RetryConfig::default()
            })
            .build()
            .unwrap();

        for attempt in 0..3 {
            assert!(
                client
                    .should_retry(reqwest::StatusCode::TOO_MANY_REQUESTS, attempt, None)
                    .is_some(),
                "attempt {attempt} should allow retry"
            );
        }
        assert!(
            client
                .should_retry(reqwest::StatusCode::TOO_MANY_REQUESTS, 3, None)
                .is_none(),
            "attempt 3 should exhaust retries"
        );
    }

    #[test]
    fn test_should_retry_zero_max_retries_never_retries() {
        let client = crate::HttpClientBuilder::new("https://example.com")
            .with_retry_config(RetryConfig {
                max_retries: 0,
                ..RetryConfig::default()
            })
            .build()
            .unwrap();

        assert!(
            client
                .should_retry(reqwest::StatusCode::TOO_MANY_REQUESTS, 0, None)
                .is_none(),
            "max_retries=0 should never retry"
        );
    }
}