1use std::collections::HashMap;
22use std::sync::Arc;
23use tokio::sync::RwLock;
24
25use crate::mcp::config_manager::McpConfigManager;
26use crate::mcp::connection_manager::{ConnectionManager, McpConnectionManager};
27use crate::mcp::error::{McpError, McpResult};
28use crate::mcp::lifecycle_manager::{
29 LifecycleManager, McpLifecycleManager, StartOptions, StopOptions,
30};
31use crate::mcp::tool_manager::{McpTool, McpToolManager, ToolCallResult, ToolManager};
32use crate::mcp::types::{JsonObject, McpServerConfig, McpServerInfo};
33use crate::permission::{PermissionContext, PermissionResult, ToolPermissionManager};
34use crate::tools::{McpToolWrapper, Tool};
35
36pub struct McpIntegration<C: ConnectionManager + 'static> {
43 connection_manager: Arc<C>,
45 lifecycle_manager: Arc<McpLifecycleManager>,
47 config_manager: Arc<McpConfigManager>,
49 tool_manager: Arc<McpToolManager<C>>,
51 permission_manager: Option<Arc<RwLock<ToolPermissionManager>>>,
53 server_extension_map: Arc<RwLock<HashMap<String, String>>>,
55}
56
57impl McpIntegration<McpConnectionManager> {
58 pub fn new() -> Self {
60 let connection_manager = Arc::new(McpConnectionManager::new());
61 let lifecycle_manager = Arc::new(McpLifecycleManager::new());
62 let config_manager = Arc::new(McpConfigManager::new());
63 let tool_manager = Arc::new(McpToolManager::new(connection_manager.clone()));
64
65 Self {
66 connection_manager,
67 lifecycle_manager,
68 config_manager,
69 tool_manager,
70 permission_manager: None,
71 server_extension_map: Arc::new(RwLock::new(HashMap::new())),
72 }
73 }
74
75 pub fn with_components(
77 connection_manager: Arc<McpConnectionManager>,
78 lifecycle_manager: Arc<McpLifecycleManager>,
79 config_manager: Arc<McpConfigManager>,
80 ) -> Self {
81 let tool_manager = Arc::new(McpToolManager::new(connection_manager.clone()));
82
83 Self {
84 connection_manager,
85 lifecycle_manager,
86 config_manager,
87 tool_manager,
88 permission_manager: None,
89 server_extension_map: Arc::new(RwLock::new(HashMap::new())),
90 }
91 }
92}
93
94impl Default for McpIntegration<McpConnectionManager> {
95 fn default() -> Self {
96 Self::new()
97 }
98}
99
100impl<C: ConnectionManager + 'static> McpIntegration<C> {
101 pub fn set_permission_manager(&mut self, manager: Arc<RwLock<ToolPermissionManager>>) {
105 self.permission_manager = Some(manager);
106 }
107
108 pub fn connection_manager(&self) -> &Arc<C> {
112 &self.connection_manager
113 }
114
115 pub fn lifecycle_manager(&self) -> &Arc<McpLifecycleManager> {
119 &self.lifecycle_manager
120 }
121
122 pub fn config_manager(&self) -> &Arc<McpConfigManager> {
124 &self.config_manager
125 }
126
127 pub fn tool_manager(&self) -> &Arc<McpToolManager<C>> {
131 &self.tool_manager
132 }
133
134 pub async fn enable_extension(
147 &self,
148 extension_name: &str,
149 config: McpServerConfig,
150 ) -> McpResult<()> {
151 let server_name = extension_name.to_string();
152
153 self.lifecycle_manager
155 .register_server(&server_name, config.clone());
156
157 let start_options = StartOptions {
159 wait_for_ready: true,
160 ..Default::default()
161 };
162 self.lifecycle_manager
163 .start(&server_name, Some(start_options))
164 .await?;
165
166 let server_info = McpServerInfo::from_config(&server_name, &config);
168
169 self.connection_manager.connect(server_info).await?;
171
172 {
174 let mut map = self.server_extension_map.write().await;
175 map.insert(server_name, extension_name.to_string());
176 }
177
178 Ok(())
179 }
180
181 pub async fn disable_extension(&self, extension_name: &str) -> McpResult<()> {
190 let server_name = extension_name.to_string();
191
192 if let Some(conn) = self
194 .connection_manager
195 .get_connection_by_server(&server_name)
196 {
197 self.connection_manager.disconnect(&conn.id).await?;
199 }
200
201 let stop_options = StopOptions {
203 reason: Some("Extension disabled".to_string()),
204 ..Default::default()
205 };
206 self.lifecycle_manager
207 .stop(&server_name, Some(stop_options))
208 .await?;
209
210 self.lifecycle_manager
212 .unregister_server(&server_name)
213 .await?;
214
215 {
217 let mut map = self.server_extension_map.write().await;
218 map.remove(&server_name);
219 }
220
221 self.tool_manager.clear_cache(Some(&server_name));
223
224 Ok(())
225 }
226
227 pub async fn is_extension_enabled(&self, extension_name: &str) -> bool {
229 self.lifecycle_manager.is_running(extension_name)
230 }
231
232 pub fn get_enabled_extensions(&self) -> Vec<String> {
234 self.lifecycle_manager.get_running_servers()
235 }
236
237 pub async fn list_tools(&self) -> McpResult<Vec<McpTool>> {
248 self.tool_manager.list_tools(None).await
249 }
250
251 pub async fn list_tools_from_server(&self, server_name: &str) -> McpResult<Vec<McpTool>> {
253 self.tool_manager.list_tools(Some(server_name)).await
254 }
255
256 pub async fn get_tool(&self, server_name: &str, tool_name: &str) -> McpResult<Option<McpTool>> {
258 self.tool_manager.get_tool(server_name, tool_name).await
259 }
260
261 pub async fn get_tool_wrappers(&self) -> McpResult<Vec<McpToolWrapper>> {
268 let tools = self.list_tools().await?;
269 Ok(tools
270 .into_iter()
271 .map(|tool| {
272 McpToolWrapper::new(
273 format!("{}_{}", tool.server_name, tool.name),
274 tool.description.unwrap_or_default(),
275 tool.input_schema,
276 tool.server_name,
277 )
278 })
279 .collect())
280 }
281
282 pub async fn register_tools_with_registry(
289 &self,
290 registry: &mut crate::tools::ToolRegistry,
291 ) -> McpResult<usize> {
292 let wrappers = self.get_tool_wrappers().await?;
293 let count = wrappers.len();
294
295 for wrapper in wrappers {
296 let name = wrapper.name().to_string();
297 registry.register_mcp(name, wrapper);
298 }
299
300 Ok(count)
301 }
302
303 pub fn unregister_tools_from_registry(
308 &self,
309 registry: &mut crate::tools::ToolRegistry,
310 server_name: Option<&str>,
311 ) {
312 let mcp_tool_names: Vec<String> = registry
313 .mcp_tool_names()
314 .iter()
315 .map(|s| s.to_string())
316 .collect();
317
318 for name in mcp_tool_names {
319 if let Some(server) = server_name {
321 if name.starts_with(&format!("{}_", server)) {
322 registry.unregister_mcp(&name);
323 }
324 } else {
325 registry.unregister_mcp(&name);
326 }
327 }
328 }
329
330 pub async fn call_tool(
339 &self,
340 server_name: &str,
341 tool_name: &str,
342 args: JsonObject,
343 context: &PermissionContext,
344 ) -> McpResult<ToolCallResult> {
345 if let Some(ref perm_manager) = self.permission_manager {
347 let full_tool_name = format!("{}_{}", server_name, tool_name);
348 let params_map = args.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
349
350 let perm_result =
351 perm_manager
352 .read()
353 .await
354 .is_allowed(&full_tool_name, ¶ms_map, context);
355
356 if !perm_result.allowed {
357 return Err(McpError::permission_denied(
358 perm_result.reason.unwrap_or_else(|| {
359 format!("Permission denied for tool '{}'", full_tool_name)
360 }),
361 ));
362 }
363 }
364
365 self.tool_manager
367 .call_tool(server_name, tool_name, args)
368 .await
369 }
370
371 pub async fn call_tool_unchecked(
375 &self,
376 server_name: &str,
377 tool_name: &str,
378 args: JsonObject,
379 ) -> McpResult<ToolCallResult> {
380 self.tool_manager
381 .call_tool(server_name, tool_name, args)
382 .await
383 }
384
385 pub async fn check_tool_permission(
396 &self,
397 server_name: &str,
398 tool_name: &str,
399 args: &JsonObject,
400 context: &PermissionContext,
401 ) -> PermissionResult {
402 if let Some(ref perm_manager) = self.permission_manager {
403 let full_tool_name = format!("{}_{}", server_name, tool_name);
405 let params_map = args.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
406
407 perm_manager
408 .read()
409 .await
410 .is_allowed(&full_tool_name, ¶ms_map, context)
411 } else {
412 PermissionResult {
414 allowed: true,
415 reason: None,
416 restricted: false,
417 suggestions: Vec::new(),
418 matched_rule: None,
419 violations: Vec::new(),
420 }
421 }
422 }
423
424 pub async fn check_tools_permissions(
428 &self,
429 tools: &[(String, String, JsonObject)], context: &PermissionContext,
431 ) -> Vec<(String, PermissionResult)> {
432 let mut results = Vec::new();
433
434 for (server_name, tool_name, args) in tools {
435 let full_name = format!("{}_{}", server_name, tool_name);
436 let result = self
437 .check_tool_permission(server_name, tool_name, args, context)
438 .await;
439 results.push((full_name, result));
440 }
441
442 results
443 }
444
445 pub async fn is_tool_allowed(
452 &self,
453 server_name: &str,
454 tool_name: &str,
455 context: &PermissionContext,
456 ) -> bool {
457 let empty_args = serde_json::Map::new();
458 let result = self
459 .check_tool_permission(server_name, tool_name, &empty_args, context)
460 .await;
461 result.allowed
462 }
463
464 pub async fn get_denied_tools(&self, context: &PermissionContext) -> Vec<String> {
471 if let Some(ref perm_manager) = self.permission_manager {
472 let manager = perm_manager.read().await;
473 let permissions = manager.get_permissions(None);
474
475 permissions
476 .iter()
477 .filter(|p| !p.allowed)
478 .filter(|p| {
479 crate::permission::check_conditions(&p.conditions, context)
481 })
482 .map(|p| p.tool.clone())
483 .collect()
484 } else {
485 Vec::new()
486 }
487 }
488
489 pub async fn filter_allowed_tools(
495 &self,
496 tools: Vec<McpTool>,
497 context: &PermissionContext,
498 ) -> Vec<McpTool> {
499 let mut allowed_tools = Vec::new();
500
501 for tool in tools {
502 if self
503 .is_tool_allowed(&tool.server_name, &tool.name, context)
504 .await
505 {
506 allowed_tools.push(tool);
507 }
508 }
509
510 allowed_tools
511 }
512
513 pub async fn list_allowed_tools(&self, context: &PermissionContext) -> McpResult<Vec<McpTool>> {
519 let all_tools = self.list_tools().await?;
520 Ok(self.filter_allowed_tools(all_tools, context).await)
521 }
522}
523
524impl McpServerInfo {
526 pub fn from_config(name: &str, config: &McpServerConfig) -> Self {
528 use crate::mcp::types::ConnectionOptions;
529
530 Self {
531 name: name.to_string(),
532 transport_type: config.transport_type,
533 command: config.command.clone(),
534 args: config.args.clone(),
535 env: config.env.clone(),
536 url: config.url.clone(),
537 headers: config.headers.clone(),
538 options: ConnectionOptions::default(),
539 }
540 }
541}
542
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mcp::types::TransportType;

    /// Minimal permission context shared by the permission tests below.
    fn test_context() -> PermissionContext {
        PermissionContext {
            working_directory: std::path::PathBuf::from("/tmp"),
            session_id: "test".to_string(),
            timestamp: 0,
            user: None,
            environment: HashMap::new(),
            metadata: HashMap::new(),
        }
    }

    /// A new integration has no permission manager.
    #[test]
    fn test_mcp_integration_new() {
        let integration = McpIntegration::new();
        assert!(integration.permission_manager.is_none());
    }

    /// Setting a permission manager stores it.
    #[test]
    fn test_mcp_integration_set_permission_manager() {
        let mut integration = McpIntegration::new();
        let perm_manager = Arc::new(RwLock::new(ToolPermissionManager::new(None)));
        integration.set_permission_manager(perm_manager);
        assert!(integration.permission_manager.is_some());
    }

    /// `from_config` copies name, transport type and command.
    #[test]
    fn test_server_info_from_config() {
        let config = McpServerConfig {
            transport_type: TransportType::Stdio,
            command: Some("echo".to_string()),
            args: Some(vec!["hello".to_string()]),
            enabled: true,
            ..Default::default()
        };

        let info = McpServerInfo::from_config("test_server", &config);
        assert_eq!(info.name, "test_server");
        assert_eq!(info.transport_type, TransportType::Stdio);
        assert_eq!(info.command, Some("echo".to_string()));
    }

    /// Without a permission manager, a permission check allows.
    #[tokio::test]
    async fn test_check_tool_permission_no_manager() {
        let integration = McpIntegration::new();
        let context = test_context();
        let args = serde_json::Map::new();

        let result = integration
            .check_tool_permission("server", "tool", &args, &context)
            .await;

        assert!(result.allowed);
    }

    /// A new integration reports no enabled extensions.
    #[tokio::test]
    async fn test_get_enabled_extensions_empty() {
        let integration = McpIntegration::new();
        let extensions = integration.get_enabled_extensions();
        assert!(extensions.is_empty());
    }

    /// A wrapper exposes the name, description and server it was built with.
    #[test]
    fn test_mcp_tool_wrapper_creation() {
        use crate::tools::Tool;

        let wrapper = McpToolWrapper::new(
            "server_tool",
            "A test tool",
            serde_json::json!({"type": "object"}),
            "test_server",
        );

        assert_eq!(wrapper.name(), "server_tool");
        assert_eq!(wrapper.description(), "A test tool");
        assert_eq!(wrapper.server_name(), "test_server");
    }

    /// Without a permission manager, every tool is allowed.
    #[tokio::test]
    async fn test_is_tool_allowed_no_manager() {
        let integration = McpIntegration::new();
        let context = test_context();

        let allowed = integration
            .is_tool_allowed("server", "tool", &context)
            .await;
        assert!(allowed);
    }

    /// Without a permission manager, no tool is reported as denied.
    #[tokio::test]
    async fn test_get_denied_tools_no_manager() {
        let integration = McpIntegration::new();
        let context = test_context();

        let denied = integration.get_denied_tools(&context).await;
        assert!(denied.is_empty());
    }

    /// Without a permission manager, filtering keeps every tool.
    #[tokio::test]
    async fn test_filter_allowed_tools_no_manager() {
        let integration = McpIntegration::new();
        let context = test_context();

        let tools = vec![
            McpTool::new("tool1", "server1", serde_json::json!({})),
            McpTool::new("tool2", "server1", serde_json::json!({})),
        ];

        let allowed = integration.filter_allowed_tools(tools, &context).await;
        assert_eq!(allowed.len(), 2);
    }

    /// Batch checks return one allowing result per input.
    #[tokio::test]
    async fn test_check_tools_permissions_multiple() {
        let integration = McpIntegration::new();
        let context = test_context();

        let tools = vec![
            (
                "server1".to_string(),
                "tool1".to_string(),
                serde_json::Map::new(),
            ),
            (
                "server2".to_string(),
                "tool2".to_string(),
                serde_json::Map::new(),
            ),
        ];

        let results = integration.check_tools_permissions(&tools, &context).await;
        assert_eq!(results.len(), 2);

        for (_, result) in results {
            assert!(result.allowed);
        }
    }
}