use crate::{
body::SimpleBody, graceful_shutdown::GracefulShutdown, https_redirect::HttpsRedirectService,
};
use anyhow::Result;
use http::HeaderValue;
use hyper::{body::Incoming, service::Service, Request, Response};
use hyper_util::{
rt::{TokioExecutor, TokioIo},
server::conn::auto::Builder as ServerBuilder,
};
use rustls::{server::ResolvesServerCert, ServerConfig};
use std::{
net::{IpAddr, SocketAddr},
sync::Arc,
time::Duration,
};
use tokio::{net::TcpListener, select};
use tokio_rustls::TlsAcceptor;
/// Request header that `WrappedService` overwrites with the connecting peer's IP address.
const X_FORWARDED_FOR: &str = "x-forwarded-for";
/// Request header that `WrappedService` overwrites with the scheme the connection arrived on.
const X_FORWARDED_PROTO: &str = "x-forwarded-proto";
/// A running HTTP(S) server whose accept loop lives in a background Tokio task.
///
/// Stop it with `graceful_shutdown` or `graceful_shutdown_with_timeout`; dropping it without
/// either aborts the accept task and logs a warning.
pub struct SimpleHttpServer {
    /// Handle of the spawned accept loop; aborted when the server is dropped.
    handle: tokio::task::JoinHandle<()>,
    /// Shutdown coordinator shared with the accept loop. Stays `Some` until one of the
    /// graceful-shutdown methods takes it out.
    graceful_shutdown: Option<GracefulShutdown>,
}
/// Pause applied after an accept error that is not tied to a single connection (for example
/// the process running out of file descriptors). Without it, a persistent failure would turn
/// the accept loop into a busy spin that floods the log.
const ACCEPT_ERROR_BACKOFF: Duration = Duration::from_secs(1);

/// Longest time a client may take to complete the TLS handshake before its connection is
/// dropped, so stalled clients cannot hold a task and a socket indefinitely.
const TLS_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);

/// Returns `true` for accept errors that only concern the connection being accepted, as opposed
/// to errors describing the state of the listener or the process.
fn is_connection_error(e: &std::io::Error) -> bool {
    matches!(
        e.kind(),
        std::io::ErrorKind::ConnectionRefused
            | std::io::ErrorKind::ConnectionAborted
            | std::io::ErrorKind::ConnectionReset
    )
}

/// Accepts the next TCP connection from `listener`, retrying until one succeeds.
///
/// Every error is logged. Errors that are not specific to one connection are followed by a
/// pause of [`ACCEPT_ERROR_BACKOFF`]. The future is meant to be raced against the shutdown
/// signal: `TcpListener::accept` is cancel safe, and dropping it mid-pause just skips the pause.
async fn accept_with_backoff(listener: &TcpListener) -> (tokio::net::TcpStream, SocketAddr) {
    loop {
        match listener.accept().await {
            Ok(accepted) => return accepted,
            Err(e) => {
                tracing::warn!(?e, "Failed to accept connection.");
                if !is_connection_error(&e) {
                    tokio::time::sleep(ACCEPT_ERROR_BACKOFF).await;
                }
            }
        }
    }
}

/// Accepts plain-HTTP connections on `listener` until `graceful_shutdown` signals.
///
/// Each connection is served on its own task. `service` is wrapped in `WrappedService` so every
/// request carries `X-Forwarded-For` (peer IP) and `X-Forwarded-Proto: http`. Connections are
/// registered with `graceful_shutdown` through `watch`.
async fn listen_loop<S>(listener: TcpListener, service: S, graceful_shutdown: GracefulShutdown)
where
    S: Service<Request<Incoming>, Response = Response<SimpleBody>> + Clone + Send + 'static,
    S::Future: Send + 'static,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    // The builder only carries protocol configuration. Build it once; `into_owned` clones it
    // for each connection.
    let builder = ServerBuilder::new(TokioExecutor::new());
    let mut recv = graceful_shutdown.subscribe();
    loop {
        let (stream, remote_addr) = select! {
            accepted = accept_with_backoff(&listener) => accepted,
            _ = recv.changed() => break,
        };
        let service = WrappedService::new(service.clone(), remote_addr.ip(), "http");
        let conn = builder.serve_connection_with_upgrades(TokioIo::new(stream), service);
        let conn = graceful_shutdown.watch(conn.into_owned());
        tokio::spawn(async move {
            if let Err(e) = conn.await {
                tracing::warn!(?e, "Failed to serve connection.");
            }
        });
    }
}

/// Accepts TLS connections on `listener` until `graceful_shutdown` signals.
///
/// Certificates are chosen by `resolver`; client authentication is not requested. Each
/// connection performs its TLS handshake (bounded by [`TLS_HANDSHAKE_TIMEOUT`]) on its own task,
/// so slow clients do not stall the accept loop. Requests are stamped with
/// `X-Forwarded-Proto: https` via `WrappedService`.
async fn listen_loop_tls<S>(
    listener: TcpListener,
    service: S,
    resolver: Arc<dyn ResolvesServerCert>,
    graceful_shutdown: GracefulShutdown,
) where
    S: Service<Request<Incoming>, Response = Response<SimpleBody>> + Clone + Send + 'static,
    S::Future: Send + 'static,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    let server_config = ServerConfig::builder()
        .with_no_client_auth()
        .with_cert_resolver(resolver);
    let tls_acceptor = TlsAcceptor::from(Arc::new(server_config));
    let builder = ServerBuilder::new(TokioExecutor::new());
    let mut recv = graceful_shutdown.subscribe();
    loop {
        let (stream, remote_addr) = select! {
            accepted = accept_with_backoff(&listener) => accepted,
            _ = recv.changed() => break,
        };
        let service = WrappedService::new(service.clone(), remote_addr.ip(), "https");
        let tls_acceptor = tls_acceptor.clone();
        let builder = builder.clone();
        let graceful_shutdown = graceful_shutdown.clone();
        // NOTE(review): the connection is only registered with `graceful_shutdown` once the
        // handshake completes, so handshakes still in flight when shutdown starts are not
        // waited for (they are bounded by TLS_HANDSHAKE_TIMEOUT).
        tokio::spawn(async move {
            let handshake = tokio::time::timeout(TLS_HANDSHAKE_TIMEOUT, tls_acceptor.accept(stream));
            let stream = match handshake.await {
                Ok(Ok(stream)) => stream,
                Ok(Err(e)) => {
                    tracing::warn!(?e, "Failed to accept TLS connection.");
                    return;
                }
                Err(e) => {
                    tracing::warn!(?e, "Timed out waiting for TLS handshake.");
                    return;
                }
            };
            let conn = builder.serve_connection_with_upgrades(TokioIo::new(stream), service);
            let conn = graceful_shutdown.watch(conn.into_owned());
            if let Err(e) = conn.await {
                tracing::warn!(?e, "Failed to serve connection.");
            }
        });
    }
}
/// Chooses whether a [`SimpleHttpServer`] serves plain HTTP or terminates TLS itself.
pub enum HttpsConfig {
    /// Plain-text HTTP; no TLS.
    Http,
    /// TLS termination using the given certificate resolver.
    Https {
        /// Passed to rustls as the server's certificate resolver.
        resolver: Arc<dyn ResolvesServerCert>,
    },
}
impl HttpsConfig {
    /// Builds an [`HttpsConfig::Https`] from any certificate resolver, boxing it behind an `Arc`.
    pub fn from_resolver<R: ResolvesServerCert + 'static>(resolver: R) -> Self {
        let shared: Arc<dyn ResolvesServerCert> = Arc::new(resolver);
        HttpsConfig::Https { resolver: shared }
    }
    /// Builds an [`HttpsConfig::Http`] (no TLS).
    pub fn http() -> Self {
        HttpsConfig::Http
    }
}
impl SimpleHttpServer {
pub fn new<S>(service: S, listener: TcpListener, https_config: HttpsConfig) -> Result<Self>
where
S: Service<Request<Incoming>, Response = Response<SimpleBody>> + Clone + Send + 'static,
S::Future: Send + 'static,
S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
let graceful_shutdown = GracefulShutdown::new();
let handle = match https_config {
HttpsConfig::Http => {
tokio::spawn(listen_loop(listener, service, graceful_shutdown.clone()))
}
HttpsConfig::Https { resolver } => {
if rustls::crypto::ring::default_provider()
.install_default()
.is_err()
{
tracing::info!("Using already-installed crypto provider.")
}
tokio::spawn(listen_loop_tls(
listener,
service,
resolver,
graceful_shutdown.clone(),
))
}
};
Ok(Self {
handle,
graceful_shutdown: Some(graceful_shutdown),
})
}
pub async fn graceful_shutdown(mut self) {
println!("Shutting down");
let graceful_shutdown = self
.graceful_shutdown
.take()
.expect("self.graceful_shutdown is always set");
graceful_shutdown.shutdown().await;
}
pub async fn graceful_shutdown_with_timeout(mut self, timeout: Duration) {
let graceful_shutdown = self
.graceful_shutdown
.take()
.expect("self.graceful_shutdown is always set");
let result = tokio::time::timeout(timeout, graceful_shutdown.shutdown()).await;
if let Err(e) = result {
tracing::warn!(?e, "Timed out waiting for graceful shutdown, aborting.");
}
}
}
impl Drop for SimpleHttpServer {
    /// Aborts the background accept task. If neither graceful-shutdown method ran first, a
    /// warning is logged before aborting.
    fn drop(&mut self) {
        // The graceful-shutdown methods take the coordinator out, so `None` means one of them ran.
        let shutdown_requested = self.graceful_shutdown.is_none();
        if !shutdown_requested {
            tracing::warn!("Shutting down SimpleHttpServer without a call to graceful_shutdown. Connections will be dropped abruptly!");
        }
        self.handle.abort();
    }
}
/// An application server on a plain-HTTP port, optionally paired with a TLS port.
///
/// When HTTPS is configured, the plain-HTTP port is handled by `HttpsRedirectService` and the
/// application is served over TLS. Otherwise the application is served over plain HTTP.
pub struct ServerWithHttpRedirect {
    /// Server on the plain-HTTP port.
    http_server: SimpleHttpServer,
    /// TLS server running the application; `None` when HTTPS is not configured.
    https_server: Option<SimpleHttpServer>,
}

/// HTTPS settings for [`ServerWithHttpRedirect`].
pub struct ServerWithHttpRedirectHttpsConfig {
    /// Port bound on `0.0.0.0` for TLS connections.
    pub https_port: u16,
    /// Certificate resolver handed to rustls.
    pub resolver: Arc<dyn ResolvesServerCert>,
}

/// Configuration accepted by [`ServerWithHttpRedirect::new`].
pub struct ServerWithHttpRedirectConfig {
    /// Port bound on `0.0.0.0` for plain-HTTP connections.
    pub http_port: u16,
    /// `Some` to enable HTTPS (plain HTTP then goes to the redirect service); `None` for HTTP only.
    pub https_config: Option<ServerWithHttpRedirectHttpsConfig>,
}
impl ServerWithHttpRedirect {
pub async fn new<S>(service: S, server_config: ServerWithHttpRedirectConfig) -> Result<Self>
where
S: Service<Request<Incoming>, Response = Response<SimpleBody>>
+ Clone
+ Send
+ Sync
+ 'static,
S::Future: Send + 'static,
S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
if let Some(https_config) = server_config.https_config {
let https_listener =
TcpListener::bind(SocketAddr::from(([0, 0, 0, 0], https_config.https_port)))
.await?;
let https_server = SimpleHttpServer::new(
service,
https_listener,
HttpsConfig::Https {
resolver: https_config.resolver,
},
)?;
let http_listener =
TcpListener::bind(SocketAddr::from(([0, 0, 0, 0], server_config.http_port)))
.await?;
let http_server =
SimpleHttpServer::new(HttpsRedirectService, http_listener, HttpsConfig::Http)?;
Ok(Self {
http_server,
https_server: Some(https_server),
})
} else {
let listener =
TcpListener::bind(SocketAddr::from(([0, 0, 0, 0], server_config.http_port)))
.await?;
let http_server = SimpleHttpServer::new(service, listener, HttpsConfig::Http)?;
Ok(Self {
http_server,
https_server: None,
})
}
}
pub async fn graceful_shutdown_with_timeout(self, timeout: Duration) {
if let Some(https_server) = self.https_server {
tokio::join!(
self.http_server.graceful_shutdown_with_timeout(timeout),
https_server.graceful_shutdown_with_timeout(timeout)
);
} else {
self.http_server
.graceful_shutdown_with_timeout(timeout)
.await;
}
}
}
/// Service adapter that stamps forwarding headers onto each request before handing it to
/// `inner`.
struct WrappedService<S> {
    /// The service that ultimately handles the request.
    inner: S,
    /// Peer IP address written into `X-Forwarded-For`.
    forwarded_for: IpAddr,
    /// Scheme written into `X-Forwarded-Proto` (`"http"` or `"https"` at this file's call sites).
    forwarded_proto: &'static str,
}
impl<S> WrappedService<S> {
    /// Wraps `inner`, recording the client address and scheme to stamp on its requests.
    pub fn new(inner: S, forwarded_for: IpAddr, forwarded_proto: &'static str) -> Self {
        WrappedService {
            forwarded_proto,
            forwarded_for,
            inner,
        }
    }
}
impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for WrappedService<S>
where
    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = S::Future;

    /// Sets `X-Forwarded-For` to the peer IP and `X-Forwarded-Proto` to the connection scheme,
    /// then forwards the request to the inner service.
    ///
    /// `HeaderMap::insert` removes any existing values for the key, so client-supplied copies of
    /// these headers are discarded rather than trusted.
    fn call(&self, mut request: Request<ReqBody>) -> Self::Future {
        let headers = request.headers_mut();
        // The textual form of an IP address uses only digits, '.', ':' and hex letters, all of
        // which are valid header-value bytes.
        headers.insert(
            X_FORWARDED_FOR,
            HeaderValue::from_str(&self.forwarded_for.to_string())
                .expect("X-Forwarded-For is always valid"),
        );
        headers.insert(
            X_FORWARDED_PROTO,
            HeaderValue::from_str(self.forwarded_proto).expect("X-Forwarded-Proto is always valid"),
        );
        self.inner.call(request)
    }
}