Skip to main content

vibeio_http/h1/
mod.rs

1mod options;
2mod tests;
3mod writebuf;
4mod zerocopy;
5
6pub use options::*;
7pub use zerocopy::*;
8
/// Raw OS handle passed to the zero-copy callback in `write_response`
/// (taken from a `ZerocopyResponse`). On Unix this is a raw file descriptor.
#[cfg(unix)]
pub(crate) type RawHandle = std::os::fd::RawFd;
/// Raw OS handle passed to the zero-copy callback in `write_response`
/// (Windows variant of the alias).
#[cfg(windows)]
pub(crate) type RawHandle = std::os::windows::io::RawHandle;
13
14use std::{
15    future::Future,
16    io::IoSlice,
17    mem::MaybeUninit,
18    pin::Pin,
19    str::FromStr,
20    task::{Context, Poll},
21    time::UNIX_EPOCH,
22};
23
24use bytes::{Buf, Bytes, BytesMut};
25use http::{header, HeaderMap, HeaderName, HeaderValue, Method, Request, Response, Uri, Version};
26use http_body::Body;
27use http_body_util::{BodyExt, Empty};
28use kanal::AsyncReceiver;
29use memchr::{memchr3_iter, memmem};
30use tokio::io::{AsyncReadExt, AsyncWriteExt};
31use tokio_util::sync::CancellationToken;
32
33use crate::{h1::writebuf::WriteBuf, EarlyHints, HttpProtocol, Incoming, Upgrade, Upgraded};
34
/// Uppercase hexadecimal digit table. It is not referenced in this part of the
/// file; presumably it is used when encoding chunk sizes (`write_chunk_size`).
/// TODO confirm.
const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";
/// While a response body is streamed, `write_response` keeps issuing writes as
/// long as `write_buf.len()` is at or above this value. This bounds how much
/// body data is queued before it is handed to the transport.
const WRITE_BUF_BATCH_THRESHOLD: usize = 16384;
37
/// An HTTP/1.x connection handler.
///
/// `Http1` wraps an async I/O stream (`Io`) and provides a complete
/// HTTP/1.0 and HTTP/1.1 server implementation, including:
///
/// - Request head parsing (via [`httparse`])
/// - Streaming request bodies (content-length and chunked transfer-encoding)
/// - Chunked response encoding and trailer support
/// - `100 Continue` and `103 Early Hints` interim responses
/// - HTTP connection upgrades (e.g. WebSocket)
/// - Optional zero-copy response sending on Linux (see `Http1::zerocopy`)
/// - Keep-alive connection reuse
/// - Graceful shutdown via a [`CancellationToken`]
///
/// # Construction
///
/// ```rust,ignore
/// let http1 = Http1::new(tcp_stream, Http1Options::default());
/// ```
///
/// # Serving requests
///
/// Use the [`HttpProtocol`] trait methods ([`handle`](HttpProtocol::handle) /
/// [`handle_with_error_fn`](HttpProtocol::handle_with_error_fn)) to drive the
/// connection to completion:
///
/// ```rust,ignore
/// http1.handle(|req| async move {
///     Ok::<_, Infallible>(Response::new(Full::new(Bytes::from("Hello!"))))
/// }).await?;
/// ```
pub struct Http1<Io> {
    /// Underlying transport; requests are read from it and responses written to it.
    io: Io,
    /// Limits and feature switches supplied to `new`.
    options: options::Http1Options,
    /// Optional graceful-shutdown token, set via `graceful_shutdown_token`.
    cancel_token: Option<CancellationToken>,
    /// Reusable `httparse` header slots, `max_header_count` long (allocated in
    /// `new`). The `'static` lifetime is a placeholder; `read_request`
    /// transmutes it to the lifetime of the head currently being parsed.
    parsed_headers: Box<[MaybeUninit<httparse::Header<'static>>]>,
    /// Last formatted `Date` header value and the time it was formatted.
    /// `get_date_header_value` refreshes it when the wall-clock second changes.
    date_header_value_cached: Option<(String, std::time::SystemTime)>,
    /// Header map reclaimed from the last written response; `read_request`
    /// clears and reuses it for the next request's headers.
    cached_headers: Option<HeaderMap>,
    /// Inbound bytes not yet consumed. It may already hold bytes that belong
    /// to the next message.
    read_buf: BytesMut,
    /// Reusable buffer in which `write_response` serializes the status line and headers.
    response_head_buf: Vec<u8>,
    /// Outgoing slices queued by `write_response` and flushed to `io`.
    write_buf: WriteBuf,
}
80
#[cfg(all(target_os = "linux", feature = "h1-zerocopy"))]
impl<Io> Http1<Io>
where
    for<'a> Io: tokio::io::AsyncRead
        + tokio::io::AsyncWrite
        + vibeio::io::AsInnerRawHandle<'a>
        + Unpin
        + 'static,
{
    /// Wraps this handler in an [`Http1Zerocopy`], which can hand response
    /// bodies to an emulated-sendfile path instead of copying them through
    /// user space.
    ///
    /// Only responses that carry a `ZerocopyResponse` extension take that path.
    /// The extension is installed with [`install_zerocopy`] and holds the file
    /// descriptor to send from. Every other response is written the ordinary way.
    ///
    /// Compiled only on Linux with the `h1-zerocopy` feature enabled, and only
    /// for `Io` types that implement [`vibeio::io::AsInnerRawHandle`].
    #[inline]
    pub fn zerocopy(self) -> Http1Zerocopy<Io> {
        let inner = self;
        Http1Zerocopy { inner }
    }
}
105
106impl<Io> Http1<Io>
107where
108    Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
109{
110    /// Creates a new `Http1` connection handler wrapping the given I/O stream.
111    ///
112    /// The `options` value controls limits, timeouts, and optional features;
113    /// see [`Http1Options`] for details.
114    ///
115    /// # Example
116    ///
117    /// ```rust,ignore
118    /// let http1 = Http1::new(tcp_stream, Http1Options::default());
119    /// ```
120    #[inline]
121    pub fn new(io: Io, options: options::Http1Options) -> Self {
122        // Safety: u8 is a primitive type, so we can safely assume initialization
123        let read_buf = BytesMut::with_capacity(options.max_header_size);
124        let parsed_headers: Box<[MaybeUninit<httparse::Header<'static>>]> =
125            Box::new_uninit_slice(options.max_header_count);
126        Self {
127            io,
128            options,
129            cancel_token: None,
130            parsed_headers,
131            date_header_value_cached: None,
132            cached_headers: None,
133            read_buf,
134            response_head_buf: Vec::with_capacity(1024),
135            write_buf: WriteBuf::new(),
136        }
137    }
138
139    #[inline]
140    fn get_date_header_value(&mut self) -> &str {
141        let now = std::time::SystemTime::now();
142        if self.date_header_value_cached.as_ref().is_none_or(|v| {
143            v.1.duration_since(UNIX_EPOCH).ok().map(|d| d.as_secs())
144                != now.duration_since(UNIX_EPOCH).ok().map(|d| d.as_secs())
145        }) {
146            let value = httpdate::fmt_http_date(now).to_string();
147            self.date_header_value_cached = Some((value, now));
148        }
149        self.date_header_value_cached
150            .as_ref()
151            .map(|v| v.0.as_str())
152            .unwrap_or("")
153    }
154
155    /// Attaches a [`CancellationToken`] for graceful shutdown.
156    ///
157    /// After the current in-flight request has been fully handled and its
158    /// response written, the connection loop checks whether the token has been
159    /// cancelled. If it has, the loop exits cleanly instead of waiting for the
160    /// next request.
161    ///
162    /// This allows the server to drain active connections without abruptly
163    /// closing them mid-response.
164    #[inline]
165    pub fn graceful_shutdown_token(mut self, token: CancellationToken) -> Self {
166        self.cancel_token = Some(token);
167        self
168    }
169
170    #[inline]
171    async fn fill_buf(&mut self) -> Result<usize, std::io::Error> {
172        if self.read_buf.remaining() < 1024 {
173            self.read_buf.reserve(1024);
174        }
175        let spare_capacity = self.read_buf.spare_capacity_mut();
176        // Safety: The buffer is are read only after the request head has been parsed
177        let n = self
178            .io
179            .read(unsafe {
180                &mut *std::ptr::slice_from_raw_parts_mut(
181                    spare_capacity.as_mut_ptr() as *mut u8,
182                    spare_capacity.len(),
183                )
184            })
185            .await?;
186        if n == 0 {
187            return Ok(0);
188        }
189        unsafe { self.read_buf.set_len(self.read_buf.len() + n) };
190        Ok(n)
191    }
192
193    #[inline]
194    async fn read_body_fn(
195        &mut self,
196        body_tx: kanal::AsyncSender<Result<http_body::Frame<bytes::Bytes>, std::io::Error>>,
197        content_length: u64,
198    ) -> Result<(), std::io::Error> {
199        let mut remaining = content_length;
200        let mut just_started = true;
201        while remaining > 0 {
202            let have_to_read_buf = !just_started || self.read_buf.is_empty();
203            just_started = false;
204            if have_to_read_buf {
205                let n = self.fill_buf().await?;
206                if n == 0 {
207                    break;
208                }
209            }
210            let chunk = self
211                .read_buf
212                .split_to(
213                    self.read_buf
214                        .len()
215                        .min(remaining.min(usize::MAX as u64) as usize),
216                )
217                .freeze();
218            remaining -= chunk.len() as u64;
219
220            let _ = body_tx.send(Ok(http_body::Frame::data(chunk))).await;
221        }
222        Ok(())
223    }
224
225    #[inline]
226    async fn read_body_chunk(
227        &mut self,
228        would_have_trailers: bool,
229    ) -> Result<bytes::Bytes, std::io::Error> {
230        let len = {
231            // Safety: u8 is a primitive type, so we can safely assume initialization
232            let mut len_buf_pos: usize = 0;
233            let mut just_started = true;
234            loop {
235                if len_buf_pos >= 48 {
236                    return Err(std::io::Error::new(
237                        std::io::ErrorKind::InvalidData,
238                        "chunk length buffer overflow",
239                    ));
240                }
241
242                let begin_search = len_buf_pos.saturating_sub(1);
243
244                let have_to_read_buf = !just_started || self.read_buf.is_empty();
245                just_started = false;
246                if have_to_read_buf {
247                    let n = self.fill_buf().await?;
248                    if n == 0 {
249                        return Err(std::io::Error::new(
250                            std::io::ErrorKind::UnexpectedEof,
251                            "unexpected EOF",
252                        ));
253                    }
254                    len_buf_pos += n;
255                } else {
256                    len_buf_pos += self.read_buf.len();
257                }
258
259                if let Some(pos) =
260                    memmem::find(&self.read_buf[begin_search..len_buf_pos.min(48)], b"\r\n")
261                {
262                    let numbers = std::str::from_utf8(&self.read_buf[..begin_search + pos])
263                        .map_err(|_| {
264                            std::io::Error::new(
265                                std::io::ErrorKind::InvalidData,
266                                "invalid chunk length",
267                            )
268                        })?;
269                    let len = usize::from_str_radix(numbers, 16).map_err(|_| {
270                        std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid chunk length")
271                    })?;
272                    // Ignore the trailing CRLF
273                    self.read_buf.advance(begin_search + pos + 2);
274                    break len;
275                }
276            }
277        };
278        // Safety: u8 is a primitive type, so we can safely assume initialization
279        let mut read = 0;
280        if len == 0 && would_have_trailers {
281            return Ok(bytes::Bytes::new()); // Empty terminating chunk
282        }
283        let mut just_started = true;
284        // + 2, because we need to read the trailing CRLF
285        let Some(len_plus_two) = len.checked_add(2) else {
286            return Err(std::io::Error::new(
287                std::io::ErrorKind::InvalidData,
288                "chunk length too large",
289            ));
290        };
291        while read < len_plus_two {
292            let have_to_read_buf = !just_started || self.read_buf.is_empty();
293            just_started = false;
294            if have_to_read_buf {
295                let n = self.fill_buf().await?;
296                if n == 0 {
297                    return Err(std::io::Error::new(
298                        std::io::ErrorKind::UnexpectedEof,
299                        "unexpected EOF",
300                    ));
301                }
302                read += n;
303            } else {
304                read += self.read_buf.len();
305            }
306        }
307        let chunk = self.read_buf.split_to(len).freeze();
308        self.read_buf.advance(2); // Ignore the trailing CRLF
309        Ok(chunk)
310    }
311
312    #[inline]
313    async fn read_trailers(&mut self) -> Result<Option<HeaderMap>, std::io::Error> {
314        // Safety: u8 is a primitive type, so we can safely assume initialization
315        let mut bytes_read: usize = 0;
316        let mut just_started = true;
317        while bytes_read < self.options.max_header_size {
318            let old_bytes_read = bytes_read;
319            let begin_search = old_bytes_read.saturating_sub(3);
320
321            let have_to_read_buf = !just_started || self.read_buf.is_empty();
322            just_started = false;
323            if have_to_read_buf {
324                let n = self.fill_buf().await?;
325                if n == 0 {
326                    return Err(std::io::Error::new(
327                        std::io::ErrorKind::UnexpectedEof,
328                        "unexpected EOF",
329                    ));
330                }
331                bytes_read = (old_bytes_read + n).min(self.options.max_header_size);
332            } else {
333                bytes_read =
334                    (old_bytes_read + self.read_buf.len()).min(self.options.max_header_size)
335            }
336
337            if bytes_read >= 2 && self.read_buf[0] == b'\r' && self.read_buf[1] == b'\n' {
338                // No trailers, return None
339                return Ok(None);
340            }
341
342            if let Some(separator_index) =
343                memmem::find(&self.read_buf[begin_search..bytes_read], b"\r\n\r\n")
344            {
345                let to_parse_length = begin_search + separator_index + 4;
346                let buf_ro = self.read_buf.split_to(to_parse_length).freeze();
347
348                // Parse trailers using `httparse` crate's header parsing
349                let mut httparse_trailers =
350                    vec![httparse::EMPTY_HEADER; self.options.max_header_count].into_boxed_slice();
351                let status = httparse::parse_headers(&buf_ro, &mut httparse_trailers)
352                    .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?;
353                if let httparse::Status::Complete((_, trailers)) = status {
354                    let mut trailers_constructed = HeaderMap::new();
355                    for header in trailers {
356                        if header == &httparse::EMPTY_HEADER {
357                            // No more headers...
358                            break;
359                        }
360                        let name = HeaderName::from_bytes(header.name.as_bytes())
361                            .map_err(|e| std::io::Error::other(e.to_string()))?;
362                        let value_start = header.value.as_ptr() as usize - buf_ro.as_ptr() as usize;
363                        let value_len = header.value.len();
364                        // Safety: the header value is already validated by httparse
365                        let value = unsafe {
366                            HeaderValue::from_maybe_shared_unchecked(
367                                buf_ro.slice(value_start..(value_start + value_len)),
368                            )
369                        };
370                        trailers_constructed.append(name, value);
371                    }
372
373                    return Ok(Some(trailers_constructed));
374                } else {
375                    return Err(std::io::Error::new(
376                        std::io::ErrorKind::InvalidInput,
377                        "trailer headers incomplete",
378                    ));
379                }
380            }
381        }
382        Err(std::io::Error::new(
383            std::io::ErrorKind::InvalidData,
384            "request too large",
385        ))
386    }
387
388    #[inline]
389    async fn read_chunked_body_fn(
390        &mut self,
391        body_tx: kanal::AsyncSender<Result<http_body::Frame<bytes::Bytes>, std::io::Error>>,
392        would_have_trailers: bool,
393    ) -> Result<(), std::io::Error> {
394        loop {
395            let chunk = self.read_body_chunk(would_have_trailers).await?;
396            if chunk.is_empty() {
397                break;
398            }
399
400            let _ = body_tx.send(Ok(http_body::Frame::data(chunk))).await;
401        }
402        if would_have_trailers {
403            // Trailers
404            let trailers = self.read_trailers().await?;
405            if let Some(trailers) = trailers {
406                let _ = body_tx.send(Ok(http_body::Frame::trailers(trailers))).await;
407            }
408        }
409        Ok(())
410    }
411
412    #[inline]
413    async fn read_request(
414        &mut self,
415    ) -> Result<
416        Option<(
417            Request<Incoming>,
418            kanal::AsyncSender<Result<http_body::Frame<bytes::Bytes>, std::io::Error>>,
419        )>,
420        std::io::Error,
421    > {
422        // Parse HTTP request using httparse
423        let (request, body_tx) = {
424            let Some((head, headers)) = self.get_head().await? else {
425                return Ok(None);
426            };
427            // Safety: The headers are read only after the request head has been parsed
428            let headers = unsafe {
429                std::mem::transmute::<
430                    &mut [MaybeUninit<httparse::Header<'static>>],
431                    &mut [MaybeUninit<httparse::Header<'_>>],
432                >(headers)
433            };
434            let mut req = httparse::Request::new(&mut []);
435            let status = req
436                .parse_with_uninit_headers(&head, headers)
437                .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
438            if status.is_partial() {
439                return Err(std::io::Error::new(
440                    std::io::ErrorKind::InvalidData,
441                    "partial request head",
442                ));
443            }
444
445            // Convert httparse HTTP request to `http` one
446            let (body_tx, body_rx) = kanal::bounded_async(2);
447            let request_body = Http1Body {
448                inner: Box::pin(body_rx),
449            };
450            let mut request = Request::new(Incoming::H1(request_body));
451            match req.version {
452                Some(0) => *request.version_mut() = http::Version::HTTP_10,
453                Some(1) => *request.version_mut() = http::Version::HTTP_11,
454                _ => *request.version_mut() = http::Version::HTTP_11,
455            };
456            if let Some(method) = req.method {
457                *request.method_mut() = Method::from_bytes(method.as_bytes())
458                    .map_err(|e| std::io::Error::other(e.to_string()))?;
459            }
460            if let Some(path) = req.path {
461                *request.uri_mut() =
462                    Uri::from_str(path).map_err(|e| std::io::Error::other(e.to_string()))?;
463            }
464            let mut header_map = self.cached_headers.take().unwrap_or_default();
465            header_map.clear();
466            let additional_capacity = req.headers.len().saturating_sub(header_map.capacity());
467            if additional_capacity > 0 {
468                header_map.reserve(additional_capacity);
469            }
470            for header in req.headers {
471                if header == &httparse::EMPTY_HEADER {
472                    // No more headers...
473                    break;
474                }
475                let name = HeaderName::from_bytes(header.name.as_bytes())
476                    .map_err(|e| std::io::Error::other(e.to_string()))?;
477                let value_start = header.value.as_ptr() as usize - head.as_ptr() as usize;
478                let value_len = header.value.len();
479                // Safety: the header value is already validated by httparse
480                let value = unsafe {
481                    HeaderValue::from_maybe_shared_unchecked(
482                        head.slice(value_start..(value_start + value_len)),
483                    )
484                };
485                header_map.append(name, value);
486            }
487            *request.headers_mut() = header_map;
488
489            (request, body_tx)
490        };
491        Ok(Some((request, body_tx)))
492    }
493
494    #[inline]
495    async fn get_head(
496        &mut self,
497    ) -> Result<Option<(Bytes, &mut [MaybeUninit<httparse::Header<'static>>])>, std::io::Error>
498    {
499        let mut request_line_read = false;
500        let mut bytes_read: usize = 0;
501        let mut whitespace_trimmed = None;
502        let mut just_started = true;
503        while bytes_read < self.options.max_header_size {
504            let old_bytes_read = bytes_read;
505            let begin_search = old_bytes_read.saturating_sub(3);
506
507            let have_to_read_buf = !just_started || self.read_buf.is_empty();
508            just_started = false;
509            if have_to_read_buf {
510                let n = self.fill_buf().await?;
511                if n == 0 {
512                    if whitespace_trimmed.is_none() {
513                        return Ok(None);
514                    }
515                    return Err(std::io::Error::new(
516                        std::io::ErrorKind::UnexpectedEof,
517                        "unexpected EOF",
518                    ));
519                }
520                bytes_read = (old_bytes_read + n).min(self.options.max_header_size);
521            } else {
522                bytes_read =
523                    (old_bytes_read + self.read_buf.len()).min(self.options.max_header_size)
524            }
525
526            if whitespace_trimmed.is_none() {
527                whitespace_trimmed = self.read_buf[old_bytes_read..bytes_read]
528                    .iter()
529                    .position(|b| !b.is_ascii_whitespace());
530            }
531
532            if let Some(whitespace_trimmed) = whitespace_trimmed {
533                // Validate first line (request line) before checking for header/body separator
534                if !request_line_read {
535                    let memchr = memchr3_iter(
536                        b' ',
537                        b'\r',
538                        b'\n',
539                        &self.read_buf[whitespace_trimmed..bytes_read],
540                    );
541                    let mut spaces = 0;
542                    for separator_index in memchr {
543                        if self.read_buf[whitespace_trimmed + separator_index] == b' ' {
544                            if spaces >= 2 {
545                                return Err(std::io::Error::new(
546                                    std::io::ErrorKind::InvalidInput,
547                                    "bad request first line",
548                                ));
549                            }
550                            spaces += 1;
551                        } else if spaces == 2 {
552                            request_line_read = true;
553                            break;
554                        } else {
555                            return Err(std::io::Error::new(
556                                std::io::ErrorKind::InvalidInput,
557                                "bad request first line",
558                            ));
559                        }
560                    }
561                }
562
563                if request_line_read {
564                    let begin_search = begin_search.max(whitespace_trimmed);
565                    if let Some((separator_index, separator_len)) =
566                        search_header_body_separator(&self.read_buf[begin_search..bytes_read])
567                    {
568                        let to_parse_length =
569                            begin_search + separator_index + separator_len - whitespace_trimmed;
570                        self.read_buf.advance(whitespace_trimmed);
571                        let head = self.read_buf.split_to(to_parse_length);
572                        return Ok(Some((head.freeze(), &mut self.parsed_headers)));
573                    }
574                }
575            }
576        }
577        Err(std::io::Error::new(
578            std::io::ErrorKind::InvalidData,
579            "request too large",
580        ))
581    }
582
583    #[inline]
584    async fn write_response<Z, ZFut>(
585        &mut self,
586        mut response: Response<
587            impl Body<Data = bytes::Bytes, Error = impl std::error::Error> + Unpin,
588        >,
589        version: Version,
590        write_trailers: bool,
591        zerocopy_fn: Option<Z>,
592    ) -> Result<(), std::io::Error>
593    where
594        Z: FnMut(RawHandle, &'static Io, u64) -> ZFut,
595        ZFut: std::future::Future<Output = Result<(), std::io::Error>>,
596    {
597        // Date header
598        if self.options.send_date_header {
599            response.headers_mut().insert(
600                header::DATE,
601                HeaderValue::from_str(self.get_date_header_value())
602                    .map_err(|e| std::io::Error::other(e.to_string()))?,
603            );
604        }
605
606        // If the body has a size hint, set the Content-Length header if it's not already set
607        if let Some(suggested_content_length) = response.body().size_hint().exact() {
608            let headers = response.headers_mut();
609            if !headers.contains_key(header::CONTENT_LENGTH) {
610                headers.insert(header::CONTENT_LENGTH, suggested_content_length.into());
611            }
612        }
613
614        let chunked = response
615            .headers()
616            .get(header::TRANSFER_ENCODING)
617            .map(|v| {
618                v.to_str().ok().is_some_and(|s| {
619                    s.split(',')
620                        .any(|s| s.trim().eq_ignore_ascii_case("chunked"))
621                })
622            })
623            .unwrap_or_else(|| {
624                response
625                    .headers()
626                    .get(header::CONTENT_LENGTH)
627                    .and_then(|v| v.to_str().ok())
628                    .is_none_or(|s| s.parse::<u64>().is_err())
629            });
630
631        if chunked {
632            response.headers_mut().insert(
633                header::TRANSFER_ENCODING,
634                HeaderValue::from_static("chunked"),
635            );
636            while response
637                .headers_mut()
638                .remove(header::CONTENT_LENGTH)
639                .is_some()
640            {}
641        }
642
643        let (parts, mut body) = response.into_parts();
644
645        self.response_head_buf.clear();
646        let estimated_head_len = 30 + parts.headers.len() * 30; // Similar to Hyper's heuristic
647        if self.response_head_buf.capacity() < estimated_head_len {
648            self.response_head_buf
649                .reserve(estimated_head_len - self.response_head_buf.capacity());
650        }
651        let head = &mut self.response_head_buf;
652        if version == Version::HTTP_10 {
653            head.extend_from_slice(b"HTTP/1.0 ");
654        } else {
655            head.extend_from_slice(b"HTTP/1.1 ");
656        }
657        let status = parts.status;
658        head.extend_from_slice(status.as_str().as_bytes());
659        if let Some(canonical_reason) = status.canonical_reason() {
660            head.extend_from_slice(b" ");
661            head.extend_from_slice(canonical_reason.as_bytes());
662        }
663        head.extend_from_slice(b"\r\n");
664        for (name, value) in &parts.headers {
665            head.extend_from_slice(name.as_str().as_bytes());
666            head.extend_from_slice(b": ");
667            head.extend_from_slice(value.as_bytes());
668            head.extend_from_slice(b"\r\n");
669        }
670        head.extend_from_slice(b"\r\n");
671        unsafe {
672            self.write_buf.push(IoSlice::new(head));
673        }
674
675        if !chunked {
676            if let Some(content_length) = parts
677                .headers
678                .get(header::CONTENT_LENGTH)
679                .and_then(|v| v.to_str().ok())
680                .and_then(|s| s.parse::<u64>().ok())
681            {
682                if let Some(zero_copy) = parts.extensions.get::<ZerocopyResponse>() {
683                    if let Some(mut zerocopy_fn) = zerocopy_fn {
684                        // Zerocopy
685                        unsafe {
686                            self.write_buf
687                                .flush(&mut self.io, self.options.enable_vectored_write)
688                                .await?
689                        };
690                        zerocopy_fn(
691                            zero_copy.handle,
692                            // Safety: the lifetime of the static reference is bound by the lifetime of the Io struct
693                            unsafe { std::mem::transmute::<&Io, &'static Io>(&self.io) },
694                            content_length,
695                        )
696                        .await?;
697                        self.io.flush().await?;
698                        let reclaimed_headers = parts.headers;
699                        self.cached_headers = Some(reclaimed_headers);
700                        return Ok(());
701                    }
702                }
703            }
704        }
705
706        let mut trailers_written = false;
707        while let Some(chunk) = body.frame().await {
708            let chunk = chunk.map_err(|e| std::io::Error::other(e.to_string()))?;
709            match chunk.into_data() {
710                Ok(data) => {
711                    if chunked {
712                        let mut chunk_size_buf = [0u8; 18];
713                        let chunk_size = write_chunk_size(&mut chunk_size_buf, data.len());
714                        self.write_buf.push_copy(chunk_size);
715                        self.write_buf.push_bytes(data);
716                        unsafe {
717                            self.write_buf.push(IoSlice::new(b"\r\n"));
718                        }
719                    } else {
720                        self.write_buf.push_bytes(data);
721                    }
722                    while self.write_buf.len() >= WRITE_BUF_BATCH_THRESHOLD {
723                        let bytes_written = unsafe {
724                            self.write_buf
725                                .write(&mut self.io, self.options.enable_vectored_write)
726                                .await?
727                        };
728                        if bytes_written == 0 {
729                            return Err(std::io::ErrorKind::WriteZero.into());
730                        }
731                    }
732                }
733                Err(chunk) => {
734                    if let Ok(trailers) = chunk.into_trailers() {
735                        if write_trailers {
736                            unsafe {
737                                self.write_buf.push(IoSlice::new(b"0\r\n"));
738                                for (name, value) in &trailers {
739                                    self.write_buf.push_copy(name.as_str().as_bytes());
740                                    self.write_buf.push(IoSlice::new(b": "));
741                                    self.write_buf.push_copy(value.as_bytes());
742                                    self.write_buf.push(IoSlice::new(b"\r\n"));
743                                }
744                                self.write_buf.push(IoSlice::new(b"\r\n"));
745                            }
746                            trailers_written = true;
747                        }
748                        break;
749                    }
750                }
751            };
752        }
753        if chunked && !trailers_written {
754            // Terminating chunk
755            unsafe {
756                self.write_buf.push(IoSlice::new(b"0\r\n\r\n"));
757            }
758        }
759        unsafe {
760            self.write_buf
761                .flush(&mut self.io, self.options.enable_vectored_write)
762                .await?;
763        }
764        self.io.flush().await?;
765        let reclaimed_headers = parts.headers;
766        self.cached_headers = Some(reclaimed_headers);
767
768        Ok(())
769    }
770
771    #[inline]
772    async fn write_100_continue(&mut self, version: Version) -> Result<(), std::io::Error> {
773        if version == Version::HTTP_10 {
774            self.io.write_all(b"HTTP/1.0 100 Continue\r\n\r\n").await?;
775        } else {
776            self.io.write_all(b"HTTP/1.1 100 Continue\r\n\r\n").await?;
777        }
778        self.io.flush().await?;
779
780        Ok(())
781    }
782
783    #[inline]
784    async fn write_early_hints(
785        &mut self,
786        version: Version,
787        headers: http::HeaderMap,
788    ) -> Result<(), std::io::Error> {
789        let mut head = Vec::new();
790        if version == Version::HTTP_10 {
791            head.extend_from_slice(b"HTTP/1.0 103 Early Hints\r\n");
792        } else {
793            head.extend_from_slice(b"HTTP/1.1 103 Early Hints\r\n");
794        }
795        let mut current_header_name = None;
796        for (name, value) in headers {
797            if let Some(name) = name {
798                current_header_name = Some(name);
799            };
800            if let Some(current_header_name) = &current_header_name {
801                head.extend_from_slice(current_header_name.as_str().as_bytes());
802                if value.is_empty() {
803                    head.extend_from_slice(b":\r\n");
804                    continue;
805                }
806                head.extend_from_slice(b": ");
807                head.extend_from_slice(value.as_bytes());
808                head.extend_from_slice(b"\r\n");
809            }
810        }
811        head.extend_from_slice(b"\r\n");
812
813        self.io.write_all(&head).await?;
814
815        Ok(())
816    }
817
818    #[inline]
819    pub(crate) async fn handle_with_error_fn_and_zerocopy<
820        F,
821        Fut,
822        ResB,
823        ResBE,
824        ResE,
825        EF,
826        EFut,
827        EResB,
828        EResBE,
829        EResE,
830        ZF,
831        ZFut,
832    >(
833        mut self,
834        request_fn: F,
835        error_fn: EF,
836        mut zerocopy_fn: Option<ZF>,
837    ) -> Result<(), std::io::Error>
838    where
839        F: Fn(Request<Incoming>) -> Fut + 'static,
840        Fut: std::future::Future<Output = Result<Response<ResB>, ResE>> + 'static,
841        ResB: Body<Data = bytes::Bytes, Error = ResBE> + Unpin + 'static,
842        ResE: std::error::Error,
843        ResBE: std::error::Error,
844        EF: FnOnce(bool) -> EFut,
845        EFut: std::future::Future<Output = Result<Response<EResB>, EResE>>,
846        EResB: Body<Data = bytes::Bytes, Error = EResBE> + Unpin + 'static,
847        EResE: std::error::Error,
848        EResBE: std::error::Error,
849        ZF: FnMut(RawHandle, &'static Io, u64) -> ZFut,
850        ZFut: std::future::Future<Output = Result<(), std::io::Error>>,
851    {
852        let mut keep_alive = true;
853
854        while keep_alive {
855            let (mut request, body_tx) = match if let Some(timeout) =
856                self.options.header_read_timeout
857            {
858                vibeio::time::timeout(timeout, async {
859                    if let Some(token) = self.cancel_token.clone() {
860                        token.run_until_cancelled(self.read_request()).await
861                    } else {
862                        Some(self.read_request().await)
863                    }
864                })
865                .await
866            } else {
867                Ok(Some(self.read_request().await))
868            } {
869                Ok(Some(Ok(Some(d)))) => d,
870                Ok(Some(Ok(None))) => {
871                    return Ok(());
872                }
873                Ok(Some(Err(e))) => {
874                    // Parse error
875                    if let Ok(mut response) = error_fn(false).await {
876                        response
877                            .headers_mut()
878                            .insert(header::CONNECTION, HeaderValue::from_static("close"));
879
880                        let _ = self
881                            .write_response(response, Version::HTTP_11, false, zerocopy_fn.as_mut())
882                            .await;
883                    }
884                    return Err(e);
885                }
886                Ok(None) => {
887                    // Graceful shutdown
888                    return Ok(());
889                }
890                Err(_) => {
891                    // Timeout error
892                    if let Ok(mut response) = error_fn(true).await {
893                        response
894                            .headers_mut()
895                            .insert(header::CONNECTION, HeaderValue::from_static("close"));
896
897                        let _ = self
898                            .write_response(response, Version::HTTP_11, false, zerocopy_fn.as_mut())
899                            .await;
900                    }
901                    return Err(std::io::Error::new(
902                        std::io::ErrorKind::TimedOut,
903                        "header read timeout",
904                    ));
905                }
906            };
907
908            // Connection header detection
909            let connection_header_split = request
910                .headers()
911                .get(header::CONNECTION)
912                .and_then(|v| v.to_str().ok())
913                .map(|v| v.split(",").map(|v| v.trim()));
914            let is_connection_close = connection_header_split
915                .clone()
916                .is_some_and(|mut split| split.any(|v| v.eq_ignore_ascii_case("close")));
917            let is_connection_keep_alive = connection_header_split
918                .is_some_and(|mut split| split.any(|v| v.eq_ignore_ascii_case("keep-alive")));
919            keep_alive = !is_connection_close
920                && (is_connection_keep_alive || request.version() == http::Version::HTTP_11);
921
922            let version = request.version();
923
924            // 100 Continue
925            if self.options.send_continue_response {
926                let is_100_continue = request
927                    .headers()
928                    .get(header::EXPECT)
929                    .and_then(|v| v.to_str().ok())
930                    .is_some_and(|v| v.eq_ignore_ascii_case("100-continue"));
931                if is_100_continue {
932                    self.write_100_continue(version).await?;
933                }
934            }
935
936            // 103 Early Hints
937            let early_hints_fut = if self.options.enable_early_hints {
938                let (early_hints, mut early_hints_rx) = EarlyHints::new_lazy();
939                request.extensions_mut().insert(early_hints);
940                // Safety: the function below is used only in futures_util::future::select
941                // Also, another function that would borrow self would read data,
942                // while this function would write data
943                let mut_self = unsafe { std::mem::transmute::<&mut Self, &mut Self>(&mut self) };
944                futures_util::future::Either::Left(async move {
945                    while let Some((headers, sender)) =
946                        std::future::poll_fn(|cx| early_hints_rx.poll_recv(cx)).await
947                    {
948                        sender
949                            .into_inner()
950                            .send(mut_self.write_early_hints(version, headers).await)
951                            .ok();
952                    }
953                    futures_util::future::pending::<Result<(), std::io::Error>>().await
954                })
955            } else {
956                futures_util::future::Either::Right(futures_util::future::pending::<
957                    Result<(), std::io::Error>,
958                >())
959            };
960
961            // Content-Length header
962            let content_length = request
963                .headers()
964                .get(header::CONTENT_LENGTH)
965                .and_then(|v| v.to_str().ok())
966                .and_then(|v| v.parse::<u64>().ok())
967                .unwrap_or(0);
968            let chunked = request
969                .headers()
970                .get(header::TRANSFER_ENCODING)
971                .and_then(|v| v.to_str().ok())
972                .is_some_and(|v| {
973                    v.split(',')
974                        .any(|v| v.trim().eq_ignore_ascii_case("chunked"))
975                });
976            let has_trailers = request
977                .headers()
978                .get(header::TRAILER)
979                .map(|v| v.to_str().ok().is_some_and(|s| !s.is_empty()))
980                .unwrap_or(false);
981            let write_trailers = request
982                .headers()
983                .get(header::TE)
984                .and_then(|v| v.to_str().ok())
985                .map(|v| {
986                    v.split(',')
987                        .any(|v| v.trim().eq_ignore_ascii_case("trailers"))
988                })
989                .unwrap_or(false);
990
991            // Install HTTP upgrade
992            let (upgrade_tx, upgrade_rx) = oneshot::async_channel();
993            let upgrade = Upgrade::new(upgrade_rx);
994            let upgraded = upgrade.upgraded.clone();
995            request.extensions_mut().insert(upgrade);
996
997            // Get HTTP response
998            let mut response = {
999                let read_body_fut = async {
1000                    if chunked {
1001                        self.read_chunked_body_fn(body_tx, has_trailers).await
1002                    } else {
1003                        self.read_body_fn(body_tx, content_length).await
1004                    }
1005                };
1006                let read_body_fut_pin = std::pin::pin!(read_body_fut);
1007                let request_fut = request_fn(request);
1008                let request_fut_pin = std::pin::pin!(request_fut);
1009                let early_hints_fut_pin = std::pin::pin!(early_hints_fut);
1010
1011                let select_read_body_either =
1012                    futures_util::future::select(request_fut_pin, early_hints_fut_pin);
1013                let select_either =
1014                    futures_util::future::select(read_body_fut_pin, select_read_body_either).await;
1015
1016                let (response, body_fut) = match select_either {
1017                    futures_util::future::Either::Left((result, request_fut)) => {
1018                        result?;
1019                        (
1020                            match request_fut.await {
1021                                futures_util::future::Either::Left((response, _)) => response,
1022                                futures_util::future::Either::Right((_, _)) => unreachable!(),
1023                            },
1024                            None,
1025                        )
1026                    }
1027                    futures_util::future::Either::Right((response, read_body_fut)) => (
1028                        match response {
1029                            futures_util::future::Either::Left((response, _)) => response,
1030                            futures_util::future::Either::Right((_, _)) => unreachable!(),
1031                        },
1032                        Some(read_body_fut),
1033                    ),
1034                };
1035
1036                // Drain away remaining body
1037                if let Some(body_fut) = body_fut {
1038                    body_fut.await?;
1039                }
1040
1041                response.map_err(|e| std::io::Error::other(e.to_string()))?
1042            };
1043
1044            let mut was_upgraded = false;
1045            if upgraded.load(std::sync::atomic::Ordering::Relaxed) {
1046                was_upgraded = true;
1047                response
1048                    .headers_mut()
1049                    .insert(header::CONNECTION, HeaderValue::from_static("upgrade"));
1050            } else if keep_alive {
1051                if version == Version::HTTP_10
1052                    || response.headers().contains_key(header::CONNECTION)
1053                {
1054                    response
1055                        .headers_mut()
1056                        .insert(header::CONNECTION, HeaderValue::from_static("keep-alive"));
1057                }
1058            } else if version == Version::HTTP_11
1059                || response.headers().contains_key(header::CONNECTION)
1060            {
1061                response
1062                    .headers_mut()
1063                    .insert(header::CONNECTION, HeaderValue::from_static("close"));
1064            }
1065
1066            // Write response to IO
1067            self.write_response(response, version, write_trailers, zerocopy_fn.as_mut())
1068                .await?;
1069
1070            if was_upgraded {
1071                // HTTP upgrade
1072                let frozen_buf = self.read_buf.freeze();
1073                let _ = upgrade_tx.send(Upgraded::new(
1074                    self.io,
1075                    if frozen_buf.is_empty() {
1076                        None
1077                    } else {
1078                        Some(frozen_buf)
1079                    },
1080                ));
1081                return Ok(());
1082            }
1083
1084            if self.cancel_token.as_ref().is_some_and(|t| t.is_cancelled()) {
1085                // Graceful shutdown requested, break out of loop
1086                break;
1087            }
1088        }
1089        Ok(())
1090    }
1091}
1092
1093impl<Io> HttpProtocol for Http1<Io>
1094where
1095    Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
1096{
1097    #[inline]
1098    fn handle_with_error_fn<F, Fut, ResB, ResBE, ResE, EF, EFut, EResB, EResBE, EResE>(
1099        self,
1100        request_fn: F,
1101        error_fn: EF,
1102    ) -> impl std::future::Future<Output = Result<(), std::io::Error>>
1103    where
1104        F: Fn(Request<Incoming>) -> Fut + 'static,
1105        Fut: std::future::Future<Output = Result<Response<ResB>, ResE>> + 'static,
1106        ResB: Body<Data = bytes::Bytes, Error = ResBE> + Unpin + 'static,
1107        ResE: std::error::Error,
1108        ResBE: std::error::Error,
1109        EF: FnOnce(bool) -> EFut,
1110        EFut: std::future::Future<Output = Result<Response<EResB>, EResE>>,
1111        EResB: Body<Data = bytes::Bytes, Error = EResBE> + Unpin + 'static,
1112        EResE: std::error::Error,
1113        EResBE: std::error::Error,
1114    {
1115        #[allow(clippy::type_complexity)]
1116        let no_zerocopy: Option<
1117            Box<
1118                dyn FnMut(
1119                    RawHandle,
1120                    &Io,
1121                    u64,
1122                ) -> Box<
1123                    dyn std::future::Future<Output = Result<(), std::io::Error>>
1124                        + Unpin
1125                        + Send
1126                        + Sync,
1127                >,
1128            >,
1129        > = None;
1130        self.handle_with_error_fn_and_zerocopy(request_fn, error_fn, no_zerocopy)
1131    }
1132
1133    #[inline]
1134    fn handle<F, Fut, ResB, ResBE, ResE>(
1135        self,
1136        request_fn: F,
1137    ) -> impl std::future::Future<Output = Result<(), std::io::Error>>
1138    where
1139        F: Fn(Request<Incoming>) -> Fut + 'static,
1140        Fut: std::future::Future<Output = Result<Response<ResB>, ResE>> + 'static,
1141        ResB: Body<Data = bytes::Bytes, Error = ResBE> + Unpin + 'static,
1142        ResE: std::error::Error,
1143        ResBE: std::error::Error,
1144    {
1145        self.handle_with_error_fn(request_fn, |is_timeout| async move {
1146            let mut response = Response::builder();
1147            if is_timeout {
1148                response = response.status(http::StatusCode::REQUEST_TIMEOUT);
1149            } else {
1150                response = response.status(http::StatusCode::BAD_REQUEST);
1151            }
1152            response.body(Empty::new())
1153        })
1154    }
1155}
1156
/// Streaming HTTP/1 request body backed by a channel receiver.
///
/// Each message received on `inner` is either a body frame or an I/O error
/// reported by the producer. A receive error (e.g. the channel being closed)
/// ends the body. Presumably the sending half is the `body_tx` fed by
/// `read_body_fn` / `read_chunked_body_fn` — confirm at the construction site.
pub(crate) struct Http1Body {
    // The nested `Pin<Box<Receiver<Result<Frame, Error>>>>` type trips clippy's
    // type_complexity lint.
    #[allow(clippy::type_complexity)]
    inner: Pin<Box<AsyncReceiver<Result<http_body::Frame<bytes::Bytes>, std::io::Error>>>>,
}

impl Body for Http1Body {
    type Data = bytes::Bytes;
    type Error = std::io::Error;

    /// Polls the channel for the next body frame.
    ///
    /// Mapping of the receive result:
    /// - `Ok(Ok(frame))` → `Some(Ok(frame))`
    /// - `Ok(Err(e))`    → `Some(Err(e))` (error forwarded from the producer)
    /// - `Err(_)`        → `None` (any receive error ends the body)
    #[inline]
    fn poll_frame(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
        // NOTE(review): a fresh `recv()` future is created, polled once and dropped
        // on every call. If the channel cancels a pending receive when its future
        // is dropped (kanal's `ReceiveFuture` appears to), the waker registered on
        // `Pending` is discarded. Wakeups then rely on something else re-polling
        // this task (e.g. the body reader sharing the same `select`). Confirm this
        // holds when a handler reads the body from a separate task; otherwise keep
        // the receive future alive across polls.
        match std::pin::pin!(self.inner.recv()).poll(cx) {
            Poll::Ready(Ok(Ok(frame))) => Poll::Ready(Some(Ok(frame))),
            Poll::Ready(Ok(Err(e))) => Poll::Ready(Some(Err(e))),
            Poll::Ready(Err(_)) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }
}
1179
/// Locates the end of an HTTP/1 message head within `slice`.
///
/// Scans left to right and returns `(index, len)` for the first separator
/// found. `index` is where the separator starts and `len` is its length in bytes:
/// - `\r\n\r\n` yields `len == 4`;
/// - a bare `\n\n` yields `len == 2`. This includes the tail of a mixed
///   `\r\n\n`, which is reported at its first `\n`.
///
/// Returns `None` when neither separator occurs, which covers every input
/// shorter than two bytes.
#[inline]
fn search_header_body_separator(slice: &[u8]) -> Option<(usize, usize)> {
    slice
        .iter()
        .enumerate()
        .find_map(|(pos, &byte)| match byte {
            // `pos < slice.len()`, so `pos + 1..` is always a valid (possibly empty) range.
            b'\r' if slice[pos + 1..].starts_with(b"\n\r\n") => Some((pos, 4)),
            b'\n' if slice.get(pos + 1) == Some(&b'\n') => Some((pos, 2)),
            _ => None,
        })
}
1199
1200/// Writes the chunk size to the given buffer in hexadecimal format, followed by `\r\n`.
1201#[inline]
1202fn write_chunk_size(dst: &mut [u8; 18], len: usize) -> &[u8] {
1203    let mut n = len;
1204    let mut pos = dst.len() - 2;
1205    loop {
1206        pos -= 1;
1207        dst[pos] = HEX_DIGITS[n & 0xF];
1208        n >>= 4;
1209        if n == 0 {
1210            break;
1211        }
1212    }
1213    dst[dst.len() - 2] = b'\r';
1214    dst[dst.len() - 1] = b'\n';
1215    &dst[pos..]
1216}