database_mcp_postgres/
adapter.rs1use std::time::Duration;
8
9use database_mcp_config::DatabaseConfig;
10use database_mcp_server::AppError;
11use database_mcp_sql::identifier::validate_identifier;
12use moka::future::Cache;
13use sqlx::PgPool;
14use sqlx::postgres::{PgConnectOptions, PgPoolOptions, PgSslMode};
15use tracing::info;
16
/// Maximum number of entries in the per-database pool cache.
///
/// The default database's pool is stored separately on the adapter and does
/// not count toward this limit; only pools for other databases are cached.
const POOL_CACHE_CAPACITY: u64 = 6;
19
/// PostgreSQL adapter holding a default connection pool plus a bounded cache
/// of pools for other databases on the same server.
#[derive(Clone)]
pub struct PostgresAdapter {
    /// Connection settings the adapter was created from; cloned and
    /// re-targeted when building pools for other databases.
    pub(crate) config: DatabaseConfig,
    /// Name of the database served by `default_pool` (`config.name` when set
    /// and non-empty, otherwise `config.user`).
    pub(crate) default_db: String,
    /// Lazily-connected pool for `default_db`.
    default_pool: PgPool,
    /// Pools for non-default databases, keyed by database name.
    pub(crate) pools: Cache<String, PgPool>,
}
33
34impl std::fmt::Debug for PostgresAdapter {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 f.debug_struct("PostgresAdapter")
37 .field("read_only", &self.config.read_only)
38 .field("default_db", &self.default_db)
39 .finish_non_exhaustive()
40 }
41}
42
43impl PostgresAdapter {
44 #[must_use]
49 pub fn new(config: &DatabaseConfig) -> Self {
50 let default_db = config
52 .name
53 .as_deref()
54 .filter(|n| !n.is_empty())
55 .map_or_else(|| config.user.clone(), String::from);
56
57 let default_pool = pool_options(config).connect_lazy_with(connect_options(config));
58
59 info!(
60 "PostgreSQL lazy connection pool created (max size: {})",
61 config.max_pool_size
62 );
63
64 let pools = Cache::builder()
65 .max_capacity(POOL_CACHE_CAPACITY)
66 .eviction_listener(|_key, pool: PgPool, _cause| {
67 tokio::spawn(async move {
68 pool.close().await;
69 });
70 })
71 .build();
72
73 Self {
74 config: config.clone(),
75 default_db,
76 default_pool,
77 pools,
78 }
79 }
80
81 pub(crate) fn quote_identifier(name: &str) -> String {
83 database_mcp_sql::identifier::quote_identifier(name, '"')
84 }
85
86 pub(crate) async fn get_pool(&self, database: Option<&str>) -> Result<PgPool, AppError> {
98 let db_key = match database {
99 Some(name) if !name.is_empty() => name,
100 _ => return Ok(self.default_pool.clone()),
101 };
102
103 if db_key == self.default_db {
105 return Ok(self.default_pool.clone());
106 }
107
108 if let Some(pool) = self.pools.get(db_key).await {
110 return Ok(pool);
111 }
112
113 validate_identifier(db_key)?;
115
116 let config = self.config.clone();
117 let db_key_owned = db_key.to_owned();
118
119 let pool = self
120 .pools
121 .get_with(db_key_owned, async {
122 let mut cfg = config;
123 cfg.name = Some(db_key.to_owned());
124 pool_options(&cfg).connect_lazy_with(connect_options(&cfg))
125 })
126 .await;
127
128 Ok(pool)
129 }
130}
131
132fn pool_options(config: &DatabaseConfig) -> PgPoolOptions {
134 let mut opts = PgPoolOptions::new()
135 .max_connections(config.max_pool_size)
136 .min_connections(DatabaseConfig::DEFAULT_MIN_CONNECTIONS)
137 .idle_timeout(Duration::from_secs(DatabaseConfig::DEFAULT_IDLE_TIMEOUT_SECS))
138 .max_lifetime(Duration::from_secs(DatabaseConfig::DEFAULT_MAX_LIFETIME_SECS));
139
140 if let Some(timeout) = config.connection_timeout {
141 opts = opts.acquire_timeout(Duration::from_secs(timeout));
142 }
143
144 opts
145}
146
147fn connect_options(config: &DatabaseConfig) -> PgConnectOptions {
153 let mut opts = PgConnectOptions::new_without_pgpass()
154 .host(&config.host)
155 .port(config.port)
156 .username(&config.user);
157
158 if let Some(ref password) = config.password {
159 opts = opts.password(password);
160 }
161 if let Some(ref name) = config.name
162 && !name.is_empty()
163 {
164 opts = opts.database(name);
165 }
166
167 if config.ssl {
168 opts = if config.ssl_verify_cert {
169 opts.ssl_mode(PgSslMode::VerifyCa)
170 } else {
171 opts.ssl_mode(PgSslMode::Require)
172 };
173 if let Some(ref ca) = config.ssl_ca {
174 opts = opts.ssl_root_cert(ca);
175 }
176 if let Some(ref cert) = config.ssl_cert {
177 opts = opts.ssl_client_cert(cert);
178 }
179 if let Some(ref key) = config.ssl_key {
180 opts = opts.ssl_client_key(key);
181 }
182 }
183
184 opts
185}
186
#[cfg(test)]
mod tests {
    // Unit tests for `pool_options`, `connect_options` and `PostgresAdapter::new`.
    //
    // NOTE(review): `connect_options` starts from `PgConnectOptions::new_without_pgpass()`,
    // which sqlx seeds from `PG*` environment variables, so the `get_database` assertions
    // assume `PGDATABASE` is unset in the test environment.
    use super::*;
    use database_mcp_config::DatabaseBackend;

    /// A Postgres config with host, port, user, password and database all set.
    fn base_config() -> DatabaseConfig {
        DatabaseConfig {
            backend: DatabaseBackend::Postgres,
            host: "pg.example.com".into(),
            port: 5433,
            user: "pgadmin".into(),
            password: Some("pgpass".into()),
            name: Some("mydb".into()),
            ..DatabaseConfig::default()
        }
    }

    #[test]
    fn pool_options_applies_defaults() {
        let config = base_config();
        let opts = pool_options(&config);

        assert_eq!(opts.get_max_connections(), config.max_pool_size);
        assert_eq!(opts.get_min_connections(), DatabaseConfig::DEFAULT_MIN_CONNECTIONS);
        assert_eq!(
            opts.get_idle_timeout(),
            Some(Duration::from_secs(DatabaseConfig::DEFAULT_IDLE_TIMEOUT_SECS))
        );
        assert_eq!(
            opts.get_max_lifetime(),
            Some(Duration::from_secs(DatabaseConfig::DEFAULT_MAX_LIFETIME_SECS))
        );
    }

    #[test]
    fn pool_options_applies_connection_timeout() {
        let config = DatabaseConfig {
            connection_timeout: Some(7),
            ..base_config()
        };
        let opts = pool_options(&config);

        assert_eq!(opts.get_acquire_timeout(), Duration::from_secs(7));
    }

    #[test]
    fn pool_options_without_connection_timeout_uses_sqlx_default() {
        let config = base_config();
        let opts = pool_options(&config);

        // 30s is sqlx's built-in acquire timeout.
        assert_eq!(opts.get_acquire_timeout(), Duration::from_secs(30));
    }

    #[test]
    fn connect_options_basic_config() {
        let config = base_config();
        let opts = connect_options(&config);

        assert_eq!(opts.get_host(), "pg.example.com");
        assert_eq!(opts.get_port(), 5433);
        assert_eq!(opts.get_username(), "pgadmin");
        assert_eq!(opts.get_database(), Some("mydb"));
    }

    #[test]
    fn connect_options_with_ssl_require() {
        let config = DatabaseConfig {
            ssl: true,
            ssl_verify_cert: false,
            ..base_config()
        };
        let opts = connect_options(&config);

        assert!(
            matches!(opts.get_ssl_mode(), PgSslMode::Require),
            "expected Require, got {:?}",
            opts.get_ssl_mode()
        );
    }

    #[test]
    fn connect_options_with_ssl_verify_ca() {
        let config = DatabaseConfig {
            ssl: true,
            ssl_verify_cert: true,
            ..base_config()
        };
        let opts = connect_options(&config);

        assert!(
            matches!(opts.get_ssl_mode(), PgSslMode::VerifyCa),
            "expected VerifyCa, got {:?}",
            opts.get_ssl_mode()
        );
    }

    #[test]
    fn connect_options_without_database_name() {
        let config = DatabaseConfig {
            name: None,
            ..base_config()
        };
        let opts = connect_options(&config);

        assert_eq!(opts.get_database(), None);
    }

    #[test]
    fn connect_options_ignores_empty_database_name() {
        let config = DatabaseConfig {
            name: Some(String::new()),
            ..base_config()
        };
        let opts = connect_options(&config);

        assert_eq!(opts.get_database(), None);
    }

    #[test]
    fn connect_options_without_password() {
        let config = DatabaseConfig {
            password: None,
            ..base_config()
        };
        let opts = connect_options(&config);

        // The password itself is not inspected; check the remaining fields are kept.
        assert_eq!(opts.get_host(), "pg.example.com");
        assert_eq!(opts.get_port(), 5433);
        assert_eq!(opts.get_username(), "pgadmin");
    }

    #[tokio::test]
    async fn new_creates_lazy_pool() {
        let config = base_config();
        let adapter = PostgresAdapter::new(&config);
        assert_eq!(adapter.default_db, "mydb");
        assert_eq!(adapter.default_pool.size(), 0);
    }

    #[tokio::test]
    async fn new_defaults_db_to_username() {
        let config = DatabaseConfig {
            name: None,
            ..base_config()
        };
        let adapter = PostgresAdapter::new(&config);
        assert_eq!(adapter.default_db, "pgadmin");
    }

    #[tokio::test]
    async fn new_treats_empty_name_as_unset() {
        let config = DatabaseConfig {
            name: Some(String::new()),
            ..base_config()
        };
        let adapter = PostgresAdapter::new(&config);
        assert_eq!(adapter.default_db, "pgadmin");
    }
}