Skip to main content

axum_bootstrap/
lib.rs

1//! # Axum Bootstrap 服务器核心模块
2//!
3//! 这个模块提供了基于 Axum 和 Hyper 的高性能 HTTP/HTTPS 服务器实现。
4//! 主要特性包括:
5//! - 支持 HTTP 和 HTTPS (TLS)
6//! - 请求拦截机制
7//! - 优雅关闭 (Graceful Shutdown)
8//! - IPv4/IPv6 双栈支持
9//! - 连接超时控制
10//! - TLS 证书动态更新
11//!
12//! # 示例
13//!
14//! ```no_run
15//! use axum::Router;
16//! use axum_bootstrap::{new_server, generate_shutdown_receiver};
17//!
18//! #[tokio::main]
19//! async fn main() {
20//!     let router = Router::new();
21//!     let shutdown_rx = generate_shutdown_receiver();
22//!     let server = new_server(8080, router, shutdown_rx);
23//!     server.run().await.unwrap();
24//! }
25//! ```
26
27use std::{convert::Infallible, net::SocketAddr, sync::Arc, time::Duration};
28
29/// 错误处理模块
30pub mod error;
31/// 日志初始化模块
32pub mod init_log;
33/// JWT 认证模块 (需要启用 jwt feature)
34#[cfg(feature = "jwt")]
35pub mod jwt;
36/// 工具函数模块
37pub mod util;
38
39/// 动态错误类型别名
40type DynError = Box<dyn std::error::Error + Send + Sync>;
41
42use crate::util::{
43    io::{self, create_dual_stack_listener},
44    tls::{TlsAcceptor, tls_config},
45};
46
47use axum::{
48    Router,
49    extract::Request,
50    response::{IntoResponse, Response},
51};
52
53use hyper::body::Incoming;
54use hyper_util::rt::TokioExecutor;
55use log::{info, warn};
56use tokio::{
57    sync::broadcast::{self, Receiver, Sender, error::RecvError},
58    time,
59};
60use tokio_rustls::rustls::ServerConfig;
61use tower::{Service, ServiceExt};
62use util::format::SocketAddrFormat;
63
64/// TLS 配置刷新间隔 (24小时)
65const REFRESH_INTERVAL: Duration = Duration::from_secs(60 * 60 * 24);
66
67/// 优雅关闭等待超时时间 (10秒)
68const GRACEFUL_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(10);
69
70/// HTTP/HTTPS 服务器核心结构
71///
72/// # 泛型参数
73/// - `I`: 请求拦截器类型,必须实现 `ReqInterceptor` trait
74///
75/// # 字段
76/// - `port`: 监听端口
77/// - `tls_param`: TLS 配置参数 (可选)
78/// - `router`: Axum 路由
79/// - `interceptor`: 请求拦截器实例 (可选)
80/// - `idle_timeout`: 连接空闲超时时间
81/// - `shutdown_rx`: 关闭信号接收器
82pub struct Server<I: ReqInterceptor = DummyInterceptor> {
83    pub port: u16,
84    pub tls_param: Option<TlsParam>,
85    router: Router,
86    pub interceptor: Option<I>,
87    pub idle_timeout: Duration,
88    shutdown_rx: broadcast::Receiver<()>,
89}
90
91/// TLS 配置参数
92///
93/// # 字段
94/// - `tls`: 是否启用 TLS
95/// - `cert`: TLS 证书文件路径
96/// - `key`: TLS 私钥文件路径
97#[derive(Debug, Clone)]
98pub struct TlsParam {
99    pub tls: bool,
100    pub cert: String,
101    pub key: String,
102}
103
104/// 请求拦截结果
105///
106/// 用于控制请求的处理流程
107///
108/// # 变体
109/// - `Return(Response)`: 直接返回响应,不继续处理
110/// - `Drop`: 丢弃请求,不返回响应
111/// - `Continue(Request)`: 继续处理请求
112/// - `Error(T)`: 返回错误响应
113pub enum InterceptResult<T: IntoResponse> {
114    Return(Response),
115    Drop,
116    Continue(Request<Incoming>),
117    Error(T),
118}
119
120/// 请求拦截器 trait
121///
122/// 实现此 trait 可以在请求到达路由处理器之前进行拦截和处理
123///
124/// # 关联类型
125/// - `Error`: 拦截器可能返回的错误类型
126///
127/// # 方法
128/// - `intercept`: 拦截请求的方法
129///
130/// # 示例
131///
132/// ```no_run
133/// use axum_bootstrap::{ReqInterceptor, InterceptResult};
134/// use axum::extract::Request;
135/// use hyper::body::Incoming;
136/// use std::net::SocketAddr;
137///
138/// #[derive(Clone)]
139/// struct MyInterceptor;
140///
141/// impl ReqInterceptor for MyInterceptor {
142///     type Error = axum_bootstrap::error::AppError;
143///
144///     async fn intercept(&self, req: Request<Incoming>, ip: SocketAddr) -> InterceptResult<Self::Error> {
145///         // 自定义拦截逻辑
146///         InterceptResult::Continue(req)
147///     }
148/// }
149/// ```
150pub trait ReqInterceptor: Send {
151    type Error: IntoResponse + Send + Sync + 'static;
152    fn intercept(&self, req: Request<Incoming>, ip: SocketAddr) -> impl std::future::Future<Output = InterceptResult<Self::Error>> + Send;
153}
154
155/// 空实现的请求拦截器
156///
157/// 默认不执行任何拦截操作,直接继续处理请求
158#[derive(Clone)]
159pub struct DummyInterceptor;
160
161impl ReqInterceptor for DummyInterceptor {
162    type Error = error::AppError;
163
164    async fn intercept(&self, req: Request<Incoming>, _ip: SocketAddr) -> InterceptResult<Self::Error> {
165        InterceptResult::Continue(req)
166    }
167}
168
169/// 默认服务器类型 (使用 DummyInterceptor)
170pub type DefaultServer = Server<DummyInterceptor>;
171
172/// 创建默认服务器实例
173///
174/// # 参数
175/// - `port`: 监听端口
176/// - `router`: Axum 路由
177/// - `shutdown_rx`: 关闭信号接收器
178///
179/// # 返回
180/// 返回配置好的服务器实例,默认不启用 TLS,空闲超时为 120 秒
181///
182/// # 示例
183///
184/// ```no_run
185/// use axum::Router;
186/// use axum_bootstrap::{new_server, generate_shutdown_receiver};
187///
188/// #[tokio::main]
189/// async fn main() {
190///     let router = Router::new();
191///     let shutdown_rx = generate_shutdown_receiver();
192///     let server = new_server(8080, router, shutdown_rx);
193///     server.run().await.unwrap();
194/// }
195/// ```
196pub fn new_server(port: u16, router: Router, shutdown_rx: broadcast::Receiver<()>) -> Server {
197    Server {
198        port,
199        tls_param: None, // 默认不启用 TLS
200        router,
201        interceptor: None,
202        idle_timeout: Duration::from_secs(120),
203        shutdown_rx,
204    }
205}
206
207impl<I> Server<I>
208where
209    I: ReqInterceptor + Clone + Send + Sync + 'static,
210{
211    /// 设置请求拦截器
212    ///
213    /// 用于将服务器的拦截器类型更改为新的类型
214    ///
215    /// # 类型参数
216    /// - `R`: 新的拦截器类型
217    ///
218    /// # 参数
219    /// - `interceptor`: 新的拦截器实例
220    ///
221    /// # 返回
222    /// 返回配置了新拦截器的服务器实例
223    pub fn with_interceptor<R>(self: Server<I>, interceptor: R) -> Server<R>
224    where
225        R: ReqInterceptor + Clone + Send + Sync + 'static,
226    {
227        Server::<R> {
228            port: self.port,
229            tls_param: self.tls_param,
230            router: self.router,
231            interceptor: Some(interceptor),
232            idle_timeout: self.idle_timeout, // 保持相同的空闲超时
233            shutdown_rx: self.shutdown_rx,
234        }
235    }
236
237    /// 设置 TLS 参数
238    ///
239    /// # 参数
240    /// - `tls_param`: TLS 配置参数,为 None 时禁用 TLS
241    ///
242    /// # 返回
243    /// 返回配置了 TLS 的服务器实例
244    pub fn with_tls_param(mut self, tls_param: Option<TlsParam>) -> Self {
245        self.tls_param = tls_param;
246        self
247    }
248
249    /// 设置连接空闲超时时间
250    ///
251    /// # 参数
252    /// - `timeout`: 超时时长
253    ///
254    /// # 返回
255    /// 返回配置了超时的服务器实例
256    pub fn with_timeout(mut self, timeout: Duration) -> Self {
257        self.idle_timeout = timeout;
258        self
259    }
260
261    /// 启动服务器
262    ///
263    /// 根据 TLS 配置启动 HTTP 或 HTTPS 服务器,并监听关闭信号
264    ///
265    /// # 返回
266    /// - `Ok(())`: 服务器成功启动并正常关闭
267    /// - `Err(std::io::Error)`: 启动或运行过程中出现 I/O 错误
268    ///
269    /// # 错误
270    /// - 端口绑定失败
271    /// - TLS 证书加载失败
272    /// - 网络 I/O 错误
273    pub async fn run(mut self) -> Result<(), std::io::Error> {
274        let use_tls = match self.tls_param.clone() {
275            Some(config) => config.tls,
276            None => false,
277        };
278        log::info!("listening on port {}, use_tls: {}", self.port, use_tls);
279        let server: hyper_util::server::conn::auto::Builder<TokioExecutor> = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
280        let graceful: hyper_util::server::graceful::GracefulShutdown = hyper_util::server::graceful::GracefulShutdown::new();
281        match use_tls {
282            #[allow(clippy::expect_used)]
283            true => {
284                serve_tls(
285                    &self.router,
286                    server,
287                    graceful,
288                    self.port,
289                    self.tls_param.as_ref().expect("should be some"),
290                    self.interceptor.clone(),
291                    self.idle_timeout,
292                    &mut self.shutdown_rx,
293                )
294                .await?
295            }
296            false => {
297                serve_plantext(&self.router, server, graceful, self.port, self.interceptor.clone(), self.idle_timeout, &mut self.shutdown_rx).await?
298            }
299        }
300        Ok(())
301    }
302}
303
304/// 处理单个 HTTP 请求
305///
306/// 如果配置了拦截器,会先调用拦截器处理请求,否则直接路由到应用
307///
308/// # 参数
309/// - `request`: HTTP 请求
310/// - `client_socket_addr`: 客户端地址
311/// - `app`: Axum 应用实例
312/// - `interceptor`: 可选的请求拦截器
313///
314/// # 返回
315/// - `Ok(Response)`: 成功生成的 HTTP 响应
316/// - `Err(std::io::Error)`: 处理过程中的 I/O 错误
317async fn handle<I>(
318    request: Request<Incoming>, client_socket_addr: SocketAddr, app: axum::middleware::AddExtension<Router, axum::extract::ConnectInfo<SocketAddr>>,
319    interceptor: Option<I>,
320) -> std::result::Result<Response, std::io::Error>
321where
322    I: ReqInterceptor + Clone + Send + Sync + 'static,
323{
324    if let Some(interceptor) = interceptor {
325        match interceptor.intercept(request, client_socket_addr).await {
326            InterceptResult::Return(res) => Ok(res),
327            InterceptResult::Drop => Err(std::io::Error::other("Request dropped by interceptor")),
328            InterceptResult::Continue(req) => app
329                .oneshot(req)
330                .await
331                .map_err(|err| std::io::Error::new(std::io::ErrorKind::Interrupted, err)),
332            InterceptResult::Error(err) => {
333                let res = err.into_response();
334                Ok(res)
335            }
336        }
337    } else {
338        app.oneshot(request)
339            .await
340            .map_err(|err| std::io::Error::new(std::io::ErrorKind::Interrupted, err))
341    }
342}
343
344/// 处理单个连接
345///
346/// 为每个连接创建超时包装器和 Hyper 服务,并在新的 tokio 任务中处理
347///
348/// # 类型参数
349/// - `C`: 连接类型,必须实现 AsyncRead + AsyncWrite
350/// - `I`: 请求拦截器类型
351///
352/// # 参数
353/// - `conn`: 网络连接
354/// - `client_socket_addr`: 客户端地址
355/// - `app`: Axum 路由
356/// - `server`: Hyper 服务器构建器
357/// - `interceptor`: 可选的请求拦截器
358/// - `graceful`: 优雅关闭句柄
359/// - `timeout`: 连接空闲超时时间
360async fn handle_connection<C, I>(
361    conn: C, client_socket_addr: std::net::SocketAddr, app: Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>,
362    interceptor: Option<I>, graceful: &hyper_util::server::graceful::GracefulShutdown, timeout: Duration,
363) where
364    C: tokio::io::AsyncRead + tokio::io::AsyncWrite + 'static + Send + Sync,
365    I: ReqInterceptor + Clone + Send + Sync + 'static,
366{
367    let timeout_io = Box::pin(io::TimeoutIO::new(conn, timeout));
368    use hyper::Request;
369    use hyper_util::rt::TokioIo;
370    let stream = TokioIo::new(timeout_io);
371    let mut app = app.into_make_service_with_connect_info::<SocketAddr>();
372    let app: axum::middleware::AddExtension<Router, axum::extract::ConnectInfo<SocketAddr>> = unwrap_infallible(app.call(client_socket_addr).await);
373    // https://github.com/tokio-rs/axum/blob/main/examples/serve-with-hyper/src/main.rs#L81
374    let hyper_service = hyper::service::service_fn(move |request: Request<hyper::body::Incoming>| {
375        handle(request, client_socket_addr, app.clone(), interceptor.clone())
376    });
377
378    let conn = server.serve_connection_with_upgrades(stream, hyper_service);
379    let conn = graceful.watch(conn.into_owned());
380
381    tokio::spawn(async move {
382        if let Err(err) = conn.await {
383            handle_hyper_error(client_socket_addr, err);
384        }
385        log::debug!("dropped: {client_socket_addr}");
386    });
387}
388
389/// 处理 Hyper 错误并记录日志
390///
391/// 根据错误类型输出不同级别的日志
392///
393/// # 参数
394/// - `client_socket_addr`: 客户端地址
395/// - `http_err`: HTTP 错误
396fn handle_hyper_error(client_socket_addr: SocketAddr, http_err: DynError) {
397    use std::error::Error;
398    match http_err.downcast_ref::<hyper::Error>() {
399        Some(hyper_err) => {
400            let level = if hyper_err.is_user() { log::Level::Warn } else { log::Level::Debug };
401            let source = hyper_err.source().unwrap_or(hyper_err);
402            log::log!(
403                level,
404                "[hyper {}]: {:?} from {}",
405                if hyper_err.is_user() { "user" } else { "system" },
406                source,
407                SocketAddrFormat(&client_socket_addr)
408            );
409        }
410        None => match http_err.downcast_ref::<std::io::Error>() {
411            Some(io_err) => {
412                warn!("[hyper io]: [{}] {} from {}", io_err.kind(), io_err, SocketAddrFormat(&client_socket_addr));
413            }
414            None => {
415                warn!("[hyper]: {} from {}", http_err, SocketAddrFormat(&client_socket_addr));
416            }
417        },
418    }
419}
420
421/// 启动纯文本 HTTP 服务器
422///
423/// 监听指定端口并处理 HTTP 连接,支持优雅关闭
424///
425/// # 参数
426/// - `app`: Axum 路由
427/// - `server`: Hyper 服务器构建器
428/// - `graceful`: 优雅关闭句柄
429/// - `port`: 监听端口
430/// - `interceptor`: 可选的请求拦截器
431/// - `timeout`: 连接空闲超时时间
432/// - `shutdown_rx`: 关闭信号接收器
433///
434/// # 返回
435/// - `Ok(())`: 服务器成功启动并正常关闭
436/// - `Err(std::io::Error)`: 启动或运行过程中出现错误
437async fn serve_plantext<I>(
438    app: &Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>, graceful: hyper_util::server::graceful::GracefulShutdown,
439    port: u16, interceptor: Option<I>, timeout: Duration, shutdown_rx: &mut broadcast::Receiver<()>,
440) -> Result<(), std::io::Error>
441where
442    I: ReqInterceptor + Clone + Send + Sync + 'static,
443{
444    let listener = create_dual_stack_listener(port).await?;
445    loop {
446        tokio::select! {
447            _ = shutdown_rx.recv() => {
448                info!("start graceful shutdown!");
449                drop(listener);
450                break;
451            }
452            conn = listener.accept() => {
453                match conn {
454                    Ok((conn, client_socket_addr)) => {
455                        handle_connection(conn,client_socket_addr, app.clone(), server.clone(),interceptor.clone(), &graceful, timeout).await;}
456                    Err(e) => {
457                        warn!("accept error:{e}");
458                    }
459                }
460            }
461        }
462    }
463    match tokio::time::timeout(GRACEFUL_SHUTDOWN_TIMEOUT, graceful.shutdown()).await {
464        Ok(_) => info!("Gracefully shutdown!"),
465        Err(_) => info!("Waited {GRACEFUL_SHUTDOWN_TIMEOUT:?} for graceful shutdown, aborting..."),
466    }
467    Ok(())
468}
469
470/// 启动 TLS HTTPS 服务器
471///
472/// 监听指定端口并处理 HTTPS 连接,支持 TLS 证书动态更新和优雅关闭
473///
474/// # 参数
475/// - `app`: Axum 路由
476/// - `server`: Hyper 服务器构建器
477/// - `graceful`: 优雅关闭句柄
478/// - `port`: 监听端口
479/// - `tls_param`: TLS 配置参数
480/// - `interceptor`: 可选的请求拦截器
481/// - `timeout`: 连接空闲超时时间
482/// - `shutdown_rx`: 关闭信号接收器
483///
484/// # 返回
485/// - `Ok(())`: 服务器成功启动并正常关闭
486/// - `Err(std::io::Error)`: 启动或运行过程中出现错误
487///
488/// # 说明
489/// 服务器会在后台启动一个定时任务,每隔 REFRESH_INTERVAL (24小时) 刷新一次 TLS 配置
490#[allow(clippy::too_many_arguments)]
491async fn serve_tls<I>(
492    app: &Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>, graceful: hyper_util::server::graceful::GracefulShutdown,
493    port: u16, tls_param: &TlsParam, interceptor: Option<I>, timeout: Duration, shutdown_rx: &mut broadcast::Receiver<()>,
494) -> Result<(), std::io::Error>
495where
496    I: ReqInterceptor + Clone + Send + Sync + 'static,
497{
498    let (tx, mut rx) = broadcast::channel::<Arc<ServerConfig>>(1);
499    let tls_param_clone = tls_param.clone();
500    tokio::spawn(async move {
501        info!("update tls config every {REFRESH_INTERVAL:?}");
502        loop {
503            time::sleep(REFRESH_INTERVAL).await;
504            if let Ok(new_acceptor) = tls_config(&tls_param_clone.key, &tls_param_clone.cert) {
505                info!("update tls config");
506                if let Err(e) = tx.send(new_acceptor) {
507                    warn!("send tls config error:{e}");
508                }
509            }
510        }
511    });
512    let mut acceptor: TlsAcceptor = TlsAcceptor::new(tls_config(&tls_param.key, &tls_param.cert)?, create_dual_stack_listener(port).await?);
513    loop {
514        tokio::select! {
515            _ = shutdown_rx.recv() => {
516                info!("start graceful shutdown!");
517                drop(acceptor);
518                break;
519            }
520            message = rx.recv() => {
521                match message {
522                    Ok(new_config) => {
523                        acceptor.replace_config(new_config);
524                        info!("replaced tls config");
525                    },
526                    Err(e) => {
527                        match e {
528                            RecvError::Closed => {
529                                warn!("this channel should not be closed!");
530                                break;
531                            },
532                            RecvError::Lagged(n) => {
533                                warn!("lagged {n} messages, this may cause tls config not updated in time");
534                            }
535                        }
536                    }
537                }
538            }
539            conn = acceptor.accept() => {
540                match conn {
541                    Ok((conn, client_socket_addr)) => {
542                        handle_connection(conn,client_socket_addr, app.clone(), server.clone(),interceptor.clone(), &graceful, timeout).await;}
543                    Err(e) => {
544                        warn!("accept error:{e}");
545                    }
546                }
547            }
548        }
549    }
550    match tokio::time::timeout(GRACEFUL_SHUTDOWN_TIMEOUT, graceful.shutdown()).await {
551        Ok(_) => info!("Gracefully shutdown!"),
552        Err(_) => info!("Waited {GRACEFUL_SHUTDOWN_TIMEOUT:?} for graceful shutdown, aborting..."),
553    }
554    Ok(())
555}
556
557/// 生成关闭信号接收器
558///
559/// 创建一个广播通道并订阅系统信号,返回接收器用于监听关闭信号
560///
561/// # 返回
562/// 关闭信号接收器,当收到 SIGTERM 或 Ctrl+C 信号时会收到通知
563///
564/// # 示例
565///
566/// ```no_run
567/// use axum_bootstrap::generate_shutdown_receiver;
568///
569/// let shutdown_rx = generate_shutdown_receiver();
570/// // 传递给服务器
571/// ```
572pub fn generate_shutdown_receiver() -> Receiver<()> {
573    let (shutdown_tx, shutdown_rx) = tokio::sync::broadcast::channel::<()>(1);
574    subscribe_shutdown_sender(shutdown_tx);
575    shutdown_rx
576}
577
578/// 订阅关闭信号发送器
579///
580/// 在后台任务中监听系统信号,当收到信号时通过发送器通知所有接收器
581///
582/// # 参数
583/// - `shutdown_tx`: 关闭信号发送器
584pub fn subscribe_shutdown_sender(shutdown_tx: Sender<()>) {
585    tokio::spawn(async move {
586        match wait_signal().await {
587            Ok(_) => {
588                let _ = shutdown_tx.send(());
589            }
590            Err(e) => {
591                log::error!("wait_signal error: {}", e);
592                panic!("wait_signal error: {}", e);
593            }
594        }
595    });
596}
597
598/// 等待系统关闭信号 (Unix 平台)
599///
600/// 监听 SIGTERM 和 Ctrl+C 信号
601///
602/// # 返回
603/// - `Ok(())`: 成功接收到信号
604/// - `Err(DynError)`: 信号处理出错
605#[cfg(unix)]
606pub(crate) async fn wait_signal() -> Result<(), DynError> {
607    use log::info;
608    use tokio::signal::unix::{SignalKind, signal};
609    let mut terminate_signal = signal(SignalKind::terminate())?;
610    tokio::select! {
611        _ = terminate_signal.recv() => {
612            info!("receive terminate signal");
613        },
614        _ = tokio::signal::ctrl_c() => {
615            info!("receive ctrl_c signal");
616        },
617    };
618    Ok(())
619}
620
621/// 等待系统关闭信号 (Windows 平台)
622///
623/// 监听 Ctrl+C 信号
624///
625/// # 返回
626/// - `Ok(())`: 成功接收到信号
627/// - `Err(DynError)`: 信号处理出错
628#[cfg(windows)]
629pub(crate) async fn wait_signal() -> Result<(), DynError> {
630    let _ = tokio::signal::ctrl_c().await;
631    info!("receive ctrl_c signal");
632    Ok(())
633}
634
635/// 解包 Infallible 结果类型
636///
637/// 因为 Infallible 类型永远不会发生,所以这个函数总是返回 Ok 值
638///
639/// # 参数
640/// - `result`: 包含 Infallible 错误的 Result
641///
642/// # 返回
643/// Result 中的成功值
644fn unwrap_infallible<T>(result: Result<T, Infallible>) -> T {
645    match result {
646        Ok(value) => value,
647        Err(err) => match err {},
648    }
649}