1use std::{
2 convert::TryFrom,
3 fmt,
4 iter::{IntoIterator, Iterator},
5 str::FromStr,
6};
7
8use http::{
9 header::{self, HeaderName, HeaderValue},
10 uri::{self, Authority, Parts, PathAndQuery, Scheme, Uri},
11 Extensions, HeaderMap, Method, StatusCode,
12};
13
14use crate::{ext::Protocol, qpack::HeaderField};
15
/// An HTTP/3 field section: the pseudo-header fields (`:method`, `:status`, ...)
/// kept apart from the regular header fields.
#[derive(Debug)]
#[cfg_attr(test, derive(PartialEq, Clone))]
pub struct Header {
    // Pseudo-header fields; each one is optional.
    pseudo: Pseudo,
    // Regular header fields. `TryFrom<Vec<HeaderField>>` uses `append`, so
    // repeated names keep every value.
    fields: HeaderMap,
}
22
23#[allow(clippy::len_without_is_empty)]
24impl Header {
25 pub fn request(
27 method: Method,
28 uri: Uri,
29 fields: HeaderMap,
30 ext: Extensions,
31 ) -> Result<Self, HeaderError> {
32 match (uri.authority(), fields.get("host")) {
33 (None, None) => Err(HeaderError::MissingAuthority),
34 (Some(a), Some(h)) if a.as_str() != h => Err(HeaderError::ContradictedAuthority),
35 _ => Ok(Self {
36 pseudo: Pseudo::request(method, uri, ext),
37 fields,
38 }),
39 }
40 }
41
42 pub fn response(status: StatusCode, fields: HeaderMap) -> Self {
43 Self {
44 pseudo: Pseudo::response(status),
45 fields,
46 }
47 }
48
49 pub fn trailer(fields: HeaderMap) -> Self {
50 Self {
51 pseudo: Pseudo::default(),
55 fields,
56 }
57 }
58
59 pub fn into_request_parts(
60 self,
61 ) -> Result<(Method, Uri, Option<Protocol>, HeaderMap), HeaderError> {
62 let mut uri = Uri::builder();
63
64 if let Some(path) = self.pseudo.path {
65 uri = uri.path_and_query(path.as_str().as_bytes());
66 }
67
68 if let Some(scheme) = self.pseudo.scheme {
69 uri = uri.scheme(scheme.as_str().as_bytes());
70 }
71
72 match (self.pseudo.authority, self.fields.get("host")) {
84 (None, None) => return Err(HeaderError::MissingAuthority),
85 (Some(a), None) => uri = uri.authority(a.as_str().as_bytes()),
86 (None, Some(h)) => uri = uri.authority(h.as_bytes()),
87 (Some(a), Some(h)) if a.as_str() != h => {
90 return Err(HeaderError::ContradictedAuthority)
91 }
92 (Some(_), Some(h)) => uri = uri.authority(h.as_bytes()),
93 }
94
95 Ok((
96 self.pseudo.method.ok_or(HeaderError::MissingMethod)?,
97 uri.build().map_err(HeaderError::InvalidRequest)?,
102 self.pseudo.protocol,
103 self.fields,
104 ))
105 }
106
107 pub fn into_response_parts(self) -> Result<(StatusCode, HeaderMap), HeaderError> {
108 Ok((
115 self.pseudo.status.ok_or(HeaderError::MissingStatus)?,
116 self.fields,
117 ))
118 }
119
120 pub fn into_fields(self) -> HeaderMap {
121 self.fields
122 }
123
124 pub fn len(&self) -> usize {
125 self.pseudo.len() + self.fields.len()
126 }
127
128 pub fn size(&self) -> usize {
129 self.pseudo.len() + self.fields.len()
130 }
131
132 #[cfg(test)]
133 pub(crate) fn authory_mut(&mut self) -> &mut Option<Authority> {
134 &mut self.pseudo.authority
135 }
136}
137
138impl IntoIterator for Header {
139 type Item = HeaderField;
140 type IntoIter = HeaderIter;
141 fn into_iter(self) -> Self::IntoIter {
142 HeaderIter {
143 pseudo: Some(self.pseudo),
144 last_header_name: None,
145 fields: self.fields.into_iter(),
146 }
147 }
148}
149
150pub struct HeaderIter {
151 pseudo: Option<Pseudo>,
152 last_header_name: Option<HeaderName>,
153 fields: header::IntoIter<HeaderValue>,
154}
155
156impl Iterator for HeaderIter {
157 type Item = HeaderField;
158
159 fn next(&mut self) -> Option<Self::Item> {
160 if let Some(ref mut pseudo) = self.pseudo {
164 if let Some(method) = pseudo.method.take() {
165 return Some((":method", method.as_str()).into());
166 }
167
168 if let Some(scheme) = pseudo.scheme.take() {
169 return Some((":scheme", scheme.as_str().as_bytes()).into());
170 }
171
172 if let Some(authority) = pseudo.authority.take() {
173 return Some((":authority", authority.as_str().as_bytes()).into());
174 }
175
176 if let Some(path) = pseudo.path.take() {
177 return Some((":path", path.as_str().as_bytes()).into());
178 }
179
180 if let Some(status) = pseudo.status.take() {
181 return Some((":status", status.as_str()).into());
182 }
183
184 if let Some(protocol) = pseudo.protocol.take() {
185 return Some((":protocol", protocol.as_str().as_bytes()).into());
186 }
187 }
188
189 self.pseudo = None;
190
191 for (new_header_name, header_value) in self.fields.by_ref() {
192 if let Some(new) = new_header_name {
193 self.last_header_name = Some(new);
194 }
195 if let (Some(ref n), v) = (&self.last_header_name, header_value) {
196 return Some((n.as_str(), v.as_bytes()).into());
197 }
198 }
199
200 None
201 }
202}
203
204impl TryFrom<Vec<HeaderField>> for Header {
205 type Error = HeaderError;
206 fn try_from(headers: Vec<HeaderField>) -> Result<Self, Self::Error> {
207 let mut fields = HeaderMap::with_capacity(headers.len());
208 let mut pseudo = Pseudo::default();
209
210 for field in headers.into_iter() {
211 let (name, value) = field.into_inner();
212 match Field::parse(name, value)? {
213 Field::Method(m) => {
214 pseudo.method = Some(m);
215 pseudo.len += 1;
216 }
217 Field::Scheme(s) => {
218 pseudo.scheme = Some(s);
219 pseudo.len += 1;
220 }
221 Field::Authority(a) => {
222 pseudo.authority = Some(a);
223 pseudo.len += 1;
224 }
225 Field::Path(p) => {
226 pseudo.path = Some(p);
227 pseudo.len += 1;
228 }
229 Field::Status(s) => {
230 pseudo.status = Some(s);
231 pseudo.len += 1;
232 }
233 Field::Header((n, v)) => {
234 fields.append(n, v);
235 }
236 Field::Protocol(p) => {
237 pseudo.protocol = Some(p);
238 pseudo.len += 1;
239 }
240 }
241 }
242
243 Ok(Header { pseudo, fields })
244 }
245}
246
/// One decoded field, classified by `Field::parse`.
enum Field {
    /// `:method`
    Method(Method),
    /// `:scheme`
    Scheme(Scheme),
    /// `:authority`
    Authority(Authority),
    /// `:path`
    Path(PathAndQuery),
    /// `:status`
    Status(StatusCode),
    /// `:protocol`
    Protocol(Protocol),
    /// Any regular (non-`:`-prefixed) header field.
    Header((HeaderName, HeaderValue)),
}
256
257impl Field {
258 fn parse<N, V>(name: N, value: V) -> Result<Self, HeaderError>
259 where
260 N: AsRef<[u8]>,
261 V: AsRef<[u8]>,
262 {
263 let name = name.as_ref();
264 if name.is_empty() {
265 return Err(HeaderError::InvalidHeaderName("name is empty".into()));
266 }
267
268 if name[0] != b':' {
284 return Ok(Field::Header((
285 HeaderName::from_lowercase(name).map_err(|_| HeaderError::invalid_name(name))?,
286 HeaderValue::from_bytes(value.as_ref())
287 .map_err(|_| HeaderError::invalid_value(name, value))?,
288 )));
289 }
290
291 Ok(match name {
292 b":scheme" => Field::Scheme(try_value(name, value)?),
293 b":authority" => Field::Authority(try_value(name, value)?),
297 b":path" => Field::Path(try_value(name, value)?),
298 b":method" => Field::Method(
299 Method::from_bytes(value.as_ref())
300 .map_err(|_| HeaderError::invalid_value(name, value))?,
301 ),
302 b":status" => Field::Status(
303 StatusCode::from_bytes(value.as_ref())
304 .map_err(|_| HeaderError::invalid_value(name, value))?,
305 ),
306 b":protocol" => Field::Protocol(try_value(name, value)?),
307 _ => return Err(HeaderError::invalid_name(name)),
308 })
309 }
310}
311
312fn try_value<N, V, R>(name: N, value: V) -> Result<R, HeaderError>
313where
314 N: AsRef<[u8]>,
315 V: AsRef<[u8]>,
316 R: FromStr,
317{
318 let (name, value) = (name.as_ref(), value.as_ref());
319 let s = std::str::from_utf8(value).map_err(|_| HeaderError::invalid_value(name, value))?;
320 R::from_str(s).map_err(|_| HeaderError::invalid_value(name, value))
321}
322
/// Pseudo-header fields of a field section; each is `None` when absent.
#[derive(Debug, Default)]
#[cfg_attr(test, derive(PartialEq, Clone))]
struct Pseudo {
    // Request pseudo-header fields.
    method: Option<Method>,
    scheme: Option<Scheme>,
    authority: Option<Authority>,
    path: Option<PathAndQuery>,

    // Response pseudo-header field.
    status: Option<StatusCode>,

    // `:protocol`; `Pseudo::request` only fills it for `CONNECT` requests.
    protocol: Option<Protocol>,

    // Count of populated pseudo-header fields. It is stored rather than
    // recomputed, and is maintained by the constructors and by
    // `Header::try_from`.
    len: usize,
}
350
351#[allow(clippy::len_without_is_empty)]
352impl Pseudo {
353 fn request(method: Method, uri: Uri, ext: Extensions) -> Self {
354 let Parts {
355 scheme,
356 authority,
357 path_and_query,
358 ..
359 } = uri::Parts::from(uri);
360
361 let path = path_and_query.map_or_else(
367 || PathAndQuery::from_static("/"),
368 |path| {
369 if path.path().is_empty() && method != Method::OPTIONS {
370 PathAndQuery::from_static("/")
371 } else {
372 path
373 }
374 },
375 );
376
377 let protocol = if method == Method::CONNECT {
381 ext.get::<Protocol>().copied()
382 } else {
383 None
384 };
385
386 let len = 3 + authority.is_some() as usize + protocol.is_some() as usize;
387
388 Self {
400 method: Some(method),
401 scheme: scheme.or(Some(Scheme::HTTPS)),
402 authority,
403 path: Some(path),
404 status: None,
405 protocol,
406 len,
407 }
408 }
409
410 fn response(status: StatusCode) -> Self {
411 Pseudo {
417 method: None,
418 scheme: None,
419 authority: None,
420 path: None,
421 status: Some(status),
422 len: 1,
423 protocol: None,
424 }
425 }
426
427 fn len(&self) -> usize {
428 self.len
429 }
430}
431
/// Errors produced while building, decoding or validating a [`Header`].
#[derive(Debug)]
pub enum HeaderError {
    /// A field name was rejected, for example because it is empty, is an
    /// unknown pseudo-header, or is not a valid lowercase header name.
    InvalidHeaderName(String),
    /// A field value could not be parsed for its field.
    InvalidHeaderValue(String),
    /// The request URI could not be assembled from the pseudo-header fields.
    InvalidRequest(http::Error),
    /// A request section had no `:method`.
    MissingMethod,
    /// A response section had no `:status`.
    MissingStatus,
    /// Neither an authority nor a `host` header was available.
    MissingAuthority,
    /// An authority and a `host` header were both present but differed.
    ContradictedAuthority,
}
442
443impl HeaderError {
444 fn invalid_name<N>(name: N) -> Self
445 where
446 N: AsRef<[u8]>,
447 {
448 HeaderError::InvalidHeaderName(format!("{:?}", name.as_ref()))
449 }
450
451 fn invalid_value<N, V>(name: N, value: V) -> Self
452 where
453 N: AsRef<[u8]>,
454 V: AsRef<[u8]>,
455 {
456 HeaderError::InvalidHeaderValue(format!(
457 "{:?} {:?}",
458 String::from_utf8_lossy(name.as_ref()),
459 value.as_ref()
460 ))
461 }
462}
463
464impl std::error::Error for HeaderError {}
465
466impl fmt::Display for HeaderError {
467 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
468 match self {
469 HeaderError::InvalidHeaderName(h) => write!(f, "invalid header name: {}", h),
470 HeaderError::InvalidHeaderValue(v) => write!(f, "invalid header value: {}", v),
471 HeaderError::InvalidRequest(r) => write!(f, "invalid request: {}", r),
472 HeaderError::MissingMethod => write!(f, "missing method in request headers"),
473 HeaderError::MissingStatus => write!(f, "missing status in response headers"),
474 HeaderError::MissingAuthority => write!(f, "missing authority"),
475 HeaderError::ContradictedAuthority => {
476 write!(f, "uri and authority field are in contradiction")
477 }
478 }
479 }
480}
481
#[cfg(test)]
mod tests {
    use super::*;
    use assert_matches::assert_matches;
    use std::borrow::Cow;

    #[test]
    fn request_has_no_authority_nor_host() {
        let header = Header::try_from(vec![(b":method", Method::GET.as_str()).into()])
            .expect("a lone :method decodes");
        assert!(header.pseudo.authority.is_none());
        let parts = header.into_request_parts();
        assert_matches!(parts, Err(HeaderError::MissingAuthority));
    }

    #[test]
    fn request_has_empty_authority() {
        let decoded = Header::try_from(vec![
            (b":method", Method::GET.as_str()).into(),
            (b":authority", b"").into(),
        ]);
        assert_matches!(decoded, Err(HeaderError::InvalidHeaderValue(_)));
    }

    #[test]
    fn request_has_empty_host() {
        let header = Header::try_from(vec![
            (b":method", Method::GET.as_str()).into(),
            (b"host", b"").into(),
        ])
        .expect("an empty host is accepted at decode time");
        let parts = header.into_request_parts();
        assert_matches!(parts, Err(HeaderError::InvalidRequest(_)));
    }

    #[test]
    fn request_has_authority() {
        let header = Header::try_from(vec![
            (b":method", Method::GET.as_str()).into(),
            (b":authority", b"test.com").into(),
        ])
        .expect("valid request section");
        let parts = header.into_request_parts();
        assert_matches!(parts, Ok(_));
    }

    #[test]
    fn request_has_host() {
        let header = Header::try_from(vec![
            (b":method", Method::GET.as_str()).into(),
            (b"host", b"test.com").into(),
        ])
        .expect("valid request section");
        assert!(header.pseudo.authority.is_none());
        let parts = header.into_request_parts();
        assert_matches!(parts, Ok(_));
    }

    #[test]
    fn request_has_same_host_and_authority() {
        let header = Header::try_from(vec![
            (b":method", Method::GET.as_str()).into(),
            (b":authority", b"test.com").into(),
            (b"host", b"test.com").into(),
        ])
        .expect("valid request section");
        let parts = header.into_request_parts();
        assert_matches!(parts, Ok(_));
    }

    #[test]
    fn request_has_different_host_and_authority() {
        let header = Header::try_from(vec![
            (b":method", Method::GET.as_str()).into(),
            (b":authority", b"authority.com").into(),
            (b"host", b"host.com").into(),
        ])
        .expect("decoding does not compare host and authority");
        let parts = header.into_request_parts();
        assert_matches!(parts, Err(HeaderError::ContradictedAuthority));
    }

    #[test]
    fn preserves_duplicate_headers() {
        // Collects every emitted field whose name equals `wanted`.
        fn named(header: Header, wanted: &[u8]) -> Vec<HeaderField> {
            header
                .into_iter()
                .filter(|field| field.name.as_ref() == wanted)
                .collect()
        }

        let header = Header::try_from(vec![
            (b":method", Method::GET.as_str()).into(),
            (b":authority", b"test.com").into(),
            (b"set-cookie", b"foo=foo").into(),
            (b"set-cookie", b"bar=bar").into(),
            (b"other-header", b"other-header-value").into(),
        ])
        .expect("valid request section");

        let expected_cookies = vec![
            HeaderField {
                name: Cow::Borrowed(b"set-cookie"),
                value: Cow::Borrowed(b"foo=foo"),
            },
            HeaderField {
                name: Cow::Borrowed(b"set-cookie"),
                value: Cow::Borrowed(b"bar=bar"),
            },
        ];
        assert_eq!(named(header.clone(), b"set-cookie"), expected_cookies);

        let expected_other = vec![HeaderField {
            name: Cow::Borrowed(b"other-header"),
            value: Cow::Borrowed(b"other-header-value"),
        }];
        assert_eq!(named(header, b"other-header"), expected_other);
    }
}