Skip to main content

px_core/exchange/
rate_limit.rs

1use std::sync::atomic::{AtomicU64, Ordering};
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4use tokio::sync::{OwnedSemaphorePermit, Semaphore};
5use tokio::time::sleep;
6
/// Single-owner rate limiter that spaces consecutive requests at least
/// `min_interval` apart.
///
/// Callers go through [`RateLimiter::wait`], which takes `&mut self`; sharing
/// one limiter between tasks therefore requires external synchronization.
pub struct RateLimiter {
    /// Instant recorded at the end of the most recent `wait` call.
    last_request: Instant,
    /// Minimum spacing enforced between two requests; zero means "no limit".
    min_interval: Duration,
}
11
12impl RateLimiter {
13    pub fn new(requests_per_second: u32) -> Self {
14        let min_interval = if requests_per_second > 0 {
15            Duration::from_secs_f64(1.0 / requests_per_second as f64)
16        } else {
17            Duration::ZERO
18        };
19
20        Self {
21            last_request: Instant::now() - min_interval,
22            min_interval,
23        }
24    }
25
26    pub async fn wait(&mut self) {
27        let elapsed = self.last_request.elapsed();
28        if elapsed < self.min_interval {
29            let wait_time = self.min_interval - elapsed;
30            sleep(wait_time).await;
31        }
32        self.last_request = Instant::now();
33    }
34}
35
36/// A concurrent rate limiter that enforces a global rate limit across multiple
37/// concurrent streams. Uses a semaphore for concurrency control and an atomic
38/// timestamp to ensure min_interval between ANY two requests globally.
39/// Lock-free: uses AtomicU64 CAS loop instead of a mutex for the timestamp.
40pub struct ConcurrentRateLimiter {
41    semaphore: Arc<Semaphore>,
42    /// Nanoseconds since `epoch` when the next request is allowed.
43    next_allowed_nanos: AtomicU64,
44    /// Reference instant for converting between Instant and u64 nanos.
45    epoch: Instant,
46    min_interval_nanos: u64,
47}
48
49impl ConcurrentRateLimiter {
50    /// Create a new concurrent rate limiter.
51    ///
52    /// # Arguments
53    /// * `requests_per_second` - Target requests per second rate limit
54    /// * `max_concurrent` - Maximum concurrent requests allowed
55    pub fn new(requests_per_second: u32, max_concurrent: usize) -> Self {
56        let min_interval = if requests_per_second > 0 {
57            Duration::from_secs_f64(1.0 / requests_per_second as f64)
58        } else {
59            Duration::ZERO
60        };
61
62        let epoch = Instant::now();
63
64        Self {
65            semaphore: Arc::new(Semaphore::new(max_concurrent)),
66            next_allowed_nanos: AtomicU64::new(0),
67            epoch,
68            min_interval_nanos: min_interval.as_nanos() as u64,
69        }
70    }
71
72    /// Acquire a rate limit permit. Waits for both:
73    /// 1. A semaphore permit (concurrency limit)
74    /// 2. The global rate limit interval since last request
75    pub async fn acquire(&self) -> OwnedSemaphorePermit {
76        // First acquire semaphore permit for concurrency control
77        // Safety: semaphore is never closed (we hold an Arc to it).
78        // If it were closed (e.g., memory corruption), panic is appropriate.
79        let permit = self
80            .semaphore
81            .clone()
82            .acquire_owned()
83            .await
84            .expect("ConcurrentRateLimiter semaphore unexpectedly closed");
85
86        // Reserve the next globally-allowed send slot via atomic CAS loop,
87        // then sleep outside the atomic to avoid serializing concurrent waiters.
88        let wait_nanos = loop {
89            let now_nanos = self.epoch.elapsed().as_nanos() as u64;
90            let current = self.next_allowed_nanos.load(Ordering::Acquire);
91            let scheduled = if now_nanos >= current {
92                now_nanos
93            } else {
94                current
95            };
96            let next = scheduled + self.min_interval_nanos;
97            match self.next_allowed_nanos.compare_exchange_weak(
98                current,
99                next,
100                Ordering::AcqRel,
101                Ordering::Acquire,
102            ) {
103                Ok(_) => break scheduled.saturating_sub(now_nanos),
104                Err(_) => continue, // Another thread won the CAS, retry
105            }
106        };
107
108        if wait_nanos > 0 {
109            sleep(Duration::from_nanos(wait_nanos)).await;
110        }
111
112        permit
113    }
114}
115
116use crate::exchange::manifest::{RateLimitCategory, RateLimitConfig};
117
118/// Holds one `RateLimiter` per endpoint category for per-category rate limiting.
119/// Indexed by `RateLimitCategory` discriminant for O(1) lookup.
120pub struct CategoryRateLimiter {
121    limiters: [tokio::sync::Mutex<RateLimiter>; RateLimitCategory::COUNT],
122}
123
124impl CategoryRateLimiter {
125    /// Build from a manifest's `RateLimitConfig`.
126    pub fn from_config(config: &RateLimitConfig) -> Self {
127        let limiters = RateLimitCategory::ALL.map(|cat| {
128            let rps = config.rps(cat);
129            tokio::sync::Mutex::new(RateLimiter::new(rps))
130        });
131        Self { limiters }
132    }
133
134    /// Wait for the rate limiter of the given category.
135    pub async fn wait(&self, category: RateLimitCategory) {
136        self.limiters[category as usize].lock().await.wait().await;
137    }
138}
139
#[cfg(test)]
mod tests {
    use super::*;

    /// Two back-to-back waits at 10 requests/second must be separated by
    /// roughly the 100 ms interval; the 90 ms bound leaves slack for timer
    /// granularity.
    #[tokio::test]
    async fn test_rate_limiter_respects_interval() {
        let mut limiter = RateLimiter::new(10);
        let began = Instant::now();

        for _ in 0..2 {
            limiter.wait().await;
        }

        assert!(began.elapsed() >= Duration::from_millis(90));
    }
}