mcp_core/
client.rs

1use std::{collections::HashMap, env, sync::Arc};
2
3use crate::{
4    protocol::RequestOptions,
5    transport::Transport,
6    types::{
7        CallToolRequest, CallToolResponse, ClientCapabilities, Implementation, InitializeRequest,
8        InitializeResponse, ListRequest, ReadResourceRequest, Resource, ResourcesListResponse,
9        ToolsListResponse, LATEST_PROTOCOL_VERSION,
10    },
11};
12
13use anyhow::Result;
14use serde_json::Value;
15use tokio::sync::RwLock;
16use tracing::debug;
17
18#[derive(Clone)]
19pub struct Client<T: Transport> {
20    transport: T,
21    strict: bool,
22    initialize_res: Arc<RwLock<Option<InitializeResponse>>>,
23    env: Option<HashMap<String, SecureValue>>,
24}
25
26impl<T: Transport> Client<T> {
27    pub fn builder(transport: T) -> ClientBuilder<T> {
28        ClientBuilder::new(transport)
29    }
30
31    pub async fn open(&self) -> Result<()> {
32        self.transport.open().await
33    }
34
35    pub async fn initialize(
36        &self,
37        client_info: Implementation,
38        capabilities: ClientCapabilities,
39    ) -> Result<InitializeResponse> {
40        let request = InitializeRequest {
41            protocol_version: LATEST_PROTOCOL_VERSION.to_string(),
42            capabilities,
43            client_info,
44        };
45        let response = self
46            .request(
47                "initialize",
48                Some(serde_json::to_value(request)?),
49                RequestOptions::default(),
50            )
51            .await?;
52        let response: InitializeResponse = serde_json::from_value(response)
53            .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?;
54
55        if response.protocol_version != LATEST_PROTOCOL_VERSION {
56            return Err(anyhow::anyhow!(
57                "Unsupported protocol version: {}",
58                response.protocol_version
59            ));
60        }
61
62        // Save the response for later use
63        let mut writer = self.initialize_res.write().await;
64        *writer = Some(response.clone());
65
66        debug!(
67            "Initialized with protocol version: {}",
68            response.protocol_version
69        );
70        self.transport
71            .send_notification("notifications/initialized", None)
72            .await?;
73
74        Ok(response)
75    }
76
77    pub async fn assert_initialized(&self) -> Result<(), anyhow::Error> {
78        let reader = self.initialize_res.read().await;
79        match &*reader {
80            Some(_) => Ok(()),
81            None => Err(anyhow::anyhow!("Not initialized")),
82        }
83    }
84
85    pub async fn request(
86        &self,
87        method: &str,
88        params: Option<serde_json::Value>,
89        options: RequestOptions,
90    ) -> Result<serde_json::Value> {
91        let response = self.transport.request(method, params, options).await?;
92        response
93            .result
94            .ok_or_else(|| anyhow::anyhow!("Request failed: {:?}", response.error))
95    }
96
97    pub async fn list_tools(
98        &self,
99        cursor: Option<String>,
100        request_options: Option<RequestOptions>,
101    ) -> Result<ToolsListResponse> {
102        if self.strict {
103            self.assert_initialized().await?;
104        }
105
106        let list_request = ListRequest { cursor, meta: None };
107
108        let response = self
109            .request(
110                "tools/list",
111                Some(serde_json::to_value(list_request)?),
112                request_options.unwrap_or_else(RequestOptions::default),
113            )
114            .await?;
115
116        Ok(serde_json::from_value(response)
117            .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?)
118    }
119
120    pub async fn call_tool(
121        &self,
122        name: &str,
123        arguements: Option<serde_json::Value>,
124    ) -> Result<CallToolResponse> {
125        if self.strict {
126            self.assert_initialized().await?;
127        }
128
129        let arguments = if let Some(env) = &self.env {
130            arguements
131                .as_ref()
132                .map(|args| apply_secure_replacements(args, env))
133        } else {
134            arguements
135        };
136
137        let arguments =
138            arguments.map(|value| serde_json::from_value(value).unwrap_or_else(|_| HashMap::new()));
139
140        let request = CallToolRequest {
141            name: name.to_string(),
142            arguments,
143            meta: None,
144        };
145
146        let response = self
147            .request(
148                "tools/call",
149                Some(serde_json::to_value(request)?),
150                RequestOptions::default(),
151            )
152            .await?;
153
154        Ok(serde_json::from_value(response)
155            .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?)
156    }
157
158    pub async fn list_resources(
159        &self,
160        cursor: Option<String>,
161        request_options: Option<RequestOptions>,
162    ) -> Result<ResourcesListResponse> {
163        if self.strict {
164            self.assert_initialized().await?;
165        }
166
167        let list_request = ListRequest { cursor, meta: None };
168
169        let response = self
170            .request(
171                "resources/list",
172                Some(serde_json::to_value(list_request)?),
173                request_options.unwrap_or_else(RequestOptions::default),
174            )
175            .await?;
176
177        Ok(serde_json::from_value(response)
178            .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?)
179    }
180
181    pub async fn read_resource(&self, uri: url::Url) -> Result<Resource> {
182        if self.strict {
183            self.assert_initialized().await?;
184        }
185
186        let read_request = ReadResourceRequest { uri };
187
188        let response = self
189            .request(
190                "resources/read",
191                Some(serde_json::to_value(read_request)?),
192                RequestOptions::default(),
193            )
194            .await?;
195
196        Ok(serde_json::from_value(response)
197            .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?)
198    }
199
200    pub async fn subscribe_to_resource(&self, uri: url::Url) -> Result<()> {
201        if self.strict {
202            self.assert_initialized().await?;
203        }
204
205        let subscribe_request = ReadResourceRequest { uri };
206
207        self.request(
208            "resources/subscribe",
209            Some(serde_json::to_value(subscribe_request)?),
210            RequestOptions::default(),
211        )
212        .await?;
213
214        Ok(())
215    }
216
217    pub async fn unsubscribe_to_resource(&self, uri: url::Url) -> Result<()> {
218        if self.strict {
219            self.assert_initialized().await?;
220        }
221
222        let unsubscribe_request = ReadResourceRequest { uri };
223
224        self.request(
225            "resources/unsubscribe",
226            Some(serde_json::to_value(unsubscribe_request)?),
227            RequestOptions::default(),
228        )
229        .await?;
230
231        Ok(())
232    }
233}
234
/// A secret that [`apply_secure_replacements`] substitutes into tool-call arguments.
///
/// `Debug` is implemented by hand so that [`SecureValue::Static`] contents are redacted and
/// never end up in logs. [`SecureValue::Env`] shows only the variable *name*.
#[derive(Clone)]
pub enum SecureValue {
    /// A literal secret, used verbatim.
    Static(String),
    /// The name of an environment variable, read each time a substitution happens.
    Env(String),
}

impl std::fmt::Debug for SecureValue {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Never print the literal secret.
            SecureValue::Static(_) => f.write_str("Static(<redacted>)"),
            SecureValue::Env(name) => f.debug_tuple("Env").field(name).finish(),
        }
    }
}
240
241pub struct ClientBuilder<T: Transport> {
242    transport: T,
243    strict: bool,
244    env: Option<HashMap<String, SecureValue>>,
245}
246
247impl<T: Transport> ClientBuilder<T> {
248    pub fn new(transport: T) -> Self {
249        Self {
250            transport,
251            strict: false,
252            env: None,
253        }
254    }
255
256    pub fn with_secure_value(mut self, key: impl Into<String>, value: SecureValue) -> Self {
257        match &mut self.env {
258            Some(env) => {
259                env.insert(key.into(), value);
260            }
261            None => {
262                let mut new_env = HashMap::new();
263                new_env.insert(key.into(), value);
264                self.env = Some(new_env);
265            }
266        }
267        self
268    }
269
270    pub fn use_strict(mut self) -> Self {
271        self.strict = true;
272        self
273    }
274
275    pub fn with_strict(mut self, strict: bool) -> Self {
276        self.strict = strict;
277        self
278    }
279
280    pub fn build(self) -> Client<T> {
281        Client {
282            transport: self.transport,
283            strict: self.strict,
284            env: self.env,
285            initialize_res: Arc::new(RwLock::new(None)),
286        }
287    }
288}
289
290/// Recursively walk through the JSON value. If a JSON string exactly matches
291/// one of the keys in the secure values map, replace it with the corresponding secure value.
292pub fn apply_secure_replacements(
293    value: &Value,
294    secure_values: &HashMap<String, SecureValue>,
295) -> Value {
296    match value {
297        Value::Object(map) => {
298            let mut new_map = serde_json::Map::new();
299            for (k, v) in map.iter() {
300                let new_value = if let Value::String(_) = v {
301                    if let Some(secure_val) = secure_values.get(k) {
302                        let replacement = match secure_val {
303                            SecureValue::Static(val) => val.clone(),
304                            SecureValue::Env(env_key) => env::var(env_key)
305                                .unwrap_or_else(|_| v.as_str().unwrap().to_string()),
306                        };
307                        Value::String(replacement)
308                    } else {
309                        apply_secure_replacements(v, secure_values)
310                    }
311                } else {
312                    apply_secure_replacements(v, secure_values)
313                };
314                new_map.insert(k.clone(), new_value);
315            }
316            Value::Object(new_map)
317        }
318        Value::Array(arr) => {
319            let new_arr: Vec<Value> = arr
320                .iter()
321                .map(|v| apply_secure_replacements(v, secure_values))
322                .collect();
323            Value::Array(new_arr)
324        }
325        _ => value.clone(),
326    }
327}