1mod options;
2mod tests;
3mod writebuf;
4mod zerocopy;
5
6pub use options::*;
7pub use zerocopy::*;
8
/// Platform-specific raw I/O handle: a raw file descriptor on Unix targets.
#[cfg(unix)]
pub(crate) type RawHandle = std::os::fd::RawFd;
/// Platform-specific raw I/O handle: a raw OS handle on Windows targets.
#[cfg(windows)]
pub(crate) type RawHandle = std::os::windows::io::RawHandle;
13
14use std::{
15 future::Future,
16 io::IoSlice,
17 mem::MaybeUninit,
18 pin::Pin,
19 str::FromStr,
20 sync::{
21 atomic::{AtomicBool, Ordering},
22 Arc,
23 },
24 task::{Context, Poll},
25 time::UNIX_EPOCH,
26};
27
28use bytes::{Buf, Bytes, BytesMut};
29use http::{header, HeaderMap, HeaderName, HeaderValue, Method, Request, Response, Uri, Version};
30use http_body::Body;
31use http_body_util::{BodyExt, Empty};
32use kanal::AsyncReceiver;
33use memchr::{memchr3_iter, memmem};
34use tokio::io::{AsyncReadExt, AsyncWriteExt};
35use tokio_util::sync::CancellationToken;
36
37use crate::{h1::writebuf::WriteBuf, EarlyHints, HttpProtocol, Incoming, Upgrade, Upgraded};
38
/// Upper-case hexadecimal digit table used when encoding chunk sizes.
const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";
/// Number of queued bytes at which the response writer starts writing the
/// write buffer out instead of queueing more body frames.
const WRITE_BUF_BATCH_THRESHOLD: usize = 16384;
41
/// State of a single HTTP/1.x server connection running over `Io`.
pub struct Http1<Io> {
    /// The underlying transport that requests are read from and responses are written to.
    io: Io,
    /// Connection options (header limits, timeouts, feature switches).
    options: options::Http1Options,
    /// Optional graceful-shutdown token, set via `graceful_shutdown_token`.
    cancel_token: Option<CancellationToken>,
    /// Uninitialised header slots handed to `httparse` when parsing a request head;
    /// sized to `options.max_header_count` in `new`.
    parsed_headers: Box<[MaybeUninit<httparse::Header<'static>>]>,
    /// Last formatted `Date` header value together with the time it was formatted at.
    date_header_value_cached: Option<(String, std::time::SystemTime)>,
    /// Header map reclaimed from the previous response so its allocation can be
    /// reused for the next request.
    cached_headers: Option<HeaderMap>,
    /// Bytes read from `io` that have not been consumed yet.
    read_buf: BytesMut,
    /// Scratch buffer in which the response status line and headers are serialised.
    response_head_buf: Vec<u8>,
    /// Queue of pending output slices for the response being written.
    write_buf: WriteBuf,
}
84
#[cfg(all(target_os = "linux", feature = "h1-zerocopy"))]
impl<Io> Http1<Io>
where
    for<'a> Io: tokio::io::AsyncRead
        + tokio::io::AsyncWrite
        + vibeio::io::AsInnerRawHandle<'a>
        + Unpin
        + 'static,
{
    /// Wraps this connection in an [`Http1Zerocopy`] adapter.
    ///
    /// Only available on Linux with the `h1-zerocopy` feature, and only for
    /// transports that expose their inner raw handle.
    #[inline]
    pub fn zerocopy(self) -> Http1Zerocopy<Io> {
        let inner = self;
        Http1Zerocopy { inner }
    }
}
109
110impl<Io> Http1<Io>
111where
112 Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
113{
114 #[inline]
125 pub fn new(io: Io, options: options::Http1Options) -> Self {
126 let read_buf = BytesMut::with_capacity(options.max_header_size);
128 let parsed_headers: Box<[MaybeUninit<httparse::Header<'static>>]> =
129 Box::new_uninit_slice(options.max_header_count);
130 Self {
131 io,
132 options,
133 cancel_token: None,
134 parsed_headers,
135 date_header_value_cached: None,
136 cached_headers: None,
137 read_buf,
138 response_head_buf: Vec::with_capacity(1024),
139 write_buf: WriteBuf::new(),
140 }
141 }
142
143 #[inline]
144 fn get_date_header_value(&mut self) -> &str {
145 let now = std::time::SystemTime::now();
146 if self.date_header_value_cached.as_ref().is_none_or(|v| {
147 v.1.duration_since(UNIX_EPOCH).ok().map(|d| d.as_secs())
148 != now.duration_since(UNIX_EPOCH).ok().map(|d| d.as_secs())
149 }) {
150 let value = httpdate::fmt_http_date(now).to_string();
151 self.date_header_value_cached = Some((value, now));
152 }
153 self.date_header_value_cached
154 .as_ref()
155 .map(|v| v.0.as_str())
156 .unwrap_or("")
157 }
158
159 #[inline]
169 pub fn graceful_shutdown_token(mut self, token: CancellationToken) -> Self {
170 self.cancel_token = Some(token);
171 self
172 }
173
174 #[inline]
175 async fn fill_buf(&mut self) -> Result<usize, std::io::Error> {
176 if self.read_buf.remaining() < 1024 {
177 self.read_buf.reserve(1024);
178 }
179 let spare_capacity = self.read_buf.spare_capacity_mut();
180 let n = self
182 .io
183 .read(unsafe {
184 &mut *std::ptr::slice_from_raw_parts_mut(
185 spare_capacity.as_mut_ptr() as *mut u8,
186 spare_capacity.len(),
187 )
188 })
189 .await?;
190 if n == 0 {
191 return Ok(0);
192 }
193 unsafe { self.read_buf.set_len(self.read_buf.len() + n) };
194 Ok(n)
195 }
196
197 #[inline]
198 async fn read_body_fn(
199 &mut self,
200 body_tx: kanal::AsyncSender<Result<http_body::Frame<bytes::Bytes>, std::io::Error>>,
201 content_length: u64,
202 send_continue_body: &Option<Arc<AtomicBool>>,
203 continue_sent: &mut bool,
204 version: Version,
205 ) -> Result<(), std::io::Error> {
206 let mut remaining = content_length;
207 let mut just_started = true;
208 while remaining > 0 {
209 if !*continue_sent
210 && send_continue_body
211 .as_ref()
212 .is_some_and(|b| b.load(Ordering::Relaxed))
213 {
214 *continue_sent = true;
215 self.write_100_continue(version).await?;
216 }
217
218 let have_to_read_buf = !just_started || self.read_buf.is_empty();
219 just_started = false;
220 if have_to_read_buf {
221 let n = self.fill_buf().await?;
222 if n == 0 {
223 break;
224 }
225 }
226 let chunk = self
227 .read_buf
228 .split_to(
229 self.read_buf
230 .len()
231 .min(remaining.min(usize::MAX as u64) as usize),
232 )
233 .freeze();
234 remaining -= chunk.len() as u64;
235
236 let _ = body_tx.send(Ok(http_body::Frame::data(chunk))).await;
237 }
238 Ok(())
239 }
240
241 #[inline]
242 async fn read_body_chunk(
243 &mut self,
244 would_have_trailers: bool,
245 send_continue_body: &Option<Arc<AtomicBool>>,
246 continue_sent: &mut bool,
247 version: Version,
248 ) -> Result<bytes::Bytes, std::io::Error> {
249 let len = {
250 let mut len_buf_pos: usize = 0;
252 let mut just_started = true;
253 loop {
254 if len_buf_pos >= 48 {
255 return Err(std::io::Error::new(
256 std::io::ErrorKind::InvalidData,
257 "chunk length buffer overflow",
258 ));
259 }
260
261 let begin_search = len_buf_pos.saturating_sub(1);
262
263 let have_to_read_buf = !just_started || self.read_buf.is_empty();
264 just_started = false;
265 if have_to_read_buf {
266 if !*continue_sent
267 && send_continue_body
268 .as_ref()
269 .is_some_and(|b| b.load(Ordering::Relaxed))
270 {
271 *continue_sent = true;
272 self.write_100_continue(version).await?;
273 }
274 let n = self.fill_buf().await?;
275 if n == 0 {
276 return Err(std::io::Error::new(
277 std::io::ErrorKind::UnexpectedEof,
278 "unexpected EOF",
279 ));
280 }
281 len_buf_pos += n;
282 } else {
283 len_buf_pos += self.read_buf.len();
284 }
285
286 if let Some(pos) =
287 memmem::find(&self.read_buf[begin_search..len_buf_pos.min(48)], b"\r\n")
288 {
289 let numbers = std::str::from_utf8(&self.read_buf[..begin_search + pos])
290 .map_err(|_| {
291 std::io::Error::new(
292 std::io::ErrorKind::InvalidData,
293 "invalid chunk length",
294 )
295 })?;
296 let len = usize::from_str_radix(numbers, 16).map_err(|_| {
297 std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid chunk length")
298 })?;
299 self.read_buf.advance(begin_search + pos + 2);
301 break len;
302 }
303 }
304 };
305 let mut read = 0;
307 if len == 0 && would_have_trailers {
308 return Ok(bytes::Bytes::new()); }
310 let mut just_started = true;
311 let Some(len_plus_two) = len.checked_add(2) else {
313 return Err(std::io::Error::new(
314 std::io::ErrorKind::InvalidData,
315 "chunk length too large",
316 ));
317 };
318 while read < len_plus_two {
319 let have_to_read_buf = !just_started || self.read_buf.is_empty();
320 just_started = false;
321 if have_to_read_buf {
322 if !*continue_sent
323 && send_continue_body
324 .as_ref()
325 .is_some_and(|b| b.load(Ordering::Relaxed))
326 {
327 *continue_sent = true;
328 self.write_100_continue(version).await?;
329 }
330 let n = self.fill_buf().await?;
331 if n == 0 {
332 return Err(std::io::Error::new(
333 std::io::ErrorKind::UnexpectedEof,
334 "unexpected EOF",
335 ));
336 }
337 read += n;
338 } else {
339 read += self.read_buf.len();
340 }
341 }
342 let chunk = self.read_buf.split_to(len).freeze();
343 self.read_buf.advance(2); Ok(chunk)
345 }
346
347 #[inline]
348 async fn read_trailers(&mut self) -> Result<Option<HeaderMap>, std::io::Error> {
349 let mut bytes_read: usize = 0;
351 let mut just_started = true;
352 while bytes_read < self.options.max_header_size {
353 let old_bytes_read = bytes_read;
354 let begin_search = old_bytes_read.saturating_sub(3);
355
356 let have_to_read_buf = !just_started || self.read_buf.is_empty();
357 just_started = false;
358 if have_to_read_buf {
359 let n = self.fill_buf().await?;
360 if n == 0 {
361 return Err(std::io::Error::new(
362 std::io::ErrorKind::UnexpectedEof,
363 "unexpected EOF",
364 ));
365 }
366 bytes_read = (old_bytes_read + n).min(self.options.max_header_size);
367 } else {
368 bytes_read =
369 (old_bytes_read + self.read_buf.len()).min(self.options.max_header_size)
370 }
371
372 if bytes_read >= 2 && self.read_buf[0] == b'\r' && self.read_buf[1] == b'\n' {
373 return Ok(None);
375 }
376
377 if let Some(separator_index) =
378 memmem::find(&self.read_buf[begin_search..bytes_read], b"\r\n\r\n")
379 {
380 let to_parse_length = begin_search + separator_index + 4;
381 let buf_ro = self.read_buf.split_to(to_parse_length).freeze();
382
383 let mut httparse_trailers =
385 vec![httparse::EMPTY_HEADER; self.options.max_header_count].into_boxed_slice();
386 let status = httparse::parse_headers(&buf_ro, &mut httparse_trailers)
387 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?;
388 if let httparse::Status::Complete((_, trailers)) = status {
389 let mut trailers_constructed = HeaderMap::new();
390 for header in trailers {
391 if header == &httparse::EMPTY_HEADER {
392 break;
394 }
395 let name = HeaderName::from_bytes(header.name.as_bytes())
396 .map_err(|e| std::io::Error::other(e.to_string()))?;
397 let value_start = header.value.as_ptr() as usize - buf_ro.as_ptr() as usize;
398 let value_len = header.value.len();
399 let value = unsafe {
401 HeaderValue::from_maybe_shared_unchecked(
402 buf_ro.slice(value_start..(value_start + value_len)),
403 )
404 };
405 trailers_constructed.append(name, value);
406 }
407
408 return Ok(Some(trailers_constructed));
409 } else {
410 return Err(std::io::Error::new(
411 std::io::ErrorKind::InvalidInput,
412 "trailer headers incomplete",
413 ));
414 }
415 }
416 }
417 Err(std::io::Error::new(
418 std::io::ErrorKind::InvalidData,
419 "request too large",
420 ))
421 }
422
423 #[inline]
424 async fn read_chunked_body_fn(
425 &mut self,
426 body_tx: kanal::AsyncSender<Result<http_body::Frame<bytes::Bytes>, std::io::Error>>,
427 would_have_trailers: bool,
428 send_continue_body: &Option<Arc<AtomicBool>>,
429 continue_sent: &mut bool,
430 version: Version,
431 ) -> Result<(), std::io::Error> {
432 loop {
433 let chunk = self
434 .read_body_chunk(
435 would_have_trailers,
436 send_continue_body,
437 continue_sent,
438 version,
439 )
440 .await?;
441 if chunk.is_empty() {
442 break;
443 }
444
445 let _ = body_tx.send(Ok(http_body::Frame::data(chunk))).await;
446 }
447 if would_have_trailers {
448 let trailers = self.read_trailers().await?;
450 if let Some(trailers) = trailers {
451 let _ = body_tx.send(Ok(http_body::Frame::trailers(trailers))).await;
452 }
453 }
454 Ok(())
455 }
456
    /// Reads and parses the next request head from the connection.
    ///
    /// Returns `Ok(None)` when `get_head` reports that no request is available.
    /// Otherwise returns the request (with a channel-backed body), the sender
    /// half of that body channel, and — when the request expects
    /// `100-continue` and the option is enabled — a shared flag that the body
    /// sets once it has been polled.
    ///
    /// # Errors
    /// Returns `InvalidData` for a malformed or partial head, and an `Other`
    /// error for an invalid method, URI or header name.
    #[inline]
    async fn read_request(
        &mut self,
    ) -> Result<
        Option<(
            Request<Incoming>,
            kanal::AsyncSender<Result<http_body::Frame<bytes::Bytes>, std::io::Error>>,
            Option<Arc<AtomicBool>>,
        )>,
        std::io::Error,
    > {
        let (request, body_tx, send_continue_body) = {
            let Some((head, headers)) = self.get_head().await? else {
                return Ok(None);
            };
            // Re-type the header slots so they can refer to `head` instead of
            // `'static` data.
            // NOTE(review): this also detaches the slice from the `&mut self`
            // borrow, which is what allows `self` to be used below — confirm the
            // slots are never read after `head` is dropped.
            let headers = unsafe {
                std::mem::transmute::<
                    &mut [MaybeUninit<httparse::Header<'static>>],
                    &mut [MaybeUninit<httparse::Header<'_>>],
                >(headers)
            };
            let mut req = httparse::Request::new(&mut []);
            let status = req
                .parse_with_uninit_headers(&head, headers)
                .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
            if status.is_partial() {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    "partial request head",
                ));
            }

            // Bounded channel carrying body frames from the connection to the handler.
            let (body_tx, body_rx) = kanal::bounded_async(2);

            // Only track `Expect: 100-continue` when the option is enabled.
            let is_100_continue = self.options.send_continue_response
                && req.headers.iter().any(|h| {
                    h.name.eq_ignore_ascii_case("expect")
                        && h.value.eq_ignore_ascii_case(b"100-continue")
                });
            let send_continue_body = is_100_continue.then(|| Arc::new(AtomicBool::new(false)));

            let request_body = Http1Body {
                inner: Box::pin(body_rx),
                send_continue_body: send_continue_body.clone(),
            };
            let mut request = Request::new(Incoming::H1(request_body));
            // Anything other than minor version 0 is treated as HTTP/1.1.
            match req.version {
                Some(0) => *request.version_mut() = http::Version::HTTP_10,
                Some(1) => *request.version_mut() = http::Version::HTTP_11,
                _ => *request.version_mut() = http::Version::HTTP_11,
            };
            if let Some(method) = req.method {
                *request.method_mut() = Method::from_bytes(method.as_bytes())
                    .map_err(|e| std::io::Error::other(e.to_string()))?;
            }
            if let Some(path) = req.path {
                *request.uri_mut() =
                    Uri::from_str(path).map_err(|e| std::io::Error::other(e.to_string()))?;
            }
            // Reuse the header map reclaimed from the previous response, if any.
            let mut header_map = self.cached_headers.take().unwrap_or_default();
            header_map.clear();
            let additional_capacity = req.headers.len().saturating_sub(header_map.capacity());
            if additional_capacity > 0 {
                header_map.reserve(additional_capacity);
            }
            for header in req.headers {
                if header == &httparse::EMPTY_HEADER {
                    break;
                }
                let name = HeaderName::from_bytes(header.name.as_bytes())
                    .map_err(|e| std::io::Error::other(e.to_string()))?;
                // The value is a sub-slice of `head`; compute its offset so the
                // bytes can be shared instead of copied.
                let value_start = header.value.as_ptr() as usize - head.as_ptr() as usize;
                let value_len = header.value.len();
                // NOTE(review): skips `HeaderValue` validation, presumably because
                // httparse already validated the bytes — confirm the accepted sets match.
                let value = unsafe {
                    HeaderValue::from_maybe_shared_unchecked(
                        head.slice(value_start..(value_start + value_len)),
                    )
                };
                header_map.append(name, value);
            }
            *request.headers_mut() = header_map;

            (request, body_tx, send_continue_body)
        };
        Ok(Some((request, body_tx, send_continue_body)))
    }
549
550 #[inline]
551 async fn get_head(
552 &mut self,
553 ) -> Result<Option<(Bytes, &mut [MaybeUninit<httparse::Header<'static>>])>, std::io::Error>
554 {
555 let mut request_line_read = false;
556 let mut bytes_read: usize = 0;
557 let mut whitespace_trimmed = None;
558 let mut just_started = true;
559 while bytes_read < self.options.max_header_size {
560 let old_bytes_read = bytes_read;
561 let begin_search = old_bytes_read.saturating_sub(3);
562
563 let have_to_read_buf = !just_started || self.read_buf.is_empty();
564 just_started = false;
565 if have_to_read_buf {
566 let n = self.fill_buf().await?;
567 if n == 0 {
568 if whitespace_trimmed.is_none() {
569 return Ok(None);
570 }
571 return Err(std::io::Error::new(
572 std::io::ErrorKind::UnexpectedEof,
573 "unexpected EOF",
574 ));
575 }
576 bytes_read = (old_bytes_read + n).min(self.options.max_header_size);
577 } else {
578 bytes_read =
579 (old_bytes_read + self.read_buf.len()).min(self.options.max_header_size)
580 }
581
582 if whitespace_trimmed.is_none() {
583 whitespace_trimmed = self.read_buf[old_bytes_read..bytes_read]
584 .iter()
585 .position(|b| !b.is_ascii_whitespace());
586 }
587
588 if let Some(whitespace_trimmed) = whitespace_trimmed {
589 if !request_line_read {
591 let memchr = memchr3_iter(
592 b' ',
593 b'\r',
594 b'\n',
595 &self.read_buf[whitespace_trimmed..bytes_read],
596 );
597 let mut spaces = 0;
598 for separator_index in memchr {
599 if self.read_buf[whitespace_trimmed + separator_index] == b' ' {
600 if spaces >= 2 {
601 return Err(std::io::Error::new(
602 std::io::ErrorKind::InvalidInput,
603 "bad request first line",
604 ));
605 }
606 spaces += 1;
607 } else if spaces == 2 {
608 request_line_read = true;
609 break;
610 } else {
611 return Err(std::io::Error::new(
612 std::io::ErrorKind::InvalidInput,
613 "bad request first line",
614 ));
615 }
616 }
617 }
618
619 if request_line_read {
620 let begin_search = begin_search.max(whitespace_trimmed);
621 if let Some((separator_index, separator_len)) =
622 search_header_body_separator(&self.read_buf[begin_search..bytes_read])
623 {
624 let to_parse_length =
625 begin_search + separator_index + separator_len - whitespace_trimmed;
626 self.read_buf.advance(whitespace_trimmed);
627 let head = self.read_buf.split_to(to_parse_length);
628 return Ok(Some((head.freeze(), &mut self.parsed_headers)));
629 }
630 }
631 }
632 }
633 Err(std::io::Error::new(
634 std::io::ErrorKind::InvalidData,
635 "request too large",
636 ))
637 }
638
639 #[inline]
640 async fn write_response<Z, ZFut>(
641 &mut self,
642 mut response: Response<
643 impl Body<Data = bytes::Bytes, Error = impl std::error::Error> + Unpin,
644 >,
645 version: Version,
646 write_trailers: bool,
647 zerocopy_fn: Option<Z>,
648 ) -> Result<(), std::io::Error>
649 where
650 Z: FnMut(RawHandle, &'static Io, u64) -> ZFut,
651 ZFut: std::future::Future<Output = Result<(), std::io::Error>>,
652 {
653 if self.options.send_date_header {
655 response.headers_mut().insert(
656 header::DATE,
657 HeaderValue::from_str(self.get_date_header_value())
658 .map_err(|e| std::io::Error::other(e.to_string()))?,
659 );
660 }
661
662 if let Some(suggested_content_length) = response.body().size_hint().exact() {
664 let headers = response.headers_mut();
665 if !headers.contains_key(header::CONTENT_LENGTH) {
666 headers.insert(header::CONTENT_LENGTH, suggested_content_length.into());
667 }
668 }
669
670 let chunked = response
671 .headers()
672 .get(header::TRANSFER_ENCODING)
673 .map(|v| {
674 v.to_str().ok().is_some_and(|s| {
675 s.split(',')
676 .any(|s| s.trim().eq_ignore_ascii_case("chunked"))
677 })
678 })
679 .unwrap_or_else(|| {
680 response
681 .headers()
682 .get(header::CONTENT_LENGTH)
683 .and_then(|v| v.to_str().ok())
684 .is_none_or(|s| s.parse::<u64>().is_err())
685 });
686
687 if chunked {
688 response.headers_mut().insert(
689 header::TRANSFER_ENCODING,
690 HeaderValue::from_static("chunked"),
691 );
692 while response
693 .headers_mut()
694 .remove(header::CONTENT_LENGTH)
695 .is_some()
696 {}
697 }
698
699 let (parts, mut body) = response.into_parts();
700
701 self.response_head_buf.clear();
702 let estimated_head_len = 30 + parts.headers.len() * 30; if self.response_head_buf.capacity() < estimated_head_len {
704 self.response_head_buf
705 .reserve(estimated_head_len - self.response_head_buf.capacity());
706 }
707 let head = &mut self.response_head_buf;
708 if version == Version::HTTP_10 {
709 head.extend_from_slice(b"HTTP/1.0 ");
710 } else {
711 head.extend_from_slice(b"HTTP/1.1 ");
712 }
713 let status = parts.status;
714 head.extend_from_slice(status.as_str().as_bytes());
715 if let Some(canonical_reason) = status.canonical_reason() {
716 head.extend_from_slice(b" ");
717 head.extend_from_slice(canonical_reason.as_bytes());
718 }
719 head.extend_from_slice(b"\r\n");
720 for (name, value) in &parts.headers {
721 head.extend_from_slice(name.as_str().as_bytes());
722 head.extend_from_slice(b": ");
723 head.extend_from_slice(value.as_bytes());
724 head.extend_from_slice(b"\r\n");
725 }
726 head.extend_from_slice(b"\r\n");
727 unsafe {
728 self.write_buf.push(IoSlice::new(head));
729 }
730
731 if !chunked {
732 if let Some(content_length) = parts
733 .headers
734 .get(header::CONTENT_LENGTH)
735 .and_then(|v| v.to_str().ok())
736 .and_then(|s| s.parse::<u64>().ok())
737 {
738 if let Some(zero_copy) = parts.extensions.get::<ZerocopyResponse>() {
739 if let Some(mut zerocopy_fn) = zerocopy_fn {
740 unsafe {
742 self.write_buf
743 .flush(&mut self.io, self.options.enable_vectored_write)
744 .await?
745 };
746 zerocopy_fn(
747 zero_copy.handle,
748 unsafe { std::mem::transmute::<&Io, &'static Io>(&self.io) },
750 content_length,
751 )
752 .await?;
753 self.io.flush().await?;
754 let reclaimed_headers = parts.headers;
755 self.cached_headers = Some(reclaimed_headers);
756 return Ok(());
757 }
758 }
759 }
760 }
761
762 let mut trailers_written = false;
763 while let Some(chunk) = body.frame().await {
764 let chunk = chunk.map_err(|e| std::io::Error::other(e.to_string()))?;
765 match chunk.into_data() {
766 Ok(data) => {
767 if chunked {
768 let mut chunk_size_buf = [0u8; 18];
769 let chunk_size = write_chunk_size(&mut chunk_size_buf, data.len());
770 self.write_buf.push_copy(chunk_size);
771 self.write_buf.push_bytes(data);
772 unsafe {
773 self.write_buf.push(IoSlice::new(b"\r\n"));
774 }
775 } else {
776 self.write_buf.push_bytes(data);
777 }
778 while self.write_buf.len() >= WRITE_BUF_BATCH_THRESHOLD {
779 let bytes_written = unsafe {
780 self.write_buf
781 .write(&mut self.io, self.options.enable_vectored_write)
782 .await?
783 };
784 if bytes_written == 0 {
785 return Err(std::io::ErrorKind::WriteZero.into());
786 }
787 }
788 }
789 Err(chunk) => {
790 if let Ok(trailers) = chunk.into_trailers() {
791 if write_trailers {
792 unsafe {
793 self.write_buf.push(IoSlice::new(b"0\r\n"));
794 for (name, value) in &trailers {
795 self.write_buf.push_copy(name.as_str().as_bytes());
796 self.write_buf.push(IoSlice::new(b": "));
797 self.write_buf.push_copy(value.as_bytes());
798 self.write_buf.push(IoSlice::new(b"\r\n"));
799 }
800 self.write_buf.push(IoSlice::new(b"\r\n"));
801 }
802 trailers_written = true;
803 }
804 break;
805 }
806 }
807 };
808 }
809 if chunked && !trailers_written {
810 unsafe {
812 self.write_buf.push(IoSlice::new(b"0\r\n\r\n"));
813 }
814 }
815 unsafe {
816 self.write_buf
817 .flush(&mut self.io, self.options.enable_vectored_write)
818 .await?;
819 }
820 self.io.flush().await?;
821 let reclaimed_headers = parts.headers;
822 self.cached_headers = Some(reclaimed_headers);
823
824 Ok(())
825 }
826
827 #[inline]
828 async fn write_100_continue(&mut self, version: Version) -> Result<(), std::io::Error> {
829 if version == Version::HTTP_10 {
830 self.io.write_all(b"HTTP/1.0 100 Continue\r\n\r\n").await?;
831 } else {
832 self.io.write_all(b"HTTP/1.1 100 Continue\r\n\r\n").await?;
833 }
834 self.io.flush().await?;
835
836 Ok(())
837 }
838
839 #[inline]
840 async fn write_early_hints(
841 &mut self,
842 version: Version,
843 headers: http::HeaderMap,
844 ) -> Result<(), std::io::Error> {
845 let mut head = Vec::new();
846 if version == Version::HTTP_10 {
847 head.extend_from_slice(b"HTTP/1.0 103 Early Hints\r\n");
848 } else {
849 head.extend_from_slice(b"HTTP/1.1 103 Early Hints\r\n");
850 }
851 let mut current_header_name = None;
852 for (name, value) in headers {
853 if let Some(name) = name {
854 current_header_name = Some(name);
855 };
856 if let Some(current_header_name) = ¤t_header_name {
857 head.extend_from_slice(current_header_name.as_str().as_bytes());
858 if value.is_empty() {
859 head.extend_from_slice(b":\r\n");
860 continue;
861 }
862 head.extend_from_slice(b": ");
863 head.extend_from_slice(value.as_bytes());
864 head.extend_from_slice(b"\r\n");
865 }
866 }
867 head.extend_from_slice(b"\r\n");
868
869 self.io.write_all(&head).await?;
870
871 Ok(())
872 }
873
874 #[inline]
875 pub(crate) async fn handle_with_error_fn_and_zerocopy<
876 F,
877 Fut,
878 ResB,
879 ResBE,
880 ResE,
881 EF,
882 EFut,
883 EResB,
884 EResBE,
885 EResE,
886 ZF,
887 ZFut,
888 >(
889 mut self,
890 request_fn: F,
891 error_fn: EF,
892 mut zerocopy_fn: Option<ZF>,
893 ) -> Result<(), std::io::Error>
894 where
895 F: Fn(Request<Incoming>) -> Fut + 'static,
896 Fut: std::future::Future<Output = Result<Response<ResB>, ResE>> + 'static,
897 ResB: Body<Data = bytes::Bytes, Error = ResBE> + Unpin + 'static,
898 ResE: std::error::Error,
899 ResBE: std::error::Error,
900 EF: FnOnce(bool) -> EFut,
901 EFut: std::future::Future<Output = Result<Response<EResB>, EResE>>,
902 EResB: Body<Data = bytes::Bytes, Error = EResBE> + Unpin + 'static,
903 EResE: std::error::Error,
904 EResBE: std::error::Error,
905 ZF: FnMut(RawHandle, &'static Io, u64) -> ZFut,
906 ZFut: std::future::Future<Output = Result<(), std::io::Error>>,
907 {
908 let mut keep_alive = true;
909
910 while keep_alive {
911 let (mut request, body_tx, send_continue_body) = match if let Some(timeout) =
912 self.options.header_read_timeout
913 {
914 vibeio::time::timeout(timeout, async {
915 if let Some(token) = self.cancel_token.clone() {
916 token.run_until_cancelled(self.read_request()).await
917 } else {
918 Some(self.read_request().await)
919 }
920 })
921 .await
922 } else {
923 Ok(Some(self.read_request().await))
924 } {
925 Ok(Some(Ok(Some(d)))) => d,
926 Ok(Some(Ok(None))) => {
927 return Ok(());
928 }
929 Ok(Some(Err(e))) => {
930 if let Ok(mut response) = error_fn(false).await {
932 response
933 .headers_mut()
934 .insert(header::CONNECTION, HeaderValue::from_static("close"));
935
936 let _ = self
937 .write_response(response, Version::HTTP_11, false, zerocopy_fn.as_mut())
938 .await;
939 }
940 return Err(e);
941 }
942 Ok(None) => {
943 return Ok(());
945 }
946 Err(_) => {
947 if let Ok(mut response) = error_fn(true).await {
949 response
950 .headers_mut()
951 .insert(header::CONNECTION, HeaderValue::from_static("close"));
952
953 let _ = self
954 .write_response(response, Version::HTTP_11, false, zerocopy_fn.as_mut())
955 .await;
956 }
957 return Err(std::io::Error::new(
958 std::io::ErrorKind::TimedOut,
959 "header read timeout",
960 ));
961 }
962 };
963
964 let connection_header_split = request
966 .headers()
967 .get(header::CONNECTION)
968 .and_then(|v| v.to_str().ok())
969 .map(|v| v.split(",").map(|v| v.trim()));
970 let is_connection_close = connection_header_split
971 .clone()
972 .is_some_and(|mut split| split.any(|v| v.eq_ignore_ascii_case("close")));
973 let is_connection_keep_alive = connection_header_split
974 .is_some_and(|mut split| split.any(|v| v.eq_ignore_ascii_case("keep-alive")));
975 keep_alive = !is_connection_close
976 && (is_connection_keep_alive || request.version() == http::Version::HTTP_11);
977
978 let version = request.version();
979 let is_100_continue = send_continue_body.is_some();
980
981 let early_hints_fut = if self.options.enable_early_hints {
983 let (early_hints, mut early_hints_rx) = EarlyHints::new_lazy();
984 request.extensions_mut().insert(early_hints);
985 let mut_self = unsafe { std::mem::transmute::<&mut Self, &mut Self>(&mut self) };
989 futures_util::future::Either::Left(async move {
990 while let Some((headers, sender)) =
991 std::future::poll_fn(|cx| early_hints_rx.poll_recv(cx)).await
992 {
993 sender
994 .into_inner()
995 .send(mut_self.write_early_hints(version, headers).await)
996 .ok();
997 }
998 futures_util::future::pending::<Result<(), std::io::Error>>().await
999 })
1000 } else {
1001 futures_util::future::Either::Right(futures_util::future::pending::<
1002 Result<(), std::io::Error>,
1003 >())
1004 };
1005
1006 let content_length = request
1008 .headers()
1009 .get(header::CONTENT_LENGTH)
1010 .and_then(|v| v.to_str().ok())
1011 .and_then(|v| v.parse::<u64>().ok())
1012 .unwrap_or(0);
1013 let chunked = request
1014 .headers()
1015 .get(header::TRANSFER_ENCODING)
1016 .and_then(|v| v.to_str().ok())
1017 .is_some_and(|v| {
1018 v.split(',')
1019 .any(|v| v.trim().eq_ignore_ascii_case("chunked"))
1020 });
1021 let has_trailers = request
1022 .headers()
1023 .get(header::TRAILER)
1024 .map(|v| v.to_str().ok().is_some_and(|s| !s.is_empty()))
1025 .unwrap_or(false);
1026 let write_trailers = request
1027 .headers()
1028 .get(header::TE)
1029 .and_then(|v| v.to_str().ok())
1030 .map(|v| {
1031 v.split(',')
1032 .any(|v| v.trim().eq_ignore_ascii_case("trailers"))
1033 })
1034 .unwrap_or(false);
1035
1036 let (upgrade_tx, upgrade_rx) = oneshot::async_channel();
1038 let upgrade = Upgrade::new(upgrade_rx);
1039 let upgraded = upgrade.upgraded.clone();
1040 request.extensions_mut().insert(upgrade);
1041
1042 let mut continue_sent = false;
1044 let mut response = {
1045 let read_body_fut = async {
1046 if chunked {
1047 self.read_chunked_body_fn(
1048 body_tx,
1049 has_trailers,
1050 &send_continue_body,
1051 &mut continue_sent,
1052 version,
1053 )
1054 .await
1055 } else {
1056 self.read_body_fn(
1057 body_tx,
1058 content_length,
1059 &send_continue_body,
1060 &mut continue_sent,
1061 version,
1062 )
1063 .await
1064 }
1065 };
1066 let read_body_fut_pin = std::pin::pin!(read_body_fut);
1067 let request_fut = request_fn(request);
1068 let request_fut_pin = std::pin::pin!(request_fut);
1069 let early_hints_fut_pin = std::pin::pin!(early_hints_fut);
1070
1071 let select_read_body_either =
1072 futures_util::future::select(request_fut_pin, early_hints_fut_pin);
1073 let select_either =
1074 futures_util::future::select(read_body_fut_pin, select_read_body_either).await;
1075
1076 let (response, body_fut) = match select_either {
1077 futures_util::future::Either::Left((result, request_fut)) => {
1078 result?;
1079 (
1080 match request_fut.await {
1081 futures_util::future::Either::Left((response, _)) => response,
1082 futures_util::future::Either::Right((_, _)) => unreachable!(),
1083 },
1084 None,
1085 )
1086 }
1087 futures_util::future::Either::Right((response, read_body_fut)) => (
1088 match response {
1089 futures_util::future::Either::Left((response, _)) => response,
1090 futures_util::future::Either::Right((_, _)) => unreachable!(),
1091 },
1092 Some(read_body_fut),
1093 ),
1094 };
1095
1096 if let Some(body_fut) = body_fut {
1098 body_fut.await?;
1099 }
1100
1101 response.map_err(|e| std::io::Error::other(e.to_string()))?
1102 };
1103
1104 if !continue_sent
1106 && is_100_continue
1107 && !response.status().is_client_error()
1108 && !response.status().is_server_error()
1109 {
1110 self.write_100_continue(version).await?;
1111 }
1112
1113 let mut was_upgraded = false;
1114 if upgraded.load(std::sync::atomic::Ordering::Relaxed) {
1115 was_upgraded = true;
1116 response
1117 .headers_mut()
1118 .insert(header::CONNECTION, HeaderValue::from_static("upgrade"));
1119 } else if keep_alive {
1120 if version == Version::HTTP_10
1121 || response.headers().contains_key(header::CONNECTION)
1122 {
1123 response
1124 .headers_mut()
1125 .insert(header::CONNECTION, HeaderValue::from_static("keep-alive"));
1126 }
1127 } else if version == Version::HTTP_11
1128 || response.headers().contains_key(header::CONNECTION)
1129 {
1130 response
1131 .headers_mut()
1132 .insert(header::CONNECTION, HeaderValue::from_static("close"));
1133 }
1134
1135 self.write_response(response, version, write_trailers, zerocopy_fn.as_mut())
1137 .await?;
1138
1139 if was_upgraded {
1140 let frozen_buf = self.read_buf.freeze();
1142 let _ = upgrade_tx.send(Upgraded::new(
1143 self.io,
1144 if frozen_buf.is_empty() {
1145 None
1146 } else {
1147 Some(frozen_buf)
1148 },
1149 ));
1150 return Ok(());
1151 }
1152
1153 if self.cancel_token.as_ref().is_some_and(|t| t.is_cancelled()) {
1154 break;
1156 }
1157 }
1158 Ok(())
1159 }
1160}
1161
1162impl<Io> HttpProtocol for Http1<Io>
1163where
1164 Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
1165{
1166 #[inline]
1167 fn handle_with_error_fn<F, Fut, ResB, ResBE, ResE, EF, EFut, EResB, EResBE, EResE>(
1168 self,
1169 request_fn: F,
1170 error_fn: EF,
1171 ) -> impl std::future::Future<Output = Result<(), std::io::Error>>
1172 where
1173 F: Fn(Request<Incoming>) -> Fut + 'static,
1174 Fut: std::future::Future<Output = Result<Response<ResB>, ResE>> + 'static,
1175 ResB: Body<Data = bytes::Bytes, Error = ResBE> + Unpin + 'static,
1176 ResE: std::error::Error,
1177 ResBE: std::error::Error,
1178 EF: FnOnce(bool) -> EFut,
1179 EFut: std::future::Future<Output = Result<Response<EResB>, EResE>>,
1180 EResB: Body<Data = bytes::Bytes, Error = EResBE> + Unpin + 'static,
1181 EResE: std::error::Error,
1182 EResBE: std::error::Error,
1183 {
1184 #[allow(clippy::type_complexity)]
1185 let no_zerocopy: Option<
1186 Box<
1187 dyn FnMut(
1188 RawHandle,
1189 &Io,
1190 u64,
1191 ) -> Box<
1192 dyn std::future::Future<Output = Result<(), std::io::Error>>
1193 + Unpin
1194 + Send
1195 + Sync,
1196 >,
1197 >,
1198 > = None;
1199 self.handle_with_error_fn_and_zerocopy(request_fn, error_fn, no_zerocopy)
1200 }
1201
1202 #[inline]
1203 fn handle<F, Fut, ResB, ResBE, ResE>(
1204 self,
1205 request_fn: F,
1206 ) -> impl std::future::Future<Output = Result<(), std::io::Error>>
1207 where
1208 F: Fn(Request<Incoming>) -> Fut + 'static,
1209 Fut: std::future::Future<Output = Result<Response<ResB>, ResE>> + 'static,
1210 ResB: Body<Data = bytes::Bytes, Error = ResBE> + Unpin + 'static,
1211 ResE: std::error::Error,
1212 ResBE: std::error::Error,
1213 {
1214 self.handle_with_error_fn(request_fn, |is_timeout| async move {
1215 let mut response = Response::builder();
1216 if is_timeout {
1217 response = response.status(http::StatusCode::REQUEST_TIMEOUT);
1218 } else {
1219 response = response.status(http::StatusCode::BAD_REQUEST);
1220 }
1221 response.body(Empty::new())
1222 })
1223 }
1224}
1225
/// Request body for HTTP/1.x connections, fed through a channel by the
/// connection's body reader.
pub(crate) struct Http1Body {
    /// Receiving half of the body frame channel.
    #[allow(clippy::type_complexity)]
    inner: Pin<Box<AsyncReceiver<Result<http_body::Frame<bytes::Bytes>, std::io::Error>>>>,
    /// When present, set to `true` once a poll of this body finds no frame
    /// ready; the connection uses it to decide when to send `100 Continue`.
    send_continue_body: Option<Arc<AtomicBool>>,
}
1231
1232impl Body for Http1Body {
1233 type Data = bytes::Bytes;
1234 type Error = std::io::Error;
1235
1236 #[inline]
1237 fn poll_frame(
1238 self: Pin<&mut Self>,
1239 cx: &mut Context<'_>,
1240 ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
1241 match std::pin::pin!(self.inner.recv()).poll(cx) {
1242 Poll::Ready(Ok(Ok(frame))) => Poll::Ready(Some(Ok(frame))),
1243 Poll::Ready(Ok(Err(e))) => Poll::Ready(Some(Err(e))),
1244 Poll::Ready(Err(_)) => Poll::Ready(None),
1245 Poll::Pending => {
1246 if let Some(scb) = self.send_continue_body.as_ref() {
1247 scb.store(true, Ordering::Relaxed);
1248 }
1249 Poll::Pending
1250 }
1251 }
1252 }
1253}
1254
/// Finds the blank line that separates a message head from its body.
///
/// Recognises `\r\n\r\n` and `\n\n`. Returns the index at which the separator
/// starts and its length in bytes (4 or 2), or `None` if there is none.
#[inline]
fn search_header_body_separator(slice: &[u8]) -> Option<(usize, usize)> {
    if slice.len() < 2 {
        return None;
    }
    slice
        .iter()
        .enumerate()
        .find_map(|(idx, &byte)| match byte {
            b'\r' if slice[idx + 1..].starts_with(b"\n\r\n") => Some((idx, 4)),
            b'\n' if slice.get(idx + 1) == Some(&b'\n') => Some((idx, 2)),
            _ => None,
        })
}
1274
1275#[inline]
1277fn write_chunk_size(dst: &mut [u8; 18], len: usize) -> &[u8] {
1278 let mut n = len;
1279 let mut pos = dst.len() - 2;
1280 loop {
1281 pos -= 1;
1282 dst[pos] = HEX_DIGITS[n & 0xF];
1283 n >>= 4;
1284 if n == 0 {
1285 break;
1286 }
1287 }
1288 dst[dst.len() - 2] = b'\r';
1289 dst[dst.len() - 1] = b'\n';
1290 &dst[pos..]
1291}