1use crate::events::{McpEvent, McpEventHandler};
4use crate::protocol::{ToolError, ToolMetadata, ToolProtocol, ToolResult};
5use async_trait::async_trait;
6use serde_json::Value as JsonValue;
7use std::error::Error;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10
/// [`ToolProtocol`] implementation that talks to a remote MCP server over HTTP.
///
/// The tool list returned by the server is cached. The cache and its refresh
/// timestamp live behind `Arc<RwLock<..>>`, so they can be updated through `&self`.
pub struct McpClientProtocol {
    /// Base URL of the MCP server; request paths such as `/tools/list` and
    /// `/tools/execute` are appended to it verbatim.
    endpoint: String,
    /// HTTP client used for every request sent to `endpoint`.
    client: reqwest::Client,
    /// Tool list from the most recent successful refresh; `None` before the first
    /// refresh and after `shutdown` clears it.
    tools_cache: Arc<RwLock<Option<Vec<ToolMetadata>>>>,
    /// Age, in seconds, after which the cached tool list is considered stale.
    cache_ttl_secs: u64,
    /// Moment `tools_cache` was last refreshed; `None` means it has never been
    /// filled (or was cleared).
    last_cache_refresh: Arc<RwLock<Option<std::time::Instant>>>,
    /// Optional observer notified of connection, cache and tool-call events.
    event_handler: Option<Arc<dyn McpEventHandler>>,
}
20
21impl McpClientProtocol {
22 pub fn new(endpoint: String) -> Self {
24 Self {
25 endpoint,
26 client: reqwest::Client::builder()
27 .timeout(std::time::Duration::from_secs(30))
28 .build()
29 .expect("Failed to build HTTP client"),
30 tools_cache: Arc::new(RwLock::new(None)),
31 cache_ttl_secs: 300,
32 last_cache_refresh: Arc::new(RwLock::new(None)),
33 event_handler: None,
34 }
35 }
36
37 pub fn with_timeout(mut self, timeout_secs: u64) -> Self {
39 self.client = reqwest::Client::builder()
40 .timeout(std::time::Duration::from_secs(timeout_secs))
41 .build()
42 .expect("Failed to build HTTP client");
43 self
44 }
45
46 pub fn with_cache_ttl(mut self, ttl_secs: u64) -> Self {
48 self.cache_ttl_secs = ttl_secs;
49 self
50 }
51
52 pub fn with_event_handler(mut self, handler: Arc<dyn McpEventHandler>) -> Self {
54 self.event_handler = Some(handler);
55 self
56 }
57
58 async fn should_refresh_cache(&self) -> bool {
59 let last_refresh = self.last_cache_refresh.read().await;
60 match *last_refresh {
61 None => true,
62 Some(instant) => instant.elapsed().as_secs() > self.cache_ttl_secs,
63 }
64 }
65
66 async fn refresh_cache(&self) -> Result<(), Box<dyn Error + Send + Sync>> {
67 let response = self
68 .client
69 .post(format!("{}/tools/list", self.endpoint))
70 .json(&serde_json::json!({}))
71 .send()
72 .await?;
73
74 if !response.status().is_success() {
75 return Err(Box::new(ToolError::ProtocolError(format!(
76 "MCP server returned status: {}",
77 response.status()
78 ))));
79 }
80
81 let body: serde_json::Value = response.json().await?;
82 let tools: Vec<ToolMetadata> =
83 if let Some(arr) = body.get("tools").and_then(|v| v.as_array()) {
84 serde_json::from_value(serde_json::Value::Array(arr.clone())).map_err(|e| {
85 Box::new(ToolError::ProtocolError(format!(
86 "Failed to deserialize tool list from MCP server: {}",
87 e
88 ))) as Box<dyn Error + Send + Sync>
89 })?
90 } else {
91 serde_json::from_value(body).map_err(|e| {
92 Box::new(ToolError::ProtocolError(format!(
93 "Failed to deserialize tool list from MCP server: {}",
94 e
95 ))) as Box<dyn Error + Send + Sync>
96 })?
97 };
98
99 let tool_count = tools.len();
100 let tool_names: Vec<String> = tools.iter().map(|t| t.name.clone()).collect();
101
102 *self.tools_cache.write().await = Some(tools);
103 *self.last_cache_refresh.write().await = Some(std::time::Instant::now());
104
105 if let Some(ref eh) = self.event_handler {
106 eh.on_mcp_event(&McpEvent::ToolsDiscovered {
107 endpoint: self.endpoint.clone(),
108 tool_count,
109 tool_names,
110 })
111 .await;
112 }
113
114 Ok(())
115 }
116}
117
118#[async_trait]
119impl ToolProtocol for McpClientProtocol {
120 async fn execute(
121 &self,
122 tool_name: &str,
123 parameters: JsonValue,
124 ) -> Result<ToolResult, Box<dyn Error + Send + Sync>> {
125 if let Some(ref eh) = self.event_handler {
126 eh.on_mcp_event(&McpEvent::RemoteToolCallStarted {
127 endpoint: self.endpoint.clone(),
128 tool_name: tool_name.to_string(),
129 parameters: parameters.clone(),
130 })
131 .await;
132 }
133
134 let call_start = std::time::Instant::now();
135
136 let response = self
137 .client
138 .post(format!("{}/tools/execute", self.endpoint))
139 .json(&serde_json::json!({
140 "tool": tool_name,
141 "parameters": parameters
142 }))
143 .send()
144 .await;
145
146 match response {
147 Err(e) => {
148 let duration_ms = call_start.elapsed().as_millis() as u64;
149 if let Some(ref eh) = self.event_handler {
150 eh.on_mcp_event(&McpEvent::ToolError {
151 source: self.endpoint.clone(),
152 tool_name: tool_name.to_string(),
153 error: e.to_string(),
154 duration_ms,
155 })
156 .await;
157 }
158 Err(Box::new(e))
159 }
160 Ok(resp) => {
161 if !resp.status().is_success() {
162 let duration_ms = call_start.elapsed().as_millis() as u64;
163 let err_msg = format!("MCP server returned status: {}", resp.status());
164 if let Some(ref eh) = self.event_handler {
165 eh.on_mcp_event(&McpEvent::ToolError {
166 source: self.endpoint.clone(),
167 tool_name: tool_name.to_string(),
168 error: err_msg.clone(),
169 duration_ms,
170 })
171 .await;
172 }
173 return Err(Box::new(ToolError::ExecutionFailed(err_msg)));
174 }
175
176 let body: serde_json::Value = resp.json().await?;
177 let result: ToolResult = if let Some(r) = body.get("result") {
178 serde_json::from_value(r.clone()).map_err(|e| {
179 Box::new(ToolError::ProtocolError(format!(
180 "Failed to deserialize tool result from MCP server: {}",
181 e
182 ))) as Box<dyn Error + Send + Sync>
183 })?
184 } else {
185 serde_json::from_value(body).map_err(|e| {
186 Box::new(ToolError::ProtocolError(format!(
187 "Failed to deserialize tool result from MCP server: {}",
188 e
189 ))) as Box<dyn Error + Send + Sync>
190 })?
191 };
192
193 let duration_ms = call_start.elapsed().as_millis() as u64;
194 if let Some(ref eh) = self.event_handler {
195 eh.on_mcp_event(&McpEvent::RemoteToolCallCompleted {
196 endpoint: self.endpoint.clone(),
197 tool_name: tool_name.to_string(),
198 success: result.success,
199 error: result.error.clone(),
200 duration_ms,
201 })
202 .await;
203 }
204
205 Ok(result)
206 }
207 }
208 }
209
210 async fn list_tools(&self) -> Result<Vec<ToolMetadata>, Box<dyn Error + Send + Sync>> {
211 if self.should_refresh_cache().await {
212 let had_cache = self.last_cache_refresh.read().await.is_some();
213 if had_cache {
214 if let Some(ref eh) = self.event_handler {
215 eh.on_mcp_event(&McpEvent::CacheExpired {
216 endpoint: self.endpoint.clone(),
217 })
218 .await;
219 }
220 }
221 self.refresh_cache().await?;
222 } else if let Some(ref eh) = self.event_handler {
223 let cache = self.tools_cache.read().await;
224 let count = cache.as_ref().map_or(0, |t| t.len());
225 drop(cache);
226 eh.on_mcp_event(&McpEvent::CacheHit {
227 endpoint: self.endpoint.clone(),
228 tool_count: count,
229 })
230 .await;
231 }
232
233 let cache = self.tools_cache.read().await;
234 cache.as_ref().cloned().ok_or_else(|| {
235 Box::new(ToolError::ProtocolError(
236 "Tools cache not initialized".to_string(),
237 )) as Box<dyn Error + Send + Sync>
238 })
239 }
240
241 async fn get_tool_metadata(
242 &self,
243 tool_name: &str,
244 ) -> Result<ToolMetadata, Box<dyn Error + Send + Sync>> {
245 let tools = self.list_tools().await?;
246 tools
247 .into_iter()
248 .find(|t| t.name == tool_name)
249 .ok_or_else(|| {
250 Box::new(ToolError::NotFound(tool_name.to_string())) as Box<dyn Error + Send + Sync>
251 })
252 }
253
254 fn protocol_name(&self) -> &str {
255 "mcp"
256 }
257
258 async fn initialize(&mut self) -> Result<(), Box<dyn Error + Send + Sync>> {
259 self.refresh_cache().await?;
260
261 let tool_count = {
262 let cache = self.tools_cache.read().await;
263 cache.as_ref().map_or(0, |t| t.len())
264 };
265
266 if let Some(ref eh) = self.event_handler {
267 eh.on_mcp_event(&McpEvent::ConnectionInitialized {
268 endpoint: self.endpoint.clone(),
269 tool_count,
270 })
271 .await;
272 }
273
274 Ok(())
275 }
276
277 async fn shutdown(&mut self) -> Result<(), Box<dyn Error + Send + Sync>> {
278 if let Some(ref eh) = self.event_handler {
279 eh.on_mcp_event(&McpEvent::ConnectionClosed {
280 endpoint: self.endpoint.clone(),
281 })
282 .await;
283 }
284 *self.tools_cache.write().await = None;
285 *self.last_cache_refresh.write().await = None;
286 Ok(())
287 }
288}