use anyhow::{Context, Result, bail};
use async_trait::async_trait;
use futures_util::StreamExt;
use serde_json::{Value, json};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::sync::Mutex;
use tracing::{debug, warn};
/// Callback used to answer requests initiated by the server.
///
/// Called with the request's `method` and `params` (`Value::Null` when absent). The returned
/// value is sent back to the server as the response. When it is a JSON object, the transports
/// in this module insert the request's `id` and default `jsonrpc` to `"2.0"`.
pub type ServerRequestHandler = Arc<dyn Fn(&str, &Value) -> Value + Send + Sync>;
/// Callback for server messages that carry a `method` but are not answered.
///
/// Called with the message's `method` and `params` (`Value::Null` when absent); nothing is
/// sent back to the server.
pub type NotificationHandler = Arc<dyn Fn(&str, &Value) + Send + Sync>;
/// Transport used to exchange JSON-RPC messages with an MCP server.
///
/// Implementors are `Send + Sync`, so a single transport can be shared between tasks.
#[async_trait]
pub trait McpTransport: Send + Sync {
/// Sends the JSON-RPC message `body` and returns the server's reply.
///
/// While waiting for the reply, the implementations in this module pass server-initiated
/// requests to `on_server_request` and send its return value back as the response. Other
/// method-carrying messages are passed to `on_notification`.
async fn request(
&self,
body: &Value,
on_server_request: &ServerRequestHandler,
on_notification: &NotificationHandler,
) -> Result<Value>;
/// Releases the transport's connection to the server; what this involves is
/// implementation-specific.
async fn shutdown(&self) -> Result<()>;
}
/// MCP transport that POSTs every JSON-RPC message to a single HTTP endpoint.
///
/// Replies are either a plain JSON body or a `text/event-stream` (SSE) stream.
pub struct HttpTransport {
    /// Endpoint every message is POSTed to.
    url: String,
    /// Optional token sent as `Authorization: Bearer <key>`.
    api_key: Option<String>,
    /// HTTP client reused for every POST.
    http: reqwest::Client,
    /// Session id last announced by the server in an `mcp-session-id` response header;
    /// echoed back on later POSTs.
    session_id: Mutex<Option<String>>,
}

impl HttpTransport {
    /// Creates a transport for `url`, optionally authenticating with `api_key` as a bearer
    /// token. No connection is made until the first request.
    pub fn new(url: String, api_key: Option<String>) -> Self {
        Self {
            url,
            api_key,
            http: reqwest::Client::new(),
            session_id: Mutex::new(None),
        }
    }

    /// Builds a JSON POST of `body` with the headers every client message needs.
    ///
    /// These are the content type and an `Accept` header allowing either a JSON or an SSE
    /// reply. The bearer token is added when configured, and `mcp-session-id` when
    /// `session_id` is known.
    fn build_request(&self, body: &Value, session_id: &Option<String>) -> reqwest::RequestBuilder {
        let mut req = self
            .http
            .post(&self.url)
            .header("content-type", "application/json")
            .header("accept", "application/json, text/event-stream");
        if let Some(key) = &self.api_key {
            req = req.header("authorization", format!("Bearer {key}"));
        }
        if let Some(sid) = session_id {
            req = req.header("mcp-session-id", sid.as_str());
        }
        req.json(body)
    }

    /// POSTs the client's `response` to a server-initiated request.
    ///
    /// Goes through [`Self::build_request`] so the reply carries the same `Authorization`,
    /// `Accept` and session headers as ordinary requests. Without the bearer token, servers
    /// that require authentication would reject it.
    ///
    /// # Errors
    /// Fails if the POST cannot be sent or the server answers with a non-success status.
    async fn send_response(&self, response: &Value, session_id: &Option<String>) -> Result<()> {
        self.build_request(response, session_id)
            .send()
            .await
            .context("Failed to send response to MCP server")?
            .error_for_status()
            .context("MCP server rejected response")?;
        Ok(())
    }

    /// Extracts the JSON payload of one SSE event.
    ///
    /// `raw` is the event text without its terminating blank line. Following the SSE format,
    /// all `data` lines are joined with `\n`, and one space after the colon is dropped. Other
    /// fields (`event:`, `id:`, ...) and `:` comment lines are ignored.
    ///
    /// Returns `None` when the event has no `data` line or its payload is not valid JSON.
    fn parse_sse_data(raw: &str) -> Option<Value> {
        let mut payload: Option<String> = None;
        for line in raw.lines() {
            let value = match line.strip_prefix("data") {
                // A bare `data` line is a data field with an empty value.
                Some("") => "",
                Some(rest) => match rest.strip_prefix(':') {
                    Some(v) => v.strip_prefix(' ').unwrap_or(v),
                    // Some other field whose name merely starts with "data".
                    None => continue,
                },
                None => continue,
            };
            match payload.as_mut() {
                Some(buf) => {
                    buf.push('\n');
                    buf.push_str(value);
                }
                None => payload = Some(value.to_owned()),
            }
        }
        serde_json::from_str(payload?.trim()).ok()
    }
}
#[async_trait]
impl McpTransport for HttpTransport {
async fn request(
&self,
body: &Value,
on_server_request: &ServerRequestHandler,
on_notification: &NotificationHandler,
) -> Result<Value> {
let session_id = self.session_id.lock().await.clone();
let resp = self
.build_request(body, &session_id)
.send()
.await
.context("Failed to send request to MCP server")?;
if let Some(sid) = resp.headers().get("mcp-session-id") {
if let Ok(s) = sid.to_str() {
*self.session_id.lock().await = Some(s.to_string());
}
}
let content_type = resp
.headers()
.get("content-type")
.and_then(|v| v.to_str().ok())
.unwrap_or("")
.to_string();
if content_type.contains("text/event-stream") {
let current_sid = self.session_id.lock().await.clone();
let req_id = body.get("id").cloned().unwrap_or(Value::Null);
let mut stream = resp.bytes_stream();
let mut buf = String::new();
loop {
match stream.next().await {
None => bail!("SSE stream ended without a final result"),
Some(Err(e)) => bail!("SSE stream error: {e}"),
Some(Ok(chunk)) => {
buf.push_str(&String::from_utf8_lossy(&chunk));
while let Some(pos) = buf.find("\n\n") {
let raw = buf[..pos].to_string();
buf.drain(..pos + 2);
let Some(data) = Self::parse_sse_data(&raw) else {
continue;
};
if data.get("method").is_some()
&& data.get("id").is_some()
&& data.get("result").is_none()
{
let method = data["method"].as_str().unwrap_or("");
let params = data.get("params").cloned().unwrap_or(Value::Null);
let mut response = on_server_request(method, ¶ms);
if let Value::Object(ref mut map) = response {
map.insert("id".to_string(), data["id"].clone());
map.entry("jsonrpc".to_string())
.or_insert_with(|| json!("2.0"));
}
if let Err(e) = self.send_response(&response, ¤t_sid).await {
warn!("Failed to send server-request response: {e}");
}
continue;
}
if data.get("id") == Some(&req_id)
&& (data.get("result").is_some() || data.get("error").is_some())
{
return Ok(data);
}
if let Some(method) = data.get("method").and_then(|m| m.as_str()) {
let params = data.get("params").cloned().unwrap_or(Value::Null);
on_notification(method, ¶ms);
} else {
debug!("SSE notification (unrecognized): {data}");
}
}
}
}
}
} else {
let data: Value = resp.json().await.context("Failed to parse JSON response")?;
Ok(data)
}
}
async fn shutdown(&self) -> Result<()> {
Ok(())
}
}
/// MCP transport that talks to a locally spawned server process using one JSON message
/// per line on its stdin/stdout.
pub struct StdioTransport {
    /// Write end of the child's stdin.
    stdin: Mutex<tokio::process::ChildStdin>,
    /// Buffered read end of the child's stdout.
    reader: Mutex<BufReader<tokio::process::ChildStdout>>,
    /// Handle to the child process, kept so it can be killed on shutdown.
    child: Mutex<tokio::process::Child>,
}

impl StdioTransport {
    /// Spawns `command` with `args` and wires up its stdin/stdout as the message channel.
    ///
    /// `env` is added to the inherited environment, and the child's stderr is discarded.
    /// The child is configured with `kill_on_drop(true)`, so dropping the transport without
    /// calling `shutdown` does not leave the server process running.
    ///
    /// # Errors
    /// Fails if the process cannot be spawned or its stdin/stdout pipes are unavailable.
    pub async fn new(
        command: &str,
        args: &[String],
        env: &HashMap<String, String>,
    ) -> Result<Self> {
        let mut child = tokio::process::Command::new(command)
            .args(args)
            .envs(env)
            .stdin(std::process::Stdio::piped())
            .stdout(std::process::Stdio::piped())
            .stderr(std::process::Stdio::null())
            // Without this, an un-shut-down transport leaks a running server process.
            .kill_on_drop(true)
            .spawn()
            .with_context(|| format!("Failed to spawn MCP server process: {command}"))?;
        let stdin = child
            .stdin
            .take()
            .ok_or_else(|| anyhow::anyhow!("Failed to open stdin of child process"))?;
        let stdout = child
            .stdout
            .take()
            .ok_or_else(|| anyhow::anyhow!("Failed to open stdout of child process"))?;
        Ok(Self {
            stdin: Mutex::new(stdin),
            reader: Mutex::new(BufReader::new(stdout)),
            child: Mutex::new(child),
        })
    }
}
#[async_trait]
impl McpTransport for StdioTransport {
    /// Writes `body` as one JSON line to the server's stdin and returns the server's reply.
    ///
    /// * If `body` has no `id`, it is a JSON-RPC notification. Notifications never get a
    ///   reply, so `Value::Null` is returned as soon as the message is written.
    /// * Otherwise stdout is read line by line until a message whose `id` equals `body["id"]`
    ///   and that carries `result` or `error` arrives; that message is returned. Meanwhile:
    ///   - server requests are answered through `on_server_request`;
    ///   - other method-carrying messages go to `on_notification`;
    ///   - blank or non-JSON lines are skipped.
    ///
    /// The stdout reader stays locked from before the request is written until its reply is
    /// read. Concurrent callers therefore take turns instead of consuming each other's replies.
    ///
    /// # Errors
    /// Fails if writing to stdin or reading from stdout fails, or if stdout reaches EOF before
    /// the reply.
    async fn request(
        &self,
        body: &Value,
        on_server_request: &ServerRequestHandler,
        on_notification: &NotificationHandler,
    ) -> Result<Value> {
        let Some(req_id) = body.get("id").cloned() else {
            self.write_message(body).await?;
            return Ok(Value::Null);
        };

        // Lock order is always reader -> stdin (write_message takes stdin), so this cannot
        // deadlock with other requests, notifications, or shutdown.
        let mut reader = self.reader.lock().await;
        self.write_message(body).await?;

        let mut line = String::new();
        loop {
            line.clear();
            let n = reader
                .read_line(&mut line)
                .await
                .context("Failed to read from MCP server stdout")?;
            if n == 0 {
                bail!("MCP server process closed stdout unexpectedly");
            }
            let trimmed = line.trim();
            if trimmed.is_empty() {
                continue;
            }
            let data: Value = match serde_json::from_str(trimmed) {
                Ok(v) => v,
                Err(e) => {
                    debug!("Ignoring non-JSON line from MCP server: {e}");
                    continue;
                }
            };

            // Server-initiated request: answer it and keep waiting for our reply.
            if data.get("method").is_some()
                && data.get("id").is_some()
                && data.get("result").is_none()
            {
                let method = data["method"].as_str().unwrap_or("");
                let params = data.get("params").cloned().unwrap_or(Value::Null);
                let mut response = on_server_request(method, &params);
                if let Some(map) = response.as_object_mut() {
                    map.insert("id".to_string(), data["id"].clone());
                    map.entry("jsonrpc".to_string())
                        .or_insert_with(|| json!("2.0"));
                }
                self.write_message(&response).await?;
                continue;
            }

            if data.get("id") == Some(&req_id)
                && (data.get("result").is_some() || data.get("error").is_some())
            {
                return Ok(data);
            }

            if let Some(method) = data.get("method").and_then(Value::as_str) {
                let params = data.get("params").cloned().unwrap_or(Value::Null);
                on_notification(method, &params);
            } else {
                debug!("stdio message (unrecognized): {data}");
            }
        }
    }

    /// Terminates the server process.
    ///
    /// Errors from killing the process are ignored: shutdown is best-effort.
    async fn shutdown(&self) -> Result<()> {
        // Acquiring the stdin lock only waits for an in-flight write to finish; it does not
        // close the pipe, because `ChildStdin` cannot be moved out of the mutex here.
        drop(self.stdin.lock().await);
        let mut child = self.child.lock().await;
        let _ = child.kill().await;
        Ok(())
    }
}

impl StdioTransport {
    /// Writes `message` to the server's stdin as a single JSON line and flushes it.
    async fn write_message(&self, message: &Value) -> Result<()> {
        let mut line = serde_json::to_string(message)?;
        line.push('\n');
        let mut stdin = self.stdin.lock().await;
        stdin
            .write_all(line.as_bytes())
            .await
            .context("Failed to write to MCP server stdin")?;
        stdin
            .flush()
            .await
            .context("Failed to flush MCP server stdin")?;
        Ok(())
    }
}