Skip to main content

a3s_code_core/mcp/
client.rs

1//! MCP Client
2//!
3//! Provides a high-level client for interacting with MCP servers.
4
5use crate::mcp::protocol::{
6    CallToolParams, CallToolResult, ClientCapabilities, ClientInfo, InitializeParams,
7    InitializeResult, JsonRpcNotification, JsonRpcRequest, ListResourcesResult, ListToolsResult,
8    McpNotification, McpResource, McpTool, ReadResourceParams, ReadResourceResult,
9    ServerCapabilities, PROTOCOL_VERSION,
10};
11use crate::mcp::transport::McpTransport;
12use anyhow::{anyhow, Result};
13use std::sync::atomic::{AtomicU64, Ordering};
14use std::sync::Arc;
15use tokio::sync::RwLock;
16
17/// MCP client for communicating with MCP servers
18pub struct McpClient {
19    /// Server name
20    pub name: String,
21    /// Transport layer
22    transport: Arc<dyn McpTransport>,
23    /// Server capabilities (after initialization)
24    capabilities: RwLock<ServerCapabilities>,
25    /// Cached tools
26    tools: RwLock<Vec<McpTool>>,
27    /// Cached resources
28    resources: RwLock<Vec<McpResource>>,
29    /// Request ID counter
30    request_id: AtomicU64,
31    /// Initialized flag
32    initialized: RwLock<bool>,
33}
34
35impl McpClient {
36    /// Create a new MCP client with the given transport
37    pub fn new(name: String, transport: Arc<dyn McpTransport>) -> Self {
38        Self {
39            name,
40            transport,
41            capabilities: RwLock::new(ServerCapabilities::default()),
42            tools: RwLock::new(Vec::new()),
43            resources: RwLock::new(Vec::new()),
44            request_id: AtomicU64::new(1),
45            initialized: RwLock::new(false),
46        }
47    }
48
49    /// Get next request ID
50    fn next_id(&self) -> u64 {
51        self.request_id.fetch_add(1, Ordering::SeqCst)
52    }
53
54    /// Initialize the MCP connection
55    pub async fn initialize(&self) -> Result<InitializeResult> {
56        let params = InitializeParams {
57            protocol_version: PROTOCOL_VERSION.to_string(),
58            capabilities: ClientCapabilities::default(),
59            client_info: ClientInfo {
60                name: "a3s-code".to_string(),
61                version: env!("CARGO_PKG_VERSION").to_string(),
62            },
63        };
64
65        let request = JsonRpcRequest::new(
66            self.next_id(),
67            "initialize",
68            Some(serde_json::to_value(&params)?),
69        );
70
71        let response = self.transport.request(request).await?;
72
73        if let Some(error) = response.error {
74            return Err(anyhow!(
75                "MCP initialize error: {} ({})",
76                error.message,
77                error.code
78            ));
79        }
80
81        let result: InitializeResult = serde_json::from_value(
82            response
83                .result
84                .ok_or_else(|| anyhow!("No result in response"))?,
85        )?;
86
87        // Store capabilities
88        {
89            let mut caps = self.capabilities.write().await;
90            *caps = result.capabilities.clone();
91        }
92
93        // Send initialized notification
94        let notification = JsonRpcNotification::new("notifications/initialized", None);
95        self.transport.notify(notification).await?;
96
97        // Mark as initialized
98        {
99            let mut init = self.initialized.write().await;
100            *init = true;
101        }
102
103        tracing::info!(
104            "MCP client '{}' initialized with server '{}' v{}",
105            self.name,
106            result.server_info.name,
107            result.server_info.version
108        );
109
110        Ok(result)
111    }
112
113    /// Check if client is initialized
114    pub async fn is_initialized(&self) -> bool {
115        *self.initialized.read().await
116    }
117
118    /// Get server capabilities
119    pub async fn capabilities(&self) -> ServerCapabilities {
120        self.capabilities.read().await.clone()
121    }
122
123    /// List available tools
124    pub async fn list_tools(&self) -> Result<Vec<McpTool>> {
125        let request = JsonRpcRequest::new(self.next_id(), "tools/list", None);
126        let response = self.transport.request(request).await?;
127
128        if let Some(error) = response.error {
129            return Err(anyhow!(
130                "MCP list_tools error: {} ({})",
131                error.message,
132                error.code
133            ));
134        }
135
136        let result: ListToolsResult =
137            serde_json::from_value(response.result.ok_or_else(|| anyhow!("No result"))?)?;
138
139        // Cache tools
140        {
141            let mut tools = self.tools.write().await;
142            *tools = result.tools.clone();
143        }
144
145        Ok(result.tools)
146    }
147
148    /// Get cached tools
149    pub async fn get_cached_tools(&self) -> Vec<McpTool> {
150        self.tools.read().await.clone()
151    }
152
153    /// Call a tool
154    pub async fn call_tool(
155        &self,
156        name: &str,
157        arguments: Option<serde_json::Value>,
158    ) -> Result<CallToolResult> {
159        let params = CallToolParams {
160            name: name.to_string(),
161            arguments,
162        };
163
164        let request = JsonRpcRequest::new(
165            self.next_id(),
166            "tools/call",
167            Some(serde_json::to_value(&params)?),
168        );
169
170        let response = self.transport.request(request).await?;
171
172        if let Some(error) = response.error {
173            return Err(anyhow!(
174                "MCP call_tool error: {} ({})",
175                error.message,
176                error.code
177            ));
178        }
179
180        let result: CallToolResult =
181            serde_json::from_value(response.result.ok_or_else(|| anyhow!("No result"))?)?;
182
183        Ok(result)
184    }
185
186    /// List available resources
187    pub async fn list_resources(&self) -> Result<Vec<McpResource>> {
188        let request = JsonRpcRequest::new(self.next_id(), "resources/list", None);
189        let response = self.transport.request(request).await?;
190
191        if let Some(error) = response.error {
192            return Err(anyhow!(
193                "MCP list_resources error: {} ({})",
194                error.message,
195                error.code
196            ));
197        }
198
199        let result: ListResourcesResult =
200            serde_json::from_value(response.result.ok_or_else(|| anyhow!("No result"))?)?;
201
202        // Cache resources
203        {
204            let mut resources = self.resources.write().await;
205            *resources = result.resources.clone();
206        }
207
208        Ok(result.resources)
209    }
210
211    /// Read a resource
212    pub async fn read_resource(&self, uri: &str) -> Result<ReadResourceResult> {
213        let params = ReadResourceParams {
214            uri: uri.to_string(),
215        };
216
217        let request = JsonRpcRequest::new(
218            self.next_id(),
219            "resources/read",
220            Some(serde_json::to_value(&params)?),
221        );
222
223        let response = self.transport.request(request).await?;
224
225        if let Some(error) = response.error {
226            return Err(anyhow!(
227                "MCP read_resource error: {} ({})",
228                error.message,
229                error.code
230            ));
231        }
232
233        let result: ReadResourceResult =
234            serde_json::from_value(response.result.ok_or_else(|| anyhow!("No result"))?)?;
235
236        Ok(result)
237    }
238
239    /// Get notification receiver
240    pub fn notifications(&self) -> tokio::sync::mpsc::Receiver<McpNotification> {
241        self.transport.notifications()
242    }
243
244    /// Close the client
245    pub async fn close(&self) -> Result<()> {
246        self.transport.close().await
247    }
248
249    /// Check if connected
250    pub fn is_connected(&self) -> bool {
251        self.transport.is_connected()
252    }
253}
254
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_client_info() {
        let serialized = serde_json::to_string(&ClientInfo {
            name: String::from("test"),
            version: String::from("1.0.0"),
        })
        .expect("ClientInfo should serialize");
        assert!(serialized.contains("test"));
    }

    #[test]
    fn test_initialize_params() {
        let serialized = serde_json::to_string(&InitializeParams {
            protocol_version: String::from(PROTOCOL_VERSION),
            capabilities: ClientCapabilities::default(),
            client_info: ClientInfo {
                name: String::from("a3s-code"),
                version: String::from("0.1.0"),
            },
        })
        .expect("InitializeParams should serialize");
        for key in ["protocolVersion", "clientInfo"] {
            assert!(serialized.contains(key), "missing `{}` in {}", key, serialized);
        }
    }

    #[test]
    fn test_client_info_serialize() {
        let serialized = serde_json::to_string(&ClientInfo {
            name: String::from("test-client"),
            version: String::from("2.0.0"),
        })
        .expect("ClientInfo should serialize");
        for needle in ["test-client", "2.0.0"] {
            assert!(serialized.contains(needle), "missing `{}` in {}", needle, serialized);
        }
    }

    #[test]
    fn test_client_info_deserialize() {
        let parsed: ClientInfo =
            serde_json::from_str(r#"{"name":"my-client","version":"1.2.3"}"#)
                .expect("ClientInfo should deserialize");
        assert_eq!(
            (parsed.name.as_str(), parsed.version.as_str()),
            ("my-client", "1.2.3")
        );
    }

    #[test]
    fn test_initialize_params_serialize() {
        let serialized = serde_json::to_string(&InitializeParams {
            protocol_version: String::from("2024-11-05"),
            capabilities: ClientCapabilities::default(),
            client_info: ClientInfo {
                name: String::from("test"),
                version: String::from("1.0.0"),
            },
        })
        .expect("InitializeParams should serialize");
        for needle in ["2024-11-05", "capabilities"] {
            assert!(serialized.contains(needle), "missing `{}` in {}", needle, serialized);
        }
    }

    #[test]
    fn test_call_tool_params_serialize() {
        let serialized = serde_json::to_string(&CallToolParams {
            name: String::from("test_tool"),
            arguments: Some(serde_json::json!({"key": "value"})),
        })
        .expect("CallToolParams should serialize");
        for needle in ["test_tool", "key"] {
            assert!(serialized.contains(needle), "missing `{}` in {}", needle, serialized);
        }
    }

    #[test]
    fn test_call_tool_params_no_arguments() {
        let serialized = serde_json::to_string(&CallToolParams {
            name: String::from("simple_tool"),
            arguments: None,
        })
        .expect("CallToolParams should serialize");
        assert!(serialized.contains("simple_tool"));
    }

    #[test]
    fn test_read_resource_params_serialize() {
        let serialized = serde_json::to_string(&ReadResourceParams {
            uri: String::from("file:///test.txt"),
        })
        .expect("ReadResourceParams should serialize");
        assert!(serialized.contains("file:///test.txt"));
    }

    #[test]
    fn test_read_resource_params_deserialize() {
        let parsed: ReadResourceParams =
            serde_json::from_str(r#"{"uri":"http://example.com/resource"}"#)
                .expect("ReadResourceParams should deserialize");
        assert_eq!(parsed.uri, "http://example.com/resource");
    }

    #[test]
    fn test_server_capabilities_default() {
        let serialized = serde_json::to_string(&ServerCapabilities::default())
            .expect("ServerCapabilities should serialize");
        assert!(!serialized.is_empty());
    }

    #[test]
    fn test_client_capabilities_default() {
        let serialized = serde_json::to_string(&ClientCapabilities::default())
            .expect("ClientCapabilities should serialize");
        assert!(!serialized.is_empty());
    }

    #[test]
    fn test_protocol_version_constant() {
        assert!(!PROTOCOL_VERSION.is_empty());
        assert!(PROTOCOL_VERSION.contains('-'));
    }
}