Skip to main content

chai/
server.rs

1use crate::config::{目标配置, 配置};
2use crate::interfaces::server::WebApi;
3use axum::extract::DefaultBodyLimit;
4use axum::http::Method;
5use axum::http::StatusCode;
6use axum::{
7    extract::State,
8    response::{Html, sse::{Event, KeepAlive, Sse}},
9    routing::{get, post},
10    Json, Router,
11};
12use crate::interfaces::{默认输入};
13use futures_util::stream::Stream;
14use serde::{Deserialize, Serialize};
15use std::sync::Arc;
16use std::time::Duration;
17use tokio::sync::{broadcast, mpsc, RwLock};
18use tower_http::cors::{Any, CorsLayer};
19use tower_http::services::ServeDir;
20use tower_http::timeout::TimeoutLayer;
21use tracing::info;
22
23/// HTTP API 响应类型
24#[derive(Debug, Serialize, Deserialize)]
25#[serde(tag = "type")]
26pub enum ApiResponse<T> {
27    #[serde(rename = "success")]
28    Success { result: T },
29    #[serde(rename = "error")]
30    Error { error: String },
31}
32
33/// 应用状态
34#[derive(Clone)]
35pub struct AppState {
36    /// 全局 WebApi 实例(使用 RwLock 支持异步并发)
37    pub api: Arc<RwLock<WebApi>>,
38    /// 优化状态(使用 RwLock 支持异步并发)
39    pub optimization_status: Arc<RwLock<OptimizationStatus>>,
40    /// WebSocket 广播发送器
41    pub status_broadcast: broadcast::Sender<OptimizationStatus>,
42    /// MPSC 发送器(用于从同步回调发送)
43    pub status_mpsc: mpsc::UnboundedSender<OptimizationStatus>,
44}
45
46/// 优化状态
47#[derive(Debug, Clone, Serialize, Deserialize)]
48#[serde(tag = "status", rename_all = "snake_case")]
49pub enum OptimizationStatus {
50    /// 空闲状态
51    Idle,
52    /// 运行中
53    Running {
54        message: serde_json::Value,
55    },
56    /// 已完成
57    Completed {
58        final_message: Option<serde_json::Value>,
59    },
60    /// 失败
61    Failed {
62        error: String,
63    },
64}
65
66/// HTTP API: 验证配置
67pub async fn validate_config(Json(config): Json<serde_json::Value>) -> Json<ApiResponse<配置>> {
68    info!("POST /api/validate");
69
70    // 直接在服务器中验证配置
71    match serde_json::from_value::<配置>(config) {
72        Ok(config) => {
73            // 配置解析成功,可以在这里添加额外的验证逻辑
74            // 例如:检查必填字段、验证数值范围等
75            Json(ApiResponse::Success { result: config })
76        }
77        Err(e) => {
78            // 配置解析失败
79            Json(ApiResponse::Error {
80                error: format!("配置解析错误: {e}"),
81            })
82        }
83    }
84}
85
86/// HTTP API: 同步参数
87pub async fn sync_params(
88    State(state): State<AppState>,
89    Json(params): Json<serde_json::Value>,
90) -> Json<ApiResponse<()>> {
91    info!("POST /api/sync");
92
93    // 直接转换为图形界面参数
94    match serde_json::from_value::<默认输入>(params) {
95        Ok(图形界面参数) => {
96            // 使用写锁,异步等待
97            let mut api = state.api.write().await;
98            let result = api.sync(图形界面参数);
99            drop(api); // 显式释放锁
100
101            match result {
102                Ok(_) => Json(ApiResponse::Success { result: () }),
103                Err(e) => Json(ApiResponse::Error { error: e.message }),
104            }
105        }
106        Err(e) => Json(ApiResponse::Error {
107            error: format!("参数解析错误: {e}"),
108        }),
109    }
110}
111
112/// HTTP API: 编码评估
113pub async fn encode_evaluate(
114    State(state): State<AppState>,
115    Json(objective): Json<serde_json::Value>,
116) -> Json<ApiResponse<serde_json::Value>> {
117    info!("POST /api/encode");
118
119    // 直接转换为目标函数配置
120    match serde_json::from_value::<目标配置>(objective) {
121        Ok(目标函数配置) => {
122            // 使用读锁,允许多个并发读取
123            let api = state.api.read().await;
124            let result = api.encode_evaluate(目标函数配置);
125            drop(api); // 显式释放锁
126
127            match result {
128                Ok(result) => Json(ApiResponse::Success {
129                    result: serde_json::json!([result.0, result.1]),
130                }),
131                Err(e) => Json(ApiResponse::Error { error: e.message }),
132            }
133        }
134        Err(e) => Json(ApiResponse::Error {
135            error: format!("目标函数配置解析错误: {e}"),
136        }),
137    }
138}
139
140/// HTTP API: 开始优化(异步)
141pub async fn start_optimize(State(state): State<AppState>) -> Json<ApiResponse<String>> {
142    info!("POST /api/optimize");
143
144    // 检查是否已经在运行
145    {
146        let status = state.optimization_status.read().await;
147        if matches!(*status, OptimizationStatus::Running { .. }) {
148            return Json(ApiResponse::Error {
149                error: "优化已在进行中".to_string(),
150            });
151        }
152    }
153
154    // 设置运行状态
155    {
156        let mut status = state.optimization_status.write().await;
157        *status = OptimizationStatus::Running {
158            message: serde_json::json!({"info": "优化已启动"}),
159        };
160        match state.status_broadcast.send(status.clone()) {
161            Ok(count) => info!("[OPTIMIZE] 初始状态广播成功,{} 个接收者", count),
162            Err(_) => info!("[OPTIMIZE] 初始状态广播失败:没有接收者"),
163        }
164    }
165
166    let api = state.api.clone();
167    let status = state.optimization_status.clone();
168    let broadcast = state.status_broadcast.clone();
169
170    // 在后台启动优化任务
171    tokio::spawn(async move {
172        // 使用 spawn_blocking 运行同步阻塞的优化任务
173        let api_clone = api.clone();
174        let result = tokio::task::spawn_blocking(move || {
175            // 在阻塞任务中使用 blocking 方式获取锁
176            let api_guard = api_clone.blocking_read();
177            api_guard.optimize()
178        }).await;
179
180        // 处理结果
181        let final_status = match result {
182            Ok(Ok(_)) => {
183                info!("优化完成");
184                OptimizationStatus::Completed {
185                    final_message: None,
186                }
187            }
188            Ok(Err(e)) => {
189                info!("优化失败: {}", e.message);
190                OptimizationStatus::Failed {
191                    error: e.message,
192                }
193            }
194            Err(e) => {
195                info!("优化任务崩溃: {:?}", e);
196                OptimizationStatus::Failed {
197                    error: format!("任务崩溃: {:?}", e),
198                }
199            }
200        };
201
202        {
203            let mut status_guard = status.write().await;
204            *status_guard = final_status.clone();
205        }
206
207        // 广播最终状态
208        match broadcast.send(final_status) {
209            Ok(count) => info!("[OPTIMIZE] 最终状态广播成功,{} 个接收者", count),
210            Err(_) => info!("[OPTIMIZE] 最终状态广播失败:没有接收者"),
211        }
212    });
213
214    Json(ApiResponse::Success {
215        result: "优化已启动".to_string(),
216    })
217}
218
219/// SSE 处理函数
220pub async fn sse_handler(
221    State(state): State<AppState>,
222) -> Sse<impl Stream<Item = Result<Event, std::convert::Infallible>>> {
223    // 获取当前状态
224    let initial_status = {
225        let status = state.optimization_status.read().await;
226        status.clone()
227    };
228    
229    // 订阅广播通道
230    let mut broadcast_rx = state.status_broadcast.subscribe();
231    
232    let stream = async_stream::stream! {
233        // 发送初始状态
234        if let Ok(json) = serde_json::to_string(&initial_status) {
235            info!("[SSE] 连接建立,发送初始状态");
236            yield Ok(Event::default().data(json));
237        }
238        
239        // 持续接收广播消息
240        let mut msg_count = 0;
241        loop {
242            match broadcast_rx.recv().await {
243                Ok(status) => {
244                    msg_count += 1;
245                    if let Ok(json) = serde_json::to_string(&status) {
246                        yield Ok(Event::default().data(json));
247                    }
248                }
249                Err(_) => {
250                    info!("[SSE] 连接关闭,共发送 {} 条消息", msg_count);
251                    break;
252                }
253            }
254        }
255    };
256    
257    Sse::new(stream).keep_alive(KeepAlive::default())
258}
259
260/// 主页面
261pub async fn index() -> Html<&'static str> {
262    Html(
263        r#"
264<!DOCTYPE html>
265<html>
266<head>
267    <title>libchai API 服务器</title>
268    <meta charset="utf-8">
269    <style>
270        body { font-family: Arial, sans-serif; margin: 40px; }
271        code { background: #f4f4f4; padding: 2px 4px; border-radius: 3px; }
272        pre { background: #f9f9f9; padding: 15px; border-radius: 5px; overflow-x: auto; }
273        button { padding: 10px 15px; margin: 5px; cursor: pointer; }
274        .status-panel { background: #f0f8ff; padding: 15px; border-radius: 5px; margin: 10px 0; }
275        .progress { color: #007acc; }
276        .error { color: #e74c3c; }
277        .success { color: #27ae60; }
278    </style>
279</head>
280<body>
281    <h1>libchai API 服务器</h1>
282    
283    <h2>HTTP API 端点</h2>
284    <ul>
285        <li><code>POST /api/validate</code> - 验证配置</li>
286        <li><code>POST /api/sync</code> - 同步参数</li>
287        <li><code>POST /api/encode</code> - 编码评估</li>
288        <li><code>POST /api/optimize</code> - 开始优化</li>
289        <li><code>GET /sse/status</code> - SSE 实时状态推送</li>
290    </ul>
291    
292    <h2>静态文件服务</h2>
293    <p><code>/*</code> - 提供 client 目录中的静态文件</p>
294    
295    <h2>测试工具</h2>
296    <button onclick="testValidate()">测试验证</button>
297    <button onclick="testSync()">测试同步</button>
298    <button onclick="testEncode()">测试编码</button>
299    <button onclick="testOptimize()">开始优化</button>
300    <button onclick="reconnectWebSocket()">重新连接 SSE</button>
301    
302    <div class="status-panel">
303        <h3>优化状态:</h3>
304        <div id="status">未知</div>
305    </div>
306    
307    <h3>输出:</h3>
308    <div id="output"></div>
309
310    <script>
311        let eventSource = null;
312        
313        function log(message) {
314            const now = new Date().toLocaleTimeString();
315            document.getElementById('output').innerHTML += `<p>[${now}] ${message}</p>`;
316        }
317        
318        async function apiCall(endpoint, data) {
319            try {
320                const timeoutMs = 600000; // 10分钟,与服务器端一致
321                
322                const controller = new AbortController();
323                const timeoutId = setTimeout(() => controller.abort(), timeoutMs);
324                
325                const response = await fetch(`/api/${endpoint}`, {
326                    method: data !== undefined ? 'POST' : 'GET',
327                    headers: { 'Content-Type': 'application/json' },
328                    body: data !== undefined ? JSON.stringify(data) : undefined,
329                    signal: controller.signal
330                });
331                
332                clearTimeout(timeoutId);
333                const result = await response.json();
334                
335                if (result.type === 'success') {
336                    log(`✅ ${endpoint}: ${JSON.stringify(result.result)}`);
337                    return result.result;
338                } else {
339                    log(`❌ ${endpoint} 错误: ${result.error}`);
340                    throw new Error(result.error);
341                }
342            } catch (error) {
343                if (error.name === 'AbortError') {
344                    log(`⏰ ${endpoint} 请求超时`);
345                    throw new Error('请求超时,请稍后重试');
346                } else {
347                    log(`❌ ${endpoint} 网络错误: ${error.message}`);
348                    throw error;
349                }
350            }
351        }
352        
353        function connectSSE() {
354            // 关闭现有连接
355            if (eventSource) {
356                eventSource.close();
357                eventSource = null;
358            }
359            
360            const sseUrl = '/sse/status';
361            log(`🔌 连接 SSE: ${sseUrl}`);
362            
363            eventSource = new EventSource(sseUrl);
364            
365            eventSource.onopen = () => {
366                log('✅ SSE 已连接');
367            };
368            
369            eventSource.onmessage = (event) => {
370                try {
371                    log(`📨 收到 SSE 消息: ${event.data.substring(0, 100)}...`);
372                    const status = JSON.parse(event.data);
373                    updateStatusDisplay(status);
374                } catch (error) {
375                    console.error('解析 SSE 消息失败:', error);
376                    log(`❌ 消息解析失败: ${error.message}`);
377                }
378            };
379            
380            eventSource.onerror = (error) => {
381                log('❌ SSE 错误,将自动重连...');
382                console.error('SSE error:', error);
383                // EventSource 会自动重连,不需要手动处理
384            };
385        }
386        
387        function reconnectWebSocket() {
388            log('🔄 手动重新连接 SSE...');
389            connectSSE();
390        }
391        
392        function updateStatusDisplay(status) {
393            const statusDiv = document.getElementById('status');
394            
395            switch (status.status) {
396                case 'idle':
397                    statusDiv.innerHTML = '⏸️ 空闲状态';
398                    break;
399                    
400                case 'running':
401                    statusDiv.innerHTML = '<span class="progress">🔄 优化进行中...</span>';
402                    if (status.message) {
403                        const msg = status.message;
404                        let details = '';
405                        
406                        if (msg.type === 'progress') {
407                            details = `<br>步数: ${msg.steps}, 温度: ${msg.temperature.toFixed(4)}, 指标: ${msg.metric}`;
408                        } else if (msg.type === 'better_solution') {
409                            details = `<br>✨ 发现更优解!指标: ${msg.metric}`;
410                        } else if (msg.type === 'parameters') {
411                            details = `<br>参数: T_max=${msg.t_max.toFixed(2)}, T_min=${msg.t_min.toFixed(6)}`;
412                        } else {
413                            details = `<br>${JSON.stringify(msg)}`;
414                        }
415                        
416                        statusDiv.innerHTML += details;
417                    }
418                    break;
419                    
420                case 'completed':
421                    statusDiv.innerHTML = '<span class="success">✅ 优化完成</span>';
422                    if (status.final_message) {
423                        statusDiv.innerHTML += `<br>结果: ${JSON.stringify(status.final_message)}`;
424                    }
425                    break;
426                    
427                case 'failed':
428                    statusDiv.innerHTML = `<span class="error">❌ 优化失败</span><br>错误: ${status.error}`;
429                    break;
430                    
431                default:
432                    statusDiv.innerHTML = `未知状态: ${JSON.stringify(status)}`;
433            }
434        }
435        
436        async function testValidate() {
437            await apiCall('validate', {"version": "1.0"});
438        }
439        
440        async function testSync() {
441            await apiCall('sync', {
442                配置: { version: "1.0" },
443                词列表: [],
444                原始键位分布信息: {},
445                原始当量信息: {}
446            });
447        }
448        
449        async function testEncode() {
450            log('🔄 开始编码评估(可能需要几分钟时间)...');
451            try {
452                await apiCall('encode', {});
453            } catch (error) {
454                // 错误已经在 apiCall 中处理了
455            }
456        }
457        
458        async function testOptimize() {
459            await apiCall('optimize', null);
460        }
461        
462        // 页面加载时连接 SSE
463        window.onload = () => {
464            connectSSE();
465        };
466        
467        // 页面卸载时关闭 SSE
468        window.onbeforeunload = () => {
469            if (eventSource) {
470                eventSource.close();
471            }
472        };
473    </script>
474</body>
475</html>
476    "#,
477    )
478}
479
480/// 创建应用路由
481pub fn create_app() -> Router {
482    // 创建广播通道(容量设置为 100)
483    let (tx, _rx) = broadcast::channel(100);
484    
485    // 创建 MPSC 通道用于从同步回调发送
486    let (mpsc_tx, mut mpsc_rx) = mpsc::unbounded_channel::<OptimizationStatus>();
487    
488    let state = AppState {
489        api: Arc::new(RwLock::new(WebApi::new())),
490        optimization_status: Arc::new(RwLock::new(OptimizationStatus::Idle)),
491        status_broadcast: tx.clone(),
492        status_mpsc: mpsc_tx.clone(),
493    };
494    
495    // 启动转发任务:从 MPSC 转发到 broadcast
496    let broadcast_clone = tx.clone();
497    tokio::spawn(async move {
498        while let Some(status) = mpsc_rx.recv().await {
499            let _ = broadcast_clone.send(status);
500        }
501    });
502
503    // 设置全局回调函数
504    {
505        // 使用 try_write 因为在异步上下文中不能使用 blocking_write
506        // 这里是初始化阶段,不会有其他线程竞争,所以 unwrap 是安全的
507        let mut api = state.api.try_write().expect("初始化时获取 API 写锁失败");
508        let status = state.optimization_status.clone();
509        let mpsc_sender = mpsc_tx.clone();
510        
511        api.set_callback(move |消息| {
512            // 将消息转换为 JSON
513            let progress_msg = serde_json::json!(消息);
514            
515            // 更新状态
516            let new_status = OptimizationStatus::Running {
517                message: progress_msg,
518            };
519            
520            // 只在重要进度时记录
521            match 消息 {
522                crate::interfaces::消息::Progress { steps, .. } => {
523                    if steps % 100 == 0 {
524                        info!("[CALLBACK] 优化进度: {} 步", steps);
525                    }
526                }
527                crate::interfaces::消息::BetterSolution { .. } => {
528                    info!("[CALLBACK] 发现更优解");
529                }
530                crate::interfaces::消息::Parameters { .. } => {
531                    info!("[CALLBACK] 设置优化参数");
532                }
533                _ => {}
534            }
535            
536            // 更新共享状态(使用 blocking_write 因为回调在同步上下文中)
537            {
538                let mut status_guard = status.blocking_write();
539                *status_guard = new_status.clone();
540            }
541            
542            // 通过 MPSC 发送(可以从任何线程调用)
543            let _ = mpsc_sender.send(new_status);
544        });
545    }
546    // 配置更详细的 CORS 设置
547    let cors = CorsLayer::new()
548        .allow_origin(Any)
549        .allow_methods([Method::GET, Method::POST, Method::OPTIONS])
550        .allow_headers(Any)
551        .allow_private_network(true)
552        .allow_credentials(false);
553
554    Router::new()
555        .route("/test", get(index))
556        .route("/api/validate", post(validate_config))
557        .route("/api/sync", post(sync_params))
558        .route("/api/encode", post(encode_evaluate))
559        .route("/api/optimize", post(start_optimize))
560        .route("/sse/status", get(sse_handler))
561        .fallback_service(ServeDir::new("client"))
562        .layer(DefaultBodyLimit::max(100 * 1024 * 1024)) // 100MB 请求体限制
563        .layer(cors)
564        .layer(TimeoutLayer::with_status_code(StatusCode::REQUEST_TIMEOUT, Duration::from_secs(600))) // 10分钟超时,与编码任务一致
565        .with_state(state)
566}
567
568/// 尝试绑定可用端口
569async fn bind_available_port(
570    preferred_port: u16,
571) -> Result<(tokio::net::TcpListener, u16), Box<dyn std::error::Error>> {
572    // 首先尝试首选端口
573    let addr = format!("0.0.0.0:{}", preferred_port);
574    match tokio::net::TcpListener::bind(&addr).await {
575        Ok(listener) => {
576            info!("成功绑定到首选端口: {}", preferred_port);
577            return Ok((listener, preferred_port));
578        }
579        Err(e) => {
580            info!("端口 {} 已被占用: {}", preferred_port, e);
581        }
582    }
583
584    // 如果首选端口被占用,尝试附近的端口
585    for offset in 1..=50 {
586        let port = preferred_port + offset;
587        if port < preferred_port {
588            break;
589        }
590
591        let addr = format!("0.0.0.0:{}", port);
592        match tokio::net::TcpListener::bind(&addr).await {
593            Ok(listener) => {
594                info!("成功绑定到替代端口: {}", port);
595                return Ok((listener, port));
596            }
597            Err(_) => {
598                // 继续尝试下一个端口
599            }
600        }
601    }
602
603    // 如果向上寻找失败,尝试向下寻找
604    for offset in 1..=50 {
605        if preferred_port < offset {
606            break;
607        }
608
609        let port = preferred_port - offset;
610        if port < 1024 {
611            // 避免使用系统保留端口
612            break;
613        }
614
615        let addr = format!("0.0.0.0:{}", port);
616        match tokio::net::TcpListener::bind(&addr).await {
617            Ok(listener) => {
618                info!("成功绑定到替代端口: {}", port);
619                return Ok((listener, port));
620            }
621            Err(_) => {
622                // 继续尝试下一个端口
623            }
624        }
625    }
626
627    // 最后尝试让系统自动分配端口
628    match tokio::net::TcpListener::bind("0.0.0.0:0").await {
629        Ok(listener) => {
630            let actual_port = listener.local_addr()?.port();
631            info!("使用系统自动分配的端口: {}", actual_port);
632            Ok((listener, actual_port))
633        }
634        Err(e) => Err(format!("无法绑定到任何端口: {}", e).into()),
635    }
636}
637
638/// 启动服务器
639pub async fn start_server(port: u16) -> Result<(), Box<dyn std::error::Error>> {
640    let app = create_app();
641
642    // 尝试绑定端口,如果失败则尝试其他端口
643    let (listener, actual_port) = bind_available_port(port).await?;
644
645    info!("Listening on: http://127.0.0.1:{}", actual_port);
646    info!("API Endpoints:");
647    info!("   POST /api/validate    - 验证配置");
648    info!("   POST /api/sync        - 同步参数");
649    info!("   POST /api/encode      - 编码评估");
650    info!("   POST /api/optimize    - 开始优化");
651    info!("   GET  /sse/status      - SSE 实时状态推送");
652
653    axum::serve(listener, app).await?;
654
655    Ok(())
656}