postgres_mcp/
pg.rs

1use anyhow::Error;
2use arc_swap::ArcSwap;
3use serde::{Deserialize, Serialize};
4use sqlparser::ast::Statement;
5use sqlx::postgres::PgPool;
6use std::collections::HashMap;
7use std::sync::Arc;
8
9#[allow(dead_code)]
10#[derive(Debug, Clone)]
11pub(crate) struct Conn {
12    pub(crate) id: String,
13    pub(crate) conn_str: String,
14    pub(crate) pool: PgPool,
15}
16
17#[derive(Debug, Clone)]
18pub struct Conns {
19    pub(crate) inner: Arc<ArcSwap<HashMap<String, Conn>>>,
20}
21
22#[derive(Debug, Clone)]
23pub struct PgMcp {
24    pub(crate) conns: Conns,
25}
26
27#[derive(Debug, sqlx::FromRow, Serialize, Deserialize)]
28struct JsonRow {
29    ret: sqlx::types::Json<serde_json::Value>,
30}
31
32impl Conns {
33    pub(crate) fn new() -> Self {
34        Self {
35            inner: Arc::new(ArcSwap::new(Arc::new(HashMap::new()))),
36        }
37    }
38
39    pub(crate) async fn register(&self, conn_str: String) -> Result<String, Error> {
40        let pool = PgPool::connect(&conn_str).await?;
41        let id = uuid::Uuid::new_v4().to_string();
42        let conn = Conn {
43            id: id.clone(),
44            conn_str: conn_str.clone(),
45            pool,
46        };
47
48        let mut conns = self.inner.load().as_ref().clone();
49        conns.insert(id.clone(), conn);
50        self.inner.store(Arc::new(conns));
51
52        Ok(id)
53    }
54
55    pub(crate) fn unregister(&self, id: String) -> Result<(), Error> {
56        let mut conns = self.inner.load().as_ref().clone();
57        if conns.remove(&id).is_none() {
58            return Err(anyhow::anyhow!("Connection not found"));
59        }
60        self.inner.store(Arc::new(conns));
61        Ok(())
62    }
63
64    pub(crate) async fn query(&self, id: &str, query: &str) -> Result<String, Error> {
65        let conns = self.inner.load();
66        let conn = conns
67            .get(id)
68            .ok_or_else(|| anyhow::anyhow!("Connection not found"))?;
69
70        let query = validate_sql(
71            query,
72            |stmt| matches!(stmt, Statement::Query(_)),
73            "Only SELECT queries are allowed",
74        )?;
75
76        let query = format!(
77            "WITH data AS ({}) SELECT JSON_AGG(data.*) as ret FROM data;",
78            query
79        );
80
81        let ret = sqlx::query_as::<_, JsonRow>(&query)
82            .fetch_one(&conn.pool)
83            .await?;
84
85        Ok(serde_json::to_string(&ret.ret)?)
86    }
87
88    pub(crate) async fn insert(&self, id: &str, query: &str) -> Result<String, Error> {
89        let conns = self.inner.load();
90        let conn = conns
91            .get(id)
92            .ok_or_else(|| anyhow::anyhow!("Connection not found"))?;
93
94        let query = validate_sql(
95            query,
96            |stmt| matches!(stmt, Statement::Insert { .. }),
97            "Only INSERT statements are allowed",
98        )?;
99
100        let result = sqlx::query(&query).execute(&conn.pool).await?;
101
102        Ok(format!(
103            "success, rows_affected: {}",
104            result.rows_affected()
105        ))
106    }
107
108    pub(crate) async fn update(&self, id: &str, query: &str) -> Result<String, Error> {
109        let conns = self.inner.load();
110        let conn = conns
111            .get(id)
112            .ok_or_else(|| anyhow::anyhow!("Connection not found"))?;
113
114        let query = validate_sql(
115            query,
116            |stmt| matches!(stmt, Statement::Update { .. }),
117            "Only UPDATE statements are allowed",
118        )?;
119
120        let result = sqlx::query(&query).execute(&conn.pool).await?;
121
122        Ok(format!(
123            "success, rows_affected: {}",
124            result.rows_affected()
125        ))
126    }
127
128    pub(crate) async fn delete(&self, id: &str, query: &str) -> Result<String, Error> {
129        let conns = self.inner.load();
130        let conn = conns
131            .get(id)
132            .ok_or_else(|| anyhow::anyhow!("Connection not found"))?;
133
134        let query = validate_sql(
135            query,
136            |stmt| matches!(stmt, Statement::Delete { .. }),
137            "Only DELETE statements are allowed",
138        )?;
139
140        let result = sqlx::query(&query).execute(&conn.pool).await?;
141
142        Ok(format!(
143            "success, rows_affected: {}",
144            result.rows_affected()
145        ))
146    }
147
148    pub(crate) async fn create_table(&self, id: &str, query: &str) -> Result<String, Error> {
149        let conns = self.inner.load();
150        let conn = conns
151            .get(id)
152            .ok_or_else(|| anyhow::anyhow!("Connection not found"))?;
153
154        let query = validate_sql(
155            query,
156            |stmt| matches!(stmt, Statement::CreateTable { .. }),
157            "Only CREATE TABLE statements are allowed",
158        )?;
159
160        sqlx::query(&query).execute(&conn.pool).await?;
161
162        Ok("success".to_string())
163    }
164
165    pub(crate) async fn drop_table(&self, id: &str, table: &str) -> Result<String, Error> {
166        let conns = self.inner.load();
167        let conn = conns
168            .get(id)
169            .ok_or_else(|| anyhow::anyhow!("Connection not found"))?;
170
171        let query = format!("DROP TABLE {}", table);
172        sqlx::query(&query).execute(&conn.pool).await?;
173
174        Ok("success".to_string())
175    }
176
177    pub(crate) async fn create_index(&self, id: &str, query: &str) -> Result<String, Error> {
178        let conns = self.inner.load();
179        let conn = conns
180            .get(id)
181            .ok_or_else(|| anyhow::anyhow!("Connection not found"))?;
182
183        let query = validate_sql(
184            query,
185            |stmt| matches!(stmt, Statement::CreateIndex { .. }),
186            "Only CREATE INDEX statements are allowed",
187        )?;
188
189        sqlx::query(&query).execute(&conn.pool).await?;
190
191        Ok("success".to_string())
192    }
193
194    pub(crate) async fn drop_index(&self, id: &str, index: &str) -> Result<String, Error> {
195        let conns = self.inner.load();
196        let conn = conns
197            .get(id)
198            .ok_or_else(|| anyhow::anyhow!("Connection not found"))?;
199
200        let query = format!("DROP INDEX {}", index);
201        sqlx::query(&query).execute(&conn.pool).await?;
202
203        Ok("success".to_string())
204    }
205
206    pub(crate) async fn describe(&self, id: &str, table: &str) -> Result<String, Error> {
207        let conns = self.inner.load();
208        let conn = conns
209            .get(id)
210            .ok_or_else(|| anyhow::anyhow!("Connection not found"))?;
211
212        let query = r#"
213        WITH data AS (
214          SELECT column_name, data_type, character_maximum_length, column_default, is_nullable
215          FROM information_schema.columns
216          WHERE table_name = $1
217          ORDER BY ordinal_position)
218        SELECT JSON_AGG(data.*) as ret FROM data"#;
219
220        let ret = sqlx::query_as::<_, JsonRow>(query)
221            .bind(table)
222            .fetch_one(&conn.pool)
223            .await?;
224
225        Ok(serde_json::to_string(&ret.ret)?)
226    }
227
228    pub(crate) async fn list_tables(&self, id: &str, schema: &str) -> Result<String, Error> {
229        let conns = self.inner.load();
230        let conn = conns
231            .get(id)
232            .ok_or_else(|| anyhow::anyhow!("Connection not found"))?;
233
234        let query = r#"
235        WITH data AS (
236          SELECT
237                t.table_name,
238                obj_description(format('%s.%s', t.table_schema, t.table_name)::regclass::oid) as description,
239                pg_stat_get_tuples_inserted(format('%s.%s', t.table_schema, t.table_name)::regclass::oid) as total_rows
240            FROM information_schema.tables t
241            WHERE
242                t.table_schema = $1
243                AND t.table_type = 'BASE TABLE'
244            ORDER BY t.table_name
245        )
246        SELECT JSON_AGG(data.*) as ret FROM data"#;
247        let ret = sqlx::query_as::<_, JsonRow>(query)
248            .bind(schema)
249            .fetch_one(&conn.pool)
250            .await?;
251
252        Ok(serde_json::to_string(&ret.ret)?)
253    }
254
255    pub(crate) async fn create_type(&self, id: &str, query: &str) -> Result<String, Error> {
256        let conns = self.inner.load();
257        let conn = conns
258            .get(id)
259            .ok_or_else(|| anyhow::anyhow!("Connection not found"))?;
260
261        let query = validate_sql(
262            query,
263            |stmt| matches!(stmt, Statement::CreateType { .. }),
264            "Only CREATE TYPE statements are allowed",
265        )?;
266
267        sqlx::query(&query).execute(&conn.pool).await?;
268
269        Ok("success".to_string())
270    }
271}
272
273impl Default for Conns {
274    fn default() -> Self {
275        Self::new()
276    }
277}
278
279fn validate_sql<F>(query: &str, validator: F, error_msg: &'static str) -> Result<String, Error>
280where
281    F: Fn(&Statement) -> bool,
282{
283    let dialect = sqlparser::dialect::PostgreSqlDialect {};
284    let ast = sqlparser::parser::Parser::parse_sql(&dialect, query)?;
285    if ast.len() != 1 || !validator(&ast[0]) {
286        return Err(anyhow::anyhow!(error_msg));
287    }
288    Ok(ast[0].to_string())
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use sqlx_db_tester::TestPg;

    const TEST_CONN_STR: &str = "postgres://postgres:postgres@localhost:5432/postgres";

    /// Creates a test database from `./fixtures/migrations` and returns the
    /// `TestPg` handle (callers keep it bound as `_tdb` for the whole test)
    /// together with the database URL.
    async fn setup_test_db() -> (TestPg, String) {
        let tdb = TestPg::new(
            TEST_CONN_STR.to_string(),
            std::path::Path::new("./fixtures/migrations"),
        );
        let pool = tdb.get_pool().await;

        // Sanity check: the fixture migrations must have created `test_table`.
        sqlx::query("SELECT * FROM test_table LIMIT 1")
            .execute(&pool)
            .await
            .unwrap();

        let conn_str = tdb.url();

        (tdb, conn_str)
    }

    #[tokio::test]
    async fn register_unregister_should_work() {
        let (_tdb, conn_str) = setup_test_db().await;
        let conns = Conns::new();

        // Test register
        let id = conns.register(conn_str.clone()).await.unwrap();
        assert!(!id.is_empty());

        // Test unregister
        assert!(conns.unregister(id.clone()).is_ok());

        // An unregistered connection can no longer be used
        assert!(conns.query(&id, "SELECT 1").await.is_err());

        // Unregistering twice fails
        assert!(conns.unregister(id).is_err());
    }

    #[tokio::test]
    async fn list_tables_describe_should_work() {
        let (_tdb, conn_str) = setup_test_db().await;
        let conns = Conns::new();
        let id = conns.register(conn_str).await.unwrap();

        // Test list tables
        let tables = conns.list_tables(&id, "public").await.unwrap();
        assert!(tables.contains("test_table"));

        // Test describe table
        let description = conns.describe(&id, "test_table").await.unwrap();
        assert!(description.contains("id"));
        assert!(description.contains("name"));
        assert!(description.contains("created_at"));
    }

    #[tokio::test]
    async fn create_table_drop_table_should_work() {
        let (_tdb, conn_str) = setup_test_db().await;
        let conns = Conns::new();
        let id = conns.register(conn_str).await.unwrap();

        // Test create table
        let create_table = "CREATE TABLE test_table2 (id SERIAL PRIMARY KEY, name TEXT)";
        assert_eq!(
            conns.create_table(&id, create_table).await.unwrap(),
            "success"
        );

        // Test drop table
        assert_eq!(
            conns.drop_table(&id, "test_table2").await.unwrap(),
            "success"
        );

        // Test drop table again
        assert!(conns.drop_table(&id, "test_table2").await.is_err());
    }

    #[tokio::test]
    async fn query_insert_update_delete_should_work() {
        let (_tdb, conn_str) = setup_test_db().await;
        let conns = Conns::new();
        let id = conns.register(conn_str).await.unwrap();

        // Test query
        let query = "SELECT * FROM test_table ORDER BY id";
        let result = conns.query(&id, query).await.unwrap();
        assert!(result.contains("test1"));
        assert!(result.contains("test2"));
        assert!(result.contains("test3"));

        // Test insert
        let insert = "INSERT INTO test_table (name) VALUES ('test4')";
        let result = conns.insert(&id, insert).await.unwrap();
        assert!(result.contains("rows_affected: 1"));

        // Test update
        let update = "UPDATE test_table SET name = 'updated' WHERE name = 'test1'";
        let result = conns.update(&id, update).await.unwrap();
        assert!(result.contains("rows_affected: 1"));

        // Test delete
        let result = conns
            .delete(&id, "DELETE FROM test_table WHERE name = 'updated'")
            .await
            .unwrap();
        assert!(result.contains("rows_affected: 1"));
    }

    #[tokio::test]
    async fn create_index_drop_index_should_work() {
        let (_tdb, conn_str) = setup_test_db().await;
        let conns = Conns::new();
        let id = conns.register(conn_str).await.unwrap();

        // Test create index
        let create_index = "CREATE INDEX idx_test_table_new ON test_table (name, created_at)";
        assert_eq!(
            conns.create_index(&id, create_index).await.unwrap(),
            "success"
        );

        // Test drop index
        assert_eq!(
            conns.drop_index(&id, "idx_test_table_new").await.unwrap(),
            "success"
        );

        // Test drop index again
        assert!(conns.drop_index(&id, "idx_test_table_new").await.is_err());
    }

    #[tokio::test]
    async fn sql_validation_should_work() {
        let (_tdb, conn_str) = setup_test_db().await;
        let conns = Conns::new();
        let id = conns.register(conn_str).await.unwrap();

        // Test invalid SELECT
        let invalid_query = "INSERT INTO test_table VALUES (1)";
        assert!(conns.query(&id, invalid_query).await.is_err());

        // Test invalid INSERT
        let invalid_insert = "SELECT * FROM test_table";
        assert!(conns.insert(&id, invalid_insert).await.is_err());

        // Test invalid UPDATE
        let invalid_update = "DELETE FROM test_table";
        assert!(conns.update(&id, invalid_update).await.is_err());

        // Test invalid DELETE
        let invalid_delete = "UPDATE test_table SET name = 'x'";
        assert!(conns.delete(&id, invalid_delete).await.is_err());

        // Test invalid CREATE TABLE
        let invalid_create = "CREATE INDEX idx_test ON test_table (id)";
        assert!(conns.create_table(&id, invalid_create).await.is_err());

        // Test invalid CREATE INDEX
        let invalid_index = "CREATE TABLE test (id INT)";
        assert!(conns.create_index(&id, invalid_index).await.is_err());

        // Test unknown connection id
        assert!(conns.query("does-not-exist", "SELECT 1").await.is_err());
    }

    #[tokio::test]
    async fn create_type_should_work() {
        let (_tdb, conn_str) = setup_test_db().await;
        let conns = Conns::new();
        let id = conns.register(conn_str).await.unwrap();

        // Test create type
        let create_type = "CREATE TYPE user_role AS ENUM ('admin', 'user')";
        assert_eq!(
            conns.create_type(&id, create_type).await.unwrap(),
            "success"
        );

        // Test invalid type creation
        let invalid_type = "CREATE TABLE test (id INT)";
        assert!(conns.create_type(&id, invalid_type).await.is_err());
    }
}