postgres_mcp/
pg.rs

1use anyhow::Error;
2use arc_swap::ArcSwap;
3use serde::{Deserialize, Serialize};
4use sqlparser::ast::Statement;
5use sqlx::postgres::PgPool;
6use std::collections::HashMap;
7use std::sync::Arc;
8
9#[allow(dead_code)]
10#[derive(Debug, Clone)]
11pub(crate) struct Conn {
12    pub(crate) id: String,
13    pub(crate) conn_str: String,
14    pub(crate) pool: PgPool,
15}
16
17#[derive(Debug, Clone)]
18pub struct Conns {
19    pub(crate) inner: Arc<ArcSwap<HashMap<String, Conn>>>,
20}
21
22#[derive(Debug, Clone)]
23pub struct PgMcp {
24    pub(crate) conns: Conns,
25}
26
27#[derive(Debug, sqlx::FromRow, Serialize, Deserialize)]
28struct JsonRow {
29    ret: sqlx::types::Json<serde_json::Value>,
30}
31
32impl Conns {
33    pub(crate) fn new() -> Self {
34        Self {
35            inner: Arc::new(ArcSwap::new(Arc::new(HashMap::new()))),
36        }
37    }
38
39    pub(crate) async fn register(&self, conn_str: String) -> Result<String, Error> {
40        let pool = PgPool::connect(&conn_str).await?;
41        let id = uuid::Uuid::new_v4().to_string();
42        let conn = Conn {
43            id: id.clone(),
44            conn_str: conn_str.clone(),
45            pool,
46        };
47
48        let mut conns = self.inner.load().as_ref().clone();
49        conns.insert(id.clone(), conn);
50        self.inner.store(Arc::new(conns));
51
52        Ok(id)
53    }
54
55    pub(crate) fn unregister(&self, id: String) -> Result<(), Error> {
56        let mut conns = self.inner.load().as_ref().clone();
57        if conns.remove(&id).is_none() {
58            return Err(anyhow::anyhow!("Connection not found"));
59        }
60        self.inner.store(Arc::new(conns));
61        Ok(())
62    }
63
64    pub(crate) async fn query(&self, id: &str, query: &str) -> Result<String, Error> {
65        let conns = self.inner.load();
66        let conn = conns
67            .get(id)
68            .ok_or_else(|| anyhow::anyhow!("Connection not found"))?;
69
70        let query = validate_sql(
71            query,
72            |stmt| matches!(stmt, Statement::Query(_)),
73            "Only SELECT queries are allowed",
74        )?;
75
76        let query = format!(
77            "WITH data AS ({}) SELECT JSON_AGG(data.*) as ret FROM data;",
78            query
79        );
80
81        let ret = sqlx::query_as::<_, JsonRow>(&query)
82            .fetch_one(&conn.pool)
83            .await?;
84
85        Ok(serde_json::to_string(&ret.ret)?)
86    }
87
88    pub(crate) async fn insert(&self, id: &str, query: &str) -> Result<String, Error> {
89        let conns = self.inner.load();
90        let conn = conns
91            .get(id)
92            .ok_or_else(|| anyhow::anyhow!("Connection not found"))?;
93
94        let query = validate_sql(
95            query,
96            |stmt| matches!(stmt, Statement::Insert { .. }),
97            "Only INSERT statements are allowed",
98        )?;
99
100        let result = sqlx::query(&query).execute(&conn.pool).await?;
101
102        Ok(format!(
103            "success, rows_affected: {}",
104            result.rows_affected()
105        ))
106    }
107
108    pub(crate) async fn update(&self, id: &str, query: &str) -> Result<String, Error> {
109        let conns = self.inner.load();
110        let conn = conns
111            .get(id)
112            .ok_or_else(|| anyhow::anyhow!("Connection not found"))?;
113
114        let query = validate_sql(
115            query,
116            |stmt| matches!(stmt, Statement::Update { .. }),
117            "Only UPDATE statements are allowed",
118        )?;
119
120        let result = sqlx::query(&query).execute(&conn.pool).await?;
121
122        Ok(format!(
123            "success, rows_affected: {}",
124            result.rows_affected()
125        ))
126    }
127
128    pub(crate) async fn delete(&self, id: &str, query: &str) -> Result<String, Error> {
129        let conns = self.inner.load();
130        let conn = conns
131            .get(id)
132            .ok_or_else(|| anyhow::anyhow!("Connection not found"))?;
133
134        let query = validate_sql(
135            query,
136            |stmt| matches!(stmt, Statement::Delete { .. }),
137            "Only DELETE statements are allowed",
138        )?;
139
140        let result = sqlx::query(&query).execute(&conn.pool).await?;
141
142        Ok(format!(
143            "success, rows_affected: {}",
144            result.rows_affected()
145        ))
146    }
147
148    pub(crate) async fn create_table(&self, id: &str, query: &str) -> Result<String, Error> {
149        let conns = self.inner.load();
150        let conn = conns
151            .get(id)
152            .ok_or_else(|| anyhow::anyhow!("Connection not found"))?;
153
154        let query = validate_sql(
155            query,
156            |stmt| matches!(stmt, Statement::CreateTable { .. }),
157            "Only CREATE TABLE statements are allowed",
158        )?;
159
160        sqlx::query(&query).execute(&conn.pool).await?;
161
162        Ok("success".to_string())
163    }
164
165    pub(crate) async fn drop_table(&self, id: &str, table: &str) -> Result<String, Error> {
166        let conns = self.inner.load();
167        let conn = conns
168            .get(id)
169            .ok_or_else(|| anyhow::anyhow!("Connection not found"))?;
170
171        let query = format!("DROP TABLE {}", table);
172        sqlx::query(&query).execute(&conn.pool).await?;
173
174        Ok("success".to_string())
175    }
176
177    pub(crate) async fn create_index(&self, id: &str, query: &str) -> Result<String, Error> {
178        let conns = self.inner.load();
179        let conn = conns
180            .get(id)
181            .ok_or_else(|| anyhow::anyhow!("Connection not found"))?;
182
183        let query = validate_sql(
184            query,
185            |stmt| matches!(stmt, Statement::CreateIndex { .. }),
186            "Only CREATE INDEX statements are allowed",
187        )?;
188
189        sqlx::query(&query).execute(&conn.pool).await?;
190
191        Ok("success".to_string())
192    }
193
194    pub(crate) async fn drop_index(&self, id: &str, index: &str) -> Result<String, Error> {
195        let conns = self.inner.load();
196        let conn = conns
197            .get(id)
198            .ok_or_else(|| anyhow::anyhow!("Connection not found"))?;
199
200        let query = format!("DROP INDEX {}", index);
201        sqlx::query(&query).execute(&conn.pool).await?;
202
203        Ok("success".to_string())
204    }
205
206    pub(crate) async fn describe(&self, id: &str, table: &str) -> Result<String, Error> {
207        let conns = self.inner.load();
208        let conn = conns
209            .get(id)
210            .ok_or_else(|| anyhow::anyhow!("Connection not found"))?;
211
212        let query = r#"
213        WITH data AS (
214          SELECT column_name, data_type, character_maximum_length, column_default, is_nullable
215          FROM information_schema.columns
216          WHERE table_name = $1
217          ORDER BY ordinal_position)
218        SELECT JSON_AGG(data.*) as ret FROM data"#;
219
220        let ret = sqlx::query_as::<_, JsonRow>(query)
221            .bind(table)
222            .fetch_one(&conn.pool)
223            .await?;
224
225        Ok(serde_json::to_string(&ret.ret)?)
226    }
227
228    pub(crate) async fn list_tables(&self, id: &str, schema: &str) -> Result<String, Error> {
229        let conns = self.inner.load();
230        let conn = conns
231            .get(id)
232            .ok_or_else(|| anyhow::anyhow!("Connection not found"))?;
233
234        let query = r#"
235        WITH data AS (
236          SELECT
237                t.table_name,
238                obj_description(format('%s.%s', t.table_schema, t.table_name)::regclass::oid) as description,
239                pg_stat_get_tuples_inserted(format('%s.%s', t.table_schema, t.table_name)::regclass::oid) as total_rows
240            FROM information_schema.tables t
241            WHERE
242                t.table_schema = $1
243                AND t.table_type = 'BASE TABLE'
244            ORDER BY t.table_name
245        )
246        SELECT JSON_AGG(data.*) as ret FROM data"#;
247        let ret = sqlx::query_as::<_, JsonRow>(query)
248            .bind(schema)
249            .fetch_one(&conn.pool)
250            .await?;
251
252        Ok(serde_json::to_string(&ret.ret)?)
253    }
254}
255
256impl Default for Conns {
257    fn default() -> Self {
258        Self::new()
259    }
260}
261
262fn validate_sql<F>(query: &str, validator: F, error_msg: &'static str) -> Result<String, Error>
263where
264    F: Fn(&Statement) -> bool,
265{
266    let dialect = sqlparser::dialect::PostgreSqlDialect {};
267    let ast = sqlparser::parser::Parser::parse_sql(&dialect, query)?;
268    if ast.len() != 1 || !validator(&ast[0]) {
269        return Err(anyhow::anyhow!(error_msg));
270    }
271    Ok(ast[0].to_string())
272}
273
#[cfg(test)]
mod tests {
    use super::*;
    use sqlx_db_tester::TestPg;

    const TEST_CONN_STR: &str = "postgres://postgres:postgres@localhost:5432/postgres";

    /// Creates a throwaway database from `./fixtures/migrations` and returns it
    /// together with its connection URL. The returned `TestPg` must be kept alive
    /// by the caller for as long as the URL is in use.
    async fn setup_test_db() -> (TestPg, String) {
        let tdb = TestPg::new(
            TEST_CONN_STR.to_string(),
            std::path::Path::new("./fixtures/migrations"),
        );
        let pool = tdb.get_pool().await;

        // Ensure migrations are applied: this panics if `test_table` is missing.
        sqlx::query("SELECT * FROM test_table LIMIT 1")
            .execute(&pool)
            .await
            .unwrap();

        let conn_str = tdb.url();

        (tdb, conn_str)
    }

    #[tokio::test]
    async fn register_unregister_should_work() {
        let (_tdb, conn_str) = setup_test_db().await;
        let conns = Conns::new();

        // Test register
        let id = conns.register(conn_str.clone()).await.unwrap();
        assert!(!id.is_empty());

        // Test unregister: the second attempt must fail because the id is gone
        assert!(conns.unregister(id.clone()).is_ok());
        assert!(conns.unregister(id).is_err());
    }

    #[tokio::test]
    async fn list_tables_describe_should_work() {
        let (_tdb, conn_str) = setup_test_db().await;
        let conns = Conns::new();
        let id = conns.register(conn_str).await.unwrap();

        // Test list tables
        let tables = conns.list_tables(&id, "public").await.unwrap();
        assert!(tables.contains("test_table"));

        // Test describe table
        let description = conns.describe(&id, "test_table").await.unwrap();
        assert!(description.contains("id"));
        assert!(description.contains("name"));
        assert!(description.contains("created_at"));
    }

    #[tokio::test]
    async fn create_table_drop_table_should_work() {
        let (_tdb, conn_str) = setup_test_db().await;
        let conns = Conns::new();
        let id = conns.register(conn_str).await.unwrap();

        // Test create table
        let create_table = "CREATE TABLE test_table2 (id SERIAL PRIMARY KEY, name TEXT)";
        assert_eq!(
            conns.create_table(&id, create_table).await.unwrap(),
            "success"
        );

        // Test drop table
        assert_eq!(
            conns.drop_table(&id, "test_table2").await.unwrap(),
            "success"
        );

        // Test drop table again
        assert!(conns.drop_table(&id, "test_table2").await.is_err());
    }

    #[tokio::test]
    async fn query_insert_update_delete_should_work() {
        let (_tdb, conn_str) = setup_test_db().await;
        let conns = Conns::new();
        let id = conns.register(conn_str).await.unwrap();

        // Test query
        let query = "SELECT * FROM test_table ORDER BY id";
        let result = conns.query(&id, query).await.unwrap();
        assert!(result.contains("test1"));
        assert!(result.contains("test2"));
        assert!(result.contains("test3"));

        // Test insert
        let insert = "INSERT INTO test_table (name) VALUES ('test4')";
        let result = conns.insert(&id, insert).await.unwrap();
        assert!(result.contains("rows_affected: 1"));

        // Test update
        let update = "UPDATE test_table SET name = 'updated' WHERE name = 'test1'";
        let result = conns.update(&id, update).await.unwrap();
        assert!(result.contains("rows_affected: 1"));

        // Test delete
        let result = conns
            .delete(&id, "DELETE FROM test_table WHERE name = 'updated'")
            .await
            .unwrap();
        assert!(result.contains("rows_affected: 1"));
    }

    #[tokio::test]
    async fn create_index_drop_index_should_work() {
        let (_tdb, conn_str) = setup_test_db().await;
        let conns = Conns::new();
        let id = conns.register(conn_str).await.unwrap();

        // Test create index
        let create_index = "CREATE INDEX idx_test_table_new ON test_table (name, created_at)";
        assert_eq!(
            conns.create_index(&id, create_index).await.unwrap(),
            "success"
        );

        // Test drop index
        assert_eq!(
            conns.drop_index(&id, "idx_test_table_new").await.unwrap(),
            "success"
        );

        // Test drop index again
        assert!(conns.drop_index(&id, "idx_test_table_new").await.is_err());
    }

    #[tokio::test]
    async fn sql_validation_should_work() {
        let (_tdb, conn_str) = setup_test_db().await;
        let conns = Conns::new();
        let id = conns.register(conn_str).await.unwrap();

        // Test invalid SELECT
        let invalid_query = "INSERT INTO test_table VALUES (1)";
        assert!(conns.query(&id, invalid_query).await.is_err());

        // Test invalid INSERT
        let invalid_insert = "SELECT * FROM test_table";
        assert!(conns.insert(&id, invalid_insert).await.is_err());

        // Test invalid UPDATE
        let invalid_update = "DELETE FROM test_table";
        assert!(conns.update(&id, invalid_update).await.is_err());

        // Test invalid DELETE
        let invalid_delete = "UPDATE test_table SET name = 'x'";
        assert!(conns.delete(&id, invalid_delete).await.is_err());

        // Test invalid CREATE TABLE
        let invalid_create = "CREATE INDEX idx_test ON test_table (id)";
        assert!(conns.create_table(&id, invalid_create).await.is_err());

        // Test invalid CREATE INDEX
        let invalid_index = "CREATE TABLE test (id INT)";
        assert!(conns.create_index(&id, invalid_index).await.is_err());

        // Test that more than one statement per call is rejected
        let multi_statement = "SELECT 1; SELECT 2";
        assert!(conns.query(&id, multi_statement).await.is_err());

        // Test that an unknown connection id is rejected
        assert!(conns.query("no-such-connection", "SELECT 1").await.is_err());
    }
}