Skip to main content

mkt_core/http/
rate_limit.rs

//! Semaphore-backed rate limiter for API calls.
2
3use std::sync::Arc;
4
5use tokio::sync::Semaphore;
6
7/// A simple rate limiter backed by a `tokio::sync::Semaphore`.
8///
9/// Each API call acquires a number of permits (1 for reads, 3 for writes).
10/// Permits are released on a timer to maintain the desired rate.
11#[derive(Debug, Clone)]
12pub struct RateLimiter {
13    semaphore: Arc<Semaphore>,
14}
15
16impl RateLimiter {
17    /// Create a new rate limiter with the given maximum concurrent permits.
18    pub fn new(max_permits: usize) -> Self {
19        Self {
20            semaphore: Arc::new(Semaphore::new(max_permits)),
21        }
22    }
23
24    /// Acquire `cost` permits. Waits if insufficient permits are available.
25    ///
26    /// # Errors
27    ///
28    /// Returns an error if the semaphore is closed.
29    #[allow(clippy::significant_drop_tightening)] // permit is consumed by .forget()
30    pub async fn acquire(&self, cost: u32) -> crate::error::Result<()> {
31        let permit =
32            self.semaphore.acquire_many(cost).await.map_err(|e| {
33                crate::error::MktError::ConfigError(format!("Rate limiter error: {e}"))
34            })?;
35
36        // Forget the permit immediately — it will be replenished below.
37        // For simplicity, we immediately release it. This means the limiter acts as
38        // a concurrency limiter rather than a strict rate limiter. A stricter
39        // implementation can use a background task to drain and refill.
40        permit.forget();
41
42        self.semaphore.add_permits(cost as usize);
43
44        Ok(())
45    }
46}
47
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a limiter with 100 permits and returns the outcome of acquiring `cost`.
    async fn acquire_from_fresh_limiter(cost: u32) -> crate::error::Result<()> {
        let limiter = RateLimiter::new(100);
        limiter.acquire(cost).await
    }

    /// A single-permit request succeeds on a fresh limiter.
    #[tokio::test]
    async fn acquire_within_budget() {
        let outcome = acquire_from_fresh_limiter(1).await;
        assert!(outcome.is_ok());
    }

    /// A three-permit request succeeds on a fresh limiter.
    #[tokio::test]
    async fn acquire_multiple_permits() {
        let outcome = acquire_from_fresh_limiter(3).await;
        assert!(outcome.is_ok());
    }
}