mcp_core/
client.rs

1use std::{collections::HashMap, env, sync::Arc};
2
3use crate::{
4    protocol::{Protocol, ProtocolBuilder, RequestOptions},
5    transport::Transport,
6    types::{
7        CallToolRequest, CallToolResponse, ClientCapabilities, Implementation, InitializeRequest,
8        InitializeResponse, ListRequest, ToolsListResponse, LATEST_PROTOCOL_VERSION,
9    },
10};
11
12use anyhow::Result;
13use serde_json::Value;
14use tokio::sync::RwLock;
15use tracing::debug;
16
17#[derive(Clone)]
18pub struct Client<T: Transport> {
19    protocol: Protocol<T>,
20    strict: bool,
21    initialize_res: Arc<RwLock<Option<InitializeResponse>>>,
22    env: Option<HashMap<String, SecureValue>>,
23}
24
25impl<T: Transport> Client<T> {
26    pub fn builder(transport: T) -> ClientBuilder<T> {
27        ClientBuilder::new(transport)
28    }
29
30    pub async fn initialize(&self, client_info: Implementation) -> Result<InitializeResponse> {
31        let request = InitializeRequest {
32            protocol_version: LATEST_PROTOCOL_VERSION.to_string(),
33            capabilities: ClientCapabilities::default(),
34            client_info,
35        };
36        let response = self
37            .request(
38                "initialize",
39                Some(serde_json::to_value(request)?),
40                RequestOptions::default(),
41            )
42            .await?;
43        let response: InitializeResponse = serde_json::from_value(response)
44            .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?;
45
46        if response.protocol_version != LATEST_PROTOCOL_VERSION {
47            return Err(anyhow::anyhow!(
48                "Unsupported protocol version: {}",
49                response.protocol_version
50            ));
51        }
52
53        // Save the response for later use
54        let mut writer = self.initialize_res.write().await;
55        *writer = Some(response.clone());
56
57        debug!(
58            "Initialized with protocol version: {}",
59            response.protocol_version
60        );
61        self.protocol
62            .notify("notifications/initialized", None)
63            .await?;
64
65        Ok(response)
66    }
67
68    pub async fn assert_initialized(&self) -> Result<(), anyhow::Error> {
69        let reader = self.initialize_res.read().await;
70        match &*reader {
71            Some(_) => Ok(()),
72            None => Err(anyhow::anyhow!("Not initialized")),
73        }
74    }
75
76    pub async fn request(
77        &self,
78        method: &str,
79        params: Option<serde_json::Value>,
80        options: RequestOptions,
81    ) -> Result<serde_json::Value> {
82        let response = self.protocol.request(method, params, options).await?;
83        response
84            .result
85            .ok_or_else(|| anyhow::anyhow!("Request failed: {:?}", response.error))
86    }
87
88    pub async fn list_tools(
89        &self,
90        cursor: Option<String>,
91        request_options: Option<RequestOptions>,
92    ) -> Result<ToolsListResponse> {
93        if self.strict {
94            self.assert_initialized().await?;
95        }
96
97        let list_request = ListRequest { cursor, meta: None };
98
99        let response = self
100            .request(
101                "tools/list",
102                Some(serde_json::to_value(list_request)?),
103                request_options.unwrap_or_else(RequestOptions::default),
104            )
105            .await?;
106
107        Ok(serde_json::from_value(response)
108            .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?)
109    }
110
111    pub async fn call_tool(
112        &self,
113        name: &str,
114        arguements: Option<serde_json::Value>,
115    ) -> Result<CallToolResponse> {
116        if self.strict {
117            self.assert_initialized().await?;
118        }
119
120        let arguments = if let Some(env) = &self.env {
121            arguements
122                .as_ref()
123                .map(|args| apply_secure_replacements(args, env))
124        } else {
125            arguements
126        };
127
128        let arguments =
129            arguments.map(|value| serde_json::from_value(value).unwrap_or_else(|_| HashMap::new()));
130
131        let request = CallToolRequest {
132            name: name.to_string(),
133            arguments,
134            meta: None,
135        };
136
137        let response = self
138            .request(
139                "tools/call",
140                Some(serde_json::to_value(request)?),
141                RequestOptions::default(),
142            )
143            .await?;
144
145        Ok(serde_json::from_value(response)
146            .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?)
147    }
148
149    pub(crate) async fn list_resources() {
150        todo!()
151    }
152
153    pub(crate) async fn read_resource() {
154        todo!()
155    }
156
157    pub(crate) async fn subscribe_to_resource() {
158        todo!()
159    }
160
161    pub async fn start(&self) -> Result<()> {
162        self.protocol.listen().await
163    }
164}
165
/// A replacement value substituted into tool-call arguments.
///
/// `Debug` is implemented by hand rather than derived so that the payload of
/// [`SecureValue::Static`] never ends up in logs or panic messages.
#[derive(Clone)]
pub enum SecureValue {
    /// A literal replacement string.
    Static(String),
    /// The name of an environment variable whose value is used as the
    /// replacement.
    Env(String),
}

impl std::fmt::Debug for SecureValue {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // The literal is the secret itself, so it is never printed.
            Self::Static(_) => f
                .debug_tuple("Static")
                .field(&format_args!("<redacted>"))
                .finish(),
            // Only the variable's name is stored here, not its value.
            Self::Env(name) => f.debug_tuple("Env").field(name).finish(),
        }
    }
}
171
172pub struct ClientBuilder<T: Transport> {
173    protocol: ProtocolBuilder<T>,
174    strict: bool,
175    env: Option<HashMap<String, SecureValue>>,
176}
177
178impl<T: Transport> ClientBuilder<T> {
179    pub fn new(transport: T) -> Self {
180        Self {
181            protocol: ProtocolBuilder::new(transport),
182            strict: false,
183            env: None,
184        }
185    }
186
187    pub fn with_secure_value(mut self, key: impl Into<String>, value: SecureValue) -> Self {
188        match &mut self.env {
189            Some(env) => {
190                env.insert(key.into(), value);
191            }
192            None => {
193                let mut new_env = HashMap::new();
194                new_env.insert(key.into(), value);
195                self.env = Some(new_env);
196            }
197        }
198        self
199    }
200
201    pub fn use_strict(mut self) -> Self {
202        self.strict = true;
203        self
204    }
205
206    pub fn with_strict(mut self, strict: bool) -> Self {
207        self.strict = strict;
208        self
209    }
210
211    pub fn build(self) -> Client<T> {
212        Client {
213            protocol: self.protocol.build(),
214            strict: self.strict,
215            env: self.env,
216            initialize_res: Arc::new(RwLock::new(None)),
217        }
218    }
219}
220
221/// Recursively walk through the JSON value. If a JSON string exactly matches
222/// one of the keys in the secure values map, replace it with the corresponding secure value.
223pub fn apply_secure_replacements(
224    value: &Value,
225    secure_values: &HashMap<String, SecureValue>,
226) -> Value {
227    match value {
228        Value::Object(map) => {
229            let mut new_map = serde_json::Map::new();
230            for (k, v) in map.iter() {
231                let new_value = if let Value::String(_) = v {
232                    if let Some(secure_val) = secure_values.get(k) {
233                        let replacement = match secure_val {
234                            SecureValue::Static(val) => val.clone(),
235                            SecureValue::Env(env_key) => env::var(env_key)
236                                .unwrap_or_else(|_| v.as_str().unwrap().to_string()),
237                        };
238                        Value::String(replacement)
239                    } else {
240                        apply_secure_replacements(v, secure_values)
241                    }
242                } else {
243                    apply_secure_replacements(v, secure_values)
244                };
245                new_map.insert(k.clone(), new_value);
246            }
247            Value::Object(new_map)
248        }
249        Value::Array(arr) => {
250            let new_arr: Vec<Value> = arr
251                .iter()
252                .map(|v| apply_secure_replacements(v, secure_values))
253                .collect();
254            Value::Array(new_arr)
255        }
256        _ => value.clone(),
257    }
258}