1use log::{debug, info};
2use rmcp::{
6 ErrorData, RoleClient, RoleServer, ServerHandler,
7 model::{
8 CallToolRequestParam, CallToolResult, ClientInfo, Content, Implementation, ListToolsResult,
9 PaginatedRequestParam, ServerInfo,
10 },
11 service::{NotificationContext, RequestContext, RunningService},
12};
13use std::collections::HashSet;
14use std::sync::{Arc, RwLock};
15use tokio::sync::Mutex;
16
/// Allow/deny configuration that decides which upstream tools the proxy exposes.
///
/// Precedence rules, as implemented by [`ToolFilter::is_allowed`]:
/// - `allow_tools` set: only the listed tools are allowed, and `deny_tools` is
///   not consulted at all;
/// - only `deny_tools` set: every tool except the listed ones is allowed;
/// - neither set: the filter is disabled and every tool is allowed.
#[derive(Clone, Debug, Default)]
pub struct ToolFilter {
    /// Whitelist of tool names; when `Some`, it overrides `deny_tools`.
    pub allow_tools: Option<HashSet<String>>,
    /// Blacklist of tool names; only used when `allow_tools` is `None`.
    pub deny_tools: Option<HashSet<String>>,
}

impl ToolFilter {
    /// Builds a whitelist filter exposing only the tools named in `tools`.
    pub fn allow(tools: Vec<String>) -> Self {
        let whitelist: HashSet<String> = tools.into_iter().collect();
        Self {
            allow_tools: Some(whitelist),
            ..Self::default()
        }
    }

    /// Builds a blacklist filter hiding the tools named in `tools`.
    pub fn deny(tools: Vec<String>) -> Self {
        let blacklist: HashSet<String> = tools.into_iter().collect();
        Self {
            deny_tools: Some(blacklist),
            ..Self::default()
        }
    }

    /// Returns `true` if `tool_name` may be listed/called through the proxy.
    pub fn is_allowed(&self, tool_name: &str) -> bool {
        match (&self.allow_tools, &self.deny_tools) {
            // The whitelist wins whenever it is present.
            (Some(whitelist), _) => whitelist.contains(tool_name),
            (None, Some(blacklist)) => !blacklist.contains(tool_name),
            (None, None) => true,
        }
    }

    /// Returns `true` when at least one of the two lists is configured.
    pub fn is_enabled(&self) -> bool {
        !(self.allow_tools.is_none() && self.deny_tools.is_none())
    }
}
62
63#[derive(Clone, Debug)]
65pub struct ProxyHandler {
66 client: Arc<Mutex<RunningService<RoleClient, ClientInfo>>>,
67 cached_info: Arc<RwLock<Option<ServerInfo>>>,
69 mcp_id: String,
71 tool_filter: ToolFilter,
73}
74
75impl ServerHandler for ProxyHandler {
76 fn get_info(&self) -> ServerInfo {
77 if let Ok(cached_read) = self.cached_info.read() {
79 if let Some(ref cached) = *cached_read {
80 return cached.clone();
81 }
82 }
83
84 let client = self.client.clone();
88 if let Ok(guard) = client.try_lock() {
89 if let Some(peer_info) = guard.peer_info() {
90 let server_info = ServerInfo {
91 protocol_version: peer_info.protocol_version.clone(),
92 server_info: Implementation {
93 name: peer_info.server_info.name.clone(),
94 version: peer_info.server_info.version.clone(),
95 title: None,
96 website_url: None,
97 icons: None,
98 },
99 instructions: peer_info.instructions.clone(),
100 capabilities: peer_info.capabilities.clone(),
101 };
102
103 if let Ok(mut cached_write) = self.cached_info.write() {
105 *cached_write = Some(server_info.clone());
106 debug!("Successfully cached server info from peer_info");
107 }
108
109 return server_info;
110 }
111 }
112
113 ServerInfo {
115 protocol_version: Default::default(),
116 server_info: Implementation {
117 name: "MCP Proxy - Service Unavailable".to_string(),
118 version: "0.1.0".to_string(),
119 title: None,
120 website_url: None,
121 icons: None,
122 },
123 instructions: Some("ERROR: MCP service is not available or still initializing. Please try again later.".to_string()),
124 capabilities: Default::default(), }
126 }
127
128 #[tracing::instrument(skip(self, request, _context), fields(
129 mcp_id = %self.mcp_id,
130 request = ?request,
131 ))]
132 async fn list_tools(
133 &self,
134 request: Option<PaginatedRequestParam>,
135 _context: RequestContext<RoleServer>,
136 ) -> Result<ListToolsResult, ErrorData> {
137 let client = self.client.clone();
138 let guard = client.lock().await;
139
140 match self.get_info().capabilities.tools {
142 Some(_) => {
143 match guard.list_tools(request).await {
144 Ok(result) => {
146 let filtered_tools: Vec<_> = if self.tool_filter.is_enabled() {
148 result
149 .tools
150 .into_iter()
151 .filter(|tool| self.tool_filter.is_allowed(&tool.name))
152 .collect()
153 } else {
154 result.tools
155 };
156
157 info!(
159 "[list_tools] 工具列表结果 - MCP ID: {}, 工具数量: {}{}",
160 self.mcp_id,
161 filtered_tools.len(),
162 if self.tool_filter.is_enabled() {
163 " (已过滤)"
164 } else {
165 ""
166 }
167 );
168
169 debug!(
170 "Proxying list_tools response with {} tools",
171 filtered_tools.len()
172 );
173 Ok(ListToolsResult {
174 tools: filtered_tools,
175 next_cursor: result.next_cursor,
176 })
177 }
178 Err(err) => {
179 tracing::error!("Error listing tools: {:?}", err);
180 Ok(ListToolsResult::default())
182 }
183 }
184 }
185 None => {
186 tracing::error!("Server doesn't support tools capability");
188 Ok(ListToolsResult::default())
189 }
190 }
191 }
192
193 #[tracing::instrument(skip(self, request, _context), fields(
194 mcp_id = %self.mcp_id,
195 tool_name = %request.name,
196 tool_arguments = ?request.arguments,
197 ))]
198 async fn call_tool(
199 &self,
200 request: CallToolRequestParam,
201 _context: RequestContext<RoleServer>,
202 ) -> Result<CallToolResult, ErrorData> {
203 if !self.tool_filter.is_allowed(&request.name) {
205 info!(
206 "[call_tool] 工具被过滤 - MCP ID: {}, 工具: {}",
207 self.mcp_id, request.name
208 );
209 return Ok(CallToolResult::error(vec![Content::text(format!(
210 "Tool '{}' is not allowed by filter configuration",
211 request.name
212 ))]));
213 }
214
215 let client = self.client.clone();
216 let guard = client.lock().await;
217
218 match self.get_info().capabilities.tools {
220 Some(_) => {
221 match guard.call_tool(request.clone()).await {
222 Ok(result) => {
223 info!(
225 "[call_tool] 工具调用结果 - MCP ID: {}, 工具: {}",
226 self.mcp_id, request.name
227 );
228
229 debug!("Tool call succeeded");
230 Ok(result)
231 }
232 Err(err) => {
233 tracing::error!("Error calling tool: {:?}", err);
234 Ok(CallToolResult::error(vec![Content::text(format!(
236 "Error: {err}"
237 ))]))
238 }
239 }
240 }
241 None => {
242 tracing::error!("Server doesn't support tools capability");
243 Ok(CallToolResult::error(vec![Content::text(
244 "Server doesn't support tools capability",
245 )]))
246 }
247 }
248 }
249
250 async fn list_resources(
251 &self,
252 request: Option<PaginatedRequestParam>,
253 _context: RequestContext<RoleServer>,
254 ) -> Result<rmcp::model::ListResourcesResult, ErrorData> {
255 let client = self.client.clone();
257 let guard = client.lock().await;
258
259 match self.get_info().capabilities.resources {
261 Some(_) => {
262 match guard.list_resources(request).await {
264 Ok(result) => {
265 info!(
267 "[list_resources] 资源列表结果 - MCP ID: {}, 资源数量: {}",
268 self.mcp_id,
269 result.resources.len()
270 );
271
272 debug!("Proxying list_resources response");
273 Ok(result)
274 }
275 Err(err) => {
276 tracing::error!("Error listing resources: {:?}", err);
277 Ok(rmcp::model::ListResourcesResult::default())
279 }
280 }
281 }
282 None => {
283 tracing::error!("Server doesn't support resources capability");
285 Ok(rmcp::model::ListResourcesResult::default())
286 }
287 }
288 }
289
290 async fn read_resource(
291 &self,
292 request: rmcp::model::ReadResourceRequestParam,
293 _context: RequestContext<RoleServer>,
294 ) -> Result<rmcp::model::ReadResourceResult, ErrorData> {
295 let client = self.client.clone();
297 let guard = client.lock().await;
298
299 match self.get_info().capabilities.resources {
301 Some(_) => {
302 match guard
304 .read_resource(rmcp::model::ReadResourceRequestParam {
305 uri: request.uri.clone(),
306 })
307 .await
308 {
309 Ok(result) => {
310 info!(
312 "[read_resource] 资源读取结果 - MCP ID: {}, URI: {}",
313 self.mcp_id, request.uri
314 );
315
316 debug!("Proxying read_resource response for {}", request.uri);
317 Ok(result)
318 }
319 Err(err) => {
320 tracing::error!("Error reading resource: {:?}", err);
321 Err(ErrorData::internal_error(
322 format!("Error reading resource: {err}"),
323 None,
324 ))
325 }
326 }
327 }
328 None => {
329 tracing::error!("Server doesn't support resources capability");
331 Ok(rmcp::model::ReadResourceResult {
332 contents: Vec::new(),
333 })
334 }
335 }
336 }
337
338 async fn list_resource_templates(
339 &self,
340 request: Option<PaginatedRequestParam>,
341 _context: RequestContext<RoleServer>,
342 ) -> Result<rmcp::model::ListResourceTemplatesResult, ErrorData> {
343 let client = self.client.clone();
345 let guard = client.lock().await;
346
347 match self.get_info().capabilities.resources {
349 Some(_) => {
350 match guard.list_resource_templates(request).await {
352 Ok(result) => {
353 debug!("Proxying list_resource_templates response");
354 Ok(result)
355 }
356 Err(err) => {
357 tracing::error!("Error listing resource templates: {:?}", err);
358 Ok(rmcp::model::ListResourceTemplatesResult::default())
360 }
361 }
362 }
363 None => {
364 tracing::error!("Server doesn't support resources capability");
366 Ok(rmcp::model::ListResourceTemplatesResult::default())
367 }
368 }
369 }
370
371 async fn list_prompts(
372 &self,
373 request: Option<PaginatedRequestParam>,
374 _context: RequestContext<RoleServer>,
375 ) -> Result<rmcp::model::ListPromptsResult, ErrorData> {
376 let client = self.client.clone();
378 let guard = client.lock().await;
379
380 match self.get_info().capabilities.prompts {
382 Some(_) => {
383 match guard.list_prompts(request).await {
385 Ok(result) => {
386 debug!("Proxying list_prompts response");
387 Ok(result)
388 }
389 Err(err) => {
390 tracing::error!("Error listing prompts: {:?}", err);
391 Ok(rmcp::model::ListPromptsResult::default())
393 }
394 }
395 }
396 None => {
397 tracing::warn!("Server doesn't support prompts capability");
399 Ok(rmcp::model::ListPromptsResult::default())
400 }
401 }
402 }
403
404 async fn get_prompt(
405 &self,
406 request: rmcp::model::GetPromptRequestParam,
407 _context: RequestContext<RoleServer>,
408 ) -> Result<rmcp::model::GetPromptResult, ErrorData> {
409 let client = self.client.clone();
411 let guard = client.lock().await;
412
413 match self.get_info().capabilities.prompts {
415 Some(_) => {
416 match guard.get_prompt(request).await {
418 Ok(result) => {
419 debug!("Proxying get_prompt response");
420 Ok(result)
421 }
422 Err(err) => {
423 tracing::error!("Error getting prompt: {:?}", err);
424 Err(ErrorData::internal_error(
425 format!("Error getting prompt: {err}"),
426 None,
427 ))
428 }
429 }
430 }
431 None => {
432 tracing::warn!("Server doesn't support prompts capability");
434 Ok(rmcp::model::GetPromptResult {
435 description: None,
436 messages: Vec::new(),
437 })
438 }
439 }
440 }
441
442 async fn complete(
443 &self,
444 request: rmcp::model::CompleteRequestParam,
445 _context: RequestContext<RoleServer>,
446 ) -> Result<rmcp::model::CompleteResult, ErrorData> {
447 let client = self.client.clone();
449 let guard = client.lock().await;
450
451 match guard.complete(request).await {
453 Ok(result) => {
454 debug!("Proxying complete response");
455 Ok(result)
456 }
457 Err(err) => {
458 tracing::error!("Error completing: {:?}", err);
459 Err(ErrorData::internal_error(
460 format!("Error completing: {err}"),
461 None,
462 ))
463 }
464 }
465 }
466
467 async fn on_progress(
468 &self,
469 notification: rmcp::model::ProgressNotificationParam,
470 _context: NotificationContext<RoleServer>,
471 ) {
472 let client = self.client.clone();
474 let guard = client.lock().await;
475 match guard.notify_progress(notification).await {
476 Ok(_) => {
477 debug!("Proxying progress notification");
478 }
479 Err(err) => {
480 tracing::error!("Error notifying progress: {:?}", err);
481 }
482 }
483 }
484
485 async fn on_cancelled(
486 &self,
487 notification: rmcp::model::CancelledNotificationParam,
488 _context: NotificationContext<RoleServer>,
489 ) {
490 let client = self.client.clone();
492 let guard = client.lock().await;
493 match guard.notify_cancelled(notification).await {
494 Ok(_) => {
495 debug!("Proxying cancelled notification");
496 }
497 Err(err) => {
498 tracing::error!("Error notifying cancelled: {:?}", err);
499 }
500 }
501 }
502}
503
504impl ProxyHandler {
505 pub fn new(client: RunningService<RoleClient, ClientInfo>) -> Self {
506 Self::with_mcp_id(client, "unknown".to_string())
507 }
508
509 pub fn with_mcp_id(client: RunningService<RoleClient, ClientInfo>, mcp_id: String) -> Self {
510 Self::with_tool_filter(client, mcp_id, ToolFilter::default())
511 }
512
513 pub fn with_tool_filter(
515 client: RunningService<RoleClient, ClientInfo>,
516 mcp_id: String,
517 tool_filter: ToolFilter,
518 ) -> Self {
519 let peer_info = client.peer_info();
520
521 let cached_info = peer_info.map(|peer_info| ServerInfo {
523 protocol_version: peer_info.protocol_version.clone(),
524 server_info: Implementation {
525 name: peer_info.server_info.name.clone(),
526 version: peer_info.server_info.version.clone(),
527 title: None,
528 website_url: None,
529 icons: None,
530 },
531 instructions: peer_info.instructions.clone(),
532 capabilities: peer_info.capabilities.clone(),
533 });
534
535 if tool_filter.is_enabled() {
537 if let Some(ref allow_list) = tool_filter.allow_tools {
538 info!(
539 "[ProxyHandler] 工具白名单已启用 - MCP ID: {}, 允许的工具: {:?}",
540 mcp_id, allow_list
541 );
542 }
543 if let Some(ref deny_list) = tool_filter.deny_tools {
544 info!(
545 "[ProxyHandler] 工具黑名单已启用 - MCP ID: {}, 排除的工具: {:?}",
546 mcp_id, deny_list
547 );
548 }
549 }
550
551 Self {
552 client: Arc::new(Mutex::new(client)),
553 cached_info: Arc::new(RwLock::new(cached_info)),
554 mcp_id,
555 tool_filter,
556 }
557 }
558
559 pub async fn is_mcp_server_ready(&self) -> bool {
561 match self.client.try_lock() {
564 Ok(guard) => (guard.list_tools(None).await).is_ok(),
565 Err(_) => {
566 debug!("is_mcp_server_ready: 无法获取锁,假设服务正常");
567 true
568 }
569 }
570 }
571
572 pub fn is_terminated(&self) -> bool {
574 match self.client.try_lock() {
576 Ok(_) => {
577 false
580 }
581 Err(_) => {
582 debug!("子进程状态检查: 无法获取锁,假设子进程仍在运行");
584 false }
586 }
587 }
588
589 pub async fn is_terminated_async(&self) -> bool {
591 match self.client.try_lock() {
593 Ok(guard) => {
594 match guard.list_tools(None).await {
597 Ok(_) => {
598 debug!("子进程状态检查: 正在运行");
599 false }
601 Err(e) => {
602 info!("子进程状态检查: 已终止,原因: {e}");
603 true }
605 }
606 }
607 Err(_) => {
608 debug!("子进程状态检查: 无法获取锁,假设子进程仍在运行");
610 false }
612 }
613 }
614}