Skip to main content

aster/chrome_mcp/
socket_server.rs

1//! Socket Server - 运行在 Native Host 进程中
2//!
3//! 架构:
4//! Chrome 扩展 → Native Messaging → Native Host (包含此 Socket Server) ← Socket ← MCP Client
5//!
6//! 平台支持:
7//! - Unix: 使用 Unix Domain Socket
8//! - Windows: 使用 Named Pipe
9
10use std::collections::HashMap;
11use std::io::{Read, Write};
12use std::sync::atomic::{AtomicU32, Ordering};
13use std::sync::Arc;
14
15use tokio::io::{AsyncReadExt, AsyncWriteExt};
16use tokio::sync::Mutex;
17
18use super::native_host::get_socket_path;
19
/// Version string this host reports to the extension in `status_response` messages.
const NATIVE_HOST_VERSION: &str = "1.0.0";
/// Largest accepted payload, in bytes, for one length-prefixed message (1 MiB).
const MAX_MESSAGE_SIZE: u32 = 1024 * 1024;
24
25/// MCP 客户端信息 (Unix)
26#[cfg(unix)]
27#[allow(dead_code)]
28struct McpClientInfo {
29    id: u32,
30    writer: tokio::net::unix::OwnedWriteHalf,
31}
32
33/// MCP 客户端信息 (Windows)
34#[cfg(windows)]
35#[allow(dead_code)]
36struct McpClientInfo {
37    id: u32,
38    pipe: Arc<Mutex<tokio::net::windows::named_pipe::NamedPipeServer>>,
39}
40
41/// Socket Server - 管理与 MCP 客户端的连接
42pub struct SocketServer {
43    mcp_clients: Arc<Mutex<HashMap<u32, McpClientInfo>>>,
44    next_client_id: AtomicU32,
45    running: Arc<Mutex<bool>>,
46}
47
48impl SocketServer {
49    /// 创建新的 Socket Server
50    pub fn new() -> Self {
51        Self {
52            mcp_clients: Arc::new(Mutex::new(HashMap::new())),
53            next_client_id: AtomicU32::new(1),
54            running: Arc::new(Mutex::new(false)),
55        }
56    }
57
58    /// 启动 Socket 服务器 (Unix)
59    #[cfg(unix)]
60    pub async fn start(&self) -> Result<(), String> {
61        let mut running = self.running.lock().await;
62        if *running {
63            return Ok(());
64        }
65
66        let socket_path = get_socket_path();
67        log_message(&format!("Creating socket listener: {}", socket_path));
68
69        // 清理旧的 socket 文件
70        let _ = std::fs::remove_file(&socket_path);
71
72        let listener = tokio::net::UnixListener::bind(&socket_path)
73            .map_err(|e| format!("Failed to bind socket: {}", e))?;
74
75        // 设置权限
76        {
77            use std::os::unix::fs::PermissionsExt;
78            let perms = std::fs::Permissions::from_mode(0o600);
79            let _ = std::fs::set_permissions(&socket_path, perms);
80        }
81
82        *running = true;
83        log_message("Socket server listening for connections");
84
85        let clients = Arc::clone(&self.mcp_clients);
86        let next_id = &self.next_client_id;
87
88        // 接受连接循环
89        loop {
90            match listener.accept().await {
91                Ok((stream, _)) => {
92                    let id = next_id.fetch_add(1, Ordering::SeqCst);
93                    self.handle_mcp_client(id, stream, Arc::clone(&clients))
94                        .await;
95                }
96                Err(e) => {
97                    log_message(&format!("Accept error: {}", e));
98                }
99            }
100        }
101    }
102
103    /// 启动 Socket 服务器 (Windows - Named Pipe)
104    #[cfg(windows)]
105    pub async fn start(&self) -> Result<(), String> {
106        use tokio::net::windows::named_pipe::ServerOptions;
107
108        let mut running = self.running.lock().await;
109        if *running {
110            return Ok(());
111        }
112
113        let pipe_path = get_socket_path();
114        log_message(&format!("Creating named pipe server: {}", pipe_path));
115
116        *running = true;
117        log_message("Named pipe server listening for connections");
118
119        let clients = Arc::clone(&self.mcp_clients);
120        let next_id = &self.next_client_id;
121
122        // 接受连接循环
123        loop {
124            // 创建新的 Named Pipe 实例
125            let server = ServerOptions::new()
126                .first_pipe_instance(false)
127                .create(&pipe_path)
128                .map_err(|e| format!("Failed to create named pipe: {}", e))?;
129
130            // 等待客户端连接
131            if let Err(e) = server.connect().await {
132                log_message(&format!("Named pipe connect error: {}", e));
133                continue;
134            }
135
136            let id = next_id.fetch_add(1, Ordering::SeqCst);
137            self.handle_mcp_client_windows(id, server, Arc::clone(&clients))
138                .await;
139        }
140    }
141
142    /// 处理 MCP 客户端连接 (Unix)
143    #[cfg(unix)]
144    async fn handle_mcp_client(
145        &self,
146        id: u32,
147        stream: tokio::net::UnixStream,
148        clients: Arc<Mutex<HashMap<u32, McpClientInfo>>>,
149    ) {
150        let (mut reader, writer) = stream.into_split();
151
152        {
153            let mut clients = clients.lock().await;
154            clients.insert(id, McpClientInfo { id, writer });
155            log_message(&format!(
156                "MCP client {} connected. Total: {}",
157                id,
158                clients.len()
159            ));
160        }
161
162        // 通知 Chrome 扩展
163        send_to_chrome(&serde_json::json!({ "type": "mcp_connected" }));
164
165        let clients_clone = Arc::clone(&clients);
166
167        // 读取循环
168        tokio::spawn(async move {
169            let mut buffer = Vec::new();
170            let mut read_buf = [0u8; 4096];
171
172            loop {
173                match reader.read(&mut read_buf).await {
174                    Ok(0) => break,
175                    Ok(n) => {
176                        buffer.extend_from_slice(&read_buf[..n]);
177                        Self::process_mcp_buffer(&mut buffer, id).await;
178                    }
179                    Err(_) => break,
180                }
181            }
182
183            let mut clients = clients_clone.lock().await;
184            clients.remove(&id);
185            log_message(&format!(
186                "MCP client {} disconnected. Total: {}",
187                id,
188                clients.len()
189            ));
190        });
191    }
192
193    /// 处理 MCP 客户端连接 (Windows)
194    #[cfg(windows)]
195    async fn handle_mcp_client_windows(
196        &self,
197        id: u32,
198        server: tokio::net::windows::named_pipe::NamedPipeServer,
199        clients: Arc<Mutex<HashMap<u32, McpClientInfo>>>,
200    ) {
201        let pipe = Arc::new(Mutex::new(server));
202
203        {
204            let mut clients = clients.lock().await;
205            clients.insert(
206                id,
207                McpClientInfo {
208                    id,
209                    pipe: Arc::clone(&pipe),
210                },
211            );
212            log_message(&format!(
213                "MCP client {} connected. Total: {}",
214                id,
215                clients.len()
216            ));
217        }
218
219        // 通知 Chrome 扩展
220        send_to_chrome(&serde_json::json!({ "type": "mcp_connected" }));
221
222        let clients_clone = Arc::clone(&clients);
223        let pipe_clone = Arc::clone(&pipe);
224
225        // 读取循环
226        tokio::spawn(async move {
227            let mut buffer = Vec::new();
228            let mut read_buf = [0u8; 4096];
229
230            loop {
231                let read_result = {
232                    let mut pipe = pipe_clone.lock().await;
233                    pipe.read(&mut read_buf).await
234                };
235
236                match read_result {
237                    Ok(0) => break,
238                    Ok(n) => {
239                        buffer.extend_from_slice(&read_buf[..n]);
240                        Self::process_mcp_buffer(&mut buffer, id).await;
241                    }
242                    Err(_) => break,
243                }
244            }
245
246            let mut clients = clients_clone.lock().await;
247            clients.remove(&id);
248            log_message(&format!(
249                "MCP client {} disconnected. Total: {}",
250                id,
251                clients.len()
252            ));
253        });
254    }
255
256    /// 处理 MCP 客户端缓冲区
257    async fn process_mcp_buffer(buffer: &mut Vec<u8>, client_id: u32) {
258        while buffer.len() >= 4 {
259            let msg_len = u32::from_le_bytes([buffer[0], buffer[1], buffer[2], buffer[3]]);
260
261            if msg_len == 0 || msg_len > MAX_MESSAGE_SIZE {
262                log_message(&format!(
263                    "Invalid message length from client {}: {}",
264                    client_id, msg_len
265                ));
266                buffer.clear();
267                return;
268            }
269
270            let total_len = 4 + msg_len as usize;
271            if buffer.len() < total_len {
272                return;
273            }
274
275            let msg_data = &buffer[4..total_len];
276            if let Ok(msg_str) = std::str::from_utf8(msg_data) {
277                if let Ok(message) = serde_json::from_str::<serde_json::Value>(msg_str) {
278                    log_message(&format!(
279                        "Received from MCP client {}: {}",
280                        client_id,
281                        msg_str.get(..msg_str.len().min(200)).unwrap_or(msg_str)
282                    ));
283                    // 转发到 Chrome 扩展
284                    send_to_chrome(&message);
285                }
286            }
287
288            buffer.drain(..total_len);
289        }
290    }
291
292    /// 处理来自 Chrome 扩展的消息
293    pub async fn handle_chrome_message(&self, message: &str) -> Result<(), String> {
294        log_message(&format!(
295            "Chrome message: {}",
296            message.get(..message.len().min(300)).unwrap_or(message)
297        ));
298
299        let data: serde_json::Value =
300            serde_json::from_str(message).map_err(|e| format!("Parse error: {}", e))?;
301
302        // 检查是否是工具响应
303        if data.get("result").is_some() || data.get("error").is_some() {
304            log_message("Received tool response, forwarding to MCP clients");
305            self.forward_to_mcp_clients(&data).await;
306            return Ok(());
307        }
308
309        // 处理其他消息类型
310        if let Some(msg_type) = data.get("type").and_then(|v| v.as_str()) {
311            match msg_type {
312                "ping" => {
313                    send_to_chrome(&serde_json::json!({
314                        "type": "pong",
315                        "timestamp": chrono::Utc::now().timestamp_millis()
316                    }));
317                }
318                "get_status" => {
319                    send_to_chrome(&serde_json::json!({
320                        "type": "status_response",
321                        "native_host_version": NATIVE_HOST_VERSION
322                    }));
323                }
324                _ => {
325                    self.forward_to_mcp_clients(&data).await;
326                }
327            }
328        } else {
329            self.forward_to_mcp_clients(&data).await;
330        }
331
332        Ok(())
333    }
334
335    /// 转发消息到所有 MCP 客户端 (Unix)
336    #[cfg(unix)]
337    async fn forward_to_mcp_clients(&self, data: &serde_json::Value) {
338        let mut clients = self.mcp_clients.lock().await;
339        if clients.is_empty() {
340            return;
341        }
342
343        log_message(&format!("Forwarding to {} MCP clients", clients.len()));
344
345        let json = serde_json::to_vec(data).unwrap_or_default();
346        let mut header = [0u8; 4];
347        header.copy_from_slice(&(json.len() as u32).to_le_bytes());
348
349        let mut failed_ids = Vec::new();
350
351        for (id, client) in clients.iter_mut() {
352            if client.writer.write_all(&header).await.is_err()
353                || client.writer.write_all(&json).await.is_err()
354            {
355                failed_ids.push(*id);
356            }
357        }
358
359        for id in failed_ids {
360            clients.remove(&id);
361        }
362    }
363
364    /// 转发消息到所有 MCP 客户端 (Windows)
365    #[cfg(windows)]
366    async fn forward_to_mcp_clients(&self, data: &serde_json::Value) {
367        let mut clients = self.mcp_clients.lock().await;
368        if clients.is_empty() {
369            return;
370        }
371
372        log_message(&format!("Forwarding to {} MCP clients", clients.len()));
373
374        let json = serde_json::to_vec(data).unwrap_or_default();
375        let mut header = [0u8; 4];
376        header.copy_from_slice(&(json.len() as u32).to_le_bytes());
377
378        let mut failed_ids = Vec::new();
379
380        for (id, client) in clients.iter_mut() {
381            let mut pipe = client.pipe.lock().await;
382            if pipe.write_all(&header).await.is_err() || pipe.write_all(&json).await.is_err() {
383                failed_ids.push(*id);
384            }
385        }
386
387        for id in failed_ids {
388            clients.remove(&id);
389        }
390    }
391
392    /// 停止服务器 (Unix)
393    #[cfg(unix)]
394    pub async fn stop(&self) {
395        let mut running = self.running.lock().await;
396        if !*running {
397            return;
398        }
399        *running = false;
400
401        // 清理 socket 文件
402        let socket_path = get_socket_path();
403        let _ = std::fs::remove_file(&socket_path);
404
405        // 关闭所有客户端
406        let mut clients = self.mcp_clients.lock().await;
407        clients.clear();
408
409        log_message("Socket server stopped");
410    }
411
412    /// 停止服务器 (Windows)
413    #[cfg(windows)]
414    pub async fn stop(&self) {
415        let mut running = self.running.lock().await;
416        if !*running {
417            return;
418        }
419        *running = false;
420
421        // 关闭所有客户端
422        let mut clients = self.mcp_clients.lock().await;
423        clients.clear();
424
425        log_message("Named pipe server stopped");
426    }
427}
428
429impl Default for SocketServer {
430    fn default() -> Self {
431        Self::new()
432    }
433}
434
435/// 日志输出到 stderr(Native Messaging 使用 stdout)
436fn log_message(message: &str) {
437    let timestamp = chrono::Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ");
438    eprintln!("[{}] [Native Host] {}", timestamp, message);
439
440    // 同时写入日志文件
441    if let Some(home) = dirs::home_dir() {
442        let log_file = home.join(".aster").join("native-host.log");
443        if let Ok(mut file) = std::fs::OpenOptions::new()
444            .create(true)
445            .append(true)
446            .open(&log_file)
447        {
448            let _ = writeln!(file, "[{}] {}", timestamp, message);
449        }
450    }
451}
452
453/// 向 Chrome 扩展发送消息(Native Messaging 协议)
454fn send_to_chrome(message: &serde_json::Value) {
455    let json_str = serde_json::to_string(message).unwrap_or_default();
456    log_message(&format!(
457        "Sending to Chrome: {}",
458        json_str.get(..json_str.len().min(200)).unwrap_or(&json_str)
459    ));
460
461    let json = json_str.as_bytes();
462    let mut header = [0u8; 4];
463    header.copy_from_slice(&(json.len() as u32).to_le_bytes());
464
465    let mut stdout = std::io::stdout().lock();
466    let _ = stdout.write_all(&header);
467    let _ = stdout.write_all(json);
468    let _ = stdout.flush();
469}
470
/// Reads length-prefixed Native Messaging frames from stdin (see `NativeMessageReader::read`).
// `buffer` is not read by any visible code, hence the allow.
#[allow(dead_code)]
pub struct NativeMessageReader {
    buffer: Vec<u8>,
}
476
477impl NativeMessageReader {
478    pub fn new() -> Self {
479        Self { buffer: Vec::new() }
480    }
481
482    /// 读取下一条消息
483    pub fn read(&mut self) -> Option<String> {
484        let mut stdin = std::io::stdin().lock();
485        let mut header = [0u8; 4];
486
487        if stdin.read_exact(&mut header).is_err() {
488            return None;
489        }
490
491        let msg_len = u32::from_le_bytes(header);
492        if msg_len == 0 || msg_len > MAX_MESSAGE_SIZE {
493            log_message(&format!("Invalid message length: {}", msg_len));
494            return None;
495        }
496
497        let mut msg_buf = vec![0u8; msg_len as usize];
498        if stdin.read_exact(&mut msg_buf).is_err() {
499            return None;
500        }
501
502        String::from_utf8(msg_buf).ok()
503    }
504}
505
506impl Default for NativeMessageReader {
507    fn default() -> Self {
508        Self::new()
509    }
510}
511
512/// 运行 Native Host 主循环
513pub async fn run_native_host() -> Result<(), String> {
514    log_message("Initializing Native Host...");
515
516    let server = SocketServer::new();
517    let mut reader = NativeMessageReader::new();
518
519    // 启动 socket server(在后台)
520    tokio::spawn(async move {
521        let s = SocketServer::new();
522        if let Err(e) = s.start().await {
523            log_message(&format!("Socket server error: {}", e));
524        }
525    });
526
527    // 从 Chrome 扩展读取消息
528    log_message("Running in Native Messaging mode");
529    while let Some(message) = reader.read() {
530        if let Err(e) = server.handle_chrome_message(&message).await {
531            log_message(&format!("Handle message error: {}", e));
532        }
533    }
534
535    server.stop().await;
536    Ok(())
537}