1use std::sync::Arc;
4use tracing::{debug, info};
5
6use crate::error::McpError;
7use crate::protocol::*;
8use crate::transport::McpTransport;
9
10pub struct McpClient<T: McpTransport> {
12 transport: Arc<T>,
13 server_info: Option<Implementation>,
14 server_capabilities: Option<ServerCapabilities>,
15 tools_cache: Option<Vec<McpTool>>,
16}
17
18impl<T: McpTransport> McpClient<T> {
19 pub async fn new(transport: T) -> Result<Self, McpError> {
21 let mut client = Self {
22 transport: Arc::new(transport),
23 server_info: None,
24 server_capabilities: None,
25 tools_cache: None,
26 };
27
28 client.initialize().await?;
29 Ok(client)
30 }
31
32 pub fn new_uninit(transport: T) -> Self {
34 Self {
35 transport: Arc::new(transport),
36 server_info: None,
37 server_capabilities: None,
38 tools_cache: None,
39 }
40 }
41
42 async fn initialize(&mut self) -> Result<(), McpError> {
44 let params = InitializeParams {
45 protocol_version: PROTOCOL_VERSION.to_string(),
46 capabilities: ClientCapabilities {
47 roots: Some(RootsCapability::default()),
48 sampling: None,
49 experimental: None,
50 },
51 client_info: Implementation {
52 name: "cortex".to_string(),
53 version: env!("CARGO_PKG_VERSION").to_string(),
54 },
55 };
56
57 let request =
58 JsonRpcRequest::new(1i64, "initialize").with_params(serde_json::to_value(¶ms)?);
59
60 let response = self.transport.request(request).await?;
61
62 if let Some(error) = response.error {
63 return Err(McpError::JsonRpc {
64 code: error.code,
65 message: error.message,
66 });
67 }
68
69 let result: InitializeResult =
70 serde_json::from_value(response.result.ok_or_else(|| {
71 McpError::InvalidResponse("No result in initialize response".to_string())
72 })?)?;
73
74 if result.protocol_version != PROTOCOL_VERSION {
76 debug!(
78 "Protocol version mismatch: client={}, server={}",
79 PROTOCOL_VERSION, result.protocol_version
80 );
81 }
82
83 self.server_info = Some(result.server_info.clone());
84 self.server_capabilities = Some(result.capabilities.clone());
85
86 info!(
87 "MCP initialized: server={} v{}, tools={}, resources={}, prompts={}",
88 result.server_info.name,
89 result.server_info.version,
90 result.capabilities.tools.is_some(),
91 result.capabilities.resources.is_some(),
92 result.capabilities.prompts.is_some(),
93 );
94
95 self.transport
97 .notify("notifications/initialized", None)
98 .await?;
99
100 Ok(())
101 }
102
103 pub fn server_info(&self) -> Option<&Implementation> {
105 self.server_info.as_ref()
106 }
107
108 pub fn capabilities(&self) -> Option<&ServerCapabilities> {
110 self.server_capabilities.as_ref()
111 }
112
113 pub fn supports_tools(&self) -> bool {
115 self.server_capabilities
116 .as_ref()
117 .map(|c| c.tools.is_some())
118 .unwrap_or(false)
119 }
120
121 pub fn supports_resources(&self) -> bool {
123 self.server_capabilities
124 .as_ref()
125 .map(|c| c.resources.is_some())
126 .unwrap_or(false)
127 }
128
129 pub fn supports_prompts(&self) -> bool {
131 self.server_capabilities
132 .as_ref()
133 .map(|c| c.prompts.is_some())
134 .unwrap_or(false)
135 }
136
137 pub async fn list_tools(&mut self) -> Result<Vec<McpTool>, McpError> {
143 if !self.supports_tools() {
144 return Err(McpError::CapabilityNotSupported("tools".to_string()));
145 }
146
147 let mut all_tools = Vec::new();
148 let mut cursor: Option<String> = None;
149
150 loop {
151 let params = ListToolsParams {
152 cursor: cursor.clone(),
153 };
154 let request =
155 JsonRpcRequest::new(0i64, "tools/list").with_params(serde_json::to_value(¶ms)?);
156
157 let response = self.transport.request(request).await?;
158
159 if let Some(error) = response.error {
160 return Err(McpError::JsonRpc {
161 code: error.code,
162 message: error.message,
163 });
164 }
165
166 let result: ListToolsResult =
167 serde_json::from_value(response.result.ok_or_else(|| {
168 McpError::InvalidResponse("No result in list_tools response".to_string())
169 })?)?;
170
171 all_tools.extend(result.tools);
172
173 match result.next_cursor {
174 Some(next) => cursor = Some(next),
175 None => break,
176 }
177 }
178
179 self.tools_cache = Some(all_tools.clone());
180 Ok(all_tools)
181 }
182
183 pub async fn get_tools(&mut self) -> Result<&[McpTool], McpError> {
185 if self.tools_cache.is_none() {
186 self.list_tools().await?;
187 }
188 Ok(self.tools_cache.as_ref().unwrap())
189 }
190
191 pub async fn call_tool(
193 &self,
194 name: &str,
195 arguments: serde_json::Value,
196 ) -> Result<CallToolResult, McpError> {
197 let params = CallToolParams {
198 name: name.to_string(),
199 arguments: Some(arguments),
200 };
201
202 let request =
203 JsonRpcRequest::new(0i64, "tools/call").with_params(serde_json::to_value(¶ms)?);
204
205 let response = self.transport.request(request).await?;
206
207 if let Some(error) = response.error {
208 return Err(McpError::JsonRpc {
209 code: error.code,
210 message: error.message,
211 });
212 }
213
214 let result: CallToolResult = serde_json::from_value(response.result.ok_or_else(|| {
215 McpError::InvalidResponse("No result in call_tool response".to_string())
216 })?)?;
217
218 Ok(result)
219 }
220
221 pub async fn list_resources(&self) -> Result<Vec<McpResource>, McpError> {
227 if !self.supports_resources() {
228 return Err(McpError::CapabilityNotSupported("resources".to_string()));
229 }
230
231 let mut all_resources = Vec::new();
232 let mut cursor: Option<String> = None;
233
234 loop {
235 let params = serde_json::json!({ "cursor": cursor });
236 let request = JsonRpcRequest::new(0i64, "resources/list").with_params(params);
237
238 let response = self.transport.request(request).await?;
239
240 if let Some(error) = response.error {
241 return Err(McpError::JsonRpc {
242 code: error.code,
243 message: error.message,
244 });
245 }
246
247 let result: ListResourcesResult =
248 serde_json::from_value(response.result.ok_or_else(|| {
249 McpError::InvalidResponse("No result in list_resources response".to_string())
250 })?)?;
251
252 all_resources.extend(result.resources);
253
254 match result.next_cursor {
255 Some(next) => cursor = Some(next),
256 None => break,
257 }
258 }
259
260 Ok(all_resources)
261 }
262
263 pub async fn read_resource(&self, uri: &str) -> Result<ResourceContent, McpError> {
265 let params = serde_json::json!({ "uri": uri });
266 let request = JsonRpcRequest::new(0i64, "resources/read").with_params(params);
267
268 let response = self.transport.request(request).await?;
269
270 if let Some(error) = response.error {
271 return Err(McpError::JsonRpc {
272 code: error.code,
273 message: error.message,
274 });
275 }
276
277 #[derive(serde::Deserialize)]
278 struct ReadResult {
279 contents: Vec<ResourceContent>,
280 }
281
282 let result: ReadResult = serde_json::from_value(response.result.ok_or_else(|| {
283 McpError::InvalidResponse("No result in read_resource response".to_string())
284 })?)?;
285
286 result.contents.into_iter().next().ok_or_else(|| {
287 McpError::InvalidResponse("Empty contents in read_resource response".to_string())
288 })
289 }
290
291 pub async fn list_prompts(&self) -> Result<Vec<McpPrompt>, McpError> {
297 if !self.supports_prompts() {
298 return Err(McpError::CapabilityNotSupported("prompts".to_string()));
299 }
300
301 let mut all_prompts = Vec::new();
302 let mut cursor: Option<String> = None;
303
304 loop {
305 let params = serde_json::json!({ "cursor": cursor });
306 let request = JsonRpcRequest::new(0i64, "prompts/list").with_params(params);
307
308 let response = self.transport.request(request).await?;
309
310 if let Some(error) = response.error {
311 return Err(McpError::JsonRpc {
312 code: error.code,
313 message: error.message,
314 });
315 }
316
317 let result: ListPromptsResult =
318 serde_json::from_value(response.result.ok_or_else(|| {
319 McpError::InvalidResponse("No result in list_prompts response".to_string())
320 })?)?;
321
322 all_prompts.extend(result.prompts);
323
324 match result.next_cursor {
325 Some(next) => cursor = Some(next),
326 None => break,
327 }
328 }
329
330 Ok(all_prompts)
331 }
332
333 pub async fn close(self) -> Result<(), McpError> {
339 self.transport.close().await
340 }
341}
342
#[cfg(test)]
mod tests {
    use super::*;

    /// A default `ClientCapabilities` advertises neither roots nor sampling.
    #[test]
    fn test_client_capabilities() {
        let ClientCapabilities {
            roots, sampling, ..
        } = ClientCapabilities::default();

        assert!(roots.is_none());
        assert!(sampling.is_none());
    }
}