use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use serde_json::Value;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tracing::{debug, error, trace, warn};
use turbomcp_protocol::jsonrpc::{
JsonRpcError, JsonRpcErrorCode, JsonRpcRequest, JsonRpcResponse, JsonRpcResponsePayload,
ResponseId,
};
use turbomcp_protocol::types::RequestId;
use turbomcp_protocol::types::{CallToolRequest, GetPromptRequest, ReadResourceRequest};
use turbomcp_protocol::{Error as McpError, Result as McpResult};
use crate::config::{BackendConfig, BackendValidationConfig, FrontendType, SsrfProtection};
use crate::error::{ProxyError, ProxyResult};
use crate::proxy::{AtomicMetrics, BackendConnector, BackendTransport, ProxyService};
use ipnetwork::IpNetwork;
/// Maximum accepted size for a single request, in bytes (10 MiB).
pub const MAX_REQUEST_SIZE: usize = 10 * 1024 * 1024;
/// Default per-request timeout in milliseconds (30 seconds).
pub const DEFAULT_TIMEOUT_MS: u64 = 30_000;
/// Upper bound callers may configure via `with_timeout` (5 minutes).
pub const MAX_TIMEOUT_MS: u64 = 300_000;
/// Allowlist of interpreter/runtime commands permitted for stdio backends;
/// anything else is rejected by `RuntimeProxyBuilder::validate_command`.
pub const ALLOWED_COMMANDS: &[&str] = &["python", "python3", "node", "deno", "uv", "npx", "bun"];
/// Default bind address used when none is supplied (loopback only).
pub const DEFAULT_BIND_ADDRESS: &str = "127.0.0.1:3000";
/// Builder that accumulates and validates configuration for a [`RuntimeProxy`].
#[derive(Debug)]
pub struct RuntimeProxyBuilder {
// Backend to proxy to; must be set before `build()` succeeds.
backend_config: Option<BackendConfig>,
// Frontend transport; must be set before `build()` succeeds.
frontend_type: Option<FrontendType>,
// Bind address for network frontends; defaults to DEFAULT_BIND_ADDRESS.
bind_address: Option<String>,
// Per-request size cap in bytes.
request_size_limit: usize,
// Per-request timeout in milliseconds (capped at MAX_TIMEOUT_MS).
timeout_ms: u64,
// When true, `build()` attaches an AtomicMetrics collector.
enable_metrics: bool,
// Scheme/host/SSRF validation policy applied to backend URLs.
validation_config: BackendValidationConfig,
}
impl RuntimeProxyBuilder {
/// Creates a builder preloaded with library defaults: loopback bind address,
/// 10 MiB request cap, 30 s timeout, metrics enabled, default validation.
#[must_use]
pub fn new() -> Self {
    Self {
        backend_config: None,
        frontend_type: None,
        bind_address: Some(String::from(DEFAULT_BIND_ADDRESS)),
        request_size_limit: MAX_REQUEST_SIZE,
        timeout_ms: DEFAULT_TIMEOUT_MS,
        enable_metrics: true,
        validation_config: BackendValidationConfig::default(),
    }
}
#[must_use]
pub fn with_stdio_backend(mut self, command: impl Into<String>, args: Vec<String>) -> Self {
self.backend_config = Some(BackendConfig::Stdio {
command: command.into(),
args,
working_dir: None,
});
self
}
#[must_use]
pub fn with_stdio_backend_and_dir(
mut self,
command: impl Into<String>,
args: Vec<String>,
working_dir: impl Into<String>,
) -> Self {
self.backend_config = Some(BackendConfig::Stdio {
command: command.into(),
args,
working_dir: Some(working_dir.into()),
});
self
}
#[must_use]
pub fn with_http_backend(mut self, url: impl Into<String>, auth_token: Option<String>) -> Self {
self.backend_config = Some(BackendConfig::Http {
url: url.into(),
auth_token,
});
self
}
#[must_use]
pub fn with_websocket_backend(mut self, url: impl Into<String>) -> Self {
self.backend_config = Some(BackendConfig::WebSocket { url: url.into() });
self
}
#[must_use]
pub fn with_tcp_backend(mut self, host: impl Into<String>, port: u16) -> Self {
self.backend_config = Some(BackendConfig::Tcp {
host: host.into(),
port,
});
self
}
#[cfg(unix)]
#[must_use]
pub fn with_unix_backend(mut self, path: impl Into<String>) -> Self {
self.backend_config = Some(BackendConfig::Unix { path: path.into() });
self
}
#[must_use]
pub fn with_http_frontend(mut self, bind: impl Into<String>) -> Self {
self.frontend_type = Some(FrontendType::Http);
self.bind_address = Some(bind.into());
self
}
#[must_use]
pub fn with_stdio_frontend(mut self) -> Self {
self.frontend_type = Some(FrontendType::Stdio);
self
}
#[must_use]
pub fn with_websocket_frontend(mut self, bind: impl Into<String>) -> Self {
self.frontend_type = Some(FrontendType::WebSocket);
self.bind_address = Some(bind.into());
self
}
#[must_use]
pub fn with_request_size_limit(mut self, limit: usize) -> Self {
self.request_size_limit = limit;
self
}
pub fn with_timeout(mut self, timeout_ms: u64) -> ProxyResult<Self> {
if timeout_ms > MAX_TIMEOUT_MS {
return Err(ProxyError::configuration_with_key(
format!("Timeout {timeout_ms}ms exceeds maximum {MAX_TIMEOUT_MS}ms"),
"timeout_ms",
));
}
self.timeout_ms = timeout_ms;
Ok(self)
}
#[must_use]
pub fn with_metrics(mut self, enable: bool) -> Self {
self.enable_metrics = enable;
self
}
#[must_use]
pub fn with_backend_validation(mut self, config: BackendValidationConfig) -> Self {
self.validation_config = config;
self
}
pub async fn build(self) -> ProxyResult<RuntimeProxy> {
let backend_config = self
.backend_config
.as_ref()
.ok_or_else(|| ProxyError::configuration("Backend configuration is required"))?;
let frontend_type = self
.frontend_type
.ok_or_else(|| ProxyError::configuration("Frontend type is required"))?;
Self::validate_command(backend_config)?;
Self::validate_url(backend_config, &self.validation_config).await?;
Self::validate_working_dir(backend_config)?;
let backend_config = self.backend_config.unwrap();
let transport = match &backend_config {
BackendConfig::Stdio {
command,
args,
working_dir,
} => BackendTransport::Stdio {
command: command.clone(),
args: args.clone(),
working_dir: working_dir.clone(),
},
BackendConfig::Http { url, auth_token } => BackendTransport::Http {
url: url.clone(),
auth_token: auth_token.clone(),
},
BackendConfig::Tcp { host, port } => BackendTransport::Tcp {
host: host.clone(),
port: *port,
},
#[cfg(unix)]
BackendConfig::Unix { path } => BackendTransport::Unix { path: path.clone() },
BackendConfig::WebSocket { url } => BackendTransport::WebSocket { url: url.clone() },
};
let connector_config = crate::proxy::backend::BackendConfig {
transport,
client_name: "turbomcp-proxy".to_string(),
client_version: crate::VERSION.to_string(),
};
let backend = BackendConnector::new(connector_config).await?;
let metrics = if self.enable_metrics {
Some(Arc::new(AtomicMetrics::new()))
} else {
None
};
Ok(RuntimeProxy {
backend,
frontend_type,
bind_address: self.bind_address,
request_size_limit: self.request_size_limit,
timeout_ms: self.timeout_ms,
metrics,
})
}
/// Rejects stdio backends whose command is not on [`ALLOWED_COMMANDS`];
/// every non-stdio backend passes unchanged.
fn validate_command(config: &BackendConfig) -> ProxyResult<()> {
    match config {
        BackendConfig::Stdio { command, .. }
            if !ALLOWED_COMMANDS.contains(&command.as_str()) =>
        {
            Err(ProxyError::configuration_with_key(
                format!("Command '{command}' not in allowlist. Allowed: {ALLOWED_COMMANDS:#?}"),
                "command",
            ))
        }
        _ => Ok(()),
    }
}
/// Validates the URL of an HTTP or WebSocket backend: scheme allowlist,
/// secure-transport requirement for non-localhost hosts, and host-level
/// SSRF checks. Backends without a URL (stdio/tcp/unix) pass trivially.
async fn validate_url(
config: &BackendConfig,
validation_config: &BackendValidationConfig,
) -> ProxyResult<()> {
// Only URL-based backends need validation; everything else is a no-op.
let (BackendConfig::Http { url: url_str, .. } | BackendConfig::WebSocket { url: url_str }) =
config
else {
return Ok(()); };
let parsed = url::Url::parse(url_str)
.map_err(|e| ProxyError::configuration_with_key(format!("Invalid URL: {e}"), "url"))?;
// Scheme must be on the configured allowlist (e.g. https/wss).
if !validation_config
.allowed_schemes
.contains(&parsed.scheme().to_string())
{
return Err(ProxyError::configuration_with_key(
format!(
"Scheme '{}' not allowed. Allowed schemes: {}",
parsed.scheme(),
validation_config.allowed_schemes.join(", ")
),
"url",
));
}
// Plaintext http/ws is tolerated for localhost targets only; anything
// else must use the TLS-secured scheme.
if parsed.scheme() == "http" || parsed.scheme() == "ws" {
let host = parsed.host_str().unwrap_or("");
if !is_localhost(host) {
let secure_scheme = if parsed.scheme() == "http" {
"https"
} else {
"wss"
};
return Err(ProxyError::configuration_with_key(
format!(
"Secure protocol required for non-localhost URLs. Use {} instead of {}",
secure_scheme,
parsed.scheme()
),
"url",
));
}
}
// Host-level checks (blocklist + SSRF policy) need a concrete port for
// DNS resolution; `port_or_known_default` covers the standard schemes.
if let Some(host) = parsed.host_str() {
let port = parsed.port_or_known_default().ok_or_else(|| {
ProxyError::configuration_with_key(
format!(
"URL is missing a usable port for scheme '{}'",
parsed.scheme()
),
"url",
)
})?;
Self::validate_host(host, port, validation_config).await?;
}
Ok(())
}
/// Applies the custom host blocklist, then dispatches to the configured
/// SSRF-protection policy for `host:port`.
async fn validate_host(
    host: &str,
    port: u16,
    validation_config: &BackendValidationConfig,
) -> ProxyResult<()> {
    // The explicit blocklist wins regardless of the SSRF policy in force.
    let blocked = validation_config.blocked_hosts.contains(&host.to_string());
    if blocked {
        return Err(ProxyError::configuration_with_key(
            format!("Host '{host}' is blocked by custom blocklist"),
            "url",
        ));
    }
    match &validation_config.ssrf_protection {
        SsrfProtection::Strict => Self::validate_host_strict(host, port).await,
        SsrfProtection::Balanced {
            allowed_private_networks,
        } => Self::validate_host_balanced(host, port, allowed_private_networks).await,
        SsrfProtection::Disabled => {
            warn!("SSRF protection disabled for host: {}", host);
            Ok(())
        }
    }
}
/// Strict SSRF policy: blocks cloud metadata endpoints and every private or
/// link-local address. Loopback is explicitly allowed (for local dev).
/// Hostnames are resolved and each resolved address is checked.
async fn validate_host_strict(host: &str, port: u16) -> ProxyResult<()> {
if Self::is_cloud_metadata_endpoint(host) {
return Err(ProxyError::configuration_with_key(
format!(
"Cloud metadata endpoint blocked: {host}. \
For internal proxies, use SsrfProtection::Balanced with allowed networks."
),
"url",
));
}
// IPv6 literals arrive bracketed from the URL parser; strip before parsing.
let host_without_brackets = host.trim_start_matches('[').trim_end_matches(']');
// IPv4 literal: decide without DNS.
if let Ok(ip) = host_without_brackets.parse::<Ipv4Addr>() {
if ip.is_loopback() {
return Ok(()); }
if ip.is_private() || ip.is_link_local() {
return Err(ProxyError::configuration_with_key(
format!(
"Private IPv4 address blocked: {ip}. \
For internal proxies, configure:\n \
SsrfProtection::Balanced {{ \
allowed_private_networks: vec![IpNetwork::from_str(\"10.0.0.0/8\")?] }}"
),
"url",
));
}
return Ok(());
}
// IPv6 literal: same treatment, with fc00::/7 + fe80::/10 as "private".
if let Ok(ip) = host_without_brackets.parse::<Ipv6Addr>() {
if ip.is_loopback() {
return Ok(()); }
let is_private = Self::is_private_ipv6(&ip);
if is_private {
return Err(ProxyError::configuration_with_key(
format!(
"Private IPv6 address blocked: {ip}. \
For internal proxies, configure:\n \
SsrfProtection::Balanced {{ \
allowed_private_networks: vec![IpNetwork::from_str(\"fc00::/7\")?] }}"
),
"url",
));
}
return Ok(());
}
// Not a literal: resolve the hostname and apply the same rules to every
// address it maps to (guards against DNS pointing at internal ranges).
Self::validate_hostname_resolution(host, port, |_host, ip| match ip {
IpAddr::V4(ipv4) => {
if ipv4.is_loopback() {
Ok(())
} else if ipv4.is_private() || ipv4.is_link_local() {
Err(ProxyError::configuration_with_key(
format!("Resolved private IPv4 address blocked: {ipv4}"),
"url",
))
} else {
Ok(())
}
}
IpAddr::V6(ipv6) => {
if ipv6.is_loopback() {
Ok(())
} else if Self::is_private_ipv6(&ipv6) {
Err(ProxyError::configuration_with_key(
format!("Resolved private IPv6 address blocked: {ipv6}"),
"url",
))
} else {
Ok(())
}
}
})
.await
}
/// Balanced SSRF policy: metadata endpoints are always blocked, loopback is
/// always allowed, and private addresses pass only when covered by one of
/// `allowed_networks`. Hostnames are resolved and every address is checked.
async fn validate_host_balanced(
    host: &str,
    port: u16,
    allowed_networks: &[IpNetwork],
) -> ProxyResult<()> {
    if Self::is_cloud_metadata_endpoint(host) {
        return Err(ProxyError::configuration_with_key(
            format!("Cloud metadata endpoint blocked: {host}"),
            "url",
        ));
    }
    // IPv6 literals arrive bracketed from the URL parser; strip before parsing.
    let bare = host.trim_start_matches('[').trim_end_matches(']');
    // `IpAddr: FromStr` accepts exactly the v4-or-v6 literals the original
    // two-step parse accepted.
    match bare.parse::<IpAddr>() {
        Ok(ip) => Self::validate_ip_balanced(ip, allowed_networks),
        Err(_) => {
            Self::validate_hostname_resolution(host, port, |_host, ip| {
                Self::validate_ip_balanced(ip, allowed_networks)
            })
            .await
        }
    }
}
/// Per-address decision for the balanced policy: loopback always passes,
/// public addresses pass, and private/link-local addresses must fall inside
/// one of the allowed networks.
fn validate_ip_balanced(ip: IpAddr, allowed_networks: &[IpNetwork]) -> ProxyResult<()> {
    let loopback = match ip {
        IpAddr::V4(v4) => v4.is_loopback(),
        IpAddr::V6(v6) => v6.is_loopback(),
    };
    if loopback {
        return Ok(());
    }
    let private = match ip {
        IpAddr::V4(v4) => v4.is_private() || v4.is_link_local(),
        IpAddr::V6(v6) => Self::is_private_ipv6(&v6),
    };
    if !private {
        return Ok(());
    }
    if allowed_networks.iter().any(|net| net.contains(ip)) {
        debug!("Private IP {} allowed by configured network", ip);
        Ok(())
    } else {
        Err(ProxyError::configuration_with_key(
            format!(
                "Private IP {ip} not in allowed networks. Allowed networks: {allowed_networks:?}"
            ),
            "url",
        ))
    }
}
/// Resolves `host:port` through the system resolver and runs `validate_ip`
/// on every resolved address, failing fast on the first rejection. A name
/// that resolves to zero addresses is reported as an error.
async fn validate_hostname_resolution<F>(
    host: &str,
    port: u16,
    mut validate_ip: F,
) -> ProxyResult<()>
where
    F: FnMut(&str, IpAddr) -> ProxyResult<()>,
{
    let addrs = tokio::net::lookup_host((host, port)).await.map_err(|e| {
        ProxyError::configuration_with_key(
            format!("Failed to resolve host '{host}': {e}"),
            "url",
        )
    })?;
    let mut checked = 0usize;
    for socket_addr in addrs {
        validate_ip(host, socket_addr.ip())?;
        checked += 1;
    }
    if checked == 0 {
        return Err(ProxyError::configuration_with_key(
            format!("Host '{host}' resolved to no addresses"),
            "url",
        ));
    }
    Ok(())
}
/// Returns true for well-known cloud metadata endpoints: the link-local
/// metadata IP used by AWS/GCP, the Azure wire-server IP, and the GCP
/// metadata hostnames. These are blocked under strict and balanced policies.
fn is_cloud_metadata_endpoint(host: &str) -> bool {
    matches!(
        host,
        "169.254.169.254" | "168.63.129.16" | "metadata.google.internal" | "metadata"
    )
}
/// Returns true for IPv6 addresses in fc00::/7 (unique-local) or
/// fe80::/10 (link-local unicast); both are treated as "private" for SSRF.
fn is_private_ipv6(ip: &Ipv6Addr) -> bool {
    let first_segment = ip.segments()[0];
    let unique_local = first_segment & 0xfe00 == 0xfc00;
    let link_local = first_segment & 0xffc0 == 0xfe80;
    unique_local || link_local
}
/// For stdio backends with an explicit working directory, verifies the path
/// exists, canonicalizes cleanly, and is a directory. Every other backend
/// configuration passes unchanged.
fn validate_working_dir(config: &BackendConfig) -> ProxyResult<()> {
    let BackendConfig::Stdio {
        working_dir: Some(wd),
        ..
    } = config
    else {
        return Ok(());
    };
    let path = PathBuf::from(wd);
    if !path.exists() {
        return Err(ProxyError::configuration_with_key(
            format!("Working directory does not exist: {wd}"),
            "working_dir",
        ));
    }
    // Canonicalize to resolve symlinks and `..` before the directory check.
    let canonical = path.canonicalize().map_err(|e| {
        ProxyError::configuration_with_key(
            format!("Failed to canonicalize working directory: {e}"),
            "working_dir",
        )
    })?;
    if canonical.is_dir() {
        Ok(())
    } else {
        Err(ProxyError::configuration_with_key(
            format!("Working directory is not a directory: {wd}"),
            "working_dir",
        ))
    }
}
}
impl Default for RuntimeProxyBuilder {
/// Equivalent to [`RuntimeProxyBuilder::new`].
fn default() -> Self {
Self::new()
}
}
/// True when `host` is a loopback spelling accepted for plaintext schemes:
/// `localhost`, the IPv4 loopback, or the bare/bracketed IPv6 loopback.
fn is_localhost(host: &str) -> bool {
    ["localhost", "127.0.0.1", "::1", "[::1]"].contains(&host)
}
/// A configured proxy instance: a connected backend plus the frontend
/// settings produced by [`RuntimeProxyBuilder::build`]. Drive it with `run()`.
#[derive(Debug)]
pub struct RuntimeProxy {
// Connected backend used to service forwarded requests.
backend: BackendConnector,
// Which frontend transport `run()` starts.
frontend_type: FrontendType,
// Bind address for network frontends; unused for stdio.
bind_address: Option<String>,
// Per-request size cap in bytes.
request_size_limit: usize,
// Per-request timeout in milliseconds.
timeout_ms: u64,
// Present only when metrics were enabled at build time.
metrics: Option<Arc<AtomicMetrics>>,
}
impl RuntimeProxy {
/// Starts the configured frontend and serves until it shuts down.
///
/// # Errors
///
/// Returns a configuration error when a network frontend has no bind
/// address, and propagates any serve/bind/backend failure.
pub async fn run(&mut self) -> ProxyResult<()> {
    match self.frontend_type {
        FrontendType::Stdio => self.run_stdio().await,
        FrontendType::Http => {
            let bind = self.bind_address.clone().ok_or_else(|| {
                ProxyError::configuration("Bind address required for HTTP frontend")
            })?;
            self.run_http(&bind).await
        }
        FrontendType::WebSocket => {
            let bind = self.bind_address.clone().ok_or_else(|| {
                ProxyError::configuration("Bind address required for WebSocket frontend")
            })?;
            self.run_websocket(&bind).await
        }
    }
}
/// Returns a reference to the connected backend.
#[must_use]
pub fn backend(&self) -> &BackendConnector {
&self.backend
}
/// Returns a point-in-time snapshot of proxy metrics, or `None` when
/// metrics collection was disabled at build time.
#[must_use]
pub fn metrics(&self) -> Option<crate::proxy::metrics::ProxyMetrics> {
self.metrics.as_ref().map(|m| m.snapshot())
}
/// Serves the proxy over HTTP: introspects the backend once, builds the
/// axum router with body-size and timeout layers, then serves on `bind`.
async fn run_http(&mut self, bind: &str) -> ProxyResult<()> {
    use axum::{Router, http::StatusCode};
    use std::time::Duration;
    use tower_http::limit::RequestBodyLimitLayer;
    use tower_http::timeout::TimeoutLayer;
    use turbomcp_transport::axum::AxumMcpExt;
    debug!("Starting HTTP frontend on {}", bind);
    // One-time capability discovery; the result seeds the routing service.
    let spec = self.backend.introspect().await?;
    debug!(
        "Backend introspection complete: {} tools, {} resources, {} prompts",
        spec.tools.len(),
        spec.resources.len(),
        spec.prompts.len()
    );
    let body_limit = RequestBodyLimitLayer::new(self.request_size_limit);
    let request_timeout = TimeoutLayer::with_status_code(
        StatusCode::REQUEST_TIMEOUT,
        Duration::from_millis(self.timeout_ms),
    );
    let router = Router::new()
        .turbo_mcp_routes(ProxyService::new(self.backend.clone(), spec))
        .layer(body_limit)
        .layer(request_timeout);
    let listener = tokio::net::TcpListener::bind(bind)
        .await
        .map_err(|e| ProxyError::backend_connection(format!("Failed to bind to {bind}: {e}")))?;
    debug!("HTTP frontend listening on {}", bind);
    axum::serve(listener, router)
        .await
        .map_err(|e| ProxyError::backend(format!("Axum serve error: {e}")))?;
    Ok(())
}
/// Serves the proxy for the WebSocket frontend on `bind`.
///
/// NOTE(review): this body is identical to `run_http` — presumably
/// `turbo_mcp_routes` exposes the WebSocket upgrade endpoint on the same
/// router (confirm). Consider deduplicating the two into a shared helper.
async fn run_websocket(&mut self, bind: &str) -> ProxyResult<()> {
use axum::{Router, http::StatusCode};
use std::time::Duration;
use tower_http::limit::RequestBodyLimitLayer;
use tower_http::timeout::TimeoutLayer;
use turbomcp_transport::axum::AxumMcpExt;
debug!("Starting WebSocket frontend on {}", bind);
// One-time capability discovery; the result seeds the routing service.
let spec = self.backend.introspect().await?;
debug!(
"Backend introspection complete: {} tools, {} resources, {} prompts",
spec.tools.len(),
spec.resources.len(),
spec.prompts.len()
);
let service = ProxyService::new(self.backend.clone(), spec);
// Body-size and timeout layers mirror the configured proxy limits.
let app = Router::new()
.turbo_mcp_routes(service)
.layer(RequestBodyLimitLayer::new(self.request_size_limit))
.layer(TimeoutLayer::with_status_code(
StatusCode::REQUEST_TIMEOUT,
Duration::from_millis(self.timeout_ms),
));
let listener = tokio::net::TcpListener::bind(bind).await.map_err(|e| {
ProxyError::backend_connection(format!("Failed to bind to {bind}: {e}"))
})?;
debug!("WebSocket frontend listening on {}", bind);
axum::serve(listener, app)
.await
.map_err(|e| ProxyError::backend(format!("Axum serve error: {e}")))?;
Ok(())
}
fn create_size_limit_error(n: usize) -> JsonRpcResponse {
JsonRpcResponse {
jsonrpc: turbomcp_protocol::jsonrpc::JsonRpcVersion,
payload: JsonRpcResponsePayload::Error {
error: JsonRpcError {
code: JsonRpcErrorCode::InvalidRequest.code(),
message: format!("Request too large: {n} bytes"),
data: None,
},
},
id: ResponseId::null(),
}
}
fn create_response(
result: Result<Result<Value, McpError>, tokio::time::error::Elapsed>,
request_id: RequestId,
timeout_ms: u64,
) -> JsonRpcResponse {
match result {
Ok(Ok(value)) => JsonRpcResponse::success(value, request_id),
Ok(Err(mcp_error)) => JsonRpcResponse::error_response(
JsonRpcError {
code: JsonRpcErrorCode::InternalError.code(),
message: mcp_error.to_string(),
data: None,
},
request_id,
),
Err(_) => JsonRpcResponse::error_response(
JsonRpcError {
code: JsonRpcErrorCode::InternalError.code(),
message: format!("Request timeout after {timeout_ms}ms"),
data: None,
},
request_id,
),
}
}
/// Serializes `response` as one JSON line and writes it to stdout.
///
/// The payload and trailing newline go out in a single `write_all` so a
/// line can never be torn between two writes (the original wrote the body
/// and the newline separately), and one syscall fewer is issued per message.
///
/// # Errors
///
/// Returns a human-readable message on serialization, write, or flush failure.
async fn write_response_to_stdout(
    stdout: &mut tokio::io::Stdout,
    response: &JsonRpcResponse,
) -> Result<(), String> {
    let json = serde_json::to_string(response)
        .map_err(|e| format!("Failed to serialize response: {e}"))?;
    let mut line = String::with_capacity(json.len() + 1);
    line.push_str(&json);
    line.push('\n');
    stdout
        .write_all(line.as_bytes())
        .await
        .map_err(|e| format!("Failed to write response: {e}"))?;
    stdout
        .flush()
        .await
        .map_err(|e| format!("Failed to flush stdout: {e}"))?;
    trace!("STDIO: Sent response: {json}");
    Ok(())
}
async fn process_request_line(
&mut self,
line: &str,
stdout: &mut tokio::io::Stdout,
) -> Result<(), String> {
let request: JsonRpcRequest = serde_json::from_str(line)
.map_err(|e| format!("STDIO: Failed to parse JSON-RPC: {e}"))?;
let request_id = request.id.clone();
let timeout = Duration::from_millis(self.timeout_ms);
let result = tokio::time::timeout(timeout, self.route_request(&request)).await;
let response = Self::create_response(result, request_id, self.timeout_ms);
Self::write_response_to_stdout(stdout, &response).await?;
if let Some(ref metrics) = self.metrics {
metrics.inc_requests_forwarded();
}
Ok(())
}
/// Serves the proxy over stdin/stdout, one JSON-RPC message per line,
/// until EOF or an unrecoverable write/read error.
async fn run_stdio(&mut self) -> ProxyResult<()> {
debug!("Starting STDIO frontend");
let stdin = tokio::io::stdin();
let mut stdout = tokio::io::stdout();
let mut reader = BufReader::new(stdin);
// Single line buffer reused across iterations to avoid per-line allocation.
let mut line = String::new();
loop {
line.clear();
match reader.read_line(&mut line).await {
Ok(0) => {
debug!("STDIO: EOF reached, shutting down");
break;
}
Ok(n) => {
// `n` is bytes read including the newline; oversized lines are
// answered with an error but the loop keeps serving.
if n > self.request_size_limit {
error!(
"STDIO: Request size {} exceeds limit {}",
n, self.request_size_limit
);
let error_response = Self::create_size_limit_error(n);
if let Ok(json) = serde_json::to_string(&error_response) {
let _ = stdout.write_all(json.as_bytes()).await;
let _ = stdout.write_all(b"\n").await;
let _ = stdout.flush().await;
}
continue;
}
let trimmed = line.trim();
if trimmed.is_empty() {
continue;
}
trace!("STDIO: Received request: {}", trimmed);
match self.process_request_line(trimmed, &mut stdout).await {
Ok(()) => {}
// Write/flush failures mean stdout is gone — stop serving.
// NOTE(review): classifying by error-message substring is
// fragile; a typed error would be more robust.
Err(e)
if e.contains("Failed to write") || e.contains("Failed to flush") =>
{
error!("STDIO: {e}");
break;
}
// Any other failure (e.g. parse error) gets a JSON-RPC parse
// error response and the loop continues.
Err(e) => {
error!("{e}");
let error_response = JsonRpcResponse::parse_error(None);
if let Ok(json) = serde_json::to_string(&error_response) {
let _ = stdout.write_all(json.as_bytes()).await;
let _ = stdout.write_all(b"\n").await;
let _ = stdout.flush().await;
}
}
}
}
Err(e) => {
error!("STDIO: Read error: {}", e);
break;
}
}
}
debug!("STDIO frontend shut down");
Ok(())
}
/// Routes a single JSON-RPC request to the matching backend operation and
/// returns the raw result value for wrapping into a JSON-RPC response.
///
/// Missing/malformed params become invalid-params errors; backend failures
/// surface as internal errors.
async fn route_request(&mut self, request: &JsonRpcRequest) -> McpResult<Value> {
trace!("Routing request: method={}", request.method);
match request.method.as_str() {
"tools/list" => {
debug!("Forwarding tools/list to backend");
let tools = self
.backend
.list_tools()
.await
.map_err(|e| McpError::internal(e.to_string()))?;
Ok(serde_json::json!({
"tools": tools
}))
}
"tools/call" => {
debug!("Forwarding tools/call to backend");
let params = request.params.as_ref().ok_or_else(|| {
McpError::invalid_params("Missing params for tools/call".to_string())
})?;
let call_request: CallToolRequest = serde_json::from_value(params.clone())
.map_err(|e| McpError::invalid_params(e.to_string()))?;
let result = self
.backend
.call_tool(&call_request.name, call_request.arguments)
.await
.map_err(|e| McpError::internal(e.to_string()))?;
Ok(serde_json::to_value(result).map_err(|e| McpError::internal(e.to_string()))?)
}
"resources/list" => {
debug!("Forwarding resources/list to backend");
let resources = self
.backend
.list_resources()
.await
.map_err(|e| McpError::internal(e.to_string()))?;
Ok(serde_json::json!({
"resources": resources
}))
}
"resources/read" => {
debug!("Forwarding resources/read to backend");
let params = request.params.as_ref().ok_or_else(|| {
McpError::invalid_params("Missing params for resources/read".to_string())
})?;
let read_request: ReadResourceRequest = serde_json::from_value(params.clone())
.map_err(|e| McpError::invalid_params(e.to_string()))?;
let contents = self
.backend
.read_resource(&read_request.uri)
.await
.map_err(|e| McpError::internal(e.to_string()))?;
Ok(serde_json::json!({
"contents": contents
}))
}
"prompts/list" => {
debug!("Forwarding prompts/list to backend");
let prompts = self
.backend
.list_prompts()
.await
.map_err(|e| McpError::internal(e.to_string()))?;
Ok(serde_json::json!({
"prompts": prompts
}))
}
"prompts/get" => {
debug!("Forwarding prompts/get to backend");
let params = request.params.as_ref().ok_or_else(|| {
McpError::invalid_params("Missing params for prompts/get".to_string())
})?;
let get_request: GetPromptRequest = serde_json::from_value(params.clone())
.map_err(|e| McpError::invalid_params(e.to_string()))?;
let result = self
.backend
.get_prompt(&get_request.name, get_request.arguments)
.await
.map_err(|e| McpError::internal(e.to_string()))?;
Ok(serde_json::to_value(result).map_err(|e| McpError::internal(e.to_string()))?)
}
method => {
// NOTE(review): unknown methods are reported as internal errors;
// JSON-RPC's MethodNotFound (-32601) may be more appropriate —
// confirm what clients expect before changing the code.
error!("Unknown method: {}", method);
Err(McpError::internal(format!("Method not found: {method}")))
}
}
}
}
// Unit tests: builder configuration, command/URL/SSRF validation, and
// constant sanity checks. Network-touching cases use only IP literals or
// names resolved locally, so they are deterministic.
#[cfg(test)]
mod tests {
use super::*;
// --- builder construction and configuration ---
#[test]
fn test_builder_creation() {
let builder = RuntimeProxyBuilder::new();
assert_eq!(builder.request_size_limit, MAX_REQUEST_SIZE);
assert_eq!(builder.timeout_ms, DEFAULT_TIMEOUT_MS);
assert!(builder.enable_metrics);
}
#[test]
fn test_builder_with_stdio_backend() {
let builder =
RuntimeProxyBuilder::new().with_stdio_backend("python", vec!["server.py".to_string()]);
assert!(matches!(
builder.backend_config,
Some(BackendConfig::Stdio { .. })
));
}
#[test]
fn test_builder_with_http_backend() {
let builder = RuntimeProxyBuilder::new().with_http_backend("https://api.example.com", None);
assert!(matches!(
builder.backend_config,
Some(BackendConfig::Http { .. })
));
}
#[test]
fn test_builder_with_tcp_backend() {
let builder = RuntimeProxyBuilder::new().with_tcp_backend("localhost", 5000);
assert!(matches!(
builder.backend_config,
Some(BackendConfig::Tcp {
host: _,
port: 5000
})
));
}
#[cfg(unix)]
#[test]
fn test_builder_with_unix_backend() {
let builder = RuntimeProxyBuilder::new().with_unix_backend("/tmp/mcp.sock");
assert!(matches!(
builder.backend_config,
Some(BackendConfig::Unix { path: _ })
));
}
#[test]
fn test_builder_with_frontends() {
let http_builder = RuntimeProxyBuilder::new().with_http_frontend("0.0.0.0:3000");
assert_eq!(http_builder.frontend_type, Some(FrontendType::Http));
let stdio_builder = RuntimeProxyBuilder::new().with_stdio_frontend();
assert_eq!(stdio_builder.frontend_type, Some(FrontendType::Stdio));
}
#[test]
fn test_builder_with_timeout() {
let result = RuntimeProxyBuilder::new().with_timeout(60_000);
assert!(result.is_ok());
assert_eq!(result.unwrap().timeout_ms, 60_000);
}
#[test]
fn test_builder_timeout_exceeds_max() {
let result = RuntimeProxyBuilder::new().with_timeout(MAX_TIMEOUT_MS + 1);
assert!(result.is_err());
match result {
Err(ProxyError::Configuration { key, .. }) => {
assert_eq!(key, Some("timeout_ms".to_string()));
}
_ => panic!("Expected Configuration error"),
}
}
// --- command allowlist validation ---
#[test]
fn test_validate_command_allowed() {
let config = BackendConfig::Stdio {
command: "python".to_string(),
args: vec![],
working_dir: None,
};
assert!(RuntimeProxyBuilder::validate_command(&config).is_ok());
}
#[test]
fn test_validate_command_not_allowed() {
let config = BackendConfig::Stdio {
command: "malicious_command".to_string(),
args: vec![],
working_dir: None,
};
let result = RuntimeProxyBuilder::validate_command(&config);
assert!(result.is_err());
match result {
Err(ProxyError::Configuration { message, key }) => {
assert!(message.contains("not in allowlist"));
assert_eq!(key, Some("command".to_string()));
}
_ => panic!("Expected Configuration error"),
}
}
// --- URL scheme and SSRF validation ---
#[tokio::test]
async fn test_validate_url_https_required() {
let config = BackendConfig::Http {
url: "http://api.example.com".to_string(),
auth_token: None,
};
let validation_config = BackendValidationConfig::default();
let result = RuntimeProxyBuilder::validate_url(&config, &validation_config).await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_validate_url_localhost_http_allowed() {
let config = BackendConfig::Http {
url: "http://localhost:3000".to_string(),
auth_token: None,
};
let validation_config = BackendValidationConfig::default();
assert!(
RuntimeProxyBuilder::validate_url(&config, &validation_config)
.await
.is_ok()
);
}
#[tokio::test]
async fn test_validate_url_https_allowed() {
let config = BackendConfig::Http {
url: "https://8.8.8.8".to_string(),
auth_token: None,
};
let validation_config = BackendValidationConfig::default();
assert!(
RuntimeProxyBuilder::validate_url(&config, &validation_config)
.await
.is_ok()
);
}
#[tokio::test]
async fn test_validate_host_blocks_metadata() {
let validation_config = BackendValidationConfig::default();
assert!(
RuntimeProxyBuilder::validate_host("169.254.169.254", 443, &validation_config)
.await
.is_err()
);
assert!(
RuntimeProxyBuilder::validate_host("metadata.google.internal", 443, &validation_config)
.await
.is_err()
);
}
#[tokio::test]
async fn test_validate_host_blocks_private_ips() {
let validation_config = BackendValidationConfig::default();
assert!(
RuntimeProxyBuilder::validate_host("192.168.1.1", 443, &validation_config)
.await
.is_err()
);
assert!(
RuntimeProxyBuilder::validate_host("10.0.0.1", 443, &validation_config)
.await
.is_err()
);
assert!(
RuntimeProxyBuilder::validate_host("172.16.0.1", 443, &validation_config)
.await
.is_err()
);
}
#[tokio::test]
async fn test_validate_host_allows_loopback() {
let validation_config = BackendValidationConfig::default();
assert!(
RuntimeProxyBuilder::validate_host("127.0.0.1", 443, &validation_config)
.await
.is_ok()
);
}
#[test]
fn test_is_localhost() {
assert!(is_localhost("localhost"));
assert!(is_localhost("127.0.0.1"));
assert!(is_localhost("::1"));
assert!(is_localhost("[::1]"));
assert!(!is_localhost("example.com"));
assert!(!is_localhost("192.168.1.1"));
}
// --- build() preconditions ---
#[tokio::test]
async fn test_builder_requires_backend() {
let result = RuntimeProxyBuilder::new()
.with_http_frontend("127.0.0.1:3000")
.build()
.await;
assert!(result.is_err());
match result {
Err(ProxyError::Configuration { message, .. }) => {
assert!(message.contains("Backend configuration is required"));
}
_ => panic!("Expected Configuration error"),
}
}
#[tokio::test]
async fn test_builder_requires_frontend() {
let result = RuntimeProxyBuilder::new()
.with_stdio_backend("python", vec!["server.py".to_string()])
.build()
.await;
assert!(result.is_err());
match result {
Err(ProxyError::Configuration { message, .. }) => {
assert!(message.contains("Frontend type is required"));
}
_ => panic!("Expected Configuration error"),
}
}
#[test]
fn test_validate_working_dir_nonexistent() {
let config = BackendConfig::Stdio {
command: "python".to_string(),
args: vec![],
working_dir: Some("/nonexistent/path/that/does/not/exist".to_string()),
};
let result = RuntimeProxyBuilder::validate_working_dir(&config);
assert!(result.is_err());
}
#[test]
fn test_constants() {
assert_eq!(MAX_REQUEST_SIZE, 10 * 1024 * 1024);
assert_eq!(DEFAULT_TIMEOUT_MS, 30_000);
assert_eq!(MAX_TIMEOUT_MS, 300_000);
assert_eq!(DEFAULT_BIND_ADDRESS, "127.0.0.1:3000");
assert!(ALLOWED_COMMANDS.contains(&"python"));
assert!(ALLOWED_COMMANDS.contains(&"node"));
}
}