Skip to main content

agent_fetch/
rate_limit.rs

1use std::sync::Mutex;
2use std::time::Instant;
3
4use tokio::sync::Semaphore;
5
6use crate::error::FetchError;
7
8/// Simple sliding-window rate limiter with a concurrency semaphore.
9pub struct RateLimiter {
10    global_max_per_minute: u32,
11    state: Mutex<Vec<Instant>>,
12    concurrency: Semaphore,
13}
14
15impl RateLimiter {
16    pub fn new(max_per_minute: u32, max_concurrent: usize) -> Self {
17        Self {
18            global_max_per_minute: max_per_minute,
19            state: Mutex::new(Vec::new()),
20            concurrency: Semaphore::new(max_concurrent),
21        }
22    }
23
24    /// Check whether a request to `domain` is allowed.
25    /// Returns a permit that must be held for the duration of the request.
26    pub async fn acquire(
27        &self,
28        _domain: &str,
29    ) -> Result<tokio::sync::SemaphorePermit<'_>, FetchError> {
30        let permit = self
31            .concurrency
32            .try_acquire()
33            .map_err(|_| FetchError::RateLimitExceeded)?;
34
35        {
36            let mut timestamps = self.state.lock().unwrap();
37            let now = Instant::now();
38            let one_minute_ago = now - std::time::Duration::from_secs(60);
39
40            timestamps.retain(|t| *t > one_minute_ago);
41
42            if timestamps.len() as u32 >= self.global_max_per_minute {
43                drop(permit);
44                return Err(FetchError::RateLimitExceeded);
45            }
46
47            timestamps.push(now);
48        }
49
50        Ok(permit)
51    }
52}
53
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn allows_within_limit() {
        // Each permit is released at the end of its statement, so five
        // concurrency slots are enough for ten sequential requests.
        let rl = RateLimiter::new(10, 5);
        for _ in 0..10 {
            assert!(rl.acquire("example.com").await.is_ok());
        }
    }

    #[tokio::test]
    async fn rejects_over_limit() {
        let rl = RateLimiter::new(3, 100);
        for _ in 0..3 {
            let _permit = rl.acquire("example.com").await.unwrap();
            // `_permit` drops at the end of each iteration, freeing its
            // concurrency slot; only the per-minute budget is consumed.
        }
        assert!(rl.acquire("example.com").await.is_err());
    }

    #[tokio::test]
    async fn rejects_over_concurrency() {
        let rl = RateLimiter::new(100, 2);
        let _p1 = rl.acquire("a.com").await.unwrap();
        let _p2 = rl.acquire("b.com").await.unwrap();
        // Third should fail — concurrency limit reached
        assert!(rl.acquire("c.com").await.is_err());
    }

    #[tokio::test]
    async fn rate_rejection_releases_concurrency_permit() {
        let rl = RateLimiter::new(1, 2);
        drop(rl.acquire("a.com").await.unwrap());
        // Rejected by the per-minute budget, not by concurrency.
        assert!(rl.acquire("b.com").await.is_err());
        // The rejected call must not keep a concurrency slot.
        assert_eq!(rl.concurrency.available_permits(), 2);
    }

    #[tokio::test]
    async fn concurrency_rejection_does_not_consume_rate_budget() {
        let rl = RateLimiter::new(2, 1);
        let held = rl.acquire("a.com").await.unwrap();
        // Rejected by concurrency before the window is touched.
        assert!(rl.acquire("b.com").await.is_err());
        drop(held);
        // Only one request is recorded in the window, so exactly one more fits.
        assert!(rl.acquire("c.com").await.is_ok());
        assert!(rl.acquire("d.com").await.is_err());
    }
}