use std::{convert::Infallible, pin::Pin, sync::Arc};
use futures::FutureExt;
use http_body_util::{BodyExt, combinators::BoxBody};
use hyper::{Request, Response, body::{Bytes, Incoming}, service::service_fn};
use hyper_util::rt::TokioIo;
use matchit::Router;
use tokio::{net::{TcpListener, TcpStream}, signal::unix::{signal, SignalKind}};
use tokio_rustls::TlsAcceptor;
use crate::http::{executor::Executor, http_request::HttpRequest, http_response::HttpResponse, server::http_server_config::HttpServerConfig};
/// Heap-allocated, type-erased future that is `Send`; shared shape of every callback's output.
type BoxedFuture<T> = Pin<Box<dyn Future<Output = T> + Send>>;
/// Handler for a matched route: consumes the request and produces the response.
type RouteCallback = Arc<dyn Fn(HttpRequest) -> BoxedFuture<HttpResponse> + Send + Sync>;
/// Hook run before the route handler; either forwards the request or answers directly.
type BeforeCallback = Arc<dyn Fn(HttpRequest) -> BoxedFuture<BeforeResult> + Send + Sync>;
/// Hook run on the response of a matched route; may replace or modify it.
type AfterCallback = Arc<dyn Fn(HttpResponse) -> BoxedFuture<HttpResponse> + Send + Sync>;
/// Handler that turns a caught panic message into a response.
type ErrorCallback = Arc<dyn Fn(String) -> BoxedFuture<HttpResponse> + Send + Sync>;
/// An HTTP(S) server assembled by [`HttpServerBuilder`]; start it with [`HttpServer::run`].
pub struct HttpServer {
    /// Listen address, port and optional TLS settings.
    config: HttpServerConfig,
    /// Maps request paths to their handlers.
    router: Router<RouteCallback>,
    /// Hooks run, in registration order, before the matched route handler.
    before: Vec<BeforeCallback>,
    /// Hooks run, in registration order, on each response produced for a matched route.
    after: Vec<AfterCallback>,
    /// Optional handler that converts a caught panic message into a response.
    on_error: Option<ErrorCallback>,
}
impl HttpServer {
    /// Starts configuring a server that will listen according to `config`.
    ///
    /// The returned builder has no routes, no `before`/`after` hooks and no panic
    /// handler registered yet.
    pub fn builder(config: HttpServerConfig) -> HttpServerBuilder {
        HttpServerBuilder {
            config,
            router: Router::new(),
            before: Vec::new(),
            after: Vec::new(),
            on_error: None,
        }
    }

    /// Binds `config.ip:config.port` and serves connections until SIGTERM or SIGINT.
    ///
    /// When a TLS config is present, every connection completes a TLS handshake first.
    /// It is then served as HTTP/2 when ALPN negotiates `h2`, and as HTTP/1 otherwise.
    /// Without TLS, connections are served as plain HTTP/1. Each accepted connection
    /// runs on its own spawned task. Accept errors are logged and the loop continues.
    ///
    /// When a signal arrives, the listener is closed and this function returns.
    /// Connection tasks that were already spawned are not awaited.
    ///
    /// # Panics
    ///
    /// Panics if the listener cannot be bound or a signal handler cannot be installed.
    pub async fn run(self) {
        let tls_acceptor = self
            .config
            .tls_config
            .map(|tls_config| TlsAcceptor::from(Arc::new(tls_config)));
        let host = format!("{}:{}", self.config.ip, self.config.port);
        let listener = TcpListener::bind(&host)
            .await
            .expect("Failed to start TCP Listener");
        let mut sigterm =
            signal(SignalKind::terminate()).expect("Failed to start SIGTERM signal receiver");
        let mut sigint =
            signal(SignalKind::interrupt()).expect("Failed to start SIGINT signal receiver");
        let router = Arc::new(self.router);
        let before: Arc<[BeforeCallback]> = self.before.into();
        let after: Arc<[AfterCallback]> = self.after.into();
        let on_error = self.on_error;
        tracing::trace!("Started on {}", &host);
        loop {
            tokio::select! {
                _ = sigterm.recv() => break,
                _ = sigint.recv() => break,
                result = listener.accept() => {
                    let (tcp_stream, _client_addr) = match result {
                        Ok(pair) => pair,
                        Err(err) => {
                            tracing::error!("{:?}", err);
                            continue;
                        }
                    };
                    // Refcount bumps only: each connection task owns its own handles.
                    let router = Arc::clone(&router);
                    let before = Arc::clone(&before);
                    let after = Arc::clone(&after);
                    let on_error = on_error.clone();
                    match &tls_acceptor {
                        Some(acceptor) => {
                            tokio::spawn(Self::tls_connection(
                                acceptor.clone(),
                                tcp_stream,
                                router,
                                before,
                                after,
                                on_error,
                            ));
                        }
                        None => {
                            tokio::spawn(Self::tcp_connection(
                                tcp_stream, router, before, after, on_error,
                            ));
                        }
                    }
                }
            }
        }
        // Stop accepting new connections before reporting shutdown.
        drop(listener);
        tracing::trace!("Shut down complete");
    }

    /// Serves one plain-TCP connection as HTTP/1 and logs any connection error.
    async fn tcp_connection(
        tcp_stream: TcpStream,
        router: Arc<Router<RouteCallback>>,
        before: Arc<[BeforeCallback]>,
        after: Arc<[AfterCallback]>,
        on_error: Option<ErrorCallback>,
    ) {
        let service = service_fn(move |req| {
            Self::incoming_request(
                req,
                Arc::clone(&router),
                Arc::clone(&before),
                Arc::clone(&after),
                on_error.clone(),
            )
        });
        let io = TokioIo::new(tcp_stream);
        if let Err(err) = hyper::server::conn::http1::Builder::new()
            .serve_connection(io, service)
            .await
        {
            tracing::error!("{:?}", err);
        }
    }

    /// Performs the TLS handshake and then serves the connection.
    ///
    /// The connection uses HTTP/2 if ALPN negotiated `h2` and HTTP/1 otherwise.
    /// Handshake failures and connection errors are logged, never propagated.
    async fn tls_connection(
        tls_acceptor: TlsAcceptor,
        tcp_stream: TcpStream,
        router: Arc<Router<RouteCallback>>,
        before: Arc<[BeforeCallback]>,
        after: Arc<[AfterCallback]>,
        on_error: Option<ErrorCallback>,
    ) {
        let tls_stream = match tls_acceptor.accept(tcp_stream).await {
            Ok(stream) => stream,
            Err(err) => {
                tracing::error!("TLS handshake failed {:?}", err);
                return;
            }
        };
        // Read the negotiated protocol before the stream is moved into hyper's I/O adapter.
        let negotiated_h2 = matches!(tls_stream.get_ref().1.alpn_protocol(), Some(b"h2"));
        let service = service_fn(move |req| {
            Self::incoming_request(
                req,
                Arc::clone(&router),
                Arc::clone(&before),
                Arc::clone(&after),
                on_error.clone(),
            )
        });
        let io = TokioIo::new(tls_stream);
        if negotiated_h2 {
            if let Err(err) = hyper::server::conn::http2::Builder::new(Executor)
                .serve_connection(io, service)
                .await
            {
                // The handshake already succeeded; this is an HTTP/2 serving error.
                tracing::error!("{:?}", err);
            }
        } else if let Err(err) = hyper::server::conn::http1::Builder::new()
            .serve_connection(io, service)
            .await
        {
            tracing::error!("{:?}", err);
        }
    }

    /// Handles one request and converts a panic anywhere in the pipeline into a response.
    ///
    /// A panic in a `before` hook, the route handler or an `after` hook is caught.
    /// Its message goes to `on_error` when one is registered. Otherwise the message is
    /// logged and an empty 500 response is returned. A panic inside `on_error` itself
    /// is not caught here.
    async fn incoming_request(
        request: Request<Incoming>,
        router: Arc<Router<RouteCallback>>,
        before: Arc<[BeforeCallback]>,
        after: Arc<[AfterCallback]>,
        on_error: Option<ErrorCallback>,
    ) -> Result<Response<BoxBody<Bytes, anyhow::Error>>, Infallible> {
        let outcome =
            std::panic::AssertUnwindSafe(Self::inner_request(request, router, before, after))
                .catch_unwind()
                .await;
        let payload = match outcome {
            Ok(response) => return response,
            Err(payload) => payload,
        };
        let error = Self::panic_message(&*payload);
        let response = match on_error {
            Some(handler) => handler(error.to_string()).await,
            None => {
                tracing::error!("{:?}", error);
                HttpResponse::builder().status(500).body_empty().unwrap()
            }
        };
        Ok(Response::from(response))
    }

    /// Extracts the message from a panic payload.
    ///
    /// Payloads that are neither `String` nor `&str` yield a fixed placeholder.
    fn panic_message(payload: &(dyn std::any::Any + Send)) -> &str {
        if let Some(message) = payload.downcast_ref::<String>() {
            message.as_str()
        } else if let Some(message) = payload.downcast_ref::<&str>() {
            *message
        } else {
            "Unknown panic!"
        }
    }

    /// Routes a request and runs the hook pipeline around the handler.
    ///
    /// Unmatched paths get an empty 404 response, and no hooks run for them.
    /// For a match, the `before` hooks run in order. The first hook that returns
    /// [`BeforeResult::Response`] skips the remaining hooks and the handler. Every
    /// response from a matched route, short-circuited or not, then passes through
    /// the `after` hooks.
    async fn inner_request(
        request: Request<Incoming>,
        router: Arc<Router<RouteCallback>>,
        before: Arc<[BeforeCallback]>,
        after: Arc<[AfterCallback]>,
    ) -> Result<Response<BoxBody<Bytes, anyhow::Error>>, Infallible> {
        let (parts, body) = request.into_parts();
        // Owned copy: `matched` borrows the path, while `parts` is moved into the request below.
        let path = parts.uri.path().to_owned();
        let Ok(matched) = router.at(&path) else {
            let response = HttpResponse::builder().status(404).body_empty().unwrap();
            return Ok(Response::from(response));
        };
        let params: Vec<(String, String)> = matched
            .params
            .iter()
            .map(|(k, v)| (k.to_string(), v.to_string()))
            .collect();
        let callback = matched.value;
        let body = body.map_err(anyhow::Error::from).boxed();
        let mut req = HttpRequest::from_parts_with_params(body, parts, params);
        for handler in before.iter() {
            match handler(req).await {
                BeforeResult::Next(request) => req = request,
                BeforeResult::Response(response) => {
                    return Ok(Response::from(Self::apply_after(&after, response).await));
                }
            }
        }
        let response = callback(req).await;
        Ok(Response::from(Self::apply_after(&after, response).await))
    }

    /// Passes `response` through every `after` hook in registration order.
    async fn apply_after(after: &[AfterCallback], mut response: HttpResponse) -> HttpResponse {
        for handler in after {
            response = handler(response).await;
        }
        response
    }
}
/// Outcome of a `before` hook.
///
/// A hook may also return an [`HttpRequest`] or an [`HttpResponse`] directly.
/// The `From` impls map each onto the matching variant.
pub enum BeforeResult {
    /// Continue processing with this request: remaining hooks, then the route handler.
    Next(HttpRequest),
    /// Stop processing and answer with this response; `after` hooks still run on it.
    Response(HttpResponse),
}
impl From<HttpRequest> for BeforeResult {
fn from(req: HttpRequest) -> Self {
BeforeResult::Next(req)
}
}
impl From<HttpResponse> for BeforeResult {
fn from(res: HttpResponse) -> Self {
BeforeResult::Response(res)
}
}
/// Collects routes and hooks for an [`HttpServer`]; created by [`HttpServer::builder`].
pub struct HttpServerBuilder {
    /// Listen address, port and optional TLS settings, passed through to the server.
    config: HttpServerConfig,
    /// Routes registered so far.
    router: Router<RouteCallback>,
    /// `before` hooks in registration order.
    before: Vec<BeforeCallback>,
    /// `after` hooks in registration order.
    after: Vec<AfterCallback>,
    /// Most recently registered panic handler, if any.
    on_error: Option<ErrorCallback>,
}
impl HttpServerBuilder {
pub fn before<T, Fut, R>(mut self, callback: T) -> Self
where
T: Fn(HttpRequest) -> Fut + Send + Sync + 'static,
Fut: Future<Output = R> + Send + 'static,
R: Into<BeforeResult> + 'static,
{
let callback = Arc::new(callback);
self.before.push(Arc::new(move |request| {
let callback = Arc::clone(&callback);
Box::pin(async move {
callback(request).await.into()
})
}));
self
}
pub fn route<T, Fut>(mut self, path: impl Into<String>, callback: T) -> Self
where
T: Fn(HttpRequest) -> Fut + Send + Sync + 'static,
Fut: Future<Output = HttpResponse> + Send + 'static,
{
self.router.insert(path.into(), Arc::new(move |request| Box::pin(callback(request)))).unwrap();
self
}
pub fn after<T, Fut>(mut self, callback: T) -> Self
where
T: Fn(HttpResponse) -> Fut + Send + Sync + 'static,
Fut: Future<Output = HttpResponse> + Send + 'static,
{
self.after.push(Arc::new(move |response| Box::pin(callback(response))));
self
}
pub fn on_error<T, Fut>(mut self, callback: T) -> Self
where
T: Fn(String) -> Fut + Send + Sync + 'static,
Fut: Future<Output = HttpResponse> + Send + 'static,
{
self.on_error = Some(Arc::new(move |err| Box::pin(callback(err))));
self
}
pub fn build(self) -> HttpServer {
HttpServer {
config: self.config,
router: self.router,
before: self.before,
on_error: self.on_error,
after: self.after
}
}
}