Skip to main content

database_mcp_postgres/
adapter.rs

1//! `PostgreSQL` adapter definition and connection configuration.
2
3use database_mcp_config::DatabaseConfig;
4use database_mcp_server::AppError;
5use database_mcp_sql::identifier::validate_identifier;
6use moka::future::Cache;
7use sqlx::PgPool;
8use sqlx::postgres::{PgConnectOptions, PgPoolOptions, PgSslMode};
9use tracing::info;
10
/// Upper bound on how many connection pools the adapter keeps cached at once:
/// one slot for the default database plus five for other databases.
const POOL_CACHE_CAPACITY: u64 = 1 + 5;
13
14/// `PostgreSQL` database adapter.
15///
16/// All connection pools — including the default — live in a single
17/// concurrent cache keyed by database name. No external mutex required.
18#[derive(Clone)]
19pub struct PostgresAdapter {
20    pub(crate) config: DatabaseConfig,
21    default_db: String,
22    pools: Cache<String, PgPool>,
23}
24
25impl std::fmt::Debug for PostgresAdapter {
26    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27        f.debug_struct("PostgresAdapter")
28            .field("read_only", &self.config.read_only)
29            .field("default_db", &self.default_db)
30            .finish_non_exhaustive()
31    }
32}
33
34impl PostgresAdapter {
35    /// Creates a new `PostgreSQL` adapter from configuration.
36    ///
37    /// Stores a clone of the configuration for constructing connection options
38    /// for non-default databases at runtime. The initial pool is placed into
39    /// the shared cache keyed by the configured database name.
40    ///
41    /// # Errors
42    ///
43    /// Returns [`AppError::Connection`] if the connection fails.
44    pub async fn new(config: &DatabaseConfig) -> Result<Self, AppError> {
45        let pool = PgPoolOptions::new()
46            .max_connections(config.max_pool_size)
47            .connect_with(connect_options(config))
48            .await
49            .map_err(|e| AppError::Connection(format!("Failed to connect to PostgreSQL: {e}")))?;
50
51        info!(
52            "PostgreSQL connection pool initialized (max size: {})",
53            config.max_pool_size
54        );
55
56        // PostgreSQL defaults to a database named after the connecting user.
57        let default_db = config
58            .name
59            .as_deref()
60            .filter(|n| !n.is_empty())
61            .map_or_else(|| config.user.clone(), String::from);
62
63        let pools = Cache::builder()
64            .max_capacity(POOL_CACHE_CAPACITY)
65            .eviction_listener(|_key, pool: PgPool, _cause| {
66                tokio::spawn(async move {
67                    pool.close().await;
68                });
69            })
70            .build();
71
72        pools.insert(default_db.clone(), pool).await;
73
74        Ok(Self {
75            config: config.clone(),
76            default_db,
77            pools,
78        })
79    }
80
81    /// Wraps `name` in double quotes for safe use in `PostgreSQL` SQL statements.
82    pub(crate) fn quote_identifier(name: &str) -> String {
83        database_mcp_sql::identifier::quote_identifier(name, '"')
84    }
85
86    /// Returns a connection pool for the requested database.
87    ///
88    /// Resolves `None` or empty names to the default pool. On a cache miss
89    /// a new pool is created and cached. Evicted pools are closed via the
90    /// cache's eviction listener.
91    ///
92    /// # Errors
93    ///
94    /// Returns [`AppError::InvalidIdentifier`] if the database name fails
95    /// validation, or [`AppError::Connection`] if the new pool cannot connect.
96    pub(crate) async fn get_pool(&self, database: Option<&str>) -> Result<PgPool, AppError> {
97        let db_key = match database {
98            Some(name) if !name.is_empty() => name,
99            _ => &self.default_db,
100        };
101
102        if let Some(pool) = self.pools.get(db_key).await {
103            return Ok(pool);
104        }
105
106        // Cache miss — validate then create a new pool.
107        validate_identifier(db_key)?;
108
109        let config = self.config.clone();
110        let db_key_owned = db_key.to_owned();
111
112        let pool = self
113            .pools
114            .try_get_with(db_key_owned, async {
115                let mut cfg = config;
116                cfg.name = Some(db_key.to_owned());
117                PgPoolOptions::new()
118                    .max_connections(cfg.max_pool_size)
119                    .connect_with(connect_options(&cfg))
120                    .await
121                    .map_err(|e| {
122                        AppError::Connection(format!("Failed to connect to PostgreSQL database '{db_key}': {e}"))
123                    })
124            })
125            .await
126            .map_err(|e| match e.as_ref() {
127                AppError::Connection(msg) => AppError::Connection(msg.clone()),
128                other => AppError::Connection(other.to_string()),
129            })?;
130
131        Ok(pool)
132    }
133}
134
135/// Builds [`PgConnectOptions`] from a [`DatabaseConfig`].
136///
137/// Uses [`PgConnectOptions::new_without_pgpass`] to avoid unintended
138/// `PG*` environment variable influence, since our config already
139/// resolves values from CLI/env.
140fn connect_options(config: &DatabaseConfig) -> PgConnectOptions {
141    let mut opts = PgConnectOptions::new_without_pgpass()
142        .host(&config.host)
143        .port(config.port)
144        .username(&config.user);
145
146    if let Some(ref password) = config.password {
147        opts = opts.password(password);
148    }
149    if let Some(ref name) = config.name
150        && !name.is_empty()
151    {
152        opts = opts.database(name);
153    }
154
155    if config.ssl {
156        opts = if config.ssl_verify_cert {
157            opts.ssl_mode(PgSslMode::VerifyCa)
158        } else {
159            opts.ssl_mode(PgSslMode::Require)
160        };
161        if let Some(ref ca) = config.ssl_ca {
162            opts = opts.ssl_root_cert(ca);
163        }
164        if let Some(ref cert) = config.ssl_cert {
165            opts = opts.ssl_client_cert(cert);
166        }
167        if let Some(ref key) = config.ssl_key {
168            opts = opts.ssl_client_key(key);
169        }
170    }
171
172    opts
173}
174
175#[cfg(test)]
176mod tests {
177    use super::*;
178    use database_mcp_config::DatabaseBackend;
179
180    fn base_config() -> DatabaseConfig {
181        DatabaseConfig {
182            backend: DatabaseBackend::Postgres,
183            host: "pg.example.com".into(),
184            port: 5433,
185            user: "pgadmin".into(),
186            password: Some("pgpass".into()),
187            name: Some("mydb".into()),
188            ..DatabaseConfig::default()
189        }
190    }
191
192    #[test]
193    fn try_from_basic_config() {
194        let config = base_config();
195        let opts = connect_options(&config);
196
197        assert_eq!(opts.get_host(), "pg.example.com");
198        assert_eq!(opts.get_port(), 5433);
199        assert_eq!(opts.get_username(), "pgadmin");
200        assert_eq!(opts.get_database(), Some("mydb"));
201    }
202
203    #[test]
204    fn try_from_with_ssl_require() {
205        let config = DatabaseConfig {
206            ssl: true,
207            ssl_verify_cert: false,
208            ..base_config()
209        };
210        let opts = connect_options(&config);
211
212        assert!(
213            matches!(opts.get_ssl_mode(), PgSslMode::Require),
214            "expected Require, got {:?}",
215            opts.get_ssl_mode()
216        );
217    }
218
219    #[test]
220    fn try_from_with_ssl_verify_ca() {
221        let config = DatabaseConfig {
222            ssl: true,
223            ssl_verify_cert: true,
224            ..base_config()
225        };
226        let opts = connect_options(&config);
227
228        assert!(
229            matches!(opts.get_ssl_mode(), PgSslMode::VerifyCa),
230            "expected VerifyCa, got {:?}",
231            opts.get_ssl_mode()
232        );
233    }
234
235    #[test]
236    fn try_from_without_database_name() {
237        let config = DatabaseConfig {
238            name: None,
239            ..base_config()
240        };
241        let opts = connect_options(&config);
242
243        assert_eq!(opts.get_database(), None);
244    }
245
246    #[test]
247    fn try_from_without_password() {
248        let config = DatabaseConfig {
249            password: None,
250            ..base_config()
251        };
252        let opts = connect_options(&config);
253
254        assert_eq!(opts.get_host(), "pg.example.com");
255    }
256}