1use crate::pg::PgMcpError;
2use crate::{Conns, PgMcp};
3use anyhow::Result;
4use rmcp::{
5 Error as McpError, ServerHandler,
6 model::{CallToolResult, Content, ServerCapabilities, ServerInfo},
7 schemars, tool,
8};
9
10#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
11pub struct RegisterRequest {
12 #[schemars(description = "Postgres connection string")]
13 pub conn_str: String,
14}
15
16#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
17pub struct UnregisterRequest {
18 #[schemars(description = "Connection ID to unregister")]
19 pub conn_id: String,
20}
21
22#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
23pub struct QueryRequest {
24 #[schemars(description = "Connection ID")]
25 pub conn_id: String,
26 #[schemars(
27 description = "Single SQL query, could return multiple rows. Caller should properly limit the number of rows returned."
28 )]
29 pub query: String,
30}
31
32#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
33pub struct InsertRequest {
34 #[schemars(description = "Connection ID")]
35 pub conn_id: String,
36 #[schemars(
37 description = "Single SQL insert statement, but multiple rows for the same table are allowed"
38 )]
39 pub query: String,
40}
41
42#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
43pub struct UpdateRequest {
44 #[schemars(description = "Connection ID")]
45 pub conn_id: String,
46 #[schemars(
47 description = "Single SQL update statement, could update multiple rows for the same table based on the WHERE clause"
48 )]
49 pub query: String,
50}
51
52#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
53pub struct DeleteRequest {
54 #[schemars(description = "Connection ID")]
55 pub conn_id: String,
56 #[schemars(
57 description = "Single SQL delete statement, could delete multiple rows for the same table based on the WHERE clause"
58 )]
59 pub query: String,
60}
61
62#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
63pub struct CreateTableRequest {
64 #[schemars(description = "Connection ID")]
65 pub conn_id: String,
66 #[schemars(description = "Single SQL create table statement")]
67 pub query: String,
68}
69
70#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
71pub struct DropTableRequest {
72 #[schemars(description = "Connection ID")]
73 pub conn_id: String,
74 #[schemars(
75 description = "Table name. Format: schema.table. If schema is not provided, it will use the current schema."
76 )]
77 pub table: String,
78}
79
80#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
81pub struct CreateIndexRequest {
82 #[schemars(description = "Connection ID")]
83 pub conn_id: String,
84 #[schemars(description = "SingleSQL create index statement")]
85 pub query: String,
86}
87
88#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
89pub struct DropIndexRequest {
90 #[schemars(description = "Connection ID")]
91 pub conn_id: String,
92 #[schemars(description = "Index name")]
93 pub index: String,
94}
95
96#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
97pub struct DescribeRequest {
98 #[schemars(description = "Connection ID")]
99 pub conn_id: String,
100 #[schemars(description = "Table name")]
101 pub table: String,
102}
103
104#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
105pub struct ListTablesRequest {
106 #[schemars(description = "Connection ID")]
107 pub conn_id: String,
108 #[schemars(description = "Schema name")]
109 pub schema: String,
110}
111
112#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
113pub struct CreateSchemaRequest {
114 #[schemars(description = "Connection ID")]
115 pub conn_id: String,
116 #[schemars(description = "Schema name")]
117 pub name: String,
118}
119
120#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
121pub struct CreateTypeRequest {
122 #[schemars(description = "Connection ID")]
123 pub conn_id: String,
124 #[schemars(description = "Single SQL create type statement")]
125 pub query: String,
126}
127
128fn map_pg_error(e: PgMcpError) -> McpError {
130 match e {
131 PgMcpError::ConnectionNotFound(id) => McpError::internal_error(
132 format!("Invalid Argument: Connection not found for ID: {}", id),
133 None,
134 ),
135 PgMcpError::ValidationFailed {
136 kind,
137 query,
138 details,
139 } => McpError::internal_error(
140 format!(
141 "Invalid Argument: SQL validation failed for query '{}': {} - {}",
142 query, kind, details
143 ),
144 None,
145 ),
146 PgMcpError::DatabaseError {
147 operation,
148 underlying,
149 } => McpError::internal_error(
150 format!("Database operation '{}' failed: {}", operation, underlying),
151 None,
152 ),
153 PgMcpError::SerializationError(se) => {
154 McpError::internal_error(format!("Result serialization failed: {}", se), None)
155 }
156 PgMcpError::ConnectionError(ce) => {
157 McpError::internal_error(format!("Database connection failed: {}", ce), None)
158 }
159 PgMcpError::InternalError(ie) => {
160 McpError::internal_error(format!("Internal error: {}", ie), None)
161 }
162 }
163}
164
165#[tool(tool_box)]
166impl PgMcp {
167 pub fn new() -> Self {
168 Self {
169 conns: Conns::new(),
170 }
171 }
172
173 #[tool(description = "Register a new Postgres connection")]
174 async fn register(
175 &self,
176 #[tool(aggr)] req: RegisterRequest,
177 ) -> Result<CallToolResult, McpError> {
178 let id = self
179 .conns
180 .register(req.conn_str)
181 .await
182 .map_err(map_pg_error)?;
183 Ok(CallToolResult::success(vec![Content::text(id)]))
184 }
185
186 #[tool(description = "Unregister a Postgres connection")]
187 async fn unregister(
188 &self,
189 #[tool(aggr)] req: UnregisterRequest,
190 ) -> Result<CallToolResult, McpError> {
191 self.conns.unregister(req.conn_id).map_err(map_pg_error)?;
192 Ok(CallToolResult::success(vec![Content::text(
193 "success".to_string(),
194 )]))
195 }
196
197 #[tool(description = "Execute a SELECT query")]
198 async fn query(&self, #[tool(aggr)] req: QueryRequest) -> Result<CallToolResult, McpError> {
199 let result = self
200 .conns
201 .query(&req.conn_id, &req.query)
202 .await
203 .map_err(map_pg_error)?;
204 Ok(CallToolResult::success(vec![Content::text(result)]))
205 }
206
207 #[tool(description = "Execute an INSERT statement")]
208 async fn insert(&self, #[tool(aggr)] req: InsertRequest) -> Result<CallToolResult, McpError> {
209 let result = self
210 .conns
211 .insert(&req.conn_id, &req.query)
212 .await
213 .map_err(map_pg_error)?;
214 Ok(CallToolResult::success(vec![Content::text(result)]))
215 }
216
217 #[tool(description = "Execute an UPDATE statement")]
218 async fn update(&self, #[tool(aggr)] req: UpdateRequest) -> Result<CallToolResult, McpError> {
219 let result = self
220 .conns
221 .update(&req.conn_id, &req.query)
222 .await
223 .map_err(map_pg_error)?;
224 Ok(CallToolResult::success(vec![Content::text(result)]))
225 }
226
227 #[tool(description = "Delete a row from a table")]
228 async fn delete(&self, #[tool(aggr)] req: DeleteRequest) -> Result<CallToolResult, McpError> {
229 let result = self
230 .conns
231 .delete(&req.conn_id, &req.query)
232 .await
233 .map_err(map_pg_error)?;
234 Ok(CallToolResult::success(vec![Content::text(result)]))
235 }
236
237 #[tool(description = "Create a new table")]
238 async fn create_table(
239 &self,
240 #[tool(aggr)] req: CreateTableRequest,
241 ) -> Result<CallToolResult, McpError> {
242 let result = self
243 .conns
244 .create_table(&req.conn_id, &req.query)
245 .await
246 .map_err(map_pg_error)?;
247 Ok(CallToolResult::success(vec![Content::text(result)]))
248 }
249
250 #[tool(description = "Drop a table")]
251 async fn drop_table(
252 &self,
253 #[tool(aggr)] req: DropTableRequest,
254 ) -> Result<CallToolResult, McpError> {
255 let result = self
256 .conns
257 .drop_table(&req.conn_id, &req.table)
258 .await
259 .map_err(map_pg_error)?;
260 Ok(CallToolResult::success(vec![Content::text(result)]))
261 }
262
263 #[tool(description = "Create an index")]
264 async fn create_index(
265 &self,
266 #[tool(aggr)] req: CreateIndexRequest,
267 ) -> Result<CallToolResult, McpError> {
268 let result = self
269 .conns
270 .create_index(&req.conn_id, &req.query)
271 .await
272 .map_err(map_pg_error)?;
273 Ok(CallToolResult::success(vec![Content::text(result)]))
274 }
275
276 #[tool(description = "Drop an index")]
277 async fn drop_index(
278 &self,
279 #[tool(aggr)] req: DropIndexRequest,
280 ) -> Result<CallToolResult, McpError> {
281 let result = self
282 .conns
283 .drop_index(&req.conn_id, &req.index)
284 .await
285 .map_err(map_pg_error)?;
286 Ok(CallToolResult::success(vec![Content::text(result)]))
287 }
288
289 #[tool(description = "Describe a table")]
290 async fn describe(
291 &self,
292 #[tool(aggr)] req: DescribeRequest,
293 ) -> Result<CallToolResult, McpError> {
294 let result = self
295 .conns
296 .describe(&req.conn_id, &req.table)
297 .await
298 .map_err(map_pg_error)?;
299 Ok(CallToolResult::success(vec![Content::text(result)]))
300 }
301
302 #[tool(description = "List tables in a schema")]
303 async fn list_tables(
304 &self,
305 #[tool(aggr)] req: ListTablesRequest,
306 ) -> Result<CallToolResult, McpError> {
307 let result = self
308 .conns
309 .list_tables(&req.conn_id, &req.schema)
310 .await
311 .map_err(map_pg_error)?;
312 Ok(CallToolResult::success(vec![Content::text(result)]))
313 }
314
315 #[tool(description = "Create a new schema")]
316 async fn create_schema(
317 &self,
318 #[tool(aggr)] req: CreateSchemaRequest,
319 ) -> Result<CallToolResult, McpError> {
320 let result = self
321 .conns
322 .create_schema(&req.conn_id, &req.name)
323 .await
324 .map_err(map_pg_error)?;
325 Ok(CallToolResult::success(vec![Content::text(result)]))
326 }
327
328 #[tool(description = "Create a new type")]
329 async fn create_type(
330 &self,
331 #[tool(aggr)] req: CreateTypeRequest,
332 ) -> Result<CallToolResult, McpError> {
333 let result = self
334 .conns
335 .create_type(&req.conn_id, &req.query)
336 .await
337 .map_err(map_pg_error)?;
338 Ok(CallToolResult::success(vec![Content::text(result)]))
339 }
340}
341
342#[tool(tool_box)]
343impl ServerHandler for PgMcp {
344 fn get_info(&self) -> ServerInfo {
345 ServerInfo {
346 instructions: Some(
347 "A Postgres MCP server that allows AI agents to interact with Postgres databases"
348 .into(),
349 ),
350 capabilities: ServerCapabilities::builder().enable_tools().build(),
351 ..Default::default()
352 }
353 }
354}
355
356impl Default for PgMcp {
357 fn default() -> Self {
358 Self::new()
359 }
360}