1use governor::{
4 clock::DefaultClock,
5 state::{InMemoryState, NotKeyed},
6 Quota, RateLimiter,
7};
8use reqwest::{header, Client, StatusCode};
9use std::num::NonZeroU32;
10use std::path::Path;
11use std::sync::Arc;
12use std::time::Duration;
13
14use crate::models::{DownloadRequest, DownloadResult};
15use crate::sources::SourceError;
16
/// Requests-per-second budget used when [`RATE_LIMIT_ENV_VAR`] is unset or cannot be
/// parsed as a `u32`.
const DEFAULT_REQUESTS_PER_SECOND: u32 = 5;

/// Environment variable that overrides the default requests-per-second budget.
/// A value of `0` disables rate limiting.
const RATE_LIMIT_ENV_VAR: &str = "RESEARCH_MASTER_RATE_LIMITS_DEFAULT_REQUESTS_PER_SECOND";

/// Environment variable holding the proxy URL applied to plain-HTTP requests.
const HTTP_PROXY_ENV_VAR: &str = "HTTP_PROXY";

/// Environment variable holding the proxy URL applied to HTTPS requests.
const HTTPS_PROXY_ENV_VAR: &str = "HTTPS_PROXY";

/// Environment variable holding a comma-separated list of hosts that should not be proxied.
const NO_PROXY_ENV_VAR: &str = "NO_PROXY";
31
/// Proxy settings, as collected from the environment by [`create_proxy_config`].
#[derive(Debug, Clone, Default)]
pub struct ProxyConfig {
    /// Proxy URL for plain-HTTP traffic, if one is configured.
    pub http_proxy: Option<String>,
    /// Proxy URL for HTTPS traffic, if one is configured.
    pub https_proxy: Option<String>,
    /// Host patterns that should bypass the proxy; `None` when no list is configured.
    pub no_proxy: Option<Vec<String>>,
}
39
40pub fn create_proxy_config() -> ProxyConfig {
42 let http_proxy = std::env::var(HTTP_PROXY_ENV_VAR).ok();
43 let https_proxy = std::env::var(HTTPS_PROXY_ENV_VAR).ok();
44 let no_proxy: Option<Vec<String>> = std::env::var(NO_PROXY_ENV_VAR)
45 .ok()
46 .map(|s| s.split(',').map(|v| v.trim().to_string()).collect());
47
48 if http_proxy.is_some() || https_proxy.is_some() {
49 tracing::info!(
50 "Proxy configured: HTTP={:?}, HTTPS={:?}, NO_PROXY={:?}",
51 http_proxy,
52 https_proxy,
53 no_proxy
54 );
55 }
56
57 ProxyConfig {
58 http_proxy,
59 https_proxy,
60 no_proxy,
61 }
62}
63
64fn should_bypass_proxy(url: &str, no_proxy: &Option<Vec<String>>) -> bool {
66 let Some(hosts) = no_proxy else {
67 return false;
68 };
69
70 if hosts.iter().any(|h| h == "*") {
71 return true;
72 }
73
74 if let Ok(url) = reqwest::Url::parse(url) {
76 let host = url.host_str().map(|h| h.to_lowercase());
77 if let Some(host) = host {
78 for no_proxy_host in hosts {
80 if host == no_proxy_host.to_lowercase() {
81 return true;
82 }
83 if host.ends_with(&format!(".{}", no_proxy_host.to_lowercase())) {
85 return true;
86 }
87 }
88 }
89 }
90
91 false
92}
93
/// A `reqwest` client paired with an optional rate limiter.
///
/// Cloning is cheap: the underlying client and the limiter are both held behind `Arc`, so
/// clones share them.
#[derive(Debug, Clone)]
pub struct HttpClient {
    /// The shared `reqwest` client used for every request.
    client: Arc<Client>,
    /// Limiter awaited before each request; `None` means requests are not throttled.
    rate_limiter: Option<Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>>,
    /// Host patterns consulted by `should_bypass_proxy`; `None` when no list is configured.
    no_proxy: Option<Vec<String>>,
}
101
/// Wrapper around `reqwest::RequestBuilder` that carries the rate limiter of the
/// [`HttpClient`] that created it.
pub struct RateLimitedRequestBuilder {
    /// The wrapped `reqwest` request builder.
    inner: reqwest::RequestBuilder,
    /// Limiter awaited when the request is sent; `None` means the request is not throttled.
    rate_limiter: Option<Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>>,
}
107
108impl RateLimitedRequestBuilder {
109 pub async fn send(self) -> Result<reqwest::Response, reqwest::Error> {
111 if let Some(ref limiter) = self.rate_limiter {
112 limiter.until_ready().await;
113 }
114 self.inner.send().await
115 }
116
117 pub fn header<K, V>(mut self, key: K, value: V) -> Self
119 where
120 K: AsRef<str>,
121 V: AsRef<str>,
122 {
123 self.inner = self.inner.header(key.as_ref(), value.as_ref());
124 self
125 }
126
127 pub fn headers(mut self, headers: header::HeaderMap) -> Self {
129 self.inner = self.inner.headers(headers);
130 self
131 }
132
133 pub fn basic_auth<U, P>(self, username: U, password: Option<P>) -> Self
135 where
136 U: Into<String> + std::fmt::Display,
137 P: Into<String> + std::fmt::Display,
138 {
139 Self {
140 inner: self.inner.basic_auth(username, password),
141 rate_limiter: self.rate_limiter,
142 }
143 }
144
145 pub fn bearer_auth<T>(self, token: T) -> Self
147 where
148 T: Into<String> + std::fmt::Display,
149 {
150 Self {
151 inner: self.inner.bearer_auth(token),
152 rate_limiter: self.rate_limiter,
153 }
154 }
155
156 pub fn query<T: serde::Serialize + ?Sized>(mut self, query: &T) -> Self {
158 self.inner = self.inner.query(query);
159 self
160 }
161
162 pub fn form<T: serde::Serialize + ?Sized>(mut self, form: &T) -> Self {
164 self.inner = self.inner.form(form);
165 self
166 }
167
168 pub fn json<T: serde::Serialize + ?Sized>(mut self, json: &T) -> Self {
170 self.inner = self.inner.json(json);
171 self
172 }
173
174 pub fn build(self) -> Result<reqwest::Request, reqwest::Error> {
176 self.inner.build()
177 }
178}
179
180impl HttpClient {
181 pub fn new() -> Result<Self, SourceError> {
183 Self::with_user_agent(concat!(
184 env!("CARGO_PKG_NAME"),
185 "/",
186 env!("CARGO_PKG_VERSION")
187 ))
188 }
189
190 pub fn with_user_agent(user_agent: &str) -> Result<Self, SourceError> {
192 let rate_limiter = Self::create_rate_limiter();
193 let proxy = create_proxy_config();
194
195 let mut builder = Client::builder()
196 .user_agent(user_agent)
197 .timeout(Duration::from_secs(30))
198 .connect_timeout(Duration::from_secs(10))
199 .pool_idle_timeout(Duration::from_secs(90));
200
201 if let Some(proxy_url) = proxy.http_proxy {
203 builder = builder.proxy(reqwest::Proxy::http(&proxy_url)?);
204 }
205 if let Some(proxy_url) = proxy.https_proxy {
206 builder = builder.proxy(reqwest::Proxy::https(&proxy_url)?);
207 }
208
209 let client = builder
210 .build()
211 .map_err(|e| SourceError::Network(format!("Failed to create HTTP client: {}", e)))?;
212
213 Ok(Self {
214 client: Arc::new(client),
215 rate_limiter,
216 no_proxy: proxy.no_proxy,
217 })
218 }
219
220 pub fn without_rate_limit(user_agent: &str) -> Result<Self, SourceError> {
222 let proxy = create_proxy_config();
223 let mut builder = Client::builder()
224 .user_agent(user_agent)
225 .timeout(Duration::from_secs(30))
226 .connect_timeout(Duration::from_secs(10))
227 .pool_idle_timeout(Duration::from_secs(90));
228
229 if let Some(proxy_url) = proxy.http_proxy {
230 builder = builder.proxy(reqwest::Proxy::http(&proxy_url)?);
231 }
232 if let Some(proxy_url) = proxy.https_proxy {
233 builder = builder.proxy(reqwest::Proxy::https(&proxy_url)?);
234 }
235
236 let client = builder
237 .build()
238 .map_err(|e| SourceError::Network(format!("Failed to create HTTP client: {}", e)))?;
239
240 Ok(Self {
241 client: Arc::new(client),
242 rate_limiter: None,
243 no_proxy: proxy.no_proxy,
244 })
245 }
246
247 pub fn should_bypass_proxy(&self, url: &str) -> bool {
249 should_bypass_proxy(url, &self.no_proxy)
250 }
251
252 pub fn with_rate_limit(
254 user_agent: &str,
255 requests_per_second: u32,
256 ) -> Result<Self, SourceError> {
257 let rate_limiter = if requests_per_second == 0 {
258 None
259 } else {
260 let nonzero = NonZeroU32::new(requests_per_second)
261 .expect("requests_per_second should be > 0 when not 0 branch");
262 let quota = Quota::per_second(nonzero);
263 Some(Arc::new(RateLimiter::direct(quota)))
264 };
265
266 let proxy = create_proxy_config();
267 let mut builder = Client::builder()
268 .user_agent(user_agent)
269 .timeout(Duration::from_secs(30))
270 .connect_timeout(Duration::from_secs(10))
271 .pool_idle_timeout(Duration::from_secs(90));
272
273 if let Some(proxy_url) = proxy.http_proxy {
274 builder = builder.proxy(reqwest::Proxy::http(&proxy_url)?);
275 }
276 if let Some(proxy_url) = proxy.https_proxy {
277 builder = builder.proxy(reqwest::Proxy::https(&proxy_url)?);
278 }
279
280 let client = builder
281 .build()
282 .map_err(|e| SourceError::Network(format!("Failed to create HTTP client: {}", e)))?;
283
284 Ok(Self {
285 client: Arc::new(client),
286 rate_limiter,
287 no_proxy: proxy.no_proxy,
288 })
289 }
290
291 pub fn with_proxy(
293 user_agent: &str,
294 http_proxy: Option<String>,
295 https_proxy: Option<String>,
296 requests_per_second: u32,
297 ) -> Result<Self, SourceError> {
298 let rate_limiter = if requests_per_second == 0 {
299 None
300 } else {
301 let nonzero = NonZeroU32::new(requests_per_second)
302 .expect("requests_per_second should be > 0 when not 0 branch");
303 let quota = Quota::per_second(nonzero);
304 Some(Arc::new(RateLimiter::direct(quota)))
305 };
306
307 let mut builder = Client::builder()
308 .user_agent(user_agent)
309 .timeout(Duration::from_secs(30))
310 .connect_timeout(Duration::from_secs(10))
311 .pool_idle_timeout(Duration::from_secs(90));
312
313 if let Some(proxy_url) = http_proxy {
314 builder = builder.proxy(reqwest::Proxy::http(&proxy_url)?);
315 }
316 if let Some(proxy_url) = https_proxy {
317 builder = builder.proxy(reqwest::Proxy::https(&proxy_url)?);
318 }
319
320 let client = builder
321 .build()
322 .map_err(|e| SourceError::Network(format!("Failed to create HTTP client: {}", e)))?;
323
324 Ok(Self {
325 client: Arc::new(client),
326 rate_limiter,
327 no_proxy: None, })
329 }
330
331 fn create_rate_limiter() -> Option<Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>> {
333 let requests_per_second = std::env::var(RATE_LIMIT_ENV_VAR)
334 .ok()
335 .and_then(|s| s.parse::<u32>().ok())
336 .unwrap_or(DEFAULT_REQUESTS_PER_SECOND);
337
338 if requests_per_second == 0 {
339 tracing::info!("Rate limiting disabled");
341 return None;
342 }
343
344 let nonzero =
345 NonZeroU32::new(requests_per_second).expect("requests_per_second should not be zero");
346 let quota = Quota::per_second(nonzero);
347 let limiter = RateLimiter::direct(quota);
348
349 tracing::info!(
350 "Rate limiting enabled: {} requests per second",
351 requests_per_second
352 );
353
354 Some(Arc::new(limiter))
355 }
356
357 pub fn from_client(client: Arc<Client>) -> Self {
359 Self {
360 client,
361 rate_limiter: Self::create_rate_limiter(),
362 no_proxy: None,
363 }
364 }
365
366 pub fn client(&self) -> &Client {
368 &self.client
369 }
370
371 pub fn get(&self, url: &str) -> RateLimitedRequestBuilder {
373 RateLimitedRequestBuilder {
374 inner: self.client.get(url),
375 rate_limiter: self.rate_limiter.clone(),
376 }
377 }
378
379 pub fn post(&self, url: &str) -> RateLimitedRequestBuilder {
381 RateLimitedRequestBuilder {
382 inner: self.client.post(url),
383 rate_limiter: self.rate_limiter.clone(),
384 }
385 }
386
387 pub async fn download_to_file(
389 &self,
390 url: &str,
391 request: &DownloadRequest,
392 filename: &str,
393 ) -> Result<DownloadResult, SourceError> {
394 if let Some(ref limiter) = self.rate_limiter {
395 limiter.until_ready().await;
396 }
397
398 let response = self
399 .client
400 .get(url)
401 .send()
402 .await
403 .map_err(|e| SourceError::Network(format!("Failed to download: {}", e)))?;
404
405 if !response.status().is_success() {
406 return Err(SourceError::NotFound(format!(
407 "Failed to download: HTTP {}",
408 response.status()
409 )));
410 }
411
412 let bytes = response
413 .bytes()
414 .await
415 .map_err(|e| SourceError::Network(format!("Failed to read response: {}", e)))?;
416
417 std::fs::create_dir_all(&request.save_path).map_err(|e| {
419 SourceError::Io(std::io::Error::other(format!(
420 "Failed to create directory: {}",
421 e
422 )))
423 })?;
424
425 let path = Path::new(&request.save_path).join(filename);
426
427 std::fs::write(&path, bytes.as_ref()).map_err(SourceError::Io)?;
428
429 Ok(DownloadResult::success(
430 path.to_string_lossy().to_string(),
431 bytes.len() as u64,
432 ))
433 }
434
435 pub async fn download_pdf(
437 &self,
438 url: &str,
439 request: &DownloadRequest,
440 paper_id: &str,
441 ) -> Result<DownloadResult, SourceError> {
442 let filename = format!("{}.pdf", paper_id.replace('/', "_"));
443 self.download_to_file(url, request, &filename).await
444 }
445
446 pub async fn head(&self, url: &str) -> Result<bool, SourceError> {
448 if let Some(ref limiter) = self.rate_limiter {
449 limiter.until_ready().await;
450 }
451
452 let response = self
453 .client
454 .head(url)
455 .send()
456 .await
457 .map_err(|e| SourceError::Network(format!("Head request failed: {}", e)))?;
458 Ok(response.status() == StatusCode::OK)
459 }
460}
461
impl Default for HttpClient {
    /// Builds the client via [`HttpClient::new`].
    ///
    /// # Panics
    ///
    /// Panics if [`HttpClient::new`] returns an error.
    fn default() -> Self {
        Self::new().expect("Failed to create default HTTP client")
    }
}
467
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard, OnceLock};

    /// Process-wide lock that serialises tests mutating the rate-limit environment variable.
    fn env_lock() -> &'static Mutex<()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(()))
    }

    /// Puts the rate-limit variable back to its previous value when dropped, including
    /// while a failed assertion unwinds.
    ///
    /// `Drop::drop` runs before the fields are dropped, so the variable is restored while
    /// the lock is still held.
    struct RateLimitEnvGuard {
        previous: Option<OsString>,
        _lock: MutexGuard<'static, ()>,
    }

    impl Drop for RateLimitEnvGuard {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(v) => std::env::set_var(RATE_LIMIT_ENV_VAR, v),
                None => std::env::remove_var(RATE_LIMIT_ENV_VAR),
            }
        }
    }

    /// Runs `f` with the rate-limit variable set to `value` (or removed for `None`),
    /// restoring the previous value afterwards.
    fn with_rate_limit_env<T>(value: Option<&str>, f: impl FnOnce() -> T) -> T {
        // The mutex guards no data, so a poisoned lock (an earlier test panicked while
        // holding it) is safe to reuse; recovering keeps one failure from cascading.
        let lock = env_lock()
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        let _guard = RateLimitEnvGuard {
            previous: std::env::var_os(RATE_LIMIT_ENV_VAR),
            _lock: lock,
        };

        match value {
            Some(v) => std::env::set_var(RATE_LIMIT_ENV_VAR, v),
            None => std::env::remove_var(RATE_LIMIT_ENV_VAR),
        }

        f()
    }

    #[test]
    fn test_create_rate_limiter_with_default() {
        with_rate_limit_env(None, || {
            let limiter = HttpClient::create_rate_limiter();
            assert!(limiter.is_some(), "Default rate limiter should be created");
        });
    }

    #[test]
    fn test_create_rate_limiter_disabled() {
        with_rate_limit_env(Some("0"), || {
            let limiter = HttpClient::create_rate_limiter();
            assert!(
                limiter.is_none(),
                "Rate limiter should be disabled when set to 0"
            );
        });
    }

    #[test]
    fn test_create_rate_limiter_custom() {
        with_rate_limit_env(Some("10"), || {
            let limiter = HttpClient::create_rate_limiter();
            assert!(limiter.is_some(), "Custom rate limiter should be created");
        });
    }

    #[test]
    fn test_create_rate_limiter_invalid() {
        with_rate_limit_env(Some("invalid"), || {
            let limiter = HttpClient::create_rate_limiter();
            assert!(
                limiter.is_some(),
                "Should fall back to default rate limiter"
            );
        });
    }

    #[test]
    fn test_should_bypass_proxy_no_config() {
        let result = should_bypass_proxy("https://api.semanticscholar.org", &None);
        assert!(!result, "Should not bypass when no no_proxy configured");
    }

    #[test]
    fn test_should_bypass_proxy_wildcard() {
        let no_proxy = Some(vec!["*".to_string()]);
        let result = should_bypass_proxy("https://api.semanticscholar.org", &no_proxy);
        assert!(result, "Should bypass for wildcard");
    }

    #[test]
    fn test_should_bypass_proxy_exact_match() {
        let no_proxy = Some(vec!["api.semanticscholar.org".to_string()]);
        let result = should_bypass_proxy("https://api.semanticscholar.org", &no_proxy);
        assert!(result, "Should bypass for exact match");
    }

    #[test]
    fn test_should_bypass_proxy_domain_suffix() {
        let no_proxy = Some(vec!["semanticscholar.org".to_string()]);
        let result = should_bypass_proxy("https://api.semanticscholar.org", &no_proxy);
        assert!(result, "Should bypass for domain suffix match");
    }

    #[test]
    fn test_should_bypass_proxy_no_match() {
        let no_proxy = Some(vec!["other-domain.org".to_string()]);
        let result = should_bypass_proxy("https://api.semanticscholar.org", &no_proxy);
        assert!(!result, "Should not bypass when domain doesn't match");
    }

    #[test]
    fn test_should_bypass_proxy_multiple_hosts() {
        let no_proxy = Some(vec![
            "api.semanticscholar.org".to_string(),
            "arxiv.org".to_string(),
        ]);
        assert!(should_bypass_proxy(
            "https://api.semanticscholar.org",
            &no_proxy
        ));
        assert!(should_bypass_proxy("https://arxiv.org", &no_proxy));
        assert!(!should_bypass_proxy("https://openalex.org", &no_proxy));
    }
}