Skip to main content

koda_core/mcp/
client.rs

1//! Single MCP server client — wraps rmcp connection lifecycle.
2//!
3//! Each `McpClient` owns one connection to one MCP server.
4//! It handles spawning (stdio) or connecting (HTTP), initialization,
5//! tool discovery, and tool invocation.
6
7use std::time::Duration;
8
9use anyhow::{Context, Result};
10use rmcp::ServiceExt;
11use rmcp::model::CallToolRequestParams;
12use rmcp::service::RunningService;
13use rmcp::transport::StreamableHttpClientTransport;
14use rmcp::transport::child_process::TokioChildProcess;
15use serde_json::Value;
16use tokio::process::Command;
17
18use super::config::{McpServerConfig, McpTransport};
19use super::tool_bridge::McpToolAnnotations;
20use crate::providers::ToolDefinition;
21use crate::tools::web_fetch::is_safe_url;
22
/// Lifecycle state of the connection to a single MCP server.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum McpClientStatus {
    /// No live connection — the state a freshly created client starts in.
    Disconnected,
    /// A connection attempt is currently running.
    Connecting,
    /// Handshake and tool discovery succeeded; the server is usable.
    Connected,
    /// The most recent connection attempt errored or timed out.
    Failed,
}
35
36/// A discovered MCP tool with its Koda-side definition and annotations.
37#[derive(Debug, Clone)]
38pub struct DiscoveredTool {
39    /// Koda tool definition (qualified name, description, schema).
40    pub definition: ToolDefinition,
41    /// MCP annotations for trust classification.
42    pub annotations: McpToolAnnotations,
43    /// Original (unqualified) tool name on the MCP server.
44    pub original_name: String,
45}
46
47/// Client for a single MCP server.
48pub struct McpClient {
49    /// Server name (user-assigned, e.g. "playwright").
50    name: String,
51    /// Server configuration.
52    config: McpServerConfig,
53    /// Running rmcp service (None when disconnected).
54    service: Option<RunningService<rmcp::service::RoleClient, ()>>,
55    /// Discovered tools after connection.
56    tools: Vec<DiscoveredTool>,
57    /// Current connection status.
58    status: McpClientStatus,
59    /// Error message from the last failed connection attempt.
60    last_error: Option<String>,
61    /// Optional human-readable instructions returned by the server during
62    /// `initialize`. Injected into the system prompt so the model picks up
63    /// per-server guidance (#922). `None` when the server provides none.
64    instructions: Option<String>,
65}
66
67impl McpClient {
68    /// Create a new (disconnected) client for the given server.
69    pub fn new(name: String, config: McpServerConfig) -> Self {
70        Self {
71            name,
72            config,
73            service: None,
74            tools: Vec::new(),
75            status: McpClientStatus::Disconnected,
76            last_error: None,
77            instructions: None,
78        }
79    }
80
81    /// Server name.
82    pub fn name(&self) -> &str {
83        &self.name
84    }
85
86    /// Current connection status.
87    pub fn status(&self) -> McpClientStatus {
88        self.status
89    }
90
91    /// Last error message (if status is Failed).
92    pub fn last_error(&self) -> Option<&str> {
93        self.last_error.as_deref()
94    }
95
96    /// Discovered tools (empty until connected).
97    pub fn tools(&self) -> &[DiscoveredTool] {
98        &self.tools
99    }
100
101    /// Server-provided instructions from the `initialize` response, if any.
102    /// Returns `None` if disconnected or the server didn't provide instructions.
103    pub fn instructions(&self) -> Option<&str> {
104        self.instructions.as_deref()
105    }
106
107    /// Connect to the MCP server, initialize, and discover tools.
108    ///
109    /// Dispatches to stdio or HTTP transport based on config.
110    pub async fn connect(&mut self) -> Result<()> {
111        self.status = McpClientStatus::Connecting;
112        self.last_error = None;
113
114        let timeout = Duration::from_secs(self.config.startup_timeout_sec);
115
116        match tokio::time::timeout(timeout, self.connect_inner()).await {
117            Ok(Ok(())) => {
118                let transport_label = match &self.config.transport {
119                    McpTransport::Stdio { .. } => "stdio",
120                    McpTransport::Http { .. } => "http",
121                };
122                self.status = McpClientStatus::Connected;
123                tracing::info!(
124                    server = %self.name,
125                    transport = transport_label,
126                    tools = self.tools.len(),
127                    "MCP server connected"
128                );
129                Ok(())
130            }
131            Ok(Err(e)) => {
132                self.status = McpClientStatus::Failed;
133                self.last_error = Some(e.to_string());
134                tracing::warn!(
135                    server = %self.name,
136                    error = %e,
137                    "MCP server connection failed"
138                );
139                Err(e)
140            }
141            Err(_) => {
142                self.status = McpClientStatus::Failed;
143                let msg = format!(
144                    "MCP server '{}' startup timed out after {}s",
145                    self.name, self.config.startup_timeout_sec
146                );
147                self.last_error = Some(msg.clone());
148                tracing::warn!(server = %self.name, "{msg}");
149                Err(anyhow::anyhow!(msg))
150            }
151        }
152    }
153
154    /// Inner connect logic — dispatches to the right transport.
155    async fn connect_inner(&mut self) -> Result<()> {
156        // Clone transport to avoid borrowing self.config while calling &mut self methods.
157        let transport = self.config.transport.clone();
158        match transport {
159            McpTransport::Stdio {
160                ref command,
161                ref args,
162                ref env,
163                ref cwd,
164            } => self.connect_stdio(command, args, env, cwd.as_deref()).await,
165            McpTransport::Http {
166                ref url,
167                ref bearer_token,
168                ref headers,
169            } => {
170                self.connect_http(url, bearer_token.as_deref(), headers)
171                    .await
172            }
173        }
174    }
175
176    /// Connect via stdio transport (spawn child process).
177    async fn connect_stdio(
178        &mut self,
179        command: &str,
180        args: &[String],
181        env: &std::collections::HashMap<String, String>,
182        cwd: Option<&str>,
183    ) -> Result<()> {
184        let mut cmd = Command::new(command);
185        cmd.args(args);
186        for (key, val) in env {
187            cmd.env(key, val);
188        }
189        if let Some(cwd) = cwd {
190            cmd.current_dir(cwd);
191        }
192
193        let transport =
194            TokioChildProcess::new(cmd).context("failed to spawn MCP server process")?;
195        let service = ().serve(transport).await.context("MCP handshake failed")?;
196        self.finish_handshake(service).await
197    }
198
199    /// Connect via Streamable HTTP transport.
200    async fn connect_http(
201        &mut self,
202        url: &str,
203        bearer_token: Option<&str>,
204        headers: &std::collections::HashMap<String, String>,
205    ) -> Result<()> {
206        use http::{HeaderName, HeaderValue};
207        use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
208
209        // SSRF protection: reject private/internal URLs before opening any connection.
210        if !is_safe_url(url) {
211            anyhow::bail!(
212                "MCP HTTP URL '{url}' is not allowed: private, loopback, or link-local \
213                 addresses are blocked to prevent SSRF attacks"
214            );
215        }
216
217        // Warn if sending a bearer token over plaintext HTTP.
218        if bearer_token.is_some() && url.starts_with("http://") {
219            tracing::warn!(
220                server = %self.name,
221                url = %url,
222                "MCP bearer token is being sent over plaintext HTTP — use HTTPS in production"
223            );
224        }
225
226        let mut config = StreamableHttpClientTransportConfig::with_uri(url);
227
228        // Set bearer token.  rmcp's StreamableHttpClientTransport passes
229        // auth_header to reqwest's `bearer_auth()`, which prepends "Bearer "
230        // automatically — so we store the raw token, not "Bearer {token}".
231        if let Some(token) = bearer_token {
232            config.auth_header = Some(token.to_string());
233        }
234
235        // Set custom headers.
236        if !headers.is_empty() {
237            let mut header_map = std::collections::HashMap::new();
238            for (k, v) in headers {
239                let name = HeaderName::try_from(k.as_str())
240                    .with_context(|| format!("invalid HTTP header name: {k}"))?;
241                let value = HeaderValue::try_from(v.as_str())
242                    .with_context(|| format!("invalid HTTP header value for {k}"))?;
243                header_map.insert(name, value);
244            }
245            config.custom_headers = header_map;
246        }
247
248        // Enable session recovery for remote servers.
249        config.reinit_on_expired_session = true;
250
251        let transport = StreamableHttpClientTransport::from_config(config);
252        let service = ().serve(transport).await.context("MCP HTTP handshake failed")?;
253        self.finish_handshake(service).await
254    }
255
256    /// Common post-handshake bookkeeping shared by stdio + HTTP connect paths.
257    /// Captures the server's `instructions` (#922) and discovers tools.
258    async fn finish_handshake(
259        &mut self,
260        service: RunningService<rmcp::service::RoleClient, ()>,
261    ) -> Result<()> {
262        // Capture server-provided instructions before storing the service so
263        // we don't have to re-borrow it. Filter empties to keep prompt clean.
264        self.instructions = service
265            .peer()
266            .peer_info()
267            .and_then(|info| info.instructions.clone())
268            .filter(|s| !s.trim().is_empty());
269        self.service = Some(service);
270        self.discover_tools().await
271    }
272
273    /// Fetch the tool list from the connected server.
274    async fn discover_tools(&mut self) -> Result<()> {
275        let service = self.service.as_ref().context("not connected")?;
276
277        let result = service
278            .list_tools(Default::default())
279            .await
280            .context("failed to list MCP tools")?;
281
282        self.tools.clear();
283
284        for tool in result.tools {
285            let tool_name: &str = &tool.name;
286
287            // Apply tool filtering.
288            if !self.config.is_tool_allowed(tool_name) {
289                tracing::debug!(
290                    server = %self.name,
291                    tool = %tool_name,
292                    "MCP tool filtered out by config"
293                );
294                continue;
295            }
296
297            let (definition, annotations) =
298                super::tool_bridge::mcp_tool_to_definition(&self.name, &tool);
299
300            self.tools.push(DiscoveredTool {
301                definition,
302                annotations,
303                original_name: tool_name.to_string(),
304            });
305        }
306
307        Ok(())
308    }
309
310    /// Call a tool on this MCP server.
311    ///
312    /// `tool_name` is the original (unqualified) name on the server.
313    /// `arguments` is the JSON arguments value.
314    pub async fn call_tool(
315        &self,
316        tool_name: &str,
317        arguments: Value,
318    ) -> Result<rmcp::model::CallToolResult> {
319        let service = self.service.as_ref().context("MCP server not connected")?;
320
321        let timeout = Duration::from_secs(self.config.tool_timeout_sec);
322
323        let mut params = CallToolRequestParams::new(tool_name.to_string());
324        if let Value::Object(map) = arguments {
325            params.arguments = Some(map);
326        }
327
328        let result = tokio::time::timeout(timeout, service.call_tool(params))
329            .await
330            .map_err(|_| {
331                anyhow::anyhow!(
332                    "MCP tool call '{}' on server '{}' timed out after {}s",
333                    tool_name,
334                    self.name,
335                    self.config.tool_timeout_sec
336                )
337            })?
338            .context("MCP tool call failed")?;
339
340        Ok(result)
341    }
342
343    /// Disconnect from the MCP server.
344    pub async fn disconnect(&mut self) {
345        if let Some(service) = self.service.take() {
346            // RunningService is dropped, which cleans up the child process.
347            drop(service);
348        }
349        self.tools.clear();
350        self.status = McpClientStatus::Disconnected;
351        self.last_error = None;
352        tracing::info!(server = %self.name, "MCP server disconnected");
353    }
354}
355
/// Test-only hooks for forcing internal state without a live MCP server.
///
/// Compiled only with the `test-support` feature, so production builds never
/// expose these setters. Manager-level tests use them to exercise success
/// paths (annotation caching, `all_tool_definitions`, status summaries)
/// without spawning a server subprocess.
#[cfg(feature = "test-support")]
impl McpClient {
    /// Overwrite the connection status (test-only).
    pub fn set_status_for_test(&mut self, status: McpClientStatus) {
        self.status = status;
    }

    /// Overwrite the last error message (test-only). `None` clears it.
    pub fn set_last_error_for_test(&mut self, err: Option<String>) {
        self.last_error = err;
    }

    /// Overwrite the server instructions (test-only). `None` clears them.
    pub fn set_instructions_for_test(&mut self, instructions: Option<String>) {
        self.instructions = instructions;
    }

    /// Replace the discovered-tool list wholesale (test-only).
    pub fn set_tools_for_test(&mut self, tools: Vec<DiscoveredTool>) {
        self.tools = tools;
    }
}
385
386impl Drop for McpClient {
387    fn drop(&mut self) {
388        // Ensure the service is dropped (child process cleaned up).
389        if self.service.is_some() {
390            tracing::debug!(server = %self.name, "McpClient dropped while still connected");
391        }
392    }
393}