mcp_core/
client.rs

//! # MCP Client
//!
//! This module provides the client-side implementation of the Model Context Protocol (MCP).
//! The client can connect to MCP servers, initialize the connection, and invoke tools
//! provided by the server.
//!
//! The core functionality includes:
//! - Establishing connections to MCP servers
//! - Managing the protocol handshake
//! - Discovering available tools
//! - Invoking tools with parameters
//! - Handling server resources
13
14use std::{collections::HashMap, env, sync::Arc};
15
16use crate::{
17    protocol::RequestOptions,
18    transport::Transport,
19    types::{
20        CallToolRequest, CallToolResponse, ClientCapabilities, Implementation, InitializeRequest,
21        InitializeResponse, ListRequest, ProtocolVersion, ReadResourceRequest, Resource,
22        ResourcesListResponse, ToolsListResponse, LATEST_PROTOCOL_VERSION,
23    },
24};
25
26use anyhow::Result;
27use serde_json::Value;
28use tokio::sync::RwLock;
29use tracing::debug;
30
31/// An MCP client for connecting to MCP servers and invoking their tools.
32///
33/// The `Client` provides a high-level API for interacting with MCP servers,
34/// including initialization, tool discovery, and tool invocation.
35#[derive(Clone)]
36pub struct Client<T: Transport> {
37    transport: T,
38    strict: bool,
39    protocol_version: ProtocolVersion,
40    initialize_res: Arc<RwLock<Option<InitializeResponse>>>,
41    env: Option<HashMap<String, SecureValue>>,
42    client_info: Implementation,
43    capabilities: ClientCapabilities,
44}
45
46impl<T: Transport> Client<T> {
47    /// Creates a new client builder.
48    ///
49    /// # Arguments
50    ///
51    /// * `transport` - The transport to use for communication with the server
52    ///
53    /// # Returns
54    ///
55    /// A `ClientBuilder` for configuring and building the client
56    pub fn builder(transport: T) -> ClientBuilder<T> {
57        ClientBuilder::new(transport)
58    }
59
60    /// Sets the protocol version for the client.
61    ///
62    /// # Arguments
63    ///
64    /// * `protocol_version` - The protocol version to use
65    ///
66    /// # Returns
67    ///
68    /// The modified client instance
69    pub fn set_protocol_version(mut self, protocol_version: ProtocolVersion) -> Self {
70        self.protocol_version = protocol_version;
71        self
72    }
73
74    /// Opens the transport connection.
75    ///
76    /// # Returns
77    ///
78    /// A `Result` indicating success or failure
79    pub async fn open(&self) -> Result<()> {
80        self.transport.open().await
81    }
82
83    /// Initializes the connection with the MCP server.
84    ///
85    /// This sends the initialize request to the server, negotiates protocol
86    /// version and capabilities, and establishes the session.
87    ///
88    /// # Returns
89    ///
90    /// A `Result` containing the server's initialization response if successful
91    pub async fn initialize(&self) -> Result<InitializeResponse> {
92        let request = InitializeRequest {
93            protocol_version: self.protocol_version.as_str().to_string(),
94            capabilities: self.capabilities.clone(),
95            client_info: self.client_info.clone(),
96        };
97        let response = self
98            .request(
99                "initialize",
100                Some(serde_json::to_value(request)?),
101                RequestOptions::default(),
102            )
103            .await?;
104        let response: InitializeResponse = serde_json::from_value(response)
105            .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?;
106
107        if response.protocol_version != self.protocol_version.as_str() {
108            return Err(anyhow::anyhow!(
109                "Unsupported protocol version: {}",
110                response.protocol_version
111            ));
112        }
113
114        // Save the response for later use
115        let mut writer = self.initialize_res.write().await;
116        *writer = Some(response.clone());
117
118        debug!(
119            "Initialized with protocol version: {}",
120            response.protocol_version
121        );
122        self.transport
123            .send_notification("notifications/initialized", None)
124            .await?;
125
126        Ok(response)
127    }
128
129    /// Checks if the client has been initialized.
130    ///
131    /// # Returns
132    ///
133    /// A `Result` indicating if the client is initialized
134    pub async fn assert_initialized(&self) -> Result<(), anyhow::Error> {
135        let reader = self.initialize_res.read().await;
136        match &*reader {
137            Some(_) => Ok(()),
138            None => Err(anyhow::anyhow!("Not initialized")),
139        }
140    }
141
142    /// Sends a request to the server.
143    ///
144    /// # Arguments
145    ///
146    /// * `method` - The method name
147    /// * `params` - Optional parameters for the request
148    /// * `options` - Request options (like timeout)
149    ///
150    /// # Returns
151    ///
152    /// A `Result` containing the server's response if successful
153    pub async fn request(
154        &self,
155        method: &str,
156        params: Option<serde_json::Value>,
157        options: RequestOptions,
158    ) -> Result<serde_json::Value> {
159        let response = self.transport.request(method, params, options).await?;
160        response
161            .result
162            .ok_or_else(|| anyhow::anyhow!("Request failed: {:?}", response.error))
163    }
164
165    /// Lists tools available on the server.
166    ///
167    /// # Arguments
168    ///
169    /// * `cursor` - Optional pagination cursor
170    /// * `request_options` - Optional request options
171    ///
172    /// # Returns
173    ///
174    /// A `Result` containing the list of tools if successful
175    pub async fn list_tools(
176        &self,
177        cursor: Option<String>,
178        request_options: Option<RequestOptions>,
179    ) -> Result<ToolsListResponse> {
180        if self.strict {
181            self.assert_initialized().await?;
182        }
183
184        let list_request = ListRequest { cursor, meta: None };
185
186        let response = self
187            .request(
188                "tools/list",
189                Some(serde_json::to_value(list_request)?),
190                request_options.unwrap_or_else(RequestOptions::default),
191            )
192            .await?;
193
194        Ok(serde_json::from_value(response)
195            .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?)
196    }
197
198    /// Calls a tool on the server.
199    ///
200    /// # Arguments
201    ///
202    /// * `name` - The name of the tool to call
203    /// * `arguments` - Optional arguments for the tool
204    ///
205    /// # Returns
206    ///
207    /// A `Result` containing the tool's response if successful
208    pub async fn call_tool(
209        &self,
210        name: &str,
211        arguements: Option<serde_json::Value>,
212    ) -> Result<CallToolResponse> {
213        if self.strict {
214            self.assert_initialized().await?;
215        }
216
217        let arguments = if let Some(env) = &self.env {
218            arguements
219                .as_ref()
220                .map(|args| apply_secure_replacements(args, env))
221        } else {
222            arguements
223        };
224
225        let arguments =
226            arguments.map(|value| serde_json::from_value(value).unwrap_or_else(|_| HashMap::new()));
227
228        let request = CallToolRequest {
229            name: name.to_string(),
230            arguments,
231            meta: None,
232        };
233
234        let response = self
235            .request(
236                "tools/call",
237                Some(serde_json::to_value(request)?),
238                RequestOptions::default(),
239            )
240            .await?;
241
242        Ok(serde_json::from_value(response)
243            .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?)
244    }
245
246    /// Lists resources available on the server.
247    ///
248    /// # Arguments
249    ///
250    /// * `cursor` - Optional pagination cursor
251    /// * `request_options` - Optional request options
252    ///
253    /// # Returns
254    ///
255    /// A `Result` containing the list of resources if successful
256    pub async fn list_resources(
257        &self,
258        cursor: Option<String>,
259        request_options: Option<RequestOptions>,
260    ) -> Result<ResourcesListResponse> {
261        if self.strict {
262            self.assert_initialized().await?;
263        }
264
265        let list_request = ListRequest { cursor, meta: None };
266
267        let response = self
268            .request(
269                "resources/list",
270                Some(serde_json::to_value(list_request)?),
271                request_options.unwrap_or_else(RequestOptions::default),
272            )
273            .await?;
274
275        Ok(serde_json::from_value(response)
276            .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?)
277    }
278
279    /// Reads a resource from the server.
280    ///
281    /// # Arguments
282    ///
283    /// * `uri` - The URI of the resource to read
284    ///
285    /// # Returns
286    ///
287    /// A `Result` containing the resource if successful
288    pub async fn read_resource(&self, uri: url::Url) -> Result<Resource> {
289        if self.strict {
290            self.assert_initialized().await?;
291        }
292
293        let read_request = ReadResourceRequest { uri };
294
295        let response = self
296            .request(
297                "resources/read",
298                Some(serde_json::to_value(read_request)?),
299                RequestOptions::default(),
300            )
301            .await?;
302
303        Ok(serde_json::from_value(response)
304            .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?)
305    }
306
307    pub async fn subscribe_to_resource(&self, uri: url::Url) -> Result<()> {
308        if self.strict {
309            self.assert_initialized().await?;
310        }
311
312        let subscribe_request = ReadResourceRequest { uri };
313
314        self.request(
315            "resources/subscribe",
316            Some(serde_json::to_value(subscribe_request)?),
317            RequestOptions::default(),
318        )
319        .await?;
320
321        Ok(())
322    }
323
324    pub async fn unsubscribe_to_resource(&self, uri: url::Url) -> Result<()> {
325        if self.strict {
326            self.assert_initialized().await?;
327        }
328
329        let unsubscribe_request = ReadResourceRequest { uri };
330
331        self.request(
332            "resources/unsubscribe",
333            Some(serde_json::to_value(unsubscribe_request)?),
334            RequestOptions::default(),
335        )
336        .await?;
337
338        Ok(())
339    }
340}
341
/// Represents a value that may contain sensitive information.
///
/// Secure values are substituted into tool arguments by
/// [`apply_secure_replacements`]. They are either a literal secret or the name
/// of an environment variable to read at substitution time.
///
/// The [`Debug`] implementation never prints the contents of
/// [`SecureValue::Static`], so secrets do not leak through `{:?}` logging.
/// Environment-variable *names* are still shown, since the name is not the secret.
#[derive(Clone)]
pub enum SecureValue {
    /// A static string value, used verbatim.
    Static(String),
    /// An environment variable reference (the variable's name).
    Env(String),
}

impl std::fmt::Debug for SecureValue {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // The literal *is* the secret, so redact it.
            SecureValue::Static(_) => f.debug_tuple("Static").field(&"<redacted>").finish(),
            SecureValue::Env(name) => f.debug_tuple("Env").field(name).finish(),
        }
    }
}
352
353/// Builder for creating configured `Client` instances.
354///
355/// The `ClientBuilder` provides a fluent API for configuring and creating
356/// MCP clients with specific settings.
357pub struct ClientBuilder<T: Transport> {
358    transport: T,
359    strict: bool,
360    env: Option<HashMap<String, SecureValue>>,
361    protocol_version: ProtocolVersion,
362    client_info: Implementation,
363    capabilities: ClientCapabilities,
364}
365
366impl<T: Transport> ClientBuilder<T> {
367    /// Creates a new client builder.
368    ///
369    /// # Arguments
370    ///
371    /// * `transport` - The transport to use for communication with the server
372    ///
373    /// # Returns
374    ///
375    /// A new `ClientBuilder` instance
376    pub fn new(transport: T) -> Self {
377        Self {
378            transport,
379            strict: false,
380            env: None,
381            protocol_version: LATEST_PROTOCOL_VERSION,
382            client_info: Implementation {
383                name: env::var("CARGO_PKG_NAME").unwrap_or_else(|_| "mcp-client".to_string()),
384                version: env::var("CARGO_PKG_VERSION").unwrap_or_else(|_| "0.1.0".to_string()),
385            },
386            capabilities: ClientCapabilities::default(),
387        }
388    }
389
390    /// Sets the protocol version for the client.
391    ///
392    /// # Arguments
393    ///
394    /// * `protocol_version` - The protocol version to use
395    ///
396    /// # Returns
397    ///
398    /// The modified builder instance
399    pub fn set_protocol_version(mut self, protocol_version: ProtocolVersion) -> Self {
400        self.protocol_version = protocol_version;
401        self
402    }
403
404    /// Sets the client information.
405    ///
406    /// # Arguments
407    ///
408    /// * `name` - The client name
409    /// * `version` - The client version
410    ///
411    /// # Returns
412    ///
413    /// The modified builder instance
414    pub fn set_client_info(mut self, name: impl Into<String>, version: impl Into<String>) -> Self {
415        self.client_info = Implementation {
416            name: name.into(),
417            version: version.into(),
418        };
419        self
420    }
421
422    /// Sets the client capabilities.
423    ///
424    /// # Arguments
425    ///
426    /// * `capabilities` - The client capabilities
427    ///
428    /// # Returns
429    ///
430    /// The modified builder instance
431    pub fn set_capabilities(mut self, capabilities: ClientCapabilities) -> Self {
432        self.capabilities = capabilities;
433        self
434    }
435
436    /// Adds a secure value for substitution in tool arguments.
437    ///
438    /// # Arguments
439    ///
440    /// * `key` - The key for the secure value
441    /// * `value` - The secure value
442    ///
443    /// # Returns
444    ///
445    /// The modified builder instance
446    pub fn with_secure_value(mut self, key: impl Into<String>, value: SecureValue) -> Self {
447        if self.env.is_none() {
448            self.env = Some(HashMap::new());
449        }
450
451        if let Some(env) = &mut self.env {
452            env.insert(key.into(), value);
453        }
454
455        self
456    }
457
458    /// Enables strict mode, which requires initialization before operations.
459    ///
460    /// # Returns
461    ///
462    /// The modified builder instance
463    pub fn use_strict(mut self) -> Self {
464        self.strict = true;
465        self
466    }
467
468    /// Sets the strict mode flag.
469    ///
470    /// # Arguments
471    ///
472    /// * `strict` - Whether to enable strict mode
473    ///
474    /// # Returns
475    ///
476    /// The modified builder instance
477    pub fn with_strict(mut self, strict: bool) -> Self {
478        self.strict = strict;
479        self
480    }
481
482    /// Builds the client with the configured settings.
483    ///
484    /// # Returns
485    ///
486    /// A new `Client` instance
487    pub fn build(self) -> Client<T> {
488        Client {
489            transport: self.transport,
490            strict: self.strict,
491            env: self.env,
492            protocol_version: self.protocol_version,
493            initialize_res: Arc::new(RwLock::new(None)),
494            client_info: self.client_info,
495            capabilities: self.capabilities,
496        }
497    }
498}
499
500/// Recursively walk through the JSON value. If a JSON string exactly matches
501/// one of the keys in the secure values map, replace it with the corresponding secure value.
502pub fn apply_secure_replacements(
503    value: &Value,
504    secure_values: &HashMap<String, SecureValue>,
505) -> Value {
506    match value {
507        Value::Object(map) => {
508            let mut new_map = serde_json::Map::new();
509            for (k, v) in map.iter() {
510                let new_value = if let Value::String(_) = v {
511                    if let Some(secure_val) = secure_values.get(k) {
512                        let replacement = match secure_val {
513                            SecureValue::Static(val) => val.clone(),
514                            SecureValue::Env(env_key) => env::var(env_key)
515                                .unwrap_or_else(|_| v.as_str().unwrap().to_string()),
516                        };
517                        Value::String(replacement)
518                    } else {
519                        apply_secure_replacements(v, secure_values)
520                    }
521                } else {
522                    apply_secure_replacements(v, secure_values)
523                };
524                new_map.insert(k.clone(), new_value);
525            }
526            Value::Object(new_map)
527        }
528        Value::Array(arr) => {
529            let new_arr: Vec<Value> = arr
530                .iter()
531                .map(|v| apply_secure_replacements(v, secure_values))
532                .collect();
533            Value::Array(new_arr)
534        }
535        _ => value.clone(),
536    }
537}