1use std::{collections::HashMap, env, sync::Arc};
2
3use crate::{
4 protocol::{Protocol, ProtocolBuilder, RequestOptions},
5 transport::Transport,
6 types::{
7 CallToolRequest, CallToolResponse, ClientCapabilities, Implementation, InitializeRequest,
8 InitializeResponse, ListRequest, ReadResourceRequest, Resource, ResourcesListResponse,
9 ToolsListResponse, LATEST_PROTOCOL_VERSION,
10 },
11};
12
13use anyhow::Result;
14use serde_json::Value;
15use tokio::sync::RwLock;
16use tracing::debug;
17
18#[derive(Clone)]
19pub struct Client<T: Transport> {
20 protocol: Protocol<T>,
21 strict: bool,
22 initialize_res: Arc<RwLock<Option<InitializeResponse>>>,
23 env: Option<HashMap<String, SecureValue>>,
24}
25
26impl<T: Transport> Client<T> {
27 pub fn builder(transport: T) -> ClientBuilder<T> {
28 ClientBuilder::new(transport)
29 }
30
31 pub async fn initialize(&self, client_info: Implementation) -> Result<InitializeResponse> {
32 let request = InitializeRequest {
33 protocol_version: LATEST_PROTOCOL_VERSION.to_string(),
34 capabilities: ClientCapabilities::default(),
35 client_info,
36 };
37 let response = self
38 .request(
39 "initialize",
40 Some(serde_json::to_value(request)?),
41 RequestOptions::default(),
42 )
43 .await?;
44 let response: InitializeResponse = serde_json::from_value(response)
45 .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?;
46
47 if response.protocol_version != LATEST_PROTOCOL_VERSION {
48 return Err(anyhow::anyhow!(
49 "Unsupported protocol version: {}",
50 response.protocol_version
51 ));
52 }
53
54 let mut writer = self.initialize_res.write().await;
56 *writer = Some(response.clone());
57
58 debug!(
59 "Initialized with protocol version: {}",
60 response.protocol_version
61 );
62 self.protocol
63 .notify("notifications/initialized", None)
64 .await?;
65
66 Ok(response)
67 }
68
69 pub async fn assert_initialized(&self) -> Result<(), anyhow::Error> {
70 let reader = self.initialize_res.read().await;
71 match &*reader {
72 Some(_) => Ok(()),
73 None => Err(anyhow::anyhow!("Not initialized")),
74 }
75 }
76
77 pub async fn request(
78 &self,
79 method: &str,
80 params: Option<serde_json::Value>,
81 options: RequestOptions,
82 ) -> Result<serde_json::Value> {
83 let response = self.protocol.request(method, params, options).await?;
84 response
85 .result
86 .ok_or_else(|| anyhow::anyhow!("Request failed: {:?}", response.error))
87 }
88
89 pub async fn list_tools(
90 &self,
91 cursor: Option<String>,
92 request_options: Option<RequestOptions>,
93 ) -> Result<ToolsListResponse> {
94 if self.strict {
95 self.assert_initialized().await?;
96 }
97
98 let list_request = ListRequest { cursor, meta: None };
99
100 let response = self
101 .request(
102 "tools/list",
103 Some(serde_json::to_value(list_request)?),
104 request_options.unwrap_or_else(RequestOptions::default),
105 )
106 .await?;
107
108 Ok(serde_json::from_value(response)
109 .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?)
110 }
111
112 pub async fn call_tool(
113 &self,
114 name: &str,
115 arguements: Option<serde_json::Value>,
116 ) -> Result<CallToolResponse> {
117 if self.strict {
118 self.assert_initialized().await?;
119 }
120
121 let arguments = if let Some(env) = &self.env {
122 arguements
123 .as_ref()
124 .map(|args| apply_secure_replacements(args, env))
125 } else {
126 arguements
127 };
128
129 let arguments =
130 arguments.map(|value| serde_json::from_value(value).unwrap_or_else(|_| HashMap::new()));
131
132 let request = CallToolRequest {
133 name: name.to_string(),
134 arguments,
135 meta: None,
136 };
137
138 let response = self
139 .request(
140 "tools/call",
141 Some(serde_json::to_value(request)?),
142 RequestOptions::default(),
143 )
144 .await?;
145
146 Ok(serde_json::from_value(response)
147 .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?)
148 }
149
150 pub async fn list_resources(
151 &self,
152 cursor: Option<String>,
153 request_options: Option<RequestOptions>,
154 ) -> Result<ResourcesListResponse> {
155 if self.strict {
156 self.assert_initialized().await?;
157 }
158
159 let list_request = ListRequest { cursor, meta: None };
160
161 let response = self
162 .request(
163 "resources/list",
164 Some(serde_json::to_value(list_request)?),
165 request_options.unwrap_or_else(RequestOptions::default),
166 )
167 .await?;
168
169 Ok(serde_json::from_value(response)
170 .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?)
171 }
172
173 pub async fn read_resource(&self, uri: url::Url) -> Result<Resource> {
174 if self.strict {
175 self.assert_initialized().await?;
176 }
177
178 let read_request = ReadResourceRequest { uri };
179
180 let response = self
181 .request(
182 "resources/read",
183 Some(serde_json::to_value(read_request)?),
184 RequestOptions::default(),
185 )
186 .await?;
187
188 Ok(serde_json::from_value(response)
189 .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?)
190 }
191
192 pub async fn subscribe_to_resource(&self, uri: url::Url) -> Result<()> {
193 if self.strict {
194 self.assert_initialized().await?;
195 }
196
197 let subscribe_request = ReadResourceRequest { uri };
198
199 self.request(
200 "resources/subscribe",
201 Some(serde_json::to_value(subscribe_request)?),
202 RequestOptions::default(),
203 )
204 .await?;
205
206 Ok(())
207 }
208
209 pub async fn unsubscribe_to_resource(&self, uri: url::Url) -> Result<()> {
210 if self.strict {
211 self.assert_initialized().await?;
212 }
213
214 let unsubscribe_request = ReadResourceRequest { uri };
215
216 self.request(
217 "resources/unsubscribe",
218 Some(serde_json::to_value(unsubscribe_request)?),
219 RequestOptions::default(),
220 )
221 .await?;
222
223 Ok(())
224 }
225
226 pub async fn start(&self) -> Result<()> {
227 self.protocol.listen().await
228 }
229}
230
/// A value substituted into `call_tool` arguments by `apply_secure_replacements`,
/// registered per argument name through `ClientBuilder::with_secure_value`.
#[derive(Clone)]
pub enum SecureValue {
    /// Replace the argument's string with this literal.
    Static(String),
    /// Replace the argument's string with the value of this environment variable; if
    /// the variable cannot be read, the argument's original string is kept.
    Env(String),
}
236
237pub struct ClientBuilder<T: Transport> {
238 protocol: ProtocolBuilder<T>,
239 strict: bool,
240 env: Option<HashMap<String, SecureValue>>,
241}
242
243impl<T: Transport> ClientBuilder<T> {
244 pub fn new(transport: T) -> Self {
245 Self {
246 protocol: ProtocolBuilder::new(transport),
247 strict: false,
248 env: None,
249 }
250 }
251
252 pub fn with_secure_value(mut self, key: impl Into<String>, value: SecureValue) -> Self {
253 match &mut self.env {
254 Some(env) => {
255 env.insert(key.into(), value);
256 }
257 None => {
258 let mut new_env = HashMap::new();
259 new_env.insert(key.into(), value);
260 self.env = Some(new_env);
261 }
262 }
263 self
264 }
265
266 pub fn use_strict(mut self) -> Self {
267 self.strict = true;
268 self
269 }
270
271 pub fn with_strict(mut self, strict: bool) -> Self {
272 self.strict = strict;
273 self
274 }
275
276 pub fn build(self) -> Client<T> {
277 Client {
278 protocol: self.protocol.build(),
279 strict: self.strict,
280 env: self.env,
281 initialize_res: Arc::new(RwLock::new(None)),
282 }
283 }
284}
285
286pub fn apply_secure_replacements(
289 value: &Value,
290 secure_values: &HashMap<String, SecureValue>,
291) -> Value {
292 match value {
293 Value::Object(map) => {
294 let mut new_map = serde_json::Map::new();
295 for (k, v) in map.iter() {
296 let new_value = if let Value::String(_) = v {
297 if let Some(secure_val) = secure_values.get(k) {
298 let replacement = match secure_val {
299 SecureValue::Static(val) => val.clone(),
300 SecureValue::Env(env_key) => env::var(env_key)
301 .unwrap_or_else(|_| v.as_str().unwrap().to_string()),
302 };
303 Value::String(replacement)
304 } else {
305 apply_secure_replacements(v, secure_values)
306 }
307 } else {
308 apply_secure_replacements(v, secure_values)
309 };
310 new_map.insert(k.clone(), new_value);
311 }
312 Value::Object(new_map)
313 }
314 Value::Array(arr) => {
315 let new_arr: Vec<Value> = arr
316 .iter()
317 .map(|v| apply_secure_replacements(v, secure_values))
318 .collect();
319 Value::Array(new_arr)
320 }
321 _ => value.clone(),
322 }
323}