mcpproxy 0.2.1

MCP proxy server for remote HTTP endpoints supporting JSON and Server-Sent Events
use anyhow::{Context, Result};
use clap::Parser;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::collections::HashMap;
use tokio::io::{self, AsyncBufReadExt, AsyncWriteExt, BufReader};
use futures_util::StreamExt;

// Command-line options for the proxy.
// NOTE: clap renders the `///` doc comments below as `--help` text, so they are
// runtime output; maintainer notes in this struct use plain `//` comments instead.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    // Each JSON-RPC request parsed from stdin is sent here via HTTP POST.
    /// Remote MCP server URL to proxy to
    #[arg(short, long)]
    url: String,

    // When set, sent as `Authorization: Bearer <key>` on every request. The
    // 'agentdb_' prefix mentioned in the help text is not validated in this file.
    /// Optional AgentDB API key for authentication (starts with 'agentdb_')
    #[arg(long)]
    api_key: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
struct JsonRpcRequest {
    jsonrpc: String,
    method: String,
    params: Option<Value>,
    id: Option<Value>,
}

/// A JSON-RPC 2.0 response, either parsed from the remote server or built by the proxy.
///
/// `result` and `error` are omitted from serialized output when `None`. `id` has no
/// skip attribute, so a missing id is written as `"id": null`.
///
/// Every field except `jsonrpc` is an `Option`, and no `deny_unknown_fields` is set.
/// So any object with a `jsonrpc` string deserializes into this type, including a
/// server notification that carries `method` and has neither `result` nor `error`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct JsonRpcResponse {
    /// Protocol version string; responses built by the proxy use "2.0".
    jsonrpc: String,
    /// Successful result payload, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    result: Option<Value>,
    /// Error payload, if the call failed.
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<JsonRpcError>,
    /// Id of the request this answers; serialized as `null` when `None`.
    id: Option<Value>,
}

/// The `error` member of a JSON-RPC 2.0 response.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct JsonRpcError {
    /// Error code. Codes built by this proxy include -32700 (parse error) and
    /// -32603 (internal error). For non-2xx HTTP replies, the remote server's
    /// HTTP status code is used.
    code: i32,
    /// Short human-readable description of the error.
    message: String,
    /// Optional extra detail, omitted from output when `None`. The proxy stores
    /// error text or the remote response body here.
    #[serde(skip_serializing_if = "Option::is_none")]
    data: Option<Value>,
}

struct McpProxy {
    client: Client,
    remote_url: String,
    api_key: Option<String>,
}

impl McpProxy {
    fn new(remote_url: String, api_key: Option<String>) -> Self {
        Self {
            client: Client::new(),
            remote_url,
            api_key,
        }
    }

    async fn forward_request(&self, request: JsonRpcRequest) -> Result<Vec<JsonRpcResponse>> {
        let mut headers = HashMap::new();
        headers.insert("Content-Type".to_string(), "application/json".to_string());
        // Accept both JSON and event-stream responses
        headers.insert("Accept".to_string(), "application/json, text/event-stream".to_string());

        if let Some(api_key) = &self.api_key {
            headers.insert("Authorization".to_string(), format!("Bearer {}", api_key));
        }

        let mut req_builder = self.client.post(&self.remote_url);

        for (key, value) in headers {
            req_builder = req_builder.header(&key, &value);
        }

        let response = req_builder
            .json(&request)
            .send()
            .await
            .context("Failed to send request to remote server")?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            return Ok(vec![JsonRpcResponse {
                jsonrpc: "2.0".to_string(),
                result: None,
                error: Some(JsonRpcError {
                    code: status.as_u16() as i32,
                    message: format!("Remote server error: {}", status),
                    data: Some(json!(error_text)),
                }),
                id: request.id,
            }]);
        }

        // Check if the response is an event stream
        let content_type = response
            .headers()
            .get("content-type")
            .and_then(|v| v.to_str().ok())
            .unwrap_or("");

        if content_type.starts_with("text/event-stream") {
            // Handle Server-Sent Events
            self.handle_event_stream_response(response, request.id).await
        } else {
            // Handle regular JSON response
            let json_response: JsonRpcResponse = response
                .json()
                .await
                .context("Failed to parse response from remote server")?;
            Ok(vec![json_response])
        }
    }

    async fn handle_event_stream_response(
        &self,
        response: reqwest::Response,
        request_id: Option<Value>,
    ) -> Result<Vec<JsonRpcResponse>> {
        let mut responses = Vec::new();
        let mut stream = response.bytes_stream();

        let mut buffer = String::new();

        while let Some(chunk) = stream.next().await {
            let chunk = chunk.context("Failed to read chunk from event stream")?;
            let chunk_str = String::from_utf8_lossy(&chunk);
            buffer.push_str(&chunk_str);

            // Process complete events (separated by double newlines)
            while let Some(event_end) = buffer.find("\n\n") {
                let event_data = buffer[..event_end].to_string();
                buffer = buffer[event_end + 2..].to_string();

                if let Some(json_response) = self.parse_sse_event(&event_data, request_id.clone())? {
                    responses.push(json_response);
                }
            }
        }

        // Process any remaining data in buffer
        if !buffer.trim().is_empty() {
            if let Some(json_response) = self.parse_sse_event(&buffer, request_id.clone())? {
                responses.push(json_response);
            }
        }

        if responses.is_empty() {
            // If no valid responses were parsed, return an error
            responses.push(JsonRpcResponse {
                jsonrpc: "2.0".to_string(),
                result: None,
                error: Some(JsonRpcError {
                    code: -32603,
                    message: "No valid responses in event stream".to_string(),
                    data: None,
                }),
                id: request_id,
            });
        }

        Ok(responses)
    }

    fn parse_sse_event(
        &self,
        event_data: &str,
        request_id: Option<Value>,
    ) -> Result<Option<JsonRpcResponse>> {
        // Parse Server-Sent Event format
        let mut data_lines = Vec::new();

        for line in event_data.lines() {
            let line = line.trim();
            if line.is_empty() || line.starts_with(':') {
                // Skip empty lines and comments
                continue;
            }

            if let Some(data) = line.strip_prefix("data: ") {
                data_lines.push(data);
            }
            // We could also handle other SSE fields like event:, id:, retry: if needed
        }

        if data_lines.is_empty() {
            return Ok(None);
        }

        // Join all data lines
        let json_data = data_lines.join("\n");

        // Try to parse as JSON-RPC response
        match serde_json::from_str::<JsonRpcResponse>(&json_data) {
            Ok(mut response) => {
                // If the response doesn't have an ID, use the request ID
                if response.id.is_none() {
                    response.id = request_id;
                }
                Ok(Some(response))
            }
            Err(_) => {
                // If it's not a valid JSON-RPC response, wrap it as a result
                Ok(Some(JsonRpcResponse {
                    jsonrpc: "2.0".to_string(),
                    result: Some(json!(json_data)),
                    error: None,
                    id: request_id,
                }))
            }
        }
    }
}

#[tokio::main]
async fn main() -> Result<()> {
    let args = Args::parse();

    let proxy = McpProxy::new(args.url, args.api_key);

    let stdin = io::stdin();
    let mut stdout = io::stdout();
    let mut reader = BufReader::new(stdin);
    let mut line = String::new();

    eprintln!("MCP Proxy started, forwarding to: {}", proxy.remote_url);

    loop {
        line.clear();

        match reader.read_line(&mut line).await {
            Ok(0) => {
                // EOF reached
                eprintln!("EOF reached, shutting down");
                break;
            }
            Ok(_) => {
                let trimmed = line.trim();
                if trimmed.is_empty() {
                    continue;
                }

                // Parse the JSON-RPC request
                match serde_json::from_str::<JsonRpcRequest>(trimmed) {
                    Ok(request) => {
                        eprintln!("Forwarding request: {} (id: {:?})", request.method, request.id);

                        // Forward the request to the remote server
                        match proxy.forward_request(request).await {
                            Ok(responses) => {
                                // Send each response back to stdout
                                for response in responses {
                                    let response_json = serde_json::to_string(&response)
                                        .context("Failed to serialize response")?;

                                    stdout.write_all(response_json.as_bytes()).await
                                        .context("Failed to write response to stdout")?;
                                    stdout.write_all(b"\n").await
                                        .context("Failed to write newline to stdout")?;
                                    stdout.flush().await
                                        .context("Failed to flush stdout")?;

                                    eprintln!("Response sent for id: {:?}", response.id);
                                }
                            }
                            Err(e) => {
                                eprintln!("Error forwarding request: {}", e);

                                // Send an error response
                                let error_response = JsonRpcResponse {
                                    jsonrpc: "2.0".to_string(),
                                    result: None,
                                    error: Some(JsonRpcError {
                                        code: -32603,
                                        message: "Internal error".to_string(),
                                        data: Some(json!(e.to_string())),
                                    }),
                                    id: None, // We might not have the original ID if parsing failed
                                };

                                let error_json = serde_json::to_string(&error_response)
                                    .context("Failed to serialize error response")?;

                                stdout.write_all(error_json.as_bytes()).await
                                    .context("Failed to write error response to stdout")?;
                                stdout.write_all(b"\n").await
                                    .context("Failed to write newline to stdout")?;
                                stdout.flush().await
                                    .context("Failed to flush stdout")?;
                            }
                        }
                    }
                    Err(e) => {
                        eprintln!("Failed to parse JSON-RPC request: {}", e);
                        eprintln!("Invalid input: {}", trimmed);

                        // Send a parse error response
                        let parse_error = JsonRpcResponse {
                            jsonrpc: "2.0".to_string(),
                            result: None,
                            error: Some(JsonRpcError {
                                code: -32700,
                                message: "Parse error".to_string(),
                                data: Some(json!(e.to_string())),
                            }),
                            id: None,
                        };

                        let error_json = serde_json::to_string(&parse_error)
                            .context("Failed to serialize parse error response")?;

                        stdout.write_all(error_json.as_bytes()).await
                            .context("Failed to write parse error response to stdout")?;
                        stdout.write_all(b"\n").await
                            .context("Failed to write newline to stdout")?;
                        stdout.flush().await
                            .context("Failed to flush stdout")?;
                    }
                }
            }
            Err(e) => {
                eprintln!("Error reading from stdin: {}", e);
                break;
            }
        }
    }

    Ok(())
}