mcp_stdio_proxy/proxy/
proxy_handler.rs

1use log::{debug, info};
2/**
3 * Create a local SSE server that proxies requests to a stdio MCP server.
4 */
5use rmcp::{
6    ErrorData, RoleClient, RoleServer, ServerHandler,
7    model::{
8        CallToolRequestParam, CallToolResult, ClientInfo, Content, Implementation, ListToolsResult,
9        PaginatedRequestParam, ServerInfo,
10    },
11    service::{NotificationContext, RequestContext, RunningService},
12};
13use std::collections::HashSet;
14use std::sync::{Arc, RwLock};
15use tokio::sync::Mutex;
16
/// Tool filtering configuration.
///
/// The whitelist and blacklist may be used independently or together. When both are set,
/// a tool must be in the whitelist AND not in the blacklist to be allowed.
#[derive(Clone, Debug, Default)]
pub struct ToolFilter {
    /// Whitelist: when `Some`, only these tools are allowed (an empty set allows nothing).
    pub allow_tools: Option<HashSet<String>>,
    /// Blacklist: when `Some`, these tools are rejected.
    pub deny_tools: Option<HashSet<String>>,
}

impl ToolFilter {
    /// Creates a whitelist-only filter: only the given tools are allowed.
    pub fn allow(tools: Vec<String>) -> Self {
        Self {
            allow_tools: Some(tools.into_iter().collect()),
            deny_tools: None,
        }
    }

    /// Creates a blacklist-only filter: every tool except the given ones is allowed.
    pub fn deny(tools: Vec<String>) -> Self {
        Self {
            allow_tools: None,
            deny_tools: Some(tools.into_iter().collect()),
        }
    }

    /// Returns `true` if `tool_name` passes the filter.
    ///
    /// A tool is allowed when it is in the whitelist (or no whitelist is configured) and it
    /// is not in the blacklist (or no blacklist is configured). With no lists configured,
    /// every tool is allowed.
    pub fn is_allowed(&self, tool_name: &str) -> bool {
        // Whitelist: absent means "everything passes this stage".
        let whitelisted = match &self.allow_tools {
            Some(allow_list) => allow_list.contains(tool_name),
            None => true,
        };
        // Blacklist is applied even when a whitelist is present, so a tool listed in both
        // is rejected rather than silently allowed.
        let blacklisted = self
            .deny_tools
            .as_ref()
            .is_some_and(|deny_list| deny_list.contains(tool_name));
        whitelisted && !blacklisted
    }

    /// Returns `true` if either a whitelist or a blacklist is configured.
    pub fn is_enabled(&self) -> bool {
        self.allow_tools.is_some() || self.deny_tools.is_some()
    }
}
62
63/// A proxy handler that forwards requests to a client based on the server's capabilities
64#[derive(Clone, Debug)]
65pub struct ProxyHandler {
66    client: Arc<Mutex<RunningService<RoleClient, ClientInfo>>>,
67    // Store the server's capabilities to avoid locking the client on every get_info call
68    cached_info: Arc<RwLock<Option<ServerInfo>>>,
69    // MCP ID 用于日志记录
70    mcp_id: String,
71    // 工具过滤配置
72    tool_filter: ToolFilter,
73}
74
75impl ServerHandler for ProxyHandler {
76    fn get_info(&self) -> ServerInfo {
77        // 首先检查缓存的信息
78        if let Ok(cached_read) = self.cached_info.read() {
79            if let Some(ref cached) = *cached_read {
80                return cached.clone();
81            }
82        }
83
84        // 如果缓存为空,尝试动态获取
85        // 使用 try_lock 而不是 lock,避免阻塞
86        // peer_info() 是同步方法,可以安全调用
87        let client = self.client.clone();
88        if let Ok(guard) = client.try_lock() {
89            if let Some(peer_info) = guard.peer_info() {
90                let server_info = ServerInfo {
91                    protocol_version: peer_info.protocol_version.clone(),
92                    server_info: Implementation {
93                        name: peer_info.server_info.name.clone(),
94                        version: peer_info.server_info.version.clone(),
95                        title: None,
96                        website_url: None,
97                        icons: None,
98                    },
99                    instructions: peer_info.instructions.clone(),
100                    capabilities: peer_info.capabilities.clone(),
101                };
102
103                // 将动态获取的信息缓存起来
104                if let Ok(mut cached_write) = self.cached_info.write() {
105                    *cached_write = Some(server_info.clone());
106                    debug!("Successfully cached server info from peer_info");
107                }
108
109                return server_info;
110            }
111        }
112
113        // 如果都获取不到,返回错误状态信息
114        ServerInfo {
115            protocol_version: Default::default(),
116            server_info: Implementation {
117                name: "MCP Proxy - Service Unavailable".to_string(),
118                version: "0.1.0".to_string(),
119                title: None,
120                website_url: None,
121                icons: None,
122            },
123            instructions: Some("ERROR: MCP service is not available or still initializing. Please try again later.".to_string()),
124            capabilities: Default::default(), // 空的能力列表,表示服务不可用
125        }
126    }
127
128    #[tracing::instrument(skip(self, request, _context), fields(
129        mcp_id = %self.mcp_id,
130        request = ?request,
131    ))]
132    async fn list_tools(
133        &self,
134        request: Option<PaginatedRequestParam>,
135        _context: RequestContext<RoleServer>,
136    ) -> Result<ListToolsResult, ErrorData> {
137        let client = self.client.clone();
138        let guard = client.lock().await;
139
140        // Check if the server has tools capability and forward the request
141        match self.get_info().capabilities.tools {
142            Some(_) => {
143                match guard.list_tools(request).await {
144                    // Forward request to client
145                    Ok(result) => {
146                        // 根据过滤配置过滤工具列表
147                        let filtered_tools: Vec<_> = if self.tool_filter.is_enabled() {
148                            result
149                                .tools
150                                .into_iter()
151                                .filter(|tool| self.tool_filter.is_allowed(&tool.name))
152                                .collect()
153                        } else {
154                            result.tools
155                        };
156
157                        // 记录工具列表结果,这些结果会通过 SSE 推送给客户端
158                        info!(
159                            "[list_tools] 工具列表结果 - MCP ID: {}, 工具数量: {}{}",
160                            self.mcp_id,
161                            filtered_tools.len(),
162                            if self.tool_filter.is_enabled() {
163                                " (已过滤)"
164                            } else {
165                                ""
166                            }
167                        );
168
169                        debug!(
170                            "Proxying list_tools response with {} tools",
171                            filtered_tools.len()
172                        );
173                        Ok(ListToolsResult {
174                            tools: filtered_tools,
175                            next_cursor: result.next_cursor,
176                        })
177                    }
178                    Err(err) => {
179                        tracing::error!("Error listing tools: {:?}", err);
180                        // Return empty list instead of error
181                        Ok(ListToolsResult::default())
182                    }
183                }
184            }
185            None => {
186                // Server doesn't support tools, return empty list
187                tracing::error!("Server doesn't support tools capability");
188                Ok(ListToolsResult::default())
189            }
190        }
191    }
192
193    #[tracing::instrument(skip(self, request, _context), fields(
194        mcp_id = %self.mcp_id,
195        tool_name = %request.name,
196        tool_arguments = ?request.arguments,
197    ))]
198    async fn call_tool(
199        &self,
200        request: CallToolRequestParam,
201        _context: RequestContext<RoleServer>,
202    ) -> Result<CallToolResult, ErrorData> {
203        // 首先检查工具是否被过滤
204        if !self.tool_filter.is_allowed(&request.name) {
205            info!(
206                "[call_tool] 工具被过滤 - MCP ID: {}, 工具: {}",
207                self.mcp_id, request.name
208            );
209            return Ok(CallToolResult::error(vec![Content::text(format!(
210                "Tool '{}' is not allowed by filter configuration",
211                request.name
212            ))]));
213        }
214
215        let client = self.client.clone();
216        let guard = client.lock().await;
217
218        // Check if the server has tools capability and forward the request
219        match self.get_info().capabilities.tools {
220            Some(_) => {
221                match guard.call_tool(request.clone()).await {
222                    Ok(result) => {
223                        // 记录工具调用结果,这些结果会通过 SSE 推送给客户端
224                        info!(
225                            "[call_tool] 工具调用结果 - MCP ID: {}, 工具: {}",
226                            self.mcp_id, request.name
227                        );
228
229                        debug!("Tool call succeeded");
230                        Ok(result)
231                    }
232                    Err(err) => {
233                        tracing::error!("Error calling tool: {:?}", err);
234                        // Return an error result instead of propagating the error
235                        Ok(CallToolResult::error(vec![Content::text(format!(
236                            "Error: {err}"
237                        ))]))
238                    }
239                }
240            }
241            None => {
242                tracing::error!("Server doesn't support tools capability");
243                Ok(CallToolResult::error(vec![Content::text(
244                    "Server doesn't support tools capability",
245                )]))
246            }
247        }
248    }
249
250    async fn list_resources(
251        &self,
252        request: Option<PaginatedRequestParam>,
253        _context: RequestContext<RoleServer>,
254    ) -> Result<rmcp::model::ListResourcesResult, ErrorData> {
255        // Get a lock on the client
256        let client = self.client.clone();
257        let guard = client.lock().await;
258
259        // Check if the server has resources capability and forward the request
260        match self.get_info().capabilities.resources {
261            Some(_) => {
262                // Forward request to client
263                match guard.list_resources(request).await {
264                    Ok(result) => {
265                        // 记录资源列表结果,这些结果会通过 SSE 推送给客户端
266                        info!(
267                            "[list_resources] 资源列表结果 - MCP ID: {}, 资源数量: {}",
268                            self.mcp_id,
269                            result.resources.len()
270                        );
271
272                        debug!("Proxying list_resources response");
273                        Ok(result)
274                    }
275                    Err(err) => {
276                        tracing::error!("Error listing resources: {:?}", err);
277                        // Return empty list instead of error
278                        Ok(rmcp::model::ListResourcesResult::default())
279                    }
280                }
281            }
282            None => {
283                // Server doesn't support resources, return empty list
284                tracing::error!("Server doesn't support resources capability");
285                Ok(rmcp::model::ListResourcesResult::default())
286            }
287        }
288    }
289
290    async fn read_resource(
291        &self,
292        request: rmcp::model::ReadResourceRequestParam,
293        _context: RequestContext<RoleServer>,
294    ) -> Result<rmcp::model::ReadResourceResult, ErrorData> {
295        // Get a lock on the client
296        let client = self.client.clone();
297        let guard = client.lock().await;
298
299        // Check if the server has resources capability and forward the request
300        match self.get_info().capabilities.resources {
301            Some(_) => {
302                // Forward request to client
303                match guard
304                    .read_resource(rmcp::model::ReadResourceRequestParam {
305                        uri: request.uri.clone(),
306                    })
307                    .await
308                {
309                    Ok(result) => {
310                        // 记录资源读取结果,这些结果会通过 SSE 推送给客户端
311                        info!(
312                            "[read_resource] 资源读取结果 - MCP ID: {}, URI: {}",
313                            self.mcp_id, request.uri
314                        );
315
316                        debug!("Proxying read_resource response for {}", request.uri);
317                        Ok(result)
318                    }
319                    Err(err) => {
320                        tracing::error!("Error reading resource: {:?}", err);
321                        Err(ErrorData::internal_error(
322                            format!("Error reading resource: {err}"),
323                            None,
324                        ))
325                    }
326                }
327            }
328            None => {
329                // Server doesn't support resources, return error
330                tracing::error!("Server doesn't support resources capability");
331                Ok(rmcp::model::ReadResourceResult {
332                    contents: Vec::new(),
333                })
334            }
335        }
336    }
337
338    async fn list_resource_templates(
339        &self,
340        request: Option<PaginatedRequestParam>,
341        _context: RequestContext<RoleServer>,
342    ) -> Result<rmcp::model::ListResourceTemplatesResult, ErrorData> {
343        // Get a lock on the client
344        let client = self.client.clone();
345        let guard = client.lock().await;
346
347        // Check if the server has resources capability and forward the request
348        match self.get_info().capabilities.resources {
349            Some(_) => {
350                // Forward request to client
351                match guard.list_resource_templates(request).await {
352                    Ok(result) => {
353                        debug!("Proxying list_resource_templates response");
354                        Ok(result)
355                    }
356                    Err(err) => {
357                        tracing::error!("Error listing resource templates: {:?}", err);
358                        // Return empty list instead of error
359                        Ok(rmcp::model::ListResourceTemplatesResult::default())
360                    }
361                }
362            }
363            None => {
364                // Server doesn't support resources, return empty list
365                tracing::error!("Server doesn't support resources capability");
366                Ok(rmcp::model::ListResourceTemplatesResult::default())
367            }
368        }
369    }
370
371    async fn list_prompts(
372        &self,
373        request: Option<PaginatedRequestParam>,
374        _context: RequestContext<RoleServer>,
375    ) -> Result<rmcp::model::ListPromptsResult, ErrorData> {
376        // Get a lock on the client
377        let client = self.client.clone();
378        let guard = client.lock().await;
379
380        // Check if the server has prompts capability and forward the request
381        match self.get_info().capabilities.prompts {
382            Some(_) => {
383                // Forward request to client
384                match guard.list_prompts(request).await {
385                    Ok(result) => {
386                        debug!("Proxying list_prompts response");
387                        Ok(result)
388                    }
389                    Err(err) => {
390                        tracing::error!("Error listing prompts: {:?}", err);
391                        // Return empty list instead of error
392                        Ok(rmcp::model::ListPromptsResult::default())
393                    }
394                }
395            }
396            None => {
397                // Server doesn't support prompts, return empty list
398                tracing::warn!("Server doesn't support prompts capability");
399                Ok(rmcp::model::ListPromptsResult::default())
400            }
401        }
402    }
403
404    async fn get_prompt(
405        &self,
406        request: rmcp::model::GetPromptRequestParam,
407        _context: RequestContext<RoleServer>,
408    ) -> Result<rmcp::model::GetPromptResult, ErrorData> {
409        // Get a lock on the client
410        let client = self.client.clone();
411        let guard = client.lock().await;
412
413        // Check if the server has prompts capability and forward the request
414        match self.get_info().capabilities.prompts {
415            Some(_) => {
416                // Forward request to client
417                match guard.get_prompt(request).await {
418                    Ok(result) => {
419                        debug!("Proxying get_prompt response");
420                        Ok(result)
421                    }
422                    Err(err) => {
423                        tracing::error!("Error getting prompt: {:?}", err);
424                        Err(ErrorData::internal_error(
425                            format!("Error getting prompt: {err}"),
426                            None,
427                        ))
428                    }
429                }
430            }
431            None => {
432                // Server doesn't support prompts, return error
433                tracing::warn!("Server doesn't support prompts capability");
434                Ok(rmcp::model::GetPromptResult {
435                    description: None,
436                    messages: Vec::new(),
437                })
438            }
439        }
440    }
441
442    async fn complete(
443        &self,
444        request: rmcp::model::CompleteRequestParam,
445        _context: RequestContext<RoleServer>,
446    ) -> Result<rmcp::model::CompleteResult, ErrorData> {
447        // Get a lock on the client
448        let client = self.client.clone();
449        let guard = client.lock().await;
450
451        // Forward request to client
452        match guard.complete(request).await {
453            Ok(result) => {
454                debug!("Proxying complete response");
455                Ok(result)
456            }
457            Err(err) => {
458                tracing::error!("Error completing: {:?}", err);
459                Err(ErrorData::internal_error(
460                    format!("Error completing: {err}"),
461                    None,
462                ))
463            }
464        }
465    }
466
467    async fn on_progress(
468        &self,
469        notification: rmcp::model::ProgressNotificationParam,
470        _context: NotificationContext<RoleServer>,
471    ) {
472        // Get a lock on the client
473        let client = self.client.clone();
474        let guard = client.lock().await;
475        match guard.notify_progress(notification).await {
476            Ok(_) => {
477                debug!("Proxying progress notification");
478            }
479            Err(err) => {
480                tracing::error!("Error notifying progress: {:?}", err);
481            }
482        }
483    }
484
485    async fn on_cancelled(
486        &self,
487        notification: rmcp::model::CancelledNotificationParam,
488        _context: NotificationContext<RoleServer>,
489    ) {
490        // Get a lock on the client
491        let client = self.client.clone();
492        let guard = client.lock().await;
493        match guard.notify_cancelled(notification).await {
494            Ok(_) => {
495                debug!("Proxying cancelled notification");
496            }
497            Err(err) => {
498                tracing::error!("Error notifying cancelled: {:?}", err);
499            }
500        }
501    }
502}
503
504impl ProxyHandler {
505    pub fn new(client: RunningService<RoleClient, ClientInfo>) -> Self {
506        Self::with_mcp_id(client, "unknown".to_string())
507    }
508
509    pub fn with_mcp_id(client: RunningService<RoleClient, ClientInfo>, mcp_id: String) -> Self {
510        Self::with_tool_filter(client, mcp_id, ToolFilter::default())
511    }
512
513    /// 创建带工具过滤器的 ProxyHandler
514    pub fn with_tool_filter(
515        client: RunningService<RoleClient, ClientInfo>,
516        mcp_id: String,
517        tool_filter: ToolFilter,
518    ) -> Self {
519        let peer_info = client.peer_info();
520
521        // Create a ServerInfo object that forwards the server's capabilities
522        let cached_info = peer_info.map(|peer_info| ServerInfo {
523            protocol_version: peer_info.protocol_version.clone(),
524            server_info: Implementation {
525                name: peer_info.server_info.name.clone(),
526                version: peer_info.server_info.version.clone(),
527                title: None,
528                website_url: None,
529                icons: None,
530            },
531            instructions: peer_info.instructions.clone(),
532            capabilities: peer_info.capabilities.clone(),
533        });
534
535        // 记录过滤器配置
536        if tool_filter.is_enabled() {
537            if let Some(ref allow_list) = tool_filter.allow_tools {
538                info!(
539                    "[ProxyHandler] 工具白名单已启用 - MCP ID: {}, 允许的工具: {:?}",
540                    mcp_id, allow_list
541                );
542            }
543            if let Some(ref deny_list) = tool_filter.deny_tools {
544                info!(
545                    "[ProxyHandler] 工具黑名单已启用 - MCP ID: {}, 排除的工具: {:?}",
546                    mcp_id, deny_list
547                );
548            }
549        }
550
551        Self {
552            client: Arc::new(Mutex::new(client)),
553            cached_info: Arc::new(RwLock::new(cached_info)),
554            mcp_id,
555            tool_filter,
556        }
557    }
558
559    //检查 mcp服务是否正常,尝试调用 list_tools 方法,如果成功返回结果,则认为成功
560    pub async fn is_mcp_server_ready(&self) -> bool {
561        // 使用 try_lock 避免在定时检查时阻塞正常的业务请求
562        // 如果无法获取锁,说明正在处理其他请求,假设服务正常
563        match self.client.try_lock() {
564            Ok(guard) => (guard.list_tools(None).await).is_ok(),
565            Err(_) => {
566                debug!("is_mcp_server_ready: 无法获取锁,假设服务正常");
567                true
568            }
569        }
570    }
571
572    /// 检查子进程是否已经终止
573    pub fn is_terminated(&self) -> bool {
574        // 尝试获取锁,如果无法获取锁,说明子进程可能已经终止
575        match self.client.try_lock() {
576            Ok(_) => {
577                // 能够获取锁,我们假设子进程仍在运行
578                // 注意:我们不再尝试执行异步操作,因为这会导致运行时嵌套问题
579                false
580            }
581            Err(_) => {
582                // 无法获取锁,可能是因为子进程正在被其他线程使用
583                debug!("子进程状态检查: 无法获取锁,假设子进程仍在运行");
584                false // 这种情况下我们假设子进程还在运行
585            }
586        }
587    }
588
589    /// 异步检查子进程是否已经终止
590    pub async fn is_terminated_async(&self) -> bool {
591        // 尝试获取锁,如果无法获取锁,说明子进程可能已经终止
592        match self.client.try_lock() {
593            Ok(guard) => {
594                // 检查客户端是否已经终止
595                // 这里我们通过尝试发送一个轻量级请求来检测连接状态
596                match guard.list_tools(None).await {
597                    Ok(_) => {
598                        debug!("子进程状态检查: 正在运行");
599                        false // 成功获取信息,子进程还在运行
600                    }
601                    Err(e) => {
602                        info!("子进程状态检查: 已终止,原因: {e}");
603                        true // 无法获取信息,子进程可能已终止
604                    }
605                }
606            }
607            Err(_) => {
608                // 无法获取锁,可能是因为子进程正在被其他线程使用
609                debug!("子进程状态检查: 无法获取锁,假设子进程仍在运行");
610                false // 这种情况下我们假设子进程还在运行
611            }
612        }
613    }
614}