axum_bootstrap/
lib.rs

1use std::{convert::Infallible, net::SocketAddr, sync::Arc, time::Duration};
2
3pub mod error;
4pub mod init_log;
5pub mod util;
6type DynError = Box<dyn std::error::Error + Send + Sync>;
7use crate::util::{
8    io::{self, create_dual_stack_listener},
9    tls::{tls_config, TlsAcceptor},
10};
11
12use axum::{
13    extract::Request,
14    response::{IntoResponse, Response},
15    Router,
16};
17
18use hyper::body::Incoming;
19use hyper_util::rt::TokioExecutor;
20use log::{info, warn};
21use tokio::{sync::broadcast, time};
22use tokio_rustls::rustls::ServerConfig;
23use tower::{Service, ServiceExt};
24use util::format::SocketAddrFormat;
25
/// How often `serve_tls` rebuilds its TLS configuration from the cert/key in `TlsParam` (once a day).
const REFRESH_INTERVAL: Duration = Duration::from_secs(24 * 60 * 60);
/// Upper bound on how long shutdown waits for in-flight connections before giving up.
const GRACEFUL_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(10);
28
29pub struct Server<I: ReqInterceptor = DummyInterceptor> {
30    pub port: u16,
31    pub tls_param: Option<TlsParam>,
32    router: Router,
33    pub interceptor: Option<I>,
34    pub idle_timeout: Duration,
35    shutdown_rx: broadcast::Receiver<()>,
36}
37
/// TLS settings for a [`Server`].
#[derive(Debug, Clone)]
pub struct TlsParam {
    /// Serve HTTPS when `true`; when `false` the server falls back to plain HTTP.
    pub tls: bool,
    /// Certificate location handed to `tls_config` (presumably a PEM file path — confirm).
    pub cert: String,
    /// Private-key location handed to `tls_config` (presumably a PEM file path — confirm).
    pub key: String,
}
44
45pub enum InterceptResult<T: IntoResponse> {
46    Return(Response),
47    Drop,
48    Continue(Request<Incoming>),
49    Error(T),
50}
51
52pub trait ReqInterceptor: Send {
53    type Error: IntoResponse + Send + Sync + 'static;
54    fn intercept(&self, req: Request<Incoming>, ip: SocketAddr) -> impl std::future::Future<Output = InterceptResult<Self::Error>> + Send;
55}
56
57#[derive(Clone)]
58pub struct DummyInterceptor;
59
60impl ReqInterceptor for DummyInterceptor {
61    type Error = error::AppError;
62
63    async fn intercept(&self, req: Request<Incoming>, _ip: SocketAddr) -> InterceptResult<Self::Error> {
64        InterceptResult::Continue(req)
65    }
66}
67
68pub type DefaultServer = Server<DummyInterceptor>;
69
70pub fn new_server(port: u16, router: Router, shutdown_rx: broadcast::Receiver<()>) -> Server {
71    let server = Server {
72        port,
73        tls_param: None, // No TLS by default
74        router,
75        interceptor: None,
76        idle_timeout: Duration::from_secs(120),
77        shutdown_rx,
78    };
79    server
80}
81
82impl<I> Server<I>
83where
84    I: ReqInterceptor + Clone + Send + Sync + 'static,
85{
86    pub fn with_interceptor<R>(self: Server<I>, interceptor: R) -> Server<R>
87    where
88        R: ReqInterceptor + Clone + Send + Sync + 'static,
89    {
90        Server::<R> {
91            port: self.port,
92            tls_param: self.tls_param,
93            router: self.router,
94            interceptor: Some(interceptor),
95            idle_timeout: self.idle_timeout, // keep the same idle timeout
96            shutdown_rx: self.shutdown_rx,
97        }
98    }
99    pub fn with_tls_param(mut self, tls_param: Option<TlsParam>) -> Self {
100        // Enable TLS by setting the tls_param
101        self.tls_param = tls_param;
102        self
103    }
104
105    pub fn with_timeout(mut self, timeout: Duration) -> Self {
106        self.idle_timeout = timeout;
107        self
108    }
109
110    pub async fn run(mut self) -> Result<(), std::io::Error> {
111        let use_tls = match self.tls_param.clone() {
112            Some(config) => config.tls,
113            None => false,
114        };
115        log::info!("listening on port {}, use_tls: {}", self.port, use_tls);
116        let server: hyper_util::server::conn::auto::Builder<TokioExecutor> = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
117        let graceful: hyper_util::server::graceful::GracefulShutdown = hyper_util::server::graceful::GracefulShutdown::new();
118        match use_tls {
119            #[allow(clippy::expect_used)]
120            true => {
121                serve_tls(
122                    &self.router,
123                    server,
124                    graceful,
125                    self.port,
126                    self.tls_param.as_ref().expect("should be some"),
127                    self.interceptor.clone(),
128                    self.idle_timeout,
129                    &mut self.shutdown_rx,
130                )
131                .await?
132            }
133            false => {
134                serve_plantext(&self.router, server, graceful, self.port, self.interceptor.clone(), self.idle_timeout, &mut self.shutdown_rx).await?
135            }
136        }
137        Ok(())
138    }
139}
140
141async fn handle<I>(
142    request: Request<Incoming>, client_socket_addr: SocketAddr, app: axum::middleware::AddExtension<Router, axum::extract::ConnectInfo<SocketAddr>>,
143    interceptor: Option<I>,
144) -> std::result::Result<Response, std::io::Error>
145where
146    I: ReqInterceptor + Clone + Send + Sync + 'static,
147{
148    if let Some(interceptor) = interceptor {
149        match interceptor.intercept(request, client_socket_addr).await {
150            InterceptResult::Return(res) => Ok(res),
151            InterceptResult::Drop => Err(std::io::Error::other("Request dropped by interceptor")),
152            InterceptResult::Continue(req) => app
153                .oneshot(req)
154                .await
155                .map_err(|err| std::io::Error::new(std::io::ErrorKind::Interrupted, err)),
156            InterceptResult::Error(err) => {
157                let res = err.into_response();
158                Ok(res)
159            }
160        }
161    } else {
162        app.oneshot(request)
163            .await
164            .map_err(|err| std::io::Error::new(std::io::ErrorKind::Interrupted, err))
165    }
166}
167
168async fn handle_connection<C, I>(
169    conn: C, client_socket_addr: std::net::SocketAddr, app: Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>,
170    interceptor: Option<I>, graceful: &hyper_util::server::graceful::GracefulShutdown, timeout: Duration,
171) where
172    C: tokio::io::AsyncRead + tokio::io::AsyncWrite + 'static + Send + Sync,
173    I: ReqInterceptor + Clone + Send + Sync + 'static,
174{
175    let timeout_io = Box::pin(io::TimeoutIO::new(conn, timeout));
176    use hyper::Request;
177    use hyper_util::rt::TokioIo;
178    let stream = TokioIo::new(timeout_io);
179    let mut app = app.into_make_service_with_connect_info::<SocketAddr>();
180    let app: axum::middleware::AddExtension<Router, axum::extract::ConnectInfo<SocketAddr>> = unwrap_infallible(app.call(client_socket_addr).await);
181    // https://github.com/tokio-rs/axum/blob/main/examples/serve-with-hyper/src/main.rs#L81
182    let hyper_service = hyper::service::service_fn(move |request: Request<hyper::body::Incoming>| {
183        handle(request, client_socket_addr, app.clone(), interceptor.clone())
184    });
185
186    let conn = server.serve_connection_with_upgrades(stream, hyper_service);
187    let conn = graceful.watch(conn.into_owned());
188
189    tokio::spawn(async move {
190        if let Err(err) = conn.await {
191            handle_hyper_error(client_socket_addr, err);
192        }
193        log::debug!("connection dropped: {client_socket_addr}");
194    });
195}
196
197fn handle_hyper_error(client_socket_addr: SocketAddr, http_err: DynError) {
198    use std::error::Error;
199    match http_err.downcast_ref::<hyper::Error>() {
200        Some(hyper_err) => {
201            let level = if hyper_err.is_user() { log::Level::Warn } else { log::Level::Debug };
202            let source = hyper_err.source().unwrap_or(hyper_err);
203            log::log!(
204                level,
205                "[hyper {}]: {:?} from {}",
206                if hyper_err.is_user() { "user" } else { "system" },
207                source,
208                SocketAddrFormat(&client_socket_addr)
209            );
210        }
211        None => match http_err.downcast_ref::<std::io::Error>() {
212            Some(io_err) => {
213                warn!("[hyper io]: [{}] {} from {}", io_err.kind(), io_err, SocketAddrFormat(&client_socket_addr));
214            }
215            None => {
216                warn!("[hyper]: {} from {}", http_err, SocketAddrFormat(&client_socket_addr));
217            }
218        },
219    }
220}
221
222async fn serve_plantext<I>(
223    app: &Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>, graceful: hyper_util::server::graceful::GracefulShutdown,
224    port: u16, interceptor: Option<I>, timeout: Duration, shutdown_rx: &mut broadcast::Receiver<()>,
225) -> Result<(), std::io::Error>
226where
227    I: ReqInterceptor + Clone + Send + Sync + 'static,
228{
229    let listener = create_dual_stack_listener(port).await?;
230    loop {
231        tokio::select! {
232            _ = shutdown_rx.recv() => {
233                info!("start graceful shutdown!");
234                drop(listener);
235                break;
236            }
237            conn = listener.accept() => {
238                match conn {
239                    Ok((conn, client_socket_addr)) => {
240                        handle_connection(conn,client_socket_addr, app.clone(), server.clone(),interceptor.clone(), &graceful, timeout).await;}
241                    Err(e) => {
242                        warn!("accept error:{e}");
243                    }
244                }
245            }
246        }
247    }
248    tokio::select! {
249        _ = graceful.shutdown() => {
250            info!("Gracefully shutdown!");
251        },
252        _ = tokio::time::sleep(GRACEFUL_SHUTDOWN_TIMEOUT) => {
253            info!("Waited {GRACEFUL_SHUTDOWN_TIMEOUT:?} for graceful shutdown, aborting...");
254        }
255    }
256    Ok(())
257}
258
259#[allow(clippy::too_many_arguments)]
260async fn serve_tls<I>(
261    app: &Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>, graceful: hyper_util::server::graceful::GracefulShutdown,
262    port: u16, tls_param: &TlsParam, interceptor: Option<I>, timeout: Duration, shutdown_rx: &mut broadcast::Receiver<()>,
263) -> Result<(), std::io::Error>
264where
265    I: ReqInterceptor + Clone + Send + Sync + 'static,
266{
267    let (tx, _rx) = broadcast::channel::<Arc<ServerConfig>>(10);
268    let tx_clone = tx.clone();
269    let tls_param_clone = tls_param.clone();
270    tokio::spawn(async move {
271        info!("update tls config every {REFRESH_INTERVAL:?}");
272        loop {
273            time::sleep(REFRESH_INTERVAL).await;
274            if let Ok(new_acceptor) = tls_config(&tls_param_clone.key, &tls_param_clone.cert) {
275                info!("update tls config");
276                if let Err(e) = tx.send(new_acceptor) {
277                    warn!("send tls config error:{e}");
278                }
279            }
280        }
281    });
282    let mut rx = tx_clone.subscribe();
283    let mut acceptor: TlsAcceptor = TlsAcceptor::new(tls_config(&tls_param.key, &tls_param.cert)?, create_dual_stack_listener(port).await?);
284    loop {
285        tokio::select! {
286            _ = shutdown_rx.recv() => {
287                info!("start graceful shutdown!");
288                drop(acceptor);
289                break;
290            }
291            message = rx.recv() => {
292                #[allow(clippy::expect_used)]
293                let new_config = message.expect("Channel should not be closed");
294                // Replace the acceptor with the new one
295                acceptor.replace_config(new_config);
296                info!("replaced tls config");
297            }
298            conn = acceptor.accept() => {
299                match conn {
300                    Ok((conn, client_socket_addr)) => {
301                        handle_connection(conn,client_socket_addr, app.clone(), server.clone(),interceptor.clone(), &graceful, timeout).await;}
302                    Err(e) => {
303                        warn!("accept error:{e}");
304                    }
305                }
306            }
307        }
308    }
309    tokio::select! {
310        _ = graceful.shutdown() => {
311            info!("Gracefully shutdown!");
312        },
313        _ = tokio::time::sleep(GRACEFUL_SHUTDOWN_TIMEOUT) => {
314            info!("Waited {GRACEFUL_SHUTDOWN_TIMEOUT:?} for graceful shutdown, aborting...");
315        }
316    }
317    Ok(())
318}
319
320#[cfg(unix)]
321pub async fn wait_signal() -> Result<(), DynError> {
322    use log::info;
323    use tokio::signal::unix::{signal, SignalKind};
324    let mut terminate_signal = signal(SignalKind::terminate())?;
325    tokio::select! {
326        _ = terminate_signal.recv() => {
327            info!("receive terminate signal");
328        },
329        _ = tokio::signal::ctrl_c() => {
330            info!("receive ctrl_c signal");
331        },
332    };
333    Ok(())
334}
335
/// Resolve once the process receives Ctrl-C.
///
/// # Errors
/// Fails if the Ctrl-C handler cannot be installed. The error is propagated
/// instead of being treated as a received signal, which would otherwise shut
/// the server down immediately.
#[cfg(windows)]
pub async fn wait_signal() -> Result<(), DynError> {
    tokio::signal::ctrl_c().await?;
    info!("receive ctrl_c signal");
    Ok(())
}
342
/// Extract the value from a `Result` whose error type has no values.
///
/// `Infallible` is uninhabited, so the empty `match` proves to the compiler that
/// the error branch cannot run — no panic path is generated.
fn unwrap_infallible<T>(result: Result<T, Infallible>) -> T {
    result.unwrap_or_else(|never| match never {})
}