1use mcp_spec::protocol::{
2    CallToolResult, GetPromptResult, Implementation, InitializeResult, JsonRpcError,
3    JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, ListPromptsResult,
4    ListResourcesResult, ListToolsResult, ReadResourceResult, ServerCapabilities, METHOD_NOT_FOUND,
5};
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use std::sync::atomic::{AtomicU64, Ordering};
9use thiserror::Error;
10use tokio::sync::Mutex;
11use tower::{Service, ServiceExt}; pub type BoxError = Box<dyn std::error::Error + Sync + Send>;
14
15#[derive(Debug, Error)]
17pub enum Error {
18    #[error("Transport error: {0}")]
19    Transport(#[from] super::transport::Error),
20
21    #[error("RPC error: code={code}, message={message}")]
22    RpcError { code: i32, message: String },
23
24    #[error("Serialization error: {0}")]
25    Serialization(#[from] serde_json::Error),
26
27    #[error("Unexpected response from server: {0}")]
28    UnexpectedResponse(String),
29
30    #[error("Not initialized")]
31    NotInitialized,
32
33    #[error("Timeout or service not ready")]
34    NotReady,
35
36    #[error("Request timed out")]
37    Timeout(#[from] tower::timeout::error::Elapsed),
38
39    #[error("Error from mcp-server: {0}")]
40    ServerBoxError(BoxError),
41
42    #[error("Call to '{server}' failed for '{method}'. {source}")]
43    McpServerError {
44        method: String,
45        server: String,
46        #[source]
47        source: BoxError,
48    },
49}
50
51impl From<BoxError> for Error {
53    fn from(err: BoxError) -> Self {
54        Error::ServerBoxError(err)
55    }
56}
57
58#[derive(Serialize, Deserialize)]
59pub struct ClientInfo {
60    pub name: String,
61    pub version: String,
62}
63
64#[derive(Serialize, Deserialize, Default)]
65pub struct ClientCapabilities {
66    }
68
69#[derive(Serialize, Deserialize)]
70pub struct InitializeParams {
71    #[serde(rename = "protocolVersion")]
72    pub protocol_version: String,
73    pub capabilities: ClientCapabilities,
74    #[serde(rename = "clientInfo")]
75    pub client_info: ClientInfo,
76}
77
78#[async_trait::async_trait]
79pub trait McpClientTrait: Send + Sync {
80    async fn initialize(
81        &mut self,
82        info: ClientInfo,
83        capabilities: ClientCapabilities,
84    ) -> Result<InitializeResult, Error>;
85
86    async fn list_resources(
87        &self,
88        next_cursor: Option<String>,
89    ) -> Result<ListResourcesResult, Error>;
90
91    async fn read_resource(&self, uri: &str) -> Result<ReadResourceResult, Error>;
92
93    async fn list_tools(&self, next_cursor: Option<String>) -> Result<ListToolsResult, Error>;
94
95    async fn call_tool(&self, name: &str, arguments: Value) -> Result<CallToolResult, Error>;
96
97    async fn list_prompts(&self, next_cursor: Option<String>) -> Result<ListPromptsResult, Error>;
98
99    async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult, Error>;
100}
101
102pub struct McpClient<S>
104where
105    S: Service<JsonRpcMessage, Response = JsonRpcMessage> + Clone + Send + Sync + 'static,
106    S::Error: Into<Error>,
107    S::Future: Send,
108{
109    service: Mutex<S>,
110    next_id: AtomicU64,
111    server_capabilities: Option<ServerCapabilities>,
112    server_info: Option<Implementation>,
113}
114
115impl<S> McpClient<S>
116where
117    S: Service<JsonRpcMessage, Response = JsonRpcMessage> + Clone + Send + Sync + 'static,
118    S::Error: Into<Error>,
119    S::Future: Send,
120{
121    pub fn new(service: S) -> Self {
122        Self {
123            service: Mutex::new(service),
124            next_id: AtomicU64::new(1),
125            server_capabilities: None,
126            server_info: None,
127        }
128    }
129
130    async fn send_request<R>(&self, method: &str, params: Value) -> Result<R, Error>
132    where
133        R: for<'de> Deserialize<'de>,
134    {
135        let mut service = self.service.lock().await;
136        service.ready().await.map_err(|_| Error::NotReady)?;
137
138        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
139        let request = JsonRpcMessage::Request(JsonRpcRequest {
140            jsonrpc: "2.0".to_string(),
141            id: Some(id),
142            method: method.to_string(),
143            params: Some(params.clone()),
144        });
145
146        let response_msg = service
147            .call(request)
148            .await
149            .map_err(|e| Error::McpServerError {
150                server: self
151                    .server_info
152                    .as_ref()
153                    .map(|s| s.name.clone())
154                    .unwrap_or("".to_string()),
155                method: method.to_string(),
156                source: Box::new(e.into()),
158            })?;
159
160        match response_msg {
161            JsonRpcMessage::Response(JsonRpcResponse {
162                id, result, error, ..
163            }) => {
164                if id != Some(self.next_id.load(Ordering::SeqCst) - 1) {
166                    return Err(Error::UnexpectedResponse(
167                        "id mismatch for JsonRpcResponse".to_string(),
168                    ));
169                }
170                if let Some(err) = error {
171                    Err(Error::RpcError {
172                        code: err.code,
173                        message: err.message,
174                    })
175                } else if let Some(r) = result {
176                    Ok(serde_json::from_value(r)?)
177                } else {
178                    Err(Error::UnexpectedResponse("missing result".to_string()))
179                }
180            }
181            JsonRpcMessage::Error(JsonRpcError { id, error, .. }) => {
182                if id != Some(self.next_id.load(Ordering::SeqCst) - 1) {
183                    return Err(Error::UnexpectedResponse(
184                        "id mismatch for JsonRpcError".to_string(),
185                    ));
186                }
187                Err(Error::RpcError {
188                    code: error.code,
189                    message: error.message,
190                })
191            }
192            _ => {
193                Err(Error::UnexpectedResponse(
195                    "unexpected message type".to_string(),
196                ))
197            }
198        }
199    }
200
201    async fn send_notification(&self, method: &str, params: Value) -> Result<(), Error> {
203        let mut service = self.service.lock().await;
204        service.ready().await.map_err(|_| Error::NotReady)?;
205
206        let notification = JsonRpcMessage::Notification(JsonRpcNotification {
207            jsonrpc: "2.0".to_string(),
208            method: method.to_string(),
209            params: Some(params.clone()),
210        });
211
212        service
213            .call(notification)
214            .await
215            .map_err(|e| Error::McpServerError {
216                server: self
217                    .server_info
218                    .as_ref()
219                    .map(|s| s.name.clone())
220                    .unwrap_or("".to_string()),
221                method: method.to_string(),
222                source: Box::new(e.into()),
224            })?;
225
226        Ok(())
227    }
228
229    fn completed_initialization(&self) -> bool {
231        self.server_capabilities.is_some()
232    }
233}
234
235#[async_trait::async_trait]
236impl<S> McpClientTrait for McpClient<S>
237where
238    S: Service<JsonRpcMessage, Response = JsonRpcMessage> + Clone + Send + Sync + 'static,
239    S::Error: Into<Error>,
240    S::Future: Send,
241{
242    async fn initialize(
243        &mut self,
244        info: ClientInfo,
245        capabilities: ClientCapabilities,
246    ) -> Result<InitializeResult, Error> {
247        let params = InitializeParams {
248            protocol_version: "1.0.0".into(),
249            client_info: info,
250            capabilities,
251        };
252        let result: InitializeResult = self
253            .send_request("initialize", serde_json::to_value(params)?)
254            .await?;
255
256        self.send_notification("notifications/initialized", serde_json::json!({}))
257            .await?;
258
259        self.server_capabilities = Some(result.capabilities.clone());
260
261        self.server_info = Some(result.server_info.clone());
262
263        Ok(result)
264    }
265
266    async fn list_resources(
267        &self,
268        next_cursor: Option<String>,
269    ) -> Result<ListResourcesResult, Error> {
270        if !self.completed_initialization() {
271            return Err(Error::NotInitialized);
272        }
273        if self
275            .server_capabilities
276            .as_ref()
277            .unwrap()
278            .resources
279            .is_none()
280        {
281            return Ok(ListResourcesResult {
282                resources: vec![],
283                next_cursor: None,
284            });
285        }
286
287        let payload = next_cursor
288            .map(|cursor| serde_json::json!({"cursor": cursor}))
289            .unwrap_or_else(|| serde_json::json!({}));
290
291        self.send_request("resources/list", payload).await
292    }
293
294    async fn read_resource(&self, uri: &str) -> Result<ReadResourceResult, Error> {
295        if !self.completed_initialization() {
296            return Err(Error::NotInitialized);
297        }
298        if self
300            .server_capabilities
301            .as_ref()
302            .unwrap()
303            .resources
304            .is_none()
305        {
306            return Err(Error::RpcError {
307                code: METHOD_NOT_FOUND,
308                message: "Server does not support 'resources' capability".to_string(),
309            });
310        }
311
312        let params = serde_json::json!({ "uri": uri });
313        self.send_request("resources/read", params).await
314    }
315
316    async fn list_tools(&self, next_cursor: Option<String>) -> Result<ListToolsResult, Error> {
317        if !self.completed_initialization() {
318            return Err(Error::NotInitialized);
319        }
320        if self.server_capabilities.as_ref().unwrap().tools.is_none() {
322            return Ok(ListToolsResult {
323                tools: vec![],
324                next_cursor: None,
325            });
326        }
327
328        let payload = next_cursor
329            .map(|cursor| serde_json::json!({"cursor": cursor}))
330            .unwrap_or_else(|| serde_json::json!({}));
331
332        self.send_request("tools/list", payload).await
333    }
334
335    async fn call_tool(&self, name: &str, arguments: Value) -> Result<CallToolResult, Error> {
336        if !self.completed_initialization() {
337            return Err(Error::NotInitialized);
338        }
339        if self.server_capabilities.as_ref().unwrap().tools.is_none() {
341            return Err(Error::RpcError {
342                code: METHOD_NOT_FOUND,
343                message: "Server does not support 'tools' capability".to_string(),
344            });
345        }
346
347        let params = serde_json::json!({ "name": name, "arguments": arguments });
348
349        self.send_request("tools/call", params).await
352    }
353
354    async fn list_prompts(&self, next_cursor: Option<String>) -> Result<ListPromptsResult, Error> {
355        if !self.completed_initialization() {
356            return Err(Error::NotInitialized);
357        }
358
359        if self.server_capabilities.as_ref().unwrap().prompts.is_none() {
361            return Err(Error::RpcError {
362                code: METHOD_NOT_FOUND,
363                message: "Server does not support 'prompts' capability".to_string(),
364            });
365        }
366
367        let payload = next_cursor
368            .map(|cursor| serde_json::json!({"cursor": cursor}))
369            .unwrap_or_else(|| serde_json::json!({}));
370
371        self.send_request("prompts/list", payload).await
372    }
373
374    async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult, Error> {
375        if !self.completed_initialization() {
376            return Err(Error::NotInitialized);
377        }
378
379        if self.server_capabilities.as_ref().unwrap().prompts.is_none() {
381            return Err(Error::RpcError {
382                code: METHOD_NOT_FOUND,
383                message: "Server does not support 'prompts' capability".to_string(),
384            });
385        }
386
387        let params = serde_json::json!({ "name": name, "arguments": arguments });
388
389        self.send_request("prompts/get", params).await
390    }
391}