postgres_mcp/
pg.rs

1use arc_swap::ArcSwap;
2use serde::{Deserialize, Serialize};
3use sqlparser::ast::Statement;
4use sqlx::postgres::PgPool;
5use std::collections::HashMap;
6use std::sync::Arc;
7use thiserror::Error;
8
9#[allow(unused)]
10#[derive(Error, Debug)]
11pub enum PgMcpError {
12    #[error("Connection not found for ID: {0}")]
13    ConnectionNotFound(String),
14
15    #[error("SQL validation failed for query '{query}': {kind}")]
16    ValidationFailed {
17        kind: ValidationErrorKind,
18        query: String,
19        details: String,
20    },
21
22    #[error("Database operation '{operation}' failed: {underlying}")]
23    DatabaseError {
24        operation: String,
25        underlying: String,
26    },
27
28    #[error("Serialization failed: {0}")]
29    SerializationError(#[from] serde_json::Error),
30
31    #[error("Database connection failed: {0}")]
32    ConnectionError(String),
33
34    #[error("Internal error: {0}")]
35    InternalError(String),
36}
37
38#[derive(Error, Debug)]
39pub enum ValidationErrorKind {
40    #[error("Invalid statement type, expected {expected}")]
41    InvalidStatementType { expected: String },
42    #[error("Failed to parse SQL")]
43    ParseError,
44}
45
46impl From<sqlx::Error> for PgMcpError {
47    fn from(e: sqlx::Error) -> Self {
48        let msg = e.to_string();
49        if let Some(db_err) = e.as_database_error() {
50            PgMcpError::DatabaseError {
51                operation: "unknown".to_string(),
52                underlying: db_err.to_string(),
53            }
54        } else if msg.contains("error connecting") || msg.contains("timed out") {
55            PgMcpError::ConnectionError(msg)
56        } else {
57            PgMcpError::DatabaseError {
58                operation: "unknown".to_string(),
59                underlying: msg,
60            }
61        }
62    }
63}
64
65#[allow(dead_code)]
66#[derive(Debug, Clone)]
67pub(crate) struct Conn {
68    pub(crate) id: String,
69    pub(crate) conn_str: String,
70    pub(crate) pool: PgPool,
71}
72
73#[derive(Debug, Clone)]
74pub struct Conns {
75    pub(crate) inner: Arc<ArcSwap<HashMap<String, Conn>>>,
76}
77
78#[derive(Debug, Clone)]
79pub struct PgMcp {
80    pub(crate) conns: Conns,
81}
82
83#[derive(Debug, sqlx::FromRow, Serialize, Deserialize)]
84struct JsonRow {
85    ret: sqlx::types::Json<serde_json::Value>,
86}
87
88impl Conns {
89    pub(crate) fn new() -> Self {
90        Self {
91            inner: Arc::new(ArcSwap::new(Arc::new(HashMap::new()))),
92        }
93    }
94
95    pub(crate) async fn register(&self, conn_str: String) -> Result<String, PgMcpError> {
96        let pool = PgPool::connect(&conn_str)
97            .await
98            .map_err(|e| PgMcpError::ConnectionError(e.to_string()))?;
99        let id = uuid::Uuid::new_v4().to_string();
100        let conn = Conn {
101            id: id.clone(),
102            conn_str: conn_str.clone(),
103            pool,
104        };
105
106        let mut conns = self.inner.load().as_ref().clone();
107        conns.insert(id.clone(), conn);
108        self.inner.store(Arc::new(conns));
109
110        Ok(id)
111    }
112
113    pub(crate) fn unregister(&self, id: String) -> Result<(), PgMcpError> {
114        let mut conns = self.inner.load().as_ref().clone();
115        if conns.remove(&id).is_none() {
116            return Err(PgMcpError::ConnectionNotFound(id));
117        }
118        self.inner.store(Arc::new(conns));
119        Ok(())
120    }
121
122    pub(crate) async fn query(&self, id: &str, query: &str) -> Result<String, PgMcpError> {
123        let operation = "query (SELECT)";
124        let conns = self.inner.load();
125        let conn = conns
126            .get(id)
127            .ok_or_else(|| PgMcpError::ConnectionNotFound(id.to_string()))?;
128
129        let validated_query =
130            validate_sql(query, |stmt| matches!(stmt, Statement::Query(_)), "SELECT")?;
131
132        let prepared_query = format!(
133            "WITH data AS ({}) SELECT JSON_AGG(data.*) as ret FROM data;",
134            validated_query
135        );
136
137        let ret = sqlx::query_as::<_, JsonRow>(&prepared_query)
138            .fetch_one(&conn.pool)
139            .await
140            .map_err(|e| PgMcpError::DatabaseError {
141                operation: operation.to_string(),
142                underlying: e.to_string(),
143            })?;
144
145        Ok(serde_json::to_string(&ret.ret)?)
146    }
147
148    pub(crate) async fn insert(&self, id: &str, query: &str) -> Result<String, PgMcpError> {
149        let operation = "insert (INSERT)";
150        let conns = self.inner.load();
151        let conn = conns
152            .get(id)
153            .ok_or_else(|| PgMcpError::ConnectionNotFound(id.to_string()))?;
154
155        let validated_query = validate_sql(
156            query,
157            |stmt| matches!(stmt, Statement::Insert { .. }),
158            "INSERT",
159        )?;
160
161        let result = sqlx::query(&validated_query)
162            .execute(&conn.pool)
163            .await
164            .map_err(|e| PgMcpError::DatabaseError {
165                operation: operation.to_string(),
166                underlying: e.to_string(),
167            })?;
168
169        Ok(format!(
170            "success, rows_affected: {}",
171            result.rows_affected()
172        ))
173    }
174
175    pub(crate) async fn update(&self, id: &str, query: &str) -> Result<String, PgMcpError> {
176        let operation = "update (UPDATE)";
177        let conns = self.inner.load();
178        let conn = conns
179            .get(id)
180            .ok_or_else(|| PgMcpError::ConnectionNotFound(id.to_string()))?;
181
182        let validated_query = validate_sql(
183            query,
184            |stmt| matches!(stmt, Statement::Update { .. }),
185            "UPDATE",
186        )?;
187
188        let result = sqlx::query(&validated_query)
189            .execute(&conn.pool)
190            .await
191            .map_err(|e| PgMcpError::DatabaseError {
192                operation: operation.to_string(),
193                underlying: e.to_string(),
194            })?;
195
196        Ok(format!(
197            "success, rows_affected: {}",
198            result.rows_affected()
199        ))
200    }
201
202    pub(crate) async fn delete(&self, id: &str, query: &str) -> Result<String, PgMcpError> {
203        let operation = "delete (DELETE)";
204        let conns = self.inner.load();
205        let conn = conns
206            .get(id)
207            .ok_or_else(|| PgMcpError::ConnectionNotFound(id.to_string()))?;
208
209        let validated_query = validate_sql(
210            query,
211            |stmt| matches!(stmt, Statement::Delete { .. }),
212            "DELETE",
213        )?;
214
215        let result = sqlx::query(&validated_query)
216            .execute(&conn.pool)
217            .await
218            .map_err(|e| PgMcpError::DatabaseError {
219                operation: operation.to_string(),
220                underlying: e.to_string(),
221            })?;
222
223        Ok(format!(
224            "success, rows_affected: {}",
225            result.rows_affected()
226        ))
227    }
228
229    pub(crate) async fn create_table(&self, id: &str, query: &str) -> Result<String, PgMcpError> {
230        let operation = "create_table (CREATE TABLE)";
231        let conns = self.inner.load();
232        let conn = conns
233            .get(id)
234            .ok_or_else(|| PgMcpError::ConnectionNotFound(id.to_string()))?;
235
236        let validated_query = validate_sql(
237            query,
238            |stmt| matches!(stmt, Statement::CreateTable { .. }),
239            "CREATE TABLE",
240        )?;
241
242        sqlx::query(&validated_query)
243            .execute(&conn.pool)
244            .await
245            .map_err(|e| PgMcpError::DatabaseError {
246                operation: operation.to_string(),
247                underlying: e.to_string(),
248            })?;
249
250        Ok("success".to_string())
251    }
252
253    pub(crate) async fn drop_table(&self, id: &str, table: &str) -> Result<String, PgMcpError> {
254        let operation = format!("drop_table (DROP TABLE {})", table);
255        let conns = self.inner.load();
256        let conn = conns
257            .get(id)
258            .ok_or_else(|| PgMcpError::ConnectionNotFound(id.to_string()))?;
259
260        let query = format!("DROP TABLE {}", table);
261        sqlx::query(&query)
262            .execute(&conn.pool)
263            .await
264            .map_err(|e| PgMcpError::DatabaseError {
265                operation,
266                underlying: e.to_string(),
267            })?;
268
269        Ok("success".to_string())
270    }
271
272    pub(crate) async fn create_index(&self, id: &str, query: &str) -> Result<String, PgMcpError> {
273        let operation = "create_index (CREATE INDEX)";
274        let conns = self.inner.load();
275        let conn = conns
276            .get(id)
277            .ok_or_else(|| PgMcpError::ConnectionNotFound(id.to_string()))?;
278
279        let validated_query = validate_sql(
280            query,
281            |stmt| matches!(stmt, Statement::CreateIndex { .. }),
282            "CREATE INDEX",
283        )?;
284
285        sqlx::query(&validated_query)
286            .execute(&conn.pool)
287            .await
288            .map_err(|e| PgMcpError::DatabaseError {
289                operation: operation.to_string(),
290                underlying: e.to_string(),
291            })?;
292
293        Ok("success".to_string())
294    }
295
296    pub(crate) async fn drop_index(&self, id: &str, index: &str) -> Result<String, PgMcpError> {
297        let operation = format!("drop_index (DROP INDEX {})", index);
298        let conns = self.inner.load();
299        let conn = conns
300            .get(id)
301            .ok_or_else(|| PgMcpError::ConnectionNotFound(id.to_string()))?;
302
303        let query = format!("DROP INDEX {}", index);
304        sqlx::query(&query)
305            .execute(&conn.pool)
306            .await
307            .map_err(|e| PgMcpError::DatabaseError {
308                operation,
309                underlying: e.to_string(),
310            })?;
311
312        Ok("success".to_string())
313    }
314
315    pub(crate) async fn describe(&self, id: &str, table: &str) -> Result<String, PgMcpError> {
316        let operation = format!("describe (table: {})", table);
317        let conns = self.inner.load();
318        let conn = conns
319            .get(id)
320            .ok_or_else(|| PgMcpError::ConnectionNotFound(id.to_string()))?;
321
322        let query = r#"
323        WITH data AS (
324          SELECT column_name, data_type, character_maximum_length, column_default, is_nullable
325          FROM information_schema.columns
326          WHERE table_name = $1
327          ORDER BY ordinal_position)
328        SELECT JSON_AGG(data.*) as ret FROM data"#;
329
330        let ret = sqlx::query_as::<_, JsonRow>(query)
331            .bind(table)
332            .fetch_one(&conn.pool)
333            .await
334            .map_err(|e| PgMcpError::DatabaseError {
335                operation: operation.to_string(),
336                underlying: e.to_string(),
337            })?;
338
339        Ok(serde_json::to_string(&ret.ret)?)
340    }
341
342    pub(crate) async fn list_tables(&self, id: &str, schema: &str) -> Result<String, PgMcpError> {
343        let operation = format!("list_tables (schema: {})", schema);
344        let conns = self.inner.load();
345        let conn = conns
346            .get(id)
347            .ok_or_else(|| PgMcpError::ConnectionNotFound(id.to_string()))?;
348
349        let query = r#"
350        WITH data AS (
351          SELECT
352                t.table_name,
353                obj_description(format('%s.%s', t.table_schema, t.table_name)::regclass::oid) as description,
354                pg_stat_get_tuples_inserted(format('%s.%s', t.table_schema, t.table_name)::regclass::oid) as total_rows
355            FROM information_schema.tables t
356            WHERE
357                t.table_schema = $1
358                AND t.table_type = 'BASE TABLE'
359            ORDER BY t.table_name
360        )
361        SELECT JSON_AGG(data.*) as ret FROM data"#;
362        let ret = sqlx::query_as::<_, JsonRow>(query)
363            .bind(schema)
364            .fetch_one(&conn.pool)
365            .await
366            .or_else(|e| {
367                if let sqlx::Error::RowNotFound = e {
368                    Ok(JsonRow {
369                        ret: sqlx::types::Json(serde_json::json!([])),
370                    })
371                } else {
372                    Err(PgMcpError::DatabaseError {
373                        operation: operation.to_string(),
374                        underlying: e.to_string(),
375                    })
376                }
377            })?;
378
379        Ok(serde_json::to_string(&ret.ret)?)
380    }
381
382    pub(crate) async fn create_schema(
383        &self,
384        id: &str,
385        schema_name: &str,
386    ) -> Result<String, PgMcpError> {
387        let operation = format!("create_schema (CREATE SCHEMA {})", schema_name);
388        let conns = self.inner.load();
389        let conn = conns
390            .get(id)
391            .ok_or_else(|| PgMcpError::ConnectionNotFound(id.to_string()))?;
392
393        let query = format!("CREATE SCHEMA {}", schema_name);
394        sqlx::query(&query)
395            .execute(&conn.pool)
396            .await
397            .map_err(|e| PgMcpError::DatabaseError {
398                operation,
399                underlying: e.to_string(),
400            })?;
401
402        Ok("success".to_string())
403    }
404
405    pub(crate) async fn create_type(&self, id: &str, query: &str) -> Result<String, PgMcpError> {
406        let operation = "create_type (CREATE TYPE)";
407        let conns = self.inner.load();
408        let conn = conns
409            .get(id)
410            .ok_or_else(|| PgMcpError::ConnectionNotFound(id.to_string()))?;
411
412        let validated_query = validate_sql(
413            query,
414            |stmt| matches!(stmt, Statement::CreateType { .. }),
415            "CREATE TYPE",
416        )?;
417
418        sqlx::query(&validated_query)
419            .execute(&conn.pool)
420            .await
421            .map_err(|e| PgMcpError::DatabaseError {
422                operation: operation.to_string(),
423                underlying: e.to_string(),
424            })?;
425
426        Ok("success".to_string())
427    }
428}
429
430impl Default for Conns {
431    fn default() -> Self {
432        Self::new()
433    }
434}
435
436fn validate_sql<F>(
437    query: &str,
438    validator: F,
439    expected_type: &'static str,
440) -> Result<String, PgMcpError>
441where
442    F: Fn(&Statement) -> bool,
443{
444    let dialect = sqlparser::dialect::PostgreSqlDialect {};
445    let statements = sqlparser::parser::Parser::parse_sql(&dialect, query).map_err(|e| {
446        PgMcpError::ValidationFailed {
447            kind: ValidationErrorKind::ParseError,
448            query: query.to_string(),
449            details: e.to_string(),
450        }
451    })?;
452
453    if statements.len() != 1 {
454        return Err(PgMcpError::ValidationFailed {
455            kind: ValidationErrorKind::InvalidStatementType {
456                expected: expected_type.to_string(),
457            },
458            query: query.to_string(),
459            details: format!(
460                "Expected exactly one SQL statement, found {}",
461                statements.len()
462            ),
463        });
464    }
465
466    let stmt = &statements[0];
467    if !validator(stmt) {
468        return Err(PgMcpError::ValidationFailed {
469            kind: ValidationErrorKind::InvalidStatementType {
470                expected: expected_type.to_string(),
471            },
472            query: query.to_string(),
473            details: format!("Statement type validation failed. Received: {:?}", stmt),
474        });
475    }
476
477    Ok(query.to_string())
478}
479
#[cfg(test)]
mod tests {
    use super::*;
    use sqlx_db_tester::TestPg;

    const TEST_CONN_STR: &str = "postgres://postgres:postgres@localhost:5432/postgres";

    /// Creates a throw-away database from `./fixtures/migrations` and checks that the seeded
    /// `test_table` can be queried. Returns the `TestPg` handle with its URL. Keep the handle
    /// bound for the whole test (presumably it owns the temporary database — TODO confirm).
    async fn setup_test_db() -> (TestPg, String) {
        let migrations = std::path::Path::new("./fixtures/migrations");
        let tdb = TestPg::new(TEST_CONN_STR.to_string(), migrations);

        let pool = tdb.get_pool().await;
        sqlx::query("SELECT * FROM test_table LIMIT 1")
            .execute(&pool)
            .await
            .unwrap();

        let url = tdb.url();
        (tdb, url)
    }

    /// Runs `setup_test_db` and registers that database in a fresh `Conns`.
    async fn setup_registered() -> (TestPg, Conns, String) {
        let (tdb, url) = setup_test_db().await;
        let conns = Conns::new();
        let id = conns.register(url).await.unwrap();
        (tdb, conns, id)
    }

    #[tokio::test]
    async fn register_unregister_should_work() {
        let (_tdb, conns, id) = setup_registered().await;
        assert!(!id.is_empty());

        assert!(conns.unregister(id.clone()).is_ok());
        // A second unregister of the same id must report it as missing.
        assert!(conns.unregister(id).is_err());
    }

    #[tokio::test]
    async fn list_tables_describe_should_work() {
        let (_tdb, conns, id) = setup_registered().await;

        let tables = conns.list_tables(&id, "public").await.unwrap();
        assert!(tables.contains("test_table"));

        let description = conns.describe(&id, "test_table").await.unwrap();
        for column in ["id", "name", "created_at"] {
            assert!(description.contains(column));
        }
    }

    #[tokio::test]
    async fn create_table_drop_table_should_work() {
        let (_tdb, conns, id) = setup_registered().await;

        let ddl = "CREATE TABLE test_table2 (id SERIAL PRIMARY KEY, name TEXT)";
        assert_eq!(conns.create_table(&id, ddl).await.unwrap(), "success");
        assert_eq!(
            conns.drop_table(&id, "test_table2").await.unwrap(),
            "success"
        );

        // The table is gone now, so dropping it again must fail.
        assert!(conns.drop_table(&id, "test_table2").await.is_err());
    }

    #[tokio::test]
    async fn query_insert_update_delete_should_work() {
        let (_tdb, conns, id) = setup_registered().await;

        let rows = conns
            .query(&id, "SELECT * FROM test_table ORDER BY id")
            .await
            .unwrap();
        for seeded in ["test1", "test2", "test3"] {
            assert!(rows.contains(seeded));
        }

        let inserted = conns
            .insert(&id, "INSERT INTO test_table (name) VALUES ('test4')")
            .await
            .unwrap();
        assert!(inserted.contains("rows_affected: 1"));

        let updated = conns
            .update(
                &id,
                "UPDATE test_table SET name = 'updated' WHERE name = 'test1'",
            )
            .await
            .unwrap();
        assert!(updated.contains("rows_affected: 1"));

        let deleted = conns
            .delete(&id, "DELETE FROM test_table WHERE name = 'updated'")
            .await
            .unwrap();
        assert!(deleted.contains("rows_affected: 1"));
    }

    #[tokio::test]
    async fn create_index_drop_index_should_work() {
        let (_tdb, conns, id) = setup_registered().await;

        let ddl = "CREATE INDEX idx_test_table_new ON test_table (name, created_at)";
        assert_eq!(conns.create_index(&id, ddl).await.unwrap(), "success");
        assert_eq!(
            conns.drop_index(&id, "idx_test_table_new").await.unwrap(),
            "success"
        );
    }

    #[tokio::test]
    async fn sql_validation_should_work() {
        let (_tdb, conns, id) = setup_registered().await;

        // Each entry point must refuse a statement of the wrong kind.
        assert!(conns
            .query(&id, "INSERT INTO test_table VALUES (1)")
            .await
            .is_err());
        assert!(conns
            .insert(&id, "SELECT * FROM test_table")
            .await
            .is_err());
        assert!(conns
            .update(&id, "DELETE FROM test_table")
            .await
            .is_err());
        assert!(conns
            .create_table(&id, "CREATE INDEX idx_test ON test_table (id)")
            .await
            .is_err());
        assert!(conns
            .create_index(&id, "CREATE TABLE test (id INT)")
            .await
            .is_err());
    }

    #[tokio::test]
    async fn create_type_should_work() {
        let (_tdb, conns, id) = setup_registered().await;

        let ddl = "CREATE TYPE user_role AS ENUM ('admin', 'user')";
        assert_eq!(conns.create_type(&id, ddl).await.unwrap(), "success");

        assert!(conns
            .create_type(&id, "CREATE TABLE test (id INT)")
            .await
            .is_err());
    }

    #[tokio::test]
    async fn create_schema_should_work() {
        let (_tdb, conns, id) = setup_registered().await;

        let schema_name = "test_schema_unit";
        assert_eq!(
            conns.create_schema(&id, schema_name).await.unwrap(),
            "success"
        );

        // `fetch_one` errors (and the unwrap panics) if the schema row is missing.
        let lookup = format!(
            "SELECT schema_name FROM information_schema.schemata WHERE schema_name = '{}'",
            schema_name
        );
        sqlx::query(&lookup)
            .fetch_one(&conns.inner.load().get(&id).unwrap().pool)
            .await
            .unwrap();

        assert!(conns.create_schema(&id, "test;schema").await.is_err());
    }
}