// axum_bootstrap/lib.rs

1use std::{convert::Infallible, fmt::Display, net::SocketAddr, sync::Arc, time::Duration};
2
3pub mod init_log;
4pub mod util;
5type DynError = Box<dyn std::error::Error + Send + Sync>;
6use crate::util::{
7    io::{self, create_dual_stack_listener},
8    tls::{tls_config, TlsAcceptor},
9};
10use anyhow::anyhow;
11use axum::{
12    extract::Request,
13    response::{IntoResponse, Response},
14    Router,
15};
16
17use hyper::{body::Incoming, StatusCode};
18use hyper_util::rt::TokioExecutor;
19use log::{info, warn};
20use tokio::{pin, sync::broadcast, time};
21use tokio_rustls::rustls::ServerConfig;
22use tower::{Service, ServiceExt};
23
24const REFRESH_INTERVAL: Duration = Duration::from_secs(60 * 60 * 24);
25const GRACEFUL_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(10);
26
27pub struct Server<I: ReqInterceptor = DummyInterceptor> {
28    pub port: u16,
29    pub tls_param: Option<TlsParam>,
30    router: Router,
31    pub interceptor: Option<I>,
32    pub idle_timeout: Duration,
33}
34
/// TLS settings for a `Server`.
///
/// `cert` and `key` are handed unchanged to the crate's `tls_config` helper; `tls` decides
/// whether TLS is used at all.
#[derive(Clone, Debug)]
pub struct TlsParam {
    /// Serve over TLS when `true`; when `false` the server serves plaintext.
    pub tls: bool,
    /// Certificate argument for `tls_config` (presumably a PEM file path — confirm in `util::tls`).
    pub cert: String,
    /// Private-key argument for `tls_config` (presumably a PEM file path — confirm in `util::tls`).
    pub key: String,
}

42pub enum InterceptResult<T: IntoResponse> {
43    Return(Response),
44    Continue(Request<Incoming>),
45    Error(T),
46}
47
48pub trait ReqInterceptor {
49    type Error: IntoResponse + Send + Sync + 'static;
50    fn intercept(&self, req: Request<Incoming>, ip: SocketAddr) -> impl std::future::Future<Output = InterceptResult<Self::Error>> + Send;
51}
52
53#[derive(Clone)]
54pub struct DummyInterceptor;
55
56impl ReqInterceptor for DummyInterceptor {
57    type Error = AppError;
58
59    async fn intercept(&self, req: Request<Incoming>, _ip: SocketAddr) -> InterceptResult<Self::Error> {
60        InterceptResult::Continue(req)
61    }
62}
63
64pub type DefaultServer = Server<DummyInterceptor>;
65
66pub fn new_server(port: u16, router: Router) -> Server {
67    Server {
68        port,
69        tls_param: None, // No TLS by default
70        router,
71        interceptor: None,
72        idle_timeout: Duration::from_secs(120),
73    }
74}
75
76impl<I> Server<I>
77where
78    I: ReqInterceptor + Clone + Send + Sync + 'static,
79{
80    pub fn with_interceptor<R>(self: Server<I>, interceptor: R) -> Server<R>
81    where
82        R: ReqInterceptor + Clone + Send + Sync + 'static,
83    {
84        Server::<R> {
85            port: self.port,
86            tls_param: self.tls_param,
87            router: self.router,
88            interceptor: Some(interceptor),
89            idle_timeout: self.idle_timeout, // keep the same idle timeout
90        }
91    }
92    pub fn with_tls_param(mut self, tls_param: Option<TlsParam>) -> Self {
93        // Enable TLS by setting the tls_param
94        self.tls_param = tls_param;
95        self
96    }
97
98    pub fn with_timeout(mut self, timeout: Duration) -> Self {
99        self.idle_timeout = timeout;
100        self
101    }
102
103    pub async fn run(&self) -> Result<(), DynError> {
104        let use_tls = match self.tls_param.clone() {
105            Some(config) => config.tls,
106            None => false,
107        };
108        log::info!("listening on port {}, use_tls: {}", self.port, use_tls);
109        let server: hyper_util::server::conn::auto::Builder<TokioExecutor> = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
110        let graceful: hyper_util::server::graceful::GracefulShutdown = hyper_util::server::graceful::GracefulShutdown::new();
111        match use_tls {
112            #[allow(clippy::expect_used)]
113            true => {
114                serve_tls(
115                    &self.router,
116                    server,
117                    graceful,
118                    self.port,
119                    self.tls_param.as_ref().expect("should be some"),
120                    self.interceptor.clone(),
121                    self.idle_timeout,
122                )
123                .await?
124            }
125            false => serve_plantext(&self.router, server, graceful, self.port, self.interceptor.clone(), self.idle_timeout).await?,
126        }
127        Ok(())
128    }
129}
130
131async fn handle<I>(
132    request: Request<Incoming>, client_socket_addr: SocketAddr, app: axum::middleware::AddExtension<Router, axum::extract::ConnectInfo<SocketAddr>>,
133    interceptor: Option<I>,
134) -> std::result::Result<Response, std::convert::Infallible>
135where
136    I: ReqInterceptor + Clone + Send + Sync + 'static,
137{
138    if let Some(interceptor) = interceptor {
139        match interceptor.intercept(request, client_socket_addr).await {
140            InterceptResult::Continue(req) => app.oneshot(req).await,
141            InterceptResult::Return(res) => Ok(res),
142            InterceptResult::Error(err) => {
143                let res = err.into_response();
144                Ok(res)
145            }
146        }
147    } else {
148        app.oneshot(request).await
149    }
150}
151
152async fn handle_connection<C, I>(
153    conn: C, client_socket_addr: std::net::SocketAddr, app: Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>,
154    interceptor: Option<I>, graceful: &hyper_util::server::graceful::GracefulShutdown, timeout: Duration,
155) where
156    C: tokio::io::AsyncRead + tokio::io::AsyncWrite + 'static + Send + Sync,
157    I: ReqInterceptor + Clone + Send + Sync + 'static,
158{
159    let timeout_io = Box::pin(io::TimeoutIO::new(conn, timeout));
160    use hyper::Request;
161    use hyper_util::rt::TokioIo;
162    let stream = TokioIo::new(timeout_io);
163    let mut app = app.into_make_service_with_connect_info::<SocketAddr>();
164    let app: axum::middleware::AddExtension<Router, axum::extract::ConnectInfo<SocketAddr>> = unwrap_infallible(app.call(client_socket_addr).await);
165    // https://github.com/tokio-rs/axum/blob/main/examples/serve-with-hyper/src/main.rs#L81
166    let hyper_service = hyper::service::service_fn(move |request: Request<hyper::body::Incoming>| {
167        handle(request, client_socket_addr, app.clone(), interceptor.clone())
168    });
169
170    let conn = server.serve_connection_with_upgrades(stream, hyper_service);
171    let conn = graceful.watch(conn.into_owned());
172
173    tokio::spawn(async move {
174        if let Err(err) = conn.await {
175            info!("connection error: {}", err);
176        }
177        log::debug!("connection dropped: {}", client_socket_addr);
178    });
179}
180
181async fn serve_plantext<I>(
182    app: &Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>, graceful: hyper_util::server::graceful::GracefulShutdown,
183    port: u16, interceptor: Option<I>, timeout: Duration,
184) -> Result<(), DynError>
185where
186    I: ReqInterceptor + Clone + Send + Sync + 'static,
187{
188    let listener = create_dual_stack_listener(port).await?;
189    let signal = handle_signal();
190    pin!(signal);
191    loop {
192        tokio::select! {
193            _ = signal.as_mut() => {
194                info!("start graceful shutdown!");
195                drop(listener);
196                break;
197            }
198            conn = listener.accept() => {
199                match conn {
200                    Ok((conn, client_socket_addr)) => {
201                        handle_connection(conn,client_socket_addr, app.clone(), server.clone(),interceptor.clone(), &graceful, timeout).await;}
202                    Err(e) => {
203                        warn!("accept error:{}", e);
204                    }
205                }
206            }
207        }
208    }
209    tokio::select! {
210        _ = graceful.shutdown() => {
211            info!("Gracefully shutdown!");
212        },
213        _ = tokio::time::sleep(GRACEFUL_SHUTDOWN_TIMEOUT) => {
214            info!("Waited {GRACEFUL_SHUTDOWN_TIMEOUT:?} for graceful shutdown, aborting...");
215        }
216    }
217    Ok(())
218}
219
220async fn serve_tls<I>(
221    app: &Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>, graceful: hyper_util::server::graceful::GracefulShutdown,
222    port: u16, tls_param: &TlsParam, interceptor: Option<I>, timeout: Duration,
223) -> Result<(), DynError>
224where
225    I: ReqInterceptor + Clone + Send + Sync + 'static,
226{
227    let (tx, _rx) = broadcast::channel::<Arc<ServerConfig>>(10);
228    let tx_clone = tx.clone();
229    let tls_param_clone = tls_param.clone();
230    tokio::spawn(async move {
231        info!("update tls config every {:?}", REFRESH_INTERVAL);
232        loop {
233            time::sleep(REFRESH_INTERVAL).await;
234            if let Ok(new_acceptor) = tls_config(&tls_param_clone.key, &tls_param_clone.cert) {
235                info!("update tls config");
236                if let Err(e) = tx.send(new_acceptor) {
237                    warn!("send tls config error:{}", e);
238                }
239            }
240        }
241    });
242    let mut rx = tx_clone.subscribe();
243    let mut acceptor: TlsAcceptor = TlsAcceptor::new(tls_config(&tls_param.key, &tls_param.cert)?, create_dual_stack_listener(port).await?);
244    let signal = handle_signal();
245    pin!(signal);
246    loop {
247        tokio::select! {
248            _ = signal.as_mut() => {
249                info!("start graceful shutdown!");
250                drop(acceptor);
251                break;
252            }
253            message = rx.recv() => {
254                #[allow(clippy::expect_used)]
255                let new_config = message.expect("Channel should not be closed");
256                // Replace the acceptor with the new one
257                acceptor.replace_config(new_config);
258                info!("replaced tls config");
259            }
260            conn = acceptor.accept() => {
261                match conn {
262                    Ok((conn, client_socket_addr)) => {
263                        handle_connection(conn,client_socket_addr, app.clone(), server.clone(),interceptor.clone(), &graceful, timeout).await;}
264                    Err(e) => {
265                        warn!("accept error:{}", e);
266                    }
267                }
268            }
269        }
270    }
271    tokio::select! {
272        _ = graceful.shutdown() => {
273            info!("Gracefully shutdown!");
274        },
275        _ = tokio::time::sleep(GRACEFUL_SHUTDOWN_TIMEOUT) => {
276            info!("Waited {GRACEFUL_SHUTDOWN_TIMEOUT:?} for graceful shutdown, aborting...");
277        }
278    }
279    Ok(())
280}
281
282#[cfg(unix)]
283async fn handle_signal() -> Result<(), DynError> {
284    use log::info;
285    use tokio::signal::unix::{signal, SignalKind};
286    let mut terminate_signal = signal(SignalKind::terminate())?;
287    tokio::select! {
288        _ = terminate_signal.recv() => {
289            info!("receive terminate signal, shutdowning");
290        },
291        _ = tokio::signal::ctrl_c() => {
292            info!("ctrl_c => shutdowning");
293        },
294    };
295    Ok(())
296}
297
/// Resolves once Ctrl-C is received.
///
/// If the Ctrl-C listener cannot be installed, the error is logged and the future stays
/// pending. Resolving would be read by callers as a shutdown request.
///
/// Always returns `Ok(())`; the `Result` is kept for signature compatibility.
#[cfg(windows)]
async fn handle_signal() -> Result<(), DynError> {
    if let Err(err) = tokio::signal::ctrl_c().await {
        warn!("failed to listen for ctrl_c: {}", err);
        std::future::pending::<()>().await;
    }
    info!("ctrl_c => shutdowning");
    Ok(())
}

305// Make our own error that wraps `anyhow::Error`.
306#[derive(Debug)]
307pub struct AppError(anyhow::Error);
308
309impl Display for AppError {
310    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
311        self.0.fmt(f)
312    }
313}
314
315// Tell axum how to convert `AppError` into a response.
316impl IntoResponse for AppError {
317    fn into_response(self) -> Response {
318        let err = self.0;
319        // Because `TraceLayer` wraps each request in a span that contains the request
320        // method, uri, etc we don't need to include those details here
321        tracing::error!(%err, "error");
322        (StatusCode::INTERNAL_SERVER_ERROR, format!("axum-bootstrap error: {}", &err)).into_response()
323    }
324}
325
326// This enables using `?` on functions that return `Result<_, anyhow::Error>` to turn them into
327// `Result<_, AppError>`. That way you don't need to do that manually.
328impl<E> From<E> for AppError
329where
330    E: Into<anyhow::Error>,
331{
332    fn from(err: E) -> Self {
333        Self(err.into())
334    }
335}
336
337impl AppError {
338    pub fn new<T: std::error::Error + Send + Sync + 'static>(err: T) -> Self {
339        Self(anyhow!(err))
340    }
341}
342
/// Extracts the value from a `Result` whose error type is uninhabited.
///
/// `Infallible` has no values, so the error closure can never run. The empty `match`
/// proves that to the compiler without any `unwrap`.
fn unwrap_infallible<T>(result: Result<T, Infallible>) -> T {
    result.unwrap_or_else(|never| match never {})
}