use anyhow::{Result, anyhow};
use serde_json::{Value, json};
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
use tokio::sync::{Mutex, RwLock};

use super::transport::{Transport, TransportConfig, create_transport};
use super::types::*;
15
16pub struct McpClient {
22 server_name: String,
24 transport: Box<dyn Transport>,
26 capabilities: RwLock<Option<ServerCapabilities>>,
28 server_info: RwLock<Option<Implementation>>,
30 tools_cache: RwLock<Vec<Tool>>,
32 request_id: RwLock<i64>,
34 initialized: RwLock<bool>,
36}
37
38impl McpClient {
39 pub async fn connect(server_name: impl Into<String>, config: TransportConfig) -> Result<Self> {
41 let server_name = server_name.into();
42 let transport = create_transport(&server_name, &config).await?;
43
44 let client = Self {
45 server_name,
46 transport,
47 capabilities: RwLock::new(None),
48 server_info: RwLock::new(None),
49 tools_cache: RwLock::new(Vec::new()),
50 request_id: RwLock::new(0),
51 initialized: RwLock::new(false),
52 };
53
54 client.initialize().await?;
56
57 Ok(client)
58 }
59
60 pub fn server_name(&self) -> &str {
62 &self.server_name
63 }
64
65 pub async fn is_initialized(&self) -> bool {
67 *self.initialized.read().await
68 }
69
70 pub async fn capabilities(&self) -> Option<ServerCapabilities> {
72 self.capabilities.read().await.clone()
73 }
74
75 pub async fn server_info(&self) -> Option<Implementation> {
77 self.server_info.read().await.clone()
78 }
79
80 async fn next_request_id(&self) -> RequestId {
86 let mut id = self.request_id.write().await;
87 *id += 1;
88 RequestId::Number(*id)
89 }
90
91 async fn send_request<T: serde::de::DeserializeOwned>(
93 &self,
94 method: &str,
95 params: Option<Value>,
96 ) -> Result<T> {
97 let id = self.next_request_id().await;
98
99 let request = JsonRpcRequest {
100 jsonrpc: "2.0".to_string(),
101 id: id.clone(),
102 method: method.to_string(),
103 params,
104 };
105
106 let message = serde_json::to_string(&request)?;
107 tracing::debug!("MCP request to '{}': {}", self.server_name, message);
108
109 self.transport.notify(&message).await?;
111
112 loop {
114 let response = self.transport.receive().await?;
115 tracing::debug!("MCP message from '{}': {}", self.server_name, response);
116
117 if let Ok(server_req) = serde_json::from_str::<JsonRpcRequest>(&response) {
119 self.handle_server_request(&server_req).await?;
121 continue;
122 }
123
124 if let Ok(success) = serde_json::from_str::<JsonRpcResponse>(&response) {
126 if success.id != id {
127 continue;
129 }
130 return serde_json::from_value(success.result)
131 .map_err(|e| anyhow!("Failed to parse result: {}", e));
132 }
133
134 if let Ok(error) = serde_json::from_str::<JsonRpcError>(&response) {
136 if error.id != id {
137 continue;
138 }
139 return Err(anyhow!(
140 "MCP error from '{}': [{}] {}",
141 self.server_name,
142 error.error.code,
143 error.error.message
144 ));
145 }
146
147 if let Ok(notification) = serde_json::from_str::<JsonRpcNotification>(&response) {
149 tracing::debug!(
150 "MCP notification from '{}': {}",
151 self.server_name,
152 notification.method
153 );
154 continue;
155 }
156
157 tracing::warn!("Unexpected MCP message format: {}", response);
159 }
160 }
161
162 async fn handle_server_request(&self, request: &JsonRpcRequest) -> Result<()> {
164 tracing::debug!(
165 "MCP server request '{}': {}",
166 self.server_name,
167 request.method
168 );
169
170 match request.method.as_str() {
172 "roots/list" => {
173 let response = JsonRpcResponse {
175 jsonrpc: "2.0".to_string(),
176 id: request.id.clone(),
177 result: json!({ "roots": [] }),
178 };
179 let message = serde_json::to_string(&response)?;
180 self.transport.notify(&message).await?;
181 }
182 "ping" => {
183 let response = JsonRpcResponse {
185 jsonrpc: "2.0".to_string(),
186 id: request.id.clone(),
187 result: json!({}),
188 };
189 let message = serde_json::to_string(&response)?;
190 self.transport.notify(&message).await?;
191 }
192 _ => {
193 tracing::warn!("Unhandled MCP server request: {}", request.method);
194 let error_response = JsonRpcError {
196 jsonrpc: "2.0".to_string(),
197 id: request.id.clone(),
198 error: JsonRpcErrorDetail {
199 code: -32601,
200 message: "Method not found".to_string(),
201 data: None,
202 },
203 };
204 let message = serde_json::to_string(&error_response)?;
205 self.transport.notify(&message).await?;
206 }
207 }
208
209 Ok(())
210 }
211
212 async fn send_notification(&self, method: &str, params: Option<Value>) -> Result<()> {
214 let notification = JsonRpcNotification {
215 jsonrpc: "2.0".to_string(),
216 method: method.to_string(),
217 params,
218 };
219
220 let message = serde_json::to_string(¬ification)?;
221 self.transport.notify(&message).await?;
222 Ok(())
223 }
224
225 async fn initialize(&self) -> Result<()> {
231 tracing::info!("Initializing MCP server '{}'", self.server_name);
232
233 let params = InitializeParams {
235 capabilities: ClientCapabilities {
236 roots: Some(RootsCapability {
237 list_changed: Some(false),
238 }),
239 ..Default::default()
240 },
241 client_info: Implementation::default(),
242 protocol_version: Some("2024-11-05".to_string()),
243 };
244
245 let result: InitializeResult = self
246 .send_request("initialize", Some(serde_json::to_value(params)?))
247 .await?;
248
249 let server_name = result.server_info.name.clone();
251 let server_version = result.server_info.version.clone();
252
253 *self.capabilities.write().await = Some(result.capabilities);
254 *self.server_info.write().await = Some(result.server_info);
255
256 tracing::info!(
257 "MCP server '{}' initialized: {} v{}",
258 self.server_name,
259 server_name,
260 server_version
261 );
262
263 self.send_notification("notifications/initialized", None)
265 .await?;
266
267 *self.initialized.write().await = true;
268 Ok(())
269 }
270
271 pub async fn list_tools(&self) -> Result<Vec<Tool>> {
277 if !self.is_initialized().await {
278 return Err(anyhow!("MCP client not initialized"));
279 }
280
281 let result: ListToolsResult = self.send_request("tools/list", None).await?;
282
283 *self.tools_cache.write().await = result.tools.clone();
285
286 Ok(result.tools)
287 }
288
289 pub async fn cached_tools(&self) -> Vec<Tool> {
291 self.tools_cache.read().await.clone()
292 }
293
294 pub async fn call_tool(&self, name: &str, arguments: Option<Value>) -> Result<CallToolResult> {
296 if !self.is_initialized().await {
297 return Err(anyhow!("MCP client not initialized"));
298 }
299
300 let params = CallToolParams {
301 name: name.to_string(),
302 arguments,
303 };
304
305 self.send_request("tools/call", Some(serde_json::to_value(params)?))
306 .await
307 }
308
309 pub async fn supports_tools(&self) -> bool {
311 self.capabilities
312 .read()
313 .await
314 .as_ref()
315 .map(|c| c.tools.is_some())
316 .unwrap_or(false)
317 }
318
319 pub async fn list_resources(&self) -> Result<Vec<Resource>> {
325 if !self.is_initialized().await {
326 return Err(anyhow!("MCP client not initialized"));
327 }
328
329 let result: ListResourcesResult = self.send_request("resources/list", None).await?;
330 Ok(result.resources)
331 }
332
333 pub async fn read_resource(&self, uri: &str) -> Result<Value> {
335 if !self.is_initialized().await {
336 return Err(anyhow!("MCP client not initialized"));
337 }
338
339 self.send_request("resources/read", Some(json!({ "uri": uri })))
340 .await
341 }
342
343 pub async fn supports_resources(&self) -> bool {
345 self.capabilities
346 .read()
347 .await
348 .as_ref()
349 .map(|c| c.resources.is_some())
350 .unwrap_or(false)
351 }
352
353 pub async fn list_prompts(&self) -> Result<Vec<Prompt>> {
359 if !self.is_initialized().await {
360 return Err(anyhow!("MCP client not initialized"));
361 }
362
363 let result: ListPromptsResult = self.send_request("prompts/list", None).await?;
364 Ok(result.prompts)
365 }
366
367 pub async fn get_prompt(
369 &self,
370 name: &str,
371 arguments: Option<HashMap<String, String>>,
372 ) -> Result<Value> {
373 if !self.is_initialized().await {
374 return Err(anyhow!("MCP client not initialized"));
375 }
376
377 let mut params = json!({ "name": name });
378 if let Some(args) = arguments {
379 params["arguments"] = serde_json::to_value(args)?;
380 }
381
382 self.send_request("prompts/get", Some(params)).await
383 }
384
385 pub async fn supports_prompts(&self) -> bool {
387 self.capabilities
388 .read()
389 .await
390 .as_ref()
391 .map(|c| c.prompts.is_some())
392 .unwrap_or(false)
393 }
394
395 pub async fn set_logging_level(&self, level: LogLevel) -> Result<()> {
401 if !self.is_initialized().await {
402 return Err(anyhow!("MCP client not initialized"));
403 }
404
405 let params = SetLoggingLevelParams { level };
406 self.send_request("logging/setLevel", Some(serde_json::to_value(params)?))
407 .await
408 }
409
410 pub async fn shutdown(&self) -> Result<()> {
416 tracing::info!("Shutting down MCP server '{}'", self.server_name);
417 self.transport.close().await
418 }
419}
420
421pub struct McpClientBuilder {
427 server_name: String,
428 config: TransportConfig,
429}
430
431impl McpClientBuilder {
432 pub fn new(name: impl Into<String>) -> Self {
434 Self {
435 server_name: name.into(),
436 config: TransportConfig::stdio("", vec![]),
437 }
438 }
439
440 pub fn stdio(mut self, command: impl Into<String>, args: Vec<String>) -> Self {
442 self.config = TransportConfig::stdio(command, args);
443 self
444 }
445
446 pub fn sse(mut self, url: impl Into<String>) -> Self {
448 self.config = TransportConfig::sse(url);
449 self
450 }
451
452 pub async fn connect(self) -> Result<McpClient> {
454 McpClient::connect(self.server_name, self.config).await
455 }
456}
457
#[cfg(test)]
mod tests {
    use super::*;

    /// Configuring a transport must not disturb the name given to `new`.
    #[test]
    fn test_client_builder() {
        let args = vec![String::from("-y"), String::from("@playwright/mcp")];
        let built = McpClientBuilder::new("test").stdio("npx", args);

        assert_eq!(built.server_name, "test");
    }
}