px_core/exchange/
rate_limit.rs1use std::sync::atomic::{AtomicU64, Ordering};
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4use tokio::sync::{OwnedSemaphorePermit, Semaphore};
5use tokio::time::sleep;
6
7pub struct RateLimiter {
8 last_request: Instant,
9 min_interval: Duration,
10}
11
12impl RateLimiter {
13 pub fn new(requests_per_second: u32) -> Self {
14 let min_interval = if requests_per_second > 0 {
15 Duration::from_secs_f64(1.0 / requests_per_second as f64)
16 } else {
17 Duration::ZERO
18 };
19
20 Self {
21 last_request: Instant::now() - min_interval,
22 min_interval,
23 }
24 }
25
26 pub async fn wait(&mut self) {
27 let elapsed = self.last_request.elapsed();
28 if elapsed < self.min_interval {
29 let wait_time = self.min_interval - elapsed;
30 sleep(wait_time).await;
31 }
32 self.last_request = Instant::now();
33 }
34}
35
36pub struct ConcurrentRateLimiter {
41 semaphore: Arc<Semaphore>,
42 next_allowed_nanos: AtomicU64,
44 epoch: Instant,
46 min_interval_nanos: u64,
47}
48
49impl ConcurrentRateLimiter {
50 pub fn new(requests_per_second: u32, max_concurrent: usize) -> Self {
56 let min_interval = if requests_per_second > 0 {
57 Duration::from_secs_f64(1.0 / requests_per_second as f64)
58 } else {
59 Duration::ZERO
60 };
61
62 let epoch = Instant::now();
63
64 Self {
65 semaphore: Arc::new(Semaphore::new(max_concurrent)),
66 next_allowed_nanos: AtomicU64::new(0),
67 epoch,
68 min_interval_nanos: min_interval.as_nanos() as u64,
69 }
70 }
71
72 pub async fn acquire(&self) -> OwnedSemaphorePermit {
76 let permit = self
80 .semaphore
81 .clone()
82 .acquire_owned()
83 .await
84 .expect("ConcurrentRateLimiter semaphore unexpectedly closed");
85
86 let wait_nanos = loop {
89 let now_nanos = self.epoch.elapsed().as_nanos() as u64;
90 let current = self.next_allowed_nanos.load(Ordering::Acquire);
91 let scheduled = if now_nanos >= current {
92 now_nanos
93 } else {
94 current
95 };
96 let next = scheduled + self.min_interval_nanos;
97 match self.next_allowed_nanos.compare_exchange_weak(
98 current,
99 next,
100 Ordering::AcqRel,
101 Ordering::Acquire,
102 ) {
103 Ok(_) => break scheduled.saturating_sub(now_nanos),
104 Err(_) => continue, }
106 };
107
108 if wait_nanos > 0 {
109 sleep(Duration::from_nanos(wait_nanos)).await;
110 }
111
112 permit
113 }
114}
115
116use crate::exchange::manifest::{RateLimitCategory, RateLimitConfig};
117
118pub struct CategoryRateLimiter {
121 limiters: [tokio::sync::Mutex<RateLimiter>; RateLimitCategory::COUNT],
122}
123
124impl CategoryRateLimiter {
125 pub fn from_config(config: &RateLimitConfig) -> Self {
127 let limiters = RateLimitCategory::ALL.map(|cat| {
128 let rps = config.rps(cat);
129 tokio::sync::Mutex::new(RateLimiter::new(rps))
130 });
131 Self { limiters }
132 }
133
134 pub async fn wait(&self, category: RateLimitCategory) {
136 self.limiters[category as usize].lock().await.wait().await;
137 }
138}
139
#[cfg(test)]
mod tests {
    use super::*;

    /// At 10 requests/second the first `wait` is free and the second must be
    /// delayed by roughly one 100 ms interval; the check uses a 90 ms lower
    /// bound rather than the full interval (presumably as timing slack).
    #[tokio::test]
    async fn test_rate_limiter_respects_interval() {
        let mut rl = RateLimiter::new(10);
        let started = Instant::now();

        for _ in 0..2 {
            rl.wait().await;
        }

        let lower_bound = Duration::from_millis(90);
        assert!(started.elapsed() >= lower_bound);
    }
}