1use crate::{CallToolResult, ListToolsResult, Tool, ToolProvider};
9use async_trait::async_trait;
10use protocol_transport_core::{
11 ProtocolError, SseTransport, Transport, TransportFactory, UniversalRequest,
12};
13use serde_json::json;
14use std::collections::HashMap;
15
16#[derive(Debug, Clone)]
18pub struct McpProxyConfig {
19 pub servers: Vec<McpProxyTarget>,
21 pub proxy_auth: Option<String>,
23 pub timeout_seconds: u64,
25}
26
/// One upstream MCP server the proxy forwards requests to over SSE.
///
/// `Debug` is implemented by hand (not derived) so that `auth_token` is never written to
/// logs or panic messages.
#[derive(Clone)]
pub struct McpProxyTarget {
    /// Name used to key the transport and to prefix this server's tools (`"<name>:<tool>"`).
    pub name: String,
    /// SSE endpoint URL used to create the transport.
    pub sse_endpoint: String,
    /// Optional token; when present an authenticated transport is created.
    pub auth_token: Option<String>,
    /// Optional human-readable description.
    pub description: Option<String>,
}

impl std::fmt::Debug for McpProxyTarget {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("McpProxyTarget")
            .field("name", &self.name)
            .field("sse_endpoint", &self.sse_endpoint)
            // Only reveal whether a token is set, never its value.
            .field("auth_token", &self.auth_token.as_ref().map(|_| "<redacted>"))
            .field("description", &self.description)
            .finish()
    }
}
39
40pub struct McpProxy {
42 config: McpProxyConfig,
44 sse_transports: HashMap<String, SseTransport>,
46}
47
48impl McpProxy {
49 pub fn new(config: McpProxyConfig) -> Self {
51 let mut sse_transports = HashMap::new();
53
54 for server in &config.servers {
55 let transport = match &server.auth_token {
56 Some(token) => TransportFactory::mcp_sse_auth(&server.sse_endpoint, token),
57 None => TransportFactory::mcp_sse(&server.sse_endpoint),
58 };
59 sse_transports.insert(server.name.clone(), transport);
60 }
61
62 Self {
63 config,
64 sse_transports,
65 }
66 }
67
68 async fn send_to_server(
70 &self,
71 server_name: &str,
72 method: &str,
73 params: serde_json::Value,
74 ) -> Result<serde_json::Value, ProtocolError> {
75 let transport = self.sse_transports.get(server_name).ok_or_else(|| {
76 ProtocolError::internal_error(&format!("Unknown server: {}", server_name))
77 })?;
78
79 let request = UniversalRequest {
81 method: method.to_string(),
82 uri: "/".to_string(),
83 headers: HashMap::new(),
84 body: json!({
85 "jsonrpc": "2.0",
86 "method": method,
87 "params": params,
88 "id": 1
89 })
90 .to_string()
91 .into_bytes(),
92 protocol: "MCP".to_string(),
93 correlation_id: format!("{}-{}", method.replace("/", "-"), server_name),
94 };
95
96 let response = transport
98 .send(request)
99 .await
100 .map_err(|e| ProtocolError::internal_error(&format!("Transport error: {:?}", e)))?;
101
102 let response_body = String::from_utf8(response.body)
104 .map_err(|e| ProtocolError::Parsing(format!("Invalid UTF-8 response: {}", e)))?;
105
106 let response_json: serde_json::Value = serde_json::from_str(&response_body)
107 .map_err(|e| ProtocolError::Parsing(format!("Invalid JSON response: {}", e)))?;
108
109 response_json
111 .get("result")
112 .ok_or_else(|| ProtocolError::Parsing("Missing 'result' field".to_string()))
113 .map(|v| v.clone())
114 }
115
116 pub async fn list_tools_async(&self) -> Result<Vec<Tool>, ProtocolError> {
118 let mut all_tools = Vec::new();
119
120 for server in &self.config.servers {
121 match self
122 .send_to_server(&server.name, "tools/list", json!({}))
123 .await
124 {
125 Ok(result) => {
126 let list_result: ListToolsResult =
127 serde_json::from_value(result).map_err(|e| {
128 ProtocolError::Parsing(format!("Invalid tools list format: {}", e))
129 })?;
130
131 let mut tools = list_result.tools;
133 for tool in &mut tools {
134 tool.name = format!("{}:{}", server.name, tool.name);
135 }
136 all_tools.extend(tools);
137 }
138 Err(e) => {
139 log::warn!(
140 "Failed to list tools from proxy target '{}': {:?}",
141 server.name,
142 e
143 );
144 }
145 }
146 }
147
148 Ok(all_tools)
149 }
150
151 pub async fn call_tool_async(
153 &self,
154 name: &str,
155 arguments: Option<serde_json::Value>,
156 ) -> Result<CallToolResult, ProtocolError> {
157 let parts: Vec<&str> = name.splitn(2, ':').collect();
159 if parts.len() != 2 {
160 return Err(ProtocolError::internal_error(
161 "Tool name must be in format 'server:tool'",
162 ));
163 }
164
165 let server_name = parts[0];
166 let tool_name = parts[1];
167
168 let params = json!({
169 "name": tool_name,
170 "arguments": arguments
171 });
172
173 let result = self
174 .send_to_server(server_name, "tools/call", params)
175 .await?;
176
177 let call_result: CallToolResult = serde_json::from_value(result).map_err(|e| {
178 ProtocolError::Parsing(format!("Invalid tool call result format: {}", e))
179 })?;
180
181 Ok(call_result)
182 }
183
184 pub async fn health_check_all(&self) -> HashMap<String, bool> {
186 let mut health_status = HashMap::new();
187
188 for server in &self.config.servers {
189 if let Some(transport) = self.sse_transports.get(&server.name) {
190 let is_healthy = transport.health_check().await.is_ok();
191 health_status.insert(server.name.clone(), is_healthy);
192 } else {
193 health_status.insert(server.name.clone(), false);
194 }
195 }
196
197 health_status
198 }
199}
200
201#[async_trait]
202impl ToolProvider for McpProxy {
203 fn list_tools(&self) -> Result<Vec<Tool>, ProtocolError> {
204 Err(ProtocolError::internal_error(
208 "Async tool listing not supported in sync context. Use async proxy methods.",
209 ))
210 }
211
212 async fn call_tool(
213 &self,
214 name: &str,
215 _arguments: Option<serde_json::Value>,
216 ) -> Result<CallToolResult, ProtocolError> {
217 let parts: Vec<&str> = name.splitn(2, ':').collect();
218 if parts.len() != 2 {
219 return Err(ProtocolError::internal_error(
220 "Tool name must be in format 'server:tool'",
221 ));
222 }
223
224 Err(ProtocolError::internal_error(
225 "Async tool calls not supported in sync context. Use async proxy methods.",
226 ))
227 }
228}
229
230pub struct McpProxyBuilder {
232 servers: Vec<McpProxyTarget>,
233 proxy_auth: Option<String>,
234 timeout_seconds: u64,
235}
236
237impl McpProxyBuilder {
238 pub fn new() -> Self {
240 Self {
241 servers: Vec::new(),
242 proxy_auth: None,
243 timeout_seconds: 30,
244 }
245 }
246
247 pub fn add_server(mut self, name: &str, sse_endpoint: &str) -> Self {
249 self.servers.push(McpProxyTarget {
250 name: name.to_string(),
251 sse_endpoint: sse_endpoint.to_string(),
252 auth_token: None,
253 description: None,
254 });
255 self
256 }
257
258 pub fn add_server_with_auth(
260 mut self,
261 name: &str,
262 sse_endpoint: &str,
263 auth_token: &str,
264 ) -> Self {
265 self.servers.push(McpProxyTarget {
266 name: name.to_string(),
267 sse_endpoint: sse_endpoint.to_string(),
268 auth_token: Some(auth_token.to_string()),
269 description: None,
270 });
271 self
272 }
273
274 pub fn with_proxy_auth(mut self, auth_token: &str) -> Self {
276 self.proxy_auth = Some(auth_token.to_string());
277 self
278 }
279
280 pub fn with_timeout(mut self, timeout_seconds: u64) -> Self {
282 self.timeout_seconds = timeout_seconds;
283 self
284 }
285
286 pub fn build(self) -> McpProxy {
288 let config = McpProxyConfig {
289 servers: self.servers,
290 proxy_auth: self.proxy_auth,
291 timeout_seconds: self.timeout_seconds,
292 };
293
294 McpProxy::new(config)
295 }
296}
297
298impl Default for McpProxyBuilder {
299 fn default() -> Self {
300 Self::new()
301 }
302}