1use crate::error::{ClientError, ClientResult};
6use crate::transport::{ClientTransport, JsonRpcMessage, next_request_id};
7use pulseengine_mcp_protocol::{
8 CallToolRequestParam, CallToolResult, CompleteRequestParam, CompleteResult,
9 GetPromptRequestParam, GetPromptResult, Implementation, InitializeRequestParam,
10 InitializeResult, ListPromptsResult, ListResourceTemplatesResult, ListResourcesResult,
11 ListToolsResult, NumberOrString, PaginatedRequestParam, ReadResourceRequestParam,
12 ReadResourceResult, Request, Response,
13};
14use serde_json::json;
15use std::collections::HashMap;
16use std::sync::Arc;
17use std::time::Duration;
18use tokio::sync::{Mutex, oneshot};
19use tracing::{debug, info, warn};
20
/// How long a single request waits for its response before failing with a
/// timeout error, unless overridden with `McpClient::with_timeout`.
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(30_000);
23
24pub struct McpClient<T: ClientTransport> {
29 transport: Arc<T>,
30 pending: Arc<Mutex<HashMap<String, oneshot::Sender<Response>>>>,
32 server_info: Option<InitializeResult>,
34 timeout: Duration,
36 client_info: Implementation,
38}
39
40impl<T: ClientTransport + 'static> McpClient<T> {
41 pub fn new(transport: T) -> Self {
43 Self {
44 transport: Arc::new(transport),
45 pending: Arc::new(Mutex::new(HashMap::new())),
46 server_info: None,
47 timeout: DEFAULT_TIMEOUT,
48 client_info: Implementation::new("pulseengine-mcp-client", env!("CARGO_PKG_VERSION")),
49 }
50 }
51
52 pub fn with_timeout(mut self, timeout: Duration) -> Self {
54 self.timeout = timeout;
55 self
56 }
57
58 pub fn with_client_info(mut self, name: &str, version: &str) -> Self {
60 self.client_info = Implementation::new(name, version);
61 self
62 }
63
64 pub fn server_info(&self) -> Option<&InitializeResult> {
66 self.server_info.as_ref()
67 }
68
69 pub fn is_initialized(&self) -> bool {
71 self.server_info.is_some()
72 }
73
74 pub async fn initialize(
78 &mut self,
79 client_name: &str,
80 client_version: &str,
81 ) -> ClientResult<InitializeResult> {
82 self.client_info = Implementation::new(client_name, client_version);
83
84 let params = InitializeRequestParam {
85 protocol_version: pulseengine_mcp_protocol::MCP_VERSION.to_string(),
86 capabilities: json!({}), client_info: self.client_info.clone(),
88 };
89
90 let result: InitializeResult = self.request("initialize", params).await?;
91
92 info!(
93 "Initialized with server: {} v{}",
94 result.server_info.name, result.server_info.version
95 );
96
97 self.server_info = Some(result.clone());
98
99 self.notify("notifications/initialized", json!({})).await?;
101
102 Ok(result)
103 }
104
105 pub async fn list_tools(&self) -> ClientResult<ListToolsResult> {
111 self.ensure_initialized()?;
112 self.request("tools/list", PaginatedRequestParam { cursor: None })
113 .await
114 }
115
116 pub async fn list_all_tools(&self) -> ClientResult<Vec<pulseengine_mcp_protocol::Tool>> {
118 self.ensure_initialized()?;
119 let mut all_tools = Vec::new();
120 let mut cursor = None;
121
122 loop {
123 let result: ListToolsResult = self
124 .request("tools/list", PaginatedRequestParam { cursor })
125 .await?;
126
127 all_tools.extend(result.tools);
128
129 match result.next_cursor {
130 Some(next) => cursor = Some(next),
131 None => break,
132 }
133 }
134
135 Ok(all_tools)
136 }
137
138 pub async fn call_tool(
140 &self,
141 name: &str,
142 arguments: serde_json::Value,
143 ) -> ClientResult<CallToolResult> {
144 self.ensure_initialized()?;
145 self.request(
146 "tools/call",
147 CallToolRequestParam {
148 name: name.to_string(),
149 arguments: Some(arguments),
150 },
151 )
152 .await
153 }
154
155 pub async fn list_resources(&self) -> ClientResult<ListResourcesResult> {
161 self.ensure_initialized()?;
162 self.request("resources/list", PaginatedRequestParam { cursor: None })
163 .await
164 }
165
166 pub async fn list_all_resources(
168 &self,
169 ) -> ClientResult<Vec<pulseengine_mcp_protocol::Resource>> {
170 self.ensure_initialized()?;
171 let mut all_resources = Vec::new();
172 let mut cursor = None;
173
174 loop {
175 let result: ListResourcesResult = self
176 .request("resources/list", PaginatedRequestParam { cursor })
177 .await?;
178
179 all_resources.extend(result.resources);
180
181 match result.next_cursor {
182 Some(next) => cursor = Some(next),
183 None => break,
184 }
185 }
186
187 Ok(all_resources)
188 }
189
190 pub async fn read_resource(&self, uri: &str) -> ClientResult<ReadResourceResult> {
192 self.ensure_initialized()?;
193 self.request(
194 "resources/read",
195 ReadResourceRequestParam {
196 uri: uri.to_string(),
197 },
198 )
199 .await
200 }
201
202 pub async fn list_resource_templates(&self) -> ClientResult<ListResourceTemplatesResult> {
204 self.ensure_initialized()?;
205 self.request(
206 "resources/templates/list",
207 PaginatedRequestParam { cursor: None },
208 )
209 .await
210 }
211
212 pub async fn list_prompts(&self) -> ClientResult<ListPromptsResult> {
218 self.ensure_initialized()?;
219 self.request("prompts/list", PaginatedRequestParam { cursor: None })
220 .await
221 }
222
223 pub async fn list_all_prompts(&self) -> ClientResult<Vec<pulseengine_mcp_protocol::Prompt>> {
225 self.ensure_initialized()?;
226 let mut all_prompts = Vec::new();
227 let mut cursor = None;
228
229 loop {
230 let result: ListPromptsResult = self
231 .request("prompts/list", PaginatedRequestParam { cursor })
232 .await?;
233
234 all_prompts.extend(result.prompts);
235
236 match result.next_cursor {
237 Some(next) => cursor = Some(next),
238 None => break,
239 }
240 }
241
242 Ok(all_prompts)
243 }
244
245 pub async fn get_prompt(
247 &self,
248 name: &str,
249 arguments: Option<HashMap<String, String>>,
250 ) -> ClientResult<GetPromptResult> {
251 self.ensure_initialized()?;
252 self.request(
253 "prompts/get",
254 GetPromptRequestParam {
255 name: name.to_string(),
256 arguments,
257 },
258 )
259 .await
260 }
261
262 pub async fn complete(&self, params: CompleteRequestParam) -> ClientResult<CompleteResult> {
268 self.ensure_initialized()?;
269 self.request("completion/complete", params).await
270 }
271
272 pub async fn ping(&self) -> ClientResult<()> {
278 self.ensure_initialized()?;
279 let _: serde_json::Value = self.request("ping", json!({})).await?;
280 Ok(())
281 }
282
283 pub async fn close(&self) -> ClientResult<()> {
285 self.transport.close().await
286 }
287
288 pub async fn notify_progress(
294 &self,
295 progress_token: &str,
296 progress: f64,
297 total: Option<f64>,
298 ) -> ClientResult<()> {
299 self.notify(
300 "notifications/progress",
301 json!({
302 "progressToken": progress_token,
303 "progress": progress,
304 "total": total,
305 }),
306 )
307 .await
308 }
309
310 pub async fn notify_cancelled(
312 &self,
313 request_id: &str,
314 reason: Option<&str>,
315 ) -> ClientResult<()> {
316 self.notify(
317 "notifications/cancelled",
318 json!({
319 "requestId": request_id,
320 "reason": reason,
321 }),
322 )
323 .await
324 }
325
326 pub async fn notify_roots_list_changed(&self) -> ClientResult<()> {
328 self.notify("notifications/roots/list_changed", json!({}))
329 .await
330 }
331
332 fn ensure_initialized(&self) -> ClientResult<()> {
338 if self.server_info.is_none() {
339 return Err(ClientError::NotInitialized);
340 }
341 Ok(())
342 }
343
344 async fn request<P, R>(&self, method: &str, params: P) -> ClientResult<R>
346 where
347 P: serde::Serialize,
348 R: serde::de::DeserializeOwned,
349 {
350 let id = next_request_id();
351 let id_str = match &id {
352 NumberOrString::Number(n) => n.to_string(),
353 NumberOrString::String(s) => s.to_string(),
354 };
355
356 let request = Request {
357 jsonrpc: "2.0".to_string(),
358 method: method.to_string(),
359 params: serde_json::to_value(params)?,
360 id: Some(id),
361 };
362
363 let (tx, rx) = oneshot::channel();
365
366 {
368 let mut pending = self.pending.lock().await;
369 pending.insert(id_str.clone(), tx);
370 }
371
372 self.transport.send(&request).await?;
374
375 debug!("Sent request: method={}, id={}", method, id_str);
376
377 let response = tokio::select! {
379 result = self.wait_for_response(rx) => result?,
380 _ = tokio::time::sleep(self.timeout) => {
381 let mut pending = self.pending.lock().await;
383 pending.remove(&id_str);
384 return Err(ClientError::Timeout(self.timeout));
385 }
386 };
387
388 if let Some(error) = response.error {
390 return Err(ClientError::from_protocol_error(error));
391 }
392
393 let result = response
395 .result
396 .ok_or_else(|| ClientError::protocol("Response has no result or error"))?;
397
398 serde_json::from_value(result).map_err(ClientError::from)
399 }
400
401 async fn wait_for_response(
403 &self,
404 mut rx: oneshot::Receiver<Response>,
405 ) -> ClientResult<Response> {
406 loop {
409 tokio::select! {
410 biased;
411
412 result = &mut rx => {
414 return result.map_err(|_| ClientError::ChannelClosed("Response channel closed".into()));
415 }
416 msg = self.transport.recv() => {
418 match msg? {
419 JsonRpcMessage::Response(response) => {
420 let id_str = response.id.as_ref().map(|id| match id {
422 NumberOrString::Number(n) => n.to_string(),
423 NumberOrString::String(s) => s.to_string(),
424 });
425
426 if let Some(id) = id_str {
427 let mut pending = self.pending.lock().await;
428 if let Some(tx) = pending.remove(&id) {
429 let _ = tx.send(response);
430 } else {
431 warn!("Received response for unknown request: {}", id);
432 }
433 }
434 }
435 JsonRpcMessage::Request(request) => {
436 warn!("Received server request (not yet handled): {}", request.method);
439 }
440 JsonRpcMessage::Notification { method, params: _ } => {
441 debug!("Received notification: {}", method);
443 }
444 }
445 }
446 }
447 }
448 }
449
450 async fn notify<P>(&self, method: &str, params: P) -> ClientResult<()>
452 where
453 P: serde::Serialize,
454 {
455 let request = Request {
456 jsonrpc: "2.0".to_string(),
457 method: method.to_string(),
458 params: serde_json::to_value(params)?,
459 id: None, };
461
462 self.transport.send(&request).await?;
463 debug!("Sent notification: method={}", method);
464 Ok(())
465 }
466}