1use mcp_core_fishcode2025::protocol::{
2 CallToolResult, GetPromptResult, Implementation, InitializeResult, JsonRpcError,
3 JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, ListPromptsResult,
4 ListResourcesResult, ListToolsResult, ReadResourceResult, ServerCapabilities, METHOD_NOT_FOUND,
5};
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use std::sync::atomic::{AtomicU64, Ordering};
9use thiserror::Error;
10use tokio::sync::Mutex;
11use tower::Service;
12use tower::ServiceExt; pub type BoxError = Box<dyn std::error::Error + Sync + Send>;
15
/// Errors produced by the MCP client.
#[derive(Debug, Error)]
pub enum Error {
    /// Failure reported by the underlying transport layer.
    #[error("Transport error: {0}")]
    Transport(#[from] super::transport::Error),

    /// A JSON-RPC error returned by the server, or one synthesized locally (with
    /// `METHOD_NOT_FOUND`) when the server did not advertise a required capability.
    #[error("RPC error: code={code}, message={message}")]
    RpcError { code: i32, message: String },

    /// JSON (de)serialization failed, e.g. a `result` could not be decoded.
    #[error("Serialization error: {0}")]
    Serialization(#[from] serde_json::Error),

    /// The reply did not fit the request: id mismatch, neither result nor error,
    /// or a message type other than a response/error.
    #[error("Unexpected response from server: {0}")]
    UnexpectedResponse(String),

    /// A request was made before `initialize` completed.
    #[error("Not initialized")]
    NotInitialized,

    /// The service's readiness check failed (its original error is discarded).
    #[error("Timeout or service not ready")]
    NotReady,

    /// Converted from tower's timeout `Elapsed` error.
    #[error("Request timed out")]
    Timeout(#[from] tower::timeout::error::Elapsed),

    /// An arbitrary boxed error, produced via `From<BoxError>`.
    #[error("Error from mcp-server: {0}")]
    ServerBoxError(BoxError),

    /// The service call for `method` failed; `server` is the server name recorded during
    /// initialization (empty if none has been recorded yet).
    #[error("Call to '{server}' failed for '{method}'. {source}")]
    McpServerError {
        method: String,
        server: String,
        #[source]
        source: BoxError,
    },
}
51
52impl From<BoxError> for Error {
54 fn from(err: BoxError) -> Self {
55 Error::ServerBoxError(err)
56 }
57}
58
59#[derive(Serialize, Deserialize)]
60pub struct ClientInfo {
61 pub name: String,
62 pub version: String,
63}
64
65#[derive(Serialize, Deserialize, Default)]
66pub struct ClientCapabilities {
67 }
69
70#[derive(Serialize, Deserialize)]
71pub struct InitializeParams {
72 #[serde(rename = "protocolVersion")]
73 pub protocol_version: String,
74 pub capabilities: ClientCapabilities,
75 #[serde(rename = "clientInfo")]
76 pub client_info: ClientInfo,
77}
78
79#[async_trait::async_trait]
80pub trait McpClientTrait: Send + Sync {
81 async fn initialize(
82 &mut self,
83 info: ClientInfo,
84 capabilities: ClientCapabilities,
85 ) -> Result<InitializeResult, Error>;
86
87 async fn list_resources(
88 &self,
89 next_cursor: Option<String>,
90 ) -> Result<ListResourcesResult, Error>;
91
92 async fn read_resource(&self, uri: &str) -> Result<ReadResourceResult, Error>;
93
94 async fn list_tools(&self, next_cursor: Option<String>) -> Result<ListToolsResult, Error>;
95
96 async fn call_tool(&self, name: &str, arguments: Value) -> Result<CallToolResult, Error>;
97
98 async fn list_prompts(&self, next_cursor: Option<String>) -> Result<ListPromptsResult, Error>;
99
100 async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult, Error>;
101}
102
103pub struct McpClient<S>
105where
106 S: Service<JsonRpcMessage, Response = JsonRpcMessage> + Clone + Send + Sync + 'static,
107 S::Error: Into<Error>,
108 S::Future: Send,
109{
110 service: Mutex<S>,
111 next_id: AtomicU64,
112 server_capabilities: Option<ServerCapabilities>,
113 server_info: Option<Implementation>,
114}
115
116impl<S> McpClient<S>
117where
118 S: Service<JsonRpcMessage, Response = JsonRpcMessage> + Clone + Send + Sync + 'static,
119 S::Error: Into<Error>,
120 S::Future: Send,
121{
122 pub fn new(service: S) -> Self {
123 Self {
124 service: Mutex::new(service),
125 next_id: AtomicU64::new(1),
126 server_capabilities: None,
127 server_info: None,
128 }
129 }
130
131 async fn send_request<R>(&self, method: &str, params: Value) -> Result<R, Error>
133 where
134 R: for<'de> Deserialize<'de>,
135 {
136 let mut service = self.service.lock().await;
137 service.ready().await.map_err(|_| Error::NotReady)?;
138
139 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
140 let request = JsonRpcMessage::Request(JsonRpcRequest {
141 jsonrpc: "2.0".to_string(),
142 id: Some(id),
143 method: method.to_string(),
144 params: Some(params.clone()),
145 });
146
147 let response_msg = service
148 .call(request)
149 .await
150 .map_err(|e| Error::McpServerError {
151 server: self
152 .server_info
153 .as_ref()
154 .map(|s| s.name.clone())
155 .unwrap_or("".to_string()),
156 method: method.to_string(),
157 source: Box::new(e.into()),
159 })?;
160
161 match response_msg {
162 JsonRpcMessage::Response(JsonRpcResponse {
163 id, result, error, ..
164 }) => {
165 if id != Some(self.next_id.load(Ordering::SeqCst) - 1) {
167 return Err(Error::UnexpectedResponse(
168 "id mismatch for JsonRpcResponse".to_string(),
169 ));
170 }
171 if let Some(err) = error {
172 Err(Error::RpcError {
173 code: err.code,
174 message: err.message,
175 })
176 } else if let Some(r) = result {
177 Ok(serde_json::from_value(r)?)
178 } else {
179 Err(Error::UnexpectedResponse("missing result".to_string()))
180 }
181 }
182 JsonRpcMessage::Error(JsonRpcError { id, error, .. }) => {
183 if id != Some(self.next_id.load(Ordering::SeqCst) - 1) {
184 return Err(Error::UnexpectedResponse(
185 "id mismatch for JsonRpcError".to_string(),
186 ));
187 }
188 Err(Error::RpcError {
189 code: error.code,
190 message: error.message,
191 })
192 }
193 _ => {
194 Err(Error::UnexpectedResponse(
196 "unexpected message type".to_string(),
197 ))
198 }
199 }
200 }
201
202 async fn send_notification(&self, method: &str, params: Value) -> Result<(), Error> {
204 let mut service = self.service.lock().await;
205 service.ready().await.map_err(|_| Error::NotReady)?;
206
207 let notification = JsonRpcMessage::Notification(JsonRpcNotification {
208 jsonrpc: "2.0".to_string(),
209 method: method.to_string(),
210 params: Some(params.clone()),
211 });
212
213 service
214 .call(notification)
215 .await
216 .map_err(|e| Error::McpServerError {
217 server: self
218 .server_info
219 .as_ref()
220 .map(|s| s.name.clone())
221 .unwrap_or("".to_string()),
222 method: method.to_string(),
223 source: Box::new(e.into()),
225 })?;
226
227 Ok(())
228 }
229
230 fn completed_initialization(&self) -> bool {
232 self.server_capabilities.is_some()
233 }
234}
235
236#[async_trait::async_trait]
237impl<S> McpClientTrait for McpClient<S>
238where
239 S: Service<JsonRpcMessage, Response = JsonRpcMessage> + Clone + Send + Sync + 'static,
240 S::Error: Into<Error>,
241 S::Future: Send,
242{
243 async fn initialize(
244 &mut self,
245 info: ClientInfo,
246 capabilities: ClientCapabilities,
247 ) -> Result<InitializeResult, Error> {
248 let params = InitializeParams {
249 protocol_version: "1.0.0".into(),
250 client_info: info,
251 capabilities,
252 };
253 let result: InitializeResult = self
254 .send_request("initialize", serde_json::to_value(params)?)
255 .await?;
256
257 self.send_notification("notifications/initialized", serde_json::json!({}))
258 .await?;
259
260 self.server_capabilities = Some(result.capabilities.clone());
261
262 self.server_info = Some(result.server_info.clone());
263
264 Ok(result)
265 }
266
267 async fn list_resources(
268 &self,
269 next_cursor: Option<String>,
270 ) -> Result<ListResourcesResult, Error> {
271 if !self.completed_initialization() {
272 return Err(Error::NotInitialized);
273 }
274 if self
276 .server_capabilities
277 .as_ref()
278 .unwrap()
279 .resources
280 .is_none()
281 {
282 return Ok(ListResourcesResult {
283 resources: vec![],
284 next_cursor: None,
285 });
286 }
287
288 let payload = next_cursor
289 .map(|cursor| serde_json::json!({"cursor": cursor}))
290 .unwrap_or_else(|| serde_json::json!({}));
291
292 self.send_request("resources/list", payload).await
293 }
294
295 async fn read_resource(&self, uri: &str) -> Result<ReadResourceResult, Error> {
296 if !self.completed_initialization() {
297 return Err(Error::NotInitialized);
298 }
299 if self
301 .server_capabilities
302 .as_ref()
303 .unwrap()
304 .resources
305 .is_none()
306 {
307 return Err(Error::RpcError {
308 code: METHOD_NOT_FOUND,
309 message: "Server does not support 'resources' capability".to_string(),
310 });
311 }
312
313 let params = serde_json::json!({ "uri": uri });
314 self.send_request("resources/read", params).await
315 }
316
317 async fn list_tools(&self, next_cursor: Option<String>) -> Result<ListToolsResult, Error> {
318 if !self.completed_initialization() {
319 return Err(Error::NotInitialized);
320 }
321 if self.server_capabilities.as_ref().unwrap().tools.is_none() {
323 return Ok(ListToolsResult {
324 tools: vec![],
325 next_cursor: None,
326 });
327 }
328
329 let payload = next_cursor
330 .map(|cursor| serde_json::json!({"cursor": cursor}))
331 .unwrap_or_else(|| serde_json::json!({}));
332
333 self.send_request("tools/list", payload).await
334 }
335
336 async fn call_tool(&self, name: &str, arguments: Value) -> Result<CallToolResult, Error> {
337 if !self.completed_initialization() {
338 return Err(Error::NotInitialized);
339 }
340 if self.server_capabilities.as_ref().unwrap().tools.is_none() {
342 return Err(Error::RpcError {
343 code: METHOD_NOT_FOUND,
344 message: "Server does not support 'tools' capability".to_string(),
345 });
346 }
347
348 let params = serde_json::json!({ "name": name, "arguments": arguments });
349
350 self.send_request("tools/call", params).await
353 }
354
355 async fn list_prompts(&self, next_cursor: Option<String>) -> Result<ListPromptsResult, Error> {
356 if !self.completed_initialization() {
357 return Err(Error::NotInitialized);
358 }
359
360 if self.server_capabilities.as_ref().unwrap().prompts.is_none() {
362 return Err(Error::RpcError {
363 code: METHOD_NOT_FOUND,
364 message: "Server does not support 'prompts' capability".to_string(),
365 });
366 }
367
368 let payload = next_cursor
369 .map(|cursor| serde_json::json!({"cursor": cursor}))
370 .unwrap_or_else(|| serde_json::json!({}));
371
372 self.send_request("prompts/list", payload).await
373 }
374
375 async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult, Error> {
376 if !self.completed_initialization() {
377 return Err(Error::NotInitialized);
378 }
379
380 if self.server_capabilities.as_ref().unwrap().prompts.is_none() {
382 return Err(Error::RpcError {
383 code: METHOD_NOT_FOUND,
384 message: "Server does not support 'prompts' capability".to_string(),
385 });
386 }
387
388 let params = serde_json::json!({ "name": name, "arguments": arguments });
389
390 self.send_request("prompts/get", params).await
391 }
392}