1use std::{cell::RefCell, cmp, fmt, io::Cursor};
2
3use ntex_bytes::{ByteString, BytesMut};
4use ntex_http::{header, uri, HeaderMap, HeaderName, Method, StatusCode, Uri};
5
6use crate::hpack;
7
8use super::priority::StreamDependency;
9use super::{util, Frame, FrameError, Head, Kind, Protocol, StreamId};
10
/// A HEADERS frame: the stream it belongs to, its header block
/// (pseudo-headers plus regular fields) and its frame flags.
#[derive(Clone)]
pub struct Headers {
    /// The stream this frame is associated with.
    stream_id: StreamId,

    /// Pseudo-headers and regular header fields carried by this frame.
    header_block: HeaderBlock,

    /// HEADERS frame flags (END_STREAM, END_HEADERS, PADDED, PRIORITY).
    flags: HeadersFlag,
}
25
/// Flag bits of a HEADERS frame, stored as the raw flag byte.
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct HeadersFlag(u8);
28
/// HTTP/2 pseudo-header values of a request or a response.
///
/// Each field is `None` when the corresponding pseudo-header is absent.
#[derive(Clone, Debug, Default)]
pub struct PseudoHeaders {
    /// `:method` (set for requests).
    pub method: Option<Method>,
    /// `:scheme` (set for requests when the URI has a scheme).
    pub scheme: Option<ByteString>,
    /// `:authority` (set for requests when the URI has an authority).
    pub authority: Option<ByteString>,
    /// `:path` (set for requests; may be absent for OPTIONS/CONNECT).
    pub path: Option<ByteString>,
    /// `:protocol` — presumably the extended-CONNECT pseudo-header (RFC 8441); verify.
    pub protocol: Option<Protocol>,

    /// `:status` (set for responses).
    pub status: Option<StatusCode>,
}
41
/// Iterator fed to the HPACK encoder: yields the pseudo-headers first,
/// then the regular header fields.
pub struct Iter<'a> {
    /// Pseudo-headers not yet yielded; `None` once all have been emitted.
    pseudo: Option<PseudoHeaders>,

    /// Remaining regular header fields.
    fields: header::Iter<'a>,
}
49
/// Header block carried by a HEADERS frame (and any following CONTINUATION frames).
#[derive(Debug, Clone)]
struct HeaderBlock {
    /// Regular (non-pseudo) header fields.
    fields: HeaderMap,

    /// Pseudo-header values.
    pseudo: PseudoHeaders,
}
59
// Flag bits defined for HEADERS frames (RFC 7540 §6.2).
const END_STREAM: u8 = 0x01;
const END_HEADERS: u8 = 0x04;
const PADDED: u8 = 0x08;
const PRIORITY: u8 = 0x20;
/// Every flag bit that has a meaning for a HEADERS frame.
const ALL: u8 = PRIORITY | PADDED | END_HEADERS | END_STREAM;
65
66impl Headers {
69 pub fn new(stream_id: StreamId, pseudo: PseudoHeaders, fields: HeaderMap, eof: bool) -> Self {
71 let mut flags = HeadersFlag::default();
72 if eof {
73 flags.set_end_stream();
74 }
75 Headers {
76 flags,
77 stream_id,
78 header_block: HeaderBlock { fields, pseudo },
79 }
80 }
81
82 pub fn trailers(stream_id: StreamId, fields: HeaderMap) -> Self {
83 let mut flags = HeadersFlag::default();
84 flags.set_end_stream();
85
86 Headers {
87 stream_id,
88 flags,
89 header_block: HeaderBlock {
90 fields,
91 pseudo: PseudoHeaders::default(),
92 },
93 }
94 }
95
96 pub fn load(head: Head, src: &mut BytesMut) -> Result<Self, FrameError> {
100 let flags = HeadersFlag(head.flag());
101
102 if head.stream_id().is_zero() {
103 return Err(FrameError::InvalidStreamId);
104 }
105
106 let pad = if flags.is_padded() {
108 if src.is_empty() {
109 return Err(FrameError::MalformedMessage);
110 }
111 let pad = src[0] as usize;
112
113 let _ = src.split_to(1);
115 pad
116 } else {
117 0
118 };
119
120 if flags.is_priority() {
122 if src.len() < 5 {
123 return Err(FrameError::MalformedMessage);
124 }
125 let stream_dep = StreamDependency::load(&src[..5])?;
126
127 if stream_dep.dependency_id() == head.stream_id() {
128 return Err(FrameError::InvalidDependencyId);
129 }
130
131 let _ = src.split_to(5);
133 }
134
135 if pad > 0 {
136 if pad > src.len() {
137 return Err(FrameError::TooMuchPadding);
138 }
139 src.truncate(src.len() - pad);
140 }
141
142 Ok(Headers {
143 flags,
144 stream_id: head.stream_id(),
145 header_block: HeaderBlock {
146 fields: HeaderMap::new(),
147 pseudo: PseudoHeaders::default(),
148 },
149 })
150 }
151
152 pub fn load_hpack(
153 &mut self,
154 src: &mut BytesMut,
155 decoder: &mut hpack::Decoder,
156 ) -> Result<(), FrameError> {
157 self.header_block.load(src, decoder)
158 }
159
160 pub fn stream_id(&self) -> StreamId {
161 self.stream_id
162 }
163
164 pub fn is_end_headers(&self) -> bool {
165 self.flags.is_end_headers()
166 }
167
168 pub fn set_end_headers(&mut self) {
169 self.flags.set_end_headers();
170 }
171
172 pub fn is_end_stream(&self) -> bool {
173 self.flags.is_end_stream()
174 }
175
176 pub fn set_end_stream(&mut self) {
177 self.flags.set_end_stream()
178 }
179
180 pub fn into_parts(self) -> (PseudoHeaders, HeaderMap) {
181 (self.header_block.pseudo, self.header_block.fields)
182 }
183
184 pub fn fields(&self) -> &HeaderMap {
185 &self.header_block.fields
186 }
187
188 pub fn pseudo(&self) -> &PseudoHeaders {
189 &self.header_block.pseudo
190 }
191
192 pub fn into_fields(self) -> HeaderMap {
193 self.header_block.fields
194 }
195
196 pub fn encode(self, encoder: &mut hpack::Encoder, dst: &mut BytesMut, max_size: usize) {
197 debug_assert!(self.flags.is_end_headers());
199
200 let head = self.head();
202
203 self.header_block.encode(encoder, &head, dst, max_size);
204 }
205
206 fn head(&self) -> Head {
207 Head::new(Kind::Headers, self.flags.into(), self.stream_id)
208 }
209}
210
211impl From<Headers> for Frame {
212 fn from(src: Headers) -> Self {
213 Frame::Headers(src)
214 }
215}
216
217impl fmt::Debug for Headers {
218 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
219 let mut builder = f.debug_struct("Headers");
220 builder
221 .field("stream_id", &self.stream_id)
222 .field("flags", &self.flags)
223 .field("pseudo", &self.header_block.pseudo);
224
225 if let Some(ref protocol) = self.header_block.pseudo.protocol {
226 builder.field("protocol", protocol);
227 }
228
229 builder.finish()
231 }
232}
233
234impl PseudoHeaders {
237 pub fn request(method: Method, uri: Uri, protocol: Option<Protocol>) -> Self {
238 let parts = uri::Parts::from(uri);
239
240 let mut path = parts
241 .path_and_query
242 .map(|v| ByteString::from(v.as_str()))
243 .unwrap_or(ByteString::from_static(""));
244
245 match method {
246 Method::OPTIONS | Method::CONNECT => {}
247 _ if path.is_empty() => {
248 path = ByteString::from_static("/");
249 }
250 _ => {}
251 }
252
253 let mut pseudo = PseudoHeaders {
254 method: Some(method),
255 scheme: None,
256 authority: None,
257 path: Some(path).filter(|p| !p.is_empty()),
258 protocol,
259 status: None,
260 };
261
262 if let Some(scheme) = parts.scheme {
266 pseudo.set_scheme(scheme);
267 }
268
269 if let Some(authority) = parts.authority {
272 pseudo.set_authority(ByteString::from(authority.as_str()));
273 }
274
275 pseudo
276 }
277
278 pub fn response(status: StatusCode) -> Self {
279 PseudoHeaders {
280 method: None,
281 scheme: None,
282 authority: None,
283 path: None,
284 protocol: None,
285 status: Some(status),
286 }
287 }
288
289 pub fn set_status(&mut self, value: StatusCode) {
290 self.status = Some(value);
291 }
292
293 pub fn set_scheme(&mut self, scheme: uri::Scheme) {
294 self.scheme = Some(match scheme.as_str() {
295 "http" => ByteString::from_static("http"),
296 "https" => ByteString::from_static("https"),
297 s => ByteString::from(s),
298 });
299 }
300
301 pub fn set_protocol(&mut self, protocol: Protocol) {
302 self.protocol = Some(protocol);
303 }
304
305 pub fn set_authority(&mut self, authority: ByteString) {
306 self.authority = Some(authority);
307 }
308}
309
310impl Iterator for Iter<'_> {
313 type Item = hpack::Header<Option<HeaderName>>;
314
315 fn next(&mut self) -> Option<Self::Item> {
316 use crate::hpack::Header::*;
317
318 if let Some(ref mut pseudo) = self.pseudo {
319 if let Some(method) = pseudo.method.take() {
320 return Some(Method(method));
321 }
322
323 if let Some(scheme) = pseudo.scheme.take() {
324 return Some(Scheme(scheme));
325 }
326
327 if let Some(authority) = pseudo.authority.take() {
328 return Some(Authority(authority));
329 }
330
331 if let Some(path) = pseudo.path.take() {
332 return Some(Path(path));
333 }
334
335 if let Some(protocol) = pseudo.protocol.take() {
336 return Some(Protocol(protocol.into()));
337 }
338
339 if let Some(status) = pseudo.status.take() {
340 return Some(Status(status));
341 }
342 }
343
344 self.pseudo = None;
345
346 self.fields.next().map(|(name, value)| Field {
347 name: Some(name.clone()),
348 value: value.clone(),
349 })
350 }
351}
352
353impl HeadersFlag {
356 pub fn empty() -> HeadersFlag {
357 HeadersFlag(0)
358 }
359
360 pub fn load(bits: u8) -> HeadersFlag {
361 HeadersFlag(bits & ALL)
362 }
363
364 pub fn is_end_stream(&self) -> bool {
365 self.0 & END_STREAM == END_STREAM
366 }
367
368 pub fn set_end_stream(&mut self) {
369 self.0 |= END_STREAM;
370 }
371
372 pub fn is_end_headers(&self) -> bool {
373 self.0 & END_HEADERS == END_HEADERS
374 }
375
376 pub fn set_end_headers(&mut self) {
377 self.0 |= END_HEADERS;
378 }
379
380 pub fn is_padded(&self) -> bool {
381 self.0 & PADDED == PADDED
382 }
383
384 pub fn is_priority(&self) -> bool {
385 self.0 & PRIORITY == PRIORITY
386 }
387}
388
389impl Default for HeadersFlag {
390 fn default() -> Self {
392 HeadersFlag(END_HEADERS)
393 }
394}
395
396impl From<HeadersFlag> for u8 {
397 fn from(src: HeadersFlag) -> u8 {
398 src.0
399 }
400}
401
402impl fmt::Debug for HeadersFlag {
403 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
404 util::debug_flags(fmt, self.0)
405 .flag_if(self.is_end_headers(), "END_HEADERS")
406 .flag_if(self.is_end_stream(), "END_STREAM")
407 .flag_if(self.is_padded(), "PADDED")
408 .flag_if(self.is_priority(), "PRIORITY")
409 .finish()
410 }
411}
412
thread_local! {
    /// Per-thread scratch buffer that `HeaderBlock::encode` clears and
    /// HPACK-encodes into before splitting the result into frames; reused
    /// across calls instead of allocating a new buffer each time.
    static HDRS_BUF: RefCell<BytesMut> = RefCell::new(BytesMut::with_capacity(1024));
}
418
419impl HeaderBlock {
420 fn load(&mut self, src: &mut BytesMut, decoder: &mut hpack::Decoder) -> Result<(), FrameError> {
421 let mut reg = !self.fields.is_empty();
422 let mut malformed = false;
423
424 macro_rules! set_pseudo {
425 ($field:ident, $val:expr) => {{
426 if reg {
427 log::trace!("load_hpack; header malformed -- pseudo not at head of block");
428 malformed = true;
429 } else if self.pseudo.$field.is_some() {
430 log::trace!("load_hpack; header malformed -- repeated pseudo");
431 malformed = true;
432 } else {
433 self.pseudo.$field = Some($val.into());
434 }
435 }};
436 }
437
438 let mut cursor = Cursor::new(src);
439
440 let res = decoder.decode(&mut cursor, |header| {
445 use crate::hpack::Header::*;
446
447 match header {
448 Field { name, value } => {
449 if name == header::CONNECTION
453 || name == header::TRANSFER_ENCODING
454 || name == header::UPGRADE
455 || name == "keep-alive"
456 || name == "proxy-connection"
457 {
458 log::trace!("load_hpack; connection level header");
459 malformed = true;
460 } else if name == header::TE && value != "trailers" {
461 log::trace!("load_hpack; TE header not set to trailers; val={value:?}");
462 malformed = true;
463 } else {
464 reg = true;
465 self.fields.append(name, value);
466 }
467 }
468 Authority(v) => {
469 set_pseudo!(authority, v)
470 }
471 Method(v) => {
472 set_pseudo!(method, v)
473 }
474 Scheme(v) => {
475 set_pseudo!(scheme, v)
476 }
477 Path(v) => {
478 set_pseudo!(path, v)
479 }
480 Protocol(v) => {
481 set_pseudo!(protocol, v)
482 }
483 Status(v) => {
484 set_pseudo!(status, v)
485 }
486 }
487 });
488
489 if let Err(e) = res {
490 log::trace!("hpack decoding error; err={e:?}");
491 return Err(e.into());
492 }
493
494 if malformed {
495 log::trace!("malformed message");
496 return Err(FrameError::MalformedMessage);
497 }
498
499 Ok(())
500 }
501
502 fn encode(
503 self,
504 encoder: &mut hpack::Encoder,
505 head: &Head,
506 dst: &mut BytesMut,
507 max_size: usize,
508 ) {
509 HDRS_BUF.with(|buf| {
510 let mut b = buf.borrow_mut();
511 let hpack = &mut b;
512 hpack.clear();
513
514 let headers = Iter {
516 pseudo: Some(self.pseudo),
517 fields: self.fields.into_iter(),
518 };
519 encoder.encode(headers, hpack);
520
521 let mut head = *head;
522 let mut start = 0;
523 loop {
524 let end = cmp::min(start + max_size, hpack.len());
525
526 if hpack.len() > end {
528 Head::new(head.kind(), head.flag() ^ END_HEADERS, head.stream_id())
529 .encode(max_size, dst);
530 dst.extend_from_slice(&hpack[start..end]);
531 head = Head::new(Kind::Continuation, END_HEADERS, head.stream_id());
532 start = end;
533 } else {
534 head.encode(end - start, dst);
535 dst.extend_from_slice(&hpack[start..end]);
536 break;
537 }
538 }
539 });
540 }
541}
542
#[cfg(test)]
mod test {
    use ntex_http::HeaderValue;

    use super::*;
    use crate::hpack::{huffman, Encoder};

    /// A header value split across the HEADERS/CONTINUATION boundary must be
    /// reassemblable into the original value.
    #[test]
    fn test_nameless_header_at_resume() {
        let mut encoder = Encoder::default();
        let mut dst = BytesMut::new();

        let mut hdrs = HeaderMap::default();
        hdrs.append(
            HeaderName::from_static("hello"),
            HeaderValue::from_static("world"),
        );
        hdrs.append(
            HeaderName::from_static("hello"),
            HeaderValue::from_static("zomg"),
        );
        hdrs.append(
            HeaderName::from_static("hello"),
            HeaderValue::from_static("sup"),
        );

        let mut headers = Headers::new(StreamId::CON, Default::default(), hdrs, false);
        headers.set_end_headers();
        headers.encode(&mut encoder, &mut dst, 8);
        assert_eq!(48, dst.len());
        // First frame: HEADERS, 8-byte payload, END_HEADERS cleared.
        assert_eq!([0, 0, 8, 1, 0, 0, 0, 0, 0], &dst[0..9]);
        assert_eq!(&[0x40, 0x80 | 4], &dst[9..11]);
        assert_eq!("hello", huff_decode(&dst[11..15]));
        assert_eq!(0x80 | 4, dst[15]);

        // The 4-byte huffman value "world" starts in the last byte of the
        // first frame and continues after the CONTINUATION frame head.
        let mut world = BytesMut::from(&dst[16..17]);
        world.extend_from_slice(&dst[26..29]);
        assert_eq!([0, 0, 8, 9, 0, 0, 0, 0, 0], &dst[17..26]);
        assert_eq!("world", huff_decode(&world));
    }

    fn huff_decode(src: &[u8]) -> BytesMut {
        let mut buf = BytesMut::new();
        huffman::decode(src, &mut buf).unwrap()
    }
}