Skip to main content

database_mcp_postgres/
adapter.rs

1//! `PostgreSQL` adapter definition and connection configuration.
2//!
3//! Creates a lazy default pool via [`PgPoolOptions::connect_lazy_with`].
4//! Non-default database pools are created on demand and cached in a
5//! moka [`Cache`].
6
7use std::time::Duration;
8
9use database_mcp_config::DatabaseConfig;
10use database_mcp_server::AppError;
11use database_mcp_sql::identifier::validate_identifier;
12use moka::future::Cache;
13use sqlx::PgPool;
14use sqlx::postgres::{PgConnectOptions, PgPoolOptions, PgSslMode};
15use tracing::info;
16
/// Maximum number of non-default database connection pools kept in the cache.
///
/// The default pool is stored in its own field on the adapter and is never
/// inserted into the cache, so it does not count toward this limit. When the
/// cache exceeds this capacity, moka evicts entries and the adapter's
/// eviction listener closes the evicted pool.
const POOL_CACHE_CAPACITY: u64 = 6;
19
20/// `PostgreSQL` database adapter.
21///
22/// The default connection pool is created with
23/// [`PgPoolOptions::connect_lazy_with`], which defers all network I/O
24/// until the first query. Non-default database pools are created on
25/// demand via the moka [`Cache`].
26#[derive(Clone)]
27pub struct PostgresAdapter {
28    pub(crate) config: DatabaseConfig,
29    pub(crate) default_db: String,
30    default_pool: PgPool,
31    pub(crate) pools: Cache<String, PgPool>,
32}
33
34impl std::fmt::Debug for PostgresAdapter {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        f.debug_struct("PostgresAdapter")
37            .field("read_only", &self.config.read_only)
38            .field("default_db", &self.default_db)
39            .finish_non_exhaustive()
40    }
41}
42
43impl PostgresAdapter {
44    /// Creates a new `PostgreSQL` adapter with a lazy connection pool.
45    ///
46    /// Does **not** establish a database connection. The default pool
47    /// connects on-demand when the first query is executed.
48    #[must_use]
49    pub fn new(config: &DatabaseConfig) -> Self {
50        // PostgreSQL defaults to a database named after the connecting user.
51        let default_db = config
52            .name
53            .as_deref()
54            .filter(|n| !n.is_empty())
55            .map_or_else(|| config.user.clone(), String::from);
56
57        let default_pool = pool_options(config).connect_lazy_with(connect_options(config));
58
59        info!(
60            "PostgreSQL lazy connection pool created (max size: {})",
61            config.max_pool_size
62        );
63
64        let pools = Cache::builder()
65            .max_capacity(POOL_CACHE_CAPACITY)
66            .eviction_listener(|_key, pool: PgPool, _cause| {
67                tokio::spawn(async move {
68                    pool.close().await;
69                });
70            })
71            .build();
72
73        Self {
74            config: config.clone(),
75            default_db,
76            default_pool,
77            pools,
78        }
79    }
80
81    /// Wraps `name` in double quotes for safe use in `PostgreSQL` SQL statements.
82    pub(crate) fn quote_identifier(name: &str) -> String {
83        database_mcp_sql::identifier::quote_identifier(name, '"')
84    }
85
86    /// Returns a connection pool for the requested database.
87    ///
88    /// Resolves `None` or empty names to the default lazy pool. On a
89    /// cache miss for a non-default database, a new lazy pool is created
90    /// and cached. Evicted pools are closed via the cache's eviction
91    /// listener.
92    ///
93    /// # Errors
94    ///
95    /// Returns [`AppError::InvalidIdentifier`] if the database name fails
96    /// validation.
97    pub(crate) async fn get_pool(&self, database: Option<&str>) -> Result<PgPool, AppError> {
98        let db_key = match database {
99            Some(name) if !name.is_empty() => name,
100            _ => return Ok(self.default_pool.clone()),
101        };
102
103        // Check if it's the default database by name.
104        if db_key == self.default_db {
105            return Ok(self.default_pool.clone());
106        }
107
108        // Non-default database: check cache first.
109        if let Some(pool) = self.pools.get(db_key).await {
110            return Ok(pool);
111        }
112
113        // Cache miss — validate then create a new lazy pool.
114        validate_identifier(db_key)?;
115
116        let config = self.config.clone();
117        let db_key_owned = db_key.to_owned();
118
119        let pool = self
120            .pools
121            .get_with(db_key_owned, async {
122                let mut cfg = config;
123                cfg.name = Some(db_key.to_owned());
124                pool_options(&cfg).connect_lazy_with(connect_options(&cfg))
125            })
126            .await;
127
128        Ok(pool)
129    }
130}
131
132/// Builds [`PgPoolOptions`] with lifecycle defaults from a [`DatabaseConfig`].
133fn pool_options(config: &DatabaseConfig) -> PgPoolOptions {
134    let mut opts = PgPoolOptions::new()
135        .max_connections(config.max_pool_size)
136        .min_connections(DatabaseConfig::DEFAULT_MIN_CONNECTIONS)
137        .idle_timeout(Duration::from_secs(DatabaseConfig::DEFAULT_IDLE_TIMEOUT_SECS))
138        .max_lifetime(Duration::from_secs(DatabaseConfig::DEFAULT_MAX_LIFETIME_SECS));
139
140    if let Some(timeout) = config.connection_timeout {
141        opts = opts.acquire_timeout(Duration::from_secs(timeout));
142    }
143
144    opts
145}
146
147/// Builds [`PgConnectOptions`] from a [`DatabaseConfig`].
148///
149/// Uses [`PgConnectOptions::new_without_pgpass`] to avoid unintended
150/// `PG*` environment variable influence, since our config already
151/// resolves values from CLI/env.
152fn connect_options(config: &DatabaseConfig) -> PgConnectOptions {
153    let mut opts = PgConnectOptions::new_without_pgpass()
154        .host(&config.host)
155        .port(config.port)
156        .username(&config.user);
157
158    if let Some(ref password) = config.password {
159        opts = opts.password(password);
160    }
161    if let Some(ref name) = config.name
162        && !name.is_empty()
163    {
164        opts = opts.database(name);
165    }
166
167    if config.ssl {
168        opts = if config.ssl_verify_cert {
169            opts.ssl_mode(PgSslMode::VerifyCa)
170        } else {
171            opts.ssl_mode(PgSslMode::Require)
172        };
173        if let Some(ref ca) = config.ssl_ca {
174            opts = opts.ssl_root_cert(ca);
175        }
176        if let Some(ref cert) = config.ssl_cert {
177            opts = opts.ssl_client_cert(cert);
178        }
179        if let Some(ref key) = config.ssl_key {
180            opts = opts.ssl_client_key(key);
181        }
182    }
183
184    opts
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use database_mcp_config::DatabaseBackend;

    /// Config with host, port, credentials and database name all set.
    fn base_config() -> DatabaseConfig {
        DatabaseConfig {
            backend: DatabaseBackend::Postgres,
            host: "pg.example.com".into(),
            port: 5433,
            user: "pgadmin".into(),
            password: Some("pgpass".into()),
            name: Some("mydb".into()),
            ..DatabaseConfig::default()
        }
    }

    #[test]
    fn pool_options_applies_defaults() {
        let config = base_config();
        let opts = pool_options(&config);

        assert_eq!(opts.get_max_connections(), config.max_pool_size);
        assert_eq!(opts.get_min_connections(), DatabaseConfig::DEFAULT_MIN_CONNECTIONS);
        assert_eq!(
            opts.get_idle_timeout(),
            Some(Duration::from_secs(DatabaseConfig::DEFAULT_IDLE_TIMEOUT_SECS))
        );
        assert_eq!(
            opts.get_max_lifetime(),
            Some(Duration::from_secs(DatabaseConfig::DEFAULT_MAX_LIFETIME_SECS))
        );
    }

    #[test]
    fn pool_options_applies_connection_timeout() {
        let config = DatabaseConfig {
            connection_timeout: Some(7),
            ..base_config()
        };
        let opts = pool_options(&config);

        assert_eq!(opts.get_acquire_timeout(), Duration::from_secs(7));
    }

    #[test]
    fn pool_options_without_connection_timeout_uses_sqlx_default() {
        let config = base_config();
        let opts = pool_options(&config);

        assert_eq!(opts.get_acquire_timeout(), Duration::from_secs(30));
    }

    #[test]
    fn connect_options_basic_config() {
        let config = base_config();
        let opts = connect_options(&config);

        assert_eq!(opts.get_host(), "pg.example.com");
        assert_eq!(opts.get_port(), 5433);
        assert_eq!(opts.get_username(), "pgadmin");
        assert_eq!(opts.get_database(), Some("mydb"));
    }

    #[test]
    fn connect_options_with_ssl_require() {
        let config = DatabaseConfig {
            ssl: true,
            ssl_verify_cert: false,
            ..base_config()
        };
        let opts = connect_options(&config);

        assert!(
            matches!(opts.get_ssl_mode(), PgSslMode::Require),
            "expected Require, got {:?}",
            opts.get_ssl_mode()
        );
    }

    #[test]
    fn connect_options_with_ssl_verify_ca() {
        let config = DatabaseConfig {
            ssl: true,
            ssl_verify_cert: true,
            ..base_config()
        };
        let opts = connect_options(&config);

        assert!(
            matches!(opts.get_ssl_mode(), PgSslMode::VerifyCa),
            "expected VerifyCa, got {:?}",
            opts.get_ssl_mode()
        );
    }

    #[test]
    fn connect_options_without_database_name() {
        let config = DatabaseConfig {
            name: None,
            ..base_config()
        };
        let opts = connect_options(&config);

        // The config sets no database, so sqlx's `PGDATABASE` fallback (if any)
        // shows through. Compare against it instead of assuming `None`, so the
        // test does not depend on the developer's shell environment.
        let from_env = std::env::var("PGDATABASE").ok();
        assert_eq!(opts.get_database(), from_env.as_deref());
    }

    #[test]
    fn connect_options_treats_empty_database_name_as_unset() {
        let config = DatabaseConfig {
            name: Some("".into()),
            ..base_config()
        };
        let opts = connect_options(&config);

        let from_env = std::env::var("PGDATABASE").ok();
        assert_eq!(opts.get_database(), from_env.as_deref());
    }

    #[test]
    fn connect_options_without_password_keeps_other_fields() {
        let config = DatabaseConfig {
            password: None,
            ..base_config()
        };
        let opts = connect_options(&config);

        // sqlx exposes no password getter; check the password-less config
        // still yields the configured endpoint and user.
        assert_eq!(opts.get_host(), "pg.example.com");
        assert_eq!(opts.get_port(), 5433);
        assert_eq!(opts.get_username(), "pgadmin");
    }

    #[tokio::test]
    async fn new_creates_lazy_pool() {
        let config = base_config();
        let adapter = PostgresAdapter::new(&config);
        assert_eq!(adapter.default_db, "mydb");
        // Pool exists but has no active connections (lazy).
        assert_eq!(adapter.default_pool.size(), 0);
    }

    #[tokio::test]
    async fn new_defaults_db_to_username() {
        let config = DatabaseConfig {
            name: None,
            ..base_config()
        };
        let adapter = PostgresAdapter::new(&config);
        assert_eq!(adapter.default_db, "pgadmin");
    }

    #[tokio::test]
    async fn new_treats_empty_name_as_username() {
        let config = DatabaseConfig {
            name: Some("".into()),
            ..base_config()
        };
        let adapter = PostgresAdapter::new(&config);
        assert_eq!(adapter.default_db, "pgadmin");
    }
}