use crate::{
body::HttpBody,
error::Error,
handler::Handler,
listener::Listener,
request::{AnyClone, Request},
};
use futures::FutureExt;
use hyper_util::{
rt::{TokioExecutor, TokioIo},
server::conn::auto::Builder,
};
use std::{
collections::HashMap,
fmt::Debug,
future::{pending, Future, IntoFuture, Pending},
io,
pin::Pin,
sync::{Arc, RwLock},
};
use tokio::{
io::{AsyncRead, AsyncWrite},
pin, select,
sync::watch,
};
/// Adapter exposing a crate [`Handler`] as a hyper service for one connection.
///
/// The accept loop builds one value per accepted connection. Each request that
/// hyper dispatches on that connection gets its own clones of these fields.
struct HyperHandler<A> {
    /// Application handler, shared (via `Arc`) with every other connection.
    handler: Arc<dyn Handler>,
    /// Address reported by the listener for this connection; stored in the
    /// extensions of every request built from it.
    addr: A,
    /// Key/value store passed to `Request::new` for every request on this
    /// connection.
    depot: Arc<RwLock<HashMap<String, Box<dyn AnyClone + Send + Sync>>>>,
}
impl<A> HyperHandler<A>
where
    A: Clone + Send + Sync + 'static,
{
    /// Bundles the shared handler, the connection address and the depot.
    #[inline]
    fn new(
        handler: Arc<dyn Handler>,
        addr: A,
        depot: Arc<RwLock<HashMap<String, Box<dyn AnyClone + Send + Sync>>>>,
    ) -> Self {
        Self {
            depot,
            addr,
            handler,
        }
    }

    /// Serves a single hyper request.
    ///
    /// Steps:
    /// - The raw request is wrapped in a crate [`Request`]. Its body becomes
    ///   `HttpBody::Incoming` and it is given `depot`.
    /// - `addr` is inserted into the request's extensions.
    /// - `handler` is invoked.
    ///
    /// A handler error is not returned as-is. It is rendered into a response
    /// via `response_builder`, starting from a default `Response`. Either way,
    /// the final value comes from `into_raw`, which is the only source of an
    /// `Err` here.
    #[inline]
    async fn serve(
        req: hyper::Request<hyper::body::Incoming>,
        handler: Arc<dyn Handler>,
        addr: A,
        depot: Arc<RwLock<HashMap<String, Box<dyn AnyClone + Send + Sync>>>>,
    ) -> crate::Result<hyper::Response<HttpBody>> {
        let incoming = req.map(HttpBody::Incoming);
        let mut request = Request::new(incoming, depot);
        request.extensions_mut().insert(addr);
        handler.handle(request).await.map_or_else(
            |err| err.response_builder(crate::response::Response::default()).into_raw(),
            |response| response.into_raw(),
        )
    }
}
impl<A> hyper::service::Service<hyper::Request<hyper::body::Incoming>> for HyperHandler<A>
where
A: Clone + Send + Sync + 'static,
{
type Error = Error;
type Response = hyper::Response<HttpBody>;
type Future =
Pin<Box<dyn Future<Output = std::result::Result<Self::Response, Self::Error>> + Send>>;
#[inline]
fn call(&self, req: hyper::Request<hyper::body::Incoming>) -> Self::Future {
Box::pin(Self::serve(
req,
self.handler.clone(),
self.addr.clone(),
self.depot.clone(),
))
}
}
/// HTTP server created by [`listen`].
///
/// Optionally install a shutdown future with [`Server::signal`], then run the
/// server by `.await`ing it (it implements [`IntoFuture`]).
pub struct Server<L, S, H> {
    /// Source of incoming connections.
    listener: L,
    /// Handler invoked for every request on every connection.
    handler: H,
    /// Future whose completion starts the graceful shutdown.
    signal: S,
    /// hyper-util `auto` connection builder used to serve each accepted
    /// connection.
    build: Builder<TokioExecutor>,
}
impl<L, S, H> Server<L, S, H> {
    /// Replaces the shutdown signal with `signal` and drops the previous one.
    ///
    /// When the new signal completes, the server does the following:
    /// - stops accepting connections;
    /// - asks open connections to shut down gracefully;
    /// - resolves once every connection has finished.
    #[inline]
    pub fn signal<X>(self, signal: X) -> Server<L, X, H> {
        let Server {
            listener,
            handler,
            build,
            ..
        } = self;
        Server {
            listener,
            signal,
            handler,
            build,
        }
    }
}
/// Creates a [`Server`] that serves `handler` on connections accepted from
/// `listener`.
///
/// The default shutdown signal is [`pending()`], which never completes. Use
/// [`Server::signal`] to install one that does.
#[inline]
pub fn listen<L, H>(handler: H, listener: L) -> Server<L, Pending<()>, H> {
    let build = Builder::new(TokioExecutor::new());
    Server {
        listener,
        handler,
        signal: pending(),
        build,
    }
}
impl<L, S, H> IntoFuture for Server<L, S, H>
where
    L: Listener + Send + 'static,
    L::Io: AsyncRead + AsyncWrite + Send + Unpin,
    L::Addr: Send + Sync + Debug,
    S: Future + Send + 'static,
    H: Handler,
{
    type Output = crate::Result<()>;
    type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send>>;

    /// Runs the accept loop and resolves after a graceful shutdown.
    ///
    /// Every accepted connection is served on its own Tokio task. When the
    /// shutdown signal completes:
    /// - the accept loop stops and the listener is dropped;
    /// - every open connection is asked to shut down gracefully;
    /// - the returned future resolves once all connection tasks have exited.
    ///
    /// Dropping the returned future also asks open connections to shut down.
    ///
    /// The signal is polled inside the returned future rather than on a task
    /// spawned by this method. As a result, calling `into_future` needs no Tokio
    /// runtime context, and no task is left behind if the future is dropped or
    /// never polled. This matters for the default `pending()` signal, which
    /// would otherwise keep such a task alive forever.
    fn into_future(self) -> Self::IntoFuture {
        let Self {
            build,
            signal,
            listener,
            handler,
        } = self;
        let build = Arc::new(build);
        let handler = Arc::new(handler);
        Box::pin(async move {
            // Connection tasks wait on `shutdown_tx.closed()`. It resolves once the
            // single receiver `shutdown_rx` is dropped after the accept loop ends,
            // or when this future is dropped.
            let (shutdown_tx, shutdown_rx) = watch::channel(());
            let shutdown_tx = Arc::new(shutdown_tx);
            // Each connection task holds a clone of `close_rx`. `close_tx.closed()`
            // resolves once every clone has been dropped, i.e. every task has exited.
            let (close_tx, close_rx) = watch::channel(());
            // Discard `S::Output` so the signal can be raced in `select!` with a
            // plain `()` pattern.
            let signal = async move {
                signal.await;
            };
            pin!(signal);

            loop {
                let (stream, addr) = select! {
                    () = &mut signal => {
                        break;
                    }
                    res = listener.accept() => {
                        match res {
                            Ok(conn) => conn,
                            Err(e) => {
                                // Refused/aborted/reset errors are skipped immediately.
                                // Any other error is logged and retried after a 1s
                                // back-off, which the shutdown signal can cut short.
                                if !matches!(
                                    e.kind(),
                                    io::ErrorKind::ConnectionRefused
                                        | io::ErrorKind::ConnectionAborted
                                        | io::ErrorKind::ConnectionReset
                                ) {
                                    eprintln!("listener accept error: {e}");
                                    select! {
                                        () = tokio::time::sleep(std::time::Duration::from_secs(1)) => {}
                                        () = &mut signal => {
                                            break;
                                        }
                                    }
                                }
                                continue;
                            }
                        }
                    }
                };

                let hyper_handler = HyperHandler::<Arc<L::Addr>>::new(
                    handler.clone(),
                    Arc::new(addr),
                    // NOTE(review): the depot is created per connection, so every
                    // request served on this connection shares it. Confirm this is
                    // intended rather than a per-request store.
                    Arc::new(RwLock::new(HashMap::default())),
                );
                let stream = TokioIo::new(stream);
                let shutdown_tx = Arc::clone(&shutdown_tx);
                let close_rx = close_rx.clone();
                let build = Arc::clone(&build);
                tokio::spawn(async move {
                    let conn = build.serve_connection_with_upgrades(stream, hyper_handler);
                    pin!(conn);
                    // Fused: after it fires, later polls stay `Pending`, so
                    // `graceful_shutdown` is requested only once.
                    let shutdown = shutdown_tx.closed().fuse();
                    pin!(shutdown);
                    let result = loop {
                        select! {
                            res = conn.as_mut() => {
                                break res;
                            }
                            () = &mut shutdown => {
                                conn.as_mut().graceful_shutdown();
                            }
                        }
                    };
                    if let Err(e) = result {
                        eprintln!("connection failed: {e}");
                    }
                    // Lets the final `close_tx.closed()` wait see this task as done.
                    drop(close_rx);
                });
            }

            // Stop accepting, then ask every open connection to finish gracefully.
            drop(listener);
            drop(shutdown_rx);
            drop(close_rx);
            eprintln!(
                "waiting for {} task(s) to finish",
                close_tx.receiver_count()
            );
            close_tx.closed().await;
            eprintln!("server shutdown complete");
            Ok(())
        })
    }
}