1#![cfg_attr(not(test), warn(clippy::indexing_slicing, clippy::unwrap_used))]
7
8use std::collections::HashSet;
9use std::io;
10use std::path::PathBuf;
11use std::sync::Arc;
12
13use std::collections::HashMap;
14
15use agent_client_protocol_schema::{Content as AcpContent, McpServer, McpServerStdio};
16use agent_client_protocol_schema::{ToolCallContent, ToolCallUpdateFields};
17use defect_agent::error::BoxError;
18use defect_agent::session::{SessionToolFactory, StaticToolRegistryBuilder, ToolRegistry};
19use defect_agent::tool::{
20 SafetyClass, Tool, ToolCallDescription, ToolContext, ToolEvent, ToolSchema, ToolStream,
21};
22use futures::future::BoxFuture;
23use futures::stream;
24use http::{HeaderName, HeaderValue};
25use rmcp::model::{CallToolRequestParams, RawContent, Tool as McpTool};
26use rmcp::service::{Peer, RoleClient, RunningService};
27use rmcp::transport::child_process::TokioChildProcess;
28use rmcp::transport::{
29 StreamableHttpClientTransport, streamable_http_client::StreamableHttpClientTransportConfig,
30};
31use rmcp::{ClientHandler, ServiceExt};
32
33use crate::streamable_http::HyperStreamableHttpClient;
34
35mod streamable_http;
36use serde_json::Value;
37use thiserror::Error;
38
/// Errors produced while connecting to MCP servers and loading their tools.
///
/// NOTE(review): the `Initialize` and `Request` messages embed their source via `{0}`
/// *and* expose it through `#[source]`, so reporters that walk the error chain will
/// print the underlying message twice.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum McpAdapterError {
    /// A server entry used an `McpServer` transport variant this adapter cannot load.
    /// The payload is the `Debug` rendering of the rejected entry.
    #[error("unsupported MCP transport: {0}")]
    UnsupportedTransport(String),

    /// Setting up a connection failed: creating the stdio child-process transport,
    /// building the HTTP client, or converting a configured HTTP header.
    #[error("rmcp initialization failed: {0}")]
    Initialize(#[source] io::Error),

    /// An rmcp client operation (`serve` or `list_all_tools`) failed; the original
    /// error is flattened into an `io::Error` carrying only its message.
    #[error("rmcp request failed: {0}")]
    Request(#[source] io::Error),
}
52
/// [`SessionToolFactory`] that builds each session's tool registry from MCP servers.
///
/// Servers configured through [`McpToolFactory::with_default_servers`] are offered to
/// every session unless the session supplies its own server with the same name.
#[derive(Debug, Default, Clone)]
pub struct McpToolFactory {
    // Servers merged into every session; a same-named session server replaces one of these.
    default_servers: Vec<McpServer>,
}
58
59impl McpToolFactory {
60 #[must_use]
61 pub fn new() -> Self {
62 Self::default()
63 }
64
65 #[must_use]
66 pub fn with_default_servers(default_servers: Vec<McpServer>) -> Self {
67 Self { default_servers }
68 }
69}
70
71impl SessionToolFactory for McpToolFactory {
72 fn build_registry(
73 &self,
74 cwd: PathBuf,
75 mcp_servers: Vec<McpServer>,
76 ) -> BoxFuture<'_, Result<Arc<dyn ToolRegistry>, BoxError>> {
77 let mcp_servers = merge_mcp_servers(&self.default_servers, &mcp_servers);
78 Box::pin(async move {
79 let mut builder = StaticToolRegistryBuilder::default();
80 for server in mcp_servers {
81 let tools = load_server_tools(cwd.clone(), server).await?;
82 for tool in tools {
83 builder = builder.insert(tool);
84 }
85 }
86 Ok(Arc::new(builder.build()) as Arc<dyn ToolRegistry>)
87 })
88 }
89}
90
91fn merge_mcp_servers(
92 default_servers: &[McpServer],
93 session_servers: &[McpServer],
94) -> Vec<McpServer> {
95 let session_server_names = session_servers
96 .iter()
97 .map(mcp_server_name)
98 .collect::<HashSet<_>>();
99
100 default_servers
101 .iter()
102 .filter(|server| !session_server_names.contains(mcp_server_name(server)))
103 .cloned()
104 .chain(session_servers.iter().cloned())
105 .collect()
106}
107
108fn mcp_server_name(server: &McpServer) -> &str {
109 match server {
110 McpServer::Stdio(stdio) => &stdio.name,
111 McpServer::Http(http) => &http.name,
112 McpServer::Sse(sse) => &sse.name,
113 other => unreachable!("unsupported MCP transport variant: {other:?}"),
114 }
115}
116
117async fn load_server_tools(
124 cwd: PathBuf,
125 server: McpServer,
126) -> Result<Vec<Arc<dyn Tool>>, BoxError> {
127 match server {
128 McpServer::Stdio(stdio) => load_stdio_server_tools(cwd, stdio).await,
129 McpServer::Http(http) => {
130 load_streamable_http_server_tools(cwd, http.name, http.url, http.headers).await
131 }
132 McpServer::Sse(sse) => {
133 load_streamable_http_server_tools(cwd, sse.name, sse.url, sse.headers).await
134 }
135 other => Err(BoxError::new(McpAdapterError::UnsupportedTransport(
136 format!("{other:?}"),
137 ))),
138 }
139}
140
141async fn load_stdio_server_tools(
148 cwd: PathBuf,
149 server: McpServerStdio,
150) -> Result<Vec<Arc<dyn Tool>>, BoxError> {
151 let server_name = server.name.clone();
152 let mut command = tokio::process::Command::new(&server.command);
153 command.args(&server.args);
154 command.current_dir(cwd);
155 command.stdin(std::process::Stdio::piped());
156 command.stdout(std::process::Stdio::piped());
157 command.stderr(std::process::Stdio::inherit());
158 for env in server.env {
159 command.env(env.name, env.value);
160 }
161
162 let transport = TokioChildProcess::new(command)
163 .map_err(|source| BoxError::new(McpAdapterError::Initialize(source)))?;
164 let client = EmptyClient.serve(transport).await.map_err(service_error)?;
165 let peer = client.peer().clone();
166 let connection = Arc::new(McpConnection::new(peer.clone(), client));
167 let tools = peer.list_all_tools().await.map_err(service_error)?;
168
169 Ok(tools
170 .into_iter()
171 .map(|tool| {
172 Arc::new(McpToolAdapter::new(connection.clone(), &server_name, tool)) as Arc<dyn Tool>
173 })
174 .collect())
175}
176
177async fn load_streamable_http_server_tools(
184 _cwd: PathBuf,
185 server_name: String,
186 url: String,
187 headers: Vec<agent_client_protocol_schema::HttpHeader>,
188) -> Result<Vec<Arc<dyn Tool>>, BoxError> {
189 let http_client =
190 HyperStreamableHttpClient::from_stack_config(&defect_http::HttpStackConfig::default())
191 .map_err(|e| {
192 BoxError::new(McpAdapterError::Initialize(io::Error::other(e.to_string())))
193 })?;
194 let transport = StreamableHttpClientTransport::with_client(
195 http_client,
196 StreamableHttpClientTransportConfig::with_uri(url).custom_headers(http_headers(headers)?),
197 );
198 let client = EmptyClient.serve(transport).await.map_err(service_error)?;
199 let peer = client.peer().clone();
200 let connection = Arc::new(McpConnection::new(peer.clone(), client));
201 let tools = peer.list_all_tools().await.map_err(service_error)?;
202
203 Ok(tools
204 .into_iter()
205 .map(|tool| {
206 Arc::new(McpToolAdapter::new(connection.clone(), &server_name, tool)) as Arc<dyn Tool>
207 })
208 .collect())
209}
210
/// rmcp client handler that overrides nothing: every `ClientHandler` method uses the
/// trait's default implementation.
#[derive(Clone, Default)]
struct EmptyClient;

impl ClientHandler for EmptyClient {}
215
/// A live rmcp client connection, shared via `Arc` by every tool adapter created for
/// one server.
struct McpConnection {
    // Request handle (cloned from `_client.peer()` by the loaders) used for tool calls.
    peer: Peer<RoleClient>,
    // Never read; kept so the running service lives as long as any adapter holds this
    // connection (presumably dropping it ends the session — confirm against rmcp docs).
    _client: RunningService<RoleClient, EmptyClient>,
}

impl McpConnection {
    /// Bundles a peer handle with the running service that owns it.
    fn new(peer: Peer<RoleClient>, client: RunningService<RoleClient, EmptyClient>) -> Self {
        Self {
            peer,
            _client: client,
        }
    }
}
229
/// Exposes a single MCP server tool through the agent's `Tool` trait.
struct McpToolAdapter {
    // Connection to the owning server; shared with the other tools from that server.
    connection: Arc<McpConnection>,
    // Tool name as reported by the server; this is what `call_tool` is invoked with.
    upstream_name: String,
    // Schema advertised to the agent under the namespaced `mcp__{server}__{tool}` name.
    schema: ToolSchema,
    // Safety class derived once from the tool's MCP annotations.
    safety: SafetyClass,
}
237
/// Returns the name under which an MCP tool is registered with the agent.
///
/// The result is `mcp__{server}__{upstream_name}`. Neither part is escaped, so names
/// that themselves contain `__` are not uniquely reversible.
#[must_use]
pub fn registered_mcp_tool_name(server: &str, upstream_name: &str) -> String {
    ["mcp", server, upstream_name].join("__")
}
253
254impl McpToolAdapter {
255 fn new(connection: Arc<McpConnection>, server: &str, tool: McpTool) -> Self {
259 let input_schema = match serde_json::to_value(tool.input_schema.as_ref()) {
260 Ok(schema) => schema,
261 Err(err) => {
262 tracing::warn!(
263 tool = %tool.name,
264 error = %err,
265 "failed to serialize MCP tool input schema; falling back to empty object"
266 );
267 Value::Object(Default::default())
268 }
269 };
270 let upstream_name = tool.name.to_string();
271 let registered_name = registered_mcp_tool_name(server, &upstream_name);
272 let schema = ToolSchema {
273 name: registered_name,
274 description: tool.description.clone().unwrap_or_default().to_string(),
275 input_schema,
276 };
277 Self {
278 connection,
279 upstream_name,
280 schema,
281 safety: infer_safety(&tool),
282 }
283 }
284}
285
286impl Tool for McpToolAdapter {
287 fn schema(&self) -> &ToolSchema {
288 &self.schema
289 }
290
291 fn safety_hint(&self, _args: &serde_json::Value) -> SafetyClass {
292 self.safety
293 }
294
295 fn describe<'a>(
296 &'a self,
297 _args: &'a serde_json::Value,
298 _ctx: ToolContext<'a>,
299 ) -> BoxFuture<'a, ToolCallDescription> {
300 Box::pin(async move {
301 ToolCallDescription {
302 fields: ToolCallUpdateFields::new().title(self.schema.description.clone()),
303 }
304 })
305 }
306
307 fn execute(&self, args: serde_json::Value, _ctx: ToolContext<'_>) -> ToolStream {
308 let peer = self.connection.peer.clone();
309 let name = self.upstream_name.clone();
310 Box::pin(stream::once(async move {
311 let params = match build_call_params(name, args) {
312 Ok(params) => params,
313 Err(err) => return ToolEvent::Failed(err),
314 };
315
316 match peer.call_tool(params).await {
317 Ok(call) => completed_event(call),
318 Err(err) => ToolEvent::Failed(defect_agent::tool::ToolError::Execution(
319 BoxError::new(io::Error::other(err.to_string())),
320 )),
321 }
322 }))
323 }
324}
325
326fn infer_safety(tool: &McpTool) -> SafetyClass {
327 let Some(annotations) = tool.annotations.as_ref() else {
328 return SafetyClass::Mutating;
329 };
330 if annotations.read_only_hint == Some(true) {
331 return SafetyClass::ReadOnly;
332 }
333 if annotations.destructive_hint == Some(true) {
334 return SafetyClass::Destructive;
335 }
336 SafetyClass::Mutating
337}
338
339fn build_call_params(
340 name: String,
341 args: Value,
342) -> Result<CallToolRequestParams, defect_agent::tool::ToolError> {
343 match args {
344 Value::Object(arguments) => Ok(CallToolRequestParams::new(name).with_arguments(arguments)),
345 Value::Null => Ok(CallToolRequestParams::new(name)),
346 other => Err(defect_agent::tool::ToolError::InvalidArgs(BoxError::new(
347 io::Error::other(format!("expected object args, got {other}")),
348 ))),
349 }
350}
351
352fn completed_event(call: rmcp::model::CallToolResult) -> ToolEvent {
353 let mut content = call
354 .content
355 .iter()
356 .filter_map(content_text)
357 .map(|text| ToolCallContent::Content(AcpContent::new(text)))
358 .collect::<Vec<_>>();
359
360 if content.is_empty()
361 && let Some(structured_content) = call.structured_content.as_ref()
362 {
363 content.push(ToolCallContent::Content(AcpContent::new(
364 structured_content.to_string(),
365 )));
366 }
367
368 let raw_output = serde_json::to_value(&call).ok();
369 ToolEvent::Completed(
370 ToolCallUpdateFields::new()
371 .content((!content.is_empty()).then_some(content))
372 .raw_output(raw_output),
373 )
374}
375
376fn content_text(content: &rmcp::model::Content) -> Option<String> {
377 match &content.raw {
378 RawContent::Text(text) => Some(text.text.clone()),
379 RawContent::Resource(resource) => match &resource.resource {
380 rmcp::model::ResourceContents::TextResourceContents { text, .. } => Some(text.clone()),
381 _ => None,
382 },
383 _ => None,
384 }
385}
386
387fn service_error<E>(err: E) -> BoxError
388where
389 E: std::error::Error,
390{
391 BoxError::new(McpAdapterError::Request(io::Error::other(err.to_string())))
392}
393
394fn http_headers(
395 headers: Vec<agent_client_protocol_schema::HttpHeader>,
396) -> Result<HashMap<HeaderName, HeaderValue>, BoxError> {
397 headers
398 .into_iter()
399 .map(|header| {
400 let name = HeaderName::try_from(header.name.as_str()).map_err(|err| {
401 BoxError::new(McpAdapterError::Initialize(io::Error::new(
402 io::ErrorKind::InvalidInput,
403 format!("invalid MCP HTTP header name '{}': {err}", header.name),
404 )))
405 })?;
406 let value = HeaderValue::from_str(&header.value).map_err(|err| {
407 BoxError::new(McpAdapterError::Initialize(io::Error::new(
408 io::ErrorKind::InvalidInput,
409 format!("invalid MCP HTTP header value for '{}': {err}", header.name),
410 )))
411 })?;
412 Ok((name, value))
413 })
414 .collect()
415}
416
417#[cfg(test)]
418mod tests;