mcp_stdio_proxy/proxy/
proxy_handler.rs

1use log::{debug, info};
2/**
3 * Create a local SSE server that proxies requests to a stdio MCP server.
4 */
5use rmcp::{
6    ErrorData, RoleClient, RoleServer, ServerHandler,
7    model::{
8        CallToolRequestParam, CallToolResult, ClientInfo, Content, Implementation, ListToolsResult,
9        PaginatedRequestParam, ServerInfo,
10    },
11    service::{NotificationContext, RequestContext, RunningService},
12};
13use std::sync::{Arc, RwLock};
14use tokio::sync::Mutex;
15
/// A proxy handler that forwards requests to a client based on the server's capabilities.
///
/// The handler derives `Clone`; because both the client and the cache sit
/// behind `Arc`, every clone shares the same underlying client connection
/// and the same cached server info.
#[derive(Clone, Debug)]
pub struct ProxyHandler {
    // Running client service for the backing MCP server. It is guarded by an
    // async (tokio) mutex, so only one forwarded operation can use it at a time.
    client: Arc<Mutex<RunningService<RoleClient, ClientInfo>>>,
    // Store the server's capabilities to avoid locking the client on every get_info call
    cached_info: Arc<RwLock<Option<ServerInfo>>>,
    // MCP ID, used for logging
    mcp_id: String,
}
25
26impl ServerHandler for ProxyHandler {
27    fn get_info(&self) -> ServerInfo {
28        // 首先检查缓存的信息
29        if let Ok(cached_read) = self.cached_info.read() {
30            if let Some(ref cached) = *cached_read {
31                return cached.clone();
32            }
33        }
34
35        // 如果缓存为空,尝试动态获取
36        // 使用 try_lock 而不是 lock,避免阻塞
37        // peer_info() 是同步方法,可以安全调用
38        let client = self.client.clone();
39        if let Ok(guard) = client.try_lock() {
40            if let Some(peer_info) = guard.peer_info() {
41                let server_info = ServerInfo {
42                    protocol_version: peer_info.protocol_version.clone(),
43                    server_info: Implementation {
44                        name: peer_info.server_info.name.clone(),
45                        version: peer_info.server_info.version.clone(),
46                        title: None,
47                        website_url: None,
48                        icons: None,
49                    },
50                    instructions: peer_info.instructions.clone(),
51                    capabilities: peer_info.capabilities.clone(),
52                };
53
54                // 将动态获取的信息缓存起来
55                if let Ok(mut cached_write) = self.cached_info.write() {
56                    *cached_write = Some(server_info.clone());
57                    debug!("Successfully cached server info from peer_info");
58                }
59
60                return server_info;
61            }
62        }
63
64        // 如果都获取不到,返回错误状态信息
65        ServerInfo {
66            protocol_version: Default::default(),
67            server_info: Implementation {
68                name: "MCP Proxy - Service Unavailable".to_string(),
69                version: "0.1.0".to_string(),
70                title: None,
71                website_url: None,
72                icons: None,
73            },
74            instructions: Some("ERROR: MCP service is not available or still initializing. Please try again later.".to_string()),
75            capabilities: Default::default(), // 空的能力列表,表示服务不可用
76        }
77    }
78
79    #[tracing::instrument(skip(self, request, _context), fields(
80        mcp_id = %self.mcp_id,
81        request = ?request,
82    ))]
83    async fn list_tools(
84        &self,
85        request: Option<PaginatedRequestParam>,
86        _context: RequestContext<RoleServer>,
87    ) -> Result<ListToolsResult, ErrorData> {
88        let client = self.client.clone();
89        let guard = client.lock().await;
90
91        // Check if the server has tools capability and forward the request
92        match self.get_info().capabilities.tools {
93            Some(_) => {
94                match guard.list_tools(request).await {
95                    // Forward request to client
96                    Ok(result) => {
97                        // 记录工具列表结果,这些结果会通过 SSE 推送给客户端
98                        info!(
99                            "[list_tools] 工具列表结果 - MCP ID: {}, 工具数量: {}",
100                            self.mcp_id,
101                            result.tools.len()
102                        );
103
104                        debug!(
105                            "Proxying list_tools response with {} tools",
106                            result.tools.len()
107                        );
108                        Ok(result)
109                    }
110                    Err(err) => {
111                        tracing::error!("Error listing tools: {:?}", err);
112                        // Return empty list instead of error
113                        Ok(ListToolsResult::default())
114                    }
115                }
116            }
117            None => {
118                // Server doesn't support tools, return empty list
119                tracing::error!("Server doesn't support tools capability");
120                Ok(ListToolsResult::default())
121            }
122        }
123    }
124
125    #[tracing::instrument(skip(self, request, _context), fields(
126        mcp_id = %self.mcp_id,
127        tool_name = %request.name,
128        tool_arguments = ?request.arguments,
129    ))]
130    async fn call_tool(
131        &self,
132        request: CallToolRequestParam,
133        _context: RequestContext<RoleServer>,
134    ) -> Result<CallToolResult, ErrorData> {
135        let client = self.client.clone();
136        let guard = client.lock().await;
137
138        // Check if the server has tools capability and forward the request
139        match self.get_info().capabilities.tools {
140            Some(_) => {
141                match guard.call_tool(request.clone()).await {
142                    Ok(result) => {
143                        // 记录工具调用结果,这些结果会通过 SSE 推送给客户端
144                        info!(
145                            "[call_tool] 工具调用结果 - MCP ID: {}, 工具: {}",
146                            self.mcp_id, request.name
147                        );
148
149                        debug!("Tool call succeeded");
150                        Ok(result)
151                    }
152                    Err(err) => {
153                        tracing::error!("Error calling tool: {:?}", err);
154                        // Return an error result instead of propagating the error
155                        Ok(CallToolResult::error(vec![Content::text(format!(
156                            "Error: {err}"
157                        ))]))
158                    }
159                }
160            }
161            None => {
162                tracing::error!("Server doesn't support tools capability");
163                Ok(CallToolResult::error(vec![Content::text(
164                    "Server doesn't support tools capability",
165                )]))
166            }
167        }
168    }
169
170    async fn list_resources(
171        &self,
172        request: Option<PaginatedRequestParam>,
173        _context: RequestContext<RoleServer>,
174    ) -> Result<rmcp::model::ListResourcesResult, ErrorData> {
175        // Get a lock on the client
176        let client = self.client.clone();
177        let guard = client.lock().await;
178
179        // Check if the server has resources capability and forward the request
180        match self.get_info().capabilities.resources {
181            Some(_) => {
182                // Forward request to client
183                match guard.list_resources(request).await {
184                    Ok(result) => {
185                        // 记录资源列表结果,这些结果会通过 SSE 推送给客户端
186                        info!(
187                            "[list_resources] 资源列表结果 - MCP ID: {}, 资源数量: {}",
188                            self.mcp_id,
189                            result.resources.len()
190                        );
191
192                        debug!("Proxying list_resources response");
193                        Ok(result)
194                    }
195                    Err(err) => {
196                        tracing::error!("Error listing resources: {:?}", err);
197                        // Return empty list instead of error
198                        Ok(rmcp::model::ListResourcesResult::default())
199                    }
200                }
201            }
202            None => {
203                // Server doesn't support resources, return empty list
204                tracing::error!("Server doesn't support resources capability");
205                Ok(rmcp::model::ListResourcesResult::default())
206            }
207        }
208    }
209
210    async fn read_resource(
211        &self,
212        request: rmcp::model::ReadResourceRequestParam,
213        _context: RequestContext<RoleServer>,
214    ) -> Result<rmcp::model::ReadResourceResult, ErrorData> {
215        // Get a lock on the client
216        let client = self.client.clone();
217        let guard = client.lock().await;
218
219        // Check if the server has resources capability and forward the request
220        match self.get_info().capabilities.resources {
221            Some(_) => {
222                // Forward request to client
223                match guard
224                    .read_resource(rmcp::model::ReadResourceRequestParam {
225                        uri: request.uri.clone(),
226                    })
227                    .await
228                {
229                    Ok(result) => {
230                        // 记录资源读取结果,这些结果会通过 SSE 推送给客户端
231                        info!(
232                            "[read_resource] 资源读取结果 - MCP ID: {}, URI: {}",
233                            self.mcp_id, request.uri
234                        );
235
236                        debug!("Proxying read_resource response for {}", request.uri);
237                        Ok(result)
238                    }
239                    Err(err) => {
240                        tracing::error!("Error reading resource: {:?}", err);
241                        Err(ErrorData::internal_error(
242                            format!("Error reading resource: {err}"),
243                            None,
244                        ))
245                    }
246                }
247            }
248            None => {
249                // Server doesn't support resources, return error
250                tracing::error!("Server doesn't support resources capability");
251                Ok(rmcp::model::ReadResourceResult {
252                    contents: Vec::new(),
253                })
254            }
255        }
256    }
257
258    async fn list_resource_templates(
259        &self,
260        request: Option<PaginatedRequestParam>,
261        _context: RequestContext<RoleServer>,
262    ) -> Result<rmcp::model::ListResourceTemplatesResult, ErrorData> {
263        // Get a lock on the client
264        let client = self.client.clone();
265        let guard = client.lock().await;
266
267        // Check if the server has resources capability and forward the request
268        match self.get_info().capabilities.resources {
269            Some(_) => {
270                // Forward request to client
271                match guard.list_resource_templates(request).await {
272                    Ok(result) => {
273                        debug!("Proxying list_resource_templates response");
274                        Ok(result)
275                    }
276                    Err(err) => {
277                        tracing::error!("Error listing resource templates: {:?}", err);
278                        // Return empty list instead of error
279                        Ok(rmcp::model::ListResourceTemplatesResult::default())
280                    }
281                }
282            }
283            None => {
284                // Server doesn't support resources, return empty list
285                tracing::error!("Server doesn't support resources capability");
286                Ok(rmcp::model::ListResourceTemplatesResult::default())
287            }
288        }
289    }
290
291    async fn list_prompts(
292        &self,
293        request: Option<PaginatedRequestParam>,
294        _context: RequestContext<RoleServer>,
295    ) -> Result<rmcp::model::ListPromptsResult, ErrorData> {
296        // Get a lock on the client
297        let client = self.client.clone();
298        let guard = client.lock().await;
299
300        // Check if the server has prompts capability and forward the request
301        match self.get_info().capabilities.prompts {
302            Some(_) => {
303                // Forward request to client
304                match guard.list_prompts(request).await {
305                    Ok(result) => {
306                        debug!("Proxying list_prompts response");
307                        Ok(result)
308                    }
309                    Err(err) => {
310                        tracing::error!("Error listing prompts: {:?}", err);
311                        // Return empty list instead of error
312                        Ok(rmcp::model::ListPromptsResult::default())
313                    }
314                }
315            }
316            None => {
317                // Server doesn't support prompts, return empty list
318                tracing::warn!("Server doesn't support prompts capability");
319                Ok(rmcp::model::ListPromptsResult::default())
320            }
321        }
322    }
323
324    async fn get_prompt(
325        &self,
326        request: rmcp::model::GetPromptRequestParam,
327        _context: RequestContext<RoleServer>,
328    ) -> Result<rmcp::model::GetPromptResult, ErrorData> {
329        // Get a lock on the client
330        let client = self.client.clone();
331        let guard = client.lock().await;
332
333        // Check if the server has prompts capability and forward the request
334        match self.get_info().capabilities.prompts {
335            Some(_) => {
336                // Forward request to client
337                match guard.get_prompt(request).await {
338                    Ok(result) => {
339                        debug!("Proxying get_prompt response");
340                        Ok(result)
341                    }
342                    Err(err) => {
343                        tracing::error!("Error getting prompt: {:?}", err);
344                        Err(ErrorData::internal_error(
345                            format!("Error getting prompt: {err}"),
346                            None,
347                        ))
348                    }
349                }
350            }
351            None => {
352                // Server doesn't support prompts, return error
353                tracing::warn!("Server doesn't support prompts capability");
354                Ok(rmcp::model::GetPromptResult {
355                    description: None,
356                    messages: Vec::new(),
357                })
358            }
359        }
360    }
361
362    async fn complete(
363        &self,
364        request: rmcp::model::CompleteRequestParam,
365        _context: RequestContext<RoleServer>,
366    ) -> Result<rmcp::model::CompleteResult, ErrorData> {
367        // Get a lock on the client
368        let client = self.client.clone();
369        let guard = client.lock().await;
370
371        // Forward request to client
372        match guard.complete(request).await {
373            Ok(result) => {
374                debug!("Proxying complete response");
375                Ok(result)
376            }
377            Err(err) => {
378                tracing::error!("Error completing: {:?}", err);
379                Err(ErrorData::internal_error(
380                    format!("Error completing: {err}"),
381                    None,
382                ))
383            }
384        }
385    }
386
387    async fn on_progress(
388        &self,
389        notification: rmcp::model::ProgressNotificationParam,
390        _context: NotificationContext<RoleServer>,
391    ) {
392        // Get a lock on the client
393        let client = self.client.clone();
394        let guard = client.lock().await;
395        match guard.notify_progress(notification).await {
396            Ok(_) => {
397                debug!("Proxying progress notification");
398            }
399            Err(err) => {
400                tracing::error!("Error notifying progress: {:?}", err);
401            }
402        }
403    }
404
405    async fn on_cancelled(
406        &self,
407        notification: rmcp::model::CancelledNotificationParam,
408        _context: NotificationContext<RoleServer>,
409    ) {
410        // Get a lock on the client
411        let client = self.client.clone();
412        let guard = client.lock().await;
413        match guard.notify_cancelled(notification).await {
414            Ok(_) => {
415                debug!("Proxying cancelled notification");
416            }
417            Err(err) => {
418                tracing::error!("Error notifying cancelled: {:?}", err);
419            }
420        }
421    }
422}
423
424impl ProxyHandler {
425    pub fn new(client: RunningService<RoleClient, ClientInfo>) -> Self {
426        Self::with_mcp_id(client, "unknown".to_string())
427    }
428
429    pub fn with_mcp_id(client: RunningService<RoleClient, ClientInfo>, mcp_id: String) -> Self {
430        let peer_info = client.peer_info();
431
432        // Create a ServerInfo object that forwards the server's capabilities
433        let cached_info = peer_info.map(|peer_info| ServerInfo {
434            protocol_version: peer_info.protocol_version.clone(),
435            server_info: Implementation {
436                name: peer_info.server_info.name.clone(),
437                version: peer_info.server_info.version.clone(),
438                title: None,
439                website_url: None,
440                icons: None,
441            },
442            instructions: peer_info.instructions.clone(),
443            capabilities: peer_info.capabilities.clone(),
444        });
445
446        Self {
447            client: Arc::new(Mutex::new(client)),
448            cached_info: Arc::new(RwLock::new(cached_info)),
449            mcp_id,
450        }
451    }
452
453    //检查 mcp服务是否正常,尝试调用 list_tools 方法,如果成功返回结果,则认为成功
454    pub async fn is_mcp_server_ready(&self) -> bool {
455        // 使用 try_lock 避免在定时检查时阻塞正常的业务请求
456        // 如果无法获取锁,说明正在处理其他请求,假设服务正常
457        match self.client.try_lock() {
458            Ok(guard) => (guard.list_tools(None).await).is_ok(),
459            Err(_) => {
460                debug!("is_mcp_server_ready: 无法获取锁,假设服务正常");
461                true
462            }
463        }
464    }
465
466    /// 检查子进程是否已经终止
467    pub fn is_terminated(&self) -> bool {
468        // 尝试获取锁,如果无法获取锁,说明子进程可能已经终止
469        match self.client.try_lock() {
470            Ok(_) => {
471                // 能够获取锁,我们假设子进程仍在运行
472                // 注意:我们不再尝试执行异步操作,因为这会导致运行时嵌套问题
473                false
474            }
475            Err(_) => {
476                // 无法获取锁,可能是因为子进程正在被其他线程使用
477                debug!("子进程状态检查: 无法获取锁,假设子进程仍在运行");
478                false // 这种情况下我们假设子进程还在运行
479            }
480        }
481    }
482
483    /// 异步检查子进程是否已经终止
484    pub async fn is_terminated_async(&self) -> bool {
485        // 尝试获取锁,如果无法获取锁,说明子进程可能已经终止
486        match self.client.try_lock() {
487            Ok(guard) => {
488                // 检查客户端是否已经终止
489                // 这里我们通过尝试发送一个轻量级请求来检测连接状态
490                match guard.list_tools(None).await {
491                    Ok(_) => {
492                        debug!("子进程状态检查: 正在运行");
493                        false // 成功获取信息,子进程还在运行
494                    }
495                    Err(e) => {
496                        info!("子进程状态检查: 已终止,原因: {e}");
497                        true // 无法获取信息,子进程可能已终止
498                    }
499                }
500            }
501            Err(_) => {
502                // 无法获取锁,可能是因为子进程正在被其他线程使用
503                debug!("子进程状态检查: 无法获取锁,假设子进程仍在运行");
504                false // 这种情况下我们假设子进程还在运行
505            }
506        }
507    }
508}