//! mcp-postgres 1.0.1
//!
//! High-performance MCP server for PostgreSQL with a lock-free connection pool.
use tokio::net::{TcpListener, TcpStream};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use serde_json::{json, Value};
use tracing::{error, warn};
use std::sync::Arc;

use crate::config::Config;
use crate::errors::{MCPError, Result as MCPResult};
use crate::metrics;
use crate::pool::ConnectionPool;
use crate::protocol::{JsonRpcRequest, JsonRpcResponse};
use crate::actions;
use once_cell::sync::Lazy;

/// Cached `tools/list` result: `{"tools": [...]}` built once from `tools.json`,
/// which is embedded into the binary at compile time.
///
/// # Panics
/// Panics on first access if the embedded `tools.json` does not parse as a JSON array.
static TOOLS_LIST: Lazy<Value> = Lazy::new(|| {
    let parsed: Vec<Value> =
        serde_json::from_str(include_str!("../tools.json")).expect("Failed to parse tools.json");
    json!({ "tools": parsed })
});

/// Capacity, in bytes, of the `BufReader` wrapped around each transport's input.
const BUFFER_CAPACITY: usize = 16 * 1024;
/// Frame terminator appended after every serialized JSON-RPC response.
const NEWLINE: &[u8] = &[b'\n'];

#[inline]
#[cold]
fn parse_error(msg: String) -> JsonRpcResponse {
    let mcp_error = MCPError::ParseError(msg);
    JsonRpcResponse::error(None, mcp_error.error_code(), mcp_error.to_string())
}

/// Parses one transport line into a [`JsonRpcRequest`].
///
/// Leading and trailing whitespace, including the line terminator, is ignored.
///
/// # Errors
/// Returns `"Empty request"` for blank input; otherwise returns the `serde_json`
/// error message describing why the JSON was rejected.
#[inline]
fn parse_request(line: &str) -> Result<JsonRpcRequest, String> {
    match line.trim() {
        "" => Err(String::from("Empty request")),
        body => serde_json::from_str(body).map_err(|err| err.to_string()),
    }
}

pub struct MCPServer {
    config: Config,
    pool: Arc<ConnectionPool>,
}

impl MCPServer {
    pub fn new(config: Config, pool: Arc<ConnectionPool>) -> Self {
        Self { config, pool }
    }

    /// Run in stdio mode for MCP compatibility (Claude Desktop, etc.)
    pub async fn run_stdio(&self) -> MCPResult<()> {
        let stdin = tokio::io::stdin();
        let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, stdin);
        let mut stdout = tokio::io::stdout();
        let mut line = String::with_capacity(512);
        let mut response_buf = Vec::with_capacity(65536);

        loop {
            line.clear();
            match reader.read_line(&mut line).await {
                Ok(0) => break,
                Ok(_) => {
                    process_one_line(&line, &self.pool, &self.config, &mut response_buf, &mut stdout).await?;
                }
                Err(e) => {
                    error!("IO error: {}", e);
                    break;
                }
            }
        }
        Ok(())
    }

    pub async fn run(&self) -> MCPResult<()> {
        let addr = format!("{}:{}", self.config.server.host, self.config.server.port);
        let listener = TcpListener::bind(&addr).await?;

        tracing::info!("MCP server listening on {}", addr);

        loop {
            let (socket, peer_addr) = listener.accept().await?;

            if let Err(e) = socket.set_nodelay(true) {
                warn!("Failed to set TCP_NODELAY: {}", e);
            }
            // Apply TCP socket options via raw fd (SO_KEEPALIVE, buffer sizes)
            use std::os::unix::io::AsRawFd;
            let raw = socket.as_raw_fd();
            let on: libc::c_int = 1;
            unsafe {
                libc::setsockopt(raw, libc::SOL_SOCKET, libc::SO_KEEPALIVE, &on as *const _ as *const libc::c_void, std::mem::size_of_val(&on) as libc::socklen_t);
                let buf_size: libc::c_int = 4 * 1024 * 1024;
                libc::setsockopt(raw, libc::SOL_SOCKET, libc::SO_RCVBUF, &buf_size as *const _ as *const libc::c_void, std::mem::size_of_val(&buf_size) as libc::socklen_t);
                libc::setsockopt(raw, libc::SOL_SOCKET, libc::SO_SNDBUF, &buf_size as *const _ as *const libc::c_void, std::mem::size_of_val(&buf_size) as libc::socklen_t);
            }

            let pool = Arc::clone(&self.pool);
            let config = self.config.clone();

            tokio::spawn(async move {
                if let Err(e) = handle_client(socket, pool, config).await {
                    error!("Client {} error: {}", peer_addr, e);
                }
            });
        }
    }
}

/// Serves one TCP client until it disconnects (EOF) or a read fails.
///
/// Each newline-terminated line is passed to `process_one_line`. The line and
/// response buffers are allocated once and reused for the whole connection.
///
/// # Errors
/// Propagates errors from `process_one_line` (response serialization or socket
/// writes). Read errors are logged and end the session with `Ok(())`.
#[inline(never)]
async fn handle_client(socket: TcpStream, pool: Arc<ConnectionPool>, config: Config) -> MCPResult<()> {
    let (read_half, mut write_half) = socket.into_split();
    let mut input = BufReader::with_capacity(BUFFER_CAPACITY, read_half);
    let mut buf = String::with_capacity(512);
    let mut out = Vec::with_capacity(65536);

    loop {
        buf.clear();
        let read = match input.read_line(&mut buf).await {
            Ok(n) => n,
            Err(e) => {
                error!("IO error: {}", e);
                return Ok(());
            }
        };
        if read == 0 {
            return Ok(());
        }
        process_one_line(&buf, &pool, &config, &mut out, &mut write_half).await?;
    }
}

/// Core per-line processing shared by TCP and stdio transports.
#[inline]
async fn process_one_line<W: AsyncWriteExt + Unpin>(
    line: &str,
    pool: &Arc<ConnectionPool>,
    config: &Config,
    response_buf: &mut Vec<u8>,
    writer: &mut W,
) -> MCPResult<()> {
    metrics::inc_requests();

    let response = match parse_request(line) {
        Ok(req) => match process_request(&req, pool, config).await {
            Ok(result) => JsonRpcResponse::success(req.id, result),
            Err(e) => {
                metrics::inc_errors();
                JsonRpcResponse::error(req.id, e.error_code(), e.to_string())
            }
        },
        Err(e) => {
            metrics::inc_errors();
            parse_error(e)
        }
    };

    response_buf.clear();
    serde_json::to_writer(&mut *response_buf, &response)?;
    response_buf.extend_from_slice(NEWLINE);

    writer.write_all(response_buf).await?;
    writer.flush().await?;
    Ok(())
}

#[inline]
async fn process_request(
    req: &JsonRpcRequest,
    pool: &Arc<ConnectionPool>,
    config: &Config,
) -> MCPResult<Value> {
    match req.method.as_str() {
        "initialize" => handle_initialize(req),
        "tools/list" => handle_tools_list(),
        "tools/call" => handle_tools_call(req, pool, config).await,
        _ => Err(MCPError::MethodNotFound(req.method.clone())),
    }
}

/// Answers the MCP `initialize` handshake.
///
/// Reports protocol version `2024-11-05` and the server's capabilities: tools,
/// resources (no subscription) and prompts, none of which emit list-changed
/// notifications. It also returns the server name and the crate version. The
/// request's parameters are not inspected.
#[inline]
fn handle_initialize(_req: &JsonRpcRequest) -> MCPResult<Value> {
    let capabilities = json!({
        "tools": { "listChanged": false },
        "resources": { "subscribe": false, "listChanged": false },
        "prompts": { "listChanged": false }
    });
    let server_info = json!({
        "name": "mcp-postgres",
        "version": env!("CARGO_PKG_VERSION")
    });

    Ok(json!({
        "protocolVersion": "2024-11-05",
        "capabilities": capabilities,
        "serverInfo": server_info
    }))
}

#[inline]
fn handle_tools_list() -> MCPResult<Value> {
    Ok((*TOOLS_LIST).clone())
}

/// Executes an MCP `tools/call` request on a pooled PostgreSQL connection.
///
/// `params.name` (required, string) selects the tool. `params.arguments`
/// (optional) is cloned and passed unchanged to the matching `actions::*` handler.
/// A connection is acquired before dispatch and returned with `pool.release` once
/// the handler's future has completed, whether it succeeded or failed.
///
/// # Errors
/// * `InvalidParams` if `name` is missing or not a string.
/// * Any error from `pool.acquire()`.
/// * `MethodNotFound` for an unrecognised tool name.
/// * Whatever error the selected handler returns.
async fn handle_tools_call(
    req: &JsonRpcRequest,
    pool: &Arc<ConnectionPool>,
    _config: &Config,
) -> MCPResult<Value> {
    let params = req.params.as_ref();
    let tool_name = params
        .and_then(|p| p.get("name"))
        .and_then(Value::as_str)
        .ok_or_else(|| MCPError::InvalidParams("Missing 'name' parameter".into()))?;
    // Only one arm below runs, so a single owned copy of the arguments suffices.
    let args = params.and_then(|p| p.get("arguments")).cloned();

    let client = pool.acquire().await?;

    let outcome = match tool_name {
        // Schema
        "list_tables" => actions::schema::list_tables(&client, args).await,
        "describe_table" => actions::schema::describe_table(&client, args).await,
        "list_indexes" => actions::schema::list_indexes(&client, args).await,
        "list_schemas" => actions::schema::list_schemas(&client, args).await,
        "show_constraints" => actions::schema::show_constraints(&client, args).await,
        // Queries
        "execute_query" => actions::query::execute_query(&client, args).await,
        "execute_insert" => actions::query::execute_insert(&client, args).await,
        "execute_update" => actions::query::execute_update(&client, args).await,
        "execute_delete" => actions::query::execute_delete(&client, args).await,
        "explain_query" => actions::query::explain_query(&client, args).await,
        // Batch operations
        "batch_insert" => actions::batch::batch_insert(&client, args).await,
        "batch_update" => actions::batch::batch_update(&client, args).await,
        "batch_delete" => actions::batch::batch_delete(&client, args).await,
        "batch_insert_copy" => actions::batch::batch_insert_copy(&client, args).await,
        // Monitoring
        "get_table_stats" => actions::monitoring::get_table_stats(&client, args).await,
        "get_index_stats" => actions::monitoring::get_index_stats(&client, args).await,
        "show_database_size" => actions::monitoring::show_database_size(&client, args).await,
        "show_table_size" => actions::monitoring::show_table_size(&client, args).await,
        "get_cache_hit_ratio" => actions::monitoring::get_cache_hit_ratio(&client, args).await,
        // Connections
        "list_connections" => actions::connections::list_connections(&client, args).await,
        "kill_connection" => actions::connections::kill_connection(&client, args).await,
        "show_current_user" => actions::connections::show_current_user(&client, args).await,
        "show_running_queries" => actions::connections::show_running_queries(&client, args).await,
        "show_connection_summary" => actions::connections::show_connection_summary(&client, args).await,
        // Maintenance
        "vacuum_analyze" => actions::maintenance::vacuum_analyze(&client, args).await,
        "analyze_table" => actions::maintenance::analyze_table(&client, args).await,
        "reindex_table" => actions::maintenance::reindex_table(&client, args).await,
        "get_pg_stat_statements" => actions::maintenance::get_pg_stat_statements(&client, args).await,
        "reset_statistics" => actions::maintenance::reset_statistics(&client, args).await,
        // Security
        "list_users" => actions::security::list_users(&client, args).await,
        "list_user_privileges" => actions::security::list_user_privileges(&client, args).await,
        "list_role_memberships" => actions::security::list_role_memberships(&client, args).await,
        "list_database_privileges" => actions::security::list_database_privileges(&client, args).await,
        "show_session_info" => actions::security::show_session_info(&client, args).await,
        // Server settings
        "show_all_settings" => actions::config::show_all_settings(&client, args).await,
        "get_setting" => actions::config::get_setting(&client, args).await,
        "show_memory_settings" => actions::config::show_memory_settings(&client, args).await,
        "show_performance_settings" => actions::config::show_performance_settings(&client, args).await,
        "show_log_settings" => actions::config::show_log_settings(&client, args).await,
        // Replication
        "show_replication_status" => actions::replication::show_replication_status(&client, args).await,
        "list_replication_slots" => actions::replication::list_replication_slots(&client, args).await,
        "list_standby_servers" => actions::replication::list_standby_servers(&client, args).await,
        "show_wal_info" => actions::replication::show_wal_info(&client, args).await,
        "show_base_backup_progress" => actions::replication::show_base_backup_progress(&client, args).await,
        // Transactions
        "show_active_transactions" => actions::transactions::show_active_transactions(&client, args).await,
        "show_locks" => actions::transactions::show_locks(&client, args).await,
        "show_waiting_locks" => actions::transactions::show_waiting_locks(&client, args).await,
        "begin_transaction" => actions::transactions::begin_transaction(&client, args).await,
        "commit_transaction" => actions::transactions::commit_transaction(&client, args).await,
        "rollback_transaction" => actions::transactions::rollback_transaction(&client, args).await,
        "show_transaction_isolation" => actions::transactions::show_transaction_isolation(&client, args).await,
        "show_deadlocks" => actions::transactions::show_deadlocks(&client, args).await,
        "show_autocommit_status" => actions::transactions::show_autocommit_status(&client, args).await,
        "show_transaction_timeout" => actions::transactions::show_transaction_timeout(&client, args).await,
        unknown => Err(method_not_found(unknown)),
    };

    pool.release(client);
    outcome
}

/// Builds the "method not found" error for an unknown JSON-RPC method or tool name.
#[cold]
fn method_not_found(name: &str) -> MCPError {
    let owned = String::from(name);
    MCPError::MethodNotFound(owned)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_valid_request() {
        let line = r#"{"jsonrpc":"2.0","method":"initialize","id":1}"#;
        let req = parse_request(line).unwrap();
        assert_eq!(req.method, "initialize");
        assert_eq!(req.id, Some(Value::Number(1.into())));
    }

    #[test]
    fn test_parse_request_with_trailing_newline() {
        // read_line() keeps the terminator, so parse_request must tolerate it.
        let line = "{\"jsonrpc\":\"2.0\",\"method\":\"tools/list\",\"id\":2}\n";
        let req = parse_request(line).unwrap();
        assert_eq!(req.method, "tools/list");
    }

    #[test]
    fn test_parse_request_with_crlf() {
        let line = "{\"jsonrpc\":\"2.0\",\"method\":\"tools/list\",\"id\":2}\r\n";
        let req = parse_request(line).unwrap();
        assert_eq!(req.method, "tools/list");
    }

    #[test]
    fn test_parse_request_with_whitespace() {
        let line = "  {\"jsonrpc\":\"2.0\",\"method\":\"ping\",\"id\":3}  ";
        let req = parse_request(line).unwrap();
        assert_eq!(req.method, "ping");
    }

    #[test]
    fn test_parse_notification_has_no_id() {
        // MCP clients send this right after `initialize`; it must parse with id == None.
        let line = r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#;
        let req = parse_request(line).unwrap();
        assert_eq!(req.method, "notifications/initialized");
        assert!(req.id.is_none());
    }

    #[test]
    fn test_parse_empty_request() {
        let err = parse_request("").unwrap_err();
        assert_eq!(err, "Empty request");
    }

    #[test]
    fn test_parse_whitespace_only() {
        let err = parse_request("   \n  ").unwrap_err();
        assert_eq!(err, "Empty request");
    }

    #[test]
    fn test_parse_invalid_json() {
        let err = parse_request("{invalid}").unwrap_err();
        assert!(!err.is_empty(), "Invalid JSON should produce an error message");
    }

    #[test]
    fn test_parse_missing_method() {
        let err = parse_request(r#"{"jsonrpc":"2.0","id":1}"#).unwrap_err();
        assert!(err.contains("method"));
    }

    #[test]
    fn test_parse_wrong_version() {
        // The version field is not validated; a "1.0" request is accepted as-is.
        let req = parse_request(r#"{"jsonrpc":"1.0","method":"init","id":1}"#).unwrap();
        assert_eq!(req.jsonrpc, "1.0");
    }

    #[test]
    fn test_method_not_found_error() {
        let err = method_not_found("test_tool");
        assert_eq!(err.error_code(), -32601);
        assert!(err.to_string().contains("test_tool"));
    }

    #[test]
    fn test_tools_list_static() {
        let list = &*TOOLS_LIST;
        let tools = list.get("tools").and_then(|v| v.as_array());
        assert!(tools.is_some(), "TOOLS_LIST should contain a tools array");
        assert!(!tools.unwrap().is_empty(), "Tools list should not be empty");
    }

    #[test]
    fn test_handle_tools_list_returns_static_list() {
        let result = handle_tools_list().unwrap();
        assert_eq!(result, *TOOLS_LIST);
    }

    #[test]
    fn test_handle_initialize_response() {
        let req = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            method: "initialize".to_string(),
            params: None,
            id: Some(Value::Number(1.into())),
        };
        let result = handle_initialize(&req).unwrap();
        assert_eq!(result["protocolVersion"], "2024-11-05");
        assert!(result["capabilities"]["tools"]["listChanged"].is_boolean());
        assert_eq!(result["serverInfo"]["version"], env!("CARGO_PKG_VERSION"));
    }
}