// database_mcp_postgres/connection.rs

//! `PostgreSQL` connection: pool cache, pool initialization, and [`Connection`] impl.
//!
//! Owns a moka cache of lazily-created per-database pools (including the default).
//! Hides every backend pool concern from [`PostgresHandler`](crate::PostgresHandler),
//! which composes one [`PostgresConnection`] as a field.

7use std::time::Duration;
8
9use database_mcp_config::DatabaseConfig;
10use database_mcp_sql::Connection;
11use database_mcp_sql::SqlError;
12use database_mcp_sql::sanitize::validate_ident;
13use moka::future::Cache;
14use sqlx::postgres::{PgConnectOptions, PgPool, PgSslMode};
15use tracing::info;
16
17/// Maximum number of cached per-database connection pools.
18pub(crate) const POOL_CACHE_CAPACITY: u64 = 16;
19
20/// Owns every `PgPool` the handler uses and the logic that builds them.
21#[derive(Clone)]
22pub(crate) struct PostgresConnection {
23    config: DatabaseConfig,
24    pools: Cache<String, PgPool>,
25}
26
27impl std::fmt::Debug for PostgresConnection {
28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29        f.debug_struct("PostgresConnection")
30            .field("default_database_name", &self.default_database_name())
31            .finish_non_exhaustive()
32    }
33}
34
35impl PostgresConnection {
36    /// Builds the connection with an empty pool cache.
37    ///
38    /// Does **not** establish a database connection. All pools — including
39    /// the default — are created lazily on first request via [`pool`](Self::pool).
40    pub(crate) fn new(config: &DatabaseConfig) -> Self {
41        info!(
42            "PostgreSQL lazy connection pool created (max size: {})",
43            config.max_pool_size
44        );
45
46        let pools = Cache::builder()
47            .max_capacity(POOL_CACHE_CAPACITY)
48            .eviction_listener(|_key, pool: PgPool, _cause| {
49                tokio::spawn(async move {
50                    pool.close().await;
51                });
52            })
53            .build();
54
55        Self {
56            config: config.clone(),
57            pools,
58        }
59    }
60
61    /// Returns the configured default database name, or the username as fallback.
62    pub(crate) fn default_database_name(&self) -> &str {
63        self.config
64            .name
65            .as_deref()
66            .filter(|n| !n.is_empty())
67            .unwrap_or(&self.config.user)
68    }
69
70    /// Evicts the cached pool for `name`, closing its connections.
71    ///
72    /// Idempotent — does nothing if the pool was not cached.
73    pub(crate) async fn invalidate(&self, name: &str) {
74        self.pools.invalidate(name).await;
75    }
76
77    /// Resolves the cached pool for `target`, creating it lazily on miss.
78    ///
79    /// Kept crate-private so every tool path goes through the unified
80    /// [`Connection`] methods and cannot bypass timeout / error capture.
81    ///
82    /// # Errors
83    ///
84    /// - [`SqlError::InvalidIdentifier`] — `target` failed identifier validation.
85    pub(crate) async fn pool(&self, target: Option<&str>) -> Result<PgPool, SqlError> {
86        let database = match target {
87            Some(name) if !name.is_empty() => name,
88            _ => self.default_database_name(),
89        };
90
91        if let Some(pool) = self.pools.get(database).await {
92            return Ok(pool);
93        }
94
95        if database != self.default_database_name() {
96            validate_ident(database)?;
97        }
98
99        let pool = self
100            .pools
101            .get_with(database.to_owned(), async { create_lazy_pool(&self.config, database) })
102            .await;
103
104        Ok(pool)
105    }
106}
107
108impl Connection for PostgresConnection {
109    type DB = sqlx::Postgres;
110
111    async fn pool(&self, target: Option<&str>) -> Result<sqlx::Pool<Self::DB>, SqlError> {
112        self.pool(target).await
113    }
114
115    fn query_timeout(&self) -> Option<u64> {
116        self.config.query_timeout
117    }
118}
119
120/// Creates a lazy `PostgreSQL` pool for `db_name`.
121///
122/// Uses [`PgConnectOptions::new_without_pgpass`] to avoid unintended
123/// `PG*` environment variable influence, since our config already
124/// resolves values from CLI/env.
125fn create_lazy_pool(config: &DatabaseConfig, database: &str) -> PgPool {
126    let mut conn_ops = PgConnectOptions::new_without_pgpass()
127        .host(&config.host)
128        .port(config.port)
129        .username(&config.user);
130
131    if let Some(ref password) = config.password {
132        conn_ops = conn_ops.password(password);
133    }
134    if !database.is_empty() {
135        conn_ops = conn_ops.database(database);
136    }
137
138    if config.ssl {
139        conn_ops = if config.ssl_verify_cert {
140            conn_ops.ssl_mode(PgSslMode::VerifyCa)
141        } else {
142            conn_ops.ssl_mode(PgSslMode::Require)
143        };
144        if let Some(ref ca) = config.ssl_ca {
145            conn_ops = conn_ops.ssl_root_cert(ca);
146        }
147        if let Some(ref cert) = config.ssl_cert {
148            conn_ops = conn_ops.ssl_client_cert(cert);
149        }
150        if let Some(ref key) = config.ssl_key {
151            conn_ops = conn_ops.ssl_client_key(key);
152        }
153    }
154
155    let mut pool_opts = sqlx::pool::PoolOptions::new()
156        .max_connections(config.max_pool_size)
157        .min_connections(DatabaseConfig::DEFAULT_MIN_CONNECTIONS)
158        .idle_timeout(Duration::from_secs(DatabaseConfig::DEFAULT_IDLE_TIMEOUT_SECS))
159        .max_lifetime(Duration::from_secs(DatabaseConfig::DEFAULT_MAX_LIFETIME_SECS));
160
161    if let Some(timeout) = config.connection_timeout {
162        pool_opts = pool_opts.acquire_timeout(Duration::from_secs(timeout));
163    }
164
165    pool_opts.connect_lazy_with(conn_ops)
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use database_mcp_config::DatabaseBackend;

    /// Config for a placeholder host (`pg.example.com`) with default db `mydb`.
    ///
    /// Pools are lazy, so none of these tests ever opens a connection.
    fn base_config() -> DatabaseConfig {
        DatabaseConfig {
            backend: DatabaseBackend::Postgres,
            host: "pg.example.com".into(),
            port: 5433,
            user: "pgadmin".into(),
            password: Some("pgpass".into()),
            name: Some("mydb".into()),
            ..DatabaseConfig::default()
        }
    }

    #[tokio::test]
    async fn create_lazy_pool_returns_idle_pool() {
        let pool = create_lazy_pool(&base_config(), "mydb");
        assert_eq!(pool.size(), 0, "pool should be lazy (no connections yet)");
    }

    #[tokio::test]
    async fn create_lazy_pool_without_password() {
        let pool = create_lazy_pool(
            &DatabaseConfig {
                password: None,
                ..base_config()
            },
            "mydb",
        );
        assert_eq!(pool.size(), 0);
    }

    #[tokio::test]
    async fn create_lazy_pool_without_database_name() {
        let pool = create_lazy_pool(
            &DatabaseConfig {
                name: None,
                ..base_config()
            },
            "",
        );
        assert_eq!(pool.size(), 0);
    }

    #[tokio::test]
    async fn default_database_name_derived_from_config() {
        let connection = PostgresConnection::new(&base_config());
        assert_eq!(connection.default_database_name(), "mydb");
    }

    #[tokio::test]
    async fn defaults_db_to_username_when_name_missing() {
        let connection = PostgresConnection::new(&DatabaseConfig {
            name: None,
            ..base_config()
        });
        assert_eq!(connection.default_database_name(), "pgadmin");
    }

    #[tokio::test]
    async fn none_target_returns_default_pool() {
        let connection = PostgresConnection::new(&base_config());
        connection.pool(None).await.expect("None target should succeed");
    }

    #[tokio::test]
    async fn arbitrary_target_database_is_permitted() {
        let connection = PostgresConnection::new(&base_config());
        connection
            .pool(Some("any_db"))
            .await
            .expect("any database should be permitted");
    }

    #[tokio::test]
    async fn empty_and_explicit_default_targets_share_cached_pool() {
        let connection = PostgresConnection::new(&base_config());

        connection.pool(None).await.expect("None target should succeed");
        connection
            .pool(Some(""))
            .await
            .expect("empty target should succeed");
        connection
            .pool(Some("mydb"))
            .await
            .expect("explicit default target should succeed");
        connection.pools.run_pending_tasks().await;

        // All three resolve to the "mydb" key, so exactly one pool is cached.
        assert_eq!(connection.pools.entry_count(), 1);
    }

    #[tokio::test]
    async fn invalidate_evicts_cached_pool_and_is_idempotent() {
        let connection = PostgresConnection::new(&base_config());

        connection.pool(None).await.expect("None target should succeed");
        connection.pools.run_pending_tasks().await;
        assert_eq!(connection.pools.entry_count(), 1);

        // Repeated and never-cached invalidations must be harmless no-ops.
        connection.invalidate("mydb").await;
        connection.invalidate("mydb").await;
        connection.invalidate("never_cached").await;
        connection.pools.run_pending_tasks().await;

        assert_eq!(connection.pools.entry_count(), 0);
    }

    #[tokio::test]
    async fn pool_cache_respects_capacity_const() {
        let connection = PostgresConnection::new(&base_config());

        // Insert one more pool than the cap; moka should evict the
        // oldest so the cached count stays at or below POOL_CACHE_CAPACITY.
        for i in 0..=POOL_CACHE_CAPACITY {
            let name = format!("db_{i}");
            connection.pool(Some(&name)).await.expect("pool should succeed");
        }
        connection.pools.run_pending_tasks().await;

        assert!(
            connection.pools.entry_count() <= POOL_CACHE_CAPACITY,
            "cached pools exceeded cap: {} > {POOL_CACHE_CAPACITY}",
            connection.pools.entry_count()
        );
    }
}