Skip to main content

research_master/utils/
http.rs

1//! HTTP client utilities with rate limiting support.
2
3use governor::{
4    clock::DefaultClock,
5    state::{InMemoryState, NotKeyed},
6    Quota, RateLimiter,
7};
8use reqwest::{header, Client, StatusCode};
9use std::num::NonZeroU32;
10use std::path::Path;
11use std::sync::Arc;
12use std::time::Duration;
13
14use crate::models::{DownloadRequest, DownloadResult};
15use crate::sources::SourceError;
16
/// Default rate limit, in requests per second.
///
/// Used when the rate-limit environment variable is unset or does not parse
/// as a `u32`.
const DEFAULT_REQUESTS_PER_SECOND: u32 = 5;

/// Name of the environment variable that overrides the default rate limit
/// (requests per second). A value of `0` disables rate limiting.
const RATE_LIMIT_ENV_VAR: &str = "RESEARCH_MASTER_RATE_LIMITS_DEFAULT_REQUESTS_PER_SECOND";

/// Name of the environment variable holding the proxy URL for plain-HTTP requests.
const HTTP_PROXY_ENV_VAR: &str = "HTTP_PROXY";

/// Name of the environment variable holding the proxy URL for HTTPS requests.
const HTTPS_PROXY_ENV_VAR: &str = "HTTPS_PROXY";

/// Name of the environment variable holding a comma-separated list of hosts
/// that should bypass the proxy.
const NO_PROXY_ENV_VAR: &str = "NO_PROXY";
31
/// Proxy configuration.
///
/// Every field is optional; the `Default` value has no proxy configured.
#[derive(Debug, Clone, Default)]
pub struct ProxyConfig {
    /// Proxy URL for plain-HTTP requests, if any.
    pub http_proxy: Option<String>,
    /// Proxy URL for HTTPS requests, if any.
    pub https_proxy: Option<String>,
    /// Hosts (or domain suffixes) that should bypass the proxy, if any.
    pub no_proxy: Option<Vec<String>>,
}
39
40/// Create proxy configuration from environment variables
41pub fn create_proxy_config() -> ProxyConfig {
42    let http_proxy = std::env::var(HTTP_PROXY_ENV_VAR).ok();
43    let https_proxy = std::env::var(HTTPS_PROXY_ENV_VAR).ok();
44    let no_proxy: Option<Vec<String>> = std::env::var(NO_PROXY_ENV_VAR)
45        .ok()
46        .map(|s| s.split(',').map(|v| v.trim().to_string()).collect());
47
48    if http_proxy.is_some() || https_proxy.is_some() {
49        tracing::info!(
50            "Proxy configured: HTTP={:?}, HTTPS={:?}, NO_PROXY={:?}",
51            http_proxy,
52            https_proxy,
53            no_proxy
54        );
55    }
56
57    ProxyConfig {
58        http_proxy,
59        https_proxy,
60        no_proxy,
61    }
62}
63
64/// Check if a URL should bypass the proxy
65fn should_bypass_proxy(url: &str, no_proxy: &Option<Vec<String>>) -> bool {
66    let Some(hosts) = no_proxy else {
67        return false;
68    };
69
70    if hosts.iter().any(|h| h == "*") {
71        return true;
72    }
73
74    // Parse URL to extract host
75    if let Ok(url) = reqwest::Url::parse(url) {
76        let host = url.host_str().map(|h| h.to_lowercase());
77        if let Some(host) = host {
78            // Check exact match or domain suffix match
79            for no_proxy_host in hosts {
80                if host == no_proxy_host.to_lowercase() {
81                    return true;
82                }
83                // Check if the host ends with the no_proxy domain
84                if host.ends_with(&format!(".{}", no_proxy_host.to_lowercase())) {
85                    return true;
86                }
87            }
88        }
89    }
90
91    false
92}
93
/// Shared HTTP client with sensible defaults and rate limiting.
///
/// Cloning is cheap: the underlying `reqwest::Client` and the rate limiter
/// are both held behind `Arc`, so clones share them.
#[derive(Debug, Clone)]
pub struct HttpClient {
    /// Underlying reqwest client used to issue every request.
    client: Arc<Client>,
    /// Limiter awaited before each request; `None` means no rate limiting.
    rate_limiter: Option<Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>>,
    /// Hosts that should bypass the proxy; consulted by `should_bypass_proxy`.
    no_proxy: Option<Vec<String>>,
}
101
/// Rate-limited request builder - compatible API with reqwest::RequestBuilder.
///
/// Wraps a `reqwest::RequestBuilder` together with an optional rate limiter
/// that is awaited when the request is sent.
pub struct RateLimitedRequestBuilder {
    /// The wrapped reqwest builder that accumulates headers, body, etc.
    inner: reqwest::RequestBuilder,
    /// Limiter awaited in `send`; `None` means the request is sent immediately.
    rate_limiter: Option<Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>>,
}
107
108impl RateLimitedRequestBuilder {
109    /// Send the request (with rate limiting applied first)
110    pub async fn send(self) -> Result<reqwest::Response, reqwest::Error> {
111        if let Some(ref limiter) = self.rate_limiter {
112            limiter.until_ready().await;
113        }
114        self.inner.send().await
115    }
116
117    /// Add a header (accepts &str for convenience - most common use case)
118    pub fn header<K, V>(mut self, key: K, value: V) -> Self
119    where
120        K: AsRef<str>,
121        V: AsRef<str>,
122    {
123        self.inner = self.inner.header(key.as_ref(), value.as_ref());
124        self
125    }
126
127    /// Set headers
128    pub fn headers(mut self, headers: header::HeaderMap) -> Self {
129        self.inner = self.inner.headers(headers);
130        self
131    }
132
133    /// Basic auth
134    pub fn basic_auth<U, P>(self, username: U, password: Option<P>) -> Self
135    where
136        U: Into<String> + std::fmt::Display,
137        P: Into<String> + std::fmt::Display,
138    {
139        Self {
140            inner: self.inner.basic_auth(username, password),
141            rate_limiter: self.rate_limiter,
142        }
143    }
144
145    /// Bearer auth
146    pub fn bearer_auth<T>(self, token: T) -> Self
147    where
148        T: Into<String> + std::fmt::Display,
149    {
150        Self {
151            inner: self.inner.bearer_auth(token),
152            rate_limiter: self.rate_limiter,
153        }
154    }
155
156    /// Query parameters
157    pub fn query<T: serde::Serialize + ?Sized>(mut self, query: &T) -> Self {
158        self.inner = self.inner.query(query);
159        self
160    }
161
162    /// Form data
163    pub fn form<T: serde::Serialize + ?Sized>(mut self, form: &T) -> Self {
164        self.inner = self.inner.form(form);
165        self
166    }
167
168    /// JSON body
169    pub fn json<T: serde::Serialize + ?Sized>(mut self, json: &T) -> Self {
170        self.inner = self.inner.json(json);
171        self
172    }
173
174    /// Build the request
175    pub fn build(self) -> Result<reqwest::Request, reqwest::Error> {
176        self.inner.build()
177    }
178}
179
180impl HttpClient {
181    /// Create a new HTTP client with default settings and rate limiting
182    pub fn new() -> Result<Self, SourceError> {
183        Self::with_user_agent(concat!(
184            env!("CARGO_PKG_NAME"),
185            "/",
186            env!("CARGO_PKG_VERSION")
187        ))
188    }
189
190    /// Create a new HTTP client with a custom user agent
191    pub fn with_user_agent(user_agent: &str) -> Result<Self, SourceError> {
192        let rate_limiter = Self::create_rate_limiter();
193        let proxy = create_proxy_config();
194
195        let mut builder = Client::builder()
196            .user_agent(user_agent)
197            .timeout(Duration::from_secs(30))
198            .connect_timeout(Duration::from_secs(10))
199            .pool_idle_timeout(Duration::from_secs(90));
200
201        // Apply proxy if configured
202        if let Some(proxy_url) = proxy.http_proxy {
203            builder = builder.proxy(reqwest::Proxy::http(&proxy_url)?);
204        }
205        if let Some(proxy_url) = proxy.https_proxy {
206            builder = builder.proxy(reqwest::Proxy::https(&proxy_url)?);
207        }
208
209        let client = builder
210            .build()
211            .map_err(|e| SourceError::Network(format!("Failed to create HTTP client: {}", e)))?;
212
213        Ok(Self {
214            client: Arc::new(client),
215            rate_limiter,
216            no_proxy: proxy.no_proxy,
217        })
218    }
219
220    /// Create a new HTTP client without rate limiting
221    pub fn without_rate_limit(user_agent: &str) -> Result<Self, SourceError> {
222        let proxy = create_proxy_config();
223        let mut builder = Client::builder()
224            .user_agent(user_agent)
225            .timeout(Duration::from_secs(30))
226            .connect_timeout(Duration::from_secs(10))
227            .pool_idle_timeout(Duration::from_secs(90));
228
229        if let Some(proxy_url) = proxy.http_proxy {
230            builder = builder.proxy(reqwest::Proxy::http(&proxy_url)?);
231        }
232        if let Some(proxy_url) = proxy.https_proxy {
233            builder = builder.proxy(reqwest::Proxy::https(&proxy_url)?);
234        }
235
236        let client = builder
237            .build()
238            .map_err(|e| SourceError::Network(format!("Failed to create HTTP client: {}", e)))?;
239
240        Ok(Self {
241            client: Arc::new(client),
242            rate_limiter: None,
243            no_proxy: proxy.no_proxy,
244        })
245    }
246
247    /// Check if a URL should bypass the proxy
248    pub fn should_bypass_proxy(&self, url: &str) -> bool {
249        should_bypass_proxy(url, &self.no_proxy)
250    }
251
252    /// Create a new HTTP client with a custom rate limit
253    pub fn with_rate_limit(
254        user_agent: &str,
255        requests_per_second: u32,
256    ) -> Result<Self, SourceError> {
257        let rate_limiter = if requests_per_second == 0 {
258            None
259        } else {
260            let nonzero = NonZeroU32::new(requests_per_second)
261                .expect("requests_per_second should be > 0 when not 0 branch");
262            let quota = Quota::per_second(nonzero);
263            Some(Arc::new(RateLimiter::direct(quota)))
264        };
265
266        let proxy = create_proxy_config();
267        let mut builder = Client::builder()
268            .user_agent(user_agent)
269            .timeout(Duration::from_secs(30))
270            .connect_timeout(Duration::from_secs(10))
271            .pool_idle_timeout(Duration::from_secs(90));
272
273        if let Some(proxy_url) = proxy.http_proxy {
274            builder = builder.proxy(reqwest::Proxy::http(&proxy_url)?);
275        }
276        if let Some(proxy_url) = proxy.https_proxy {
277            builder = builder.proxy(reqwest::Proxy::https(&proxy_url)?);
278        }
279
280        let client = builder
281            .build()
282            .map_err(|e| SourceError::Network(format!("Failed to create HTTP client: {}", e)))?;
283
284        Ok(Self {
285            client: Arc::new(client),
286            rate_limiter,
287            no_proxy: proxy.no_proxy,
288        })
289    }
290
291    /// Create HTTP client with per-source proxy
292    pub fn with_proxy(
293        user_agent: &str,
294        http_proxy: Option<String>,
295        https_proxy: Option<String>,
296        requests_per_second: u32,
297    ) -> Result<Self, SourceError> {
298        let rate_limiter = if requests_per_second == 0 {
299            None
300        } else {
301            let nonzero = NonZeroU32::new(requests_per_second)
302                .expect("requests_per_second should be > 0 when not 0 branch");
303            let quota = Quota::per_second(nonzero);
304            Some(Arc::new(RateLimiter::direct(quota)))
305        };
306
307        let mut builder = Client::builder()
308            .user_agent(user_agent)
309            .timeout(Duration::from_secs(30))
310            .connect_timeout(Duration::from_secs(10))
311            .pool_idle_timeout(Duration::from_secs(90));
312
313        if let Some(proxy_url) = http_proxy {
314            builder = builder.proxy(reqwest::Proxy::http(&proxy_url)?);
315        }
316        if let Some(proxy_url) = https_proxy {
317            builder = builder.proxy(reqwest::Proxy::https(&proxy_url)?);
318        }
319
320        let client = builder
321            .build()
322            .map_err(|e| SourceError::Network(format!("Failed to create HTTP client: {}", e)))?;
323
324        Ok(Self {
325            client: Arc::new(client),
326            rate_limiter,
327            no_proxy: None, // Per-source proxy doesn't use env no_proxy
328        })
329    }
330
331    /// Create rate limiter from environment variable or default
332    fn create_rate_limiter() -> Option<Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>> {
333        let requests_per_second = std::env::var(RATE_LIMIT_ENV_VAR)
334            .ok()
335            .and_then(|s| s.parse::<u32>().ok())
336            .unwrap_or(DEFAULT_REQUESTS_PER_SECOND);
337
338        if requests_per_second == 0 {
339            // Rate limiting disabled
340            tracing::info!("Rate limiting disabled");
341            return None;
342        }
343
344        let nonzero =
345            NonZeroU32::new(requests_per_second).expect("requests_per_second should not be zero");
346        let quota = Quota::per_second(nonzero);
347        let limiter = RateLimiter::direct(quota);
348
349        tracing::info!(
350            "Rate limiting enabled: {} requests per second",
351            requests_per_second
352        );
353
354        Some(Arc::new(limiter))
355    }
356
357    /// Create from an existing reqwest Client
358    pub fn from_client(client: Arc<Client>) -> Self {
359        Self {
360            client,
361            rate_limiter: Self::create_rate_limiter(),
362            no_proxy: None,
363        }
364    }
365
366    /// Get the underlying client
367    pub fn client(&self) -> &Client {
368        &self.client
369    }
370
371    /// Create a rate-limited GET request builder
372    pub fn get(&self, url: &str) -> RateLimitedRequestBuilder {
373        RateLimitedRequestBuilder {
374            inner: self.client.get(url),
375            rate_limiter: self.rate_limiter.clone(),
376        }
377    }
378
379    /// Create a rate-limited POST request builder
380    pub fn post(&self, url: &str) -> RateLimitedRequestBuilder {
381        RateLimitedRequestBuilder {
382            inner: self.client.post(url),
383            rate_limiter: self.rate_limiter.clone(),
384        }
385    }
386
387    /// Download a file from a URL to the specified path
388    pub async fn download_to_file(
389        &self,
390        url: &str,
391        request: &DownloadRequest,
392        filename: &str,
393    ) -> Result<DownloadResult, SourceError> {
394        if let Some(ref limiter) = self.rate_limiter {
395            limiter.until_ready().await;
396        }
397
398        let response = self
399            .client
400            .get(url)
401            .send()
402            .await
403            .map_err(|e| SourceError::Network(format!("Failed to download: {}", e)))?;
404
405        if !response.status().is_success() {
406            return Err(SourceError::NotFound(format!(
407                "Failed to download: HTTP {}",
408                response.status()
409            )));
410        }
411
412        let bytes = response
413            .bytes()
414            .await
415            .map_err(|e| SourceError::Network(format!("Failed to read response: {}", e)))?;
416
417        // Create download directory if it doesn't exist
418        std::fs::create_dir_all(&request.save_path).map_err(|e| {
419            SourceError::Io(std::io::Error::other(format!(
420                "Failed to create directory: {}",
421                e
422            )))
423        })?;
424
425        let path = Path::new(&request.save_path).join(filename);
426
427        std::fs::write(&path, bytes.as_ref()).map_err(SourceError::Io)?;
428
429        Ok(DownloadResult::success(
430            path.to_string_lossy().to_string(),
431            bytes.len() as u64,
432        ))
433    }
434
435    /// Download a PDF with a sanitized filename
436    pub async fn download_pdf(
437        &self,
438        url: &str,
439        request: &DownloadRequest,
440        paper_id: &str,
441    ) -> Result<DownloadResult, SourceError> {
442        let filename = format!("{}.pdf", paper_id.replace('/', "_"));
443        self.download_to_file(url, request, &filename).await
444    }
445
446    /// Check if a URL returns success status
447    pub async fn head(&self, url: &str) -> Result<bool, SourceError> {
448        if let Some(ref limiter) = self.rate_limiter {
449            limiter.until_ready().await;
450        }
451
452        let response = self
453            .client
454            .head(url)
455            .send()
456            .await
457            .map_err(|e| SourceError::Network(format!("Head request failed: {}", e)))?;
458        Ok(response.status() == StatusCode::OK)
459    }
460}
461
/// `Default` delegates to [`HttpClient::new`].
impl Default for HttpClient {
    /// Build the default client.
    ///
    /// # Panics
    ///
    /// Panics if `HttpClient::new` returns an error; use `HttpClient::new`
    /// directly to handle that failure instead.
    fn default() -> Self {
        Self::new().expect("Failed to create default HTTP client")
    }
}
467
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, OnceLock};

    /// Process-wide lock serializing tests that touch the environment.
    fn env_lock() -> &'static Mutex<()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(()))
    }

    /// Puts the rate-limit variable back to its previous state when dropped,
    /// including while unwinding from a failed assertion.
    struct EnvRestore {
        previous: Option<String>,
    }

    impl Drop for EnvRestore {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(v) => std::env::set_var(RATE_LIMIT_ENV_VAR, v),
                None => std::env::remove_var(RATE_LIMIT_ENV_VAR),
            }
        }
    }

    /// Run `f` with the rate-limit variable set to `value` (or removed for
    /// `None`), restoring the previous value afterwards.
    fn with_rate_limit_env<T>(value: Option<&str>, f: impl FnOnce() -> T) -> T {
        // The mutex guards no data, so a poisoned lock (left behind by a
        // failing test) is safe to reuse; recovering keeps one failure from
        // cascading into every other environment test.
        let _guard = env_lock()
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        // Declared after the lock so it is dropped first: the variable is
        // restored before the lock is released.
        let _restore = EnvRestore {
            previous: std::env::var(RATE_LIMIT_ENV_VAR).ok(),
        };

        match value {
            Some(v) => std::env::set_var(RATE_LIMIT_ENV_VAR, v),
            None => std::env::remove_var(RATE_LIMIT_ENV_VAR),
        }

        f()
    }

    #[test]
    fn test_create_rate_limiter_with_default() {
        with_rate_limit_env(None, || {
            let limiter = HttpClient::create_rate_limiter();
            assert!(limiter.is_some(), "Default rate limiter should be created");
        });
    }

    #[test]
    fn test_create_rate_limiter_disabled() {
        with_rate_limit_env(Some("0"), || {
            let limiter = HttpClient::create_rate_limiter();
            assert!(
                limiter.is_none(),
                "Rate limiter should be disabled when set to 0"
            );
        });
    }

    #[test]
    fn test_create_rate_limiter_custom() {
        with_rate_limit_env(Some("10"), || {
            let limiter = HttpClient::create_rate_limiter();
            assert!(limiter.is_some(), "Custom rate limiter should be created");
        });
    }

    #[test]
    fn test_create_rate_limiter_invalid() {
        with_rate_limit_env(Some("invalid"), || {
            let limiter = HttpClient::create_rate_limiter();
            // Should fall back to default when invalid value is provided
            assert!(
                limiter.is_some(),
                "Should fall back to default rate limiter"
            );
        });
    }

    #[test]
    fn test_should_bypass_proxy_no_config() {
        // No no_proxy configured
        let result = should_bypass_proxy("https://api.semanticscholar.org", &None);
        assert!(!result, "Should not bypass when no no_proxy configured");
    }

    #[test]
    fn test_should_bypass_proxy_wildcard() {
        let no_proxy = Some(vec!["*".to_string()]);
        let result = should_bypass_proxy("https://api.semanticscholar.org", &no_proxy);
        assert!(result, "Should bypass for wildcard");
    }

    #[test]
    fn test_should_bypass_proxy_exact_match() {
        let no_proxy = Some(vec!["api.semanticscholar.org".to_string()]);
        let result = should_bypass_proxy("https://api.semanticscholar.org", &no_proxy);
        assert!(result, "Should bypass for exact match");
    }

    #[test]
    fn test_should_bypass_proxy_domain_suffix() {
        let no_proxy = Some(vec!["semanticscholar.org".to_string()]);
        let result = should_bypass_proxy("https://api.semanticscholar.org", &no_proxy);
        assert!(result, "Should bypass for domain suffix match");
    }

    #[test]
    fn test_should_bypass_proxy_no_match() {
        let no_proxy = Some(vec!["other-domain.org".to_string()]);
        let result = should_bypass_proxy("https://api.semanticscholar.org", &no_proxy);
        assert!(!result, "Should not bypass when domain doesn't match");
    }

    #[test]
    fn test_should_bypass_proxy_multiple_hosts() {
        let no_proxy = Some(vec![
            "api.semanticscholar.org".to_string(),
            "arxiv.org".to_string(),
        ]);
        assert!(should_bypass_proxy(
            "https://api.semanticscholar.org",
            &no_proxy
        ));
        assert!(should_bypass_proxy("https://arxiv.org", &no_proxy));
        assert!(!should_bypass_proxy("https://openalex.org", &no_proxy));
    }
}