argus_worker/
rate_limit.rs1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4
5use async_trait::async_trait;
6use tokio::sync::Mutex;
7
8#[async_trait]
9pub trait RateLimiter: Send + Sync {
10 async fn wait_for_host(&self, host: &str, delay_ms: u64);
12}
13
14#[derive(Clone, Default)]
16pub struct InMemoryRateLimiter {
17 last: Arc<Mutex<HashMap<String, Instant>>>,
18}
19
20#[async_trait]
21impl RateLimiter for InMemoryRateLimiter {
22 async fn wait_for_host(&self, host: &str, delay_ms: u64) {
23 let delay = Duration::from_millis(delay_ms);
24 let map = self.last.lock().await;
25 let last = map.get(host).copied();
26 drop(map);
27 if let Some(last) = last {
28 let elapsed = last.elapsed();
29 if elapsed < delay {
30 tokio::time::sleep(delay - elapsed).await;
31 }
32 }
33 let mut map = self.last.lock().await;
34 map.insert(host.to_string(), Instant::now());
35 }
36}
37
#[cfg(test)]
mod tests {
    use super::*;

    /// Two back-to-back calls with a zero delay must both complete.
    #[tokio::test]
    async fn in_memory_rate_limiter_zero_delay_does_not_block() {
        let limiter = InMemoryRateLimiter::default();
        for _ in 0..2 {
            limiter.wait_for_host("example.com", 0).await;
        }
    }

    /// The second call for the same host has to sit out (most of) the delay.
    #[tokio::test]
    async fn in_memory_rate_limiter_second_call_waits() {
        let limiter = InMemoryRateLimiter::default();
        // Lower bound is 40ms rather than 50ms to leave slack for timer jitter.
        let minimum = Duration::from_millis(40);

        let started = Instant::now();
        limiter.wait_for_host("host", 50).await;
        limiter.wait_for_host("host", 50).await;
        let took = started.elapsed();

        assert!(
            took >= minimum,
            "second call should wait ~50ms, got {:?}",
            took
        );
    }
}
63
#[cfg(feature = "redis")]
mod redis_limiter {
    use std::sync::Arc;
    use std::time::{Duration, SystemTime, UNIX_EPOCH};

    use async_trait::async_trait;
    use redis::aio::MultiplexedConnection;
    use redis::AsyncCommands;
    use tokio::sync::Mutex;

    use super::RateLimiter;

    /// Prefix of the Redis key that holds a host's timestamp (Unix millis).
    const KEY_PREFIX: &str = "argus:rate:";

    /// Current wall-clock time as milliseconds since the Unix epoch.
    ///
    /// Yields 0 if the system clock reports a time before the epoch.
    fn unix_millis() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            // Millisecond timestamps fit in u64 for hundreds of millions of
            // years, so this narrowing cast cannot truncate in practice.
            .as_millis() as u64
    }

    /// [`RateLimiter`] that keeps per-host timestamps in Redis.
    ///
    /// Timestamps come from the local wall clock (`SystemTime`).
    /// NOTE(review): when several machines share the same Redis this presumably
    /// relies on their clocks being reasonably in sync — confirm.
    #[derive(Clone)]
    pub struct RedisRateLimiter {
        /// Shared connection. The mutex also serialises the read-then-write
        /// slot reservation performed by `wait_for_host` within this process.
        conn: Arc<Mutex<MultiplexedConnection>>,
    }

    impl RedisRateLimiter {
        /// Opens a client for `redis_url` and establishes a multiplexed
        /// connection.
        ///
        /// # Errors
        ///
        /// Returns an error if the URL cannot be turned into a client or the
        /// connection cannot be established.
        pub async fn connect(redis_url: &str) -> anyhow::Result<Self> {
            let client = redis::Client::open(redis_url)?;
            let conn = client.get_multiplexed_tokio_connection().await?;
            Ok(Self {
                conn: Arc::new(Mutex::new(conn)),
            })
        }
    }

    #[async_trait]
    impl RateLimiter for RedisRateLimiter {
        /// Waits until a request to `host` is allowed, keeping consecutive
        /// releases for the same host at least `delay_ms` milliseconds apart.
        ///
        /// The caller's release time is reserved (GET followed by SET) during a
        /// single hold of the connection mutex, and only then does the caller
        /// sleep. Writing the stamp only after the sleep, under a second lock
        /// acquisition, lets concurrent callers read the same stamp and be
        /// released together.
        ///
        /// Redis errors are deliberately non-fatal: a failed read is treated as
        /// "no previous request" and a failed write is ignored, so the limiter
        /// fails open instead of blocking the worker.
        ///
        /// NOTE(review): GET and SET are still two separate commands, so other
        /// processes sharing the key can interleave within one round trip, and
        /// keys are written without a TTL. A single atomic server-side
        /// operation would close both gaps.
        async fn wait_for_host(&self, host: &str, delay_ms: u64) {
            let key = format!("{}{}", KEY_PREFIX, host);

            // Reserve the slot under the lock; sleep only after releasing it so
            // callers for other hosts are not held up.
            let release_at_ms = {
                let mut conn = self.conn.lock().await;
                let last_ms: Option<u64> = conn.get(&key).await.ok().flatten();
                // Sample the clock after the lock wait and the GET so that time
                // already spent waiting counts towards the delay.
                let now_ms = unix_millis();
                let release_at_ms = match last_ms {
                    Some(last) => last.saturating_add(delay_ms).max(now_ms),
                    None => now_ms,
                };
                let _: Result<(), _> = conn.set(&key, release_at_ms).await;
                release_at_ms
            };

            let wait_ms = release_at_ms.saturating_sub(unix_millis());
            if wait_ms > 0 {
                tokio::time::sleep(Duration::from_millis(wait_ms)).await;
            }
        }
    }
}
122
123#[cfg(feature = "redis")]
124pub use redis_limiter::RedisRateLimiter;