postgres_mcp/
mcp.rs

1use crate::pg::PgMcpError;
2use crate::{Conns, PgMcp};
3use anyhow::Result;
4use rmcp::{
5    Error as McpError, ServerHandler,
6    model::{CallToolResult, Content, ServerCapabilities, ServerInfo},
7    schemars, tool,
8};
9
10#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
11pub struct RegisterRequest {
12    #[schemars(description = "Postgres connection string")]
13    pub conn_str: String,
14}
15
16#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
17pub struct UnregisterRequest {
18    #[schemars(description = "Connection ID to unregister")]
19    pub conn_id: String,
20}
21
22#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
23pub struct QueryRequest {
24    #[schemars(description = "Connection ID")]
25    pub conn_id: String,
26    #[schemars(
27        description = "Single SQL query, could return multiple rows. Caller should properly limit the number of rows returned."
28    )]
29    pub query: String,
30}
31
32#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
33pub struct InsertRequest {
34    #[schemars(description = "Connection ID")]
35    pub conn_id: String,
36    #[schemars(
37        description = "Single SQL insert statement, but multiple rows for the same table are allowed"
38    )]
39    pub query: String,
40}
41
42#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
43pub struct UpdateRequest {
44    #[schemars(description = "Connection ID")]
45    pub conn_id: String,
46    #[schemars(
47        description = "Single SQL update statement, could update multiple rows for the same table based on the WHERE clause"
48    )]
49    pub query: String,
50}
51
52#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
53pub struct DeleteRequest {
54    #[schemars(description = "Connection ID")]
55    pub conn_id: String,
56    #[schemars(
57        description = "Single SQL delete statement, could delete multiple rows for the same table based on the WHERE clause"
58    )]
59    pub query: String,
60}
61
62#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
63pub struct CreateTableRequest {
64    #[schemars(description = "Connection ID")]
65    pub conn_id: String,
66    #[schemars(description = "Single SQL create table statement")]
67    pub query: String,
68}
69
70#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
71pub struct DropTableRequest {
72    #[schemars(description = "Connection ID")]
73    pub conn_id: String,
74    #[schemars(
75        description = "Table name. Format: schema.table. If schema is not provided, it will use the current schema."
76    )]
77    pub table: String,
78}
79
80#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
81pub struct CreateIndexRequest {
82    #[schemars(description = "Connection ID")]
83    pub conn_id: String,
84    #[schemars(description = "SingleSQL create index statement")]
85    pub query: String,
86}
87
88#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
89pub struct DropIndexRequest {
90    #[schemars(description = "Connection ID")]
91    pub conn_id: String,
92    #[schemars(description = "Index name")]
93    pub index: String,
94}
95
96#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
97pub struct DescribeRequest {
98    #[schemars(description = "Connection ID")]
99    pub conn_id: String,
100    #[schemars(description = "Table name")]
101    pub table: String,
102}
103
104#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
105pub struct ListTablesRequest {
106    #[schemars(description = "Connection ID")]
107    pub conn_id: String,
108    #[schemars(description = "Schema name")]
109    pub schema: String,
110}
111
112#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
113pub struct CreateSchemaRequest {
114    #[schemars(description = "Connection ID")]
115    pub conn_id: String,
116    #[schemars(description = "Schema name")]
117    pub name: String,
118}
119
120#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
121pub struct CreateTypeRequest {
122    #[schemars(description = "Connection ID")]
123    pub conn_id: String,
124    #[schemars(description = "Single SQL create type statement")]
125    pub query: String,
126}
127
128// Helper function to map PgMcpError to McpError
129fn map_pg_error(e: PgMcpError) -> McpError {
130    match e {
131        PgMcpError::ConnectionNotFound(id) => McpError::internal_error(
132            format!("Invalid Argument: Connection not found for ID: {}", id),
133            None,
134        ),
135        PgMcpError::ValidationFailed {
136            kind,
137            query,
138            details,
139        } => McpError::internal_error(
140            format!(
141                "Invalid Argument: SQL validation failed for query '{}': {} - {}",
142                query, kind, details
143            ),
144            None,
145        ),
146        PgMcpError::DatabaseError {
147            operation,
148            underlying,
149        } => McpError::internal_error(
150            format!("Database operation '{}' failed: {}", operation, underlying),
151            None,
152        ),
153        PgMcpError::SerializationError(se) => {
154            McpError::internal_error(format!("Result serialization failed: {}", se), None)
155        }
156        PgMcpError::ConnectionError(ce) => {
157            McpError::internal_error(format!("Database connection failed: {}", ce), None)
158        }
159        PgMcpError::InternalError(ie) => {
160            McpError::internal_error(format!("Internal error: {}", ie), None)
161        }
162    }
163}
164
165#[tool(tool_box)]
166impl PgMcp {
167    pub fn new() -> Self {
168        Self {
169            conns: Conns::new(),
170        }
171    }
172
173    #[tool(description = "Register a new Postgres connection")]
174    async fn register(
175        &self,
176        #[tool(aggr)] req: RegisterRequest,
177    ) -> Result<CallToolResult, McpError> {
178        let id = self
179            .conns
180            .register(req.conn_str)
181            .await
182            .map_err(map_pg_error)?;
183        Ok(CallToolResult::success(vec![Content::text(id)]))
184    }
185
186    #[tool(description = "Unregister a Postgres connection")]
187    async fn unregister(
188        &self,
189        #[tool(aggr)] req: UnregisterRequest,
190    ) -> Result<CallToolResult, McpError> {
191        self.conns.unregister(req.conn_id).map_err(map_pg_error)?;
192        Ok(CallToolResult::success(vec![Content::text(
193            "success".to_string(),
194        )]))
195    }
196
197    #[tool(description = "Execute a SELECT query")]
198    async fn query(&self, #[tool(aggr)] req: QueryRequest) -> Result<CallToolResult, McpError> {
199        let result = self
200            .conns
201            .query(&req.conn_id, &req.query)
202            .await
203            .map_err(map_pg_error)?;
204        Ok(CallToolResult::success(vec![Content::text(result)]))
205    }
206
207    #[tool(description = "Execute an INSERT statement")]
208    async fn insert(&self, #[tool(aggr)] req: InsertRequest) -> Result<CallToolResult, McpError> {
209        let result = self
210            .conns
211            .insert(&req.conn_id, &req.query)
212            .await
213            .map_err(map_pg_error)?;
214        Ok(CallToolResult::success(vec![Content::text(result)]))
215    }
216
217    #[tool(description = "Execute an UPDATE statement")]
218    async fn update(&self, #[tool(aggr)] req: UpdateRequest) -> Result<CallToolResult, McpError> {
219        let result = self
220            .conns
221            .update(&req.conn_id, &req.query)
222            .await
223            .map_err(map_pg_error)?;
224        Ok(CallToolResult::success(vec![Content::text(result)]))
225    }
226
227    #[tool(description = "Delete a row from a table")]
228    async fn delete(&self, #[tool(aggr)] req: DeleteRequest) -> Result<CallToolResult, McpError> {
229        let result = self
230            .conns
231            .delete(&req.conn_id, &req.query)
232            .await
233            .map_err(map_pg_error)?;
234        Ok(CallToolResult::success(vec![Content::text(result)]))
235    }
236
237    #[tool(description = "Create a new table")]
238    async fn create_table(
239        &self,
240        #[tool(aggr)] req: CreateTableRequest,
241    ) -> Result<CallToolResult, McpError> {
242        let result = self
243            .conns
244            .create_table(&req.conn_id, &req.query)
245            .await
246            .map_err(map_pg_error)?;
247        Ok(CallToolResult::success(vec![Content::text(result)]))
248    }
249
250    #[tool(description = "Drop a table")]
251    async fn drop_table(
252        &self,
253        #[tool(aggr)] req: DropTableRequest,
254    ) -> Result<CallToolResult, McpError> {
255        let result = self
256            .conns
257            .drop_table(&req.conn_id, &req.table)
258            .await
259            .map_err(map_pg_error)?;
260        Ok(CallToolResult::success(vec![Content::text(result)]))
261    }
262
263    #[tool(description = "Create an index")]
264    async fn create_index(
265        &self,
266        #[tool(aggr)] req: CreateIndexRequest,
267    ) -> Result<CallToolResult, McpError> {
268        let result = self
269            .conns
270            .create_index(&req.conn_id, &req.query)
271            .await
272            .map_err(map_pg_error)?;
273        Ok(CallToolResult::success(vec![Content::text(result)]))
274    }
275
276    #[tool(description = "Drop an index")]
277    async fn drop_index(
278        &self,
279        #[tool(aggr)] req: DropIndexRequest,
280    ) -> Result<CallToolResult, McpError> {
281        let result = self
282            .conns
283            .drop_index(&req.conn_id, &req.index)
284            .await
285            .map_err(map_pg_error)?;
286        Ok(CallToolResult::success(vec![Content::text(result)]))
287    }
288
289    #[tool(description = "Describe a table")]
290    async fn describe(
291        &self,
292        #[tool(aggr)] req: DescribeRequest,
293    ) -> Result<CallToolResult, McpError> {
294        let result = self
295            .conns
296            .describe(&req.conn_id, &req.table)
297            .await
298            .map_err(map_pg_error)?;
299        Ok(CallToolResult::success(vec![Content::text(result)]))
300    }
301
302    #[tool(description = "List tables in a schema")]
303    async fn list_tables(
304        &self,
305        #[tool(aggr)] req: ListTablesRequest,
306    ) -> Result<CallToolResult, McpError> {
307        let result = self
308            .conns
309            .list_tables(&req.conn_id, &req.schema)
310            .await
311            .map_err(map_pg_error)?;
312        Ok(CallToolResult::success(vec![Content::text(result)]))
313    }
314
315    #[tool(description = "Create a new schema")]
316    async fn create_schema(
317        &self,
318        #[tool(aggr)] req: CreateSchemaRequest,
319    ) -> Result<CallToolResult, McpError> {
320        let result = self
321            .conns
322            .create_schema(&req.conn_id, &req.name)
323            .await
324            .map_err(map_pg_error)?;
325        Ok(CallToolResult::success(vec![Content::text(result)]))
326    }
327
328    #[tool(description = "Create a new type")]
329    async fn create_type(
330        &self,
331        #[tool(aggr)] req: CreateTypeRequest,
332    ) -> Result<CallToolResult, McpError> {
333        let result = self
334            .conns
335            .create_type(&req.conn_id, &req.query)
336            .await
337            .map_err(map_pg_error)?;
338        Ok(CallToolResult::success(vec![Content::text(result)]))
339    }
340}
341
342#[tool(tool_box)]
343impl ServerHandler for PgMcp {
344    fn get_info(&self) -> ServerInfo {
345        ServerInfo {
346            instructions: Some(
347                "A Postgres MCP server that allows AI agents to interact with Postgres databases"
348                    .into(),
349            ),
350            capabilities: ServerCapabilities::builder().enable_tools().build(),
351            ..Default::default()
352        }
353    }
354}
355
356impl Default for PgMcp {
357    fn default() -> Self {
358        Self::new()
359    }
360}