axum_bootstrap/
lib.rs

1use std::{convert::Infallible, net::SocketAddr, sync::Arc, time::Duration};
2
3pub mod error;
4pub mod init_log;
5pub mod util;
/// Boxed, thread-safe (`Send + Sync`) error type used as the error half of this file's
/// fallible functions (`Server::run`, the serve helpers and `wait_signal`).
type DynError = Box<dyn std::error::Error + Send + Sync>;
7use crate::util::{
8    io::{self, create_dual_stack_listener},
9    tls::{tls_config, TlsAcceptor},
10};
11
12use axum::{
13    extract::Request,
14    response::{IntoResponse, Response},
15    Router,
16};
17
18use hyper::body::Incoming;
19use hyper_util::rt::TokioExecutor;
20use log::{info, warn};
21use tokio::{
22    sync::{broadcast, mpsc},
23    time,
24};
25use tokio_rustls::rustls::ServerConfig;
26use tower::{Service, ServiceExt};
27use util::format::SocketAddrFormat;
28
/// How often `serve_tls` rebuilds its TLS config from the `TlsParam` it was given (24 h).
const REFRESH_INTERVAL: Duration = Duration::from_secs(86_400);
/// Upper bound on how long shutdown waits for open connections to finish before giving up.
const GRACEFUL_SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(10_000);
31
/// HTTP server wrapping an axum [`Router`].
///
/// Configure it with the builder-style `with_*` methods and start it with `run`. It is stopped
/// through the shutdown channel whose sender is returned by [`new_server`].
pub struct Server<I: ReqInterceptor = DummyInterceptor> {
    /// Port to listen on (handed to `create_dual_stack_listener`).
    pub port: u16,
    /// TLS settings; TLS is only used when this is `Some` *and* its `tls` flag is `true`.
    pub tls_param: Option<TlsParam>,
    router: Router,
    /// Optional hook that sees every request before the router does.
    pub interceptor: Option<I>,
    /// Idle timeout applied to every accepted connection (via `io::TimeoutIO`).
    pub idle_timeout: Duration,
    // Receiving end of the shutdown channel; see `new_server`.
    shutdown_rx: mpsc::Receiver<()>,
}
40
/// TLS settings for [`Server`].
///
/// `cert` and `key` are passed to `tls_config` once at startup and again every
/// `REFRESH_INTERVAL`.
#[derive(Debug, Clone)]
pub struct TlsParam {
    /// Whether TLS is enabled; when `false` the server serves plain text even if
    /// `cert`/`key` are set.
    pub tls: bool,
    /// Certificate argument for `tls_config` (presumably a file path -- confirm in `util::tls`).
    pub cert: String,
    /// Private-key argument for `tls_config` (presumably a file path -- confirm in `util::tls`).
    pub key: String,
}
47
/// Decision returned by [`ReqInterceptor::intercept`] for a single request.
pub enum InterceptResult<T: IntoResponse> {
    /// Answer immediately with this response; the router is not called.
    Return(Response),
    /// Send no response: the request handler returns an `io::Error` ("Request dropped by
    /// interceptor") to hyper instead (hyper then aborts the connection/stream -- confirm
    /// against hyper's docs).
    Drop,
    /// Forward this (possibly modified) request to the router.
    Continue(Request<Incoming>),
    /// Answer with `T::into_response()`; the router is not called.
    Error(T),
}
54
/// Hook that inspects every incoming request before it reaches the router.
///
/// The server calls `intercept` once per request (on a per-request clone of the interceptor)
/// and acts on the returned [`InterceptResult`].
pub trait ReqInterceptor: Send {
    /// Error type the interceptor can reply with; it is turned into a response via `IntoResponse`.
    type Error: IntoResponse + Send + Sync + 'static;
    /// Decides what happens to `req`; `ip` is the socket address of the connected client.
    fn intercept(&self, req: Request<Incoming>, ip: SocketAddr) -> impl std::future::Future<Output = InterceptResult<Self::Error>> + Send;
}
59
/// Interceptor that lets every request through unchanged; the default interceptor type of
/// [`Server`].
///
/// It is a stateless unit struct, so it derives `Debug`, `Copy` and `Default` in addition to
/// `Clone`. This lets it be printed and constructed generically like other public types.
#[derive(Debug, Clone, Copy, Default)]
pub struct DummyInterceptor;
62
63impl ReqInterceptor for DummyInterceptor {
64    type Error = error::AppError;
65
66    async fn intercept(&self, req: Request<Incoming>, _ip: SocketAddr) -> InterceptResult<Self::Error> {
67        InterceptResult::Continue(req)
68    }
69}
70
/// A [`Server`] whose interceptor type is the pass-through [`DummyInterceptor`]; this is the
/// type produced by [`new_server`].
pub type DefaultServer = Server<DummyInterceptor>;
72
73pub fn new_server(port: u16, router: Router) -> (Server, mpsc::Sender<()>) {
74    let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
75    let server = Server {
76        port,
77        tls_param: None, // No TLS by default
78        router,
79        interceptor: None,
80        idle_timeout: Duration::from_secs(120),
81        shutdown_rx,
82    };
83    (server, shutdown_tx)
84}
85
86impl<I> Server<I>
87where
88    I: ReqInterceptor + Clone + Send + Sync + 'static,
89{
90    pub fn with_interceptor<R>(self: Server<I>, interceptor: R) -> Server<R>
91    where
92        R: ReqInterceptor + Clone + Send + Sync + 'static,
93    {
94        Server::<R> {
95            port: self.port,
96            tls_param: self.tls_param,
97            router: self.router,
98            interceptor: Some(interceptor),
99            idle_timeout: self.idle_timeout, // keep the same idle timeout
100            shutdown_rx: self.shutdown_rx,
101        }
102    }
103    pub fn with_tls_param(mut self, tls_param: Option<TlsParam>) -> Self {
104        // Enable TLS by setting the tls_param
105        self.tls_param = tls_param;
106        self
107    }
108
109    pub fn with_timeout(mut self, timeout: Duration) -> Self {
110        self.idle_timeout = timeout;
111        self
112    }
113
114    pub async fn run(mut self) -> Result<(), DynError> {
115        let use_tls = match self.tls_param.clone() {
116            Some(config) => config.tls,
117            None => false,
118        };
119        log::info!("listening on port {}, use_tls: {}", self.port, use_tls);
120        let server: hyper_util::server::conn::auto::Builder<TokioExecutor> = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
121        let graceful: hyper_util::server::graceful::GracefulShutdown = hyper_util::server::graceful::GracefulShutdown::new();
122        match use_tls {
123            #[allow(clippy::expect_used)]
124            true => {
125                serve_tls(
126                    &self.router,
127                    server,
128                    graceful,
129                    self.port,
130                    self.tls_param.as_ref().expect("should be some"),
131                    self.interceptor.clone(),
132                    self.idle_timeout,
133                    &mut self.shutdown_rx,
134                )
135                .await?
136            }
137            false => {
138                serve_plantext(&self.router, server, graceful, self.port, self.interceptor.clone(), self.idle_timeout, &mut self.shutdown_rx).await?
139            }
140        }
141        Ok(())
142    }
143}
144
145async fn handle<I>(
146    request: Request<Incoming>, client_socket_addr: SocketAddr, app: axum::middleware::AddExtension<Router, axum::extract::ConnectInfo<SocketAddr>>,
147    interceptor: Option<I>,
148) -> std::result::Result<Response, std::io::Error>
149where
150    I: ReqInterceptor + Clone + Send + Sync + 'static,
151{
152    if let Some(interceptor) = interceptor {
153        match interceptor.intercept(request, client_socket_addr).await {
154            InterceptResult::Return(res) => Ok(res),
155            InterceptResult::Drop => Err(std::io::Error::other("Request dropped by interceptor")),
156            InterceptResult::Continue(req) => app
157                .oneshot(req)
158                .await
159                .map_err(|err| std::io::Error::new(std::io::ErrorKind::Interrupted, err)),
160            InterceptResult::Error(err) => {
161                let res = err.into_response();
162                Ok(res)
163            }
164        }
165    } else {
166        app.oneshot(request)
167            .await
168            .map_err(|err| std::io::Error::new(std::io::ErrorKind::Interrupted, err))
169    }
170}
171
/// Sets up one accepted connection and serves it on a spawned task.
///
/// Steps:
/// - Wrap the stream in `io::TimeoutIO` using `timeout` (the server's idle timeout).
/// - Give the router `ConnectInfo<SocketAddr>` for this peer.
/// - Route every request through [`handle`], and therefore through the optional interceptor.
/// - Register the connection with `graceful` so a graceful shutdown can wait for it.
///
/// Returns as soon as the task is spawned. Connection-level errors are logged on that task
/// by `handle_hyper_error`.
async fn handle_connection<C, I>(
    conn: C, client_socket_addr: std::net::SocketAddr, app: Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>,
    interceptor: Option<I>, graceful: &hyper_util::server::graceful::GracefulShutdown, timeout: Duration,
) where
    C: tokio::io::AsyncRead + tokio::io::AsyncWrite + 'static + Send + Sync,
    I: ReqInterceptor + Clone + Send + Sync + 'static,
{
    // Enforce the idle timeout on the raw stream (exact semantics live in util::io::TimeoutIO).
    let timeout_io = Box::pin(io::TimeoutIO::new(conn, timeout));
    use hyper::Request;
    use hyper_util::rt::TokioIo;
    // Adapt tokio's I/O traits to the ones hyper expects.
    let stream = TokioIo::new(timeout_io);
    // Build the per-connection service once; it attaches `ConnectInfo<SocketAddr>` to each
    // request so handlers can extract the peer address.
    let mut app = app.into_make_service_with_connect_info::<SocketAddr>();
    let app: axum::middleware::AddExtension<Router, axum::extract::ConnectInfo<SocketAddr>> = unwrap_infallible(app.call(client_socket_addr).await);
    // https://github.com/tokio-rs/axum/blob/main/examples/serve-with-hyper/src/main.rs#L81
    // Each request gets its own clone of the router service and of the interceptor.
    let hyper_service = hyper::service::service_fn(move |request: Request<hyper::body::Incoming>| {
        handle(request, client_socket_addr, app.clone(), interceptor.clone())
    });

    // HTTP/1 or HTTP/2 (auto builder) with protocol-upgrade support.
    let conn = server.serve_connection_with_upgrades(stream, hyper_service);
    // `into_owned` makes the connection independent of the borrowed builder so it can move
    // into the task; `watch` lets `graceful.shutdown()` wait for it.
    let conn = graceful.watch(conn.into_owned());

    tokio::spawn(async move {
        if let Err(err) = conn.await {
            handle_hyper_error(client_socket_addr, err);
        }
        log::debug!("connection dropped: {client_socket_addr}");
    });
}
200
201fn handle_hyper_error(client_socket_addr: SocketAddr, http_err: DynError) {
202    use std::error::Error;
203    match http_err.downcast_ref::<hyper::Error>() {
204        Some(hyper_err) => {
205            let level = if hyper_err.is_user() { log::Level::Warn } else { log::Level::Debug };
206            let source = hyper_err.source().unwrap_or(hyper_err);
207            log::log!(
208                level,
209                "[hyper {}]: {:?} from {}",
210                if hyper_err.is_user() { "user" } else { "system" },
211                source,
212                SocketAddrFormat(&client_socket_addr)
213            );
214        }
215        None => match http_err.downcast_ref::<std::io::Error>() {
216            Some(io_err) => {
217                warn!("[hyper io]: [{}] {} from {}", io_err.kind(), io_err, SocketAddrFormat(&client_socket_addr));
218            }
219            None => {
220                warn!("[hyper]: {} from {}", http_err, SocketAddrFormat(&client_socket_addr));
221            }
222        },
223    }
224}
225
226async fn serve_plantext<I>(
227    app: &Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>, graceful: hyper_util::server::graceful::GracefulShutdown,
228    port: u16, interceptor: Option<I>, timeout: Duration, shutdown_rx: &mut mpsc::Receiver<()>,
229) -> Result<(), DynError>
230where
231    I: ReqInterceptor + Clone + Send + Sync + 'static,
232{
233    let listener = create_dual_stack_listener(port).await?;
234    loop {
235        tokio::select! {
236            _ = shutdown_rx.recv() => {
237                info!("start graceful shutdown!");
238                drop(listener);
239                break;
240            }
241            conn = listener.accept() => {
242                match conn {
243                    Ok((conn, client_socket_addr)) => {
244                        handle_connection(conn,client_socket_addr, app.clone(), server.clone(),interceptor.clone(), &graceful, timeout).await;}
245                    Err(e) => {
246                        warn!("accept error:{e}");
247                    }
248                }
249            }
250        }
251    }
252    tokio::select! {
253        _ = graceful.shutdown() => {
254            info!("Gracefully shutdown!");
255        },
256        _ = tokio::time::sleep(GRACEFUL_SHUTDOWN_TIMEOUT) => {
257            info!("Waited {GRACEFUL_SHUTDOWN_TIMEOUT:?} for graceful shutdown, aborting...");
258        }
259    }
260    Ok(())
261}
262
263#[allow(clippy::too_many_arguments)]
264async fn serve_tls<I>(
265    app: &Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>, graceful: hyper_util::server::graceful::GracefulShutdown,
266    port: u16, tls_param: &TlsParam, interceptor: Option<I>, timeout: Duration, shutdown_rx: &mut mpsc::Receiver<()>,
267) -> Result<(), DynError>
268where
269    I: ReqInterceptor + Clone + Send + Sync + 'static,
270{
271    let (tx, _rx) = broadcast::channel::<Arc<ServerConfig>>(10);
272    let tx_clone = tx.clone();
273    let tls_param_clone = tls_param.clone();
274    tokio::spawn(async move {
275        info!("update tls config every {REFRESH_INTERVAL:?}");
276        loop {
277            time::sleep(REFRESH_INTERVAL).await;
278            if let Ok(new_acceptor) = tls_config(&tls_param_clone.key, &tls_param_clone.cert) {
279                info!("update tls config");
280                if let Err(e) = tx.send(new_acceptor) {
281                    warn!("send tls config error:{e}");
282                }
283            }
284        }
285    });
286    let mut rx = tx_clone.subscribe();
287    let mut acceptor: TlsAcceptor = TlsAcceptor::new(tls_config(&tls_param.key, &tls_param.cert)?, create_dual_stack_listener(port).await?);
288    loop {
289        tokio::select! {
290            _ = shutdown_rx.recv() => {
291                info!("start graceful shutdown!");
292                drop(acceptor);
293                break;
294            }
295            message = rx.recv() => {
296                #[allow(clippy::expect_used)]
297                let new_config = message.expect("Channel should not be closed");
298                // Replace the acceptor with the new one
299                acceptor.replace_config(new_config);
300                info!("replaced tls config");
301            }
302            conn = acceptor.accept() => {
303                match conn {
304                    Ok((conn, client_socket_addr)) => {
305                        handle_connection(conn,client_socket_addr, app.clone(), server.clone(),interceptor.clone(), &graceful, timeout).await;}
306                    Err(e) => {
307                        warn!("accept error:{e}");
308                    }
309                }
310            }
311        }
312    }
313    tokio::select! {
314        _ = graceful.shutdown() => {
315            info!("Gracefully shutdown!");
316        },
317        _ = tokio::time::sleep(GRACEFUL_SHUTDOWN_TIMEOUT) => {
318            info!("Waited {GRACEFUL_SHUTDOWN_TIMEOUT:?} for graceful shutdown, aborting...");
319        }
320    }
321    Ok(())
322}
323
324#[cfg(unix)]
325pub async fn wait_signal() -> Result<(), DynError> {
326    use log::info;
327    use tokio::signal::unix::{signal, SignalKind};
328    let mut terminate_signal = signal(SignalKind::terminate())?;
329    tokio::select! {
330        _ = terminate_signal.recv() => {
331            info!("receive terminate signal");
332        },
333        _ = tokio::signal::ctrl_c() => {
334            info!("receive ctrl_c signal");
335        },
336    };
337    Ok(())
338}
339
/// Waits until the process receives Ctrl-C.
///
/// # Errors
/// Returns an error if the Ctrl-C handler cannot be registered. Previously such a failure
/// was silently reported as a received signal.
#[cfg(windows)]
pub async fn wait_signal() -> Result<(), DynError> {
    tokio::signal::ctrl_c().await?;
    info!("receive ctrl_c signal");
    Ok(())
}
346
/// Extracts the value of a `Result` whose error type has no values.
fn unwrap_infallible<T>(result: Result<T, Infallible>) -> T {
    // `Infallible` is uninhabited, so the empty match proves the error arm can never run.
    result.unwrap_or_else(|never| match never {})
}