1use anyhow::Error;
2use arc_swap::ArcSwap;
3use serde::{Deserialize, Serialize};
4use sqlparser::ast::Statement;
5use sqlx::postgres::PgPool;
6use std::collections::HashMap;
7use std::sync::Arc;
8
9#[allow(dead_code)]
10#[derive(Debug, Clone)]
11pub(crate) struct Conn {
12 pub(crate) id: String,
13 pub(crate) conn_str: String,
14 pub(crate) pool: PgPool,
15}
16
17#[derive(Debug, Clone)]
18pub struct Conns {
19 pub(crate) inner: Arc<ArcSwap<HashMap<String, Conn>>>,
20}
21
22#[derive(Debug, Clone)]
23pub struct PgMcp {
24 pub(crate) conns: Conns,
25}
26
27#[derive(Debug, sqlx::FromRow, Serialize, Deserialize)]
28struct JsonRow {
29 ret: sqlx::types::Json<serde_json::Value>,
30}
31
32impl Conns {
33 pub(crate) fn new() -> Self {
34 Self {
35 inner: Arc::new(ArcSwap::new(Arc::new(HashMap::new()))),
36 }
37 }
38
39 pub(crate) async fn register(&self, conn_str: String) -> Result<String, Error> {
40 let pool = PgPool::connect(&conn_str).await?;
41 let id = uuid::Uuid::new_v4().to_string();
42 let conn = Conn {
43 id: id.clone(),
44 conn_str: conn_str.clone(),
45 pool,
46 };
47
48 let mut conns = self.inner.load().as_ref().clone();
49 conns.insert(id.clone(), conn);
50 self.inner.store(Arc::new(conns));
51
52 Ok(id)
53 }
54
55 pub(crate) fn unregister(&self, id: String) -> Result<(), Error> {
56 let mut conns = self.inner.load().as_ref().clone();
57 if conns.remove(&id).is_none() {
58 return Err(anyhow::anyhow!("Connection not found"));
59 }
60 self.inner.store(Arc::new(conns));
61 Ok(())
62 }
63
64 pub(crate) async fn query(&self, id: &str, query: &str) -> Result<String, Error> {
65 let conns = self.inner.load();
66 let conn = conns
67 .get(id)
68 .ok_or_else(|| anyhow::anyhow!("Connection not found"))?;
69
70 let query = validate_sql(
71 query,
72 |stmt| matches!(stmt, Statement::Query(_)),
73 "Only SELECT queries are allowed",
74 )?;
75
76 let query = format!(
77 "WITH data AS ({}) SELECT JSON_AGG(data.*) as ret FROM data;",
78 query
79 );
80
81 let ret = sqlx::query_as::<_, JsonRow>(&query)
82 .fetch_one(&conn.pool)
83 .await?;
84
85 Ok(serde_json::to_string(&ret.ret)?)
86 }
87
88 pub(crate) async fn insert(&self, id: &str, query: &str) -> Result<String, Error> {
89 let conns = self.inner.load();
90 let conn = conns
91 .get(id)
92 .ok_or_else(|| anyhow::anyhow!("Connection not found"))?;
93
94 let query = validate_sql(
95 query,
96 |stmt| matches!(stmt, Statement::Insert { .. }),
97 "Only INSERT statements are allowed",
98 )?;
99
100 let result = sqlx::query(&query).execute(&conn.pool).await?;
101
102 Ok(format!(
103 "success, rows_affected: {}",
104 result.rows_affected()
105 ))
106 }
107
108 pub(crate) async fn update(&self, id: &str, query: &str) -> Result<String, Error> {
109 let conns = self.inner.load();
110 let conn = conns
111 .get(id)
112 .ok_or_else(|| anyhow::anyhow!("Connection not found"))?;
113
114 let query = validate_sql(
115 query,
116 |stmt| matches!(stmt, Statement::Update { .. }),
117 "Only UPDATE statements are allowed",
118 )?;
119
120 let result = sqlx::query(&query).execute(&conn.pool).await?;
121
122 Ok(format!(
123 "success, rows_affected: {}",
124 result.rows_affected()
125 ))
126 }
127
128 pub(crate) async fn delete(&self, id: &str, query: &str) -> Result<String, Error> {
129 let conns = self.inner.load();
130 let conn = conns
131 .get(id)
132 .ok_or_else(|| anyhow::anyhow!("Connection not found"))?;
133
134 let query = validate_sql(
135 query,
136 |stmt| matches!(stmt, Statement::Delete { .. }),
137 "Only DELETE statements are allowed",
138 )?;
139
140 let result = sqlx::query(&query).execute(&conn.pool).await?;
141
142 Ok(format!(
143 "success, rows_affected: {}",
144 result.rows_affected()
145 ))
146 }
147
148 pub(crate) async fn create_table(&self, id: &str, query: &str) -> Result<String, Error> {
149 let conns = self.inner.load();
150 let conn = conns
151 .get(id)
152 .ok_or_else(|| anyhow::anyhow!("Connection not found"))?;
153
154 let query = validate_sql(
155 query,
156 |stmt| matches!(stmt, Statement::CreateTable { .. }),
157 "Only CREATE TABLE statements are allowed",
158 )?;
159
160 sqlx::query(&query).execute(&conn.pool).await?;
161
162 Ok("success".to_string())
163 }
164
165 pub(crate) async fn drop_table(&self, id: &str, table: &str) -> Result<String, Error> {
166 let conns = self.inner.load();
167 let conn = conns
168 .get(id)
169 .ok_or_else(|| anyhow::anyhow!("Connection not found"))?;
170
171 let query = format!("DROP TABLE {}", table);
172 sqlx::query(&query).execute(&conn.pool).await?;
173
174 Ok("success".to_string())
175 }
176
177 pub(crate) async fn create_index(&self, id: &str, query: &str) -> Result<String, Error> {
178 let conns = self.inner.load();
179 let conn = conns
180 .get(id)
181 .ok_or_else(|| anyhow::anyhow!("Connection not found"))?;
182
183 let query = validate_sql(
184 query,
185 |stmt| matches!(stmt, Statement::CreateIndex { .. }),
186 "Only CREATE INDEX statements are allowed",
187 )?;
188
189 sqlx::query(&query).execute(&conn.pool).await?;
190
191 Ok("success".to_string())
192 }
193
194 pub(crate) async fn drop_index(&self, id: &str, index: &str) -> Result<String, Error> {
195 let conns = self.inner.load();
196 let conn = conns
197 .get(id)
198 .ok_or_else(|| anyhow::anyhow!("Connection not found"))?;
199
200 let query = format!("DROP INDEX {}", index);
201 sqlx::query(&query).execute(&conn.pool).await?;
202
203 Ok("success".to_string())
204 }
205
206 pub(crate) async fn describe(&self, id: &str, table: &str) -> Result<String, Error> {
207 let conns = self.inner.load();
208 let conn = conns
209 .get(id)
210 .ok_or_else(|| anyhow::anyhow!("Connection not found"))?;
211
212 let query = r#"
213 WITH data AS (
214 SELECT column_name, data_type, character_maximum_length, column_default, is_nullable
215 FROM information_schema.columns
216 WHERE table_name = $1
217 ORDER BY ordinal_position)
218 SELECT JSON_AGG(data.*) as ret FROM data"#;
219
220 let ret = sqlx::query_as::<_, JsonRow>(query)
221 .bind(table)
222 .fetch_one(&conn.pool)
223 .await?;
224
225 Ok(serde_json::to_string(&ret.ret)?)
226 }
227
228 pub(crate) async fn list_tables(&self, id: &str, schema: &str) -> Result<String, Error> {
229 let conns = self.inner.load();
230 let conn = conns
231 .get(id)
232 .ok_or_else(|| anyhow::anyhow!("Connection not found"))?;
233
234 let query = r#"
235 WITH data AS (
236 SELECT
237 t.table_name,
238 obj_description(format('%s.%s', t.table_schema, t.table_name)::regclass::oid) as description,
239 pg_stat_get_tuples_inserted(format('%s.%s', t.table_schema, t.table_name)::regclass::oid) as total_rows
240 FROM information_schema.tables t
241 WHERE
242 t.table_schema = $1
243 AND t.table_type = 'BASE TABLE'
244 ORDER BY t.table_name
245 )
246 SELECT JSON_AGG(data.*) as ret FROM data"#;
247 let ret = sqlx::query_as::<_, JsonRow>(query)
248 .bind(schema)
249 .fetch_one(&conn.pool)
250 .await?;
251
252 Ok(serde_json::to_string(&ret.ret)?)
253 }
254
255 pub(crate) async fn create_type(&self, id: &str, query: &str) -> Result<String, Error> {
256 let conns = self.inner.load();
257 let conn = conns
258 .get(id)
259 .ok_or_else(|| anyhow::anyhow!("Connection not found"))?;
260
261 let query = validate_sql(
262 query,
263 |stmt| matches!(stmt, Statement::CreateType { .. }),
264 "Only CREATE TYPE statements are allowed",
265 )?;
266
267 sqlx::query(&query).execute(&conn.pool).await?;
268
269 Ok("success".to_string())
270 }
271}
272
273impl Default for Conns {
274 fn default() -> Self {
275 Self::new()
276 }
277}
278
279fn validate_sql<F>(query: &str, validator: F, error_msg: &'static str) -> Result<String, Error>
280where
281 F: Fn(&Statement) -> bool,
282{
283 let dialect = sqlparser::dialect::PostgreSqlDialect {};
284 let ast = sqlparser::parser::Parser::parse_sql(&dialect, query)?;
285 if ast.len() != 1 || !validator(&ast[0]) {
286 return Err(anyhow::anyhow!(error_msg));
287 }
288 Ok(ast[0].to_string())
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use sqlx_db_tester::TestPg;

    const TEST_CONN_STR: &str = "postgres://postgres:postgres@localhost:5432/postgres";

    /// Creates a test database from `./fixtures/migrations` and returns it with its URL.
    ///
    /// NOTE(review): presumably the database only lives as long as the returned `TestPg`,
    /// which is why callers bind it to `_tdb` rather than `_`.
    async fn setup_test_db() -> (TestPg, String) {
        let tdb = TestPg::new(
            TEST_CONN_STR.to_string(),
            std::path::Path::new("./fixtures/migrations"),
        );

        // Fail fast if the fixture migrations did not create `test_table`.
        let pool = tdb.get_pool().await;
        sqlx::query("SELECT * FROM test_table LIMIT 1")
            .execute(&pool)
            .await
            .unwrap();

        let url = tdb.url();
        (tdb, url)
    }

    /// Sets up a test database and registers it in a fresh registry.
    async fn registered_conns() -> (TestPg, Conns, String) {
        let (tdb, url) = setup_test_db().await;
        let conns = Conns::new();
        let id = conns.register(url).await.unwrap();
        (tdb, conns, id)
    }

    #[tokio::test]
    async fn register_unregister_should_work() {
        let (_tdb, url) = setup_test_db().await;
        let conns = Conns::new();

        let id = conns.register(url).await.unwrap();
        assert!(!id.is_empty());

        // First removal succeeds; the id is gone for the second.
        assert!(conns.unregister(id.clone()).is_ok());
        assert!(conns.unregister(id).is_err());
    }

    #[tokio::test]
    async fn list_tables_describe_should_work() {
        let (_tdb, conns, id) = registered_conns().await;

        let tables = conns.list_tables(&id, "public").await.unwrap();
        assert!(tables.contains("test_table"));

        let columns = conns.describe(&id, "test_table").await.unwrap();
        for column in ["id", "name", "created_at"].iter() {
            assert!(columns.contains(*column), "missing column {}", column);
        }
    }

    #[tokio::test]
    async fn create_table_drop_table_should_work() {
        let (_tdb, conns, id) = registered_conns().await;

        let ddl = "CREATE TABLE test_table2 (id SERIAL PRIMARY KEY, name TEXT)";
        assert_eq!(conns.create_table(&id, ddl).await.unwrap(), "success");
        assert_eq!(
            conns.drop_table(&id, "test_table2").await.unwrap(),
            "success"
        );

        // The table no longer exists, so dropping it again must fail.
        assert!(conns.drop_table(&id, "test_table2").await.is_err());
    }

    #[tokio::test]
    async fn query_insert_update_delete_should_work() {
        let (_tdb, conns, id) = registered_conns().await;

        let rows = conns
            .query(&id, "SELECT * FROM test_table ORDER BY id")
            .await
            .unwrap();
        for name in ["test1", "test2", "test3"].iter() {
            assert!(rows.contains(*name));
        }

        let inserted = conns
            .insert(&id, "INSERT INTO test_table (name) VALUES ('test4')")
            .await
            .unwrap();
        assert!(inserted.contains("rows_affected: 1"));

        let updated = conns
            .update(
                &id,
                "UPDATE test_table SET name = 'updated' WHERE name = 'test1'",
            )
            .await
            .unwrap();
        assert!(updated.contains("rows_affected: 1"));

        let deleted = conns
            .delete(&id, "DELETE FROM test_table WHERE name = 'updated'")
            .await
            .unwrap();
        assert!(deleted.contains("rows_affected: 1"));
    }

    #[tokio::test]
    async fn create_index_drop_index_should_work() {
        let (_tdb, conns, id) = registered_conns().await;

        let ddl = "CREATE INDEX idx_test_table_new ON test_table (name, created_at)";
        assert_eq!(conns.create_index(&id, ddl).await.unwrap(), "success");
        assert_eq!(
            conns.drop_index(&id, "idx_test_table_new").await.unwrap(),
            "success"
        );
    }

    #[tokio::test]
    async fn sql_validation_should_work() {
        let (_tdb, conns, id) = registered_conns().await;

        // Each entry point must reject a statement of the wrong kind.
        assert!(conns
            .query(&id, "INSERT INTO test_table VALUES (1)")
            .await
            .is_err());
        assert!(conns
            .insert(&id, "SELECT * FROM test_table")
            .await
            .is_err());
        assert!(conns
            .update(&id, "DELETE FROM test_table")
            .await
            .is_err());
        assert!(conns
            .create_table(&id, "CREATE INDEX idx_test ON test_table (id)")
            .await
            .is_err());
        assert!(conns
            .create_index(&id, "CREATE TABLE test (id INT)")
            .await
            .is_err());
    }

    #[tokio::test]
    async fn create_type_should_work() {
        let (_tdb, conns, id) = registered_conns().await;

        let ddl = "CREATE TYPE user_role AS ENUM ('admin', 'user')";
        assert_eq!(conns.create_type(&id, ddl).await.unwrap(), "success");
        assert!(conns
            .create_type(&id, "CREATE TABLE test (id INT)")
            .await
            .is_err());
    }
}