use super::*;
use std::{io::BufRead, net::SocketAddr, ops};
use tokio::{
io::{self, AsyncRead, AsyncWrite},
net::{TcpStream, ToSocketAddrs},
};
use tokio_tls_listener::{tls_config, tokio_rustls::server::TlsStream, TlsListener};
/// An HTTP/2 server that accepts connections through a [`TlsListener`].
///
/// Construct one with [`Server::bind`]; each call to [`Server::accept`] yields a
/// [`Conn`] on which the HTTP/2 handshake has already completed.
pub struct Server {
/// The wrapped TLS listener. Public, but hidden from the rendered documentation;
/// it is also reachable through this type's `Deref` implementation.
#[doc(hidden)]
pub listener: TlsListener,
}
impl Server {
    /// Binds a TLS listener to `addr` and configures it for HTTP/2.
    ///
    /// `certs` and `key` are handed to `tls_config` to build the TLS server
    /// configuration. The only ALPN protocol advertised is `h2`.
    ///
    /// In builds with `debug_assertions` enabled, a rustls `KeyLogFile` is installed
    /// as the key log when `std::env::var("SSLKEYLOGFILE")` succeeds.
    ///
    /// # Errors
    ///
    /// Propagates any error from building the TLS configuration or from binding
    /// the listener.
    pub async fn bind(
        addr: impl ToSocketAddrs,
        certs: &mut dyn BufRead,
        key: &mut dyn BufRead,
    ) -> Result<Self, BoxErr> {
        let mut tls = tls_config(certs, key)?;
        tls.alpn_protocols = vec![b"h2".to_vec()];

        // Debug-only aid: lets external tooling obtain the session secrets.
        #[cfg(debug_assertions)]
        if std::env::var("SSLKEYLOGFILE").is_ok() {
            tls.key_log = std::sync::Arc::new(tokio_tls_listener::rustls::KeyLogFile::new());
        }

        let listener = TlsListener::bind(addr, tls).await?;
        Ok(Self { listener })
    }

    /// Consumes the server and wraps it in a [`GracefulShutdown`].
    pub fn with_graceful_shutdown(self) -> GracefulShutdown<Self> {
        GracefulShutdown::new(self)
    }

    /// Accepts the next TLS connection and performs the HTTP/2 server handshake on it.
    ///
    /// On success, returns the ready [`Conn`] together with the socket address
    /// reported by the listener. The handshake is awaited inline, so this future
    /// does not resolve until the handshake has finished.
    ///
    /// # Errors
    ///
    /// Returns the I/O error from accepting the TLS stream, or the HTTP/2 handshake
    /// error converted through `io_err`.
    #[inline]
    pub async fn accept(&self) -> io::Result<(Conn<TlsStream<TcpStream>>, SocketAddr)> {
        let (tls_stream, peer) = self.listener.accept_tls().await?;
        let handshake = Conn::handshake(tls_stream).await;
        let conn = handshake.map_err(io_err)?;
        Ok((conn, peer))
    }
}
/// A server-side HTTP/2 connection running over an I/O stream of type `IO`.
///
/// Wraps an `h2::server::Connection<IO, Bytes>`. Obtain one with
/// [`Conn::handshake`], then pull requests from it with [`Conn::accept`].
#[derive(Debug)]
pub struct Conn<IO> {
// The underlying h2 connection; also exposed through `Deref` / `DerefMut`.
inner: h2::server::Connection<IO, Bytes>,
}
impl<IO> Conn<IO>
where
    IO: Unpin + AsyncRead + AsyncWrite,
{
    /// Performs the HTTP/2 server handshake over `io` and wraps the resulting connection.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `h2::server::handshake`.
    #[inline]
    pub async fn handshake(io: IO) -> Result<Conn<IO>> {
        let inner = h2::server::handshake(io).await?;
        Ok(Self { inner })
    }

    /// Waits for the next incoming request on this connection.
    ///
    /// This is the `async` form of [`Conn::poll_accept`]; see that method for the
    /// meaning of the returned value.
    pub async fn accept(&mut self) -> Option<Result<(Request, Response)>> {
        poll_fn(|cx| self.poll_accept(cx)).await
    }

    /// Polls the underlying h2 connection for the next incoming request.
    ///
    /// When a stream is accepted, the request is split into its head and body to form
    /// a [`Request`], and it is paired with a [`Response`] that starts with the default
    /// status code, an empty header map and the stream's response sender.
    ///
    /// `Pending`, `None` and `Err` results from the underlying connection are passed
    /// through unchanged.
    #[doc(hidden)]
    pub fn poll_accept(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<(Request, Response)>>> {
        let (req, sender) = match self.inner.poll_accept(cx) {
            Poll::Ready(Some(Ok(accepted))) => accepted,
            Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err))),
            Poll::Ready(None) => return Poll::Ready(None),
            Poll::Pending => return Poll::Pending,
        };

        let (head, body) = req.into_parts();
        let request = Request { head, body };
        let response = Response {
            status: http::StatusCode::default(),
            headers: http::HeaderMap::default(),
            sender,
        };
        Poll::Ready(Some(Ok((request, response))))
    }

    /// Drives this connection on a background Tokio task.
    ///
    /// The task repeatedly accepts requests. For each one it calls `on_stream` with a
    /// mutable reference to the connection, a clone of `state`, the request and the
    /// response, and spawns the returned future as its own task. Because that future
    /// must be `'static`, it cannot keep borrowing the connection.
    ///
    /// The loop stops as soon as accepting yields `None` or an error; an error is
    /// discarded rather than reported. Afterwards `on_close` is called with the
    /// original `state` and its future is awaited.
    ///
    /// The spawned tasks are detached: nothing is returned to the caller.
    ///
    /// # Panics
    ///
    /// `tokio::spawn` panics when called outside of a Tokio runtime.
    pub fn incoming<State, Stream, Close>(
        mut self,
        state: State,
        on_stream: fn(&mut Self, State, Request, Response) -> Stream,
        on_close: fn(State) -> Close,
    ) where
        IO: Send + 'static,
        State: Clone + Send + 'static,
        Stream: Future + Send + 'static,
        Stream::Output: Send,
        Close: Future + Send + 'static,
    {
        let serve = async move {
            loop {
                let (req, res) = match self.accept().await {
                    Some(Ok(pair)) => pair,
                    // Connection finished or failed: stop accepting in both cases.
                    _ => break,
                };
                let stream_state = state.clone();
                tokio::spawn(on_stream(&mut self, stream_state, req, res));
            }
            on_close(state).await;
        };
        tokio::spawn(serve);
    }
}
/// Lets a [`Server`] be used wherever a shared reference to its [`TlsListener`] is expected.
impl ops::Deref for Server {
    type Target = TlsListener;

    /// Borrows the wrapped listener.
    #[inline]
    fn deref(&self) -> &Self::Target {
        let Self { listener } = self;
        listener
    }
}
/// Exposes the shared-reference API of the wrapped h2 connection directly on [`Conn`].
impl<IO> ops::Deref for Conn<IO> {
    type Target = h2::server::Connection<IO, Bytes>;

    /// Borrows the wrapped h2 connection.
    #[inline]
    fn deref(&self) -> &Self::Target {
        let Self { inner } = self;
        inner
    }
}
/// Exposes the mutable API of the wrapped h2 connection directly on [`Conn`].
impl<IO> ops::DerefMut for Conn<IO> {
    /// Mutably borrows the wrapped h2 connection.
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self { inner } = self;
        inner
    }
}