1use std::{cell::RefCell, cmp, fmt, io::Cursor};
2
3use ntex_bytes::{ByteString, Bytes, BytesMut};
4use ntex_http::{HeaderMap, HeaderName, Method, StatusCode, Uri, header, uri};
5
6use crate::hpack;
7
8use super::priority::StreamDependency;
9use super::{Frame, FrameError, Head, Kind, Protocol, StreamId, util};
10
/// A HEADERS frame: the header block of one stream together with the frame
/// flags.
#[derive(Clone)]
pub struct Headers {
    /// Stream this frame belongs to.
    stream_id: StreamId,

    /// Pseudo-headers and regular header fields carried by the frame.
    header_block: HeaderBlock,

    /// Frame flags (END_STREAM, END_HEADERS, PADDED, PRIORITY).
    flags: HeadersFlag,
}
25
/// Flags of a HEADERS frame, stored as the raw flag byte of the frame header.
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct HeadersFlag(u8);
28
/// HTTP/2 pseudo-header fields of a request or a response.
#[derive(Clone, Debug, Default)]
pub struct PseudoHeaders {
    // Request pseudo-headers.
    /// `:method`
    pub method: Option<Method>,
    /// `:scheme`
    pub scheme: Option<ByteString>,
    /// `:authority`
    pub authority: Option<ByteString>,
    /// `:path`
    pub path: Option<ByteString>,
    /// `:protocol` (extended CONNECT)
    pub protocol: Option<Protocol>,

    // Response pseudo-headers.
    /// `:status`
    pub status: Option<StatusCode>,
}
41
/// Iterator over the entries of a header block, as handed to the HPACK
/// encoder: pseudo-headers first, then regular fields.
pub(super) struct Iter<'a> {
    /// Pseudo-headers not yet yielded; set to `None` once they are exhausted.
    pseudo: Option<PseudoHeaders>,

    /// Remaining regular header fields.
    fields: header::Iter<'a>,
}
49
/// Contents of a header block: regular fields plus pseudo-headers.
#[derive(Debug, Clone)]
struct HeaderBlock {
    /// Regular (non-pseudo) header fields.
    fields: HeaderMap,

    /// Pseudo-header fields.
    pseudo: PseudoHeaders,
}
59
/// END_STREAM: the sender will send no further frames on this stream.
const END_STREAM: u8 = 0x1;
/// END_HEADERS: this frame ends the header block (no CONTINUATION follows).
const END_HEADERS: u8 = 0x4;
/// PADDED: the payload starts with a pad-length byte and ends with padding.
const PADDED: u8 = 0x8;
/// PRIORITY: the payload carries a 5-byte stream dependency.
const PRIORITY: u8 = 0x20;
/// Every flag defined for HEADERS frames.
const ALL: u8 = END_STREAM | END_HEADERS | PADDED | PRIORITY;
65
66impl Headers {
69 pub fn new(stream_id: StreamId, pseudo: PseudoHeaders, fields: HeaderMap, eof: bool) -> Self {
71 let mut flags = HeadersFlag::default();
72 if eof {
73 flags.set_end_stream();
74 }
75 Headers {
76 flags,
77 stream_id,
78 header_block: HeaderBlock { fields, pseudo },
79 }
80 }
81
82 pub fn trailers(stream_id: StreamId, fields: HeaderMap) -> Self {
83 let mut flags = HeadersFlag::default();
84 flags.set_end_stream();
85
86 Headers {
87 stream_id,
88 flags,
89 header_block: HeaderBlock {
90 fields,
91 pseudo: PseudoHeaders::default(),
92 },
93 }
94 }
95
96 pub fn load(head: Head, src: &mut Bytes) -> Result<Self, FrameError> {
100 let flags = HeadersFlag(head.flag());
101
102 if head.stream_id().is_zero() {
103 return Err(FrameError::InvalidStreamId);
104 }
105
106 let pad = if flags.is_padded() {
108 if src.is_empty() {
109 return Err(FrameError::MalformedMessage);
110 }
111 let pad = src[0] as usize;
112
113 src.advance_to(1);
115 pad
116 } else {
117 0
118 };
119
120 if flags.is_priority() {
122 if src.len() < 5 {
123 return Err(FrameError::MalformedMessage);
124 }
125 let stream_dep = StreamDependency::load(&src[..5])?;
126
127 if stream_dep.dependency_id() == head.stream_id() {
128 return Err(FrameError::InvalidDependencyId);
129 }
130
131 src.advance_to(5);
133 }
134
135 if pad > 0 {
136 if pad > src.len() {
137 return Err(FrameError::TooMuchPadding);
138 }
139 src.truncate(src.len() - pad);
140 }
141
142 Ok(Headers {
143 flags,
144 stream_id: head.stream_id(),
145 header_block: HeaderBlock {
146 fields: HeaderMap::new(),
147 pseudo: PseudoHeaders::default(),
148 },
149 })
150 }
151
152 pub fn load_hpack(
153 &mut self,
154 src: &mut Bytes,
155 decoder: &mut hpack::Decoder,
156 ) -> Result<(), FrameError> {
157 self.header_block.load(src, decoder)
158 }
159
160 pub fn stream_id(&self) -> StreamId {
161 self.stream_id
162 }
163
164 pub fn is_end_headers(&self) -> bool {
165 self.flags.is_end_headers()
166 }
167
168 pub fn set_end_headers(&mut self) {
169 self.flags.set_end_headers();
170 }
171
172 pub fn is_end_stream(&self) -> bool {
173 self.flags.is_end_stream()
174 }
175
176 pub fn set_end_stream(&mut self) {
177 self.flags.set_end_stream();
178 }
179
180 pub fn into_parts(self) -> (PseudoHeaders, HeaderMap) {
181 (self.header_block.pseudo, self.header_block.fields)
182 }
183
184 pub fn fields(&self) -> &HeaderMap {
185 &self.header_block.fields
186 }
187
188 pub fn pseudo(&self) -> &PseudoHeaders {
189 &self.header_block.pseudo
190 }
191
192 pub fn into_fields(self) -> HeaderMap {
193 self.header_block.fields
194 }
195
196 pub fn encode(self, encoder: &mut hpack::Encoder, dst: &mut BytesMut, max_size: usize) {
197 debug_assert!(self.flags.is_end_headers());
199
200 let head = self.head();
202
203 self.header_block.encode(encoder, head, dst, max_size);
204 }
205
206 fn head(&self) -> Head {
207 Head::new(Kind::Headers, self.flags.into(), self.stream_id)
208 }
209}
210
211impl From<Headers> for Frame {
212 fn from(src: Headers) -> Self {
213 Frame::Headers(src)
214 }
215}
216
217impl fmt::Debug for Headers {
218 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
219 let mut builder = f.debug_struct("Headers");
220 builder
221 .field("stream_id", &self.stream_id)
222 .field("flags", &self.flags)
223 .field("pseudo", &self.header_block.pseudo);
224
225 if let Some(ref protocol) = self.header_block.pseudo.protocol {
226 builder.field("protocol", protocol);
227 }
228
229 builder.finish()
231 }
232}
233
234impl PseudoHeaders {
237 pub fn request(method: Method, uri: Uri, protocol: Option<Protocol>) -> Self {
238 let parts = uri::Parts::from(uri);
239
240 let mut path = parts
241 .path_and_query
242 .map_or(ByteString::from_static(""), |v| {
243 ByteString::from(v.as_str())
244 });
245
246 match method {
247 Method::OPTIONS | Method::CONNECT => {}
248 _ if path.is_empty() => {
249 path = ByteString::from_static("/");
250 }
251 _ => {}
252 }
253
254 let mut pseudo = PseudoHeaders {
255 method: Some(method),
256 scheme: None,
257 authority: None,
258 path: Some(path).filter(|p| !p.is_empty()),
259 protocol,
260 status: None,
261 };
262
263 if let Some(ref scheme) = parts.scheme {
267 pseudo.set_scheme(scheme);
268 }
269
270 if let Some(authority) = parts.authority {
273 pseudo.set_authority(ByteString::from(authority.as_str()));
274 }
275
276 pseudo
277 }
278
279 pub fn response(status: StatusCode) -> Self {
280 PseudoHeaders {
281 method: None,
282 scheme: None,
283 authority: None,
284 path: None,
285 protocol: None,
286 status: Some(status),
287 }
288 }
289
290 pub fn set_status(&mut self, value: StatusCode) {
291 self.status = Some(value);
292 }
293
294 pub fn set_scheme(&mut self, scheme: &uri::Scheme) {
295 self.scheme = Some(match scheme.as_str() {
296 "http" => ByteString::from_static("http"),
297 "https" => ByteString::from_static("https"),
298 s => ByteString::from(s),
299 });
300 }
301
302 pub fn set_protocol(&mut self, protocol: Protocol) {
303 self.protocol = Some(protocol);
304 }
305
306 pub fn set_authority(&mut self, authority: ByteString) {
307 self.authority = Some(authority);
308 }
309}
310
311impl Iterator for Iter<'_> {
314 type Item = hpack::Header<Option<HeaderName>>;
315
316 fn next(&mut self) -> Option<Self::Item> {
317 use crate::hpack::Header;
318
319 if let Some(ref mut pseudo) = self.pseudo {
320 if let Some(method) = pseudo.method.take() {
321 return Some(Header::Method(method));
322 }
323
324 if let Some(scheme) = pseudo.scheme.take() {
325 return Some(Header::Scheme(scheme));
326 }
327
328 if let Some(authority) = pseudo.authority.take() {
329 return Some(Header::Authority(authority));
330 }
331
332 if let Some(path) = pseudo.path.take() {
333 return Some(Header::Path(path));
334 }
335
336 if let Some(protocol) = pseudo.protocol.take() {
337 return Some(Header::Protocol(protocol.into()));
338 }
339
340 if let Some(status) = pseudo.status.take() {
341 return Some(Header::Status(status));
342 }
343 }
344
345 self.pseudo = None;
346
347 self.fields.next().map(|(name, value)| Header::Field {
348 name: Some(name.clone()),
349 value: value.clone(),
350 })
351 }
352}
353
354impl HeadersFlag {
357 pub fn empty() -> HeadersFlag {
358 HeadersFlag(0)
359 }
360
361 pub fn load(bits: u8) -> HeadersFlag {
362 HeadersFlag(bits & ALL)
363 }
364
365 pub fn is_end_stream(self) -> bool {
366 self.0 & END_STREAM == END_STREAM
367 }
368
369 pub fn set_end_stream(&mut self) {
370 self.0 |= END_STREAM;
371 }
372
373 pub fn is_end_headers(self) -> bool {
374 self.0 & END_HEADERS == END_HEADERS
375 }
376
377 pub fn set_end_headers(&mut self) {
378 self.0 |= END_HEADERS;
379 }
380
381 pub fn is_padded(self) -> bool {
382 self.0 & PADDED == PADDED
383 }
384
385 pub fn is_priority(self) -> bool {
386 self.0 & PRIORITY == PRIORITY
387 }
388}
389
390impl Default for HeadersFlag {
391 fn default() -> Self {
393 HeadersFlag(END_HEADERS)
394 }
395}
396
397impl From<HeadersFlag> for u8 {
398 fn from(src: HeadersFlag) -> u8 {
399 src.0
400 }
401}
402
403impl fmt::Debug for HeadersFlag {
404 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
405 util::debug_flags(fmt, self.0)
406 .flag_if(self.is_end_headers(), "END_HEADERS")
407 .flag_if(self.is_end_stream(), "END_STREAM")
408 .flag_if(self.is_padded(), "PADDED")
409 .flag_if(self.is_priority(), "PRIORITY")
410 .finish()
411 }
412}
413
thread_local! {
    // Per-thread scratch buffer (initial capacity 1024 bytes) that holds the
    // HPACK-encoded header block in `HeaderBlock::encode` before it is split
    // into frames. It is cleared before each use, not reallocated.
    static HDRS_BUF: RefCell<BytesMut> = RefCell::new(BytesMut::with_capacity(1024));
}
419
420impl HeaderBlock {
421 fn load(&mut self, src: &mut Bytes, decoder: &mut hpack::Decoder) -> Result<(), FrameError> {
422 let mut reg = !self.fields.is_empty();
423 let mut malformed = false;
424
425 macro_rules! set_pseudo {
426 ($field:ident, $val:expr) => {{
427 if reg {
428 log::trace!("load_hpack; header malformed -- pseudo not at head of block");
429 malformed = true;
430 } else if self.pseudo.$field.is_some() {
431 log::trace!("load_hpack; header malformed -- repeated pseudo");
432 malformed = true;
433 } else {
434 self.pseudo.$field = Some($val.into());
435 }
436 }};
437 }
438
439 let mut cursor = Cursor::new(src);
440
441 let res = decoder.decode(&mut cursor, |header| {
446 use crate::hpack::Header;
447
448 match header {
449 Header::Field { name, value } => {
450 if name == header::CONNECTION
454 || name == header::TRANSFER_ENCODING
455 || name == header::UPGRADE
456 || name == "keep-alive"
457 || name == "proxy-connection"
458 {
459 log::trace!("load_hpack; connection level header");
460 malformed = true;
461 } else if name == header::TE && value != "trailers" {
462 log::trace!("load_hpack; TE header not set to trailers; val={value:?}");
463 malformed = true;
464 } else {
465 reg = true;
466 self.fields.append(name, value);
467 }
468 }
469 Header::Authority(v) => {
470 set_pseudo!(authority, v);
471 }
472 Header::Method(v) => {
473 set_pseudo!(method, v);
474 }
475 Header::Scheme(v) => {
476 set_pseudo!(scheme, v);
477 }
478 Header::Path(v) => {
479 set_pseudo!(path, v);
480 }
481 Header::Protocol(v) => {
482 set_pseudo!(protocol, v);
483 }
484 Header::Status(v) => {
485 set_pseudo!(status, v);
486 }
487 }
488 });
489
490 if let Err(e) = res {
491 log::trace!("hpack decoding error; err={e:?}");
492 return Err(e.into());
493 }
494
495 if malformed {
496 log::trace!("malformed message");
497 return Err(FrameError::MalformedMessage);
498 }
499
500 Ok(())
501 }
502
503 fn encode(self, encoder: &mut hpack::Encoder, head: Head, dst: &mut BytesMut, max_size: usize) {
504 HDRS_BUF.with(|buf| {
505 let mut b = buf.borrow_mut();
506 let hpack = &mut b;
507 hpack.clear();
508
509 let headers = Iter {
511 pseudo: Some(self.pseudo),
512 fields: self.fields.into_iter(),
513 };
514 encoder.encode(headers, hpack);
515
516 let mut head = head;
517 let mut start = 0;
518 loop {
519 let end = cmp::min(start + max_size, hpack.len());
520
521 if hpack.len() > end {
523 Head::new(head.kind(), head.flag() ^ END_HEADERS, head.stream_id())
524 .encode(max_size, dst);
525 dst.extend_from_slice(&hpack[start..end]);
526 head = Head::new(Kind::Continuation, END_HEADERS, head.stream_id());
527 start = end;
528 } else {
529 head.encode(end - start, dst);
530 dst.extend_from_slice(&hpack[start..end]);
531 break;
532 }
533 }
534 });
535 }
536}
537
#[cfg(test)]
mod test {
    use ntex_http::HeaderValue;

    use super::*;
    use crate::hpack::{Encoder, huffman};

    /// Encodes three values for one header name with an 8-byte frame limit.
    /// This forces a HEADERS frame plus two CONTINUATION frames; the test
    /// checks the exact wire layout, including a value split across frames.
    #[test]
    fn test_nameless_header_at_resume() {
        let mut encoder = Encoder::default();
        let mut dst = BytesMut::new();

        let mut hdrs = HeaderMap::default();
        hdrs.append(
            HeaderName::from_static("hello"),
            HeaderValue::from_static("world"),
        );
        hdrs.append(
            HeaderName::from_static("hello"),
            HeaderValue::from_static("zomg"),
        );
        hdrs.append(
            HeaderName::from_static("hello"),
            HeaderValue::from_static("sup"),
        );

        let mut headers = Headers::new(StreamId::CON, Default::default(), hdrs, false);
        headers.set_end_headers();
        headers.encode(&mut encoder, &mut dst, 8);
        // (9 + 8) HEADERS + (9 + 8) CONTINUATION + (9 + 5) final CONTINUATION.
        assert_eq!(48, dst.len());

        // HEADERS: length 8, type 1, no flags (the block continues), stream 0.
        assert_eq!([0, 0, 8, 1, 0, 0, 0, 0, 0], &dst[0..9]);
        assert_eq!(&[0x40, 0x80 | 4], &dst[9..11]);
        assert_eq!("hello", huff_decode(&dst[11..15]));
        assert_eq!(0x80 | 4, dst[15]);

        // The 4-byte Huffman "world" is split: 1 byte here, 3 bytes after the
        // next frame header.
        let mut world = BytesMut::from(&dst[16..17]);
        world.extend_from_slice(&dst[26..29]);
        // First CONTINUATION: length 8, type 9, no flags.
        assert_eq!([0, 0, 8, 9, 0, 0, 0, 0, 0], &dst[17..26]);
        assert_eq!("world", huff_decode(&world));

        // Final CONTINUATION: length 5, type 9, END_HEADERS (0x4).
        assert_eq!([0, 0, 5, 9, 4, 0, 0, 0, 0], &dst[34..43]);
    }

    fn huff_decode(src: &[u8]) -> Bytes {
        let mut buf = BytesMut::new();
        huffman::decode(src, &mut buf).unwrap()
    }
}