mcp_core/
client.rs

1use std::{collections::HashMap, env, sync::Arc};
2
3use crate::{
4    protocol::{Protocol, ProtocolBuilder, RequestOptions},
5    transport::Transport,
6    types::{
7        CallToolRequest, CallToolResponse, ClientCapabilities, Implementation, InitializeRequest,
8        InitializeResponse, ListRequest, ReadResourceRequest, Resource, ResourcesListResponse,
9        ToolsListResponse, LATEST_PROTOCOL_VERSION,
10    },
11};
12
13use anyhow::Result;
14use serde_json::Value;
15use tokio::sync::RwLock;
16use tracing::debug;
17
18#[derive(Clone)]
19pub struct Client<T: Transport> {
20    protocol: Protocol<T>,
21    strict: bool,
22    initialize_res: Arc<RwLock<Option<InitializeResponse>>>,
23    env: Option<HashMap<String, SecureValue>>,
24}
25
26impl<T: Transport> Client<T> {
27    pub fn builder(transport: T) -> ClientBuilder<T> {
28        ClientBuilder::new(transport)
29    }
30
31    pub async fn initialize(&self, client_info: Implementation) -> Result<InitializeResponse> {
32        let request = InitializeRequest {
33            protocol_version: LATEST_PROTOCOL_VERSION.to_string(),
34            capabilities: ClientCapabilities::default(),
35            client_info,
36        };
37        let response = self
38            .request(
39                "initialize",
40                Some(serde_json::to_value(request)?),
41                RequestOptions::default(),
42            )
43            .await?;
44        let response: InitializeResponse = serde_json::from_value(response)
45            .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?;
46
47        if response.protocol_version != LATEST_PROTOCOL_VERSION {
48            return Err(anyhow::anyhow!(
49                "Unsupported protocol version: {}",
50                response.protocol_version
51            ));
52        }
53
54        // Save the response for later use
55        let mut writer = self.initialize_res.write().await;
56        *writer = Some(response.clone());
57
58        debug!(
59            "Initialized with protocol version: {}",
60            response.protocol_version
61        );
62        self.protocol
63            .notify("notifications/initialized", None)
64            .await?;
65
66        Ok(response)
67    }
68
69    pub async fn assert_initialized(&self) -> Result<(), anyhow::Error> {
70        let reader = self.initialize_res.read().await;
71        match &*reader {
72            Some(_) => Ok(()),
73            None => Err(anyhow::anyhow!("Not initialized")),
74        }
75    }
76
77    pub async fn request(
78        &self,
79        method: &str,
80        params: Option<serde_json::Value>,
81        options: RequestOptions,
82    ) -> Result<serde_json::Value> {
83        let response = self.protocol.request(method, params, options).await?;
84        response
85            .result
86            .ok_or_else(|| anyhow::anyhow!("Request failed: {:?}", response.error))
87    }
88
89    pub async fn list_tools(
90        &self,
91        cursor: Option<String>,
92        request_options: Option<RequestOptions>,
93    ) -> Result<ToolsListResponse> {
94        if self.strict {
95            self.assert_initialized().await?;
96        }
97
98        let list_request = ListRequest { cursor, meta: None };
99
100        let response = self
101            .request(
102                "tools/list",
103                Some(serde_json::to_value(list_request)?),
104                request_options.unwrap_or_else(RequestOptions::default),
105            )
106            .await?;
107
108        Ok(serde_json::from_value(response)
109            .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?)
110    }
111
112    pub async fn call_tool(
113        &self,
114        name: &str,
115        arguements: Option<serde_json::Value>,
116    ) -> Result<CallToolResponse> {
117        if self.strict {
118            self.assert_initialized().await?;
119        }
120
121        let arguments = if let Some(env) = &self.env {
122            arguements
123                .as_ref()
124                .map(|args| apply_secure_replacements(args, env))
125        } else {
126            arguements
127        };
128
129        let arguments =
130            arguments.map(|value| serde_json::from_value(value).unwrap_or_else(|_| HashMap::new()));
131
132        let request = CallToolRequest {
133            name: name.to_string(),
134            arguments,
135            meta: None,
136        };
137
138        let response = self
139            .request(
140                "tools/call",
141                Some(serde_json::to_value(request)?),
142                RequestOptions::default(),
143            )
144            .await?;
145
146        Ok(serde_json::from_value(response)
147            .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?)
148    }
149
150    pub async fn list_resources(
151        &self,
152        cursor: Option<String>,
153        request_options: Option<RequestOptions>,
154    ) -> Result<ResourcesListResponse> {
155        if self.strict {
156            self.assert_initialized().await?;
157        }
158
159        let list_request = ListRequest { cursor, meta: None };
160
161        let response = self
162            .request(
163                "resources/list",
164                Some(serde_json::to_value(list_request)?),
165                request_options.unwrap_or_else(RequestOptions::default),
166            )
167            .await?;
168
169        Ok(serde_json::from_value(response)
170            .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?)
171    }
172
173    pub async fn read_resource(&self, uri: url::Url) -> Result<Resource> {
174        if self.strict {
175            self.assert_initialized().await?;
176        }
177
178        let read_request = ReadResourceRequest { uri };
179
180        let response = self
181            .request(
182                "resources/read",
183                Some(serde_json::to_value(read_request)?),
184                RequestOptions::default(),
185            )
186            .await?;
187
188        Ok(serde_json::from_value(response)
189            .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?)
190    }
191
192    pub async fn subscribe_to_resource(&self, uri: url::Url) -> Result<()> {
193        if self.strict {
194            self.assert_initialized().await?;
195        }
196
197        let subscribe_request = ReadResourceRequest { uri };
198
199        self.request(
200            "resources/subscribe",
201            Some(serde_json::to_value(subscribe_request)?),
202            RequestOptions::default(),
203        )
204        .await?;
205
206        Ok(())
207    }
208
209    pub async fn unsubscribe_to_resource(&self, uri: url::Url) -> Result<()> {
210        if self.strict {
211            self.assert_initialized().await?;
212        }
213
214        let unsubscribe_request = ReadResourceRequest { uri };
215
216        self.request(
217            "resources/unsubscribe",
218            Some(serde_json::to_value(unsubscribe_request)?),
219            RequestOptions::default(),
220        )
221        .await?;
222
223        Ok(())
224    }
225
226    pub async fn start(&self) -> Result<()> {
227        self.protocol.listen().await
228    }
229}
230
/// Source of a value that `apply_secure_replacements` substitutes into
/// tool-call arguments.
#[derive(Clone)]
pub enum SecureValue {
    /// Use the contained string as the replacement.
    Static(String),
    /// Read the replacement from the environment variable with this name;
    /// if it cannot be read, the original argument value is kept.
    Env(String),
}
236
237pub struct ClientBuilder<T: Transport> {
238    protocol: ProtocolBuilder<T>,
239    strict: bool,
240    env: Option<HashMap<String, SecureValue>>,
241}
242
243impl<T: Transport> ClientBuilder<T> {
244    pub fn new(transport: T) -> Self {
245        Self {
246            protocol: ProtocolBuilder::new(transport),
247            strict: false,
248            env: None,
249        }
250    }
251
252    pub fn with_secure_value(mut self, key: impl Into<String>, value: SecureValue) -> Self {
253        match &mut self.env {
254            Some(env) => {
255                env.insert(key.into(), value);
256            }
257            None => {
258                let mut new_env = HashMap::new();
259                new_env.insert(key.into(), value);
260                self.env = Some(new_env);
261            }
262        }
263        self
264    }
265
266    pub fn use_strict(mut self) -> Self {
267        self.strict = true;
268        self
269    }
270
271    pub fn with_strict(mut self, strict: bool) -> Self {
272        self.strict = strict;
273        self
274    }
275
276    pub fn build(self) -> Client<T> {
277        Client {
278            protocol: self.protocol.build(),
279            strict: self.strict,
280            env: self.env,
281            initialize_res: Arc::new(RwLock::new(None)),
282        }
283    }
284}
285
286/// Recursively walk through the JSON value. If a JSON string exactly matches
287/// one of the keys in the secure values map, replace it with the corresponding secure value.
288pub fn apply_secure_replacements(
289    value: &Value,
290    secure_values: &HashMap<String, SecureValue>,
291) -> Value {
292    match value {
293        Value::Object(map) => {
294            let mut new_map = serde_json::Map::new();
295            for (k, v) in map.iter() {
296                let new_value = if let Value::String(_) = v {
297                    if let Some(secure_val) = secure_values.get(k) {
298                        let replacement = match secure_val {
299                            SecureValue::Static(val) => val.clone(),
300                            SecureValue::Env(env_key) => env::var(env_key)
301                                .unwrap_or_else(|_| v.as_str().unwrap().to_string()),
302                        };
303                        Value::String(replacement)
304                    } else {
305                        apply_secure_replacements(v, secure_values)
306                    }
307                } else {
308                    apply_secure_replacements(v, secure_values)
309                };
310                new_map.insert(k.clone(), new_value);
311            }
312            Value::Object(new_map)
313        }
314        Value::Array(arr) => {
315            let new_arr: Vec<Value> = arr
316                .iter()
317                .map(|v| apply_secure_replacements(v, secure_values))
318                .collect();
319            Value::Array(new_arr)
320        }
321        _ => value.clone(),
322    }
323}