Skip to main content

vibeio_http/h2/
mod.rs

1mod date;
2mod options;
3mod send;
4
5pub use options::*;
6use tokio_util::sync::CancellationToken;
7
8use std::{
9    future::Future,
10    pin::Pin,
11    rc::Rc,
12    task::{Context, Poll},
13};
14
15use bytes::Bytes;
16use http::{Request, Response};
17use http_body::{Body, Frame};
18
19use crate::{
20    h2::{date::DateCache, send::PipeToSendStream},
21    EarlyHints, HttpProtocol, Incoming,
22};
23
24static HTTP2_INVALID_HEADERS: [http::header::HeaderName; 5] = [
25    http::header::HeaderName::from_static("keep-alive"),
26    http::header::HeaderName::from_static("proxy-connection"),
27    http::header::CONNECTION,
28    http::header::TRANSFER_ENCODING,
29    http::header::UPGRADE,
30];
31
32pub(crate) struct H2Body {
33    recv: h2::RecvStream,
34    data_done: bool,
35}
36
37impl H2Body {
38    #[inline]
39    fn new(recv: h2::RecvStream) -> Self {
40        Self {
41            recv,
42            data_done: false,
43        }
44    }
45}
46
47impl Body for H2Body {
48    type Data = Bytes;
49    type Error = std::io::Error;
50
51    #[inline]
52    fn poll_frame(
53        mut self: Pin<&mut Self>,
54        cx: &mut Context<'_>,
55    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
56        if !self.data_done {
57            match self.recv.poll_data(cx) {
58                Poll::Ready(Some(Ok(data))) => {
59                    let _ = self.recv.flow_control().release_capacity(data.len());
60                    return Poll::Ready(Some(Ok(Frame::data(data))));
61                }
62                Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(h2_error_to_io(err)))),
63                Poll::Ready(None) => self.data_done = true,
64                Poll::Pending => return Poll::Pending,
65            }
66        }
67
68        match self.recv.poll_trailers(cx) {
69            Poll::Ready(Ok(Some(trailers))) => Poll::Ready(Some(Ok(Frame::trailers(trailers)))),
70            Poll::Ready(Ok(None)) => Poll::Ready(None),
71            Poll::Ready(Err(err)) => Poll::Ready(Some(Err(h2_error_to_io(err)))),
72            Poll::Pending => Poll::Pending,
73        }
74    }
75}
76
77#[inline]
78pub(super) fn h2_error_to_io(error: h2::Error) -> std::io::Error {
79    if error.is_io() {
80        error.into_io().unwrap_or(std::io::Error::other("io error"))
81    } else {
82        std::io::Error::other(error)
83    }
84}
85
86#[inline]
87pub(super) fn h2_reason_to_io(reason: h2::Reason) -> std::io::Error {
88    std::io::Error::other(h2::Error::from(reason))
89}
90
91/// An HTTP/2 connection handler.
92///
93/// `Http2` wraps an async I/O stream (`Io`) and drives the HTTP/2 server
94/// connection using the [`h2`] crate. It supports:
95///
96/// - Concurrent request stream handling
97/// - Streaming request/response bodies and trailers
98/// - Automatic `100 Continue` and `103 Early Hints` interim responses
99/// - Per-connection `Date` header caching
100/// - Graceful shutdown via a [`CancellationToken`]
101///
102/// # Construction
103///
104/// ```rust,ignore
105/// let http2 = Http2::new(tcp_stream, Http2Options::default());
106/// ```
107///
108/// # Serving requests
109///
110/// Use the [`HttpProtocol`] trait methods ([`handle`](HttpProtocol::handle) /
111/// [`handle_with_error_fn`](HttpProtocol::handle_with_error_fn)) to drive the
112/// connection to completion.
113pub struct Http2<Io> {
114    io_to_handshake: Option<Io>,
115    date_header_value_cached: DateCache,
116    options: Http2Options,
117    cancel_token: Option<CancellationToken>,
118}
119
120impl<Io> Http2<Io>
121where
122    Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
123{
124    /// Creates a new `Http2` connection handler wrapping the given I/O stream.
125    ///
126    /// The `options` value controls HTTP/2 protocol configuration, handshake
127    /// and accept timeouts, and optional behaviour such as automatic
128    /// `100 Continue` responses; see [`Http2Options`] for details.
129    ///
130    /// # Example
131    ///
132    /// ```rust,ignore
133    /// let http2 = Http2::new(tcp_stream, Http2Options::default());
134    /// ```
135    #[inline]
136    pub fn new(io: Io, options: Http2Options) -> Self {
137        Self {
138            io_to_handshake: Some(io),
139            date_header_value_cached: DateCache::default(),
140            options,
141            cancel_token: None,
142        }
143    }
144
145    /// Attaches a [`CancellationToken`] for graceful shutdown.
146    ///
147    /// When the token is cancelled, the handler sends HTTP/2 graceful shutdown
148    /// signals (GOAWAY), stops accepting new streams, and exits cleanly.
149    #[inline]
150    pub fn graceful_shutdown_token(mut self, token: CancellationToken) -> Self {
151        self.cancel_token = Some(token);
152        self
153    }
154}
155
156impl<Io> HttpProtocol for Http2<Io>
157where
158    Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
159{
160    #[allow(clippy::manual_async_fn)]
161    #[inline]
162    fn handle<F, Fut, ResB, ResBE, ResE>(
163        mut self,
164        request_fn: F,
165    ) -> impl std::future::Future<Output = Result<(), std::io::Error>>
166    where
167        F: Fn(Request<super::Incoming>) -> Fut + 'static,
168        Fut: std::future::Future<Output = Result<Response<ResB>, ResE>>,
169        ResB: http_body::Body<Data = bytes::Bytes, Error = ResBE> + Unpin,
170        ResE: std::error::Error,
171        ResBE: std::error::Error,
172    {
173        async move {
174            let request_fn = Rc::new(request_fn);
175            let handshake_fut = self.options.h2.handshake(
176                self.io_to_handshake
177                    .take()
178                    .ok_or_else(|| std::io::Error::other("no io to handshake"))?,
179            );
180            let mut h2 = (if let Some(timeout) = self.options.handshake_timeout {
181                vibeio::time::timeout(timeout, handshake_fut).await
182            } else {
183                Ok(handshake_fut.await)
184            })
185            .map_err(|_| std::io::Error::new(std::io::ErrorKind::TimedOut, "handshake timeout"))?
186            .map_err(|e| {
187                if e.is_io() {
188                    e.into_io().unwrap_or(std::io::Error::other("io error"))
189                } else {
190                    std::io::Error::other(e)
191                }
192            })?;
193
194            while let Some(request) = {
195                let res = {
196                    let accept_fut_orig = h2.accept();
197                    let accept_fut_orig_pin = std::pin::pin!(accept_fut_orig);
198                    let cancel_token = self.cancel_token.clone();
199                    let cancel_fut = async move {
200                        if let Some(token) = cancel_token {
201                            token.cancelled().await
202                        } else {
203                            futures_util::future::pending().await
204                        }
205                    };
206                    let cancel_fut_pin = std::pin::pin!(cancel_fut);
207                    let accept_fut =
208                        futures_util::future::select(cancel_fut_pin, accept_fut_orig_pin);
209
210                    match if let Some(timeout) = self.options.accept_timeout {
211                        vibeio::time::timeout(timeout, accept_fut).await
212                    } else {
213                        Ok(accept_fut.await)
214                    } {
215                        Ok(futures_util::future::Either::Right((request, _))) => {
216                            (Some(request), false)
217                        }
218                        Ok(futures_util::future::Either::Left((_, _))) => {
219                            // Canceled
220                            (None, true)
221                        }
222                        Err(_) => {
223                            // Timeout
224                            (None, false)
225                        }
226                    }
227                };
228                match res {
229                    (Some(request), _) => request,
230                    (None, graceful) => {
231                        h2.graceful_shutdown();
232                        let _ = h2.accept().await;
233                        if graceful {
234                            return Ok(());
235                        }
236                        return Err(std::io::Error::new(
237                            std::io::ErrorKind::TimedOut,
238                            "accept timeout",
239                        ));
240                    }
241                }
242            } {
243                let (request, mut stream) = match request {
244                    Ok(d) => d,
245                    Err(e) if e.is_go_away() => {
246                        continue;
247                    }
248                    Err(e) if e.is_io() => {
249                        return Err(e.into_io().unwrap_or(std::io::Error::other("io error")));
250                    }
251                    Err(e) => {
252                        return Err(std::io::Error::other(e));
253                    }
254                };
255
256                let date_cache = self.date_header_value_cached.clone();
257                let request_fn = request_fn.clone();
258                vibeio::spawn(async move {
259                    let (request_parts, recv_stream) = request.into_parts();
260                    let request_body = Incoming::H2(H2Body::new(recv_stream));
261                    let mut request = Request::from_parts(request_parts, request_body);
262
263                    // 100 Continue
264                    if self.options.send_continue_response {
265                        let is_100_continue = request
266                            .headers()
267                            .get(http::header::EXPECT)
268                            .and_then(|v| v.to_str().ok())
269                            .is_some_and(|v| v.eq_ignore_ascii_case("100-continue"));
270                        if is_100_continue {
271                            let mut response = Response::new(());
272                            *response.status_mut() = http::StatusCode::CONTINUE;
273                            let _ = stream.send_informational(response).map_err(h2_error_to_io);
274                        }
275                    }
276
277                    let (early_hints_tx, early_hints_rx) = async_channel::unbounded();
278                    let early_hints = EarlyHints::new(early_hints_tx);
279                    request.extensions_mut().insert(early_hints);
280
281                    let mut response_fut = std::pin::pin!(request_fn(request));
282                    let early_hints_rx = early_hints_rx;
283                    let response_result = loop {
284                        let early_hints_recv_fut = early_hints_rx.recv();
285                        let mut early_hints_recv_fut = std::pin::pin!(early_hints_recv_fut);
286                        let next = std::future::poll_fn(|cx| {
287                            match stream.poll_reset(cx) {
288                                Poll::Ready(Ok(reason)) => {
289                                    return Poll::Ready(Err(h2_reason_to_io(reason)));
290                                }
291                                Poll::Ready(Err(err)) => {
292                                    return Poll::Ready(Err(h2_error_to_io(err)));
293                                }
294                                Poll::Pending => {}
295                            }
296
297                            if let Poll::Ready(res) = response_fut.as_mut().poll(cx) {
298                                return Poll::Ready(Ok(futures_util::future::Either::Left(res)));
299                            }
300
301                            match early_hints_recv_fut.as_mut().poll(cx) {
302                                Poll::Ready(Ok(msg)) => {
303                                    Poll::Ready(Ok(futures_util::future::Either::Right(msg)))
304                                }
305                                Poll::Ready(Err(_)) => Poll::Pending,
306                                Poll::Pending => Poll::Pending,
307                            }
308                        })
309                        .await;
310
311                        match next {
312                            Ok(futures_util::future::Either::Left(response_result)) => {
313                                break response_result;
314                            }
315                            Ok(futures_util::future::Either::Right((headers, sender))) => {
316                                let mut response = Response::new(());
317                                *response.status_mut() = http::StatusCode::EARLY_HINTS;
318                                *response.headers_mut() = headers;
319                                sender
320                                    .into_inner()
321                                    .send(
322                                        stream.send_informational(response).map_err(h2_error_to_io),
323                                    )
324                                    .ok();
325                            }
326                            Err(_) => {
327                                return;
328                            }
329                        }
330                    };
331                    let Ok(mut response) = response_result else {
332                        // Return early if the request handler returns an error
333                        return;
334                    };
335
336                    {
337                        let response_headers = response.headers_mut();
338                        if self.options.send_date_header {
339                            if let Some(http_date) = date_cache.get_date_header_value() {
340                                response_headers
341                                    .entry(http::header::DATE)
342                                    .or_insert(http_date);
343                            }
344                        }
345                        for header in &HTTP2_INVALID_HEADERS {
346                            if let http::header::Entry::Occupied(entry) =
347                                response_headers.entry(header)
348                            {
349                                entry.remove();
350                            }
351                        }
352                        if response_headers
353                            .get(http::header::TE)
354                            .is_some_and(|v| v != "trailers")
355                        {
356                            response_headers.remove(http::header::TE);
357                        }
358                    }
359
360                    let response_is_end_stream = response.body().is_end_stream();
361                    if !response_is_end_stream {
362                        if let Some(content_length) = response.body().size_hint().exact() {
363                            if !response
364                                .headers()
365                                .contains_key(http::header::CONTENT_LENGTH)
366                            {
367                                response
368                                    .headers_mut()
369                                    .insert(http::header::CONTENT_LENGTH, content_length.into());
370                            }
371                        }
372                    }
373
374                    let (response_parts, mut response_body) = response.into_parts();
375                    let Ok(send) = stream.send_response(
376                        Response::from_parts(response_parts, ()),
377                        response_is_end_stream,
378                    ) else {
379                        return;
380                    };
381
382                    if response_is_end_stream {
383                        return;
384                    }
385
386                    let _ = PipeToSendStream::new(send, &mut response_body).await;
387                });
388            }
389
390            Ok(())
391        }
392    }
393}