use core::{
cmp, fmt,
future::{poll_fn, Future},
marker::PhantomData,
pin::{pin, Pin},
task::{ready, Context, Poll},
time::Duration,
};
use std::net::SocketAddr;
use ::h2::{
server::{Connection, SendResponse},
Ping, PingPong,
};
use futures_core::stream::Stream;
use tracing::trace;
use xitca_io::io::{AsyncRead, AsyncWrite};
use xitca_service::Service;
use xitca_unsafe_collection::futures::{Select as _, SelectOutput};
use crate::{
body::BodySize,
bytes::Bytes,
date::{DateTime, DateTimeHandle},
error::HttpServiceError,
h2::{body::RequestBody, error::Error},
http::{
header::{HeaderMap, HeaderName, HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRAILER},
Extension, Request, RequestExt, Response, Version,
},
util::{futures::Queue, timer::KeepAlive},
};
/// Per-connection HTTP/2 dispatcher: accepts request streams from an h2 server connection,
/// hands each one to `service` (via an internal queue) and writes the responses back.
pub(crate) struct Dispatcher<'a, TlsSt, S, ReqB> {
    /// The h2 server connection request streams are accepted from.
    io: &'a mut Connection<TlsSt, Bytes>,
    /// Peer address; attached to every request as an `Extension`.
    addr: SocketAddr,
    /// Connection keep-alive timer; also drives the PING/PONG liveness check.
    keep_alive: Pin<&'a mut KeepAlive>,
    /// Base keep-alive duration used to compute timer deadlines.
    ka_dur: Duration,
    /// Service that produces a response for every accepted request.
    service: &'a S,
    /// Date/time source used for `date` response headers and keep-alive deadlines.
    date: &'a DateTimeHandle,
    /// Marks the request body type `ReqB` without storing a value of it.
    _req_body: PhantomData<ReqB>,
}
impl<'a, TlsSt, S, ReqB, ResB, BE> Dispatcher<'a, TlsSt, S, ReqB>
where
S: Service<Request<RequestExt<ReqB>>, Response = Response<ResB>>,
S::Error: fmt::Debug,
ResB: Stream<Item = Result<Bytes, BE>>,
BE: fmt::Debug,
TlsSt: AsyncRead + AsyncWrite + Unpin,
ReqB: From<RequestBody>,
{
pub(crate) fn new(
io: &'a mut Connection<TlsSt, Bytes>,
addr: SocketAddr,
keep_alive: Pin<&'a mut KeepAlive>,
ka_dur: Duration,
service: &'a S,
date: &'a DateTimeHandle,
) -> Self {
Self {
io,
addr,
keep_alive,
ka_dur,
service,
date,
_req_body: PhantomData,
}
}
pub(crate) async fn run(self) -> Result<(), Error<S::Error, BE>> {
let Self {
io,
addr,
mut keep_alive,
ka_dur,
service,
date,
..
} = self;
let ping_pong = io.ping_pong().expect("first call to ping_pong should never fail");
let deadline = date.now() + ka_dur;
keep_alive.as_mut().update(deadline);
let mut ping_pong = H2PingPong {
on_flight: false,
keep_alive: keep_alive.as_mut(),
ping_pong,
date,
ka_dur,
};
let mut queue = Queue::new();
loop {
match io.accept().select(try_poll_queue(&mut queue, &mut ping_pong)).await {
SelectOutput::A(Some(Ok((req, tx)))) => {
let req = req.map(|body| {
let body = ReqB::from(RequestBody::from(body));
RequestExt::from_parts(body, Extension::new(addr))
});
queue.push(async move {
let fut = service.call(req);
h2_handler(fut, tx, date).await
});
}
SelectOutput::B(SelectOutput::A(_)) => io.graceful_shutdown(),
SelectOutput::B(SelectOutput::B(Ok(_))) => {
trace!("Connection keep-alive timeout. Shutting down");
return Ok(());
}
SelectOutput::A(None) => {
trace!("Connection closed by remote. Shutting down");
break;
}
SelectOutput::A(Some(Err(e))) | SelectOutput::B(SelectOutput::B(Err(e))) => return Err(From::from(e)),
}
}
queue.drain().await;
Ok(())
}
}
/// Wait on in-flight request handlers, falling back to the keep-alive future once idle.
///
/// While `queue` holds handlers they are completed one at a time:
/// - a handler returning [`ConnectionState::Close`] yields `SelectOutput::A(())`;
/// - a handler error is logged and skipped.
///
/// When the queue is (or becomes) empty, `ping_pong` is awaited and its result is yielded
/// as `SelectOutput::B`.
async fn try_poll_queue<F, E, S, B>(
    queue: &mut Queue<F>,
    ping_pong: &mut H2PingPong<'_>,
) -> SelectOutput<(), Result<(), ::h2::Error>>
where
    F: Future<Output = Result<ConnectionState, E>>,
    HttpServiceError<S, B>: From<E>,
    S: fmt::Debug,
    B: fmt::Debug,
{
    while !queue.is_empty() {
        match queue.next2().await {
            Ok(ConnectionState::Close) => return SelectOutput::A(()),
            Ok(ConnectionState::KeepAlive) => {}
            Err(err) => HttpServiceError::from(err).log("h2_dispatcher"),
        }
    }

    // Nothing in flight: only now does idle time count against keep-alive.
    SelectOutput::B(ping_pong.await)
}
/// Future implementing PING-based keep-alive for an idle HTTP/2 connection.
///
/// Resolves with `Ok(())` when the keep-alive timer expires while a PING is still
/// unanswered, or with the h2 error from sending a PING / polling for its PONG.
struct H2PingPong<'a> {
    /// `true` between sending a PING and receiving its PONG.
    on_flight: bool,
    /// Timer whose expiry either triggers a PING (idle) or ends the wait (PING in flight).
    keep_alive: Pin<&'a mut KeepAlive>,
    /// h2 handle used to send PINGs and poll for PONGs.
    ping_pong: PingPong,
    /// Source of "now" when computing new deadlines.
    date: &'a DateTimeHandle,
    /// Base keep-alive duration.
    ka_dur: Duration,
}
impl Future for H2PingPong<'_> {
    type Output = Result<(), ::h2::Error>;
    // Two-state loop selected by `on_flight`:
    // - idle: wait for the keep-alive timer, send a PING, then give the peer
    //   `10 * ka_dur` to answer;
    // - PING in flight: on PONG, start a fresh idle period (`now + ka_dur`) and go back to
    //   idle; if the timer fires first, resolve with `Ok(())` (the dispatcher treats that as
    //   a keep-alive timeout).
    // h2 errors from `poll_pong` / `send_ping` are propagated with `?`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `get_mut` requires `Self: Unpin`, which lets the fields be used directly.
        let this = self.get_mut();
        loop {
            if this.on_flight {
                match this.ping_pong.poll_pong(cx)? {
                    Poll::Ready(_) => {
                        // Peer answered: leave the in-flight state and restart the idle period.
                        this.on_flight = false;
                        let deadline = this.date.now() + this.ka_dur;
                        this.keep_alive.as_mut().update(deadline);
                        // NOTE(review): this deadline is earlier than the one set when the PING
                        // was sent; the explicit `reset` presumably re-arms the timer for it
                        // (exact `update`/`reset` semantics live in util::timer — confirm).
                        this.keep_alive.as_mut().reset();
                    }
                    // No PONG yet: resolve only if the timer expires before it arrives.
                    Poll::Pending => return this.keep_alive.as_mut().poll(cx).map(|_| Ok(())),
                }
            } else {
                // Idle: nothing to do until the keep-alive timer fires.
                ready!(this.keep_alive.as_mut().poll(cx));
                this.ping_pong.send_ping(Ping::opaque())?;
                // Allow 10x the keep-alive duration for the PONG; the factor is not derived
                // from anything visible here.
                // NOTE(review): only `update` (no `reset`) is called on this path — presumably
                // `KeepAlive` picks up a later deadline by itself; confirm in util::timer.
                let deadline = this.date.now() + (this.ka_dur * 10);
                this.keep_alive.as_mut().update(deadline);
                this.on_flight = true;
            }
        }
    }
}
/// What the dispatcher should do with the connection after a response has been written.
enum ConnectionState {
    /// Keep serving further requests.
    KeepAlive,
    /// The response carried `connection: close`; the dispatcher starts a graceful shutdown.
    Close,
}
async fn h2_handler<Fut, B, SE, BE>(
fut: Fut,
mut tx: SendResponse<Bytes>,
date: &DateTimeHandle,
) -> Result<ConnectionState, Error<SE, BE>>
where
Fut: Future<Output = Result<Response<B>, SE>>,
B: Stream<Item = Result<Bytes, BE>>,
BE: fmt::Debug,
{
let (res, body) = fut.await.map_err(Error::Service)?.into_parts();
let mut res = Response::from_parts(res, ());
*res.version_mut() = Version::HTTP_2;
let is_eof = match BodySize::from_stream(&body) {
BodySize::None => {
debug_assert!(!res.headers().contains_key(CONTENT_LENGTH));
true
}
BodySize::Stream => {
debug_assert!(!res.headers().contains_key(CONTENT_LENGTH));
false
}
BodySize::Sized(n) => {
if !res.headers().contains_key(CONTENT_LENGTH) {
res.headers_mut().insert(CONTENT_LENGTH, HeaderValue::from(n));
}
n == 0
}
};
let mut trailers = HeaderMap::with_capacity(0);
while let Some(value) = res.headers_mut().remove(TRAILER) {
let name = HeaderName::from_bytes(value.as_bytes()).unwrap();
let value = res.headers_mut().remove(name.clone()).unwrap();
trailers.append(name, value);
}
if !res.headers().contains_key(DATE) {
let date = date.with_date(HeaderValue::from_bytes).unwrap();
res.headers_mut().insert(DATE, date);
}
let state = res
.headers_mut()
.remove(CONNECTION)
.and_then(|v| {
v.as_bytes()
.eq_ignore_ascii_case(b"close")
.then_some(ConnectionState::Close)
})
.unwrap_or(ConnectionState::KeepAlive);
let mut stream = tx.send_response(res, is_eof)?;
if !is_eof {
let mut body = pin!(body);
while let Some(res) = poll_fn(|cx| body.as_mut().poll_next(cx)).await {
let mut chunk = res.map_err(Error::Body)?;
while !chunk.is_empty() {
let len = chunk.len();
stream.reserve_capacity(cmp::min(len, CHUNK_SIZE));
let cap = poll_fn(|cx| stream.poll_capacity(cx))
.await
.expect("No capacity left. http2 response is dropped")?;
let bytes = chunk.split_to(cmp::min(cap, len));
stream.send_data(bytes, false)?;
}
}
}
stream.send_trailers(trailers)?;
Ok(state)
}
/// Upper bound on the send capacity `h2_handler` reserves at a time while writing one body
/// chunk, so large chunks are sent in several flow-controlled pieces.
/// 16 KiB equals the HTTP/2 default maximum frame size (2^14); presumably chosen to match it.
const CHUNK_SIZE: usize = 16_384;