1use crate::{Conns, PgMcp};
2use anyhow::Result;
3use rmcp::{
4 Error as McpError, ServerHandler,
5 model::{CallToolResult, Content, ServerCapabilities, ServerInfo},
6 schemars, tool,
7};
8
9#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
10pub struct RegisterRequest {
11 #[schemars(description = "Postgres connection string")]
12 pub conn_str: String,
13}
14
15#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
16pub struct UnregisterRequest {
17 #[schemars(description = "Connection ID to unregister")]
18 pub conn_id: String,
19}
20
21#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
22pub struct QueryRequest {
23 #[schemars(description = "Connection ID")]
24 pub conn_id: String,
25 #[schemars(
26 description = "Single SQL query, could return multiple rows. Caller should properly limit the number of rows returned."
27 )]
28 pub query: String,
29}
30
31#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
32pub struct InsertRequest {
33 #[schemars(description = "Connection ID")]
34 pub conn_id: String,
35 #[schemars(
36 description = "Single SQL insert statement, but multiple rows for the same table are allowed"
37 )]
38 pub query: String,
39}
40
41#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
42pub struct UpdateRequest {
43 #[schemars(description = "Connection ID")]
44 pub conn_id: String,
45 #[schemars(
46 description = "Single SQL update statement, could update multiple rows for the same table based on the WHERE clause"
47 )]
48 pub query: String,
49}
50
51#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
52pub struct DeleteRequest {
53 #[schemars(description = "Connection ID")]
54 pub conn_id: String,
55 #[schemars(
56 description = "Single SQL delete statement, could delete multiple rows for the same table based on the WHERE clause"
57 )]
58 pub query: String,
59}
60
61#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
62pub struct CreateTableRequest {
63 #[schemars(description = "Connection ID")]
64 pub conn_id: String,
65 #[schemars(description = "Single SQL create table statement")]
66 pub query: String,
67}
68
69#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
70pub struct DropTableRequest {
71 #[schemars(description = "Connection ID")]
72 pub conn_id: String,
73 #[schemars(
74 description = "Table name. Format: schema.table. If schema is not provided, it will use the current schema."
75 )]
76 pub table: String,
77}
78
79#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
80pub struct CreateIndexRequest {
81 #[schemars(description = "Connection ID")]
82 pub conn_id: String,
83 #[schemars(description = "SingleSQL create index statement")]
84 pub query: String,
85}
86
87#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
88pub struct DropIndexRequest {
89 #[schemars(description = "Connection ID")]
90 pub conn_id: String,
91 #[schemars(description = "Index name")]
92 pub index: String,
93}
94
95#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
96pub struct DescribeRequest {
97 #[schemars(description = "Connection ID")]
98 pub conn_id: String,
99 #[schemars(description = "Table name")]
100 pub table: String,
101}
102
103#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
104pub struct ListTablesRequest {
105 #[schemars(description = "Connection ID")]
106 pub conn_id: String,
107 #[schemars(description = "Schema name")]
108 pub schema: String,
109}
110
111#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
112pub struct CreateTypeRequest {
113 #[schemars(description = "Connection ID")]
114 pub conn_id: String,
115 #[schemars(description = "Single SQL create type statement")]
116 pub query: String,
117}
118
119#[tool(tool_box)]
120impl PgMcp {
121 pub fn new() -> Self {
122 Self {
123 conns: Conns::new(),
124 }
125 }
126
127 #[tool(description = "Register a new Postgres connection")]
128 async fn register(
129 &self,
130 #[tool(aggr)] req: RegisterRequest,
131 ) -> Result<CallToolResult, McpError> {
132 let id = self
133 .conns
134 .register(req.conn_str)
135 .await
136 .map_err(|e| McpError::internal_error(e.to_string(), None))?;
137 Ok(CallToolResult::success(vec![Content::text(id)]))
138 }
139
140 #[tool(description = "Unregister a Postgres connection")]
141 async fn unregister(
142 &self,
143 #[tool(aggr)] req: UnregisterRequest,
144 ) -> Result<CallToolResult, McpError> {
145 self.conns
146 .unregister(req.conn_id)
147 .map_err(|e| McpError::internal_error(e.to_string(), None))?;
148 Ok(CallToolResult::success(vec![Content::text(
149 "success".to_string(),
150 )]))
151 }
152
153 #[tool(description = "Execute a SELECT query")]
154 async fn query(&self, #[tool(aggr)] req: QueryRequest) -> Result<CallToolResult, McpError> {
155 let result = self
156 .conns
157 .query(&req.conn_id, &req.query)
158 .await
159 .map_err(|e| McpError::internal_error(e.to_string(), None))?;
160 Ok(CallToolResult::success(vec![Content::text(result)]))
161 }
162
163 #[tool(description = "Execute an INSERT statement")]
164 async fn insert(&self, #[tool(aggr)] req: InsertRequest) -> Result<CallToolResult, McpError> {
165 let result = self
166 .conns
167 .insert(&req.conn_id, &req.query)
168 .await
169 .map_err(|e| McpError::internal_error(e.to_string(), None))?;
170 Ok(CallToolResult::success(vec![Content::text(result)]))
171 }
172
173 #[tool(description = "Execute an UPDATE statement")]
174 async fn update(&self, #[tool(aggr)] req: UpdateRequest) -> Result<CallToolResult, McpError> {
175 let result = self
176 .conns
177 .update(&req.conn_id, &req.query)
178 .await
179 .map_err(|e| McpError::internal_error(e.to_string(), None))?;
180 Ok(CallToolResult::success(vec![Content::text(result)]))
181 }
182
183 #[tool(description = "Delete a row from a table")]
184 async fn delete(&self, #[tool(aggr)] req: DeleteRequest) -> Result<CallToolResult, McpError> {
185 let result = self
186 .conns
187 .delete(&req.conn_id, &req.query)
188 .await
189 .map_err(|e| McpError::internal_error(e.to_string(), None))?;
190 Ok(CallToolResult::success(vec![Content::text(result)]))
191 }
192
193 #[tool(description = "Create a new table")]
194 async fn create_table(
195 &self,
196 #[tool(aggr)] req: CreateTableRequest,
197 ) -> Result<CallToolResult, McpError> {
198 let result = self
199 .conns
200 .create_table(&req.conn_id, &req.query)
201 .await
202 .map_err(|e| McpError::internal_error(e.to_string(), None))?;
203 Ok(CallToolResult::success(vec![Content::text(result)]))
204 }
205
206 #[tool(description = "Drop a table")]
207 async fn drop_table(
208 &self,
209 #[tool(aggr)] req: DropTableRequest,
210 ) -> Result<CallToolResult, McpError> {
211 let result = self
212 .conns
213 .drop_table(&req.conn_id, &req.table)
214 .await
215 .map_err(|e| McpError::internal_error(e.to_string(), None))?;
216 Ok(CallToolResult::success(vec![Content::text(result)]))
217 }
218
219 #[tool(description = "Create an index")]
220 async fn create_index(
221 &self,
222 #[tool(aggr)] req: CreateIndexRequest,
223 ) -> Result<CallToolResult, McpError> {
224 let result = self
225 .conns
226 .create_index(&req.conn_id, &req.query)
227 .await
228 .map_err(|e| McpError::internal_error(e.to_string(), None))?;
229 Ok(CallToolResult::success(vec![Content::text(result)]))
230 }
231
232 #[tool(description = "Drop an index")]
233 async fn drop_index(
234 &self,
235 #[tool(aggr)] req: DropIndexRequest,
236 ) -> Result<CallToolResult, McpError> {
237 let result = self
238 .conns
239 .drop_index(&req.conn_id, &req.index)
240 .await
241 .map_err(|e| McpError::internal_error(e.to_string(), None))?;
242 Ok(CallToolResult::success(vec![Content::text(result)]))
243 }
244
245 #[tool(description = "Describe a table")]
246 async fn describe(
247 &self,
248 #[tool(aggr)] req: DescribeRequest,
249 ) -> Result<CallToolResult, McpError> {
250 let result = self
251 .conns
252 .describe(&req.conn_id, &req.table)
253 .await
254 .map_err(|e| McpError::internal_error(e.to_string(), None))?;
255 Ok(CallToolResult::success(vec![Content::text(result)]))
256 }
257
258 #[tool(description = "List all tables")]
259 async fn list_tables(
260 &self,
261 #[tool(aggr)] req: ListTablesRequest,
262 ) -> Result<CallToolResult, McpError> {
263 let result = self
264 .conns
265 .list_tables(&req.conn_id, &req.schema)
266 .await
267 .map_err(|e| McpError::internal_error(e.to_string(), None))?;
268 Ok(CallToolResult::success(vec![Content::text(result)]))
269 }
270
271 #[tool(description = "Create a new type")]
272 async fn create_type(
273 &self,
274 #[tool(aggr)] req: CreateTypeRequest,
275 ) -> Result<CallToolResult, McpError> {
276 let result = self
277 .conns
278 .create_type(&req.conn_id, &req.query)
279 .await
280 .map_err(|e| McpError::internal_error(e.to_string(), None))?;
281 Ok(CallToolResult::success(vec![Content::text(result)]))
282 }
283}
284
285#[tool(tool_box)]
286impl ServerHandler for PgMcp {
287 fn get_info(&self) -> ServerInfo {
288 ServerInfo {
289 instructions: Some(
290 "A Postgres MCP server that allows AI agents to interact with Postgres databases"
291 .into(),
292 ),
293 capabilities: ServerCapabilities::builder().enable_tools().build(),
294 ..Default::default()
295 }
296 }
297}
298
299impl Default for PgMcp {
300 fn default() -> Self {
301 Self::new()
302 }
303}