agent_fetch/
rate_limit.rs1use std::sync::Mutex;
2use std::time::Instant;
3
4use tokio::sync::Semaphore;
5
6use crate::error::FetchError;
7
8pub struct RateLimiter {
10 global_max_per_minute: u32,
11 state: Mutex<Vec<Instant>>,
12 concurrency: Semaphore,
13}
14
15impl RateLimiter {
16 pub fn new(max_per_minute: u32, max_concurrent: usize) -> Self {
17 Self {
18 global_max_per_minute: max_per_minute,
19 state: Mutex::new(Vec::new()),
20 concurrency: Semaphore::new(max_concurrent),
21 }
22 }
23
24 pub async fn acquire(
27 &self,
28 _domain: &str,
29 ) -> Result<tokio::sync::SemaphorePermit<'_>, FetchError> {
30 let permit = self
31 .concurrency
32 .try_acquire()
33 .map_err(|_| FetchError::RateLimitExceeded)?;
34
35 {
36 let mut timestamps = self.state.lock().unwrap();
37 let now = Instant::now();
38 let one_minute_ago = now - std::time::Duration::from_secs(60);
39
40 timestamps.retain(|t| *t > one_minute_ago);
41
42 if timestamps.len() as u32 >= self.global_max_per_minute {
43 drop(permit);
44 return Err(FetchError::RateLimitExceeded);
45 }
46
47 timestamps.push(now);
48 }
49
50 Ok(permit)
51 }
52}
53
#[cfg(test)]
mod tests {
    use super::*;

    /// A ten-per-minute budget admits ten back-to-back requests when each
    /// permit is released before the next call.
    #[tokio::test]
    async fn allows_within_limit() {
        let limiter = RateLimiter::new(10, 5);
        for _ in 0..10 {
            let outcome = limiter.acquire("example.com").await;
            assert!(outcome.is_ok());
        }
    }

    /// After three recorded acquisitions, a fourth within the same minute is
    /// refused even though plenty of concurrency headroom remains.
    #[tokio::test]
    async fn rejects_over_limit() {
        let limiter = RateLimiter::new(3, 100);
        for _ in 0..3 {
            let _permit = limiter.acquire("example.com").await.unwrap();
        }
        let over_budget = limiter.acquire("example.com").await;
        assert!(over_budget.is_err());
    }

    /// With both concurrency slots held, a third caller is turned away
    /// regardless of which domain it names.
    #[tokio::test]
    async fn rejects_over_concurrency() {
        let limiter = RateLimiter::new(100, 2);
        let _held_a = limiter.acquire("a.com").await.unwrap();
        let _held_b = limiter.acquire("b.com").await.unwrap();
        let third = limiter.acquire("c.com").await;
        assert!(third.is_err());
    }
}