Skip to main content

mcp_protocol/
proxy.rs

//! # MCP Proxy
//!
//! **Transport-aware proxy** for connecting to MCP servers:
//! - **Client ↔ Proxy**: HTTP/JSON-RPC (internal)
//! - **Proxy ↔ MCP Server**: SSE (web-standard)
//! - **Multi-server routing** with optional per-server authentication tokens
//!
//! Tools from each upstream server are exposed under `server:tool` names.
7
8use crate::{CallToolResult, ListToolsResult, Tool, ToolProvider};
9use async_trait::async_trait;
10use protocol_transport_core::{
11    ProtocolError, SseTransport, Transport, TransportFactory, UniversalRequest,
12};
13use serde_json::json;
14use std::collections::HashMap;
15
16/// **MCP Proxy Configuration**
17#[derive(Debug, Clone)]
18pub struct McpProxyConfig {
19    /// Target MCP servers the proxy routes to
20    pub servers: Vec<McpProxyTarget>,
21    /// Proxy authentication (for incoming requests)
22    pub proxy_auth: Option<String>,
23    /// Default timeout for requests (seconds)
24    pub timeout_seconds: u64,
25}
26
/// **MCP Proxy Target** - An external MCP server the proxy connects to
///
/// `Debug` is implemented by hand so that `auth_token` is printed as
/// `Some("<redacted>")` rather than leaking the credential into logs.
#[derive(Clone)]
pub struct McpProxyTarget {
    /// Server identifier. The proxy uses it as the routing key and as the
    /// `server` prefix of exposed `server:tool` names, so it should not
    /// contain ':'.
    pub name: String,
    /// SSE endpoint URL (e.g., "https://api.example.com/sse")
    pub sse_endpoint: String,
    /// Authentication token for this server
    pub auth_token: Option<String>,
    /// Server description
    pub description: Option<String>,
}

impl std::fmt::Debug for McpProxyTarget {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("McpProxyTarget")
            .field("name", &self.name)
            .field("sse_endpoint", &self.sse_endpoint)
            // Never print the secret itself, only whether one is configured.
            .field("auth_token", &self.auth_token.as_ref().map(|_| "<redacted>"))
            .field("description", &self.description)
            .finish()
    }
}
39
40/// **MCP Proxy** - Routes between internal HTTP and external SSE
41pub struct McpProxy {
42    /// Configuration
43    config: McpProxyConfig,
44    /// SSE transports for external servers
45    sse_transports: HashMap<String, SseTransport>,
46}
47
48impl McpProxy {
49    /// Create new MCP proxy
50    pub fn new(config: McpProxyConfig) -> Self {
51        // Create SSE transports for each server
52        let mut sse_transports = HashMap::new();
53
54        for server in &config.servers {
55            let transport = match &server.auth_token {
56                Some(token) => TransportFactory::mcp_sse_auth(&server.sse_endpoint, token),
57                None => TransportFactory::mcp_sse(&server.sse_endpoint),
58            };
59            sse_transports.insert(server.name.clone(), transport);
60        }
61
62        Self {
63            config,
64            sse_transports,
65        }
66    }
67
68    /// Send JSON-RPC request to specific server
69    async fn send_to_server(
70        &self,
71        server_name: &str,
72        method: &str,
73        params: serde_json::Value,
74    ) -> Result<serde_json::Value, ProtocolError> {
75        let transport = self.sse_transports.get(server_name).ok_or_else(|| {
76            ProtocolError::internal_error(&format!("Unknown server: {}", server_name))
77        })?;
78
79        // Build JSON-RPC request
80        let request = UniversalRequest {
81            method: method.to_string(),
82            uri: "/".to_string(),
83            headers: HashMap::new(),
84            body: json!({
85                "jsonrpc": "2.0",
86                "method": method,
87                "params": params,
88                "id": 1
89            })
90            .to_string()
91            .into_bytes(),
92            protocol: "MCP".to_string(),
93            correlation_id: format!("{}-{}", method.replace("/", "-"), server_name),
94        };
95
96        // Send via SSE transport
97        let response = transport
98            .send(request)
99            .await
100            .map_err(|e| ProtocolError::internal_error(&format!("Transport error: {:?}", e)))?;
101
102        // Parse JSON-RPC response
103        let response_body = String::from_utf8(response.body)
104            .map_err(|e| ProtocolError::Parsing(format!("Invalid UTF-8 response: {}", e)))?;
105
106        let response_json: serde_json::Value = serde_json::from_str(&response_body)
107            .map_err(|e| ProtocolError::Parsing(format!("Invalid JSON response: {}", e)))?;
108
109        // Extract result from JSON-RPC
110        response_json
111            .get("result")
112            .ok_or_else(|| ProtocolError::Parsing("Missing 'result' field".to_string()))
113            .map(|v| v.clone())
114    }
115
116    /// List tools from all servers (async version)
117    pub async fn list_tools_async(&self) -> Result<Vec<Tool>, ProtocolError> {
118        let mut all_tools = Vec::new();
119
120        for server in &self.config.servers {
121            match self
122                .send_to_server(&server.name, "tools/list", json!({}))
123                .await
124            {
125                Ok(result) => {
126                    let list_result: ListToolsResult =
127                        serde_json::from_value(result).map_err(|e| {
128                            ProtocolError::Parsing(format!("Invalid tools list format: {}", e))
129                        })?;
130
131                    // Prefix tool names with server name for uniqueness
132                    let mut tools = list_result.tools;
133                    for tool in &mut tools {
134                        tool.name = format!("{}:{}", server.name, tool.name);
135                    }
136                    all_tools.extend(tools);
137                }
138                Err(e) => {
139                    log::warn!(
140                        "Failed to list tools from proxy target '{}': {:?}",
141                        server.name,
142                        e
143                    );
144                }
145            }
146        }
147
148        Ok(all_tools)
149    }
150
151    /// Call tool (async version)
152    pub async fn call_tool_async(
153        &self,
154        name: &str,
155        arguments: Option<serde_json::Value>,
156    ) -> Result<CallToolResult, ProtocolError> {
157        // Parse tool name: "server:tool" format
158        let parts: Vec<&str> = name.splitn(2, ':').collect();
159        if parts.len() != 2 {
160            return Err(ProtocolError::internal_error(
161                "Tool name must be in format 'server:tool'",
162            ));
163        }
164
165        let server_name = parts[0];
166        let tool_name = parts[1];
167
168        let params = json!({
169            "name": tool_name,
170            "arguments": arguments
171        });
172
173        let result = self
174            .send_to_server(server_name, "tools/call", params)
175            .await?;
176
177        let call_result: CallToolResult = serde_json::from_value(result).map_err(|e| {
178            ProtocolError::Parsing(format!("Invalid tool call result format: {}", e))
179        })?;
180
181        Ok(call_result)
182    }
183
184    /// Health check all servers
185    pub async fn health_check_all(&self) -> HashMap<String, bool> {
186        let mut health_status = HashMap::new();
187
188        for server in &self.config.servers {
189            if let Some(transport) = self.sse_transports.get(&server.name) {
190                let is_healthy = transport.health_check().await.is_ok();
191                health_status.insert(server.name.clone(), is_healthy);
192            } else {
193                health_status.insert(server.name.clone(), false);
194            }
195        }
196
197        health_status
198    }
199}
200
201#[async_trait]
202impl ToolProvider for McpProxy {
203    fn list_tools(&self) -> Result<Vec<Tool>, ProtocolError> {
204        // Note: Since ToolProvider is sync but SSE calls are async,
205        // in practice this would use a cached tool list or async context
206        // For now, return placeholder - actual implementation needs async support
207        Err(ProtocolError::internal_error(
208            "Async tool listing not supported in sync context. Use async proxy methods.",
209        ))
210    }
211
212    async fn call_tool(
213        &self,
214        name: &str,
215        _arguments: Option<serde_json::Value>,
216    ) -> Result<CallToolResult, ProtocolError> {
217        let parts: Vec<&str> = name.splitn(2, ':').collect();
218        if parts.len() != 2 {
219            return Err(ProtocolError::internal_error(
220                "Tool name must be in format 'server:tool'",
221            ));
222        }
223
224        Err(ProtocolError::internal_error(
225            "Async tool calls not supported in sync context. Use async proxy methods.",
226        ))
227    }
228}
229
230/// **MCP Proxy Builder** - Convenient proxy configuration
231pub struct McpProxyBuilder {
232    servers: Vec<McpProxyTarget>,
233    proxy_auth: Option<String>,
234    timeout_seconds: u64,
235}
236
237impl McpProxyBuilder {
238    /// Create new proxy builder
239    pub fn new() -> Self {
240        Self {
241            servers: Vec::new(),
242            proxy_auth: None,
243            timeout_seconds: 30,
244        }
245    }
246
247    /// Add MCP server
248    pub fn add_server(mut self, name: &str, sse_endpoint: &str) -> Self {
249        self.servers.push(McpProxyTarget {
250            name: name.to_string(),
251            sse_endpoint: sse_endpoint.to_string(),
252            auth_token: None,
253            description: None,
254        });
255        self
256    }
257
258    /// Add MCP server with authentication
259    pub fn add_server_with_auth(
260        mut self,
261        name: &str,
262        sse_endpoint: &str,
263        auth_token: &str,
264    ) -> Self {
265        self.servers.push(McpProxyTarget {
266            name: name.to_string(),
267            sse_endpoint: sse_endpoint.to_string(),
268            auth_token: Some(auth_token.to_string()),
269            description: None,
270        });
271        self
272    }
273
274    /// Set proxy authentication token
275    pub fn with_proxy_auth(mut self, auth_token: &str) -> Self {
276        self.proxy_auth = Some(auth_token.to_string());
277        self
278    }
279
280    /// Set timeout for external requests
281    pub fn with_timeout(mut self, timeout_seconds: u64) -> Self {
282        self.timeout_seconds = timeout_seconds;
283        self
284    }
285
286    /// Build the MCP proxy
287    pub fn build(self) -> McpProxy {
288        let config = McpProxyConfig {
289            servers: self.servers,
290            proxy_auth: self.proxy_auth,
291            timeout_seconds: self.timeout_seconds,
292        };
293
294        McpProxy::new(config)
295    }
296}
297
298impl Default for McpProxyBuilder {
299    fn default() -> Self {
300        Self::new()
301    }
302}