axum_bootstrap/
lib.rs

1use std::{convert::Infallible, net::SocketAddr, sync::Arc, time::Duration};
2pub mod error;
3pub mod init_log;
4#[cfg(feature = "jwt")]
5pub mod jwt;
6pub mod util;
/// Boxed, thread-safe error used by the signal-waiting and hyper-error-logging paths.
type DynError = Box<dyn std::error::Error + Sync + Send>;
8use crate::util::{
9    io::{self, create_dual_stack_listener},
10    tls::{TlsAcceptor, tls_config},
11};
12
13use axum::{
14    Router,
15    extract::Request,
16    response::{IntoResponse, Response},
17};
18
19use hyper::body::Incoming;
20use hyper_util::rt::TokioExecutor;
21use log::{info, warn};
22use tokio::{
23    sync::broadcast::{self, Receiver, Sender, error::RecvError},
24    time,
25};
26use tokio_rustls::rustls::ServerConfig;
27use tower::{Service, ServiceExt};
28use util::format::SocketAddrFormat;
29
/// How often the TLS configuration is rebuilt from `TlsParam::key` / `TlsParam::cert` (one day).
const REFRESH_INTERVAL: Duration = Duration::from_secs(24 * 60 * 60);
/// Upper bound on how long open connections may drain after a shutdown signal.
const GRACEFUL_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(10);
32
33pub struct Server<I: ReqInterceptor = DummyInterceptor> {
34    pub port: u16,
35    pub tls_param: Option<TlsParam>,
36    router: Router,
37    pub interceptor: Option<I>,
38    pub idle_timeout: Duration,
39    shutdown_rx: broadcast::Receiver<()>,
40}
41
/// TLS settings for a [`Server`].
///
/// `cert` and `key` are passed to `tls_config` at startup and on every periodic refresh;
/// presumably they are file paths (TODO confirm against `util::tls`).
#[derive(Debug, Clone)]
pub struct TlsParam {
    /// Master switch: `false` makes the server fall back to plain HTTP.
    pub tls: bool,
    /// Certificate argument for `tls_config`.
    pub cert: String,
    /// Private-key argument for `tls_config`.
    pub key: String,
}
48
49pub enum InterceptResult<T: IntoResponse> {
50    Return(Response),
51    Drop,
52    Continue(Request<Incoming>),
53    Error(T),
54}
55
56pub trait ReqInterceptor: Send {
57    type Error: IntoResponse + Send + Sync + 'static;
58    fn intercept(&self, req: Request<Incoming>, ip: SocketAddr) -> impl std::future::Future<Output = InterceptResult<Self::Error>> + Send;
59}
60
/// Pass-through interceptor: every request is forwarded to the router untouched.
///
/// As a stateless unit struct it is also `Debug`, `Default` and `Copy`, so it can be
/// printed and constructed generically like any other public type.
#[derive(Debug, Default, Clone, Copy)]
pub struct DummyInterceptor;
63
64impl ReqInterceptor for DummyInterceptor {
65    type Error = error::AppError;
66
67    async fn intercept(&self, req: Request<Incoming>, _ip: SocketAddr) -> InterceptResult<Self::Error> {
68        InterceptResult::Continue(req)
69    }
70}
71
72pub type DefaultServer = Server<DummyInterceptor>;
73
74pub fn new_server(port: u16, router: Router, shutdown_rx: broadcast::Receiver<()>) -> Server {
75    Server {
76        port,
77        tls_param: None, // No TLS by default
78        router,
79        interceptor: None,
80        idle_timeout: Duration::from_secs(120),
81        shutdown_rx,
82    }
83}
84
85impl<I> Server<I>
86where
87    I: ReqInterceptor + Clone + Send + Sync + 'static,
88{
89    pub fn with_interceptor<R>(self: Server<I>, interceptor: R) -> Server<R>
90    where
91        R: ReqInterceptor + Clone + Send + Sync + 'static,
92    {
93        Server::<R> {
94            port: self.port,
95            tls_param: self.tls_param,
96            router: self.router,
97            interceptor: Some(interceptor),
98            idle_timeout: self.idle_timeout, // keep the same idle timeout
99            shutdown_rx: self.shutdown_rx,
100        }
101    }
102    pub fn with_tls_param(mut self, tls_param: Option<TlsParam>) -> Self {
103        // Enable TLS by setting the tls_param
104        self.tls_param = tls_param;
105        self
106    }
107
108    pub fn with_timeout(mut self, timeout: Duration) -> Self {
109        self.idle_timeout = timeout;
110        self
111    }
112
113    pub async fn run(mut self) -> Result<(), std::io::Error> {
114        let use_tls = match self.tls_param.clone() {
115            Some(config) => config.tls,
116            None => false,
117        };
118        log::info!("listening on port {}, use_tls: {}", self.port, use_tls);
119        let server: hyper_util::server::conn::auto::Builder<TokioExecutor> = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
120        let graceful: hyper_util::server::graceful::GracefulShutdown = hyper_util::server::graceful::GracefulShutdown::new();
121        match use_tls {
122            #[allow(clippy::expect_used)]
123            true => {
124                serve_tls(
125                    &self.router,
126                    server,
127                    graceful,
128                    self.port,
129                    self.tls_param.as_ref().expect("should be some"),
130                    self.interceptor.clone(),
131                    self.idle_timeout,
132                    &mut self.shutdown_rx,
133                )
134                .await?
135            }
136            false => {
137                serve_plantext(&self.router, server, graceful, self.port, self.interceptor.clone(), self.idle_timeout, &mut self.shutdown_rx).await?
138            }
139        }
140        Ok(())
141    }
142}
143
144async fn handle<I>(
145    request: Request<Incoming>, client_socket_addr: SocketAddr, app: axum::middleware::AddExtension<Router, axum::extract::ConnectInfo<SocketAddr>>,
146    interceptor: Option<I>,
147) -> std::result::Result<Response, std::io::Error>
148where
149    I: ReqInterceptor + Clone + Send + Sync + 'static,
150{
151    if let Some(interceptor) = interceptor {
152        match interceptor.intercept(request, client_socket_addr).await {
153            InterceptResult::Return(res) => Ok(res),
154            InterceptResult::Drop => Err(std::io::Error::other("Request dropped by interceptor")),
155            InterceptResult::Continue(req) => app
156                .oneshot(req)
157                .await
158                .map_err(|err| std::io::Error::new(std::io::ErrorKind::Interrupted, err)),
159            InterceptResult::Error(err) => {
160                let res = err.into_response();
161                Ok(res)
162            }
163        }
164    } else {
165        app.oneshot(request)
166            .await
167            .map_err(|err| std::io::Error::new(std::io::ErrorKind::Interrupted, err))
168    }
169}
170
171async fn handle_connection<C, I>(
172    conn: C, client_socket_addr: std::net::SocketAddr, app: Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>,
173    interceptor: Option<I>, graceful: &hyper_util::server::graceful::GracefulShutdown, timeout: Duration,
174) where
175    C: tokio::io::AsyncRead + tokio::io::AsyncWrite + 'static + Send + Sync,
176    I: ReqInterceptor + Clone + Send + Sync + 'static,
177{
178    let timeout_io = Box::pin(io::TimeoutIO::new(conn, timeout));
179    use hyper::Request;
180    use hyper_util::rt::TokioIo;
181    let stream = TokioIo::new(timeout_io);
182    let mut app = app.into_make_service_with_connect_info::<SocketAddr>();
183    let app: axum::middleware::AddExtension<Router, axum::extract::ConnectInfo<SocketAddr>> = unwrap_infallible(app.call(client_socket_addr).await);
184    // https://github.com/tokio-rs/axum/blob/main/examples/serve-with-hyper/src/main.rs#L81
185    let hyper_service = hyper::service::service_fn(move |request: Request<hyper::body::Incoming>| {
186        handle(request, client_socket_addr, app.clone(), interceptor.clone())
187    });
188
189    let conn = server.serve_connection_with_upgrades(stream, hyper_service);
190    let conn = graceful.watch(conn.into_owned());
191
192    tokio::spawn(async move {
193        if let Err(err) = conn.await {
194            handle_hyper_error(client_socket_addr, err);
195        }
196        log::debug!("connection dropped: {client_socket_addr}");
197    });
198}
199
200fn handle_hyper_error(client_socket_addr: SocketAddr, http_err: DynError) {
201    use std::error::Error;
202    match http_err.downcast_ref::<hyper::Error>() {
203        Some(hyper_err) => {
204            let level = if hyper_err.is_user() { log::Level::Warn } else { log::Level::Debug };
205            let source = hyper_err.source().unwrap_or(hyper_err);
206            log::log!(
207                level,
208                "[hyper {}]: {:?} from {}",
209                if hyper_err.is_user() { "user" } else { "system" },
210                source,
211                SocketAddrFormat(&client_socket_addr)
212            );
213        }
214        None => match http_err.downcast_ref::<std::io::Error>() {
215            Some(io_err) => {
216                warn!("[hyper io]: [{}] {} from {}", io_err.kind(), io_err, SocketAddrFormat(&client_socket_addr));
217            }
218            None => {
219                warn!("[hyper]: {} from {}", http_err, SocketAddrFormat(&client_socket_addr));
220            }
221        },
222    }
223}
224
225async fn serve_plantext<I>(
226    app: &Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>, graceful: hyper_util::server::graceful::GracefulShutdown,
227    port: u16, interceptor: Option<I>, timeout: Duration, shutdown_rx: &mut broadcast::Receiver<()>,
228) -> Result<(), std::io::Error>
229where
230    I: ReqInterceptor + Clone + Send + Sync + 'static,
231{
232    let listener = create_dual_stack_listener(port).await?;
233    loop {
234        tokio::select! {
235            _ = shutdown_rx.recv() => {
236                info!("start graceful shutdown!");
237                drop(listener);
238                break;
239            }
240            conn = listener.accept() => {
241                match conn {
242                    Ok((conn, client_socket_addr)) => {
243                        handle_connection(conn,client_socket_addr, app.clone(), server.clone(),interceptor.clone(), &graceful, timeout).await;}
244                    Err(e) => {
245                        warn!("accept error:{e}");
246                    }
247                }
248            }
249        }
250    }
251    match tokio::time::timeout(GRACEFUL_SHUTDOWN_TIMEOUT, graceful.shutdown()).await {
252        Ok(_) => info!("Gracefully shutdown!"),
253        Err(_) => info!("Waited {GRACEFUL_SHUTDOWN_TIMEOUT:?} for graceful shutdown, aborting..."),
254    }
255    Ok(())
256}
257
258#[allow(clippy::too_many_arguments)]
259async fn serve_tls<I>(
260    app: &Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>, graceful: hyper_util::server::graceful::GracefulShutdown,
261    port: u16, tls_param: &TlsParam, interceptor: Option<I>, timeout: Duration, shutdown_rx: &mut broadcast::Receiver<()>,
262) -> Result<(), std::io::Error>
263where
264    I: ReqInterceptor + Clone + Send + Sync + 'static,
265{
266    let (tx, mut rx) = broadcast::channel::<Arc<ServerConfig>>(1);
267    let tls_param_clone = tls_param.clone();
268    tokio::spawn(async move {
269        info!("update tls config every {REFRESH_INTERVAL:?}");
270        loop {
271            time::sleep(REFRESH_INTERVAL).await;
272            if let Ok(new_acceptor) = tls_config(&tls_param_clone.key, &tls_param_clone.cert) {
273                info!("update tls config");
274                if let Err(e) = tx.send(new_acceptor) {
275                    warn!("send tls config error:{e}");
276                }
277            }
278        }
279    });
280    let mut acceptor: TlsAcceptor = TlsAcceptor::new(tls_config(&tls_param.key, &tls_param.cert)?, create_dual_stack_listener(port).await?);
281    loop {
282        tokio::select! {
283            _ = shutdown_rx.recv() => {
284                info!("start graceful shutdown!");
285                drop(acceptor);
286                break;
287            }
288            message = rx.recv() => {
289                match message {
290                    Ok(new_config) => {
291                        acceptor.replace_config(new_config);
292                        info!("replaced tls config");
293                    },
294                    Err(e) => {
295                        match e {
296                            RecvError::Closed => {
297                                warn!("this channel should not be closed!");
298                                break;
299                            },
300                            RecvError::Lagged(n) => {
301                                warn!("lagged {n} messages, this may cause tls config not updated in time");
302                            }
303                        }
304                    }
305                }
306            }
307            conn = acceptor.accept() => {
308                match conn {
309                    Ok((conn, client_socket_addr)) => {
310                        handle_connection(conn,client_socket_addr, app.clone(), server.clone(),interceptor.clone(), &graceful, timeout).await;}
311                    Err(e) => {
312                        warn!("accept error:{e}");
313                    }
314                }
315            }
316        }
317    }
318    match tokio::time::timeout(GRACEFUL_SHUTDOWN_TIMEOUT, graceful.shutdown()).await {
319        Ok(_) => info!("Gracefully shutdown!"),
320        Err(_) => info!("Waited {GRACEFUL_SHUTDOWN_TIMEOUT:?} for graceful shutdown, aborting..."),
321    }
322    Ok(())
323}
324
325pub fn generate_shutdown_receiver() -> Receiver<()> {
326    let (shutdown_tx, shutdown_rx) = tokio::sync::broadcast::channel::<()>(1);
327    subscribe_shutdown_sender(shutdown_tx);
328    shutdown_rx
329}
330
331pub fn subscribe_shutdown_sender(shutdown_tx: Sender<()>) {
332    tokio::spawn(async move {
333        match wait_signal().await {
334            Ok(_) => {
335                let _ = shutdown_tx.send(());
336            }
337            Err(e) => {
338                log::error!("wait_signal error: {}", e);
339                panic!("wait_signal error: {}", e);
340            }
341        }
342    });
343}
344
345#[cfg(unix)]
346pub(crate) async fn wait_signal() -> Result<(), DynError> {
347    use log::info;
348    use tokio::signal::unix::{SignalKind, signal};
349    let mut terminate_signal = signal(SignalKind::terminate())?;
350    tokio::select! {
351        _ = terminate_signal.recv() => {
352            info!("receive terminate signal");
353        },
354        _ = tokio::signal::ctrl_c() => {
355            info!("receive ctrl_c signal");
356        },
357    };
358    Ok(())
359}
360
/// Waits for Ctrl-C.
///
/// # Errors
/// Returns an error if the Ctrl-C handler cannot be installed, rather than treating that
/// failure as a received signal.
#[cfg(windows)]
pub(crate) async fn wait_signal() -> Result<(), DynError> {
    tokio::signal::ctrl_c().await?;
    info!("receive ctrl_c signal");
    Ok(())
}
367
/// Extracts the value from a `Result` whose error type has no values.
fn unwrap_infallible<T>(result: Result<T, Infallible>) -> T {
    result.unwrap_or_else(|never| match never {})
}