use anyhow::{anyhow, Result};
use async_trait::async_trait;
use reqwest::{header, Client};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, ChildStdin, ChildStdout, Command};
use tokio::sync::Mutex;
use crate::auth::AuthConfig;
use crate::providers::base::Provider;
use crate::providers::mcp::McpProvider;
use crate::security::{validate_size_limit, validate_url_security};
use crate::tools::Tool;
use crate::transports::{stream::StreamResult, ClientTransport};
/// Upper bound, in bytes (10 MiB), on a single MCP response accepted by this
/// transport. Checked with `validate_size_limit` on HTTP response bodies and on
/// each response line read from a stdio server.
const MAX_RESPONSE_SIZE: usize = 10 * 1024 * 1024;
/// A spawned stdio MCP server together with the shared handles used to talk to it.
///
/// Requests are written to `stdin` as one JSON document per line and responses
/// are read line by line from `stdout`. Each handle sits behind its own async
/// mutex so the process can be shared (via `Arc`) between concurrent callers.
struct McpStdioProcess {
// Held only so the child handle lives exactly as long as this struct; it is
// never read directly, hence the `dead_code` allowance.
#[allow(dead_code)] child: Child,
// Write half of the JSON-RPC channel (one serialized request per line).
stdin: Arc<Mutex<ChildStdin>>,
// Buffered read half of the JSON-RPC channel; responses are read per line.
stdout: Arc<Mutex<BufReader<ChildStdout>>>,
// Next JSON-RPC request id to hand out; starts at 1 and is incremented per request.
request_id: Arc<Mutex<u64>>,
}
impl McpStdioProcess {
async fn new(
command: &str,
args: &Option<Vec<String>>,
env_vars: &Option<HashMap<String, String>>,
) -> Result<Self> {
crate::security::validate_command(command, &[])?;
if let Some(args_vec) = args {
crate::security::validate_command_args(args_vec)?;
}
let mut cmd = Command::new(command);
if let Some(args_vec) = args {
cmd.args(args_vec);
}
if let Some(env) = env_vars {
for (k, v) in env {
cmd.env(k, v);
}
}
cmd.stdin(std::process::Stdio::piped());
cmd.stdout(std::process::Stdio::piped());
cmd.stderr(std::process::Stdio::piped());
let mut child = cmd.spawn()?;
let stdin = child
.stdin
.take()
.ok_or_else(|| anyhow!("Failed to get stdin"))?;
let stdout = child
.stdout
.take()
.ok_or_else(|| anyhow!("Failed to get stdout"))?;
let buf_reader = BufReader::with_capacity(65536, stdout);
Ok(Self {
child,
stdin: Arc::new(Mutex::new(stdin)),
stdout: Arc::new(Mutex::new(buf_reader)),
request_id: Arc::new(Mutex::new(1)),
})
}
async fn send_request(&self, method: &str, params: Value) -> Result<Value> {
let mut id_guard = self.request_id.lock().await;
let id = *id_guard;
*id_guard += 1;
drop(id_guard);
let request = serde_json::json!({
"jsonrpc": "2.0",
"method": method,
"params": params,
"id": id,
});
let request_str = serde_json::to_string(&request)?;
let mut stdin = self.stdin.lock().await;
stdin.write_all(request_str.as_bytes()).await?;
stdin.write_all(b"\n").await?;
stdin.flush().await?;
drop(stdin);
let mut stdout = self.stdout.lock().await;
let mut line = String::new();
stdout.read_line(&mut line).await?;
drop(stdout);
if line.is_empty() {
return Err(anyhow!("MCP process closed connection"));
}
validate_size_limit(line.as_bytes(), MAX_RESPONSE_SIZE)?;
let response: Value = serde_json::from_str(&line)?;
if let Some(error) = response.get("error") {
return Err(anyhow!("MCP error: {}", error));
}
response
.get("result")
.cloned()
.ok_or_else(|| anyhow!("No result in MCP response"))
}
}
/// Client transport for MCP providers, which are reached either over HTTP
/// (JSON-RPC POST, with SSE for streaming calls) or through a spawned stdio process.
pub struct McpTransport {
// Shared, connection-pooled HTTP client used for every HTTP MCP provider.
client: Client,
// Long-lived stdio server processes, keyed by provider name.
stdio_processes: Arc<Mutex<HashMap<String, Arc<McpStdioProcess>>>>,
}
impl McpTransport {
pub fn new() -> Self {
let client = Client::builder()
.timeout(std::time::Duration::from_secs(120)) .pool_max_idle_per_host(50) .pool_idle_timeout(Some(std::time::Duration::from_secs(90)))
.tcp_keepalive(Some(std::time::Duration::from_secs(30)))
.gzip(true) .http2_adaptive_window(true)
.build()
.expect("Failed to build MCP HTTP client");
Self {
client,
stdio_processes: Arc::new(Mutex::new(HashMap::new())),
}
}
fn apply_auth(
&self,
builder: reqwest::RequestBuilder,
auth: &AuthConfig,
) -> Result<reqwest::RequestBuilder> {
match auth {
AuthConfig::ApiKey(api_key) => {
let location = api_key.location.to_ascii_lowercase();
match location.as_str() {
"header" => Ok(builder.header(&api_key.var_name, &api_key.api_key)),
"query" => {
Ok(builder.query(&[(api_key.var_name.clone(), api_key.api_key.clone())]))
}
"cookie" => {
let cookie_value = format!("{}={}", api_key.var_name, api_key.api_key);
Ok(builder.header(header::COOKIE, cookie_value))
}
other => Err(anyhow!("Unsupported API key location: {}", other)),
}
}
AuthConfig::Basic(basic) => {
Ok(builder.basic_auth(&basic.username, Some(&basic.password)))
}
AuthConfig::OAuth2(_) => Err(anyhow!(
"OAuth2 auth is not yet supported by the MCP transport"
)),
}
}
async fn mcp_http_request(
&self,
prov: &McpProvider,
method: &str,
params: Value,
) -> Result<Value> {
let url = prov
.url
.as_ref()
.ok_or_else(|| anyhow!("No URL provided for HTTP MCP provider"))?;
validate_url_security(url, false)?;
let request = serde_json::json!({
"jsonrpc": "2.0",
"method": method,
"params": params,
"id": 1,
});
let mut req = self.client.post(url).json(&request);
if let Some(headers) = &prov.headers {
for (k, v) in headers {
req = req.header(k, v);
}
}
if let Some(auth) = &prov.base.auth {
req = self.apply_auth(req, auth)?;
}
let response = req.send().await?;
if !response.status().is_success() {
return Err(anyhow!("MCP request failed: {}", response.status()));
}
let body_bytes = response.bytes().await?;
validate_size_limit(&body_bytes, MAX_RESPONSE_SIZE)?;
let result: Value = serde_json::from_slice(&body_bytes)?;
if let Some(error) = result.get("error") {
return Err(anyhow!("MCP error: {}", error));
}
result
.get("result")
.cloned()
.ok_or_else(|| anyhow!("No result in MCP response"))
}
async fn get_or_create_stdio_process(
&self,
prov: &McpProvider,
) -> Result<Arc<McpStdioProcess>> {
let mut processes = self.stdio_processes.lock().await;
if let Some(process) = processes.get(&prov.base.name) {
return Ok(Arc::clone(process));
}
let command = prov
.command
.as_ref()
.ok_or_else(|| anyhow!("No command provided for stdio MCP provider"))?;
let process = Arc::new(McpStdioProcess::new(command, &prov.args, &prov.env_vars).await?);
processes.insert(prov.base.name.clone(), Arc::clone(&process));
Ok(process)
}
async fn mcp_request(&self, prov: &McpProvider, method: &str, params: Value) -> Result<Value> {
if prov.is_http() {
self.mcp_http_request(prov, method, params).await
} else if prov.is_stdio() {
let process = self.get_or_create_stdio_process(prov).await?;
process.send_request(method, params).await
} else {
Err(anyhow!(
"MCP provider must have either 'url' (HTTP) or 'command' (stdio)"
))
}
}
async fn mcp_http_stream(
&self,
prov: &McpProvider,
params: Value,
) -> Result<Box<dyn StreamResult>> {
use eventsource_stream::Eventsource;
use futures::StreamExt;
let url = prov
.url
.as_ref()
.ok_or_else(|| anyhow!("No URL provided for HTTP MCP provider"))?;
validate_url_security(url, false)?;
let request = serde_json::json!({
"jsonrpc": "2.0",
"method": "tools/call",
"params": params,
"id": 1,
});
let mut req = self.client.post(url).json(&request);
if let Some(headers) = &prov.headers {
for (k, v) in headers {
req = req.header(k, v);
}
}
if let Some(auth) = &prov.base.auth {
req = self.apply_auth(req, auth)?;
}
req = req.header("Accept", "text/event-stream");
let response = req.send().await?;
if !response.status().is_success() {
return Err(anyhow!("MCP stream request failed: {}", response.status()));
}
let (tx, rx) = tokio::sync::mpsc::channel(256);
tokio::spawn(async move {
let byte_stream = response.bytes_stream();
let mut event_stream = byte_stream.eventsource();
while let Some(event_result) = event_stream.next().await {
match event_result {
Ok(event) => {
match serde_json::from_str::<Value>(&event.data) {
Ok(value) => {
if tx.send(Ok(value)).await.is_err() {
break; }
}
Err(e) => {
let _ = tx
.send(Err(anyhow!("Failed to parse SSE event: {}", e)))
.await;
break;
}
}
}
Err(e) => {
let _ = tx.send(Err(anyhow!("SSE stream error: {}", e))).await;
break;
}
}
}
});
Ok(crate::transports::stream::boxed_channel_stream(rx, None))
}
async fn mcp_stdio_stream(
&self,
prov: &McpProvider,
params: Value,
) -> Result<Box<dyn StreamResult>> {
let process = self.get_or_create_stdio_process(prov).await?;
let mut id_guard = process.request_id.lock().await;
let id = *id_guard;
*id_guard += 1;
drop(id_guard);
let request = serde_json::json!({
"jsonrpc": "2.0",
"method": "tools/call",
"params": params,
"id": id,
});
let request_str = serde_json::to_string(&request)?;
let mut stdin = process.stdin.lock().await;
stdin.write_all(request_str.as_bytes()).await?;
stdin.write_all(b"\n").await?;
stdin.flush().await?;
drop(stdin);
let (tx, rx) = tokio::sync::mpsc::channel(256);
let stdout = Arc::clone(&process.stdout);
tokio::spawn(async move {
let mut stdout_guard = stdout.lock().await;
loop {
let mut line = String::new();
match stdout_guard.read_line(&mut line).await {
Ok(0) => break, Ok(_) => {
if line.trim().is_empty() {
continue;
}
match serde_json::from_str::<Value>(&line) {
Ok(response) => {
if let Some(error) = response.get("error") {
let _ = tx.send(Err(anyhow!("MCP error: {}", error))).await;
break;
}
if let Some(result) = response.get("result") {
if tx.send(Ok(result.clone())).await.is_err() {
break; }
if response
.get("final")
.and_then(|v| v.as_bool())
.unwrap_or(false)
{
break;
}
}
}
Err(e) => {
let _ = tx
.send(Err(anyhow!("Failed to parse response: {}", e)))
.await;
break;
}
}
}
Err(e) => {
let _ = tx
.send(Err(anyhow!("Failed to read from stdout: {}", e)))
.await;
break;
}
}
}
});
Ok(crate::transports::stream::boxed_channel_stream(rx, None))
}
}
#[async_trait]
impl ClientTransport for McpTransport {
async fn register_tool_provider(&self, _prov: &dyn Provider) -> Result<Vec<Tool>> {
let mcp_prov = _prov
.as_any()
.downcast_ref::<McpProvider>()
.ok_or_else(|| anyhow!("Provider is not an McpProvider"))?;
let params = serde_json::json!({ "cursor": null });
let result = self.mcp_request(mcp_prov, "tools/list", params).await?;
if let Some(tools) = result.get("tools").and_then(|v| v.as_array()) {
let mut parsed = Vec::new();
for tool in tools {
if let Ok(t) = serde_json::from_value::<Tool>(tool.clone()) {
parsed.push(t);
}
}
return Ok(parsed);
}
Ok(vec![])
}
async fn deregister_tool_provider(&self, _prov: &dyn Provider) -> Result<()> {
let mcp_prov = _prov
.as_any()
.downcast_ref::<McpProvider>()
.ok_or_else(|| anyhow!("Provider is not an McpProvider"))?;
if mcp_prov.is_stdio() {
let mut processes = self.stdio_processes.lock().await;
if let Some(process) = processes.remove(&mcp_prov.base.name) {
drop(process);
}
}
Ok(())
}
async fn call_tool(
&self,
tool_name: &str,
args: HashMap<String, Value>,
prov: &dyn Provider,
) -> Result<Value> {
let mcp_prov = prov
.as_any()
.downcast_ref::<McpProvider>()
.ok_or_else(|| anyhow!("Provider is not an McpProvider"))?;
let params = serde_json::json!({
"name": tool_name,
"arguments": args,
});
self.mcp_request(mcp_prov, "tools/call", params).await
}
async fn call_tool_stream(
&self,
tool_name: &str,
args: HashMap<String, Value>,
prov: &dyn Provider,
) -> Result<Box<dyn StreamResult>> {
let mcp_prov = prov
.as_any()
.downcast_ref::<McpProvider>()
.ok_or_else(|| anyhow!("Provider is not an McpProvider"))?;
let params = serde_json::json!({
"name": tool_name,
"arguments": args,
});
if mcp_prov.is_http() {
self.mcp_http_stream(mcp_prov, params).await
} else if mcp_prov.is_stdio() {
self.mcp_stdio_stream(mcp_prov, params).await
} else {
Err(anyhow!(
"MCP provider must have either 'url' (HTTP) or 'command' (stdio)"
))
}
}
}
// Tests for the MCP transport. The HTTP test serves a local axum router on an
// ephemeral port; none of these tests spawn a stdio MCP process.
#[cfg(test)]
mod tests {
use super::*;
use crate::auth::{ApiKeyAuth, AuthType};
use crate::providers::base::{BaseProvider, ProviderType};
use axum::{extract::Json, http::HeaderValue, routing::post, Router};
use bytes::Bytes;
use serde_json::json;
use std::net::TcpListener;
// A header-located API key must be attached as a header named `var_name`.
#[test]
fn apply_auth_adds_expected_headers() {
let transport = McpTransport::new();
let auth = AuthConfig::ApiKey(ApiKeyAuth {
auth_type: AuthType::ApiKey,
api_key: "secret".to_string(),
var_name: "X-MCP".to_string(),
location: "header".to_string(),
});
// Build (not send) the request so its headers can be inspected directly.
let request = transport
.apply_auth(reqwest::Client::new().post("http://example.com"), &auth)
.unwrap()
.build()
.unwrap();
assert_eq!(request.headers().get("X-MCP").unwrap(), "secret");
}
// A provider with neither `url` nor `command` must be rejected with a clear error.
#[tokio::test]
async fn mcp_request_requires_transport_configuration() {
let transport = McpTransport::new();
let prov = McpProvider {
base: BaseProvider {
name: "invalid".to_string(),
provider_type: ProviderType::Mcp,
auth: None,
allowed_communication_protocols: None,
},
url: None,
headers: None,
command: None,
args: None,
env_vars: None,
};
let err = transport
.mcp_request(&prov, "ping", Value::Null)
.await
.unwrap_err();
assert!(err
.to_string()
.contains("MCP provider must have either 'url' (HTTP) or 'command' (stdio)"));
}
// End-to-end over HTTP: register (tools/list), call (tools/call), then stream via SSE.
#[tokio::test]
async fn register_call_and_stream_mcp_http_transport() {
// JSON-RPC handler mounted at `/`. It answers `tools/list` with one tool and
// echoes `tools/call` params back inside `result`. An SSE-accepting request
// gets a "wrong handler" error, so a misrouted stream call is detectable.
async fn handler(
headers: axum::http::HeaderMap,
Json(payload): Json<Value>,
) -> Json<Value> {
if headers
.get(axum::http::header::ACCEPT)
.and_then(|v| v.to_str().ok())
== Some("text/event-stream")
{
return Json(json!({ "error": "wrong handler" }));
}
let method = payload.get("method").and_then(|v| v.as_str()).unwrap_or("");
match method {
"tools/list" => Json(json!({
"jsonrpc": "2.0",
"result": {
"tools": [{
"name": "echo",
"description": "echo tool",
"inputs": { "type": "object" },
"outputs": { "type": "object" },
"tags": []
}]
},
"id": 1
})),
"tools/call" => {
let params = payload.get("params").cloned().unwrap_or_default();
Json(json!({
"jsonrpc": "2.0",
"result": { "called": params },
"id": 1
}))
}
_ => Json(json!({ "jsonrpc": "2.0", "result": {}, "id": 1 })),
}
}
// SSE handler mounted at `/stream`. It asserts the client asked for an event
// stream, then emits two JSON events.
async fn stream_handler(
headers: axum::http::HeaderMap,
Json(_payload): Json<Value>,
) -> impl axum::response::IntoResponse {
assert_eq!(
headers.get(axum::http::header::ACCEPT),
Some(&HeaderValue::from_static("text/event-stream"))
);
let stream = tokio_stream::iter(vec![
Ok::<_, std::convert::Infallible>(Bytes::from_static(b"data: {\"idx\":1}\n\n")),
Ok(Bytes::from_static(b"data: {\"idx\":2}\n\n")),
]);
(
[(axum::http::header::CONTENT_TYPE, "text/event-stream")],
axum::body::boxed(axum::body::Body::wrap_stream(stream)),
)
}
let app = Router::new()
.route("/", post(handler))
.route("/stream", post(stream_handler));
// Port 0 lets the OS pick a free port; the server runs in a background task.
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn(async move {
axum::Server::from_tcp(listener)
.unwrap()
.serve(app.into_make_service())
.await
.unwrap();
});
let prov = McpProvider {
base: BaseProvider {
name: "mcp".to_string(),
provider_type: ProviderType::Mcp,
auth: None,
allowed_communication_protocols: None,
},
url: Some(format!("http://{}", addr)),
headers: None,
command: None,
args: None,
env_vars: None,
};
let transport = McpTransport::new();
// Registration should surface the single advertised tool.
let tools = transport
.register_tool_provider(&prov)
.await
.expect("register");
assert_eq!(tools.len(), 1);
assert_eq!(tools[0].name, "echo");
let mut args = HashMap::new();
args.insert("msg".into(), Value::String("hi".into()));
// call_tool returns the JSON-RPC `result`, which the handler built from our params.
let call_value = transport
.call_tool("echo", args.clone(), &prov)
.await
.expect("call");
assert_eq!(
call_value,
json!({ "called": { "name": "echo", "arguments": json!(args) } })
);
// Streaming targets the SSE route; each event's data arrives as parsed JSON.
let stream_prov = McpProvider {
url: Some(format!("http://{}/stream", addr)),
..prov.clone()
};
let mut stream = transport
.call_tool_stream("echo", args, &stream_prov)
.await
.expect("stream");
assert_eq!(stream.next().await.unwrap().unwrap(), json!({"idx":1}));
assert_eq!(stream.next().await.unwrap().unwrap(), json!({"idx":2}));
stream.close().await.unwrap();
}
}