1use mcp_spec::protocol::{
2 CallToolResult, GetPromptResult, Implementation, InitializeResult, JsonRpcError,
3 JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, ListPromptsResult,
4 ListResourcesResult, ListToolsResult, ReadResourceResult, ServerCapabilities, METHOD_NOT_FOUND,
5};
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use std::sync::atomic::{AtomicU64, Ordering};
9use thiserror::Error;
10use tokio::sync::Mutex;
11use tower::{Service, ServiceExt}; pub type BoxError = Box<dyn std::error::Error + Sync + Send>;
14
15#[derive(Debug, Error)]
17pub enum Error {
18 #[error("Transport error: {0}")]
19 Transport(#[from] super::transport::Error),
20
21 #[error("RPC error: code={code}, message={message}")]
22 RpcError { code: i32, message: String },
23
24 #[error("Serialization error: {0}")]
25 Serialization(#[from] serde_json::Error),
26
27 #[error("Unexpected response from server: {0}")]
28 UnexpectedResponse(String),
29
30 #[error("Not initialized")]
31 NotInitialized,
32
33 #[error("Timeout or service not ready")]
34 NotReady,
35
36 #[error("Request timed out")]
37 Timeout(#[from] tower::timeout::error::Elapsed),
38
39 #[error("Error from mcp-server: {0}")]
40 ServerBoxError(BoxError),
41
42 #[error("Call to '{server}' failed for '{method}'. {source}")]
43 McpServerError {
44 method: String,
45 server: String,
46 #[source]
47 source: BoxError,
48 },
49}
50
51impl From<BoxError> for Error {
53 fn from(err: BoxError) -> Self {
54 Error::ServerBoxError(err)
55 }
56}
57
58#[derive(Serialize, Deserialize)]
59pub struct ClientInfo {
60 pub name: String,
61 pub version: String,
62}
63
64#[derive(Serialize, Deserialize, Default)]
65pub struct ClientCapabilities {
66 }
68
69#[derive(Serialize, Deserialize)]
70pub struct InitializeParams {
71 #[serde(rename = "protocolVersion")]
72 pub protocol_version: String,
73 pub capabilities: ClientCapabilities,
74 #[serde(rename = "clientInfo")]
75 pub client_info: ClientInfo,
76}
77
78#[async_trait::async_trait]
79pub trait McpClientTrait: Send + Sync {
80 async fn initialize(
81 &mut self,
82 info: ClientInfo,
83 capabilities: ClientCapabilities,
84 ) -> Result<InitializeResult, Error>;
85
86 async fn list_resources(
87 &self,
88 next_cursor: Option<String>,
89 ) -> Result<ListResourcesResult, Error>;
90
91 async fn read_resource(&self, uri: &str) -> Result<ReadResourceResult, Error>;
92
93 async fn list_tools(&self, next_cursor: Option<String>) -> Result<ListToolsResult, Error>;
94
95 async fn call_tool(&self, name: &str, arguments: Value) -> Result<CallToolResult, Error>;
96
97 async fn list_prompts(&self, next_cursor: Option<String>) -> Result<ListPromptsResult, Error>;
98
99 async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult, Error>;
100}
101
102pub struct McpClient<S>
104where
105 S: Service<JsonRpcMessage, Response = JsonRpcMessage> + Clone + Send + Sync + 'static,
106 S::Error: Into<Error>,
107 S::Future: Send,
108{
109 service: Mutex<S>,
110 next_id: AtomicU64,
111 server_capabilities: Option<ServerCapabilities>,
112 server_info: Option<Implementation>,
113}
114
115impl<S> McpClient<S>
116where
117 S: Service<JsonRpcMessage, Response = JsonRpcMessage> + Clone + Send + Sync + 'static,
118 S::Error: Into<Error>,
119 S::Future: Send,
120{
121 pub fn new(service: S) -> Self {
122 Self {
123 service: Mutex::new(service),
124 next_id: AtomicU64::new(1),
125 server_capabilities: None,
126 server_info: None,
127 }
128 }
129
130 async fn send_request<R>(&self, method: &str, params: Value) -> Result<R, Error>
132 where
133 R: for<'de> Deserialize<'de>,
134 {
135 let mut service = self.service.lock().await;
136 service.ready().await.map_err(|_| Error::NotReady)?;
137
138 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
139 let request = JsonRpcMessage::Request(JsonRpcRequest {
140 jsonrpc: "2.0".to_string(),
141 id: Some(id),
142 method: method.to_string(),
143 params: Some(params.clone()),
144 });
145
146 let response_msg = service
147 .call(request)
148 .await
149 .map_err(|e| Error::McpServerError {
150 server: self
151 .server_info
152 .as_ref()
153 .map(|s| s.name.clone())
154 .unwrap_or("".to_string()),
155 method: method.to_string(),
156 source: Box::new(e.into()),
158 })?;
159
160 match response_msg {
161 JsonRpcMessage::Response(JsonRpcResponse {
162 id, result, error, ..
163 }) => {
164 if id != Some(self.next_id.load(Ordering::SeqCst) - 1) {
166 return Err(Error::UnexpectedResponse(
167 "id mismatch for JsonRpcResponse".to_string(),
168 ));
169 }
170 if let Some(err) = error {
171 Err(Error::RpcError {
172 code: err.code,
173 message: err.message,
174 })
175 } else if let Some(r) = result {
176 Ok(serde_json::from_value(r)?)
177 } else {
178 Err(Error::UnexpectedResponse("missing result".to_string()))
179 }
180 }
181 JsonRpcMessage::Error(JsonRpcError { id, error, .. }) => {
182 if id != Some(self.next_id.load(Ordering::SeqCst) - 1) {
183 return Err(Error::UnexpectedResponse(
184 "id mismatch for JsonRpcError".to_string(),
185 ));
186 }
187 Err(Error::RpcError {
188 code: error.code,
189 message: error.message,
190 })
191 }
192 _ => {
193 Err(Error::UnexpectedResponse(
195 "unexpected message type".to_string(),
196 ))
197 }
198 }
199 }
200
201 async fn send_notification(&self, method: &str, params: Value) -> Result<(), Error> {
203 let mut service = self.service.lock().await;
204 service.ready().await.map_err(|_| Error::NotReady)?;
205
206 let notification = JsonRpcMessage::Notification(JsonRpcNotification {
207 jsonrpc: "2.0".to_string(),
208 method: method.to_string(),
209 params: Some(params.clone()),
210 });
211
212 service
213 .call(notification)
214 .await
215 .map_err(|e| Error::McpServerError {
216 server: self
217 .server_info
218 .as_ref()
219 .map(|s| s.name.clone())
220 .unwrap_or("".to_string()),
221 method: method.to_string(),
222 source: Box::new(e.into()),
224 })?;
225
226 Ok(())
227 }
228
229 fn completed_initialization(&self) -> bool {
231 self.server_capabilities.is_some()
232 }
233}
234
235#[async_trait::async_trait]
236impl<S> McpClientTrait for McpClient<S>
237where
238 S: Service<JsonRpcMessage, Response = JsonRpcMessage> + Clone + Send + Sync + 'static,
239 S::Error: Into<Error>,
240 S::Future: Send,
241{
242 async fn initialize(
243 &mut self,
244 info: ClientInfo,
245 capabilities: ClientCapabilities,
246 ) -> Result<InitializeResult, Error> {
247 let params = InitializeParams {
248 protocol_version: "1.0.0".into(),
249 client_info: info,
250 capabilities,
251 };
252 let result: InitializeResult = self
253 .send_request("initialize", serde_json::to_value(params)?)
254 .await?;
255
256 self.send_notification("notifications/initialized", serde_json::json!({}))
257 .await?;
258
259 self.server_capabilities = Some(result.capabilities.clone());
260
261 self.server_info = Some(result.server_info.clone());
262
263 Ok(result)
264 }
265
266 async fn list_resources(
267 &self,
268 next_cursor: Option<String>,
269 ) -> Result<ListResourcesResult, Error> {
270 if !self.completed_initialization() {
271 return Err(Error::NotInitialized);
272 }
273 if self
275 .server_capabilities
276 .as_ref()
277 .unwrap()
278 .resources
279 .is_none()
280 {
281 return Ok(ListResourcesResult {
282 resources: vec![],
283 next_cursor: None,
284 });
285 }
286
287 let payload = next_cursor
288 .map(|cursor| serde_json::json!({"cursor": cursor}))
289 .unwrap_or_else(|| serde_json::json!({}));
290
291 self.send_request("resources/list", payload).await
292 }
293
294 async fn read_resource(&self, uri: &str) -> Result<ReadResourceResult, Error> {
295 if !self.completed_initialization() {
296 return Err(Error::NotInitialized);
297 }
298 if self
300 .server_capabilities
301 .as_ref()
302 .unwrap()
303 .resources
304 .is_none()
305 {
306 return Err(Error::RpcError {
307 code: METHOD_NOT_FOUND,
308 message: "Server does not support 'resources' capability".to_string(),
309 });
310 }
311
312 let params = serde_json::json!({ "uri": uri });
313 self.send_request("resources/read", params).await
314 }
315
316 async fn list_tools(&self, next_cursor: Option<String>) -> Result<ListToolsResult, Error> {
317 if !self.completed_initialization() {
318 return Err(Error::NotInitialized);
319 }
320 if self.server_capabilities.as_ref().unwrap().tools.is_none() {
322 return Ok(ListToolsResult {
323 tools: vec![],
324 next_cursor: None,
325 });
326 }
327
328 let payload = next_cursor
329 .map(|cursor| serde_json::json!({"cursor": cursor}))
330 .unwrap_or_else(|| serde_json::json!({}));
331
332 self.send_request("tools/list", payload).await
333 }
334
335 async fn call_tool(&self, name: &str, arguments: Value) -> Result<CallToolResult, Error> {
336 if !self.completed_initialization() {
337 return Err(Error::NotInitialized);
338 }
339 if self.server_capabilities.as_ref().unwrap().tools.is_none() {
341 return Err(Error::RpcError {
342 code: METHOD_NOT_FOUND,
343 message: "Server does not support 'tools' capability".to_string(),
344 });
345 }
346
347 let params = serde_json::json!({ "name": name, "arguments": arguments });
348
349 self.send_request("tools/call", params).await
352 }
353
354 async fn list_prompts(&self, next_cursor: Option<String>) -> Result<ListPromptsResult, Error> {
355 if !self.completed_initialization() {
356 return Err(Error::NotInitialized);
357 }
358
359 if self.server_capabilities.as_ref().unwrap().prompts.is_none() {
361 return Err(Error::RpcError {
362 code: METHOD_NOT_FOUND,
363 message: "Server does not support 'prompts' capability".to_string(),
364 });
365 }
366
367 let payload = next_cursor
368 .map(|cursor| serde_json::json!({"cursor": cursor}))
369 .unwrap_or_else(|| serde_json::json!({}));
370
371 self.send_request("prompts/list", payload).await
372 }
373
374 async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult, Error> {
375 if !self.completed_initialization() {
376 return Err(Error::NotInitialized);
377 }
378
379 if self.server_capabilities.as_ref().unwrap().prompts.is_none() {
381 return Err(Error::RpcError {
382 code: METHOD_NOT_FOUND,
383 message: "Server does not support 'prompts' capability".to_string(),
384 });
385 }
386
387 let params = serde_json::json!({ "name": name, "arguments": arguments });
388
389 self.send_request("prompts/get", params).await
390 }
391}