postgres_mcp/
mcp.rs

1use crate::{Conns, PgMcp};
2use anyhow::Result;
3use rmcp::{
4    Error as McpError, ServerHandler,
5    model::{CallToolResult, Content, ServerCapabilities, ServerInfo},
6    schemars, tool,
7};
8
9#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
10pub struct RegisterRequest {
11    #[schemars(description = "Postgres connection string")]
12    pub conn_str: String,
13}
14
15#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
16pub struct UnregisterRequest {
17    #[schemars(description = "Connection ID to unregister")]
18    pub conn_id: String,
19}
20
21#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
22pub struct QueryRequest {
23    #[schemars(description = "Connection ID")]
24    pub conn_id: String,
25    #[schemars(
26        description = "Single SQL query, could return multiple rows. Caller should properly limit the number of rows returned."
27    )]
28    pub query: String,
29}
30
31#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
32pub struct InsertRequest {
33    #[schemars(description = "Connection ID")]
34    pub conn_id: String,
35    #[schemars(
36        description = "Single SQL insert statement, but multiple rows for the same table are allowed"
37    )]
38    pub query: String,
39}
40
41#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
42pub struct UpdateRequest {
43    #[schemars(description = "Connection ID")]
44    pub conn_id: String,
45    #[schemars(
46        description = "Single SQL update statement, could update multiple rows for the same table based on the WHERE clause"
47    )]
48    pub query: String,
49}
50
51#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
52pub struct DeleteRequest {
53    #[schemars(description = "Connection ID")]
54    pub conn_id: String,
55    #[schemars(
56        description = "Single SQL delete statement, could delete multiple rows for the same table based on the WHERE clause"
57    )]
58    pub query: String,
59}
60
61#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
62pub struct CreateTableRequest {
63    #[schemars(description = "Connection ID")]
64    pub conn_id: String,
65    #[schemars(description = "Single SQL create table statement")]
66    pub query: String,
67}
68
69#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
70pub struct DropTableRequest {
71    #[schemars(description = "Connection ID")]
72    pub conn_id: String,
73    #[schemars(
74        description = "Table name. Format: schema.table. If schema is not provided, it will use the current schema."
75    )]
76    pub table: String,
77}
78
79#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
80pub struct CreateIndexRequest {
81    #[schemars(description = "Connection ID")]
82    pub conn_id: String,
83    #[schemars(description = "SingleSQL create index statement")]
84    pub query: String,
85}
86
87#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
88pub struct DropIndexRequest {
89    #[schemars(description = "Connection ID")]
90    pub conn_id: String,
91    #[schemars(description = "Index name")]
92    pub index: String,
93}
94
95#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
96pub struct DescribeRequest {
97    #[schemars(description = "Connection ID")]
98    pub conn_id: String,
99    #[schemars(description = "Table name")]
100    pub table: String,
101}
102
103#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
104pub struct ListTablesRequest {
105    #[schemars(description = "Connection ID")]
106    pub conn_id: String,
107    #[schemars(description = "Schema name")]
108    pub schema: String,
109}
110
111#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
112pub struct CreateTypeRequest {
113    #[schemars(description = "Connection ID")]
114    pub conn_id: String,
115    #[schemars(description = "Single SQL create type statement")]
116    pub query: String,
117}
118
119#[tool(tool_box)]
120impl PgMcp {
121    pub fn new() -> Self {
122        Self {
123            conns: Conns::new(),
124        }
125    }
126
127    #[tool(description = "Register a new Postgres connection")]
128    async fn register(
129        &self,
130        #[tool(aggr)] req: RegisterRequest,
131    ) -> Result<CallToolResult, McpError> {
132        let id = self
133            .conns
134            .register(req.conn_str)
135            .await
136            .map_err(|e| McpError::internal_error(e.to_string(), None))?;
137        Ok(CallToolResult::success(vec![Content::text(id)]))
138    }
139
140    #[tool(description = "Unregister a Postgres connection")]
141    async fn unregister(
142        &self,
143        #[tool(aggr)] req: UnregisterRequest,
144    ) -> Result<CallToolResult, McpError> {
145        self.conns
146            .unregister(req.conn_id)
147            .map_err(|e| McpError::internal_error(e.to_string(), None))?;
148        Ok(CallToolResult::success(vec![Content::text(
149            "success".to_string(),
150        )]))
151    }
152
153    #[tool(description = "Execute a SELECT query")]
154    async fn query(&self, #[tool(aggr)] req: QueryRequest) -> Result<CallToolResult, McpError> {
155        let result = self
156            .conns
157            .query(&req.conn_id, &req.query)
158            .await
159            .map_err(|e| McpError::internal_error(e.to_string(), None))?;
160        Ok(CallToolResult::success(vec![Content::text(result)]))
161    }
162
163    #[tool(description = "Execute an INSERT statement")]
164    async fn insert(&self, #[tool(aggr)] req: InsertRequest) -> Result<CallToolResult, McpError> {
165        let result = self
166            .conns
167            .insert(&req.conn_id, &req.query)
168            .await
169            .map_err(|e| McpError::internal_error(e.to_string(), None))?;
170        Ok(CallToolResult::success(vec![Content::text(result)]))
171    }
172
173    #[tool(description = "Execute an UPDATE statement")]
174    async fn update(&self, #[tool(aggr)] req: UpdateRequest) -> Result<CallToolResult, McpError> {
175        let result = self
176            .conns
177            .update(&req.conn_id, &req.query)
178            .await
179            .map_err(|e| McpError::internal_error(e.to_string(), None))?;
180        Ok(CallToolResult::success(vec![Content::text(result)]))
181    }
182
183    #[tool(description = "Delete a row from a table")]
184    async fn delete(&self, #[tool(aggr)] req: DeleteRequest) -> Result<CallToolResult, McpError> {
185        let result = self
186            .conns
187            .delete(&req.conn_id, &req.query)
188            .await
189            .map_err(|e| McpError::internal_error(e.to_string(), None))?;
190        Ok(CallToolResult::success(vec![Content::text(result)]))
191    }
192
193    #[tool(description = "Create a new table")]
194    async fn create_table(
195        &self,
196        #[tool(aggr)] req: CreateTableRequest,
197    ) -> Result<CallToolResult, McpError> {
198        let result = self
199            .conns
200            .create_table(&req.conn_id, &req.query)
201            .await
202            .map_err(|e| McpError::internal_error(e.to_string(), None))?;
203        Ok(CallToolResult::success(vec![Content::text(result)]))
204    }
205
206    #[tool(description = "Drop a table")]
207    async fn drop_table(
208        &self,
209        #[tool(aggr)] req: DropTableRequest,
210    ) -> Result<CallToolResult, McpError> {
211        let result = self
212            .conns
213            .drop_table(&req.conn_id, &req.table)
214            .await
215            .map_err(|e| McpError::internal_error(e.to_string(), None))?;
216        Ok(CallToolResult::success(vec![Content::text(result)]))
217    }
218
219    #[tool(description = "Create an index")]
220    async fn create_index(
221        &self,
222        #[tool(aggr)] req: CreateIndexRequest,
223    ) -> Result<CallToolResult, McpError> {
224        let result = self
225            .conns
226            .create_index(&req.conn_id, &req.query)
227            .await
228            .map_err(|e| McpError::internal_error(e.to_string(), None))?;
229        Ok(CallToolResult::success(vec![Content::text(result)]))
230    }
231
232    #[tool(description = "Drop an index")]
233    async fn drop_index(
234        &self,
235        #[tool(aggr)] req: DropIndexRequest,
236    ) -> Result<CallToolResult, McpError> {
237        let result = self
238            .conns
239            .drop_index(&req.conn_id, &req.index)
240            .await
241            .map_err(|e| McpError::internal_error(e.to_string(), None))?;
242        Ok(CallToolResult::success(vec![Content::text(result)]))
243    }
244
245    #[tool(description = "Describe a table")]
246    async fn describe(
247        &self,
248        #[tool(aggr)] req: DescribeRequest,
249    ) -> Result<CallToolResult, McpError> {
250        let result = self
251            .conns
252            .describe(&req.conn_id, &req.table)
253            .await
254            .map_err(|e| McpError::internal_error(e.to_string(), None))?;
255        Ok(CallToolResult::success(vec![Content::text(result)]))
256    }
257
258    #[tool(description = "List all tables")]
259    async fn list_tables(
260        &self,
261        #[tool(aggr)] req: ListTablesRequest,
262    ) -> Result<CallToolResult, McpError> {
263        let result = self
264            .conns
265            .list_tables(&req.conn_id, &req.schema)
266            .await
267            .map_err(|e| McpError::internal_error(e.to_string(), None))?;
268        Ok(CallToolResult::success(vec![Content::text(result)]))
269    }
270
271    #[tool(description = "Create a new type")]
272    async fn create_type(
273        &self,
274        #[tool(aggr)] req: CreateTypeRequest,
275    ) -> Result<CallToolResult, McpError> {
276        let result = self
277            .conns
278            .create_type(&req.conn_id, &req.query)
279            .await
280            .map_err(|e| McpError::internal_error(e.to_string(), None))?;
281        Ok(CallToolResult::success(vec![Content::text(result)]))
282    }
283}
284
285#[tool(tool_box)]
286impl ServerHandler for PgMcp {
287    fn get_info(&self) -> ServerInfo {
288        ServerInfo {
289            instructions: Some(
290                "A Postgres MCP server that allows AI agents to interact with Postgres databases"
291                    .into(),
292            ),
293            capabilities: ServerCapabilities::builder().enable_tools().build(),
294            ..Default::default()
295        }
296    }
297}
298
299impl Default for PgMcp {
300    fn default() -> Self {
301        Self::new()
302    }
303}