areq_h1/
handler.rs

1use {
2    crate::{bytes::InitBytesMut, error::Error},
3    bytes::Bytes,
4    futures_lite::prelude::*,
5    http::{Request, Response, Uri, Version},
6    httparse::{Header, ParserConfig},
7    std::{io::Write, str},
8};
9
/// Number of bytes in one kibibyte.
const KB: usize = 1024;
/// Initial capacity of the write buffer and the seed for the adaptive read size.
const INIT_BUFFER_LEN: usize = KB * 2;
/// Adaptive read-size limit used by `ReadStrategy::default()`.
const MAX_BUFFER_LEN: usize = KB * 128;
13
14pub(crate) struct Handler<I> {
15    io: I,
16    read_buf: InitBytesMut,
17    read_strategy: Strategy,
18    write_buf: Vec<u8>,
19}
20
21impl<I> Handler<I> {
22    pub fn new(io: I, read_strategy: ReadStrategy) -> Self {
23        Self {
24            io,
25            read_buf: InitBytesMut::new(),
26            read_strategy: read_strategy.state(),
27            write_buf: Vec::with_capacity(INIT_BUFFER_LEN),
28        }
29    }
30
31    async fn read_to_buf(&mut self) -> Result<(), Error>
32    where
33        I: AsyncRead + Unpin,
34    {
35        let next = self.read_strategy.next();
36        if self.read_buf.spare_capacity_len() < next {
37            self.read_buf.reserve(next);
38        }
39
40        let buf = self.read_buf.spare_capacity_mut();
41        if buf.is_empty() {
42            return Err(Error::TooLargeInput);
43        }
44
45        let n = self.io.read(buf).await?;
46        self.read_buf.advance(n);
47        self.read_strategy.record(n);
48
49        if n == 0 {
50            Err(Error::unexpected_eof())
51        } else {
52            Ok(())
53        }
54    }
55
56    async fn read_until(&mut self, sep: &[u8]) -> Result<Bytes, Error>
57    where
58        I: AsyncRead + Unpin,
59    {
60        debug_assert!(!sep.is_empty(), "sep must not be empty");
61
62        let mut cursor = 0;
63        loop {
64            let start = usize::saturating_sub(cursor, sep.len());
65            let buf = &self.read_buf.as_mut()[start..];
66            for i in memchr::memchr_iter(sep[0], buf) {
67                let (_, rest) = buf.split_at(i);
68                if rest.starts_with(sep) {
69                    let at = start + i + sep.len();
70                    return Ok(self.read_buf.split_to(at).freeze());
71                }
72            }
73
74            cursor += self.read_buf.len() - cursor;
75            self.read_to_buf().await?;
76        }
77    }
78
79    pub async fn read_header(&mut self) -> Result<Bytes, Error>
80    where
81        I: AsyncRead + Unpin,
82    {
83        let bytes = self.read_until(b"\r\n\r\n").await?;
84        Ok(bytes)
85    }
86
87    pub async fn read_body(&mut self, remaining: &mut usize) -> Result<Bytes, Error>
88    where
89        I: AsyncRead + Unpin,
90    {
91        debug_assert_ne!(*remaining, 0, "do not call this when remaining is zero");
92
93        if self.read_buf.is_empty() {
94            self.read_to_buf().await?;
95        }
96
97        let chunk_len = usize::min(*remaining, self.read_buf.len());
98        let chunk = self.read_buf.split_to(chunk_len).freeze();
99        *remaining -= chunk_len;
100        Ok(chunk)
101    }
102
103    pub async fn read_chunk(&mut self) -> Result<Bytes, Error>
104    where
105        I: AsyncRead + Unpin,
106    {
107        const SEP: &[u8; 2] = b"\r\n";
108
109        let len = {
110            let len_bytes = self.read_until(SEP).await?;
111            let len_bytes = len_bytes
112                .strip_suffix(SEP)
113                .expect("bytes read include suffix");
114
115            let len_str = str::from_utf8(len_bytes).map_err(|_| Error::invalid_input())?;
116            let len = usize::from_str_radix(len_str, 16).map_err(|_| Error::invalid_input())?;
117            len + SEP.len()
118        };
119
120        while self.read_buf.len() < len {
121            self.read_to_buf().await?;
122        }
123
124        let mut chunk = self.read_buf.split_to(len);
125        if chunk.ends_with(SEP) {
126            chunk.truncate(chunk.len() - SEP.len());
127            Ok(chunk.freeze())
128        } else {
129            Err(Error::invalid_input())
130        }
131    }
132
133    pub async fn write_header(&mut self, req: &Request<()>) -> Result<(), Error>
134    where
135        I: AsyncWrite + Unpin,
136    {
137        fn write_uri_to_buf(uri: &Uri, buf: &mut Vec<u8>) {
138            let n = buf.len();
139            _ = write!(buf, "{uri}");
140
141            // uri was not written
142            // may happen because of https://github.com/hyperium/http/issues/507
143            if n == buf.len() {
144                buf.push(b'/');
145            }
146        }
147
148        fn write_to_buf(req: &Request<()>, buf: &mut Vec<u8>) {
149            let method = req.method();
150            let uri = req.uri();
151
152            assert_eq!(
153                req.version(),
154                Version::HTTP_11,
155                "only HTTP/1.1 version is supported",
156            );
157
158            _ = write!(buf, "{method} ");
159            write_uri_to_buf(uri, buf);
160            buf.extend_from_slice(b" HTTP/1.1\r\n");
161            for (name, value) in req.headers() {
162                _ = write!(buf, "{name}: ");
163                buf.extend_from_slice(value.as_bytes());
164                buf.extend_from_slice(b"\r\n");
165            }
166
167            buf.extend_from_slice(b"\r\n");
168        }
169
170        self.write_buf.clear();
171        write_to_buf(req, &mut self.write_buf);
172        self.io.write(&self.write_buf).await?;
173        Ok(())
174    }
175
176    pub async fn write_body(&mut self, body: &[u8]) -> Result<(), Error>
177    where
178        I: AsyncWrite + Unpin,
179    {
180        self.io.write(body).await?;
181        Ok(())
182    }
183
184    pub async fn write_chunk(&mut self, chunk: &[u8]) -> Result<(), Error>
185    where
186        I: AsyncWrite + Unpin,
187    {
188        self.write_buf.clear();
189        let chunk_len = chunk.len();
190        _ = write!(&mut self.write_buf, "{chunk_len:X}\r\n");
191
192        self.io.write(&self.write_buf).await?;
193        self.io.write(chunk).await?;
194        self.io.write(b"\r\n").await?;
195        Ok(())
196    }
197
198    pub async fn flush(&mut self) -> Result<(), Error>
199    where
200        I: AsyncWrite + Unpin,
201    {
202        self.io.flush().await?;
203        Ok(())
204    }
205}
206
/// Policy for sizing the reads performed by a `Handler`.
#[derive(Clone, Copy)]
pub enum ReadStrategy {
    /// Always request the given number of bytes.
    Exact(usize),
    /// Request a size that is doubled, up to `max`, whenever a read delivers
    /// at least the requested amount.
    Adaptive { max: usize },
}
212
213impl ReadStrategy {
214    fn state(self) -> Strategy {
215        match self {
216            Self::Exact(n) => Strategy::Exact(n),
217            Self::Adaptive { max } => Strategy::Adaptive {
218                next: INIT_BUFFER_LEN,
219                max,
220            },
221        }
222    }
223}
224
225impl Default for ReadStrategy {
226    fn default() -> Self {
227        Self::Adaptive {
228            max: MAX_BUFFER_LEN,
229        }
230    }
231}
232
/// Run-time state of a `ReadStrategy`.
#[derive(Clone, Copy)]
enum Strategy {
    /// Fixed read size.
    Exact(usize),
    /// Current read size `next`, which is doubled up to `max`.
    Adaptive { next: usize, max: usize },
}

impl Strategy {
    /// Size to request for the upcoming read.
    fn next(self) -> usize {
        match self {
            Self::Exact(size) | Self::Adaptive { next: size, .. } => size,
        }
    }

    /// Records that the last read produced `n` bytes.
    ///
    /// When an adaptive read delivered at least the requested size, the next
    /// request is doubled (saturating) and capped at `max`. `Exact` never
    /// changes.
    fn record(&mut self, n: usize) {
        if let Self::Adaptive { next, max } = self {
            if n >= *next {
                *next = next.saturating_mul(2).min(*max);
            }
        }
    }
}
259
260#[derive(Clone)]
261pub(crate) struct Parser {
262    conf: ParserConfig,
263    max_headers: usize,
264}
265
266impl Parser {
267    const HEADERS_STACK_BUFFER_LEN: usize = 150;
268
269    pub fn new() -> Self {
270        Self {
271            conf: ParserConfig::default(),
272            max_headers: Self::HEADERS_STACK_BUFFER_LEN,
273        }
274    }
275
276    pub fn set_max_headers(&mut self, n: usize) {
277        self.max_headers = n;
278    }
279
280    pub fn parse_header(&self, buf: Bytes) -> Result<Response<()>, Error> {
281        use {
282            http::{HeaderName, HeaderValue, StatusCode},
283            httparse::Status,
284            std::mem::MaybeUninit,
285        };
286
287        let mut out = httparse::Response::new(&mut []);
288        let uninit_headers = if self.max_headers <= Self::HEADERS_STACK_BUFFER_LEN {
289            &mut [MaybeUninit::uninit(); Self::HEADERS_STACK_BUFFER_LEN][..self.max_headers]
290        } else {
291            &mut vec![MaybeUninit::uninit(); self.max_headers][..]
292        };
293
294        match self
295            .conf
296            .parse_response_with_uninit_headers(&mut out, &buf, uninit_headers)?
297        {
298            Status::Complete(n) if n == buf.len() => {}
299            _ => panic!("failed to complete parsing"),
300        }
301
302        let mut res = Response::new(());
303        *res.version_mut() = match out.version {
304            Some(9) => return Err(Error::UnsupportedVersion(Version::HTTP_09)),
305            Some(0) => return Err(Error::UnsupportedVersion(Version::HTTP_10)),
306            Some(1) => Version::HTTP_11,
307            _ => return Err(Error::Parse(httparse::Error::Version)),
308        };
309
310        *res.status_mut() =
311            StatusCode::from_u16(out.code.unwrap_or_default()).expect("valid status code");
312
313        *res.headers_mut() = {
314            let entry = |header: Header<'_>| {
315                let name =
316                    HeaderName::from_bytes(header.name.as_bytes()).expect("valid header name");
317                let value = HeaderValue::from_maybe_shared(buf.slice_ref(header.value))
318                    .expect("valid header value");
319
320                (name, value)
321            };
322
323            out.headers.iter().copied().map(entry).collect()
324        };
325
326        Ok(res)
327    }
328}
329
330#[cfg(test)]
331mod tests {
332    use {super::*, futures_lite::future};
333
334    impl<I> Handler<I> {
335        fn test(io: I) -> Self {
336            Self::new(io, ReadStrategy::default())
337        }
338    }
339
340    const RESPONSE: &[u8] = b"\
341        HTTP/1.1 200 OK\r\n\
342        date: mon, 27 jul 2009 12:28:53 gmt\r\n\
343        last-modified: wed, 22 jul 2009 19:15:56 gmt\r\n\
344        accept-ranges: bytes\r\n\
345        content-length: 4\r\n\
346        vary: accept-encoding\r\n\
347        content-type: text/plain\r\n\
348        \r\n\
349        body\
350    ";
351
352    fn header() -> &'static [u8] {
353        RESPONSE.strip_suffix(b"body").expect("strip body")
354    }
355
356    #[test]
357    fn read_head() -> Result<(), Error> {
358        let mut h = Handler::test(RESPONSE);
359        let head = future::block_on(h.read_header())?;
360        assert_eq!(head, header());
361        Ok(())
362    }
363
364    #[test]
365    fn parse_head() -> Result<(), Error> {
366        use http::{HeaderMap, HeaderName, HeaderValue, StatusCode, Version};
367
368        let res = header();
369        let head = Parser::new().parse_header(Bytes::copy_from_slice(res))?;
370        assert_eq!(head.status(), StatusCode::OK);
371        assert_eq!(head.version(), Version::HTTP_11);
372
373        let headers = [
374            ("date", "mon, 27 jul 2009 12:28:53 gmt"),
375            ("last-modified", "wed, 22 jul 2009 19:15:56 gmt"),
376            ("accept-ranges", "bytes"),
377            ("content-length", "4"),
378            ("vary", "accept-encoding"),
379            ("content-type", "text/plain"),
380        ];
381
382        let headers: HeaderMap = headers
383            .into_iter()
384            .map(|(name, value)| {
385                (
386                    HeaderName::from_bytes(name.as_bytes()).expect("lowercased header name"),
387                    HeaderValue::from_static(value),
388                )
389            })
390            .collect();
391
392        assert_eq!(head.headers(), &headers);
393        Ok(())
394    }
395
396    #[test]
397    fn parse_head_max_headers() -> Result<(), Error> {
398        use http::{StatusCode, Version};
399
400        let parser = Parser {
401            max_headers: 5,
402            ..Parser::new()
403        };
404
405        let res = header();
406        let e = parser
407            .parse_header(Bytes::copy_from_slice(res))
408            .expect_err("too many headers");
409
410        assert!(matches!(e, Error::Parse(httparse::Error::TooManyHeaders)));
411
412        let parser = Parser {
413            max_headers: 6,
414            ..Parser::new()
415        };
416
417        let head = parser.parse_header(Bytes::copy_from_slice(res))?;
418        assert_eq!(head.status(), StatusCode::OK);
419        assert_eq!(head.version(), Version::HTTP_11);
420        Ok(())
421    }
422
423    #[test]
424    fn read_body() -> Result<(), Error> {
425        const BODY: &[u8] = b"Hello, World!";
426
427        let mut h = Handler::test(BODY);
428        let mut remaining = BODY.len();
429        let body = future::block_on(h.read_body(&mut remaining))?;
430        assert_eq!(body, BODY);
431        assert_eq!(remaining, 0);
432        Ok(())
433    }
434
435    #[test]
436    fn read_response() -> Result<(), Error> {
437        let mut h = Handler::test(RESPONSE);
438        let head = future::block_on(h.read_header())?;
439        assert_eq!(head, header());
440
441        let mut remaining = 4;
442        let body = future::block_on(h.read_body(&mut remaining))?;
443        assert_eq!(body, "body".as_bytes());
444        assert_eq!(remaining, 0);
445        assert!(h.read_buf.is_empty());
446        Ok(())
447    }
448
449    #[test]
450    fn read_partial() -> Result<(), Error> {
451        use crate::test;
452
453        let cases = [
454            (["_", "_", "A"].as_slice(), "A", "__A"),
455            (&["_", "_", "A", "_"], "A", "__A"),
456            (&["A", "B"], "AB", "AB"),
457            (&["A", "B", "C"], "ABC", "ABC"),
458            (&["___A", "B", "___"], "AB", "___AB"),
459            (&["___A", "B", "C___"], "ABC", "___ABC"),
460            (&["_", "__", "_A", "B", "C___"], "ABC", "____ABC"),
461            (&["_", "__", "_A", "B", "C___"], "A", "____A"),
462            (&["AA", "_BA_", "_A", "B", "C___"], "AB", "AA_BA__AB"),
463        ];
464
465        for (reads, until, actual) in cases {
466            let parts = test::parts(reads.iter().copied().map(str::as_bytes));
467            let mut h = Handler::test(parts);
468            let bytes = future::block_on(h.read_until(until.as_bytes()))?;
469            assert_eq!(bytes, actual);
470        }
471
472        Ok(())
473    }
474
475    #[test]
476    fn write_head() -> Result<(), Error> {
477        use http::{HeaderValue, Method, Uri, Version};
478
479        const REQUEST: &[u8] = b"\
480            GET /get HTTP/1.1\r\n\
481            name: value\r\n\
482            \r\n\
483        ";
484
485        let mut req = Request::new(());
486        *req.method_mut() = Method::GET;
487        *req.uri_mut() = Uri::from_static("/get");
488        *req.version_mut() = Version::HTTP_11;
489        req.headers_mut()
490            .append("name", HeaderValue::from_static("value"));
491
492        let mut write = vec![];
493        let mut h = Handler::test(&mut write);
494        future::block_on(h.write_header(&req))?;
495        assert_eq!(write, REQUEST);
496        Ok(())
497    }
498
499    #[test]
500    fn write_head_empty_path() -> Result<(), Error> {
501        use http::{Method, Uri, Version};
502
503        const REQUEST: &[u8] = b"\
504            GET / HTTP/1.1\r\n\
505            \r\n\
506        ";
507
508        let mut req = Request::new(());
509        *req.method_mut() = Method::GET;
510        *req.uri_mut() = Uri::from_static("s://a")
511            .into_parts()
512            .path_and_query
513            .expect("get empty path")
514            .into();
515
516        *req.version_mut() = Version::HTTP_11;
517
518        let mut write = vec![];
519        let mut h = Handler::test(&mut write);
520        future::block_on(h.write_header(&req))?;
521        assert_eq!(write, REQUEST);
522        Ok(())
523    }
524
525    #[test]
526    fn exact_read() -> Result<(), Error> {
527        let mut h = Handler::test(RESPONSE);
528        h.read_strategy = Strategy::Exact(2);
529
530        future::block_on(h.read_to_buf())?;
531        assert_eq!(h.read_strategy.next(), 2);
532        Ok(())
533    }
534
535    #[test]
536    fn adaptive_read() -> Result<(), Error> {
537        let mut h = Handler::test(RESPONSE);
538        h.read_strategy = Strategy::Adaptive { next: 1, max: 10 };
539
540        for n in [2, 4, 8, 10] {
541            future::block_on(h.read_to_buf())?;
542            assert_eq!(h.read_strategy.next(), n);
543        }
544
545        Ok(())
546    }
547}