1use std::{collections::HashMap, env, sync::Arc};
2
3use crate::{
4 protocol::{Protocol, ProtocolBuilder, RequestOptions},
5 transport::Transport,
6 types::{
7 CallToolRequest, CallToolResponse, ClientCapabilities, Implementation, InitializeRequest,
8 InitializeResponse, ListRequest, ToolsListResponse, LATEST_PROTOCOL_VERSION,
9 },
10};
11
12use anyhow::Result;
13use serde_json::Value;
14use tokio::sync::RwLock;
15use tracing::debug;
16
17#[derive(Clone)]
18pub struct Client<T: Transport> {
19 protocol: Protocol<T>,
20 strict: bool,
21 initialize_res: Arc<RwLock<Option<InitializeResponse>>>,
22 env: Option<HashMap<String, SecureValue>>,
23}
24
25impl<T: Transport> Client<T> {
26 pub fn builder(transport: T) -> ClientBuilder<T> {
27 ClientBuilder::new(transport)
28 }
29
30 pub async fn initialize(&self, client_info: Implementation) -> Result<InitializeResponse> {
31 let request = InitializeRequest {
32 protocol_version: LATEST_PROTOCOL_VERSION.to_string(),
33 capabilities: ClientCapabilities::default(),
34 client_info,
35 };
36 let response = self
37 .request(
38 "initialize",
39 Some(serde_json::to_value(request)?),
40 RequestOptions::default(),
41 )
42 .await?;
43 let response: InitializeResponse = serde_json::from_value(response)
44 .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?;
45
46 if response.protocol_version != LATEST_PROTOCOL_VERSION {
47 return Err(anyhow::anyhow!(
48 "Unsupported protocol version: {}",
49 response.protocol_version
50 ));
51 }
52
53 let mut writer = self.initialize_res.write().await;
55 *writer = Some(response.clone());
56
57 debug!(
58 "Initialized with protocol version: {}",
59 response.protocol_version
60 );
61 self.protocol
62 .notify("notifications/initialized", None)
63 .await?;
64
65 Ok(response)
66 }
67
68 pub async fn assert_initialized(&self) -> Result<(), anyhow::Error> {
69 let reader = self.initialize_res.read().await;
70 match &*reader {
71 Some(_) => Ok(()),
72 None => Err(anyhow::anyhow!("Not initialized")),
73 }
74 }
75
76 pub async fn request(
77 &self,
78 method: &str,
79 params: Option<serde_json::Value>,
80 options: RequestOptions,
81 ) -> Result<serde_json::Value> {
82 let response = self.protocol.request(method, params, options).await?;
83 response
84 .result
85 .ok_or_else(|| anyhow::anyhow!("Request failed: {:?}", response.error))
86 }
87
88 pub async fn list_tools(
89 &self,
90 cursor: Option<String>,
91 request_options: Option<RequestOptions>,
92 ) -> Result<ToolsListResponse> {
93 if self.strict {
94 self.assert_initialized().await?;
95 }
96
97 let list_request = ListRequest { cursor, meta: None };
98
99 let response = self
100 .request(
101 "tools/list",
102 Some(serde_json::to_value(list_request)?),
103 request_options.unwrap_or_else(RequestOptions::default),
104 )
105 .await?;
106
107 Ok(serde_json::from_value(response)
108 .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?)
109 }
110
111 pub async fn call_tool(
112 &self,
113 name: &str,
114 arguements: Option<serde_json::Value>,
115 ) -> Result<CallToolResponse> {
116 if self.strict {
117 self.assert_initialized().await?;
118 }
119
120 let arguments = if let Some(env) = &self.env {
121 arguements
122 .as_ref()
123 .map(|args| apply_secure_replacements(args, env))
124 } else {
125 arguements
126 };
127
128 let arguments =
129 arguments.map(|value| serde_json::from_value(value).unwrap_or_else(|_| HashMap::new()));
130
131 let request = CallToolRequest {
132 name: name.to_string(),
133 arguments,
134 meta: None,
135 };
136
137 let response = self
138 .request(
139 "tools/call",
140 Some(serde_json::to_value(request)?),
141 RequestOptions::default(),
142 )
143 .await?;
144
145 Ok(serde_json::from_value(response)
146 .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?)
147 }
148
149 pub(crate) async fn list_resources() {
150 todo!()
151 }
152
153 pub(crate) async fn read_resource() {
154 todo!()
155 }
156
157 pub(crate) async fn subscribe_to_resource() {
158 todo!()
159 }
160
161 pub async fn start(&self) -> Result<()> {
162 self.protocol.listen().await
163 }
164}
165
/// Replacement source for a secure argument value.
///
/// When tool-call arguments are processed by `apply_secure_replacements`, an
/// object entry holding a JSON string whose key maps to a `SecureValue` has its
/// string replaced as follows:
/// - `Static(text)`: `text` is used verbatim.
/// - `Env(name)`: the environment variable `name` is read at replacement time;
///   if it cannot be read, the original string is kept.
#[derive(Clone, PartialEq, Eq)]
pub enum SecureValue {
    Static(String),
    Env(String),
}

/// Hand-written so that formatting with `{:?}` never prints a literal secret.
impl std::fmt::Debug for SecureValue {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SecureValue::Static(_) => f.write_str("Static(<redacted>)"),
            // Only the variable's name is stored here, not its value.
            SecureValue::Env(name) => f.debug_tuple("Env").field(name).finish(),
        }
    }
}
171
172pub struct ClientBuilder<T: Transport> {
173 protocol: ProtocolBuilder<T>,
174 strict: bool,
175 env: Option<HashMap<String, SecureValue>>,
176}
177
178impl<T: Transport> ClientBuilder<T> {
179 pub fn new(transport: T) -> Self {
180 Self {
181 protocol: ProtocolBuilder::new(transport),
182 strict: false,
183 env: None,
184 }
185 }
186
187 pub fn with_secure_value(mut self, key: impl Into<String>, value: SecureValue) -> Self {
188 match &mut self.env {
189 Some(env) => {
190 env.insert(key.into(), value);
191 }
192 None => {
193 let mut new_env = HashMap::new();
194 new_env.insert(key.into(), value);
195 self.env = Some(new_env);
196 }
197 }
198 self
199 }
200
201 pub fn use_strict(mut self) -> Self {
202 self.strict = true;
203 self
204 }
205
206 pub fn with_strict(mut self, strict: bool) -> Self {
207 self.strict = strict;
208 self
209 }
210
211 pub fn build(self) -> Client<T> {
212 Client {
213 protocol: self.protocol.build(),
214 strict: self.strict,
215 env: self.env,
216 initialize_res: Arc::new(RwLock::new(None)),
217 }
218 }
219}
220
221pub fn apply_secure_replacements(
224 value: &Value,
225 secure_values: &HashMap<String, SecureValue>,
226) -> Value {
227 match value {
228 Value::Object(map) => {
229 let mut new_map = serde_json::Map::new();
230 for (k, v) in map.iter() {
231 let new_value = if let Value::String(_) = v {
232 if let Some(secure_val) = secure_values.get(k) {
233 let replacement = match secure_val {
234 SecureValue::Static(val) => val.clone(),
235 SecureValue::Env(env_key) => env::var(env_key)
236 .unwrap_or_else(|_| v.as_str().unwrap().to_string()),
237 };
238 Value::String(replacement)
239 } else {
240 apply_secure_replacements(v, secure_values)
241 }
242 } else {
243 apply_secure_replacements(v, secure_values)
244 };
245 new_map.insert(k.clone(), new_value);
246 }
247 Value::Object(new_map)
248 }
249 Value::Array(arr) => {
250 let new_arr: Vec<Value> = arr
251 .iter()
252 .map(|v| apply_secure_replacements(v, secure_values))
253 .collect();
254 Value::Array(new_arr)
255 }
256 _ => value.clone(),
257 }
258}