1use crate::config::{目标配置, 配置};
2use crate::interfaces::server::WebApi;
3use axum::extract::DefaultBodyLimit;
4use axum::http::Method;
5use axum::http::StatusCode;
6use axum::{
7 extract::State,
8 response::{Html, sse::{Event, KeepAlive, Sse}},
9 routing::{get, post},
10 Json, Router,
11};
12use crate::interfaces::{默认输入};
13use futures_util::stream::Stream;
14use serde::{Deserialize, Serialize};
15use std::sync::{Arc, Mutex};
16use std::time::Duration;
17use tokio::sync::{broadcast, mpsc};
18use tower_http::cors::{Any, CorsLayer};
19use tower_http::services::ServeDir;
20use tower_http::timeout::TimeoutLayer;
21use tracing::info;
22
23#[derive(Debug, Serialize, Deserialize)]
25#[serde(tag = "type")]
26pub enum ApiResponse<T> {
27 #[serde(rename = "success")]
28 Success { result: T },
29 #[serde(rename = "error")]
30 Error { error: String },
31}
32
/// Shared state handed to the axum handlers via `State<AppState>`.
///
/// Every field is a cheaply clonable handle, so cloning the state shares the
/// same underlying data.
#[derive(Clone)]
pub struct AppState {
    /// The `WebApi` instance, guarded by a mutex.
    pub api: Arc<Mutex<WebApi>>,
    /// Most recently recorded optimization status; sent to new SSE clients
    /// as their first event.
    pub optimization_status: Arc<Mutex<OptimizationStatus>>,
    /// Broadcast sender that `sse_handler` subscribes to for status updates.
    pub status_broadcast: broadcast::Sender<OptimizationStatus>,
    /// Sender side of the mpsc channel whose messages are forwarded to
    /// `status_broadcast` by a task spawned in `create_app`.
    pub status_mpsc: mpsc::UnboundedSender<OptimizationStatus>,
}
45
/// State of the optimization run as reported to clients.
///
/// Serialized as an internally tagged object whose `status` field holds the
/// snake_case variant name (`"idle"`, `"running"`, `"completed"`, `"failed"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "status", rename_all = "snake_case")]
pub enum OptimizationStatus {
    /// No optimization has been started; this is the initial state set in `create_app`.
    Idle,
    /// An optimization is in progress.
    Running {
        /// Arbitrary JSON describing the latest progress update.
        message: serde_json::Value,
    },
    /// The optimization finished successfully.
    Completed {
        /// Optional JSON payload attached to the completion event.
        final_message: Option<serde_json::Value>,
    },
    /// The optimization ended with an error or its task crashed.
    Failed {
        /// Human-readable failure description.
        error: String,
    },
}
65
66pub async fn validate_config(Json(config): Json<serde_json::Value>) -> Json<ApiResponse<配置>> {
68 info!("POST /api/validate");
69
70 match serde_json::from_value::<配置>(config) {
72 Ok(config) => {
73 Json(ApiResponse::Success { result: config })
76 }
77 Err(e) => {
78 Json(ApiResponse::Error {
80 error: format!("配置解析错误: {e}"),
81 })
82 }
83 }
84}
85
86pub async fn sync_params(
88 State(state): State<AppState>,
89 Json(params): Json<serde_json::Value>,
90) -> Json<ApiResponse<()>> {
91 info!("POST /api/sync");
92
93 match serde_json::from_value::<默认输入>(params) {
95 Ok(图形界面参数) => {
96 let result = {
97 let mut api = state.api.lock().unwrap();
98 api.sync(图形界面参数)
99 }; match result {
102 Ok(_) => Json(ApiResponse::Success { result: () }),
103 Err(e) => Json(ApiResponse::Error { error: e.message }),
104 }
105 }
106 Err(e) => Json(ApiResponse::Error {
107 error: format!("参数解析错误: {e}"),
108 }),
109 }
110}
111
112pub async fn encode_evaluate(
114 State(state): State<AppState>,
115 Json(objective): Json<serde_json::Value>,
116) -> Json<ApiResponse<serde_json::Value>> {
117 info!("POST /api/encode");
118
119 match serde_json::from_value::<目标配置>(objective) {
121 Ok(目标函数配置) => {
122 let result = {
123 let api = state.api.lock().unwrap();
124 api.encode_evaluate(目标函数配置)
125 }; match result {
128 Ok(result) => Json(ApiResponse::Success {
129 result: serde_json::json!([result.0, result.1]),
130 }),
131 Err(e) => Json(ApiResponse::Error { error: e.message }),
132 }
133 }
134 Err(e) => Json(ApiResponse::Error {
135 error: format!("目标函数配置解析错误: {e}"),
136 }),
137 }
138}
139
140pub async fn start_optimize(State(state): State<AppState>) -> Json<ApiResponse<String>> {
142 info!("POST /api/optimize");
143
144 {
146 let status = state.optimization_status.lock().unwrap();
147 if matches!(*status, OptimizationStatus::Running { .. }) {
148 return Json(ApiResponse::Error {
149 error: "优化已在进行中".to_string(),
150 });
151 }
152 }
153
154 {
156 let mut status = state.optimization_status.lock().unwrap();
157 *status = OptimizationStatus::Running {
158 message: serde_json::json!({"info": "优化已启动"}),
159 };
160 match state.status_broadcast.send(status.clone()) {
161 Ok(count) => info!("[OPTIMIZE] 初始状态广播成功,{} 个接收者", count),
162 Err(_) => info!("[OPTIMIZE] 初始状态广播失败:没有接收者"),
163 }
164 }
165
166 let api = state.api.clone();
167 let status = state.optimization_status.clone();
168 let broadcast = state.status_broadcast.clone();
169
170 tokio::spawn(async move {
172 let result = tokio::task::spawn_blocking(move || {
174 let api_guard = api.lock().unwrap();
175 api_guard.optimize()
176 }).await;
177
178 let final_status = match result {
180 Ok(Ok(_)) => {
181 info!("优化完成");
182 OptimizationStatus::Completed {
183 final_message: None,
184 }
185 }
186 Ok(Err(e)) => {
187 info!("优化失败: {}", e.message);
188 OptimizationStatus::Failed {
189 error: e.message,
190 }
191 }
192 Err(e) => {
193 info!("优化任务崩溃: {:?}", e);
194 OptimizationStatus::Failed {
195 error: format!("任务崩溃: {:?}", e),
196 }
197 }
198 };
199
200 {
201 let mut status_guard = status.lock().unwrap();
202 *status_guard = final_status.clone();
203 }
204
205 match broadcast.send(final_status) {
207 Ok(count) => info!("[OPTIMIZE] 最终状态广播成功,{} 个接收者", count),
208 Err(_) => info!("[OPTIMIZE] 最终状态广播失败:没有接收者"),
209 }
210 });
211
212 Json(ApiResponse::Success {
213 result: "优化已启动".to_string(),
214 })
215}
216
217pub async fn sse_handler(
219 State(state): State<AppState>,
220) -> Sse<impl Stream<Item = Result<Event, std::convert::Infallible>>> {
221 let initial_status = {
223 let status = state.optimization_status.lock().unwrap();
224 status.clone()
225 };
226
227 let mut broadcast_rx = state.status_broadcast.subscribe();
229
230 let stream = async_stream::stream! {
231 if let Ok(json) = serde_json::to_string(&initial_status) {
233 info!("[SSE] 连接建立,发送初始状态");
234 yield Ok(Event::default().data(json));
235 }
236
237 let mut msg_count = 0;
239 loop {
240 match broadcast_rx.recv().await {
241 Ok(status) => {
242 msg_count += 1;
243 if let Ok(json) = serde_json::to_string(&status) {
244 yield Ok(Event::default().data(json));
245 }
246 }
247 Err(_) => {
248 info!("[SSE] 连接关闭,共发送 {} 条消息", msg_count);
249 break;
250 }
251 }
252 }
253 };
254
255 Sse::new(stream).keep_alive(KeepAlive::default())
256}
257
/// Serves the built-in HTML test console (routed at `GET /test` in `create_app`).
///
/// The page is a single static string. It lists the HTTP endpoints, provides
/// buttons that send sample requests to `/api/*` via `fetch`, and opens an
/// `EventSource` on `/sse/status` to display optimization status updates.
pub async fn index() -> Html<&'static str> {
    Html(
        r#"
<!DOCTYPE html>
<html>
<head>
 <title>libchai API 服务器</title>
 <meta charset="utf-8">
 <style>
 body { font-family: Arial, sans-serif; margin: 40px; }
 code { background: #f4f4f4; padding: 2px 4px; border-radius: 3px; }
 pre { background: #f9f9f9; padding: 15px; border-radius: 5px; overflow-x: auto; }
 button { padding: 10px 15px; margin: 5px; cursor: pointer; }
 .status-panel { background: #f0f8ff; padding: 15px; border-radius: 5px; margin: 10px 0; }
 .progress { color: #007acc; }
 .error { color: #e74c3c; }
 .success { color: #27ae60; }
 </style>
</head>
<body>
 <h1>libchai API 服务器</h1>

 <h2>HTTP API 端点</h2>
 <ul>
 <li><code>POST /api/validate</code> - 验证配置</li>
 <li><code>POST /api/sync</code> - 同步参数</li>
 <li><code>POST /api/encode</code> - 编码评估</li>
 <li><code>POST /api/optimize</code> - 开始优化</li>
 <li><code>GET /sse/status</code> - SSE 实时状态推送</li>
 </ul>

 <h2>静态文件服务</h2>
 <p><code>/*</code> - 提供 client 目录中的静态文件</p>

 <h2>测试工具</h2>
 <button onclick="testValidate()">测试验证</button>
 <button onclick="testSync()">测试同步</button>
 <button onclick="testEncode()">测试编码</button>
 <button onclick="testOptimize()">开始优化</button>
 <button onclick="reconnectWebSocket()">重新连接 SSE</button>

 <div class="status-panel">
 <h3>优化状态:</h3>
 <div id="status">未知</div>
 </div>

 <h3>输出:</h3>
 <div id="output"></div>

 <script>
 let eventSource = null;

 function log(message) {
 const now = new Date().toLocaleTimeString();
 document.getElementById('output').innerHTML += `<p>[${now}] ${message}</p>`;
 }

 async function apiCall(endpoint, data) {
 try {
 const timeoutMs = 600000; // 10分钟,与服务器端一致

 const controller = new AbortController();
 const timeoutId = setTimeout(() => controller.abort(), timeoutMs);

 const response = await fetch(`/api/${endpoint}`, {
 method: data !== undefined ? 'POST' : 'GET',
 headers: { 'Content-Type': 'application/json' },
 body: data !== undefined ? JSON.stringify(data) : undefined,
 signal: controller.signal
 });

 clearTimeout(timeoutId);
 const result = await response.json();

 if (result.type === 'success') {
 log(`✅ ${endpoint}: ${JSON.stringify(result.result)}`);
 return result.result;
 } else {
 log(`❌ ${endpoint} 错误: ${result.error}`);
 throw new Error(result.error);
 }
 } catch (error) {
 if (error.name === 'AbortError') {
 log(`⏰ ${endpoint} 请求超时`);
 throw new Error('请求超时,请稍后重试');
 } else {
 log(`❌ ${endpoint} 网络错误: ${error.message}`);
 throw error;
 }
 }
 }

 function connectSSE() {
 // 关闭现有连接
 if (eventSource) {
 eventSource.close();
 eventSource = null;
 }

 const sseUrl = '/sse/status';
 log(`🔌 连接 SSE: ${sseUrl}`);

 eventSource = new EventSource(sseUrl);

 eventSource.onopen = () => {
 log('✅ SSE 已连接');
 };

 eventSource.onmessage = (event) => {
 try {
 log(`📨 收到 SSE 消息: ${event.data.substring(0, 100)}...`);
 const status = JSON.parse(event.data);
 updateStatusDisplay(status);
 } catch (error) {
 console.error('解析 SSE 消息失败:', error);
 log(`❌ 消息解析失败: ${error.message}`);
 }
 };

 eventSource.onerror = (error) => {
 log('❌ SSE 错误,将自动重连...');
 console.error('SSE error:', error);
 // EventSource 会自动重连,不需要手动处理
 };
 }

 function reconnectWebSocket() {
 log('🔄 手动重新连接 SSE...');
 connectSSE();
 }

 function updateStatusDisplay(status) {
 const statusDiv = document.getElementById('status');

 switch (status.status) {
 case 'idle':
 statusDiv.innerHTML = '⏸️ 空闲状态';
 break;

 case 'running':
 statusDiv.innerHTML = '<span class="progress">🔄 优化进行中...</span>';
 if (status.message) {
 const msg = status.message;
 let details = '';

 if (msg.type === 'progress') {
 details = `<br>步数: ${msg.steps}, 温度: ${msg.temperature.toFixed(4)}, 指标: ${msg.metric}`;
 } else if (msg.type === 'better_solution') {
 details = `<br>✨ 发现更优解!指标: ${msg.metric}`;
 } else if (msg.type === 'parameters') {
 details = `<br>参数: T_max=${msg.t_max.toFixed(2)}, T_min=${msg.t_min.toFixed(6)}`;
 } else {
 details = `<br>${JSON.stringify(msg)}`;
 }

 statusDiv.innerHTML += details;
 }
 break;

 case 'completed':
 statusDiv.innerHTML = '<span class="success">✅ 优化完成</span>';
 if (status.final_message) {
 statusDiv.innerHTML += `<br>结果: ${JSON.stringify(status.final_message)}`;
 }
 break;

 case 'failed':
 statusDiv.innerHTML = `<span class="error">❌ 优化失败</span><br>错误: ${status.error}`;
 break;

 default:
 statusDiv.innerHTML = `未知状态: ${JSON.stringify(status)}`;
 }
 }

 async function testValidate() {
 await apiCall('validate', {"version": "1.0"});
 }

 async function testSync() {
 await apiCall('sync', {
 配置: { version: "1.0" },
 词列表: [],
 原始键位分布信息: {},
 原始当量信息: {}
 });
 }

 async function testEncode() {
 log('🔄 开始编码评估(可能需要几分钟时间)...');
 try {
 await apiCall('encode', {});
 } catch (error) {
 // 错误已经在 apiCall 中处理了
 }
 }

 async function testOptimize() {
 await apiCall('optimize', null);
 }

 // 页面加载时连接 SSE
 window.onload = () => {
 connectSSE();
 };

 // 页面卸载时关闭 SSE
 window.onbeforeunload = () => {
 if (eventSource) {
 eventSource.close();
 }
 };
 </script>
</body>
</html>
 "#,
    )
}
477
478pub fn create_app() -> Router {
480 let (tx, _rx) = broadcast::channel(100);
482
483 let (mpsc_tx, mut mpsc_rx) = mpsc::unbounded_channel::<OptimizationStatus>();
485
486 let state = AppState {
487 api: Arc::new(Mutex::new(WebApi::new())),
488 optimization_status: Arc::new(Mutex::new(OptimizationStatus::Idle)),
489 status_broadcast: tx.clone(),
490 status_mpsc: mpsc_tx.clone(),
491 };
492
493 let broadcast_clone = tx.clone();
495 tokio::spawn(async move {
496 while let Some(status) = mpsc_rx.recv().await {
497 let _ = broadcast_clone.send(status);
498 }
499 });
500
501 {
503 let mut api = state.api.lock().unwrap();
504 let status = state.optimization_status.clone();
505 let mpsc_sender = mpsc_tx.clone();
506
507 api.set_callback(move |消息| {
508 let progress_msg = serde_json::json!(消息);
510
511 let new_status = OptimizationStatus::Running {
513 message: progress_msg,
514 };
515
516 match 消息 {
518 crate::interfaces::消息::Progress { steps, .. } => {
519 if steps % 100 == 0 {
520 info!("[CALLBACK] 优化进度: {} 步", steps);
521 }
522 }
523 crate::interfaces::消息::BetterSolution { .. } => {
524 info!("[CALLBACK] 发现更优解");
525 }
526 crate::interfaces::消息::Parameters { .. } => {
527 info!("[CALLBACK] 设置优化参数");
528 }
529 _ => {}
530 }
531
532 {
534 let mut status_guard = status.lock().unwrap();
535 *status_guard = new_status.clone();
536 }
537
538 let _ = mpsc_sender.send(new_status);
540 });
541 }
542 let cors = CorsLayer::new()
544 .allow_origin(Any)
545 .allow_methods([Method::GET, Method::POST, Method::OPTIONS])
546 .allow_headers(Any)
547 .allow_private_network(true)
548 .allow_credentials(false);
549
550 Router::new()
551 .route("/test", get(index))
552 .route("/api/validate", post(validate_config))
553 .route("/api/sync", post(sync_params))
554 .route("/api/encode", post(encode_evaluate))
555 .route("/api/optimize", post(start_optimize))
556 .route("/sse/status", get(sse_handler))
557 .fallback_service(ServeDir::new("client"))
558 .layer(DefaultBodyLimit::max(100 * 1024 * 1024)) .layer(cors)
560 .layer(TimeoutLayer::with_status_code(StatusCode::REQUEST_TIMEOUT, Duration::from_secs(600))) .with_state(state)
562}
563
564async fn bind_available_port(
566 preferred_port: u16,
567) -> Result<(tokio::net::TcpListener, u16), Box<dyn std::error::Error>> {
568 let addr = format!("0.0.0.0:{}", preferred_port);
570 match tokio::net::TcpListener::bind(&addr).await {
571 Ok(listener) => {
572 info!("成功绑定到首选端口: {}", preferred_port);
573 return Ok((listener, preferred_port));
574 }
575 Err(e) => {
576 info!("端口 {} 已被占用: {}", preferred_port, e);
577 }
578 }
579
580 for offset in 1..=50 {
582 let port = preferred_port + offset;
583 if port < preferred_port {
584 break;
585 }
586
587 let addr = format!("0.0.0.0:{}", port);
588 match tokio::net::TcpListener::bind(&addr).await {
589 Ok(listener) => {
590 info!("成功绑定到替代端口: {}", port);
591 return Ok((listener, port));
592 }
593 Err(_) => {
594 }
596 }
597 }
598
599 for offset in 1..=50 {
601 if preferred_port < offset {
602 break;
603 }
604
605 let port = preferred_port - offset;
606 if port < 1024 {
607 break;
609 }
610
611 let addr = format!("0.0.0.0:{}", port);
612 match tokio::net::TcpListener::bind(&addr).await {
613 Ok(listener) => {
614 info!("成功绑定到替代端口: {}", port);
615 return Ok((listener, port));
616 }
617 Err(_) => {
618 }
620 }
621 }
622
623 match tokio::net::TcpListener::bind("0.0.0.0:0").await {
625 Ok(listener) => {
626 let actual_port = listener.local_addr()?.port();
627 info!("使用系统自动分配的端口: {}", actual_port);
628 Ok((listener, actual_port))
629 }
630 Err(e) => Err(format!("无法绑定到任何端口: {}", e).into()),
631 }
632}
633
634pub async fn start_server(port: u16) -> Result<(), Box<dyn std::error::Error>> {
636 tracing_subscriber::fmt::init();
637
638 let app = create_app();
639
640 let (listener, actual_port) = bind_available_port(port).await?;
642
643 info!("Listening on: http://127.0.0.1:{}", actual_port);
644 info!("API Endpoints:");
645 info!(" POST /api/validate - 验证配置");
646 info!(" POST /api/sync - 同步参数");
647 info!(" POST /api/encode - 编码评估");
648 info!(" POST /api/optimize - 开始优化");
649 info!(" GET /sse/status - SSE 实时状态推送");
650
651 axum::serve(listener, app).await?;
652
653 Ok(())
654}