use database_mcp_config::DatabaseConfig;
use database_mcp_server::AppError;
use database_mcp_sql::identifier::validate_identifier;
use moka::future::Cache;
use sqlx::PgPool;
use sqlx::postgres::{PgConnectOptions, PgPoolOptions, PgSslMode};
use tracing::info;
const POOL_CACHE_CAPACITY: u64 = 6;
#[derive(Clone)]
pub struct PostgresAdapter {
pub(crate) config: DatabaseConfig,
default_db: String,
pools: Cache<String, PgPool>,
}
impl std::fmt::Debug for PostgresAdapter {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PostgresAdapter")
.field("read_only", &self.config.read_only)
.field("default_db", &self.default_db)
.finish_non_exhaustive()
}
}
impl PostgresAdapter {
pub async fn new(config: &DatabaseConfig) -> Result<Self, AppError> {
let pool = PgPoolOptions::new()
.max_connections(config.max_pool_size)
.connect_with(connect_options(config))
.await
.map_err(|e| AppError::Connection(format!("Failed to connect to PostgreSQL: {e}")))?;
info!(
"PostgreSQL connection pool initialized (max size: {})",
config.max_pool_size
);
let default_db = config
.name
.as_deref()
.filter(|n| !n.is_empty())
.map_or_else(|| config.user.clone(), String::from);
let pools = Cache::builder()
.max_capacity(POOL_CACHE_CAPACITY)
.eviction_listener(|_key, pool: PgPool, _cause| {
tokio::spawn(async move {
pool.close().await;
});
})
.build();
pools.insert(default_db.clone(), pool).await;
Ok(Self {
config: config.clone(),
default_db,
pools,
})
}
pub(crate) fn quote_identifier(name: &str) -> String {
database_mcp_sql::identifier::quote_identifier(name, '"')
}
pub(crate) async fn get_pool(&self, database: Option<&str>) -> Result<PgPool, AppError> {
let db_key = match database {
Some(name) if !name.is_empty() => name,
_ => &self.default_db,
};
if let Some(pool) = self.pools.get(db_key).await {
return Ok(pool);
}
validate_identifier(db_key)?;
let config = self.config.clone();
let db_key_owned = db_key.to_owned();
let pool = self
.pools
.try_get_with(db_key_owned, async {
let mut cfg = config;
cfg.name = Some(db_key.to_owned());
PgPoolOptions::new()
.max_connections(cfg.max_pool_size)
.connect_with(connect_options(&cfg))
.await
.map_err(|e| {
AppError::Connection(format!("Failed to connect to PostgreSQL database '{db_key}': {e}"))
})
})
.await
.map_err(|e| match e.as_ref() {
AppError::Connection(msg) => AppError::Connection(msg.clone()),
other => AppError::Connection(other.to_string()),
})?;
Ok(pool)
}
}
fn connect_options(config: &DatabaseConfig) -> PgConnectOptions {
let mut opts = PgConnectOptions::new_without_pgpass()
.host(&config.host)
.port(config.port)
.username(&config.user);
if let Some(ref password) = config.password {
opts = opts.password(password);
}
if let Some(ref name) = config.name
&& !name.is_empty()
{
opts = opts.database(name);
}
if config.ssl {
opts = if config.ssl_verify_cert {
opts.ssl_mode(PgSslMode::VerifyCa)
} else {
opts.ssl_mode(PgSslMode::Require)
};
if let Some(ref ca) = config.ssl_ca {
opts = opts.ssl_root_cert(ca);
}
if let Some(ref cert) = config.ssl_cert {
opts = opts.ssl_client_cert(cert);
}
if let Some(ref key) = config.ssl_key {
opts = opts.ssl_client_key(key);
}
}
opts
}
#[cfg(test)]
mod tests {
use super::*;
use database_mcp_config::DatabaseBackend;
fn base_config() -> DatabaseConfig {
DatabaseConfig {
backend: DatabaseBackend::Postgres,
host: "pg.example.com".into(),
port: 5433,
user: "pgadmin".into(),
password: Some("pgpass".into()),
name: Some("mydb".into()),
..DatabaseConfig::default()
}
}
#[test]
fn try_from_basic_config() {
let config = base_config();
let opts = connect_options(&config);
assert_eq!(opts.get_host(), "pg.example.com");
assert_eq!(opts.get_port(), 5433);
assert_eq!(opts.get_username(), "pgadmin");
assert_eq!(opts.get_database(), Some("mydb"));
}
#[test]
fn try_from_with_ssl_require() {
let config = DatabaseConfig {
ssl: true,
ssl_verify_cert: false,
..base_config()
};
let opts = connect_options(&config);
assert!(
matches!(opts.get_ssl_mode(), PgSslMode::Require),
"expected Require, got {:?}",
opts.get_ssl_mode()
);
}
#[test]
fn try_from_with_ssl_verify_ca() {
let config = DatabaseConfig {
ssl: true,
ssl_verify_cert: true,
..base_config()
};
let opts = connect_options(&config);
assert!(
matches!(opts.get_ssl_mode(), PgSslMode::VerifyCa),
"expected VerifyCa, got {:?}",
opts.get_ssl_mode()
);
}
#[test]
fn try_from_without_database_name() {
let config = DatabaseConfig {
name: None,
..base_config()
};
let opts = connect_options(&config);
assert_eq!(opts.get_database(), None);
}
#[test]
fn try_from_without_password() {
let config = DatabaseConfig {
password: None,
..base_config()
};
let opts = connect_options(&config);
assert_eq!(opts.get_host(), "pg.example.com");
}
}