Skip to main content

argus_worker/
rate_limit.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4
5use async_trait::async_trait;
6use tokio::sync::Mutex;
7
8#[async_trait]
9pub trait RateLimiter: Send + Sync {
10    /// Wait if needed so that at least `delay_ms` has passed since the last fetch for this host.
11    async fn wait_for_host(&self, host: &str, delay_ms: u64);
12}
13
14/// Per-process rate limit; each host is delayed based on local fetch times.
15#[derive(Clone, Default)]
16pub struct InMemoryRateLimiter {
17    last: Arc<Mutex<HashMap<String, Instant>>>,
18}
19
20#[async_trait]
21impl RateLimiter for InMemoryRateLimiter {
22    async fn wait_for_host(&self, host: &str, delay_ms: u64) {
23        let delay = Duration::from_millis(delay_ms);
24        let map = self.last.lock().await;
25        let last = map.get(host).copied();
26        drop(map);
27        if let Some(last) = last {
28            let elapsed = last.elapsed();
29            if elapsed < delay {
30                tokio::time::sleep(delay - elapsed).await;
31            }
32        }
33        let mut map = self.last.lock().await;
34        map.insert(host.to_string(), Instant::now());
35    }
36}
37
38#[cfg(test)]
39mod tests {
40    use super::*;
41
42    #[tokio::test]
43    async fn in_memory_rate_limiter_zero_delay_does_not_block() {
44        let limiter = InMemoryRateLimiter::default();
45        limiter.wait_for_host("example.com", 0).await;
46        limiter.wait_for_host("example.com", 0).await;
47    }
48
49    #[tokio::test]
50    async fn in_memory_rate_limiter_second_call_waits() {
51        let limiter = InMemoryRateLimiter::default();
52        let start = std::time::Instant::now();
53        limiter.wait_for_host("host", 50).await;
54        limiter.wait_for_host("host", 50).await;
55        let elapsed = start.elapsed();
56        assert!(
57            elapsed >= Duration::from_millis(40),
58            "second call should wait ~50ms, got {:?}",
59            elapsed
60        );
61    }
62}
63
#[cfg(feature = "redis")]
mod redis_limiter {
    use std::sync::Arc;

    use async_trait::async_trait;
    use redis::aio::MultiplexedConnection;
    use tokio::sync::Mutex;

    use super::RateLimiter;

    const KEY_PREFIX: &str = "argus:rate:";

    /// Lua script that reserves the next fetch slot for a host in a single atomic step.
    ///
    /// - `KEYS[1]`: the per-host key.
    /// - `ARGV[1]`: the caller's current time, in ms since the UNIX epoch.
    /// - `ARGV[2]`: the minimum spacing between fetches, in ms.
    ///
    /// The reserved slot is stored in the same integer-millisecond format the key already used.
    /// The script returns how many ms the caller must wait before its slot begins.
    const RESERVE_SLOT_SCRIPT: &str = r"
        local now = tonumber(ARGV[1])
        local delay = tonumber(ARGV[2])
        local last = tonumber(redis.call('GET', KEYS[1]))
        local slot = now
        if last ~= nil and last + delay > now then
            slot = last + delay
        end
        redis.call('SET', KEYS[1], string.format('%d', slot))
        return slot - now
    ";

    /// Rate limiter whose per-host state lives in Redis under `argus:rate:<host>`, so every
    /// limiter connected to the same Redis shares it.
    #[derive(Clone)]
    pub struct RedisRateLimiter {
        conn: Arc<Mutex<MultiplexedConnection>>,
    }

    impl RedisRateLimiter {
        /// Opens a multiplexed async connection to `redis_url`.
        ///
        /// # Errors
        /// Returns an error if the URL is invalid or the connection cannot be established.
        pub async fn connect(redis_url: &str) -> anyhow::Result<Self> {
            let client = redis::Client::open(redis_url)?;
            let conn = client.get_multiplexed_tokio_connection().await?;
            Ok(Self {
                conn: Arc::new(Mutex::new(conn)),
            })
        }
    }

    #[async_trait]
    impl RateLimiter for RedisRateLimiter {
        /// Reserves this caller's slot for `host` atomically in Redis, then sleeps until it begins.
        ///
        /// The read-compare-write runs inside one Lua script, which Redis executes atomically.
        /// Concurrent workers therefore cannot all observe the same previous timestamp and fetch
        /// together, which a separate GET followed by SET allowed.
        ///
        /// Timestamps come from this process's wall clock. Workers with skewed clocks may space
        /// fetches unevenly — NOTE(review): confirm that is acceptable for the deployment.
        ///
        /// The limiter is best-effort: if Redis is unreachable or the script fails, the caller
        /// proceeds without waiting.
        async fn wait_for_host(&self, host: &str, delay_ms: u64) {
            let key = format!("{}{}", KEY_PREFIX, host);
            let now_ms = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_millis() as u64;

            let script = redis::Script::new(RESERVE_SLOT_SCRIPT);
            let mut conn = self.conn.lock().await;
            let reserved: redis::RedisResult<u64> = script
                .key(&key)
                .arg(now_ms)
                .arg(delay_ms)
                .invoke_async(&mut *conn)
                .await;
            // Release the shared connection before sleeping so other hosts are not blocked.
            drop(conn);

            let wait_ms = reserved.unwrap_or(0);
            if wait_ms > 0 {
                tokio::time::sleep(std::time::Duration::from_millis(wait_ms)).await;
            }
        }
    }
}
122
123#[cfg(feature = "redis")]
124pub use redis_limiter::RedisRateLimiter;