1use log::{debug, info};
2use tokio::time::{timeout, Duration};
3use rmcp::{
7 ErrorData, RoleClient, RoleServer, ServerHandler,
8 model::{
9 CallToolRequestParam, CallToolResult, ClientInfo, Content, Implementation, ListToolsResult,
10 PaginatedRequestParam, ServerInfo,
11 },
12 service::{NotificationContext, RequestContext, RunningService},
13};
14use std::collections::HashSet;
15use std::sync::{Arc, RwLock};
16use tokio::sync::Mutex;
17
/// Maximum time, in seconds, to wait for a single proxied `call_tool` request (5 minutes).
const DEFAULT_TOOL_CALL_TIMEOUT_SECS: u64 = 5 * 60;

/// Maximum time, in seconds, to wait for proxied upstream requests other than tool calls.
const DEFAULT_REQUEST_TIMEOUT_SECS: u64 = 60;
23
/// Allow/deny rules that decide which upstream tools the proxy exposes.
///
/// Precedence, as implemented by [`ToolFilter::is_allowed`]:
/// - `allow_tools` set: only the listed names are allowed and `deny_tools` is ignored;
/// - only `deny_tools` set: every name except the listed ones is allowed;
/// - neither set (the `Default`): every name is allowed.
#[derive(Debug, Default, Clone)]
pub struct ToolFilter {
    /// Whitelist of tool names. Takes precedence over `deny_tools` when set.
    pub allow_tools: Option<HashSet<String>>,
    /// Blacklist of tool names. Consulted only when `allow_tools` is `None`.
    pub deny_tools: Option<HashSet<String>>,
}

impl ToolFilter {
    /// Builds a whitelist filter that allows exactly the given tool names.
    pub fn allow(tools: Vec<String>) -> Self {
        let allow_set: HashSet<String> = tools.into_iter().collect();
        ToolFilter {
            allow_tools: Some(allow_set),
            deny_tools: None,
        }
    }

    /// Builds a blacklist filter that allows every tool except the given names.
    pub fn deny(tools: Vec<String>) -> Self {
        let deny_set: HashSet<String> = tools.into_iter().collect();
        ToolFilter {
            allow_tools: None,
            deny_tools: Some(deny_set),
        }
    }

    /// Returns whether `tool_name` passes this filter (see the type-level docs
    /// for how the two lists interact).
    pub fn is_allowed(&self, tool_name: &str) -> bool {
        match (&self.allow_tools, &self.deny_tools) {
            (Some(allowed), _) => allowed.contains(tool_name),
            (None, Some(denied)) => !denied.contains(tool_name),
            (None, None) => true,
        }
    }

    /// Returns `true` when at least one of the allow/deny lists is configured.
    pub fn is_enabled(&self) -> bool {
        !(self.allow_tools.is_none() && self.deny_tools.is_none())
    }
}
69
70#[derive(Clone, Debug)]
72pub struct ProxyHandler {
73 client: Arc<Mutex<RunningService<RoleClient, ClientInfo>>>,
74 cached_info: Arc<RwLock<Option<ServerInfo>>>,
76 mcp_id: String,
78 tool_filter: ToolFilter,
80}
81
82impl ServerHandler for ProxyHandler {
83 fn get_info(&self) -> ServerInfo {
84 if let Ok(cached_read) = self.cached_info.read() {
86 if let Some(ref cached) = *cached_read {
87 return cached.clone();
88 }
89 }
90
91 let client = self.client.clone();
95 if let Ok(guard) = client.try_lock() {
96 if let Some(peer_info) = guard.peer_info() {
97 let server_info = ServerInfo {
98 protocol_version: peer_info.protocol_version.clone(),
99 server_info: Implementation {
100 name: peer_info.server_info.name.clone(),
101 version: peer_info.server_info.version.clone(),
102 title: None,
103 website_url: None,
104 icons: None,
105 },
106 instructions: peer_info.instructions.clone(),
107 capabilities: peer_info.capabilities.clone(),
108 };
109
110 if let Ok(mut cached_write) = self.cached_info.write() {
112 *cached_write = Some(server_info.clone());
113 debug!("Successfully cached server info from peer_info");
114 }
115
116 return server_info;
117 }
118 }
119
120 ServerInfo {
122 protocol_version: Default::default(),
123 server_info: Implementation {
124 name: "MCP Proxy - Service Unavailable".to_string(),
125 version: "0.1.0".to_string(),
126 title: None,
127 website_url: None,
128 icons: None,
129 },
130 instructions: Some("ERROR: MCP service is not available or still initializing. Please try again later.".to_string()),
131 capabilities: Default::default(), }
133 }
134
135 #[tracing::instrument(skip(self, request, _context), fields(
136 mcp_id = %self.mcp_id,
137 request = ?request,
138 ))]
139 async fn list_tools(
140 &self,
141 request: Option<PaginatedRequestParam>,
142 _context: RequestContext<RoleServer>,
143 ) -> Result<ListToolsResult, ErrorData> {
144 let client = self.client.clone();
145 let guard = client.lock().await;
146
147 match self.get_info().capabilities.tools {
149 Some(_) => {
150 let timeout_duration = Duration::from_secs(DEFAULT_REQUEST_TIMEOUT_SECS);
151 match timeout(timeout_duration, guard.list_tools(request)).await {
152 Ok(Ok(result)) => {
153 let filtered_tools: Vec<_> = if self.tool_filter.is_enabled() {
155 result
156 .tools
157 .into_iter()
158 .filter(|tool| self.tool_filter.is_allowed(&tool.name))
159 .collect()
160 } else {
161 result.tools
162 };
163
164 info!(
166 "[list_tools] 工具列表结果 - MCP ID: {}, 工具数量: {}{}",
167 self.mcp_id,
168 filtered_tools.len(),
169 if self.tool_filter.is_enabled() {
170 " (已过滤)"
171 } else {
172 ""
173 }
174 );
175
176 debug!(
177 "Proxying list_tools response with {} tools",
178 filtered_tools.len()
179 );
180 Ok(ListToolsResult {
181 tools: filtered_tools,
182 next_cursor: result.next_cursor,
183 })
184 }
185 Ok(Err(err)) => {
186 tracing::error!("Error listing tools: {:?}", err);
187 Ok(ListToolsResult::default())
189 }
190 Err(_) => {
191 tracing::error!(
192 "[list_tools] 请求超时 - MCP ID: {}, 超时: {}s",
193 self.mcp_id,
194 DEFAULT_REQUEST_TIMEOUT_SECS
195 );
196 Ok(ListToolsResult::default())
197 }
198 }
199 }
200 None => {
201 tracing::error!("Server doesn't support tools capability");
203 Ok(ListToolsResult::default())
204 }
205 }
206 }
207
208 #[tracing::instrument(skip(self, request, _context), fields(
209 mcp_id = %self.mcp_id,
210 tool_name = %request.name,
211 tool_arguments = ?request.arguments,
212 ))]
213 async fn call_tool(
214 &self,
215 request: CallToolRequestParam,
216 _context: RequestContext<RoleServer>,
217 ) -> Result<CallToolResult, ErrorData> {
218 if !self.tool_filter.is_allowed(&request.name) {
220 info!(
221 "[call_tool] 工具被过滤 - MCP ID: {}, 工具: {}",
222 self.mcp_id, request.name
223 );
224 return Ok(CallToolResult::error(vec![Content::text(format!(
225 "Tool '{}' is not allowed by filter configuration",
226 request.name
227 ))]));
228 }
229
230 let client = self.client.clone();
231 let guard = client.lock().await;
232
233 match self.get_info().capabilities.tools {
235 Some(_) => {
236 let timeout_duration = Duration::from_secs(DEFAULT_TOOL_CALL_TIMEOUT_SECS);
238 match timeout(timeout_duration, guard.call_tool(request.clone())).await {
239 Ok(Ok(result)) => {
240 info!(
242 "[call_tool] 工具调用成功 - MCP ID: {}, 工具: {}",
243 self.mcp_id, request.name
244 );
245
246 debug!("Tool call succeeded");
247 Ok(result)
248 }
249 Ok(Err(err)) => {
250 tracing::error!("Error calling tool: {:?}", err);
251 Ok(CallToolResult::error(vec![Content::text(format!(
253 "Error: {err}"
254 ))]))
255 }
256 Err(_) => {
257 tracing::error!(
259 "[call_tool] 工具调用超时 - MCP ID: {}, 工具: {}, 超时: {}s",
260 self.mcp_id,
261 request.name,
262 DEFAULT_TOOL_CALL_TIMEOUT_SECS
263 );
264 Ok(CallToolResult::error(vec![Content::text(format!(
265 "Tool call timed out after {}s. The underlying MCP service may be unresponsive.",
266 DEFAULT_TOOL_CALL_TIMEOUT_SECS
267 ))]))
268 }
269 }
270 }
271 None => {
272 tracing::error!("Server doesn't support tools capability");
273 Ok(CallToolResult::error(vec![Content::text(
274 "Server doesn't support tools capability",
275 )]))
276 }
277 }
278 }
279
280 async fn list_resources(
281 &self,
282 request: Option<PaginatedRequestParam>,
283 _context: RequestContext<RoleServer>,
284 ) -> Result<rmcp::model::ListResourcesResult, ErrorData> {
285 let client = self.client.clone();
287 let guard = client.lock().await;
288
289 match self.get_info().capabilities.resources {
291 Some(_) => {
292 let timeout_duration = Duration::from_secs(DEFAULT_REQUEST_TIMEOUT_SECS);
293 match timeout(timeout_duration, guard.list_resources(request)).await {
294 Ok(Ok(result)) => {
295 info!(
297 "[list_resources] 资源列表结果 - MCP ID: {}, 资源数量: {}",
298 self.mcp_id,
299 result.resources.len()
300 );
301
302 debug!("Proxying list_resources response");
303 Ok(result)
304 }
305 Ok(Err(err)) => {
306 tracing::error!("Error listing resources: {:?}", err);
307 Ok(rmcp::model::ListResourcesResult::default())
309 }
310 Err(_) => {
311 tracing::error!(
312 "[list_resources] 请求超时 - MCP ID: {}, 超时: {}s",
313 self.mcp_id,
314 DEFAULT_REQUEST_TIMEOUT_SECS
315 );
316 Ok(rmcp::model::ListResourcesResult::default())
317 }
318 }
319 }
320 None => {
321 tracing::error!("Server doesn't support resources capability");
323 Ok(rmcp::model::ListResourcesResult::default())
324 }
325 }
326 }
327
328 async fn read_resource(
329 &self,
330 request: rmcp::model::ReadResourceRequestParam,
331 _context: RequestContext<RoleServer>,
332 ) -> Result<rmcp::model::ReadResourceResult, ErrorData> {
333 let client = self.client.clone();
335 let guard = client.lock().await;
336
337 match self.get_info().capabilities.resources {
339 Some(_) => {
340 let timeout_duration = Duration::from_secs(DEFAULT_REQUEST_TIMEOUT_SECS);
341 let read_future = guard.read_resource(rmcp::model::ReadResourceRequestParam {
342 uri: request.uri.clone(),
343 });
344 match timeout(timeout_duration, read_future).await {
345 Ok(Ok(result)) => {
346 info!(
348 "[read_resource] 资源读取结果 - MCP ID: {}, URI: {}",
349 self.mcp_id, request.uri
350 );
351
352 debug!("Proxying read_resource response for {}", request.uri);
353 Ok(result)
354 }
355 Ok(Err(err)) => {
356 tracing::error!("Error reading resource: {:?}", err);
357 Err(ErrorData::internal_error(
358 format!("Error reading resource: {err}"),
359 None,
360 ))
361 }
362 Err(_) => {
363 tracing::error!(
364 "[read_resource] 请求超时 - MCP ID: {}, URI: {}, 超时: {}s",
365 self.mcp_id,
366 request.uri,
367 DEFAULT_REQUEST_TIMEOUT_SECS
368 );
369 Err(ErrorData::internal_error(
370 format!("Request timed out after {}s", DEFAULT_REQUEST_TIMEOUT_SECS),
371 None,
372 ))
373 }
374 }
375 }
376 None => {
377 tracing::error!("Server doesn't support resources capability");
379 Ok(rmcp::model::ReadResourceResult {
380 contents: Vec::new(),
381 })
382 }
383 }
384 }
385
386 async fn list_resource_templates(
387 &self,
388 request: Option<PaginatedRequestParam>,
389 _context: RequestContext<RoleServer>,
390 ) -> Result<rmcp::model::ListResourceTemplatesResult, ErrorData> {
391 let client = self.client.clone();
393 let guard = client.lock().await;
394
395 match self.get_info().capabilities.resources {
397 Some(_) => {
398 let timeout_duration = Duration::from_secs(DEFAULT_REQUEST_TIMEOUT_SECS);
399 match timeout(timeout_duration, guard.list_resource_templates(request)).await {
400 Ok(Ok(result)) => {
401 debug!("Proxying list_resource_templates response");
402 Ok(result)
403 }
404 Ok(Err(err)) => {
405 tracing::error!("Error listing resource templates: {:?}", err);
406 Ok(rmcp::model::ListResourceTemplatesResult::default())
408 }
409 Err(_) => {
410 tracing::error!(
411 "[list_resource_templates] 请求超时 - MCP ID: {}, 超时: {}s",
412 self.mcp_id,
413 DEFAULT_REQUEST_TIMEOUT_SECS
414 );
415 Ok(rmcp::model::ListResourceTemplatesResult::default())
416 }
417 }
418 }
419 None => {
420 tracing::error!("Server doesn't support resources capability");
422 Ok(rmcp::model::ListResourceTemplatesResult::default())
423 }
424 }
425 }
426
427 async fn list_prompts(
428 &self,
429 request: Option<PaginatedRequestParam>,
430 _context: RequestContext<RoleServer>,
431 ) -> Result<rmcp::model::ListPromptsResult, ErrorData> {
432 let client = self.client.clone();
434 let guard = client.lock().await;
435
436 match self.get_info().capabilities.prompts {
438 Some(_) => {
439 let timeout_duration = Duration::from_secs(DEFAULT_REQUEST_TIMEOUT_SECS);
440 match timeout(timeout_duration, guard.list_prompts(request)).await {
441 Ok(Ok(result)) => {
442 debug!("Proxying list_prompts response");
443 Ok(result)
444 }
445 Ok(Err(err)) => {
446 tracing::error!("Error listing prompts: {:?}", err);
447 Ok(rmcp::model::ListPromptsResult::default())
449 }
450 Err(_) => {
451 tracing::error!(
452 "[list_prompts] 请求超时 - MCP ID: {}, 超时: {}s",
453 self.mcp_id,
454 DEFAULT_REQUEST_TIMEOUT_SECS
455 );
456 Ok(rmcp::model::ListPromptsResult::default())
457 }
458 }
459 }
460 None => {
461 tracing::warn!("Server doesn't support prompts capability");
463 Ok(rmcp::model::ListPromptsResult::default())
464 }
465 }
466 }
467
468 async fn get_prompt(
469 &self,
470 request: rmcp::model::GetPromptRequestParam,
471 _context: RequestContext<RoleServer>,
472 ) -> Result<rmcp::model::GetPromptResult, ErrorData> {
473 let client = self.client.clone();
475 let guard = client.lock().await;
476
477 match self.get_info().capabilities.prompts {
479 Some(_) => {
480 let timeout_duration = Duration::from_secs(DEFAULT_REQUEST_TIMEOUT_SECS);
481 match timeout(timeout_duration, guard.get_prompt(request)).await {
482 Ok(Ok(result)) => {
483 debug!("Proxying get_prompt response");
484 Ok(result)
485 }
486 Ok(Err(err)) => {
487 tracing::error!("Error getting prompt: {:?}", err);
488 Err(ErrorData::internal_error(
489 format!("Error getting prompt: {err}"),
490 None,
491 ))
492 }
493 Err(_) => {
494 tracing::error!(
495 "[get_prompt] 请求超时 - MCP ID: {}, 超时: {}s",
496 self.mcp_id,
497 DEFAULT_REQUEST_TIMEOUT_SECS
498 );
499 Err(ErrorData::internal_error(
500 format!("Request timed out after {}s", DEFAULT_REQUEST_TIMEOUT_SECS),
501 None,
502 ))
503 }
504 }
505 }
506 None => {
507 tracing::warn!("Server doesn't support prompts capability");
509 Ok(rmcp::model::GetPromptResult {
510 description: None,
511 messages: Vec::new(),
512 })
513 }
514 }
515 }
516
517 async fn complete(
518 &self,
519 request: rmcp::model::CompleteRequestParam,
520 _context: RequestContext<RoleServer>,
521 ) -> Result<rmcp::model::CompleteResult, ErrorData> {
522 let client = self.client.clone();
524 let guard = client.lock().await;
525
526 let timeout_duration = Duration::from_secs(DEFAULT_REQUEST_TIMEOUT_SECS);
527 match timeout(timeout_duration, guard.complete(request)).await {
528 Ok(Ok(result)) => {
529 debug!("Proxying complete response");
530 Ok(result)
531 }
532 Ok(Err(err)) => {
533 tracing::error!("Error completing: {:?}", err);
534 Err(ErrorData::internal_error(
535 format!("Error completing: {err}"),
536 None,
537 ))
538 }
539 Err(_) => {
540 tracing::error!(
541 "[complete] 请求超时 - MCP ID: {}, 超时: {}s",
542 self.mcp_id,
543 DEFAULT_REQUEST_TIMEOUT_SECS
544 );
545 Err(ErrorData::internal_error(
546 format!("Request timed out after {}s", DEFAULT_REQUEST_TIMEOUT_SECS),
547 None,
548 ))
549 }
550 }
551 }
552
553 async fn on_progress(
554 &self,
555 notification: rmcp::model::ProgressNotificationParam,
556 _context: NotificationContext<RoleServer>,
557 ) {
558 let client = self.client.clone();
560 let guard = client.lock().await;
561 match guard.notify_progress(notification).await {
562 Ok(_) => {
563 debug!("Proxying progress notification");
564 }
565 Err(err) => {
566 tracing::error!("Error notifying progress: {:?}", err);
567 }
568 }
569 }
570
571 async fn on_cancelled(
572 &self,
573 notification: rmcp::model::CancelledNotificationParam,
574 _context: NotificationContext<RoleServer>,
575 ) {
576 let client = self.client.clone();
578 let guard = client.lock().await;
579 match guard.notify_cancelled(notification).await {
580 Ok(_) => {
581 debug!("Proxying cancelled notification");
582 }
583 Err(err) => {
584 tracing::error!("Error notifying cancelled: {:?}", err);
585 }
586 }
587 }
588}
589
590impl ProxyHandler {
591 pub fn new(client: RunningService<RoleClient, ClientInfo>) -> Self {
592 Self::with_mcp_id(client, "unknown".to_string())
593 }
594
595 pub fn with_mcp_id(client: RunningService<RoleClient, ClientInfo>, mcp_id: String) -> Self {
596 Self::with_tool_filter(client, mcp_id, ToolFilter::default())
597 }
598
599 pub fn with_tool_filter(
601 client: RunningService<RoleClient, ClientInfo>,
602 mcp_id: String,
603 tool_filter: ToolFilter,
604 ) -> Self {
605 let peer_info = client.peer_info();
606
607 let cached_info = peer_info.map(|peer_info| ServerInfo {
609 protocol_version: peer_info.protocol_version.clone(),
610 server_info: Implementation {
611 name: peer_info.server_info.name.clone(),
612 version: peer_info.server_info.version.clone(),
613 title: None,
614 website_url: None,
615 icons: None,
616 },
617 instructions: peer_info.instructions.clone(),
618 capabilities: peer_info.capabilities.clone(),
619 });
620
621 if tool_filter.is_enabled() {
623 if let Some(ref allow_list) = tool_filter.allow_tools {
624 info!(
625 "[ProxyHandler] 工具白名单已启用 - MCP ID: {}, 允许的工具: {:?}",
626 mcp_id, allow_list
627 );
628 }
629 if let Some(ref deny_list) = tool_filter.deny_tools {
630 info!(
631 "[ProxyHandler] 工具黑名单已启用 - MCP ID: {}, 排除的工具: {:?}",
632 mcp_id, deny_list
633 );
634 }
635 }
636
637 Self {
638 client: Arc::new(Mutex::new(client)),
639 cached_info: Arc::new(RwLock::new(cached_info)),
640 mcp_id,
641 tool_filter,
642 }
643 }
644
645 pub async fn is_mcp_server_ready(&self) -> bool {
647 match self.client.try_lock() {
650 Ok(guard) => (guard.list_tools(None).await).is_ok(),
651 Err(_) => {
652 debug!("is_mcp_server_ready: 无法获取锁,假设服务正常");
653 true
654 }
655 }
656 }
657
658 pub fn is_terminated(&self) -> bool {
660 match self.client.try_lock() {
662 Ok(_) => {
663 false
666 }
667 Err(_) => {
668 debug!("子进程状态检查: 无法获取锁,假设子进程仍在运行");
670 false }
672 }
673 }
674
675 pub async fn is_terminated_async(&self) -> bool {
677 match self.client.try_lock() {
679 Ok(guard) => {
680 match guard.list_tools(None).await {
683 Ok(_) => {
684 debug!("子进程状态检查: 正在运行");
685 false }
687 Err(e) => {
688 info!("子进程状态检查: 已终止,原因: {e}");
689 true }
691 }
692 }
693 Err(_) => {
694 debug!("子进程状态检查: 无法获取锁,假设子进程仍在运行");
696 false }
698 }
699 }
700}