Skip to main content

essence/crawler/
rate_limiter.rs

1use governor::{
2    clock::DefaultClock,
3    state::{InMemoryState, NotKeyed},
4    Quota, RateLimiter,
5};
6use std::collections::HashMap;
7use std::num::NonZeroU32;
8use std::sync::{Arc, Mutex};
9use url::Url;
10
11type DomainLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>;
12
13/// Per-domain rate limiter to ensure respectful crawling
14pub struct DomainRateLimiter {
15    limiters: Arc<Mutex<HashMap<String, Arc<DomainLimiter>>>>,
16    default_quota: Quota,
17}
18
19impl DomainRateLimiter {
20    /// Create a new rate limiter with a default requests-per-second limit
21    pub fn new(requests_per_second: u32) -> Self {
22        let quota = Quota::per_second(
23            NonZeroU32::new(requests_per_second).unwrap_or(NonZeroU32::new(2).unwrap()),
24        );
25
26        Self {
27            limiters: Arc::new(Mutex::new(HashMap::new())),
28            default_quota: quota,
29        }
30    }
31
32    /// Wait until we're allowed to make a request to this domain
33    pub async fn wait_for_permission(&self, url: &str) -> Result<(), String> {
34        let domain = Self::extract_domain(url)?;
35
36        let limiter = {
37            let mut limiters = self
38                .limiters
39                .lock()
40                .map_err(|e| format!("Failed to acquire lock: {}", e))?;
41            
42            limiters
43                .entry(domain.clone())
44                .or_insert_with(|| Arc::new(RateLimiter::direct(self.default_quota)))
45                .clone()
46        };
47
48        // Wait until we have permission (non-blocking in async context)
49        limiter.until_ready().await;
50
51        tracing::debug!(
52            "Rate limiter: Granted permission for domain: {}",
53            domain
54        );
55
56        Ok(())
57    }
58
59    /// Extract domain from URL
60    fn extract_domain(url: &str) -> Result<String, String> {
61        let parsed = Url::parse(url).map_err(|e| format!("Invalid URL: {}", e))?;
62
63        parsed
64            .host_str()
65            .map(|h| h.to_string())
66            .ok_or_else(|| "No host in URL".to_string())
67    }
68
69    /// Set custom rate for specific domain
70    pub fn set_domain_rate(&self, domain: &str, requests_per_second: u32) {
71        let quota = Quota::per_second(
72            NonZeroU32::new(requests_per_second).unwrap_or(NonZeroU32::new(1).unwrap()),
73        );
74        let limiter = Arc::new(RateLimiter::direct(quota));
75
76        if let Ok(mut limiters) = self.limiters.lock() {
77            limiters.insert(domain.to_string(), limiter);
78            tracing::info!(
79                "Set custom rate limit for {}: {} req/sec",
80                domain,
81                requests_per_second
82            );
83        }
84    }
85}
86
87impl Default for DomainRateLimiter {
88    fn default() -> Self {
89        Self::new(2) // 2 requests/second default
90    }
91}
92
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{Duration, Instant};

    /// Request permission for each URL in order (unwrapping every result) and
    /// return the total wall-clock time spent.
    async fn time_requests(limiter: &DomainRateLimiter, urls: &[&str]) -> Duration {
        let clock = Instant::now();
        for &url in urls {
            limiter.wait_for_permission(url).await.unwrap();
        }
        clock.elapsed()
    }

    #[tokio::test]
    async fn test_rate_limiter_enforces_delay() {
        // 2 req/sec: the first two requests use the burst, the third waits ~500ms.
        let limiter = DomainRateLimiter::new(2);

        let waited = time_requests(
            &limiter,
            &[
                "https://example.com/1",
                "https://example.com/2",
                "https://example.com/3",
            ],
        )
        .await
        .as_millis();

        // Expect ~500ms; assert a 400ms floor to tolerate clock/scheduler jitter.
        assert!(
            waited >= 400,
            "Expected at least 400ms delay, got {}ms",
            waited
        );
    }

    #[tokio::test]
    async fn test_different_domains_not_limited() {
        // 1 req/sec per domain, but each domain has its own budget.
        let limiter = DomainRateLimiter::new(1);

        let waited = time_requests(&limiter, &["https://example.com", "https://other.com"])
            .await
            .as_millis();

        // Both should pass immediately; the bound leaves room for timing noise.
        assert!(
            waited < 300,
            "Different domains should not block each other, got {}ms",
            waited
        );
    }

    #[tokio::test]
    async fn test_custom_domain_rate() {
        // Fast default (10 req/sec), but one domain is throttled to 1 req/sec.
        let limiter = DomainRateLimiter::new(10);
        limiter.set_domain_rate("slow.example.com", 1);

        let waited = time_requests(
            &limiter,
            &["https://slow.example.com/1", "https://slow.example.com/2"],
        )
        .await
        .as_millis();

        // Second request must wait ~1s under the custom quota.
        assert!(
            waited >= 800,
            "Custom rate should be enforced, got {}ms",
            waited
        );
    }

    #[test]
    fn test_extract_domain() {
        let cases = [
            ("https://example.com/path", "example.com"),
            ("https://sub.example.com", "sub.example.com"),
        ];
        for &(url, host) in &cases {
            assert_eq!(DomainRateLimiter::extract_domain(url).unwrap(), host);
        }

        assert!(DomainRateLimiter::extract_domain("invalid").is_err());
    }
}