database_mcp_postgres/
connection.rs1use std::time::Duration;
8
9use database_mcp_config::DatabaseConfig;
10use database_mcp_sql::Connection;
11use database_mcp_sql::SqlError;
12use database_mcp_sql::sanitize::validate_ident;
13use moka::future::Cache;
14use sqlx::postgres::{PgConnectOptions, PgPool, PgSslMode};
15use tracing::info;
16
/// Upper bound passed to the pool cache's `max_capacity` when a
/// `PostgresConnection` is constructed.
pub(crate) const POOL_CACHE_CAPACITY: u64 = 16;
19
20#[derive(Clone)]
22pub(crate) struct PostgresConnection {
23 config: DatabaseConfig,
24 pools: Cache<String, PgPool>,
25}
26
27impl std::fmt::Debug for PostgresConnection {
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 f.debug_struct("PostgresConnection")
30 .field("default_database_name", &self.default_database_name())
31 .finish_non_exhaustive()
32 }
33}
34
35impl PostgresConnection {
36 pub(crate) fn new(config: &DatabaseConfig) -> Self {
41 info!(
42 "PostgreSQL lazy connection pool created (max size: {})",
43 config.max_pool_size
44 );
45
46 let pools = Cache::builder()
47 .max_capacity(POOL_CACHE_CAPACITY)
48 .eviction_listener(|_key, pool: PgPool, _cause| {
49 tokio::spawn(async move {
50 pool.close().await;
51 });
52 })
53 .build();
54
55 Self {
56 config: config.clone(),
57 pools,
58 }
59 }
60
61 pub(crate) fn default_database_name(&self) -> &str {
63 self.config
64 .name
65 .as_deref()
66 .filter(|n| !n.is_empty())
67 .unwrap_or(&self.config.user)
68 }
69
70 pub(crate) async fn invalidate(&self, name: &str) {
74 self.pools.invalidate(name).await;
75 }
76
77 pub(crate) async fn pool(&self, target: Option<&str>) -> Result<PgPool, SqlError> {
86 let database = match target {
87 Some(name) if !name.is_empty() => name,
88 _ => self.default_database_name(),
89 };
90
91 if let Some(pool) = self.pools.get(database).await {
92 return Ok(pool);
93 }
94
95 if database != self.default_database_name() {
96 validate_ident(database)?;
97 }
98
99 let pool = self
100 .pools
101 .get_with(database.to_owned(), async { create_lazy_pool(&self.config, database) })
102 .await;
103
104 Ok(pool)
105 }
106}
107
108impl Connection for PostgresConnection {
109 type DB = sqlx::Postgres;
110
111 async fn pool(&self, target: Option<&str>) -> Result<sqlx::Pool<Self::DB>, SqlError> {
112 self.pool(target).await
113 }
114
115 fn query_timeout(&self) -> Option<u64> {
116 self.config.query_timeout
117 }
118}
119
120fn create_lazy_pool(config: &DatabaseConfig, database: &str) -> PgPool {
126 let mut conn_ops = PgConnectOptions::new_without_pgpass()
127 .host(&config.host)
128 .port(config.port)
129 .username(&config.user);
130
131 if let Some(ref password) = config.password {
132 conn_ops = conn_ops.password(password);
133 }
134 if !database.is_empty() {
135 conn_ops = conn_ops.database(database);
136 }
137
138 if config.ssl {
139 conn_ops = if config.ssl_verify_cert {
140 conn_ops.ssl_mode(PgSslMode::VerifyCa)
141 } else {
142 conn_ops.ssl_mode(PgSslMode::Require)
143 };
144 if let Some(ref ca) = config.ssl_ca {
145 conn_ops = conn_ops.ssl_root_cert(ca);
146 }
147 if let Some(ref cert) = config.ssl_cert {
148 conn_ops = conn_ops.ssl_client_cert(cert);
149 }
150 if let Some(ref key) = config.ssl_key {
151 conn_ops = conn_ops.ssl_client_key(key);
152 }
153 }
154
155 let mut pool_opts = sqlx::pool::PoolOptions::new()
156 .max_connections(config.max_pool_size)
157 .min_connections(DatabaseConfig::DEFAULT_MIN_CONNECTIONS)
158 .idle_timeout(Duration::from_secs(DatabaseConfig::DEFAULT_IDLE_TIMEOUT_SECS))
159 .max_lifetime(Duration::from_secs(DatabaseConfig::DEFAULT_MAX_LIFETIME_SECS));
160
161 if let Some(timeout) = config.connection_timeout {
162 pool_opts = pool_opts.acquire_timeout(Duration::from_secs(timeout));
163 }
164
165 pool_opts.connect_lazy_with(conn_ops)
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use database_mcp_config::DatabaseBackend;

    /// Baseline configuration shared by the tests below.
    fn base_config() -> DatabaseConfig {
        DatabaseConfig {
            backend: DatabaseBackend::Postgres,
            host: "pg.example.com".into(),
            port: 5433,
            user: "pgadmin".into(),
            password: Some("pgpass".into()),
            name: Some("mydb".into()),
            ..DatabaseConfig::default()
        }
    }

    /// A freshly built pool holds no connections.
    #[tokio::test]
    async fn create_lazy_pool_returns_idle_pool() {
        let config = base_config();
        let pool = create_lazy_pool(&config, "mydb");
        assert_eq!(pool.size(), 0, "pool should be lazy (no connections yet)");
    }

    /// A missing password does not prevent pool construction.
    #[tokio::test]
    async fn create_lazy_pool_without_password() {
        let config = DatabaseConfig {
            password: None,
            ..base_config()
        };
        let pool = create_lazy_pool(&config, "mydb");
        assert_eq!(pool.size(), 0);
    }

    /// An empty database name does not prevent pool construction.
    #[tokio::test]
    async fn create_lazy_pool_without_database_name() {
        let config = DatabaseConfig {
            name: None,
            ..base_config()
        };
        let pool = create_lazy_pool(&config, "");
        assert_eq!(pool.size(), 0);
    }

    /// The configured name is used as the default database.
    #[tokio::test]
    async fn default_database_name_derived_from_config() {
        let config = base_config();
        let connection = PostgresConnection::new(&config);
        assert_eq!(connection.default_database_name(), "mydb");
    }

    /// Without a configured name the user name becomes the default database.
    #[tokio::test]
    async fn defaults_db_to_username_when_name_missing() {
        let config = DatabaseConfig {
            name: None,
            ..base_config()
        };
        let connection = PostgresConnection::new(&config);
        assert_eq!(connection.default_database_name(), "pgadmin");
    }

    /// Requesting a pool without a target succeeds.
    #[tokio::test]
    async fn none_target_returns_default_pool() {
        let config = base_config();
        let connection = PostgresConnection::new(&config);
        let outcome = connection.pool(None).await;
        outcome.expect("None target should succeed");
    }

    /// A target other than the default database is accepted.
    #[tokio::test]
    async fn arbitrary_target_database_is_permitted() {
        let config = base_config();
        let connection = PostgresConnection::new(&config);
        let outcome = connection.pool(Some("any_db")).await;
        outcome.expect("any database should be permitted");
    }

    /// Requesting one pool more than the cap leaves at most the cap cached.
    #[tokio::test]
    async fn pool_cache_respects_capacity_const() {
        let config = base_config();
        let connection = PostgresConnection::new(&config);

        for name in (0..=POOL_CACHE_CAPACITY).map(|i| format!("db_{i}")) {
            connection.pool(Some(&name)).await.expect("pool should succeed");
        }
        // Flush pending cache maintenance before reading the entry count.
        connection.pools.run_pending_tasks().await;

        let cached = connection.pools.entry_count();
        assert!(
            cached <= POOL_CACHE_CAPACITY,
            "cached pools exceeded cap: {} > {POOL_CACHE_CAPACITY}",
            cached
        );
    }
}