1mod options;
2mod tests;
3mod upgrade;
4mod writebuf;
5mod zerocopy;
6
7pub use options::*;
8pub use upgrade::*;
9pub use zerocopy::*;
10
/// Platform-specific raw OS handle type used inside this crate.
///
/// On Unix targets this is a raw file descriptor.
#[cfg(unix)]
pub(crate) type RawHandle = std::os::fd::RawFd;
/// Platform-specific raw OS handle type used inside this crate.
///
/// On Windows targets this is a raw Windows I/O handle.
#[cfg(windows)]
pub(crate) type RawHandle = std::os::windows::io::RawHandle;
15
16use std::{
17 io::IoSlice,
18 mem::MaybeUninit,
19 pin::Pin,
20 str::FromStr,
21 task::{Context, Poll},
22 time::UNIX_EPOCH,
23};
24
25use async_channel::Receiver;
26use bytes::{Buf, Bytes, BytesMut};
27use futures_util::stream::Stream;
28use http::{header, HeaderMap, HeaderName, HeaderValue, Method, Request, Response, Uri, Version};
29use http_body::Body;
30use http_body_util::{BodyExt, Empty};
31use memchr::{memchr3_iter, memmem};
32use tokio::io::{AsyncReadExt, AsyncWriteExt};
33use tokio_util::sync::CancellationToken;
34
35use crate::{h1::writebuf::WriteBuf, EarlyHints, HttpProtocol, Incoming};
36
/// Upper-case hexadecimal digit table, indexed by nibble value; used by
/// `write_chunk_size` to format chunk lengths.
const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";
/// Once `WriteBuf::len()` reaches this value while a response body is being
/// streamed, the queued data is written out before more frames are queued.
// NOTE(review): presumably `WriteBuf::len()` counts bytes (16 KiB here) — confirm in `writebuf`.
const WRITE_BUF_BATCH_THRESHOLD: usize = 16384;
39
/// State for serving HTTP/1.x requests over a single I/O object.
///
/// Created with `Http1::new` and consumed by the `handle*` methods, which read
/// requests from `io`, invoke the request callback and write the responses back.
pub struct Http1<Io> {
    /// The underlying I/O object requests are read from and responses are written to.
    io: Io,
    /// Protocol options supplied at construction time.
    options: options::Http1Options,
    /// Optional token set via `graceful_shutdown_token`; when it is cancelled the
    /// request loop stops after the response currently being handled.
    cancel_token: Option<CancellationToken>,
    /// Uninitialised header slots (`options.max_header_count` of them) handed to
    /// `httparse` when a request head is parsed.
    parsed_headers: Box<[MaybeUninit<httparse::Header<'static>>]>,
    /// Last formatted `Date` header value together with the time it was produced.
    date_header_value_cached: Option<(String, std::time::SystemTime)>,
    /// Header map reclaimed from the previously written response, reused (after
    /// `clear()`) for the next request's headers.
    cached_headers: Option<HeaderMap>,
    /// Bytes that were read from `io` but not yet consumed by the parser.
    read_buf: BytesMut,
    /// Scratch buffer the response status line and headers are serialised into.
    response_head_buf: Vec<u8>,
    /// Queue of pending output that is written/flushed to `io`.
    write_buf: WriteBuf,
}
82
// Only compiled on Linux with the `h1-zerocopy` feature enabled; additionally
// requires the I/O type to expose its inner raw handle.
#[cfg(all(target_os = "linux", feature = "h1-zerocopy"))]
impl<Io> Http1<Io>
where
    for<'a> Io: tokio::io::AsyncRead
        + tokio::io::AsyncWrite
        + vibeio::io::AsInnerRawHandle<'a>
        + Unpin
        + 'static,
{
    /// Consumes this connection and wraps it in an [`Http1Zerocopy`], storing
    /// `self` as its `inner` field.
    #[inline]
    pub fn zerocopy(self) -> Http1Zerocopy<Io> {
        Http1Zerocopy { inner: self }
    }
}
107
108impl<Io> Http1<Io>
109where
110 Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
111{
112 #[inline]
123 pub fn new(io: Io, options: options::Http1Options) -> Self {
124 let read_buf = BytesMut::with_capacity(options.max_header_size);
126 let parsed_headers: Box<[MaybeUninit<httparse::Header<'static>>]> =
127 Box::new_uninit_slice(options.max_header_count);
128 Self {
129 io,
130 options,
131 cancel_token: None,
132 parsed_headers,
133 date_header_value_cached: None,
134 cached_headers: None,
135 read_buf,
136 response_head_buf: Vec::with_capacity(1024),
137 write_buf: WriteBuf::new(),
138 }
139 }
140
141 #[inline]
142 fn get_date_header_value(&mut self) -> &str {
143 let now = std::time::SystemTime::now();
144 if self.date_header_value_cached.as_ref().is_none_or(|v| {
145 v.1.duration_since(UNIX_EPOCH).ok().map(|d| d.as_secs())
146 != now.duration_since(UNIX_EPOCH).ok().map(|d| d.as_secs())
147 }) {
148 let value = httpdate::fmt_http_date(now).to_string();
149 self.date_header_value_cached = Some((value, now));
150 }
151 self.date_header_value_cached
152 .as_ref()
153 .map(|v| v.0.as_str())
154 .unwrap_or("")
155 }
156
157 #[inline]
167 pub fn graceful_shutdown_token(mut self, token: CancellationToken) -> Self {
168 self.cancel_token = Some(token);
169 self
170 }
171
172 #[inline]
173 async fn fill_buf(&mut self) -> Result<usize, std::io::Error> {
174 if self.read_buf.remaining() < 1024 {
175 self.read_buf.reserve(1024);
176 }
177 let spare_capacity = self.read_buf.spare_capacity_mut();
178 let n = self
180 .io
181 .read(unsafe {
182 &mut *std::ptr::slice_from_raw_parts_mut(
183 spare_capacity.as_mut_ptr() as *mut u8,
184 spare_capacity.len(),
185 )
186 })
187 .await?;
188 if n == 0 {
189 return Ok(0);
190 }
191 unsafe { self.read_buf.set_len(self.read_buf.len() + n) };
192 Ok(n)
193 }
194
195 #[inline]
196 async fn read_body_fn(
197 &mut self,
198 body_tx: &async_channel::Sender<Result<http_body::Frame<bytes::Bytes>, std::io::Error>>,
199 content_length: u64,
200 ) -> Result<(), std::io::Error> {
201 let mut remaining = content_length;
202 let mut just_started = true;
203 while remaining > 0 {
204 let have_to_read_buf = !just_started || self.read_buf.is_empty();
205 just_started = false;
206 if have_to_read_buf {
207 let n = self.fill_buf().await?;
208 if n == 0 {
209 break;
210 }
211 }
212 let chunk = self
213 .read_buf
214 .split_to(
215 self.read_buf
216 .len()
217 .min(remaining.min(usize::MAX as u64) as usize),
218 )
219 .freeze();
220 remaining -= chunk.len() as u64;
221
222 let _ = body_tx.send(Ok(http_body::Frame::data(chunk))).await;
223 }
224 body_tx.close(); Ok(())
226 }
227
228 #[inline]
229 async fn read_body_chunk(
230 &mut self,
231 would_have_trailers: bool,
232 ) -> Result<bytes::Bytes, std::io::Error> {
233 let len = {
234 let mut len_buf_pos: usize = 0;
236 let mut just_started = true;
237 loop {
238 if len_buf_pos >= 48 {
239 return Err(std::io::Error::new(
240 std::io::ErrorKind::InvalidData,
241 "chunk length buffer overflow",
242 ));
243 }
244
245 let begin_search = len_buf_pos.saturating_sub(1);
246
247 let have_to_read_buf = !just_started || self.read_buf.is_empty();
248 just_started = false;
249 if have_to_read_buf {
250 let n = self.fill_buf().await?;
251 if n == 0 {
252 return Err(std::io::Error::new(
253 std::io::ErrorKind::UnexpectedEof,
254 "unexpected EOF",
255 ));
256 }
257 len_buf_pos += n;
258 } else {
259 len_buf_pos += self.read_buf.len();
260 }
261
262 if let Some(pos) =
263 memmem::find(&self.read_buf[begin_search..len_buf_pos.min(48)], b"\r\n")
264 {
265 let numbers =
266 std::str::from_utf8(&self.read_buf[begin_search..begin_search + pos])
267 .map_err(|_| {
268 std::io::Error::new(
269 std::io::ErrorKind::InvalidData,
270 "invalid chunk length",
271 )
272 })?;
273 let len = usize::from_str_radix(numbers, 16).map_err(|_| {
274 std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid chunk length")
275 })?;
276 self.read_buf.advance(begin_search + pos + 2);
278 break len;
279 }
280 }
281 };
282 let mut read = 0;
284 if len == 0 && would_have_trailers {
285 return Ok(bytes::Bytes::new()); }
287 let mut just_started = true;
288 while read < len + 2 {
290 let have_to_read_buf = !just_started || self.read_buf.is_empty();
291 just_started = false;
292 if have_to_read_buf {
293 let n = self.fill_buf().await?;
294 if n == 0 {
295 return Err(std::io::Error::new(
296 std::io::ErrorKind::UnexpectedEof,
297 "unexpected EOF",
298 ));
299 }
300 read += n;
301 } else {
302 read += self.read_buf.len();
303 }
304 }
305 let chunk = self.read_buf.split_to(len).freeze();
306 self.read_buf.advance(2); Ok(chunk)
308 }
309
310 #[inline]
311 async fn read_trailers(&mut self) -> Result<Option<HeaderMap>, std::io::Error> {
312 let mut bytes_read: usize = 0;
314 let mut just_started = true;
315 while bytes_read < self.options.max_header_size {
316 let old_bytes_read = bytes_read;
317 let begin_search = old_bytes_read.saturating_sub(3);
318
319 let have_to_read_buf = !just_started || self.read_buf.is_empty();
320 just_started = false;
321 if have_to_read_buf {
322 let n = self.fill_buf().await?;
323 if n == 0 {
324 return Err(std::io::Error::new(
325 std::io::ErrorKind::UnexpectedEof,
326 "unexpected EOF",
327 ));
328 }
329 bytes_read = (old_bytes_read + n).min(self.options.max_header_size);
330 } else {
331 bytes_read =
332 (old_bytes_read + self.read_buf.len()).min(self.options.max_header_size)
333 }
334
335 if bytes_read > 2 && self.read_buf[0] == b'\r' && self.read_buf[1] == b'\n' {
336 return Ok(None);
338 }
339
340 if let Some(separator_index) =
341 memmem::find(&self.read_buf[begin_search..bytes_read], b"\r\n\r\n")
342 {
343 let to_parse_length = begin_search + separator_index + 4;
344 let buf_ro = self.read_buf.split_to(to_parse_length).freeze();
345
346 let mut httparse_trailers =
348 vec![httparse::EMPTY_HEADER; self.options.max_header_count].into_boxed_slice();
349 let status = httparse::parse_headers(&buf_ro, &mut httparse_trailers)
350 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?;
351 if let httparse::Status::Complete((_, trailers)) = status {
352 let mut trailers_constructed = HeaderMap::new();
353 for header in trailers {
354 if header == &httparse::EMPTY_HEADER {
355 break;
357 }
358 let name = HeaderName::from_bytes(header.name.as_bytes())
359 .map_err(|e| std::io::Error::other(e.to_string()))?;
360 let value_start = header.value.as_ptr() as usize - buf_ro.as_ptr() as usize;
361 let value_len = header.value.len();
362 let value = unsafe {
364 HeaderValue::from_maybe_shared_unchecked(
365 buf_ro.slice(value_start..(value_start + value_len)),
366 )
367 };
368 trailers_constructed.append(name, value);
369 }
370
371 return Ok(Some(trailers_constructed));
372 } else {
373 return Err(std::io::Error::new(
374 std::io::ErrorKind::InvalidInput,
375 "trailer headers incomplete",
376 ));
377 }
378 }
379 }
380 Err(std::io::Error::new(
381 std::io::ErrorKind::InvalidData,
382 "request too large",
383 ))
384 }
385
386 #[inline]
387 async fn read_chunked_body_fn(
388 &mut self,
389 body_tx: &async_channel::Sender<Result<http_body::Frame<bytes::Bytes>, std::io::Error>>,
390 would_have_trailers: bool,
391 ) -> Result<(), std::io::Error> {
392 loop {
393 let chunk = self.read_body_chunk(would_have_trailers).await?;
394 if chunk.is_empty() {
395 break;
396 }
397
398 let _ = body_tx.send(Ok(http_body::Frame::data(chunk))).await;
399 }
400 if would_have_trailers {
401 let trailers = self.read_trailers().await?;
403 if let Some(trailers) = trailers {
404 let _ = body_tx.send(Ok(http_body::Frame::trailers(trailers))).await;
405 }
406 }
407 body_tx.close(); Ok(())
409 }
410
411 #[inline]
412 async fn read_request(
413 &mut self,
414 ) -> Result<
415 Option<(
416 Request<Incoming>,
417 async_channel::Sender<Result<http_body::Frame<bytes::Bytes>, std::io::Error>>,
418 )>,
419 std::io::Error,
420 > {
421 let (request, body_tx) = {
423 let Some((head, headers)) = self.get_head().await? else {
424 return Ok(None);
425 };
426 let headers = unsafe {
428 std::mem::transmute::<
429 &mut [MaybeUninit<httparse::Header<'static>>],
430 &mut [MaybeUninit<httparse::Header<'_>>],
431 >(headers)
432 };
433 let mut req = httparse::Request::new(&mut []);
434 let status = req
435 .parse_with_uninit_headers(&head, headers)
436 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
437 if status.is_partial() {
438 return Err(std::io::Error::new(
439 std::io::ErrorKind::InvalidData,
440 "partial request head",
441 ));
442 }
443
444 let (body_tx, body_rx) = async_channel::bounded(2);
446 let request_body = Http1Body {
447 inner: Box::pin(body_rx),
448 };
449 let mut request = Request::new(Incoming::H1(request_body));
450 match req.version {
451 Some(0) => *request.version_mut() = http::Version::HTTP_10,
452 Some(1) => *request.version_mut() = http::Version::HTTP_11,
453 _ => *request.version_mut() = http::Version::HTTP_11,
454 };
455 if let Some(method) = req.method {
456 *request.method_mut() = Method::from_bytes(method.as_bytes())
457 .map_err(|e| std::io::Error::other(e.to_string()))?;
458 }
459 if let Some(path) = req.path {
460 *request.uri_mut() =
461 Uri::from_str(path).map_err(|e| std::io::Error::other(e.to_string()))?;
462 }
463 let mut header_map = self.cached_headers.take().unwrap_or_default();
464 header_map.clear();
465 let additional_capacity = req.headers.len().saturating_sub(header_map.capacity());
466 if additional_capacity > 0 {
467 header_map.reserve(additional_capacity);
468 }
469 for header in req.headers {
470 if header == &httparse::EMPTY_HEADER {
471 break;
473 }
474 let name = HeaderName::from_bytes(header.name.as_bytes())
475 .map_err(|e| std::io::Error::other(e.to_string()))?;
476 let value_start = header.value.as_ptr() as usize - head.as_ptr() as usize;
477 let value_len = header.value.len();
478 let value = unsafe {
480 HeaderValue::from_maybe_shared_unchecked(
481 head.slice(value_start..(value_start + value_len)),
482 )
483 };
484 header_map.append(name, value);
485 }
486 *request.headers_mut() = header_map;
487
488 (request, body_tx)
489 };
490 Ok(Some((request, body_tx)))
491 }
492
493 #[inline]
494 async fn get_head(
495 &mut self,
496 ) -> Result<Option<(Bytes, &mut [MaybeUninit<httparse::Header<'static>>])>, std::io::Error>
497 {
498 let mut request_line_read = false;
499 let mut bytes_read: usize = 0;
500 let mut whitespace_trimmed = None;
501 let mut just_started = true;
502 while bytes_read < self.options.max_header_size {
503 let old_bytes_read = bytes_read;
504 let begin_search = old_bytes_read.saturating_sub(3);
505
506 let have_to_read_buf = !just_started || self.read_buf.is_empty();
507 just_started = false;
508 if have_to_read_buf {
509 let n = self.fill_buf().await?;
510 if n == 0 {
511 if whitespace_trimmed.is_none() {
512 return Ok(None);
513 }
514 return Err(std::io::Error::new(
515 std::io::ErrorKind::UnexpectedEof,
516 "unexpected EOF",
517 ));
518 }
519 bytes_read = (old_bytes_read + n).min(self.options.max_header_size);
520 } else {
521 bytes_read =
522 (old_bytes_read + self.read_buf.len()).min(self.options.max_header_size)
523 }
524
525 if whitespace_trimmed.is_none() {
526 whitespace_trimmed = self.read_buf[old_bytes_read..bytes_read]
527 .iter()
528 .position(|b| !b.is_ascii_whitespace());
529 }
530
531 if let Some(whitespace_trimmed) = whitespace_trimmed {
532 if !request_line_read {
534 let memchr = memchr3_iter(
535 b' ',
536 b'\r',
537 b'\n',
538 &self.read_buf[whitespace_trimmed..bytes_read],
539 );
540 let mut spaces = 0;
541 for separator_index in memchr {
542 if self.read_buf[whitespace_trimmed + separator_index] == b' ' {
543 if spaces >= 2 {
544 return Err(std::io::Error::new(
545 std::io::ErrorKind::InvalidInput,
546 "bad request first line",
547 ));
548 }
549 spaces += 1;
550 } else if spaces == 2 {
551 request_line_read = true;
552 break;
553 } else {
554 return Err(std::io::Error::new(
555 std::io::ErrorKind::InvalidInput,
556 "bad request first line",
557 ));
558 }
559 }
560 }
561
562 if request_line_read {
563 let begin_search = begin_search.max(whitespace_trimmed);
564 if let Some((separator_index, separator_len)) =
565 search_header_body_separator(&self.read_buf[begin_search..bytes_read])
566 {
567 let to_parse_length =
568 begin_search + separator_index + separator_len - whitespace_trimmed;
569 self.read_buf.advance(whitespace_trimmed);
570 let head = self.read_buf.split_to(to_parse_length);
571 return Ok(Some((head.freeze(), &mut self.parsed_headers)));
572 }
573 }
574 }
575 }
576 Err(std::io::Error::new(
577 std::io::ErrorKind::InvalidData,
578 "request too large",
579 ))
580 }
581
582 #[inline]
583 async fn write_response<Z, ZFut>(
584 &mut self,
585 mut response: Response<
586 impl Body<Data = bytes::Bytes, Error = impl std::error::Error> + Unpin,
587 >,
588 version: Version,
589 write_trailers: bool,
590 zerocopy_fn: Option<Z>,
591 ) -> Result<(), std::io::Error>
592 where
593 Z: FnMut(RawHandle, &'static Io, u64) -> ZFut,
594 ZFut: std::future::Future<Output = Result<(), std::io::Error>>,
595 {
596 if self.options.send_date_header {
598 response.headers_mut().insert(
599 header::DATE,
600 HeaderValue::from_str(self.get_date_header_value())
601 .map_err(|e| std::io::Error::other(e.to_string()))?,
602 );
603 }
604
605 if let Some(suggested_content_length) = response.body().size_hint().exact() {
607 let headers = response.headers_mut();
608 if !headers.contains_key(header::CONTENT_LENGTH) {
609 headers.insert(header::CONTENT_LENGTH, suggested_content_length.into());
610 }
611 }
612
613 let chunked = response
614 .headers()
615 .get(header::TRANSFER_ENCODING)
616 .map(|v| {
617 v.to_str().ok().is_some_and(|s| {
618 s.split(',')
619 .any(|s| s.trim().eq_ignore_ascii_case("chunked"))
620 })
621 })
622 .unwrap_or_else(|| {
623 response
624 .headers()
625 .get(header::CONTENT_LENGTH)
626 .and_then(|v| v.to_str().ok())
627 .is_none_or(|s| s.parse::<u64>().is_err())
628 });
629
630 if chunked {
631 response.headers_mut().insert(
632 header::TRANSFER_ENCODING,
633 HeaderValue::from_static("chunked"),
634 );
635 while response
636 .headers_mut()
637 .remove(header::CONTENT_LENGTH)
638 .is_some()
639 {}
640 }
641
642 let (parts, mut body) = response.into_parts();
643
644 self.response_head_buf.clear();
645 let estimated_head_len = 30 + parts.headers.len() * 30; if self.response_head_buf.capacity() < estimated_head_len {
647 self.response_head_buf
648 .reserve(estimated_head_len - self.response_head_buf.capacity());
649 }
650 let head = &mut self.response_head_buf;
651 if version == Version::HTTP_10 {
652 head.extend_from_slice(b"HTTP/1.0 ");
653 } else {
654 head.extend_from_slice(b"HTTP/1.1 ");
655 }
656 let status = parts.status;
657 head.extend_from_slice(status.as_str().as_bytes());
658 if let Some(canonical_reason) = status.canonical_reason() {
659 head.extend_from_slice(b" ");
660 head.extend_from_slice(canonical_reason.as_bytes());
661 }
662 head.extend_from_slice(b"\r\n");
663 for (name, value) in &parts.headers {
664 head.extend_from_slice(name.as_str().as_bytes());
665 head.extend_from_slice(b": ");
666 head.extend_from_slice(value.as_bytes());
667 head.extend_from_slice(b"\r\n");
668 }
669 head.extend_from_slice(b"\r\n");
670 unsafe {
671 self.write_buf.push(IoSlice::new(head));
672 }
673
674 if !chunked {
675 if let Some(content_length) = parts
676 .headers
677 .get(header::CONTENT_LENGTH)
678 .and_then(|v| v.to_str().ok())
679 .and_then(|s| s.parse::<u64>().ok())
680 {
681 if let Some(zero_copy) = parts.extensions.get::<ZerocopyResponse>() {
682 if let Some(mut zerocopy_fn) = zerocopy_fn {
683 unsafe {
685 self.write_buf
686 .flush(&mut self.io, self.options.enable_vectored_write)
687 .await?
688 };
689 zerocopy_fn(
690 zero_copy.handle,
691 unsafe { std::mem::transmute::<&Io, &'static Io>(&self.io) },
693 content_length,
694 )
695 .await?;
696 self.io.flush().await?;
697 let reclaimed_headers = parts.headers;
698 self.cached_headers = Some(reclaimed_headers);
699 return Ok(());
700 }
701 }
702 }
703 }
704
705 let mut trailers_written = false;
706 while let Some(chunk) = body.frame().await {
707 let chunk = chunk.map_err(|e| std::io::Error::other(e.to_string()))?;
708 match chunk.into_data() {
709 Ok(data) => {
710 if chunked {
711 let mut chunk_size_buf = [0u8; 18];
712 let chunk_size = write_chunk_size(&mut chunk_size_buf, data.len());
713 self.write_buf.push_copy(chunk_size);
714 self.write_buf.push_bytes(data);
715 unsafe {
716 self.write_buf.push(IoSlice::new(b"\r\n"));
717 }
718 } else {
719 self.write_buf.push_bytes(data);
720 }
721 while self.write_buf.len() >= WRITE_BUF_BATCH_THRESHOLD {
722 unsafe {
723 self.write_buf
724 .write(&mut self.io, self.options.enable_vectored_write)
725 .await?;
726 }
727 }
728 }
729 Err(chunk) => {
730 if let Ok(trailers) = chunk.into_trailers() {
731 if write_trailers {
732 unsafe {
733 self.write_buf.push(IoSlice::new(b"0\r\n"));
734 for (name, value) in &trailers {
735 self.write_buf.push_copy(name.as_str().as_bytes());
736 self.write_buf.push(IoSlice::new(b": "));
737 self.write_buf.push_copy(value.as_bytes());
738 self.write_buf.push(IoSlice::new(b"\r\n"));
739 }
740 self.write_buf.push(IoSlice::new(b"\r\n"));
741 }
742 trailers_written = true;
743 }
744 break;
745 }
746 }
747 };
748 }
749 if chunked && !trailers_written {
750 unsafe {
752 self.write_buf.push(IoSlice::new(b"0\r\n\r\n"));
753 }
754 }
755 unsafe {
756 self.write_buf
757 .flush(&mut self.io, self.options.enable_vectored_write)
758 .await?;
759 }
760 self.io.flush().await?;
761 let reclaimed_headers = parts.headers;
762 self.cached_headers = Some(reclaimed_headers);
763
764 Ok(())
765 }
766
767 #[inline]
768 async fn write_100_continue(&mut self, version: Version) -> Result<(), std::io::Error> {
769 if version == Version::HTTP_10 {
770 self.io.write_all(b"HTTP/1.0 100 Continue\r\n\r\n").await?;
771 } else {
772 self.io.write_all(b"HTTP/1.1 100 Continue\r\n\r\n").await?;
773 }
774 self.io.flush().await?;
775
776 Ok(())
777 }
778
779 #[inline]
780 async fn write_early_hints(
781 &mut self,
782 version: Version,
783 headers: http::HeaderMap,
784 ) -> Result<(), std::io::Error> {
785 let mut head = Vec::new();
786 if version == Version::HTTP_10 {
787 head.extend_from_slice(b"HTTP/1.0 103 Early Hints\r\n");
788 } else {
789 head.extend_from_slice(b"HTTP/1.1 103 Early Hints\r\n");
790 }
791 let mut current_header_name = None;
792 for (name, value) in headers {
793 if let Some(name) = name {
794 current_header_name = Some(name);
795 };
796 if let Some(current_header_name) = ¤t_header_name {
797 head.extend_from_slice(current_header_name.as_str().as_bytes());
798 if value.is_empty() {
799 head.extend_from_slice(b":\r\n");
800 continue;
801 }
802 head.extend_from_slice(b": ");
803 head.extend_from_slice(value.as_bytes());
804 head.extend_from_slice(b"\r\n");
805 }
806 }
807 head.extend_from_slice(b"\r\n");
808
809 self.io.write_all(&head).await?;
810
811 Ok(())
812 }
813
    /// Serves requests on this connection until it is closed, upgraded, or
    /// keep-alive ends.
    ///
    /// For every request: the head is read (optionally bounded by
    /// `options.header_read_timeout`), `request_fn` is invoked while the request
    /// body is read concurrently, and the resulting response is written back.
    ///
    /// * `request_fn` — produces the response for a request.
    /// * `error_fn` — produces the response sent when reading a request head
    ///   fails; it receives `true` for a header read timeout and `false` for
    ///   any other read/parse error. The connection is closed afterwards.
    /// * `zerocopy_fn` — optional callback passed on to `write_response`.
    ///
    /// # Errors
    ///
    /// Returns the request read error, a `TimedOut` error on header read
    /// timeout, the handler's error (converted to an I/O error), or any I/O
    /// error raised while reading the body or writing the response.
    #[inline]
    pub(crate) async fn handle_with_error_fn_and_zerocopy<
        F,
        Fut,
        ResB,
        ResBE,
        ResE,
        EF,
        EFut,
        EResB,
        EResBE,
        EResE,
        ZF,
        ZFut,
    >(
        mut self,
        request_fn: F,
        error_fn: EF,
        mut zerocopy_fn: Option<ZF>,
    ) -> Result<(), std::io::Error>
    where
        F: Fn(Request<Incoming>) -> Fut + 'static,
        Fut: std::future::Future<Output = Result<Response<ResB>, ResE>>,
        ResB: Body<Data = bytes::Bytes, Error = ResBE> + Unpin,
        ResE: std::error::Error,
        ResBE: std::error::Error,
        EF: FnOnce(bool) -> EFut,
        EFut: std::future::Future<Output = Result<Response<EResB>, EResE>>,
        EResB: Body<Data = bytes::Bytes, Error = EResBE> + Unpin,
        EResE: std::error::Error,
        EResBE: std::error::Error,
        ZF: FnMut(RawHandle, &'static Io, u64) -> ZFut,
        ZFut: std::future::Future<Output = Result<(), std::io::Error>>,
    {
        let mut keep_alive = true;

        while keep_alive {
            // Read the next request head, with a timeout if one is configured.
            // The outer `Result` is the timeout outcome, the inner one the
            // read outcome.
            let (mut request, body_tx) = match if let Some(timeout) =
                self.options.header_read_timeout
            {
                vibeio::time::timeout(timeout, self.read_request()).await
            } else {
                Ok(self.read_request().await)
            } {
                Ok(Ok(Some(d))) => d,
                Ok(Ok(None)) => {
                    // No further request on this connection.
                    return Ok(());
                }
                Ok(Err(e)) => {
                    // Read/parse failure: best-effort error response, then give up.
                    if let Ok(mut response) = error_fn(false).await {
                        response
                            .headers_mut()
                            .insert(header::CONNECTION, HeaderValue::from_static("close"));

                        let _ = self
                            .write_response(response, Version::HTTP_11, false, zerocopy_fn.as_mut())
                            .await;
                    }
                    return Err(e);
                }
                Err(_) => {
                    // Header read timeout: best-effort error response, then give up.
                    if let Ok(mut response) = error_fn(true).await {
                        response
                            .headers_mut()
                            .insert(header::CONNECTION, HeaderValue::from_static("close"));

                        let _ = self
                            .write_response(response, Version::HTTP_11, false, zerocopy_fn.as_mut())
                            .await;
                    }
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::TimedOut,
                        "header read timeout",
                    ));
                }
            };

            // Decide whether the connection stays open after this exchange:
            // never when `Connection: close` is present; otherwise when
            // `Connection: keep-alive` is present or the request is HTTP/1.1.
            let connection_header_split = request
                .headers()
                .get(header::CONNECTION)
                .and_then(|v| v.to_str().ok())
                .map(|v| v.split(",").map(|v| v.trim()));
            let is_connection_close = connection_header_split
                .clone()
                .is_some_and(|mut split| split.any(|v| v.eq_ignore_ascii_case("close")));
            let is_connection_keep_alive = connection_header_split
                .is_some_and(|mut split| split.any(|v| v.eq_ignore_ascii_case("keep-alive")));
            keep_alive = !is_connection_close
                && (is_connection_keep_alive || request.version() == http::Version::HTTP_11);

            let version = request.version();

            if self.options.send_continue_response {
                // Answer `Expect: 100-continue` before the body is read.
                let is_100_continue = request
                    .headers()
                    .get(header::EXPECT)
                    .and_then(|v| v.to_str().ok())
                    .is_some_and(|v| v.eq_ignore_ascii_case("100-continue"));
                if is_100_continue {
                    self.write_100_continue(version).await?;
                }
            }

            let early_hints_fut = if self.options.enable_early_hints {
                // Give the handler an `EarlyHints` extension; hints it sends
                // through the channel are written by the future below.
                let (early_hints_tx, early_hints_rx) = async_channel::unbounded();
                let early_hints = EarlyHints::new(early_hints_tx);
                request.extensions_mut().insert(early_hints);
                // NOTE(review): this transmute yields a second `&mut Self` whose
                // lifetime is not tied to `self`, so the early-hints future can
                // use the connection while the body-reading future below also
                // borrows `self`. Soundness appears to rely on these futures
                // never touching the same state at the same time — confirm.
                let mut_self = unsafe { std::mem::transmute::<&mut Self, &mut Self>(&mut self) };
                Some(async {
                    let early_hints_rx = early_hints_rx;
                    while let Ok((headers, sender)) = early_hints_rx.recv().await {
                        // Report the write result back to whoever sent the hints.
                        sender
                            .into_inner()
                            .send(mut_self.write_early_hints(version, headers).await)
                            .ok();
                    }
                    // Never completes, so the `select` below can only be
                    // resolved by the handler or the body reader.
                    futures_util::future::pending::<Result<(), std::io::Error>>().await
                })
            } else {
                None
            };

            // Body framing information taken from the request headers; a
            // missing or unparseable Content-Length is treated as 0.
            let content_length = request
                .headers()
                .get(header::CONTENT_LENGTH)
                .and_then(|v| v.to_str().ok())
                .and_then(|v| v.parse::<u64>().ok())
                .unwrap_or(0);
            let chunked = request
                .headers()
                .get(header::TRANSFER_ENCODING)
                .and_then(|v| v.to_str().ok())
                .is_some_and(|v| {
                    v.split(',')
                        .any(|v| v.trim().eq_ignore_ascii_case("chunked"))
                });
            // A non-empty `Trailer` header announces request trailers.
            let has_trailers = request
                .headers()
                .get(header::TRAILER)
                .map(|v| v.to_str().ok().is_some_and(|s| !s.is_empty()))
                .unwrap_or(false);
            // `TE: trailers` means the client accepts response trailers.
            let write_trailers = request
                .headers()
                .get(header::TE)
                .and_then(|v| v.to_str().ok())
                .map(|v| {
                    v.split(',')
                        .any(|v| v.trim().eq_ignore_ascii_case("trailers"))
                })
                .unwrap_or(false);

            // Expose an `Upgrade` extension so the handler can request a
            // protocol upgrade; `upgraded` is the shared flag it sets.
            let (upgrade_tx, upgrade_rx) = oneshot::async_channel();
            let upgrade = Upgrade::new(upgrade_rx);
            let upgraded = upgrade.upgraded.clone();
            request.extensions_mut().insert(upgrade);

            let mut response = {
                // Read the request body concurrently with the handler.
                let read_body_fut = async {
                    if chunked {
                        self.read_chunked_body_fn(&body_tx, has_trailers).await
                    } else {
                        self.read_body_fn(&body_tx, content_length).await
                    }
                };
                let read_body_fut_pin = std::pin::pin!(read_body_fut);
                let request_fut = request_fn(request);
                let request_fut_pin = std::pin::pin!(request_fut);
                let early_hints_fut: Pin<
                    Box<dyn std::future::Future<Output = Result<(), std::io::Error>>>,
                > = if let Some(early_hints) = early_hints_fut {
                    Box::pin(early_hints)
                } else {
                    Box::pin(futures_util::future::pending::<Result<(), std::io::Error>>())
                };

                let select_read_body_either =
                    futures_util::future::select(request_fut_pin, early_hints_fut);
                let select_either =
                    futures_util::future::select(read_body_fut_pin, select_read_body_either).await;

                let (response, body_fut) = match select_either {
                    // The body was read completely first: keep waiting for the handler.
                    futures_util::future::Either::Left((result, request_fut)) => {
                        result?;
                        (
                            match request_fut.await {
                                futures_util::future::Either::Left((response, _)) => response,
                                // The early-hints future never completes.
                                futures_util::future::Either::Right((_, _)) => unreachable!(),
                            },
                            None,
                        )
                    }
                    // The handler finished first: the body still has to be read.
                    futures_util::future::Either::Right((response, read_body_fut)) => (
                        match response {
                            futures_util::future::Either::Left((response, _)) => response,
                            // The early-hints future never completes.
                            futures_util::future::Either::Right((_, _)) => unreachable!(),
                        },
                        Some(read_body_fut),
                    ),
                };

                if let Some(body_fut) = body_fut {
                    // Finish reading the request body before responding.
                    body_fut.await?;
                }

                response.map_err(|e| std::io::Error::other(e.to_string()))?
            };

            // Set the response `Connection` header according to the outcome.
            let mut was_upgraded = false;
            if upgraded.load(std::sync::atomic::Ordering::Relaxed) {
                was_upgraded = true;
                response
                    .headers_mut()
                    .insert(header::CONNECTION, HeaderValue::from_static("upgrade"));
            } else if keep_alive {
                // Keep-alive has to be stated for HTTP/1.0; for HTTP/1.1 it is
                // only written when a `Connection` header is already present.
                if version == Version::HTTP_10
                    || response.headers().contains_key(header::CONNECTION)
                {
                    response
                        .headers_mut()
                        .insert(header::CONNECTION, HeaderValue::from_static("keep-alive"));
                }
            } else if version == Version::HTTP_11
                || response.headers().contains_key(header::CONNECTION)
            {
                response
                    .headers_mut()
                    .insert(header::CONNECTION, HeaderValue::from_static("close"));
            }

            self.write_response(response, version, write_trailers, zerocopy_fn.as_mut())
                .await?;

            if was_upgraded {
                // Hand the connection, plus any bytes already read past the
                // request, over to the upgrade receiver and stop serving HTTP.
                let frozen_buf = self.read_buf.freeze();
                let _ = upgrade_tx.send(Upgraded::new(
                    self.io,
                    if frozen_buf.is_empty() {
                        None
                    } else {
                        Some(frozen_buf)
                    },
                ));
                return Ok(());
            }

            if self.cancel_token.as_ref().is_some_and(|t| t.is_cancelled()) {
                // Graceful shutdown requested: do not read another request.
                break;
            }
        }
        Ok(())
    }
1080}
1081
1082impl<Io> HttpProtocol for Http1<Io>
1083where
1084 Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
1085{
1086 #[inline]
1087 fn handle_with_error_fn<F, Fut, ResB, ResBE, ResE, EF, EFut, EResB, EResBE, EResE>(
1088 self,
1089 request_fn: F,
1090 error_fn: EF,
1091 ) -> impl std::future::Future<Output = Result<(), std::io::Error>>
1092 where
1093 F: Fn(Request<Incoming>) -> Fut + 'static,
1094 Fut: std::future::Future<Output = Result<Response<ResB>, ResE>>,
1095 ResB: Body<Data = bytes::Bytes, Error = ResBE> + Unpin,
1096 ResE: std::error::Error,
1097 ResBE: std::error::Error,
1098 EF: FnOnce(bool) -> EFut,
1099 EFut: std::future::Future<Output = Result<Response<EResB>, EResE>>,
1100 EResB: Body<Data = bytes::Bytes, Error = EResBE> + Unpin,
1101 EResE: std::error::Error,
1102 EResBE: std::error::Error,
1103 {
1104 #[allow(clippy::type_complexity)]
1105 let no_zerocopy: Option<
1106 Box<
1107 dyn FnMut(
1108 RawHandle,
1109 &Io,
1110 u64,
1111 ) -> Box<
1112 dyn std::future::Future<Output = Result<(), std::io::Error>>
1113 + Unpin
1114 + Send
1115 + Sync,
1116 >,
1117 >,
1118 > = None;
1119 self.handle_with_error_fn_and_zerocopy(request_fn, error_fn, no_zerocopy)
1120 }
1121
1122 #[inline]
1123 fn handle<F, Fut, ResB, ResBE, ResE>(
1124 self,
1125 request_fn: F,
1126 ) -> impl std::future::Future<Output = Result<(), std::io::Error>>
1127 where
1128 F: Fn(Request<Incoming>) -> Fut + 'static,
1129 Fut: std::future::Future<Output = Result<Response<ResB>, ResE>>,
1130 ResB: Body<Data = bytes::Bytes, Error = ResBE> + Unpin,
1131 ResE: std::error::Error,
1132 ResBE: std::error::Error,
1133 {
1134 self.handle_with_error_fn(request_fn, |is_timeout| async move {
1135 let mut response = Response::builder();
1136 if is_timeout {
1137 response = response.status(http::StatusCode::REQUEST_TIMEOUT);
1138 } else {
1139 response = response.status(http::StatusCode::BAD_REQUEST);
1140 }
1141 response.body(Empty::new())
1142 })
1143 }
1144}
1145
/// Request body for HTTP/1.x connections.
///
/// Frames are produced by the connection's body reader and delivered through
/// a channel; this type is the receiving side.
pub(crate) struct Http1Body {
    /// Receiving half of the body channel, pinned so it can be polled as a stream.
    #[allow(clippy::type_complexity)]
    inner: Pin<Box<Receiver<Result<http_body::Frame<bytes::Bytes>, std::io::Error>>>>,
}
1150
1151impl Body for Http1Body {
1152 type Data = bytes::Bytes;
1153 type Error = std::io::Error;
1154
1155 #[inline]
1156 fn poll_frame(
1157 mut self: Pin<&mut Self>,
1158 cx: &mut Context<'_>,
1159 ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
1160 match self.inner.as_mut().poll_next(cx) {
1161 Poll::Ready(Some(Ok(frame))) => Poll::Ready(Some(Ok(frame))),
1162 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
1163 Poll::Ready(None) => Poll::Ready(None),
1164 Poll::Pending => Poll::Pending,
1165 }
1166 }
1167}
1168
/// Finds the blank line that separates a message head from its body.
///
/// Returns the index at which the separator starts together with its length:
/// `4` for `\r\n\r\n` and `2` for `\n\n`. The first match in `slice` wins.
/// Returns `None` if `slice` is shorter than two bytes or contains neither
/// sequence.
#[inline]
fn search_header_body_separator(slice: &[u8]) -> Option<(usize, usize)> {
    if slice.len() < 2 {
        return None;
    }
    slice
        .iter()
        .enumerate()
        .find_map(|(idx, &byte)| match byte {
            b'\r' if slice[idx + 1..].starts_with(b"\n\r\n") => Some((idx, 4)),
            b'\n' if slice.get(idx + 1) == Some(&b'\n') => Some((idx, 2)),
            _ => None,
        })
}
1188
1189#[inline]
1191fn write_chunk_size(dst: &mut [u8; 18], len: usize) -> &[u8] {
1192 let mut n = len;
1193 let mut pos = dst.len() - 2;
1194 loop {
1195 pos -= 1;
1196 dst[pos] = HEX_DIGITS[n & 0xF];
1197 n >>= 4;
1198 if n == 0 {
1199 break;
1200 }
1201 }
1202 dst[dst.len() - 2] = b'\r';
1203 dst[dst.len() - 1] = b'\n';
1204 &dst[pos..]
1205}