use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, Command};
use super::protocol::{JsonRpcNotification, JsonRpcRequest, JsonRpcResponse};
use crate::common::{AgentError, Result};
/// JSON-RPC transport to an MCP server running as a child process.
///
/// Messages are exchanged one JSON document per line: requests and notifications are written
/// to the child's stdin and responses are read from its stdout. The child's stderr is
/// forwarded to `tracing` at debug level under the `mcp::stderr` target.
pub struct StdioTransport {
    /// Handle to the spawned server process.
    child: Child,
    /// Pipe to the child's stdin; outgoing messages are written here.
    stdin: tokio::process::ChildStdin,
    /// Buffered pipe from the child's stdout; incoming lines are read from here.
    reader: BufReader<tokio::process::ChildStdout>,
}

impl StdioTransport {
    /// Spawns `command` with `args`, leaving the inherited environment unchanged.
    ///
    /// # Errors
    /// Same as [`StdioTransport::with_env`].
    pub fn new(command: &str, args: &[&str]) -> Result<Self> {
        Self::with_env(command, args, &std::collections::HashMap::new())
    }

    /// Spawns `command` with `args`, adding the variables in `env` to the inherited
    /// environment.
    ///
    /// stdin, stdout and stderr are all piped. The child is marked kill-on-drop, so dropping
    /// the transport without calling [`StdioTransport::shutdown`] does not leave an orphaned
    /// server behind. A background task is spawned (via `tokio::spawn`) to drain stderr, so
    /// this must be called from within a Tokio runtime.
    ///
    /// # Errors
    /// Returns [`AgentError::Transport`] if the process cannot be spawned or its stdin/stdout
    /// pipes cannot be taken.
    pub fn with_env(
        command: &str,
        args: &[&str],
        env: &std::collections::HashMap<String, String>,
    ) -> Result<Self> {
        let mut cmd = Command::new(command);
        cmd.args(args)
            .stdin(std::process::Stdio::piped())
            .stdout(std::process::Stdio::piped())
            .stderr(std::process::Stdio::piped())
            // Without this, a server that ignores stdin EOF would outlive a dropped transport.
            .kill_on_drop(true);
        if !env.is_empty() {
            cmd.envs(env);
        }
        let mut child = cmd.spawn().map_err(|e| {
            AgentError::Transport(format!("Failed to spawn MCP server: {}: {}", command, e))
        })?;
        let stdin = child
            .stdin
            .take()
            .ok_or_else(|| AgentError::Transport("Failed to open child stdin".to_string()))?;
        let stdout = child
            .stdout
            .take()
            .ok_or_else(|| AgentError::Transport("Failed to open child stdout".to_string()))?;
        if let Some(stderr) = child.stderr.take() {
            tokio::spawn(Self::drain_stderr(stderr));
        }
        Ok(Self {
            child,
            stdin,
            reader: BufReader::new(stdout),
        })
    }

    /// Forwards each stderr line of the child to `tracing` until EOF or an I/O error.
    ///
    /// Lines are read as raw bytes and decoded lossily. Stopping at a non-UTF-8 line would
    /// drop the pipe's read end, after which every stderr write by the child fails with
    /// EPIPE/SIGPIPE.
    async fn drain_stderr(stderr: tokio::process::ChildStderr) {
        let mut reader = BufReader::new(stderr);
        let mut buf = Vec::new();
        loop {
            buf.clear();
            match reader.read_until(b'\n', &mut buf).await {
                Ok(0) | Err(_) => break,
                Ok(_) => {
                    // Strip the terminator (LF or CRLF) the same way a line reader would.
                    let line = buf.strip_suffix(b"\n").unwrap_or(&buf);
                    let line = line.strip_suffix(b"\r").unwrap_or(line);
                    tracing::debug!(target: "mcp::stderr", "{}", String::from_utf8_lossy(line));
                }
            }
        }
    }

    /// OS process id of the child, or `None` once tokio has observed it exit.
    pub fn pid(&self) -> Option<u32> {
        self.child.id()
    }

    /// Sends `request` and waits for the response line whose `id` matches `request.id`.
    ///
    /// Blank lines, lines that do not parse as a [`JsonRpcResponse`], and responses with a
    /// different id are skipped. There is no timeout: this waits until a matching response
    /// arrives or stdout reaches EOF.
    ///
    /// # Errors
    /// Fails on serialization or I/O errors, or with [`AgentError::Transport`] if stdout
    /// closes before a matching response is read.
    pub async fn send(&mut self, request: &JsonRpcRequest) -> Result<JsonRpcResponse> {
        let payload = serde_json::to_string(request)?;
        self.write_line(payload, "Failed to write to MCP server stdin")
            .await?;
        let expected_id = request.id;
        let mut line = String::new();
        loop {
            line.clear();
            let bytes_read = self.reader.read_line(&mut line).await.map_err(|e| {
                AgentError::Transport(format!("Failed to read from MCP server stdout: {}", e))
            })?;
            if bytes_read == 0 {
                return Err(AgentError::Transport(
                    "MCP server closed stdout before responding".to_string(),
                ));
            }
            let trimmed = line.trim();
            if trimmed.is_empty() {
                continue;
            }
            // Anything that is not our response is skipped rather than treated as an error.
            if let Ok(resp) = serde_json::from_str::<JsonRpcResponse>(trimmed)
                && resp.id == expected_id
            {
                return Ok(resp);
            }
        }
    }

    /// Writes `note` to the child's stdin without waiting for any reply.
    ///
    /// # Errors
    /// Fails on serialization or I/O errors.
    pub async fn send_notification(&mut self, note: &JsonRpcNotification) -> Result<()> {
        let payload = serde_json::to_string(note)?;
        self.write_line(payload, "Failed to write notification to MCP server stdin")
            .await
    }

    /// Appends a newline to `payload`, writes it to the child's stdin and flushes.
    ///
    /// A write failure is reported as `AgentError::Transport("{context}: {error}")`.
    async fn write_line(&mut self, mut payload: String, context: &str) -> Result<()> {
        payload.push('\n');
        self.stdin
            .write_all(payload.as_bytes())
            .await
            .map_err(|e| AgentError::Transport(format!("{}: {}", context, e)))?;
        self.stdin.flush().await?;
        Ok(())
    }

    /// Forcibly kills the child and waits for it to exit.
    ///
    /// Kill errors (e.g. the child has already exited) are ignored, so this always returns
    /// `Ok(())`.
    pub async fn shutdown(&mut self) -> Result<()> {
        let _ = self.child.kill().await;
        Ok(())
    }
}
/// JSON-RPC transport to an MCP server reached over HTTP.
///
/// Every message is POSTed to `base_url`. A reply to a request may be a plain JSON body or a
/// `text/event-stream`; both are handled. An `Mcp-Session-Id` returned by the server is
/// remembered and sent back on subsequent messages.
pub struct HttpTransport {
    /// Endpoint URL with trailing `/` characters removed.
    base_url: String,
    client: reqwest::Client,
    /// Extra headers attached to every outgoing message.
    custom_headers: reqwest::header::HeaderMap,
    /// Session id from the last successful `send` response that carried one.
    session_id: Option<String>,
}

impl HttpTransport {
    /// Creates a transport for `base_url` with no extra headers.
    pub fn new(base_url: &str) -> Self {
        Self::with_headers(base_url, &std::collections::HashMap::new())
    }

    /// Creates a transport for `base_url` that attaches `headers` to every message.
    ///
    /// Entries whose name or value is not valid in an HTTP header are skipped with a warning.
    /// Only the header name is logged, since values are often credentials.
    pub fn with_headers(
        base_url: &str,
        headers: &std::collections::HashMap<String, String>,
    ) -> Self {
        use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
        let mut header_map = HeaderMap::new();
        for (k, v) in headers {
            match (
                HeaderName::from_bytes(k.as_bytes()),
                HeaderValue::from_str(v),
            ) {
                (Ok(name), Ok(val)) => {
                    header_map.insert(name, val);
                }
                _ => tracing::warn!(
                    header = %k,
                    "Skipping invalid custom header for MCP HTTP transport"
                ),
            }
        }
        Self {
            base_url: base_url.trim_end_matches('/').to_string(),
            client: reqwest::Client::new(),
            custom_headers: header_map,
            session_id: None,
        }
    }

    /// Starts a POST to the endpoint with every header a JSON-RPC message needs: the custom
    /// headers, a JSON content type, an `Accept` covering both reply formats, and the session
    /// id when one is known.
    fn post_builder(&self) -> reqwest::RequestBuilder {
        use reqwest::header::{ACCEPT, CONTENT_TYPE};
        let builder = self
            .client
            .post(&self.base_url)
            .headers(self.custom_headers.clone())
            .header(CONTENT_TYPE, "application/json")
            .header(ACCEPT, "application/json, text/event-stream");
        match &self.session_id {
            Some(sid) => builder.header("Mcp-Session-Id", sid),
            None => builder,
        }
    }

    /// Sends `request` and returns the matching response, from either a JSON body or an SSE
    /// stream depending on the reply's `Content-Type`.
    ///
    /// # Errors
    /// Returns [`AgentError::Transport`] if the HTTP call fails, the server answers with a
    /// non-success status, or no matching JSON-RPC response can be parsed from the body.
    pub async fn send(&mut self, request: &JsonRpcRequest) -> Result<JsonRpcResponse> {
        use reqwest::header::CONTENT_TYPE;
        tracing::trace!(url = %self.base_url, "MCP HTTP send");
        let resp = self.post_builder().json(request).send().await.map_err(|e| {
            AgentError::Transport(format!("HTTP request to MCP server failed: {}", e))
        })?;
        let status = resp.status();
        if !status.is_success() {
            let body = resp.text().await.unwrap_or_default();
            return Err(AgentError::Transport(format!(
                "MCP HTTP server returned {}: {}",
                status, body
            )));
        }
        // Remember the session so later messages are sent within it.
        if let Some(sid) = resp.headers().get("mcp-session-id")
            && let Ok(s) = sid.to_str()
        {
            self.session_id = Some(s.to_string());
        }
        let content_type = resp
            .headers()
            .get(CONTENT_TYPE)
            .and_then(|v| v.to_str().ok())
            .unwrap_or("")
            .to_lowercase();
        if content_type.contains("text/event-stream") {
            self.parse_sse_response(resp, request.id).await
        } else {
            resp.json::<JsonRpcResponse>().await.map_err(|e| {
                AgentError::Transport(format!(
                    "Failed to parse JSON-RPC response from MCP HTTP server: {}",
                    e
                ))
            })
        }
    }

    /// Reads a `text/event-stream` body until an event whose data parses as a
    /// [`JsonRpcResponse`] with id `expected_id` is dispatched.
    ///
    /// The body is buffered as raw bytes and split into lines before UTF-8 decoding, so a
    /// multi-byte character straddling two network chunks is not corrupted. LF, CRLF and bare
    /// CR line endings are all accepted, as the SSE format allows. Events carrying other ids
    /// or non-JSON-RPC data are skipped.
    ///
    /// # Errors
    /// Returns [`AgentError::Transport`] on a stream read error, or if the stream ends before a
    /// matching response is dispatched.
    async fn parse_sse_response(
        &self,
        resp: reqwest::Response,
        expected_id: u64,
    ) -> Result<JsonRpcResponse> {
        use futures::StreamExt;
        let mut stream = resp.bytes_stream();
        // Received bytes not yet consumed as complete lines.
        let mut pending: Vec<u8> = Vec::new();
        // `data` field values of the event currently being assembled, each followed by '\n'.
        let mut data_buf = String::new();
        while let Some(chunk) = stream.next().await {
            let chunk = chunk
                .map_err(|e| AgentError::Transport(format!("SSE stream read error: {}", e)))?;
            pending.extend_from_slice(&chunk);
            if let Some(rpc_resp) =
                Self::consume_sse_lines(&mut pending, false, &mut data_buf, expected_id)
            {
                return Ok(rpc_resp);
            }
        }
        // At EOF a trailing CR can no longer be the first half of a CRLF, so it terminates a
        // line; this may complete the blank line that dispatches the final event.
        if let Some(rpc_resp) =
            Self::consume_sse_lines(&mut pending, true, &mut data_buf, expected_id)
        {
            return Ok(rpc_resp);
        }
        Err(AgentError::Transport(
            "SSE stream ended without a matching JSON-RPC response".to_string(),
        ))
    }

    /// Feeds every complete line at the front of `pending` to
    /// [`HttpTransport::apply_sse_line`], then removes the consumed bytes. It stops early when
    /// a matching response is dispatched.
    fn consume_sse_lines(
        pending: &mut Vec<u8>,
        at_eof: bool,
        data_buf: &mut String,
        expected_id: u64,
    ) -> Option<JsonRpcResponse> {
        let mut start = 0;
        let mut matched = None;
        while let Some((line_len, consumed)) = Self::split_sse_line(&pending[start..], at_eof) {
            matched =
                Self::apply_sse_line(&pending[start..start + line_len], data_buf, expected_id);
            start += consumed;
            if matched.is_some() {
                break;
            }
        }
        pending.drain(..start);
        matched
    }

    /// Locates the first complete line in `buf`, returning `(line_len, consumed)`, where
    /// `consumed` also covers the terminator (LF, CRLF or CR).
    ///
    /// A CR that is the last byte may be half of a CRLF split across chunks. It is therefore
    /// only accepted as a terminator when `at_eof` is set.
    fn split_sse_line(buf: &[u8], at_eof: bool) -> Option<(usize, usize)> {
        let idx = buf.iter().position(|&b| b == b'\n' || b == b'\r')?;
        if buf[idx] == b'\n' {
            return Some((idx, idx + 1));
        }
        match buf.get(idx + 1) {
            Some(b'\n') => Some((idx, idx + 2)),
            Some(_) => Some((idx, idx + 1)),
            None if at_eof => Some((idx, idx + 1)),
            None => None,
        }
    }

    /// Applies one SSE line (without its terminator) to the event being assembled.
    ///
    /// - A blank line dispatches the event. If its data parses as a [`JsonRpcResponse`] with
    ///   id `expected_id`, that response is returned. Either way `data_buf` is reset.
    /// - A `data` field appends its value (minus one optional leading space).
    /// - Comments (lines starting with `:`) and all other fields are ignored.
    fn apply_sse_line(
        line: &[u8],
        data_buf: &mut String,
        expected_id: u64,
    ) -> Option<JsonRpcResponse> {
        if line.is_empty() {
            let dispatched = if data_buf.is_empty() {
                None
            } else {
                let payload = data_buf.strip_suffix('\n').unwrap_or(data_buf.as_str());
                serde_json::from_str::<JsonRpcResponse>(payload)
                    .ok()
                    .filter(|rpc_resp| rpc_resp.id == expected_id)
            };
            data_buf.clear();
            return dispatched;
        }
        let text = String::from_utf8_lossy(line);
        let (field, value) = match text.split_once(':') {
            Some((field, value)) => (field, value.strip_prefix(' ').unwrap_or(value)),
            None => (&*text, ""),
        };
        if field == "data" {
            data_buf.push_str(value);
            data_buf.push('\n');
        }
        None
    }

    /// POSTs `note` to the endpoint. Only the status of the reply is checked; the body is
    /// ignored.
    ///
    /// # Errors
    /// Returns [`AgentError::Transport`] if the HTTP call fails or the status is not a success.
    pub async fn send_notification(&mut self, note: &JsonRpcNotification) -> Result<()> {
        let resp = self.post_builder().json(note).send().await.map_err(|e| {
            AgentError::Transport(format!("HTTP notification to MCP server failed: {}", e))
        })?;
        let status = resp.status();
        if !status.is_success() {
            let body = resp.text().await.unwrap_or_default();
            return Err(AgentError::Transport(format!(
                "MCP HTTP server returned {} for notification: {}",
                status, body
            )));
        }
        Ok(())
    }

    /// Builds, without sending, the same POST that [`HttpTransport::send`] issues for
    /// `request`. It includes the custom headers, content type, `Accept` and session id.
    ///
    /// # Errors
    /// Returns [`AgentError::Transport`] if the request cannot be built (e.g. invalid URL).
    pub fn build_request(&self, request: &JsonRpcRequest) -> Result<reqwest::Request> {
        self.post_builder()
            .json(request)
            .build()
            .map_err(|e| AgentError::Transport(format!("Failed to build HTTP request: {}", e)))
    }
}
/// A connection to an MCP server over one of the supported transports.
pub enum McpTransport {
    /// Server spawned as a child process, spoken to over its stdin/stdout.
    Stdio(Box<StdioTransport>),
    /// Server reached over HTTP.
    Http(HttpTransport),
}

impl McpTransport {
    /// OS process id of a stdio server; always `None` for HTTP.
    pub fn pid(&self) -> Option<u32> {
        match self {
            Self::Http(_) => None,
            Self::Stdio(stdio) => stdio.pid(),
        }
    }

    /// URL that an HTTP transport POSTs to, as parsed by reqwest.
    ///
    /// Returns `None` for stdio transports, or if the HTTP request cannot be built.
    pub fn target_url(&self) -> Option<String> {
        match self {
            Self::Stdio(_) => None,
            Self::Http(http) => {
                // A throwaway request is built only to read back the resolved URL.
                let probe = JsonRpcRequest::new(0, "ping", None);
                let built = http.build_request(&probe).ok()?;
                Some(built.url().to_string())
            }
        }
    }

    /// Sends `request` over the underlying transport and returns its response.
    pub async fn send(&mut self, request: &JsonRpcRequest) -> Result<JsonRpcResponse> {
        match self {
            Self::Http(http) => http.send(request).await,
            Self::Stdio(stdio) => stdio.send(request).await,
        }
    }

    /// Sends `note` over the underlying transport without waiting for a JSON-RPC reply.
    pub async fn send_notification(&mut self, note: &JsonRpcNotification) -> Result<()> {
        match self {
            Self::Http(http) => http.send_notification(note).await,
            Self::Stdio(stdio) => stdio.send_notification(note).await,
        }
    }

    /// Terminates a stdio server process; a no-op returning `Ok(())` for HTTP.
    pub async fn shutdown(&mut self) -> Result<()> {
        match self {
            Self::Http(_) => Ok(()),
            Self::Stdio(stdio) => stdio.shutdown().await,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mcp::protocol::JsonRpcRequest;
    use serde_json::json;
    use std::collections::HashMap;

    #[test]
    fn test_stdio_transport_new_missing_binary() {
        let result = StdioTransport::new("__collet_nonexistent_mcp_server__", &[]);
        assert!(result.is_err(), "Expected error when binary is not found");
    }

    #[test]
    fn test_http_transport_build_request() {
        let transport = HttpTransport::new("http://localhost:8080/rpc");
        let req = JsonRpcRequest::new(1, "tools/list", Some(json!({})));
        let http_req = transport.build_request(&req).unwrap();
        assert_eq!(http_req.method(), reqwest::Method::POST);
        assert_eq!(http_req.url().as_str(), "http://localhost:8080/rpc");
        assert_eq!(
            http_req
                .headers()
                .get(reqwest::header::CONTENT_TYPE)
                .and_then(|v| v.to_str().ok()),
            Some("application/json")
        );
    }

    #[test]
    fn test_http_transport_strips_trailing_slash() {
        let transport = HttpTransport::new("http://example.com/mcp/");
        assert_eq!(transport.base_url, "http://example.com/mcp");
    }

    #[test]
    fn test_http_transport_session_id_initially_none() {
        let transport = HttpTransport::new("http://example.com/mcp");
        assert!(transport.session_id.is_none());
    }

    #[test]
    fn test_http_transport_with_headers_skips_invalid_entries() {
        let mut headers = HashMap::new();
        headers.insert("X-Api-Key".to_string(), "secret".to_string());
        // Space is not allowed in a header name.
        headers.insert("bad header".to_string(), "value".to_string());
        // Control characters are not allowed in a header value.
        headers.insert("X-Bad-Value".to_string(), "line\nbreak".to_string());
        let transport = HttpTransport::with_headers("http://example.com/mcp", &headers);
        assert_eq!(transport.custom_headers.len(), 1);
        assert_eq!(
            transport
                .custom_headers
                .get("x-api-key")
                .and_then(|v| v.to_str().ok()),
            Some("secret")
        );
    }

    #[test]
    fn test_mcp_transport_http_target_url_and_pid() {
        let transport = McpTransport::Http(HttpTransport::new("http://example.com/mcp/"));
        assert_eq!(
            transport.target_url().as_deref(),
            Some("http://example.com/mcp")
        );
        assert_eq!(transport.pid(), None);
    }
}