use std::{convert::Infallible, net::SocketAddr, sync::Arc, time::Duration};
pub mod error;
pub mod init_log;
#[cfg(feature = "jwt")]
pub mod jwt;
pub mod util;
/// Boxed, thread-safe error trait object used where the concrete error type does not matter.
type DynError = Box<dyn std::error::Error + Send + Sync>;
use crate::util::{
io::{self, create_dual_stack_listener},
tls::{TlsAcceptor, tls_config},
};
use axum::{
Router,
extract::Request,
response::{IntoResponse, Response},
};
use hyper::body::Incoming;
use hyper_util::rt::TokioExecutor;
use log::{info, warn};
use tokio::{
sync::broadcast::{self, Receiver, Sender, error::RecvError},
time,
};
use tokio_rustls::rustls::ServerConfig;
use tower::{Service, ServiceExt};
use util::format::SocketAddrFormat;
/// Pause between two TLS configuration reloads performed by `serve_tls` (24 hours).
const REFRESH_INTERVAL: Duration = Duration::from_secs(60 * 60 * 24);
/// Upper bound on how long the serve loops wait for `GracefulShutdown::shutdown()` to finish (10 seconds).
const GRACEFUL_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(10);
/// Server configuration: an axum [`Router`] plus listening, TLS, interception and shutdown settings.
///
/// Create one with [`new_server`], adjust it with the `with_*` methods and start it with `run`.
pub struct Server<I: ReqInterceptor = DummyInterceptor> {
/// TCP port handed to `create_dual_stack_listener`.
pub port: u16,
/// TLS settings; `None` (or a value whose `tls` flag is `false`) makes `run` serve plaintext.
pub tls_param: Option<TlsParam>,
/// Application router that receives every request the interceptor does not short-circuit.
router: Router,
/// Optional hook consulted before each request is dispatched to the router.
pub interceptor: Option<I>,
/// Timeout handed to `io::TimeoutIO` for every accepted connection.
pub idle_timeout: Duration,
/// Receiver whose next completed `recv()` makes the serve loop begin graceful shutdown.
shutdown_rx: broadcast::Receiver<()>,
}
/// TLS settings for [`Server`].
#[derive(Debug, Clone)]
pub struct TlsParam {
/// Whether TLS is enabled; when `false`, `Server::run` serves plaintext.
pub tls: bool,
/// Certificate argument forwarded to `tls_config` (presumably a PEM file path — confirm in `util::tls`).
pub cert: String,
/// Private-key argument forwarded to `tls_config` (presumably a PEM file path — confirm in `util::tls`).
pub key: String,
}
/// Outcome of [`ReqInterceptor::intercept`], telling the request handler how to proceed.
pub enum InterceptResult<T: IntoResponse> {
/// Answer immediately with this response; the router is not called.
Return(Response),
/// Send no response: the request handler fails with an I/O error ("Request dropped by interceptor").
Drop,
/// Forward this request to the router.
Continue(Request<Incoming>),
/// Answer with the response obtained by converting this error through `IntoResponse`.
Error(T),
}
/// Hook that inspects every incoming request before it is dispatched to the router.
pub trait ReqInterceptor: Send {
/// Error value that is turned into the HTTP response when `intercept` yields [`InterceptResult::Error`].
type Error: IntoResponse + Send + Sync + 'static;
/// Decides what happens to `req`; `ip` is the socket address of the peer that sent it.
fn intercept(&self, req: Request<Incoming>, ip: SocketAddr) -> impl std::future::Future<Output = InterceptResult<Self::Error>> + Send;
}
/// Interceptor that lets every request through unchanged; it is the default type parameter of [`Server`].
#[derive(Clone)]
pub struct DummyInterceptor;
impl ReqInterceptor for DummyInterceptor {
type Error = error::AppError;
// Always hands the request back untouched so that it continues to the router.
async fn intercept(&self, req: Request<Incoming>, _ip: SocketAddr) -> InterceptResult<Self::Error> {
InterceptResult::Continue(req)
}
}
/// A [`Server`] whose interceptor type is the pass-through [`DummyInterceptor`].
pub type DefaultServer = Server<DummyInterceptor>;
/// Creates a [`Server`] for `router` on `port` with the defaults: no TLS, no interceptor and a
/// 120-second connection timeout.
///
/// `shutdown_rx` is the receiver the serve loop watches to start graceful shutdown.
pub fn new_server(port: u16, router: Router, shutdown_rx: broadcast::Receiver<()>) -> Server {
    let idle_timeout = Duration::from_secs(120);
    Server {
        port,
        router,
        shutdown_rx,
        idle_timeout,
        tls_param: None,
        interceptor: None,
    }
}
impl<I> Server<I>
where
    I: ReqInterceptor + Clone + Send + Sync + 'static,
{
    /// Replaces the interceptor (and with it the server's interceptor type), keeping every
    /// other setting.
    pub fn with_interceptor<R>(self, interceptor: R) -> Server<R>
    where
        R: ReqInterceptor + Clone + Send + Sync + 'static,
    {
        let Server { port, tls_param, router, idle_timeout, shutdown_rx, .. } = self;
        Server::<R> {
            port,
            tls_param,
            router,
            interceptor: Some(interceptor),
            idle_timeout,
            shutdown_rx,
        }
    }

    /// Sets the TLS settings; `None` serves plaintext.
    pub fn with_tls_param(mut self, tls_param: Option<TlsParam>) -> Self {
        self.tls_param = tls_param;
        self
    }

    /// Sets the timeout applied to every accepted connection.
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.idle_timeout = timeout;
        self
    }

    /// Runs the server until the shutdown receiver fires and the graceful shutdown phase ends.
    ///
    /// TLS is used only when `tls_param` is `Some` and its `tls` flag is `true`; otherwise the
    /// server listens in plaintext.
    ///
    /// # Errors
    ///
    /// Returns the I/O error reported by `serve_tls` / `serve_plantext`, for example when the
    /// listener cannot be created or the initial TLS configuration cannot be built.
    pub async fn run(self) -> Result<(), std::io::Error> {
        let Server { port, tls_param, router, interceptor, idle_timeout, mut shutdown_rx } = self;
        // Keep the parameters only when TLS is actually enabled, so the match below cannot
        // reach the TLS branch without them (no clone, no `expect`).
        let tls_param = tls_param.filter(|param| param.tls);
        log::info!("listening on port {}, use_tls: {}", port, tls_param.is_some());
        let server: hyper_util::server::conn::auto::Builder<TokioExecutor> = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
        let graceful: hyper_util::server::graceful::GracefulShutdown = hyper_util::server::graceful::GracefulShutdown::new();
        match tls_param {
            Some(tls_param) => {
                serve_tls(&router, server, graceful, port, &tls_param, interceptor, idle_timeout, &mut shutdown_rx).await
            }
            None => serve_plantext(&router, server, graceful, port, interceptor, idle_timeout, &mut shutdown_rx).await,
        }
    }
}
async fn handle<I>(
request: Request<Incoming>, client_socket_addr: SocketAddr, app: axum::middleware::AddExtension<Router, axum::extract::ConnectInfo<SocketAddr>>,
interceptor: Option<I>,
) -> std::result::Result<Response, std::io::Error>
where
I: ReqInterceptor + Clone + Send + Sync + 'static,
{
if let Some(interceptor) = interceptor {
match interceptor.intercept(request, client_socket_addr).await {
InterceptResult::Return(res) => Ok(res),
InterceptResult::Drop => Err(std::io::Error::other("Request dropped by interceptor")),
InterceptResult::Continue(req) => app
.oneshot(req)
.await
.map_err(|err| std::io::Error::new(std::io::ErrorKind::Interrupted, err)),
InterceptResult::Error(err) => {
let res = err.into_response();
Ok(res)
}
}
} else {
app.oneshot(request)
.await
.map_err(|err| std::io::Error::new(std::io::ErrorKind::Interrupted, err))
}
}
/// Wires one accepted connection to the application and spawns a task that drives it.
///
/// The stream is wrapped in `io::TimeoutIO` with `timeout`, every request on it is routed through
/// [`handle`], and the connection is registered with `graceful` before being spawned. This
/// function returns once the task is spawned; it does not wait for the connection to finish.
async fn handle_connection<C, I>(
    conn: C, client_socket_addr: std::net::SocketAddr, app: Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>,
    interceptor: Option<I>, graceful: &hyper_util::server::graceful::GracefulShutdown, timeout: Duration,
) where
    C: tokio::io::AsyncRead + tokio::io::AsyncWrite + 'static + Send + Sync,
    I: ReqInterceptor + Clone + Send + Sync + 'static,
{
    use hyper::Request;
    use hyper_util::rt::TokioIo;

    let hyper_io = TokioIo::new(Box::pin(io::TimeoutIO::new(conn, timeout)));

    // Attach the peer address to the router so handlers can extract `ConnectInfo<SocketAddr>`.
    let mut make_service = app.into_make_service_with_connect_info::<SocketAddr>();
    let connected_app: axum::middleware::AddExtension<Router, axum::extract::ConnectInfo<SocketAddr>> =
        unwrap_infallible(make_service.call(client_socket_addr).await);

    let per_request = hyper::service::service_fn(move |request: Request<hyper::body::Incoming>| {
        handle(request, client_socket_addr, connected_app.clone(), interceptor.clone())
    });

    // `into_owned` detaches the connection from the borrowed builder so it can move into the task.
    let watched = graceful.watch(server.serve_connection_with_upgrades(hyper_io, per_request).into_owned());
    tokio::spawn(async move {
        if let Err(err) = watched.await {
            handle_hyper_error(client_socket_addr, err);
        }
        log::debug!("dropped: {client_socket_addr}");
    });
}
fn handle_hyper_error(client_socket_addr: SocketAddr, http_err: DynError) {
use std::error::Error;
match http_err.downcast_ref::<hyper::Error>() {
Some(hyper_err) => {
let level = if hyper_err.is_user() { log::Level::Warn } else { log::Level::Debug };
let source = hyper_err.source().unwrap_or(hyper_err);
log::log!(
level,
"[hyper {}]: {:?} from {}",
if hyper_err.is_user() { "user" } else { "system" },
source,
SocketAddrFormat(&client_socket_addr)
);
}
None => match http_err.downcast_ref::<std::io::Error>() {
Some(io_err) => {
warn!("[hyper io]: [{}] {} from {}", io_err.kind(), io_err, SocketAddrFormat(&client_socket_addr));
}
None => {
warn!("[hyper]: {} from {}", http_err, SocketAddrFormat(&client_socket_addr));
}
},
}
}
async fn serve_plantext<I>(
app: &Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>, graceful: hyper_util::server::graceful::GracefulShutdown,
port: u16, interceptor: Option<I>, timeout: Duration, shutdown_rx: &mut broadcast::Receiver<()>,
) -> Result<(), std::io::Error>
where
I: ReqInterceptor + Clone + Send + Sync + 'static,
{
let listener = create_dual_stack_listener(port).await?;
loop {
tokio::select! {
_ = shutdown_rx.recv() => {
info!("start graceful shutdown!");
drop(listener);
break;
}
conn = listener.accept() => {
match conn {
Ok((conn, client_socket_addr)) => {
handle_connection(conn,client_socket_addr, app.clone(), server.clone(),interceptor.clone(), &graceful, timeout).await;}
Err(e) => {
warn!("accept error:{e}");
}
}
}
}
}
match tokio::time::timeout(GRACEFUL_SHUTDOWN_TIMEOUT, graceful.shutdown()).await {
Ok(_) => info!("Gracefully shutdown!"),
Err(_) => info!("Waited {GRACEFUL_SHUTDOWN_TIMEOUT:?} for graceful shutdown, aborting..."),
}
Ok(())
}
#[allow(clippy::too_many_arguments)]
async fn serve_tls<I>(
app: &Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>, graceful: hyper_util::server::graceful::GracefulShutdown,
port: u16, tls_param: &TlsParam, interceptor: Option<I>, timeout: Duration, shutdown_rx: &mut broadcast::Receiver<()>,
) -> Result<(), std::io::Error>
where
I: ReqInterceptor + Clone + Send + Sync + 'static,
{
let (tx, mut rx) = broadcast::channel::<Arc<ServerConfig>>(1);
let tls_param_clone = tls_param.clone();
tokio::spawn(async move {
info!("update tls config every {REFRESH_INTERVAL:?}");
loop {
time::sleep(REFRESH_INTERVAL).await;
if let Ok(new_acceptor) = tls_config(&tls_param_clone.key, &tls_param_clone.cert) {
info!("update tls config");
if let Err(e) = tx.send(new_acceptor) {
warn!("send tls config error:{e}");
}
}
}
});
let mut acceptor: TlsAcceptor = TlsAcceptor::new(tls_config(&tls_param.key, &tls_param.cert)?, create_dual_stack_listener(port).await?);
loop {
tokio::select! {
_ = shutdown_rx.recv() => {
info!("start graceful shutdown!");
drop(acceptor);
break;
}
message = rx.recv() => {
match message {
Ok(new_config) => {
acceptor.replace_config(new_config);
info!("replaced tls config");
},
Err(e) => {
match e {
RecvError::Closed => {
warn!("this channel should not be closed!");
break;
},
RecvError::Lagged(n) => {
warn!("lagged {n} messages, this may cause tls config not updated in time");
}
}
}
}
}
conn = acceptor.accept() => {
match conn {
Ok((conn, client_socket_addr)) => {
handle_connection(conn,client_socket_addr, app.clone(), server.clone(),interceptor.clone(), &graceful, timeout).await;}
Err(e) => {
warn!("accept error:{e}");
}
}
}
}
}
match tokio::time::timeout(GRACEFUL_SHUTDOWN_TIMEOUT, graceful.shutdown()).await {
Ok(_) => info!("Gracefully shutdown!"),
Err(_) => info!("Waited {GRACEFUL_SHUTDOWN_TIMEOUT:?} for graceful shutdown, aborting..."),
}
Ok(())
}
/// Creates a capacity-1 broadcast channel, hands its sender to [`subscribe_shutdown_sender`] and
/// returns the receiver.
///
/// Because `subscribe_shutdown_sender` calls `tokio::spawn`, this must be called from within a
/// Tokio runtime.
pub fn generate_shutdown_receiver() -> Receiver<()> {
    let (sender, receiver) = tokio::sync::broadcast::channel::<()>(1);
    subscribe_shutdown_sender(sender);
    receiver
}
/// Spawns a task that waits for `wait_signal` and then sends `()` on `shutdown_tx`.
///
/// If `wait_signal` fails, the error is logged and the spawned task panics. A failed `send`
/// (no receiver left) is ignored.
pub fn subscribe_shutdown_sender(shutdown_tx: Sender<()>) {
    tokio::spawn(async move {
        if let Err(e) = wait_signal().await {
            log::error!("wait_signal error: {}", e);
            panic!("wait_signal error: {}", e);
        }
        let _ = shutdown_tx.send(());
    });
}
/// Completes when the process receives SIGTERM or Ctrl-C.
///
/// # Errors
///
/// Returns an error when the SIGTERM handler cannot be installed or when waiting for Ctrl-C
/// fails; a failure is reported as an error rather than logged as a received signal.
#[cfg(unix)]
pub(crate) async fn wait_signal() -> Result<(), DynError> {
    use log::info;
    use tokio::signal::unix::{SignalKind, signal};
    let mut terminate_signal = signal(SignalKind::terminate())?;
    tokio::select! {
        _ = terminate_signal.recv() => {
            info!("receive terminate signal");
        },
        result = tokio::signal::ctrl_c() => {
            // Propagate a listener failure instead of treating it as a received signal.
            result?;
            info!("receive ctrl_c signal");
        },
    };
    Ok(())
}
/// Completes when the process receives Ctrl-C.
///
/// # Errors
///
/// Returns an error when waiting for Ctrl-C fails; a failure is reported as an error rather
/// than logged as a received signal.
#[cfg(windows)]
pub(crate) async fn wait_signal() -> Result<(), DynError> {
    tokio::signal::ctrl_c().await?;
    info!("receive ctrl_c signal");
    Ok(())
}
/// Extracts the value from a `Result` whose error type is [`Infallible`].
///
/// The error arm can never be taken, which the empty `match` on the uninhabited value proves to
/// the compiler, so no panic path exists.
fn unwrap_infallible<T>(result: Result<T, Infallible>) -> T {
    result.unwrap_or_else(|never| match never {})
}