axum_bootstrap/
lib.rs

1use std::{convert::Infallible, fmt::Display, net::SocketAddr, sync::Arc, time::Duration};
2
3pub mod init_log;
4pub mod util;
/// Boxed, thread-safe error used as the catch-all error type of this crate's fallible functions.
type DynError = Box<dyn std::error::Error + Send + Sync + 'static>;
6use crate::util::{
7    io::{self, create_dual_stack_listener},
8    tls::{tls_config, TlsAcceptor},
9};
10use anyhow::anyhow;
11use axum::{
12    extract::Request,
13    response::{IntoResponse, Response},
14    Router,
15};
16
17use hyper::{body::Incoming, StatusCode};
18use hyper_util::rt::TokioExecutor;
19use log::{info, warn};
20use tokio::{pin, sync::broadcast, time};
21use tokio_rustls::rustls::ServerConfig;
22use tower::{Service, ServiceExt};
23use util::format::SocketAddrFormat;
24
/// How long the TLS refresh task sleeps between two attempts to reload the TLS configuration (24 hours).
const REFRESH_INTERVAL: Duration = Duration::from_secs(24 * 60 * 60);
/// Upper bound on how long the serve loops wait for in-flight connections after a shutdown signal.
const GRACEFUL_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(10);
27
28pub struct Server<I: ReqInterceptor = DummyInterceptor> {
29    pub port: u16,
30    pub tls_param: Option<TlsParam>,
31    router: Router,
32    pub interceptor: Option<I>,
33    pub idle_timeout: Duration,
34}
35
/// TLS settings for a [`Server`].
#[derive(Clone, Debug)]
pub struct TlsParam {
    /// Whether TLS should actually be used; `false` makes the server serve plaintext.
    pub tls: bool,
    /// Certificate argument passed to `tls_config` (presumably a file path — confirm against `util::tls`).
    pub cert: String,
    /// Private-key argument passed to `tls_config` (presumably a file path — confirm against `util::tls`).
    pub key: String,
}
42
43pub enum InterceptResult<T: IntoResponse> {
44    Return(Response),
45    Drop,
46    Continue(Request<Incoming>),
47    Error(T),
48}
49
50pub trait ReqInterceptor: Send {
51    type Error: IntoResponse + Send + Sync + 'static;
52    fn intercept(&self, req: Request<Incoming>, ip: SocketAddr) -> impl std::future::Future<Output = InterceptResult<Self::Error>> + Send;
53}
54
/// Interceptor that does nothing; it is the default interceptor type of [`Server`].
///
/// The type is a zero-sized marker, so it is cheap to copy and can be built with `Default`.
#[derive(Clone, Copy, Debug, Default)]
pub struct DummyInterceptor;
57
58impl ReqInterceptor for DummyInterceptor {
59    type Error = AppError;
60
61    async fn intercept(&self, req: Request<Incoming>, _ip: SocketAddr) -> InterceptResult<Self::Error> {
62        InterceptResult::Continue(req)
63    }
64}
65
66pub type DefaultServer = Server<DummyInterceptor>;
67
68pub fn new_server(port: u16, router: Router) -> Server {
69    Server {
70        port,
71        tls_param: None, // No TLS by default
72        router,
73        interceptor: None,
74        idle_timeout: Duration::from_secs(120),
75    }
76}
77
78impl<I> Server<I>
79where
80    I: ReqInterceptor + Clone + Send + Sync + 'static,
81{
82    pub fn with_interceptor<R>(self: Server<I>, interceptor: R) -> Server<R>
83    where
84        R: ReqInterceptor + Clone + Send + Sync + 'static,
85    {
86        Server::<R> {
87            port: self.port,
88            tls_param: self.tls_param,
89            router: self.router,
90            interceptor: Some(interceptor),
91            idle_timeout: self.idle_timeout, // keep the same idle timeout
92        }
93    }
94    pub fn with_tls_param(mut self, tls_param: Option<TlsParam>) -> Self {
95        // Enable TLS by setting the tls_param
96        self.tls_param = tls_param;
97        self
98    }
99
100    pub fn with_timeout(mut self, timeout: Duration) -> Self {
101        self.idle_timeout = timeout;
102        self
103    }
104
105    pub async fn run(&self) -> Result<(), DynError> {
106        let use_tls = match self.tls_param.clone() {
107            Some(config) => config.tls,
108            None => false,
109        };
110        log::info!("listening on port {}, use_tls: {}", self.port, use_tls);
111        let server: hyper_util::server::conn::auto::Builder<TokioExecutor> = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
112        let graceful: hyper_util::server::graceful::GracefulShutdown = hyper_util::server::graceful::GracefulShutdown::new();
113        match use_tls {
114            #[allow(clippy::expect_used)]
115            true => {
116                serve_tls(
117                    &self.router,
118                    server,
119                    graceful,
120                    self.port,
121                    self.tls_param.as_ref().expect("should be some"),
122                    self.interceptor.clone(),
123                    self.idle_timeout,
124                )
125                .await?
126            }
127            false => serve_plantext(&self.router, server, graceful, self.port, self.interceptor.clone(), self.idle_timeout).await?,
128        }
129        Ok(())
130    }
131}
132
133async fn handle<I>(
134    request: Request<Incoming>, client_socket_addr: SocketAddr, app: axum::middleware::AddExtension<Router, axum::extract::ConnectInfo<SocketAddr>>,
135    interceptor: Option<I>,
136) -> std::result::Result<Response, std::io::Error>
137where
138    I: ReqInterceptor + Clone + Send + Sync + 'static,
139{
140    if let Some(interceptor) = interceptor {
141        match interceptor.intercept(request, client_socket_addr).await {
142            InterceptResult::Return(res) => Ok(res),
143            InterceptResult::Drop => Err(std::io::Error::other("Request dropped by interceptor")),
144            InterceptResult::Continue(req) => app
145                .oneshot(req)
146                .await
147                .map_err(|err| std::io::Error::new(std::io::ErrorKind::Interrupted, err)),
148            InterceptResult::Error(err) => {
149                let res = err.into_response();
150                Ok(res)
151            }
152        }
153    } else {
154        app.oneshot(request)
155            .await
156            .map_err(|err| std::io::Error::new(std::io::ErrorKind::Interrupted, err))
157    }
158}
159
160async fn handle_connection<C, I>(
161    conn: C, client_socket_addr: std::net::SocketAddr, app: Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>,
162    interceptor: Option<I>, graceful: &hyper_util::server::graceful::GracefulShutdown, timeout: Duration,
163) where
164    C: tokio::io::AsyncRead + tokio::io::AsyncWrite + 'static + Send + Sync,
165    I: ReqInterceptor + Clone + Send + Sync + 'static,
166{
167    let timeout_io = Box::pin(io::TimeoutIO::new(conn, timeout));
168    use hyper::Request;
169    use hyper_util::rt::TokioIo;
170    let stream = TokioIo::new(timeout_io);
171    let mut app = app.into_make_service_with_connect_info::<SocketAddr>();
172    let app: axum::middleware::AddExtension<Router, axum::extract::ConnectInfo<SocketAddr>> = unwrap_infallible(app.call(client_socket_addr).await);
173    // https://github.com/tokio-rs/axum/blob/main/examples/serve-with-hyper/src/main.rs#L81
174    let hyper_service = hyper::service::service_fn(move |request: Request<hyper::body::Incoming>| {
175        handle(request, client_socket_addr, app.clone(), interceptor.clone())
176    });
177
178    let conn = server.serve_connection_with_upgrades(stream, hyper_service);
179    let conn = graceful.watch(conn.into_owned());
180
181    tokio::spawn(async move {
182        if let Err(err) = conn.await {
183            handle_hyper_error(client_socket_addr, err);
184        }
185        log::debug!("connection dropped: {client_socket_addr}");
186    });
187}
188
189fn handle_hyper_error(client_socket_addr: SocketAddr, http_err: DynError) {
190    use std::error::Error;
191    match http_err.downcast_ref::<hyper::Error>() {
192        Some(hyper_err) => {
193            let level = if hyper_err.is_user() { log::Level::Warn } else { log::Level::Debug };
194            let source = hyper_err.source().unwrap_or(hyper_err);
195            log::log!(
196                level,
197                "[hyper {}]: {:?} from {}",
198                if hyper_err.is_user() { "user" } else { "system" },
199                source,
200                SocketAddrFormat(&client_socket_addr)
201            );
202        }
203        None => match http_err.downcast_ref::<std::io::Error>() {
204            Some(io_err) => {
205                warn!("[hyper io]: [{}] {} from {}", io_err.kind(), io_err, SocketAddrFormat(&client_socket_addr));
206            }
207            None => {
208                warn!("[hyper]: {} from {}", http_err, SocketAddrFormat(&client_socket_addr));
209            }
210        },
211    }
212}
213
214async fn serve_plantext<I>(
215    app: &Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>, graceful: hyper_util::server::graceful::GracefulShutdown,
216    port: u16, interceptor: Option<I>, timeout: Duration,
217) -> Result<(), DynError>
218where
219    I: ReqInterceptor + Clone + Send + Sync + 'static,
220{
221    let listener = create_dual_stack_listener(port).await?;
222    let signal = handle_signal();
223    pin!(signal);
224    loop {
225        tokio::select! {
226            _ = signal.as_mut() => {
227                info!("start graceful shutdown!");
228                drop(listener);
229                break;
230            }
231            conn = listener.accept() => {
232                match conn {
233                    Ok((conn, client_socket_addr)) => {
234                        handle_connection(conn,client_socket_addr, app.clone(), server.clone(),interceptor.clone(), &graceful, timeout).await;}
235                    Err(e) => {
236                        warn!("accept error:{e}");
237                    }
238                }
239            }
240        }
241    }
242    tokio::select! {
243        _ = graceful.shutdown() => {
244            info!("Gracefully shutdown!");
245        },
246        _ = tokio::time::sleep(GRACEFUL_SHUTDOWN_TIMEOUT) => {
247            info!("Waited {GRACEFUL_SHUTDOWN_TIMEOUT:?} for graceful shutdown, aborting...");
248        }
249    }
250    Ok(())
251}
252
253async fn serve_tls<I>(
254    app: &Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>, graceful: hyper_util::server::graceful::GracefulShutdown,
255    port: u16, tls_param: &TlsParam, interceptor: Option<I>, timeout: Duration,
256) -> Result<(), DynError>
257where
258    I: ReqInterceptor + Clone + Send + Sync + 'static,
259{
260    let (tx, _rx) = broadcast::channel::<Arc<ServerConfig>>(10);
261    let tx_clone = tx.clone();
262    let tls_param_clone = tls_param.clone();
263    tokio::spawn(async move {
264        info!("update tls config every {REFRESH_INTERVAL:?}");
265        loop {
266            time::sleep(REFRESH_INTERVAL).await;
267            if let Ok(new_acceptor) = tls_config(&tls_param_clone.key, &tls_param_clone.cert) {
268                info!("update tls config");
269                if let Err(e) = tx.send(new_acceptor) {
270                    warn!("send tls config error:{e}");
271                }
272            }
273        }
274    });
275    let mut rx = tx_clone.subscribe();
276    let mut acceptor: TlsAcceptor = TlsAcceptor::new(tls_config(&tls_param.key, &tls_param.cert)?, create_dual_stack_listener(port).await?);
277    let signal = handle_signal();
278    pin!(signal);
279    loop {
280        tokio::select! {
281            _ = signal.as_mut() => {
282                info!("start graceful shutdown!");
283                drop(acceptor);
284                break;
285            }
286            message = rx.recv() => {
287                #[allow(clippy::expect_used)]
288                let new_config = message.expect("Channel should not be closed");
289                // Replace the acceptor with the new one
290                acceptor.replace_config(new_config);
291                info!("replaced tls config");
292            }
293            conn = acceptor.accept() => {
294                match conn {
295                    Ok((conn, client_socket_addr)) => {
296                        handle_connection(conn,client_socket_addr, app.clone(), server.clone(),interceptor.clone(), &graceful, timeout).await;}
297                    Err(e) => {
298                        warn!("accept error:{e}");
299                    }
300                }
301            }
302        }
303    }
304    tokio::select! {
305        _ = graceful.shutdown() => {
306            info!("Gracefully shutdown!");
307        },
308        _ = tokio::time::sleep(GRACEFUL_SHUTDOWN_TIMEOUT) => {
309            info!("Waited {GRACEFUL_SHUTDOWN_TIMEOUT:?} for graceful shutdown, aborting...");
310        }
311    }
312    Ok(())
313}
314
315#[cfg(unix)]
316async fn handle_signal() -> Result<(), DynError> {
317    use log::info;
318    use tokio::signal::unix::{signal, SignalKind};
319    let mut terminate_signal = signal(SignalKind::terminate())?;
320    tokio::select! {
321        _ = terminate_signal.recv() => {
322            info!("receive terminate signal, shutdowning");
323        },
324        _ = tokio::signal::ctrl_c() => {
325            info!("ctrl_c => shutdowning");
326        },
327    };
328    Ok(())
329}
330
/// Completes when the process receives Ctrl-C (windows).
///
/// The result of waiting for Ctrl-C is ignored, so this function always returns `Ok(())`.
#[cfg(windows)]
async fn handle_signal() -> Result<(), DynError> {
    let _outcome = tokio::signal::ctrl_c().await;
    info!("ctrl_c => shutdowning");
    Ok(())
}
337
338// Make our own error that wraps `anyhow::Error`.
339#[derive(Debug)]
340pub struct AppError(anyhow::Error);
341
342impl Display for AppError {
343    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
344        self.0.fmt(f)
345    }
346}
347
348// Tell axum how to convert `AppError` into a response.
349impl IntoResponse for AppError {
350    fn into_response(self) -> Response {
351        let err = self.0;
352        // Because `TraceLayer` wraps each request in a span that contains the request
353        // method, uri, etc we don't need to include those details here
354        tracing::error!(%err, "error");
355        (StatusCode::INTERNAL_SERVER_ERROR, format!("ERROR: {}", &err)).into_response()
356    }
357}
358
359// This enables using `?` on functions that return `Result<_, anyhow::Error>` to turn them into
360// `Result<_, AppError>`. That way you don't need to do that manually.
361impl<E> From<E> for AppError
362where
363    E: Into<anyhow::Error>,
364{
365    fn from(err: E) -> Self {
366        Self(err.into())
367    }
368}
369
370impl AppError {
371    pub fn new<T: std::error::Error + Send + Sync + 'static>(err: T) -> Self {
372        Self(anyhow!(err))
373    }
374}
375
/// Extracts the value of a `Result` whose error type is uninhabited, without any panic path.
fn unwrap_infallible<T>(result: Result<T, Infallible>) -> T {
    // `Infallible` has no values, so the empty match proves the error branch is unreachable.
    result.unwrap_or_else(|never| match never {})
}