Skip to main content

aster/chrome_mcp/
socket_client.rs

1//! Socket Client - 连接到 Native Host Socket Server
2//!
3//! 架构:
4//! MCP Server (包含此 Socket Client) → Socket → Native Host → Native Messaging → Chrome 扩展
5
6use std::collections::HashMap;
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::sync::Arc;
9use std::time::Duration;
10
11#[cfg(unix)]
12use tokio::io::AsyncReadExt;
13use tokio::io::AsyncWriteExt;
14use tokio::sync::{mpsc, oneshot, Mutex};
15use tokio::time::timeout;
16
17use super::native_host::get_socket_path;
18use super::types::ToolCallResult;
19
/// Largest accepted message payload, in bytes (1 MiB).
const MAX_MESSAGE_SIZE: u32 = 1024 * 1024;
/// How long a connection attempt may take before it is abandoned (5 seconds).
const CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
/// How long a tool call waits for its response (60 seconds).
const TOOL_CALL_TIMEOUT: Duration = Duration::from_secs(60);
/// Pause between reconnect attempts (1 second).
#[allow(dead_code)]
const RECONNECT_DELAY: Duration = Duration::from_secs(1);
/// Upper bound on the number of reconnect attempts.
#[allow(dead_code)]
const MAX_RECONNECT_ATTEMPTS: u32 = 10;
32
/// Error reported for any failure while talking to the socket server.
#[derive(Debug, Clone)]
pub struct SocketConnectionError {
    /// Human-readable description of what went wrong.
    pub message: String,
}

impl std::fmt::Display for SocketConnectionError {
    /// Renders the error as `SocketConnectionError: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("SocketConnectionError: ")?;
        f.write_str(&self.message)
    }
}

impl std::error::Error for SocketConnectionError {}

impl SocketConnectionError {
    /// Builds an error from anything convertible into a `String`.
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}
54
/// A tool call that has been registered and is still waiting for its response.
struct PendingCall {
    /// One-shot channel on which the outcome is delivered: `Ok` with the tool
    /// result once a response is matched to this call, or `Err` when the
    /// connection goes away first.
    sender: oneshot::Sender<Result<ToolCallResult, SocketConnectionError>>,
}
59
/// Internal, mutex-protected state of the socket client.
struct ClientState {
    /// `true` once a connection attempt has succeeded; cleared on disconnect.
    connected: bool,
    /// `true` while a connection attempt is in flight.
    connecting: bool,
    /// Calls awaiting a response, keyed by their generated call id.
    pending_calls: HashMap<String, PendingCall>,
    /// Reset to 0 on a successful connect; not otherwise read in the code
    /// visible here.
    reconnect_attempts: u32,
}
67
/// Socket client that connects to the Native Host socket server.
pub struct SocketClient {
    /// Shared connection state; also handed to the background read task.
    state: Arc<Mutex<ClientState>>,
    /// Monotonic counter used when generating call ids.
    call_id: AtomicU64,
    /// Write half of the Unix stream; `None` until a connect succeeds and
    /// again after `disconnect()`.
    #[cfg(unix)]
    writer: Arc<Mutex<Option<tokio::net::unix::OwnedWriteHalf>>>,
    /// Named-pipe client used for writing; `None` until a connect succeeds
    /// and again after `disconnect()`.
    #[cfg(windows)]
    writer: Arc<Mutex<Option<tokio::net::windows::named_pipe::NamedPipeClient>>>,
    /// Sender used by `disconnect()` to tell the read loop to stop.
    shutdown_tx: Arc<Mutex<Option<mpsc::Sender<()>>>>,
}
78
79impl SocketClient {
80    /// 创建新的 Socket Client
81    pub fn new() -> Self {
82        Self {
83            state: Arc::new(Mutex::new(ClientState {
84                connected: false,
85                connecting: false,
86                pending_calls: HashMap::new(),
87                reconnect_attempts: 0,
88            })),
89            call_id: AtomicU64::new(0),
90            writer: Arc::new(Mutex::new(None)),
91            shutdown_tx: Arc::new(Mutex::new(None)),
92        }
93    }
94
95    /// 检查是否已连接
96    pub async fn is_connected(&self) -> bool {
97        self.state.lock().await.connected
98    }
99
100    /// 确保已连接
101    pub async fn ensure_connected(&self) -> bool {
102        {
103            let state = self.state.lock().await;
104            if state.connected {
105                return true;
106            }
107            if state.connecting {
108                drop(state);
109                // 等待连接完成
110                for _ in 0..50 {
111                    tokio::time::sleep(Duration::from_millis(100)).await;
112                    let state = self.state.lock().await;
113                    if state.connected {
114                        return true;
115                    }
116                    if !state.connecting {
117                        return false;
118                    }
119                }
120                return false;
121            }
122        }
123
124        match self.connect().await {
125            Ok(_) => self.state.lock().await.connected,
126            Err(e) => {
127                tracing::warn!("Failed to connect to socket: {}", e);
128                false
129            }
130        }
131    }
132
133    /// 连接到 Socket Server (Unix)
134    #[cfg(unix)]
135    async fn connect(&self) -> Result<(), SocketConnectionError> {
136        {
137            let mut state = self.state.lock().await;
138            if state.connected || state.connecting {
139                return Ok(());
140            }
141            state.connecting = true;
142        }
143
144        let socket_path = get_socket_path();
145
146        let connect_result = timeout(
147            CONNECT_TIMEOUT,
148            tokio::net::UnixStream::connect(&socket_path),
149        )
150        .await;
151
152        match connect_result {
153            Ok(Ok(stream)) => {
154                let (reader, writer) = stream.into_split();
155                *self.writer.lock().await = Some(writer);
156
157                let state_clone = Arc::clone(&self.state);
158                let (shutdown_tx, shutdown_rx) = mpsc::channel::<()>(1);
159                *self.shutdown_tx.lock().await = Some(shutdown_tx);
160
161                // 启动读取任务
162                tokio::spawn(async move {
163                    Self::read_loop(reader, state_clone, shutdown_rx).await;
164                });
165
166                let mut state = self.state.lock().await;
167                state.connected = true;
168                state.connecting = false;
169                state.reconnect_attempts = 0;
170                tracing::info!("Connected to socket server");
171                Ok(())
172            }
173            Ok(Err(e)) => {
174                let mut state = self.state.lock().await;
175                state.connecting = false;
176                Err(SocketConnectionError::new(format!(
177                    "Connection failed: {}",
178                    e
179                )))
180            }
181            Err(_) => {
182                let mut state = self.state.lock().await;
183                state.connecting = false;
184                Err(SocketConnectionError::new("Connection timeout"))
185            }
186        }
187    }
188
189    /// 连接到 Socket Server (Windows)
190    #[cfg(windows)]
191    async fn connect(&self) -> Result<(), SocketConnectionError> {
192        {
193            let mut state = self.state.lock().await;
194            if state.connected || state.connecting {
195                return Ok(());
196            }
197            state.connecting = true;
198        }
199
200        let socket_path = get_socket_path();
201
202        // Windows named pipe 连接
203        let connect_result = timeout(CONNECT_TIMEOUT, async {
204            tokio::net::windows::named_pipe::ClientOptions::new().open(&socket_path)
205        })
206        .await;
207
208        match connect_result {
209            Ok(Ok(pipe)) => {
210                *self.writer.lock().await = Some(pipe);
211
212                let mut state = self.state.lock().await;
213                state.connected = true;
214                state.connecting = false;
215                state.reconnect_attempts = 0;
216                tracing::info!("Connected to socket server");
217                Ok(())
218            }
219            Ok(Err(e)) => {
220                let mut state = self.state.lock().await;
221                state.connecting = false;
222                Err(SocketConnectionError::new(format!(
223                    "Connection failed: {}",
224                    e
225                )))
226            }
227            Err(_) => {
228                let mut state = self.state.lock().await;
229                state.connecting = false;
230                Err(SocketConnectionError::new("Connection timeout"))
231            }
232        }
233    }
234
235    /// Unix 读取循环
236    #[cfg(unix)]
237    async fn read_loop(
238        mut reader: tokio::net::unix::OwnedReadHalf,
239        state: Arc<Mutex<ClientState>>,
240        mut shutdown_rx: mpsc::Receiver<()>,
241    ) {
242        let mut buffer = Vec::new();
243        let mut read_buf = [0u8; 4096];
244
245        loop {
246            tokio::select! {
247                _ = shutdown_rx.recv() => {
248                    tracing::info!("Socket read loop shutdown");
249                    break;
250                }
251                result = reader.read(&mut read_buf) => {
252                    match result {
253                        Ok(0) => {
254                            tracing::info!("Socket connection closed");
255                            Self::handle_disconnect(state).await;
256                            break;
257                        }
258                        Ok(n) => {
259                            buffer.extend_from_slice(&read_buf[..n]);
260                            Self::process_buffer(&mut buffer, &state).await;
261                        }
262                        Err(e) => {
263                            tracing::error!("Socket read error: {}", e);
264                            Self::handle_disconnect(state).await;
265                            break;
266                        }
267                    }
268                }
269            }
270        }
271    }
272
273    /// 处理断开连接
274    async fn handle_disconnect(state: Arc<Mutex<ClientState>>) {
275        let mut state = state.lock().await;
276        state.connected = false;
277        state.connecting = false;
278
279        // 拒绝所有等待中的调用
280        for (_, pending) in state.pending_calls.drain() {
281            let _ = pending
282                .sender
283                .send(Err(SocketConnectionError::new("Connection closed")));
284        }
285    }
286
287    /// 处理缓冲区中的消息
288    async fn process_buffer(buffer: &mut Vec<u8>, state: &Arc<Mutex<ClientState>>) {
289        while buffer.len() >= 4 {
290            let msg_len = u32::from_le_bytes([buffer[0], buffer[1], buffer[2], buffer[3]]);
291
292            if msg_len == 0 || msg_len > MAX_MESSAGE_SIZE {
293                tracing::error!("Invalid message length: {}", msg_len);
294                buffer.clear();
295                return;
296            }
297
298            let total_len = 4 + msg_len as usize;
299            if buffer.len() < total_len {
300                return; // 消息不完整
301            }
302
303            let msg_data = &buffer[4..total_len];
304            if let Ok(msg_str) = std::str::from_utf8(msg_data) {
305                Self::handle_message(msg_str, state).await;
306            }
307
308            buffer.drain(..total_len);
309        }
310    }
311
312    /// 处理接收到的消息
313    async fn handle_message(msg_str: &str, state: &Arc<Mutex<ClientState>>) {
314        let msg: serde_json::Value = match serde_json::from_str(msg_str) {
315            Ok(v) => v,
316            Err(e) => {
317                tracing::error!("Failed to parse message: {}", e);
318                return;
319            }
320        };
321
322        tracing::debug!(
323            "Received message: {}",
324            msg_str.get(..msg_str.len().min(300)).unwrap_or(msg_str)
325        );
326
327        // 检查是否是工具调用响应
328        if msg.get("result").is_some() || msg.get("error").is_some() {
329            let result = super::types::ToolCallResult {
330                result: msg.get("result").and_then(|r| {
331                    r.get("content").map(|c| super::types::ToolResultContent {
332                        content: c.as_array().cloned().unwrap_or_default(),
333                    })
334                }),
335                error: msg.get("error").and_then(|e| {
336                    e.get("content").map(|c| super::types::ToolErrorContent {
337                        content: c.as_array().cloned().unwrap_or_default(),
338                    })
339                }),
340            };
341
342            let mut state = state.lock().await;
343            // 处理第一个等待中的请求
344            if let Some(call_id) = state.pending_calls.keys().next().cloned() {
345                if let Some(pending) = state.pending_calls.remove(&call_id) {
346                    let _ = pending.sender.send(Ok(result));
347                }
348            }
349        }
350    }
351
352    /// 调用工具
353    pub async fn call_tool(
354        &self,
355        tool_name: &str,
356        args: serde_json::Value,
357    ) -> Result<ToolCallResult, SocketConnectionError> {
358        if !self.is_connected().await {
359            return Err(SocketConnectionError::new("Not connected"));
360        }
361
362        let call_id = format!(
363            "call_{}_{}",
364            self.call_id.fetch_add(1, Ordering::SeqCst),
365            chrono::Utc::now().timestamp_millis()
366        );
367
368        let (tx, rx) = oneshot::channel();
369
370        // 注册等待中的调用
371        {
372            let mut state = self.state.lock().await;
373            state
374                .pending_calls
375                .insert(call_id.clone(), PendingCall { sender: tx });
376        }
377
378        // 构造消息
379        let message = serde_json::json!({
380            "type": "tool_request",
381            "method": "execute_tool",
382            "params": {
383                "tool": tool_name,
384                "client_id": "aster",
385                "args": args
386            }
387        });
388
389        // 发送消息
390        if let Err(e) = self.send_message(&message).await {
391            let mut state = self.state.lock().await;
392            state.pending_calls.remove(&call_id);
393            return Err(e);
394        }
395
396        // 等待响应
397        match timeout(TOOL_CALL_TIMEOUT, rx).await {
398            Ok(Ok(result)) => result,
399            Ok(Err(_)) => Err(SocketConnectionError::new("Response channel closed")),
400            Err(_) => {
401                let mut state = self.state.lock().await;
402                state.pending_calls.remove(&call_id);
403                Err(SocketConnectionError::new("Tool call timeout"))
404            }
405        }
406    }
407
408    /// 发送消息 (Unix)
409    #[cfg(unix)]
410    async fn send_message(&self, message: &serde_json::Value) -> Result<(), SocketConnectionError> {
411        let json = serde_json::to_vec(message)
412            .map_err(|e| SocketConnectionError::new(format!("Serialize error: {}", e)))?;
413
414        let mut header = [0u8; 4];
415        header.copy_from_slice(&(json.len() as u32).to_le_bytes());
416
417        let mut writer = self.writer.lock().await;
418        if let Some(ref mut w) = *writer {
419            w.write_all(&header)
420                .await
421                .map_err(|e| SocketConnectionError::new(format!("Write error: {}", e)))?;
422            w.write_all(&json)
423                .await
424                .map_err(|e| SocketConnectionError::new(format!("Write error: {}", e)))?;
425            Ok(())
426        } else {
427            Err(SocketConnectionError::new("Not connected"))
428        }
429    }
430
431    /// 发送消息 (Windows)
432    #[cfg(windows)]
433    async fn send_message(&self, message: &serde_json::Value) -> Result<(), SocketConnectionError> {
434        let json = serde_json::to_vec(message)
435            .map_err(|e| SocketConnectionError::new(format!("Serialize error: {}", e)))?;
436
437        let mut header = [0u8; 4];
438        header.copy_from_slice(&(json.len() as u32).to_le_bytes());
439
440        let mut writer = self.writer.lock().await;
441        if let Some(ref mut w) = *writer {
442            w.write_all(&header)
443                .await
444                .map_err(|e| SocketConnectionError::new(format!("Write error: {}", e)))?;
445            w.write_all(&json)
446                .await
447                .map_err(|e| SocketConnectionError::new(format!("Write error: {}", e)))?;
448            Ok(())
449        } else {
450            Err(SocketConnectionError::new("Not connected"))
451        }
452    }
453
454    /// 断开连接
455    pub async fn disconnect(&self) {
456        // 发送关闭信号
457        if let Some(tx) = self.shutdown_tx.lock().await.take() {
458            let _ = tx.send(()).await;
459        }
460
461        // 清理 writer
462        *self.writer.lock().await = None;
463
464        // 更新状态
465        let mut state = self.state.lock().await;
466        state.connected = false;
467        state.connecting = false;
468
469        // 拒绝所有等待中的调用
470        for (_, pending) in state.pending_calls.drain() {
471            let _ = pending
472                .sender
473                .send(Err(SocketConnectionError::new("Disconnected")));
474        }
475    }
476}
477
478impl Default for SocketClient {
479    fn default() -> Self {
480        Self::new()
481    }
482}
483
484/// 创建 Socket Client 实例
485pub fn create_socket_client() -> SocketClient {
486    SocketClient::new()
487}