1use std::{collections::HashMap, env, sync::Arc};
2
3use crate::{
4 protocol::RequestOptions,
5 transport::Transport,
6 types::{
7 CallToolRequest, CallToolResponse, ClientCapabilities, Implementation, InitializeRequest,
8 InitializeResponse, ListRequest, ReadResourceRequest, Resource, ResourcesListResponse,
9 ToolsListResponse, LATEST_PROTOCOL_VERSION,
10 },
11};
12
13use anyhow::Result;
14use serde_json::Value;
15use tokio::sync::RwLock;
16use tracing::debug;
17
18#[derive(Clone)]
19pub struct Client<T: Transport> {
20 transport: T,
21 strict: bool,
22 initialize_res: Arc<RwLock<Option<InitializeResponse>>>,
23 env: Option<HashMap<String, SecureValue>>,
24}
25
26impl<T: Transport> Client<T> {
27 pub fn builder(transport: T) -> ClientBuilder<T> {
28 ClientBuilder::new(transport)
29 }
30
31 pub async fn open(&self) -> Result<()> {
32 self.transport.open().await
33 }
34
35 pub async fn initialize(
36 &self,
37 client_info: Implementation,
38 capabilities: ClientCapabilities,
39 ) -> Result<InitializeResponse> {
40 let request = InitializeRequest {
41 protocol_version: LATEST_PROTOCOL_VERSION.to_string(),
42 capabilities,
43 client_info,
44 };
45 let response = self
46 .request(
47 "initialize",
48 Some(serde_json::to_value(request)?),
49 RequestOptions::default(),
50 )
51 .await?;
52 let response: InitializeResponse = serde_json::from_value(response)
53 .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?;
54
55 if response.protocol_version != LATEST_PROTOCOL_VERSION {
56 return Err(anyhow::anyhow!(
57 "Unsupported protocol version: {}",
58 response.protocol_version
59 ));
60 }
61
62 let mut writer = self.initialize_res.write().await;
64 *writer = Some(response.clone());
65
66 debug!(
67 "Initialized with protocol version: {}",
68 response.protocol_version
69 );
70 self.transport
71 .send_notification("notifications/initialized", None)
72 .await?;
73
74 Ok(response)
75 }
76
77 pub async fn assert_initialized(&self) -> Result<(), anyhow::Error> {
78 let reader = self.initialize_res.read().await;
79 match &*reader {
80 Some(_) => Ok(()),
81 None => Err(anyhow::anyhow!("Not initialized")),
82 }
83 }
84
85 pub async fn request(
86 &self,
87 method: &str,
88 params: Option<serde_json::Value>,
89 options: RequestOptions,
90 ) -> Result<serde_json::Value> {
91 let response = self.transport.request(method, params, options).await?;
92 response
93 .result
94 .ok_or_else(|| anyhow::anyhow!("Request failed: {:?}", response.error))
95 }
96
97 pub async fn list_tools(
98 &self,
99 cursor: Option<String>,
100 request_options: Option<RequestOptions>,
101 ) -> Result<ToolsListResponse> {
102 if self.strict {
103 self.assert_initialized().await?;
104 }
105
106 let list_request = ListRequest { cursor, meta: None };
107
108 let response = self
109 .request(
110 "tools/list",
111 Some(serde_json::to_value(list_request)?),
112 request_options.unwrap_or_else(RequestOptions::default),
113 )
114 .await?;
115
116 Ok(serde_json::from_value(response)
117 .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?)
118 }
119
120 pub async fn call_tool(
121 &self,
122 name: &str,
123 arguements: Option<serde_json::Value>,
124 ) -> Result<CallToolResponse> {
125 if self.strict {
126 self.assert_initialized().await?;
127 }
128
129 let arguments = if let Some(env) = &self.env {
130 arguements
131 .as_ref()
132 .map(|args| apply_secure_replacements(args, env))
133 } else {
134 arguements
135 };
136
137 let arguments =
138 arguments.map(|value| serde_json::from_value(value).unwrap_or_else(|_| HashMap::new()));
139
140 let request = CallToolRequest {
141 name: name.to_string(),
142 arguments,
143 meta: None,
144 };
145
146 let response = self
147 .request(
148 "tools/call",
149 Some(serde_json::to_value(request)?),
150 RequestOptions::default(),
151 )
152 .await?;
153
154 Ok(serde_json::from_value(response)
155 .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?)
156 }
157
158 pub async fn list_resources(
159 &self,
160 cursor: Option<String>,
161 request_options: Option<RequestOptions>,
162 ) -> Result<ResourcesListResponse> {
163 if self.strict {
164 self.assert_initialized().await?;
165 }
166
167 let list_request = ListRequest { cursor, meta: None };
168
169 let response = self
170 .request(
171 "resources/list",
172 Some(serde_json::to_value(list_request)?),
173 request_options.unwrap_or_else(RequestOptions::default),
174 )
175 .await?;
176
177 Ok(serde_json::from_value(response)
178 .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?)
179 }
180
181 pub async fn read_resource(&self, uri: url::Url) -> Result<Resource> {
182 if self.strict {
183 self.assert_initialized().await?;
184 }
185
186 let read_request = ReadResourceRequest { uri };
187
188 let response = self
189 .request(
190 "resources/read",
191 Some(serde_json::to_value(read_request)?),
192 RequestOptions::default(),
193 )
194 .await?;
195
196 Ok(serde_json::from_value(response)
197 .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?)
198 }
199
200 pub async fn subscribe_to_resource(&self, uri: url::Url) -> Result<()> {
201 if self.strict {
202 self.assert_initialized().await?;
203 }
204
205 let subscribe_request = ReadResourceRequest { uri };
206
207 self.request(
208 "resources/subscribe",
209 Some(serde_json::to_value(subscribe_request)?),
210 RequestOptions::default(),
211 )
212 .await?;
213
214 Ok(())
215 }
216
217 pub async fn unsubscribe_to_resource(&self, uri: url::Url) -> Result<()> {
218 if self.strict {
219 self.assert_initialized().await?;
220 }
221
222 let unsubscribe_request = ReadResourceRequest { uri };
223
224 self.request(
225 "resources/unsubscribe",
226 Some(serde_json::to_value(unsubscribe_request)?),
227 RequestOptions::default(),
228 )
229 .await?;
230
231 Ok(())
232 }
233}
234
/// A value substituted into tool-call arguments in place of whatever string the caller
/// supplied under the matching key.
#[derive(Clone)]
pub enum SecureValue {
    /// Substitute this literal string.
    Static(String),
    /// Substitute the value of the environment variable with this name, read when the
    /// substitution is performed.
    Env(String),
}

/// Hand-written rather than derived so that `{:?}` never prints a secret. The contents of
/// [`SecureValue::Static`] are redacted. The environment variable *name* held by
/// [`SecureValue::Env`] is shown, since it is not the secret itself.
impl std::fmt::Debug for SecureValue {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SecureValue::Static(_) => f
                .debug_tuple("Static")
                .field(&format_args!("<redacted>"))
                .finish(),
            SecureValue::Env(name) => f.debug_tuple("Env").field(name).finish(),
        }
    }
}
240
241pub struct ClientBuilder<T: Transport> {
242 transport: T,
243 strict: bool,
244 env: Option<HashMap<String, SecureValue>>,
245}
246
247impl<T: Transport> ClientBuilder<T> {
248 pub fn new(transport: T) -> Self {
249 Self {
250 transport,
251 strict: false,
252 env: None,
253 }
254 }
255
256 pub fn with_secure_value(mut self, key: impl Into<String>, value: SecureValue) -> Self {
257 match &mut self.env {
258 Some(env) => {
259 env.insert(key.into(), value);
260 }
261 None => {
262 let mut new_env = HashMap::new();
263 new_env.insert(key.into(), value);
264 self.env = Some(new_env);
265 }
266 }
267 self
268 }
269
270 pub fn use_strict(mut self) -> Self {
271 self.strict = true;
272 self
273 }
274
275 pub fn with_strict(mut self, strict: bool) -> Self {
276 self.strict = strict;
277 self
278 }
279
280 pub fn build(self) -> Client<T> {
281 Client {
282 transport: self.transport,
283 strict: self.strict,
284 env: self.env,
285 initialize_res: Arc::new(RwLock::new(None)),
286 }
287 }
288}
289
290pub fn apply_secure_replacements(
293 value: &Value,
294 secure_values: &HashMap<String, SecureValue>,
295) -> Value {
296 match value {
297 Value::Object(map) => {
298 let mut new_map = serde_json::Map::new();
299 for (k, v) in map.iter() {
300 let new_value = if let Value::String(_) = v {
301 if let Some(secure_val) = secure_values.get(k) {
302 let replacement = match secure_val {
303 SecureValue::Static(val) => val.clone(),
304 SecureValue::Env(env_key) => env::var(env_key)
305 .unwrap_or_else(|_| v.as_str().unwrap().to_string()),
306 };
307 Value::String(replacement)
308 } else {
309 apply_secure_replacements(v, secure_values)
310 }
311 } else {
312 apply_secure_replacements(v, secure_values)
313 };
314 new_map.insert(k.clone(), new_value);
315 }
316 Value::Object(new_map)
317 }
318 Value::Array(arr) => {
319 let new_arr: Vec<Value> = arr
320 .iter()
321 .map(|v| apply_secure_replacements(v, secure_values))
322 .collect();
323 Value::Array(new_arr)
324 }
325 _ => value.clone(),
326 }
327}