// rmcp_proxy/proxy_handler.rs

//! Server handler used to build a local SSE server that proxies requests to a
//! stdio MCP server.
4use rmcp::{
5    model::{
6        CallToolRequestParam, CallToolResult, ClientInfo, Content, Implementation, ListToolsResult,
7        PaginatedRequestParam, ServerInfo,
8    },
9    service::{RequestContext, RunningService},
10    Error, RoleClient, RoleServer, ServerHandler,
11};
12use std::sync::Arc;
13use tokio::sync::Mutex;
14use tracing::debug;
15
16/// A proxy handler that forwards requests to a client based on the server's capabilities
17#[derive(Clone)]
18pub struct ProxyHandler {
19    client: Arc<Mutex<RunningService<RoleClient, ClientInfo>>>,
20    // Store the server's capabilities to avoid locking the client on every get_info call
21    cached_info: Arc<ServerInfo>,
22}
23
24impl ServerHandler for ProxyHandler {
25    fn get_info(&self) -> ServerInfo {
26        // Return the cached server info with capabilities
27        self.cached_info.as_ref().clone()
28    }
29
30    async fn list_tools(
31        &self,
32        request: PaginatedRequestParam,
33        _context: RequestContext<RoleServer>,
34    ) -> Result<ListToolsResult, Error> {
35        let client = self.client.clone();
36        let guard = client.lock().await;
37
38        // Check if the server has tools capability and forward the request
39        match self.cached_info.capabilities.tools {
40            Some(_) => {
41                match guard.list_tools(request).await {
42                    // Forward request to client
43                    Ok(result) => {
44                        debug!(
45                            "Proxying list_tools response with {} tools",
46                            result.tools.len()
47                        );
48                        Ok(result)
49                    }
50                    Err(err) => {
51                        tracing::error!("Error listing tools: {:?}", err);
52                        // Return empty list instead of error
53                        Ok(ListToolsResult::default())
54                    }
55                }
56            }
57            None => {
58                // Server doesn't support tools, return empty list
59                tracing::error!("Server doesn't support tools capability");
60                Ok(ListToolsResult::default())
61            }
62        }
63    }
64
65    async fn call_tool(
66        &self,
67        request: CallToolRequestParam,
68        _context: RequestContext<RoleServer>,
69    ) -> Result<CallToolResult, Error> {
70        let client = self.client.clone();
71        let guard = client.lock().await;
72
73        // Check if the server has tools capability and forward the request
74        match self.cached_info.capabilities.tools {
75            Some(_) => {
76                match guard.call_tool(request.clone()).await {
77                    Ok(result) => {
78                        debug!("Tool call succeeded");
79                        Ok(result)
80                    }
81                    Err(err) => {
82                        tracing::error!("Error calling tool: {:?}", err);
83                        // Return an error result instead of propagating the error
84                        Ok(CallToolResult::error(vec![Content::text(format!(
85                            "Error: {}",
86                            err
87                        ))]))
88                    }
89                }
90            }
91            None => {
92                tracing::error!("Server doesn't support tools capability");
93                Ok(CallToolResult::error(vec![Content::text(
94                    "Server doesn't support tools capability",
95                )]))
96            }
97        }
98    }
99
100    async fn list_resources(
101        &self,
102        request: PaginatedRequestParam,
103        _context: RequestContext<RoleServer>,
104    ) -> Result<rmcp::model::ListResourcesResult, Error> {
105        // Get a lock on the client
106        let client = self.client.clone();
107        let guard = client.lock().await;
108
109        // Check if the server has resources capability and forward the request
110        match self.cached_info.capabilities.resources {
111            Some(_) => {
112                // Forward request to client
113                match guard.list_resources(request).await {
114                    Ok(result) => {
115                        debug!("Proxying list_resources response");
116                        Ok(result)
117                    }
118                    Err(err) => {
119                        tracing::error!("Error listing resources: {:?}", err);
120                        // Return empty list instead of error
121                        Ok(rmcp::model::ListResourcesResult::default())
122                    }
123                }
124            }
125            None => {
126                // Server doesn't support resources, return empty list
127                tracing::error!("Server doesn't support resources capability");
128                Ok(rmcp::model::ListResourcesResult::default())
129            }
130        }
131    }
132
133    async fn read_resource(
134        &self,
135        request: rmcp::model::ReadResourceRequestParam,
136        _context: RequestContext<RoleServer>,
137    ) -> Result<rmcp::model::ReadResourceResult, Error> {
138        // Get a lock on the client
139        let client = self.client.clone();
140        let guard = client.lock().await;
141
142        // Check if the server has resources capability and forward the request
143        match self.cached_info.capabilities.resources {
144            Some(_) => {
145                // Forward request to client
146                match guard
147                    .read_resource(rmcp::model::ReadResourceRequestParam {
148                        uri: request.uri.clone(),
149                    })
150                    .await
151                {
152                    Ok(result) => {
153                        debug!("Proxying read_resource response for {}", request.uri);
154                        Ok(result)
155                    }
156                    Err(err) => {
157                        tracing::error!("Error reading resource: {:?}", err);
158                        Err(Error::internal_error(
159                            format!("Error reading resource: {}", err),
160                            None,
161                        ))
162                    }
163                }
164            }
165            None => {
166                // Server doesn't support resources, return error
167                tracing::error!("Server doesn't support resources capability");
168                Err(Error::internal_error(
169                    "Server doesn't support resources capability".to_string(),
170                    None,
171                ))
172            }
173        }
174    }
175
176    async fn list_resource_templates(
177        &self,
178        request: PaginatedRequestParam,
179        _context: RequestContext<RoleServer>,
180    ) -> Result<rmcp::model::ListResourceTemplatesResult, Error> {
181        // Get a lock on the client
182        let client = self.client.clone();
183        let guard = client.lock().await;
184
185        // Check if the server has resources capability and forward the request
186        match self.cached_info.capabilities.resources {
187            Some(_) => {
188                // Forward request to client
189                match guard.list_resource_templates(request).await {
190                    Ok(result) => {
191                        debug!("Proxying list_resource_templates response");
192                        Ok(result)
193                    }
194                    Err(err) => {
195                        tracing::error!("Error listing resource templates: {:?}", err);
196                        // Return empty list instead of error
197                        Ok(rmcp::model::ListResourceTemplatesResult::default())
198                    }
199                }
200            }
201            None => {
202                // Server doesn't support resources, return empty list
203                tracing::error!("Server doesn't support resources capability");
204                Ok(rmcp::model::ListResourceTemplatesResult::default())
205            }
206        }
207    }
208
209    async fn list_prompts(
210        &self,
211        request: PaginatedRequestParam,
212        _context: RequestContext<RoleServer>,
213    ) -> Result<rmcp::model::ListPromptsResult, Error> {
214        // Get a lock on the client
215        let client = self.client.clone();
216        let guard = client.lock().await;
217
218        // Check if the server has prompts capability and forward the request
219        match self.cached_info.capabilities.prompts {
220            Some(_) => {
221                // Forward request to client
222                match guard.list_prompts(request).await {
223                    Ok(result) => {
224                        debug!("Proxying list_prompts response");
225                        Ok(result)
226                    }
227                    Err(err) => {
228                        tracing::error!("Error listing prompts: {:?}", err);
229                        // Return empty list instead of error
230                        Ok(rmcp::model::ListPromptsResult::default())
231                    }
232                }
233            }
234            None => {
235                // Server doesn't support prompts, return empty list
236                tracing::error!("Server doesn't support prompts capability");
237                Ok(rmcp::model::ListPromptsResult::default())
238            }
239        }
240    }
241
242    async fn get_prompt(
243        &self,
244        request: rmcp::model::GetPromptRequestParam,
245        _context: RequestContext<RoleServer>,
246    ) -> Result<rmcp::model::GetPromptResult, Error> {
247        // Get a lock on the client
248        let client = self.client.clone();
249        let guard = client.lock().await;
250
251        // Check if the server has prompts capability and forward the request
252        match self.cached_info.capabilities.prompts {
253            Some(_) => {
254                // Forward request to client
255                match guard.get_prompt(request).await {
256                    Ok(result) => {
257                        debug!("Proxying get_prompt response");
258                        Ok(result)
259                    }
260                    Err(err) => {
261                        tracing::error!("Error getting prompt: {:?}", err);
262                        Err(Error::internal_error(
263                            format!("Error getting prompt: {}", err),
264                            None,
265                        ))
266                    }
267                }
268            }
269            None => {
270                // Server doesn't support prompts, return error
271                tracing::error!("Server doesn't support prompts capability");
272                Err(Error::internal_error(
273                    "Server doesn't support prompts capability".to_string(),
274                    None,
275                ))
276            }
277        }
278    }
279
280    async fn complete(
281        &self,
282        request: rmcp::model::CompleteRequestParam,
283        _context: RequestContext<RoleServer>,
284    ) -> Result<rmcp::model::CompleteResult, Error> {
285        // Get a lock on the client
286        let client = self.client.clone();
287        let guard = client.lock().await;
288
289        // Forward request to client
290        match guard.complete(request).await {
291            Ok(result) => {
292                debug!("Proxying complete response");
293                Ok(result)
294            }
295            Err(err) => {
296                tracing::error!("Error completing: {:?}", err);
297                Err(Error::internal_error(
298                    format!("Error completing: {}", err),
299                    None,
300                ))
301            }
302        }
303    }
304
305    async fn on_progress(&self, notification: rmcp::model::ProgressNotificationParam) {
306        // Get a lock on the client
307        let client = self.client.clone();
308        let guard = client.lock().await;
309        match guard.notify_progress(notification).await {
310            Ok(_) => {
311                debug!("Proxying progress notification");
312            }
313            Err(err) => {
314                tracing::error!("Error notifying progress: {:?}", err);
315            }
316        }
317    }
318
319    async fn on_cancelled(&self, notification: rmcp::model::CancelledNotificationParam) {
320        // Get a lock on the client
321        let client = self.client.clone();
322        let guard = client.lock().await;
323        match guard.notify_cancelled(notification).await {
324            Ok(_) => {
325                debug!("Proxying cancelled notification");
326            }
327            Err(err) => {
328                tracing::error!("Error notifying cancelled: {:?}", err);
329            }
330        }
331    }
332}
333
334impl ProxyHandler {
335    pub fn new(client: RunningService<RoleClient, ClientInfo>) -> Self {
336        let peer_info = client.peer_info();
337
338        // Create a ServerInfo object that forwards the server's capabilities
339        let cached_info = ServerInfo {
340            protocol_version: peer_info.protocol_version.clone(),
341            server_info: Implementation {
342                name: peer_info.server_info.name.clone(),
343                version: peer_info.server_info.version.clone(),
344            },
345            instructions: peer_info.instructions.clone(),
346            capabilities: peer_info.capabilities.clone(),
347        };
348
349        Self {
350            client: Arc::new(Mutex::new(client)),
351            cached_info: Arc::new(cached_info),
352        }
353    }
354}