1use crate::config::{目标配置, 配置};
2use crate::interfaces::server::WebApi;
3use axum::extract::DefaultBodyLimit;
4use axum::http::Method;
5use axum::http::StatusCode;
6use axum::{
7 extract::State,
8 response::{Html, sse::{Event, KeepAlive, Sse}},
9 routing::{get, post},
10 Json, Router,
11};
12use crate::interfaces::{默认输入};
13use futures_util::stream::Stream;
14use serde::{Deserialize, Serialize};
15use std::sync::Arc;
16use std::time::Duration;
17use tokio::sync::{broadcast, mpsc, RwLock};
18use tower_http::cors::{Any, CorsLayer};
19use tower_http::services::ServeDir;
20use tower_http::timeout::TimeoutLayer;
21use tracing::info;
22
23#[derive(Debug, Serialize, Deserialize)]
25#[serde(tag = "type")]
26pub enum ApiResponse<T> {
27 #[serde(rename = "success")]
28 Success { result: T },
29 #[serde(rename = "error")]
30 Error { error: String },
31}
32
33#[derive(Clone)]
35pub struct AppState {
36 pub api: Arc<RwLock<WebApi>>,
38 pub optimization_status: Arc<RwLock<OptimizationStatus>>,
40 pub status_broadcast: broadcast::Sender<OptimizationStatus>,
42 pub status_mpsc: mpsc::UnboundedSender<OptimizationStatus>,
44}
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
48#[serde(tag = "status", rename_all = "snake_case")]
49pub enum OptimizationStatus {
50 Idle,
52 Running {
54 message: serde_json::Value,
55 },
56 Completed {
58 final_message: Option<serde_json::Value>,
59 },
60 Failed {
62 error: String,
63 },
64}
65
66pub async fn validate_config(Json(config): Json<serde_json::Value>) -> Json<ApiResponse<配置>> {
68 info!("POST /api/validate");
69
70 match serde_json::from_value::<配置>(config) {
72 Ok(config) => {
73 Json(ApiResponse::Success { result: config })
76 }
77 Err(e) => {
78 Json(ApiResponse::Error {
80 error: format!("配置解析错误: {e}"),
81 })
82 }
83 }
84}
85
86pub async fn sync_params(
88 State(state): State<AppState>,
89 Json(params): Json<serde_json::Value>,
90) -> Json<ApiResponse<()>> {
91 info!("POST /api/sync");
92
93 match serde_json::from_value::<默认输入>(params) {
95 Ok(图形界面参数) => {
96 let mut api = state.api.write().await;
98 let result = api.sync(图形界面参数);
99 drop(api); match result {
102 Ok(_) => Json(ApiResponse::Success { result: () }),
103 Err(e) => Json(ApiResponse::Error { error: e.message }),
104 }
105 }
106 Err(e) => Json(ApiResponse::Error {
107 error: format!("参数解析错误: {e}"),
108 }),
109 }
110}
111
112pub async fn encode_evaluate(
114 State(state): State<AppState>,
115 Json(objective): Json<serde_json::Value>,
116) -> Json<ApiResponse<serde_json::Value>> {
117 info!("POST /api/encode");
118
119 match serde_json::from_value::<目标配置>(objective) {
121 Ok(目标函数配置) => {
122 let api = state.api.read().await;
124 let result = api.encode_evaluate(目标函数配置);
125 drop(api); match result {
128 Ok(result) => Json(ApiResponse::Success {
129 result: serde_json::json!([result.0, result.1]),
130 }),
131 Err(e) => Json(ApiResponse::Error { error: e.message }),
132 }
133 }
134 Err(e) => Json(ApiResponse::Error {
135 error: format!("目标函数配置解析错误: {e}"),
136 }),
137 }
138}
139
140pub async fn start_optimize(State(state): State<AppState>) -> Json<ApiResponse<String>> {
142 info!("POST /api/optimize");
143
144 {
146 let status = state.optimization_status.read().await;
147 if matches!(*status, OptimizationStatus::Running { .. }) {
148 return Json(ApiResponse::Error {
149 error: "优化已在进行中".to_string(),
150 });
151 }
152 }
153
154 {
156 let mut status = state.optimization_status.write().await;
157 *status = OptimizationStatus::Running {
158 message: serde_json::json!({"info": "优化已启动"}),
159 };
160 match state.status_broadcast.send(status.clone()) {
161 Ok(count) => info!("[OPTIMIZE] 初始状态广播成功,{} 个接收者", count),
162 Err(_) => info!("[OPTIMIZE] 初始状态广播失败:没有接收者"),
163 }
164 }
165
166 let api = state.api.clone();
167 let status = state.optimization_status.clone();
168 let broadcast = state.status_broadcast.clone();
169
170 tokio::spawn(async move {
172 let api_clone = api.clone();
174 let result = tokio::task::spawn_blocking(move || {
175 let api_guard = api_clone.blocking_read();
177 api_guard.optimize()
178 }).await;
179
180 let final_status = match result {
182 Ok(Ok(_)) => {
183 info!("优化完成");
184 OptimizationStatus::Completed {
185 final_message: None,
186 }
187 }
188 Ok(Err(e)) => {
189 info!("优化失败: {}", e.message);
190 OptimizationStatus::Failed {
191 error: e.message,
192 }
193 }
194 Err(e) => {
195 info!("优化任务崩溃: {:?}", e);
196 OptimizationStatus::Failed {
197 error: format!("任务崩溃: {:?}", e),
198 }
199 }
200 };
201
202 {
203 let mut status_guard = status.write().await;
204 *status_guard = final_status.clone();
205 }
206
207 match broadcast.send(final_status) {
209 Ok(count) => info!("[OPTIMIZE] 最终状态广播成功,{} 个接收者", count),
210 Err(_) => info!("[OPTIMIZE] 最终状态广播失败:没有接收者"),
211 }
212 });
213
214 Json(ApiResponse::Success {
215 result: "优化已启动".to_string(),
216 })
217}
218
219pub async fn sse_handler(
221 State(state): State<AppState>,
222) -> Sse<impl Stream<Item = Result<Event, std::convert::Infallible>>> {
223 let initial_status = {
225 let status = state.optimization_status.read().await;
226 status.clone()
227 };
228
229 let mut broadcast_rx = state.status_broadcast.subscribe();
231
232 let stream = async_stream::stream! {
233 if let Ok(json) = serde_json::to_string(&initial_status) {
235 info!("[SSE] 连接建立,发送初始状态");
236 yield Ok(Event::default().data(json));
237 }
238
239 let mut msg_count = 0;
241 loop {
242 match broadcast_rx.recv().await {
243 Ok(status) => {
244 msg_count += 1;
245 if let Ok(json) = serde_json::to_string(&status) {
246 yield Ok(Event::default().data(json));
247 }
248 }
249 Err(_) => {
250 info!("[SSE] 连接关闭,共发送 {} 条消息", msg_count);
251 break;
252 }
253 }
254 }
255 };
256
257 Sse::new(stream).keep_alive(KeepAlive::default())
258}
259
260pub async fn index() -> Html<&'static str> {
262 Html(
263 r#"
264<!DOCTYPE html>
265<html>
266<head>
267 <title>libchai API 服务器</title>
268 <meta charset="utf-8">
269 <style>
270 body { font-family: Arial, sans-serif; margin: 40px; }
271 code { background: #f4f4f4; padding: 2px 4px; border-radius: 3px; }
272 pre { background: #f9f9f9; padding: 15px; border-radius: 5px; overflow-x: auto; }
273 button { padding: 10px 15px; margin: 5px; cursor: pointer; }
274 .status-panel { background: #f0f8ff; padding: 15px; border-radius: 5px; margin: 10px 0; }
275 .progress { color: #007acc; }
276 .error { color: #e74c3c; }
277 .success { color: #27ae60; }
278 </style>
279</head>
280<body>
281 <h1>libchai API 服务器</h1>
282
283 <h2>HTTP API 端点</h2>
284 <ul>
285 <li><code>POST /api/validate</code> - 验证配置</li>
286 <li><code>POST /api/sync</code> - 同步参数</li>
287 <li><code>POST /api/encode</code> - 编码评估</li>
288 <li><code>POST /api/optimize</code> - 开始优化</li>
289 <li><code>GET /sse/status</code> - SSE 实时状态推送</li>
290 </ul>
291
292 <h2>静态文件服务</h2>
293 <p><code>/*</code> - 提供 client 目录中的静态文件</p>
294
295 <h2>测试工具</h2>
296 <button onclick="testValidate()">测试验证</button>
297 <button onclick="testSync()">测试同步</button>
298 <button onclick="testEncode()">测试编码</button>
299 <button onclick="testOptimize()">开始优化</button>
300 <button onclick="reconnectWebSocket()">重新连接 SSE</button>
301
302 <div class="status-panel">
303 <h3>优化状态:</h3>
304 <div id="status">未知</div>
305 </div>
306
307 <h3>输出:</h3>
308 <div id="output"></div>
309
310 <script>
311 let eventSource = null;
312
313 function log(message) {
314 const now = new Date().toLocaleTimeString();
315 document.getElementById('output').innerHTML += `<p>[${now}] ${message}</p>`;
316 }
317
318 async function apiCall(endpoint, data) {
319 try {
320 const timeoutMs = 600000; // 10分钟,与服务器端一致
321
322 const controller = new AbortController();
323 const timeoutId = setTimeout(() => controller.abort(), timeoutMs);
324
325 const response = await fetch(`/api/${endpoint}`, {
326 method: data !== undefined ? 'POST' : 'GET',
327 headers: { 'Content-Type': 'application/json' },
328 body: data !== undefined ? JSON.stringify(data) : undefined,
329 signal: controller.signal
330 });
331
332 clearTimeout(timeoutId);
333 const result = await response.json();
334
335 if (result.type === 'success') {
336 log(`✅ ${endpoint}: ${JSON.stringify(result.result)}`);
337 return result.result;
338 } else {
339 log(`❌ ${endpoint} 错误: ${result.error}`);
340 throw new Error(result.error);
341 }
342 } catch (error) {
343 if (error.name === 'AbortError') {
344 log(`⏰ ${endpoint} 请求超时`);
345 throw new Error('请求超时,请稍后重试');
346 } else {
347 log(`❌ ${endpoint} 网络错误: ${error.message}`);
348 throw error;
349 }
350 }
351 }
352
353 function connectSSE() {
354 // 关闭现有连接
355 if (eventSource) {
356 eventSource.close();
357 eventSource = null;
358 }
359
360 const sseUrl = '/sse/status';
361 log(`🔌 连接 SSE: ${sseUrl}`);
362
363 eventSource = new EventSource(sseUrl);
364
365 eventSource.onopen = () => {
366 log('✅ SSE 已连接');
367 };
368
369 eventSource.onmessage = (event) => {
370 try {
371 log(`📨 收到 SSE 消息: ${event.data.substring(0, 100)}...`);
372 const status = JSON.parse(event.data);
373 updateStatusDisplay(status);
374 } catch (error) {
375 console.error('解析 SSE 消息失败:', error);
376 log(`❌ 消息解析失败: ${error.message}`);
377 }
378 };
379
380 eventSource.onerror = (error) => {
381 log('❌ SSE 错误,将自动重连...');
382 console.error('SSE error:', error);
383 // EventSource 会自动重连,不需要手动处理
384 };
385 }
386
387 function reconnectWebSocket() {
388 log('🔄 手动重新连接 SSE...');
389 connectSSE();
390 }
391
392 function updateStatusDisplay(status) {
393 const statusDiv = document.getElementById('status');
394
395 switch (status.status) {
396 case 'idle':
397 statusDiv.innerHTML = '⏸️ 空闲状态';
398 break;
399
400 case 'running':
401 statusDiv.innerHTML = '<span class="progress">🔄 优化进行中...</span>';
402 if (status.message) {
403 const msg = status.message;
404 let details = '';
405
406 if (msg.type === 'progress') {
407 details = `<br>步数: ${msg.steps}, 温度: ${msg.temperature.toFixed(4)}, 指标: ${msg.metric}`;
408 } else if (msg.type === 'better_solution') {
409 details = `<br>✨ 发现更优解!指标: ${msg.metric}`;
410 } else if (msg.type === 'parameters') {
411 details = `<br>参数: T_max=${msg.t_max.toFixed(2)}, T_min=${msg.t_min.toFixed(6)}`;
412 } else {
413 details = `<br>${JSON.stringify(msg)}`;
414 }
415
416 statusDiv.innerHTML += details;
417 }
418 break;
419
420 case 'completed':
421 statusDiv.innerHTML = '<span class="success">✅ 优化完成</span>';
422 if (status.final_message) {
423 statusDiv.innerHTML += `<br>结果: ${JSON.stringify(status.final_message)}`;
424 }
425 break;
426
427 case 'failed':
428 statusDiv.innerHTML = `<span class="error">❌ 优化失败</span><br>错误: ${status.error}`;
429 break;
430
431 default:
432 statusDiv.innerHTML = `未知状态: ${JSON.stringify(status)}`;
433 }
434 }
435
436 async function testValidate() {
437 await apiCall('validate', {"version": "1.0"});
438 }
439
440 async function testSync() {
441 await apiCall('sync', {
442 配置: { version: "1.0" },
443 词列表: [],
444 原始键位分布信息: {},
445 原始当量信息: {}
446 });
447 }
448
449 async function testEncode() {
450 log('🔄 开始编码评估(可能需要几分钟时间)...');
451 try {
452 await apiCall('encode', {});
453 } catch (error) {
454 // 错误已经在 apiCall 中处理了
455 }
456 }
457
458 async function testOptimize() {
459 await apiCall('optimize', null);
460 }
461
462 // 页面加载时连接 SSE
463 window.onload = () => {
464 connectSSE();
465 };
466
467 // 页面卸载时关闭 SSE
468 window.onbeforeunload = () => {
469 if (eventSource) {
470 eventSource.close();
471 }
472 };
473 </script>
474</body>
475</html>
476 "#,
477 )
478}
479
480pub fn create_app() -> Router {
482 let (tx, _rx) = broadcast::channel(100);
484
485 let (mpsc_tx, mut mpsc_rx) = mpsc::unbounded_channel::<OptimizationStatus>();
487
488 let state = AppState {
489 api: Arc::new(RwLock::new(WebApi::new())),
490 optimization_status: Arc::new(RwLock::new(OptimizationStatus::Idle)),
491 status_broadcast: tx.clone(),
492 status_mpsc: mpsc_tx.clone(),
493 };
494
495 let broadcast_clone = tx.clone();
497 tokio::spawn(async move {
498 while let Some(status) = mpsc_rx.recv().await {
499 let _ = broadcast_clone.send(status);
500 }
501 });
502
503 {
505 let mut api = state.api.try_write().expect("初始化时获取 API 写锁失败");
508 let status = state.optimization_status.clone();
509 let mpsc_sender = mpsc_tx.clone();
510
511 api.set_callback(move |消息| {
512 let progress_msg = serde_json::json!(消息);
514
515 let new_status = OptimizationStatus::Running {
517 message: progress_msg,
518 };
519
520 match 消息 {
522 crate::interfaces::消息::Progress { steps, .. } => {
523 if steps % 100 == 0 {
524 info!("[CALLBACK] 优化进度: {} 步", steps);
525 }
526 }
527 crate::interfaces::消息::BetterSolution { .. } => {
528 info!("[CALLBACK] 发现更优解");
529 }
530 crate::interfaces::消息::Parameters { .. } => {
531 info!("[CALLBACK] 设置优化参数");
532 }
533 _ => {}
534 }
535
536 {
538 let mut status_guard = status.blocking_write();
539 *status_guard = new_status.clone();
540 }
541
542 let _ = mpsc_sender.send(new_status);
544 });
545 }
546 let cors = CorsLayer::new()
548 .allow_origin(Any)
549 .allow_methods([Method::GET, Method::POST, Method::OPTIONS])
550 .allow_headers(Any)
551 .allow_private_network(true)
552 .allow_credentials(false);
553
554 Router::new()
555 .route("/test", get(index))
556 .route("/api/validate", post(validate_config))
557 .route("/api/sync", post(sync_params))
558 .route("/api/encode", post(encode_evaluate))
559 .route("/api/optimize", post(start_optimize))
560 .route("/sse/status", get(sse_handler))
561 .fallback_service(ServeDir::new("client"))
562 .layer(DefaultBodyLimit::max(100 * 1024 * 1024)) .layer(cors)
564 .layer(TimeoutLayer::with_status_code(StatusCode::REQUEST_TIMEOUT, Duration::from_secs(600))) .with_state(state)
566}
567
568async fn bind_available_port(
570 preferred_port: u16,
571) -> Result<(tokio::net::TcpListener, u16), Box<dyn std::error::Error>> {
572 let addr = format!("0.0.0.0:{}", preferred_port);
574 match tokio::net::TcpListener::bind(&addr).await {
575 Ok(listener) => {
576 info!("成功绑定到首选端口: {}", preferred_port);
577 return Ok((listener, preferred_port));
578 }
579 Err(e) => {
580 info!("端口 {} 已被占用: {}", preferred_port, e);
581 }
582 }
583
584 for offset in 1..=50 {
586 let port = preferred_port + offset;
587 if port < preferred_port {
588 break;
589 }
590
591 let addr = format!("0.0.0.0:{}", port);
592 match tokio::net::TcpListener::bind(&addr).await {
593 Ok(listener) => {
594 info!("成功绑定到替代端口: {}", port);
595 return Ok((listener, port));
596 }
597 Err(_) => {
598 }
600 }
601 }
602
603 for offset in 1..=50 {
605 if preferred_port < offset {
606 break;
607 }
608
609 let port = preferred_port - offset;
610 if port < 1024 {
611 break;
613 }
614
615 let addr = format!("0.0.0.0:{}", port);
616 match tokio::net::TcpListener::bind(&addr).await {
617 Ok(listener) => {
618 info!("成功绑定到替代端口: {}", port);
619 return Ok((listener, port));
620 }
621 Err(_) => {
622 }
624 }
625 }
626
627 match tokio::net::TcpListener::bind("0.0.0.0:0").await {
629 Ok(listener) => {
630 let actual_port = listener.local_addr()?.port();
631 info!("使用系统自动分配的端口: {}", actual_port);
632 Ok((listener, actual_port))
633 }
634 Err(e) => Err(format!("无法绑定到任何端口: {}", e).into()),
635 }
636}
637
638pub async fn start_server(port: u16) -> Result<(), Box<dyn std::error::Error>> {
640 let app = create_app();
641
642 let (listener, actual_port) = bind_available_port(port).await?;
644
645 info!("Listening on: http://127.0.0.1:{}", actual_port);
646 info!("API Endpoints:");
647 info!(" POST /api/validate - 验证配置");
648 info!(" POST /api/sync - 同步参数");
649 info!(" POST /api/encode - 编码评估");
650 info!(" POST /api/optimize - 开始优化");
651 info!(" GET /sse/status - SSE 实时状态推送");
652
653 axum::serve(listener, app).await?;
654
655 Ok(())
656}