use crate::web::middleware::{
ForbiddenUrnsState, IpBanState, LocalOnlyUrnsState, forbidden_urns_middleware,
ip_ban_middleware, local_only_middleware, local_only_urns_middleware,
};
use crate::web::{HttpsConfig, WebServerConfig, WebServerError, build_cors, build_https};
use axum::{Router, debug_handler, middleware, routing::get};
use linkme::distributed_slice;
use log::{debug, error, info};
use robotech_macros::log_call;
use socket2::{Domain, Socket, Type};
use std::net::{IpAddr, SocketAddr, TcpListener};
use std::sync::{Arc, RwLock};
use std::time::Duration;
use tokio::sync::broadcast;
use tokio::task::JoinHandle;
use tokio::time::timeout;
use tower_http::trace::TraceLayer;
use utoipa::openapi::OpenApi;
use utoipa_swagger_ui::{SwaggerUi, Url};
use wheel_rs::process::terminate_process;
/// `linkme` distributed slice of router factories.
///
/// `start_web_server` calls every registered function and merges the returned `Router`
/// into the application router.
#[distributed_slice]
pub static ROUTER_SLICE: [fn() -> Router];
/// `linkme` distributed slice of OpenAPI document factories.
///
/// Each function returns the Swagger UI URL entry and its `OpenApi` document. When at least
/// one is registered, `start_web_server` serves them all through a `SwaggerUi` mounted at
/// `/swagger-ui`.
#[distributed_slice]
pub static API_DOC_SLICE: [fn() -> (Url<'static>, OpenApi)];
/// Task handles of the server instances started by `start_web_server` (one per listener).
///
/// `None` when no server has been started or after the handles were taken for shutdown.
static WEB_SERVICE_HANDLES: RwLock<Option<Vec<JoinHandle<()>>>> = RwLock::new(None);
/// Sender side of the broadcast channel used to signal shutdown to those server tasks.
///
/// `None` when no server has been started or after it was taken for shutdown.
static STOP_WEB_SERVICE_SENDER: RwLock<Option<broadcast::Sender<()>>> = RwLock::new(None);
/// Records `value` as the task handles of the currently running server, replacing any
/// previously stored handles.
///
/// # Errors
/// `WebServerError::SetWebServiceHandles` when the global lock is poisoned.
fn set_web_service_handles(value: Vec<JoinHandle<()>>) -> Result<(), WebServerError> {
    WEB_SERVICE_HANDLES
        .write()
        .map(|mut slot| *slot = Some(value))
        .map_err(|poisoned| WebServerError::SetWebServiceHandles(poisoned.to_string()))
}
/// Removes and returns the stored server task handles, leaving `None` behind.
///
/// # Errors
/// `WebServerError::TakeWebServiceHandles` when the global lock is poisoned.
fn take_web_service_handles() -> Result<Option<Vec<JoinHandle<()>>>, WebServerError> {
    match WEB_SERVICE_HANDLES.write() {
        Ok(mut slot) => Ok(slot.take()),
        Err(poisoned) => Err(WebServerError::TakeWebServiceHandles(poisoned.to_string())),
    }
}
/// Records `value` as the shutdown sender of the currently running server, replacing any
/// previously stored sender.
///
/// # Errors
/// `WebServerError::SetWebServiceHandles` when the global lock is poisoned.
/// NOTE(review): this reuses the handles' error variant; a dedicated variant may be intended.
fn set_stop_web_service_sender(value: broadcast::Sender<()>) -> Result<(), WebServerError> {
    match STOP_WEB_SERVICE_SENDER.write() {
        Ok(mut slot) => {
            *slot = Some(value);
            Ok(())
        }
        Err(poisoned) => Err(WebServerError::SetWebServiceHandles(poisoned.to_string())),
    }
}
/// Removes and returns the stored shutdown sender, leaving `None` behind.
///
/// # Errors
/// `WebServerError::TakeWebServiceHandles` when the global lock is poisoned.
/// NOTE(review): this reuses the handles' error variant; a dedicated variant may be intended.
fn take_stop_web_service_sender() -> Result<Option<broadcast::Sender<()>>, WebServerError> {
    STOP_WEB_SERVICE_SENDER
        .write()
        .map(|mut slot| slot.take())
        .map_err(|poisoned| WebServerError::TakeWebServiceHandles(poisoned.to_string()))
}
/// Health-check handler: always responds with the plain-text body `Ok`.
#[debug_handler]
#[log_call]
pub async fn health() -> &'static str {
    const HEALTHY_BODY: &str = "Ok";
    HEALTHY_BODY
}
#[log_call]
pub async fn start_web_server(
web_server_config: WebServerConfig,
port_of_args: Option<u16>,
old_pid: Option<u32>,
) -> Result<(), WebServerError> {
let WebServerConfig {
bind: binds,
port: port_option,
listen: listens,
mut reuse_port,
https: https_config,
forbidden_urns,
local_only_urns,
ip_white_list,
ip_black_list,
log_enabled,
cors: cors_config,
health_check,
start_wait_timeout,
start_retry_interval,
terminate_old_app_wait_timeout,
terminate_old_app_retry_interval,
} = web_server_config;
let health_check_uri = &health_check.uri;
let (is_random_port, listen_binds) =
get_listen_binds(port_of_args, binds, port_option, listens)?;
if listen_binds.is_empty() {
Err(WebServerError::ParseListenBinds(
"没有配置监听绑定".to_string(),
))?;
}
let mut old_web_service_handles = take_web_service_handles()?;
let stop_old_web_service_sender = take_stop_web_service_sender()?;
if is_random_port {
reuse_port = false;
} else if !reuse_port {
if let Some(old_pid) = old_pid {
terminate_old_app(
old_pid,
terminate_old_app_wait_timeout,
terminate_old_app_retry_interval,
)
.await?;
} else {
if let Some(web_service_handles) = old_web_service_handles.take() {
stop_old_web_service(stop_old_web_service_sender.clone(), web_service_handles)
.await?;
}
}
}
let mut router = Router::new();
for build_router in ROUTER_SLICE.iter() {
router = router.merge(build_router());
}
if health_check.exposed {
router = router.route(health_check_uri, get(health));
} else {
router = router.route(
health_check_uri,
get(health).layer(axum::middleware::from_fn(local_only_middleware)),
);
}
let mut api_docs = vec![];
for init_api_doc in API_DOC_SLICE.iter() {
api_docs.push(init_api_doc());
}
if !api_docs.is_empty() {
router = router.merge(SwaggerUi::new("/swagger-ui").urls(api_docs));
}
if log_enabled {
router = router.layer(TraceLayer::new_for_http());
}
if !ip_white_list.is_empty() || !ip_black_list.is_empty() {
let ip_ban_state = IpBanState {
ip_white_list: Arc::new(ip_white_list.clone()),
ip_black_list: Arc::new(ip_black_list.clone()),
};
router = router.layer(middleware::from_fn_with_state(
ip_ban_state.clone(),
ip_ban_middleware,
));
}
if !forbidden_urns.is_empty() {
let forbidden_urns_state = ForbiddenUrnsState {
forbidden_urns: Arc::new(forbidden_urns.clone()),
};
router = router.layer(middleware::from_fn_with_state(
forbidden_urns_state.clone(),
forbidden_urns_middleware,
));
}
if !local_only_urns.is_empty() {
let local_only_urns_state = LocalOnlyUrnsState {
local_only_urns: Arc::new(local_only_urns.clone()),
};
router = router.layer(middleware::from_fn_with_state(
local_only_urns_state.clone(),
local_only_urns_middleware,
));
}
if let Some(cors_layer) = build_cors(&cors_config)? {
router = router.layer(cors_layer);
}
let http_protocol = if let Some(https_config) = https_config.clone()
&& https_config.enabled
{
"https"
} else {
"http"
};
let (stop_web_service_sender, stop_web_service_receiver) = broadcast::channel::<()>(1);
let (health_check_url_prefix, web_service_handles) = bind_and_start(
router,
reuse_port,
listen_binds,
http_protocol,
https_config,
stop_web_service_receiver,
)?;
if old_web_service_handles.is_none() {
let heath_check_url = format!("{health_check_url_prefix}{health_check_uri}");
wait_for_web_server_ready(
heath_check_url.as_str(),
start_wait_timeout,
start_retry_interval,
)
.await?;
}
if is_random_port || reuse_port {
if let Some(old_pid) = old_pid {
terminate_old_app(
old_pid,
terminate_old_app_wait_timeout,
terminate_old_app_retry_interval,
)
.await?;
} else {
if let Some(web_service_handles) = old_web_service_handles.take() {
tokio::spawn({
let stop_old_web_service_sender = stop_old_web_service_sender.clone();
async move {
tokio::time::sleep(Duration::from_secs(5)).await;
stop_old_web_service(stop_old_web_service_sender, web_service_handles).await
}
});
}
}
}
set_web_service_handles(web_service_handles)?;
set_stop_web_service_sender(stop_web_service_sender)?;
Ok(())
}
/// Creates a non-blocking, listening TCP socket bound to `bind:port`.
///
/// * `bind` — an IP literal; IPv6 may be written bracketed (`[::]`, `[::1]`) or bare.
///   Host names are not resolved.
/// * `port` — `0` lets the OS pick a port.
/// * `reuse_port` — sets `SO_REUSEPORT`. `SO_REUSEADDR` is always enabled.
///
/// The listen backlog is 1024.
///
/// # Errors
/// `WebServerError::Socket` when `bind` is not an IP literal (the message includes the
/// offending value), or when creating, configuring, binding or listening on the socket fails.
#[log_call]
pub fn create_listener(
    bind: String,
    port: u16,
    reuse_port: bool,
) -> Result<TcpListener, WebServerError> {
    let host = bind
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(bind.as_str());
    let ip_addr: IpAddr = host
        .parse()
        .map_err(|_| WebServerError::Socket(format!("无效的 IP 地址格式: {bind}")))?;
    let addr = SocketAddr::new(ip_addr, port);
    let socket = Socket::new(
        Domain::for_address(addr),
        Type::STREAM,
        Some(socket2::Protocol::TCP),
    )
    .map_err(|e| WebServerError::Socket(format!("创建 socket 失败: {e}")))?;
    socket
        .set_reuse_address(true)
        .map_err(|e| WebServerError::Socket(format!("设置地址复用选项失败: {e}")))?;
    socket
        .set_reuse_port(reuse_port)
        .map_err(|e| WebServerError::Socket(format!("设置端口复用选项失败: {e}")))?;
    // Non-blocking mode is required before handing the socket to tokio's `from_std`.
    socket
        .set_nonblocking(true)
        .map_err(|e| WebServerError::Socket(format!("设置非阻塞模式失败: {e}")))?;
    socket
        .bind(&addr.into())
        .map_err(|e| WebServerError::Socket(format!("绑定{addr}失败: {e}")))?;
    socket
        .listen(1024)
        .map_err(|e| WebServerError::Socket(format!("开始监听{addr}失败: {e}")))?;
    Ok(TcpListener::from(socket))
}
/// Polls `health_check_url` until it responds with a 2xx status or `wait_timeout` elapses.
///
/// Every attempt is preceded by a `retry_interval` sleep; request errors simply count as
/// "not ready yet". For `https://` URLs the client accepts invalid certificates.
///
/// # Errors
/// * `WebServerError::BuildReqwestClient` if the HTTPS client cannot be built.
/// * `WebServerError::StartWebServerTimeout` (carrying the URL) when the deadline passes.
async fn wait_for_web_server_ready(
    health_check_url: &str,
    wait_timeout: Duration,
    retry_interval: Duration,
) -> Result<(), WebServerError> {
    let client = if health_check_url.starts_with("https://") {
        reqwest::Client::builder()
            .danger_accept_invalid_certs(true)
            .build()
    } else {
        Ok(reqwest::Client::new())
    }
    .map_err(|e| WebServerError::BuildReqwestClient(e.to_string()))?;

    let poll_until_healthy = async move {
        loop {
            tokio::time::sleep(retry_interval).await;
            let healthy = match client.get(health_check_url).send().await {
                Ok(response) => response.status().is_success(),
                Err(_) => false,
            };
            if healthy {
                info!("Web服务器通过健康检查,启动完成.");
                return;
            }
        }
    };

    timeout(wait_timeout, poll_until_healthy)
        .await
        .map_err(|_| WebServerError::StartWebServerTimeout(health_check_url.to_string()))
}
/// Gracefully stops the server tasks registered by `start_web_server`.
///
/// The shutdown signal is broadcast first (when a sender is registered). After that, the
/// registered task handles are taken and awaited one by one.
///
/// # Errors
/// `WebServerError::StopService` when the signal cannot be delivered. In that case the
/// handles are left registered. It is also returned when a task failed to join. Lock
/// errors from the globals are propagated as well.
pub async fn stop_web_service() -> Result<(), WebServerError> {
    if let Some(sender) = take_stop_web_service_sender()? {
        if let Err(send_error) = sender.send(()) {
            return Err(WebServerError::StopService(send_error.to_string()));
        }
    }
    let handles = take_web_service_handles()?.unwrap_or_default();
    for handle in handles {
        handle
            .await
            .map_err(|join_error| WebServerError::StopService(join_error.to_string()))?;
    }
    Ok(())
}
/// Stops a previous server instance given its shutdown sender and task handles.
///
/// Sends the shutdown signal through `old_sender` (if any), then awaits every handle in
/// `old_handles` in order.
///
/// # Errors
/// `WebServerError::StopService` when the signal cannot be delivered (the handles are then
/// not awaited) or when a task failed to join.
pub async fn stop_old_web_service(
    old_sender: Option<broadcast::Sender<()>>,
    old_handles: Vec<JoinHandle<()>>,
) -> Result<(), WebServerError> {
    old_sender
        .map(|sender| sender.send(()))
        .transpose()
        .map_err(|send_error| WebServerError::StopService(send_error.to_string()))?;
    for handle in old_handles {
        handle
            .await
            .map_err(|join_error| WebServerError::StopService(join_error.to_string()))?;
    }
    Ok(())
}
async fn terminate_old_app(
old_pid: u32,
wait_timeout: Duration,
retry_interval: Duration,
) -> Result<(), WebServerError> {
debug!("停止运行旧的Web服务器...");
terminate_process(old_pid, wait_timeout, retry_interval).await?;
Ok(())
}
/// Resolves the `(bind address, port)` pairs the server should listen on.
///
/// * `port_of_args` — port from the command line; when present it overrides `port_option`.
/// * `binds` — addresses that all listen on the resolved configured port.
/// * `port_option` — configured port; absent means `0` (OS-assigned random port).
/// * `listens` — extra entries written as `"port"` (binds `0.0.0.0`) or `"host:port"`
///   (split at the last `:`, so `"[::]:8080"` works).
///
/// When both `binds` and `listens` are empty, the single default `0.0.0.0:<port>` is used.
/// The configured port is ignored when only `listens` are given.
///
/// Returns `(is_random_port, listen_binds)`. `is_random_port` is `true` only when every
/// resolved listener uses port `0`.
///
/// # Errors
/// `WebServerError::ParsePort` (carrying the whole entry) when a `listens` entry has an
/// unparsable port.
fn get_listen_binds(
    port_of_args: Option<u16>,
    binds: Vec<String>,
    port_option: Option<u16>,
    listens: Vec<String>,
) -> Result<(bool, Vec<(String, u16)>), WebServerError> {
    let port = port_of_args.or(port_option).unwrap_or(0);
    let mut listen_binds: Vec<(String, u16)> = if !binds.is_empty() {
        binds.into_iter().map(|bind| (bind, port)).collect()
    } else if listens.is_empty() {
        vec![("0.0.0.0".to_string(), port)]
    } else {
        Vec::with_capacity(listens.len())
    };
    for listen in &listens {
        let (bind, port_text) = listen
            .rsplit_once(':')
            .unwrap_or(("0.0.0.0", listen.as_str()));
        let listen_port: u16 = port_text
            .parse()
            .map_err(|_| WebServerError::ParsePort(listen.to_string()))?;
        listen_binds.push((bind.to_string(), listen_port));
    }
    // Derive the flag from the listeners actually produced. The configured port must not
    // count when it was not used (only `listens` given).
    let is_random_port = listen_binds.iter().all(|&(_, p)| p == 0);
    Ok((is_random_port, listen_binds))
}
#[log_call]
fn bind_and_start(
router: Router,
reuse_port: bool,
listen_binds: Vec<(String, u16)>,
http_protocol: &str,
https_config: Option<HttpsConfig>,
stop_web_service_receiver: broadcast::Receiver<()>,
) -> Result<(String, Vec<JoinHandle<()>>), WebServerError> {
let mut web_service_handles = Vec::new();
let mut health_check_url_prefix = None;
for (bind, port) in listen_binds {
let tcp_listener = create_listener(bind.to_string(), port, reuse_port)?;
let actual_addr = tcp_listener.local_addr()?;
let tokio_listener = tokio::net::TcpListener::from_std(tcp_listener)
.map_err(|e| WebServerError::Socket(format!("转换为tokio listener失败: {:#}", e)))?;
let mut stop_web_service_receiver = stop_web_service_receiver.resubscribe();
if let Some(https_config) = https_config.clone()
&& https_config.enabled
{
let handle = build_https(
router.clone(),
tokio_listener,
stop_web_service_receiver,
https_config,
)?;
web_service_handles.push(handle);
} else {
let server = axum::serve(
tokio_listener,
router
.clone()
.into_make_service_with_connect_info::<SocketAddr>(),
)
.with_graceful_shutdown(async move {
let _ = stop_web_service_receiver.recv().await;
info!("停止Axum Web服务");
});
let handle = tokio::spawn(async move {
if let Err(e) = server.await {
error!("Axum Web服务运行异常: {:#}", e);
}
});
web_service_handles.push(handle);
}
let ip = if bind == "0.0.0.0" {
health_check_url_prefix = Some(format!("{http_protocol}://localhost:{port}"));
"127.0.0.1"
} else if bind == r"[::]" {
health_check_url_prefix = Some(format!("{http_protocol}://localhost:{port}"));
r"[::1]"
} else {
if health_check_url_prefix.is_none() {
health_check_url_prefix = Some(format!("{http_protocol}://{bind}:{port}"));
}
&bind
};
info!("监听 <{actual_addr}> 成功✅ -> 🌐 {http_protocol}://{ip}:{port}");
}
Ok((health_check_url_prefix.unwrap(), web_service_handles))
}