1use std::collections::HashMap;
2use std::fmt;
3use std::io::IoSlice;
4
5use bytes::buf::{Chain, Take};
6use bytes::{Buf, Bytes};
7use http::{
8 header::{
9 AUTHORIZATION, CACHE_CONTROL, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_RANGE,
10 CONTENT_TYPE, HOST, MAX_FORWARDS, SET_COOKIE, TE, TRAILER, TRANSFER_ENCODING,
11 },
12 HeaderMap, HeaderName, HeaderValue,
13};
14
15use super::io::WriteBuf;
16use super::role::{write_headers, write_headers_title_case};
17
/// A `'static` byte slice, used for the fixed framing sequences this encoder
/// emits (CRLF, the zero-length terminating chunk, etc.).
type StaticBuf = &'static [u8];
19
/// Encodes an HTTP/1 message body according to its framing: chunked
/// transfer-coding, a fixed `Content-Length`, or (with the `server` feature)
/// close-delimited.
#[derive(Debug, Clone, PartialEq)]
pub(crate) struct Encoder {
    /// Framing in use, plus its per-framing state (e.g. bytes remaining).
    kind: Kind,
    /// Flag set via `set_last` and reported by `is_last`; when `true`,
    /// `encode_and_end` returns `false` in the cases where it would otherwise
    /// return `true`.
    is_last: bool,
}
26
/// A body buffer produced by `Encoder`, possibly wrapped with framing bytes.
/// It implements `Buf`, so it can be written out like any other buffer.
#[derive(Debug)]
pub(crate) struct EncodedBuf<B> {
    /// Which framing shape wraps the caller's buffer.
    kind: BufKind<B>,
}
31
/// Error returned by `Encoder::end` when a fixed-length body is ended early.
/// The wrapped value is the number of bytes that were still expected.
#[derive(Debug)]
pub(crate) struct NotEof(u64);
34
/// Framing strategy for a message body, along with any per-strategy state.
#[derive(Debug, PartialEq, Clone)]
enum Kind {
    /// `Transfer-Encoding: chunked`. The optional list holds the declared
    /// trailer field names (set by `Encoder::into_chunked_with_trailing_fields`)
    /// that `Encoder::encode_trailers` is permitted to emit.
    Chunked(Option<Vec<HeaderValue>>),
    /// Fixed-length body; the value is the number of bytes still to be written
    /// (decremented by `Encoder::encode`).
    Length(u64),
    /// Body with neither a declared length nor chunk framing: bytes are written
    /// as-is and `Encoder::end` emits nothing.
    #[cfg(feature = "server")]
    CloseDelimited,
}
50
/// The concrete buffer shapes an `EncodedBuf` can wrap.
#[derive(Debug)]
enum BufKind<B> {
    /// The caller's buffer, written unchanged.
    Exact(B),
    /// The caller's buffer truncated to the remaining `Content-Length`.
    Limited(Take<B>),
    /// One chunk: hex size line, the data, then CRLF.
    Chunked(Chain<Chain<ChunkSize, B>, StaticBuf>),
    /// The terminating zero-length chunk (`0\r\n\r\n`).
    ChunkedEnd(StaticBuf),
    /// Zero-length chunk line, the serialized trailer fields, then the final CRLF.
    Trailers(Chain<Chain<StaticBuf, Bytes>, StaticBuf>),
}
59
60impl Encoder {
61 fn new(kind: Kind) -> Encoder {
62 Encoder {
63 kind,
64 is_last: false,
65 }
66 }
67 pub(crate) fn chunked() -> Encoder {
68 Encoder::new(Kind::Chunked(None))
69 }
70
71 pub(crate) fn length(len: u64) -> Encoder {
72 Encoder::new(Kind::Length(len))
73 }
74
75 #[cfg(feature = "server")]
76 pub(crate) fn close_delimited() -> Encoder {
77 Encoder::new(Kind::CloseDelimited)
78 }
79
80 pub(crate) fn into_chunked_with_trailing_fields(self, trailers: Vec<HeaderValue>) -> Encoder {
81 match self.kind {
82 Kind::Chunked(_) => Encoder {
83 kind: Kind::Chunked(Some(trailers)),
84 is_last: self.is_last,
85 },
86 _ => self,
87 }
88 }
89
90 pub(crate) fn is_eof(&self) -> bool {
91 matches!(self.kind, Kind::Length(0))
92 }
93
94 #[cfg(feature = "server")]
95 pub(crate) fn set_last(mut self, is_last: bool) -> Self {
96 self.is_last = is_last;
97 self
98 }
99
100 pub(crate) fn is_last(&self) -> bool {
101 self.is_last
102 }
103
104 pub(crate) fn is_close_delimited(&self) -> bool {
105 match self.kind {
106 #[cfg(feature = "server")]
107 Kind::CloseDelimited => true,
108 _ => false,
109 }
110 }
111
112 pub(crate) fn is_chunked(&self) -> bool {
113 match self.kind {
114 Kind::Chunked(_) => true,
115 _ => false,
116 }
117 }
118
119 pub(crate) fn end<B>(&self) -> Result<Option<EncodedBuf<B>>, NotEof> {
120 match self.kind {
121 Kind::Length(0) => Ok(None),
122 Kind::Chunked(_) => Ok(Some(EncodedBuf {
123 kind: BufKind::ChunkedEnd(b"0\r\n\r\n"),
124 })),
125 #[cfg(feature = "server")]
126 Kind::CloseDelimited => Ok(None),
127 Kind::Length(n) => Err(NotEof(n)),
128 }
129 }
130
131 pub(crate) fn encode<B>(&mut self, msg: B) -> EncodedBuf<B>
132 where
133 B: Buf,
134 {
135 let len = msg.remaining();
136 debug_assert!(len > 0, "encode() called with empty buf");
137
138 let kind = match self.kind {
139 Kind::Chunked(_) => {
140 trace!("encoding chunked {}B", len);
141 let buf = ChunkSize::new(len)
142 .chain(msg)
143 .chain(b"\r\n" as &'static [u8]);
144 BufKind::Chunked(buf)
145 }
146 Kind::Length(ref mut remaining) => {
147 trace!("sized write, len = {}", len);
148 if len as u64 > *remaining {
149 let limit = *remaining as usize;
150 *remaining = 0;
151 BufKind::Limited(msg.take(limit))
152 } else {
153 *remaining -= len as u64;
154 BufKind::Exact(msg)
155 }
156 }
157 #[cfg(feature = "server")]
158 Kind::CloseDelimited => {
159 trace!("close delimited write {}B", len);
160 BufKind::Exact(msg)
161 }
162 };
163 EncodedBuf { kind }
164 }
165
166 pub(crate) fn encode_trailers<B>(
167 &self,
168 trailers: HeaderMap,
169 title_case_headers: bool,
170 ) -> Option<EncodedBuf<B>> {
171 match &self.kind {
172 Kind::Chunked(Some(ref allowed_trailer_fields)) => {
173 let allowed_trailer_field_map = allowed_trailer_field_map(&allowed_trailer_fields);
174
175 let mut cur_name = None;
176 let mut allowed_trailers = HeaderMap::new();
177
178 for (opt_name, value) in trailers {
179 if let Some(n) = opt_name {
180 cur_name = Some(n);
181 }
182 let name = cur_name.as_ref().expect("current header name");
183
184 if allowed_trailer_field_map.contains_key(name.as_str())
185 && valid_trailer_field(name)
186 {
187 allowed_trailers.insert(name, value);
188 }
189 }
190
191 let mut buf = Vec::new();
192 if title_case_headers {
193 write_headers_title_case(&allowed_trailers, &mut buf);
194 } else {
195 write_headers(&allowed_trailers, &mut buf);
196 }
197
198 if buf.is_empty() {
199 return None;
200 }
201
202 Some(EncodedBuf {
203 kind: BufKind::Trailers(b"0\r\n".chain(Bytes::from(buf)).chain(b"\r\n")),
204 })
205 }
206 _ => {
207 debug!("attempted to encode trailers for non-chunked response");
208 None
209 }
210 }
211 }
212
213 pub(super) fn encode_and_end<B>(&self, msg: B, dst: &mut WriteBuf<EncodedBuf<B>>) -> bool
214 where
215 B: Buf,
216 {
217 let len = msg.remaining();
218 debug_assert!(len > 0, "encode() called with empty buf");
219
220 match self.kind {
221 Kind::Chunked(_) => {
222 trace!("encoding chunked {}B", len);
223 let buf = ChunkSize::new(len)
224 .chain(msg)
225 .chain(b"\r\n0\r\n\r\n" as &'static [u8]);
226 dst.buffer(buf);
227 !self.is_last
228 }
229 Kind::Length(remaining) => {
230 use std::cmp::Ordering;
231
232 trace!("sized write, len = {}", len);
233 match (len as u64).cmp(&remaining) {
234 Ordering::Equal => {
235 dst.buffer(msg);
236 !self.is_last
237 }
238 Ordering::Greater => {
239 dst.buffer(msg.take(remaining as usize));
240 !self.is_last
241 }
242 Ordering::Less => {
243 dst.buffer(msg);
244 false
245 }
246 }
247 }
248 #[cfg(feature = "server")]
249 Kind::CloseDelimited => {
250 trace!("close delimited write {}B", len);
251 dst.buffer(msg);
252 false
253 }
254 }
255 }
256}
257
258fn valid_trailer_field(name: &HeaderName) -> bool {
259 match name {
260 &AUTHORIZATION => false,
261 &CACHE_CONTROL => false,
262 &CONTENT_ENCODING => false,
263 &CONTENT_LENGTH => false,
264 &CONTENT_RANGE => false,
265 &CONTENT_TYPE => false,
266 &HOST => false,
267 &MAX_FORWARDS => false,
268 &SET_COOKIE => false,
269 &TRAILER => false,
270 &TRANSFER_ENCODING => false,
271 &TE => false,
272 _ => true,
273 }
274}
275
276fn allowed_trailer_field_map(allowed_trailer_fields: &Vec<HeaderValue>) -> HashMap<String, ()> {
277 let mut trailer_map = HashMap::new();
278
279 for header_value in allowed_trailer_fields {
280 if let Ok(header_str) = header_value.to_str() {
281 let items: Vec<&str> = header_str.split(',').map(|item| item.trim()).collect();
282
283 for item in items {
284 trailer_map.entry(item.to_string()).or_insert(());
285 }
286 }
287 }
288
289 trailer_map
290}
291
292impl<B> Buf for EncodedBuf<B>
293where
294 B: Buf,
295{
296 #[inline]
297 fn remaining(&self) -> usize {
298 match self.kind {
299 BufKind::Exact(ref b) => b.remaining(),
300 BufKind::Limited(ref b) => b.remaining(),
301 BufKind::Chunked(ref b) => b.remaining(),
302 BufKind::ChunkedEnd(ref b) => b.remaining(),
303 BufKind::Trailers(ref b) => b.remaining(),
304 }
305 }
306
307 #[inline]
308 fn chunk(&self) -> &[u8] {
309 match self.kind {
310 BufKind::Exact(ref b) => b.chunk(),
311 BufKind::Limited(ref b) => b.chunk(),
312 BufKind::Chunked(ref b) => b.chunk(),
313 BufKind::ChunkedEnd(ref b) => b.chunk(),
314 BufKind::Trailers(ref b) => b.chunk(),
315 }
316 }
317
318 #[inline]
319 fn advance(&mut self, cnt: usize) {
320 match self.kind {
321 BufKind::Exact(ref mut b) => b.advance(cnt),
322 BufKind::Limited(ref mut b) => b.advance(cnt),
323 BufKind::Chunked(ref mut b) => b.advance(cnt),
324 BufKind::ChunkedEnd(ref mut b) => b.advance(cnt),
325 BufKind::Trailers(ref mut b) => b.advance(cnt),
326 }
327 }
328
329 #[inline]
330 fn chunks_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize {
331 match self.kind {
332 BufKind::Exact(ref b) => b.chunks_vectored(dst),
333 BufKind::Limited(ref b) => b.chunks_vectored(dst),
334 BufKind::Chunked(ref b) => b.chunks_vectored(dst),
335 BufKind::ChunkedEnd(ref b) => b.chunks_vectored(dst),
336 BufKind::Trailers(ref b) => b.chunks_vectored(dst),
337 }
338 }
339}
340
/// Size of a `usize` on the target, in bytes.
///
/// Computed with `size_of` rather than per-`target_pointer_width` constants, so
/// the file compiles on every target and not only 32- and 64-bit ones.
const USIZE_BYTES: usize = std::mem::size_of::<usize>();

/// Maximum number of hex digits needed to render any `usize` chunk length
/// (two hex digits per byte).
const CHUNK_SIZE_MAX_BYTES: usize = USIZE_BYTES * 2;
349
/// Inline, fixed-capacity buffer holding one chunk-size line: the length in
/// hex followed by CRLF.
#[derive(Clone, Copy)]
struct ChunkSize {
    /// Storage: up to `CHUNK_SIZE_MAX_BYTES` hex digits plus `\r\n`.
    bytes: [u8; CHUNK_SIZE_MAX_BYTES + 2],
    /// Read cursor; bytes before `pos` have been consumed via `Buf::advance`.
    pos: u8,
    /// Number of bytes written into `bytes` so far.
    len: u8,
}
356
357impl ChunkSize {
358 fn new(len: usize) -> ChunkSize {
359 use std::fmt::Write;
360 let mut size = ChunkSize {
361 bytes: [0; CHUNK_SIZE_MAX_BYTES + 2],
362 pos: 0,
363 len: 0,
364 };
365 write!(&mut size, "{:X}\r\n", len).expect("CHUNK_SIZE_MAX_BYTES should fit any usize");
366 size
367 }
368}
369
370impl Buf for ChunkSize {
371 #[inline]
372 fn remaining(&self) -> usize {
373 (self.len - self.pos).into()
374 }
375
376 #[inline]
377 fn chunk(&self) -> &[u8] {
378 &self.bytes[self.pos.into()..self.len.into()]
379 }
380
381 #[inline]
382 fn advance(&mut self, cnt: usize) {
383 assert!(cnt <= self.remaining());
384 self.pos += cnt as u8; }
386}
387
388impl fmt::Debug for ChunkSize {
389 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
390 f.debug_struct("ChunkSize")
391 .field("bytes", &&self.bytes[..self.len.into()])
392 .field("pos", &self.pos)
393 .finish()
394 }
395}
396
397impl fmt::Write for ChunkSize {
398 fn write_str(&mut self, num: &str) -> fmt::Result {
399 use std::io::Write;
400 (&mut self.bytes[self.len.into()..])
401 .write_all(num.as_bytes())
402 .expect("&mut [u8].write() cannot error");
403 self.len += num.len() as u8; Ok(())
405 }
406}
407
408impl<B: Buf> From<B> for EncodedBuf<B> {
409 fn from(buf: B) -> Self {
410 EncodedBuf {
411 kind: BufKind::Exact(buf),
412 }
413 }
414}
415
416impl<B: Buf> From<Take<B>> for EncodedBuf<B> {
417 fn from(buf: Take<B>) -> Self {
418 EncodedBuf {
419 kind: BufKind::Limited(buf),
420 }
421 }
422}
423
424impl<B: Buf> From<Chain<Chain<ChunkSize, B>, StaticBuf>> for EncodedBuf<B> {
425 fn from(buf: Chain<Chain<ChunkSize, B>, StaticBuf>) -> Self {
426 EncodedBuf {
427 kind: BufKind::Chunked(buf),
428 }
429 }
430}
431
432impl fmt::Display for NotEof {
433 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
434 write!(f, "early end, expected {} more bytes", self.0)
435 }
436}
437
438impl std::error::Error for NotEof {}
439
#[cfg(test)]
mod tests {
    use std::iter::FromIterator;

    use bytes::BufMut;
    use http::{
        header::{
            AUTHORIZATION, CACHE_CONTROL, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_RANGE,
            CONTENT_TYPE, HOST, MAX_FORWARDS, SET_COOKIE, TE, TRAILER, TRANSFER_ENCODING,
        },
        HeaderMap, HeaderName, HeaderValue,
    };

    use super::super::io::Cursor;
    use super::Encoder;

    #[test]
    fn chunked() {
        let mut encoder = Encoder::chunked();
        let mut dst = Vec::new();

        let msg1 = b"foo bar".as_ref();
        let buf1 = encoder.encode(msg1);
        dst.put(buf1);
        assert_eq!(dst, b"7\r\nfoo bar\r\n");

        let msg2 = b"baz quux herp".as_ref();
        let buf2 = encoder.encode(msg2);
        dst.put(buf2);

        assert_eq!(dst, b"7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n");

        let end = encoder.end::<Cursor<Vec<u8>>>().unwrap().unwrap();
        dst.put(end);

        assert_eq!(
            dst,
            b"7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n0\r\n\r\n".as_ref()
        );
    }

    #[test]
    fn length() {
        let max_len = 8;
        let mut encoder = Encoder::length(max_len as u64);
        let mut dst = Vec::new();

        let msg1 = b"foo bar".as_ref();
        let buf1 = encoder.encode(msg1);
        dst.put(buf1);

        assert_eq!(dst, b"foo bar");
        assert!(!encoder.is_eof());
        encoder.end::<()>().unwrap_err();

        let msg2 = b"baz".as_ref();
        let buf2 = encoder.encode(msg2);
        dst.put(buf2);

        assert_eq!(dst.len(), max_len);
        assert_eq!(dst, b"foo barb");
        assert!(encoder.is_eof());
        assert!(encoder.end::<()>().unwrap().is_none());
    }

    // `Encoder::close_delimited` only exists with the `server` feature; without
    // this gate the test module fails to compile in client-only builds.
    #[cfg(feature = "server")]
    #[test]
    fn eof() {
        let mut encoder = Encoder::close_delimited();
        let mut dst = Vec::new();

        let msg1 = b"foo bar".as_ref();
        let buf1 = encoder.encode(msg1);
        dst.put(buf1);

        assert_eq!(dst, b"foo bar");
        assert!(!encoder.is_eof());
        encoder.end::<()>().unwrap();

        let msg2 = b"baz".as_ref();
        let buf2 = encoder.encode(msg2);
        dst.put(buf2);

        assert_eq!(dst, b"foo barbaz");
        assert!(!encoder.is_eof());
        encoder.end::<()>().unwrap();
    }

    #[test]
    fn chunked_with_valid_trailers() {
        let encoder = Encoder::chunked();
        let trailers = vec![HeaderValue::from_static("chunky-trailer")];
        let encoder = encoder.into_chunked_with_trailing_fields(trailers);

        let headers = HeaderMap::from_iter(
            vec![
                (
                    HeaderName::from_static("chunky-trailer"),
                    HeaderValue::from_static("header data"),
                ),
                (
                    HeaderName::from_static("should-not-be-included"),
                    HeaderValue::from_static("oops"),
                ),
            ]
            .into_iter(),
        );

        let buf1 = encoder.encode_trailers::<&[u8]>(headers, false).unwrap();

        let mut dst = Vec::new();
        dst.put(buf1);
        assert_eq!(dst, b"0\r\nchunky-trailer: header data\r\n\r\n");
    }

    #[test]
    fn chunked_with_multiple_trailer_headers() {
        let encoder = Encoder::chunked();
        let trailers = vec![
            HeaderValue::from_static("chunky-trailer"),
            HeaderValue::from_static("chunky-trailer-2"),
        ];
        let encoder = encoder.into_chunked_with_trailing_fields(trailers);

        let headers = HeaderMap::from_iter(
            vec![
                (
                    HeaderName::from_static("chunky-trailer"),
                    HeaderValue::from_static("header data"),
                ),
                (
                    HeaderName::from_static("chunky-trailer-2"),
                    HeaderValue::from_static("more header data"),
                ),
            ]
            .into_iter(),
        );

        let buf1 = encoder.encode_trailers::<&[u8]>(headers, false).unwrap();

        let mut dst = Vec::new();
        dst.put(buf1);
        assert_eq!(
            dst,
            b"0\r\nchunky-trailer: header data\r\nchunky-trailer-2: more header data\r\n\r\n"
        );
    }

    #[test]
    fn chunked_with_no_trailer_header() {
        let encoder = Encoder::chunked();

        let headers = HeaderMap::from_iter(
            vec![(
                HeaderName::from_static("chunky-trailer"),
                HeaderValue::from_static("header data"),
            )]
            .into_iter(),
        );

        assert!(encoder
            .encode_trailers::<&[u8]>(headers.clone(), false)
            .is_none());

        let trailers = vec![];
        let encoder = encoder.into_chunked_with_trailing_fields(trailers);

        assert!(encoder.encode_trailers::<&[u8]>(headers, false).is_none());
    }

    #[test]
    fn chunked_with_invalid_trailers() {
        let encoder = Encoder::chunked();

        let trailers = format!(
            "{},{},{},{},{},{},{},{},{},{},{},{}",
            AUTHORIZATION,
            CACHE_CONTROL,
            CONTENT_ENCODING,
            CONTENT_LENGTH,
            CONTENT_RANGE,
            CONTENT_TYPE,
            HOST,
            MAX_FORWARDS,
            SET_COOKIE,
            TRAILER,
            TRANSFER_ENCODING,
            TE,
        );
        let trailers = vec![HeaderValue::from_str(&trailers).unwrap()];
        let encoder = encoder.into_chunked_with_trailing_fields(trailers);

        let mut headers = HeaderMap::new();
        headers.insert(AUTHORIZATION, HeaderValue::from_static("header data"));
        headers.insert(CACHE_CONTROL, HeaderValue::from_static("header data"));
        headers.insert(CONTENT_ENCODING, HeaderValue::from_static("header data"));
        headers.insert(CONTENT_LENGTH, HeaderValue::from_static("header data"));
        headers.insert(CONTENT_RANGE, HeaderValue::from_static("header data"));
        headers.insert(CONTENT_TYPE, HeaderValue::from_static("header data"));
        headers.insert(HOST, HeaderValue::from_static("header data"));
        headers.insert(MAX_FORWARDS, HeaderValue::from_static("header data"));
        headers.insert(SET_COOKIE, HeaderValue::from_static("header data"));
        headers.insert(TRAILER, HeaderValue::from_static("header data"));
        headers.insert(TRANSFER_ENCODING, HeaderValue::from_static("header data"));
        headers.insert(TE, HeaderValue::from_static("header data"));

        assert!(encoder.encode_trailers::<&[u8]>(headers, true).is_none());
    }

    #[test]
    fn chunked_with_title_case_headers() {
        let encoder = Encoder::chunked();
        let trailers = vec![HeaderValue::from_static("chunky-trailer")];
        let encoder = encoder.into_chunked_with_trailing_fields(trailers);

        let headers = HeaderMap::from_iter(
            vec![(
                HeaderName::from_static("chunky-trailer"),
                HeaderValue::from_static("header data"),
            )]
            .into_iter(),
        );
        let buf1 = encoder.encode_trailers::<&[u8]>(headers, true).unwrap();

        let mut dst = Vec::new();
        dst.put(buf1);
        assert_eq!(dst, b"0\r\nChunky-Trailer: header data\r\n\r\n");
    }
}