Skip to main content

mcp/
client.rs

1//! HTTP client protocol for talking to MCP-compatible legacy tool servers.
2
3use crate::events::{McpEvent, McpEventHandler};
4use crate::protocol::{ToolError, ToolMetadata, ToolProtocol, ToolResult};
5use async_trait::async_trait;
6use serde_json::Value as JsonValue;
7use std::error::Error;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10
11/// HTTP client adapter for MCP-compatible tool servers.
12pub struct McpClientProtocol {
13    endpoint: String,
14    client: reqwest::Client,
15    tools_cache: Arc<RwLock<Option<Vec<ToolMetadata>>>>,
16    cache_ttl_secs: u64,
17    last_cache_refresh: Arc<RwLock<Option<std::time::Instant>>>,
18    event_handler: Option<Arc<dyn McpEventHandler>>,
19}
20
21impl McpClientProtocol {
22    /// Create a new client for a remote endpoint.
23    pub fn new(endpoint: String) -> Self {
24        Self {
25            endpoint,
26            client: reqwest::Client::builder()
27                .timeout(std::time::Duration::from_secs(30))
28                .build()
29                .expect("Failed to build HTTP client"),
30            tools_cache: Arc::new(RwLock::new(None)),
31            cache_ttl_secs: 300,
32            last_cache_refresh: Arc::new(RwLock::new(None)),
33            event_handler: None,
34        }
35    }
36
37    /// Override the request timeout.
38    pub fn with_timeout(mut self, timeout_secs: u64) -> Self {
39        self.client = reqwest::Client::builder()
40            .timeout(std::time::Duration::from_secs(timeout_secs))
41            .build()
42            .expect("Failed to build HTTP client");
43        self
44    }
45
46    /// Override the metadata cache TTL.
47    pub fn with_cache_ttl(mut self, ttl_secs: u64) -> Self {
48        self.cache_ttl_secs = ttl_secs;
49        self
50    }
51
52    /// Attach an event handler.
53    pub fn with_event_handler(mut self, handler: Arc<dyn McpEventHandler>) -> Self {
54        self.event_handler = Some(handler);
55        self
56    }
57
58    async fn should_refresh_cache(&self) -> bool {
59        let last_refresh = self.last_cache_refresh.read().await;
60        match *last_refresh {
61            None => true,
62            Some(instant) => instant.elapsed().as_secs() > self.cache_ttl_secs,
63        }
64    }
65
66    async fn refresh_cache(&self) -> Result<(), Box<dyn Error + Send + Sync>> {
67        let response = self
68            .client
69            .post(format!("{}/tools/list", self.endpoint))
70            .json(&serde_json::json!({}))
71            .send()
72            .await?;
73
74        if !response.status().is_success() {
75            return Err(Box::new(ToolError::ProtocolError(format!(
76                "MCP server returned status: {}",
77                response.status()
78            ))));
79        }
80
81        let body: serde_json::Value = response.json().await?;
82        let tools: Vec<ToolMetadata> =
83            if let Some(arr) = body.get("tools").and_then(|v| v.as_array()) {
84                serde_json::from_value(serde_json::Value::Array(arr.clone())).map_err(|e| {
85                    Box::new(ToolError::ProtocolError(format!(
86                        "Failed to deserialize tool list from MCP server: {}",
87                        e
88                    ))) as Box<dyn Error + Send + Sync>
89                })?
90            } else {
91                serde_json::from_value(body).map_err(|e| {
92                    Box::new(ToolError::ProtocolError(format!(
93                        "Failed to deserialize tool list from MCP server: {}",
94                        e
95                    ))) as Box<dyn Error + Send + Sync>
96                })?
97            };
98
99        let tool_count = tools.len();
100        let tool_names: Vec<String> = tools.iter().map(|t| t.name.clone()).collect();
101
102        *self.tools_cache.write().await = Some(tools);
103        *self.last_cache_refresh.write().await = Some(std::time::Instant::now());
104
105        if let Some(ref eh) = self.event_handler {
106            eh.on_mcp_event(&McpEvent::ToolsDiscovered {
107                endpoint: self.endpoint.clone(),
108                tool_count,
109                tool_names,
110            })
111            .await;
112        }
113
114        Ok(())
115    }
116}
117
118#[async_trait]
119impl ToolProtocol for McpClientProtocol {
120    async fn execute(
121        &self,
122        tool_name: &str,
123        parameters: JsonValue,
124    ) -> Result<ToolResult, Box<dyn Error + Send + Sync>> {
125        if let Some(ref eh) = self.event_handler {
126            eh.on_mcp_event(&McpEvent::RemoteToolCallStarted {
127                endpoint: self.endpoint.clone(),
128                tool_name: tool_name.to_string(),
129                parameters: parameters.clone(),
130            })
131            .await;
132        }
133
134        let call_start = std::time::Instant::now();
135
136        let response = self
137            .client
138            .post(format!("{}/tools/execute", self.endpoint))
139            .json(&serde_json::json!({
140                "tool": tool_name,
141                "parameters": parameters
142            }))
143            .send()
144            .await;
145
146        match response {
147            Err(e) => {
148                let duration_ms = call_start.elapsed().as_millis() as u64;
149                if let Some(ref eh) = self.event_handler {
150                    eh.on_mcp_event(&McpEvent::ToolError {
151                        source: self.endpoint.clone(),
152                        tool_name: tool_name.to_string(),
153                        error: e.to_string(),
154                        duration_ms,
155                    })
156                    .await;
157                }
158                Err(Box::new(e))
159            }
160            Ok(resp) => {
161                if !resp.status().is_success() {
162                    let duration_ms = call_start.elapsed().as_millis() as u64;
163                    let err_msg = format!("MCP server returned status: {}", resp.status());
164                    if let Some(ref eh) = self.event_handler {
165                        eh.on_mcp_event(&McpEvent::ToolError {
166                            source: self.endpoint.clone(),
167                            tool_name: tool_name.to_string(),
168                            error: err_msg.clone(),
169                            duration_ms,
170                        })
171                        .await;
172                    }
173                    return Err(Box::new(ToolError::ExecutionFailed(err_msg)));
174                }
175
176                let body: serde_json::Value = resp.json().await?;
177                let result: ToolResult = if let Some(r) = body.get("result") {
178                    serde_json::from_value(r.clone()).map_err(|e| {
179                        Box::new(ToolError::ProtocolError(format!(
180                            "Failed to deserialize tool result from MCP server: {}",
181                            e
182                        ))) as Box<dyn Error + Send + Sync>
183                    })?
184                } else {
185                    serde_json::from_value(body).map_err(|e| {
186                        Box::new(ToolError::ProtocolError(format!(
187                            "Failed to deserialize tool result from MCP server: {}",
188                            e
189                        ))) as Box<dyn Error + Send + Sync>
190                    })?
191                };
192
193                let duration_ms = call_start.elapsed().as_millis() as u64;
194                if let Some(ref eh) = self.event_handler {
195                    eh.on_mcp_event(&McpEvent::RemoteToolCallCompleted {
196                        endpoint: self.endpoint.clone(),
197                        tool_name: tool_name.to_string(),
198                        success: result.success,
199                        error: result.error.clone(),
200                        duration_ms,
201                    })
202                    .await;
203                }
204
205                Ok(result)
206            }
207        }
208    }
209
210    async fn list_tools(&self) -> Result<Vec<ToolMetadata>, Box<dyn Error + Send + Sync>> {
211        if self.should_refresh_cache().await {
212            let had_cache = self.last_cache_refresh.read().await.is_some();
213            if had_cache {
214                if let Some(ref eh) = self.event_handler {
215                    eh.on_mcp_event(&McpEvent::CacheExpired {
216                        endpoint: self.endpoint.clone(),
217                    })
218                    .await;
219                }
220            }
221            self.refresh_cache().await?;
222        } else if let Some(ref eh) = self.event_handler {
223            let cache = self.tools_cache.read().await;
224            let count = cache.as_ref().map_or(0, |t| t.len());
225            drop(cache);
226            eh.on_mcp_event(&McpEvent::CacheHit {
227                endpoint: self.endpoint.clone(),
228                tool_count: count,
229            })
230            .await;
231        }
232
233        let cache = self.tools_cache.read().await;
234        cache.as_ref().cloned().ok_or_else(|| {
235            Box::new(ToolError::ProtocolError(
236                "Tools cache not initialized".to_string(),
237            )) as Box<dyn Error + Send + Sync>
238        })
239    }
240
241    async fn get_tool_metadata(
242        &self,
243        tool_name: &str,
244    ) -> Result<ToolMetadata, Box<dyn Error + Send + Sync>> {
245        let tools = self.list_tools().await?;
246        tools
247            .into_iter()
248            .find(|t| t.name == tool_name)
249            .ok_or_else(|| {
250                Box::new(ToolError::NotFound(tool_name.to_string())) as Box<dyn Error + Send + Sync>
251            })
252    }
253
254    fn protocol_name(&self) -> &str {
255        "mcp"
256    }
257
258    async fn initialize(&mut self) -> Result<(), Box<dyn Error + Send + Sync>> {
259        self.refresh_cache().await?;
260
261        let tool_count = {
262            let cache = self.tools_cache.read().await;
263            cache.as_ref().map_or(0, |t| t.len())
264        };
265
266        if let Some(ref eh) = self.event_handler {
267            eh.on_mcp_event(&McpEvent::ConnectionInitialized {
268                endpoint: self.endpoint.clone(),
269                tool_count,
270            })
271            .await;
272        }
273
274        Ok(())
275    }
276
277    async fn shutdown(&mut self) -> Result<(), Box<dyn Error + Send + Sync>> {
278        if let Some(ref eh) = self.event_handler {
279            eh.on_mcp_event(&McpEvent::ConnectionClosed {
280                endpoint: self.endpoint.clone(),
281            })
282            .await;
283        }
284        *self.tools_cache.write().await = None;
285        *self.last_cache_refresh.write().await = None;
286        Ok(())
287    }
288}