axum_bootstrap/
lib.rs

1use std::{convert::Infallible, fmt::Display, net::SocketAddr, sync::Arc, time::Duration};
2
3pub mod init_log;
4pub mod util;
/// Boxed, thread-safe error used as the error type of the server entry points.
type DynError = Box<dyn std::error::Error + Sync + Send>;
6use crate::util::{
7    io::{self, create_dual_stack_listener},
8    tls::{tls_config, TlsAcceptor},
9};
10use anyhow::anyhow;
11use axum::{
12    extract::Request,
13    response::{IntoResponse, Response},
14    Router,
15};
16
17use hyper::{body::Incoming, StatusCode};
18use hyper_util::rt::TokioExecutor;
19use log::{info, warn};
20use tokio::{pin, sync::broadcast, time};
21use tokio_rustls::rustls::ServerConfig;
22use tower::{Service, ServiceExt};
23
/// How often `serve_tls` rebuilds the TLS config from the cert/key (one day).
const REFRESH_INTERVAL: Duration = Duration::from_secs(24 * 60 * 60);
/// Upper bound on how long shutdown waits for in-flight connections.
const GRACEFUL_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(10);
26
27pub struct Server<I: ReqInterceptor = DummyInterceptor> {
28    pub port: u16,
29    pub tls_param: Option<TlsParam>,
30    router: Router,
31    pub interceptor: Option<I>,
32    pub idle_timeout: Duration,
33}
34
/// Settings for serving HTTPS.
///
/// `cert` and `key` are handed to `tls_config` unchanged, both at startup and on
/// every `REFRESH_INTERVAL` reload (presumably file paths — confirm in `util::tls`).
#[derive(Clone, Debug)]
pub struct TlsParam {
    /// Serve TLS when `true`; with `false` the server runs plain HTTP.
    pub tls: bool,
    /// Certificate argument for `tls_config`.
    pub cert: String,
    /// Private-key argument for `tls_config`.
    pub key: String,
}
41
42pub enum InterceptResult<T: IntoResponse> {
43    Return(Response),
44    Drop,
45    Continue(Request<Incoming>),
46    Error(T),
47}
48
49pub trait ReqInterceptor: Send {
50    type Error: IntoResponse + Send + Sync + 'static;
51    fn intercept(&self, req: Request<Incoming>, ip: SocketAddr) -> impl std::future::Future<Output = InterceptResult<Self::Error>> + Send;
52}
53
/// No-op interceptor: every request is forwarded to the router unchanged.
///
/// This is the default interceptor type of `Server`. It is a stateless unit
/// struct, so it is also `Copy`, `Default` and `Debug`.
#[derive(Debug, Clone, Copy, Default)]
pub struct DummyInterceptor;
56
57impl ReqInterceptor for DummyInterceptor {
58    type Error = AppError;
59
60    async fn intercept(&self, req: Request<Incoming>, _ip: SocketAddr) -> InterceptResult<Self::Error> {
61        InterceptResult::Continue(req)
62    }
63}
64
65pub type DefaultServer = Server<DummyInterceptor>;
66
67pub fn new_server(port: u16, router: Router) -> Server {
68    Server {
69        port,
70        tls_param: None, // No TLS by default
71        router,
72        interceptor: None,
73        idle_timeout: Duration::from_secs(120),
74    }
75}
76
77impl<I> Server<I>
78where
79    I: ReqInterceptor + Clone + Send + Sync + 'static,
80{
81    pub fn with_interceptor<R>(self: Server<I>, interceptor: R) -> Server<R>
82    where
83        R: ReqInterceptor + Clone + Send + Sync + 'static,
84    {
85        Server::<R> {
86            port: self.port,
87            tls_param: self.tls_param,
88            router: self.router,
89            interceptor: Some(interceptor),
90            idle_timeout: self.idle_timeout, // keep the same idle timeout
91        }
92    }
93    pub fn with_tls_param(mut self, tls_param: Option<TlsParam>) -> Self {
94        // Enable TLS by setting the tls_param
95        self.tls_param = tls_param;
96        self
97    }
98
99    pub fn with_timeout(mut self, timeout: Duration) -> Self {
100        self.idle_timeout = timeout;
101        self
102    }
103
104    pub async fn run(&self) -> Result<(), DynError> {
105        let use_tls = match self.tls_param.clone() {
106            Some(config) => config.tls,
107            None => false,
108        };
109        log::info!("listening on port {}, use_tls: {}", self.port, use_tls);
110        let server: hyper_util::server::conn::auto::Builder<TokioExecutor> = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
111        let graceful: hyper_util::server::graceful::GracefulShutdown = hyper_util::server::graceful::GracefulShutdown::new();
112        match use_tls {
113            #[allow(clippy::expect_used)]
114            true => {
115                serve_tls(
116                    &self.router,
117                    server,
118                    graceful,
119                    self.port,
120                    self.tls_param.as_ref().expect("should be some"),
121                    self.interceptor.clone(),
122                    self.idle_timeout,
123                )
124                .await?
125            }
126            false => serve_plantext(&self.router, server, graceful, self.port, self.interceptor.clone(), self.idle_timeout).await?,
127        }
128        Ok(())
129    }
130}
131
132async fn handle<I>(
133    request: Request<Incoming>, client_socket_addr: SocketAddr, app: axum::middleware::AddExtension<Router, axum::extract::ConnectInfo<SocketAddr>>,
134    interceptor: Option<I>,
135) -> std::result::Result<Response, std::io::Error>
136where
137    I: ReqInterceptor + Clone + Send + Sync + 'static,
138{
139    if let Some(interceptor) = interceptor {
140        match interceptor.intercept(request, client_socket_addr).await {
141            InterceptResult::Return(res) => Ok(res),
142            InterceptResult::Drop => Err(std::io::Error::new(std::io::ErrorKind::Other, "Request dropped by interceptor")),
143            InterceptResult::Continue(req) => app
144                .oneshot(req)
145                .await
146                .map_err(|err| std::io::Error::new(std::io::ErrorKind::Interrupted, err)),
147            InterceptResult::Error(err) => {
148                let res = err.into_response();
149                Ok(res)
150            }
151        }
152    } else {
153        app.oneshot(request)
154            .await
155            .map_err(|err| std::io::Error::new(std::io::ErrorKind::Interrupted, err))
156    }
157}
158
159async fn handle_connection<C, I>(
160    conn: C, client_socket_addr: std::net::SocketAddr, app: Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>,
161    interceptor: Option<I>, graceful: &hyper_util::server::graceful::GracefulShutdown, timeout: Duration,
162) where
163    C: tokio::io::AsyncRead + tokio::io::AsyncWrite + 'static + Send + Sync,
164    I: ReqInterceptor + Clone + Send + Sync + 'static,
165{
166    let timeout_io = Box::pin(io::TimeoutIO::new(conn, timeout));
167    use hyper::Request;
168    use hyper_util::rt::TokioIo;
169    let stream = TokioIo::new(timeout_io);
170    let mut app = app.into_make_service_with_connect_info::<SocketAddr>();
171    let app: axum::middleware::AddExtension<Router, axum::extract::ConnectInfo<SocketAddr>> = unwrap_infallible(app.call(client_socket_addr).await);
172    // https://github.com/tokio-rs/axum/blob/main/examples/serve-with-hyper/src/main.rs#L81
173    let hyper_service = hyper::service::service_fn(move |request: Request<hyper::body::Incoming>| {
174        handle(request, client_socket_addr, app.clone(), interceptor.clone())
175    });
176
177    let conn = server.serve_connection_with_upgrades(stream, hyper_service);
178    let conn = graceful.watch(conn.into_owned());
179
180    tokio::spawn(async move {
181        if let Err(err) = conn.await {
182            info!("connection error: {}", err);
183        }
184        log::debug!("connection dropped: {}", client_socket_addr);
185    });
186}
187
188async fn serve_plantext<I>(
189    app: &Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>, graceful: hyper_util::server::graceful::GracefulShutdown,
190    port: u16, interceptor: Option<I>, timeout: Duration,
191) -> Result<(), DynError>
192where
193    I: ReqInterceptor + Clone + Send + Sync + 'static,
194{
195    let listener = create_dual_stack_listener(port).await?;
196    let signal = handle_signal();
197    pin!(signal);
198    loop {
199        tokio::select! {
200            _ = signal.as_mut() => {
201                info!("start graceful shutdown!");
202                drop(listener);
203                break;
204            }
205            conn = listener.accept() => {
206                match conn {
207                    Ok((conn, client_socket_addr)) => {
208                        handle_connection(conn,client_socket_addr, app.clone(), server.clone(),interceptor.clone(), &graceful, timeout).await;}
209                    Err(e) => {
210                        warn!("accept error:{}", e);
211                    }
212                }
213            }
214        }
215    }
216    tokio::select! {
217        _ = graceful.shutdown() => {
218            info!("Gracefully shutdown!");
219        },
220        _ = tokio::time::sleep(GRACEFUL_SHUTDOWN_TIMEOUT) => {
221            info!("Waited {GRACEFUL_SHUTDOWN_TIMEOUT:?} for graceful shutdown, aborting...");
222        }
223    }
224    Ok(())
225}
226
227async fn serve_tls<I>(
228    app: &Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>, graceful: hyper_util::server::graceful::GracefulShutdown,
229    port: u16, tls_param: &TlsParam, interceptor: Option<I>, timeout: Duration,
230) -> Result<(), DynError>
231where
232    I: ReqInterceptor + Clone + Send + Sync + 'static,
233{
234    let (tx, _rx) = broadcast::channel::<Arc<ServerConfig>>(10);
235    let tx_clone = tx.clone();
236    let tls_param_clone = tls_param.clone();
237    tokio::spawn(async move {
238        info!("update tls config every {:?}", REFRESH_INTERVAL);
239        loop {
240            time::sleep(REFRESH_INTERVAL).await;
241            if let Ok(new_acceptor) = tls_config(&tls_param_clone.key, &tls_param_clone.cert) {
242                info!("update tls config");
243                if let Err(e) = tx.send(new_acceptor) {
244                    warn!("send tls config error:{}", e);
245                }
246            }
247        }
248    });
249    let mut rx = tx_clone.subscribe();
250    let mut acceptor: TlsAcceptor = TlsAcceptor::new(tls_config(&tls_param.key, &tls_param.cert)?, create_dual_stack_listener(port).await?);
251    let signal = handle_signal();
252    pin!(signal);
253    loop {
254        tokio::select! {
255            _ = signal.as_mut() => {
256                info!("start graceful shutdown!");
257                drop(acceptor);
258                break;
259            }
260            message = rx.recv() => {
261                #[allow(clippy::expect_used)]
262                let new_config = message.expect("Channel should not be closed");
263                // Replace the acceptor with the new one
264                acceptor.replace_config(new_config);
265                info!("replaced tls config");
266            }
267            conn = acceptor.accept() => {
268                match conn {
269                    Ok((conn, client_socket_addr)) => {
270                        handle_connection(conn,client_socket_addr, app.clone(), server.clone(),interceptor.clone(), &graceful, timeout).await;}
271                    Err(e) => {
272                        warn!("accept error:{}", e);
273                    }
274                }
275            }
276        }
277    }
278    tokio::select! {
279        _ = graceful.shutdown() => {
280            info!("Gracefully shutdown!");
281        },
282        _ = tokio::time::sleep(GRACEFUL_SHUTDOWN_TIMEOUT) => {
283            info!("Waited {GRACEFUL_SHUTDOWN_TIMEOUT:?} for graceful shutdown, aborting...");
284        }
285    }
286    Ok(())
287}
288
289#[cfg(unix)]
290async fn handle_signal() -> Result<(), DynError> {
291    use log::info;
292    use tokio::signal::unix::{signal, SignalKind};
293    let mut terminate_signal = signal(SignalKind::terminate())?;
294    tokio::select! {
295        _ = terminate_signal.recv() => {
296            info!("receive terminate signal, shutdowning");
297        },
298        _ = tokio::signal::ctrl_c() => {
299            info!("ctrl_c => shutdowning");
300        },
301    };
302    Ok(())
303}
304
/// Resolves when Ctrl-C is received.
///
/// # Errors
/// Returns an error if waiting for Ctrl-C fails. Previously the error was
/// discarded and logged as if Ctrl-C had been pressed.
#[cfg(windows)]
async fn handle_signal() -> Result<(), DynError> {
    tokio::signal::ctrl_c().await?;
    info!("ctrl_c => shutdowning");
    Ok(())
}
311
312// Make our own error that wraps `anyhow::Error`.
313#[derive(Debug)]
314pub struct AppError(anyhow::Error);
315
316impl Display for AppError {
317    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
318        self.0.fmt(f)
319    }
320}
321
322// Tell axum how to convert `AppError` into a response.
323impl IntoResponse for AppError {
324    fn into_response(self) -> Response {
325        let err = self.0;
326        // Because `TraceLayer` wraps each request in a span that contains the request
327        // method, uri, etc we don't need to include those details here
328        tracing::error!(%err, "error");
329        (StatusCode::INTERNAL_SERVER_ERROR, format!("axum-bootstrap error: {}", &err)).into_response()
330    }
331}
332
333// This enables using `?` on functions that return `Result<_, anyhow::Error>` to turn them into
334// `Result<_, AppError>`. That way you don't need to do that manually.
335impl<E> From<E> for AppError
336where
337    E: Into<anyhow::Error>,
338{
339    fn from(err: E) -> Self {
340        Self(err.into())
341    }
342}
343
344impl AppError {
345    pub fn new<T: std::error::Error + Send + Sync + 'static>(err: T) -> Self {
346        Self(anyhow!(err))
347    }
348}
349
/// Extracts the value from a `Result` whose error type cannot exist.
///
/// `Infallible` has no values, so the empty `match` proves the error arm is
/// unreachable without any panic path.
fn unwrap_infallible<T>(result: Result<T, Infallible>) -> T {
    result.unwrap_or_else(|never| match never {})
}