use crate::{McpError, Result};
use async_trait::async_trait;
use serde_json::Value;
use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, ChildStdin, ChildStdout, Command};
8
9#[async_trait]
11pub trait McpTransport: Send + Sync {
12 async fn send_request(&mut self, request: Value) -> Result<Value>;
14
15 async fn close(&mut self) -> Result<()>;
17}
18
19pub struct StdioTransport {
21 child: Child,
22 stdin: ChildStdin,
23 stdout: BufReader<ChildStdout>,
24 request_id: u64,
25 max_response_size: usize,
27}
28
29impl StdioTransport {
30 const DEFAULT_MAX_RESPONSE_SIZE: usize = 10 * 1024 * 1024;
32
33 pub async fn new(command: &str, args: &[&str]) -> Result<Self> {
35 let mut child = Command::new(command)
36 .args(args)
37 .stdin(std::process::Stdio::piped())
38 .stdout(std::process::Stdio::piped())
39 .stderr(std::process::Stdio::null())
40 .spawn()
41 .map_err(|e| McpError::ServerError(format!("Failed to spawn MCP server: {}", e)))?;
42
43 let stdin = child
44 .stdin
45 .take()
46 .ok_or_else(|| McpError::ServerError("Failed to get stdin".to_string()))?;
47
48 let stdout = child
49 .stdout
50 .take()
51 .ok_or_else(|| McpError::ServerError("Failed to get stdout".to_string()))?;
52
53 Ok(Self {
54 child,
55 stdin,
56 stdout: BufReader::new(stdout),
57 request_id: 1,
58 max_response_size: Self::DEFAULT_MAX_RESPONSE_SIZE,
59 })
60 }
61
62 pub fn with_max_response_size(mut self, size: usize) -> Self {
64 self.max_response_size = size;
65 self
66 }
67}
68
69#[async_trait]
70impl McpTransport for StdioTransport {
71 async fn send_request(&mut self, mut request: Value) -> Result<Value> {
72 if let Value::Object(ref mut obj) = request {
74 obj.insert("jsonrpc".to_string(), Value::String("2.0".to_string()));
75 obj.insert("id".to_string(), Value::Number(self.request_id.into()));
76 self.request_id += 1;
77 }
78
79 let request_str = serde_json::to_string(&request)
81 .map_err(|e| McpError::ProtocolError(format!("Failed to serialize request: {}", e)))?;
82
83 self.stdin
84 .write_all(request_str.as_bytes())
85 .await
86 .map_err(|e| McpError::ServerError(format!("Failed to write request: {}", e)))?;
87
88 self.stdin
89 .write_all(b"\n")
90 .await
91 .map_err(|e| McpError::ServerError(format!("Failed to write newline: {}", e)))?;
92
93 self.stdin
94 .flush()
95 .await
96 .map_err(|e| McpError::ServerError(format!("Failed to flush: {}", e)))?;
97
98 let mut response_line = String::new();
100 let bytes_read = self
101 .stdout
102 .read_line(&mut response_line)
103 .await
104 .map_err(|e| McpError::ServerError(format!("Failed to read response: {}", e)))?;
105
106 if bytes_read > self.max_response_size {
107 return Err(McpError::ProtocolError(format!(
108 "Response too large: {} bytes (max: {})",
109 bytes_read, self.max_response_size
110 )));
111 }
112
113 let response: Value = serde_json::from_str(&response_line)
114 .map_err(|e| McpError::ProtocolError(format!("Failed to parse response: {}", e)))?;
115
116 Ok(response)
117 }
118
119 async fn close(&mut self) -> Result<()> {
120 self.child
121 .kill()
122 .await
123 .map_err(|e| McpError::ServerError(format!("Failed to kill child process: {}", e)))
124 }
125}
126
127pub struct HttpTransport {
129 client: reqwest::Client,
130 base_url: String,
131 request_id: u64,
132 max_response_size: usize,
134}
135
136impl HttpTransport {
137 const DEFAULT_MAX_RESPONSE_SIZE: usize = 10 * 1024 * 1024;
139
140 pub fn new(base_url: String) -> Self {
142 let client = reqwest::Client::builder()
143 .timeout(std::time::Duration::from_secs(30))
144 .build()
145 .unwrap_or_else(|_| reqwest::Client::new());
146
147 Self {
148 client,
149 base_url,
150 request_id: 1,
151 max_response_size: Self::DEFAULT_MAX_RESPONSE_SIZE,
152 }
153 }
154
155 pub fn with_max_response_size(mut self, size: usize) -> Self {
157 self.max_response_size = size;
158 self
159 }
160
161 pub fn with_timeout(mut self, timeout: std::time::Duration) -> Self {
163 self.client = reqwest::Client::builder()
164 .timeout(timeout)
165 .build()
166 .unwrap_or_else(|_| reqwest::Client::new());
167 self
168 }
169}
170
171#[async_trait]
172impl McpTransport for HttpTransport {
173 async fn send_request(&mut self, mut request: Value) -> Result<Value> {
174 if let Value::Object(ref mut obj) = request {
176 obj.insert("jsonrpc".to_string(), Value::String("2.0".to_string()));
177 obj.insert("id".to_string(), Value::Number(self.request_id.into()));
178 self.request_id += 1;
179 }
180
181 let response = self
182 .client
183 .post(&self.base_url)
184 .json(&request)
185 .send()
186 .await
187 .map_err(|e| McpError::ServerError(format!("HTTP request failed: {}", e)))?;
188
189 let response_json: Value = response
190 .json()
191 .await
192 .map_err(|e| McpError::ProtocolError(format!("Failed to parse response: {}", e)))?;
193
194 Ok(response_json)
195 }
196
197 async fn close(&mut self) -> Result<()> {
198 Ok(())
199 }
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    /// Endpoint used by every HTTP construction test.
    const LOCAL_URL: &str = "http://localhost:3000";
    /// Expected default response-size limit for both transports.
    const TEN_MIB: usize = 10 * 1024 * 1024;

    #[test]
    fn test_stdio_transport_constants() {
        assert_eq!(StdioTransport::DEFAULT_MAX_RESPONSE_SIZE, TEN_MIB);
    }

    #[test]
    fn test_http_transport_constants() {
        assert_eq!(HttpTransport::DEFAULT_MAX_RESPONSE_SIZE, TEN_MIB);
    }

    #[test]
    fn test_http_transport_creation() {
        let HttpTransport {
            base_url,
            request_id,
            max_response_size,
            ..
        } = HttpTransport::new(LOCAL_URL.to_string());

        assert_eq!(base_url, LOCAL_URL);
        assert_eq!(request_id, 1);
        assert_eq!(max_response_size, HttpTransport::DEFAULT_MAX_RESPONSE_SIZE);
    }

    #[test]
    fn test_http_transport_with_max_response_size() {
        let limited = HttpTransport::new(LOCAL_URL.to_string()).with_max_response_size(1024);
        assert_eq!(limited.max_response_size, 1024);
    }

    #[test]
    fn test_http_transport_with_timeout() {
        let five_seconds = std::time::Duration::from_secs(5);
        let timed = HttpTransport::new(LOCAL_URL.to_string()).with_timeout(five_seconds);
        assert_eq!(timed.base_url, LOCAL_URL);
    }
}