1use std::{convert::Infallible, net::SocketAddr, sync::Arc, time::Duration};
28
29pub mod error;
31pub mod init_log;
33#[cfg(feature = "jwt")]
35pub mod jwt;
36pub mod util;
38
/// Boxed, thread-safe error object used where the concrete error type is not
/// known up front (connection errors, signal registration failures).
type DynError = Box<dyn std::error::Error + Send + Sync + 'static>;
41
42use crate::util::{
43 io::{self, create_dual_stack_listener},
44 tls::{TlsAcceptor, tls_config},
45};
46
47use axum::{
48 Router,
49 extract::Request,
50 response::{IntoResponse, Response},
51};
52
53use hyper::body::Incoming;
54use hyper_util::rt::TokioExecutor;
55use log::{info, warn};
56use tokio::{
57 sync::broadcast::{self, Receiver, Sender, error::RecvError},
58 time,
59};
60use tokio_rustls::rustls::ServerConfig;
61use tower::{Service, ServiceExt};
62use util::format::SocketAddrFormat;
63
/// Time the background task in `serve_tls` sleeps between two reloads of the
/// TLS configuration (one day).
const REFRESH_INTERVAL: Duration = Duration::from_secs(24 * 60 * 60);

/// Upper bound on how long the serve loops wait for `graceful.shutdown()` to
/// complete before giving up on the remaining connections.
const GRACEFUL_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(10);
69
70pub struct Server<I: ReqInterceptor = DummyInterceptor> {
83 pub port: u16,
84 pub tls_param: Option<TlsParam>,
85 router: Router,
86 pub interceptor: Option<I>,
87 pub idle_timeout: Duration,
88 shutdown_rx: broadcast::Receiver<()>,
89}
90
/// TLS settings for a [`Server`].
#[derive(Clone, Debug)]
pub struct TlsParam {
    /// Whether TLS should actually be used; when `false` the server serves plaintext.
    pub tls: bool,
    /// Certificate argument forwarded to `tls_config`.
    // NOTE(review): presumably a path to a PEM file — confirm against `util::tls::tls_config`.
    pub cert: String,
    /// Private-key argument forwarded to `tls_config`.
    // NOTE(review): presumably a path to a PEM file — confirm against `util::tls::tls_config`.
    pub key: String,
}
103
104pub enum InterceptResult<T: IntoResponse> {
114 Return(Response),
115 Drop,
116 Continue(Request<Incoming>),
117 Error(T),
118}
119
120pub trait ReqInterceptor: Send {
151 type Error: IntoResponse + Send + Sync + 'static;
152 fn intercept(&self, req: Request<Incoming>, ip: SocketAddr) -> impl std::future::Future<Output = InterceptResult<Self::Error>> + Send;
153}
154
/// Interceptor that does nothing: every request is passed on to the router.
///
/// It is the default interceptor type of [`Server`].
#[derive(Debug, Clone, Copy, Default)]
pub struct DummyInterceptor;
160
161impl ReqInterceptor for DummyInterceptor {
162 type Error = error::AppError;
163
164 async fn intercept(&self, req: Request<Incoming>, _ip: SocketAddr) -> InterceptResult<Self::Error> {
165 InterceptResult::Continue(req)
166 }
167}
168
169pub type DefaultServer = Server<DummyInterceptor>;
171
172pub fn new_server(port: u16, router: Router, shutdown_rx: broadcast::Receiver<()>) -> Server {
197 Server {
198 port,
199 tls_param: None, router,
201 interceptor: None,
202 idle_timeout: Duration::from_secs(120),
203 shutdown_rx,
204 }
205}
206
207impl<I> Server<I>
208where
209 I: ReqInterceptor + Clone + Send + Sync + 'static,
210{
211 pub fn with_interceptor<R>(self: Server<I>, interceptor: R) -> Server<R>
224 where
225 R: ReqInterceptor + Clone + Send + Sync + 'static,
226 {
227 Server::<R> {
228 port: self.port,
229 tls_param: self.tls_param,
230 router: self.router,
231 interceptor: Some(interceptor),
232 idle_timeout: self.idle_timeout, shutdown_rx: self.shutdown_rx,
234 }
235 }
236
237 pub fn with_tls_param(mut self, tls_param: Option<TlsParam>) -> Self {
245 self.tls_param = tls_param;
246 self
247 }
248
249 pub fn with_timeout(mut self, timeout: Duration) -> Self {
257 self.idle_timeout = timeout;
258 self
259 }
260
261 pub async fn run(mut self) -> Result<(), std::io::Error> {
274 let use_tls = match self.tls_param.clone() {
275 Some(config) => config.tls,
276 None => false,
277 };
278 log::info!("listening on port {}, use_tls: {}", self.port, use_tls);
279 let server: hyper_util::server::conn::auto::Builder<TokioExecutor> = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
280 let graceful: hyper_util::server::graceful::GracefulShutdown = hyper_util::server::graceful::GracefulShutdown::new();
281 match use_tls {
282 #[allow(clippy::expect_used)]
283 true => {
284 serve_tls(
285 &self.router,
286 server,
287 graceful,
288 self.port,
289 self.tls_param.as_ref().expect("should be some"),
290 self.interceptor.clone(),
291 self.idle_timeout,
292 &mut self.shutdown_rx,
293 )
294 .await?
295 }
296 false => {
297 serve_plantext(&self.router, server, graceful, self.port, self.interceptor.clone(), self.idle_timeout, &mut self.shutdown_rx).await?
298 }
299 }
300 Ok(())
301 }
302}
303
304async fn handle<I>(
318 request: Request<Incoming>, client_socket_addr: SocketAddr, app: axum::middleware::AddExtension<Router, axum::extract::ConnectInfo<SocketAddr>>,
319 interceptor: Option<I>,
320) -> std::result::Result<Response, std::io::Error>
321where
322 I: ReqInterceptor + Clone + Send + Sync + 'static,
323{
324 if let Some(interceptor) = interceptor {
325 match interceptor.intercept(request, client_socket_addr).await {
326 InterceptResult::Return(res) => Ok(res),
327 InterceptResult::Drop => Err(std::io::Error::other("Request dropped by interceptor")),
328 InterceptResult::Continue(req) => app
329 .oneshot(req)
330 .await
331 .map_err(|err| std::io::Error::new(std::io::ErrorKind::Interrupted, err)),
332 InterceptResult::Error(err) => {
333 let res = err.into_response();
334 Ok(res)
335 }
336 }
337 } else {
338 app.oneshot(request)
339 .await
340 .map_err(|err| std::io::Error::new(std::io::ErrorKind::Interrupted, err))
341 }
342}
343
344async fn handle_connection<C, I>(
361 conn: C, client_socket_addr: std::net::SocketAddr, app: Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>,
362 interceptor: Option<I>, graceful: &hyper_util::server::graceful::GracefulShutdown, timeout: Duration,
363) where
364 C: tokio::io::AsyncRead + tokio::io::AsyncWrite + 'static + Send + Sync,
365 I: ReqInterceptor + Clone + Send + Sync + 'static,
366{
367 let timeout_io = Box::pin(io::TimeoutIO::new(conn, timeout));
368 use hyper::Request;
369 use hyper_util::rt::TokioIo;
370 let stream = TokioIo::new(timeout_io);
371 let mut app = app.into_make_service_with_connect_info::<SocketAddr>();
372 let app: axum::middleware::AddExtension<Router, axum::extract::ConnectInfo<SocketAddr>> = unwrap_infallible(app.call(client_socket_addr).await);
373 let hyper_service = hyper::service::service_fn(move |request: Request<hyper::body::Incoming>| {
375 handle(request, client_socket_addr, app.clone(), interceptor.clone())
376 });
377
378 let conn = server.serve_connection_with_upgrades(stream, hyper_service);
379 let conn = graceful.watch(conn.into_owned());
380
381 tokio::spawn(async move {
382 if let Err(err) = conn.await {
383 handle_hyper_error(client_socket_addr, err);
384 }
385 log::debug!("dropped: {client_socket_addr}");
386 });
387}
388
389fn handle_hyper_error(client_socket_addr: SocketAddr, http_err: DynError) {
397 use std::error::Error;
398 match http_err.downcast_ref::<hyper::Error>() {
399 Some(hyper_err) => {
400 let level = if hyper_err.is_user() { log::Level::Warn } else { log::Level::Debug };
401 let source = hyper_err.source().unwrap_or(hyper_err);
402 log::log!(
403 level,
404 "[hyper {}]: {:?} from {}",
405 if hyper_err.is_user() { "user" } else { "system" },
406 source,
407 SocketAddrFormat(&client_socket_addr)
408 );
409 }
410 None => match http_err.downcast_ref::<std::io::Error>() {
411 Some(io_err) => {
412 warn!("[hyper io]: [{}] {} from {}", io_err.kind(), io_err, SocketAddrFormat(&client_socket_addr));
413 }
414 None => {
415 warn!("[hyper]: {} from {}", http_err, SocketAddrFormat(&client_socket_addr));
416 }
417 },
418 }
419}
420
421async fn serve_plantext<I>(
438 app: &Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>, graceful: hyper_util::server::graceful::GracefulShutdown,
439 port: u16, interceptor: Option<I>, timeout: Duration, shutdown_rx: &mut broadcast::Receiver<()>,
440) -> Result<(), std::io::Error>
441where
442 I: ReqInterceptor + Clone + Send + Sync + 'static,
443{
444 let listener = create_dual_stack_listener(port).await?;
445 loop {
446 tokio::select! {
447 _ = shutdown_rx.recv() => {
448 info!("start graceful shutdown!");
449 drop(listener);
450 break;
451 }
452 conn = listener.accept() => {
453 match conn {
454 Ok((conn, client_socket_addr)) => {
455 handle_connection(conn,client_socket_addr, app.clone(), server.clone(),interceptor.clone(), &graceful, timeout).await;}
456 Err(e) => {
457 warn!("accept error:{e}");
458 }
459 }
460 }
461 }
462 }
463 match tokio::time::timeout(GRACEFUL_SHUTDOWN_TIMEOUT, graceful.shutdown()).await {
464 Ok(_) => info!("Gracefully shutdown!"),
465 Err(_) => info!("Waited {GRACEFUL_SHUTDOWN_TIMEOUT:?} for graceful shutdown, aborting..."),
466 }
467 Ok(())
468}
469
470#[allow(clippy::too_many_arguments)]
491async fn serve_tls<I>(
492 app: &Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>, graceful: hyper_util::server::graceful::GracefulShutdown,
493 port: u16, tls_param: &TlsParam, interceptor: Option<I>, timeout: Duration, shutdown_rx: &mut broadcast::Receiver<()>,
494) -> Result<(), std::io::Error>
495where
496 I: ReqInterceptor + Clone + Send + Sync + 'static,
497{
498 let (tx, mut rx) = broadcast::channel::<Arc<ServerConfig>>(1);
499 let tls_param_clone = tls_param.clone();
500 tokio::spawn(async move {
501 info!("update tls config every {REFRESH_INTERVAL:?}");
502 loop {
503 time::sleep(REFRESH_INTERVAL).await;
504 if let Ok(new_acceptor) = tls_config(&tls_param_clone.key, &tls_param_clone.cert) {
505 info!("update tls config");
506 if let Err(e) = tx.send(new_acceptor) {
507 warn!("send tls config error:{e}");
508 }
509 }
510 }
511 });
512 let mut acceptor: TlsAcceptor = TlsAcceptor::new(tls_config(&tls_param.key, &tls_param.cert)?, create_dual_stack_listener(port).await?);
513 loop {
514 tokio::select! {
515 _ = shutdown_rx.recv() => {
516 info!("start graceful shutdown!");
517 drop(acceptor);
518 break;
519 }
520 message = rx.recv() => {
521 match message {
522 Ok(new_config) => {
523 acceptor.replace_config(new_config);
524 info!("replaced tls config");
525 },
526 Err(e) => {
527 match e {
528 RecvError::Closed => {
529 warn!("this channel should not be closed!");
530 break;
531 },
532 RecvError::Lagged(n) => {
533 warn!("lagged {n} messages, this may cause tls config not updated in time");
534 }
535 }
536 }
537 }
538 }
539 conn = acceptor.accept() => {
540 match conn {
541 Ok((conn, client_socket_addr)) => {
542 handle_connection(conn,client_socket_addr, app.clone(), server.clone(),interceptor.clone(), &graceful, timeout).await;}
543 Err(e) => {
544 warn!("accept error:{e}");
545 }
546 }
547 }
548 }
549 }
550 match tokio::time::timeout(GRACEFUL_SHUTDOWN_TIMEOUT, graceful.shutdown()).await {
551 Ok(_) => info!("Gracefully shutdown!"),
552 Err(_) => info!("Waited {GRACEFUL_SHUTDOWN_TIMEOUT:?} for graceful shutdown, aborting..."),
553 }
554 Ok(())
555}
556
557pub fn generate_shutdown_receiver() -> Receiver<()> {
573 let (shutdown_tx, shutdown_rx) = tokio::sync::broadcast::channel::<()>(1);
574 subscribe_shutdown_sender(shutdown_tx);
575 shutdown_rx
576}
577
578pub fn subscribe_shutdown_sender(shutdown_tx: Sender<()>) {
585 tokio::spawn(async move {
586 match wait_signal().await {
587 Ok(_) => {
588 let _ = shutdown_tx.send(());
589 }
590 Err(e) => {
591 log::error!("wait_signal error: {}", e);
592 panic!("wait_signal error: {}", e);
593 }
594 }
595 });
596}
597
598#[cfg(unix)]
606pub(crate) async fn wait_signal() -> Result<(), DynError> {
607 use log::info;
608 use tokio::signal::unix::{SignalKind, signal};
609 let mut terminate_signal = signal(SignalKind::terminate())?;
610 tokio::select! {
611 _ = terminate_signal.recv() => {
612 info!("receive terminate signal");
613 },
614 _ = tokio::signal::ctrl_c() => {
615 info!("receive ctrl_c signal");
616 },
617 };
618 Ok(())
619}
620
/// Waits until the process receives Ctrl-C.
///
/// # Errors
///
/// Returns an error when listening for Ctrl-C fails.
#[cfg(windows)]
pub(crate) async fn wait_signal() -> Result<(), DynError> {
    // An `Err` here means the handler could not be registered, not that a
    // signal arrived, so report it instead of pretending Ctrl-C was pressed.
    tokio::signal::ctrl_c().await?;
    info!("receive ctrl_c signal");
    Ok(())
}
634
/// Extracts the value of a `Result` whose error type has no values.
///
/// Because [`Infallible`] cannot be constructed, the error branch is
/// statically unreachable and this never panics.
fn unwrap_infallible<T>(result: Result<T, Infallible>) -> T {
    result.unwrap_or_else(|never| match never {})
}