1mod date;
4mod options;
5mod send;
6mod upgrade;
7
8pub use options::*;
9use pin_project_lite::pin_project;
10use tokio_util::sync::CancellationToken;
11
12use std::{
13 future::Future,
14 pin::Pin,
15 sync::{atomic::AtomicBool, Arc},
16 task::{Context, Poll},
17};
18
19use bytes::Bytes;
20use http::{Request, Response};
21use http_body::{Body, Frame};
22
23use crate::{
24 early_hints::EarlyHintsReceiver,
25 h2::{
26 date::DateCache,
27 send::{PipeToSendStream, SendBuf},
28 },
29 EarlyHints, HttpProtocol, Incoming, Upgrade, Upgraded,
30};
31
32static HTTP2_INVALID_HEADERS: [http::header::HeaderName; 5] = [
33 http::header::HeaderName::from_static("keep-alive"),
34 http::header::HeaderName::from_static("proxy-connection"),
35 http::header::CONNECTION,
36 http::header::TRANSFER_ENCODING,
37 http::header::UPGRADE,
38];
39
40pub(crate) struct H2Body {
41 recv: h2::RecvStream,
42 data_done: bool,
43 send_continue_body: Option<Arc<AtomicBool>>,
44}
45
46impl H2Body {
47 #[inline]
48 fn new(recv: h2::RecvStream, send_continue_body: Option<Arc<AtomicBool>>) -> Self {
49 Self {
50 recv,
51 data_done: false,
52 send_continue_body,
53 }
54 }
55}
56
57impl Body for H2Body {
58 type Data = Bytes;
59 type Error = std::io::Error;
60
61 #[inline]
62 fn poll_frame(
63 mut self: Pin<&mut Self>,
64 cx: &mut Context<'_>,
65 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
66 if !self.data_done {
67 match self.recv.poll_data(cx) {
68 Poll::Ready(Some(Ok(data))) => {
69 let _ = self.recv.flow_control().release_capacity(data.len());
70 return Poll::Ready(Some(Ok(Frame::data(data))));
71 }
72 Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(h2_error_to_io(err)))),
73 Poll::Ready(None) => self.data_done = true,
74 Poll::Pending => {
75 if let Some(scb) = self.send_continue_body.as_ref() {
76 scb.store(true, std::sync::atomic::Ordering::Relaxed);
77 }
78 return Poll::Pending;
79 }
80 }
81 }
82
83 match self.recv.poll_trailers(cx) {
84 Poll::Ready(Ok(Some(trailers))) => Poll::Ready(Some(Ok(Frame::trailers(trailers)))),
85 Poll::Ready(Ok(None)) => Poll::Ready(None),
86 Poll::Ready(Err(err)) => Poll::Ready(Some(Err(h2_error_to_io(err)))),
87 Poll::Pending => {
88 if let Some(scb) = self.send_continue_body.as_ref() {
89 scb.store(true, std::sync::atomic::Ordering::Relaxed);
90 }
91 Poll::Pending
92 }
93 }
94 }
95}
96
97#[inline]
98pub(super) fn h2_error_to_io(error: h2::Error) -> std::io::Error {
99 if error.is_io() {
100 error.into_io().unwrap_or(std::io::Error::other("io error"))
101 } else {
102 std::io::Error::other(error)
103 }
104}
105
106#[inline]
107pub(super) fn h2_reason_to_io(reason: h2::Reason) -> std::io::Error {
108 std::io::Error::other(h2::Error::from(reason))
109}
110
111#[inline]
112fn sanitize_response<ResB>(
113 response: &mut Response<ResB>,
114 send_date_header: bool,
115 date_cache: &DateCache,
116) where
117 ResB: Body<Data = bytes::Bytes>,
118{
119 let response_headers = response.headers_mut();
120 if send_date_header {
121 if let Some(http_date) = date_cache.get_date_header_value() {
122 response_headers
123 .entry(http::header::DATE)
124 .or_insert(http_date);
125 }
126 }
127 for header in &HTTP2_INVALID_HEADERS {
128 if let http::header::Entry::Occupied(entry) = response_headers.entry(header) {
129 entry.remove();
130 }
131 }
132 if response_headers
133 .get(http::header::TE)
134 .is_some_and(|v| v != "trailers")
135 {
136 response_headers.remove(http::header::TE);
137 }
138}
139
140struct PendingUpgrade {
141 tx: oneshot::Sender<Upgraded>,
142 upgraded: std::sync::Arc<std::sync::atomic::AtomicBool>,
143 recv_stream: h2::RecvStream,
144}
145
// Per-request task. It owns the h2 response handle and tracks whether the request is
// still waiting on the service (`H2StreamState::Service`) or streaming the response
// body (`H2StreamState::Body`).
pin_project! {
    struct H2Stream<Fut, ResB>
    where
        Fut: Future,
        ResB: Body<Data = bytes::Bytes>,
    {
        // Handle used to send informational responses and the final response head,
        // and to watch for a peer reset (`poll_reset`).
        stream: h2::server::SendResponse<SendBuf<ResB::Data>>,
        // Current phase. It is pinned because it holds the service future or the body pipe.
        #[pin]
        state: H2StreamState<Fut, ResB>,
    }
}
157
// Phases of an `H2Stream`.
pin_project! {
    #[project = H2StreamStateProj]
    enum H2StreamState<Fut, ResB>
    where
        Fut: Future,
        ResB: Body<Data = bytes::Bytes>,
    {
        // Waiting for the service future to produce the response.
        Service {
            // The service future producing the response.
            #[pin]
            response_fut: Fut,
            // Early hints (103) queued by the service while it runs.
            early_hints_rx: EarlyHintsReceiver,
            // Source of the `Date` header value used by `sanitize_response`.
            date_cache: DateCache,
            // Whether a `Date` header should be added to the response.
            send_date_header: bool,
            // Present only for CONNECT requests, which may be upgraded.
            upgrade: Option<PendingUpgrade>,
            // True when `send_continue_response` is enabled and the request carried
            // `Expect: 100-continue`.
            send_continue: bool,
            // Cleared once the early-hints channel reports that it is closed.
            early_hints_open: bool,
            // Set by the request body when it is polled with no data ready.
            send_continue_body: Option<Arc<AtomicBool>>,
            // `100 Continue` was already sent, or was deliberately skipped.
            continue_sent: bool
        },
        // Forwarding the response body through `PipeToSendStream`.
        Body {
            #[pin]
            pipe: PipeToSendStream<ResB>,
        },
    }
}
183
184impl<Fut, ResB> H2Stream<Fut, ResB>
185where
186 Fut: Future,
187 ResB: Body<Data = bytes::Bytes>,
188{
189 #[allow(clippy::too_many_arguments)]
190 #[inline]
191 const fn new(
192 stream: h2::server::SendResponse<SendBuf<ResB::Data>>,
193 response_fut: Fut,
194 early_hints_rx: EarlyHintsReceiver,
195 date_cache: DateCache,
196 send_date_header: bool,
197 upgrade: Option<PendingUpgrade>,
198 send_continue: bool,
199 send_continue_body: Option<Arc<AtomicBool>>,
200 ) -> Self {
201 Self {
202 stream,
203 state: H2StreamState::Service {
204 response_fut,
205 early_hints_rx,
206 date_cache,
207 send_date_header,
208 upgrade,
209 send_continue,
210 early_hints_open: true,
211 send_continue_body,
212 continue_sent: false,
213 },
214 }
215 }
216}
217
218impl<Fut, ResB, ResBE, ResE> Future for H2Stream<Fut, ResB>
219where
220 Fut: Future<Output = Result<Response<ResB>, ResE>>,
221 ResB: Body<Data = bytes::Bytes, Error = ResBE>,
222 ResE: std::error::Error,
223 ResBE: std::error::Error,
224{
225 type Output = ();
226
227 #[inline]
228 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
229 let mut this = self.project();
230
231 loop {
232 match this.state.as_mut().project() {
233 H2StreamStateProj::Service {
234 response_fut,
235 early_hints_rx,
236 date_cache,
237 send_date_header,
238 upgrade,
239 send_continue,
240 early_hints_open,
241 send_continue_body,
242 continue_sent,
243 } => {
244 if let Poll::Ready(response_result) = response_fut.poll(cx) {
245 let Ok(mut response) = response_result else {
246 return Poll::Ready(());
247 };
248
249 sanitize_response(&mut response, *send_date_header, date_cache);
250
251 let response_is_end_stream = response.body().is_end_stream();
252 if !response_is_end_stream {
253 if let Some(content_length) = response.body().size_hint().exact() {
254 if !response
255 .headers()
256 .contains_key(http::header::CONTENT_LENGTH)
257 {
258 response.headers_mut().insert(
259 http::header::CONTENT_LENGTH,
260 content_length.into(),
261 );
262 }
263 }
264 }
265
266 if *send_continue && !*continue_sent {
267 if !response.status().is_client_error()
268 && !response.status().is_server_error()
269 {
270 let mut response = Response::new(());
271 *response.status_mut() = http::StatusCode::CONTINUE;
272 let _ = this
273 .stream
274 .send_informational(response)
275 .map_err(h2_error_to_io);
276 }
277 *continue_sent = true;
278 }
279
280 let (response_parts, response_body) = response.into_parts();
281 let Ok(send) = this.stream.send_response(
282 Response::from_parts(response_parts, ()),
283 response_is_end_stream && upgrade.is_none(),
284 ) else {
285 return Poll::Ready(());
286 };
287
288 if let Some(PendingUpgrade {
289 tx,
290 upgraded,
291 recv_stream,
292 }) = upgrade.take()
293 {
294 if upgraded.load(std::sync::atomic::Ordering::Relaxed) {
295 let (upgraded, task) = self::upgrade::pair(send, recv_stream);
296 let _ = tx.send(Upgraded::new(upgraded, None));
297 vibeio::spawn(task);
298 return Poll::Ready(());
299 }
300 }
301
302 if response_is_end_stream {
303 return Poll::Ready(());
304 }
305
306 this.state.set(H2StreamState::Body {
307 pipe: PipeToSendStream::new(send, response_body),
308 });
309 continue;
310 }
311
312 match this.stream.poll_reset(cx) {
313 Poll::Ready(Ok(_)) | Poll::Ready(Err(_)) => return Poll::Ready(()),
314 Poll::Pending => {}
315 }
316
317 if *send_continue
318 && !*continue_sent
319 && send_continue_body
320 .as_ref()
321 .is_some_and(|scb| scb.load(std::sync::atomic::Ordering::Relaxed))
322 {
323 let mut response = Response::new(());
324 *response.status_mut() = http::StatusCode::CONTINUE;
325 let _ = this
326 .stream
327 .send_informational(response)
328 .map_err(h2_error_to_io);
329 *continue_sent = true;
330 }
331
332 if *early_hints_open {
333 match early_hints_rx.poll_recv(cx) {
334 Poll::Ready(Some((headers, sender))) => {
335 let mut response = Response::new(());
336 *response.status_mut() = http::StatusCode::EARLY_HINTS;
337 *response.headers_mut() = headers;
338 sender
339 .into_inner()
340 .send(
341 this.stream
342 .send_informational(response)
343 .map_err(h2_error_to_io),
344 )
345 .ok();
346 continue;
347 }
348 Poll::Ready(None) => {
349 *early_hints_open = false;
350 continue;
351 }
352 Poll::Pending => {}
353 }
354 }
355
356 return Poll::Pending;
357 }
358 H2StreamStateProj::Body { pipe } => {
359 return pipe.poll(cx).map(|_| ());
360 }
361 }
362 }
363 }
364}
365
366pub struct Http2<Io> {
389 io_to_handshake: Option<Io>,
390 date_header_value_cached: DateCache,
391 options: Http2Options,
392 cancel_token: Option<CancellationToken>,
393}
394
395impl<Io> Http2<Io>
396where
397 Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
398{
399 #[inline]
411 pub fn new(io: Io, options: Http2Options) -> Self {
412 Self {
413 io_to_handshake: Some(io),
414 date_header_value_cached: DateCache::default(),
415 options,
416 cancel_token: None,
417 }
418 }
419
420 #[inline]
425 pub fn graceful_shutdown_token(mut self, token: CancellationToken) -> Self {
426 self.cancel_token = Some(token);
427 self
428 }
429}
430
431impl<Io> HttpProtocol for Http2<Io>
432where
433 Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
434{
435 #[allow(clippy::manual_async_fn)]
436 #[inline]
437 fn handle<F, Fut, ResB, ResBE, ResE>(
438 mut self,
439 request_fn: F,
440 ) -> impl std::future::Future<Output = Result<(), std::io::Error>>
441 where
442 F: Fn(Request<super::Incoming>) -> Fut + 'static,
443 Fut: std::future::Future<Output = Result<Response<ResB>, ResE>> + 'static,
444 ResB: http_body::Body<Data = bytes::Bytes, Error = ResBE> + Unpin + 'static,
445 ResE: std::error::Error,
446 ResBE: std::error::Error,
447 {
448 async move {
449 let handshake_fut = self.options.h2.handshake(
450 self.io_to_handshake
451 .take()
452 .ok_or_else(|| std::io::Error::other("no io to handshake"))?,
453 );
454 let mut h2 = (if let Some(timeout) = self.options.handshake_timeout {
455 vibeio::time::timeout(timeout, handshake_fut).await
456 } else {
457 Ok(handshake_fut.await)
458 })
459 .map_err(|_| std::io::Error::new(std::io::ErrorKind::TimedOut, "handshake timeout"))?
460 .map_err(|e| {
461 if e.is_io() {
462 e.into_io().unwrap_or(std::io::Error::other("io error"))
463 } else {
464 std::io::Error::other(e)
465 }
466 })?;
467
468 while let Some(request) = {
469 let res = {
470 let accept_fut_orig = h2.accept();
471 let accept_fut_orig_pin = std::pin::pin!(accept_fut_orig);
472 let cancel_token = self.cancel_token.clone();
473 let cancel_fut = async move {
474 if let Some(token) = cancel_token {
475 token.cancelled().await
476 } else {
477 futures_util::future::pending().await
478 }
479 };
480 let cancel_fut_pin = std::pin::pin!(cancel_fut);
481 let accept_fut =
482 futures_util::future::select(cancel_fut_pin, accept_fut_orig_pin);
483
484 match if let Some(timeout) = self.options.accept_timeout {
485 vibeio::time::timeout(timeout, accept_fut).await
486 } else {
487 Ok(accept_fut.await)
488 } {
489 Ok(futures_util::future::Either::Right((request, _))) => {
490 (Some(request), false)
491 }
492 Ok(futures_util::future::Either::Left((_, _))) => {
493 (None, true)
495 }
496 Err(_) => {
497 (None, false)
499 }
500 }
501 };
502 match res {
503 (Some(request), _) => request,
504 (None, graceful) => {
505 h2.graceful_shutdown();
506 let _ = h2.accept().await;
507 if graceful {
508 return Ok(());
509 }
510 return Err(std::io::Error::new(
511 std::io::ErrorKind::TimedOut,
512 "accept timeout",
513 ));
514 }
515 }
516 } {
517 let (request, stream) = match request {
518 Ok(d) => d,
519 Err(e) if e.is_go_away() => {
520 continue;
521 }
522 Err(e) if e.is_io() => {
523 return Err(e.into_io().unwrap_or(std::io::Error::other("io error")));
524 }
525 Err(e) => {
526 return Err(std::io::Error::other(e));
527 }
528 };
529
530 let is_100_continue = self.options.send_continue_response
532 && request
533 .headers()
534 .get(http::header::EXPECT)
535 .and_then(|v| v.to_str().ok())
536 .is_some_and(|v| v.eq_ignore_ascii_case("100-continue"));
537
538 let date_cache = self.date_header_value_cached.clone();
539 let send_continue_body = is_100_continue.then(|| Arc::new(AtomicBool::new(false)));
540 let (request_parts, recv_stream) = request.into_parts();
541 let (request_body, upgrade) = if request_parts.method == http::Method::CONNECT {
542 (Incoming::Empty, Some(recv_stream))
543 } else {
544 (
545 Incoming::H2(H2Body::new(recv_stream, send_continue_body.clone())),
546 None,
547 )
548 };
549 let mut request = Request::from_parts(request_parts, request_body);
550
551 let (early_hints, early_hints_rx) = EarlyHints::new_lazy();
553 request.extensions_mut().insert(early_hints);
554
555 let upgrade = if let Some(recv_stream) = upgrade {
557 let (upgrade_tx, upgrade_rx) = oneshot::async_channel();
558 let upgrade = Upgrade::new(upgrade_rx);
559 let upgraded = upgrade.upgraded.clone();
560 request.extensions_mut().insert(upgrade);
561 Some(PendingUpgrade {
562 tx: upgrade_tx,
563 upgraded,
564 recv_stream,
565 })
566 } else {
567 None
568 };
569
570 vibeio::spawn(H2Stream::new(
571 stream,
572 request_fn(request),
573 early_hints_rx,
574 date_cache,
575 self.options.send_date_header,
576 upgrade,
577 is_100_continue,
578 send_continue_body,
579 ));
580 }
581
582 Ok(())
583 }
584 }
585}