1use std::{collections::HashMap, env, sync::Arc};
15
16use crate::{
17 protocol::RequestOptions,
18 transport::Transport,
19 types::{
20 CallToolRequest, CallToolResponse, ClientCapabilities, Implementation, InitializeRequest,
21 InitializeResponse, ListRequest, ProtocolVersion, ReadResourceRequest, Resource,
22 ResourcesListResponse, ToolsListResponse, LATEST_PROTOCOL_VERSION,
23 },
24};
25
26use anyhow::Result;
27use serde_json::Value;
28use tokio::sync::RwLock;
29use tracing::debug;
30
31#[derive(Clone)]
36pub struct Client<T: Transport> {
37 transport: T,
38 strict: bool,
39 protocol_version: ProtocolVersion,
40 initialize_res: Arc<RwLock<Option<InitializeResponse>>>,
41 env: Option<HashMap<String, SecureValue>>,
42 client_info: Implementation,
43 capabilities: ClientCapabilities,
44}
45
46impl<T: Transport> Client<T> {
47 pub fn builder(transport: T) -> ClientBuilder<T> {
57 ClientBuilder::new(transport)
58 }
59
60 pub fn set_protocol_version(mut self, protocol_version: ProtocolVersion) -> Self {
70 self.protocol_version = protocol_version;
71 self
72 }
73
74 pub async fn open(&self) -> Result<()> {
80 self.transport.open().await
81 }
82
83 pub async fn initialize(&self) -> Result<InitializeResponse> {
92 let request = InitializeRequest {
93 protocol_version: self.protocol_version.as_str().to_string(),
94 capabilities: self.capabilities.clone(),
95 client_info: self.client_info.clone(),
96 };
97 let response = self
98 .request(
99 "initialize",
100 Some(serde_json::to_value(request)?),
101 RequestOptions::default(),
102 )
103 .await?;
104 let response: InitializeResponse = serde_json::from_value(response)
105 .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?;
106
107 if response.protocol_version != self.protocol_version.as_str() {
108 return Err(anyhow::anyhow!(
109 "Unsupported protocol version: {}",
110 response.protocol_version
111 ));
112 }
113
114 let mut writer = self.initialize_res.write().await;
116 *writer = Some(response.clone());
117
118 debug!(
119 "Initialized with protocol version: {}",
120 response.protocol_version
121 );
122 self.transport
123 .send_notification("notifications/initialized", None)
124 .await?;
125
126 Ok(response)
127 }
128
129 pub async fn assert_initialized(&self) -> Result<(), anyhow::Error> {
135 let reader = self.initialize_res.read().await;
136 match &*reader {
137 Some(_) => Ok(()),
138 None => Err(anyhow::anyhow!("Not initialized")),
139 }
140 }
141
142 pub async fn request(
154 &self,
155 method: &str,
156 params: Option<serde_json::Value>,
157 options: RequestOptions,
158 ) -> Result<serde_json::Value> {
159 let response = self.transport.request(method, params, options).await?;
160 response
161 .result
162 .ok_or_else(|| anyhow::anyhow!("Request failed: {:?}", response.error))
163 }
164
165 pub async fn list_tools(
176 &self,
177 cursor: Option<String>,
178 request_options: Option<RequestOptions>,
179 ) -> Result<ToolsListResponse> {
180 if self.strict {
181 self.assert_initialized().await?;
182 }
183
184 let list_request = ListRequest { cursor, meta: None };
185
186 let response = self
187 .request(
188 "tools/list",
189 Some(serde_json::to_value(list_request)?),
190 request_options.unwrap_or_else(RequestOptions::default),
191 )
192 .await?;
193
194 Ok(serde_json::from_value(response)
195 .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?)
196 }
197
198 pub async fn call_tool(
209 &self,
210 name: &str,
211 arguements: Option<serde_json::Value>,
212 ) -> Result<CallToolResponse> {
213 if self.strict {
214 self.assert_initialized().await?;
215 }
216
217 let arguments = if let Some(env) = &self.env {
218 arguements
219 .as_ref()
220 .map(|args| apply_secure_replacements(args, env))
221 } else {
222 arguements
223 };
224
225 let arguments =
226 arguments.map(|value| serde_json::from_value(value).unwrap_or_else(|_| HashMap::new()));
227
228 let request = CallToolRequest {
229 name: name.to_string(),
230 arguments,
231 meta: None,
232 };
233
234 let response = self
235 .request(
236 "tools/call",
237 Some(serde_json::to_value(request)?),
238 RequestOptions::default(),
239 )
240 .await?;
241
242 Ok(serde_json::from_value(response)
243 .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?)
244 }
245
246 pub async fn list_resources(
257 &self,
258 cursor: Option<String>,
259 request_options: Option<RequestOptions>,
260 ) -> Result<ResourcesListResponse> {
261 if self.strict {
262 self.assert_initialized().await?;
263 }
264
265 let list_request = ListRequest { cursor, meta: None };
266
267 let response = self
268 .request(
269 "resources/list",
270 Some(serde_json::to_value(list_request)?),
271 request_options.unwrap_or_else(RequestOptions::default),
272 )
273 .await?;
274
275 Ok(serde_json::from_value(response)
276 .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?)
277 }
278
279 pub async fn read_resource(&self, uri: url::Url) -> Result<Resource> {
289 if self.strict {
290 self.assert_initialized().await?;
291 }
292
293 let read_request = ReadResourceRequest { uri };
294
295 let response = self
296 .request(
297 "resources/read",
298 Some(serde_json::to_value(read_request)?),
299 RequestOptions::default(),
300 )
301 .await?;
302
303 Ok(serde_json::from_value(response)
304 .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?)
305 }
306
307 pub async fn subscribe_to_resource(&self, uri: url::Url) -> Result<()> {
308 if self.strict {
309 self.assert_initialized().await?;
310 }
311
312 let subscribe_request = ReadResourceRequest { uri };
313
314 self.request(
315 "resources/subscribe",
316 Some(serde_json::to_value(subscribe_request)?),
317 RequestOptions::default(),
318 )
319 .await?;
320
321 Ok(())
322 }
323
324 pub async fn unsubscribe_to_resource(&self, uri: url::Url) -> Result<()> {
325 if self.strict {
326 self.assert_initialized().await?;
327 }
328
329 let unsubscribe_request = ReadResourceRequest { uri };
330
331 self.request(
332 "resources/unsubscribe",
333 Some(serde_json::to_value(unsubscribe_request)?),
334 RequestOptions::default(),
335 )
336 .await?;
337
338 Ok(())
339 }
340}
341
/// A value substituted into tool-call arguments in place of whatever the caller passed.
///
/// `Debug` is implemented by hand so that `Static` secrets never appear in logs or panic
/// messages; only the environment-variable *name* of an `Env` value is shown.
#[derive(Clone, PartialEq, Eq)]
pub enum SecureValue {
    /// A literal replacement string (treated as secret).
    Static(String),
    /// The name of an environment variable whose value is used as the replacement.
    Env(String),
}

impl std::fmt::Debug for SecureValue {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Never print the secret itself.
            SecureValue::Static(_) => f.write_str("Static(<redacted>)"),
            SecureValue::Env(name) => f.debug_tuple("Env").field(name).finish(),
        }
    }
}
352
353pub struct ClientBuilder<T: Transport> {
358 transport: T,
359 strict: bool,
360 env: Option<HashMap<String, SecureValue>>,
361 protocol_version: ProtocolVersion,
362 client_info: Implementation,
363 capabilities: ClientCapabilities,
364}
365
366impl<T: Transport> ClientBuilder<T> {
367 pub fn new(transport: T) -> Self {
377 Self {
378 transport,
379 strict: false,
380 env: None,
381 protocol_version: LATEST_PROTOCOL_VERSION,
382 client_info: Implementation {
383 name: env::var("CARGO_PKG_NAME").unwrap_or_else(|_| "mcp-client".to_string()),
384 version: env::var("CARGO_PKG_VERSION").unwrap_or_else(|_| "0.1.0".to_string()),
385 },
386 capabilities: ClientCapabilities::default(),
387 }
388 }
389
390 pub fn set_protocol_version(mut self, protocol_version: ProtocolVersion) -> Self {
400 self.protocol_version = protocol_version;
401 self
402 }
403
404 pub fn set_client_info(mut self, name: impl Into<String>, version: impl Into<String>) -> Self {
415 self.client_info = Implementation {
416 name: name.into(),
417 version: version.into(),
418 };
419 self
420 }
421
422 pub fn set_capabilities(mut self, capabilities: ClientCapabilities) -> Self {
432 self.capabilities = capabilities;
433 self
434 }
435
436 pub fn with_secure_value(mut self, key: impl Into<String>, value: SecureValue) -> Self {
447 if self.env.is_none() {
448 self.env = Some(HashMap::new());
449 }
450
451 if let Some(env) = &mut self.env {
452 env.insert(key.into(), value);
453 }
454
455 self
456 }
457
458 pub fn use_strict(mut self) -> Self {
464 self.strict = true;
465 self
466 }
467
468 pub fn with_strict(mut self, strict: bool) -> Self {
478 self.strict = strict;
479 self
480 }
481
482 pub fn build(self) -> Client<T> {
488 Client {
489 transport: self.transport,
490 strict: self.strict,
491 env: self.env,
492 protocol_version: self.protocol_version,
493 initialize_res: Arc::new(RwLock::new(None)),
494 client_info: self.client_info,
495 capabilities: self.capabilities,
496 }
497 }
498}
499
500pub fn apply_secure_replacements(
503 value: &Value,
504 secure_values: &HashMap<String, SecureValue>,
505) -> Value {
506 match value {
507 Value::Object(map) => {
508 let mut new_map = serde_json::Map::new();
509 for (k, v) in map.iter() {
510 let new_value = if let Value::String(_) = v {
511 if let Some(secure_val) = secure_values.get(k) {
512 let replacement = match secure_val {
513 SecureValue::Static(val) => val.clone(),
514 SecureValue::Env(env_key) => env::var(env_key)
515 .unwrap_or_else(|_| v.as_str().unwrap().to_string()),
516 };
517 Value::String(replacement)
518 } else {
519 apply_secure_replacements(v, secure_values)
520 }
521 } else {
522 apply_secure_replacements(v, secure_values)
523 };
524 new_map.insert(k.clone(), new_value);
525 }
526 Value::Object(new_map)
527 }
528 Value::Array(arr) => {
529 let new_arr: Vec<Value> = arr
530 .iter()
531 .map(|v| apply_secure_replacements(v, secure_values))
532 .collect();
533 Value::Array(new_arr)
534 }
535 _ => value.clone(),
536 }
537}