1use axum::{
2 http::StatusCode,
3 response::Json,
4 routing::{get, Router},
5 Router as AxumRouter,
6};
7use serde_json::json;
8use std::sync::Arc;
9use tokio::signal;
10use tower::ServiceBuilder;
11use tower_http::compression::CompressionLayer;
12use tracing::{info, warn};
13
14use crate::config::Config;
15use crate::handlers::{health_check, root, AppState};
16use crate::middleware::{create_cors_layer, create_trace_layer};
17
18pub struct Server {
20 config: Config,
21 app_state: Arc<AppState>,
22}
23
24impl Server {
25 pub async fn new(config: Config) -> Result<Self, anyhow::Error> {
27 let app_state = Arc::new(AppState::new());
28
29 Ok(Self { config, app_state })
30 }
31
32 fn create_base_router(&self) -> AxumRouter {
34 let state = self.app_state.clone();
35
36 Router::new()
37 .route("/", get(root))
38 .route("/health", get(health_check))
39 .fallback(|| async {
40 (
41 StatusCode::NOT_FOUND,
42 Json(json!({
43 "error": "Not found",
44 "status": 404
45 })),
46 )
47 })
48 .layer(
49 ServiceBuilder::new()
50 .layer(CompressionLayer::new())
51 .into_inner(),
52 )
53 .layer(create_cors_layer(&self.config))
54 .layer(create_trace_layer())
55 .with_state(state)
56 }
57
58 fn create_router(&self, custom_routes: Option<AxumRouter>) -> AxumRouter {
60 let base_router = self.create_base_router();
61
62 let router = if let Some(custom) = custom_routes {
64 base_router.merge(custom)
65 } else {
66 base_router
67 };
68
69 if let Some(ref context_path) = self.config.server.context_path {
71 use std::borrow::Cow;
72 let path: Cow<'_, str> = if context_path.starts_with('/') {
74 Cow::Borrowed(context_path.as_str())
75 } else {
76 tracing::warn!("上下文路径 '{}' 应该以 '/' 开头,已自动修正", context_path);
78 Cow::Owned(format!("/{}", context_path))
79 };
80 Router::new().nest(path.as_ref(), router)
81 } else {
82 router
83 }
84 }
85
86 pub async fn start(&self) -> Result<(), anyhow::Error> {
88 self.start_with_routes(None).await
89 }
90
91 pub async fn start_with_routes(
96 &self,
97 custom_routes: Option<AxumRouter>,
98 ) -> Result<(), anyhow::Error> {
99 let addr = self.config.server.socket_addr()?;
100 let app = self.create_router(custom_routes);
101
102 let base_url = if let Some(ref context_path) = self.config.server.context_path {
103 format!("http://{}{}", addr, context_path)
104 } else {
105 format!("http://{}", addr)
106 };
107
108 info!("服务器启动在 {}", base_url);
109 info!("健康检查: {}/health", base_url);
110
111 let listener = tokio::net::TcpListener::bind(&addr).await?;
112
113 axum::serve(listener, app)
114 .with_graceful_shutdown(shutdown_signal())
115 .await?;
116
117 Ok(())
118 }
119}
120
121async fn shutdown_signal() {
123 let ctrl_c = async {
124 signal::ctrl_c().await.expect("无法安装 Ctrl+C 信号处理器");
125 info!("收到 Ctrl+C 信号,开始优雅关闭...");
126 };
127
128 #[cfg(unix)]
129 let terminate = async {
130 signal::unix::signal(signal::unix::SignalKind::terminate())
131 .expect("无法安装 SIGTERM 信号处理器")
132 .recv()
133 .await;
134 warn!("收到 SIGTERM 信号,开始优雅关闭...");
135 };
136
137 #[cfg(not(unix))]
138 let terminate = std::future::pending::<()>();
139
140 tokio::select! {
141 _ = ctrl_c => {},
142 _ = terminate => {},
143 }
144
145 info!("服务器正在关闭...");
146}