Skip to main content

defect_mcp/
lib.rs

//! MCP client adapter layer.
//!
//! Wraps the tools exposed by external MCP servers into a per-session tool table for
//! [`defect_agent`].
5
6#![cfg_attr(not(test), warn(clippy::indexing_slicing, clippy::unwrap_used))]
7
8use std::collections::HashSet;
9use std::io;
10use std::path::PathBuf;
11use std::sync::Arc;
12
13use std::collections::HashMap;
14
15use agent_client_protocol_schema::{Content as AcpContent, McpServer, McpServerStdio};
16use agent_client_protocol_schema::{ToolCallContent, ToolCallUpdateFields};
17use defect_agent::error::BoxError;
18use defect_agent::session::{SessionToolFactory, StaticToolRegistryBuilder, ToolRegistry};
19use defect_agent::tool::{
20    SafetyClass, Tool, ToolCallDescription, ToolContext, ToolEvent, ToolSchema, ToolStream,
21};
22use futures::future::BoxFuture;
23use futures::stream;
24use http::{HeaderName, HeaderValue};
25use rmcp::model::{CallToolRequestParams, RawContent, Tool as McpTool};
26use rmcp::service::{Peer, RoleClient, RunningService};
27use rmcp::transport::child_process::TokioChildProcess;
28use rmcp::transport::{
29    StreamableHttpClientTransport, streamable_http_client::StreamableHttpClientTransportConfig,
30};
31use rmcp::{ClientHandler, ServiceExt};
32
33use crate::streamable_http::HyperStreamableHttpClient;
34
35mod streamable_http;
36use serde_json::Value;
37use thiserror::Error;
38
39/// MCP adapter errors.
40#[derive(Debug, Error)]
41#[non_exhaustive]
42pub enum McpAdapterError {
43    #[error("unsupported MCP transport: {0}")]
44    UnsupportedTransport(String),
45
46    #[error("rmcp initialization failed: {0}")]
47    Initialize(#[source] io::Error),
48
49    #[error("rmcp request failed: {0}")]
50    Request(#[source] io::Error),
51}
52
53/// Minimal MCP factory.
54#[derive(Debug, Default, Clone)]
55pub struct McpToolFactory {
56    default_servers: Vec<McpServer>,
57}
58
59impl McpToolFactory {
60    #[must_use]
61    pub fn new() -> Self {
62        Self::default()
63    }
64
65    #[must_use]
66    pub fn with_default_servers(default_servers: Vec<McpServer>) -> Self {
67        Self { default_servers }
68    }
69}
70
71impl SessionToolFactory for McpToolFactory {
72    fn build_registry(
73        &self,
74        cwd: PathBuf,
75        mcp_servers: Vec<McpServer>,
76    ) -> BoxFuture<'_, Result<Arc<dyn ToolRegistry>, BoxError>> {
77        let mcp_servers = merge_mcp_servers(&self.default_servers, &mcp_servers);
78        Box::pin(async move {
79            let mut builder = StaticToolRegistryBuilder::default();
80            for server in mcp_servers {
81                let tools = load_server_tools(cwd.clone(), server).await?;
82                for tool in tools {
83                    builder = builder.insert(tool);
84                }
85            }
86            Ok(Arc::new(builder.build()) as Arc<dyn ToolRegistry>)
87        })
88    }
89}
90
91fn merge_mcp_servers(
92    default_servers: &[McpServer],
93    session_servers: &[McpServer],
94) -> Vec<McpServer> {
95    let session_server_names = session_servers
96        .iter()
97        .map(mcp_server_name)
98        .collect::<HashSet<_>>();
99
100    default_servers
101        .iter()
102        .filter(|server| !session_server_names.contains(mcp_server_name(server)))
103        .cloned()
104        .chain(session_servers.iter().cloned())
105        .collect()
106}
107
108fn mcp_server_name(server: &McpServer) -> &str {
109    match server {
110        McpServer::Stdio(stdio) => &stdio.name,
111        McpServer::Http(http) => &http.name,
112        McpServer::Sse(sse) => &sse.name,
113        other => unreachable!("unsupported MCP transport variant: {other:?}"),
114    }
115}
116
117/// Load MCP tools according to the transport configuration.
118///
119/// # Errors
120///
121/// Returns an error if the transport is unsupported, connection initialization fails, or
122/// the remote tool list cannot be fetched.
123async fn load_server_tools(
124    cwd: PathBuf,
125    server: McpServer,
126) -> Result<Vec<Arc<dyn Tool>>, BoxError> {
127    match server {
128        McpServer::Stdio(stdio) => load_stdio_server_tools(cwd, stdio).await,
129        McpServer::Http(http) => {
130            load_streamable_http_server_tools(cwd, http.name, http.url, http.headers).await
131        }
132        McpServer::Sse(sse) => {
133            load_streamable_http_server_tools(cwd, sse.name, sse.url, sse.headers).await
134        }
135        other => Err(BoxError::new(McpAdapterError::UnsupportedTransport(
136            format!("{other:?}"),
137        ))),
138    }
139}
140
141/// Spawns a stdio MCP server and wraps its tools as local tools.
142///
143/// # Errors
144///
145/// Returns an error if the child process fails to start, rmcp initialization fails, or
146/// the tool list request fails.
147async fn load_stdio_server_tools(
148    cwd: PathBuf,
149    server: McpServerStdio,
150) -> Result<Vec<Arc<dyn Tool>>, BoxError> {
151    let server_name = server.name.clone();
152    let mut command = tokio::process::Command::new(&server.command);
153    command.args(&server.args);
154    command.current_dir(cwd);
155    command.stdin(std::process::Stdio::piped());
156    command.stdout(std::process::Stdio::piped());
157    command.stderr(std::process::Stdio::inherit());
158    for env in server.env {
159        command.env(env.name, env.value);
160    }
161
162    let transport = TokioChildProcess::new(command)
163        .map_err(|source| BoxError::new(McpAdapterError::Initialize(source)))?;
164    let client = EmptyClient.serve(transport).await.map_err(service_error)?;
165    let peer = client.peer().clone();
166    let connection = Arc::new(McpConnection::new(peer.clone(), client));
167    let tools = peer.list_all_tools().await.map_err(service_error)?;
168
169    Ok(tools
170        .into_iter()
171        .map(|tool| {
172            Arc::new(McpToolAdapter::new(connection.clone(), &server_name, tool)) as Arc<dyn Tool>
173        })
174        .collect())
175}
176
177/// Connects to an HTTP/SSE MCP server and wraps its tools as local tools.
178///
179/// # Errors
180///
181/// Returns an error if headers are invalid, rmcp initialization fails, or the tool list
182/// request fails.
183async fn load_streamable_http_server_tools(
184    _cwd: PathBuf,
185    server_name: String,
186    url: String,
187    headers: Vec<agent_client_protocol_schema::HttpHeader>,
188) -> Result<Vec<Arc<dyn Tool>>, BoxError> {
189    let http_client =
190        HyperStreamableHttpClient::from_stack_config(&defect_http::HttpStackConfig::default())
191            .map_err(|e| {
192                BoxError::new(McpAdapterError::Initialize(io::Error::other(e.to_string())))
193            })?;
194    let transport = StreamableHttpClientTransport::with_client(
195        http_client,
196        StreamableHttpClientTransportConfig::with_uri(url).custom_headers(http_headers(headers)?),
197    );
198    let client = EmptyClient.serve(transport).await.map_err(service_error)?;
199    let peer = client.peer().clone();
200    let connection = Arc::new(McpConnection::new(peer.clone(), client));
201    let tools = peer.list_all_tools().await.map_err(service_error)?;
202
203    Ok(tools
204        .into_iter()
205        .map(|tool| {
206            Arc::new(McpToolAdapter::new(connection.clone(), &server_name, tool)) as Arc<dyn Tool>
207        })
208        .collect())
209}
210
211#[derive(Clone, Default)]
212struct EmptyClient;
213
214impl ClientHandler for EmptyClient {}
215
216struct McpConnection {
217    peer: Peer<RoleClient>,
218    _client: RunningService<RoleClient, EmptyClient>,
219}
220
221impl McpConnection {
222    fn new(peer: Peer<RoleClient>, client: RunningService<RoleClient, EmptyClient>) -> Self {
223        Self {
224            peer,
225            _client: client,
226        }
227    }
228}
229
230struct McpToolAdapter {
231    connection: Arc<McpConnection>,
232    /// The raw tool name sent back to the MCP server when calling `call_tool`.
233    upstream_name: String,
234    schema: ToolSchema,
235    safety: SafetyClass,
236}
237
/// Builds the name under which an MCP tool is registered in the local `ToolRegistry`.
///
/// The result is `mcp__<server>__<upstream_name>`. Prefixing every MCP tool this way
/// avoids name collisions with built-in tools. The `__` separator (matching the Claude
/// Code / ecosystem convention) keeps the qualified name within the Anthropic/Bedrock
/// tool-name charset `^[a-zA-Z0-9_-]{1,128}$`; a `.` separator was previously rejected
/// by Bedrock.
///
/// This is plain string concatenation: `server` and `upstream_name` are NOT sanitized,
/// so a server or tool name containing characters outside that charset will still be
/// rejected upstream.
#[must_use]
pub fn registered_mcp_tool_name(server: &str, upstream_name: &str) -> String {
    ["mcp", server, upstream_name].join("__")
}
253
254impl McpToolAdapter {
255    /// See [`registered_mcp_tool_name`]: all MCP tools are registered locally as
256    /// `mcp.<server>.<name>`. `upstream_name` remains the original name, used when
257    /// sending back to the MCP server.
258    fn new(connection: Arc<McpConnection>, server: &str, tool: McpTool) -> Self {
259        let input_schema = match serde_json::to_value(tool.input_schema.as_ref()) {
260            Ok(schema) => schema,
261            Err(err) => {
262                tracing::warn!(
263                    tool = %tool.name,
264                    error = %err,
265                    "failed to serialize MCP tool input schema; falling back to empty object"
266                );
267                Value::Object(Default::default())
268            }
269        };
270        let upstream_name = tool.name.to_string();
271        let registered_name = registered_mcp_tool_name(server, &upstream_name);
272        let schema = ToolSchema {
273            name: registered_name,
274            description: tool.description.clone().unwrap_or_default().to_string(),
275            input_schema,
276        };
277        Self {
278            connection,
279            upstream_name,
280            schema,
281            safety: infer_safety(&tool),
282        }
283    }
284}
285
286impl Tool for McpToolAdapter {
287    fn schema(&self) -> &ToolSchema {
288        &self.schema
289    }
290
291    fn safety_hint(&self, _args: &serde_json::Value) -> SafetyClass {
292        self.safety
293    }
294
295    fn describe<'a>(
296        &'a self,
297        _args: &'a serde_json::Value,
298        _ctx: ToolContext<'a>,
299    ) -> BoxFuture<'a, ToolCallDescription> {
300        Box::pin(async move {
301            ToolCallDescription {
302                fields: ToolCallUpdateFields::new().title(self.schema.description.clone()),
303            }
304        })
305    }
306
307    fn execute(&self, args: serde_json::Value, _ctx: ToolContext<'_>) -> ToolStream {
308        let peer = self.connection.peer.clone();
309        let name = self.upstream_name.clone();
310        Box::pin(stream::once(async move {
311            let params = match build_call_params(name, args) {
312                Ok(params) => params,
313                Err(err) => return ToolEvent::Failed(err),
314            };
315
316            match peer.call_tool(params).await {
317                Ok(call) => completed_event(call),
318                Err(err) => ToolEvent::Failed(defect_agent::tool::ToolError::Execution(
319                    BoxError::new(io::Error::other(err.to_string())),
320                )),
321            }
322        }))
323    }
324}
325
326fn infer_safety(tool: &McpTool) -> SafetyClass {
327    let Some(annotations) = tool.annotations.as_ref() else {
328        return SafetyClass::Mutating;
329    };
330    if annotations.read_only_hint == Some(true) {
331        return SafetyClass::ReadOnly;
332    }
333    if annotations.destructive_hint == Some(true) {
334        return SafetyClass::Destructive;
335    }
336    SafetyClass::Mutating
337}
338
339fn build_call_params(
340    name: String,
341    args: Value,
342) -> Result<CallToolRequestParams, defect_agent::tool::ToolError> {
343    match args {
344        Value::Object(arguments) => Ok(CallToolRequestParams::new(name).with_arguments(arguments)),
345        Value::Null => Ok(CallToolRequestParams::new(name)),
346        other => Err(defect_agent::tool::ToolError::InvalidArgs(BoxError::new(
347            io::Error::other(format!("expected object args, got {other}")),
348        ))),
349    }
350}
351
352fn completed_event(call: rmcp::model::CallToolResult) -> ToolEvent {
353    let mut content = call
354        .content
355        .iter()
356        .filter_map(content_text)
357        .map(|text| ToolCallContent::Content(AcpContent::new(text)))
358        .collect::<Vec<_>>();
359
360    if content.is_empty()
361        && let Some(structured_content) = call.structured_content.as_ref()
362    {
363        content.push(ToolCallContent::Content(AcpContent::new(
364            structured_content.to_string(),
365        )));
366    }
367
368    let raw_output = serde_json::to_value(&call).ok();
369    ToolEvent::Completed(
370        ToolCallUpdateFields::new()
371            .content((!content.is_empty()).then_some(content))
372            .raw_output(raw_output),
373    )
374}
375
376fn content_text(content: &rmcp::model::Content) -> Option<String> {
377    match &content.raw {
378        RawContent::Text(text) => Some(text.text.clone()),
379        RawContent::Resource(resource) => match &resource.resource {
380            rmcp::model::ResourceContents::TextResourceContents { text, .. } => Some(text.clone()),
381            _ => None,
382        },
383        _ => None,
384    }
385}
386
387fn service_error<E>(err: E) -> BoxError
388where
389    E: std::error::Error,
390{
391    BoxError::new(McpAdapterError::Request(io::Error::other(err.to_string())))
392}
393
394fn http_headers(
395    headers: Vec<agent_client_protocol_schema::HttpHeader>,
396) -> Result<HashMap<HeaderName, HeaderValue>, BoxError> {
397    headers
398        .into_iter()
399        .map(|header| {
400            let name = HeaderName::try_from(header.name.as_str()).map_err(|err| {
401                BoxError::new(McpAdapterError::Initialize(io::Error::new(
402                    io::ErrorKind::InvalidInput,
403                    format!("invalid MCP HTTP header name '{}': {err}", header.name),
404                )))
405            })?;
406            let value = HeaderValue::from_str(&header.value).map_err(|err| {
407                BoxError::new(McpAdapterError::Initialize(io::Error::new(
408                    io::ErrorKind::InvalidInput,
409                    format!("invalid MCP HTTP header value for '{}': {err}", header.name),
410                )))
411            })?;
412            Ok((name, value))
413        })
414        .collect()
415}
416
417#[cfg(test)]
418mod tests;