Skip to main content

nano_get/
request.rs

1use std::time::SystemTime;
2
3use crate::auth::basic_authorization_value;
4use crate::client::Client;
5use crate::date::format_http_date;
6use crate::errors::NanoGetError;
7use crate::url::{ToUrl, Url};
8
/// `User-Agent` value used by the test-only default header set.
#[cfg(test)]
const DEFAULT_USER_AGENT: &str = "nano-get/0.3.0";
/// `Accept` value used by the test-only default header set.
#[cfg(test)]
const DEFAULT_ACCEPT: &str = "*/*";
13
/// A single HTTP header field.
///
/// Header name matching in this crate is ASCII case-insensitive.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Header {
    /// Field name, stored exactly as supplied (no case normalization).
    name: String,
    /// Field value, stored exactly as supplied.
    value: String,
}
22
23impl Header {
24    /// Creates a new header, validating the header name and value.
25    pub fn new(name: impl Into<String>, value: impl Into<String>) -> Result<Self, NanoGetError> {
26        let name = name.into();
27        let value = value.into();
28        validate_header_name(&name)?;
29        validate_header_value(&value)?;
30        Ok(Self { name, value })
31    }
32
33    pub(crate) fn unchecked(name: impl Into<String>, value: impl Into<String>) -> Self {
34        Self {
35            name: name.into(),
36            value: value.into(),
37        }
38    }
39
40    /// Returns the header field-name as provided.
41    pub fn name(&self) -> &str {
42        &self.name
43    }
44
45    /// Returns the header field-value as provided.
46    pub fn value(&self) -> &str {
47        &self.value
48    }
49
50    /// Returns `true` when `needle` matches this header name (case-insensitive).
51    pub fn matches_name(&self, needle: &str) -> bool {
52        self.name.eq_ignore_ascii_case(needle)
53    }
54}
55
/// Supported request methods.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Method {
    /// HTTP `GET`.
    Get,
    /// HTTP `HEAD`.
    Head,
}

impl Method {
    /// Returns the method token exactly as it is written on the request line.
    pub fn as_str(self) -> &'static str {
        match self {
            Method::Head => "HEAD",
            Method::Get => "GET",
        }
    }
}
74
/// Redirect behavior for a request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RedirectPolicy {
    /// Do not follow redirects automatically.
    None,
    /// Follow redirects up to `max_redirects`.
    Follow {
        /// Maximum number of redirect hops to follow.
        max_redirects: usize,
    },
}

impl RedirectPolicy {
    /// Returns [`RedirectPolicy::None`]; usable in `const` contexts.
    pub const fn none() -> Self {
        RedirectPolicy::None
    }

    /// Returns [`RedirectPolicy::Follow`] capped at `max_redirects` hops.
    pub const fn follow(max_redirects: usize) -> Self {
        RedirectPolicy::Follow { max_redirects }
    }

    /// Returns the hop limit when redirects are followed, or `None` for
    /// [`RedirectPolicy::None`].
    pub fn max_redirects(self) -> Option<usize> {
        if let RedirectPolicy::Follow { max_redirects } = self {
            Some(max_redirects)
        } else {
            None
        }
    }
}
106
107/// A typed HTTP request for `GET` or `HEAD`.
108#[derive(Debug, Clone)]
109pub struct Request {
110    url: Url,
111    method: Method,
112    headers: Vec<Header>,
113    redirect_policy: RedirectPolicy,
114    redirect_policy_explicit: bool,
115    preemptive_origin_auth_allowed: bool,
116}
117
118impl Request {
119    /// Creates a new request with the given method and URL.
120    pub fn new<U: ToUrl>(method: Method, url: U) -> Result<Self, NanoGetError> {
121        Ok(Self {
122            url: url.to_url()?,
123            method,
124            headers: Vec::new(),
125            redirect_policy: RedirectPolicy::none(),
126            redirect_policy_explicit: false,
127            preemptive_origin_auth_allowed: true,
128        })
129    }
130
131    /// Creates a new `GET` request.
132    pub fn get<U: ToUrl>(url: U) -> Result<Self, NanoGetError> {
133        Self::new(Method::Get, url)
134    }
135
136    /// Creates a new `HEAD` request.
137    pub fn head<U: ToUrl>(url: U) -> Result<Self, NanoGetError> {
138        Self::new(Method::Head, url)
139    }
140
141    /// Returns the request method.
142    pub fn method(&self) -> Method {
143        self.method
144    }
145
146    /// Returns the parsed request URL.
147    pub fn url(&self) -> &Url {
148        &self.url
149    }
150
151    /// Returns all request headers in insertion order.
152    pub fn headers(&self) -> &[Header] {
153        &self.headers
154    }
155
156    /// Returns the first header value matching `name` (case-insensitive).
157    pub fn header(&self, name: &str) -> Option<&str> {
158        self.headers
159            .iter()
160            .find(|header| header.matches_name(name))
161            .map(Header::value)
162    }
163
164    /// Iterates over all headers matching `name` (case-insensitive), preserving insertion order.
165    pub fn headers_named<'a>(&'a self, name: &'a str) -> impl Iterator<Item = &'a Header> + 'a {
166        self.headers
167            .iter()
168            .filter(move |header| header.matches_name(name))
169    }
170
171    /// Returns this request's redirect policy.
172    pub fn redirect_policy(&self) -> RedirectPolicy {
173        self.redirect_policy
174    }
175
176    /// Sets redirect policy on this request using builder style.
177    pub fn with_redirect_policy(mut self, policy: RedirectPolicy) -> Self {
178        self.redirect_policy = policy;
179        self.redirect_policy_explicit = true;
180        self
181    }
182
183    /// Sets redirect policy on this request in-place.
184    pub fn set_redirect_policy(&mut self, policy: RedirectPolicy) -> &mut Self {
185        self.redirect_policy = policy;
186        self.redirect_policy_explicit = true;
187        self
188    }
189
190    /// Appends a header without replacing existing headers of the same name.
191    ///
192    /// Protocol-managed and hop-by-hop header names are rejected.
193    pub fn add_header(
194        &mut self,
195        name: impl Into<String>,
196        value: impl Into<String>,
197    ) -> Result<&mut Self, NanoGetError> {
198        let name = name.into();
199        validate_request_header_name(&name)?;
200        self.headers.push(Header::new(name, value)?);
201        Ok(self)
202    }
203
204    /// Sets a header value by removing existing headers with the same name first.
205    ///
206    /// Protocol-managed and hop-by-hop header names are rejected.
207    pub fn set_header(
208        &mut self,
209        name: impl Into<String>,
210        value: impl Into<String>,
211    ) -> Result<&mut Self, NanoGetError> {
212        let name = name.into();
213        validate_request_header_name(&name)?;
214        self.remove_headers_named(&name);
215        self.headers.push(Header::new(name, value)?);
216        Ok(self)
217    }
218
219    /// Removes all headers with the provided name.
220    pub fn remove_headers_named(&mut self, name: &str) -> &mut Self {
221        self.headers.retain(|header| !header.matches_name(name));
222        self
223    }
224
225    /// Sets `If-None-Match`.
226    pub fn if_none_match(&mut self, etag: impl Into<String>) -> Result<&mut Self, NanoGetError> {
227        self.set_header("If-None-Match", etag)
228    }
229
230    /// Sets `If-Match`.
231    pub fn if_match(&mut self, etag: impl Into<String>) -> Result<&mut Self, NanoGetError> {
232        self.set_header("If-Match", etag)
233    }
234
235    /// Sets `If-Modified-Since` using IMF-fixdate formatting.
236    pub fn if_modified_since(&mut self, timestamp: SystemTime) -> Result<&mut Self, NanoGetError> {
237        self.set_header("If-Modified-Since", format_http_date(timestamp)?)
238    }
239
240    /// Sets `If-Unmodified-Since` using IMF-fixdate formatting.
241    pub fn if_unmodified_since(
242        &mut self,
243        timestamp: SystemTime,
244    ) -> Result<&mut Self, NanoGetError> {
245        self.set_header("If-Unmodified-Since", format_http_date(timestamp)?)
246    }
247
248    /// Sets `If-Range`.
249    pub fn if_range(&mut self, value: impl Into<String>) -> Result<&mut Self, NanoGetError> {
250        self.set_header("If-Range", value)
251    }
252
253    /// Sets an explicit `Authorization` header for this request.
254    ///
255    /// Manual request-level credentials take precedence over automatic client-level auth helpers.
256    pub fn authorization(&mut self, value: impl Into<String>) -> Result<&mut Self, NanoGetError> {
257        self.set_header("Authorization", value)
258    }
259
260    /// Sets an explicit `Proxy-Authorization` header for this request.
261    ///
262    /// Manual request-level credentials take precedence over automatic client-level proxy auth
263    /// helpers.
264    pub fn proxy_authorization(
265        &mut self,
266        value: impl Into<String>,
267    ) -> Result<&mut Self, NanoGetError> {
268        self.set_header("Proxy-Authorization", value)
269    }
270
271    /// Encodes `username:password` as HTTP Basic credentials and stores them in
272    /// `Authorization`.
273    pub fn basic_auth(
274        &mut self,
275        username: impl Into<String>,
276        password: impl Into<String>,
277    ) -> Result<&mut Self, NanoGetError> {
278        self.authorization(basic_authorization_value(username.into(), password.into()))
279    }
280
281    /// Encodes `username:password` as HTTP Basic credentials and stores them in
282    /// `Proxy-Authorization`.
283    pub fn proxy_basic_auth(
284        &mut self,
285        username: impl Into<String>,
286        password: impl Into<String>,
287    ) -> Result<&mut Self, NanoGetError> {
288        self.proxy_authorization(basic_authorization_value(username.into(), password.into()))
289    }
290
291    /// Sets a `Range` header for byte-range requests.
292    ///
293    /// Valid forms:
294    /// - `Some(start), Some(end)` => `bytes=start-end`
295    /// - `Some(start), None` => `bytes=start-`
296    /// - `None, Some(end)` => `bytes=-end`
297    pub fn range_bytes(
298        &mut self,
299        start: Option<u64>,
300        end: Option<u64>,
301    ) -> Result<&mut Self, NanoGetError> {
302        let range = match (start, end) {
303            (Some(start), Some(end)) if start <= end => format!("bytes={start}-{end}"),
304            (Some(start), None) => format!("bytes={start}-"),
305            (None, Some(end)) => format!("bytes=-{end}"),
306            _ => {
307                return Err(NanoGetError::InvalidHeaderValue(
308                    "invalid byte range".to_string(),
309                ))
310            }
311        };
312
313        self.set_header("Range", range)
314    }
315
316    /// Executes this request using [`Client::default`].
317    ///
318    /// Use [`crate::Client`] directly when you need explicit client configuration.
319    pub fn execute(&self) -> Result<crate::response::Response, NanoGetError> {
320        Client::default().execute(self.clone())
321    }
322
323    pub(crate) fn has_header(&self, name: &str) -> bool {
324        self.headers.iter().any(|header| header.matches_name(name))
325    }
326
327    #[cfg(test)]
328    pub(crate) fn default_headers(&self) -> [Header; 4] {
329        self.default_headers_for(true)
330    }
331
332    #[cfg(test)]
333    pub(crate) fn default_headers_for(&self, connection_close: bool) -> [Header; 4] {
334        [
335            Header::unchecked("Host", self.url.host_header_value()),
336            Header::unchecked("User-Agent", DEFAULT_USER_AGENT),
337            Header::unchecked("Accept", DEFAULT_ACCEPT),
338            Header::unchecked(
339                "Connection",
340                if connection_close {
341                    "close"
342                } else {
343                    "keep-alive"
344                },
345            ),
346        ]
347    }
348
349    pub(crate) fn clone_with_url(&self, url: Url) -> Self {
350        let mut cloned = self.clone();
351        cloned.url = url;
352        cloned
353    }
354
355    pub(crate) fn effective_redirect_policy(&self, fallback: RedirectPolicy) -> RedirectPolicy {
356        if self.redirect_policy_explicit {
357            self.redirect_policy
358        } else {
359            fallback
360        }
361    }
362
363    pub(crate) fn preemptive_origin_auth_allowed(&self) -> bool {
364        self.preemptive_origin_auth_allowed
365    }
366
367    pub(crate) fn disable_preemptive_origin_auth(&mut self) {
368        self.preemptive_origin_auth_allowed = false;
369    }
370}
371
372fn validate_header_name(name: &str) -> Result<(), NanoGetError> {
373    if name.is_empty() || !name.as_bytes().iter().all(|byte| is_tchar(*byte)) {
374        return Err(NanoGetError::InvalidHeaderName(name.to_string()));
375    }
376    Ok(())
377}
378
379fn validate_header_value(value: &str) -> Result<(), NanoGetError> {
380    if value
381        .chars()
382        .any(|ch| ch == '\r' || ch == '\n' || (ch.is_ascii_control() && ch != '\t'))
383    {
384        return Err(NanoGetError::InvalidHeaderValue(value.to_string()));
385    }
386    Ok(())
387}
388
/// Returns `true` when `byte` is an RFC 9110 `tchar`: an ASCII letter or digit, or one of
/// ``!#$%&'*+-.^_`|~``.
fn is_tchar(byte: u8) -> bool {
    const TCHAR_SYMBOLS: &[u8] = b"!#$%&'*+-.^_`|~";
    byte.is_ascii_alphanumeric() || TCHAR_SYMBOLS.contains(&byte)
}
409
410fn validate_request_header_name(name: &str) -> Result<(), NanoGetError> {
411    if matches_protocol_managed_header(name) {
412        return Err(NanoGetError::ProtocolManagedHeader(name.to_string()));
413    }
414
415    if matches_hop_by_hop_header(name) {
416        return Err(NanoGetError::HopByHopHeader(name.to_string()));
417    }
418
419    Ok(())
420}
421
/// Returns `true` when `name` (ASCII case-insensitive) is one of the protocol-managed header
/// names that request-level header setters refuse.
fn matches_protocol_managed_header(name: &str) -> bool {
    const MANAGED: [&str; 6] = [
        "host",
        "connection",
        "content-length",
        "transfer-encoding",
        "trailer",
        "upgrade",
    ];
    MANAGED
        .iter()
        .any(|managed| name.eq_ignore_ascii_case(managed))
}
428
/// Returns `true` when `name` (ASCII case-insensitive) is one of the hop-by-hop header names
/// that request-level header setters refuse.
fn matches_hop_by_hop_header(name: &str) -> bool {
    const HOP_BY_HOP: [&str; 3] = ["keep-alive", "proxy-connection", "te"];
    HOP_BY_HOP
        .iter()
        .any(|hop| name.eq_ignore_ascii_case(hop))
}
435
/// Returns `true` for the redirect status codes this crate follows: 301, 302, 303, 307, 308.
pub(crate) fn should_follow_redirect(status_code: u16) -> bool {
    const REDIRECT_STATUSES: [u16; 5] = [301, 302, 303, 307, 308];
    REDIRECT_STATUSES.contains(&status_code)
}
439
#[cfg(test)]
mod tests {
    use std::time::{Duration, UNIX_EPOCH};

    use super::{Method, RedirectPolicy, Request};
    use crate::errors::NanoGetError;

    /// Seconds after the Unix epoch for `Sun, 06 Nov 1994 08:49:37 GMT`.
    const EXAMPLE_HTTP_DATE_SECS: u64 = 784_111_777;

    /// Builds a fresh `GET http://example.com` request.
    fn example_get() -> Request {
        Request::get("http://example.com").unwrap()
    }

    /// Adds `name: value` to a fresh request and returns the error it must produce.
    fn add_header_error(name: &str, value: &str) -> NanoGetError {
        example_get().add_header(name, value).unwrap_err()
    }

    #[test]
    fn request_defaults_to_no_redirects() {
        assert_eq!(example_get().redirect_policy(), RedirectPolicy::None);
    }

    #[test]
    fn add_header_validates_name() {
        for bad_name in ["bad:name", "bad(name)"] {
            assert!(matches!(
                add_header_error(bad_name, "value"),
                NanoGetError::InvalidHeaderName(_)
            ));
        }
    }

    #[test]
    fn add_header_validates_value() {
        for bad_value in ["bad\r\nvalue", "bad\u{0000}value"] {
            assert!(matches!(
                add_header_error("x-test", bad_value),
                NanoGetError::InvalidHeaderValue(_)
            ));
        }
    }

    #[test]
    fn builder_updates_redirect_policy() {
        let request = Request::head("http://example.com")
            .unwrap()
            .with_redirect_policy(RedirectPolicy::follow(5));
        assert_eq!(request.method(), Method::Head);
        assert_eq!(request.redirect_policy().max_redirects(), Some(5));
        assert_eq!(RedirectPolicy::none().max_redirects(), None);
    }

    #[test]
    fn set_redirect_policy_updates_in_place() {
        let mut request = example_get();
        request.set_redirect_policy(RedirectPolicy::follow(2));
        assert_eq!(request.redirect_policy().max_redirects(), Some(2));
    }

    #[test]
    fn default_headers_include_host() {
        let request = Request::get("http://example.com:8080/path").unwrap();
        let has_host = request
            .default_headers()
            .iter()
            .any(|header| header.matches_name("host") && header.value() == "example.com:8080");
        assert!(has_host);
    }

    #[test]
    fn set_header_replaces_existing_values() {
        let mut request = example_get();
        request.add_header("X-Test", "one").unwrap();
        request.set_header("x-test", "two").unwrap();
        let values: Vec<&str> = request
            .headers_named("X-Test")
            .map(|header| header.value())
            .collect();
        assert_eq!(values, vec!["two"]);
    }

    #[test]
    fn range_header_helper_supports_suffixes() {
        let mut request = example_get();
        request.range_bytes(None, Some(128)).unwrap();
        assert_eq!(request.header("range"), Some("bytes=-128"));
    }

    #[test]
    fn authorization_helpers_set_headers() {
        let mut request = example_get();
        request.basic_auth("user", "pass").unwrap();
        request.proxy_basic_auth("proxy", "secret").unwrap();
        assert_eq!(request.header("authorization"), Some("Basic dXNlcjpwYXNz"));
        assert_eq!(
            request.header("proxy-authorization"),
            Some("Basic cHJveHk6c2VjcmV0")
        );
    }

    #[test]
    fn rejects_protocol_managed_headers() {
        let managed = [
            "Host",
            "Connection",
            "Content-Length",
            "Transfer-Encoding",
            "Trailer",
            "Upgrade",
        ];
        for name in managed {
            assert!(matches!(
                add_header_error(name, "value"),
                NanoGetError::ProtocolManagedHeader(_)
            ));
        }
    }

    #[test]
    fn rejects_hop_by_hop_headers() {
        for name in ["Keep-Alive", "Proxy-Connection", "TE"] {
            assert!(matches!(
                add_header_error(name, "value"),
                NanoGetError::HopByHopHeader(_)
            ));
        }
    }

    #[test]
    fn date_header_helpers_format_http_dates() {
        const EXPECTED_DATE: &str = "Sun, 06 Nov 1994 08:49:37 GMT";
        let timestamp = UNIX_EPOCH + Duration::from_secs(EXAMPLE_HTTP_DATE_SECS);

        let mut request = example_get();
        request.if_modified_since(timestamp).unwrap();
        request.if_unmodified_since(timestamp).unwrap();
        request.if_match("\"etag\"").unwrap();

        assert_eq!(request.header("if-modified-since"), Some(EXPECTED_DATE));
        assert_eq!(request.header("if-unmodified-since"), Some(EXPECTED_DATE));
        assert_eq!(request.header("if-match"), Some("\"etag\""));
    }

    #[test]
    fn range_helper_supports_open_ended_ranges_and_rejects_invalid_values() {
        let mut request = example_get();
        request.range_bytes(Some(128), None).unwrap();
        assert_eq!(request.header("range"), Some("bytes=128-"));

        let error = request.range_bytes(Some(10), Some(2)).unwrap_err();
        assert!(matches!(error, NanoGetError::InvalidHeaderValue(_)));
    }
}