1use crate::{Conns, PgMcp};
2use anyhow::Result;
3use rmcp::{
4 Error as McpError, ServerHandler,
5 model::{CallToolResult, Content, ServerCapabilities, ServerInfo},
6 schemars, tool,
7};
8
9#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
10pub struct RegisterRequest {
11 #[schemars(description = "Postgres connection string")]
12 pub conn_str: String,
13}
14
15#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
16pub struct UnregisterRequest {
17 #[schemars(description = "Connection ID to unregister")]
18 pub conn_id: String,
19}
20
21#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
22pub struct QueryRequest {
23 #[schemars(description = "Connection ID")]
24 pub conn_id: String,
25 #[schemars(
26 description = "Single SQL query, could return multiple rows. Caller should properly limit the number of rows returned."
27 )]
28 pub query: String,
29}
30
31#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
32pub struct InsertRequest {
33 #[schemars(description = "Connection ID")]
34 pub conn_id: String,
35 #[schemars(
36 description = "Single SQL insert statement, but multiple rows for the same table are allowed"
37 )]
38 pub query: String,
39}
40
41#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
42pub struct UpdateRequest {
43 #[schemars(description = "Connection ID")]
44 pub conn_id: String,
45 #[schemars(
46 description = "Single SQL update statement, could update multiple rows for the same table based on the WHERE clause"
47 )]
48 pub query: String,
49}
50
51#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
52pub struct DeleteRequest {
53 #[schemars(description = "Connection ID")]
54 pub conn_id: String,
55 #[schemars(
56 description = "Single SQL delete statement, could delete multiple rows for the same table based on the WHERE clause"
57 )]
58 pub query: String,
59}
60
61#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
62pub struct CreateTableRequest {
63 #[schemars(description = "Connection ID")]
64 pub conn_id: String,
65 #[schemars(description = "Single SQL create table statement")]
66 pub query: String,
67}
68
69#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
70pub struct DropTableRequest {
71 #[schemars(description = "Connection ID")]
72 pub conn_id: String,
73 #[schemars(
74 description = "Table name. Format: schema.table. If schema is not provided, it will use the current schema."
75 )]
76 pub table: String,
77}
78
79#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
80pub struct CreateIndexRequest {
81 #[schemars(description = "Connection ID")]
82 pub conn_id: String,
83 #[schemars(description = "SingleSQL create index statement")]
84 pub query: String,
85}
86
87#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
88pub struct DropIndexRequest {
89 #[schemars(description = "Connection ID")]
90 pub conn_id: String,
91 #[schemars(description = "Index name")]
92 pub index: String,
93}
94
95#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
96pub struct DescribeRequest {
97 #[schemars(description = "Connection ID")]
98 pub conn_id: String,
99 #[schemars(description = "Table name")]
100 pub table: String,
101}
102
103#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
104pub struct ListTablesRequest {
105 #[schemars(description = "Connection ID")]
106 pub conn_id: String,
107 #[schemars(description = "Schema name")]
108 pub schema: String,
109}
110
111#[tool(tool_box)]
112impl PgMcp {
113 pub fn new() -> Self {
114 Self {
115 conns: Conns::new(),
116 }
117 }
118
119 #[tool(description = "Register a new Postgres connection")]
120 async fn register(
121 &self,
122 #[tool(aggr)] req: RegisterRequest,
123 ) -> Result<CallToolResult, McpError> {
124 let id = self
125 .conns
126 .register(req.conn_str)
127 .await
128 .map_err(|e| McpError::internal_error(e.to_string(), None))?;
129 Ok(CallToolResult::success(vec![Content::text(id)]))
130 }
131
132 #[tool(description = "Unregister a Postgres connection")]
133 async fn unregister(
134 &self,
135 #[tool(aggr)] req: UnregisterRequest,
136 ) -> Result<CallToolResult, McpError> {
137 self.conns
138 .unregister(req.conn_id)
139 .map_err(|e| McpError::internal_error(e.to_string(), None))?;
140 Ok(CallToolResult::success(vec![Content::text(
141 "success".to_string(),
142 )]))
143 }
144
145 #[tool(description = "Execute a SELECT query")]
146 async fn query(&self, #[tool(aggr)] req: QueryRequest) -> Result<CallToolResult, McpError> {
147 let result = self
148 .conns
149 .query(&req.conn_id, &req.query)
150 .await
151 .map_err(|e| McpError::internal_error(e.to_string(), None))?;
152 Ok(CallToolResult::success(vec![Content::text(result)]))
153 }
154
155 #[tool(description = "Execute an INSERT statement")]
156 async fn insert(&self, #[tool(aggr)] req: InsertRequest) -> Result<CallToolResult, McpError> {
157 let result = self
158 .conns
159 .insert(&req.conn_id, &req.query)
160 .await
161 .map_err(|e| McpError::internal_error(e.to_string(), None))?;
162 Ok(CallToolResult::success(vec![Content::text(result)]))
163 }
164
165 #[tool(description = "Execute an UPDATE statement")]
166 async fn update(&self, #[tool(aggr)] req: UpdateRequest) -> Result<CallToolResult, McpError> {
167 let result = self
168 .conns
169 .update(&req.conn_id, &req.query)
170 .await
171 .map_err(|e| McpError::internal_error(e.to_string(), None))?;
172 Ok(CallToolResult::success(vec![Content::text(result)]))
173 }
174
175 #[tool(description = "Delete a row from a table")]
176 async fn delete(&self, #[tool(aggr)] req: DeleteRequest) -> Result<CallToolResult, McpError> {
177 let result = self
178 .conns
179 .delete(&req.conn_id, &req.query)
180 .await
181 .map_err(|e| McpError::internal_error(e.to_string(), None))?;
182 Ok(CallToolResult::success(vec![Content::text(result)]))
183 }
184
185 #[tool(description = "Create a new table")]
186 async fn create_table(
187 &self,
188 #[tool(aggr)] req: CreateTableRequest,
189 ) -> Result<CallToolResult, McpError> {
190 let result = self
191 .conns
192 .create_table(&req.conn_id, &req.query)
193 .await
194 .map_err(|e| McpError::internal_error(e.to_string(), None))?;
195 Ok(CallToolResult::success(vec![Content::text(result)]))
196 }
197
198 #[tool(description = "Drop a table")]
199 async fn drop_table(
200 &self,
201 #[tool(aggr)] req: DropTableRequest,
202 ) -> Result<CallToolResult, McpError> {
203 let result = self
204 .conns
205 .drop_table(&req.conn_id, &req.table)
206 .await
207 .map_err(|e| McpError::internal_error(e.to_string(), None))?;
208 Ok(CallToolResult::success(vec![Content::text(result)]))
209 }
210
211 #[tool(description = "Create an index")]
212 async fn create_index(
213 &self,
214 #[tool(aggr)] req: CreateIndexRequest,
215 ) -> Result<CallToolResult, McpError> {
216 let result = self
217 .conns
218 .create_index(&req.conn_id, &req.query)
219 .await
220 .map_err(|e| McpError::internal_error(e.to_string(), None))?;
221 Ok(CallToolResult::success(vec![Content::text(result)]))
222 }
223
224 #[tool(description = "Drop an index")]
225 async fn drop_index(
226 &self,
227 #[tool(aggr)] req: DropIndexRequest,
228 ) -> Result<CallToolResult, McpError> {
229 let result = self
230 .conns
231 .drop_index(&req.conn_id, &req.index)
232 .await
233 .map_err(|e| McpError::internal_error(e.to_string(), None))?;
234 Ok(CallToolResult::success(vec![Content::text(result)]))
235 }
236
237 #[tool(description = "Describe a table")]
238 async fn describe(
239 &self,
240 #[tool(aggr)] req: DescribeRequest,
241 ) -> Result<CallToolResult, McpError> {
242 let result = self
243 .conns
244 .describe(&req.conn_id, &req.table)
245 .await
246 .map_err(|e| McpError::internal_error(e.to_string(), None))?;
247 Ok(CallToolResult::success(vec![Content::text(result)]))
248 }
249
250 #[tool(description = "List all tables")]
251 async fn list_tables(
252 &self,
253 #[tool(aggr)] req: ListTablesRequest,
254 ) -> Result<CallToolResult, McpError> {
255 let result = self
256 .conns
257 .list_tables(&req.conn_id, &req.schema)
258 .await
259 .map_err(|e| McpError::internal_error(e.to_string(), None))?;
260 Ok(CallToolResult::success(vec![Content::text(result)]))
261 }
262}
263
264#[tool(tool_box)]
265impl ServerHandler for PgMcp {
266 fn get_info(&self) -> ServerInfo {
267 ServerInfo {
268 instructions: Some(
269 "A Postgres MCP server that allows AI agents to interact with Postgres databases"
270 .into(),
271 ),
272 capabilities: ServerCapabilities::builder().enable_tools().build(),
273 ..Default::default()
274 }
275 }
276}
277
278impl Default for PgMcp {
279 fn default() -> Self {
280 Self::new()
281 }
282}