// agent_code_lib/services/mcp/transport.rs
use std::collections::HashMap;
use std::process::Stdio;
use std::time::Duration;

use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, Command};
use tokio::sync::Mutex;
use tracing::{debug, warn};

use super::types::*;

15pub struct McpTransportConnection {
17 inner: TransportInner,
18 next_id: Mutex<u64>,
19}
20
21#[allow(clippy::large_enum_variant)]
22enum TransportInner {
23 Stdio {
24 child: Mutex<Child>,
25 stdin: Mutex<tokio::process::ChildStdin>,
26 stdout: Mutex<BufReader<tokio::process::ChildStdout>>,
27 },
28 Sse {
29 base_url: String,
30 http: reqwest::Client,
31 },
32}
33
34impl McpTransportConnection {
35 pub async fn connect_stdio(
37 command: &str,
38 args: &[String],
39 env: &HashMap<String, String>,
40 ) -> Result<Self, String> {
41 let mut cmd = Command::new(command);
42 cmd.args(args)
43 .stdin(Stdio::piped())
44 .stdout(Stdio::piped())
45 .stderr(Stdio::null());
46
47 for (key, value) in env {
48 cmd.env(key, value);
49 }
50
51 let mut child = cmd
52 .spawn()
53 .map_err(|e| format!("Failed to spawn MCP server '{command}': {e}"))?;
54
55 let stdin = child
56 .stdin
57 .take()
58 .ok_or_else(|| "Failed to capture stdin".to_string())?;
59
60 let stdout = child
61 .stdout
62 .take()
63 .ok_or_else(|| "Failed to capture stdout".to_string())?;
64
65 Ok(Self {
66 inner: TransportInner::Stdio {
67 child: Mutex::new(child),
68 stdin: Mutex::new(stdin),
69 stdout: Mutex::new(BufReader::new(stdout)),
70 },
71 next_id: Mutex::new(1),
72 })
73 }
74
75 pub async fn connect_sse(base_url: &str) -> Result<Self, String> {
77 let http = reqwest::Client::builder()
78 .timeout(std::time::Duration::from_secs(60))
79 .build()
80 .map_err(|e| format!("HTTP client error: {e}"))?;
81
82 let health_url = format!("{}/health", base_url.trim_end_matches('/'));
84 match http.get(&health_url).send().await {
85 Ok(resp) if resp.status().is_success() => {
86 debug!("MCP SSE server reachable at {base_url}");
87 }
88 Ok(resp) => {
89 debug!(
90 "MCP SSE server returned {}, proceeding anyway",
91 resp.status()
92 );
93 }
94 Err(e) => {
95 warn!("MCP SSE server health check failed: {e}, proceeding anyway");
96 }
97 }
98
99 Ok(Self {
100 inner: TransportInner::Sse {
101 base_url: base_url.trim_end_matches('/').to_string(),
102 http,
103 },
104 next_id: Mutex::new(1),
105 })
106 }
107
108 pub async fn request(
110 &self,
111 method: &str,
112 params: Option<serde_json::Value>,
113 ) -> Result<serde_json::Value, String> {
114 let id = {
115 let mut next = self.next_id.lock().await;
116 let id = *next;
117 *next += 1;
118 id
119 };
120
121 let request = JsonRpcRequest::new(id, method, params);
122 let request_json = serde_json::to_string(&request)
123 .map_err(|e| format!("Failed to serialize request: {e}"))?;
124
125 debug!("MCP request: {method} (id={id})");
126
127 match &self.inner {
128 TransportInner::Stdio { stdin, stdout, .. } => {
129 {
131 let mut stdin = stdin.lock().await;
132 stdin
133 .write_all(request_json.as_bytes())
134 .await
135 .map_err(|e| format!("Failed to write to MCP server: {e}"))?;
136 stdin
137 .write_all(b"\n")
138 .await
139 .map_err(|e| format!("Failed to write newline: {e}"))?;
140 stdin
141 .flush()
142 .await
143 .map_err(|e| format!("Failed to flush: {e}"))?;
144 }
145
146 let mut line = String::new();
148 {
149 let mut stdout = stdout.lock().await;
150 stdout
151 .read_line(&mut line)
152 .await
153 .map_err(|e| format!("Failed to read from MCP server: {e}"))?;
154 }
155
156 if line.is_empty() {
157 return Err("MCP server closed connection".to_string());
158 }
159
160 let response: JsonRpcResponse = serde_json::from_str(&line)
161 .map_err(|e| format!("Invalid JSON-RPC response: {e}"))?;
162
163 if let Some(error) = response.error {
164 return Err(format!("MCP error ({}): {}", error.code, error.message));
165 }
166
167 response
168 .result
169 .ok_or_else(|| "MCP response missing 'result'".to_string())
170 }
171 TransportInner::Sse { base_url, http } => {
172 let url = format!("{base_url}/jsonrpc");
173 let resp = http
174 .post(&url)
175 .json(&request)
176 .send()
177 .await
178 .map_err(|e| format!("SSE request failed: {e}"))?;
179
180 if !resp.status().is_success() {
181 let status = resp.status();
182 let body = resp.text().await.unwrap_or_default();
183 return Err(format!("SSE error ({status}): {body}"));
184 }
185
186 let response: JsonRpcResponse = resp
187 .json()
188 .await
189 .map_err(|e| format!("SSE response parse error: {e}"))?;
190
191 if let Some(error) = response.error {
192 return Err(format!("MCP error ({}): {}", error.code, error.message));
193 }
194
195 response
196 .result
197 .ok_or_else(|| "MCP response missing 'result'".to_string())
198 }
199 }
200 }
201
202 pub async fn notify(
204 &self,
205 method: &str,
206 params: Option<serde_json::Value>,
207 ) -> Result<(), String> {
208 let notification = serde_json::json!({
209 "jsonrpc": "2.0",
210 "method": method,
211 "params": params,
212 });
213
214 let json = serde_json::to_string(¬ification)
215 .map_err(|e| format!("Failed to serialize notification: {e}"))?;
216
217 match &self.inner {
218 TransportInner::Stdio { stdin, .. } => {
219 let mut stdin = stdin.lock().await;
220 stdin
221 .write_all(json.as_bytes())
222 .await
223 .map_err(|e| format!("Failed to write notification: {e}"))?;
224 stdin
225 .write_all(b"\n")
226 .await
227 .map_err(|e| format!("Failed to write newline: {e}"))?;
228 stdin
229 .flush()
230 .await
231 .map_err(|e| format!("Flush failed: {e}"))?;
232 }
233 TransportInner::Sse { base_url, http } => {
234 let url = format!("{base_url}/jsonrpc");
235 let _ = http.post(&url).json(¬ification).send().await;
236 }
237 }
238
239 Ok(())
240 }
241
242 pub async fn shutdown(&self) {
244 match &self.inner {
245 TransportInner::Stdio { child, .. } => {
246 let mut child = child.lock().await;
247 let _ = child.kill().await;
248 }
249 TransportInner::Sse { .. } => {
250 }
252 }
253 }
254}