Skip to main content

vibeio_http/h1/
mod.rs

1mod options;
2mod tests;
3mod writebuf;
4mod zerocopy;
5
6pub use options::*;
7pub use zerocopy::*;
8
/// Raw OS handle type used by the zero-copy response path
/// (the handle carried by `ZerocopyResponse`).
///
/// On Unix platforms this is a raw file descriptor.
#[cfg(unix)]
pub(crate) type RawHandle = std::os::unix::io::RawFd;
/// Raw OS handle type used by the zero-copy response path
/// (the handle carried by `ZerocopyResponse`).
///
/// On Windows this is a raw `HANDLE`.
#[cfg(windows)]
pub(crate) type RawHandle = std::os::windows::io::RawHandle;
13
14use std::{
15    future::Future,
16    io::IoSlice,
17    mem::MaybeUninit,
18    pin::Pin,
19    str::FromStr,
20    sync::{
21        atomic::{AtomicBool, Ordering},
22        Arc,
23    },
24    task::{Context, Poll},
25    time::UNIX_EPOCH,
26};
27
28use bytes::{Buf, Bytes, BytesMut};
29use http::{header, HeaderMap, HeaderName, HeaderValue, Method, Request, Response, Uri, Version};
30use http_body::Body;
31use http_body_util::{BodyExt, Empty};
32use kanal::AsyncReceiver;
33use memchr::{memchr3_iter, memmem};
34use tokio::io::{AsyncReadExt, AsyncWriteExt};
35use tokio_util::sync::CancellationToken;
36
37use crate::{h1::writebuf::WriteBuf, EarlyHints, HttpProtocol, Incoming, Upgrade, Upgraded};
38
/// Uppercase hexadecimal digits, indexed by nibble value (`0..=15`).
const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";
/// Once at least this many bytes are queued in the response write buffer,
/// they are written to the connection before more body data is queued.
const WRITE_BUF_BATCH_THRESHOLD: usize = 16 * 1024;
41
/// An HTTP/1.x connection handler.
///
/// `Http1` wraps an async I/O stream (`Io`) and provides a complete
/// HTTP/1.0 and HTTP/1.1 server implementation, including:
///
/// - Request head parsing (via [`httparse`])
/// - Streaming request bodies (content-length and chunked transfer-encoding)
/// - Chunked response encoding and trailer support
/// - `100 Continue` and `103 Early Hints` interim responses
/// - HTTP connection upgrades (e.g. WebSocket)
/// - Optional zero-copy response sending on Linux (see `Http1::zerocopy`)
/// - Keep-alive connection reuse
/// - Graceful shutdown via a [`CancellationToken`]
///
/// # Construction
///
/// ```rust,ignore
/// let http1 = Http1::new(tcp_stream, Http1Options::default());
/// ```
///
/// # Serving requests
///
/// Use the [`HttpProtocol`] trait methods ([`handle`](HttpProtocol::handle) /
/// [`handle_with_error_fn`](HttpProtocol::handle_with_error_fn)) to drive the
/// connection to completion:
///
/// ```rust,ignore
/// http1.handle(|req| async move {
///     Ok::<_, Infallible>(Response::new(Full::new(Bytes::from("Hello!"))))
/// }).await?;
/// ```
pub struct Http1<Io> {
    /// The wrapped connection stream; all request bytes are read from and all
    /// response bytes are written to it.
    io: Io,
    /// Limits, timeouts and feature switches supplied at construction.
    options: options::Http1Options,
    /// Optional graceful-shutdown token, set by `graceful_shutdown_token`.
    cancel_token: Option<CancellationToken>,
    /// Reusable, uninitialized scratch slots for `httparse` request headers,
    /// sized to `options.max_header_count`. The `'static` lifetime is only a
    /// placeholder: `read_request` transmutes it to the lifetime of the
    /// buffered request head before parsing.
    parsed_headers: Box<[MaybeUninit<httparse::Header<'static>>]>,
    /// Last formatted `Date` header value and the time it was formatted,
    /// reused while the wall-clock second is unchanged.
    date_header_value_cached: Option<(String, std::time::SystemTime)>,
    /// A header map reclaimed from a previous response, cleared and reused for
    /// the next request's headers to avoid reallocating.
    cached_headers: Option<HeaderMap>,
    /// Bytes read from `io` that have not been consumed by the parser yet.
    read_buf: BytesMut,
    /// Scratch buffer into which the response status line and headers are
    /// serialized.
    response_head_buf: Vec<u8>,
    /// Pending output (response head, body data, chunk framing) that is
    /// written to `io` in batches.
    write_buf: WriteBuf,
}
84
#[cfg(all(target_os = "linux", feature = "h1-zerocopy"))]
impl<Io> Http1<Io>
where
    for<'a> Io: tokio::io::AsyncRead
        + tokio::io::AsyncWrite
        + vibeio::io::AsInnerRawHandle<'a>
        + Unpin
        + 'static,
{
    /// Converts this `Http1` into an [`Http1Zerocopy`] that uses emulated
    /// sendfile (Linux only) to send response bodies without copying data
    /// through user space.
    ///
    /// The zero-copy path is taken only for responses whose *extensions*
    /// contain a `ZerocopyResponse` (installed via [`install_zerocopy`])
    /// holding the file descriptor to send from. The response must also not
    /// be chunked and must carry a `Content-Length` that parses as `u64`; that
    /// length is the number of bytes sent. All other responses are written
    /// normally.
    ///
    /// Only available on Linux (`target_os = "linux"`) with the `h1-zerocopy`
    /// feature enabled, and only when `Io` implements
    /// [`vibeio::io::AsInnerRawHandle`].
    #[inline]
    pub fn zerocopy(self) -> Http1Zerocopy<Io> {
        Http1Zerocopy { inner: self }
    }
}
109
110impl<Io> Http1<Io>
111where
112    Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
113{
114    /// Creates a new `Http1` connection handler wrapping the given I/O stream.
115    ///
116    /// The `options` value controls limits, timeouts, and optional features;
117    /// see [`Http1Options`] for details.
118    ///
119    /// # Example
120    ///
121    /// ```rust,ignore
122    /// let http1 = Http1::new(tcp_stream, Http1Options::default());
123    /// ```
124    #[inline]
125    pub fn new(io: Io, options: options::Http1Options) -> Self {
126        // Safety: u8 is a primitive type, so we can safely assume initialization
127        let read_buf = BytesMut::with_capacity(options.max_header_size);
128        let parsed_headers: Box<[MaybeUninit<httparse::Header<'static>>]> =
129            Box::new_uninit_slice(options.max_header_count);
130        Self {
131            io,
132            options,
133            cancel_token: None,
134            parsed_headers,
135            date_header_value_cached: None,
136            cached_headers: None,
137            read_buf,
138            response_head_buf: Vec::with_capacity(1024),
139            write_buf: WriteBuf::new(),
140        }
141    }
142
143    #[inline]
144    fn get_date_header_value(&mut self) -> &str {
145        let now = std::time::SystemTime::now();
146        if self.date_header_value_cached.as_ref().is_none_or(|v| {
147            v.1.duration_since(UNIX_EPOCH).ok().map(|d| d.as_secs())
148                != now.duration_since(UNIX_EPOCH).ok().map(|d| d.as_secs())
149        }) {
150            let value = httpdate::fmt_http_date(now).to_string();
151            self.date_header_value_cached = Some((value, now));
152        }
153        self.date_header_value_cached
154            .as_ref()
155            .map(|v| v.0.as_str())
156            .unwrap_or("")
157    }
158
159    /// Attaches a [`CancellationToken`] for graceful shutdown.
160    ///
161    /// After the current in-flight request has been fully handled and its
162    /// response written, the connection loop checks whether the token has been
163    /// cancelled. If it has, the loop exits cleanly instead of waiting for the
164    /// next request.
165    ///
166    /// This allows the server to drain active connections without abruptly
167    /// closing them mid-response.
168    #[inline]
169    pub fn graceful_shutdown_token(mut self, token: CancellationToken) -> Self {
170        self.cancel_token = Some(token);
171        self
172    }
173
174    #[inline]
175    async fn fill_buf(&mut self) -> Result<usize, std::io::Error> {
176        if self.read_buf.remaining() < 1024 {
177            self.read_buf.reserve(1024);
178        }
179        let spare_capacity = self.read_buf.spare_capacity_mut();
180        // Safety: The buffer is are read only after the request head has been parsed
181        let n = self
182            .io
183            .read(unsafe {
184                &mut *std::ptr::slice_from_raw_parts_mut(
185                    spare_capacity.as_mut_ptr() as *mut u8,
186                    spare_capacity.len(),
187                )
188            })
189            .await?;
190        if n == 0 {
191            return Ok(0);
192        }
193        unsafe { self.read_buf.set_len(self.read_buf.len() + n) };
194        Ok(n)
195    }
196
197    #[inline]
198    async fn read_body_fn(
199        &mut self,
200        body_tx: kanal::AsyncSender<Result<http_body::Frame<bytes::Bytes>, std::io::Error>>,
201        content_length: u64,
202        send_continue_body: &Option<Arc<AtomicBool>>,
203        continue_sent: &mut bool,
204        version: Version,
205    ) -> Result<(), std::io::Error> {
206        let mut remaining = content_length;
207        let mut just_started = true;
208        while remaining > 0 {
209            if !*continue_sent
210                && send_continue_body
211                    .as_ref()
212                    .is_some_and(|b| b.load(Ordering::Relaxed))
213            {
214                *continue_sent = true;
215                self.write_100_continue(version).await?;
216            }
217
218            let have_to_read_buf = !just_started || self.read_buf.is_empty();
219            just_started = false;
220            if have_to_read_buf {
221                let n = self.fill_buf().await?;
222                if n == 0 {
223                    break;
224                }
225            }
226            let chunk = self
227                .read_buf
228                .split_to(
229                    self.read_buf
230                        .len()
231                        .min(remaining.min(usize::MAX as u64) as usize),
232                )
233                .freeze();
234            remaining -= chunk.len() as u64;
235
236            let _ = body_tx.send(Ok(http_body::Frame::data(chunk))).await;
237        }
238        Ok(())
239    }
240
241    #[inline]
242    async fn read_body_chunk(
243        &mut self,
244        would_have_trailers: bool,
245        send_continue_body: &Option<Arc<AtomicBool>>,
246        continue_sent: &mut bool,
247        version: Version,
248    ) -> Result<bytes::Bytes, std::io::Error> {
249        let len = {
250            // Safety: u8 is a primitive type, so we can safely assume initialization
251            let mut len_buf_pos: usize = 0;
252            let mut just_started = true;
253            loop {
254                if len_buf_pos >= 48 {
255                    return Err(std::io::Error::new(
256                        std::io::ErrorKind::InvalidData,
257                        "chunk length buffer overflow",
258                    ));
259                }
260
261                let begin_search = len_buf_pos.saturating_sub(1);
262
263                let have_to_read_buf = !just_started || self.read_buf.is_empty();
264                just_started = false;
265                if have_to_read_buf {
266                    if !*continue_sent
267                        && send_continue_body
268                            .as_ref()
269                            .is_some_and(|b| b.load(Ordering::Relaxed))
270                    {
271                        *continue_sent = true;
272                        self.write_100_continue(version).await?;
273                    }
274                    let n = self.fill_buf().await?;
275                    if n == 0 {
276                        return Err(std::io::Error::new(
277                            std::io::ErrorKind::UnexpectedEof,
278                            "unexpected EOF",
279                        ));
280                    }
281                    len_buf_pos += n;
282                } else {
283                    len_buf_pos += self.read_buf.len();
284                }
285
286                if let Some(pos) =
287                    memmem::find(&self.read_buf[begin_search..len_buf_pos.min(48)], b"\r\n")
288                {
289                    let numbers = std::str::from_utf8(&self.read_buf[..begin_search + pos])
290                        .map_err(|_| {
291                            std::io::Error::new(
292                                std::io::ErrorKind::InvalidData,
293                                "invalid chunk length",
294                            )
295                        })?;
296                    let len = usize::from_str_radix(numbers, 16).map_err(|_| {
297                        std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid chunk length")
298                    })?;
299                    // Ignore the trailing CRLF
300                    self.read_buf.advance(begin_search + pos + 2);
301                    break len;
302                }
303            }
304        };
305        // Safety: u8 is a primitive type, so we can safely assume initialization
306        let mut read = 0;
307        if len == 0 && would_have_trailers {
308            return Ok(bytes::Bytes::new()); // Empty terminating chunk
309        }
310        let mut just_started = true;
311        // + 2, because we need to read the trailing CRLF
312        let Some(len_plus_two) = len.checked_add(2) else {
313            return Err(std::io::Error::new(
314                std::io::ErrorKind::InvalidData,
315                "chunk length too large",
316            ));
317        };
318        while read < len_plus_two {
319            let have_to_read_buf = !just_started || self.read_buf.is_empty();
320            just_started = false;
321            if have_to_read_buf {
322                if !*continue_sent
323                    && send_continue_body
324                        .as_ref()
325                        .is_some_and(|b| b.load(Ordering::Relaxed))
326                {
327                    *continue_sent = true;
328                    self.write_100_continue(version).await?;
329                }
330                let n = self.fill_buf().await?;
331                if n == 0 {
332                    return Err(std::io::Error::new(
333                        std::io::ErrorKind::UnexpectedEof,
334                        "unexpected EOF",
335                    ));
336                }
337                read += n;
338            } else {
339                read += self.read_buf.len();
340            }
341        }
342        let chunk = self.read_buf.split_to(len).freeze();
343        self.read_buf.advance(2); // Ignore the trailing CRLF
344        Ok(chunk)
345    }
346
347    #[inline]
348    async fn read_trailers(&mut self) -> Result<Option<HeaderMap>, std::io::Error> {
349        // Safety: u8 is a primitive type, so we can safely assume initialization
350        let mut bytes_read: usize = 0;
351        let mut just_started = true;
352        while bytes_read < self.options.max_header_size {
353            let old_bytes_read = bytes_read;
354            let begin_search = old_bytes_read.saturating_sub(3);
355
356            let have_to_read_buf = !just_started || self.read_buf.is_empty();
357            just_started = false;
358            if have_to_read_buf {
359                let n = self.fill_buf().await?;
360                if n == 0 {
361                    return Err(std::io::Error::new(
362                        std::io::ErrorKind::UnexpectedEof,
363                        "unexpected EOF",
364                    ));
365                }
366                bytes_read = (old_bytes_read + n).min(self.options.max_header_size);
367            } else {
368                bytes_read =
369                    (old_bytes_read + self.read_buf.len()).min(self.options.max_header_size)
370            }
371
372            if bytes_read >= 2 && self.read_buf[0] == b'\r' && self.read_buf[1] == b'\n' {
373                // No trailers, return None
374                return Ok(None);
375            }
376
377            if let Some(separator_index) =
378                memmem::find(&self.read_buf[begin_search..bytes_read], b"\r\n\r\n")
379            {
380                let to_parse_length = begin_search + separator_index + 4;
381                let buf_ro = self.read_buf.split_to(to_parse_length).freeze();
382
383                // Parse trailers using `httparse` crate's header parsing
384                let mut httparse_trailers =
385                    vec![httparse::EMPTY_HEADER; self.options.max_header_count].into_boxed_slice();
386                let status = httparse::parse_headers(&buf_ro, &mut httparse_trailers)
387                    .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?;
388                if let httparse::Status::Complete((_, trailers)) = status {
389                    let mut trailers_constructed = HeaderMap::new();
390                    for header in trailers {
391                        if header == &httparse::EMPTY_HEADER {
392                            // No more headers...
393                            break;
394                        }
395                        let name = HeaderName::from_bytes(header.name.as_bytes())
396                            .map_err(|e| std::io::Error::other(e.to_string()))?;
397                        let value_start = header.value.as_ptr() as usize - buf_ro.as_ptr() as usize;
398                        let value_len = header.value.len();
399                        // Safety: the header value is already validated by httparse
400                        let value = unsafe {
401                            HeaderValue::from_maybe_shared_unchecked(
402                                buf_ro.slice(value_start..(value_start + value_len)),
403                            )
404                        };
405                        trailers_constructed.append(name, value);
406                    }
407
408                    return Ok(Some(trailers_constructed));
409                } else {
410                    return Err(std::io::Error::new(
411                        std::io::ErrorKind::InvalidInput,
412                        "trailer headers incomplete",
413                    ));
414                }
415            }
416        }
417        Err(std::io::Error::new(
418            std::io::ErrorKind::InvalidData,
419            "request too large",
420        ))
421    }
422
423    #[inline]
424    async fn read_chunked_body_fn(
425        &mut self,
426        body_tx: kanal::AsyncSender<Result<http_body::Frame<bytes::Bytes>, std::io::Error>>,
427        would_have_trailers: bool,
428        send_continue_body: &Option<Arc<AtomicBool>>,
429        continue_sent: &mut bool,
430        version: Version,
431    ) -> Result<(), std::io::Error> {
432        loop {
433            let chunk = self
434                .read_body_chunk(
435                    would_have_trailers,
436                    send_continue_body,
437                    continue_sent,
438                    version,
439                )
440                .await?;
441            if chunk.is_empty() {
442                break;
443            }
444
445            let _ = body_tx.send(Ok(http_body::Frame::data(chunk))).await;
446        }
447        if would_have_trailers {
448            // Trailers
449            let trailers = self.read_trailers().await?;
450            if let Some(trailers) = trailers {
451                let _ = body_tx.send(Ok(http_body::Frame::trailers(trailers))).await;
452            }
453        }
454        Ok(())
455    }
456
457    #[inline]
458    async fn read_request(
459        &mut self,
460    ) -> Result<
461        Option<(
462            Request<Incoming>,
463            kanal::AsyncSender<Result<http_body::Frame<bytes::Bytes>, std::io::Error>>,
464            Option<Arc<AtomicBool>>,
465        )>,
466        std::io::Error,
467    > {
468        // Parse HTTP request using httparse
469        let (request, body_tx, send_continue_body) = {
470            let Some((head, headers)) = self.get_head().await? else {
471                return Ok(None);
472            };
473            // Safety: The headers are read only after the request head has been parsed
474            let headers = unsafe {
475                std::mem::transmute::<
476                    &mut [MaybeUninit<httparse::Header<'static>>],
477                    &mut [MaybeUninit<httparse::Header<'_>>],
478                >(headers)
479            };
480            let mut req = httparse::Request::new(&mut []);
481            let status = req
482                .parse_with_uninit_headers(&head, headers)
483                .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
484            if status.is_partial() {
485                return Err(std::io::Error::new(
486                    std::io::ErrorKind::InvalidData,
487                    "partial request head",
488                ));
489            }
490
491            // Convert httparse HTTP request to `http` one
492            let (body_tx, body_rx) = kanal::bounded_async(2);
493
494            // Detect 100-continue and create flag before building the body
495            let is_100_continue = self.options.send_continue_response
496                && req.headers.iter().any(|h| {
497                    h.name.eq_ignore_ascii_case("expect")
498                        && h.value.eq_ignore_ascii_case(b"100-continue")
499                });
500            let send_continue_body = is_100_continue.then(|| Arc::new(AtomicBool::new(false)));
501
502            let request_body = Http1Body {
503                inner: Box::pin(body_rx),
504                send_continue_body: send_continue_body.clone(),
505            };
506            let mut request = Request::new(Incoming::H1(request_body));
507            match req.version {
508                Some(0) => *request.version_mut() = http::Version::HTTP_10,
509                Some(1) => *request.version_mut() = http::Version::HTTP_11,
510                _ => *request.version_mut() = http::Version::HTTP_11,
511            };
512            if let Some(method) = req.method {
513                *request.method_mut() = Method::from_bytes(method.as_bytes())
514                    .map_err(|e| std::io::Error::other(e.to_string()))?;
515            }
516            if let Some(path) = req.path {
517                *request.uri_mut() =
518                    Uri::from_str(path).map_err(|e| std::io::Error::other(e.to_string()))?;
519            }
520            let mut header_map = self.cached_headers.take().unwrap_or_default();
521            header_map.clear();
522            let additional_capacity = req.headers.len().saturating_sub(header_map.capacity());
523            if additional_capacity > 0 {
524                header_map.reserve(additional_capacity);
525            }
526            for header in req.headers {
527                if header == &httparse::EMPTY_HEADER {
528                    // No more headers...
529                    break;
530                }
531                let name = HeaderName::from_bytes(header.name.as_bytes())
532                    .map_err(|e| std::io::Error::other(e.to_string()))?;
533                let value_start = header.value.as_ptr() as usize - head.as_ptr() as usize;
534                let value_len = header.value.len();
535                // Safety: the header value is already validated by httparse
536                let value = unsafe {
537                    HeaderValue::from_maybe_shared_unchecked(
538                        head.slice(value_start..(value_start + value_len)),
539                    )
540                };
541                header_map.append(name, value);
542            }
543            *request.headers_mut() = header_map;
544
545            (request, body_tx, send_continue_body)
546        };
547        Ok(Some((request, body_tx, send_continue_body)))
548    }
549
550    #[inline]
551    async fn get_head(
552        &mut self,
553    ) -> Result<Option<(Bytes, &mut [MaybeUninit<httparse::Header<'static>>])>, std::io::Error>
554    {
555        let mut request_line_read = false;
556        let mut bytes_read: usize = 0;
557        let mut whitespace_trimmed = None;
558        let mut just_started = true;
559        while bytes_read < self.options.max_header_size {
560            let old_bytes_read = bytes_read;
561            let begin_search = old_bytes_read.saturating_sub(3);
562
563            let have_to_read_buf = !just_started || self.read_buf.is_empty();
564            just_started = false;
565            if have_to_read_buf {
566                let n = self.fill_buf().await?;
567                if n == 0 {
568                    if whitespace_trimmed.is_none() {
569                        return Ok(None);
570                    }
571                    return Err(std::io::Error::new(
572                        std::io::ErrorKind::UnexpectedEof,
573                        "unexpected EOF",
574                    ));
575                }
576                bytes_read = (old_bytes_read + n).min(self.options.max_header_size);
577            } else {
578                bytes_read =
579                    (old_bytes_read + self.read_buf.len()).min(self.options.max_header_size)
580            }
581
582            if whitespace_trimmed.is_none() {
583                whitespace_trimmed = self.read_buf[old_bytes_read..bytes_read]
584                    .iter()
585                    .position(|b| !b.is_ascii_whitespace());
586            }
587
588            if let Some(whitespace_trimmed) = whitespace_trimmed {
589                // Validate first line (request line) before checking for header/body separator
590                if !request_line_read {
591                    let memchr = memchr3_iter(
592                        b' ',
593                        b'\r',
594                        b'\n',
595                        &self.read_buf[whitespace_trimmed..bytes_read],
596                    );
597                    let mut spaces = 0;
598                    for separator_index in memchr {
599                        if self.read_buf[whitespace_trimmed + separator_index] == b' ' {
600                            if spaces >= 2 {
601                                return Err(std::io::Error::new(
602                                    std::io::ErrorKind::InvalidInput,
603                                    "bad request first line",
604                                ));
605                            }
606                            spaces += 1;
607                        } else if spaces == 2 {
608                            request_line_read = true;
609                            break;
610                        } else {
611                            return Err(std::io::Error::new(
612                                std::io::ErrorKind::InvalidInput,
613                                "bad request first line",
614                            ));
615                        }
616                    }
617                }
618
619                if request_line_read {
620                    let begin_search = begin_search.max(whitespace_trimmed);
621                    if let Some((separator_index, separator_len)) =
622                        search_header_body_separator(&self.read_buf[begin_search..bytes_read])
623                    {
624                        let to_parse_length =
625                            begin_search + separator_index + separator_len - whitespace_trimmed;
626                        self.read_buf.advance(whitespace_trimmed);
627                        let head = self.read_buf.split_to(to_parse_length);
628                        return Ok(Some((head.freeze(), &mut self.parsed_headers)));
629                    }
630                }
631            }
632        }
633        Err(std::io::Error::new(
634            std::io::ErrorKind::InvalidData,
635            "request too large",
636        ))
637    }
638
639    #[inline]
640    async fn write_response<Z, ZFut>(
641        &mut self,
642        mut response: Response<
643            impl Body<Data = bytes::Bytes, Error = impl std::error::Error> + Unpin,
644        >,
645        version: Version,
646        write_trailers: bool,
647        zerocopy_fn: Option<Z>,
648    ) -> Result<(), std::io::Error>
649    where
650        Z: FnMut(RawHandle, &'static Io, u64) -> ZFut,
651        ZFut: std::future::Future<Output = Result<(), std::io::Error>>,
652    {
653        // Date header
654        if self.options.send_date_header {
655            response.headers_mut().insert(
656                header::DATE,
657                HeaderValue::from_str(self.get_date_header_value())
658                    .map_err(|e| std::io::Error::other(e.to_string()))?,
659            );
660        }
661
662        // If the body has a size hint, set the Content-Length header if it's not already set
663        if let Some(suggested_content_length) = response.body().size_hint().exact() {
664            let headers = response.headers_mut();
665            if !headers.contains_key(header::CONTENT_LENGTH) {
666                headers.insert(header::CONTENT_LENGTH, suggested_content_length.into());
667            }
668        }
669
670        let chunked = response
671            .headers()
672            .get(header::TRANSFER_ENCODING)
673            .map(|v| {
674                v.to_str().ok().is_some_and(|s| {
675                    s.split(',')
676                        .any(|s| s.trim().eq_ignore_ascii_case("chunked"))
677                })
678            })
679            .unwrap_or_else(|| {
680                response
681                    .headers()
682                    .get(header::CONTENT_LENGTH)
683                    .and_then(|v| v.to_str().ok())
684                    .is_none_or(|s| s.parse::<u64>().is_err())
685            });
686
687        if chunked {
688            response.headers_mut().insert(
689                header::TRANSFER_ENCODING,
690                HeaderValue::from_static("chunked"),
691            );
692            while response
693                .headers_mut()
694                .remove(header::CONTENT_LENGTH)
695                .is_some()
696            {}
697        }
698
699        let (parts, mut body) = response.into_parts();
700
701        self.response_head_buf.clear();
702        let estimated_head_len = 30 + parts.headers.len() * 30; // Similar to Hyper's heuristic
703        if self.response_head_buf.capacity() < estimated_head_len {
704            self.response_head_buf
705                .reserve(estimated_head_len - self.response_head_buf.capacity());
706        }
707        let head = &mut self.response_head_buf;
708        if version == Version::HTTP_10 {
709            head.extend_from_slice(b"HTTP/1.0 ");
710        } else {
711            head.extend_from_slice(b"HTTP/1.1 ");
712        }
713        let status = parts.status;
714        head.extend_from_slice(status.as_str().as_bytes());
715        if let Some(canonical_reason) = status.canonical_reason() {
716            head.extend_from_slice(b" ");
717            head.extend_from_slice(canonical_reason.as_bytes());
718        }
719        head.extend_from_slice(b"\r\n");
720        for (name, value) in &parts.headers {
721            head.extend_from_slice(name.as_str().as_bytes());
722            head.extend_from_slice(b": ");
723            head.extend_from_slice(value.as_bytes());
724            head.extend_from_slice(b"\r\n");
725        }
726        head.extend_from_slice(b"\r\n");
727        unsafe {
728            self.write_buf.push(IoSlice::new(head));
729        }
730
731        if !chunked {
732            if let Some(content_length) = parts
733                .headers
734                .get(header::CONTENT_LENGTH)
735                .and_then(|v| v.to_str().ok())
736                .and_then(|s| s.parse::<u64>().ok())
737            {
738                if let Some(zero_copy) = parts.extensions.get::<ZerocopyResponse>() {
739                    if let Some(mut zerocopy_fn) = zerocopy_fn {
740                        // Zerocopy
741                        unsafe {
742                            self.write_buf
743                                .flush(&mut self.io, self.options.enable_vectored_write)
744                                .await?
745                        };
746                        zerocopy_fn(
747                            zero_copy.handle,
748                            // Safety: the lifetime of the static reference is bound by the lifetime of the Io struct
749                            unsafe { std::mem::transmute::<&Io, &'static Io>(&self.io) },
750                            content_length,
751                        )
752                        .await?;
753                        self.io.flush().await?;
754                        let reclaimed_headers = parts.headers;
755                        self.cached_headers = Some(reclaimed_headers);
756                        return Ok(());
757                    }
758                }
759            }
760        }
761
762        let mut trailers_written = false;
763        while let Some(chunk) = body.frame().await {
764            let chunk = chunk.map_err(|e| std::io::Error::other(e.to_string()))?;
765            match chunk.into_data() {
766                Ok(data) => {
767                    if chunked {
768                        let mut chunk_size_buf = [0u8; 18];
769                        let chunk_size = write_chunk_size(&mut chunk_size_buf, data.len());
770                        self.write_buf.push_copy(chunk_size);
771                        self.write_buf.push_bytes(data);
772                        unsafe {
773                            self.write_buf.push(IoSlice::new(b"\r\n"));
774                        }
775                    } else {
776                        self.write_buf.push_bytes(data);
777                    }
778                    while self.write_buf.len() >= WRITE_BUF_BATCH_THRESHOLD {
779                        let bytes_written = unsafe {
780                            self.write_buf
781                                .write(&mut self.io, self.options.enable_vectored_write)
782                                .await?
783                        };
784                        if bytes_written == 0 {
785                            return Err(std::io::ErrorKind::WriteZero.into());
786                        }
787                    }
788                }
789                Err(chunk) => {
790                    if let Ok(trailers) = chunk.into_trailers() {
791                        if write_trailers {
792                            unsafe {
793                                self.write_buf.push(IoSlice::new(b"0\r\n"));
794                                for (name, value) in &trailers {
795                                    self.write_buf.push_copy(name.as_str().as_bytes());
796                                    self.write_buf.push(IoSlice::new(b": "));
797                                    self.write_buf.push_copy(value.as_bytes());
798                                    self.write_buf.push(IoSlice::new(b"\r\n"));
799                                }
800                                self.write_buf.push(IoSlice::new(b"\r\n"));
801                            }
802                            trailers_written = true;
803                        }
804                        break;
805                    }
806                }
807            };
808        }
809        if chunked && !trailers_written {
810            // Terminating chunk
811            unsafe {
812                self.write_buf.push(IoSlice::new(b"0\r\n\r\n"));
813            }
814        }
815        unsafe {
816            self.write_buf
817                .flush(&mut self.io, self.options.enable_vectored_write)
818                .await?;
819        }
820        self.io.flush().await?;
821        let reclaimed_headers = parts.headers;
822        self.cached_headers = Some(reclaimed_headers);
823
824        Ok(())
825    }
826
827    #[inline]
828    async fn write_100_continue(&mut self, version: Version) -> Result<(), std::io::Error> {
829        if version == Version::HTTP_10 {
830            self.io.write_all(b"HTTP/1.0 100 Continue\r\n\r\n").await?;
831        } else {
832            self.io.write_all(b"HTTP/1.1 100 Continue\r\n\r\n").await?;
833        }
834        self.io.flush().await?;
835
836        Ok(())
837    }
838
839    #[inline]
840    async fn write_early_hints(
841        &mut self,
842        version: Version,
843        headers: http::HeaderMap,
844    ) -> Result<(), std::io::Error> {
845        let mut head = Vec::new();
846        if version == Version::HTTP_10 {
847            head.extend_from_slice(b"HTTP/1.0 103 Early Hints\r\n");
848        } else {
849            head.extend_from_slice(b"HTTP/1.1 103 Early Hints\r\n");
850        }
851        let mut current_header_name = None;
852        for (name, value) in headers {
853            if let Some(name) = name {
854                current_header_name = Some(name);
855            };
856            if let Some(current_header_name) = &current_header_name {
857                head.extend_from_slice(current_header_name.as_str().as_bytes());
858                if value.is_empty() {
859                    head.extend_from_slice(b":\r\n");
860                    continue;
861                }
862                head.extend_from_slice(b": ");
863                head.extend_from_slice(value.as_bytes());
864                head.extend_from_slice(b"\r\n");
865            }
866        }
867        head.extend_from_slice(b"\r\n");
868
869        self.io.write_all(&head).await?;
870
871        Ok(())
872    }
873
874    #[inline]
875    pub(crate) async fn handle_with_error_fn_and_zerocopy<
876        F,
877        Fut,
878        ResB,
879        ResBE,
880        ResE,
881        EF,
882        EFut,
883        EResB,
884        EResBE,
885        EResE,
886        ZF,
887        ZFut,
888    >(
889        mut self,
890        request_fn: F,
891        error_fn: EF,
892        mut zerocopy_fn: Option<ZF>,
893    ) -> Result<(), std::io::Error>
894    where
895        F: Fn(Request<Incoming>) -> Fut + 'static,
896        Fut: std::future::Future<Output = Result<Response<ResB>, ResE>> + 'static,
897        ResB: Body<Data = bytes::Bytes, Error = ResBE> + Unpin + 'static,
898        ResE: std::error::Error,
899        ResBE: std::error::Error,
900        EF: FnOnce(bool) -> EFut,
901        EFut: std::future::Future<Output = Result<Response<EResB>, EResE>>,
902        EResB: Body<Data = bytes::Bytes, Error = EResBE> + Unpin + 'static,
903        EResE: std::error::Error,
904        EResBE: std::error::Error,
905        ZF: FnMut(RawHandle, &'static Io, u64) -> ZFut,
906        ZFut: std::future::Future<Output = Result<(), std::io::Error>>,
907    {
908        let mut keep_alive = true;
909
910        while keep_alive {
911            let (mut request, body_tx, send_continue_body) = match if let Some(timeout) =
912                self.options.header_read_timeout
913            {
914                vibeio::time::timeout(timeout, async {
915                    if let Some(token) = self.cancel_token.clone() {
916                        token.run_until_cancelled(self.read_request()).await
917                    } else {
918                        Some(self.read_request().await)
919                    }
920                })
921                .await
922            } else {
923                Ok(Some(self.read_request().await))
924            } {
925                Ok(Some(Ok(Some(d)))) => d,
926                Ok(Some(Ok(None))) => {
927                    return Ok(());
928                }
929                Ok(Some(Err(e))) => {
930                    // Parse error
931                    if let Ok(mut response) = error_fn(false).await {
932                        response
933                            .headers_mut()
934                            .insert(header::CONNECTION, HeaderValue::from_static("close"));
935
936                        let _ = self
937                            .write_response(response, Version::HTTP_11, false, zerocopy_fn.as_mut())
938                            .await;
939                    }
940                    return Err(e);
941                }
942                Ok(None) => {
943                    // Graceful shutdown
944                    return Ok(());
945                }
946                Err(_) => {
947                    // Timeout error
948                    if let Ok(mut response) = error_fn(true).await {
949                        response
950                            .headers_mut()
951                            .insert(header::CONNECTION, HeaderValue::from_static("close"));
952
953                        let _ = self
954                            .write_response(response, Version::HTTP_11, false, zerocopy_fn.as_mut())
955                            .await;
956                    }
957                    return Err(std::io::Error::new(
958                        std::io::ErrorKind::TimedOut,
959                        "header read timeout",
960                    ));
961                }
962            };
963
964            // Connection header detection
965            let connection_header_split = request
966                .headers()
967                .get(header::CONNECTION)
968                .and_then(|v| v.to_str().ok())
969                .map(|v| v.split(",").map(|v| v.trim()));
970            let is_connection_close = connection_header_split
971                .clone()
972                .is_some_and(|mut split| split.any(|v| v.eq_ignore_ascii_case("close")));
973            let is_connection_keep_alive = connection_header_split
974                .is_some_and(|mut split| split.any(|v| v.eq_ignore_ascii_case("keep-alive")));
975            keep_alive = !is_connection_close
976                && (is_connection_keep_alive || request.version() == http::Version::HTTP_11);
977
978            let version = request.version();
979            let is_100_continue = send_continue_body.is_some();
980
981            // 103 Early Hints
982            let early_hints_fut = if self.options.enable_early_hints {
983                let (early_hints, mut early_hints_rx) = EarlyHints::new_lazy();
984                request.extensions_mut().insert(early_hints);
985                // Safety: the function below is used only in futures_util::future::select
986                // Also, another function that would borrow self would read data,
987                // while this function would write data
988                let mut_self = unsafe { std::mem::transmute::<&mut Self, &mut Self>(&mut self) };
989                futures_util::future::Either::Left(async move {
990                    while let Some((headers, sender)) =
991                        std::future::poll_fn(|cx| early_hints_rx.poll_recv(cx)).await
992                    {
993                        sender
994                            .into_inner()
995                            .send(mut_self.write_early_hints(version, headers).await)
996                            .ok();
997                    }
998                    futures_util::future::pending::<Result<(), std::io::Error>>().await
999                })
1000            } else {
1001                futures_util::future::Either::Right(futures_util::future::pending::<
1002                    Result<(), std::io::Error>,
1003                >())
1004            };
1005
1006            // Content-Length header
1007            let content_length = request
1008                .headers()
1009                .get(header::CONTENT_LENGTH)
1010                .and_then(|v| v.to_str().ok())
1011                .and_then(|v| v.parse::<u64>().ok())
1012                .unwrap_or(0);
1013            let chunked = request
1014                .headers()
1015                .get(header::TRANSFER_ENCODING)
1016                .and_then(|v| v.to_str().ok())
1017                .is_some_and(|v| {
1018                    v.split(',')
1019                        .any(|v| v.trim().eq_ignore_ascii_case("chunked"))
1020                });
1021            let has_trailers = request
1022                .headers()
1023                .get(header::TRAILER)
1024                .map(|v| v.to_str().ok().is_some_and(|s| !s.is_empty()))
1025                .unwrap_or(false);
1026            let write_trailers = request
1027                .headers()
1028                .get(header::TE)
1029                .and_then(|v| v.to_str().ok())
1030                .map(|v| {
1031                    v.split(',')
1032                        .any(|v| v.trim().eq_ignore_ascii_case("trailers"))
1033                })
1034                .unwrap_or(false);
1035
1036            // Install HTTP upgrade
1037            let (upgrade_tx, upgrade_rx) = oneshot::async_channel();
1038            let upgrade = Upgrade::new(upgrade_rx);
1039            let upgraded = upgrade.upgraded.clone();
1040            request.extensions_mut().insert(upgrade);
1041
1042            // Get HTTP response
1043            let mut continue_sent = false;
1044            let mut response = {
1045                let read_body_fut = async {
1046                    if chunked {
1047                        self.read_chunked_body_fn(
1048                            body_tx,
1049                            has_trailers,
1050                            &send_continue_body,
1051                            &mut continue_sent,
1052                            version,
1053                        )
1054                        .await
1055                    } else {
1056                        self.read_body_fn(
1057                            body_tx,
1058                            content_length,
1059                            &send_continue_body,
1060                            &mut continue_sent,
1061                            version,
1062                        )
1063                        .await
1064                    }
1065                };
1066                let read_body_fut_pin = std::pin::pin!(read_body_fut);
1067                let request_fut = request_fn(request);
1068                let request_fut_pin = std::pin::pin!(request_fut);
1069                let early_hints_fut_pin = std::pin::pin!(early_hints_fut);
1070
1071                let select_read_body_either =
1072                    futures_util::future::select(request_fut_pin, early_hints_fut_pin);
1073                let select_either =
1074                    futures_util::future::select(read_body_fut_pin, select_read_body_either).await;
1075
1076                let (response, body_fut) = match select_either {
1077                    futures_util::future::Either::Left((result, request_fut)) => {
1078                        result?;
1079                        (
1080                            match request_fut.await {
1081                                futures_util::future::Either::Left((response, _)) => response,
1082                                futures_util::future::Either::Right((_, _)) => unreachable!(),
1083                            },
1084                            None,
1085                        )
1086                    }
1087                    futures_util::future::Either::Right((response, read_body_fut)) => (
1088                        match response {
1089                            futures_util::future::Either::Left((response, _)) => response,
1090                            futures_util::future::Either::Right((_, _)) => unreachable!(),
1091                        },
1092                        Some(read_body_fut),
1093                    ),
1094                };
1095
1096                // Drain away remaining body
1097                if let Some(body_fut) = body_fut {
1098                    body_fut.await?;
1099                }
1100
1101                response.map_err(|e| std::io::Error::other(e.to_string()))?
1102            };
1103
1104            // Response-triggered 100 Continue
1105            if !continue_sent
1106                && is_100_continue
1107                && !response.status().is_client_error()
1108                && !response.status().is_server_error()
1109            {
1110                self.write_100_continue(version).await?;
1111            }
1112
1113            let mut was_upgraded = false;
1114            if upgraded.load(std::sync::atomic::Ordering::Relaxed) {
1115                was_upgraded = true;
1116                response
1117                    .headers_mut()
1118                    .insert(header::CONNECTION, HeaderValue::from_static("upgrade"));
1119            } else if keep_alive {
1120                if version == Version::HTTP_10
1121                    || response.headers().contains_key(header::CONNECTION)
1122                {
1123                    response
1124                        .headers_mut()
1125                        .insert(header::CONNECTION, HeaderValue::from_static("keep-alive"));
1126                }
1127            } else if version == Version::HTTP_11
1128                || response.headers().contains_key(header::CONNECTION)
1129            {
1130                response
1131                    .headers_mut()
1132                    .insert(header::CONNECTION, HeaderValue::from_static("close"));
1133            }
1134
1135            // Write response to IO
1136            self.write_response(response, version, write_trailers, zerocopy_fn.as_mut())
1137                .await?;
1138
1139            if was_upgraded {
1140                // HTTP upgrade
1141                let frozen_buf = self.read_buf.freeze();
1142                let _ = upgrade_tx.send(Upgraded::new(
1143                    self.io,
1144                    if frozen_buf.is_empty() {
1145                        None
1146                    } else {
1147                        Some(frozen_buf)
1148                    },
1149                ));
1150                return Ok(());
1151            }
1152
1153            if self.cancel_token.as_ref().is_some_and(|t| t.is_cancelled()) {
1154                // Graceful shutdown requested, break out of loop
1155                break;
1156            }
1157        }
1158        Ok(())
1159    }
1160}
1161
1162impl<Io> HttpProtocol for Http1<Io>
1163where
1164    Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
1165{
1166    #[inline]
1167    fn handle_with_error_fn<F, Fut, ResB, ResBE, ResE, EF, EFut, EResB, EResBE, EResE>(
1168        self,
1169        request_fn: F,
1170        error_fn: EF,
1171    ) -> impl std::future::Future<Output = Result<(), std::io::Error>>
1172    where
1173        F: Fn(Request<Incoming>) -> Fut + 'static,
1174        Fut: std::future::Future<Output = Result<Response<ResB>, ResE>> + 'static,
1175        ResB: Body<Data = bytes::Bytes, Error = ResBE> + Unpin + 'static,
1176        ResE: std::error::Error,
1177        ResBE: std::error::Error,
1178        EF: FnOnce(bool) -> EFut,
1179        EFut: std::future::Future<Output = Result<Response<EResB>, EResE>>,
1180        EResB: Body<Data = bytes::Bytes, Error = EResBE> + Unpin + 'static,
1181        EResE: std::error::Error,
1182        EResBE: std::error::Error,
1183    {
1184        #[allow(clippy::type_complexity)]
1185        let no_zerocopy: Option<
1186            Box<
1187                dyn FnMut(
1188                    RawHandle,
1189                    &Io,
1190                    u64,
1191                ) -> Box<
1192                    dyn std::future::Future<Output = Result<(), std::io::Error>>
1193                        + Unpin
1194                        + Send
1195                        + Sync,
1196                >,
1197            >,
1198        > = None;
1199        self.handle_with_error_fn_and_zerocopy(request_fn, error_fn, no_zerocopy)
1200    }
1201
1202    #[inline]
1203    fn handle<F, Fut, ResB, ResBE, ResE>(
1204        self,
1205        request_fn: F,
1206    ) -> impl std::future::Future<Output = Result<(), std::io::Error>>
1207    where
1208        F: Fn(Request<Incoming>) -> Fut + 'static,
1209        Fut: std::future::Future<Output = Result<Response<ResB>, ResE>> + 'static,
1210        ResB: Body<Data = bytes::Bytes, Error = ResBE> + Unpin + 'static,
1211        ResE: std::error::Error,
1212        ResBE: std::error::Error,
1213    {
1214        self.handle_with_error_fn(request_fn, |is_timeout| async move {
1215            let mut response = Response::builder();
1216            if is_timeout {
1217                response = response.status(http::StatusCode::REQUEST_TIMEOUT);
1218            } else {
1219                response = response.status(http::StatusCode::BAD_REQUEST);
1220            }
1221            response.body(Empty::new())
1222        })
1223    }
1224}
1225
1226pub(crate) struct Http1Body {
1227    #[allow(clippy::type_complexity)]
1228    inner: Pin<Box<AsyncReceiver<Result<http_body::Frame<bytes::Bytes>, std::io::Error>>>>,
1229    send_continue_body: Option<Arc<AtomicBool>>,
1230}
1231
1232impl Body for Http1Body {
1233    type Data = bytes::Bytes;
1234    type Error = std::io::Error;
1235
1236    #[inline]
1237    fn poll_frame(
1238        self: Pin<&mut Self>,
1239        cx: &mut Context<'_>,
1240    ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
1241        match std::pin::pin!(self.inner.recv()).poll(cx) {
1242            Poll::Ready(Ok(Ok(frame))) => Poll::Ready(Some(Ok(frame))),
1243            Poll::Ready(Ok(Err(e))) => Poll::Ready(Some(Err(e))),
1244            Poll::Ready(Err(_)) => Poll::Ready(None),
1245            Poll::Pending => {
1246                if let Some(scb) = self.send_continue_body.as_ref() {
1247                    scb.store(true, Ordering::Relaxed);
1248                }
1249                Poll::Pending
1250            }
1251        }
1252    }
1253}
1254
/// Locates the first blank line ending an HTTP head.
///
/// Two separators are recognised, whichever starts first:
/// - `\r\n\r\n`, reported as `(index of its '\r', 4)`;
/// - `\n\n`, reported as `(index of the first '\n', 2)`.
///
/// Returns `None` when neither occurs, including for inputs shorter than two bytes.
#[inline]
fn search_header_body_separator(slice: &[u8]) -> Option<(usize, usize)> {
    if slice.len() < 2 {
        // Too short to hold any separator.
        return None;
    }
    slice.iter().enumerate().find_map(|(idx, &byte)| {
        let after = &slice[idx + 1..];
        match byte {
            b'\r' if after.starts_with(b"\n\r\n") => Some((idx, 4)),
            b'\n' if after.first() == Some(&b'\n') => Some((idx, 2)),
            _ => None,
        }
    })
}
1274
1275/// Writes the chunk size to the given buffer in hexadecimal format, followed by `\r\n`.
1276#[inline]
1277fn write_chunk_size(dst: &mut [u8; 18], len: usize) -> &[u8] {
1278    let mut n = len;
1279    let mut pos = dst.len() - 2;
1280    loop {
1281        pos -= 1;
1282        dst[pos] = HEX_DIGITS[n & 0xF];
1283        n >>= 4;
1284        if n == 0 {
1285            break;
1286        }
1287    }
1288    dst[dst.len() - 2] = b'\r';
1289    dst[dst.len() - 1] = b'\n';
1290    &dst[pos..]
1291}