use anyhow::{Context, Result};
use clap::Parser;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::collections::HashMap;
use tokio::io::{self, AsyncBufReadExt, AsyncWriteExt, BufReader};
use futures_util::StreamExt;
// Command-line arguments for the proxy.
// NOTE: plain `//` comments on purpose — clap derive turns `///` doc comments
// into `--help` text, which would change the program's output.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
// Remote server URL that every JSON-RPC message is POSTed to (`-u` / `--url`, required).
#[arg(short, long)]
url: String,
// Optional API key; when present it is sent as `Authorization: Bearer <key>` (`--api-key`).
#[arg(long)]
api_key: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct JsonRpcRequest {
jsonrpc: String,
method: String,
params: Option<Value>,
id: Option<Value>,
}
/// A JSON-RPC 2.0 response written back to the client on stdout.
///
/// Instances are either deserialized from the remote server's reply or built
/// locally to report proxy-side failures. `result` and `error` are left out of
/// the serialized JSON when `None`. `id` has no skip attribute, so it is always
/// written (as `null` when `None`).
#[derive(Debug, Clone, Serialize, Deserialize)]
struct JsonRpcResponse {
/// Protocol version; responses built locally in this file use "2.0".
jsonrpc: String,
/// Successful result payload, if any.
#[serde(skip_serializing_if = "Option::is_none")]
result: Option<Value>,
/// Error object, if the call failed.
#[serde(skip_serializing_if = "Option::is_none")]
error: Option<JsonRpcError>,
/// Id of the request this answers; serialized as `null` when `None`.
id: Option<Value>,
}
/// The `error` member of a JSON-RPC 2.0 response.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct JsonRpcError {
/// Numeric error code. Errors built by this proxy use JSON-RPC codes such as
/// -32603 or -32700, or the HTTP status code when the remote replied non-2xx.
code: i32,
/// Short human-readable description of the error.
message: String,
/// Optional extra detail (e.g. the remote error body); omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
data: Option<Value>,
}
/// Relays JSON-RPC messages read from stdio to a remote server over HTTP POST.
struct McpProxy {
/// HTTP client reused for every forwarded request.
client: Client,
/// Endpoint that every message is POSTed to.
remote_url: String,
/// Optional token, sent as `Authorization: Bearer <api_key>` when present.
api_key: Option<String>,
}
impl McpProxy {
    /// Creates a proxy that POSTs every message to `remote_url`, authenticating
    /// with `api_key` as a Bearer token when one is given.
    fn new(remote_url: String, api_key: Option<String>) -> Self {
        Self {
            client: Client::new(),
            remote_url,
            api_key,
        }
    }

    /// Builds a JSON-RPC 2.0 error response.
    fn rpc_error(
        code: i32,
        message: String,
        data: Option<Value>,
        id: Option<Value>,
    ) -> JsonRpcResponse {
        JsonRpcResponse {
            jsonrpc: "2.0".to_string(),
            result: None,
            error: Some(JsonRpcError {
                code,
                message,
                data,
            }),
            id,
        }
    }

    /// Forwards one JSON-RPC message to the remote server and collects the
    /// JSON-RPC responses it sends back.
    ///
    /// The reply body may be a single JSON object or a `text/event-stream`.
    /// A non-2xx status is reported as a JSON-RPC error response whose code is
    /// the HTTP status and whose `data` is the response body text.
    ///
    /// An empty (or all-whitespace) body yields an empty vector for a
    /// notification (no `id`). This is how a server acknowledges one, typically
    /// with `202 Accepted`, and JSON-RPC forbids answering notifications. For a
    /// request with an id, an empty body becomes a `-32603` error response so
    /// the client is not left waiting.
    ///
    /// # Errors
    /// Fails if the HTTP request cannot be sent or its body cannot be read, or
    /// if a non-empty, non-event-stream body is not a valid JSON-RPC response.
    async fn forward_request(&self, request: JsonRpcRequest) -> Result<Vec<JsonRpcResponse>> {
        let mut headers = HashMap::new();
        headers.insert("Content-Type".to_string(), "application/json".to_string());
        headers.insert(
            "Accept".to_string(),
            "application/json, text/event-stream".to_string(),
        );
        if let Some(api_key) = &self.api_key {
            headers.insert("Authorization".to_string(), format!("Bearer {}", api_key));
        }
        let mut req_builder = self.client.post(&self.remote_url);
        for (key, value) in headers {
            req_builder = req_builder.header(&key, &value);
        }
        let response = req_builder
            .json(&request)
            .send()
            .await
            .context("Failed to send request to remote server")?;

        let status = response.status();
        if !status.is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Ok(vec![Self::rpc_error(
                i32::from(status.as_u16()),
                format!("Remote server error: {}", status),
                Some(json!(error_text)),
                request.id,
            )]);
        }

        let content_type = response
            .headers()
            .get("content-type")
            .and_then(|v| v.to_str().ok())
            .unwrap_or("");
        if content_type.starts_with("text/event-stream") {
            return self.handle_event_stream_response(response, request.id).await;
        }

        let body = response
            .bytes()
            .await
            .context("Failed to read response body from remote server")?;
        // An empty body is an acknowledgement (e.g. `202 Accepted` for a
        // notification); parsing it as JSON would fail.
        if body.iter().all(u8::is_ascii_whitespace) {
            return Ok(match request.id {
                None => Vec::new(),
                Some(id) => vec![Self::rpc_error(
                    -32603,
                    "Empty response from remote server".to_string(),
                    None,
                    Some(id),
                )],
            });
        }
        let json_response: JsonRpcResponse = serde_json::from_slice(&body)
            .context("Failed to parse response from remote server")?;
        Ok(vec![json_response])
    }

    /// Reads a `text/event-stream` body to its end and converts each event's
    /// `data` payload into a JSON-RPC response.
    ///
    /// Bytes are buffered undecoded until a whole event has arrived. A UTF-8
    /// character split across network chunks is therefore decoded intact rather
    /// than replaced with U+FFFD. If nothing usable arrives for a request that
    /// has an id, a single `-32603` error response is returned instead.
    ///
    /// # Errors
    /// Fails if reading a chunk from the stream fails.
    async fn handle_event_stream_response(
        &self,
        response: reqwest::Response,
        request_id: Option<Value>,
    ) -> Result<Vec<JsonRpcResponse>> {
        let mut responses = Vec::new();
        let mut stream = response.bytes_stream();
        let mut buffer: Vec<u8> = Vec::new();
        while let Some(chunk) = stream.next().await {
            let chunk = chunk.context("Failed to read chunk from event stream")?;
            buffer.extend_from_slice(&chunk);
            while let Some((event_len, consumed)) = Self::find_event_boundary(&buffer) {
                let event: Vec<u8> = buffer.drain(..consumed).collect();
                let event_data = String::from_utf8_lossy(&event[..event_len]);
                if let Some(json_response) =
                    self.parse_sse_event(&event_data, request_id.clone())?
                {
                    responses.push(json_response);
                }
            }
        }
        // The last event may lack its terminating blank line if the server
        // closed the stream right after sending it.
        let tail = String::from_utf8_lossy(&buffer);
        if !tail.trim().is_empty() {
            if let Some(json_response) = self.parse_sse_event(&tail, request_id.clone())? {
                responses.push(json_response);
            }
        }
        // Notifications must never be answered, even with an error.
        if responses.is_empty() && request_id.is_some() {
            responses.push(Self::rpc_error(
                -32603,
                "No valid responses in event stream".to_string(),
                None,
                request_id,
            ));
        }
        Ok(responses)
    }

    /// Locates the first complete SSE event in `buf`; events end at a blank line.
    ///
    /// Returns `(event_len, consumed)`. The event text is `buf[..event_len]`,
    /// including its final line terminator, and `consumed` also covers the blank
    /// line. Both `\n` and `\r\n` line endings are recognised; bare `\r` line
    /// endings are not.
    fn find_event_boundary(buf: &[u8]) -> Option<(usize, usize)> {
        buf.iter().enumerate().find_map(|(i, &byte)| {
            if byte != b'\n' {
                return None;
            }
            match (buf.get(i + 1).copied(), buf.get(i + 2).copied()) {
                (Some(b'\n'), _) => Some((i + 1, i + 2)),
                (Some(b'\r'), Some(b'\n')) => Some((i + 1, i + 3)),
                _ => None,
            }
        })
    }

    /// Extracts the `data` payload of one SSE event and turns it into a
    /// JSON-RPC response.
    ///
    /// Returns `Ok(None)` in two cases:
    /// - the event has no non-blank `data` (comments, priming events);
    /// - the payload is a server-initiated request or notification (it has a
    ///   `method` member), which a `JsonRpcResponse` cannot represent.
    ///
    /// A response without an `id` inherits `request_id`. A payload that does not
    /// parse as a response is passed through as a string `result`.
    fn parse_sse_event(
        &self,
        event_data: &str,
        request_id: Option<Value>,
    ) -> Result<Option<JsonRpcResponse>> {
        let mut data_lines = Vec::new();
        for line in event_data.lines() {
            let line = line.trim();
            if line.is_empty() || line.starts_with(':') {
                continue;
            }
            // SSE field syntax: the name ends at the first ':' and one optional
            // space follows it, so `data:{..}` and `data: {..}` are equivalent.
            let (field, value) = line.split_once(':').unwrap_or((line, ""));
            if field == "data" {
                data_lines.push(value.strip_prefix(' ').unwrap_or(value));
            }
        }
        let json_data = data_lines.join("\n");
        if json_data.trim().is_empty() {
            return Ok(None);
        }

        // Server-to-client messages (e.g. progress notifications) share the
        // stream with the response. Every `JsonRpcResponse` field except
        // `jsonrpc` is optional, so they would deserialize "successfully" and be
        // given this request's id, producing a bogus early response.
        // NOTE(review): relaying them instead would need a raw-message return type.
        let is_server_message = serde_json::from_str::<Value>(&json_data)
            .map(|message| message.get("method").is_some())
            .unwrap_or(false);
        if is_server_message {
            eprintln!(
                "Dropping server-initiated message from event stream: {}",
                json_data
            );
            return Ok(None);
        }

        match serde_json::from_str::<JsonRpcResponse>(&json_data) {
            Ok(mut response) => {
                if response.id.is_none() {
                    response.id = request_id;
                }
                Ok(Some(response))
            }
            Err(_) => Ok(Some(JsonRpcResponse {
                jsonrpc: "2.0".to_string(),
                result: Some(json!(json_data)),
                error: None,
                id: request_id,
            })),
        }
    }
}
#[tokio::main]
async fn main() -> Result<()> {
let args = Args::parse();
let proxy = McpProxy::new(args.url, args.api_key);
let stdin = io::stdin();
let mut stdout = io::stdout();
let mut reader = BufReader::new(stdin);
let mut line = String::new();
eprintln!("MCP Proxy started, forwarding to: {}", proxy.remote_url);
loop {
line.clear();
match reader.read_line(&mut line).await {
Ok(0) => {
eprintln!("EOF reached, shutting down");
break;
}
Ok(_) => {
let trimmed = line.trim();
if trimmed.is_empty() {
continue;
}
match serde_json::from_str::<JsonRpcRequest>(trimmed) {
Ok(request) => {
eprintln!("Forwarding request: {} (id: {:?})", request.method, request.id);
match proxy.forward_request(request).await {
Ok(responses) => {
for response in responses {
let response_json = serde_json::to_string(&response)
.context("Failed to serialize response")?;
stdout.write_all(response_json.as_bytes()).await
.context("Failed to write response to stdout")?;
stdout.write_all(b"\n").await
.context("Failed to write newline to stdout")?;
stdout.flush().await
.context("Failed to flush stdout")?;
eprintln!("Response sent for id: {:?}", response.id);
}
}
Err(e) => {
eprintln!("Error forwarding request: {}", e);
let error_response = JsonRpcResponse {
jsonrpc: "2.0".to_string(),
result: None,
error: Some(JsonRpcError {
code: -32603,
message: "Internal error".to_string(),
data: Some(json!(e.to_string())),
}),
id: None, };
let error_json = serde_json::to_string(&error_response)
.context("Failed to serialize error response")?;
stdout.write_all(error_json.as_bytes()).await
.context("Failed to write error response to stdout")?;
stdout.write_all(b"\n").await
.context("Failed to write newline to stdout")?;
stdout.flush().await
.context("Failed to flush stdout")?;
}
}
}
Err(e) => {
eprintln!("Failed to parse JSON-RPC request: {}", e);
eprintln!("Invalid input: {}", trimmed);
let parse_error = JsonRpcResponse {
jsonrpc: "2.0".to_string(),
result: None,
error: Some(JsonRpcError {
code: -32700,
message: "Parse error".to_string(),
data: Some(json!(e.to_string())),
}),
id: None,
};
let error_json = serde_json::to_string(&parse_error)
.context("Failed to serialize parse error response")?;
stdout.write_all(error_json.as_bytes()).await
.context("Failed to write parse error response to stdout")?;
stdout.write_all(b"\n").await
.context("Failed to write newline to stdout")?;
stdout.flush().await
.context("Failed to flush stdout")?;
}
}
}
Err(e) => {
eprintln!("Error reading from stdin: {}", e);
break;
}
}
}
Ok(())
}