1mod date;
2mod options;
3
4pub use options::*;
5use tokio_util::sync::CancellationToken;
6
7use std::{
8 future::Future,
9 pin::Pin,
10 rc::Rc,
11 task::{Context, Poll},
12};
13
14use bytes::Bytes;
15use http::{Request, Response};
16use http_body::{Body, Frame};
17use http_body_util::BodyExt;
18
19use crate::{h2::date::DateCache, EarlyHints, HttpProtocol, Incoming};
20
21static HTTP2_INVALID_HEADERS: [http::header::HeaderName; 5] = [
22 http::header::HeaderName::from_static("keep-alive"),
23 http::header::HeaderName::from_static("proxy-connection"),
24 http::header::TRANSFER_ENCODING,
25 http::header::TE,
26 http::header::UPGRADE,
27];
28
29struct H2Body {
30 recv: h2::RecvStream,
31 data_done: bool,
32}
33
34impl H2Body {
35 #[inline]
36 fn new(recv: h2::RecvStream) -> Self {
37 Self {
38 recv,
39 data_done: false,
40 }
41 }
42}
43
44impl Body for H2Body {
45 type Data = Bytes;
46 type Error = std::io::Error;
47
48 #[inline]
49 fn poll_frame(
50 mut self: Pin<&mut Self>,
51 cx: &mut Context<'_>,
52 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
53 if !self.data_done {
54 match self.recv.poll_data(cx) {
55 Poll::Ready(Some(Ok(data))) => {
56 let _ = self.recv.flow_control().release_capacity(data.len());
57 return Poll::Ready(Some(Ok(Frame::data(data))));
58 }
59 Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(h2_error_to_io(err)))),
60 Poll::Ready(None) => self.data_done = true,
61 Poll::Pending => return Poll::Pending,
62 }
63 }
64
65 match self.recv.poll_trailers(cx) {
66 Poll::Ready(Ok(Some(trailers))) => Poll::Ready(Some(Ok(Frame::trailers(trailers)))),
67 Poll::Ready(Ok(None)) => Poll::Ready(None),
68 Poll::Ready(Err(err)) => Poll::Ready(Some(Err(h2_error_to_io(err)))),
69 Poll::Pending => Poll::Pending,
70 }
71 }
72}
73
74#[inline]
75fn h2_error_to_io(error: h2::Error) -> std::io::Error {
76 if error.is_io() {
77 error.into_io().unwrap_or(std::io::Error::other("io error"))
78 } else {
79 std::io::Error::other(error)
80 }
81}
82
83#[inline]
84fn h2_reason_to_io(reason: h2::Reason) -> std::io::Error {
85 std::io::Error::other(h2::Error::from(reason))
86}
87
88#[inline]
89async fn wait_for_send_capacity(
90 send: &mut h2::SendStream<Bytes>,
91 desired_capacity: usize,
92) -> Result<usize, std::io::Error> {
93 if desired_capacity == 0 {
94 return Ok(0);
95 }
96
97 send.reserve_capacity(desired_capacity);
98
99 if send.capacity() > 0 {
100 return Ok(send.capacity().min(desired_capacity));
101 }
102
103 std::future::poll_fn(|cx| loop {
104 match send.poll_reset(cx) {
105 Poll::Ready(Ok(reason)) => return Poll::Ready(Err(h2_reason_to_io(reason))),
106 Poll::Ready(Err(err)) => return Poll::Ready(Err(h2_error_to_io(err))),
107 Poll::Pending => {}
108 }
109
110 match send.poll_capacity(cx) {
111 Poll::Ready(Some(Ok(0))) => {}
112 Poll::Ready(Some(Ok(capacity))) => {
113 return Poll::Ready(Ok(capacity.min(desired_capacity)));
114 }
115 Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(h2_error_to_io(err))),
116 Poll::Ready(None) => {
117 return Poll::Ready(Err(std::io::Error::other(
118 "send stream capacity unexpectedly closed",
119 )))
120 }
121 Poll::Pending => return Poll::Pending,
122 }
123 })
124 .await
125}
126
127pub struct Http2<Io> {
150 io_to_handshake: Option<Io>,
151 date_header_value_cached: DateCache,
152 options: Http2Options,
153 cancel_token: Option<CancellationToken>,
154}
155
156impl<Io> Http2<Io>
157where
158 Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
159{
160 #[inline]
172 pub fn new(io: Io, options: Http2Options) -> Self {
173 Self {
174 io_to_handshake: Some(io),
175 date_header_value_cached: DateCache::default(),
176 options,
177 cancel_token: None,
178 }
179 }
180
181 #[inline]
186 pub fn graceful_shutdown_token(mut self, token: CancellationToken) -> Self {
187 self.cancel_token = Some(token);
188 self
189 }
190}
191
192impl<Io> HttpProtocol for Http2<Io>
193where
194 Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static,
195{
196 #[allow(clippy::manual_async_fn)]
197 #[inline]
198 fn handle<F, Fut, ResB, ResBE, ResE>(
199 mut self,
200 request_fn: F,
201 ) -> impl std::future::Future<Output = Result<(), std::io::Error>>
202 where
203 F: Fn(Request<super::Incoming>) -> Fut + 'static,
204 Fut: std::future::Future<Output = Result<Response<ResB>, ResE>>,
205 ResB: http_body::Body<Data = bytes::Bytes, Error = ResBE> + Unpin,
206 ResE: std::error::Error,
207 ResBE: std::error::Error,
208 {
209 async move {
210 let request_fn = Rc::new(request_fn);
211 let handshake_fut = self.options.h2.handshake(
212 self.io_to_handshake
213 .take()
214 .ok_or_else(|| std::io::Error::other("no io to handshake"))?,
215 );
216 let mut h2 = (if let Some(timeout) = self.options.handshake_timeout {
217 vibeio::time::timeout(timeout, handshake_fut).await
218 } else {
219 Ok(handshake_fut.await)
220 })
221 .map_err(|_| std::io::Error::new(std::io::ErrorKind::TimedOut, "handshake timeout"))?
222 .map_err(|e| {
223 if e.is_io() {
224 e.into_io().unwrap_or(std::io::Error::other("io error"))
225 } else {
226 std::io::Error::other(e)
227 }
228 })?;
229
230 while let Some(request) = {
231 let res = {
232 let accept_fut_orig = h2.accept();
233 let accept_fut_orig_pin = std::pin::pin!(accept_fut_orig);
234 let cancel_token = self.cancel_token.clone();
235 let cancel_fut = async move {
236 if let Some(token) = cancel_token {
237 token.cancelled().await
238 } else {
239 futures_util::future::pending().await
240 }
241 };
242 let cancel_fut_pin = std::pin::pin!(cancel_fut);
243 let accept_fut =
244 futures_util::future::select(cancel_fut_pin, accept_fut_orig_pin);
245
246 match if let Some(timeout) = self.options.accept_timeout {
247 vibeio::time::timeout(timeout, accept_fut).await
248 } else {
249 Ok(accept_fut.await)
250 } {
251 Ok(futures_util::future::Either::Right((request, _))) => {
252 (Some(request), false)
253 }
254 Ok(futures_util::future::Either::Left((_, _))) => {
255 (None, true)
257 }
258 Err(_) => {
259 (None, false)
261 }
262 }
263 };
264 match res {
265 (Some(request), _) => request,
266 (None, graceful) => {
267 h2.graceful_shutdown();
268 let _ = h2.accept().await;
269 if graceful {
270 return Ok(());
271 }
272 return Err(std::io::Error::new(
273 std::io::ErrorKind::TimedOut,
274 "accept timeout",
275 ));
276 }
277 }
278 } {
279 let (request, mut stream) = match request {
280 Ok(d) => d,
281 Err(e) if e.is_go_away() => {
282 continue;
283 }
284 Err(e) if e.is_io() => {
285 return Err(e.into_io().unwrap_or(std::io::Error::other("io error")));
286 }
287 Err(e) => {
288 return Err(std::io::Error::other(e));
289 }
290 };
291
292 let date_cache = self.date_header_value_cached.clone();
293 let request_fn = request_fn.clone();
294 let send_continue_response = self.options.send_continue_response;
295 vibeio::spawn(async move {
296 let (request_parts, recv_stream) = request.into_parts();
297 let request_body = Incoming::new(H2Body::new(recv_stream));
298 let mut request = Request::from_parts(request_parts, request_body);
299
300 if send_continue_response {
302 let is_100_continue = request
303 .headers()
304 .get(http::header::EXPECT)
305 .and_then(|v| v.to_str().ok())
306 .is_some_and(|v| v.eq_ignore_ascii_case("100-continue"));
307 if is_100_continue {
308 let mut response = Response::new(());
309 *response.status_mut() = http::StatusCode::CONTINUE;
310 let _ = stream.send_informational(response).map_err(h2_error_to_io);
311 }
312 }
313
314 let (early_hints_tx, early_hints_rx) = async_channel::unbounded();
315 let early_hints = EarlyHints::new(early_hints_tx);
316 request.extensions_mut().insert(early_hints);
317
318 let mut response_fut = std::pin::pin!(request_fn(request));
319 let early_hints_rx = early_hints_rx;
320 let response_result = loop {
321 let early_hints_recv_fut = early_hints_rx.recv();
322 let mut early_hints_recv_fut = std::pin::pin!(early_hints_recv_fut);
323 let next = std::future::poll_fn(|cx| {
324 match stream.poll_reset(cx) {
325 Poll::Ready(Ok(reason)) => {
326 return Poll::Ready(Err(h2_reason_to_io(reason)));
327 }
328 Poll::Ready(Err(err)) => {
329 return Poll::Ready(Err(h2_error_to_io(err)));
330 }
331 Poll::Pending => {}
332 }
333
334 if let Poll::Ready(res) = response_fut.as_mut().poll(cx) {
335 return Poll::Ready(Ok(futures_util::future::Either::Left(res)));
336 }
337
338 match early_hints_recv_fut.as_mut().poll(cx) {
339 Poll::Ready(Ok(msg)) => {
340 Poll::Ready(Ok(futures_util::future::Either::Right(msg)))
341 }
342 Poll::Ready(Err(_)) => Poll::Pending,
343 Poll::Pending => Poll::Pending,
344 }
345 })
346 .await;
347
348 match next {
349 Ok(futures_util::future::Either::Left(response_result)) => {
350 break response_result;
351 }
352 Ok(futures_util::future::Either::Right((headers, sender))) => {
353 let mut response = Response::new(());
354 *response.status_mut() = http::StatusCode::EARLY_HINTS;
355 *response.headers_mut() = headers;
356 sender
357 .into_inner()
358 .send(
359 stream.send_informational(response).map_err(h2_error_to_io),
360 )
361 .ok();
362 }
363 Err(_) => {
364 return;
365 }
366 }
367 };
368 let Ok(mut response) = response_result else {
369 return;
371 };
372
373 {
374 let response_headers = response.headers_mut();
375 if let Some(http_date) = date_cache.get_date_header_value() {
376 response_headers
377 .entry(http::header::DATE)
378 .or_insert(http_date);
379 }
380 if let Some(connection_header) = response_headers
381 .remove(http::header::CONNECTION)
382 .as_ref()
383 .and_then(|v| v.to_str().ok())
384 {
385 for name in connection_header.split(',') {
386 response_headers.remove(name.trim());
387 }
388 }
389 while response_headers.remove(http::header::CONNECTION).is_some() {}
390 for header in &HTTP2_INVALID_HEADERS {
391 while response_headers.remove(header).is_some() {}
392 }
393 }
394
395 let response_is_end_stream = response.body().is_end_stream();
396 if !response_is_end_stream {
397 if let Some(content_length) = response.body().size_hint().exact() {
398 if !response
399 .headers()
400 .contains_key(http::header::CONTENT_LENGTH)
401 {
402 response
403 .headers_mut()
404 .insert(http::header::CONTENT_LENGTH, content_length.into());
405 }
406 }
407 }
408
409 let (response_parts, mut response_body) = response.into_parts();
410 let mut send = match stream.send_response(
411 Response::from_parts(response_parts, ()),
412 response_is_end_stream,
413 ) {
414 Ok(send) => send,
415 Err(_) => {
416 return;
417 }
418 };
419
420 if response_is_end_stream {
421 return;
422 }
423
424 while let Some(chunk) = {
425 let frame_fut = response_body.frame();
426 let mut frame_fut = std::pin::pin!(frame_fut);
427 match std::future::poll_fn(|cx| {
428 match send.poll_reset(cx) {
429 Poll::Ready(Ok(reason)) => {
430 return Poll::Ready(Err(h2_reason_to_io(reason)));
431 }
432 Poll::Ready(Err(err)) => {
433 return Poll::Ready(Err(h2_error_to_io(err)));
434 }
435 Poll::Pending => {}
436 }
437
438 match frame_fut.as_mut().poll(cx) {
439 Poll::Ready(frame) => Poll::Ready(Ok(frame)),
440 Poll::Pending => Poll::Pending,
441 }
442 })
443 .await
444 {
445 Ok(frame) => frame,
446 Err(_) => {
447 return;
448 }
449 }
450 } {
451 match chunk {
452 Ok(frame) => {
453 if frame.is_data() {
454 match frame.into_data() {
455 Ok(mut data) => {
456 let response_is_end_stream =
457 response_body.is_end_stream();
458 if data.is_empty() {
459 if send
460 .send_data(data, response_is_end_stream)
461 .is_err()
462 {
463 return;
464 }
465 if response_is_end_stream {
466 return;
467 }
468 continue;
469 }
470
471 while !data.is_empty() {
472 let capacity = match wait_for_send_capacity(
473 &mut send,
474 data.len(),
475 )
476 .await
477 {
478 Ok(capacity) => capacity,
479 Err(_) => return,
480 };
481 let chunk = data.split_to(capacity.min(data.len()));
482 let is_end_stream =
483 response_is_end_stream && data.is_empty();
484 if send.send_data(chunk, is_end_stream).is_err() {
485 return;
486 }
487 if is_end_stream {
488 return;
489 }
490 }
491 }
492 Err(_) => {
493 return;
494 }
495 }
496 } else if frame.is_trailers() {
497 match frame.into_trailers() {
498 Ok(trailers) => {
499 if send.send_trailers(trailers).is_err() {
500 return;
501 }
502 return;
503 }
504 Err(_) => {
505 return;
506 }
507 }
508 }
509 }
510 Err(_) => {
511 return;
512 }
513 }
514 }
515 let _ = send.send_data(Bytes::new(), true);
516 });
517 }
518
519 Ok(())
520 }
521 }
522}