feedparser_rs/http/
client.rs

1use super::response::FeedHttpResponse;
2use super::validation::validate_url;
3use crate::error::{FeedError, Result};
4use reqwest::blocking::{Client, Response};
5use reqwest::header::{
6    ACCEPT, ACCEPT_ENCODING, HeaderMap, HeaderName, HeaderValue, IF_MODIFIED_SINCE, IF_NONE_MATCH,
7    USER_AGENT,
8};
9use std::collections::HashMap;
10use std::time::Duration;
11
12/// HTTP client for fetching feeds
13pub struct FeedHttpClient {
14    client: Client,
15    user_agent: String,
16    timeout: Duration,
17}
18
19impl FeedHttpClient {
20    /// Creates a new HTTP client with default settings
21    ///
22    /// Default settings:
23    /// - 30 second timeout
24    /// - Gzip, deflate, and brotli compression enabled
25    /// - Maximum 10 redirects
26    /// - Custom User-Agent
27    ///
28    /// # Errors
29    ///
30    /// Returns `FeedError::Http` if the underlying HTTP client cannot be created.
31    pub fn new() -> Result<Self> {
32        let client = Client::builder()
33            .timeout(Duration::from_secs(30))
34            .gzip(true)
35            .deflate(true)
36            .brotli(true)
37            .redirect(reqwest::redirect::Policy::limited(10))
38            .build()
39            .map_err(|e| FeedError::Http {
40                message: format!("Failed to create HTTP client: {e}"),
41            })?;
42
43        Ok(Self {
44            client,
45            user_agent: format!(
46                "feedparser-rs/{} (+https://github.com/bug-ops/feedparser-rs)",
47                env!("CARGO_PKG_VERSION")
48            ),
49            timeout: Duration::from_secs(30),
50        })
51    }
52
53    /// Sets a custom User-Agent header
54    ///
55    /// # Security
56    ///
57    /// User-Agent is truncated to 512 bytes to prevent header injection attacks.
58    #[must_use]
59    pub fn with_user_agent(mut self, agent: String) -> Self {
60        // Truncate to 512 bytes to prevent header injection
61        const MAX_USER_AGENT_LEN: usize = 512;
62        self.user_agent = if agent.len() > MAX_USER_AGENT_LEN {
63            agent.chars().take(MAX_USER_AGENT_LEN).collect()
64        } else {
65            agent
66        };
67        self
68    }
69
70    /// Sets request timeout
71    #[must_use]
72    pub const fn with_timeout(mut self, timeout: Duration) -> Self {
73        self.timeout = timeout;
74        self
75    }
76
77    /// Insert header with consistent error handling
78    ///
79    /// Helper method to reduce boilerplate in header insertion.
80    #[inline]
81    fn insert_header(
82        headers: &mut HeaderMap,
83        name: HeaderName,
84        value: &str,
85        field_name: &str,
86    ) -> Result<()> {
87        headers.insert(
88            name,
89            HeaderValue::from_str(value).map_err(|e| FeedError::Http {
90                message: format!("Invalid {field_name}: {e}"),
91            })?,
92        );
93        Ok(())
94    }
95
96    /// Fetches a feed from the given URL
97    ///
98    /// Supports conditional GET with `ETag` and `Last-Modified` headers.
99    ///
100    /// # Arguments
101    ///
102    /// * `url` - HTTP/HTTPS URL to fetch
103    /// * `etag` - Optional `ETag` from previous fetch
104    /// * `modified` - Optional `Last-Modified` from previous fetch
105    /// * `extra_headers` - Additional custom headers
106    ///
107    /// # Errors
108    ///
109    /// Returns `FeedError::Http` if the request fails or headers are invalid.
110    pub fn get(
111        &self,
112        url: &str,
113        etag: Option<&str>,
114        modified: Option<&str>,
115        extra_headers: Option<&HeaderMap>,
116    ) -> Result<FeedHttpResponse> {
117        // Validate URL to prevent SSRF attacks
118        let validated_url = validate_url(url)?;
119        let url_str = validated_url.as_str();
120
121        let mut headers = HeaderMap::new();
122
123        // Standard headers
124        Self::insert_header(&mut headers, USER_AGENT, &self.user_agent, "User-Agent")?;
125
126        headers.insert(
127            ACCEPT,
128            HeaderValue::from_static(
129                "application/rss+xml, application/atom+xml, application/xml, text/xml, */*",
130            ),
131        );
132
133        headers.insert(
134            ACCEPT_ENCODING,
135            HeaderValue::from_static("gzip, deflate, br"),
136        );
137
138        // Conditional GET headers with length validation
139        if let Some(etag_val) = etag {
140            // Truncate ETag to 1KB to prevent oversized headers
141            const MAX_ETAG_LEN: usize = 1024;
142            let sanitized_etag = if etag_val.len() > MAX_ETAG_LEN {
143                &etag_val[..MAX_ETAG_LEN]
144            } else {
145                etag_val
146            };
147            Self::insert_header(&mut headers, IF_NONE_MATCH, sanitized_etag, "ETag")?;
148        }
149
150        if let Some(modified_val) = modified {
151            // Truncate Last-Modified to 64 bytes (RFC 822 dates are ~30 bytes)
152            const MAX_MODIFIED_LEN: usize = 64;
153            let sanitized_modified = if modified_val.len() > MAX_MODIFIED_LEN {
154                &modified_val[..MAX_MODIFIED_LEN]
155            } else {
156                modified_val
157            };
158            Self::insert_header(
159                &mut headers,
160                IF_MODIFIED_SINCE,
161                sanitized_modified,
162                "Last-Modified",
163            )?;
164        }
165
166        // Merge extra headers
167        if let Some(extra) = extra_headers {
168            headers.extend(extra.clone());
169        }
170
171        let response = self
172            .client
173            .get(url_str)
174            .headers(headers)
175            .send()
176            .map_err(|e| FeedError::Http {
177                message: format!("HTTP request failed: {e}"),
178            })?;
179
180        Self::build_response(response, url_str)
181    }
182
183    /// Converts `reqwest` Response to `FeedHttpResponse`
184    fn build_response(response: Response, _original_url: &str) -> Result<FeedHttpResponse> {
185        let status = response.status().as_u16();
186        let url = response.url().to_string();
187
188        // Convert headers to HashMap with pre-allocated capacity
189        let mut headers_map = HashMap::with_capacity(response.headers().len());
190        for (name, value) in response.headers() {
191            if let Ok(val_str) = value.to_str() {
192                headers_map.insert(name.to_string(), val_str.to_string());
193            }
194        }
195
196        // Extract caching headers
197        let etag = headers_map.get("etag").cloned();
198        let last_modified = headers_map.get("last-modified").cloned();
199        let content_type = headers_map.get("content-type").cloned();
200
201        // Extract encoding from Content-Type
202        let encoding = content_type
203            .as_ref()
204            .and_then(|ct| FeedHttpResponse::extract_charset_from_content_type(ct));
205
206        // Read body (handles gzip/deflate automatically)
207        let body = if status == 304 {
208            // Not Modified - no body
209            Vec::new()
210        } else {
211            response
212                .bytes()
213                .map_err(|e| FeedError::Http {
214                    message: format!("Failed to read response body: {e}"),
215                })?
216                .to_vec()
217        };
218
219        Ok(FeedHttpResponse {
220            status,
221            url,
222            headers: headers_map,
223            body,
224            etag,
225            last_modified,
226            content_type,
227            encoding,
228        })
229    }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a client with default settings, panicking if construction fails.
    fn default_client() -> FeedHttpClient {
        FeedHttpClient::new().unwrap()
    }

    /// Requests `url` with no conditional or extra headers, asserts that the
    /// call fails, and returns the error rendered as text.
    fn rejection_message(url: &str) -> String {
        let outcome = default_client().get(url, None, None, None);
        assert!(outcome.is_err());
        outcome.err().unwrap().to_string()
    }

    #[test]
    fn test_client_creation() {
        assert!(FeedHttpClient::new().is_ok());
    }

    #[test]
    fn test_custom_user_agent() {
        let agent = "CustomBot/1.0".to_string();
        let client = default_client().with_user_agent(agent);
        assert_eq!(client.user_agent, "CustomBot/1.0");
    }

    #[test]
    fn test_custom_timeout() {
        let wanted = Duration::from_secs(60);
        let client = default_client().with_timeout(wanted);
        assert_eq!(client.timeout, wanted);
    }

    // SSRF protection tests
    #[test]
    fn test_reject_localhost_url() {
        let text = rejection_message("http://localhost/feed.xml");
        assert!(text.contains("Localhost domain not allowed"));
    }

    #[test]
    fn test_reject_private_ip() {
        let text = rejection_message("http://192.168.1.1/feed.xml");
        assert!(text.contains("Private IP address not allowed"));
    }

    #[test]
    fn test_reject_metadata_endpoint() {
        let text = rejection_message("http://169.254.169.254/latest/meta-data/");
        // Either the metadata-endpoint rule or the link-local rule may fire.
        let flagged_as_metadata = text.contains("metadata");
        let flagged_as_link_local = text.contains("Link-local");
        assert!(flagged_as_metadata || flagged_as_link_local);
    }

    #[test]
    fn test_reject_file_scheme() {
        let text = rejection_message("file:///etc/passwd");
        assert!(text.contains("Unsupported URL scheme"));
    }

    #[test]
    fn test_reject_internal_domain() {
        let text = rejection_message("http://server.local/feed.xml");
        assert!(text.contains("Internal domain TLD not allowed"));
    }

    #[test]
    fn test_insert_header_valid() {
        let mut map = HeaderMap::new();
        let outcome =
            FeedHttpClient::insert_header(&mut map, USER_AGENT, "TestBot/1.0", "User-Agent");
        assert!(outcome.is_ok());
        assert_eq!(map.get(USER_AGENT).unwrap(), "TestBot/1.0");
    }

    #[test]
    fn test_insert_header_invalid_value() {
        let mut map = HeaderMap::new();
        // A newline is a control character and is not a legal header value.
        let outcome =
            FeedHttpClient::insert_header(&mut map, USER_AGENT, "Invalid\nHeader", "User-Agent");
        assert!(outcome.is_err());
        match outcome {
            Err(FeedError::Http { message }) => assert!(message.contains("Invalid User-Agent")),
            _ => panic!("Expected Http error"),
        }
    }

    #[test]
    fn test_insert_header_multiple_headers() {
        let mut map = HeaderMap::new();

        let entries = [
            (USER_AGENT, "TestBot/1.0", "User-Agent"),
            (ACCEPT, "application/xml", "Accept"),
        ];
        for (name, value, label) in entries {
            FeedHttpClient::insert_header(&mut map, name, value, label).unwrap();
        }

        assert_eq!(map.len(), 2);
        assert_eq!(map.get(USER_AGENT).unwrap(), "TestBot/1.0");
        assert_eq!(map.get(ACCEPT).unwrap(), "application/xml");
    }
}