ttpkit/
request.rs

1//! Request types.
2
3use std::{borrow::Borrow, marker::PhantomData, ops::Deref, str::Utf8Error};
4
5use bytes::{Bytes, BytesMut};
6
7#[cfg(feature = "tokio-codec")]
8use tokio_util::codec::{Decoder, Encoder};
9
10use crate::{
11    error::Error,
12    header::{
13        FieldIter, HeaderField, HeaderFieldDecoder, HeaderFieldEncoder, HeaderFieldValue,
14        HeaderFields, Iter,
15    },
16    line::{LineDecoder, LineDecoderOptions},
17    utils::ascii::AsciiExt,
18};
19
20#[cfg(feature = "tokio-codec")]
21use crate::error::CodecError;
22
23/// Request path.
24#[derive(Debug, Clone)]
25pub struct RequestPath {
26    inner: Bytes,
27}
28
29impl RequestPath {
30    /// Create a new request path.
31    #[inline]
32    pub const fn from_static_str(s: &'static str) -> Self {
33        Self::from_static_bytes(s.as_bytes())
34    }
35
36    /// Create a new request path.
37    #[inline]
38    pub const fn from_static_bytes(s: &'static [u8]) -> Self {
39        Self {
40            inner: Bytes::from_static(s),
41        }
42    }
43
44    /// Get the request path as an UTF-8 string.
45    #[inline]
46    pub fn to_str(&self) -> Result<&str, Utf8Error> {
47        std::str::from_utf8(&self.inner)
48    }
49}
50
51impl PartialEq for RequestPath {
52    #[inline]
53    fn eq(&self, other: &Self) -> bool {
54        self.inner.eq(&other.inner)
55    }
56}
57
58impl Eq for RequestPath {}
59
60impl AsRef<[u8]> for RequestPath {
61    #[inline]
62    fn as_ref(&self) -> &[u8] {
63        &self.inner
64    }
65}
66
67impl Borrow<[u8]> for RequestPath {
68    #[inline]
69    fn borrow(&self) -> &[u8] {
70        &self.inner
71    }
72}
73
74impl Deref for RequestPath {
75    type Target = [u8];
76
77    #[inline]
78    fn deref(&self) -> &Self::Target {
79        &self.inner
80    }
81}
82
83impl From<&'static [u8]> for RequestPath {
84    #[inline]
85    fn from(s: &'static [u8]) -> Self {
86        Self::from(Bytes::from(s))
87    }
88}
89
90impl From<&'static str> for RequestPath {
91    #[inline]
92    fn from(s: &'static str) -> Self {
93        Self::from(Bytes::from(s))
94    }
95}
96
97impl From<Bytes> for RequestPath {
98    #[inline]
99    fn from(bytes: Bytes) -> Self {
100        Self { inner: bytes }
101    }
102}
103
104impl From<BytesMut> for RequestPath {
105    #[inline]
106    fn from(bytes: BytesMut) -> Self {
107        Self::from(Bytes::from(bytes))
108    }
109}
110
111impl From<Box<[u8]>> for RequestPath {
112    #[inline]
113    fn from(bytes: Box<[u8]>) -> Self {
114        Self::from(Bytes::from(bytes))
115    }
116}
117
118impl From<Vec<u8>> for RequestPath {
119    #[inline]
120    fn from(bytes: Vec<u8>) -> Self {
121        Self::from(Bytes::from(bytes))
122    }
123}
124
125impl From<String> for RequestPath {
126    #[inline]
127    fn from(s: String) -> Self {
128        Self::from(Bytes::from(s))
129    }
130}
131
132/// Internal error type for invalid request lines.
133struct InvalidRequestLine;
134
135impl From<InvalidRequestLine> for Error {
136    fn from(_: InvalidRequestLine) -> Self {
137        Error::from_static_msg("invalid request line")
138    }
139}
140
141/// Request header builder.
142#[derive(Clone)]
143pub struct RequestHeaderBuilder<P = Bytes, V = Bytes, M = Bytes> {
144    header: RequestHeader<P, V, M>,
145}
146
147impl<P, V, M> RequestHeaderBuilder<P, V, M> {
148    /// Set the protocol version.
149    #[inline]
150    pub fn set_version(mut self, version: V) -> Self {
151        self.header.version = version;
152        self
153    }
154
155    /// Set the request method.
156    #[inline]
157    pub fn set_method(mut self, method: M) -> Self {
158        self.header.method = method;
159        self
160    }
161
162    /// Set the request path.
163    #[inline]
164    pub fn set_path(mut self, path: RequestPath) -> Self {
165        self.header.path = path;
166        self
167    }
168
169    /// Replace the current header fields having the same name (if any).
170    pub fn set_header_field<T>(mut self, field: T) -> Self
171    where
172        T: Into<HeaderField>,
173    {
174        self.header.header_fields.set(field);
175        self
176    }
177
178    /// Add a given header field.
179    pub fn add_header_field<T>(mut self, field: T) -> Self
180    where
181        T: Into<HeaderField>,
182    {
183        self.header.header_fields.add(field);
184        self
185    }
186
187    /// Remove all header fields with a given name.
188    pub fn remove_header_fields<N>(mut self, name: &N) -> Self
189    where
190        N: AsRef<[u8]> + ?Sized,
191    {
192        self.header.header_fields.remove(name);
193        self
194    }
195
196    /// Build the request header.
197    #[inline]
198    pub fn build(self) -> RequestHeader<P, V, M> {
199        self.header
200    }
201}
202
203impl<P, V, M> From<RequestHeader<P, V, M>> for RequestHeaderBuilder<P, V, M> {
204    #[inline]
205    fn from(header: RequestHeader<P, V, M>) -> Self {
206        Self { header }
207    }
208}
209
210/// Request header.
211///
212/// It can be used for constructing custom headers that have the same structure
213/// as HTTP headers.
214#[derive(Debug, Clone)]
215pub struct RequestHeader<P = Bytes, V = Bytes, M = Bytes> {
216    method: M,
217    path: RequestPath,
218    protocol: P,
219    version: V,
220    header_fields: HeaderFields,
221}
222
223impl RequestHeader {
224    /// Parse a given request line.
225    fn parse_request_line(line: Bytes) -> Result<Self, InvalidRequestLine> {
226        let (method, rest) = line
227            .trim_ascii_start()
228            .split_once(|b| b.is_ascii_whitespace())
229            .ok_or(InvalidRequestLine)?;
230
231        let (path, rest) = rest
232            .trim_ascii_start()
233            .split_once(|b| b.is_ascii_whitespace())
234            .ok_or(InvalidRequestLine)?;
235
236        let (protocol, version) = rest.split_once(|b| b == b'/').ok_or(InvalidRequestLine)?;
237
238        let res = Self {
239            method,
240            path: path.into(),
241            protocol: protocol.trim_ascii(),
242            version: version.trim_ascii(),
243            header_fields: HeaderFields::new(),
244        };
245
246        Ok(res)
247    }
248
249    /// Parse the request parts from the current header.
250    fn parse_request_parts<P, V, M>(self) -> Result<RequestHeader<P, V, M>, Error>
251    where
252        P: TryFrom<Bytes>,
253        V: TryFrom<Bytes>,
254        M: TryFrom<Bytes>,
255        Error: From<P::Error>,
256        Error: From<V::Error>,
257        Error: From<M::Error>,
258    {
259        let protocol = P::try_from(self.protocol)?;
260        let version = V::try_from(self.version)?;
261        let method = M::try_from(self.method)?;
262
263        let res = RequestHeader {
264            method,
265            path: self.path,
266            protocol,
267            version,
268            header_fields: self.header_fields,
269        };
270
271        Ok(res)
272    }
273}
274
275impl<P, V, M> RequestHeader<P, V, M> {
276    /// Create a new request header.
277    ///
278    /// # Arguments
279    ///
280    /// * `protocol` - type of the protocol (e.g. "HTTP")
281    /// * `version` - version of the protocol (e.g. "1.1")
282    /// * `method` - request method (e.g. "GET")
283    /// * `path` - request path (e.g. /some/path?with=query)
284    #[inline]
285    pub const fn new(protocol: P, version: V, method: M, path: RequestPath) -> Self {
286        Self {
287            method,
288            path,
289            protocol,
290            version,
291            header_fields: HeaderFields::new(),
292        }
293    }
294
295    /// Get a request header builder.
296    ///
297    /// # Arguments
298    ///
299    /// * `protocol` - type of the protocol (e.g. "HTTP")
300    /// * `version` - version of the protocol (e.g. "1.1")
301    /// * `method` - request method (e.g. "GET")
302    /// * `path` - request path (e.g. /some/path?with=query)
303    #[inline]
304    pub const fn builder(
305        protocol: P,
306        version: V,
307        method: M,
308        path: RequestPath,
309    ) -> RequestHeaderBuilder<P, V, M> {
310        RequestHeaderBuilder {
311            header: Self::new(protocol, version, method, path),
312        }
313    }
314
315    /// Get the request method.
316    #[inline]
317    pub fn method(&self) -> &M {
318        &self.method
319    }
320
321    /// Get the request protocol.
322    #[inline]
323    pub fn protocol(&self) -> &P {
324        &self.protocol
325    }
326
327    /// Get the request protocol version.
328    #[inline]
329    pub fn version(&self) -> &V {
330        &self.version
331    }
332
333    /// Get the request path.
334    #[inline]
335    pub fn path(&self) -> &RequestPath {
336        &self.path
337    }
338
339    /// Get all header fields.
340    #[inline]
341    pub fn get_all_header_fields(&self) -> Iter<'_> {
342        self.header_fields.all()
343    }
344
345    /// Get header fields corresponding to a given name.
346    pub fn get_header_fields<'a, N>(&'a self, name: &'a N) -> FieldIter<'a>
347    where
348        N: AsRef<[u8]> + ?Sized,
349    {
350        self.header_fields.get(name)
351    }
352
353    /// Get last header field of a given name.
354    pub fn get_header_field<'a, N>(&'a self, name: &'a N) -> Option<&'a HeaderField>
355    where
356        N: AsRef<[u8]> + ?Sized,
357    {
358        self.header_fields.last(name)
359    }
360
361    /// Get value of the last header field with a given name.
362    pub fn get_header_field_value<'a, N>(&'a self, name: &'a N) -> Option<&'a HeaderFieldValue>
363    where
364        N: AsRef<[u8]> + ?Sized,
365    {
366        self.header_fields.last_value(name)
367    }
368}
369
370/// Encoder for request headers.
371pub struct RequestHeaderEncoder(());
372
373impl RequestHeaderEncoder {
374    /// Create a new request header encoder.
375    #[inline]
376    pub const fn new() -> Self {
377        Self(())
378    }
379
380    /// Encode a given request header.
381    pub fn encode<P, V, M>(&mut self, header: &RequestHeader<P, V, M>, dst: &mut BytesMut)
382    where
383        P: AsRef<[u8]>,
384        V: AsRef<[u8]>,
385        M: AsRef<[u8]>,
386    {
387        // helper function to avoid expensive monomorphizations
388        fn inner(
389            method: &[u8],
390            path: &[u8],
391            protocol: &[u8],
392            version: &[u8],
393            fields: &HeaderFields,
394            dst: &mut BytesMut,
395        ) {
396            let mut hfe = HeaderFieldEncoder::new();
397
398            let len = 7
399                + method.len()
400                + path.len()
401                + protocol.len()
402                + version.len()
403                + fields
404                    .all()
405                    .map(|f| 2 + hfe.get_encoded_length(f))
406                    .sum::<usize>();
407
408            dst.reserve(len);
409
410            dst.extend_from_slice(method);
411            dst.extend_from_slice(b" ");
412
413            dst.extend_from_slice(path);
414            dst.extend_from_slice(b" ");
415
416            dst.extend_from_slice(protocol);
417            dst.extend_from_slice(b"/");
418            dst.extend_from_slice(version);
419            dst.extend_from_slice(b"\r\n");
420
421            for field in fields.all() {
422                hfe.encode(field, dst);
423                dst.extend_from_slice(b"\r\n");
424            }
425
426            dst.extend_from_slice(b"\r\n");
427        }
428
429        let method = header.method.as_ref();
430        let path = header.path.as_ref();
431        let protocol = header.protocol.as_ref();
432        let version = header.version.as_ref();
433
434        inner(method, path, protocol, version, &header.header_fields, dst)
435    }
436}
437
438impl Default for RequestHeaderEncoder {
439    #[inline]
440    fn default() -> Self {
441        Self::new()
442    }
443}
444
#[cfg(feature = "tokio-codec")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio-codec")))]
impl<P, V, M> Encoder<&RequestHeader<P, V, M>> for RequestHeaderEncoder
where
    P: AsRef<[u8]>,
    V: AsRef<[u8]>,
    M: AsRef<[u8]>,
{
    type Error = CodecError;

    /// Encode a given request header into `dst`.
    ///
    /// This forwards to the inherent `RequestHeaderEncoder::encode` method and
    /// always returns `Ok(())`.
    #[inline]
    fn encode(
        &mut self,
        header: &RequestHeader<P, V, M>,
        dst: &mut BytesMut,
    ) -> Result<(), Self::Error> {
        // the inherent method returns nothing, so there is no error to map
        let done: () = RequestHeaderEncoder::encode(self, header, dst);

        Ok(done)
    }
}
466
467/// Request header decoder options.
468#[derive(Copy, Clone)]
469pub struct RequestHeaderDecoderOptions {
470    line_decoder_options: LineDecoderOptions,
471    max_header_field_length: Option<usize>,
472    max_header_fields: Option<usize>,
473}
474
475impl RequestHeaderDecoderOptions {
476    /// Create new request header decoder options.
477    ///
478    /// By default only CRLF line endings are accepted, the maximum line length
479    /// is 4096 bytes, the maximum header field length is 4096 bytes and the
480    /// maximum number of header fields is 64.
481    #[inline]
482    pub const fn new() -> Self {
483        let line_decoder_options = LineDecoderOptions::new()
484            .cr(false)
485            .lf(false)
486            .crlf(true)
487            .max_line_length(Some(4096))
488            .require_terminator(false);
489
490        Self {
491            line_decoder_options,
492            max_header_field_length: Some(4096),
493            max_header_fields: Some(64),
494        }
495    }
496
497    /// Enable or disable acceptance of all line endings (CR, LF, CRLF).
498    #[inline]
499    pub const fn accept_all_line_endings(mut self, enabled: bool) -> Self {
500        self.line_decoder_options = self.line_decoder_options.cr(enabled).lf(enabled).crlf(true);
501
502        self
503    }
504
505    /// Set the maximum allowed line length.
506    #[inline]
507    pub const fn max_line_length(mut self, max_length: Option<usize>) -> Self {
508        self.line_decoder_options = self.line_decoder_options.max_line_length(max_length);
509        self
510    }
511
512    /// Set the maximum allowed header field length.
513    #[inline]
514    pub const fn max_header_field_length(mut self, max_length: Option<usize>) -> Self {
515        self.max_header_field_length = max_length;
516        self
517    }
518
519    /// Set the maximum allowed number of header fields.
520    #[inline]
521    pub const fn max_header_fields(mut self, max_fields: Option<usize>) -> Self {
522        self.max_header_fields = max_fields;
523        self
524    }
525}
526
527impl Default for RequestHeaderDecoderOptions {
528    #[inline]
529    fn default() -> Self {
530        Self::new()
531    }
532}
533
534/// Decoder for request headers.
535pub struct RequestHeaderDecoder<P, V, M> {
536    inner: InternalRequestHeaderDecoder,
537    _pd: PhantomData<(P, V, M)>,
538}
539
540impl<P, V, M> RequestHeaderDecoder<P, V, M> {
541    /// Create a new request header decoder.
542    pub fn new(options: RequestHeaderDecoderOptions) -> Self {
543        Self {
544            inner: InternalRequestHeaderDecoder::new(options),
545            _pd: PhantomData,
546        }
547    }
548
549    /// Reset the decoder and make it ready for parsing a new request header.
550    pub fn reset(&mut self) {
551        self.inner.reset();
552    }
553}
554
555impl<P, V, M> RequestHeaderDecoder<P, V, M>
556where
557    P: TryFrom<Bytes>,
558    V: TryFrom<Bytes>,
559    M: TryFrom<Bytes>,
560    Error: From<P::Error>,
561    Error: From<V::Error>,
562    Error: From<M::Error>,
563{
564    /// Decode a given request header chunk.
565    pub fn decode(&mut self, data: &mut BytesMut) -> Result<Option<RequestHeader<P, V, M>>, Error> {
566        let res = self
567            .inner
568            .decode(data)?
569            .map(RequestHeader::parse_request_parts)
570            .transpose()?;
571
572        Ok(res)
573    }
574
575    /// Decode a given request header chunk at the end of the stream.
576    pub fn decode_eof(
577        &mut self,
578        data: &mut BytesMut,
579    ) -> Result<Option<RequestHeader<P, V, M>>, Error> {
580        let res = self
581            .inner
582            .decode_eof(data)?
583            .map(RequestHeader::parse_request_parts)
584            .transpose()?;
585
586        Ok(res)
587    }
588}
589
#[cfg(feature = "tokio-codec")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio-codec")))]
impl<P, V, M> Decoder for RequestHeaderDecoder<P, V, M>
where
    P: TryFrom<Bytes>,
    V: TryFrom<Bytes>,
    M: TryFrom<Bytes>,
    Error: From<P::Error>,
    Error: From<V::Error>,
    Error: From<M::Error>,
{
    type Item = RequestHeader<P, V, M>;
    type Error = CodecError;

    /// Forward to the inherent `decode` method and wrap its error in
    /// `CodecError::Other`.
    #[inline]
    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        match RequestHeaderDecoder::<P, V, M>::decode(self, buf) {
            Ok(item) => Ok(item),
            Err(err) => Err(CodecError::Other(err)),
        }
    }

    /// Forward to the inherent `decode_eof` method and wrap its error in
    /// `CodecError::Other`.
    #[inline]
    fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        match RequestHeaderDecoder::<P, V, M>::decode_eof(self, buf) {
            Ok(item) => Ok(item),
            Err(err) => Err(CodecError::Other(err)),
        }
    }
}
614
615/// Request header decoder.
616struct InternalRequestHeaderDecoder {
617    line_decoder: LineDecoder,
618    header: Option<RequestHeader>,
619    field_decoder: HeaderFieldDecoder,
620    max_header_fields: Option<usize>,
621}
622
623impl InternalRequestHeaderDecoder {
624    /// Create a new request header decoder.
625    fn new(options: RequestHeaderDecoderOptions) -> Self {
626        Self {
627            line_decoder: LineDecoder::new(options.line_decoder_options),
628            header: None,
629            field_decoder: HeaderFieldDecoder::new(options.max_header_field_length),
630            max_header_fields: options.max_header_fields,
631        }
632    }
633
634    /// Decode a given request header chunk.
635    fn decode(&mut self, data: &mut BytesMut) -> Result<Option<RequestHeader>, Error> {
636        while let Some(line) = self.line_decoder.decode(data)? {
637            if let Some(header) = self.decode_line(line)? {
638                return Ok(Some(header));
639            }
640        }
641
642        Ok(None)
643    }
644
645    /// Decode a given request header chunk at the end of the stream.
646    fn decode_eof(&mut self, data: &mut BytesMut) -> Result<Option<RequestHeader>, Error> {
647        while let Some(line) = self.line_decoder.decode_eof(data)? {
648            if let Some(header) = self.decode_line(line)? {
649                return Ok(Some(header));
650            }
651        }
652
653        if data.is_empty() && self.line_decoder.is_empty() && self.header.is_none() {
654            Ok(None)
655        } else {
656            Err(Error::from_static_msg("incomplete request header"))
657        }
658    }
659
660    /// Decode a given request line.
661    fn decode_line(&mut self, line: Bytes) -> Result<Option<RequestHeader>, Error> {
662        if let Some(header) = self.header.as_mut() {
663            let is_empty_line = line.is_empty();
664
665            if let Some(field) = self.field_decoder.decode(line)? {
666                if let Some(max_fields) = self.max_header_fields {
667                    if header.header_fields.len() >= max_fields {
668                        return Err(Error::from_static_msg(
669                            "maximum number of header fields exceeded",
670                        ));
671                    }
672                }
673
674                header.header_fields.add(field);
675            }
676
677            // an empty line means the end of the header
678            if is_empty_line {
679                return Ok(self.take());
680            }
681        } else {
682            self.header = Some(RequestHeader::parse_request_line(line)?);
683        }
684
685        Ok(None)
686    }
687
688    /// Reset the decoder and make it ready for parsing a new request header.
689    fn reset(&mut self) {
690        self.take();
691    }
692
693    /// Take the current header and reset the decoder.
694    fn take(&mut self) -> Option<RequestHeader> {
695        self.line_decoder.reset();
696        self.field_decoder.reset();
697
698        self.header.take()
699    }
700}