1use arc_swap::ArcSwap;
2use serde::{Deserialize, Serialize};
3use sqlparser::ast::Statement;
4use sqlx::postgres::PgPool;
5use std::collections::HashMap;
6use std::sync::Arc;
7use thiserror::Error;
8
/// Errors returned by the connection registry and its query helpers.
// `unused` is allowed because not every variant is constructed in this module
// (e.g. `InternalError` is never built here).
#[allow(unused)]
#[derive(Error, Debug)]
pub enum PgMcpError {
    /// No connection is registered under the given ID.
    #[error("Connection not found for ID: {0}")]
    ConnectionNotFound(String),

    /// The SQL text was rejected before being sent to the database.
    ///
    /// Note that `details` is not part of the `Display` message; it is only
    /// visible through `Debug` or by matching on the variant.
    #[error("SQL validation failed for query '{query}': {kind}")]
    ValidationFailed {
        /// Which validation step failed.
        kind: ValidationErrorKind,
        /// The query text exactly as submitted.
        query: String,
        /// Human-readable explanation of the failure.
        details: String,
    },

    /// The database rejected or failed to execute an operation.
    #[error("Database operation '{operation}' failed: {underlying}")]
    DatabaseError {
        /// Short label for the attempted operation; `"unknown"` when produced by
        /// the `From<sqlx::Error>` conversion.
        operation: String,
        /// Stringified underlying error.
        underlying: String,
    },

    /// A value could not be serialized to JSON.
    #[error("Serialization failed: {0}")]
    SerializationError(#[from] serde_json::Error),

    /// A connection to the database could not be established.
    #[error("Database connection failed: {0}")]
    ConnectionError(String),

    /// Generic internal failure (not constructed in this module).
    #[error("Internal error: {0}")]
    InternalError(String),
}
37
/// The reason a SQL string failed validation.
#[derive(Error, Debug)]
pub enum ValidationErrorKind {
    /// The text parsed, but was not exactly one statement of the expected kind.
    /// Used both for a wrong statement count and for a wrong statement type.
    #[error("Invalid statement type, expected {expected}")]
    InvalidStatementType { expected: String },
    /// The text could not be parsed as SQL.
    #[error("Failed to parse SQL")]
    ParseError,
}
45
46impl From<sqlx::Error> for PgMcpError {
47 fn from(e: sqlx::Error) -> Self {
48 let msg = e.to_string();
49 if let Some(db_err) = e.as_database_error() {
50 PgMcpError::DatabaseError {
51 operation: "unknown".to_string(),
52 underlying: db_err.to_string(),
53 }
54 } else if msg.contains("error connecting") || msg.contains("timed out") {
55 PgMcpError::ConnectionError(msg)
56 } else {
57 PgMcpError::DatabaseError {
58 operation: "unknown".to_string(),
59 underlying: msg,
60 }
61 }
62 }
63}
64
/// A single registered database connection.
// `dead_code` is allowed because `id` and `conn_str` are stored but never read
// in this module.
// NOTE(review): the derived `Debug` prints `conn_str`, which may embed credentials
// (the test URL in this file contains a password) — consider redacting it.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub(crate) struct Conn {
    /// The ID this connection is registered under (also its key in `Conns`).
    pub(crate) id: String,
    /// The connection string the pool was opened with.
    pub(crate) conn_str: String,
    /// The `sqlx` pool used to run statements for this connection.
    pub(crate) pool: PgPool,
}
72
/// Registry of open connections, keyed by connection ID.
///
/// The map lives behind an `ArcSwap`: readers take a snapshot with `load()`, and
/// writers publish a modified copy of the whole map. Because `inner` is wrapped in
/// an `Arc`, clones of `Conns` share the same registry.
#[derive(Debug, Clone)]
pub struct Conns {
    /// Shared, atomically swappable map from connection ID to its `Conn`.
    pub(crate) inner: Arc<ArcSwap<HashMap<String, Conn>>>,
}
77
/// Top-level service state.
///
/// In this file it holds only the connection registry.
#[derive(Debug, Clone)]
pub struct PgMcp {
    /// Registry of registered database connections.
    pub(crate) conns: Conns,
}
82
/// Single-column row used to decode queries that aggregate their output into one
/// JSON value aliased `ret`.
#[derive(Debug, sqlx::FromRow, Serialize, Deserialize)]
struct JsonRow {
    /// The aggregated JSON value (in this file, the result of `JSON_AGG`).
    ret: sqlx::types::Json<serde_json::Value>,
}
87
88impl Conns {
89 pub(crate) fn new() -> Self {
90 Self {
91 inner: Arc::new(ArcSwap::new(Arc::new(HashMap::new()))),
92 }
93 }
94
95 pub(crate) async fn register(&self, conn_str: String) -> Result<String, PgMcpError> {
96 let pool = PgPool::connect(&conn_str)
97 .await
98 .map_err(|e| PgMcpError::ConnectionError(e.to_string()))?;
99 let id = uuid::Uuid::new_v4().to_string();
100 let conn = Conn {
101 id: id.clone(),
102 conn_str: conn_str.clone(),
103 pool,
104 };
105
106 let mut conns = self.inner.load().as_ref().clone();
107 conns.insert(id.clone(), conn);
108 self.inner.store(Arc::new(conns));
109
110 Ok(id)
111 }
112
113 pub(crate) fn unregister(&self, id: String) -> Result<(), PgMcpError> {
114 let mut conns = self.inner.load().as_ref().clone();
115 if conns.remove(&id).is_none() {
116 return Err(PgMcpError::ConnectionNotFound(id));
117 }
118 self.inner.store(Arc::new(conns));
119 Ok(())
120 }
121
122 pub(crate) async fn query(&self, id: &str, query: &str) -> Result<String, PgMcpError> {
123 let operation = "query (SELECT)";
124 let conns = self.inner.load();
125 let conn = conns
126 .get(id)
127 .ok_or_else(|| PgMcpError::ConnectionNotFound(id.to_string()))?;
128
129 let validated_query =
130 validate_sql(query, |stmt| matches!(stmt, Statement::Query(_)), "SELECT")?;
131
132 let prepared_query = format!(
133 "WITH data AS ({}) SELECT JSON_AGG(data.*) as ret FROM data;",
134 validated_query
135 );
136
137 let ret = sqlx::query_as::<_, JsonRow>(&prepared_query)
138 .fetch_one(&conn.pool)
139 .await
140 .map_err(|e| PgMcpError::DatabaseError {
141 operation: operation.to_string(),
142 underlying: e.to_string(),
143 })?;
144
145 Ok(serde_json::to_string(&ret.ret)?)
146 }
147
148 pub(crate) async fn insert(&self, id: &str, query: &str) -> Result<String, PgMcpError> {
149 let operation = "insert (INSERT)";
150 let conns = self.inner.load();
151 let conn = conns
152 .get(id)
153 .ok_or_else(|| PgMcpError::ConnectionNotFound(id.to_string()))?;
154
155 let validated_query = validate_sql(
156 query,
157 |stmt| matches!(stmt, Statement::Insert { .. }),
158 "INSERT",
159 )?;
160
161 let result = sqlx::query(&validated_query)
162 .execute(&conn.pool)
163 .await
164 .map_err(|e| PgMcpError::DatabaseError {
165 operation: operation.to_string(),
166 underlying: e.to_string(),
167 })?;
168
169 Ok(format!(
170 "success, rows_affected: {}",
171 result.rows_affected()
172 ))
173 }
174
175 pub(crate) async fn update(&self, id: &str, query: &str) -> Result<String, PgMcpError> {
176 let operation = "update (UPDATE)";
177 let conns = self.inner.load();
178 let conn = conns
179 .get(id)
180 .ok_or_else(|| PgMcpError::ConnectionNotFound(id.to_string()))?;
181
182 let validated_query = validate_sql(
183 query,
184 |stmt| matches!(stmt, Statement::Update { .. }),
185 "UPDATE",
186 )?;
187
188 let result = sqlx::query(&validated_query)
189 .execute(&conn.pool)
190 .await
191 .map_err(|e| PgMcpError::DatabaseError {
192 operation: operation.to_string(),
193 underlying: e.to_string(),
194 })?;
195
196 Ok(format!(
197 "success, rows_affected: {}",
198 result.rows_affected()
199 ))
200 }
201
202 pub(crate) async fn delete(&self, id: &str, query: &str) -> Result<String, PgMcpError> {
203 let operation = "delete (DELETE)";
204 let conns = self.inner.load();
205 let conn = conns
206 .get(id)
207 .ok_or_else(|| PgMcpError::ConnectionNotFound(id.to_string()))?;
208
209 let validated_query = validate_sql(
210 query,
211 |stmt| matches!(stmt, Statement::Delete { .. }),
212 "DELETE",
213 )?;
214
215 let result = sqlx::query(&validated_query)
216 .execute(&conn.pool)
217 .await
218 .map_err(|e| PgMcpError::DatabaseError {
219 operation: operation.to_string(),
220 underlying: e.to_string(),
221 })?;
222
223 Ok(format!(
224 "success, rows_affected: {}",
225 result.rows_affected()
226 ))
227 }
228
229 pub(crate) async fn create_table(&self, id: &str, query: &str) -> Result<String, PgMcpError> {
230 let operation = "create_table (CREATE TABLE)";
231 let conns = self.inner.load();
232 let conn = conns
233 .get(id)
234 .ok_or_else(|| PgMcpError::ConnectionNotFound(id.to_string()))?;
235
236 let validated_query = validate_sql(
237 query,
238 |stmt| matches!(stmt, Statement::CreateTable { .. }),
239 "CREATE TABLE",
240 )?;
241
242 sqlx::query(&validated_query)
243 .execute(&conn.pool)
244 .await
245 .map_err(|e| PgMcpError::DatabaseError {
246 operation: operation.to_string(),
247 underlying: e.to_string(),
248 })?;
249
250 Ok("success".to_string())
251 }
252
253 pub(crate) async fn drop_table(&self, id: &str, table: &str) -> Result<String, PgMcpError> {
254 let operation = format!("drop_table (DROP TABLE {})", table);
255 let conns = self.inner.load();
256 let conn = conns
257 .get(id)
258 .ok_or_else(|| PgMcpError::ConnectionNotFound(id.to_string()))?;
259
260 let query = format!("DROP TABLE {}", table);
261 sqlx::query(&query)
262 .execute(&conn.pool)
263 .await
264 .map_err(|e| PgMcpError::DatabaseError {
265 operation,
266 underlying: e.to_string(),
267 })?;
268
269 Ok("success".to_string())
270 }
271
272 pub(crate) async fn create_index(&self, id: &str, query: &str) -> Result<String, PgMcpError> {
273 let operation = "create_index (CREATE INDEX)";
274 let conns = self.inner.load();
275 let conn = conns
276 .get(id)
277 .ok_or_else(|| PgMcpError::ConnectionNotFound(id.to_string()))?;
278
279 let validated_query = validate_sql(
280 query,
281 |stmt| matches!(stmt, Statement::CreateIndex { .. }),
282 "CREATE INDEX",
283 )?;
284
285 sqlx::query(&validated_query)
286 .execute(&conn.pool)
287 .await
288 .map_err(|e| PgMcpError::DatabaseError {
289 operation: operation.to_string(),
290 underlying: e.to_string(),
291 })?;
292
293 Ok("success".to_string())
294 }
295
296 pub(crate) async fn drop_index(&self, id: &str, index: &str) -> Result<String, PgMcpError> {
297 let operation = format!("drop_index (DROP INDEX {})", index);
298 let conns = self.inner.load();
299 let conn = conns
300 .get(id)
301 .ok_or_else(|| PgMcpError::ConnectionNotFound(id.to_string()))?;
302
303 let query = format!("DROP INDEX {}", index);
304 sqlx::query(&query)
305 .execute(&conn.pool)
306 .await
307 .map_err(|e| PgMcpError::DatabaseError {
308 operation,
309 underlying: e.to_string(),
310 })?;
311
312 Ok("success".to_string())
313 }
314
315 pub(crate) async fn describe(&self, id: &str, table: &str) -> Result<String, PgMcpError> {
316 let operation = format!("describe (table: {})", table);
317 let conns = self.inner.load();
318 let conn = conns
319 .get(id)
320 .ok_or_else(|| PgMcpError::ConnectionNotFound(id.to_string()))?;
321
322 let query = r#"
323 WITH data AS (
324 SELECT column_name, data_type, character_maximum_length, column_default, is_nullable
325 FROM information_schema.columns
326 WHERE table_name = $1
327 ORDER BY ordinal_position)
328 SELECT JSON_AGG(data.*) as ret FROM data"#;
329
330 let ret = sqlx::query_as::<_, JsonRow>(query)
331 .bind(table)
332 .fetch_one(&conn.pool)
333 .await
334 .map_err(|e| PgMcpError::DatabaseError {
335 operation: operation.to_string(),
336 underlying: e.to_string(),
337 })?;
338
339 Ok(serde_json::to_string(&ret.ret)?)
340 }
341
342 pub(crate) async fn list_tables(&self, id: &str, schema: &str) -> Result<String, PgMcpError> {
343 let operation = format!("list_tables (schema: {})", schema);
344 let conns = self.inner.load();
345 let conn = conns
346 .get(id)
347 .ok_or_else(|| PgMcpError::ConnectionNotFound(id.to_string()))?;
348
349 let query = r#"
350 WITH data AS (
351 SELECT
352 t.table_name,
353 obj_description(format('%s.%s', t.table_schema, t.table_name)::regclass::oid) as description,
354 pg_stat_get_tuples_inserted(format('%s.%s', t.table_schema, t.table_name)::regclass::oid) as total_rows
355 FROM information_schema.tables t
356 WHERE
357 t.table_schema = $1
358 AND t.table_type = 'BASE TABLE'
359 ORDER BY t.table_name
360 )
361 SELECT JSON_AGG(data.*) as ret FROM data"#;
362 let ret = sqlx::query_as::<_, JsonRow>(query)
363 .bind(schema)
364 .fetch_one(&conn.pool)
365 .await
366 .or_else(|e| {
367 if let sqlx::Error::RowNotFound = e {
368 Ok(JsonRow {
369 ret: sqlx::types::Json(serde_json::json!([])),
370 })
371 } else {
372 Err(PgMcpError::DatabaseError {
373 operation: operation.to_string(),
374 underlying: e.to_string(),
375 })
376 }
377 })?;
378
379 Ok(serde_json::to_string(&ret.ret)?)
380 }
381
382 pub(crate) async fn create_schema(
383 &self,
384 id: &str,
385 schema_name: &str,
386 ) -> Result<String, PgMcpError> {
387 let operation = format!("create_schema (CREATE SCHEMA {})", schema_name);
388 let conns = self.inner.load();
389 let conn = conns
390 .get(id)
391 .ok_or_else(|| PgMcpError::ConnectionNotFound(id.to_string()))?;
392
393 let query = format!("CREATE SCHEMA {}", schema_name);
394 sqlx::query(&query)
395 .execute(&conn.pool)
396 .await
397 .map_err(|e| PgMcpError::DatabaseError {
398 operation,
399 underlying: e.to_string(),
400 })?;
401
402 Ok("success".to_string())
403 }
404
405 pub(crate) async fn create_type(&self, id: &str, query: &str) -> Result<String, PgMcpError> {
406 let operation = "create_type (CREATE TYPE)";
407 let conns = self.inner.load();
408 let conn = conns
409 .get(id)
410 .ok_or_else(|| PgMcpError::ConnectionNotFound(id.to_string()))?;
411
412 let validated_query = validate_sql(
413 query,
414 |stmt| matches!(stmt, Statement::CreateType { .. }),
415 "CREATE TYPE",
416 )?;
417
418 sqlx::query(&validated_query)
419 .execute(&conn.pool)
420 .await
421 .map_err(|e| PgMcpError::DatabaseError {
422 operation: operation.to_string(),
423 underlying: e.to_string(),
424 })?;
425
426 Ok("success".to_string())
427 }
428}
429
430impl Default for Conns {
431 fn default() -> Self {
432 Self::new()
433 }
434}
435
436fn validate_sql<F>(
437 query: &str,
438 validator: F,
439 expected_type: &'static str,
440) -> Result<String, PgMcpError>
441where
442 F: Fn(&Statement) -> bool,
443{
444 let dialect = sqlparser::dialect::PostgreSqlDialect {};
445 let statements = sqlparser::parser::Parser::parse_sql(&dialect, query).map_err(|e| {
446 PgMcpError::ValidationFailed {
447 kind: ValidationErrorKind::ParseError,
448 query: query.to_string(),
449 details: e.to_string(),
450 }
451 })?;
452
453 if statements.len() != 1 {
454 return Err(PgMcpError::ValidationFailed {
455 kind: ValidationErrorKind::InvalidStatementType {
456 expected: expected_type.to_string(),
457 },
458 query: query.to_string(),
459 details: format!(
460 "Expected exactly one SQL statement, found {}",
461 statements.len()
462 ),
463 });
464 }
465
466 let stmt = &statements[0];
467 if !validator(stmt) {
468 return Err(PgMcpError::ValidationFailed {
469 kind: ValidationErrorKind::InvalidStatementType {
470 expected: expected_type.to_string(),
471 },
472 query: query.to_string(),
473 details: format!("Statement type validation failed. Received: {:?}", stmt),
474 });
475 }
476
477 Ok(query.to_string())
478}
479
#[cfg(test)]
mod tests {
    use super::*;
    use sqlx_db_tester::TestPg;

    const TEST_CONN_STR: &str = "postgres://postgres:postgres@localhost:5432/postgres";

    /// Creates a throw-away database from `./fixtures/migrations` and returns it
    /// together with its connection URL.
    async fn setup_test_db() -> (TestPg, String) {
        let migrations = std::path::Path::new("./fixtures/migrations");
        let tdb = TestPg::new(TEST_CONN_STR.to_string(), migrations);
        let pool = tdb.get_pool().await;

        // Sanity check: the fixtures must have created `test_table`.
        sqlx::query("SELECT * FROM test_table LIMIT 1")
            .execute(&pool)
            .await
            .unwrap();

        let url = tdb.url();
        (tdb, url)
    }

    /// Creates a fixture database plus a registry with that database registered.
    ///
    /// Bind the returned `TestPg` to a named variable (e.g. `_tdb`, never `_`) so it
    /// stays alive for the whole test.
    async fn setup_registered() -> (TestPg, Conns, String) {
        let (tdb, url) = setup_test_db().await;
        let conns = Conns::new();
        let id = conns.register(url).await.unwrap();
        (tdb, conns, id)
    }

    #[tokio::test]
    async fn register_unregister_should_work() {
        let (_tdb, conns, id) = setup_registered().await;
        assert!(!id.is_empty());

        // The first removal succeeds; the second finds nothing to remove.
        assert!(conns.unregister(id.clone()).is_ok());
        assert!(conns.unregister(id).is_err());
    }

    #[tokio::test]
    async fn list_tables_describe_should_work() {
        let (_tdb, conns, id) = setup_registered().await;

        let tables = conns.list_tables(&id, "public").await.unwrap();
        assert!(tables.contains("test_table"));

        let description = conns.describe(&id, "test_table").await.unwrap();
        for column in ["id", "name", "created_at"] {
            assert!(description.contains(column));
        }
    }

    #[tokio::test]
    async fn create_table_drop_table_should_work() {
        let (_tdb, conns, id) = setup_registered().await;

        let ddl = "CREATE TABLE test_table2 (id SERIAL PRIMARY KEY, name TEXT)";
        assert_eq!(conns.create_table(&id, ddl).await.unwrap(), "success");
        assert_eq!(
            conns.drop_table(&id, "test_table2").await.unwrap(),
            "success"
        );

        // The table is already gone, so a second drop must fail.
        assert!(conns.drop_table(&id, "test_table2").await.is_err());
    }

    #[tokio::test]
    async fn query_insert_update_delete_should_work() {
        let (_tdb, conns, id) = setup_registered().await;

        let rows = conns
            .query(&id, "SELECT * FROM test_table ORDER BY id")
            .await
            .unwrap();
        for seeded in ["test1", "test2", "test3"] {
            assert!(rows.contains(seeded));
        }

        let inserted = conns
            .insert(&id, "INSERT INTO test_table (name) VALUES ('test4')")
            .await
            .unwrap();
        assert!(inserted.contains("rows_affected: 1"));

        let updated = conns
            .update(
                &id,
                "UPDATE test_table SET name = 'updated' WHERE name = 'test1'",
            )
            .await
            .unwrap();
        assert!(updated.contains("rows_affected: 1"));

        let deleted = conns
            .delete(&id, "DELETE FROM test_table WHERE name = 'updated'")
            .await
            .unwrap();
        assert!(deleted.contains("rows_affected: 1"));
    }

    #[tokio::test]
    async fn create_index_drop_index_should_work() {
        let (_tdb, conns, id) = setup_registered().await;

        let ddl = "CREATE INDEX idx_test_table_new ON test_table (name, created_at)";
        assert_eq!(conns.create_index(&id, ddl).await.unwrap(), "success");
        assert_eq!(
            conns.drop_index(&id, "idx_test_table_new").await.unwrap(),
            "success"
        );
    }

    #[tokio::test]
    async fn sql_validation_should_work() {
        let (_tdb, conns, id) = setup_registered().await;

        // Every entry point must reject a statement of the wrong kind.
        assert!(conns
            .query(&id, "INSERT INTO test_table VALUES (1)")
            .await
            .is_err());
        assert!(conns
            .insert(&id, "SELECT * FROM test_table")
            .await
            .is_err());
        assert!(conns
            .update(&id, "DELETE FROM test_table")
            .await
            .is_err());
        assert!(conns
            .create_table(&id, "CREATE INDEX idx_test ON test_table (id)")
            .await
            .is_err());
        assert!(conns
            .create_index(&id, "CREATE TABLE test (id INT)")
            .await
            .is_err());
    }

    #[tokio::test]
    async fn create_type_should_work() {
        let (_tdb, conns, id) = setup_registered().await;

        let ddl = "CREATE TYPE user_role AS ENUM ('admin', 'user')";
        assert_eq!(conns.create_type(&id, ddl).await.unwrap(), "success");

        assert!(conns
            .create_type(&id, "CREATE TABLE test (id INT)")
            .await
            .is_err());
    }

    #[tokio::test]
    async fn create_schema_should_work() {
        let (_tdb, conns, id) = setup_registered().await;

        let schema_name = "test_schema_unit";
        assert_eq!(
            conns.create_schema(&id, schema_name).await.unwrap(),
            "success"
        );

        // The new schema must now be visible in the catalog.
        let lookup = format!(
            "SELECT schema_name FROM information_schema.schemata WHERE schema_name = '{}'",
            schema_name
        );
        let pool = conns.inner.load().get(&id).unwrap().pool.clone();
        sqlx::query(&lookup).fetch_one(&pool).await.unwrap();

        // A name containing a statement separator must be refused.
        assert!(conns.create_schema(&id, "test;schema").await.is_err());
    }
}