1use std::time::Duration;
8
9use database_mcp_config::DatabaseConfig;
10use database_mcp_server::AppError;
11use database_mcp_server::server_info;
12use database_mcp_sql::identifier::validate_identifier;
13use moka::future::Cache;
14use rmcp::RoleServer;
15use rmcp::handler::server::router::tool::ToolRouter;
16use rmcp::handler::server::tool::ToolCallContext;
17use rmcp::model::{CallToolRequestParams, CallToolResult, ListToolsResult, PaginatedRequestParams, ServerInfo, Tool};
18use rmcp::service::RequestContext;
19use rmcp::{ErrorData, ServerHandler};
20use sqlx::PgPool;
21use sqlx::postgres::{PgConnectOptions, PgPoolOptions, PgSslMode};
22use tracing::info;
23
24use crate::tools::{
25 CreateDatabaseTool, DropDatabaseTool, DropTableTool, ExplainQueryTool, GetTableSchemaTool, ListDatabasesTool,
26 ListTablesTool, ReadQueryTool, WriteQueryTool,
27};
28
/// Maximum number of non-default database pools kept in the handler's pool cache.
///
/// The default database always uses its own dedicated pool and is not counted here;
/// once this many other databases have pools, the cache evicts entries (and the
/// eviction listener closes the evicted pool).
const POOL_CACHE_CAPACITY: u64 = 6;

/// Human-readable server description reported from `get_info`.
const DESCRIPTION: &str = "Database MCP Server for PostgreSQL";

/// Usage instructions reported to MCP clients from `get_info`.
///
/// Keep this in sync with `build_tool_router`: every registered tool should be
/// mentioned, and the read-only constraint must list every tool that is hidden
/// when `read_only` is set.
const INSTRUCTIONS: &str = r"## Workflow

1. Call `list_databases` to discover available databases.
2. Call `list_tables` with a `database_name` to see its tables.
3. Call `get_table_schema` with `database_name` and `table_name` to inspect columns, types, and foreign keys before writing queries.
4. Use `read_query` for read-only SQL (SELECT, SHOW, EXPLAIN).
5. Use `explain_query` to get the execution plan for a query.
6. Use `write_query` for data changes (INSERT, UPDATE, DELETE, CREATE, ALTER, DROP).
7. Use `create_database` to create a new database.
8. Use `drop_database` to drop an existing database.
9. Use `drop_table` to drop a table from a database.

Tools accept an optional `database_name` parameter to query across databases without reconnecting.

## Constraints

- The `write_query`, `create_database`, `drop_database`, and `drop_table` tools are hidden when read-only mode is active.
- Multi-statement queries are not supported. Send one statement per request.";
52
/// MCP server handler that exposes database tools backed by PostgreSQL.
///
/// Requests for the default database go through `default_pool`; requests naming any
/// other database are served from a bounded cache of per-database pools (`pools`).
#[derive(Clone)]
pub struct PostgresHandler {
    /// Configuration the handler was built from. `get_pool` clones it and swaps in a
    /// different database name when it opens a pool for another database.
    pub(crate) config: DatabaseConfig,
    /// Name treated as the default database: `config.name` when set and non-empty,
    /// otherwise `config.user`. Requests for this name are served by `default_pool`.
    pub(crate) default_db: String,
    /// Lazily-connected pool used when no database, an empty name, or `default_db`
    /// is requested.
    default_pool: PgPool,
    /// Cache of pools for non-default databases, keyed by database name. Its size is
    /// bounded, and evicted pools are closed by the cache's eviction listener.
    pub(crate) pools: Cache<String, PgPool>,
    /// Registered tools. Write tools are omitted when `config.read_only` is set.
    tool_router: ToolRouter<Self>,
}
67
68impl std::fmt::Debug for PostgresHandler {
69 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70 f.debug_struct("PostgresHandler")
71 .field("read_only", &self.config.read_only)
72 .field("default_db", &self.default_db)
73 .finish_non_exhaustive()
74 }
75}
76
77impl PostgresHandler {
78 #[must_use]
84 pub fn new(config: &DatabaseConfig) -> Self {
85 let default_db = config
87 .name
88 .as_deref()
89 .filter(|n| !n.is_empty())
90 .map_or_else(|| config.user.clone(), String::from);
91
92 let default_pool = pool_options(config).connect_lazy_with(connect_options(config));
93
94 info!(
95 "PostgreSQL lazy connection pool created (max size: {})",
96 config.max_pool_size
97 );
98
99 let pools = Cache::builder()
100 .max_capacity(POOL_CACHE_CAPACITY)
101 .eviction_listener(|_key, pool: PgPool, _cause| {
102 tokio::spawn(async move {
103 pool.close().await;
104 });
105 })
106 .build();
107
108 Self {
109 config: config.clone(),
110 default_db,
111 default_pool,
112 pools,
113 tool_router: build_tool_router(config.read_only),
114 }
115 }
116
117 pub(crate) fn quote_identifier(name: &str) -> String {
119 database_mcp_sql::identifier::quote_identifier(name, '"')
120 }
121
122 pub(crate) async fn get_pool(&self, database: Option<&str>) -> Result<PgPool, AppError> {
134 let db_key = match database {
135 Some(name) if !name.is_empty() => name,
136 _ => return Ok(self.default_pool.clone()),
137 };
138
139 if db_key == self.default_db {
141 return Ok(self.default_pool.clone());
142 }
143
144 if let Some(pool) = self.pools.get(db_key).await {
146 return Ok(pool);
147 }
148
149 validate_identifier(db_key)?;
151
152 let config = self.config.clone();
153 let db_key_owned = db_key.to_owned();
154
155 let pool = self
156 .pools
157 .get_with(db_key_owned, async {
158 let mut cfg = config;
159 cfg.name = Some(db_key.to_owned());
160 pool_options(&cfg).connect_lazy_with(connect_options(&cfg))
161 })
162 .await;
163
164 Ok(pool)
165 }
166}
167
168fn pool_options(config: &DatabaseConfig) -> PgPoolOptions {
170 let mut opts = PgPoolOptions::new()
171 .max_connections(config.max_pool_size)
172 .min_connections(DatabaseConfig::DEFAULT_MIN_CONNECTIONS)
173 .idle_timeout(Duration::from_secs(DatabaseConfig::DEFAULT_IDLE_TIMEOUT_SECS))
174 .max_lifetime(Duration::from_secs(DatabaseConfig::DEFAULT_MAX_LIFETIME_SECS));
175
176 if let Some(timeout) = config.connection_timeout {
177 opts = opts.acquire_timeout(Duration::from_secs(timeout));
178 }
179
180 opts
181}
182
183fn connect_options(config: &DatabaseConfig) -> PgConnectOptions {
189 let mut opts = PgConnectOptions::new_without_pgpass()
190 .host(&config.host)
191 .port(config.port)
192 .username(&config.user);
193
194 if let Some(ref password) = config.password {
195 opts = opts.password(password);
196 }
197 if let Some(ref name) = config.name
198 && !name.is_empty()
199 {
200 opts = opts.database(name);
201 }
202
203 if config.ssl {
204 opts = if config.ssl_verify_cert {
205 opts.ssl_mode(PgSslMode::VerifyCa)
206 } else {
207 opts.ssl_mode(PgSslMode::Require)
208 };
209 if let Some(ref ca) = config.ssl_ca {
210 opts = opts.ssl_root_cert(ca);
211 }
212 if let Some(ref cert) = config.ssl_cert {
213 opts = opts.ssl_client_cert(cert);
214 }
215 if let Some(ref key) = config.ssl_key {
216 opts = opts.ssl_client_key(key);
217 }
218 }
219
220 opts
221}
222
223fn build_tool_router(read_only: bool) -> ToolRouter<PostgresHandler> {
225 let mut router = ToolRouter::new()
226 .with_async_tool::<ListDatabasesTool>()
227 .with_async_tool::<ListTablesTool>()
228 .with_async_tool::<GetTableSchemaTool>()
229 .with_async_tool::<ReadQueryTool>()
230 .with_async_tool::<ExplainQueryTool>();
231
232 if !read_only {
233 router = router
234 .with_async_tool::<CreateDatabaseTool>()
235 .with_async_tool::<DropDatabaseTool>()
236 .with_async_tool::<DropTableTool>()
237 .with_async_tool::<WriteQueryTool>();
238 }
239 router
240}
241
242impl ServerHandler for PostgresHandler {
243 fn get_info(&self) -> ServerInfo {
244 let mut info = server_info();
245 info.server_info.description = Some(DESCRIPTION.into());
246 info.instructions = Some(INSTRUCTIONS.into());
247 info
248 }
249
250 async fn call_tool(
251 &self,
252 request: CallToolRequestParams,
253 context: RequestContext<RoleServer>,
254 ) -> Result<CallToolResult, ErrorData> {
255 let tcc = ToolCallContext::new(self, request, context);
256 self.tool_router.call(tcc).await
257 }
258
259 async fn list_tools(
260 &self,
261 _request: Option<PaginatedRequestParams>,
262 _context: RequestContext<RoleServer>,
263 ) -> Result<ListToolsResult, ErrorData> {
264 Ok(ListToolsResult {
265 tools: self.tool_router.list_all(),
266 next_cursor: None,
267 meta: None,
268 })
269 }
270
271 fn get_tool(&self, name: &str) -> Option<Tool> {
272 self.tool_router.get(name).cloned()
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use database_mcp_config::DatabaseBackend;

    /// Tools registered in every mode.
    const READ_TOOLS: [&str; 5] = [
        "list_databases",
        "list_tables",
        "get_table_schema",
        "read_query",
        "explain_query",
    ];

    /// Tools registered only when read-only mode is off.
    const WRITE_TOOLS: [&str; 4] = [
        "create_database",
        "drop_database",
        "drop_table",
        "write_query",
    ];

    /// PostgreSQL config with non-default host, port, credentials and database name.
    fn base_config() -> DatabaseConfig {
        DatabaseConfig {
            backend: DatabaseBackend::Postgres,
            host: "pg.example.com".into(),
            port: 5433,
            user: "pgadmin".into(),
            password: Some("pgpass".into()),
            name: Some("mydb".into()),
            ..DatabaseConfig::default()
        }
    }

    /// Builds a handler from `base_config` with the given read-only flag.
    fn handler(read_only: bool) -> PostgresHandler {
        let config = DatabaseConfig {
            read_only,
            ..base_config()
        };
        PostgresHandler::new(&config)
    }

    #[test]
    fn pool_options_applies_defaults() {
        let config = base_config();
        let opts = pool_options(&config);

        let expected_idle = Duration::from_secs(DatabaseConfig::DEFAULT_IDLE_TIMEOUT_SECS);
        let expected_lifetime = Duration::from_secs(DatabaseConfig::DEFAULT_MAX_LIFETIME_SECS);

        assert_eq!(opts.get_max_connections(), config.max_pool_size);
        assert_eq!(opts.get_min_connections(), DatabaseConfig::DEFAULT_MIN_CONNECTIONS);
        assert_eq!(opts.get_idle_timeout(), Some(expected_idle));
        assert_eq!(opts.get_max_lifetime(), Some(expected_lifetime));
    }

    #[test]
    fn pool_options_applies_connection_timeout() {
        let opts = pool_options(&DatabaseConfig {
            connection_timeout: Some(7),
            ..base_config()
        });

        assert_eq!(opts.get_acquire_timeout(), Duration::from_secs(7));
    }

    #[test]
    fn pool_options_without_connection_timeout_uses_sqlx_default() {
        let opts = pool_options(&base_config());

        assert_eq!(opts.get_acquire_timeout(), Duration::from_secs(30));
    }

    #[test]
    fn connect_options_basic_config() {
        let opts = connect_options(&base_config());

        assert_eq!(opts.get_host(), "pg.example.com");
        assert_eq!(opts.get_port(), 5433);
        assert_eq!(opts.get_username(), "pgadmin");
        assert_eq!(opts.get_database(), Some("mydb"));
    }

    #[test]
    fn connect_options_with_ssl_require() {
        let opts = connect_options(&DatabaseConfig {
            ssl: true,
            ssl_verify_cert: false,
            ..base_config()
        });

        let mode = opts.get_ssl_mode();
        assert!(
            matches!(mode, PgSslMode::Require),
            "expected Require, got {:?}",
            mode
        );
    }

    #[test]
    fn connect_options_with_ssl_verify_ca() {
        let opts = connect_options(&DatabaseConfig {
            ssl: true,
            ssl_verify_cert: true,
            ..base_config()
        });

        let mode = opts.get_ssl_mode();
        assert!(
            matches!(mode, PgSslMode::VerifyCa),
            "expected VerifyCa, got {:?}",
            mode
        );
    }

    #[test]
    fn connect_options_without_database_name() {
        let opts = connect_options(&DatabaseConfig {
            name: None,
            ..base_config()
        });

        assert_eq!(opts.get_database(), None);
    }

    #[test]
    fn connect_options_without_password() {
        let opts = connect_options(&DatabaseConfig {
            password: None,
            ..base_config()
        });

        assert_eq!(opts.get_host(), "pg.example.com");
    }

    #[tokio::test]
    async fn new_creates_lazy_pool() {
        let handler = PostgresHandler::new(&base_config());

        assert_eq!(handler.default_db, "mydb");
        assert_eq!(handler.default_pool.size(), 0);
    }

    #[tokio::test]
    async fn new_defaults_db_to_username() {
        let handler = PostgresHandler::new(&DatabaseConfig {
            name: None,
            ..base_config()
        });

        assert_eq!(handler.default_db, "pgadmin");
    }

    #[tokio::test]
    async fn router_exposes_all_nine_tools_in_read_write_mode() {
        let router = handler(false).tool_router;
        for name in READ_TOOLS.into_iter().chain(WRITE_TOOLS) {
            assert!(router.has_route(name), "missing tool: {name}");
        }
    }

    #[tokio::test]
    async fn router_hides_write_tools_in_read_only_mode() {
        let router = handler(true).tool_router;
        for name in READ_TOOLS {
            assert!(router.has_route(name));
        }
        for name in WRITE_TOOLS {
            assert!(!router.has_route(name));
        }
    }
}