pulseengine_mcp_client/
client.rs

1//! MCP Client implementation
2//!
3//! The main client struct for interacting with MCP servers.
4
5use crate::error::{ClientError, ClientResult};
6use crate::transport::{ClientTransport, JsonRpcMessage, next_request_id};
7use pulseengine_mcp_protocol::{
8    CallToolRequestParam, CallToolResult, CompleteRequestParam, CompleteResult,
9    GetPromptRequestParam, GetPromptResult, Implementation, InitializeRequestParam,
10    InitializeResult, ListPromptsResult, ListResourceTemplatesResult, ListResourcesResult,
11    ListToolsResult, NumberOrString, PaginatedRequestParam, ReadResourceRequestParam,
12    ReadResourceResult, Request, Response,
13};
14use serde_json::json;
15use std::collections::HashMap;
16use std::sync::Arc;
17use std::time::Duration;
18use tokio::sync::{Mutex, oneshot};
19use tracing::{debug, info, warn};
20
/// Request timeout used when none is configured via `McpClient::with_timeout` (30 seconds).
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(30_000);
23
24/// MCP Client for connecting to MCP servers
25///
26/// Provides a high-level API for interacting with MCP servers,
27/// handling request/response correlation and protocol details.
28pub struct McpClient<T: ClientTransport> {
29    transport: Arc<T>,
30    /// Pending requests waiting for responses
31    pending: Arc<Mutex<HashMap<String, oneshot::Sender<Response>>>>,
32    /// Server info after initialization
33    server_info: Option<InitializeResult>,
34    /// Default request timeout
35    timeout: Duration,
36    /// Client info sent during initialization
37    client_info: Implementation,
38}
39
40impl<T: ClientTransport + 'static> McpClient<T> {
41    /// Create a new MCP client with the given transport
42    pub fn new(transport: T) -> Self {
43        Self {
44            transport: Arc::new(transport),
45            pending: Arc::new(Mutex::new(HashMap::new())),
46            server_info: None,
47            timeout: DEFAULT_TIMEOUT,
48            client_info: Implementation::new("pulseengine-mcp-client", env!("CARGO_PKG_VERSION")),
49        }
50    }
51
52    /// Set the default request timeout
53    pub fn with_timeout(mut self, timeout: Duration) -> Self {
54        self.timeout = timeout;
55        self
56    }
57
58    /// Set the client info for initialization
59    pub fn with_client_info(mut self, name: &str, version: &str) -> Self {
60        self.client_info = Implementation::new(name, version);
61        self
62    }
63
64    /// Get the server info (available after initialization)
65    pub fn server_info(&self) -> Option<&InitializeResult> {
66        self.server_info.as_ref()
67    }
68
69    /// Check if the client has been initialized
70    pub fn is_initialized(&self) -> bool {
71        self.server_info.is_some()
72    }
73
74    /// Initialize the connection with the server
75    ///
76    /// This must be called before any other methods.
77    pub async fn initialize(
78        &mut self,
79        client_name: &str,
80        client_version: &str,
81    ) -> ClientResult<InitializeResult> {
82        self.client_info = Implementation::new(client_name, client_version);
83
84        let params = InitializeRequestParam {
85            protocol_version: pulseengine_mcp_protocol::MCP_VERSION.to_string(),
86            capabilities: json!({}), // Empty capabilities - server will respond with its capabilities
87            client_info: self.client_info.clone(),
88        };
89
90        let result: InitializeResult = self.request("initialize", params).await?;
91
92        info!(
93            "Initialized with server: {} v{}",
94            result.server_info.name, result.server_info.version
95        );
96
97        self.server_info = Some(result.clone());
98
99        // Send initialized notification
100        self.notify("notifications/initialized", json!({})).await?;
101
102        Ok(result)
103    }
104
105    // =========================================================================
106    // Tools API
107    // =========================================================================
108
109    /// List available tools from the server
110    pub async fn list_tools(&self) -> ClientResult<ListToolsResult> {
111        self.ensure_initialized()?;
112        self.request("tools/list", PaginatedRequestParam { cursor: None })
113            .await
114    }
115
116    /// List all tools, automatically handling pagination
117    pub async fn list_all_tools(&self) -> ClientResult<Vec<pulseengine_mcp_protocol::Tool>> {
118        self.ensure_initialized()?;
119        let mut all_tools = Vec::new();
120        let mut cursor = None;
121
122        loop {
123            let result: ListToolsResult = self
124                .request("tools/list", PaginatedRequestParam { cursor })
125                .await?;
126
127            all_tools.extend(result.tools);
128
129            match result.next_cursor {
130                Some(next) => cursor = Some(next),
131                None => break,
132            }
133        }
134
135        Ok(all_tools)
136    }
137
138    /// Call a tool on the server
139    pub async fn call_tool(
140        &self,
141        name: &str,
142        arguments: serde_json::Value,
143    ) -> ClientResult<CallToolResult> {
144        self.ensure_initialized()?;
145        self.request(
146            "tools/call",
147            CallToolRequestParam {
148                name: name.to_string(),
149                arguments: Some(arguments),
150            },
151        )
152        .await
153    }
154
155    // =========================================================================
156    // Resources API
157    // =========================================================================
158
159    /// List available resources from the server
160    pub async fn list_resources(&self) -> ClientResult<ListResourcesResult> {
161        self.ensure_initialized()?;
162        self.request("resources/list", PaginatedRequestParam { cursor: None })
163            .await
164    }
165
166    /// List all resources, automatically handling pagination
167    pub async fn list_all_resources(
168        &self,
169    ) -> ClientResult<Vec<pulseengine_mcp_protocol::Resource>> {
170        self.ensure_initialized()?;
171        let mut all_resources = Vec::new();
172        let mut cursor = None;
173
174        loop {
175            let result: ListResourcesResult = self
176                .request("resources/list", PaginatedRequestParam { cursor })
177                .await?;
178
179            all_resources.extend(result.resources);
180
181            match result.next_cursor {
182                Some(next) => cursor = Some(next),
183                None => break,
184            }
185        }
186
187        Ok(all_resources)
188    }
189
190    /// Read a resource from the server
191    pub async fn read_resource(&self, uri: &str) -> ClientResult<ReadResourceResult> {
192        self.ensure_initialized()?;
193        self.request(
194            "resources/read",
195            ReadResourceRequestParam {
196                uri: uri.to_string(),
197            },
198        )
199        .await
200    }
201
202    /// List resource templates from the server
203    pub async fn list_resource_templates(&self) -> ClientResult<ListResourceTemplatesResult> {
204        self.ensure_initialized()?;
205        self.request(
206            "resources/templates/list",
207            PaginatedRequestParam { cursor: None },
208        )
209        .await
210    }
211
212    // =========================================================================
213    // Prompts API
214    // =========================================================================
215
216    /// List available prompts from the server
217    pub async fn list_prompts(&self) -> ClientResult<ListPromptsResult> {
218        self.ensure_initialized()?;
219        self.request("prompts/list", PaginatedRequestParam { cursor: None })
220            .await
221    }
222
223    /// List all prompts, automatically handling pagination
224    pub async fn list_all_prompts(&self) -> ClientResult<Vec<pulseengine_mcp_protocol::Prompt>> {
225        self.ensure_initialized()?;
226        let mut all_prompts = Vec::new();
227        let mut cursor = None;
228
229        loop {
230            let result: ListPromptsResult = self
231                .request("prompts/list", PaginatedRequestParam { cursor })
232                .await?;
233
234            all_prompts.extend(result.prompts);
235
236            match result.next_cursor {
237                Some(next) => cursor = Some(next),
238                None => break,
239            }
240        }
241
242        Ok(all_prompts)
243    }
244
245    /// Get a prompt by name
246    pub async fn get_prompt(
247        &self,
248        name: &str,
249        arguments: Option<HashMap<String, String>>,
250    ) -> ClientResult<GetPromptResult> {
251        self.ensure_initialized()?;
252        self.request(
253            "prompts/get",
254            GetPromptRequestParam {
255                name: name.to_string(),
256                arguments,
257            },
258        )
259        .await
260    }
261
262    // =========================================================================
263    // Completion API
264    // =========================================================================
265
266    /// Request completion suggestions
267    pub async fn complete(&self, params: CompleteRequestParam) -> ClientResult<CompleteResult> {
268        self.ensure_initialized()?;
269        self.request("completion/complete", params).await
270    }
271
272    // =========================================================================
273    // Utility Methods
274    // =========================================================================
275
276    /// Send a ping to the server
277    pub async fn ping(&self) -> ClientResult<()> {
278        self.ensure_initialized()?;
279        let _: serde_json::Value = self.request("ping", json!({})).await?;
280        Ok(())
281    }
282
283    /// Close the client connection
284    pub async fn close(&self) -> ClientResult<()> {
285        self.transport.close().await
286    }
287
288    // =========================================================================
289    // Notification Methods
290    // =========================================================================
291
292    /// Send a progress notification
293    pub async fn notify_progress(
294        &self,
295        progress_token: &str,
296        progress: f64,
297        total: Option<f64>,
298    ) -> ClientResult<()> {
299        self.notify(
300            "notifications/progress",
301            json!({
302                "progressToken": progress_token,
303                "progress": progress,
304                "total": total,
305            }),
306        )
307        .await
308    }
309
310    /// Send a cancellation notification
311    pub async fn notify_cancelled(
312        &self,
313        request_id: &str,
314        reason: Option<&str>,
315    ) -> ClientResult<()> {
316        self.notify(
317            "notifications/cancelled",
318            json!({
319                "requestId": request_id,
320                "reason": reason,
321            }),
322        )
323        .await
324    }
325
326    /// Send a roots list changed notification
327    pub async fn notify_roots_list_changed(&self) -> ClientResult<()> {
328        self.notify("notifications/roots/list_changed", json!({}))
329            .await
330    }
331
332    // =========================================================================
333    // Internal Methods
334    // =========================================================================
335
336    /// Ensure the client has been initialized
337    fn ensure_initialized(&self) -> ClientResult<()> {
338        if self.server_info.is_none() {
339            return Err(ClientError::NotInitialized);
340        }
341        Ok(())
342    }
343
344    /// Send a request and wait for the response
345    async fn request<P, R>(&self, method: &str, params: P) -> ClientResult<R>
346    where
347        P: serde::Serialize,
348        R: serde::de::DeserializeOwned,
349    {
350        let id = next_request_id();
351        let id_str = match &id {
352            NumberOrString::Number(n) => n.to_string(),
353            NumberOrString::String(s) => s.to_string(),
354        };
355
356        let request = Request {
357            jsonrpc: "2.0".to_string(),
358            method: method.to_string(),
359            params: serde_json::to_value(params)?,
360            id: Some(id),
361        };
362
363        // Create channel for response
364        let (tx, rx) = oneshot::channel();
365
366        // Register pending request
367        {
368            let mut pending = self.pending.lock().await;
369            pending.insert(id_str.clone(), tx);
370        }
371
372        // Send request
373        self.transport.send(&request).await?;
374
375        debug!("Sent request: method={}, id={}", method, id_str);
376
377        // Wait for response with timeout
378        let response = tokio::select! {
379            result = self.wait_for_response(rx) => result?,
380            _ = tokio::time::sleep(self.timeout) => {
381                // Remove from pending on timeout
382                let mut pending = self.pending.lock().await;
383                pending.remove(&id_str);
384                return Err(ClientError::Timeout(self.timeout));
385            }
386        };
387
388        // Check for error response
389        if let Some(error) = response.error {
390            return Err(ClientError::from_protocol_error(error));
391        }
392
393        // Parse result
394        let result = response
395            .result
396            .ok_or_else(|| ClientError::protocol("Response has no result or error"))?;
397
398        serde_json::from_value(result).map_err(ClientError::from)
399    }
400
401    /// Wait for a response and handle incoming messages
402    async fn wait_for_response(
403        &self,
404        mut rx: oneshot::Receiver<Response>,
405    ) -> ClientResult<Response> {
406        // In a simple implementation, we just read messages until we get our response
407        // A more sophisticated implementation would use a background task
408        loop {
409            tokio::select! {
410                biased;
411
412                // Check if response arrived via channel (priority)
413                result = &mut rx => {
414                    return result.map_err(|_| ClientError::ChannelClosed("Response channel closed".into()));
415                }
416                // Read next message from transport
417                msg = self.transport.recv() => {
418                    match msg? {
419                        JsonRpcMessage::Response(response) => {
420                            // Route response to waiting request
421                            let id_str = response.id.as_ref().map(|id| match id {
422                                NumberOrString::Number(n) => n.to_string(),
423                                NumberOrString::String(s) => s.to_string(),
424                            });
425
426                            if let Some(id) = id_str {
427                                let mut pending = self.pending.lock().await;
428                                if let Some(tx) = pending.remove(&id) {
429                                    let _ = tx.send(response);
430                                } else {
431                                    warn!("Received response for unknown request: {}", id);
432                                }
433                            }
434                        }
435                        JsonRpcMessage::Request(request) => {
436                            // Handle server-initiated request (sampling, etc.)
437                            // For now, log and continue - could add a handler callback
438                            warn!("Received server request (not yet handled): {}", request.method);
439                        }
440                        JsonRpcMessage::Notification { method, params: _ } => {
441                            // Handle notification from server
442                            debug!("Received notification: {}", method);
443                        }
444                    }
445                }
446            }
447        }
448    }
449
450    /// Send a notification (no response expected)
451    async fn notify<P>(&self, method: &str, params: P) -> ClientResult<()>
452    where
453        P: serde::Serialize,
454    {
455        let request = Request {
456            jsonrpc: "2.0".to_string(),
457            method: method.to_string(),
458            params: serde_json::to_value(params)?,
459            id: None, // No ID for notifications
460        };
461
462        self.transport.send(&request).await?;
463        debug!("Sent notification: method={}", method);
464        Ok(())
465    }
466}