Skip to main content

database_mcp/db/
postgres.rs

//! `PostgreSQL` backend implementation via sqlx.
//!
//! Implements [`DatabaseBackend`] for `PostgreSQL` databases. Supports
//! cross-database operations by maintaining a bounded, concurrent cache of
//! connection pools keyed by database name.
6
7use crate::config::DatabaseConfig;
8use crate::db::backend::DatabaseBackend;
9use crate::db::identifier::validate_identifier;
10use crate::error::AppError;
11use moka::future::Cache;
12use serde_json::{Value, json};
13use sqlx::postgres::{PgConnectOptions, PgPoolOptions, PgRow, PgSslMode};
14use sqlx::{PgPool, Row};
15use sqlx_to_json::RowExt;
16use std::collections::HashMap;
17use tracing::info;
18
/// Maximum number of database connection pools kept in the pool cache.
///
/// The default database's pool lives in the same cache as every other pool
/// (it is inserted in `PostgresBackend::new`), so it counts toward this limit
/// and can be evicted like any other entry.
const POOL_CACHE_CAPACITY: u64 = 6;
21
22/// Converts [`DatabaseConfig`] into [`PgConnectOptions`].
23///
24/// Uses [`PgConnectOptions::new_without_pgpass`] to avoid unintended
25/// `PG*` environment variable influence, since our config already
26/// resolves values from CLI/env.
27impl From<&DatabaseConfig> for PgConnectOptions {
28    fn from(config: &DatabaseConfig) -> Self {
29        let mut opts = PgConnectOptions::new_without_pgpass()
30            .host(&config.host)
31            .port(config.port)
32            .username(&config.user);
33
34        if let Some(ref password) = config.password {
35            opts = opts.password(password);
36        }
37        if let Some(ref name) = config.name
38            && !name.is_empty()
39        {
40            opts = opts.database(name);
41        }
42
43        if config.ssl {
44            opts = if config.ssl_verify_cert {
45                opts.ssl_mode(PgSslMode::VerifyCa)
46            } else {
47                opts.ssl_mode(PgSslMode::Require)
48            };
49            if let Some(ref ca) = config.ssl_ca {
50                opts = opts.ssl_root_cert(ca);
51            }
52            if let Some(ref cert) = config.ssl_cert {
53                opts = opts.ssl_client_cert(cert);
54            }
55            if let Some(ref key) = config.ssl_key {
56                opts = opts.ssl_client_key(key);
57            }
58        }
59
60        opts
61    }
62}
63
/// `PostgreSQL` database backend.
///
/// All connection pools — including the default — live in a single bounded
/// concurrent cache keyed by database name. No external mutex is required.
/// NOTE(review): the `Clone` derive clones the `moka` cache handle; moka caches
/// are documented as cheaply cloneable handles to shared storage, so clones
/// presumably share the same pools. Confirm.
#[derive(Clone)]
pub struct PostgresBackend {
    /// Connection settings reused to build pools for non-default databases on demand.
    config: DatabaseConfig,
    /// Cache key of the pool created at construction time: the configured
    /// database name, or the user name when none is configured.
    default_db: String,
    /// Pools keyed by database name, bounded by `POOL_CACHE_CAPACITY`.
    pools: Cache<String, PgPool>,
    /// When `true`, mutating operations such as `create_database` are rejected.
    pub read_only: bool,
}
75
76impl std::fmt::Debug for PostgresBackend {
77    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78        f.debug_struct("PostgresBackend")
79            .field("read_only", &self.read_only)
80            .field("default_db", &self.default_db)
81            .finish_non_exhaustive()
82    }
83}
84
85impl PostgresBackend {
86    /// Creates a new `PostgreSQL` backend from configuration.
87    ///
88    /// Stores a clone of the configuration for constructing connection options
89    /// for non-default databases at runtime. The initial pool is placed into
90    /// the shared cache keyed by the configured database name.
91    ///
92    /// # Errors
93    ///
94    /// Returns [`AppError::Connection`] if the connection fails.
95    pub async fn new(config: &DatabaseConfig) -> Result<Self, AppError> {
96        let pool = PgPoolOptions::new()
97            .max_connections(config.max_pool_size)
98            .connect_with(config.into())
99            .await
100            .map_err(|e| AppError::Connection(format!("Failed to connect to PostgreSQL: {e}")))?;
101
102        info!(
103            "PostgreSQL connection pool initialized (max size: {})",
104            config.max_pool_size
105        );
106
107        // PostgreSQL defaults to a database named after the connecting user.
108        let default_db = config
109            .name
110            .as_deref()
111            .filter(|n| !n.is_empty())
112            .map_or_else(|| config.user.clone(), String::from);
113
114        let pools = Cache::builder()
115            .max_capacity(POOL_CACHE_CAPACITY)
116            .eviction_listener(|_key, pool: PgPool, _cause| {
117                tokio::spawn(async move {
118                    pool.close().await;
119                });
120            })
121            .build();
122
123        pools.insert(default_db.clone(), pool).await;
124
125        Ok(Self {
126            config: config.clone(),
127            default_db,
128            pools,
129            read_only: config.read_only,
130        })
131    }
132}
133
134impl PostgresBackend {
135    /// Wraps `name` in double quotes for safe use in `PostgreSQL` SQL statements.
136    ///
137    /// Escapes internal double quotes by doubling them.
138    fn quote_identifier(name: &str) -> String {
139        let escaped = name.replace('"', "\"\"");
140        format!("\"{escaped}\"")
141    }
142
143    /// Returns a connection pool for the requested database.
144    ///
145    /// Resolves `None` or empty names to the default pool. On a cache miss
146    /// a new pool is created and cached. Evicted pools are closed via the
147    /// cache's eviction listener.
148    ///
149    /// # Errors
150    ///
151    /// Returns [`AppError::InvalidIdentifier`] if the database name fails
152    /// validation, or [`AppError::Connection`] if the new pool cannot connect.
153    async fn get_pool(&self, database: Option<&str>) -> Result<PgPool, AppError> {
154        let db_key = match database {
155            Some(name) if !name.is_empty() => name,
156            _ => &self.default_db,
157        };
158
159        if let Some(pool) = self.pools.get(db_key).await {
160            return Ok(pool);
161        }
162
163        // Cache miss — validate then create a new pool.
164        validate_identifier(db_key)?;
165
166        let config = self.config.clone();
167        let db_key_owned = db_key.to_owned();
168
169        let pool = self
170            .pools
171            .try_get_with(db_key_owned, async {
172                let mut cfg = config;
173                cfg.name = Some(db_key.to_owned());
174                PgPoolOptions::new()
175                    .max_connections(cfg.max_pool_size)
176                    .connect_with((&cfg).into())
177                    .await
178                    .map_err(|e| {
179                        AppError::Connection(format!("Failed to connect to PostgreSQL database '{db_key}': {e}"))
180                    })
181            })
182            .await
183            .map_err(|e| match e.as_ref() {
184                AppError::Connection(msg) => AppError::Connection(msg.clone()),
185                other => AppError::Connection(other.to_string()),
186            })?;
187
188        Ok(pool)
189    }
190}
191
192impl DatabaseBackend for PostgresBackend {
193    // `list_databases` uses the default pool intentionally — `pg_database`
194    // is a server-wide catalog that returns all databases regardless of
195    // which database the connection targets.
196    async fn list_databases(&self) -> Result<Vec<String>, AppError> {
197        let pool = self.get_pool(None).await?;
198        let rows: Vec<(String,)> =
199            sqlx::query_as("SELECT datname FROM pg_database WHERE datistemplate = false ORDER BY datname")
200                .fetch_all(&pool)
201                .await
202                .map_err(|e| AppError::Query(e.to_string()))?;
203        Ok(rows.into_iter().map(|r| r.0).collect())
204    }
205
206    async fn list_tables(&self, database: &str) -> Result<Vec<String>, AppError> {
207        let db = if database.is_empty() { None } else { Some(database) };
208        let pool = self.get_pool(db).await?;
209        let rows: Vec<(String,)> =
210            sqlx::query_as("SELECT tablename FROM pg_tables WHERE schemaname = 'public' ORDER BY tablename")
211                .fetch_all(&pool)
212                .await
213                .map_err(|e| AppError::Query(e.to_string()))?;
214        Ok(rows.into_iter().map(|r| r.0).collect())
215    }
216
217    async fn get_table_schema(&self, database: &str, table: &str) -> Result<Value, AppError> {
218        validate_identifier(table)?;
219        let db = if database.is_empty() { None } else { Some(database) };
220        let pool = self.get_pool(db).await?;
221        let rows: Vec<PgRow> = sqlx::query(
222            r"SELECT column_name, data_type, is_nullable, column_default,
223                      character_maximum_length
224               FROM information_schema.columns
225               WHERE table_schema = 'public' AND table_name = $1
226               ORDER BY ordinal_position",
227        )
228        .bind(table)
229        .fetch_all(&pool)
230        .await
231        .map_err(|e| AppError::Query(e.to_string()))?;
232
233        if rows.is_empty() {
234            return Err(AppError::TableNotFound(table.to_string()));
235        }
236
237        let mut schema: HashMap<String, Value> = HashMap::new();
238        for row in &rows {
239            let col_name: String = row.try_get("column_name").unwrap_or_default();
240            let data_type: String = row.try_get("data_type").unwrap_or_default();
241            let nullable: String = row.try_get("is_nullable").unwrap_or_default();
242            let default: Option<String> = row.try_get("column_default").ok();
243            schema.insert(
244                col_name,
245                json!({
246                    "type": data_type,
247                    "nullable": nullable.to_uppercase() == "YES",
248                    "key": Value::Null,
249                    "default": default,
250                    "extra": Value::Null,
251                }),
252            );
253        }
254        Ok(json!(schema))
255    }
256
257    async fn get_table_schema_with_relations(&self, database: &str, table: &str) -> Result<Value, AppError> {
258        let schema = self.get_table_schema(database, table).await?;
259        let mut columns: HashMap<String, Value> = serde_json::from_value(schema).unwrap_or_default();
260
261        // Add null foreign_key to all columns
262        for col in columns.values_mut() {
263            if let Some(obj) = col.as_object_mut() {
264                obj.entry("foreign_key".to_string()).or_insert(Value::Null);
265            }
266        }
267
268        // Get FK info using the same pool as the schema query
269        let db = if database.is_empty() { None } else { Some(database) };
270        let pool = self.get_pool(db).await?;
271        let fk_rows: Vec<PgRow> = sqlx::query(
272            r"SELECT
273                kcu.column_name,
274                tc.constraint_name,
275                ccu.table_name AS referenced_table,
276                ccu.column_name AS referenced_column,
277                rc.update_rule AS on_update,
278                rc.delete_rule AS on_delete
279            FROM information_schema.table_constraints tc
280            JOIN information_schema.key_column_usage kcu
281                ON tc.constraint_name = kcu.constraint_name
282                AND tc.table_schema = kcu.table_schema
283            JOIN information_schema.constraint_column_usage ccu
284                ON ccu.constraint_name = tc.constraint_name
285                AND ccu.table_schema = tc.table_schema
286            JOIN information_schema.referential_constraints rc
287                ON rc.constraint_name = tc.constraint_name
288                AND rc.constraint_schema = tc.table_schema
289            WHERE tc.constraint_type = 'FOREIGN KEY'
290                AND tc.table_name = $1
291                AND tc.table_schema = 'public'",
292        )
293        .bind(table)
294        .fetch_all(&pool)
295        .await
296        .map_err(|e| AppError::Query(e.to_string()))?;
297
298        for fk_row in &fk_rows {
299            let col_name: String = fk_row.try_get("column_name").unwrap_or_default();
300            if let Some(col_info) = columns.get_mut(&col_name)
301                && let Some(obj) = col_info.as_object_mut()
302            {
303                obj.insert(
304                    "foreign_key".to_string(),
305                    json!({
306                        "constraint_name": fk_row.try_get::<String, _>("constraint_name").ok(),
307                        "referenced_table": fk_row.try_get::<String, _>("referenced_table").ok(),
308                        "referenced_column": fk_row.try_get::<String, _>("referenced_column").ok(),
309                        "on_update": fk_row.try_get::<String, _>("on_update").ok(),
310                        "on_delete": fk_row.try_get::<String, _>("on_delete").ok(),
311                    }),
312                );
313            }
314        }
315
316        Ok(json!({
317            "table_name": table,
318            "columns": columns,
319        }))
320    }
321
322    async fn execute_query(&self, sql: &str, database: Option<&str>) -> Result<Value, AppError> {
323        let pool = self.get_pool(database).await?;
324        let rows: Vec<PgRow> = sqlx::query(sql)
325            .fetch_all(&pool)
326            .await
327            .map_err(|e| AppError::Query(e.to_string()))?;
328        Ok(Value::Array(rows.iter().map(RowExt::to_json).collect()))
329    }
330
331    async fn create_database(&self, name: &str) -> Result<Value, AppError> {
332        if self.read_only {
333            return Err(AppError::ReadOnlyViolation);
334        }
335        validate_identifier(name)?;
336
337        let pool = self.get_pool(None).await?;
338
339        // PostgreSQL CREATE DATABASE can't use parameterized queries
340        sqlx::query(&format!("CREATE DATABASE {}", Self::quote_identifier(name)))
341            .execute(&pool)
342            .await
343            .map_err(|e| {
344                let msg = e.to_string();
345                if msg.contains("already exists") {
346                    return AppError::Query(format!("Database '{name}' already exists."));
347                }
348                AppError::Query(msg)
349            })?;
350
351        Ok(json!({
352            "status": "success",
353            "message": format!("Database '{name}' created successfully."),
354            "database_name": name,
355        }))
356    }
357
358    fn dialect(&self) -> Box<dyn sqlparser::dialect::Dialect> {
359        Box::new(sqlparser::dialect::PostgreSqlDialect {})
360    }
361
362    fn read_only(&self) -> bool {
363        self.read_only
364    }
365}
366
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::DatabaseBackend;

    /// A fully specified non-TLS config used as the base for each test.
    fn base_config() -> DatabaseConfig {
        DatabaseConfig {
            backend: DatabaseBackend::Postgres,
            host: "pg.example.com".into(),
            port: 5433,
            user: "pgadmin".into(),
            password: Some("pgpass".into()),
            name: Some("mydb".into()),
            ..DatabaseConfig::default()
        }
    }

    #[test]
    fn quote_identifier_wraps_in_double_quotes() {
        assert_eq!(PostgresBackend::quote_identifier("users"), "\"users\"");
        assert_eq!(PostgresBackend::quote_identifier("eu-docker"), "\"eu-docker\"");
    }

    #[test]
    fn quote_identifier_escapes_double_quotes() {
        assert_eq!(PostgresBackend::quote_identifier("test\"db"), "\"test\"\"db\"");
        assert_eq!(PostgresBackend::quote_identifier("a\"b\"c"), "\"a\"\"b\"\"c\"");
    }

    #[test]
    fn quote_identifier_handles_empty_name() {
        assert_eq!(PostgresBackend::quote_identifier(""), "\"\"");
    }

    #[test]
    fn from_basic_config() {
        let config = base_config();
        let opts = PgConnectOptions::from(&config);

        assert_eq!(opts.get_host(), "pg.example.com");
        assert_eq!(opts.get_port(), 5433);
        assert_eq!(opts.get_username(), "pgadmin");
        assert_eq!(opts.get_database(), Some("mydb"));
    }

    #[test]
    fn from_with_ssl_require() {
        let config = DatabaseConfig {
            ssl: true,
            ssl_verify_cert: false,
            ..base_config()
        };
        let opts = PgConnectOptions::from(&config);

        assert!(
            matches!(opts.get_ssl_mode(), PgSslMode::Require),
            "expected Require, got {:?}",
            opts.get_ssl_mode()
        );
    }

    #[test]
    fn from_with_ssl_verify_ca() {
        let config = DatabaseConfig {
            ssl: true,
            ssl_verify_cert: true,
            ..base_config()
        };
        let opts = PgConnectOptions::from(&config);

        assert!(
            matches!(opts.get_ssl_mode(), PgSslMode::VerifyCa),
            "expected VerifyCa, got {:?}",
            opts.get_ssl_mode()
        );
    }

    // NOTE(review): the two database-name tests below assume `PGDATABASE` is
    // unset in the test environment, since `new_without_pgpass` reads it.
    #[test]
    fn from_without_database_name() {
        let config = DatabaseConfig {
            name: None,
            ..base_config()
        };
        let opts = PgConnectOptions::from(&config);

        assert_eq!(opts.get_database(), None);
    }

    #[test]
    fn from_with_empty_database_name_is_ignored() {
        let config = DatabaseConfig {
            name: Some("".into()),
            ..base_config()
        };
        let opts = PgConnectOptions::from(&config);

        assert_eq!(opts.get_database(), None);
    }

    #[test]
    fn from_without_password() {
        let config = DatabaseConfig {
            password: None,
            ..base_config()
        };
        let opts = PgConnectOptions::from(&config);

        assert_eq!(opts.get_host(), "pg.example.com");
    }
}