Skip to main content

chai/
server.rs

1use crate::config::{目标配置, 配置};
2use crate::interfaces::server::WebApi;
3use axum::extract::DefaultBodyLimit;
4use axum::http::Method;
5use axum::http::StatusCode;
6use axum::{
7    extract::State,
8    response::{Html, sse::{Event, KeepAlive, Sse}},
9    routing::{get, post},
10    Json, Router,
11};
12use crate::interfaces::{默认输入};
13use futures_util::stream::Stream;
14use serde::{Deserialize, Serialize};
15use std::sync::{Arc, Mutex};
16use std::time::Duration;
17use tokio::sync::{broadcast, mpsc};
18use tower_http::cors::{Any, CorsLayer};
19use tower_http::services::ServeDir;
20use tower_http::timeout::TimeoutLayer;
21use tracing::info;
22
23/// HTTP API 响应类型
24#[derive(Debug, Serialize, Deserialize)]
25#[serde(tag = "type")]
26pub enum ApiResponse<T> {
27    #[serde(rename = "success")]
28    Success { result: T },
29    #[serde(rename = "error")]
30    Error { error: String },
31}
32
33/// 应用状态
34#[derive(Clone)]
35pub struct AppState {
36    /// 全局 WebApi 实例
37    pub api: Arc<Mutex<WebApi>>,
38    /// 优化状态
39    pub optimization_status: Arc<Mutex<OptimizationStatus>>,
40    /// WebSocket 广播发送器
41    pub status_broadcast: broadcast::Sender<OptimizationStatus>,
42    /// MPSC 发送器(用于从同步回调发送)
43    pub status_mpsc: mpsc::UnboundedSender<OptimizationStatus>,
44}
45
46/// 优化状态
47#[derive(Debug, Clone, Serialize, Deserialize)]
48#[serde(tag = "status", rename_all = "snake_case")]
49pub enum OptimizationStatus {
50    /// 空闲状态
51    Idle,
52    /// 运行中
53    Running {
54        message: serde_json::Value,
55    },
56    /// 已完成
57    Completed {
58        final_message: Option<serde_json::Value>,
59    },
60    /// 失败
61    Failed {
62        error: String,
63    },
64}
65
66/// HTTP API: 验证配置
67pub async fn validate_config(Json(config): Json<serde_json::Value>) -> Json<ApiResponse<配置>> {
68    info!("POST /api/validate");
69
70    // 直接在服务器中验证配置
71    match serde_json::from_value::<配置>(config) {
72        Ok(config) => {
73            // 配置解析成功,可以在这里添加额外的验证逻辑
74            // 例如:检查必填字段、验证数值范围等
75            Json(ApiResponse::Success { result: config })
76        }
77        Err(e) => {
78            // 配置解析失败
79            Json(ApiResponse::Error {
80                error: format!("配置解析错误: {e}"),
81            })
82        }
83    }
84}
85
86/// HTTP API: 同步参数
87pub async fn sync_params(
88    State(state): State<AppState>,
89    Json(params): Json<serde_json::Value>,
90) -> Json<ApiResponse<()>> {
91    info!("POST /api/sync");
92
93    // 直接转换为图形界面参数
94    match serde_json::from_value::<默认输入>(params) {
95        Ok(图形界面参数) => {
96            let result = {
97                let mut api = state.api.lock().unwrap();
98                api.sync(图形界面参数)
99            }; // 锁在这里被释放
100
101            match result {
102                Ok(_) => Json(ApiResponse::Success { result: () }),
103                Err(e) => Json(ApiResponse::Error { error: e.message }),
104            }
105        }
106        Err(e) => Json(ApiResponse::Error {
107            error: format!("参数解析错误: {e}"),
108        }),
109    }
110}
111
112/// HTTP API: 编码评估
113pub async fn encode_evaluate(
114    State(state): State<AppState>,
115    Json(objective): Json<serde_json::Value>,
116) -> Json<ApiResponse<serde_json::Value>> {
117    info!("POST /api/encode");
118
119    // 直接转换为目标函数配置
120    match serde_json::from_value::<目标配置>(objective) {
121        Ok(目标函数配置) => {
122            let result = {
123                let api = state.api.lock().unwrap();
124                api.encode_evaluate(目标函数配置)
125            }; // 锁在这里被释放
126
127            match result {
128                Ok(result) => Json(ApiResponse::Success {
129                    result: serde_json::json!([result.0, result.1]),
130                }),
131                Err(e) => Json(ApiResponse::Error { error: e.message }),
132            }
133        }
134        Err(e) => Json(ApiResponse::Error {
135            error: format!("目标函数配置解析错误: {e}"),
136        }),
137    }
138}
139
140/// HTTP API: 开始优化(异步)
141pub async fn start_optimize(State(state): State<AppState>) -> Json<ApiResponse<String>> {
142    info!("POST /api/optimize");
143
144    // 检查是否已经在运行
145    {
146        let status = state.optimization_status.lock().unwrap();
147        if matches!(*status, OptimizationStatus::Running { .. }) {
148            return Json(ApiResponse::Error {
149                error: "优化已在进行中".to_string(),
150            });
151        }
152    }
153
154    // 设置运行状态
155    {
156        let mut status = state.optimization_status.lock().unwrap();
157        *status = OptimizationStatus::Running {
158            message: serde_json::json!({"info": "优化已启动"}),
159        };
160        match state.status_broadcast.send(status.clone()) {
161            Ok(count) => info!("[OPTIMIZE] 初始状态广播成功,{} 个接收者", count),
162            Err(_) => info!("[OPTIMIZE] 初始状态广播失败:没有接收者"),
163        }
164    }
165
166    let api = state.api.clone();
167    let status = state.optimization_status.clone();
168    let broadcast = state.status_broadcast.clone();
169
170    // 在后台启动优化任务
171    tokio::spawn(async move {
172        // 使用 spawn_blocking 运行同步阻塞的优化任务
173        let result = tokio::task::spawn_blocking(move || {
174            let api_guard = api.lock().unwrap();
175            api_guard.optimize()
176        }).await;
177
178        // 处理结果
179        let final_status = match result {
180            Ok(Ok(_)) => {
181                info!("优化完成");
182                OptimizationStatus::Completed {
183                    final_message: None,
184                }
185            }
186            Ok(Err(e)) => {
187                info!("优化失败: {}", e.message);
188                OptimizationStatus::Failed {
189                    error: e.message,
190                }
191            }
192            Err(e) => {
193                info!("优化任务崩溃: {:?}", e);
194                OptimizationStatus::Failed {
195                    error: format!("任务崩溃: {:?}", e),
196                }
197            }
198        };
199
200        {
201            let mut status_guard = status.lock().unwrap();
202            *status_guard = final_status.clone();
203        }
204
205        // 广播最终状态
206        match broadcast.send(final_status) {
207            Ok(count) => info!("[OPTIMIZE] 最终状态广播成功,{} 个接收者", count),
208            Err(_) => info!("[OPTIMIZE] 最终状态广播失败:没有接收者"),
209        }
210    });
211
212    Json(ApiResponse::Success {
213        result: "优化已启动".to_string(),
214    })
215}
216
217/// SSE 处理函数
218pub async fn sse_handler(
219    State(state): State<AppState>,
220) -> Sse<impl Stream<Item = Result<Event, std::convert::Infallible>>> {
221    // 获取当前状态
222    let initial_status = {
223        let status = state.optimization_status.lock().unwrap();
224        status.clone()
225    };
226    
227    // 订阅广播通道
228    let mut broadcast_rx = state.status_broadcast.subscribe();
229    
230    let stream = async_stream::stream! {
231        // 发送初始状态
232        if let Ok(json) = serde_json::to_string(&initial_status) {
233            info!("[SSE] 连接建立,发送初始状态");
234            yield Ok(Event::default().data(json));
235        }
236        
237        // 持续接收广播消息
238        let mut msg_count = 0;
239        loop {
240            match broadcast_rx.recv().await {
241                Ok(status) => {
242                    msg_count += 1;
243                    if let Ok(json) = serde_json::to_string(&status) {
244                        yield Ok(Event::default().data(json));
245                    }
246                }
247                Err(_) => {
248                    info!("[SSE] 连接关闭,共发送 {} 条消息", msg_count);
249                    break;
250                }
251            }
252        }
253    };
254    
255    Sse::new(stream).keep_alive(KeepAlive::default())
256}
257
258/// 主页面
259pub async fn index() -> Html<&'static str> {
260    Html(
261        r#"
262<!DOCTYPE html>
263<html>
264<head>
265    <title>libchai API 服务器</title>
266    <meta charset="utf-8">
267    <style>
268        body { font-family: Arial, sans-serif; margin: 40px; }
269        code { background: #f4f4f4; padding: 2px 4px; border-radius: 3px; }
270        pre { background: #f9f9f9; padding: 15px; border-radius: 5px; overflow-x: auto; }
271        button { padding: 10px 15px; margin: 5px; cursor: pointer; }
272        .status-panel { background: #f0f8ff; padding: 15px; border-radius: 5px; margin: 10px 0; }
273        .progress { color: #007acc; }
274        .error { color: #e74c3c; }
275        .success { color: #27ae60; }
276    </style>
277</head>
278<body>
279    <h1>libchai API 服务器</h1>
280    
281    <h2>HTTP API 端点</h2>
282    <ul>
283        <li><code>POST /api/validate</code> - 验证配置</li>
284        <li><code>POST /api/sync</code> - 同步参数</li>
285        <li><code>POST /api/encode</code> - 编码评估</li>
286        <li><code>POST /api/optimize</code> - 开始优化</li>
287        <li><code>GET /sse/status</code> - SSE 实时状态推送</li>
288    </ul>
289    
290    <h2>静态文件服务</h2>
291    <p><code>/*</code> - 提供 client 目录中的静态文件</p>
292    
293    <h2>测试工具</h2>
294    <button onclick="testValidate()">测试验证</button>
295    <button onclick="testSync()">测试同步</button>
296    <button onclick="testEncode()">测试编码</button>
297    <button onclick="testOptimize()">开始优化</button>
298    <button onclick="reconnectWebSocket()">重新连接 SSE</button>
299    
300    <div class="status-panel">
301        <h3>优化状态:</h3>
302        <div id="status">未知</div>
303    </div>
304    
305    <h3>输出:</h3>
306    <div id="output"></div>
307
308    <script>
309        let eventSource = null;
310        
311        function log(message) {
312            const now = new Date().toLocaleTimeString();
313            document.getElementById('output').innerHTML += `<p>[${now}] ${message}</p>`;
314        }
315        
316        async function apiCall(endpoint, data) {
317            try {
318                const timeoutMs = 600000; // 10分钟,与服务器端一致
319                
320                const controller = new AbortController();
321                const timeoutId = setTimeout(() => controller.abort(), timeoutMs);
322                
323                const response = await fetch(`/api/${endpoint}`, {
324                    method: data !== undefined ? 'POST' : 'GET',
325                    headers: { 'Content-Type': 'application/json' },
326                    body: data !== undefined ? JSON.stringify(data) : undefined,
327                    signal: controller.signal
328                });
329                
330                clearTimeout(timeoutId);
331                const result = await response.json();
332                
333                if (result.type === 'success') {
334                    log(`✅ ${endpoint}: ${JSON.stringify(result.result)}`);
335                    return result.result;
336                } else {
337                    log(`❌ ${endpoint} 错误: ${result.error}`);
338                    throw new Error(result.error);
339                }
340            } catch (error) {
341                if (error.name === 'AbortError') {
342                    log(`⏰ ${endpoint} 请求超时`);
343                    throw new Error('请求超时,请稍后重试');
344                } else {
345                    log(`❌ ${endpoint} 网络错误: ${error.message}`);
346                    throw error;
347                }
348            }
349        }
350        
351        function connectSSE() {
352            // 关闭现有连接
353            if (eventSource) {
354                eventSource.close();
355                eventSource = null;
356            }
357            
358            const sseUrl = '/sse/status';
359            log(`🔌 连接 SSE: ${sseUrl}`);
360            
361            eventSource = new EventSource(sseUrl);
362            
363            eventSource.onopen = () => {
364                log('✅ SSE 已连接');
365            };
366            
367            eventSource.onmessage = (event) => {
368                try {
369                    log(`📨 收到 SSE 消息: ${event.data.substring(0, 100)}...`);
370                    const status = JSON.parse(event.data);
371                    updateStatusDisplay(status);
372                } catch (error) {
373                    console.error('解析 SSE 消息失败:', error);
374                    log(`❌ 消息解析失败: ${error.message}`);
375                }
376            };
377            
378            eventSource.onerror = (error) => {
379                log('❌ SSE 错误,将自动重连...');
380                console.error('SSE error:', error);
381                // EventSource 会自动重连,不需要手动处理
382            };
383        }
384        
385        function reconnectWebSocket() {
386            log('🔄 手动重新连接 SSE...');
387            connectSSE();
388        }
389        
390        function updateStatusDisplay(status) {
391            const statusDiv = document.getElementById('status');
392            
393            switch (status.status) {
394                case 'idle':
395                    statusDiv.innerHTML = '⏸️ 空闲状态';
396                    break;
397                    
398                case 'running':
399                    statusDiv.innerHTML = '<span class="progress">🔄 优化进行中...</span>';
400                    if (status.message) {
401                        const msg = status.message;
402                        let details = '';
403                        
404                        if (msg.type === 'progress') {
405                            details = `<br>步数: ${msg.steps}, 温度: ${msg.temperature.toFixed(4)}, 指标: ${msg.metric}`;
406                        } else if (msg.type === 'better_solution') {
407                            details = `<br>✨ 发现更优解!指标: ${msg.metric}`;
408                        } else if (msg.type === 'parameters') {
409                            details = `<br>参数: T_max=${msg.t_max.toFixed(2)}, T_min=${msg.t_min.toFixed(6)}`;
410                        } else {
411                            details = `<br>${JSON.stringify(msg)}`;
412                        }
413                        
414                        statusDiv.innerHTML += details;
415                    }
416                    break;
417                    
418                case 'completed':
419                    statusDiv.innerHTML = '<span class="success">✅ 优化完成</span>';
420                    if (status.final_message) {
421                        statusDiv.innerHTML += `<br>结果: ${JSON.stringify(status.final_message)}`;
422                    }
423                    break;
424                    
425                case 'failed':
426                    statusDiv.innerHTML = `<span class="error">❌ 优化失败</span><br>错误: ${status.error}`;
427                    break;
428                    
429                default:
430                    statusDiv.innerHTML = `未知状态: ${JSON.stringify(status)}`;
431            }
432        }
433        
434        async function testValidate() {
435            await apiCall('validate', {"version": "1.0"});
436        }
437        
438        async function testSync() {
439            await apiCall('sync', {
440                配置: { version: "1.0" },
441                词列表: [],
442                原始键位分布信息: {},
443                原始当量信息: {}
444            });
445        }
446        
447        async function testEncode() {
448            log('🔄 开始编码评估(可能需要几分钟时间)...');
449            try {
450                await apiCall('encode', {});
451            } catch (error) {
452                // 错误已经在 apiCall 中处理了
453            }
454        }
455        
456        async function testOptimize() {
457            await apiCall('optimize', null);
458        }
459        
460        // 页面加载时连接 SSE
461        window.onload = () => {
462            connectSSE();
463        };
464        
465        // 页面卸载时关闭 SSE
466        window.onbeforeunload = () => {
467            if (eventSource) {
468                eventSource.close();
469            }
470        };
471    </script>
472</body>
473</html>
474    "#,
475    )
476}
477
478/// 创建应用路由
479pub fn create_app() -> Router {
480    // 创建广播通道(容量设置为 100)
481    let (tx, _rx) = broadcast::channel(100);
482    
483    // 创建 MPSC 通道用于从同步回调发送
484    let (mpsc_tx, mut mpsc_rx) = mpsc::unbounded_channel::<OptimizationStatus>();
485    
486    let state = AppState {
487        api: Arc::new(Mutex::new(WebApi::new())),
488        optimization_status: Arc::new(Mutex::new(OptimizationStatus::Idle)),
489        status_broadcast: tx.clone(),
490        status_mpsc: mpsc_tx.clone(),
491    };
492    
493    // 启动转发任务:从 MPSC 转发到 broadcast
494    let broadcast_clone = tx.clone();
495    tokio::spawn(async move {
496        while let Some(status) = mpsc_rx.recv().await {
497            let _ = broadcast_clone.send(status);
498        }
499    });
500
501    // 设置全局回调函数
502    {
503        let mut api = state.api.lock().unwrap();
504        let status = state.optimization_status.clone();
505        let mpsc_sender = mpsc_tx.clone();
506        
507        api.set_callback(move |消息| {
508            // 将消息转换为 JSON
509            let progress_msg = serde_json::json!(消息);
510            
511            // 更新状态
512            let new_status = OptimizationStatus::Running {
513                message: progress_msg,
514            };
515            
516            // 只在重要进度时记录
517            match 消息 {
518                crate::interfaces::消息::Progress { steps, .. } => {
519                    if steps % 100 == 0 {
520                        info!("[CALLBACK] 优化进度: {} 步", steps);
521                    }
522                }
523                crate::interfaces::消息::BetterSolution { .. } => {
524                    info!("[CALLBACK] 发现更优解");
525                }
526                crate::interfaces::消息::Parameters { .. } => {
527                    info!("[CALLBACK] 设置优化参数");
528                }
529                _ => {}
530            }
531            
532            // 更新共享状态
533            {
534                let mut status_guard = status.lock().unwrap();
535                *status_guard = new_status.clone();
536            }
537            
538            // 通过 MPSC 发送(可以从任何线程调用)
539            let _ = mpsc_sender.send(new_status);
540        });
541    }
542    // 配置更详细的 CORS 设置
543    let cors = CorsLayer::new()
544        .allow_origin(Any)
545        .allow_methods([Method::GET, Method::POST, Method::OPTIONS])
546        .allow_headers(Any)
547        .allow_private_network(true)
548        .allow_credentials(false);
549
550    Router::new()
551        .route("/test", get(index))
552        .route("/api/validate", post(validate_config))
553        .route("/api/sync", post(sync_params))
554        .route("/api/encode", post(encode_evaluate))
555        .route("/api/optimize", post(start_optimize))
556        .route("/sse/status", get(sse_handler))
557        .fallback_service(ServeDir::new("client"))
558        .layer(DefaultBodyLimit::max(100 * 1024 * 1024)) // 100MB 请求体限制
559        .layer(cors)
560        .layer(TimeoutLayer::with_status_code(StatusCode::REQUEST_TIMEOUT, Duration::from_secs(600))) // 10分钟超时,与编码任务一致
561        .with_state(state)
562}
563
564/// 尝试绑定可用端口
565async fn bind_available_port(
566    preferred_port: u16,
567) -> Result<(tokio::net::TcpListener, u16), Box<dyn std::error::Error>> {
568    // 首先尝试首选端口
569    let addr = format!("0.0.0.0:{}", preferred_port);
570    match tokio::net::TcpListener::bind(&addr).await {
571        Ok(listener) => {
572            info!("成功绑定到首选端口: {}", preferred_port);
573            return Ok((listener, preferred_port));
574        }
575        Err(e) => {
576            info!("端口 {} 已被占用: {}", preferred_port, e);
577        }
578    }
579
580    // 如果首选端口被占用,尝试附近的端口
581    for offset in 1..=50 {
582        let port = preferred_port + offset;
583        if port < preferred_port {
584            break;
585        }
586
587        let addr = format!("0.0.0.0:{}", port);
588        match tokio::net::TcpListener::bind(&addr).await {
589            Ok(listener) => {
590                info!("成功绑定到替代端口: {}", port);
591                return Ok((listener, port));
592            }
593            Err(_) => {
594                // 继续尝试下一个端口
595            }
596        }
597    }
598
599    // 如果向上寻找失败,尝试向下寻找
600    for offset in 1..=50 {
601        if preferred_port < offset {
602            break;
603        }
604
605        let port = preferred_port - offset;
606        if port < 1024 {
607            // 避免使用系统保留端口
608            break;
609        }
610
611        let addr = format!("0.0.0.0:{}", port);
612        match tokio::net::TcpListener::bind(&addr).await {
613            Ok(listener) => {
614                info!("成功绑定到替代端口: {}", port);
615                return Ok((listener, port));
616            }
617            Err(_) => {
618                // 继续尝试下一个端口
619            }
620        }
621    }
622
623    // 最后尝试让系统自动分配端口
624    match tokio::net::TcpListener::bind("0.0.0.0:0").await {
625        Ok(listener) => {
626            let actual_port = listener.local_addr()?.port();
627            info!("使用系统自动分配的端口: {}", actual_port);
628            Ok((listener, actual_port))
629        }
630        Err(e) => Err(format!("无法绑定到任何端口: {}", e).into()),
631    }
632}
633
634/// 启动服务器
635pub async fn start_server(port: u16) -> Result<(), Box<dyn std::error::Error>> {
636    tracing_subscriber::fmt::init();
637
638    let app = create_app();
639
640    // 尝试绑定端口,如果失败则尝试其他端口
641    let (listener, actual_port) = bind_available_port(port).await?;
642
643    info!("Listening on: http://127.0.0.1:{}", actual_port);
644    info!("API Endpoints:");
645    info!("   POST /api/validate    - 验证配置");
646    info!("   POST /api/sync        - 同步参数");
647    info!("   POST /api/encode      - 编码评估");
648    info!("   POST /api/optimize    - 开始优化");
649    info!("   GET  /sse/status      - SSE 实时状态推送");
650
651    axum::serve(listener, app).await?;
652
653    Ok(())
654}