Skip to main content

cortexai_mcp/
client.rs

1//! MCP Client implementation
2
3use std::sync::Arc;
4use tracing::{debug, info};
5
6use crate::error::McpError;
7use crate::protocol::*;
8use crate::transport::McpTransport;
9
10/// MCP Client for communicating with MCP servers
11pub struct McpClient<T: McpTransport> {
12    transport: Arc<T>,
13    server_info: Option<Implementation>,
14    server_capabilities: Option<ServerCapabilities>,
15    tools_cache: Option<Vec<McpTool>>,
16}
17
18impl<T: McpTransport> McpClient<T> {
19    /// Create a new MCP client and initialize the connection
20    pub async fn new(transport: T) -> Result<Self, McpError> {
21        let mut client = Self {
22            transport: Arc::new(transport),
23            server_info: None,
24            server_capabilities: None,
25            tools_cache: None,
26        };
27
28        client.initialize().await?;
29        Ok(client)
30    }
31
32    /// Create without auto-initialization (for testing)
33    pub fn new_uninit(transport: T) -> Self {
34        Self {
35            transport: Arc::new(transport),
36            server_info: None,
37            server_capabilities: None,
38            tools_cache: None,
39        }
40    }
41
42    /// Initialize the MCP connection
43    async fn initialize(&mut self) -> Result<(), McpError> {
44        let params = InitializeParams {
45            protocol_version: PROTOCOL_VERSION.to_string(),
46            capabilities: ClientCapabilities {
47                roots: Some(RootsCapability::default()),
48                sampling: None,
49                experimental: None,
50            },
51            client_info: Implementation {
52                name: "cortex".to_string(),
53                version: env!("CARGO_PKG_VERSION").to_string(),
54            },
55        };
56
57        let request =
58            JsonRpcRequest::new(1i64, "initialize").with_params(serde_json::to_value(&params)?);
59
60        let response = self.transport.request(request).await?;
61
62        if let Some(error) = response.error {
63            return Err(McpError::JsonRpc {
64                code: error.code,
65                message: error.message,
66            });
67        }
68
69        let result: InitializeResult =
70            serde_json::from_value(response.result.ok_or_else(|| {
71                McpError::InvalidResponse("No result in initialize response".to_string())
72            })?)?;
73
74        // Check version compatibility
75        if result.protocol_version != PROTOCOL_VERSION {
76            // Log warning but continue - servers may support multiple versions
77            debug!(
78                "Protocol version mismatch: client={}, server={}",
79                PROTOCOL_VERSION, result.protocol_version
80            );
81        }
82
83        self.server_info = Some(result.server_info.clone());
84        self.server_capabilities = Some(result.capabilities.clone());
85
86        info!(
87            "MCP initialized: server={} v{}, tools={}, resources={}, prompts={}",
88            result.server_info.name,
89            result.server_info.version,
90            result.capabilities.tools.is_some(),
91            result.capabilities.resources.is_some(),
92            result.capabilities.prompts.is_some(),
93        );
94
95        // Send initialized notification
96        self.transport
97            .notify("notifications/initialized", None)
98            .await?;
99
100        Ok(())
101    }
102
103    /// Get server info
104    pub fn server_info(&self) -> Option<&Implementation> {
105        self.server_info.as_ref()
106    }
107
108    /// Get server capabilities
109    pub fn capabilities(&self) -> Option<&ServerCapabilities> {
110        self.server_capabilities.as_ref()
111    }
112
113    /// Check if server supports tools
114    pub fn supports_tools(&self) -> bool {
115        self.server_capabilities
116            .as_ref()
117            .map(|c| c.tools.is_some())
118            .unwrap_or(false)
119    }
120
121    /// Check if server supports resources
122    pub fn supports_resources(&self) -> bool {
123        self.server_capabilities
124            .as_ref()
125            .map(|c| c.resources.is_some())
126            .unwrap_or(false)
127    }
128
129    /// Check if server supports prompts
130    pub fn supports_prompts(&self) -> bool {
131        self.server_capabilities
132            .as_ref()
133            .map(|c| c.prompts.is_some())
134            .unwrap_or(false)
135    }
136
137    // =========================================================================
138    // Tool Operations
139    // =========================================================================
140
141    /// List available tools
142    pub async fn list_tools(&mut self) -> Result<Vec<McpTool>, McpError> {
143        if !self.supports_tools() {
144            return Err(McpError::CapabilityNotSupported("tools".to_string()));
145        }
146
147        let mut all_tools = Vec::new();
148        let mut cursor: Option<String> = None;
149
150        loop {
151            let params = ListToolsParams {
152                cursor: cursor.clone(),
153            };
154            let request =
155                JsonRpcRequest::new(0i64, "tools/list").with_params(serde_json::to_value(&params)?);
156
157            let response = self.transport.request(request).await?;
158
159            if let Some(error) = response.error {
160                return Err(McpError::JsonRpc {
161                    code: error.code,
162                    message: error.message,
163                });
164            }
165
166            let result: ListToolsResult =
167                serde_json::from_value(response.result.ok_or_else(|| {
168                    McpError::InvalidResponse("No result in list_tools response".to_string())
169                })?)?;
170
171            all_tools.extend(result.tools);
172
173            match result.next_cursor {
174                Some(next) => cursor = Some(next),
175                None => break,
176            }
177        }
178
179        self.tools_cache = Some(all_tools.clone());
180        Ok(all_tools)
181    }
182
183    /// Get cached tools (or fetch if not cached)
184    pub async fn get_tools(&mut self) -> Result<&[McpTool], McpError> {
185        if self.tools_cache.is_none() {
186            self.list_tools().await?;
187        }
188        Ok(self.tools_cache.as_ref().unwrap())
189    }
190
191    /// Call a tool
192    pub async fn call_tool(
193        &self,
194        name: &str,
195        arguments: serde_json::Value,
196    ) -> Result<CallToolResult, McpError> {
197        let params = CallToolParams {
198            name: name.to_string(),
199            arguments: Some(arguments),
200        };
201
202        let request =
203            JsonRpcRequest::new(0i64, "tools/call").with_params(serde_json::to_value(&params)?);
204
205        let response = self.transport.request(request).await?;
206
207        if let Some(error) = response.error {
208            return Err(McpError::JsonRpc {
209                code: error.code,
210                message: error.message,
211            });
212        }
213
214        let result: CallToolResult = serde_json::from_value(response.result.ok_or_else(|| {
215            McpError::InvalidResponse("No result in call_tool response".to_string())
216        })?)?;
217
218        Ok(result)
219    }
220
221    // =========================================================================
222    // Resource Operations
223    // =========================================================================
224
225    /// List available resources
226    pub async fn list_resources(&self) -> Result<Vec<McpResource>, McpError> {
227        if !self.supports_resources() {
228            return Err(McpError::CapabilityNotSupported("resources".to_string()));
229        }
230
231        let mut all_resources = Vec::new();
232        let mut cursor: Option<String> = None;
233
234        loop {
235            let params = serde_json::json!({ "cursor": cursor });
236            let request = JsonRpcRequest::new(0i64, "resources/list").with_params(params);
237
238            let response = self.transport.request(request).await?;
239
240            if let Some(error) = response.error {
241                return Err(McpError::JsonRpc {
242                    code: error.code,
243                    message: error.message,
244                });
245            }
246
247            let result: ListResourcesResult =
248                serde_json::from_value(response.result.ok_or_else(|| {
249                    McpError::InvalidResponse("No result in list_resources response".to_string())
250                })?)?;
251
252            all_resources.extend(result.resources);
253
254            match result.next_cursor {
255                Some(next) => cursor = Some(next),
256                None => break,
257            }
258        }
259
260        Ok(all_resources)
261    }
262
263    /// Read a resource
264    pub async fn read_resource(&self, uri: &str) -> Result<ResourceContent, McpError> {
265        let params = serde_json::json!({ "uri": uri });
266        let request = JsonRpcRequest::new(0i64, "resources/read").with_params(params);
267
268        let response = self.transport.request(request).await?;
269
270        if let Some(error) = response.error {
271            return Err(McpError::JsonRpc {
272                code: error.code,
273                message: error.message,
274            });
275        }
276
277        #[derive(serde::Deserialize)]
278        struct ReadResult {
279            contents: Vec<ResourceContent>,
280        }
281
282        let result: ReadResult = serde_json::from_value(response.result.ok_or_else(|| {
283            McpError::InvalidResponse("No result in read_resource response".to_string())
284        })?)?;
285
286        result.contents.into_iter().next().ok_or_else(|| {
287            McpError::InvalidResponse("Empty contents in read_resource response".to_string())
288        })
289    }
290
291    // =========================================================================
292    // Prompt Operations
293    // =========================================================================
294
295    /// List available prompts
296    pub async fn list_prompts(&self) -> Result<Vec<McpPrompt>, McpError> {
297        if !self.supports_prompts() {
298            return Err(McpError::CapabilityNotSupported("prompts".to_string()));
299        }
300
301        let mut all_prompts = Vec::new();
302        let mut cursor: Option<String> = None;
303
304        loop {
305            let params = serde_json::json!({ "cursor": cursor });
306            let request = JsonRpcRequest::new(0i64, "prompts/list").with_params(params);
307
308            let response = self.transport.request(request).await?;
309
310            if let Some(error) = response.error {
311                return Err(McpError::JsonRpc {
312                    code: error.code,
313                    message: error.message,
314                });
315            }
316
317            let result: ListPromptsResult =
318                serde_json::from_value(response.result.ok_or_else(|| {
319                    McpError::InvalidResponse("No result in list_prompts response".to_string())
320                })?)?;
321
322            all_prompts.extend(result.prompts);
323
324            match result.next_cursor {
325                Some(next) => cursor = Some(next),
326                None => break,
327            }
328        }
329
330        Ok(all_prompts)
331    }
332
333    // =========================================================================
334    // Lifecycle
335    // =========================================================================
336
337    /// Close the connection
338    pub async fn close(self) -> Result<(), McpError> {
339        self.transport.close().await
340    }
341}
342
343#[cfg(test)]
344mod tests {
345    use super::*;
346
347    #[test]
348    fn test_client_capabilities() {
349        let caps = ClientCapabilities::default();
350        assert!(caps.roots.is_none());
351        assert!(caps.sampling.is_none());
352    }
353}