// mcp-postgres 1.2.2
//
// High-performance MCP server for PostgreSQL with CPU-aware connection pooling and optimized buffers.
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::fmt;
use std::str::FromStr;
use std::time::Duration;

/// Access mode the server runs in.
///
/// Serialized (and deserialized) by serde as the lowercase variant name:
/// `"unrestricted"` or `"restricted"`.
#[derive(Clone, Copy, Debug, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum AccessMode {
    /// Serialized as `unrestricted`.
    Unrestricted,
    /// Serialized as `restricted`.
    Restricted,
}

impl fmt::Display for AccessMode {
    /// Writes the lowercase mode name: `unrestricted` or `restricted`.
    ///
    /// Uses [`fmt::Formatter::pad`] so width, fill, alignment and precision
    /// flags are honoured (e.g. `format!("{:<12}", mode)`). With no flags the
    /// output is exactly the bare name, which is also what [`FromStr`] accepts.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            AccessMode::Unrestricted => "unrestricted",
            AccessMode::Restricted => "restricted",
        };
        f.pad(name)
    }
}

impl FromStr for AccessMode {
    type Err = String;

    /// Parses an access mode, ignoring ASCII case (`"Restricted"`,
    /// `"UNRESTRICTED"`, ... are accepted).
    ///
    /// # Errors
    ///
    /// Returns a human-readable message naming the rejected input and the two
    /// accepted values when `s` matches neither mode.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        // `eq_ignore_ascii_case` compares in place instead of allocating a
        // lowercased copy; for these all-ASCII keywords it accepts exactly the
        // same inputs as `s.to_lowercase() == keyword`.
        if s.eq_ignore_ascii_case("unrestricted") {
            Ok(AccessMode::Unrestricted)
        } else if s.eq_ignore_ascii_case("restricted") {
            Ok(AccessMode::Restricted)
        } else {
            Err(format!("Invalid access mode: {s}. Use 'unrestricted' or 'restricted'"))
        }
    }
}

/// Complete runtime configuration, grouped by concern.
///
/// Field order is also the serde serialization order.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Config {
    /// Database connection settings.
    pub database: DatabaseConfig,
    /// Server settings: host, port, request timeout and access mode.
    pub server: ServerConfig,
    /// Connection-pool sizing and queue timeout.
    pub pool: PoolConfig,
    /// Metrics settings: enable flag and port.
    pub metrics: MetricsConfig,
}

/// Database connection settings.
///
/// NOTE(review): the derived `Debug` prints the full URL, including any
/// password embedded in it — avoid logging this (or [`Config`]) with `{:?}`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct DatabaseConfig {
    /// Connection URL, e.g. `postgres://user:password@host:5432/dbname`.
    pub url: String,
}

/// Settings for the server process itself.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ServerConfig {
    /// Host name or address for the server.
    pub host: String,
    /// Port number for the server.
    pub port: u16,
    /// Request timeout; presumably enforced per request by the server — confirm at the use site.
    pub request_timeout: Duration,
    /// Whether the server runs in unrestricted or restricted mode.
    pub access_mode: AccessMode,
}

/// Connection-pool sizing.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct PoolConfig {
    /// Lower bound on the number of pooled connections.
    pub min_size: u32,
    /// Upper bound on the number of pooled connections.
    pub max_size: u32,
    /// Queue wait timeout; presumably how long a caller waits for a free
    /// connection — TODO confirm against the pool implementation.
    pub queue_timeout: Duration,
}

/// Metrics settings.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct MetricsConfig {
    /// Whether metrics are enabled.
    pub enabled: bool,
    /// Port number used for metrics.
    pub port: u16,
}

impl Config {
    pub fn from_args(args: &super::Args) -> Result<Self> {
        let database_url = args.database_url.clone()
            .or_else(|| std::env::var("DATABASE_URL").ok())
            .unwrap_or_else(|| "postgres://postgres:postgres@localhost:5432/postgres".to_string());

        let min_size = args.min_connections.unwrap_or(5);
        let max_size = args.max_connections.unwrap_or(20);

        Ok(Config {
            database: DatabaseConfig {
                url: database_url,
            },
            server: ServerConfig {
                host: args.host.clone(),
                port: args.port,
                request_timeout: Duration::from_secs(30),
                access_mode: args.access_mode,
            },
            pool: PoolConfig {
                min_size,
                max_size,
                queue_timeout: Duration::from_secs(10),
            },
            metrics: MetricsConfig {
                enabled: args.enable_metrics,
                port: args.metrics_port,
            },
        })
    }
}

impl Default for Config {
    fn default() -> Self {
        Self {
            database: DatabaseConfig {
                url: "postgres://postgres:postgres@localhost:5432/postgres".to_string(),
            },
            server: ServerConfig {
                host: "127.0.0.1".to_string(),
                port: 3000,
                request_timeout: Duration::from_secs(30),
                access_mode: AccessMode::Unrestricted,
            },
            pool: PoolConfig {
                min_size: 5,
                max_size: 20,
                queue_timeout: Duration::from_secs(10),
            },
            metrics: MetricsConfig {
                enabled: false,
                port: 9090,
            },
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_defaults() {
        let cfg = Config::default();
        assert_eq!(cfg.server.host, "127.0.0.1");
        assert_eq!(cfg.server.port, 3000);
        assert_eq!(cfg.server.request_timeout, Duration::from_secs(30));
    }

    #[test]
    fn test_database_config_defaults() {
        let cfg = Config::default();
        assert_eq!(cfg.database.url, "postgres://postgres:postgres@localhost:5432/postgres");
    }

    #[test]
    fn test_pool_config_defaults() {
        let cfg = Config::default();
        assert_eq!(cfg.pool.min_size, 5);
        assert_eq!(cfg.pool.max_size, 20);
        assert_eq!(cfg.pool.queue_timeout, Duration::from_secs(10));
    }

    #[test]
    fn test_metrics_config_defaults() {
        let cfg = Config::default();
        assert!(!cfg.metrics.enabled);
        assert_eq!(cfg.metrics.port, 9090);
    }

    #[test]
    fn test_config_serde() {
        let cfg = Config::default();
        let json = serde_json::to_string(&cfg).unwrap();
        let deserialized: Config = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.server.port, cfg.server.port);
        assert_eq!(deserialized.pool.min_size, cfg.pool.min_size);
        assert_eq!(deserialized.database.url, cfg.database.url);
    }

    #[test]
    fn test_pool_config_values() {
        let cfg = Config::default();
        assert!(cfg.pool.min_size > 0);
        assert!(cfg.pool.max_size >= cfg.pool.min_size);
    }

    #[test]
    fn test_server_config_debug() {
        let cfg = Config::default();
        let debug = format!("{:?}", cfg);
        assert!(debug.contains("127.0.0.1"));
        assert!(debug.contains("3000"));
    }

    // Replaces a former tautological "cpu_aware" test that asserted values it
    // computed itself (and which did not match `from_args`' real defaults).
    #[test]
    fn test_access_mode_from_str_accepts_any_case() {
        assert_eq!("unrestricted".parse::<AccessMode>(), Ok(AccessMode::Unrestricted));
        assert_eq!("Restricted".parse::<AccessMode>(), Ok(AccessMode::Restricted));
        assert_eq!("UNRESTRICTED".parse::<AccessMode>(), Ok(AccessMode::Unrestricted));
    }

    #[test]
    fn test_access_mode_from_str_rejects_unknown() {
        let err = "readonly".parse::<AccessMode>().unwrap_err();
        assert_eq!(err, "Invalid access mode: readonly. Use 'unrestricted' or 'restricted'");
    }

    #[test]
    fn test_access_mode_display_round_trips() {
        for mode in [AccessMode::Unrestricted, AccessMode::Restricted] {
            assert_eq!(mode.to_string().parse::<AccessMode>(), Ok(mode));
        }
    }

    #[test]
    fn test_access_mode_serde_names() {
        assert_eq!(serde_json::to_string(&AccessMode::Restricted).unwrap(), "\"restricted\"");
        let parsed: AccessMode = serde_json::from_str("\"unrestricted\"").unwrap();
        assert_eq!(parsed, AccessMode::Unrestricted);
    }
}