1use {
2 crate::{bytes::InitBytesMut, error::Error},
3 bytes::Bytes,
4 futures_lite::prelude::*,
5 http::{Request, Response, Uri, Version},
6 httparse::{Header, ParserConfig},
7 std::{io::Write, str},
8};
9
/// Number of bytes in one kibibyte.
const KB: usize = 1024;
/// Initial capacity of the write buffer and starting read size of the adaptive read strategy.
const INIT_BUFFER_LEN: usize = KB * 2;
/// Upper bound on the adaptive read size used by `ReadStrategy::default()`.
const MAX_BUFFER_LEN: usize = KB * 128;
13
/// Buffered HTTP/1.1 message reader/writer on top of a byte stream `I`.
///
/// Reading methods require `I: AsyncRead + Unpin`, writing methods require
/// `I: AsyncWrite + Unpin`; the bounds live on the individual methods.
pub(crate) struct Handler<I> {
    /// Underlying transport.
    io: I,
    /// Bytes already read from `io` that have not been handed to the caller yet.
    read_buf: InitBytesMut,
    /// Decides how much spare capacity is requested before each read.
    read_strategy: Strategy,
    /// Scratch buffer reused to serialise request heads and chunk-size lines.
    write_buf: Vec<u8>,
}
20
21impl<I> Handler<I> {
22 pub fn new(io: I, read_strategy: ReadStrategy) -> Self {
23 Self {
24 io,
25 read_buf: InitBytesMut::new(),
26 read_strategy: read_strategy.state(),
27 write_buf: Vec::with_capacity(INIT_BUFFER_LEN),
28 }
29 }
30
31 async fn read_to_buf(&mut self) -> Result<(), Error>
32 where
33 I: AsyncRead + Unpin,
34 {
35 let next = self.read_strategy.next();
36 if self.read_buf.spare_capacity_len() < next {
37 self.read_buf.reserve(next);
38 }
39
40 let buf = self.read_buf.spare_capacity_mut();
41 if buf.is_empty() {
42 return Err(Error::TooLargeInput);
43 }
44
45 let n = self.io.read(buf).await?;
46 self.read_buf.advance(n);
47 self.read_strategy.record(n);
48
49 if n == 0 {
50 Err(Error::unexpected_eof())
51 } else {
52 Ok(())
53 }
54 }
55
56 async fn read_until(&mut self, sep: &[u8]) -> Result<Bytes, Error>
57 where
58 I: AsyncRead + Unpin,
59 {
60 debug_assert!(!sep.is_empty(), "sep must not be empty");
61
62 let mut cursor = 0;
63 loop {
64 let start = usize::saturating_sub(cursor, sep.len());
65 let buf = &self.read_buf.as_mut()[start..];
66 for i in memchr::memchr_iter(sep[0], buf) {
67 let (_, rest) = buf.split_at(i);
68 if rest.starts_with(sep) {
69 let at = start + i + sep.len();
70 return Ok(self.read_buf.split_to(at).freeze());
71 }
72 }
73
74 cursor += self.read_buf.len() - cursor;
75 self.read_to_buf().await?;
76 }
77 }
78
79 pub async fn read_header(&mut self) -> Result<Bytes, Error>
80 where
81 I: AsyncRead + Unpin,
82 {
83 let bytes = self.read_until(b"\r\n\r\n").await?;
84 Ok(bytes)
85 }
86
87 pub async fn read_body(&mut self, remaining: &mut usize) -> Result<Bytes, Error>
88 where
89 I: AsyncRead + Unpin,
90 {
91 debug_assert_ne!(*remaining, 0, "do not call this when remaining is zero");
92
93 if self.read_buf.is_empty() {
94 self.read_to_buf().await?;
95 }
96
97 let chunk_len = usize::min(*remaining, self.read_buf.len());
98 let chunk = self.read_buf.split_to(chunk_len).freeze();
99 *remaining -= chunk_len;
100 Ok(chunk)
101 }
102
103 pub async fn read_chunk(&mut self) -> Result<Bytes, Error>
104 where
105 I: AsyncRead + Unpin,
106 {
107 const SEP: &[u8; 2] = b"\r\n";
108
109 let len = {
110 let len_bytes = self.read_until(SEP).await?;
111 let len_bytes = len_bytes
112 .strip_suffix(SEP)
113 .expect("bytes read include suffix");
114
115 let len_str = str::from_utf8(len_bytes).map_err(|_| Error::invalid_input())?;
116 let len = usize::from_str_radix(len_str, 16).map_err(|_| Error::invalid_input())?;
117 len + SEP.len()
118 };
119
120 while self.read_buf.len() < len {
121 self.read_to_buf().await?;
122 }
123
124 let mut chunk = self.read_buf.split_to(len);
125 if chunk.ends_with(SEP) {
126 chunk.truncate(chunk.len() - SEP.len());
127 Ok(chunk.freeze())
128 } else {
129 Err(Error::invalid_input())
130 }
131 }
132
133 pub async fn write_header(&mut self, req: &Request<()>) -> Result<(), Error>
134 where
135 I: AsyncWrite + Unpin,
136 {
137 fn write_uri_to_buf(uri: &Uri, buf: &mut Vec<u8>) {
138 let n = buf.len();
139 _ = write!(buf, "{uri}");
140
141 if n == buf.len() {
144 buf.push(b'/');
145 }
146 }
147
148 fn write_to_buf(req: &Request<()>, buf: &mut Vec<u8>) {
149 let method = req.method();
150 let uri = req.uri();
151
152 assert_eq!(
153 req.version(),
154 Version::HTTP_11,
155 "only HTTP/1.1 version is supported",
156 );
157
158 _ = write!(buf, "{method} ");
159 write_uri_to_buf(uri, buf);
160 buf.extend_from_slice(b" HTTP/1.1\r\n");
161 for (name, value) in req.headers() {
162 _ = write!(buf, "{name}: ");
163 buf.extend_from_slice(value.as_bytes());
164 buf.extend_from_slice(b"\r\n");
165 }
166
167 buf.extend_from_slice(b"\r\n");
168 }
169
170 self.write_buf.clear();
171 write_to_buf(req, &mut self.write_buf);
172 self.io.write(&self.write_buf).await?;
173 Ok(())
174 }
175
176 pub async fn write_body(&mut self, body: &[u8]) -> Result<(), Error>
177 where
178 I: AsyncWrite + Unpin,
179 {
180 self.io.write(body).await?;
181 Ok(())
182 }
183
184 pub async fn write_chunk(&mut self, chunk: &[u8]) -> Result<(), Error>
185 where
186 I: AsyncWrite + Unpin,
187 {
188 self.write_buf.clear();
189 let chunk_len = chunk.len();
190 _ = write!(&mut self.write_buf, "{chunk_len:X}\r\n");
191
192 self.io.write(&self.write_buf).await?;
193 self.io.write(chunk).await?;
194 self.io.write(b"\r\n").await?;
195 Ok(())
196 }
197
198 pub async fn flush(&mut self) -> Result<(), Error>
199 where
200 I: AsyncWrite + Unpin,
201 {
202 self.io.flush().await?;
203 Ok(())
204 }
205}
206
/// Policy for how much spare read-buffer capacity the handler asks for
/// before each read from the stream.
#[derive(Clone, Copy)]
pub enum ReadStrategy {
    /// Always ask for the given number of bytes.
    Exact(usize),
    /// Start from an initial size and double it whenever a read fills the
    /// requested size, never going above `max`.
    Adaptive { max: usize },
}
212
213impl ReadStrategy {
214 fn state(self) -> Strategy {
215 match self {
216 Self::Exact(n) => Strategy::Exact(n),
217 Self::Adaptive { max } => Strategy::Adaptive {
218 next: INIT_BUFFER_LEN,
219 max,
220 },
221 }
222 }
223}
224
225impl Default for ReadStrategy {
226 fn default() -> Self {
227 Self::Adaptive {
228 max: MAX_BUFFER_LEN,
229 }
230 }
231}
232
/// Runtime state of a read strategy.
#[derive(Clone, Copy)]
enum Strategy {
    /// Fixed size for every read.
    Exact(usize),
    /// Current size `next`, which may grow up to `max`.
    Adaptive { next: usize, max: usize },
}

impl Strategy {
    /// Size to request for the upcoming read.
    fn next(self) -> usize {
        let (Self::Exact(len) | Self::Adaptive { next: len, .. }) = self;
        len
    }

    /// Records that the last read returned `n` bytes.
    ///
    /// For the adaptive strategy, a read that returned at least the
    /// requested size doubles the next size (saturating), capped at `max`.
    /// Shorter reads and the exact strategy leave the state unchanged.
    fn record(&mut self, n: usize) {
        if let Self::Adaptive { next, max } = self {
            if n >= *next {
                *next = next.saturating_mul(2).min(*max);
            }
        }
    }
}
259
/// Parses raw HTTP response heads into `http::Response<()>` values using `httparse`.
#[derive(Clone)]
pub(crate) struct Parser {
    /// `httparse` configuration used for every parse.
    conf: ParserConfig,
    /// Number of header slots offered to the parser; a response with more
    /// headers fails to parse.
    max_headers: usize,
}
265
266impl Parser {
/// Default header limit. Limits up to this value are served from a
/// fixed-size array in `parse_header`; larger limits use a `Vec`.
const HEADERS_STACK_BUFFER_LEN: usize = 150;
268
269 pub fn new() -> Self {
270 Self {
271 conf: ParserConfig::default(),
272 max_headers: Self::HEADERS_STACK_BUFFER_LEN,
273 }
274 }
275
/// Sets the maximum number of headers a response may contain to `n`.
pub fn set_max_headers(&mut self, n: usize) {
    self.max_headers = n;
}
279
280 pub fn parse_header(&self, buf: Bytes) -> Result<Response<()>, Error> {
281 use {
282 http::{HeaderName, HeaderValue, StatusCode},
283 httparse::Status,
284 std::mem::MaybeUninit,
285 };
286
287 let mut out = httparse::Response::new(&mut []);
288 let uninit_headers = if self.max_headers <= Self::HEADERS_STACK_BUFFER_LEN {
289 &mut [MaybeUninit::uninit(); Self::HEADERS_STACK_BUFFER_LEN][..self.max_headers]
290 } else {
291 &mut vec![MaybeUninit::uninit(); self.max_headers][..]
292 };
293
294 match self
295 .conf
296 .parse_response_with_uninit_headers(&mut out, &buf, uninit_headers)?
297 {
298 Status::Complete(n) if n == buf.len() => {}
299 _ => panic!("failed to complete parsing"),
300 }
301
302 let mut res = Response::new(());
303 *res.version_mut() = match out.version {
304 Some(9) => return Err(Error::UnsupportedVersion(Version::HTTP_09)),
305 Some(0) => return Err(Error::UnsupportedVersion(Version::HTTP_10)),
306 Some(1) => Version::HTTP_11,
307 _ => return Err(Error::Parse(httparse::Error::Version)),
308 };
309
310 *res.status_mut() =
311 StatusCode::from_u16(out.code.unwrap_or_default()).expect("valid status code");
312
313 *res.headers_mut() = {
314 let entry = |header: Header<'_>| {
315 let name =
316 HeaderName::from_bytes(header.name.as_bytes()).expect("valid header name");
317 let value = HeaderValue::from_maybe_shared(buf.slice_ref(header.value))
318 .expect("valid header value");
319
320 (name, value)
321 };
322
323 out.headers.iter().copied().map(entry).collect()
324 };
325
326 Ok(res)
327 }
328}
329
330#[cfg(test)]
331mod tests {
332 use {super::*, futures_lite::future};
333
334 impl<I> Handler<I> {
335 fn test(io: I) -> Self {
336 Self::new(io, ReadStrategy::default())
337 }
338 }
339
/// Sample HTTP/1.1 response used by the tests: a head with six headers
/// followed by the four-byte body `body`.
const RESPONSE: &[u8] = b"\
    HTTP/1.1 200 OK\r\n\
    date: mon, 27 jul 2009 12:28:53 gmt\r\n\
    last-modified: wed, 22 jul 2009 19:15:56 gmt\r\n\
    accept-ranges: bytes\r\n\
    content-length: 4\r\n\
    vary: accept-encoding\r\n\
    content-type: text/plain\r\n\
    \r\n\
    body\
";
351
352 fn header() -> &'static [u8] {
353 RESPONSE.strip_suffix(b"body").expect("strip body")
354 }
355
/// `read_header` returns exactly the head and leaves the body unread.
#[test]
fn read_head() -> Result<(), Error> {
    let mut handler = Handler::test(RESPONSE);
    let got = future::block_on(handler.read_header())?;
    assert_eq!(got, header());
    Ok(())
}
363
/// Parsing the sample head yields the expected status, version and headers.
#[test]
fn parse_head() -> Result<(), Error> {
    use http::{HeaderMap, HeaderName, HeaderValue, StatusCode, Version};

    let raw = Bytes::copy_from_slice(header());
    let parsed = Parser::new().parse_header(raw)?;
    assert_eq!(parsed.status(), StatusCode::OK);
    assert_eq!(parsed.version(), Version::HTTP_11);

    let pairs = [
        ("date", "mon, 27 jul 2009 12:28:53 gmt"),
        ("last-modified", "wed, 22 jul 2009 19:15:56 gmt"),
        ("accept-ranges", "bytes"),
        ("content-length", "4"),
        ("vary", "accept-encoding"),
        ("content-type", "text/plain"),
    ];

    let mut expected = HeaderMap::new();
    for (name, value) in pairs {
        let name = HeaderName::from_bytes(name.as_bytes()).expect("lowercased header name");
        expected.append(name, HeaderValue::from_static(value));
    }

    assert_eq!(parsed.headers(), &expected);
    Ok(())
}
395
/// The sample head has six headers: a limit of five must fail with
/// `TooManyHeaders`, a limit of six must succeed.
#[test]
fn parse_head_max_headers() -> Result<(), Error> {
    use http::{StatusCode, Version};

    let limited = |max_headers| Parser {
        max_headers,
        ..Parser::new()
    };
    let raw = header();

    let err = limited(5)
        .parse_header(Bytes::copy_from_slice(raw))
        .expect_err("too many headers");

    assert!(matches!(err, Error::Parse(httparse::Error::TooManyHeaders)));

    let parsed = limited(6).parse_header(Bytes::copy_from_slice(raw))?;
    assert_eq!(parsed.status(), StatusCode::OK);
    assert_eq!(parsed.version(), Version::HTTP_11);
    Ok(())
}
422
/// A body that fits in one read is returned whole and `remaining` drops to zero.
#[test]
fn read_body() -> Result<(), Error> {
    const BODY: &[u8] = b"Hello, World!";

    let mut left = BODY.len();
    let mut handler = Handler::test(BODY);
    let got = future::block_on(handler.read_body(&mut left))?;
    assert_eq!(got, BODY);
    assert_eq!(left, 0);
    Ok(())
}
434
/// Reading the head and then the body consumes the whole sample response.
#[test]
fn read_response() -> Result<(), Error> {
    let mut handler = Handler::test(RESPONSE);

    let head = future::block_on(handler.read_header())?;
    assert_eq!(head, header());

    let mut left = 4;
    let body = future::block_on(handler.read_body(&mut left))?;
    assert_eq!(body, "body".as_bytes());
    assert_eq!(left, 0);
    assert!(handler.read_buf.is_empty());
    Ok(())
}
448
/// `read_until` finds the separator even when it is split across reads.
///
/// Each case is (sequence of reads, separator, expected returned bytes).
#[test]
fn read_partial() -> Result<(), Error> {
    use crate::test;

    let cases = [
        (["_", "_", "A"].as_slice(), "A", "__A"),
        (&["_", "_", "A", "_"], "A", "__A"),
        (&["A", "B"], "AB", "AB"),
        (&["A", "B", "C"], "ABC", "ABC"),
        (&["___A", "B", "___"], "AB", "___AB"),
        (&["___A", "B", "C___"], "ABC", "___ABC"),
        (&["_", "__", "_A", "B", "C___"], "ABC", "____ABC"),
        (&["_", "__", "_A", "B", "C___"], "A", "____A"),
        (&["AA", "_BA_", "_A", "B", "C___"], "AB", "AA_BA__AB"),
    ];

    for (reads, sep, expected) in cases {
        let io = test::parts(reads.iter().copied().map(str::as_bytes));
        let mut handler = Handler::test(io);
        let got = future::block_on(handler.read_until(sep.as_bytes()))?;
        assert_eq!(got, expected);
    }

    Ok(())
}
474
/// A request with a path and one header is serialised as expected.
#[test]
fn write_head() -> Result<(), Error> {
    use http::{HeaderValue, Method, Uri, Version};

    const REQUEST: &[u8] = b"\
        GET /get HTTP/1.1\r\n\
        name: value\r\n\
        \r\n\
    ";

    let mut request = Request::new(());
    *request.version_mut() = Version::HTTP_11;
    *request.method_mut() = Method::GET;
    *request.uri_mut() = Uri::from_static("/get");
    request
        .headers_mut()
        .append("name", HeaderValue::from_static("value"));

    let mut sink = Vec::<u8>::new();
    let mut handler = Handler::test(&mut sink);
    future::block_on(handler.write_header(&request))?;
    assert_eq!(sink, REQUEST);
    Ok(())
}
498
/// A URI whose path is empty is written as `/`.
#[test]
fn write_head_empty_path() -> Result<(), Error> {
    use http::{Method, Uri, Version};

    const REQUEST: &[u8] = b"\
        GET / HTTP/1.1\r\n\
        \r\n\
    ";

    // Take only the (empty) path component of an absolute URI.
    let empty_path = Uri::from_static("s://a")
        .into_parts()
        .path_and_query
        .expect("get empty path");

    let mut request = Request::new(());
    *request.method_mut() = Method::GET;
    *request.uri_mut() = empty_path.into();
    *request.version_mut() = Version::HTTP_11;

    let mut sink = Vec::<u8>::new();
    let mut handler = Handler::test(&mut sink);
    future::block_on(handler.write_header(&request))?;
    assert_eq!(sink, REQUEST);
    Ok(())
}
524
/// The exact strategy keeps its size after a read.
#[test]
fn exact_read() -> Result<(), Error> {
    let mut handler = Handler::test(RESPONSE);
    handler.read_strategy = Strategy::Exact(2);

    future::block_on(handler.read_to_buf())?;
    let size = handler.read_strategy.next();
    assert_eq!(size, 2);
    Ok(())
}
534
/// The adaptive strategy doubles after each full read until it reaches `max`.
#[test]
fn adaptive_read() -> Result<(), Error> {
    let mut handler = Handler::test(RESPONSE);
    handler.read_strategy = Strategy::Adaptive { next: 1, max: 10 };

    for expected in [2, 4, 8, 10] {
        future::block_on(handler.read_to_buf())?;
        let size = handler.read_strategy.next();
        assert_eq!(size, expected);
    }

    Ok(())
}
547}