axum_bootstrap/
lib.rs

1use std::{convert::Infallible, fmt::Display, net::SocketAddr, sync::Arc, time::Duration};
2
3pub mod init_log;
4pub mod util;
/// Boxed, thread-safe error used as the error type of the server entry points.
type DynError = Box<dyn std::error::Error + Sync + Send>;
6use crate::util::{
7    io::{self, create_dual_stack_listener},
8    tls::{tls_config, TlsAcceptor},
9};
10use anyhow::anyhow;
11use axum::{
12    extract::Request,
13    response::{IntoResponse, Response},
14    Router,
15};
16
17use hyper::{body::Incoming, StatusCode};
18use hyper_util::rt::TokioExecutor;
19use log::{info, warn};
20use tokio::{pin, sync::broadcast, time};
21use tokio_rustls::rustls::ServerConfig;
22use tower::{Service, ServiceExt};
23
24const REFRESH_INTERVAL: Duration = Duration::from_secs(60 * 60 * 24);
25const GRACEFUL_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(10);
26
27pub struct Server<I: ReqInterceptor = DummyInterceptor> {
28    pub port: u16,
29    pub tls_param: Option<TlsParam>,
30    router: Router,
31    pub interceptor: Option<I>,
32    pub idle_timeout: Duration,
33}
34
/// TLS settings for a `Server`.
///
/// The server only serves TLS when its `tls_param` is `Some` *and* `tls` is `true`;
/// otherwise it falls back to plain HTTP.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TlsParam {
    /// Whether TLS should be used at all.
    pub tls: bool,
    /// Certificate source handed to `tls_config` (presumably a PEM file path — confirm in `util::tls`).
    pub cert: String,
    /// Private-key source handed to `tls_config` (presumably a PEM file path — confirm in `util::tls`).
    pub key: String,
}
41
42pub enum InterceptResult {
43    Return(Response),
44    Continue(Request<Incoming>),
45    Error(AppError),
46}
47
48pub trait ReqInterceptor {
49    fn intercept(&self, req: Request<Incoming>, ip: SocketAddr) -> impl std::future::Future<Output = InterceptResult> + Send;
50}
51
52#[derive(Clone)]
53pub struct DummyInterceptor;
54
55impl ReqInterceptor for DummyInterceptor {
56    async fn intercept(&self, req: Request<Incoming>, _ip: SocketAddr) -> InterceptResult {
57        InterceptResult::Continue(req)
58    }
59}
60
61pub type DefaultServer = Server<DummyInterceptor>;
62
63pub fn new_server(port: u16, tls_param: Option<TlsParam>, router: Router) -> Server {
64    Server {
65        port,
66        tls_param,
67        router,
68        interceptor: None,
69        idle_timeout: Duration::from_secs(120),
70    }
71}
72
73pub fn new_server_with_interceptor<I>(port: u16, tls_param: Option<TlsParam>, interceptor: I, router: Router) -> Server<I>
74where
75    I: ReqInterceptor + Clone + Send + Sync + 'static,
76{
77    Server {
78        port,
79        tls_param,
80        router,
81        interceptor: Some(interceptor),
82        idle_timeout: Duration::from_secs(120),
83    }
84}
85
86impl<I> Server<I>
87where
88    I: ReqInterceptor + Clone + Send + Sync + 'static,
89{
90    pub fn with_timeout(mut self, timeout: Duration) -> Self {
91        self.idle_timeout = timeout;
92        self
93    }
94
95    pub async fn run(&self) -> Result<(), DynError> {
96        let use_tls = match self.tls_param.clone() {
97            Some(config) => config.tls,
98            None => false,
99        };
100        log::info!("listening on port {}, use_tls: {}", self.port, use_tls);
101        let server: hyper_util::server::conn::auto::Builder<TokioExecutor> = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
102        let graceful: hyper_util::server::graceful::GracefulShutdown = hyper_util::server::graceful::GracefulShutdown::new();
103        match use_tls {
104            #[allow(clippy::expect_used)]
105            true => {
106                serve_tls(
107                    &self.router,
108                    server,
109                    graceful,
110                    self.port,
111                    self.tls_param.as_ref().expect("should be some"),
112                    self.interceptor.clone(),
113                    self.idle_timeout,
114                )
115                .await?
116            }
117            false => serve_plantext(&self.router, server, graceful, self.port, self.interceptor.clone(), self.idle_timeout).await?,
118        }
119        Ok(())
120    }
121}
122
123async fn handle<I>(
124    request: Request<Incoming>, client_socket_addr: SocketAddr, app: axum::middleware::AddExtension<Router, axum::extract::ConnectInfo<SocketAddr>>,
125    interceptor: Option<I>,
126) -> std::result::Result<Response, std::convert::Infallible>
127where
128    I: ReqInterceptor + Clone + Send + Sync + 'static,
129{
130    if let Some(interceptor) = interceptor {
131        match interceptor.intercept(request, client_socket_addr).await {
132            InterceptResult::Continue(req) => app.clone().oneshot(req).await,
133            InterceptResult::Return(res) => Ok(res),
134            InterceptResult::Error(err) => {
135                let res = err.into_response();
136                Ok(res)
137            }
138        }
139    } else {
140        app.clone().oneshot(request).await
141    }
142}
143
144async fn handle_connection<C, I>(
145    conn: C, client_socket_addr: std::net::SocketAddr, mut app: axum::extract::connect_info::IntoMakeServiceWithConnectInfo<Router, SocketAddr>,
146    server: hyper_util::server::conn::auto::Builder<TokioExecutor>, interceptor: Option<I>,
147    graceful: &hyper_util::server::graceful::GracefulShutdown, timeout: Duration,
148) where
149    C: tokio::io::AsyncRead + tokio::io::AsyncWrite + 'static + Send + Sync,
150    I: ReqInterceptor + Clone + Send + Sync + 'static,
151{
152    let timeout_io = Box::pin(io::TimeoutIO::new(conn, timeout));
153    use hyper::Request;
154    use hyper_util::rt::TokioIo;
155    let stream = TokioIo::new(timeout_io);
156    let app: axum::middleware::AddExtension<Router, axum::extract::ConnectInfo<SocketAddr>> = unwrap_infallible(app.call(client_socket_addr).await);
157    // https://github.com/tokio-rs/axum/blob/main/examples/serve-with-hyper/src/main.rs#L81
158    let hyper_service = hyper::service::service_fn(move |request: Request<hyper::body::Incoming>| {
159        handle(request, client_socket_addr, app.clone(), interceptor.clone())
160    });
161
162    let conn = server.serve_connection_with_upgrades(stream, hyper_service);
163    let conn = graceful.watch(conn.into_owned());
164
165    tokio::spawn(async move {
166        if let Err(err) = conn.await {
167            info!("connection error: {}", err);
168        }
169        log::debug!("connection dropped: {}", client_socket_addr);
170    });
171}
172
173async fn serve_plantext<I>(
174    app: &Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>, graceful: hyper_util::server::graceful::GracefulShutdown,
175    port: u16, interceptor: Option<I>, timeout: Duration,
176) -> Result<(), DynError>
177where
178    I: ReqInterceptor + Clone + Send + Sync + 'static,
179{
180    let listener = create_dual_stack_listener(port).await?;
181    let signal = handle_signal();
182    pin!(signal);
183    loop {
184        tokio::select! {
185            _ = signal.as_mut() => {
186                info!("start graceful shutdown!");
187                drop(listener);
188                break;
189            }
190            conn = listener.accept() => {
191                match conn {
192                    Ok((conn, client_socket_addr)) => {
193                        let app: axum::extract::connect_info::IntoMakeServiceWithConnectInfo<Router, SocketAddr> = app.clone().into_make_service_with_connect_info::<SocketAddr>();
194                        handle_connection(conn,client_socket_addr, app.clone(), server.clone(),interceptor.clone(), &graceful, timeout).await;}
195                    Err(e) => {
196                        warn!("accept error:{}", e);
197                    }
198                }
199            }
200        }
201    }
202    tokio::select! {
203        _ = graceful.shutdown() => {
204            info!("Gracefully shutdown!");
205        },
206        _ = tokio::time::sleep(GRACEFUL_SHUTDOWN_TIMEOUT) => {
207            info!("Waited {GRACEFUL_SHUTDOWN_TIMEOUT:?} for graceful shutdown, aborting...");
208        }
209    }
210    Ok(())
211}
212
213async fn serve_tls<I>(
214    app: &Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>, graceful: hyper_util::server::graceful::GracefulShutdown,
215    port: u16, tls_param: &TlsParam, interceptor: Option<I>, timeout: Duration,
216) -> Result<(), DynError>
217where
218    I: ReqInterceptor + Clone + Send + Sync + 'static,
219{
220    let (tx, _rx) = broadcast::channel::<Arc<ServerConfig>>(10);
221    let tx_clone = tx.clone();
222    let tls_param_clone = tls_param.clone();
223    tokio::spawn(async move {
224        info!("update tls config every {:?}", REFRESH_INTERVAL);
225        loop {
226            time::sleep(REFRESH_INTERVAL).await;
227            if let Ok(new_acceptor) = tls_config(&tls_param_clone.key, &tls_param_clone.cert) {
228                info!("update tls config");
229                if let Err(e) = tx.send(new_acceptor) {
230                    warn!("send tls config error:{}", e);
231                }
232            }
233        }
234    });
235    let mut rx = tx_clone.subscribe();
236    let mut acceptor: TlsAcceptor = TlsAcceptor::new(tls_config(&tls_param.key, &tls_param.cert)?, create_dual_stack_listener(port).await?);
237    let signal = handle_signal();
238    pin!(signal);
239    loop {
240        tokio::select! {
241            _ = signal.as_mut() => {
242                info!("start graceful shutdown!");
243                drop(acceptor);
244                break;
245            }
246            message = rx.recv() => {
247                #[allow(clippy::expect_used)]
248                let new_config = message.expect("Channel should not be closed");
249                // Replace the acceptor with the new one
250                acceptor.replace_config(new_config);
251                info!("replaced tls config");
252            }
253            conn = acceptor.accept() => {
254                match conn {
255                    Ok((conn, client_socket_addr)) => {
256                        let app: axum::extract::connect_info::IntoMakeServiceWithConnectInfo<Router, SocketAddr> = app.clone().into_make_service_with_connect_info::<SocketAddr>();
257                        handle_connection(conn,client_socket_addr, app.clone(), server.clone(),interceptor.clone(), &graceful, timeout).await;}
258                    Err(e) => {
259                        warn!("accept error:{}", e);
260                    }
261                }
262            }
263        }
264    }
265    tokio::select! {
266        _ = graceful.shutdown() => {
267            info!("Gracefully shutdown!");
268        },
269        _ = tokio::time::sleep(GRACEFUL_SHUTDOWN_TIMEOUT) => {
270            info!("Waited {GRACEFUL_SHUTDOWN_TIMEOUT:?} for graceful shutdown, aborting...");
271        }
272    }
273    Ok(())
274}
275
276#[cfg(unix)]
277async fn handle_signal() -> Result<(), DynError> {
278    use log::info;
279    use tokio::signal::unix::{signal, SignalKind};
280    let mut terminate_signal = signal(SignalKind::terminate())?;
281    tokio::select! {
282        _ = terminate_signal.recv() => {
283            info!("receive terminate signal, shutdowning");
284        },
285        _ = tokio::signal::ctrl_c() => {
286            info!("ctrl_c => shutdowning");
287        },
288    };
289    Ok(())
290}
291
/// Resolves when the process receives Ctrl-C.
///
/// If the Ctrl-C handler cannot be installed, a warning is logged and this future never
/// resolves. Previously it resolved immediately and stopped the server at startup.
/// Always returns `Ok(())`.
#[cfg(windows)]
async fn handle_signal() -> Result<(), DynError> {
    if let Err(e) = tokio::signal::ctrl_c().await {
        warn!("failed to listen for ctrl_c: {}", e);
        std::future::pending::<()>().await;
    }
    info!("ctrl_c => shutdowning");
    Ok(())
}
298
299// Make our own error that wraps `anyhow::Error`.
300#[derive(Debug)]
301pub struct AppError(anyhow::Error);
302
303impl Display for AppError {
304    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
305        self.0.fmt(f)
306    }
307}
308
309// Tell axum how to convert `AppError` into a response.
310impl IntoResponse for AppError {
311    fn into_response(self) -> Response {
312        let err = self.0;
313        // Because `TraceLayer` wraps each request in a span that contains the request
314        // method, uri, etc we don't need to include those details here
315        tracing::error!(%err, "error");
316        (StatusCode::INTERNAL_SERVER_ERROR, format!("Something went wrong: {}", &err)).into_response()
317    }
318}
319
320// This enables using `?` on functions that return `Result<_, anyhow::Error>` to turn them into
321// `Result<_, AppError>`. That way you don't need to do that manually.
322impl<E> From<E> for AppError
323where
324    E: Into<anyhow::Error>,
325{
326    fn from(err: E) -> Self {
327        Self(err.into())
328    }
329}
330
331impl AppError {
332    pub fn new<T: std::error::Error + Send + Sync + 'static>(err: T) -> Self {
333        Self(anyhow!(err))
334    }
335}
336
/// Extracts the value from a `Result` whose error type is uninhabited.
///
/// The `Err` arm can never run; the empty `match` proves that to the compiler, so no
/// panic path exists.
fn unwrap_infallible<T>(result: Result<T, Infallible>) -> T {
    result.unwrap_or_else(|never| match never {})
}