// essence/crawler/rate_limiter.rs
use governor::{
    clock::DefaultClock,
    state::{InMemoryState, NotKeyed},
    Quota, RateLimiter,
};
use std::collections::HashMap;
use std::num::NonZeroU32;
use std::sync::{Arc, Mutex};
use url::Url;
10
11type DomainLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>;
12
13pub struct DomainRateLimiter {
15 limiters: Arc<Mutex<HashMap<String, Arc<DomainLimiter>>>>,
16 default_quota: Quota,
17}
18
19impl DomainRateLimiter {
20 pub fn new(requests_per_second: u32) -> Self {
22 let quota = Quota::per_second(
23 NonZeroU32::new(requests_per_second).unwrap_or(NonZeroU32::new(2).unwrap()),
24 );
25
26 Self {
27 limiters: Arc::new(Mutex::new(HashMap::new())),
28 default_quota: quota,
29 }
30 }
31
32 pub async fn wait_for_permission(&self, url: &str) -> Result<(), String> {
34 let domain = Self::extract_domain(url)?;
35
36 let limiter = {
37 let mut limiters = self
38 .limiters
39 .lock()
40 .map_err(|e| format!("Failed to acquire lock: {}", e))?;
41
42 limiters
43 .entry(domain.clone())
44 .or_insert_with(|| Arc::new(RateLimiter::direct(self.default_quota)))
45 .clone()
46 };
47
48 limiter.until_ready().await;
50
51 tracing::debug!(
52 "Rate limiter: Granted permission for domain: {}",
53 domain
54 );
55
56 Ok(())
57 }
58
59 fn extract_domain(url: &str) -> Result<String, String> {
61 let parsed = Url::parse(url).map_err(|e| format!("Invalid URL: {}", e))?;
62
63 parsed
64 .host_str()
65 .map(|h| h.to_string())
66 .ok_or_else(|| "No host in URL".to_string())
67 }
68
69 pub fn set_domain_rate(&self, domain: &str, requests_per_second: u32) {
71 let quota = Quota::per_second(
72 NonZeroU32::new(requests_per_second).unwrap_or(NonZeroU32::new(1).unwrap()),
73 );
74 let limiter = Arc::new(RateLimiter::direct(quota));
75
76 if let Ok(mut limiters) = self.limiters.lock() {
77 limiters.insert(domain.to_string(), limiter);
78 tracing::info!(
79 "Set custom rate limit for {}: {} req/sec",
80 domain,
81 requests_per_second
82 );
83 }
84 }
85}
86
87impl Default for DomainRateLimiter {
88 fn default() -> Self {
89 Self::new(2) }
91}
92
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Instant;

    /// Awaits permission for each URL in order, panicking on any error.
    async fn request_all(limiter: &DomainRateLimiter, urls: &[&str]) {
        for &url in urls {
            limiter.wait_for_permission(url).await.unwrap();
        }
    }

    #[tokio::test]
    async fn test_rate_limiter_enforces_delay() {
        // At 2 req/sec, the third request to the same host must wait.
        let limiter = DomainRateLimiter::new(2);
        let start = Instant::now();

        request_all(
            &limiter,
            &[
                "https://example.com/1",
                "https://example.com/2",
                "https://example.com/3",
            ],
        )
        .await;

        let waited_ms = start.elapsed().as_millis();
        assert!(
            waited_ms >= 400,
            "Expected at least 400ms delay, got {}ms",
            waited_ms
        );
    }

    #[tokio::test]
    async fn test_different_domains_not_limited() {
        // Each host has its own limiter, so one request each never waits.
        let limiter = DomainRateLimiter::new(1);
        let start = Instant::now();

        request_all(&limiter, &["https://example.com", "https://other.com"]).await;

        let waited_ms = start.elapsed().as_millis();
        assert!(
            waited_ms < 300,
            "Different domains should not block each other, got {}ms",
            waited_ms
        );
    }

    #[tokio::test]
    async fn test_custom_domain_rate() {
        // The per-domain override (1 req/sec) wins over the 10 req/sec default.
        let limiter = DomainRateLimiter::new(10);
        limiter.set_domain_rate("slow.example.com", 1);
        let start = Instant::now();

        request_all(
            &limiter,
            &["https://slow.example.com/1", "https://slow.example.com/2"],
        )
        .await;

        let waited_ms = start.elapsed().as_millis();
        assert!(
            waited_ms >= 800,
            "Custom rate should be enforced, got {}ms",
            waited_ms
        );
    }

    #[test]
    fn test_extract_domain() {
        let cases = [
            ("https://example.com/path", "example.com"),
            ("https://sub.example.com", "sub.example.com"),
        ];
        for &(url, expected) in cases.iter() {
            assert_eq!(DomainRateLimiter::extract_domain(url).unwrap(), expected);
        }
        assert!(DomainRateLimiter::extract_domain("invalid").is_err());
    }
}