mcp_client_fishcode2025/
client.rs

1use mcp_core_fishcode2025::protocol::{
2    CallToolResult, GetPromptResult, Implementation, InitializeResult, JsonRpcError,
3    JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, ListPromptsResult,
4    ListResourcesResult, ListToolsResult, ReadResourceResult, ServerCapabilities, METHOD_NOT_FOUND,
5};
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use std::sync::atomic::{AtomicU64, Ordering};
9use thiserror::Error;
10use tokio::sync::Mutex;
11use tower::Service;
12use tower::ServiceExt; // for Service::ready()
13
/// A type-erased, thread-safe error, as exchanged with the mcp-server side.
pub type BoxError = Box<dyn std::error::Error + Send + Sync>;
15
16/// Error type for MCP client operations.
17#[derive(Debug, Error)]
18pub enum Error {
19    #[error("Transport error: {0}")]
20    Transport(#[from] super::transport::Error),
21
22    #[error("RPC error: code={code}, message={message}")]
23    RpcError { code: i32, message: String },
24
25    #[error("Serialization error: {0}")]
26    Serialization(#[from] serde_json::Error),
27
28    #[error("Unexpected response from server: {0}")]
29    UnexpectedResponse(String),
30
31    #[error("Not initialized")]
32    NotInitialized,
33
34    #[error("Timeout or service not ready")]
35    NotReady,
36
37    #[error("Request timed out")]
38    Timeout(#[from] tower::timeout::error::Elapsed),
39
40    #[error("Error from mcp-server: {0}")]
41    ServerBoxError(BoxError),
42
43    #[error("Call to '{server}' failed for '{method}'. {source}")]
44    McpServerError {
45        method: String,
46        server: String,
47        #[source]
48        source: BoxError,
49    },
50}
51
52// BoxError from mcp-server gets converted to our Error type
53impl From<BoxError> for Error {
54    fn from(err: BoxError) -> Self {
55        Error::ServerBoxError(err)
56    }
57}
58
59#[derive(Serialize, Deserialize)]
60pub struct ClientInfo {
61    pub name: String,
62    pub version: String,
63}
64
65#[derive(Serialize, Deserialize, Default)]
66pub struct ClientCapabilities {
67    // Add fields as needed. For now, empty capabilities are fine.
68}
69
70#[derive(Serialize, Deserialize)]
71pub struct InitializeParams {
72    #[serde(rename = "protocolVersion")]
73    pub protocol_version: String,
74    pub capabilities: ClientCapabilities,
75    #[serde(rename = "clientInfo")]
76    pub client_info: ClientInfo,
77}
78
79#[async_trait::async_trait]
80pub trait McpClientTrait: Send + Sync {
81    async fn initialize(
82        &mut self,
83        info: ClientInfo,
84        capabilities: ClientCapabilities,
85    ) -> Result<InitializeResult, Error>;
86
87    async fn list_resources(
88        &self,
89        next_cursor: Option<String>,
90    ) -> Result<ListResourcesResult, Error>;
91
92    async fn read_resource(&self, uri: &str) -> Result<ReadResourceResult, Error>;
93
94    async fn list_tools(&self, next_cursor: Option<String>) -> Result<ListToolsResult, Error>;
95
96    async fn call_tool(&self, name: &str, arguments: Value) -> Result<CallToolResult, Error>;
97
98    async fn list_prompts(&self, next_cursor: Option<String>) -> Result<ListPromptsResult, Error>;
99
100    async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult, Error>;
101}
102
103/// The MCP client is the interface for MCP operations.
104pub struct McpClient<S>
105where
106    S: Service<JsonRpcMessage, Response = JsonRpcMessage> + Clone + Send + Sync + 'static,
107    S::Error: Into<Error>,
108    S::Future: Send,
109{
110    service: Mutex<S>,
111    next_id: AtomicU64,
112    server_capabilities: Option<ServerCapabilities>,
113    server_info: Option<Implementation>,
114}
115
116impl<S> McpClient<S>
117where
118    S: Service<JsonRpcMessage, Response = JsonRpcMessage> + Clone + Send + Sync + 'static,
119    S::Error: Into<Error>,
120    S::Future: Send,
121{
122    pub fn new(service: S) -> Self {
123        Self {
124            service: Mutex::new(service),
125            next_id: AtomicU64::new(1),
126            server_capabilities: None,
127            server_info: None,
128        }
129    }
130
131    /// Send a JSON-RPC request and check we don't get an error response.
132    async fn send_request<R>(&self, method: &str, params: Value) -> Result<R, Error>
133    where
134        R: for<'de> Deserialize<'de>,
135    {
136        let mut service = self.service.lock().await;
137        service.ready().await.map_err(|_| Error::NotReady)?;
138
139        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
140        let request = JsonRpcMessage::Request(JsonRpcRequest {
141            jsonrpc: "2.0".to_string(),
142            id: Some(id),
143            method: method.to_string(),
144            params: Some(params.clone()),
145        });
146
147        let response_msg = service
148            .call(request)
149            .await
150            .map_err(|e| Error::McpServerError {
151                server: self
152                    .server_info
153                    .as_ref()
154                    .map(|s| s.name.clone())
155                    .unwrap_or("".to_string()),
156                method: method.to_string(),
157                // we don't need include params because it can be really large
158                source: Box::new(e.into()),
159            })?;
160
161        match response_msg {
162            JsonRpcMessage::Response(JsonRpcResponse {
163                id, result, error, ..
164            }) => {
165                // Verify id matches
166                if id != Some(self.next_id.load(Ordering::SeqCst) - 1) {
167                    return Err(Error::UnexpectedResponse(
168                        "id mismatch for JsonRpcResponse".to_string(),
169                    ));
170                }
171                if let Some(err) = error {
172                    Err(Error::RpcError {
173                        code: err.code,
174                        message: err.message,
175                    })
176                } else if let Some(r) = result {
177                    Ok(serde_json::from_value(r)?)
178                } else {
179                    Err(Error::UnexpectedResponse("missing result".to_string()))
180                }
181            }
182            JsonRpcMessage::Error(JsonRpcError { id, error, .. }) => {
183                if id != Some(self.next_id.load(Ordering::SeqCst) - 1) {
184                    return Err(Error::UnexpectedResponse(
185                        "id mismatch for JsonRpcError".to_string(),
186                    ));
187                }
188                Err(Error::RpcError {
189                    code: error.code,
190                    message: error.message,
191                })
192            }
193            _ => {
194                // Requests/notifications not expected as a response
195                Err(Error::UnexpectedResponse(
196                    "unexpected message type".to_string(),
197                ))
198            }
199        }
200    }
201
202    /// Send a JSON-RPC notification.
203    async fn send_notification(&self, method: &str, params: Value) -> Result<(), Error> {
204        let mut service = self.service.lock().await;
205        service.ready().await.map_err(|_| Error::NotReady)?;
206
207        let notification = JsonRpcMessage::Notification(JsonRpcNotification {
208            jsonrpc: "2.0".to_string(),
209            method: method.to_string(),
210            params: Some(params.clone()),
211        });
212
213        service
214            .call(notification)
215            .await
216            .map_err(|e| Error::McpServerError {
217                server: self
218                    .server_info
219                    .as_ref()
220                    .map(|s| s.name.clone())
221                    .unwrap_or("".to_string()),
222                method: method.to_string(),
223                // we don't need include params because it can be really large
224                source: Box::new(e.into()),
225            })?;
226
227        Ok(())
228    }
229
230    // Check if the client has completed initialization
231    fn completed_initialization(&self) -> bool {
232        self.server_capabilities.is_some()
233    }
234}
235
236#[async_trait::async_trait]
237impl<S> McpClientTrait for McpClient<S>
238where
239    S: Service<JsonRpcMessage, Response = JsonRpcMessage> + Clone + Send + Sync + 'static,
240    S::Error: Into<Error>,
241    S::Future: Send,
242{
243    async fn initialize(
244        &mut self,
245        info: ClientInfo,
246        capabilities: ClientCapabilities,
247    ) -> Result<InitializeResult, Error> {
248        let params = InitializeParams {
249            protocol_version: "1.0.0".into(),
250            client_info: info,
251            capabilities,
252        };
253        let result: InitializeResult = self
254            .send_request("initialize", serde_json::to_value(params)?)
255            .await?;
256
257        self.send_notification("notifications/initialized", serde_json::json!({}))
258            .await?;
259
260        self.server_capabilities = Some(result.capabilities.clone());
261
262        self.server_info = Some(result.server_info.clone());
263
264        Ok(result)
265    }
266
267    async fn list_resources(
268        &self,
269        next_cursor: Option<String>,
270    ) -> Result<ListResourcesResult, Error> {
271        if !self.completed_initialization() {
272            return Err(Error::NotInitialized);
273        }
274        // If resources is not supported, return an empty list
275        if self
276            .server_capabilities
277            .as_ref()
278            .unwrap()
279            .resources
280            .is_none()
281        {
282            return Ok(ListResourcesResult {
283                resources: vec![],
284                next_cursor: None,
285            });
286        }
287
288        let payload = next_cursor
289            .map(|cursor| serde_json::json!({"cursor": cursor}))
290            .unwrap_or_else(|| serde_json::json!({}));
291
292        self.send_request("resources/list", payload).await
293    }
294
295    async fn read_resource(&self, uri: &str) -> Result<ReadResourceResult, Error> {
296        if !self.completed_initialization() {
297            return Err(Error::NotInitialized);
298        }
299        // If resources is not supported, return an error
300        if self
301            .server_capabilities
302            .as_ref()
303            .unwrap()
304            .resources
305            .is_none()
306        {
307            return Err(Error::RpcError {
308                code: METHOD_NOT_FOUND,
309                message: "Server does not support 'resources' capability".to_string(),
310            });
311        }
312
313        let params = serde_json::json!({ "uri": uri });
314        self.send_request("resources/read", params).await
315    }
316
317    async fn list_tools(&self, next_cursor: Option<String>) -> Result<ListToolsResult, Error> {
318        if !self.completed_initialization() {
319            return Err(Error::NotInitialized);
320        }
321        // If tools is not supported, return an empty list
322        if self.server_capabilities.as_ref().unwrap().tools.is_none() {
323            return Ok(ListToolsResult {
324                tools: vec![],
325                next_cursor: None,
326            });
327        }
328
329        let payload = next_cursor
330            .map(|cursor| serde_json::json!({"cursor": cursor}))
331            .unwrap_or_else(|| serde_json::json!({}));
332
333        self.send_request("tools/list", payload).await
334    }
335
336    async fn call_tool(&self, name: &str, arguments: Value) -> Result<CallToolResult, Error> {
337        if !self.completed_initialization() {
338            return Err(Error::NotInitialized);
339        }
340        // If tools is not supported, return an error
341        if self.server_capabilities.as_ref().unwrap().tools.is_none() {
342            return Err(Error::RpcError {
343                code: METHOD_NOT_FOUND,
344                message: "Server does not support 'tools' capability".to_string(),
345            });
346        }
347
348        let params = serde_json::json!({ "name": name, "arguments": arguments });
349
350        // TODO ERROR: check that if there is an error, we send back is_error: true with msg
351        // https://modelcontextprotocol.io/docs/concepts/tools#error-handling-2
352        self.send_request("tools/call", params).await
353    }
354
355    async fn list_prompts(&self, next_cursor: Option<String>) -> Result<ListPromptsResult, Error> {
356        if !self.completed_initialization() {
357            return Err(Error::NotInitialized);
358        }
359
360        // If prompts is not supported, return an error
361        if self.server_capabilities.as_ref().unwrap().prompts.is_none() {
362            return Err(Error::RpcError {
363                code: METHOD_NOT_FOUND,
364                message: "Server does not support 'prompts' capability".to_string(),
365            });
366        }
367
368        let payload = next_cursor
369            .map(|cursor| serde_json::json!({"cursor": cursor}))
370            .unwrap_or_else(|| serde_json::json!({}));
371
372        self.send_request("prompts/list", payload).await
373    }
374
375    async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult, Error> {
376        if !self.completed_initialization() {
377            return Err(Error::NotInitialized);
378        }
379
380        // If prompts is not supported, return an error
381        if self.server_capabilities.as_ref().unwrap().prompts.is_none() {
382            return Err(Error::RpcError {
383                code: METHOD_NOT_FOUND,
384                message: "Server does not support 'prompts' capability".to_string(),
385            });
386        }
387
388        let params = serde_json::json!({ "name": name, "arguments": arguments });
389
390        self.send_request("prompts/get", params).await
391    }
392}