axum_bootstrap/
lib.rs

1use std::{convert::Infallible, net::SocketAddr, sync::Arc, time::Duration};
2
3pub mod error;
4pub mod init_log;
5pub mod util;
/// Boxed, thread-safe error used by signal waiting and by hyper connection futures.
type DynError = Box<dyn std::error::Error + Sync + Send>;
7use crate::util::{
8    io::{self, create_dual_stack_listener},
9    tls::{TlsAcceptor, tls_config},
10};
11
12use axum::{
13    Router,
14    extract::Request,
15    response::{IntoResponse, Response},
16};
17
18use hyper::body::Incoming;
19use hyper_util::rt::TokioExecutor;
20use log::{info, warn};
21use tokio::{
22    sync::broadcast::{self, Receiver, Sender, error::RecvError},
23    time,
24};
25use tokio_rustls::rustls::ServerConfig;
26use tower::{Service, ServiceExt};
27use util::format::SocketAddrFormat;
28
/// Period between TLS configuration rebuilds in `serve_tls` (one day).
const REFRESH_INTERVAL: Duration = Duration::from_secs(24 * 60 * 60);
/// Upper bound on how long shutdown waits for in-flight connections before giving up.
const GRACEFUL_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(10);
31
32pub struct Server<I: ReqInterceptor = DummyInterceptor> {
33    pub port: u16,
34    pub tls_param: Option<TlsParam>,
35    router: Router,
36    pub interceptor: Option<I>,
37    pub idle_timeout: Duration,
38    shutdown_rx: broadcast::Receiver<()>,
39}
40
/// TLS settings for a [`Server`].
///
/// `key` and `cert` are passed to `tls_config` at startup and again on every periodic refresh.
#[derive(Debug, Clone)]
pub struct TlsParam {
    /// Master switch: when `false` the server serves plaintext even though this struct is present.
    pub tls: bool,
    /// Certificate argument for `tls_config` (presumably a PEM file path — confirm in util::tls).
    pub cert: String,
    /// Private-key argument for `tls_config` (presumably a PEM file path — confirm in util::tls).
    pub key: String,
}
47
48pub enum InterceptResult<T: IntoResponse> {
49    Return(Response),
50    Drop,
51    Continue(Request<Incoming>),
52    Error(T),
53}
54
55pub trait ReqInterceptor: Send {
56    type Error: IntoResponse + Send + Sync + 'static;
57    fn intercept(&self, req: Request<Incoming>, ip: SocketAddr) -> impl std::future::Future<Output = InterceptResult<Self::Error>> + Send;
58}
59
/// Interceptor that never interferes; the default type parameter of [`Server`].
#[derive(Clone)]
pub struct DummyInterceptor;
62
63impl ReqInterceptor for DummyInterceptor {
64    type Error = error::AppError;
65
66    async fn intercept(&self, req: Request<Incoming>, _ip: SocketAddr) -> InterceptResult<Self::Error> {
67        InterceptResult::Continue(req)
68    }
69}
70
71pub type DefaultServer = Server<DummyInterceptor>;
72
73pub fn new_server(port: u16, router: Router, shutdown_rx: broadcast::Receiver<()>) -> Server {
74    Server {
75        port,
76        tls_param: None, // No TLS by default
77        router,
78        interceptor: None,
79        idle_timeout: Duration::from_secs(120),
80        shutdown_rx,
81    }
82}
83
84impl<I> Server<I>
85where
86    I: ReqInterceptor + Clone + Send + Sync + 'static,
87{
88    pub fn with_interceptor<R>(self: Server<I>, interceptor: R) -> Server<R>
89    where
90        R: ReqInterceptor + Clone + Send + Sync + 'static,
91    {
92        Server::<R> {
93            port: self.port,
94            tls_param: self.tls_param,
95            router: self.router,
96            interceptor: Some(interceptor),
97            idle_timeout: self.idle_timeout, // keep the same idle timeout
98            shutdown_rx: self.shutdown_rx,
99        }
100    }
101    pub fn with_tls_param(mut self, tls_param: Option<TlsParam>) -> Self {
102        // Enable TLS by setting the tls_param
103        self.tls_param = tls_param;
104        self
105    }
106
107    pub fn with_timeout(mut self, timeout: Duration) -> Self {
108        self.idle_timeout = timeout;
109        self
110    }
111
112    pub async fn run(mut self) -> Result<(), std::io::Error> {
113        let use_tls = match self.tls_param.clone() {
114            Some(config) => config.tls,
115            None => false,
116        };
117        log::info!("listening on port {}, use_tls: {}", self.port, use_tls);
118        let server: hyper_util::server::conn::auto::Builder<TokioExecutor> = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
119        let graceful: hyper_util::server::graceful::GracefulShutdown = hyper_util::server::graceful::GracefulShutdown::new();
120        match use_tls {
121            #[allow(clippy::expect_used)]
122            true => {
123                serve_tls(
124                    &self.router,
125                    server,
126                    graceful,
127                    self.port,
128                    self.tls_param.as_ref().expect("should be some"),
129                    self.interceptor.clone(),
130                    self.idle_timeout,
131                    &mut self.shutdown_rx,
132                )
133                .await?
134            }
135            false => {
136                serve_plantext(&self.router, server, graceful, self.port, self.interceptor.clone(), self.idle_timeout, &mut self.shutdown_rx).await?
137            }
138        }
139        Ok(())
140    }
141}
142
143async fn handle<I>(
144    request: Request<Incoming>, client_socket_addr: SocketAddr, app: axum::middleware::AddExtension<Router, axum::extract::ConnectInfo<SocketAddr>>,
145    interceptor: Option<I>,
146) -> std::result::Result<Response, std::io::Error>
147where
148    I: ReqInterceptor + Clone + Send + Sync + 'static,
149{
150    if let Some(interceptor) = interceptor {
151        match interceptor.intercept(request, client_socket_addr).await {
152            InterceptResult::Return(res) => Ok(res),
153            InterceptResult::Drop => Err(std::io::Error::other("Request dropped by interceptor")),
154            InterceptResult::Continue(req) => app
155                .oneshot(req)
156                .await
157                .map_err(|err| std::io::Error::new(std::io::ErrorKind::Interrupted, err)),
158            InterceptResult::Error(err) => {
159                let res = err.into_response();
160                Ok(res)
161            }
162        }
163    } else {
164        app.oneshot(request)
165            .await
166            .map_err(|err| std::io::Error::new(std::io::ErrorKind::Interrupted, err))
167    }
168}
169
170async fn handle_connection<C, I>(
171    conn: C, client_socket_addr: std::net::SocketAddr, app: Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>,
172    interceptor: Option<I>, graceful: &hyper_util::server::graceful::GracefulShutdown, timeout: Duration,
173) where
174    C: tokio::io::AsyncRead + tokio::io::AsyncWrite + 'static + Send + Sync,
175    I: ReqInterceptor + Clone + Send + Sync + 'static,
176{
177    let timeout_io = Box::pin(io::TimeoutIO::new(conn, timeout));
178    use hyper::Request;
179    use hyper_util::rt::TokioIo;
180    let stream = TokioIo::new(timeout_io);
181    let mut app = app.into_make_service_with_connect_info::<SocketAddr>();
182    let app: axum::middleware::AddExtension<Router, axum::extract::ConnectInfo<SocketAddr>> = unwrap_infallible(app.call(client_socket_addr).await);
183    // https://github.com/tokio-rs/axum/blob/main/examples/serve-with-hyper/src/main.rs#L81
184    let hyper_service = hyper::service::service_fn(move |request: Request<hyper::body::Incoming>| {
185        handle(request, client_socket_addr, app.clone(), interceptor.clone())
186    });
187
188    let conn = server.serve_connection_with_upgrades(stream, hyper_service);
189    let conn = graceful.watch(conn.into_owned());
190
191    tokio::spawn(async move {
192        if let Err(err) = conn.await {
193            handle_hyper_error(client_socket_addr, err);
194        }
195        log::debug!("connection dropped: {client_socket_addr}");
196    });
197}
198
199fn handle_hyper_error(client_socket_addr: SocketAddr, http_err: DynError) {
200    use std::error::Error;
201    match http_err.downcast_ref::<hyper::Error>() {
202        Some(hyper_err) => {
203            let level = if hyper_err.is_user() { log::Level::Warn } else { log::Level::Debug };
204            let source = hyper_err.source().unwrap_or(hyper_err);
205            log::log!(
206                level,
207                "[hyper {}]: {:?} from {}",
208                if hyper_err.is_user() { "user" } else { "system" },
209                source,
210                SocketAddrFormat(&client_socket_addr)
211            );
212        }
213        None => match http_err.downcast_ref::<std::io::Error>() {
214            Some(io_err) => {
215                warn!("[hyper io]: [{}] {} from {}", io_err.kind(), io_err, SocketAddrFormat(&client_socket_addr));
216            }
217            None => {
218                warn!("[hyper]: {} from {}", http_err, SocketAddrFormat(&client_socket_addr));
219            }
220        },
221    }
222}
223
224async fn serve_plantext<I>(
225    app: &Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>, graceful: hyper_util::server::graceful::GracefulShutdown,
226    port: u16, interceptor: Option<I>, timeout: Duration, shutdown_rx: &mut broadcast::Receiver<()>,
227) -> Result<(), std::io::Error>
228where
229    I: ReqInterceptor + Clone + Send + Sync + 'static,
230{
231    let listener = create_dual_stack_listener(port).await?;
232    loop {
233        tokio::select! {
234            _ = shutdown_rx.recv() => {
235                info!("start graceful shutdown!");
236                drop(listener);
237                break;
238            }
239            conn = listener.accept() => {
240                match conn {
241                    Ok((conn, client_socket_addr)) => {
242                        handle_connection(conn,client_socket_addr, app.clone(), server.clone(),interceptor.clone(), &graceful, timeout).await;}
243                    Err(e) => {
244                        warn!("accept error:{e}");
245                    }
246                }
247            }
248        }
249    }
250    match tokio::time::timeout(GRACEFUL_SHUTDOWN_TIMEOUT, graceful.shutdown()).await {
251        Ok(_) => info!("Gracefully shutdown!"),
252        Err(_) => info!("Waited {GRACEFUL_SHUTDOWN_TIMEOUT:?} for graceful shutdown, aborting..."),
253    }
254    Ok(())
255}
256
257#[allow(clippy::too_many_arguments)]
258async fn serve_tls<I>(
259    app: &Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>, graceful: hyper_util::server::graceful::GracefulShutdown,
260    port: u16, tls_param: &TlsParam, interceptor: Option<I>, timeout: Duration, shutdown_rx: &mut broadcast::Receiver<()>,
261) -> Result<(), std::io::Error>
262where
263    I: ReqInterceptor + Clone + Send + Sync + 'static,
264{
265    let (tx, mut rx) = broadcast::channel::<Arc<ServerConfig>>(1);
266    let tls_param_clone = tls_param.clone();
267    tokio::spawn(async move {
268        info!("update tls config every {REFRESH_INTERVAL:?}");
269        loop {
270            time::sleep(REFRESH_INTERVAL).await;
271            if let Ok(new_acceptor) = tls_config(&tls_param_clone.key, &tls_param_clone.cert) {
272                info!("update tls config");
273                if let Err(e) = tx.send(new_acceptor) {
274                    warn!("send tls config error:{e}");
275                }
276            }
277        }
278    });
279    let mut acceptor: TlsAcceptor = TlsAcceptor::new(tls_config(&tls_param.key, &tls_param.cert)?, create_dual_stack_listener(port).await?);
280    loop {
281        tokio::select! {
282            _ = shutdown_rx.recv() => {
283                info!("start graceful shutdown!");
284                drop(acceptor);
285                break;
286            }
287            message = rx.recv() => {
288                match message {
289                    Ok(new_config) => {
290                        acceptor.replace_config(new_config);
291                        info!("replaced tls config");
292                    },
293                    Err(e) => {
294                        match e {
295                            RecvError::Closed => {
296                                warn!("this channel should not be closed!");
297                                break;
298                            },
299                            RecvError::Lagged(n) => {
300                                warn!("lagged {n} messages, this may cause tls config not updated in time");
301                            }
302                        }
303                    }
304                }
305            }
306            conn = acceptor.accept() => {
307                match conn {
308                    Ok((conn, client_socket_addr)) => {
309                        handle_connection(conn,client_socket_addr, app.clone(), server.clone(),interceptor.clone(), &graceful, timeout).await;}
310                    Err(e) => {
311                        warn!("accept error:{e}");
312                    }
313                }
314            }
315        }
316    }
317    match tokio::time::timeout(GRACEFUL_SHUTDOWN_TIMEOUT, graceful.shutdown()).await {
318        Ok(_) => info!("Gracefully shutdown!"),
319        Err(_) => info!("Waited {GRACEFUL_SHUTDOWN_TIMEOUT:?} for graceful shutdown, aborting..."),
320    }
321    Ok(())
322}
323
324pub fn generate_shutdown_receiver() -> Receiver<()> {
325    let (shutdown_tx, shutdown_rx) = tokio::sync::broadcast::channel::<()>(1);
326    subscribe_shutdown_sender(shutdown_tx);
327    shutdown_rx
328}
329
330pub fn subscribe_shutdown_sender(shutdown_tx: Sender<()>) {
331    tokio::spawn(async move {
332        match wait_signal().await {
333            Ok(_) => {
334                let _ = shutdown_tx.send(());
335            }
336            Err(e) => {
337                log::error!("wait_signal error: {}", e);
338                panic!("wait_signal error: {}", e);
339            }
340        }
341    });
342}
343
344#[cfg(unix)]
345pub(crate) async fn wait_signal() -> Result<(), DynError> {
346    use log::info;
347    use tokio::signal::unix::{SignalKind, signal};
348    let mut terminate_signal = signal(SignalKind::terminate())?;
349    tokio::select! {
350        _ = terminate_signal.recv() => {
351            info!("receive terminate signal");
352        },
353        _ = tokio::signal::ctrl_c() => {
354            info!("receive ctrl_c signal");
355        },
356    };
357    Ok(())
358}
359
/// Waits until the process receives Ctrl-C.
///
/// # Errors
/// Returns an error if the Ctrl-C handler cannot be registered. Previously this error was
/// discarded and logged as a received signal.
#[cfg(windows)]
pub(crate) async fn wait_signal() -> Result<(), DynError> {
    tokio::signal::ctrl_c().await?;
    info!("receive ctrl_c signal");
    Ok(())
}
366
/// Extracts the value from a `Result` whose error type is uninhabited and so can never occur.
fn unwrap_infallible<T>(result: Result<T, Infallible>) -> T {
    // `Infallible` has no values, so the empty match proves this closure is never called.
    result.unwrap_or_else(|never| match never {})
}
372}