rmcp_postgres/
lib.rs

1//! PostgreSQL MCP Server
2//!
3//! A Model Context Protocol (MCP) server for PostgreSQL databases, built with rmcp.
4//! Provides tools for querying, inserting, updating, deleting data, and inspecting schemas.
5//!
6//! # Example
7//!
8//! ```no_run
9//! use rmcp_postgres::PostgresServer;
10//!
11//! #[tokio::main]
12//! async fn main() -> anyhow::Result<()> {
13//!     let server = PostgresServer::new("host=localhost user=postgres dbname=mydb");
14//!     // Use with rmcp ServiceExt trait
15//!     Ok(())
16//! }
17//! ```
18
use anyhow::Result;
use rmcp::{
    handler::server::{router::tool::ToolRouter, wrapper::Parameters, ServerHandler},
    model::*,
    ErrorData as McpError,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use tokio_postgres::{
    types::{FromSql, ToSql, Type},
    NoTls, Row,
};
28
29// ============================================================================
30// Parameter Types
31// ============================================================================
32
33#[derive(Debug, Serialize, Deserialize, JsonSchema)]
34pub struct QueryParams {
35    #[schemars(description = "SQL SELECT query to execute")]
36    pub query: String,
37}
38
39#[derive(Debug, Serialize, Deserialize, JsonSchema)]
40pub struct SchemaParams {
41    #[schemars(description = "Optional table name to filter schema")]
42    pub table_name: Option<String>,
43}
44
45#[derive(Debug, Serialize, Deserialize, JsonSchema)]
46pub struct InsertParams {
47    #[schemars(description = "Table name to insert into")]
48    pub table_name: String,
49    #[schemars(description = "Data to insert as JSON object")]
50    pub data: serde_json::Value,
51}
52
53#[derive(Debug, Serialize, Deserialize, JsonSchema)]
54pub struct TableNameParams {
55    #[schemars(description = "Name of the table")]
56    pub table_name: String,
57}
58
59#[derive(Debug, Serialize, Deserialize, JsonSchema)]
60pub struct CountRowsParams {
61    #[schemars(description = "Name of the table to count rows from")]
62    pub table_name: String,
63    #[schemars(description = "Optional WHERE conditions as JSON object")]
64    pub where_conditions: Option<serde_json::Value>,
65}
66
67#[derive(Debug, Serialize, Deserialize, JsonSchema)]
68pub struct ColumnExistsParams {
69    #[schemars(description = "Name of the table")]
70    pub table_name: String,
71    #[schemars(description = "Name of the column to check")]
72    pub column_name: String,
73}
74
75#[derive(Debug, Serialize, Deserialize, JsonSchema)]
76pub struct TableSampleParams {
77    #[schemars(description = "Name of the table to sample")]
78    pub table_name: String,
79    #[schemars(description = "Number of rows to return (default: 10, max: 100)")]
80    pub limit: Option<i32>,
81}
82
83#[derive(Debug, Serialize, Deserialize, JsonSchema)]
84pub struct UpdateDataParams {
85    #[schemars(description = "Name of the table to update")]
86    pub table_name: String,
87    #[schemars(description = "Object with column names as keys and new values")]
88    pub values: serde_json::Value,
89    #[schemars(description = "Object with column names as keys and values to match for WHERE clause")]
90    pub where_conditions: serde_json::Value,
91    #[schemars(description = "Maximum number of rows to update (safety limit, default: 1000)")]
92    pub limit: Option<i32>,
93}
94
95#[derive(Debug, Serialize, Deserialize, JsonSchema)]
96pub struct DeleteDataParams {
97    #[schemars(description = "Name of the table to delete from")]
98    pub table_name: String,
99    #[schemars(description = "Object with column names as keys and values to match for WHERE clause")]
100    pub where_conditions: serde_json::Value,
101    #[schemars(description = "Maximum number of rows to delete (safety limit, default: 1000)")]
102    pub limit: Option<i32>,
103}
104
105#[derive(Debug, Serialize, Deserialize, JsonSchema)]
106pub struct ExecuteRawQueryParams {
107    #[schemars(description = "SQL query to execute (use with caution)")]
108    pub query: String,
109    #[schemars(description = "Optional array of parameters for parameterized queries")]
110    pub params: Option<Vec<serde_json::Value>>,
111}
112
113#[derive(Debug, Serialize, Deserialize, JsonSchema)]
114pub struct RelationshipsParams {
115    #[schemars(description = "Optional table name to filter relationships")]
116    pub table_name: Option<String>,
117}
118
119// ============================================================================
120// PostgreSQL MCP Server
121// ============================================================================
122
123/// PostgreSQL MCP Server
124///
125/// Provides MCP tools for interacting with a PostgreSQL database.
126pub struct PostgresServer {
127    db_config: String,
128    pub tool_router: ToolRouter<Self>,
129}
130
131impl PostgresServer {
132    /// Create a new PostgreSQL MCP server
133    ///
134    /// # Arguments
135    ///
136    /// * `db_config` - PostgreSQL connection string (e.g., "host=localhost user=postgres dbname=mydb")
137    ///
138    /// # Example
139    ///
140    /// ```no_run
141    /// use rmcp_postgres::PostgresServer;
142    ///
143    /// let server = PostgresServer::new("host=localhost user=postgres dbname=mydb");
144    /// ```
145    pub fn new(db_config: impl Into<String>) -> Self {
146        Self {
147            db_config: db_config.into(),
148            tool_router: Self::tool_router(),
149        }
150    }
151
152    async fn get_client(&self) -> Result<tokio_postgres::Client> {
153        let (client, connection) = tokio_postgres::connect(&self.db_config, NoTls).await?;
154
155        tokio::spawn(async move {
156            if let Err(e) = connection.await {
157                eprintln!("Connection error: {}", e);
158            }
159        });
160
161        Ok(client)
162    }
163
164    fn row_to_json(&self, row: &Row) -> serde_json::Value {
165        let mut map = serde_json::Map::new();
166
167        for (idx, column) in row.columns().iter().enumerate() {
168            let value: serde_json::Value = match column.type_().name() {
169                "int4" | "int8" => {
170                    row.try_get::<_, i64>(idx)
171                        .map(|v| serde_json::json!(v))
172                        .unwrap_or(serde_json::Value::Null)
173                }
174                "float4" | "float8" => {
175                    row.try_get::<_, f64>(idx)
176                        .map(|v| serde_json::json!(v))
177                        .unwrap_or(serde_json::Value::Null)
178                }
179                "bool" => {
180                    row.try_get::<_, bool>(idx)
181                        .map(|v| serde_json::json!(v))
182                        .unwrap_or(serde_json::Value::Null)
183                }
184                "text" | "varchar" => {
185                    row.try_get::<_, String>(idx)
186                        .map(|v| serde_json::json!(v))
187                        .unwrap_or(serde_json::Value::Null)
188                }
189                _ => {
190                    row.try_get::<_, String>(idx)
191                        .map(|v| serde_json::json!(v))
192                        .unwrap_or(serde_json::Value::Null)
193                }
194            };
195
196            map.insert(column.name().to_string(), value);
197        }
198
199        serde_json::Value::Object(map)
200    }
201}
202
203// ============================================================================
204// MCP Tools
205// ============================================================================
206
207#[rmcp::tool_router]
208impl PostgresServer {
209    /// Execute a SELECT query on the database
210    #[rmcp::tool(description = "Execute a SELECT query and return results as JSON")]
211    pub async fn query_data(
212        &self,
213        Parameters(params): Parameters<QueryParams>,
214    ) -> Result<CallToolResult, McpError> {
215        let client = self
216            .get_client()
217            .await
218            .map_err(|e| McpError::internal_error(format!("DB connection failed: {}", e), None))?;
219
220        let rows = client
221            .query(&params.query, &[])
222            .await
223            .map_err(|e| McpError::internal_error(format!("Query failed: {}", e), None))?;
224
225        let json_rows: Vec<serde_json::Value> = rows.iter().map(|row| self.row_to_json(row)).collect();
226
227        let result = serde_json::json!({
228            "rows": json_rows,
229            "row_count": json_rows.len()
230        });
231
232        Ok(CallToolResult::success(vec![Content::text(
233            serde_json::to_string_pretty(&result).unwrap(),
234        )]))
235    }
236
237    /// Get schema information for database tables
238    #[rmcp::tool(description = "Get column information for database tables")]
239    pub async fn get_schema(
240        &self,
241        Parameters(params): Parameters<SchemaParams>,
242    ) -> Result<CallToolResult, McpError> {
243        let client = self
244            .get_client()
245            .await
246            .map_err(|e| McpError::internal_error(format!("DB connection failed: {}", e), None))?;
247
248        let query = if let Some(table) = params.table_name {
249            format!(
250                "SELECT table_name, column_name, data_type, is_nullable
251                 FROM information_schema.columns
252                 WHERE table_name = '{}'
253                 ORDER BY ordinal_position",
254                table
255            )
256        } else {
257            "SELECT table_name, column_name, data_type, is_nullable
258             FROM information_schema.columns
259             WHERE table_schema = 'public'
260             ORDER BY table_name, ordinal_position"
261                .to_string()
262        };
263
264        let rows = client
265            .query(&query, &[])
266            .await
267            .map_err(|e| McpError::internal_error(format!("Schema query failed: {}", e), None))?;
268
269        let schema: Vec<serde_json::Value> = rows
270            .iter()
271            .map(|row| {
272                serde_json::json!({
273                    "table_name": row.get::<_, String>(0),
274                    "column_name": row.get::<_, String>(1),
275                    "data_type": row.get::<_, String>(2),
276                    "is_nullable": row.get::<_, String>(3),
277                })
278            })
279            .collect();
280
281        Ok(CallToolResult::success(vec![Content::text(
282            serde_json::to_string_pretty(&schema).unwrap(),
283        )]))
284    }
285
286    /// Insert data into a table
287    #[rmcp::tool(description = "Insert a row into a database table")]
288    pub async fn insert_data(
289        &self,
290        Parameters(params): Parameters<InsertParams>,
291    ) -> Result<CallToolResult, McpError> {
292        let client = self
293            .get_client()
294            .await
295            .map_err(|e| McpError::internal_error(format!("DB connection failed: {}", e), None))?;
296
297        let obj = params
298            .data
299            .as_object()
300            .ok_or_else(|| McpError::invalid_params("Data must be a JSON object", None))?;
301
302        let columns: Vec<String> = obj.keys().cloned().collect();
303        let placeholders: Vec<String> = (1..=columns.len()).map(|i| format!("${}", i)).collect();
304
305        let query = format!(
306            "INSERT INTO {} ({}) VALUES ({})",
307            params.table_name,
308            columns.join(", "),
309            placeholders.join(", ")
310        );
311
312        // For now, convert all values to strings (we can improve this later)
313        let values: Vec<String> = obj
314            .values()
315            .map(|v| v.to_string().trim_matches('"').to_string())
316            .collect();
317
318        let value_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
319            values.iter().map(|v| v as &(dyn tokio_postgres::types::ToSql + Sync)).collect();
320
321        client
322            .execute(&query, &value_refs[..])
323            .await
324            .map_err(|e| McpError::internal_error(format!("Insert failed: {}", e), None))?;
325
326        Ok(CallToolResult::success(vec![Content::text(format!(
327            "Successfully inserted into {}",
328            params.table_name
329        ))]))
330    }
331
332    /// List all tables in the database
333    #[rmcp::tool(description = "List all tables in the database")]
334    pub async fn list_tables(&self) -> Result<CallToolResult, McpError> {
335        let client = self
336            .get_client()
337            .await
338            .map_err(|e| McpError::internal_error(format!("DB connection failed: {}", e), None))?;
339
340        let rows = client
341            .query(
342                "SELECT tablename FROM pg_tables WHERE schemaname = 'public' ORDER BY tablename",
343                &[],
344            )
345            .await
346            .map_err(|e| McpError::internal_error(format!("Failed to list tables: {}", e), None))?;
347
348        let tables: Vec<String> = rows.iter().map(|row| row.get(0)).collect();
349
350        Ok(CallToolResult::success(vec![Content::text(
351            serde_json::to_string_pretty(&tables).unwrap(),
352        )]))
353    }
354
355    /// Get detailed information about a table
356    #[rmcp::tool(description = "Get detailed information about a table including indexes and constraints")]
357    pub async fn describe_table(
358        &self,
359        Parameters(params): Parameters<TableNameParams>,
360    ) -> Result<CallToolResult, McpError> {
361        let client = self
362            .get_client()
363            .await
364            .map_err(|e| McpError::internal_error(format!("DB connection failed: {}", e), None))?;
365
366        // Get columns
367        let columns = client
368            .query(
369                "SELECT column_name, data_type, is_nullable, column_default
370                 FROM information_schema.columns
371                 WHERE table_schema = 'public' AND table_name = $1
372                 ORDER BY ordinal_position",
373                &[&params.table_name],
374            )
375            .await
376            .map_err(|e| McpError::internal_error(format!("Failed to get columns: {}", e), None))?;
377
378        let column_info: Vec<serde_json::Value> = columns
379            .iter()
380            .map(|row| {
381                serde_json::json!({
382                    "column_name": row.get::<_, String>(0),
383                    "data_type": row.get::<_, String>(1),
384                    "is_nullable": row.get::<_, String>(2),
385                    "column_default": row.get::<_, Option<String>>(3),
386                })
387            })
388            .collect();
389
390        // Get indexes
391        let indexes = client
392            .query(
393                "SELECT indexname, indexdef
394                 FROM pg_indexes
395                 WHERE schemaname = 'public' AND tablename = $1",
396                &[&params.table_name],
397            )
398            .await
399            .map_err(|e| McpError::internal_error(format!("Failed to get indexes: {}", e), None))?;
400
401        let index_info: Vec<serde_json::Value> = indexes
402            .iter()
403            .map(|row| {
404                serde_json::json!({
405                    "index_name": row.get::<_, String>(0),
406                    "definition": row.get::<_, String>(1),
407                })
408            })
409            .collect();
410
411        Ok(CallToolResult::success(vec![Content::text(
412            serde_json::to_string_pretty(&serde_json::json!({
413                "table_name": params.table_name,
414                "columns": column_info,
415                "indexes": index_info
416            }))
417            .unwrap(),
418        )]))
419    }
420
421    /// Count rows in a table
422    #[rmcp::tool(description = "Count rows in a table with optional WHERE conditions")]
423    pub async fn count_rows(
424        &self,
425        Parameters(params): Parameters<CountRowsParams>,
426    ) -> Result<CallToolResult, McpError> {
427        let client = self
428            .get_client()
429            .await
430            .map_err(|e| McpError::internal_error(format!("DB connection failed: {}", e), None))?;
431
432        let query = if let Some(where_obj) = params.where_conditions {
433            let conditions: Vec<String> = where_obj
434                .as_object()
435                .ok_or_else(|| McpError::invalid_params("WHERE conditions must be a JSON object", None))?
436                .iter()
437                .map(|(k, v)| format!("{} = '{}'", k, v.as_str().unwrap_or("")))
438                .collect();
439
440            format!("SELECT COUNT(*) FROM {} WHERE {}", params.table_name, conditions.join(" AND "))
441        } else {
442            format!("SELECT COUNT(*) FROM {}", params.table_name)
443        };
444
445        let row = client
446            .query_one(&query, &[])
447            .await
448            .map_err(|e| McpError::internal_error(format!("Count query failed: {}", e), None))?;
449
450        let count: i64 = row.get(0);
451
452        Ok(CallToolResult::success(vec![Content::text(
453            serde_json::to_string_pretty(&serde_json::json!({
454                "table_name": params.table_name,
455                "count": count
456            }))
457            .unwrap(),
458        )]))
459    }
460
461    /// Check if a table exists
462    #[rmcp::tool(description = "Check if a table exists in the database")]
463    pub async fn table_exists(
464        &self,
465        Parameters(params): Parameters<TableNameParams>,
466    ) -> Result<CallToolResult, McpError> {
467        let client = self
468            .get_client()
469            .await
470            .map_err(|e| McpError::internal_error(format!("DB connection failed: {}", e), None))?;
471
472        let row = client
473            .query_one(
474                "SELECT EXISTS (SELECT 1 FROM pg_tables WHERE schemaname = 'public' AND tablename = $1)",
475                &[&params.table_name],
476            )
477            .await
478            .map_err(|e| McpError::internal_error(format!("Table exists query failed: {}", e), None))?;
479
480        let exists: bool = row.get(0);
481
482        Ok(CallToolResult::success(vec![Content::text(
483            serde_json::to_string_pretty(&serde_json::json!({
484                "table_name": params.table_name,
485                "exists": exists
486            }))
487            .unwrap(),
488        )]))
489    }
490
491    /// Check if a column exists in a table
492    #[rmcp::tool(description = "Check if a column exists in a table")]
493    pub async fn column_exists(
494        &self,
495        Parameters(params): Parameters<ColumnExistsParams>,
496    ) -> Result<CallToolResult, McpError> {
497        let client = self
498            .get_client()
499            .await
500            .map_err(|e| McpError::internal_error(format!("DB connection failed: {}", e), None))?;
501
502        let row = client
503            .query_one(
504                "SELECT EXISTS (
505                    SELECT 1 FROM information_schema.columns
506                    WHERE table_schema = 'public'
507                      AND table_name = $1
508                      AND column_name = $2
509                )",
510                &[&params.table_name, &params.column_name],
511            )
512            .await
513            .map_err(|e| McpError::internal_error(format!("Column exists query failed: {}", e), None))?;
514
515        let exists: bool = row.get(0);
516
517        Ok(CallToolResult::success(vec![Content::text(
518            serde_json::to_string_pretty(&serde_json::json!({
519                "table_name": params.table_name,
520                "column_name": params.column_name,
521                "exists": exists
522            }))
523            .unwrap(),
524        )]))
525    }
526
527    /// Get a sample of rows from a table
528    #[rmcp::tool(description = "Get a sample of rows from a table")]
529    pub async fn get_table_sample(
530        &self,
531        Parameters(params): Parameters<TableSampleParams>,
532    ) -> Result<CallToolResult, McpError> {
533        let limit = params.limit.unwrap_or(10).min(100);
534
535        let client = self
536            .get_client()
537            .await
538            .map_err(|e| McpError::internal_error(format!("DB connection failed: {}", e), None))?;
539
540        let query = format!("SELECT * FROM {} LIMIT {}", params.table_name, limit);
541
542        let rows = client
543            .query(&query, &[])
544            .await
545            .map_err(|e| McpError::internal_error(format!("Sample query failed: {}", e), None))?;
546
547        let json_rows: Vec<serde_json::Value> = rows.iter().map(|row| self.row_to_json(row)).collect();
548
549        Ok(CallToolResult::success(vec![Content::text(
550            serde_json::to_string_pretty(&serde_json::json!({
551                "table_name": params.table_name,
552                "rows": json_rows,
553                "count": json_rows.len()
554            }))
555            .unwrap(),
556        )]))
557    }
558
559    /// Update rows in a table
560    #[rmcp::tool(description = "Update rows in a table with specified values and conditions")]
561    pub async fn update_data(
562        &self,
563        Parameters(params): Parameters<UpdateDataParams>,
564    ) -> Result<CallToolResult, McpError> {
565        let client = self
566            .get_client()
567            .await
568            .map_err(|e| McpError::internal_error(format!("DB connection failed: {}", e), None))?;
569
570        let limit = params.limit.unwrap_or(1000);
571
572        let values_obj = params
573            .values
574            .as_object()
575            .ok_or_else(|| McpError::invalid_params("Values must be a JSON object", None))?;
576        let where_obj = params
577            .where_conditions
578            .as_object()
579            .ok_or_else(|| McpError::invalid_params("WHERE conditions must be a JSON object", None))?;
580
581        let set_clauses: Vec<String> = values_obj
582            .iter()
583            .map(|(k, v)| format!("{} = '{}'", k, v.as_str().unwrap_or("")))
584            .collect();
585
586        let where_clauses: Vec<String> = where_obj
587            .iter()
588            .map(|(k, v)| format!("{} = '{}'", k, v.as_str().unwrap_or("")))
589            .collect();
590
591        let query = format!(
592            "UPDATE {} SET {} WHERE {} LIMIT {}",
593            params.table_name,
594            set_clauses.join(", "),
595            where_clauses.join(" AND "),
596            limit
597        );
598
599        let rows_affected = client
600            .execute(&query, &[])
601            .await
602            .map_err(|e| McpError::internal_error(format!("Update failed: {}", e), None))?;
603
604        Ok(CallToolResult::success(vec![Content::text(
605            serde_json::to_string_pretty(&serde_json::json!({
606                "table_name": params.table_name,
607                "rows_affected": rows_affected
608            }))
609            .unwrap(),
610        )]))
611    }
612
613    /// Delete rows from a table
614    #[rmcp::tool(description = "Delete rows from a table based on specified conditions")]
615    pub async fn delete_data(
616        &self,
617        Parameters(params): Parameters<DeleteDataParams>,
618    ) -> Result<CallToolResult, McpError> {
619        let client = self
620            .get_client()
621            .await
622            .map_err(|e| McpError::internal_error(format!("DB connection failed: {}", e), None))?;
623
624        let limit = params.limit.unwrap_or(1000);
625
626        let where_obj = params
627            .where_conditions
628            .as_object()
629            .ok_or_else(|| McpError::invalid_params("WHERE conditions must be a JSON object", None))?;
630
631        let where_clauses: Vec<String> = where_obj
632            .iter()
633            .map(|(k, v)| format!("{} = '{}'", k, v.as_str().unwrap_or("")))
634            .collect();
635
636        let query = format!(
637            "DELETE FROM {} WHERE {} LIMIT {}",
638            params.table_name,
639            where_clauses.join(" AND "),
640            limit
641        );
642
643        let rows_affected = client
644            .execute(&query, &[])
645            .await
646            .map_err(|e| McpError::internal_error(format!("Delete failed: {}", e), None))?;
647
648        Ok(CallToolResult::success(vec![Content::text(
649            serde_json::to_string_pretty(&serde_json::json!({
650                "table_name": params.table_name,
651                "rows_affected": rows_affected
652            }))
653            .unwrap(),
654        )]))
655    }
656
657    /// Execute a raw SQL query
658    #[rmcp::tool(description = "Execute any SQL query including INSERT, UPDATE, DELETE (use with caution)")]
659    pub async fn execute_raw_query(
660        &self,
661        Parameters(params): Parameters<ExecuteRawQueryParams>,
662    ) -> Result<CallToolResult, McpError> {
663        let client = self
664            .get_client()
665            .await
666            .map_err(|e| McpError::internal_error(format!("DB connection failed: {}", e), None))?;
667
668        // For SELECT queries, return results
669        if params.query.trim().to_uppercase().starts_with("SELECT") {
670            let rows = client
671                .query(&params.query, &[])
672                .await
673                .map_err(|e| McpError::internal_error(format!("Query failed: {}", e), None))?;
674
675            let json_rows: Vec<serde_json::Value> = rows.iter().map(|row| self.row_to_json(row)).collect();
676
677            Ok(CallToolResult::success(vec![Content::text(
678                serde_json::to_string_pretty(&serde_json::json!({
679                    "rows": json_rows,
680                    "count": json_rows.len()
681                }))
682                .unwrap(),
683            )]))
684        } else {
685            // For other queries, return rows affected
686            let rows_affected = client
687                .execute(&params.query, &[])
688                .await
689                .map_err(|e| McpError::internal_error(format!("Query execution failed: {}", e), None))?;
690
691            Ok(CallToolResult::success(vec![Content::text(
692                serde_json::to_string_pretty(&serde_json::json!({
693                    "rows_affected": rows_affected
694                }))
695                .unwrap(),
696            )]))
697        }
698    }
699
700    /// Get foreign key relationships for tables
701    #[rmcp::tool(description = "Get foreign key relationships for tables")]
702    pub async fn get_relationships(
703        &self,
704        Parameters(params): Parameters<RelationshipsParams>,
705    ) -> Result<CallToolResult, McpError> {
706        let client = self
707            .get_client()
708            .await
709            .map_err(|e| McpError::internal_error(format!("DB connection failed: {}", e), None))?;
710
711        let query = if let Some(table) = params.table_name {
712            format!(
713                "SELECT
714                    tc.table_name,
715                    kcu.column_name,
716                    ccu.table_name AS foreign_table_name,
717                    ccu.column_name AS foreign_column_name
718                FROM information_schema.table_constraints AS tc
719                JOIN information_schema.key_column_usage AS kcu
720                  ON tc.constraint_name = kcu.constraint_name
721                  AND tc.table_schema = kcu.table_schema
722                JOIN information_schema.constraint_column_usage AS ccu
723                  ON ccu.constraint_name = tc.constraint_name
724                  AND ccu.table_schema = tc.table_schema
725                WHERE tc.constraint_type = 'FOREIGN KEY'
726                  AND tc.table_schema = 'public'
727                  AND tc.table_name = '{}'",
728                table
729            )
730        } else {
731            "SELECT
732                tc.table_name,
733                kcu.column_name,
734                ccu.table_name AS foreign_table_name,
735                ccu.column_name AS foreign_column_name
736            FROM information_schema.table_constraints AS tc
737            JOIN information_schema.key_column_usage AS kcu
738              ON tc.constraint_name = kcu.constraint_name
739              AND tc.table_schema = kcu.table_schema
740            JOIN information_schema.constraint_column_usage AS ccu
741              ON ccu.constraint_name = tc.constraint_name
742              AND ccu.table_schema = tc.table_schema
743            WHERE tc.constraint_type = 'FOREIGN KEY'
744              AND tc.table_schema = 'public'"
745                .to_string()
746        };
747
748        let rows = client
749            .query(&query, &[])
750            .await
751            .map_err(|e| McpError::internal_error(format!("Relationships query failed: {}", e), None))?;
752
753        let relationships: Vec<serde_json::Value> = rows
754            .iter()
755            .map(|row| {
756                serde_json::json!({
757                    "table_name": row.get::<_, String>(0),
758                    "column_name": row.get::<_, String>(1),
759                    "foreign_table_name": row.get::<_, String>(2),
760                    "foreign_column_name": row.get::<_, String>(3),
761                })
762            })
763            .collect();
764
765        Ok(CallToolResult::success(vec![Content::text(
766            serde_json::to_string_pretty(&relationships).unwrap(),
767        )]))
768    }
769
770    /// Get database connection status
771    #[rmcp::tool(description = "Get database connection status and basic info")]
772    pub async fn get_connection_status(&self) -> Result<CallToolResult, McpError> {
773        let client = self
774            .get_client()
775            .await
776            .map_err(|e| McpError::internal_error(format!("DB connection failed: {}", e), None))?;
777
778        let version_row = client
779            .query_one("SELECT version()", &[])
780            .await
781            .map_err(|e| McpError::internal_error(format!("Version query failed: {}", e), None))?;
782
783        let version: String = version_row.get(0);
784
785        // Parse connection string to get database name
786        let db_name = self
787            .db_config
788            .split_whitespace()
789            .find(|s| s.starts_with("dbname="))
790            .and_then(|s| s.strip_prefix("dbname="))
791            .unwrap_or("unknown");
792
793        let user = self
794            .db_config
795            .split_whitespace()
796            .find(|s| s.starts_with("user="))
797            .and_then(|s| s.strip_prefix("user="))
798            .unwrap_or("unknown");
799
800        let host = self
801            .db_config
802            .split_whitespace()
803            .find(|s| s.starts_with("host="))
804            .and_then(|s| s.strip_prefix("host="))
805            .unwrap_or("localhost");
806
807        Ok(CallToolResult::success(vec![Content::text(
808            serde_json::to_string_pretty(&serde_json::json!({
809                "connected": true,
810                "database": db_name,
811                "user": user,
812                "host": host,
813                "version": version
814            }))
815            .unwrap(),
816        )]))
817    }
818}
819
820// ============================================================================
821// Server Handler Implementation
822// ============================================================================
823
824#[rmcp::tool_handler]
825impl ServerHandler for PostgresServer {
826    fn get_info(&self) -> ServerInfo {
827        InitializeResult {
828            protocol_version: ProtocolVersion::default(),
829            capabilities: ServerCapabilities {
830                tools: Some(ToolsCapability { list_changed: None }),
831                ..Default::default()
832            },
833            server_info: Implementation {
834                name: "rmcp-postgres".to_string(),
835                title: Some("PostgreSQL MCP Server".to_string()),
836                version: env!("CARGO_PKG_VERSION").to_string(),
837                icons: None,
838                website_url: None,
839            },
840            instructions: Some("MCP server for PostgreSQL databases with full CRUD and schema inspection capabilities".to_string()),
841        }
842    }
843}