1use std::time::SystemTime;
2
3use crate::auth::basic_authorization_value;
4use crate::client::Client;
5use crate::date::format_http_date;
6use crate::errors::NanoGetError;
7use crate::url::{ToUrl, Url};
8
/// `User-Agent` value placed in the test-only default header set.
#[cfg(test)]
const DEFAULT_USER_AGENT: &str = "nano-get/0.3.0";

/// `Accept` value placed in the test-only default header set.
#[cfg(test)]
const DEFAULT_ACCEPT: &str = "*/*";
13
14#[derive(Debug, Clone, PartialEq, Eq)]
18pub struct Header {
19 name: String,
20 value: String,
21}
22
23impl Header {
24 pub fn new(name: impl Into<String>, value: impl Into<String>) -> Result<Self, NanoGetError> {
26 let name = name.into();
27 let value = value.into();
28 validate_header_name(&name)?;
29 validate_header_value(&value)?;
30 Ok(Self { name, value })
31 }
32
33 pub(crate) fn unchecked(name: impl Into<String>, value: impl Into<String>) -> Self {
34 Self {
35 name: name.into(),
36 value: value.into(),
37 }
38 }
39
40 pub fn name(&self) -> &str {
42 &self.name
43 }
44
45 pub fn value(&self) -> &str {
47 &self.value
48 }
49
50 pub fn matches_name(&self, needle: &str) -> bool {
52 self.name.eq_ignore_ascii_case(needle)
53 }
54}
55
/// HTTP request methods that a [`Request`] can carry.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Method {
    /// The `GET` method.
    Get,
    /// The `HEAD` method.
    Head,
}

impl Method {
    /// Returns the upper-case method token (`"GET"` or `"HEAD"`).
    pub fn as_str(self) -> &'static str {
        match self {
            Method::Head => "HEAD",
            Method::Get => "GET",
        }
    }
}
74
/// How redirect responses should be handled for a request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RedirectPolicy {
    /// No redirect limit is carried; [`RedirectPolicy::max_redirects`] yields `None`.
    None,
    /// Redirects are followed, bounded by `max_redirects`.
    Follow {
        /// Limit carried by this policy.
        max_redirects: usize,
    },
}

impl RedirectPolicy {
    /// Returns the [`RedirectPolicy::None`] policy.
    pub const fn none() -> Self {
        RedirectPolicy::None
    }

    /// Returns a [`RedirectPolicy::Follow`] policy carrying `max_redirects`.
    pub const fn follow(max_redirects: usize) -> Self {
        RedirectPolicy::Follow { max_redirects }
    }

    /// Returns the limit carried by `Follow`, or `None` for the `None` policy.
    pub fn max_redirects(self) -> Option<usize> {
        match self {
            RedirectPolicy::Follow {
                max_redirects: limit,
            } => Some(limit),
            RedirectPolicy::None => None,
        }
    }
}
106
107#[derive(Debug, Clone)]
109pub struct Request {
110 url: Url,
111 method: Method,
112 headers: Vec<Header>,
113 redirect_policy: RedirectPolicy,
114 redirect_policy_explicit: bool,
115 preemptive_origin_auth_allowed: bool,
116}
117
118impl Request {
119 pub fn new<U: ToUrl>(method: Method, url: U) -> Result<Self, NanoGetError> {
121 Ok(Self {
122 url: url.to_url()?,
123 method,
124 headers: Vec::new(),
125 redirect_policy: RedirectPolicy::none(),
126 redirect_policy_explicit: false,
127 preemptive_origin_auth_allowed: true,
128 })
129 }
130
131 pub fn get<U: ToUrl>(url: U) -> Result<Self, NanoGetError> {
133 Self::new(Method::Get, url)
134 }
135
136 pub fn head<U: ToUrl>(url: U) -> Result<Self, NanoGetError> {
138 Self::new(Method::Head, url)
139 }
140
141 pub fn method(&self) -> Method {
143 self.method
144 }
145
146 pub fn url(&self) -> &Url {
148 &self.url
149 }
150
151 pub fn headers(&self) -> &[Header] {
153 &self.headers
154 }
155
156 pub fn header(&self, name: &str) -> Option<&str> {
158 self.headers
159 .iter()
160 .find(|header| header.matches_name(name))
161 .map(Header::value)
162 }
163
164 pub fn headers_named<'a>(&'a self, name: &'a str) -> impl Iterator<Item = &'a Header> + 'a {
166 self.headers
167 .iter()
168 .filter(move |header| header.matches_name(name))
169 }
170
171 pub fn redirect_policy(&self) -> RedirectPolicy {
173 self.redirect_policy
174 }
175
176 pub fn with_redirect_policy(mut self, policy: RedirectPolicy) -> Self {
178 self.redirect_policy = policy;
179 self.redirect_policy_explicit = true;
180 self
181 }
182
183 pub fn set_redirect_policy(&mut self, policy: RedirectPolicy) -> &mut Self {
185 self.redirect_policy = policy;
186 self.redirect_policy_explicit = true;
187 self
188 }
189
190 pub fn add_header(
194 &mut self,
195 name: impl Into<String>,
196 value: impl Into<String>,
197 ) -> Result<&mut Self, NanoGetError> {
198 let name = name.into();
199 validate_request_header_name(&name)?;
200 self.headers.push(Header::new(name, value)?);
201 Ok(self)
202 }
203
204 pub fn set_header(
208 &mut self,
209 name: impl Into<String>,
210 value: impl Into<String>,
211 ) -> Result<&mut Self, NanoGetError> {
212 let name = name.into();
213 validate_request_header_name(&name)?;
214 self.remove_headers_named(&name);
215 self.headers.push(Header::new(name, value)?);
216 Ok(self)
217 }
218
219 pub fn remove_headers_named(&mut self, name: &str) -> &mut Self {
221 self.headers.retain(|header| !header.matches_name(name));
222 self
223 }
224
225 pub fn if_none_match(&mut self, etag: impl Into<String>) -> Result<&mut Self, NanoGetError> {
227 self.set_header("If-None-Match", etag)
228 }
229
230 pub fn if_match(&mut self, etag: impl Into<String>) -> Result<&mut Self, NanoGetError> {
232 self.set_header("If-Match", etag)
233 }
234
235 pub fn if_modified_since(&mut self, timestamp: SystemTime) -> Result<&mut Self, NanoGetError> {
237 self.set_header("If-Modified-Since", format_http_date(timestamp)?)
238 }
239
240 pub fn if_unmodified_since(
242 &mut self,
243 timestamp: SystemTime,
244 ) -> Result<&mut Self, NanoGetError> {
245 self.set_header("If-Unmodified-Since", format_http_date(timestamp)?)
246 }
247
248 pub fn if_range(&mut self, value: impl Into<String>) -> Result<&mut Self, NanoGetError> {
250 self.set_header("If-Range", value)
251 }
252
253 pub fn authorization(&mut self, value: impl Into<String>) -> Result<&mut Self, NanoGetError> {
257 self.set_header("Authorization", value)
258 }
259
260 pub fn proxy_authorization(
265 &mut self,
266 value: impl Into<String>,
267 ) -> Result<&mut Self, NanoGetError> {
268 self.set_header("Proxy-Authorization", value)
269 }
270
271 pub fn basic_auth(
274 &mut self,
275 username: impl Into<String>,
276 password: impl Into<String>,
277 ) -> Result<&mut Self, NanoGetError> {
278 self.authorization(basic_authorization_value(username.into(), password.into()))
279 }
280
281 pub fn proxy_basic_auth(
284 &mut self,
285 username: impl Into<String>,
286 password: impl Into<String>,
287 ) -> Result<&mut Self, NanoGetError> {
288 self.proxy_authorization(basic_authorization_value(username.into(), password.into()))
289 }
290
291 pub fn range_bytes(
298 &mut self,
299 start: Option<u64>,
300 end: Option<u64>,
301 ) -> Result<&mut Self, NanoGetError> {
302 let range = match (start, end) {
303 (Some(start), Some(end)) if start <= end => format!("bytes={start}-{end}"),
304 (Some(start), None) => format!("bytes={start}-"),
305 (None, Some(end)) => format!("bytes=-{end}"),
306 _ => {
307 return Err(NanoGetError::InvalidHeaderValue(
308 "invalid byte range".to_string(),
309 ))
310 }
311 };
312
313 self.set_header("Range", range)
314 }
315
316 pub fn execute(&self) -> Result<crate::response::Response, NanoGetError> {
320 Client::default().execute(self.clone())
321 }
322
323 pub(crate) fn has_header(&self, name: &str) -> bool {
324 self.headers.iter().any(|header| header.matches_name(name))
325 }
326
327 #[cfg(test)]
328 pub(crate) fn default_headers(&self) -> [Header; 4] {
329 self.default_headers_for(true)
330 }
331
332 #[cfg(test)]
333 pub(crate) fn default_headers_for(&self, connection_close: bool) -> [Header; 4] {
334 [
335 Header::unchecked("Host", self.url.host_header_value()),
336 Header::unchecked("User-Agent", DEFAULT_USER_AGENT),
337 Header::unchecked("Accept", DEFAULT_ACCEPT),
338 Header::unchecked(
339 "Connection",
340 if connection_close {
341 "close"
342 } else {
343 "keep-alive"
344 },
345 ),
346 ]
347 }
348
349 pub(crate) fn clone_with_url(&self, url: Url) -> Self {
350 let mut cloned = self.clone();
351 cloned.url = url;
352 cloned
353 }
354
355 pub(crate) fn effective_redirect_policy(&self, fallback: RedirectPolicy) -> RedirectPolicy {
356 if self.redirect_policy_explicit {
357 self.redirect_policy
358 } else {
359 fallback
360 }
361 }
362
363 pub(crate) fn preemptive_origin_auth_allowed(&self) -> bool {
364 self.preemptive_origin_auth_allowed
365 }
366
367 pub(crate) fn disable_preemptive_origin_auth(&mut self) {
368 self.preemptive_origin_auth_allowed = false;
369 }
370}
371
372fn validate_header_name(name: &str) -> Result<(), NanoGetError> {
373 if name.is_empty() || !name.as_bytes().iter().all(|byte| is_tchar(*byte)) {
374 return Err(NanoGetError::InvalidHeaderName(name.to_string()));
375 }
376 Ok(())
377}
378
379fn validate_header_value(value: &str) -> Result<(), NanoGetError> {
380 if value
381 .chars()
382 .any(|ch| ch == '\r' || ch == '\n' || (ch.is_ascii_control() && ch != '\t'))
383 {
384 return Err(NanoGetError::InvalidHeaderValue(value.to_string()));
385 }
386 Ok(())
387}
388
/// Reports whether `byte` is an ASCII letter, an ASCII digit, or one of
/// ``! # $ % & ' * + - . ^ _ ` | ~``.
fn is_tchar(byte: u8) -> bool {
    /// Punctuation accepted in addition to ASCII alphanumerics.
    const TOKEN_PUNCTUATION: &[u8] = b"!#$%&'*+-.^_`|~";

    byte.is_ascii_alphanumeric() || TOKEN_PUNCTUATION.contains(&byte)
}
409
410fn validate_request_header_name(name: &str) -> Result<(), NanoGetError> {
411 if matches_protocol_managed_header(name) {
412 return Err(NanoGetError::ProtocolManagedHeader(name.to_string()));
413 }
414
415 if matches_hop_by_hop_header(name) {
416 return Err(NanoGetError::HopByHopHeader(name.to_string()));
417 }
418
419 Ok(())
420}
421
/// Reports whether `name` equals, ignoring ASCII case, one of `host`, `connection`,
/// `content-length`, `transfer-encoding`, `trailer` or `upgrade`.
fn matches_protocol_managed_header(name: &str) -> bool {
    /// Names that `validate_request_header_name` reports as protocol-managed.
    const PROTOCOL_MANAGED: [&str; 6] = [
        "host",
        "connection",
        "content-length",
        "transfer-encoding",
        "trailer",
        "upgrade",
    ];

    PROTOCOL_MANAGED
        .iter()
        .any(|managed| name.eq_ignore_ascii_case(managed))
}
428
/// Reports whether `name` equals, ignoring ASCII case, one of `keep-alive`,
/// `proxy-connection` or `te`.
fn matches_hop_by_hop_header(name: &str) -> bool {
    /// Names that `validate_request_header_name` reports as hop-by-hop.
    const HOP_BY_HOP: [&str; 3] = ["keep-alive", "proxy-connection", "te"];

    HOP_BY_HOP
        .iter()
        .any(|hop| name.eq_ignore_ascii_case(hop))
}
435
/// Reports whether `status_code` is one of 301, 302, 303, 307 or 308.
pub(crate) fn should_follow_redirect(status_code: u16) -> bool {
    matches!(status_code, 301..=303 | 307..=308)
}
439
/// Unit tests for request construction, header management and the convenience helpers.
#[cfg(test)]
mod tests {
    use std::time::{Duration, UNIX_EPOCH};

    use super::{Method, RedirectPolicy, Request};
    use crate::errors::NanoGetError;

    /// URL used by most tests below.
    const ORIGIN: &str = "http://example.com";

    /// Builds a fresh `GET` request for `ORIGIN`.
    fn new_get() -> Request {
        Request::get(ORIGIN).unwrap()
    }

    #[test]
    fn request_defaults_to_no_redirects() {
        assert_eq!(new_get().redirect_policy(), RedirectPolicy::None);
    }

    #[test]
    fn add_header_validates_name() {
        for name in ["bad:name", "bad(name)"] {
            let error = new_get().add_header(name, "value").unwrap_err();
            assert!(matches!(error, NanoGetError::InvalidHeaderName(_)));
        }
    }

    #[test]
    fn add_header_validates_value() {
        for value in ["bad\r\nvalue", "bad\u{0000}value"] {
            let error = new_get().add_header("x-test", value).unwrap_err();
            assert!(matches!(error, NanoGetError::InvalidHeaderValue(_)));
        }
    }

    #[test]
    fn builder_updates_redirect_policy() {
        let policy = RedirectPolicy::follow(5);
        let request = Request::head(ORIGIN).unwrap().with_redirect_policy(policy);
        assert_eq!(request.method(), Method::Head);
        assert_eq!(request.redirect_policy().max_redirects(), Some(5));
        assert_eq!(RedirectPolicy::none().max_redirects(), None);
    }

    #[test]
    fn set_redirect_policy_updates_in_place() {
        let mut request = new_get();
        request.set_redirect_policy(RedirectPolicy::follow(2));
        assert_eq!(request.redirect_policy().max_redirects(), Some(2));
    }

    #[test]
    fn default_headers_include_host() {
        let request = Request::get("http://example.com:8080/path").unwrap();
        let has_host = request
            .default_headers()
            .iter()
            .any(|header| header.matches_name("host") && header.value() == "example.com:8080");
        assert!(has_host);
    }

    #[test]
    fn set_header_replaces_existing_values() {
        let mut request = new_get();
        request.add_header("X-Test", "one").unwrap();
        request.set_header("x-test", "two").unwrap();
        let values: Vec<&str> = request
            .headers_named("X-Test")
            .map(|header| header.value())
            .collect();
        assert_eq!(values, ["two"]);
    }

    #[test]
    fn range_header_helper_supports_suffixes() {
        let mut request = new_get();
        request.range_bytes(None, Some(128)).unwrap();
        assert_eq!(request.header("range"), Some("bytes=-128"));
    }

    #[test]
    fn authorization_helpers_set_headers() {
        let mut request = new_get();
        request.basic_auth("user", "pass").unwrap();
        request.proxy_basic_auth("proxy", "secret").unwrap();
        assert_eq!(request.header("authorization"), Some("Basic dXNlcjpwYXNz"));
        assert_eq!(
            request.header("proxy-authorization"),
            Some("Basic cHJveHk6c2VjcmV0")
        );
    }

    #[test]
    fn rejects_protocol_managed_headers() {
        let managed = [
            "Host",
            "Connection",
            "Content-Length",
            "Transfer-Encoding",
            "Trailer",
            "Upgrade",
        ];
        for name in managed {
            let error = new_get().add_header(name, "value").unwrap_err();
            assert!(matches!(error, NanoGetError::ProtocolManagedHeader(_)));
        }
    }

    #[test]
    fn rejects_hop_by_hop_headers() {
        for name in ["Keep-Alive", "Proxy-Connection", "TE"] {
            let error = new_get().add_header(name, "value").unwrap_err();
            assert!(matches!(error, NanoGetError::HopByHopHeader(_)));
        }
    }

    #[test]
    fn date_header_helpers_format_http_dates() {
        let timestamp = UNIX_EPOCH + Duration::from_secs(784_111_777);
        let expected_date = Some("Sun, 06 Nov 1994 08:49:37 GMT");

        let mut request = new_get();
        request.if_modified_since(timestamp).unwrap();
        request.if_unmodified_since(timestamp).unwrap();
        request.if_match("\"etag\"").unwrap();

        assert_eq!(request.header("if-modified-since"), expected_date);
        assert_eq!(request.header("if-unmodified-since"), expected_date);
        assert_eq!(request.header("if-match"), Some("\"etag\""));
    }

    #[test]
    fn range_helper_supports_open_ended_ranges_and_rejects_invalid_values() {
        let mut request = new_get();
        request.range_bytes(Some(128), None).unwrap();
        assert_eq!(request.header("range"), Some("bytes=128-"));

        let error = request.range_bytes(Some(10), Some(2)).unwrap_err();
        assert!(matches!(error, NanoGetError::InvalidHeaderValue(_)));
    }
}