axum_bootstrap/
lib.rs

1use std::{fmt::Display, net::SocketAddr, sync::Arc, time::Duration};
2
3pub mod init_log;
4pub mod util;
5type DynError = Box<dyn std::error::Error + Send + Sync>;
6use crate::util::{
7    io::{self, create_dual_stack_listener},
8    tls::{tls_config, TlsAcceptor},
9};
10use anyhow::anyhow;
11use axum::{
12    response::{IntoResponse, Response},
13    Router,
14};
15
16use hyper::{body::Incoming, Request, StatusCode};
17use hyper_util::rt::TokioExecutor;
18use log::{info, warn};
19use tokio::{pin, sync::broadcast, time};
20use tokio_rustls::rustls::ServerConfig;
21use tower_service::Service;
22
/// Interval between reloads of the TLS configuration: 24 hours.
const REFRESH_INTERVAL: Duration = Duration::from_secs(24 * 60 * 60);
/// Longest time shutdown waits for open connections to finish before giving up.
const GRACEFUL_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(10);
25
26pub struct Server<I: ReqInterceptor = DummyInterceptor> {
27    pub port: u16,
28    pub tls_param: Option<TlsParam>,
29    router: Router,
30    pub interceptor: Option<I>,
31    pub idle_timeout: Duration,
32}
33
/// TLS settings for a [`Server`].
#[derive(Debug, Clone)]
pub struct TlsParam {
    /// Master switch: when `false` the server runs without TLS even though
    /// `cert` and `key` are present.
    pub tls: bool,
    /// Certificate argument passed to `tls_config`
    /// (presumably a file path — confirm against `util::tls`).
    pub cert: String,
    /// Private-key argument passed to `tls_config`
    /// (presumably a file path — confirm against `util::tls`).
    pub key: String,
}
40
41pub enum InterceptResult {
42    Return(Response),
43    Continue(Request<Incoming>),
44    Error(AppError),
45}
46
47pub trait ReqInterceptor {
48    fn intercept(&self, req: Request<Incoming>, ip: SocketAddr) -> impl std::future::Future<Output = InterceptResult> + Send;
49}
50
51#[derive(Clone)]
52pub struct DummyInterceptor;
53
54impl ReqInterceptor for DummyInterceptor {
55    async fn intercept(&self, req: Request<Incoming>, _ip: SocketAddr) -> InterceptResult {
56        InterceptResult::Continue(req)
57    }
58}
59
60pub type DefaultServer = Server<DummyInterceptor>;
61
62pub fn new_server(port: u16, tls_param: Option<TlsParam>, router: Router) -> Server {
63    Server {
64        port,
65        tls_param,
66        router,
67        interceptor: None,
68        idle_timeout: Duration::from_secs(120),
69    }
70}
71
72pub fn new_server_with_interceptor<I>(port: u16, tls_param: Option<TlsParam>, interceptor: I, router: Router) -> Server<I>
73where
74    I: ReqInterceptor + Clone + Send + Sync + 'static,
75{
76    Server {
77        port,
78        tls_param,
79        router,
80        interceptor: Some(interceptor),
81        idle_timeout: Duration::from_secs(120),
82    }
83}
84
85impl<I> Server<I>
86where
87    I: ReqInterceptor + Clone + Send + Sync + 'static,
88{
89    pub fn with_timeout(mut self, timeout: Duration) -> Self {
90        self.idle_timeout = timeout;
91        self
92    }
93
94    pub async fn run(&self) -> Result<(), DynError> {
95        let use_tls = match self.tls_param.clone() {
96            Some(config) => config.tls,
97            None => false,
98        };
99        log::info!("listening on port {}, use_tls: {}", self.port, use_tls);
100        let server: hyper_util::server::conn::auto::Builder<TokioExecutor> = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
101        let graceful: hyper_util::server::graceful::GracefulShutdown = hyper_util::server::graceful::GracefulShutdown::new();
102        match use_tls {
103            #[allow(clippy::expect_used)]
104            true => {
105                serve_tls(
106                    &self.router,
107                    server,
108                    graceful,
109                    self.port,
110                    self.tls_param.as_ref().expect("should be some"),
111                    self.interceptor.clone(),
112                    self.idle_timeout,
113                )
114                .await?
115            }
116            false => serve_plantext(&self.router, server, graceful, self.port, self.interceptor.clone(), self.idle_timeout).await?,
117        }
118        Ok(())
119    }
120}
121
122async fn handle<I>(
123    request: Request<Incoming>, client_socket_addr: SocketAddr, mut app: Router, interceptor: Option<I>,
124) -> std::result::Result<Response, std::convert::Infallible>
125where
126    I: ReqInterceptor + Clone + Send + Sync + 'static,
127{
128    if let Some(interceptor) = interceptor {
129        match interceptor.intercept(request, client_socket_addr).await {
130            InterceptResult::Continue(req) => {
131                let res = app.call(req).await.into_response();
132                Ok(res)
133            }
134            InterceptResult::Return(res) => Ok(res),
135            InterceptResult::Error(err) => {
136                let res = err.into_response();
137                Ok(res)
138            }
139        }
140    } else {
141        let res = app.call(request).await.into_response();
142        Ok(res)
143    }
144}
145
146fn handle_connection<C, I>(
147    conn: C, client_socket_addr: std::net::SocketAddr, app: Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>,
148    interceptor: Option<I>, graceful: &hyper_util::server::graceful::GracefulShutdown, timeout: Duration,
149) where
150    C: tokio::io::AsyncRead + tokio::io::AsyncWrite + 'static + Send + Sync,
151    I: ReqInterceptor + Clone + Send + Sync + 'static,
152{
153    let timeout_io = Box::pin(io::TimeoutIO::new(conn, timeout));
154    use hyper::Request;
155    use hyper_util::rt::TokioIo;
156    let stream = TokioIo::new(timeout_io);
157    let hyper_service = hyper::service::service_fn(move |request: Request<hyper::body::Incoming>| {
158        handle(request, client_socket_addr, app.clone(), interceptor.clone())
159    });
160
161    let conn = server.serve_connection_with_upgrades(stream, hyper_service);
162    let conn = graceful.watch(conn.into_owned());
163
164    tokio::spawn(async move {
165        if let Err(err) = conn.await {
166            info!("connection error: {}", err);
167        }
168        log::debug!("connection dropped: {}", client_socket_addr);
169    });
170}
171
172async fn serve_plantext<I>(
173    app: &Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>, graceful: hyper_util::server::graceful::GracefulShutdown,
174    port: u16, interceptor: Option<I>, timeout: Duration,
175) -> Result<(), DynError>
176where
177    I: ReqInterceptor + Clone + Send + Sync + 'static,
178{
179    let listener = create_dual_stack_listener(port).await?;
180    let signal = handle_signal();
181    pin!(signal);
182    loop {
183        tokio::select! {
184            _ = signal.as_mut() => {
185                info!("start graceful shutdown!");
186                drop(listener);
187                break;
188            }
189            conn = listener.accept() => {
190                match conn {
191                    Ok((conn, client_socket_addr)) => {
192                        handle_connection(conn,client_socket_addr, app.clone(), server.clone(),interceptor.clone(), &graceful, timeout);}
193                    Err(e) => {
194                        warn!("accept error:{}", e);
195                    }
196                }
197            }
198        }
199    }
200    tokio::select! {
201        _ = graceful.shutdown() => {
202            info!("Gracefully shutdown!");
203        },
204        _ = tokio::time::sleep(GRACEFUL_SHUTDOWN_TIMEOUT) => {
205            info!("Waited {GRACEFUL_SHUTDOWN_TIMEOUT:?} for graceful shutdown, aborting...");
206        }
207    }
208    Ok(())
209}
210
211async fn serve_tls<I>(
212    app: &Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>, graceful: hyper_util::server::graceful::GracefulShutdown,
213    port: u16, tls_param: &TlsParam, interceptor: Option<I>, timeout: Duration,
214) -> Result<(), DynError>
215where
216    I: ReqInterceptor + Clone + Send + Sync + 'static,
217{
218    let (tx, _rx) = broadcast::channel::<Arc<ServerConfig>>(10);
219    let tx_clone = tx.clone();
220    let tls_param_clone = tls_param.clone();
221    tokio::spawn(async move {
222        info!("update tls config every {:?}", REFRESH_INTERVAL);
223        loop {
224            time::sleep(REFRESH_INTERVAL).await;
225            if let Ok(new_acceptor) = tls_config(&tls_param_clone.key, &tls_param_clone.cert) {
226                info!("update tls config");
227                if let Err(e) = tx.send(new_acceptor) {
228                    warn!("send tls config error:{}", e);
229                }
230            }
231        }
232    });
233    let mut rx = tx_clone.subscribe();
234    let mut acceptor: TlsAcceptor = TlsAcceptor::new(tls_config(&tls_param.key, &tls_param.cert)?, create_dual_stack_listener(port).await?);
235    let signal = handle_signal();
236    pin!(signal);
237    loop {
238        tokio::select! {
239            _ = signal.as_mut() => {
240                info!("start graceful shutdown!");
241                drop(acceptor);
242                break;
243            }
244            message = rx.recv() => {
245                #[allow(clippy::expect_used)]
246                let new_config = message.expect("Channel should not be closed");
247                // Replace the acceptor with the new one
248                acceptor.replace_config(new_config);
249                info!("replaced tls config");
250            }
251            conn = acceptor.accept() => {
252                match conn {
253                    Ok((conn, client_socket_addr)) => {
254                        handle_connection(conn,client_socket_addr, app.clone(), server.clone(),interceptor.clone(), &graceful, timeout);}
255                    Err(e) => {
256                        warn!("accept error:{}", e);
257                    }
258                }
259            }
260        }
261    }
262    tokio::select! {
263        _ = graceful.shutdown() => {
264            info!("Gracefully shutdown!");
265        },
266        _ = tokio::time::sleep(GRACEFUL_SHUTDOWN_TIMEOUT) => {
267            info!("Waited {GRACEFUL_SHUTDOWN_TIMEOUT:?} for graceful shutdown, aborting...");
268        }
269    }
270    Ok(())
271}
272
273#[cfg(unix)]
274async fn handle_signal() -> Result<(), DynError> {
275    use log::info;
276    use tokio::signal::unix::{signal, SignalKind};
277    let mut terminate_signal = signal(SignalKind::terminate())?;
278    tokio::select! {
279        _ = terminate_signal.recv() => {
280            info!("receive terminate signal, shutdowning");
281        },
282        _ = tokio::signal::ctrl_c() => {
283            info!("ctrl_c => shutdowning");
284        },
285    };
286    Ok(())
287}
288
/// Completes when the process receives Ctrl-C.
///
/// # Errors
///
/// Returns an error when waiting for Ctrl-C fails.
#[cfg(windows)]
async fn handle_signal() -> Result<(), DynError> {
    // A listener failure is not a Ctrl-C; propagate it instead of logging a key press.
    tokio::signal::ctrl_c().await?;
    info!("ctrl_c => shutdowning");
    Ok(())
}
295
296// Make our own error that wraps `anyhow::Error`.
297#[derive(Debug)]
298pub struct AppError(anyhow::Error);
299
300impl Display for AppError {
301    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
302        self.0.fmt(f)
303    }
304}
305
306// Tell axum how to convert `AppError` into a response.
307impl IntoResponse for AppError {
308    fn into_response(self) -> Response {
309        let err = self.0;
310        // Because `TraceLayer` wraps each request in a span that contains the request
311        // method, uri, etc we don't need to include those details here
312        tracing::error!(%err, "error");
313        (StatusCode::INTERNAL_SERVER_ERROR, format!("Something went wrong: {}", &err)).into_response()
314    }
315}
316
317// This enables using `?` on functions that return `Result<_, anyhow::Error>` to turn them into
318// `Result<_, AppError>`. That way you don't need to do that manually.
319impl<E> From<E> for AppError
320where
321    E: Into<anyhow::Error>,
322{
323    fn from(err: E) -> Self {
324        Self(err.into())
325    }
326}
327
328impl AppError {
329    pub fn new<T: std::error::Error + Send + Sync + 'static>(err: T) -> Self {
330        Self(anyhow!(err))
331    }
332}