1mod date;
4mod options;
5mod send;
6mod upgrade;
7
8pub use options::*;
9use tokio_util::sync::CancellationToken;
10
11use std::{
12 future::Future,
13 pin::Pin,
14 task::{Context, Poll},
15};
16
17use bytes::Bytes;
18use http::{Request, Response};
19use http_body::{Body, Frame};
20
21use crate::{
22 h2::{date::DateCache, send::PipeToSendStream},
23 EarlyHints, HttpProtocol, Incoming, Upgrade, Upgraded,
24};
25
26static HTTP2_INVALID_HEADERS: [http::header::HeaderName; 5] = [
27 http::header::HeaderName::from_static("keep-alive"),
28 http::header::HeaderName::from_static("proxy-connection"),
29 http::header::CONNECTION,
30 http::header::TRANSFER_ENCODING,
31 http::header::UPGRADE,
32];
33
34pub(crate) struct H2Body {
35 recv: h2::RecvStream,
36 data_done: bool,
37}
38
39impl H2Body {
40 #[inline]
41 fn new(recv: h2::RecvStream) -> Self {
42 Self {
43 recv,
44 data_done: false,
45 }
46 }
47}
48
49impl Body for H2Body {
50 type Data = Bytes;
51 type Error = std::io::Error;
52
53 #[inline]
54 fn poll_frame(
55 mut self: Pin<&mut Self>,
56 cx: &mut Context<'_>,
57 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
58 if !self.data_done {
59 match self.recv.poll_data(cx) {
60 Poll::Ready(Some(Ok(data))) => {
61 let _ = self.recv.flow_control().release_capacity(data.len());
62 return Poll::Ready(Some(Ok(Frame::data(data))));
63 }
64 Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(h2_error_to_io(err)))),
65 Poll::Ready(None) => self.data_done = true,
66 Poll::Pending => return Poll::Pending,
67 }
68 }
69
70 match self.recv.poll_trailers(cx) {
71 Poll::Ready(Ok(Some(trailers))) => Poll::Ready(Some(Ok(Frame::trailers(trailers)))),
72 Poll::Ready(Ok(None)) => Poll::Ready(None),
73 Poll::Ready(Err(err)) => Poll::Ready(Some(Err(h2_error_to_io(err)))),
74 Poll::Pending => Poll::Pending,
75 }
76 }
77}
78
79#[inline]
80pub(super) fn h2_error_to_io(error: h2::Error) -> std::io::Error {
81 if error.is_io() {
82 error.into_io().unwrap_or(std::io::Error::other("io error"))
83 } else {
84 std::io::Error::other(error)
85 }
86}
87
88#[inline]
89pub(super) fn h2_reason_to_io(reason: h2::Reason) -> std::io::Error {
90 std::io::Error::other(h2::Error::from(reason))
91}
92
93pub struct Http2<Io> {
116 io_to_handshake: Option<Io>,
117 date_header_value_cached: DateCache,
118 options: Http2Options,
119 cancel_token: Option<CancellationToken>,
120}
121
122impl<Io> Http2<Io>
123where
124 Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
125{
126 #[inline]
138 pub fn new(io: Io, options: Http2Options) -> Self {
139 Self {
140 io_to_handshake: Some(io),
141 date_header_value_cached: DateCache::default(),
142 options,
143 cancel_token: None,
144 }
145 }
146
147 #[inline]
152 pub fn graceful_shutdown_token(mut self, token: CancellationToken) -> Self {
153 self.cancel_token = Some(token);
154 self
155 }
156}
157
158impl<Io> HttpProtocol for Http2<Io>
159where
160 Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
161{
162 #[allow(clippy::manual_async_fn)]
163 #[inline]
164 fn handle<F, Fut, ResB, ResBE, ResE>(
165 mut self,
166 request_fn: F,
167 ) -> impl std::future::Future<Output = Result<(), std::io::Error>>
168 where
169 F: Fn(Request<super::Incoming>) -> Fut + 'static,
170 Fut: std::future::Future<Output = Result<Response<ResB>, ResE>> + 'static,
171 ResB: http_body::Body<Data = bytes::Bytes, Error = ResBE> + Unpin,
172 ResE: std::error::Error,
173 ResBE: std::error::Error,
174 {
175 async move {
176 let handshake_fut = self.options.h2.handshake(
177 self.io_to_handshake
178 .take()
179 .ok_or_else(|| std::io::Error::other("no io to handshake"))?,
180 );
181 let mut h2 = (if let Some(timeout) = self.options.handshake_timeout {
182 vibeio::time::timeout(timeout, handshake_fut).await
183 } else {
184 Ok(handshake_fut.await)
185 })
186 .map_err(|_| std::io::Error::new(std::io::ErrorKind::TimedOut, "handshake timeout"))?
187 .map_err(|e| {
188 if e.is_io() {
189 e.into_io().unwrap_or(std::io::Error::other("io error"))
190 } else {
191 std::io::Error::other(e)
192 }
193 })?;
194
195 while let Some(request) = {
196 let res = {
197 let accept_fut_orig = h2.accept();
198 let accept_fut_orig_pin = std::pin::pin!(accept_fut_orig);
199 let cancel_token = self.cancel_token.clone();
200 let cancel_fut = async move {
201 if let Some(token) = cancel_token {
202 token.cancelled().await
203 } else {
204 futures_util::future::pending().await
205 }
206 };
207 let cancel_fut_pin = std::pin::pin!(cancel_fut);
208 let accept_fut =
209 futures_util::future::select(cancel_fut_pin, accept_fut_orig_pin);
210
211 match if let Some(timeout) = self.options.accept_timeout {
212 vibeio::time::timeout(timeout, accept_fut).await
213 } else {
214 Ok(accept_fut.await)
215 } {
216 Ok(futures_util::future::Either::Right((request, _))) => {
217 (Some(request), false)
218 }
219 Ok(futures_util::future::Either::Left((_, _))) => {
220 (None, true)
222 }
223 Err(_) => {
224 (None, false)
226 }
227 }
228 };
229 match res {
230 (Some(request), _) => request,
231 (None, graceful) => {
232 h2.graceful_shutdown();
233 let _ = h2.accept().await;
234 if graceful {
235 return Ok(());
236 }
237 return Err(std::io::Error::new(
238 std::io::ErrorKind::TimedOut,
239 "accept timeout",
240 ));
241 }
242 }
243 } {
244 let (request, mut stream) = match request {
245 Ok(d) => d,
246 Err(e) if e.is_go_away() => {
247 continue;
248 }
249 Err(e) if e.is_io() => {
250 return Err(e.into_io().unwrap_or(std::io::Error::other("io error")));
251 }
252 Err(e) => {
253 return Err(std::io::Error::other(e));
254 }
255 };
256
257 let date_cache = self.date_header_value_cached.clone();
258 let (request_parts, recv_stream) = request.into_parts();
259 let (request_body, upgrade) = if request_parts.method == http::Method::CONNECT {
260 (Incoming::Empty, Some(recv_stream))
261 } else {
262 (Incoming::H2(H2Body::new(recv_stream)), None)
263 };
264 let mut request = Request::from_parts(request_parts, request_body);
265
266 let is_100_continue = self.options.send_continue_response
268 && request
269 .headers()
270 .get(http::header::EXPECT)
271 .and_then(|v| v.to_str().ok())
272 .is_some_and(|v| v.eq_ignore_ascii_case("100-continue"));
273
274 let (early_hints_tx, early_hints_rx) = kanal::unbounded_async();
276 let early_hints = EarlyHints::new(early_hints_tx);
277 request.extensions_mut().insert(early_hints);
278
279 let upgrade = if let Some(recv_stream) = upgrade {
281 let (upgrade_tx, upgrade_rx) = oneshot::async_channel();
282 let upgrade = Upgrade::new(upgrade_rx);
283 let upgraded = upgrade.upgraded.clone();
284 request.extensions_mut().insert(upgrade);
285 Some((upgrade_tx, upgraded, recv_stream))
286 } else {
287 None
288 };
289
290 let response_fut = Box::new(request_fn(request));
291
292 vibeio::spawn(async move {
293 if is_100_continue {
294 let mut response = Response::new(());
295 *response.status_mut() = http::StatusCode::CONTINUE;
296 let _ = stream.send_informational(response).map_err(h2_error_to_io);
297 }
298
299 let mut response_fut = Box::into_pin(response_fut);
300 let early_hints_rx = early_hints_rx;
301 let response_result = loop {
302 let early_hints_recv_fut = early_hints_rx.recv();
303 let mut early_hints_recv_fut = std::pin::pin!(early_hints_recv_fut);
304 let next = std::future::poll_fn(|cx| {
305 match stream.poll_reset(cx) {
306 Poll::Ready(Ok(reason)) => {
307 return Poll::Ready(Err(h2_reason_to_io(reason)));
308 }
309 Poll::Ready(Err(err)) => {
310 return Poll::Ready(Err(h2_error_to_io(err)));
311 }
312 Poll::Pending => {}
313 }
314
315 if let Poll::Ready(res) = response_fut.as_mut().poll(cx) {
316 return Poll::Ready(Ok(futures_util::future::Either::Left(res)));
317 }
318
319 match early_hints_recv_fut.as_mut().poll(cx) {
320 Poll::Ready(Ok(msg)) => {
321 Poll::Ready(Ok(futures_util::future::Either::Right(msg)))
322 }
323 Poll::Ready(Err(_)) => Poll::Pending,
324 Poll::Pending => Poll::Pending,
325 }
326 })
327 .await;
328
329 match next {
330 Ok(futures_util::future::Either::Left(response_result)) => {
331 break response_result;
332 }
333 Ok(futures_util::future::Either::Right((headers, sender))) => {
334 let mut response = Response::new(());
335 *response.status_mut() = http::StatusCode::EARLY_HINTS;
336 *response.headers_mut() = headers;
337 sender
338 .into_inner()
339 .send(
340 stream.send_informational(response).map_err(h2_error_to_io),
341 )
342 .ok();
343 }
344 Err(_) => {
345 return;
346 }
347 }
348 };
349 let Ok(mut response) = response_result else {
350 return;
352 };
353
354 {
355 let response_headers = response.headers_mut();
356 if self.options.send_date_header {
357 if let Some(http_date) = date_cache.get_date_header_value() {
358 response_headers
359 .entry(http::header::DATE)
360 .or_insert(http_date);
361 }
362 }
363 for header in &HTTP2_INVALID_HEADERS {
364 if let http::header::Entry::Occupied(entry) =
365 response_headers.entry(header)
366 {
367 entry.remove();
368 }
369 }
370 if response_headers
371 .get(http::header::TE)
372 .is_some_and(|v| v != "trailers")
373 {
374 response_headers.remove(http::header::TE);
375 }
376 }
377
378 let response_is_end_stream = response.body().is_end_stream();
379 if !response_is_end_stream {
380 if let Some(content_length) = response.body().size_hint().exact() {
381 if !response
382 .headers()
383 .contains_key(http::header::CONTENT_LENGTH)
384 {
385 response
386 .headers_mut()
387 .insert(http::header::CONTENT_LENGTH, content_length.into());
388 }
389 }
390 }
391
392 let (response_parts, mut response_body) = response.into_parts();
393 let Ok(send) = stream.send_response(
394 Response::from_parts(response_parts, ()),
395 response_is_end_stream && upgrade.is_none(),
396 ) else {
397 return;
398 };
399
400 if let Some((upgrade_tx, upgraded, recv_stream)) = upgrade {
401 if upgraded.load(std::sync::atomic::Ordering::Relaxed) {
402 let (upgraded, task) = self::upgrade::pair(send, recv_stream);
403 let _ = upgrade_tx.send(Upgraded::new(upgraded, None));
404 task.await;
405 return;
406 }
407 }
408
409 if response_is_end_stream {
410 return;
411 }
412
413 let _ = PipeToSendStream::new(send, &mut response_body).await;
415 });
416 }
417
418 Ok(())
419 }
420 }
421}