1use anyhow::{anyhow, Result};
9use serde_json::{json, Value};
10use std::collections::HashMap;
11use tokio::sync::RwLock;
12
13use super::transport::{create_transport, Transport, TransportConfig};
14use super::types::*;
15
/// Client for a single MCP server, talking JSON-RPC over a boxed [`Transport`].
///
/// All mutable state sits behind `tokio::sync::RwLock`s so that every method
/// can take `&self`.
pub struct McpClient {
    /// Name given to this client at connect time; used in log and error messages.
    server_name: String,
    /// Transport used to send and receive raw JSON-RPC message strings.
    transport: Box<dyn Transport>,
    /// Capabilities reported by the server; `None` until `initialize` stores them.
    capabilities: RwLock<Option<ServerCapabilities>>,
    /// Server implementation info; `None` until `initialize` stores it.
    server_info: RwLock<Option<Implementation>>,
    /// Tools returned by the most recent successful `list_tools` call.
    tools_cache: RwLock<Vec<Tool>>,
    /// Last request id handed out; incremented before each new request.
    request_id: RwLock<i64>,
    /// Set to `true` at the end of a successful `initialize`.
    initialized: RwLock<bool>,
}
37
38impl McpClient {
39 pub async fn connect(
41 server_name: impl Into<String>,
42 config: TransportConfig,
43 ) -> Result<Self> {
44 let server_name = server_name.into();
45 let transport = create_transport(&server_name, &config).await?;
46
47 let client = Self {
48 server_name,
49 transport,
50 capabilities: RwLock::new(None),
51 server_info: RwLock::new(None),
52 tools_cache: RwLock::new(Vec::new()),
53 request_id: RwLock::new(0),
54 initialized: RwLock::new(false),
55 };
56
57 client.initialize().await?;
59
60 Ok(client)
61 }
62
63 pub fn server_name(&self) -> &str {
65 &self.server_name
66 }
67
68 pub async fn is_initialized(&self) -> bool {
70 *self.initialized.read().await
71 }
72
73 pub async fn capabilities(&self) -> Option<ServerCapabilities> {
75 self.capabilities.read().await.clone()
76 }
77
78 pub async fn server_info(&self) -> Option<Implementation> {
80 self.server_info.read().await.clone()
81 }
82
83 async fn next_request_id(&self) -> RequestId {
89 let mut id = self.request_id.write().await;
90 *id += 1;
91 RequestId::Number(*id)
92 }
93
94 async fn send_request<T: serde::de::DeserializeOwned>(
96 &self,
97 method: &str,
98 params: Option<Value>,
99 ) -> Result<T> {
100 let id = self.next_request_id().await;
101
102 let request = JsonRpcRequest {
103 jsonrpc: "2.0".to_string(),
104 id: id.clone(),
105 method: method.to_string(),
106 params,
107 };
108
109 let message = serde_json::to_string(&request)?;
110 tracing::debug!("MCP request to '{}': {}", self.server_name, message);
111
112 self.transport.notify(&message).await?;
114
115 loop {
117 let response = self.transport.receive().await?;
118 tracing::debug!("MCP message from '{}': {}", self.server_name, response);
119
120 if let Ok(server_req) = serde_json::from_str::<JsonRpcRequest>(&response) {
122 self.handle_server_request(&server_req).await?;
124 continue;
125 }
126
127 if let Ok(success) = serde_json::from_str::<JsonRpcResponse>(&response) {
129 if success.id != id {
130 continue;
132 }
133 return serde_json::from_value(success.result)
134 .map_err(|e| anyhow!("Failed to parse result: {}", e));
135 }
136
137 if let Ok(error) = serde_json::from_str::<JsonRpcError>(&response) {
139 if error.id != id {
140 continue;
141 }
142 return Err(anyhow!(
143 "MCP error from '{}': [{}] {}",
144 self.server_name,
145 error.error.code,
146 error.error.message
147 ));
148 }
149
150 if let Ok(notification) = serde_json::from_str::<JsonRpcNotification>(&response) {
152 tracing::debug!("MCP notification from '{}': {}", self.server_name, notification.method);
153 continue;
154 }
155
156 tracing::warn!("Unexpected MCP message format: {}", response);
158 }
159 }
160
161 async fn handle_server_request(&self, request: &JsonRpcRequest) -> Result<()> {
163 tracing::debug!("MCP server request '{}': {}", self.server_name, request.method);
164
165 match request.method.as_str() {
167 "roots/list" => {
168 let response = JsonRpcResponse {
170 jsonrpc: "2.0".to_string(),
171 id: request.id.clone(),
172 result: json!({ "roots": [] }),
173 };
174 let message = serde_json::to_string(&response)?;
175 self.transport.notify(&message).await?;
176 }
177 "ping" => {
178 let response = JsonRpcResponse {
180 jsonrpc: "2.0".to_string(),
181 id: request.id.clone(),
182 result: json!({}),
183 };
184 let message = serde_json::to_string(&response)?;
185 self.transport.notify(&message).await?;
186 }
187 _ => {
188 tracing::warn!("Unhandled MCP server request: {}", request.method);
189 let error_response = JsonRpcError {
191 jsonrpc: "2.0".to_string(),
192 id: request.id.clone(),
193 error: JsonRpcErrorDetail {
194 code: -32601,
195 message: "Method not found".to_string(),
196 data: None,
197 },
198 };
199 let message = serde_json::to_string(&error_response)?;
200 self.transport.notify(&message).await?;
201 }
202 }
203
204 Ok(())
205 }
206
207 async fn send_notification(&self, method: &str, params: Option<Value>) -> Result<()> {
209 let notification = JsonRpcNotification {
210 jsonrpc: "2.0".to_string(),
211 method: method.to_string(),
212 params,
213 };
214
215 let message = serde_json::to_string(¬ification)?;
216 self.transport.notify(&message).await?;
217 Ok(())
218 }
219
220 async fn initialize(&self) -> Result<()> {
226 tracing::info!("Initializing MCP server '{}'", self.server_name);
227
228 let params = InitializeParams {
230 capabilities: ClientCapabilities {
231 roots: Some(RootsCapability {
232 list_changed: Some(false),
233 }),
234 ..Default::default()
235 },
236 client_info: Implementation::default(),
237 protocol_version: Some("2024-11-05".to_string()),
238 };
239
240 let result: InitializeResult = self.send_request(
241 "initialize",
242 Some(serde_json::to_value(params)?),
243 ).await?;
244
245 let server_name = result.server_info.name.clone();
247 let server_version = result.server_info.version.clone();
248
249 *self.capabilities.write().await = Some(result.capabilities);
250 *self.server_info.write().await = Some(result.server_info);
251
252 tracing::info!(
253 "MCP server '{}' initialized: {} v{}",
254 self.server_name,
255 server_name,
256 server_version
257 );
258
259 self.send_notification("notifications/initialized", None).await?;
261
262 *self.initialized.write().await = true;
263 Ok(())
264 }
265
266 pub async fn list_tools(&self) -> Result<Vec<Tool>> {
272 if !self.is_initialized().await {
273 return Err(anyhow!("MCP client not initialized"));
274 }
275
276 let result: ListToolsResult = self.send_request("tools/list", None).await?;
277
278 *self.tools_cache.write().await = result.tools.clone();
280
281 Ok(result.tools)
282 }
283
284 pub async fn cached_tools(&self) -> Vec<Tool> {
286 self.tools_cache.read().await.clone()
287 }
288
289 pub async fn call_tool(
291 &self,
292 name: &str,
293 arguments: Option<Value>,
294 ) -> Result<CallToolResult> {
295 if !self.is_initialized().await {
296 return Err(anyhow!("MCP client not initialized"));
297 }
298
299 let params = CallToolParams {
300 name: name.to_string(),
301 arguments,
302 };
303
304 self.send_request("tools/call", Some(serde_json::to_value(params)?)).await
305 }
306
307 pub async fn supports_tools(&self) -> bool {
309 self.capabilities.read().await
310 .as_ref()
311 .map(|c| c.tools.is_some())
312 .unwrap_or(false)
313 }
314
315 pub async fn list_resources(&self) -> Result<Vec<Resource>> {
321 if !self.is_initialized().await {
322 return Err(anyhow!("MCP client not initialized"));
323 }
324
325 let result: ListResourcesResult = self.send_request("resources/list", None).await?;
326 Ok(result.resources)
327 }
328
329 pub async fn read_resource(&self, uri: &str) -> Result<Value> {
331 if !self.is_initialized().await {
332 return Err(anyhow!("MCP client not initialized"));
333 }
334
335 self.send_request("resources/read", Some(json!({ "uri": uri }))).await
336 }
337
338 pub async fn supports_resources(&self) -> bool {
340 self.capabilities.read().await
341 .as_ref()
342 .map(|c| c.resources.is_some())
343 .unwrap_or(false)
344 }
345
346 pub async fn list_prompts(&self) -> Result<Vec<Prompt>> {
352 if !self.is_initialized().await {
353 return Err(anyhow!("MCP client not initialized"));
354 }
355
356 let result: ListPromptsResult = self.send_request("prompts/list", None).await?;
357 Ok(result.prompts)
358 }
359
360 pub async fn get_prompt(&self, name: &str, arguments: Option<HashMap<String, String>>) -> Result<Value> {
362 if !self.is_initialized().await {
363 return Err(anyhow!("MCP client not initialized"));
364 }
365
366 let mut params = json!({ "name": name });
367 if let Some(args) = arguments {
368 params["arguments"] = serde_json::to_value(args)?;
369 }
370
371 self.send_request("prompts/get", Some(params)).await
372 }
373
374 pub async fn supports_prompts(&self) -> bool {
376 self.capabilities.read().await
377 .as_ref()
378 .map(|c| c.prompts.is_some())
379 .unwrap_or(false)
380 }
381
382 pub async fn set_logging_level(&self, level: LogLevel) -> Result<()> {
388 if !self.is_initialized().await {
389 return Err(anyhow!("MCP client not initialized"));
390 }
391
392 let params = SetLoggingLevelParams { level };
393 self.send_request("logging/setLevel", Some(serde_json::to_value(params)?)).await
394 }
395
396 pub async fn shutdown(&self) -> Result<()> {
402 tracing::info!("Shutting down MCP server '{}'", self.server_name);
403 self.transport.close().await
404 }
405}
406
/// Builder that collects a server name and a transport configuration and
/// then connects an [`McpClient`] with them.
pub struct McpClientBuilder {
    /// Name passed on to `McpClient::connect`.
    server_name: String,
    /// Transport configuration passed on to `McpClient::connect`.
    config: TransportConfig,
}
416
417impl McpClientBuilder {
418 pub fn new(name: impl Into<String>) -> Self {
420 Self {
421 server_name: name.into(),
422 config: TransportConfig::stdio("", vec![]),
423 }
424 }
425
426 pub fn stdio(mut self, command: impl Into<String>, args: Vec<String>) -> Self {
428 self.config = TransportConfig::stdio(command, args);
429 self
430 }
431
432 pub fn sse(mut self, url: impl Into<String>) -> Self {
434 self.config = TransportConfig::sse(url);
435 self
436 }
437
438 pub async fn connect(self) -> Result<McpClient> {
440 McpClient::connect(self.server_name, self.config).await
441 }
442}
443
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder keeps the server name it was created with after a stdio
    /// transport has been selected.
    #[test]
    fn test_client_builder() {
        let args = vec![String::from("-y"), String::from("@playwright/mcp")];
        let builder = McpClientBuilder::new("test").stdio("npx", args);

        assert_eq!(builder.server_name, "test");
    }
}