Skip to main content

vibeio_http/h2/
mod.rs

1// TODO: add support for extended CONNECT
2
3mod date;
4mod options;
5mod send;
6mod upgrade;
7
8pub use options::*;
9use tokio_util::sync::CancellationToken;
10
11use std::{
12    future::Future,
13    pin::Pin,
14    task::{Context, Poll},
15};
16
17use bytes::Bytes;
18use http::{Request, Response};
19use http_body::{Body, Frame};
20
21use crate::{
22    h2::{date::DateCache, send::PipeToSendStream},
23    EarlyHints, HttpProtocol, Incoming, Upgrade, Upgraded,
24};
25
// Connection-specific header fields that must not appear in HTTP/2 messages
// (see RFC 9113, section 8.2.2). The connection handler removes each of these
// from every response before the response headers are sent.
static HTTP2_INVALID_HEADERS: [http::header::HeaderName; 5] = [
    http::header::HeaderName::from_static("keep-alive"),
    http::header::HeaderName::from_static("proxy-connection"),
    http::header::CONNECTION,
    http::header::TRANSFER_ENCODING,
    http::header::UPGRADE,
];
33
/// Request body backed by the receive half of an HTTP/2 stream.
///
/// Yields the stream's DATA frames first and, once those are exhausted, the
/// optional trailers.
pub(crate) struct H2Body {
    /// Receive half of the HTTP/2 stream the body is read from.
    recv: h2::RecvStream,
    /// Set once `poll_data` has reported the end of the data frames, so that
    /// later polls go straight to the trailers.
    data_done: bool,
}
38
impl H2Body {
    /// Wraps `recv` in a body that has not yet read any data frames.
    #[inline]
    fn new(recv: h2::RecvStream) -> Self {
        Self {
            recv,
            data_done: false,
        }
    }
}
48
49impl Body for H2Body {
50    type Data = Bytes;
51    type Error = std::io::Error;
52
53    #[inline]
54    fn poll_frame(
55        mut self: Pin<&mut Self>,
56        cx: &mut Context<'_>,
57    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
58        if !self.data_done {
59            match self.recv.poll_data(cx) {
60                Poll::Ready(Some(Ok(data))) => {
61                    let _ = self.recv.flow_control().release_capacity(data.len());
62                    return Poll::Ready(Some(Ok(Frame::data(data))));
63                }
64                Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(h2_error_to_io(err)))),
65                Poll::Ready(None) => self.data_done = true,
66                Poll::Pending => return Poll::Pending,
67            }
68        }
69
70        match self.recv.poll_trailers(cx) {
71            Poll::Ready(Ok(Some(trailers))) => Poll::Ready(Some(Ok(Frame::trailers(trailers)))),
72            Poll::Ready(Ok(None)) => Poll::Ready(None),
73            Poll::Ready(Err(err)) => Poll::Ready(Some(Err(h2_error_to_io(err)))),
74            Poll::Pending => Poll::Pending,
75        }
76    }
77}
78
79#[inline]
80pub(super) fn h2_error_to_io(error: h2::Error) -> std::io::Error {
81    if error.is_io() {
82        error.into_io().unwrap_or(std::io::Error::other("io error"))
83    } else {
84        std::io::Error::other(error)
85    }
86}
87
88#[inline]
89pub(super) fn h2_reason_to_io(reason: h2::Reason) -> std::io::Error {
90    std::io::Error::other(h2::Error::from(reason))
91}
92
/// An HTTP/2 connection handler.
///
/// `Http2` wraps an async I/O stream (`Io`) and drives the HTTP/2 server
/// connection using the [`h2`] crate. It supports:
///
/// - Concurrent request stream handling
/// - Streaming request/response bodies and trailers
/// - Automatic `100 Continue` responses (when enabled in the options) and
///   handler-triggered `103 Early Hints` interim responses
/// - Per-connection `Date` header caching
/// - Graceful shutdown via a [`CancellationToken`]
///
/// # Construction
///
/// ```rust,ignore
/// let http2 = Http2::new(tcp_stream, Http2Options::default());
/// ```
///
/// # Serving requests
///
/// Use the [`HttpProtocol`] trait methods ([`handle`](HttpProtocol::handle) /
/// [`handle_with_error_fn`](HttpProtocol::handle_with_error_fn)) to drive the
/// connection to completion.
pub struct Http2<Io> {
    /// The I/O stream, held until the handshake starts; it is taken out when
    /// the connection begins to be handled.
    io_to_handshake: Option<Io>,
    /// Cache used to obtain the value of the `Date` response header.
    date_header_value_cached: DateCache,
    /// Protocol configuration, timeouts and optional behaviours.
    options: Http2Options,
    /// Optional token whose cancellation triggers a graceful shutdown.
    cancel_token: Option<CancellationToken>,
}
121
122impl<Io> Http2<Io>
123where
124    Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
125{
126    /// Creates a new `Http2` connection handler wrapping the given I/O stream.
127    ///
128    /// The `options` value controls HTTP/2 protocol configuration, handshake
129    /// and accept timeouts, and optional behaviour such as automatic
130    /// `100 Continue` responses; see [`Http2Options`] for details.
131    ///
132    /// # Example
133    ///
134    /// ```rust,ignore
135    /// let http2 = Http2::new(tcp_stream, Http2Options::default());
136    /// ```
137    #[inline]
138    pub fn new(io: Io, options: Http2Options) -> Self {
139        Self {
140            io_to_handshake: Some(io),
141            date_header_value_cached: DateCache::default(),
142            options,
143            cancel_token: None,
144        }
145    }
146
147    /// Attaches a [`CancellationToken`] for graceful shutdown.
148    ///
149    /// When the token is cancelled, the handler sends HTTP/2 graceful shutdown
150    /// signals (GOAWAY), stops accepting new streams, and exits cleanly.
151    #[inline]
152    pub fn graceful_shutdown_token(mut self, token: CancellationToken) -> Self {
153        self.cancel_token = Some(token);
154        self
155    }
156}
157
158impl<Io> HttpProtocol for Http2<Io>
159where
160    Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
161{
162    #[allow(clippy::manual_async_fn)]
163    #[inline]
164    fn handle<F, Fut, ResB, ResBE, ResE>(
165        mut self,
166        request_fn: F,
167    ) -> impl std::future::Future<Output = Result<(), std::io::Error>>
168    where
169        F: Fn(Request<super::Incoming>) -> Fut + 'static,
170        Fut: std::future::Future<Output = Result<Response<ResB>, ResE>> + 'static,
171        ResB: http_body::Body<Data = bytes::Bytes, Error = ResBE> + Unpin,
172        ResE: std::error::Error,
173        ResBE: std::error::Error,
174    {
175        async move {
176            let handshake_fut = self.options.h2.handshake(
177                self.io_to_handshake
178                    .take()
179                    .ok_or_else(|| std::io::Error::other("no io to handshake"))?,
180            );
181            let mut h2 = (if let Some(timeout) = self.options.handshake_timeout {
182                vibeio::time::timeout(timeout, handshake_fut).await
183            } else {
184                Ok(handshake_fut.await)
185            })
186            .map_err(|_| std::io::Error::new(std::io::ErrorKind::TimedOut, "handshake timeout"))?
187            .map_err(|e| {
188                if e.is_io() {
189                    e.into_io().unwrap_or(std::io::Error::other("io error"))
190                } else {
191                    std::io::Error::other(e)
192                }
193            })?;
194
195            while let Some(request) = {
196                let res = {
197                    let accept_fut_orig = h2.accept();
198                    let accept_fut_orig_pin = std::pin::pin!(accept_fut_orig);
199                    let cancel_token = self.cancel_token.clone();
200                    let cancel_fut = async move {
201                        if let Some(token) = cancel_token {
202                            token.cancelled().await
203                        } else {
204                            futures_util::future::pending().await
205                        }
206                    };
207                    let cancel_fut_pin = std::pin::pin!(cancel_fut);
208                    let accept_fut =
209                        futures_util::future::select(cancel_fut_pin, accept_fut_orig_pin);
210
211                    match if let Some(timeout) = self.options.accept_timeout {
212                        vibeio::time::timeout(timeout, accept_fut).await
213                    } else {
214                        Ok(accept_fut.await)
215                    } {
216                        Ok(futures_util::future::Either::Right((request, _))) => {
217                            (Some(request), false)
218                        }
219                        Ok(futures_util::future::Either::Left((_, _))) => {
220                            // Canceled
221                            (None, true)
222                        }
223                        Err(_) => {
224                            // Timeout
225                            (None, false)
226                        }
227                    }
228                };
229                match res {
230                    (Some(request), _) => request,
231                    (None, graceful) => {
232                        h2.graceful_shutdown();
233                        let _ = h2.accept().await;
234                        if graceful {
235                            return Ok(());
236                        }
237                        return Err(std::io::Error::new(
238                            std::io::ErrorKind::TimedOut,
239                            "accept timeout",
240                        ));
241                    }
242                }
243            } {
244                let (request, mut stream) = match request {
245                    Ok(d) => d,
246                    Err(e) if e.is_go_away() => {
247                        continue;
248                    }
249                    Err(e) if e.is_io() => {
250                        return Err(e.into_io().unwrap_or(std::io::Error::other("io error")));
251                    }
252                    Err(e) => {
253                        return Err(std::io::Error::other(e));
254                    }
255                };
256
257                let date_cache = self.date_header_value_cached.clone();
258                let (request_parts, recv_stream) = request.into_parts();
259                let (request_body, upgrade) = if request_parts.method == http::Method::CONNECT {
260                    (Incoming::Empty, Some(recv_stream))
261                } else {
262                    (Incoming::H2(H2Body::new(recv_stream)), None)
263                };
264                let mut request = Request::from_parts(request_parts, request_body);
265
266                // 100 Continue
267                let is_100_continue = self.options.send_continue_response
268                    && request
269                        .headers()
270                        .get(http::header::EXPECT)
271                        .and_then(|v| v.to_str().ok())
272                        .is_some_and(|v| v.eq_ignore_ascii_case("100-continue"));
273
274                // Install early hints
275                let (early_hints_tx, early_hints_rx) = kanal::unbounded_async();
276                let early_hints = EarlyHints::new(early_hints_tx);
277                request.extensions_mut().insert(early_hints);
278
279                // Install HTTP upgrade
280                let upgrade = if let Some(recv_stream) = upgrade {
281                    let (upgrade_tx, upgrade_rx) = oneshot::async_channel();
282                    let upgrade = Upgrade::new(upgrade_rx);
283                    let upgraded = upgrade.upgraded.clone();
284                    request.extensions_mut().insert(upgrade);
285                    Some((upgrade_tx, upgraded, recv_stream))
286                } else {
287                    None
288                };
289
290                let response_fut = Box::new(request_fn(request));
291
292                vibeio::spawn(async move {
293                    if is_100_continue {
294                        let mut response = Response::new(());
295                        *response.status_mut() = http::StatusCode::CONTINUE;
296                        let _ = stream.send_informational(response).map_err(h2_error_to_io);
297                    }
298
299                    let mut response_fut = Box::into_pin(response_fut);
300                    let early_hints_rx = early_hints_rx;
301                    let response_result = loop {
302                        let early_hints_recv_fut = early_hints_rx.recv();
303                        let mut early_hints_recv_fut = std::pin::pin!(early_hints_recv_fut);
304                        let next = std::future::poll_fn(|cx| {
305                            match stream.poll_reset(cx) {
306                                Poll::Ready(Ok(reason)) => {
307                                    return Poll::Ready(Err(h2_reason_to_io(reason)));
308                                }
309                                Poll::Ready(Err(err)) => {
310                                    return Poll::Ready(Err(h2_error_to_io(err)));
311                                }
312                                Poll::Pending => {}
313                            }
314
315                            if let Poll::Ready(res) = response_fut.as_mut().poll(cx) {
316                                return Poll::Ready(Ok(futures_util::future::Either::Left(res)));
317                            }
318
319                            match early_hints_recv_fut.as_mut().poll(cx) {
320                                Poll::Ready(Ok(msg)) => {
321                                    Poll::Ready(Ok(futures_util::future::Either::Right(msg)))
322                                }
323                                Poll::Ready(Err(_)) => Poll::Pending,
324                                Poll::Pending => Poll::Pending,
325                            }
326                        })
327                        .await;
328
329                        match next {
330                            Ok(futures_util::future::Either::Left(response_result)) => {
331                                break response_result;
332                            }
333                            Ok(futures_util::future::Either::Right((headers, sender))) => {
334                                let mut response = Response::new(());
335                                *response.status_mut() = http::StatusCode::EARLY_HINTS;
336                                *response.headers_mut() = headers;
337                                sender
338                                    .into_inner()
339                                    .send(
340                                        stream.send_informational(response).map_err(h2_error_to_io),
341                                    )
342                                    .ok();
343                            }
344                            Err(_) => {
345                                return;
346                            }
347                        }
348                    };
349                    let Ok(mut response) = response_result else {
350                        // Return early if the request handler returns an error
351                        return;
352                    };
353
354                    {
355                        let response_headers = response.headers_mut();
356                        if self.options.send_date_header {
357                            if let Some(http_date) = date_cache.get_date_header_value() {
358                                response_headers
359                                    .entry(http::header::DATE)
360                                    .or_insert(http_date);
361                            }
362                        }
363                        for header in &HTTP2_INVALID_HEADERS {
364                            if let http::header::Entry::Occupied(entry) =
365                                response_headers.entry(header)
366                            {
367                                entry.remove();
368                            }
369                        }
370                        if response_headers
371                            .get(http::header::TE)
372                            .is_some_and(|v| v != "trailers")
373                        {
374                            response_headers.remove(http::header::TE);
375                        }
376                    }
377
378                    let response_is_end_stream = response.body().is_end_stream();
379                    if !response_is_end_stream {
380                        if let Some(content_length) = response.body().size_hint().exact() {
381                            if !response
382                                .headers()
383                                .contains_key(http::header::CONTENT_LENGTH)
384                            {
385                                response
386                                    .headers_mut()
387                                    .insert(http::header::CONTENT_LENGTH, content_length.into());
388                            }
389                        }
390                    }
391
392                    let (response_parts, mut response_body) = response.into_parts();
393                    let Ok(send) = stream.send_response(
394                        Response::from_parts(response_parts, ()),
395                        response_is_end_stream && upgrade.is_none(),
396                    ) else {
397                        return;
398                    };
399
400                    if let Some((upgrade_tx, upgraded, recv_stream)) = upgrade {
401                        if upgraded.load(std::sync::atomic::Ordering::Relaxed) {
402                            let (upgraded, task) = self::upgrade::pair(send, recv_stream);
403                            let _ = upgrade_tx.send(Upgraded::new(upgraded, None));
404                            task.await;
405                            return;
406                        }
407                    }
408
409                    if response_is_end_stream {
410                        return;
411                    }
412
413                    // No upgrade, send the body directly
414                    let _ = PipeToSendStream::new(send, &mut response_body).await;
415                });
416            }
417
418            Ok(())
419        }
420    }
421}