modelcontextprotocol_client/
client.rs

1// mcp-client/src/client.rs
2use anyhow::{anyhow, Result};
3use serde_json::json;
4use std::collections::HashMap;
5use std::sync::Arc;
6use tokio::sync::{mpsc, Mutex, RwLock};
7
8use mcp_protocol::{
9    constants::{error_codes, methods, PROTOCOL_VERSION},
10    messages::{ClientCapabilities, InitializeParams, InitializeResult, JsonRpcMessage},
11    types::{
12        completion::{CompleteRequest, CompleteResponse},
13        sampling::{CreateMessageParams, CreateMessageResult},
14        tool::{ToolCallParams, ToolCallResult, ToolsListResult},
15        ClientInfo,
16    },
17};
18
19use crate::transport::Transport;
20
/// Lifecycle phase of an MCP [`Client`].
///
/// The normal progression is `Created` -> `Initializing` -> `Ready` -> `ShuttingDown`.
#[derive(Debug, Clone, PartialEq)]
enum ClientState {
    /// The client has been built but `initialize` has not run yet.
    Created,
    /// The initialize handshake with the server is in progress.
    Initializing,
    /// The handshake completed; regular requests may be sent.
    Ready,
    /// `shutdown` has been requested.
    ShuttingDown,
}
29
30/// Represents a pending request waiting for a response
31struct PendingRequest {
32    response_tx: mpsc::Sender<Result<JsonRpcMessage>>,
33}
34
35/// MCP client builder
36pub struct ClientBuilder {
37    name: String,
38    version: String,
39    transport: Option<Box<dyn Transport>>,
40    sampling_enabled: bool,
41}
42
43impl ClientBuilder {
44    /// Create a new client builder
45    pub fn new(name: &str, version: &str) -> Self {
46        Self {
47            name: name.to_string(),
48            version: version.to_string(),
49            transport: None,
50            sampling_enabled: false,
51        }
52    }
53
54    /// Enable sampling capability
55    pub fn with_sampling(mut self) -> Self {
56        self.sampling_enabled = true;
57        self
58    }
59
60    /// Set the transport to use
61    pub fn with_transport<T: Transport>(mut self, transport: T) -> Self {
62        self.transport = Some(Box::new(transport));
63        self
64    }
65
66    /// Build the client
67    pub fn build(self) -> Result<Client> {
68        let transport = self
69            .transport
70            .ok_or_else(|| anyhow!("Transport is required"))?;
71
72        // Create capabilities
73        let capabilities = if self.sampling_enabled {
74            let mut caps = ClientCapabilities::default();
75            caps.sampling = Some(HashMap::new());
76            caps
77        } else {
78            ClientCapabilities::default()
79        };
80
81        Ok(Client {
82            name: self.name,
83            version: self.version,
84            transport,
85            sampling_enabled: self.sampling_enabled,
86            capabilities,
87            state: Arc::new(RwLock::new(ClientState::Created)),
88            next_id: Arc::new(Mutex::new(1)),
89            pending_requests: Arc::new(RwLock::new(HashMap::new())),
90            initialized_result: Arc::new(RwLock::new(None)),
91            sampling_callback: Arc::new(RwLock::new(None)),
92        })
93    }
94}
95
96/// Type for sampling callback function
97pub type SamplingCallback =
98    Box<dyn Fn(CreateMessageParams) -> Result<CreateMessageResult> + Send + Sync>;
99
100/// MCP client
101pub struct Client {
102    name: String,
103    version: String,
104    transport: Box<dyn Transport>,
105    sampling_enabled: bool,
106    capabilities: ClientCapabilities,
107    state: Arc<RwLock<ClientState>>,
108    next_id: Arc<Mutex<i64>>,
109    pending_requests: Arc<RwLock<HashMap<String, PendingRequest>>>,
110    initialized_result: Arc<RwLock<Option<InitializeResult>>>,
111    sampling_callback: Arc<RwLock<Option<SamplingCallback>>>,
112}
113
114impl Client {
115    /// Initialize the client
116    pub async fn initialize(&self) -> Result<InitializeResult> {
117        // Check if we're already initialized
118        {
119            let state = self.state.read().await;
120            if *state != ClientState::Created {
121                return Err(anyhow!("Client already initialized"));
122            }
123        }
124
125        // Update state to initializing
126        {
127            let mut state = self.state.write().await;
128            *state = ClientState::Initializing;
129        }
130
131        // Start the transport
132        self.transport.start().await?;
133
134        // Create initialize parameters
135        let params = InitializeParams {
136            protocol_version: PROTOCOL_VERSION.to_string(),
137            capabilities: self.capabilities.clone(),
138            client_info: ClientInfo {
139                name: self.name.clone(),
140                version: self.version.clone(),
141            },
142        };
143
144        // Send initialize request
145        let id = self.next_request_id().await?;
146        let response = self
147            .send_request(methods::INITIALIZE, Some(json!(params)), id.to_string())
148            .await?;
149
150        match response {
151            JsonRpcMessage::Response { result, error, .. } => {
152                if let Some(error) = error {
153                    return Err(anyhow!(
154                        "Initialize error: {} (code: {})",
155                        error.message,
156                        error.code
157                    ));
158                }
159
160                if let Some(result) = result {
161                    let result: InitializeResult = serde_json::from_value(result)?;
162
163                    // Store the result
164                    {
165                        let mut initialized = self.initialized_result.write().await;
166                        *initialized = Some(result.clone());
167                    }
168
169                    // Send initialized notification
170                    self.transport
171                        .send(JsonRpcMessage::notification(methods::INITIALIZED, None))
172                        .await?;
173
174                    // Update state to ready
175                    {
176                        let mut state = self.state.write().await;
177                        *state = ClientState::Ready;
178                    }
179
180                    return Ok(result);
181                }
182
183                Err(anyhow!("Invalid initialize response"))
184            }
185            _ => Err(anyhow!("Invalid response type")),
186        }
187    }
188
189    /// List available tools
190    pub async fn list_tools(&self) -> Result<ToolsListResult> {
191        // Check if we're initialized
192        {
193            let state = self.state.read().await;
194            if *state != ClientState::Ready {
195                return Err(anyhow!("Client not initialized"));
196            }
197        }
198
199        // Send tools/list request
200        let id = self.next_request_id().await?;
201        let response = self
202            .send_request(methods::TOOLS_LIST, None, id.to_string())
203            .await?;
204
205        match response {
206            JsonRpcMessage::Response { result, error, .. } => {
207                if let Some(error) = error {
208                    return Err(anyhow!(
209                        "List tools error: {} (code: {})",
210                        error.message,
211                        error.code
212                    ));
213                }
214
215                if let Some(result) = result {
216                    let result: ToolsListResult = serde_json::from_value(result)?;
217                    return Ok(result);
218                }
219
220                Err(anyhow!("Invalid list tools response"))
221            }
222            _ => Err(anyhow!("Invalid response type")),
223        }
224    }
225
226    /// List available resource templates
227    pub async fn list_resource_templates(
228        &self,
229    ) -> Result<mcp_protocol::types::resource::ResourceTemplatesListResult> {
230        // Check if we're initialized
231        {
232            let state = self.state.read().await;
233            if *state != ClientState::Ready {
234                return Err(anyhow!("Client not initialized"));
235            }
236        }
237
238        // Send resources/templates/list request
239        let id = self.next_request_id().await?;
240        let response = self
241            .send_request(methods::RESOURCES_TEMPLATES_LIST, None, id.to_string())
242            .await?;
243
244        match response {
245            JsonRpcMessage::Response { result, error, .. } => {
246                if let Some(error) = error {
247                    return Err(anyhow!(
248                        "List resource templates error: {} (code: {})",
249                        error.message,
250                        error.code
251                    ));
252                }
253
254                if let Some(result) = result {
255                    let result: mcp_protocol::types::resource::ResourceTemplatesListResult =
256                        serde_json::from_value(result)?;
257                    return Ok(result);
258                }
259
260                Err(anyhow!("Invalid resource templates list response"))
261            }
262            _ => Err(anyhow!("Invalid response type")),
263        }
264    }
265
266    /// Get completion suggestions for a resource or prompt parameter
267    pub async fn complete(&self, request: CompleteRequest) -> Result<CompleteResponse> {
268        // Check if we're initialized
269        {
270            let state = self.state.read().await;
271            if *state != ClientState::Ready {
272                return Err(anyhow!("Client not initialized"));
273            }
274        }
275
276        // Send completion/complete request
277        let id = self.next_request_id().await?;
278        let response = self
279            .send_request("completion/complete", Some(json!(request)), id.to_string())
280            .await?;
281
282        match response {
283            JsonRpcMessage::Response { result, error, .. } => {
284                if let Some(error) = error {
285                    return Err(anyhow!(
286                        "Completion error: {} (code: {})",
287                        error.message,
288                        error.code
289                    ));
290                }
291
292                if let Some(result) = result {
293                    let result: CompleteResponse = serde_json::from_value(result)?;
294                    return Ok(result);
295                }
296
297                Err(anyhow!("Invalid completion response"))
298            }
299            _ => Err(anyhow!("Invalid response type")),
300        }
301    }
302
303    /// Call a tool on the server
304    pub async fn call_tool(
305        &self,
306        name: &str,
307        arguments: &serde_json::Value,
308    ) -> Result<ToolCallResult> {
309        // Check if we're initialized
310        {
311            let state = self.state.read().await;
312            if *state != ClientState::Ready {
313                return Err(anyhow!("Client not initialized"));
314            }
315        }
316
317        // Create tool call parameters
318        let params = ToolCallParams {
319            name: name.to_string(),
320            arguments: arguments.clone(),
321        };
322
323        // Send tools/call request
324        let id = self.next_request_id().await?;
325        let response = self
326            .send_request(methods::TOOLS_CALL, Some(json!(params)), id.to_string())
327            .await?;
328
329        match response {
330            JsonRpcMessage::Response { result, error, .. } => {
331                if let Some(error) = error {
332                    return Err(anyhow!(
333                        "Tool call error: {} (code: {})",
334                        error.message,
335                        error.code
336                    ));
337                }
338
339                if let Some(result) = result {
340                    let result: ToolCallResult = serde_json::from_value(result)?;
341                    return Ok(result);
342                }
343
344                Err(anyhow!("Invalid tool call response"))
345            }
346            _ => Err(anyhow!("Invalid response type")),
347        }
348    }
349
350    /// Shutdown the client
351    pub async fn shutdown(&self) -> Result<()> {
352        // Check if we're initialized
353        {
354            let state = self.state.read().await;
355            if *state != ClientState::Ready {
356                return Err(anyhow!("Client not initialized"));
357            }
358        }
359
360        // Update state to shutting down
361        {
362            let mut state = self.state.write().await;
363            *state = ClientState::ShuttingDown;
364        }
365
366        // Close the transport
367        self.transport.close().await?;
368
369        Ok(())
370    }
371
372    /// Refresh the list of available prompts
373    pub async fn refresh_prompts(&self) -> Result<serde_json::Value> {
374        // Check if we're initialized
375        {
376            let state = self.state.read().await;
377            if *state != ClientState::Ready {
378                return Err(anyhow!("Client not initialized"));
379            }
380        }
381
382        // Send prompts/list request
383        let id = self.next_request_id().await?;
384        let response = self
385            .send_request(methods::PROMPTS_LIST, None, id.to_string())
386            .await?;
387
388        match response {
389            JsonRpcMessage::Response { result, error, .. } => {
390                if let Some(error) = error {
391                    return Err(anyhow!(
392                        "List prompts error: {} (code: {})",
393                        error.message,
394                        error.code
395                    ));
396                }
397
398                if let Some(result) = result {
399                    return Ok(result);
400                }
401
402                Err(anyhow!("Invalid list prompts response"))
403            }
404            _ => Err(anyhow!("Invalid response type")),
405        }
406    }
407
408    /// Get the next request ID
409    pub async fn next_request_id(&self) -> Result<i64> {
410        let mut id = self.next_id.lock().await;
411        let current = *id;
412        *id += 1;
413        Ok(current)
414    }
415
416    /// Send a request and wait for a response
417    pub async fn send_request(
418        &self,
419        method: &str,
420        params: Option<serde_json::Value>,
421        id: String,
422    ) -> Result<JsonRpcMessage> {
423        // Create request
424        let request = JsonRpcMessage::request(id.clone().into(), method, params);
425
426        // Create response channel
427        let (tx, mut rx) = mpsc::channel(1);
428
429        // Register pending request
430        {
431            let mut pending = self.pending_requests.write().await;
432            pending.insert(id.clone(), PendingRequest { response_tx: tx });
433        }
434
435        // Send request
436        self.transport.send(request).await?;
437
438        // Wait for response
439        match rx.recv().await {
440            Some(result) => {
441                // Remove pending request
442                let mut pending = self.pending_requests.write().await;
443                pending.remove(&id);
444
445                result
446            }
447            None => Err(anyhow!("Failed to receive response")),
448        }
449    }
450
451    /// Register a sampling callback
452    pub async fn register_sampling_callback(&self, callback: SamplingCallback) -> Result<()> {
453        if !self.sampling_enabled {
454            return Err(anyhow!("Sampling is not enabled"));
455        }
456
457        let mut sampling_callback = self.sampling_callback.write().await;
458        *sampling_callback = Some(callback);
459
460        Ok(())
461    }
462
463    /// Handle sampling createMessage request
464    async fn handle_sampling_create_message(&self, message: JsonRpcMessage) -> Result<()> {
465        match message {
466            JsonRpcMessage::Request { id, params, .. } => {
467                // Check if sampling is enabled
468                if !self.sampling_enabled {
469                    // Send error response
470                    self.transport
471                        .send(JsonRpcMessage::error(
472                            id,
473                            error_codes::SAMPLING_NOT_ENABLED,
474                            "Sampling is not enabled",
475                            None,
476                        ))
477                        .await?;
478                    return Ok(());
479                }
480
481                // Parse parameters
482                let params: CreateMessageParams = match params {
483                    Some(params) => match serde_json::from_value(params) {
484                        Ok(params) => params,
485                        Err(err) => {
486                            // Send error response
487                            self.transport
488                                .send(JsonRpcMessage::error(
489                                    id,
490                                    error_codes::INVALID_PARAMS,
491                                    &format!("Invalid sampling parameters: {}", err),
492                                    None,
493                                ))
494                                .await?;
495                            return Ok(());
496                        }
497                    },
498                    None => {
499                        // Send error response
500                        self.transport
501                            .send(JsonRpcMessage::error(
502                                id,
503                                error_codes::INVALID_PARAMS,
504                                "Missing sampling parameters",
505                                None,
506                            ))
507                            .await?;
508                        return Ok(());
509                    }
510                };
511
512                // Get the callback
513                let callback_result = {
514                    let callback = self.sampling_callback.read().await;
515                    if callback.is_some() {
516                        Ok(())
517                    } else {
518                        Err(anyhow!("No sampling callback registered"))
519                    }
520                };
521
522                // Check if we have a callback
523                if let Err(_) = callback_result {
524                    // Send error response
525                    self.transport
526                        .send(JsonRpcMessage::error(
527                            id,
528                            error_codes::SAMPLING_NO_CALLBACK,
529                            "No sampling callback registered",
530                            None,
531                        ))
532                        .await?;
533                    return Ok(());
534                }
535
536                // Call the callback
537                // Get a lock on the callback to invoke it
538                let result = {
539                    let callback_guard = self.sampling_callback.read().await;
540                    // We know this is Some because we checked earlier
541                    if let Some(callback) = &*callback_guard {
542                        callback(params.clone())
543                    } else {
544                        // This shouldn't happen, but just in case
545                        Err(anyhow!("No sampling callback registered"))
546                    }
547                };
548
549                match result {
550                    Ok(result) => {
551                        // Send response
552                        self.transport
553                            .send(JsonRpcMessage::response(id, json!(result)))
554                            .await?;
555                    }
556                    Err(err) => {
557                        // Send error response
558                        self.transport
559                            .send(JsonRpcMessage::error(
560                                id,
561                                error_codes::SAMPLING_ERROR,
562                                &format!("Sampling error: {}", err),
563                                None,
564                            ))
565                            .await?;
566                    }
567                }
568
569                Ok(())
570            }
571            _ => Err(anyhow!(
572                "Expected request message for sampling/createMessage"
573            )),
574        }
575    }
576
577    /// Handle a received message
578    pub async fn handle_message(&self, message: JsonRpcMessage) -> Result<()> {
579        match message.clone() {
580            JsonRpcMessage::Response { ref id, .. } => {
581                // Get id as string
582                let id = match id {
583                    serde_json::Value::String(s) => s.clone(),
584                    serde_json::Value::Number(n) => n.to_string(),
585                    _ => return Err(anyhow!("Invalid response ID type")),
586                };
587
588                // Find pending request
589                let pending = {
590                    let pending = self.pending_requests.read().await;
591                    match pending.get(&id) {
592                        Some(req) => req.response_tx.clone(),
593                        None => return Err(anyhow!("No pending request for ID: {}", id)),
594                    }
595                };
596
597                // Send response
598                if let Err(e) = pending.send(Ok(message)).await {
599                    Err(anyhow!("Failed to send response: {}", e))
600                } else {
601                    Ok(())
602                }
603            }
604            JsonRpcMessage::Notification { method, params, .. } => {
605                // Handle notification
606                match method.as_str() {
607                    // Handle prompt list changed notification
608                    methods::PROMPTS_LIST_CHANGED => {
609                        // Emit a debug message about the change
610                        tracing::debug!("Received notification: prompts list changed");
611
612                        // We could trigger a refresh of the prompts list here
613                        // but we'll skip it for now to avoid complexity with clones
614                        Ok(())
615                    }
616                    // Handle resource updated notification
617                    methods::RESOURCES_UPDATED => {
618                        // Extract the resource URI if available
619                        if let Some(params) = params {
620                            if let Some(uri) = params.get("uri").and_then(|u| u.as_str()) {
621                                tracing::debug!(
622                                    "Received notification: resource updated - URI: {}",
623                                    uri
624                                );
625                            }
626                        }
627                        Ok(())
628                    }
629                    // Add other handlers for specific notifications here
630                    _ => {
631                        tracing::debug!("Unhandled notification: {}", method);
632                        Ok(())
633                    }
634                }
635            }
636            JsonRpcMessage::Request { method, .. } => match method.as_str() {
637                methods::SAMPLING_CREATE_MESSAGE => {
638                    self.handle_sampling_create_message(message).await
639                }
640                _ => {
641                    tracing::debug!("Unhandled server request: {}", method);
642                    Ok(())
643                }
644            },
645        }
646    }
647}