1use crate::mcp::protocol::{
6 CallToolParams, CallToolResult, ClientCapabilities, ClientInfo, InitializeParams,
7 InitializeResult, JsonRpcNotification, JsonRpcRequest, ListResourcesResult, ListToolsResult,
8 McpNotification, McpResource, McpTool, ReadResourceParams, ReadResourceResult,
9 ServerCapabilities, PROTOCOL_VERSION,
10};
11use crate::mcp::transport::McpTransport;
12use anyhow::{anyhow, Result};
13use std::sync::atomic::{AtomicU64, Ordering};
14use std::sync::Arc;
15use tokio::sync::RwLock;
16
17pub struct McpClient {
19 pub name: String,
21 transport: Arc<dyn McpTransport>,
23 capabilities: RwLock<ServerCapabilities>,
25 tools: RwLock<Vec<McpTool>>,
27 resources: RwLock<Vec<McpResource>>,
29 request_id: AtomicU64,
31 initialized: RwLock<bool>,
33}
34
35impl McpClient {
36 pub fn new(name: String, transport: Arc<dyn McpTransport>) -> Self {
38 Self {
39 name,
40 transport,
41 capabilities: RwLock::new(ServerCapabilities::default()),
42 tools: RwLock::new(Vec::new()),
43 resources: RwLock::new(Vec::new()),
44 request_id: AtomicU64::new(1),
45 initialized: RwLock::new(false),
46 }
47 }
48
49 fn next_id(&self) -> u64 {
51 self.request_id.fetch_add(1, Ordering::SeqCst)
52 }
53
54 pub async fn initialize(&self) -> Result<InitializeResult> {
56 let params = InitializeParams {
57 protocol_version: PROTOCOL_VERSION.to_string(),
58 capabilities: ClientCapabilities::default(),
59 client_info: ClientInfo {
60 name: "a3s-code".to_string(),
61 version: env!("CARGO_PKG_VERSION").to_string(),
62 },
63 };
64
65 let request = JsonRpcRequest::new(
66 self.next_id(),
67 "initialize",
68 Some(serde_json::to_value(¶ms)?),
69 );
70
71 let response = self.transport.request(request).await?;
72
73 if let Some(error) = response.error {
74 return Err(anyhow!(
75 "MCP initialize error: {} ({})",
76 error.message,
77 error.code
78 ));
79 }
80
81 let result: InitializeResult = serde_json::from_value(
82 response
83 .result
84 .ok_or_else(|| anyhow!("No result in response"))?,
85 )?;
86
87 {
89 let mut caps = self.capabilities.write().await;
90 *caps = result.capabilities.clone();
91 }
92
93 let notification = JsonRpcNotification::new("notifications/initialized", None);
95 self.transport.notify(notification).await?;
96
97 {
99 let mut init = self.initialized.write().await;
100 *init = true;
101 }
102
103 tracing::info!(
104 "MCP client '{}' initialized with server '{}' v{}",
105 self.name,
106 result.server_info.name,
107 result.server_info.version
108 );
109
110 Ok(result)
111 }
112
113 pub async fn is_initialized(&self) -> bool {
115 *self.initialized.read().await
116 }
117
118 pub async fn capabilities(&self) -> ServerCapabilities {
120 self.capabilities.read().await.clone()
121 }
122
123 pub async fn list_tools(&self) -> Result<Vec<McpTool>> {
125 let request = JsonRpcRequest::new(self.next_id(), "tools/list", None);
126 let response = self.transport.request(request).await?;
127
128 if let Some(error) = response.error {
129 return Err(anyhow!(
130 "MCP list_tools error: {} ({})",
131 error.message,
132 error.code
133 ));
134 }
135
136 let result: ListToolsResult =
137 serde_json::from_value(response.result.ok_or_else(|| anyhow!("No result"))?)?;
138
139 {
141 let mut tools = self.tools.write().await;
142 *tools = result.tools.clone();
143 }
144
145 Ok(result.tools)
146 }
147
148 pub async fn get_cached_tools(&self) -> Vec<McpTool> {
150 self.tools.read().await.clone()
151 }
152
153 pub async fn call_tool(
155 &self,
156 name: &str,
157 arguments: Option<serde_json::Value>,
158 ) -> Result<CallToolResult> {
159 let params = CallToolParams {
160 name: name.to_string(),
161 arguments,
162 };
163
164 let request = JsonRpcRequest::new(
165 self.next_id(),
166 "tools/call",
167 Some(serde_json::to_value(¶ms)?),
168 );
169
170 let response = self.transport.request(request).await?;
171
172 if let Some(error) = response.error {
173 return Err(anyhow!(
174 "MCP call_tool error: {} ({})",
175 error.message,
176 error.code
177 ));
178 }
179
180 let result: CallToolResult =
181 serde_json::from_value(response.result.ok_or_else(|| anyhow!("No result"))?)?;
182
183 Ok(result)
184 }
185
186 pub async fn list_resources(&self) -> Result<Vec<McpResource>> {
188 let request = JsonRpcRequest::new(self.next_id(), "resources/list", None);
189 let response = self.transport.request(request).await?;
190
191 if let Some(error) = response.error {
192 return Err(anyhow!(
193 "MCP list_resources error: {} ({})",
194 error.message,
195 error.code
196 ));
197 }
198
199 let result: ListResourcesResult =
200 serde_json::from_value(response.result.ok_or_else(|| anyhow!("No result"))?)?;
201
202 {
204 let mut resources = self.resources.write().await;
205 *resources = result.resources.clone();
206 }
207
208 Ok(result.resources)
209 }
210
211 pub async fn read_resource(&self, uri: &str) -> Result<ReadResourceResult> {
213 let params = ReadResourceParams {
214 uri: uri.to_string(),
215 };
216
217 let request = JsonRpcRequest::new(
218 self.next_id(),
219 "resources/read",
220 Some(serde_json::to_value(¶ms)?),
221 );
222
223 let response = self.transport.request(request).await?;
224
225 if let Some(error) = response.error {
226 return Err(anyhow!(
227 "MCP read_resource error: {} ({})",
228 error.message,
229 error.code
230 ));
231 }
232
233 let result: ReadResourceResult =
234 serde_json::from_value(response.result.ok_or_else(|| anyhow!("No result"))?)?;
235
236 Ok(result)
237 }
238
239 pub fn notifications(&self) -> tokio::sync::mpsc::Receiver<McpNotification> {
241 self.transport.notifications()
242 }
243
244 pub async fn close(&self) -> Result<()> {
246 self.transport.close().await
247 }
248
249 pub fn is_connected(&self) -> bool {
251 self.transport.is_connected()
252 }
253}
254
#[cfg(test)]
mod tests {
    use super::*;

    // These tests assert on the parsed JSON tree (`serde_json::Value`) rather
    // than on substrings of the encoded text. A substring check like
    // `json.contains("test")` would also pass if the value landed under the
    // wrong key.

    #[test]
    fn test_client_info() {
        let info = ClientInfo {
            name: "test".to_string(),
            version: "1.0.0".to_string(),
        };
        let json = serde_json::to_value(&info).unwrap();
        assert_eq!(json["name"], "test");
        assert_eq!(json["version"], "1.0.0");
    }

    #[test]
    fn test_initialize_params() {
        let params = InitializeParams {
            protocol_version: PROTOCOL_VERSION.to_string(),
            capabilities: ClientCapabilities::default(),
            client_info: ClientInfo {
                name: "a3s-code".to_string(),
                version: "0.1.0".to_string(),
            },
        };
        let json = serde_json::to_value(&params).unwrap();
        // MCP uses camelCase keys on the wire.
        assert_eq!(json["protocolVersion"], PROTOCOL_VERSION);
        assert_eq!(json["clientInfo"]["name"], "a3s-code");
        assert_eq!(json["clientInfo"]["version"], "0.1.0");
    }

    #[test]
    fn test_client_info_serialize() {
        let info = ClientInfo {
            name: "test-client".to_string(),
            version: "2.0.0".to_string(),
        };
        let json = serde_json::to_value(&info).unwrap();
        assert_eq!(json["name"], "test-client");
        assert_eq!(json["version"], "2.0.0");

        // Round-trip: what we send must be readable back unchanged.
        let back: ClientInfo = serde_json::from_value(json).unwrap();
        assert_eq!(back.name, "test-client");
        assert_eq!(back.version, "2.0.0");
    }

    #[test]
    fn test_client_info_deserialize() {
        let json = r#"{"name":"my-client","version":"1.2.3"}"#;
        let info: ClientInfo = serde_json::from_str(json).unwrap();
        assert_eq!(info.name, "my-client");
        assert_eq!(info.version, "1.2.3");
    }

    #[test]
    fn test_initialize_params_serialize() {
        let params = InitializeParams {
            protocol_version: "2024-11-05".to_string(),
            capabilities: ClientCapabilities::default(),
            client_info: ClientInfo {
                name: "test".to_string(),
                version: "1.0.0".to_string(),
            },
        };
        let json = serde_json::to_value(&params).unwrap();
        assert_eq!(json["protocolVersion"], "2024-11-05");
        assert!(json.get("capabilities").is_some());
        assert_eq!(json["clientInfo"]["name"], "test");
    }

    #[test]
    fn test_call_tool_params_serialize() {
        let params = CallToolParams {
            name: "test_tool".to_string(),
            arguments: Some(serde_json::json!({"key": "value"})),
        };
        let json = serde_json::to_value(&params).unwrap();
        assert_eq!(json["name"], "test_tool");
        assert_eq!(json["arguments"], serde_json::json!({"key": "value"}));
    }

    #[test]
    fn test_call_tool_params_no_arguments() {
        let params = CallToolParams {
            name: "simple_tool".to_string(),
            arguments: None,
        };
        let json = serde_json::to_value(&params).unwrap();
        assert_eq!(json["name"], "simple_tool");
        // `None` may be omitted or encoded as `null`. Either way, no arguments are sent.
        assert!(json
            .get("arguments")
            .map_or(true, serde_json::Value::is_null));
    }

    #[test]
    fn test_read_resource_params_serialize() {
        let params = ReadResourceParams {
            uri: "file:///test.txt".to_string(),
        };
        let json = serde_json::to_value(&params).unwrap();
        assert_eq!(json["uri"], "file:///test.txt");
    }

    #[test]
    fn test_read_resource_params_deserialize() {
        let json = r#"{"uri":"http://example.com/resource"}"#;
        let params: ReadResourceParams = serde_json::from_str(json).unwrap();
        assert_eq!(params.uri, "http://example.com/resource");
    }

    #[test]
    fn test_server_capabilities_default() {
        let caps = ServerCapabilities::default();
        let json = serde_json::to_value(&caps).unwrap();
        // Capabilities are a JSON object in the MCP handshake.
        assert!(json.is_object());

        // Round-trip: the default must survive decode + re-encode unchanged.
        let back: ServerCapabilities = serde_json::from_value(json.clone()).unwrap();
        assert_eq!(serde_json::to_value(&back).unwrap(), json);
    }

    #[test]
    fn test_client_capabilities_default() {
        let caps = ClientCapabilities::default();
        let json = serde_json::to_value(&caps).unwrap();
        // `initialize` params require `capabilities` to be a JSON object.
        assert!(json.is_object());
    }

    #[test]
    fn test_protocol_version_constant() {
        // MCP protocol revisions are dated identifiers: YYYY-MM-DD.
        let parts: Vec<&str> = PROTOCOL_VERSION.split('-').collect();
        assert_eq!(parts.len(), 3, "unexpected version: {}", PROTOCOL_VERSION);
        assert_eq!(parts[0].len(), 4);
        assert_eq!(parts[1].len(), 2);
        assert_eq!(parts[2].len(), 2);
        assert!(parts
            .iter()
            .all(|part| part.chars().all(|c| c.is_ascii_digit())));
    }
}