Skip to main content

database_mcp_postgres/
connection.rs

1//! `PostgreSQL` connection configuration and backend definition.
2
3use database_mcp_backend::error::AppError;
4use database_mcp_backend::identifier::validate_identifier;
5use database_mcp_config::DatabaseConfig;
6use moka::future::Cache;
7use sqlx::PgPool;
8use sqlx::postgres::{PgConnectOptions, PgPoolOptions, PgSslMode};
9use tracing::info;
10
/// Upper bound on how many connection pools the pool cache holds at once.
/// The pool for the default database counts toward this limit.
const POOL_CACHE_CAPACITY: u64 = 6;
13
14/// `PostgreSQL` database backend.
15///
16/// All connection pools — including the default — live in a single
17/// concurrent cache keyed by database name. No external mutex required.
18#[derive(Clone)]
19pub struct PostgresBackend {
20    config: DatabaseConfig,
21    default_db: String,
22    pools: Cache<String, PgPool>,
23    pub read_only: bool,
24}
25
26impl std::fmt::Debug for PostgresBackend {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        f.debug_struct("PostgresBackend")
29            .field("read_only", &self.read_only)
30            .field("default_db", &self.default_db)
31            .finish_non_exhaustive()
32    }
33}
34
35impl PostgresBackend {
36    /// Creates a new `PostgreSQL` backend from configuration.
37    ///
38    /// Stores a clone of the configuration for constructing connection options
39    /// for non-default databases at runtime. The initial pool is placed into
40    /// the shared cache keyed by the configured database name.
41    ///
42    /// # Errors
43    ///
44    /// Returns [`AppError::Connection`] if the connection fails.
45    pub async fn new(config: &DatabaseConfig) -> Result<Self, AppError> {
46        let pool = PgPoolOptions::new()
47            .max_connections(config.max_pool_size)
48            .connect_with(connect_options(config))
49            .await
50            .map_err(|e| AppError::Connection(format!("Failed to connect to PostgreSQL: {e}")))?;
51
52        info!(
53            "PostgreSQL connection pool initialized (max size: {})",
54            config.max_pool_size
55        );
56
57        // PostgreSQL defaults to a database named after the connecting user.
58        let default_db = config
59            .name
60            .as_deref()
61            .filter(|n| !n.is_empty())
62            .map_or_else(|| config.user.clone(), String::from);
63
64        let pools = Cache::builder()
65            .max_capacity(POOL_CACHE_CAPACITY)
66            .eviction_listener(|_key, pool: PgPool, _cause| {
67                tokio::spawn(async move {
68                    pool.close().await;
69                });
70            })
71            .build();
72
73        pools.insert(default_db.clone(), pool).await;
74
75        Ok(Self {
76            config: config.clone(),
77            default_db,
78            pools,
79            read_only: config.read_only,
80        })
81    }
82
83    /// Wraps `name` in double quotes for safe use in `PostgreSQL` SQL statements.
84    pub(crate) fn quote_identifier(name: &str) -> String {
85        database_mcp_backend::identifier::quote_identifier(name, '"')
86    }
87
88    /// Returns a connection pool for the requested database.
89    ///
90    /// Resolves `None` or empty names to the default pool. On a cache miss
91    /// a new pool is created and cached. Evicted pools are closed via the
92    /// cache's eviction listener.
93    ///
94    /// # Errors
95    ///
96    /// Returns [`AppError::InvalidIdentifier`] if the database name fails
97    /// validation, or [`AppError::Connection`] if the new pool cannot connect.
98    pub(crate) async fn get_pool(&self, database: Option<&str>) -> Result<PgPool, AppError> {
99        let db_key = match database {
100            Some(name) if !name.is_empty() => name,
101            _ => &self.default_db,
102        };
103
104        if let Some(pool) = self.pools.get(db_key).await {
105            return Ok(pool);
106        }
107
108        // Cache miss — validate then create a new pool.
109        validate_identifier(db_key)?;
110
111        let config = self.config.clone();
112        let db_key_owned = db_key.to_owned();
113
114        let pool = self
115            .pools
116            .try_get_with(db_key_owned, async {
117                let mut cfg = config;
118                cfg.name = Some(db_key.to_owned());
119                PgPoolOptions::new()
120                    .max_connections(cfg.max_pool_size)
121                    .connect_with(connect_options(&cfg))
122                    .await
123                    .map_err(|e| {
124                        AppError::Connection(format!("Failed to connect to PostgreSQL database '{db_key}': {e}"))
125                    })
126            })
127            .await
128            .map_err(|e| match e.as_ref() {
129                AppError::Connection(msg) => AppError::Connection(msg.clone()),
130                other => AppError::Connection(other.to_string()),
131            })?;
132
133        Ok(pool)
134    }
135}
136
137/// Builds [`PgConnectOptions`] from a [`DatabaseConfig`].
138///
139/// Uses [`PgConnectOptions::new_without_pgpass`] to avoid unintended
140/// `PG*` environment variable influence, since our config already
141/// resolves values from CLI/env.
142fn connect_options(config: &DatabaseConfig) -> PgConnectOptions {
143    let mut opts = PgConnectOptions::new_without_pgpass()
144        .host(&config.host)
145        .port(config.port)
146        .username(&config.user);
147
148    if let Some(ref password) = config.password {
149        opts = opts.password(password);
150    }
151    if let Some(ref name) = config.name
152        && !name.is_empty()
153    {
154        opts = opts.database(name);
155    }
156
157    if config.ssl {
158        opts = if config.ssl_verify_cert {
159            opts.ssl_mode(PgSslMode::VerifyCa)
160        } else {
161            opts.ssl_mode(PgSslMode::Require)
162        };
163        if let Some(ref ca) = config.ssl_ca {
164            opts = opts.ssl_root_cert(ca);
165        }
166        if let Some(ref cert) = config.ssl_cert {
167            opts = opts.ssl_client_cert(cert);
168        }
169        if let Some(ref key) = config.ssl_key {
170            opts = opts.ssl_client_key(key);
171        }
172    }
173
174    opts
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use database_mcp_config::DatabaseBackend;

    /// Baseline configuration. Individual tests override single fields via
    /// `options_with`.
    fn base_config() -> DatabaseConfig {
        DatabaseConfig {
            backend: DatabaseBackend::Postgres,
            host: "pg.example.com".into(),
            port: 5433,
            user: "pgadmin".into(),
            password: Some("pgpass".into()),
            name: Some("mydb".into()),
            ..DatabaseConfig::default()
        }
    }

    /// Builds connect options from `base_config` after applying `tweak` to it.
    fn options_with(tweak: impl FnOnce(&mut DatabaseConfig)) -> PgConnectOptions {
        let mut config = base_config();
        tweak(&mut config);
        connect_options(&config)
    }

    #[test]
    fn connect_options_basic_config() {
        let opts = options_with(|_| {});

        assert_eq!(opts.get_host(), "pg.example.com");
        assert_eq!(opts.get_port(), 5433);
        assert_eq!(opts.get_username(), "pgadmin");
        assert_eq!(opts.get_database(), Some("mydb"));
    }

    #[test]
    fn connect_options_ssl_require() {
        let opts = options_with(|c| {
            c.ssl = true;
            c.ssl_verify_cert = false;
        });
        let mode = opts.get_ssl_mode();

        assert!(matches!(mode, PgSslMode::Require), "expected Require, got {mode:?}");
    }

    #[test]
    fn connect_options_ssl_verify_ca() {
        let opts = options_with(|c| {
            c.ssl = true;
            c.ssl_verify_cert = true;
        });
        let mode = opts.get_ssl_mode();

        assert!(matches!(mode, PgSslMode::VerifyCa), "expected VerifyCa, got {mode:?}");
    }

    #[test]
    fn connect_options_without_database_name() {
        let opts = options_with(|c| c.name = None);

        assert_eq!(opts.get_database(), None);
    }

    #[test]
    fn connect_options_without_password() {
        let opts = options_with(|c| c.password = None);

        assert_eq!(opts.get_host(), "pg.example.com");
    }
}