// postgres_mcp/mcp.rs

1use crate::{Conns, PgMcp};
2use anyhow::Result;
3use rmcp::{
4    Error as McpError, ServerHandler,
5    model::{CallToolResult, Content, ServerCapabilities, ServerInfo},
6    schemars, tool,
7};
8
9#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
10pub struct RegisterRequest {
11    #[schemars(description = "Postgres connection string")]
12    pub conn_str: String,
13}
14
15#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
16pub struct UnregisterRequest {
17    #[schemars(description = "Connection ID to unregister")]
18    pub conn_id: String,
19}
20
21#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
22pub struct QueryRequest {
23    #[schemars(description = "Connection ID")]
24    pub conn_id: String,
25    #[schemars(
26        description = "Single SQL query, could return multiple rows. Caller should properly limit the number of rows returned."
27    )]
28    pub query: String,
29}
30
31#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
32pub struct InsertRequest {
33    #[schemars(description = "Connection ID")]
34    pub conn_id: String,
35    #[schemars(
36        description = "Single SQL insert statement, but multiple rows for the same table are allowed"
37    )]
38    pub query: String,
39}
40
41#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
42pub struct UpdateRequest {
43    #[schemars(description = "Connection ID")]
44    pub conn_id: String,
45    #[schemars(
46        description = "Single SQL update statement, could update multiple rows for the same table based on the WHERE clause"
47    )]
48    pub query: String,
49}
50
51#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
52pub struct DeleteRequest {
53    #[schemars(description = "Connection ID")]
54    pub conn_id: String,
55    #[schemars(
56        description = "Single SQL delete statement, could delete multiple rows for the same table based on the WHERE clause"
57    )]
58    pub query: String,
59}
60
61#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
62pub struct CreateTableRequest {
63    #[schemars(description = "Connection ID")]
64    pub conn_id: String,
65    #[schemars(description = "Single SQL create table statement")]
66    pub query: String,
67}
68
69#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
70pub struct DropTableRequest {
71    #[schemars(description = "Connection ID")]
72    pub conn_id: String,
73    #[schemars(
74        description = "Table name. Format: schema.table. If schema is not provided, it will use the current schema."
75    )]
76    pub table: String,
77}
78
79#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
80pub struct CreateIndexRequest {
81    #[schemars(description = "Connection ID")]
82    pub conn_id: String,
83    #[schemars(description = "SingleSQL create index statement")]
84    pub query: String,
85}
86
87#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
88pub struct DropIndexRequest {
89    #[schemars(description = "Connection ID")]
90    pub conn_id: String,
91    #[schemars(description = "Index name")]
92    pub index: String,
93}
94
95#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
96pub struct DescribeRequest {
97    #[schemars(description = "Connection ID")]
98    pub conn_id: String,
99    #[schemars(description = "Table name")]
100    pub table: String,
101}
102
103#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
104pub struct ListTablesRequest {
105    #[schemars(description = "Connection ID")]
106    pub conn_id: String,
107    #[schemars(description = "Schema name")]
108    pub schema: String,
109}
110
111#[tool(tool_box)]
112impl PgMcp {
113    pub fn new() -> Self {
114        Self {
115            conns: Conns::new(),
116        }
117    }
118
119    #[tool(description = "Register a new Postgres connection")]
120    async fn register(
121        &self,
122        #[tool(aggr)] req: RegisterRequest,
123    ) -> Result<CallToolResult, McpError> {
124        let id = self
125            .conns
126            .register(req.conn_str)
127            .await
128            .map_err(|e| McpError::internal_error(e.to_string(), None))?;
129        Ok(CallToolResult::success(vec![Content::text(id)]))
130    }
131
132    #[tool(description = "Unregister a Postgres connection")]
133    async fn unregister(
134        &self,
135        #[tool(aggr)] req: UnregisterRequest,
136    ) -> Result<CallToolResult, McpError> {
137        self.conns
138            .unregister(req.conn_id)
139            .map_err(|e| McpError::internal_error(e.to_string(), None))?;
140        Ok(CallToolResult::success(vec![Content::text(
141            "success".to_string(),
142        )]))
143    }
144
145    #[tool(description = "Execute a SELECT query")]
146    async fn query(&self, #[tool(aggr)] req: QueryRequest) -> Result<CallToolResult, McpError> {
147        let result = self
148            .conns
149            .query(&req.conn_id, &req.query)
150            .await
151            .map_err(|e| McpError::internal_error(e.to_string(), None))?;
152        Ok(CallToolResult::success(vec![Content::text(result)]))
153    }
154
155    #[tool(description = "Execute an INSERT statement")]
156    async fn insert(&self, #[tool(aggr)] req: InsertRequest) -> Result<CallToolResult, McpError> {
157        let result = self
158            .conns
159            .insert(&req.conn_id, &req.query)
160            .await
161            .map_err(|e| McpError::internal_error(e.to_string(), None))?;
162        Ok(CallToolResult::success(vec![Content::text(result)]))
163    }
164
165    #[tool(description = "Execute an UPDATE statement")]
166    async fn update(&self, #[tool(aggr)] req: UpdateRequest) -> Result<CallToolResult, McpError> {
167        let result = self
168            .conns
169            .update(&req.conn_id, &req.query)
170            .await
171            .map_err(|e| McpError::internal_error(e.to_string(), None))?;
172        Ok(CallToolResult::success(vec![Content::text(result)]))
173    }
174
175    #[tool(description = "Delete a row from a table")]
176    async fn delete(&self, #[tool(aggr)] req: DeleteRequest) -> Result<CallToolResult, McpError> {
177        let result = self
178            .conns
179            .delete(&req.conn_id, &req.query)
180            .await
181            .map_err(|e| McpError::internal_error(e.to_string(), None))?;
182        Ok(CallToolResult::success(vec![Content::text(result)]))
183    }
184
185    #[tool(description = "Create a new table")]
186    async fn create_table(
187        &self,
188        #[tool(aggr)] req: CreateTableRequest,
189    ) -> Result<CallToolResult, McpError> {
190        let result = self
191            .conns
192            .create_table(&req.conn_id, &req.query)
193            .await
194            .map_err(|e| McpError::internal_error(e.to_string(), None))?;
195        Ok(CallToolResult::success(vec![Content::text(result)]))
196    }
197
198    #[tool(description = "Drop a table")]
199    async fn drop_table(
200        &self,
201        #[tool(aggr)] req: DropTableRequest,
202    ) -> Result<CallToolResult, McpError> {
203        let result = self
204            .conns
205            .drop_table(&req.conn_id, &req.table)
206            .await
207            .map_err(|e| McpError::internal_error(e.to_string(), None))?;
208        Ok(CallToolResult::success(vec![Content::text(result)]))
209    }
210
211    #[tool(description = "Create an index")]
212    async fn create_index(
213        &self,
214        #[tool(aggr)] req: CreateIndexRequest,
215    ) -> Result<CallToolResult, McpError> {
216        let result = self
217            .conns
218            .create_index(&req.conn_id, &req.query)
219            .await
220            .map_err(|e| McpError::internal_error(e.to_string(), None))?;
221        Ok(CallToolResult::success(vec![Content::text(result)]))
222    }
223
224    #[tool(description = "Drop an index")]
225    async fn drop_index(
226        &self,
227        #[tool(aggr)] req: DropIndexRequest,
228    ) -> Result<CallToolResult, McpError> {
229        let result = self
230            .conns
231            .drop_index(&req.conn_id, &req.index)
232            .await
233            .map_err(|e| McpError::internal_error(e.to_string(), None))?;
234        Ok(CallToolResult::success(vec![Content::text(result)]))
235    }
236
237    #[tool(description = "Describe a table")]
238    async fn describe(
239        &self,
240        #[tool(aggr)] req: DescribeRequest,
241    ) -> Result<CallToolResult, McpError> {
242        let result = self
243            .conns
244            .describe(&req.conn_id, &req.table)
245            .await
246            .map_err(|e| McpError::internal_error(e.to_string(), None))?;
247        Ok(CallToolResult::success(vec![Content::text(result)]))
248    }
249
250    #[tool(description = "List all tables")]
251    async fn list_tables(
252        &self,
253        #[tool(aggr)] req: ListTablesRequest,
254    ) -> Result<CallToolResult, McpError> {
255        let result = self
256            .conns
257            .list_tables(&req.conn_id, &req.schema)
258            .await
259            .map_err(|e| McpError::internal_error(e.to_string(), None))?;
260        Ok(CallToolResult::success(vec![Content::text(result)]))
261    }
262}
263
264#[tool(tool_box)]
265impl ServerHandler for PgMcp {
266    fn get_info(&self) -> ServerInfo {
267        ServerInfo {
268            instructions: Some(
269                "A Postgres MCP server that allows AI agents to interact with Postgres databases"
270                    .into(),
271            ),
272            capabilities: ServerCapabilities::builder().enable_tools().build(),
273            ..Default::default()
274        }
275    }
276}
277
278impl Default for PgMcp {
279    fn default() -> Self {
280        Self::new()
281    }
282}