use clap::{ArgAction, Parser};
use rmcp_proxy::{
run_sse_client, run_sse_server,
sse_client::SseClientConfig,
sse_server::{SseServerSettings, StdioServerParameters},
};
use std::{collections::HashMap, env, error::Error, net::SocketAddr, time::Duration};
use tracing::debug;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
// Command-line interface of `mcp-proxy`.
//
// With clap derive, the `///` doc comment on each field becomes that argument's `--help` text,
// so those comments are written for end users. Maintainer notes use `//` so they stay out of the
// help output.
#[derive(Parser)]
#[command(
    name = "mcp-proxy",
    about = "Start the MCP proxy in one of two possible modes: as an SSE or stdio client.",
    long_about = None,
    after_help = "Examples:\n \
        Connect to a remote SSE server:\n \
        mcp-proxy http://localhost:8080/sse\n\n \
        Expose a local stdio server as an SSE server:\n \
        mcp-proxy your-command --sse-port 8080 -e KEY VALUE -e ANOTHER_KEY ANOTHER_VALUE\n \
        mcp-proxy --sse-port 8080 -- your-command --arg1 value1 --arg2 value2\n \
        mcp-proxy --sse-port 8080 -- python mcp_server.py\n \
        mcp-proxy --sse-port 8080 --sse-host 0.0.0.0 -- npx -y @modelcontextprotocol/server-everything
",
    // NOTE: the line above ends the string literal with a real newline; the `",` line must stay
    // at column 0 or its indentation would become part of the help text.
)]
struct Cli {
    /// URL of a remote SSE server (http:// or https://) to connect to, or a command to run as a
    /// local stdio MCP server
    #[arg(env = "SSE_URL")]
    command_or_url: Option<String>,

    /// HTTP header for the SSE client (URL mode only); may be repeated
    #[arg(short = 'H', long = "headers", value_names = ["KEY", "VALUE"], number_of_values = 2)]
    headers: Vec<String>,

    /// Arguments after --, passed to the command; the first one is used as the command or URL
    /// when none was given before --
    #[arg(last = true, allow_hyphen_values = true)]
    args: Vec<String>,

    /// Environment variable to set for the command (command mode only); may be repeated
    #[arg(short = 'e', long = "env", value_names = ["KEY", "VALUE"], number_of_values = 2)]
    env_vars: Vec<String>,

    /// Pass this process's environment variables on to the command (command mode only)
    #[arg(long = "pass-environment", action = ArgAction::SetTrue)]
    pass_environment: bool,

    /// Port for the SSE server to listen on (command mode only)
    #[arg(long = "sse-port", default_value = "0")]
    sse_port: u16,

    /// IP address for the SSE server to bind to (command mode only)
    #[arg(long = "sse-host", default_value = "127.0.0.1")]
    sse_host: String,
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
tracing_subscriber::registry()
.with(EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")))
.with(tracing_subscriber::fmt::layer())
.init();
let mut cli = Cli::parse();
let command_or_url = match cli.command_or_url {
Some(value) => value,
None => match cli.args.len() {
0 => {
eprintln!("Error: command or URL is required");
std::process::exit(1);
}
_ => cli.args.remove(0),
},
};
if command_or_url.starts_with("http://") || command_or_url.starts_with("https://") {
debug!("Starting SSE client and stdio server");
let mut headers = HashMap::new();
for i in (0..cli.headers.len()).step_by(2) {
if i + 1 < cli.headers.len() {
headers.insert(cli.headers[i].clone(), cli.headers[i + 1].clone());
}
}
let config = SseClientConfig {
url: command_or_url,
headers,
};
run_sse_client(config).await?;
} else {
debug!("Starting stdio client and SSE server");
let mut env_map = HashMap::new();
if cli.pass_environment {
for (key, value) in env::vars() {
env_map.insert(key, value);
}
}
for i in (0..cli.env_vars.len()).step_by(2) {
if i + 1 < cli.env_vars.len() {
env_map.insert(cli.env_vars[i].clone(), cli.env_vars[i + 1].clone());
}
}
let stdio_params = StdioServerParameters {
command: command_or_url,
args: cli.args,
env: env_map,
};
let sse_settings = SseServerSettings {
bind_addr: format!("{}:{}", cli.sse_host, cli.sse_port).parse::<SocketAddr>()?,
keep_alive: Some(Duration::from_secs(15)),
};
run_sse_server(stdio_params, sse_settings).await?;
}
Ok(())
}