Skip to main content

database_mcp_postgres/
handler.rs

1//! `PostgreSQL` handler: connection pool cache, MCP tool router, and `ServerHandler` impl.
2//!
3//! Creates a lazy default pool via [`PgPoolOptions::connect_lazy_with`].
4//! Non-default database pools are created on demand and cached in a
5//! moka [`Cache`].
6
7use std::time::Duration;
8
9use database_mcp_config::DatabaseConfig;
10use database_mcp_server::AppError;
11use database_mcp_server::server_info;
12use database_mcp_sql::identifier::validate_identifier;
13use moka::future::Cache;
14use rmcp::RoleServer;
15use rmcp::handler::server::router::tool::ToolRouter;
16use rmcp::handler::server::tool::ToolCallContext;
17use rmcp::model::{CallToolRequestParams, CallToolResult, ListToolsResult, PaginatedRequestParams, ServerInfo, Tool};
18use rmcp::service::RequestContext;
19use rmcp::{ErrorData, ServerHandler};
20use sqlx::PgPool;
21use sqlx::postgres::{PgConnectOptions, PgPoolOptions, PgSslMode};
22use tracing::info;
23
24use crate::tools::{
25    CreateDatabaseTool, DropDatabaseTool, DropTableTool, ExplainQueryTool, GetTableSchemaTool, ListDatabasesTool,
26    ListTablesTool, ReadQueryTool, WriteQueryTool,
27};
28
/// Maximum number of non-default database connection pools to cache.
///
/// The default pool is stored in its own `PostgresHandler` field and is
/// returned by `get_pool` without consulting the cache, so within this
/// file it does not count toward this limit.
const POOL_CACHE_CAPACITY: u64 = 6;

/// Backend-specific description for `PostgreSQL`.
const DESCRIPTION: &str = "Database MCP Server for PostgreSQL";
34
/// Backend-specific instructions for `PostgreSQL`.
///
/// Exposed to clients as the `instructions` field of the server info.
/// The list of tools hidden in read-only mode must stay in sync with the
/// write tools registered by `build_tool_router` (`write_query`,
/// `create_database`, `drop_database`, `drop_table`).
const INSTRUCTIONS: &str = r"## Workflow

1. Call `list_databases` to discover available databases.
2. Call `list_tables` with a `database_name` to see its tables.
3. Call `get_table_schema` with `database_name` and `table_name` to inspect columns, types, and foreign keys before writing queries.
4. Use `read_query` for read-only SQL (SELECT, SHOW, EXPLAIN).
5. Use `write_query` for data changes (INSERT, UPDATE, DELETE, CREATE, ALTER, DROP).
6. Use `create_database` to create a new database.
7. Use `drop_database` to drop an existing database.
8. Use `drop_table` to drop an existing table.

Tools accept an optional `database_name` parameter to query across databases without reconnecting.

## Constraints

- The `write_query`, `create_database`, `drop_database`, and `drop_table` tools are hidden when read-only mode is active.
- Multi-statement queries are not supported. Send one statement per request.";
52
/// `PostgreSQL` database handler.
///
/// The default connection pool is created with
/// [`PgPoolOptions::connect_lazy_with`], which defers all network I/O
/// until the first query. Non-default database pools are created on
/// demand via the moka [`Cache`].
#[derive(Clone)]
pub struct PostgresHandler {
    /// Connection settings the handler was built from; also cloned as the
    /// template when a pool for a non-default database is created.
    pub(crate) config: DatabaseConfig,
    /// Database name that resolves to the default pool: the configured
    /// name, or the user name when no (or an empty) name is configured.
    pub(crate) default_db: String,
    /// Lazily-connecting pool for the default database.
    default_pool: PgPool,
    /// Pools for non-default databases, keyed by database name.
    pub(crate) pools: Cache<String, PgPool>,
    /// Tool router built once in `new`; which tools it contains depends
    /// on `config.read_only`.
    tool_router: ToolRouter<Self>,
}
67
68impl std::fmt::Debug for PostgresHandler {
69    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70        f.debug_struct("PostgresHandler")
71            .field("read_only", &self.config.read_only)
72            .field("default_db", &self.default_db)
73            .finish_non_exhaustive()
74    }
75}
76
77impl PostgresHandler {
78    /// Creates a new `PostgreSQL` handler with a lazy connection pool.
79    ///
80    /// Does **not** establish a database connection. The default pool
81    /// connects on-demand when the first query is executed. The MCP tool
82    /// router is built once here and reused for every request.
83    #[must_use]
84    pub fn new(config: &DatabaseConfig) -> Self {
85        // PostgreSQL defaults to a database named after the connecting user.
86        let default_db = config
87            .name
88            .as_deref()
89            .filter(|n| !n.is_empty())
90            .map_or_else(|| config.user.clone(), String::from);
91
92        let default_pool = pool_options(config).connect_lazy_with(connect_options(config));
93
94        info!(
95            "PostgreSQL lazy connection pool created (max size: {})",
96            config.max_pool_size
97        );
98
99        let pools = Cache::builder()
100            .max_capacity(POOL_CACHE_CAPACITY)
101            .eviction_listener(|_key, pool: PgPool, _cause| {
102                tokio::spawn(async move {
103                    pool.close().await;
104                });
105            })
106            .build();
107
108        Self {
109            config: config.clone(),
110            default_db,
111            default_pool,
112            pools,
113            tool_router: build_tool_router(config.read_only),
114        }
115    }
116
117    /// Wraps `name` in double quotes for safe use in `PostgreSQL` SQL statements.
118    pub(crate) fn quote_identifier(name: &str) -> String {
119        database_mcp_sql::identifier::quote_identifier(name, '"')
120    }
121
122    /// Returns a connection pool for the requested database.
123    ///
124    /// Resolves `None` or empty names to the default lazy pool. On a
125    /// cache miss for a non-default database, a new lazy pool is created
126    /// and cached. Evicted pools are closed via the cache's eviction
127    /// listener.
128    ///
129    /// # Errors
130    ///
131    /// Returns [`AppError::InvalidIdentifier`] if the database name fails
132    /// validation.
133    pub(crate) async fn get_pool(&self, database: Option<&str>) -> Result<PgPool, AppError> {
134        let db_key = match database {
135            Some(name) if !name.is_empty() => name,
136            _ => return Ok(self.default_pool.clone()),
137        };
138
139        // Check if it's the default database by name.
140        if db_key == self.default_db {
141            return Ok(self.default_pool.clone());
142        }
143
144        // Non-default database: check cache first.
145        if let Some(pool) = self.pools.get(db_key).await {
146            return Ok(pool);
147        }
148
149        // Cache miss — validate then create a new lazy pool.
150        validate_identifier(db_key)?;
151
152        let config = self.config.clone();
153        let db_key_owned = db_key.to_owned();
154
155        let pool = self
156            .pools
157            .get_with(db_key_owned, async {
158                let mut cfg = config;
159                cfg.name = Some(db_key.to_owned());
160                pool_options(&cfg).connect_lazy_with(connect_options(&cfg))
161            })
162            .await;
163
164        Ok(pool)
165    }
166}
167
168/// Builds [`PgPoolOptions`] with lifecycle defaults from a [`DatabaseConfig`].
169fn pool_options(config: &DatabaseConfig) -> PgPoolOptions {
170    let mut opts = PgPoolOptions::new()
171        .max_connections(config.max_pool_size)
172        .min_connections(DatabaseConfig::DEFAULT_MIN_CONNECTIONS)
173        .idle_timeout(Duration::from_secs(DatabaseConfig::DEFAULT_IDLE_TIMEOUT_SECS))
174        .max_lifetime(Duration::from_secs(DatabaseConfig::DEFAULT_MAX_LIFETIME_SECS));
175
176    if let Some(timeout) = config.connection_timeout {
177        opts = opts.acquire_timeout(Duration::from_secs(timeout));
178    }
179
180    opts
181}
182
183/// Builds [`PgConnectOptions`] from a [`DatabaseConfig`].
184///
185/// Uses [`PgConnectOptions::new_without_pgpass`] to avoid unintended
186/// `PG*` environment variable influence, since our config already
187/// resolves values from CLI/env.
188fn connect_options(config: &DatabaseConfig) -> PgConnectOptions {
189    let mut opts = PgConnectOptions::new_without_pgpass()
190        .host(&config.host)
191        .port(config.port)
192        .username(&config.user);
193
194    if let Some(ref password) = config.password {
195        opts = opts.password(password);
196    }
197    if let Some(ref name) = config.name
198        && !name.is_empty()
199    {
200        opts = opts.database(name);
201    }
202
203    if config.ssl {
204        opts = if config.ssl_verify_cert {
205            opts.ssl_mode(PgSslMode::VerifyCa)
206        } else {
207            opts.ssl_mode(PgSslMode::Require)
208        };
209        if let Some(ref ca) = config.ssl_ca {
210            opts = opts.ssl_root_cert(ca);
211        }
212        if let Some(ref cert) = config.ssl_cert {
213            opts = opts.ssl_client_cert(cert);
214        }
215        if let Some(ref key) = config.ssl_key {
216            opts = opts.ssl_client_key(key);
217        }
218    }
219
220    opts
221}
222
223/// Builds the tool router, including write tools only when not in read-only mode.
224fn build_tool_router(read_only: bool) -> ToolRouter<PostgresHandler> {
225    let mut router = ToolRouter::new()
226        .with_async_tool::<ListDatabasesTool>()
227        .with_async_tool::<ListTablesTool>()
228        .with_async_tool::<GetTableSchemaTool>()
229        .with_async_tool::<ReadQueryTool>()
230        .with_async_tool::<ExplainQueryTool>();
231
232    if !read_only {
233        router = router
234            .with_async_tool::<CreateDatabaseTool>()
235            .with_async_tool::<DropDatabaseTool>()
236            .with_async_tool::<DropTableTool>()
237            .with_async_tool::<WriteQueryTool>();
238    }
239    router
240}
241
242impl ServerHandler for PostgresHandler {
243    fn get_info(&self) -> ServerInfo {
244        let mut info = server_info();
245        info.server_info.description = Some(DESCRIPTION.into());
246        info.instructions = Some(INSTRUCTIONS.into());
247        info
248    }
249
250    async fn call_tool(
251        &self,
252        request: CallToolRequestParams,
253        context: RequestContext<RoleServer>,
254    ) -> Result<CallToolResult, ErrorData> {
255        let tcc = ToolCallContext::new(self, request, context);
256        self.tool_router.call(tcc).await
257    }
258
259    async fn list_tools(
260        &self,
261        _request: Option<PaginatedRequestParams>,
262        _context: RequestContext<RoleServer>,
263    ) -> Result<ListToolsResult, ErrorData> {
264        Ok(ListToolsResult {
265            tools: self.tool_router.list_all(),
266            next_cursor: None,
267            meta: None,
268        })
269    }
270
271    fn get_tool(&self, name: &str) -> Option<Tool> {
272        self.tool_router.get(name).cloned()
273    }
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use database_mcp_config::DatabaseBackend;

    /// Baseline configuration shared by all tests; individual tests
    /// override single fields with struct-update syntax.
    fn base_config() -> DatabaseConfig {
        DatabaseConfig {
            backend: DatabaseBackend::Postgres,
            host: "pg.example.com".into(),
            port: 5433,
            user: "pgadmin".into(),
            password: Some("pgpass".into()),
            name: Some("mydb".into()),
            ..DatabaseConfig::default()
        }
    }

    /// Builds a handler from the baseline config with the given read-only flag.
    fn handler(read_only: bool) -> PostgresHandler {
        PostgresHandler::new(&DatabaseConfig {
            read_only,
            ..base_config()
        })
    }

    #[test]
    fn pool_options_applies_defaults() {
        let config = base_config();
        let opts = pool_options(&config);

        assert_eq!(opts.get_max_connections(), config.max_pool_size);
        assert_eq!(opts.get_min_connections(), DatabaseConfig::DEFAULT_MIN_CONNECTIONS);
        assert_eq!(
            opts.get_idle_timeout(),
            Some(Duration::from_secs(DatabaseConfig::DEFAULT_IDLE_TIMEOUT_SECS))
        );
        assert_eq!(
            opts.get_max_lifetime(),
            Some(Duration::from_secs(DatabaseConfig::DEFAULT_MAX_LIFETIME_SECS))
        );
    }

    #[test]
    fn pool_options_applies_connection_timeout() {
        let config = DatabaseConfig {
            connection_timeout: Some(7),
            ..base_config()
        };
        let opts = pool_options(&config);

        assert_eq!(opts.get_acquire_timeout(), Duration::from_secs(7));
    }

    #[test]
    fn pool_options_without_connection_timeout_uses_sqlx_default() {
        let config = base_config();
        let opts = pool_options(&config);

        // 30 seconds is the acquire timeout sqlx applies when none is set.
        assert_eq!(opts.get_acquire_timeout(), Duration::from_secs(30));
    }

    #[test]
    fn connect_options_basic_config() {
        let config = base_config();
        let opts = connect_options(&config);

        assert_eq!(opts.get_host(), "pg.example.com");
        assert_eq!(opts.get_port(), 5433);
        assert_eq!(opts.get_username(), "pgadmin");
        assert_eq!(opts.get_database(), Some("mydb"));
    }

    #[test]
    fn connect_options_with_ssl_require() {
        let config = DatabaseConfig {
            ssl: true,
            ssl_verify_cert: false,
            ..base_config()
        };
        let opts = connect_options(&config);

        assert!(
            matches!(opts.get_ssl_mode(), PgSslMode::Require),
            "expected Require, got {:?}",
            opts.get_ssl_mode()
        );
    }

    #[test]
    fn connect_options_with_ssl_verify_ca() {
        let config = DatabaseConfig {
            ssl: true,
            ssl_verify_cert: true,
            ..base_config()
        };
        let opts = connect_options(&config);

        assert!(
            matches!(opts.get_ssl_mode(), PgSslMode::VerifyCa),
            "expected VerifyCa, got {:?}",
            opts.get_ssl_mode()
        );
    }

    #[test]
    fn connect_options_without_database_name() {
        let config = DatabaseConfig {
            name: None,
            ..base_config()
        };
        let opts = connect_options(&config);

        // NOTE(review): presumably relies on `PGDATABASE` being unset in
        // the test environment — confirm against the sqlx defaults.
        assert_eq!(opts.get_database(), None);
    }

    #[test]
    fn connect_options_without_password() {
        let config = DatabaseConfig {
            password: None,
            ..base_config()
        };
        let opts = connect_options(&config);

        // Smoke test: building options without a password must still
        // carry over the remaining settings.
        assert_eq!(opts.get_host(), "pg.example.com");
    }

    #[tokio::test]
    async fn new_creates_lazy_pool() {
        let config = base_config();
        let handler = PostgresHandler::new(&config);
        assert_eq!(handler.default_db, "mydb");
        // Pool exists but has no active connections (lazy).
        assert_eq!(handler.default_pool.size(), 0);
    }

    #[tokio::test]
    async fn new_defaults_db_to_username() {
        let config = DatabaseConfig {
            name: None,
            ..base_config()
        };
        let handler = PostgresHandler::new(&config);
        assert_eq!(handler.default_db, "pgadmin");
    }

    #[tokio::test]
    async fn router_exposes_all_nine_tools_in_read_write_mode() {
        let router = handler(false).tool_router;
        for name in [
            "list_databases",
            "list_tables",
            "get_table_schema",
            "read_query",
            "explain_query",
            "create_database",
            "drop_database",
            "drop_table",
            "write_query",
        ] {
            assert!(router.has_route(name), "missing tool: {name}");
        }
    }

    #[tokio::test]
    async fn router_hides_write_tools_in_read_only_mode() {
        let router = handler(true).tool_router;

        for name in [
            "list_databases",
            "list_tables",
            "get_table_schema",
            "read_query",
            "explain_query",
        ] {
            assert!(router.has_route(name), "missing read tool: {name}");
        }

        for name in ["write_query", "create_database", "drop_database", "drop_table"] {
            assert!(
                !router.has_route(name),
                "write tool exposed in read-only mode: {name}"
            );
        }
    }
}
451}