rama_hyper/proto/h1/
encode.rs

1use std::collections::HashMap;
2use std::fmt;
3use std::io::IoSlice;
4
5use bytes::buf::{Chain, Take};
6use bytes::{Buf, Bytes};
7use http::{
8    header::{
9        AUTHORIZATION, CACHE_CONTROL, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_RANGE,
10        CONTENT_TYPE, HOST, MAX_FORWARDS, SET_COOKIE, TE, TRAILER, TRANSFER_ENCODING,
11    },
12    HeaderMap, HeaderName, HeaderValue,
13};
14
15use super::io::WriteBuf;
16use super::role::{write_headers, write_headers_title_case};
17
/// A borrowed, compile-time byte string (CRLFs, chunk terminators) used as a
/// zero-copy piece of an encoded frame.
type StaticBuf = &'static [u8];
19
20/// Encoders to handle different Transfer-Encodings.
21#[derive(Debug, Clone, PartialEq)]
22pub(crate) struct Encoder {
23    kind: Kind,
24    is_last: bool,
25}
26
27#[derive(Debug)]
28pub(crate) struct EncodedBuf<B> {
29    kind: BufKind<B>,
30}
31
/// Error from `Encoder::end` when a fixed-length body is finished early.
///
/// The wrapped value is the number of body bytes that were still expected.
#[derive(Debug)]
pub(crate) struct NotEof(u64);
34
35#[derive(Debug, PartialEq, Clone)]
36enum Kind {
37    /// An Encoder for when Transfer-Encoding includes `chunked`.
38    Chunked(Option<Vec<HeaderValue>>),
39    /// An Encoder for when Content-Length is set.
40    ///
41    /// Enforces that the body is not longer than the Content-Length header.
42    Length(u64),
43    /// An Encoder for when neither Content-Length nor Chunked encoding is set.
44    ///
45    /// This is mostly only used with HTTP/1.0 with a length. This kind requires
46    /// the connection to be closed when the body is finished.
47    #[cfg(feature = "server")]
48    CloseDelimited,
49}
50
51#[derive(Debug)]
52enum BufKind<B> {
53    Exact(B),
54    Limited(Take<B>),
55    Chunked(Chain<Chain<ChunkSize, B>, StaticBuf>),
56    ChunkedEnd(StaticBuf),
57    Trailers(Chain<Chain<StaticBuf, Bytes>, StaticBuf>),
58}
59
60impl Encoder {
61    fn new(kind: Kind) -> Encoder {
62        Encoder {
63            kind,
64            is_last: false,
65        }
66    }
67    pub(crate) fn chunked() -> Encoder {
68        Encoder::new(Kind::Chunked(None))
69    }
70
71    pub(crate) fn length(len: u64) -> Encoder {
72        Encoder::new(Kind::Length(len))
73    }
74
75    #[cfg(feature = "server")]
76    pub(crate) fn close_delimited() -> Encoder {
77        Encoder::new(Kind::CloseDelimited)
78    }
79
80    pub(crate) fn into_chunked_with_trailing_fields(self, trailers: Vec<HeaderValue>) -> Encoder {
81        match self.kind {
82            Kind::Chunked(_) => Encoder {
83                kind: Kind::Chunked(Some(trailers)),
84                is_last: self.is_last,
85            },
86            _ => self,
87        }
88    }
89
90    pub(crate) fn is_eof(&self) -> bool {
91        matches!(self.kind, Kind::Length(0))
92    }
93
94    #[cfg(feature = "server")]
95    pub(crate) fn set_last(mut self, is_last: bool) -> Self {
96        self.is_last = is_last;
97        self
98    }
99
100    pub(crate) fn is_last(&self) -> bool {
101        self.is_last
102    }
103
104    pub(crate) fn is_close_delimited(&self) -> bool {
105        match self.kind {
106            #[cfg(feature = "server")]
107            Kind::CloseDelimited => true,
108            _ => false,
109        }
110    }
111
112    pub(crate) fn is_chunked(&self) -> bool {
113        match self.kind {
114            Kind::Chunked(_) => true,
115            _ => false,
116        }
117    }
118
119    pub(crate) fn end<B>(&self) -> Result<Option<EncodedBuf<B>>, NotEof> {
120        match self.kind {
121            Kind::Length(0) => Ok(None),
122            Kind::Chunked(_) => Ok(Some(EncodedBuf {
123                kind: BufKind::ChunkedEnd(b"0\r\n\r\n"),
124            })),
125            #[cfg(feature = "server")]
126            Kind::CloseDelimited => Ok(None),
127            Kind::Length(n) => Err(NotEof(n)),
128        }
129    }
130
131    pub(crate) fn encode<B>(&mut self, msg: B) -> EncodedBuf<B>
132    where
133        B: Buf,
134    {
135        let len = msg.remaining();
136        debug_assert!(len > 0, "encode() called with empty buf");
137
138        let kind = match self.kind {
139            Kind::Chunked(_) => {
140                trace!("encoding chunked {}B", len);
141                let buf = ChunkSize::new(len)
142                    .chain(msg)
143                    .chain(b"\r\n" as &'static [u8]);
144                BufKind::Chunked(buf)
145            }
146            Kind::Length(ref mut remaining) => {
147                trace!("sized write, len = {}", len);
148                if len as u64 > *remaining {
149                    let limit = *remaining as usize;
150                    *remaining = 0;
151                    BufKind::Limited(msg.take(limit))
152                } else {
153                    *remaining -= len as u64;
154                    BufKind::Exact(msg)
155                }
156            }
157            #[cfg(feature = "server")]
158            Kind::CloseDelimited => {
159                trace!("close delimited write {}B", len);
160                BufKind::Exact(msg)
161            }
162        };
163        EncodedBuf { kind }
164    }
165
166    pub(crate) fn encode_trailers<B>(
167        &self,
168        trailers: HeaderMap,
169        title_case_headers: bool,
170    ) -> Option<EncodedBuf<B>> {
171        match &self.kind {
172            Kind::Chunked(Some(ref allowed_trailer_fields)) => {
173                let allowed_trailer_field_map = allowed_trailer_field_map(&allowed_trailer_fields);
174
175                let mut cur_name = None;
176                let mut allowed_trailers = HeaderMap::new();
177
178                for (opt_name, value) in trailers {
179                    if let Some(n) = opt_name {
180                        cur_name = Some(n);
181                    }
182                    let name = cur_name.as_ref().expect("current header name");
183
184                    if allowed_trailer_field_map.contains_key(name.as_str())
185                        && valid_trailer_field(name)
186                    {
187                        allowed_trailers.insert(name, value);
188                    }
189                }
190
191                let mut buf = Vec::new();
192                if title_case_headers {
193                    write_headers_title_case(&allowed_trailers, &mut buf);
194                } else {
195                    write_headers(&allowed_trailers, &mut buf);
196                }
197
198                if buf.is_empty() {
199                    return None;
200                }
201
202                Some(EncodedBuf {
203                    kind: BufKind::Trailers(b"0\r\n".chain(Bytes::from(buf)).chain(b"\r\n")),
204                })
205            }
206            _ => {
207                debug!("attempted to encode trailers for non-chunked response");
208                None
209            }
210        }
211    }
212
213    pub(super) fn encode_and_end<B>(&self, msg: B, dst: &mut WriteBuf<EncodedBuf<B>>) -> bool
214    where
215        B: Buf,
216    {
217        let len = msg.remaining();
218        debug_assert!(len > 0, "encode() called with empty buf");
219
220        match self.kind {
221            Kind::Chunked(_) => {
222                trace!("encoding chunked {}B", len);
223                let buf = ChunkSize::new(len)
224                    .chain(msg)
225                    .chain(b"\r\n0\r\n\r\n" as &'static [u8]);
226                dst.buffer(buf);
227                !self.is_last
228            }
229            Kind::Length(remaining) => {
230                use std::cmp::Ordering;
231
232                trace!("sized write, len = {}", len);
233                match (len as u64).cmp(&remaining) {
234                    Ordering::Equal => {
235                        dst.buffer(msg);
236                        !self.is_last
237                    }
238                    Ordering::Greater => {
239                        dst.buffer(msg.take(remaining as usize));
240                        !self.is_last
241                    }
242                    Ordering::Less => {
243                        dst.buffer(msg);
244                        false
245                    }
246                }
247            }
248            #[cfg(feature = "server")]
249            Kind::CloseDelimited => {
250                trace!("close delimited write {}B", len);
251                dst.buffer(msg);
252                false
253            }
254        }
255    }
256}
257
258fn valid_trailer_field(name: &HeaderName) -> bool {
259    match name {
260        &AUTHORIZATION => false,
261        &CACHE_CONTROL => false,
262        &CONTENT_ENCODING => false,
263        &CONTENT_LENGTH => false,
264        &CONTENT_RANGE => false,
265        &CONTENT_TYPE => false,
266        &HOST => false,
267        &MAX_FORWARDS => false,
268        &SET_COOKIE => false,
269        &TRAILER => false,
270        &TRANSFER_ENCODING => false,
271        &TE => false,
272        _ => true,
273    }
274}
275
276fn allowed_trailer_field_map(allowed_trailer_fields: &Vec<HeaderValue>) -> HashMap<String, ()> {
277    let mut trailer_map = HashMap::new();
278
279    for header_value in allowed_trailer_fields {
280        if let Ok(header_str) = header_value.to_str() {
281            let items: Vec<&str> = header_str.split(',').map(|item| item.trim()).collect();
282
283            for item in items {
284                trailer_map.entry(item.to_string()).or_insert(());
285            }
286        }
287    }
288
289    trailer_map
290}
291
292impl<B> Buf for EncodedBuf<B>
293where
294    B: Buf,
295{
296    #[inline]
297    fn remaining(&self) -> usize {
298        match self.kind {
299            BufKind::Exact(ref b) => b.remaining(),
300            BufKind::Limited(ref b) => b.remaining(),
301            BufKind::Chunked(ref b) => b.remaining(),
302            BufKind::ChunkedEnd(ref b) => b.remaining(),
303            BufKind::Trailers(ref b) => b.remaining(),
304        }
305    }
306
307    #[inline]
308    fn chunk(&self) -> &[u8] {
309        match self.kind {
310            BufKind::Exact(ref b) => b.chunk(),
311            BufKind::Limited(ref b) => b.chunk(),
312            BufKind::Chunked(ref b) => b.chunk(),
313            BufKind::ChunkedEnd(ref b) => b.chunk(),
314            BufKind::Trailers(ref b) => b.chunk(),
315        }
316    }
317
318    #[inline]
319    fn advance(&mut self, cnt: usize) {
320        match self.kind {
321            BufKind::Exact(ref mut b) => b.advance(cnt),
322            BufKind::Limited(ref mut b) => b.advance(cnt),
323            BufKind::Chunked(ref mut b) => b.advance(cnt),
324            BufKind::ChunkedEnd(ref mut b) => b.advance(cnt),
325            BufKind::Trailers(ref mut b) => b.advance(cnt),
326        }
327    }
328
329    #[inline]
330    fn chunks_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize {
331        match self.kind {
332            BufKind::Exact(ref b) => b.chunks_vectored(dst),
333            BufKind::Limited(ref b) => b.chunks_vectored(dst),
334            BufKind::Chunked(ref b) => b.chunks_vectored(dst),
335            BufKind::ChunkedEnd(ref b) => b.chunks_vectored(dst),
336            BufKind::Trailers(ref b) => b.chunks_vectored(dst),
337        }
338    }
339}
340
/// Size of a `usize` on the target, in bytes. Derived from the type itself so
/// it is defined for every pointer width, not only 32- and 64-bit targets.
const USIZE_BYTES: usize = std::mem::size_of::<usize>();

/// Maximum number of hex digits needed to print any `usize`: each byte
/// becomes at most two hex digits.
const CHUNK_SIZE_MAX_BYTES: usize = USIZE_BYTES * 2;
349
350#[derive(Clone, Copy)]
351struct ChunkSize {
352    bytes: [u8; CHUNK_SIZE_MAX_BYTES + 2],
353    pos: u8,
354    len: u8,
355}
356
357impl ChunkSize {
358    fn new(len: usize) -> ChunkSize {
359        use std::fmt::Write;
360        let mut size = ChunkSize {
361            bytes: [0; CHUNK_SIZE_MAX_BYTES + 2],
362            pos: 0,
363            len: 0,
364        };
365        write!(&mut size, "{:X}\r\n", len).expect("CHUNK_SIZE_MAX_BYTES should fit any usize");
366        size
367    }
368}
369
370impl Buf for ChunkSize {
371    #[inline]
372    fn remaining(&self) -> usize {
373        (self.len - self.pos).into()
374    }
375
376    #[inline]
377    fn chunk(&self) -> &[u8] {
378        &self.bytes[self.pos.into()..self.len.into()]
379    }
380
381    #[inline]
382    fn advance(&mut self, cnt: usize) {
383        assert!(cnt <= self.remaining());
384        self.pos += cnt as u8; // just asserted cnt fits in u8
385    }
386}
387
388impl fmt::Debug for ChunkSize {
389    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
390        f.debug_struct("ChunkSize")
391            .field("bytes", &&self.bytes[..self.len.into()])
392            .field("pos", &self.pos)
393            .finish()
394    }
395}
396
397impl fmt::Write for ChunkSize {
398    fn write_str(&mut self, num: &str) -> fmt::Result {
399        use std::io::Write;
400        (&mut self.bytes[self.len.into()..])
401            .write_all(num.as_bytes())
402            .expect("&mut [u8].write() cannot error");
403        self.len += num.len() as u8; // safe because bytes is never bigger than 256
404        Ok(())
405    }
406}
407
408impl<B: Buf> From<B> for EncodedBuf<B> {
409    fn from(buf: B) -> Self {
410        EncodedBuf {
411            kind: BufKind::Exact(buf),
412        }
413    }
414}
415
416impl<B: Buf> From<Take<B>> for EncodedBuf<B> {
417    fn from(buf: Take<B>) -> Self {
418        EncodedBuf {
419            kind: BufKind::Limited(buf),
420        }
421    }
422}
423
424impl<B: Buf> From<Chain<Chain<ChunkSize, B>, StaticBuf>> for EncodedBuf<B> {
425    fn from(buf: Chain<Chain<ChunkSize, B>, StaticBuf>) -> Self {
426        EncodedBuf {
427            kind: BufKind::Chunked(buf),
428        }
429    }
430}
431
432impl fmt::Display for NotEof {
433    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
434        write!(f, "early end, expected {} more bytes", self.0)
435    }
436}
437
438impl std::error::Error for NotEof {}
439
#[cfg(test)]
mod tests {
    use std::iter::FromIterator;

    use bytes::BufMut;
    use http::{
        header::{
            AUTHORIZATION, CACHE_CONTROL, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_RANGE,
            CONTENT_TYPE, HOST, MAX_FORWARDS, SET_COOKIE, TE, TRAILER, TRANSFER_ENCODING,
        },
        HeaderMap, HeaderName, HeaderValue,
    };

    use super::super::io::Cursor;
    use super::Encoder;

    #[test]
    fn chunked() {
        let mut encoder = Encoder::chunked();
        let mut dst = Vec::new();

        let msg1 = b"foo bar".as_ref();
        let buf1 = encoder.encode(msg1);
        dst.put(buf1);
        assert_eq!(dst, b"7\r\nfoo bar\r\n");

        let msg2 = b"baz quux herp".as_ref();
        let buf2 = encoder.encode(msg2);
        dst.put(buf2);

        assert_eq!(dst, b"7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n");

        let end = encoder.end::<Cursor<Vec<u8>>>().unwrap().unwrap();
        dst.put(end);

        assert_eq!(
            dst,
            b"7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n0\r\n\r\n".as_ref()
        );
    }

    #[test]
    fn length() {
        let max_len = 8;
        let mut encoder = Encoder::length(max_len as u64);
        let mut dst = Vec::new();

        let msg1 = b"foo bar".as_ref();
        let buf1 = encoder.encode(msg1);
        dst.put(buf1);

        assert_eq!(dst, b"foo bar");
        assert!(!encoder.is_eof());
        encoder.end::<()>().unwrap_err();

        // Only one byte of the length budget is left, so "baz" is truncated.
        let msg2 = b"baz".as_ref();
        let buf2 = encoder.encode(msg2);
        dst.put(buf2);

        assert_eq!(dst.len(), max_len);
        assert_eq!(dst, b"foo barb");
        assert!(encoder.is_eof());
        assert!(encoder.end::<()>().unwrap().is_none());
    }

    // `Encoder::close_delimited` only exists with the `server` feature.
    #[cfg(feature = "server")]
    #[test]
    fn eof() {
        let mut encoder = Encoder::close_delimited();
        let mut dst = Vec::new();

        let msg1 = b"foo bar".as_ref();
        let buf1 = encoder.encode(msg1);
        dst.put(buf1);

        assert_eq!(dst, b"foo bar");
        assert!(!encoder.is_eof());
        encoder.end::<()>().unwrap();

        let msg2 = b"baz".as_ref();
        let buf2 = encoder.encode(msg2);
        dst.put(buf2);

        assert_eq!(dst, b"foo barbaz");
        assert!(!encoder.is_eof());
        encoder.end::<()>().unwrap();
    }

    #[test]
    fn chunked_with_valid_trailers() {
        let encoder = Encoder::chunked();
        let trailers = vec![HeaderValue::from_static("chunky-trailer")];
        let encoder = encoder.into_chunked_with_trailing_fields(trailers);

        let headers = HeaderMap::from_iter(
            vec![
                (
                    HeaderName::from_static("chunky-trailer"),
                    HeaderValue::from_static("header data"),
                ),
                (
                    HeaderName::from_static("should-not-be-included"),
                    HeaderValue::from_static("oops"),
                ),
            ]
            .into_iter(),
        );

        let buf1 = encoder.encode_trailers::<&[u8]>(headers, false).unwrap();

        let mut dst = Vec::new();
        dst.put(buf1);
        assert_eq!(dst, b"0\r\nchunky-trailer: header data\r\n\r\n");
    }

    #[test]
    fn chunked_with_multiple_trailer_headers() {
        let encoder = Encoder::chunked();
        let trailers = vec![
            HeaderValue::from_static("chunky-trailer"),
            HeaderValue::from_static("chunky-trailer-2"),
        ];
        let encoder = encoder.into_chunked_with_trailing_fields(trailers);

        let headers = HeaderMap::from_iter(
            vec![
                (
                    HeaderName::from_static("chunky-trailer"),
                    HeaderValue::from_static("header data"),
                ),
                (
                    HeaderName::from_static("chunky-trailer-2"),
                    HeaderValue::from_static("more header data"),
                ),
            ]
            .into_iter(),
        );

        let buf1 = encoder.encode_trailers::<&[u8]>(headers, false).unwrap();

        let mut dst = Vec::new();
        dst.put(buf1);
        assert_eq!(
            dst,
            b"0\r\nchunky-trailer: header data\r\nchunky-trailer-2: more header data\r\n\r\n"
        );
    }

    #[test]
    fn chunked_with_no_trailer_header() {
        let encoder = Encoder::chunked();

        let headers = HeaderMap::from_iter(
            vec![(
                HeaderName::from_static("chunky-trailer"),
                HeaderValue::from_static("header data"),
            )]
            .into_iter(),
        );

        // No `Trailer` header declared at all.
        assert!(encoder
            .encode_trailers::<&[u8]>(headers.clone(), false)
            .is_none());

        // `Trailer` declared but empty.
        let trailers = vec![];
        let encoder = encoder.into_chunked_with_trailing_fields(trailers);

        assert!(encoder.encode_trailers::<&[u8]>(headers, false).is_none());
    }

    #[test]
    fn chunked_with_invalid_trailers() {
        let encoder = Encoder::chunked();

        let trailers = format!(
            "{},{},{},{},{},{},{},{},{},{},{},{}",
            AUTHORIZATION,
            CACHE_CONTROL,
            CONTENT_ENCODING,
            CONTENT_LENGTH,
            CONTENT_RANGE,
            CONTENT_TYPE,
            HOST,
            MAX_FORWARDS,
            SET_COOKIE,
            TRAILER,
            TRANSFER_ENCODING,
            TE,
        );
        let trailers = vec![HeaderValue::from_str(&trailers).unwrap()];
        let encoder = encoder.into_chunked_with_trailing_fields(trailers);

        let mut headers = HeaderMap::new();
        headers.insert(AUTHORIZATION, HeaderValue::from_static("header data"));
        headers.insert(CACHE_CONTROL, HeaderValue::from_static("header data"));
        headers.insert(CONTENT_ENCODING, HeaderValue::from_static("header data"));
        headers.insert(CONTENT_LENGTH, HeaderValue::from_static("header data"));
        headers.insert(CONTENT_RANGE, HeaderValue::from_static("header data"));
        headers.insert(CONTENT_TYPE, HeaderValue::from_static("header data"));
        headers.insert(HOST, HeaderValue::from_static("header data"));
        headers.insert(MAX_FORWARDS, HeaderValue::from_static("header data"));
        headers.insert(SET_COOKIE, HeaderValue::from_static("header data"));
        headers.insert(TRAILER, HeaderValue::from_static("header data"));
        headers.insert(TRANSFER_ENCODING, HeaderValue::from_static("header data"));
        headers.insert(TE, HeaderValue::from_static("header data"));

        // Every declared field is forbidden as a trailer, so nothing is emitted.
        assert!(encoder.encode_trailers::<&[u8]>(headers, true).is_none());
    }

    #[test]
    fn chunked_with_title_case_headers() {
        let encoder = Encoder::chunked();
        let trailers = vec![HeaderValue::from_static("chunky-trailer")];
        let encoder = encoder.into_chunked_with_trailing_fields(trailers);

        let headers = HeaderMap::from_iter(
            vec![(
                HeaderName::from_static("chunky-trailer"),
                HeaderValue::from_static("header data"),
            )]
            .into_iter(),
        );
        let buf1 = encoder.encode_trailers::<&[u8]>(headers, true).unwrap();

        let mut dst = Vec::new();
        dst.put(buf1);
        assert_eq!(dst, b"0\r\nChunky-Trailer: header data\r\n\r\n");
    }
}