Skip to main content

agent_code_lib/services/mcp/
transport.rs

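//! MCP transport layer.
//!
//! Handles the low-level communication with MCP servers over stdio (a
//! spawned subprocess exchanging newline-delimited JSON-RPC) or HTTP (the
//! "SSE" transport, which POSTs each JSON-RPC message to `{base_url}/jsonrpc`).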
5
6use std::collections::HashMap;
7use std::process::Stdio;
8use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
9use tokio::process::{Child, Command};
10use tokio::sync::Mutex;
11use tracing::{debug, warn};
12
13use super::types::*;
14
15/// A transport connection to an MCP server.
16pub struct McpTransportConnection {
17    inner: TransportInner,
18    next_id: Mutex<u64>,
19}
20
21#[allow(clippy::large_enum_variant)]
22enum TransportInner {
23    Stdio {
24        child: Mutex<Child>,
25        stdin: Mutex<tokio::process::ChildStdin>,
26        stdout: Mutex<BufReader<tokio::process::ChildStdout>>,
27    },
28    Sse {
29        base_url: String,
30        http: reqwest::Client,
31    },
32}
33
34impl McpTransportConnection {
35    /// Connect to an MCP server via stdio subprocess.
36    pub async fn connect_stdio(
37        command: &str,
38        args: &[String],
39        env: &HashMap<String, String>,
40    ) -> Result<Self, String> {
41        let mut cmd = Command::new(command);
42        cmd.args(args)
43            .stdin(Stdio::piped())
44            .stdout(Stdio::piped())
45            .stderr(Stdio::null());
46
47        for (key, value) in env {
48            cmd.env(key, value);
49        }
50
51        let mut child = cmd
52            .spawn()
53            .map_err(|e| format!("Failed to spawn MCP server '{command}': {e}"))?;
54
55        let stdin = child
56            .stdin
57            .take()
58            .ok_or_else(|| "Failed to capture stdin".to_string())?;
59
60        let stdout = child
61            .stdout
62            .take()
63            .ok_or_else(|| "Failed to capture stdout".to_string())?;
64
65        Ok(Self {
66            inner: TransportInner::Stdio {
67                child: Mutex::new(child),
68                stdin: Mutex::new(stdin),
69                stdout: Mutex::new(BufReader::new(stdout)),
70            },
71            next_id: Mutex::new(1),
72        })
73    }
74
75    /// Connect to an MCP server via HTTP/SSE.
76    pub async fn connect_sse(base_url: &str) -> Result<Self, String> {
77        let http = reqwest::Client::builder()
78            .timeout(std::time::Duration::from_secs(60))
79            .build()
80            .map_err(|e| format!("HTTP client error: {e}"))?;
81
82        // Verify the server is reachable.
83        let health_url = format!("{}/health", base_url.trim_end_matches('/'));
84        match http.get(&health_url).send().await {
85            Ok(resp) if resp.status().is_success() => {
86                debug!("MCP SSE server reachable at {base_url}");
87            }
88            Ok(resp) => {
89                debug!(
90                    "MCP SSE server returned {}, proceeding anyway",
91                    resp.status()
92                );
93            }
94            Err(e) => {
95                warn!("MCP SSE server health check failed: {e}, proceeding anyway");
96            }
97        }
98
99        Ok(Self {
100            inner: TransportInner::Sse {
101                base_url: base_url.trim_end_matches('/').to_string(),
102                http,
103            },
104            next_id: Mutex::new(1),
105        })
106    }
107
108    /// Send a JSON-RPC request and wait for the response.
109    pub async fn request(
110        &self,
111        method: &str,
112        params: Option<serde_json::Value>,
113    ) -> Result<serde_json::Value, String> {
114        let id = {
115            let mut next = self.next_id.lock().await;
116            let id = *next;
117            *next += 1;
118            id
119        };
120
121        let request = JsonRpcRequest::new(id, method, params);
122        let request_json = serde_json::to_string(&request)
123            .map_err(|e| format!("Failed to serialize request: {e}"))?;
124
125        debug!("MCP request: {method} (id={id})");
126
127        match &self.inner {
128            TransportInner::Stdio { stdin, stdout, .. } => {
129                // Write the request.
130                {
131                    let mut stdin = stdin.lock().await;
132                    stdin
133                        .write_all(request_json.as_bytes())
134                        .await
135                        .map_err(|e| format!("Failed to write to MCP server: {e}"))?;
136                    stdin
137                        .write_all(b"\n")
138                        .await
139                        .map_err(|e| format!("Failed to write newline: {e}"))?;
140                    stdin
141                        .flush()
142                        .await
143                        .map_err(|e| format!("Failed to flush: {e}"))?;
144                }
145
146                // Read the response.
147                let mut line = String::new();
148                {
149                    let mut stdout = stdout.lock().await;
150                    stdout
151                        .read_line(&mut line)
152                        .await
153                        .map_err(|e| format!("Failed to read from MCP server: {e}"))?;
154                }
155
156                if line.is_empty() {
157                    return Err("MCP server closed connection".to_string());
158                }
159
160                let response: JsonRpcResponse = serde_json::from_str(&line)
161                    .map_err(|e| format!("Invalid JSON-RPC response: {e}"))?;
162
163                if let Some(error) = response.error {
164                    return Err(format!("MCP error ({}): {}", error.code, error.message));
165                }
166
167                response
168                    .result
169                    .ok_or_else(|| "MCP response missing 'result'".to_string())
170            }
171            TransportInner::Sse { base_url, http } => {
172                let url = format!("{base_url}/jsonrpc");
173                let resp = http
174                    .post(&url)
175                    .json(&request)
176                    .send()
177                    .await
178                    .map_err(|e| format!("SSE request failed: {e}"))?;
179
180                if !resp.status().is_success() {
181                    let status = resp.status();
182                    let body = resp.text().await.unwrap_or_default();
183                    return Err(format!("SSE error ({status}): {body}"));
184                }
185
186                let response: JsonRpcResponse = resp
187                    .json()
188                    .await
189                    .map_err(|e| format!("SSE response parse error: {e}"))?;
190
191                if let Some(error) = response.error {
192                    return Err(format!("MCP error ({}): {}", error.code, error.message));
193                }
194
195                response
196                    .result
197                    .ok_or_else(|| "MCP response missing 'result'".to_string())
198            }
199        }
200    }
201
202    /// Send a notification (no response expected).
203    pub async fn notify(
204        &self,
205        method: &str,
206        params: Option<serde_json::Value>,
207    ) -> Result<(), String> {
208        let notification = serde_json::json!({
209            "jsonrpc": "2.0",
210            "method": method,
211            "params": params,
212        });
213
214        let json = serde_json::to_string(&notification)
215            .map_err(|e| format!("Failed to serialize notification: {e}"))?;
216
217        match &self.inner {
218            TransportInner::Stdio { stdin, .. } => {
219                let mut stdin = stdin.lock().await;
220                stdin
221                    .write_all(json.as_bytes())
222                    .await
223                    .map_err(|e| format!("Failed to write notification: {e}"))?;
224                stdin
225                    .write_all(b"\n")
226                    .await
227                    .map_err(|e| format!("Failed to write newline: {e}"))?;
228                stdin
229                    .flush()
230                    .await
231                    .map_err(|e| format!("Flush failed: {e}"))?;
232            }
233            TransportInner::Sse { base_url, http } => {
234                let url = format!("{base_url}/jsonrpc");
235                let _ = http.post(&url).json(&notification).send().await;
236            }
237        }
238
239        Ok(())
240    }
241
242    /// Shut down the transport connection.
243    pub async fn shutdown(&self) {
244        match &self.inner {
245            TransportInner::Stdio { child, .. } => {
246                let mut child = child.lock().await;
247                let _ = child.kill().await;
248            }
249            TransportInner::Sse { .. } => {
250                // HTTP connections are stateless; nothing to shut down.
251            }
252        }
253    }
254}