1use anyhow::Error;
2use arc_swap::ArcSwap;
3use serde::{Deserialize, Serialize};
4use sqlparser::ast::Statement;
5use sqlx::postgres::PgPool;
6use std::collections::HashMap;
7use std::sync::Arc;
8
9#[allow(dead_code)]
10#[derive(Debug, Clone)]
11pub(crate) struct Conn {
12 pub(crate) id: String,
13 pub(crate) conn_str: String,
14 pub(crate) pool: PgPool,
15}
16
17#[derive(Debug, Clone)]
18pub struct Conns {
19 pub(crate) inner: Arc<ArcSwap<HashMap<String, Conn>>>,
20}
21
22#[derive(Debug, Clone)]
23pub struct PgMcp {
24 pub(crate) conns: Conns,
25}
26
27#[derive(Debug, sqlx::FromRow, Serialize, Deserialize)]
28struct JsonRow {
29 ret: sqlx::types::Json<serde_json::Value>,
30}
31
32impl Conns {
33 pub(crate) fn new() -> Self {
34 Self {
35 inner: Arc::new(ArcSwap::new(Arc::new(HashMap::new()))),
36 }
37 }
38
39 pub(crate) async fn register(&self, conn_str: String) -> Result<String, Error> {
40 let pool = PgPool::connect(&conn_str).await?;
41 let id = uuid::Uuid::new_v4().to_string();
42 let conn = Conn {
43 id: id.clone(),
44 conn_str: conn_str.clone(),
45 pool,
46 };
47
48 let mut conns = self.inner.load().as_ref().clone();
49 conns.insert(id.clone(), conn);
50 self.inner.store(Arc::new(conns));
51
52 Ok(id)
53 }
54
55 pub(crate) fn unregister(&self, id: String) -> Result<(), Error> {
56 let mut conns = self.inner.load().as_ref().clone();
57 if conns.remove(&id).is_none() {
58 return Err(anyhow::anyhow!("Connection not found"));
59 }
60 self.inner.store(Arc::new(conns));
61 Ok(())
62 }
63
64 pub(crate) async fn query(&self, id: &str, query: &str) -> Result<String, Error> {
65 let conns = self.inner.load();
66 let conn = conns
67 .get(id)
68 .ok_or_else(|| anyhow::anyhow!("Connection not found"))?;
69
70 let query = validate_sql(
71 query,
72 |stmt| matches!(stmt, Statement::Query(_)),
73 "Only SELECT queries are allowed",
74 )?;
75
76 let query = format!(
77 "WITH data AS ({}) SELECT JSON_AGG(data.*) as ret FROM data;",
78 query
79 );
80
81 let ret = sqlx::query_as::<_, JsonRow>(&query)
82 .fetch_one(&conn.pool)
83 .await?;
84
85 Ok(serde_json::to_string(&ret.ret)?)
86 }
87
88 pub(crate) async fn insert(&self, id: &str, query: &str) -> Result<String, Error> {
89 let conns = self.inner.load();
90 let conn = conns
91 .get(id)
92 .ok_or_else(|| anyhow::anyhow!("Connection not found"))?;
93
94 let query = validate_sql(
95 query,
96 |stmt| matches!(stmt, Statement::Insert { .. }),
97 "Only INSERT statements are allowed",
98 )?;
99
100 let result = sqlx::query(&query).execute(&conn.pool).await?;
101
102 Ok(format!(
103 "success, rows_affected: {}",
104 result.rows_affected()
105 ))
106 }
107
108 pub(crate) async fn update(&self, id: &str, query: &str) -> Result<String, Error> {
109 let conns = self.inner.load();
110 let conn = conns
111 .get(id)
112 .ok_or_else(|| anyhow::anyhow!("Connection not found"))?;
113
114 let query = validate_sql(
115 query,
116 |stmt| matches!(stmt, Statement::Update { .. }),
117 "Only UPDATE statements are allowed",
118 )?;
119
120 let result = sqlx::query(&query).execute(&conn.pool).await?;
121
122 Ok(format!(
123 "success, rows_affected: {}",
124 result.rows_affected()
125 ))
126 }
127
128 pub(crate) async fn delete(&self, id: &str, query: &str) -> Result<String, Error> {
129 let conns = self.inner.load();
130 let conn = conns
131 .get(id)
132 .ok_or_else(|| anyhow::anyhow!("Connection not found"))?;
133
134 let query = validate_sql(
135 query,
136 |stmt| matches!(stmt, Statement::Delete { .. }),
137 "Only DELETE statements are allowed",
138 )?;
139
140 let result = sqlx::query(&query).execute(&conn.pool).await?;
141
142 Ok(format!(
143 "success, rows_affected: {}",
144 result.rows_affected()
145 ))
146 }
147
148 pub(crate) async fn create_table(&self, id: &str, query: &str) -> Result<String, Error> {
149 let conns = self.inner.load();
150 let conn = conns
151 .get(id)
152 .ok_or_else(|| anyhow::anyhow!("Connection not found"))?;
153
154 let query = validate_sql(
155 query,
156 |stmt| matches!(stmt, Statement::CreateTable { .. }),
157 "Only CREATE TABLE statements are allowed",
158 )?;
159
160 sqlx::query(&query).execute(&conn.pool).await?;
161
162 Ok("success".to_string())
163 }
164
165 pub(crate) async fn drop_table(&self, id: &str, table: &str) -> Result<String, Error> {
166 let conns = self.inner.load();
167 let conn = conns
168 .get(id)
169 .ok_or_else(|| anyhow::anyhow!("Connection not found"))?;
170
171 let query = format!("DROP TABLE {}", table);
172 sqlx::query(&query).execute(&conn.pool).await?;
173
174 Ok("success".to_string())
175 }
176
177 pub(crate) async fn create_index(&self, id: &str, query: &str) -> Result<String, Error> {
178 let conns = self.inner.load();
179 let conn = conns
180 .get(id)
181 .ok_or_else(|| anyhow::anyhow!("Connection not found"))?;
182
183 let query = validate_sql(
184 query,
185 |stmt| matches!(stmt, Statement::CreateIndex { .. }),
186 "Only CREATE INDEX statements are allowed",
187 )?;
188
189 sqlx::query(&query).execute(&conn.pool).await?;
190
191 Ok("success".to_string())
192 }
193
194 pub(crate) async fn drop_index(&self, id: &str, index: &str) -> Result<String, Error> {
195 let conns = self.inner.load();
196 let conn = conns
197 .get(id)
198 .ok_or_else(|| anyhow::anyhow!("Connection not found"))?;
199
200 let query = format!("DROP INDEX {}", index);
201 sqlx::query(&query).execute(&conn.pool).await?;
202
203 Ok("success".to_string())
204 }
205
206 pub(crate) async fn describe(&self, id: &str, table: &str) -> Result<String, Error> {
207 let conns = self.inner.load();
208 let conn = conns
209 .get(id)
210 .ok_or_else(|| anyhow::anyhow!("Connection not found"))?;
211
212 let query = r#"
213 WITH data AS (
214 SELECT column_name, data_type, character_maximum_length, column_default, is_nullable
215 FROM information_schema.columns
216 WHERE table_name = $1
217 ORDER BY ordinal_position)
218 SELECT JSON_AGG(data.*) as ret FROM data"#;
219
220 let ret = sqlx::query_as::<_, JsonRow>(query)
221 .bind(table)
222 .fetch_one(&conn.pool)
223 .await?;
224
225 Ok(serde_json::to_string(&ret.ret)?)
226 }
227
228 pub(crate) async fn list_tables(&self, id: &str, schema: &str) -> Result<String, Error> {
229 let conns = self.inner.load();
230 let conn = conns
231 .get(id)
232 .ok_or_else(|| anyhow::anyhow!("Connection not found"))?;
233
234 let query = r#"
235 WITH data AS (
236 SELECT
237 t.table_name,
238 obj_description(format('%s.%s', t.table_schema, t.table_name)::regclass::oid) as description,
239 pg_stat_get_tuples_inserted(format('%s.%s', t.table_schema, t.table_name)::regclass::oid) as total_rows
240 FROM information_schema.tables t
241 WHERE
242 t.table_schema = $1
243 AND t.table_type = 'BASE TABLE'
244 ORDER BY t.table_name
245 )
246 SELECT JSON_AGG(data.*) as ret FROM data"#;
247 let ret = sqlx::query_as::<_, JsonRow>(query)
248 .bind(schema)
249 .fetch_one(&conn.pool)
250 .await?;
251
252 Ok(serde_json::to_string(&ret.ret)?)
253 }
254}
255
256impl Default for Conns {
257 fn default() -> Self {
258 Self::new()
259 }
260}
261
262fn validate_sql<F>(query: &str, validator: F, error_msg: &'static str) -> Result<String, Error>
263where
264 F: Fn(&Statement) -> bool,
265{
266 let dialect = sqlparser::dialect::PostgreSqlDialect {};
267 let ast = sqlparser::parser::Parser::parse_sql(&dialect, query)?;
268 if ast.len() != 1 || !validator(&ast[0]) {
269 return Err(anyhow::anyhow!(error_msg));
270 }
271 Ok(ast[0].to_string())
272}
273
#[cfg(test)]
mod tests {
    use super::*;
    use sqlx_db_tester::TestPg;

    const TEST_CONN_STR: &str = "postgres://postgres:postgres@localhost:5432/postgres";

    /// Creates a test database from `./fixtures/migrations` via `TestPg` and checks
    /// that `test_table` is queryable. Returns the `TestPg` handle together with its
    /// URL.
    ///
    /// Callers keep the handle bound (not `_`) for the whole test. Presumably dropping
    /// it tears the database down (see sqlx-db-tester).
    async fn setup_test_db() -> (TestPg, String) {
        let tdb = TestPg::new(
            TEST_CONN_STR.to_string(),
            std::path::Path::new("./fixtures/migrations"),
        );
        let pool = tdb.get_pool().await;

        // Fail fast if the migrations did not create the fixture table.
        sqlx::query("SELECT * FROM test_table LIMIT 1")
            .execute(&pool)
            .await
            .unwrap();

        let url = tdb.url();
        (tdb, url)
    }

    /// Sets up a test database and registers it in a fresh registry.
    async fn setup_registered() -> (TestPg, Conns, String) {
        let (tdb, url) = setup_test_db().await;
        let conns = Conns::new();
        let id = conns.register(url).await.unwrap();
        (tdb, conns, id)
    }

    #[tokio::test]
    async fn register_unregister_should_work() {
        let (_tdb, url) = setup_test_db().await;
        let registry = Conns::new();

        let id = registry.register(url).await.unwrap();
        assert!(!id.is_empty());

        // The first removal succeeds; a second one must report the id as unknown.
        assert!(registry.unregister(id.clone()).is_ok());
        assert!(registry.unregister(id).is_err());
    }

    #[tokio::test]
    async fn list_tables_describe_should_work() {
        let (_tdb, conns, id) = setup_registered().await;

        let tables = conns.list_tables(&id, "public").await.unwrap();
        assert!(tables.contains("test_table"));

        let columns = conns.describe(&id, "test_table").await.unwrap();
        for column in ["id", "name", "created_at"] {
            assert!(columns.contains(column), "missing column {}", column);
        }
    }

    #[tokio::test]
    async fn create_table_drop_table_should_work() {
        let (_tdb, conns, id) = setup_registered().await;

        let ddl = "CREATE TABLE test_table2 (id SERIAL PRIMARY KEY, name TEXT)";
        assert_eq!(conns.create_table(&id, ddl).await.unwrap(), "success");
        assert_eq!(
            conns.drop_table(&id, "test_table2").await.unwrap(),
            "success"
        );

        // A second drop must fail because the table is gone.
        assert!(conns.drop_table(&id, "test_table2").await.is_err());
    }

    #[tokio::test]
    async fn query_insert_update_delete_should_work() {
        let (_tdb, conns, id) = setup_registered().await;

        let rows = conns
            .query(&id, "SELECT * FROM test_table ORDER BY id")
            .await
            .unwrap();
        for expected in ["test1", "test2", "test3"] {
            assert!(rows.contains(expected));
        }

        let inserted = conns
            .insert(&id, "INSERT INTO test_table (name) VALUES ('test4')")
            .await
            .unwrap();
        assert!(inserted.contains("rows_affected: 1"));

        let updated = conns
            .update(
                &id,
                "UPDATE test_table SET name = 'updated' WHERE name = 'test1'",
            )
            .await
            .unwrap();
        assert!(updated.contains("rows_affected: 1"));

        let deleted = conns
            .delete(&id, "DELETE FROM test_table WHERE name = 'updated'")
            .await
            .unwrap();
        assert!(deleted.contains("rows_affected: 1"));
    }

    #[tokio::test]
    async fn create_index_drop_index_should_work() {
        let (_tdb, conns, id) = setup_registered().await;

        let ddl = "CREATE INDEX idx_test_table_new ON test_table (name, created_at)";
        assert_eq!(conns.create_index(&id, ddl).await.unwrap(), "success");
        assert_eq!(
            conns.drop_index(&id, "idx_test_table_new").await.unwrap(),
            "success"
        );
    }

    #[tokio::test]
    async fn sql_validation_should_work() {
        let (_tdb, conns, id) = setup_registered().await;

        // Every entry point must refuse a statement of the wrong kind.
        assert!(conns
            .query(&id, "INSERT INTO test_table VALUES (1)")
            .await
            .is_err());
        assert!(conns
            .insert(&id, "SELECT * FROM test_table")
            .await
            .is_err());
        assert!(conns
            .update(&id, "DELETE FROM test_table")
            .await
            .is_err());
        assert!(conns
            .create_table(&id, "CREATE INDEX idx_test ON test_table (id)")
            .await
            .is_err());
        assert!(conns
            .create_index(&id, "CREATE TABLE test (id INT)")
            .await
            .is_err());
    }
}
430}