1use std::sync::Arc;
8
9use crate::db::backend::{Backend, DatabaseBackend};
10use crate::db::validation::validate_read_only_with_dialect;
11use rmcp::handler::server::common::{FromContextPart, schema_for_empty_input, schema_for_type};
12use rmcp::handler::server::router::tool::{ToolRoute, ToolRouter};
13use rmcp::handler::server::tool::ToolCallContext;
14use rmcp::handler::server::wrapper::Parameters;
15use rmcp::model::{
16 CallToolRequestParams, CallToolResult, Content, ErrorData, Implementation, ListToolsResult, PaginatedRequestParams,
17 ServerCapabilities, ServerInfo, Tool, ToolAnnotations,
18};
19use rmcp::schemars;
20use rmcp::schemars::JsonSchema;
21use rmcp::service::RequestContext;
22use rmcp::{RoleServer, ServerHandler};
23use serde::Deserialize;
24use serde_json::Map as JsonObject;
25
26#[derive(Debug, Deserialize, JsonSchema)]
32pub struct ListTablesRequest {
33 #[schemars(
34 description = "The database name to list tables from. Required. Use list_databases first to see available databases."
35 )]
36 pub database_name: String,
37}
38
39#[derive(Debug, Deserialize, JsonSchema)]
41pub struct GetTableSchemaRequest {
42 #[schemars(
43 description = "The database name containing the table. Required. Use list_databases first to see available databases."
44 )]
45 pub database_name: String,
46 #[schemars(
47 description = "The table name to inspect. Use list_tables first to see available tables in the database."
48 )]
49 pub table_name: String,
50}
51
52#[derive(Debug, Deserialize, JsonSchema)]
54pub struct QueryRequest {
55 #[schemars(description = "The SQL query to execute.")]
56 pub sql_query: String,
57 #[schemars(
58 description = "The database to run the query against. Required. Use list_databases first to see available databases."
59 )]
60 pub database_name: String,
61}
62
63#[derive(Debug, Deserialize, JsonSchema)]
65pub struct CreateDatabaseRequest {
66 #[schemars(
67 description = "Name of the database to create. Must contain only alphanumeric characters and underscores."
68 )]
69 pub database_name: String,
70}
71
72fn schema_for<T: JsonSchema + 'static>() -> Arc<JsonObject<String, serde_json::Value>> {
78 schema_for_type::<Parameters<T>>()
79}
80
81#[must_use]
83fn list_databases_route() -> ToolRoute<Server> {
84 ToolRoute::new_dyn(
85 Tool::new(
86 "list_databases",
87 "List all accessible databases on the connected database server. Call this first to discover available database names.",
88 schema_for_empty_input(),
89 )
90 .with_annotations(
91 ToolAnnotations::new()
92 .read_only(true)
93 .destructive(false)
94 .idempotent(true)
95 .open_world(false),
96 ),
97 |ctx: ToolCallContext<'_, Server>| {
98 let server = ctx.service;
99 Box::pin(async move { server.list_databases().await })
100 },
101 )
102}
103
104#[must_use]
106fn list_tables_route() -> ToolRoute<Server> {
107 ToolRoute::new_dyn(
108 Tool::new(
109 "list_tables",
110 "List all tables in a specific database. Requires database_name from list_databases.",
111 schema_for::<ListTablesRequest>(),
112 )
113 .with_annotations(
114 ToolAnnotations::new()
115 .read_only(true)
116 .destructive(false)
117 .idempotent(true)
118 .open_world(false),
119 ),
120 |mut ctx: ToolCallContext<'_, Server>| {
121 let params = Parameters::<ListTablesRequest>::from_context_part(&mut ctx);
122 let server = ctx.service;
123 Box::pin(async move {
124 let params = params?;
125 server.list_tables(params).await
126 })
127 },
128 )
129}
130
131#[must_use]
133fn get_table_schema_route() -> ToolRoute<Server> {
134 ToolRoute::new_dyn(
135 Tool::new(
136 "get_table_schema",
137 "Get column definitions (type, nullable, key, default) for a table. Requires database_name and table_name.",
138 schema_for::<GetTableSchemaRequest>(),
139 )
140 .with_annotations(
141 ToolAnnotations::new()
142 .read_only(true)
143 .destructive(false)
144 .idempotent(true)
145 .open_world(false),
146 ),
147 |mut ctx: ToolCallContext<'_, Server>| {
148 let params = Parameters::<GetTableSchemaRequest>::from_context_part(&mut ctx);
149 let server = ctx.service;
150 Box::pin(async move {
151 let params = params?;
152 server.get_table_schema(params).await
153 })
154 },
155 )
156}
157
158#[must_use]
160fn get_table_schema_with_relations_route() -> ToolRoute<Server> {
161 ToolRoute::new_dyn(
162 Tool::new(
163 "get_table_schema_with_relations",
164 "Get column definitions plus foreign key relationships for a table. Requires database_name and table_name.",
165 schema_for::<GetTableSchemaRequest>(),
166 )
167 .with_annotations(
168 ToolAnnotations::new()
169 .read_only(true)
170 .destructive(false)
171 .idempotent(true)
172 .open_world(false),
173 ),
174 |mut ctx: ToolCallContext<'_, Server>| {
175 let params = Parameters::<GetTableSchemaRequest>::from_context_part(&mut ctx);
176 let server = ctx.service;
177 Box::pin(async move {
178 let params = params?;
179 server.get_table_schema_with_relations(params).await
180 })
181 },
182 )
183}
184
185#[must_use]
187fn read_query_route() -> ToolRoute<Server> {
188 ToolRoute::new_dyn(
189 Tool::new(
190 "read_query",
191 "Execute a read-only SQL query (SELECT, SHOW, DESCRIBE, USE, EXPLAIN).",
192 schema_for::<QueryRequest>(),
193 )
194 .with_annotations(
195 ToolAnnotations::new()
196 .read_only(true)
197 .destructive(false)
198 .idempotent(true)
199 .open_world(true),
200 ),
201 |mut ctx: ToolCallContext<'_, Server>| {
202 let params = Parameters::<QueryRequest>::from_context_part(&mut ctx);
203 let server = ctx.service;
204 Box::pin(async move {
205 let params = params?;
206 server.read_query(params).await
207 })
208 },
209 )
210}
211
212#[must_use]
214fn write_query_route() -> ToolRoute<Server> {
215 ToolRoute::new_dyn(
216 Tool::new(
217 "write_query",
218 "Execute a write SQL query (INSERT, UPDATE, DELETE, CREATE, ALTER, DROP).",
219 schema_for::<QueryRequest>(),
220 )
221 .with_annotations(
222 ToolAnnotations::new()
223 .read_only(false)
224 .destructive(true)
225 .idempotent(false)
226 .open_world(true),
227 ),
228 |mut ctx: ToolCallContext<'_, Server>| {
229 let params = Parameters::<QueryRequest>::from_context_part(&mut ctx);
230 let server = ctx.service;
231 Box::pin(async move {
232 let params = params?;
233 server.write_query(params).await
234 })
235 },
236 )
237}
238
239#[must_use]
241fn create_database_route() -> ToolRoute<Server> {
242 ToolRoute::new_dyn(
243 Tool::new(
244 "create_database",
245 "Create a new database. Not supported for SQLite.",
246 schema_for::<CreateDatabaseRequest>(),
247 )
248 .with_annotations(
249 ToolAnnotations::new()
250 .read_only(false)
251 .destructive(false)
252 .idempotent(false)
253 .open_world(false),
254 ),
255 |mut ctx: ToolCallContext<'_, Server>| {
256 let params = Parameters::<CreateDatabaseRequest>::from_context_part(&mut ctx);
257 let server = ctx.service;
258 Box::pin(async move {
259 let params = params?;
260 server.create_database(params).await
261 })
262 },
263 )
264}
265
266fn map_error(e: impl std::fmt::Display) -> ErrorData {
271 ErrorData::internal_error(e.to_string(), None)
272}
273
274#[derive(Clone)]
276pub struct Server {
277 pub backend: Backend,
279 tool_router: ToolRouter<Self>,
280}
281
282impl std::fmt::Debug for Server {
283 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
284 f.debug_struct("Server").finish_non_exhaustive()
285 }
286}
287
288impl Server {
289 #[must_use]
294 pub fn new(backend: Backend) -> Self {
295 let tool_router = Self::build_tool_router(&backend);
296 Self { backend, tool_router }
297 }
298
299 fn build_tool_router(backend: &Backend) -> ToolRouter<Self> {
305 let mut router = ToolRouter::new();
306
307 if !matches!(backend, Backend::Sqlite(_)) {
308 router.add_route(list_databases_route());
309 }
310
311 router.add_route(list_tables_route());
312 router.add_route(get_table_schema_route());
313 router.add_route(get_table_schema_with_relations_route());
314 router.add_route(read_query_route());
315
316 if backend.read_only() {
317 return router;
318 }
319
320 router.add_route(write_query_route());
321
322 if !matches!(backend, Backend::Sqlite(_)) {
323 router.add_route(create_database_route());
324 }
325
326 router
327 }
328}
329
330impl Server {
335 pub async fn list_databases(&self) -> Result<CallToolResult, ErrorData> {
341 let result = self.backend.tool_list_databases().await.map_err(map_error)?;
342 Ok(CallToolResult::success(vec![Content::text(result)]))
343 }
344
345 pub async fn list_tables(&self, req: Parameters<ListTablesRequest>) -> Result<CallToolResult, ErrorData> {
351 let result = self
352 .backend
353 .tool_list_tables(&req.0.database_name)
354 .await
355 .map_err(map_error)?;
356 Ok(CallToolResult::success(vec![Content::text(result)]))
357 }
358
359 pub async fn get_table_schema(&self, req: Parameters<GetTableSchemaRequest>) -> Result<CallToolResult, ErrorData> {
365 let result = self
366 .backend
367 .tool_get_table_schema(&req.0.database_name, &req.0.table_name)
368 .await
369 .map_err(map_error)?;
370 Ok(CallToolResult::success(vec![Content::text(result)]))
371 }
372
373 pub async fn get_table_schema_with_relations(
379 &self,
380 req: Parameters<GetTableSchemaRequest>,
381 ) -> Result<CallToolResult, ErrorData> {
382 let result = self
383 .backend
384 .tool_get_table_schema_with_relations(&req.0.database_name, &req.0.table_name)
385 .await
386 .map_err(map_error)?;
387 Ok(CallToolResult::success(vec![Content::text(result)]))
388 }
389
390 pub async fn read_query(&self, req: Parameters<QueryRequest>) -> Result<CallToolResult, ErrorData> {
400 {
402 let dialect = self.backend.dialect();
403 validate_read_only_with_dialect(&req.0.sql_query, dialect.as_ref()).map_err(map_error)?;
404 }
405
406 let result = self
407 .backend
408 .tool_execute_sql(&req.0.sql_query, &req.0.database_name)
409 .await
410 .map_err(map_error)?;
411 Ok(CallToolResult::success(vec![Content::text(result)]))
412 }
413
414 pub async fn write_query(&self, req: Parameters<QueryRequest>) -> Result<CallToolResult, ErrorData> {
423 let result = self
424 .backend
425 .tool_execute_sql(&req.0.sql_query, &req.0.database_name)
426 .await
427 .map_err(map_error)?;
428 Ok(CallToolResult::success(vec![Content::text(result)]))
429 }
430
431 pub async fn create_database(&self, req: Parameters<CreateDatabaseRequest>) -> Result<CallToolResult, ErrorData> {
437 let result = self
438 .backend
439 .tool_create_database(&req.0.database_name)
440 .await
441 .map_err(map_error)?;
442 Ok(CallToolResult::success(vec![Content::text(result)]))
443 }
444}
445
446impl ServerHandler for Server {
451 fn get_info(&self) -> ServerInfo {
452 ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
453 .with_server_info(Implementation::new(env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION")))
454 .with_instructions(
455 "Database MCP Server - provides database exploration and query tools for MySQL, MariaDB, PostgreSQL, and SQLite",
456 )
457 }
458
459 async fn list_tools(
460 &self,
461 _request: Option<PaginatedRequestParams>,
462 _context: RequestContext<RoleServer>,
463 ) -> Result<ListToolsResult, ErrorData> {
464 Ok(ListToolsResult {
465 tools: self.tool_router.list_all(),
466 next_cursor: None,
467 meta: None,
468 })
469 }
470
471 async fn call_tool(
472 &self,
473 request: CallToolRequestParams,
474 context: RequestContext<RoleServer>,
475 ) -> Result<CallToolResult, ErrorData> {
476 let tcc = ToolCallContext::new(self, request, context);
477 self.tool_router.call(tcc).await
478 }
479
480 fn get_tool(&self, name: &str) -> Option<Tool> {
481 self.tool_router.get(name).cloned()
482 }
483}
484
#[cfg(test)]
mod tests {
    //! Unit tests for error mapping, route metadata, tool annotations and
    //! backend-dependent router assembly.

    use super::*;
    use crate::db::sqlite::SqliteBackend;
    use crate::error::AppError;

    // ----- helpers ---------------------------------------------------------

    /// Names of the tools the router registers for `backend`.
    fn router_tool_names(backend: &Backend) -> Vec<String> {
        Server::build_tool_router(backend)
            .list_all()
            .into_iter()
            .map(|t| t.name.to_string())
            .collect()
    }

    /// In-memory SQLite backend with the given read-only setting.
    fn sqlite_backend(read_only: bool) -> Backend {
        Backend::Sqlite(SqliteBackend::in_memory(read_only))
    }

    /// Annotations attached to `route`; panics if the route has none.
    fn annotations(route: &ToolRoute<Server>) -> &ToolAnnotations {
        route.attr.annotations.as_ref().expect("tool should have annotations")
    }

    /// True when the route's input schema declares every property in `keys`.
    fn has_properties(route: &ToolRoute<Server>, keys: &[&str]) -> bool {
        route
            .attr
            .input_schema
            .get("properties")
            .and_then(|v| v.as_object())
            .is_some_and(|props| keys.iter().all(|key| props.contains_key(*key)))
    }

    /// Asserts the hint set shared by the metadata-inspection tools:
    /// read-only, non-destructive, idempotent, closed-world.
    fn assert_read_only_closed_world(route: &ToolRoute<Server>) {
        let ann = annotations(route);
        assert_eq!(ann.read_only_hint, Some(true));
        assert_eq!(ann.destructive_hint, Some(false));
        assert_eq!(ann.idempotent_hint, Some(true));
        assert_eq!(ann.open_world_hint, Some(false));
    }

    // ----- error mapping ---------------------------------------------------

    #[test]
    fn map_error_converts_display_to_error_data() {
        let err = AppError::ReadOnlyViolation;
        let mapped = map_error(err);
        assert!(
            mapped.message.contains("read-only"),
            "mapped error should preserve the original message"
        );
    }

    #[test]
    fn map_error_converts_string_to_error_data() {
        let mapped = map_error("something went wrong");
        assert_eq!(mapped.message, "something went wrong");
    }

    // ----- server info -----------------------------------------------------

    #[tokio::test]
    async fn get_info_returns_tools_capability_and_server_info() {
        // Exercise the real handler instead of a hand-built `ServerInfo`, so
        // a regression in `Server::get_info` is actually caught.
        let info = Server::new(sqlite_backend(true)).get_info();
        assert!(info.capabilities.tools.is_some(), "tools capability should be enabled");
        assert_eq!(info.server_info.name, "database-mcp");
        assert!(!info.server_info.version.is_empty(), "version should not be empty");
    }

    // ----- route names, descriptions and schemas ---------------------------

    #[test]
    fn list_databases_route_has_correct_name_and_empty_schema() {
        let route = list_databases_route();
        assert_eq!(route.attr.name.as_ref(), "list_databases");
        assert!(
            route
                .attr
                .description
                .as_deref()
                .is_some_and(|d| d.contains("List all accessible databases")),
            "description should mention listing databases"
        );
        let schema = &route.attr.input_schema;
        assert_eq!(schema.get("type").and_then(|v| v.as_str()), Some("object"));
    }

    #[test]
    fn list_tables_route_has_correct_name_and_schema() {
        let route = list_tables_route();
        assert_eq!(route.attr.name.as_ref(), "list_tables");
        assert!(
            has_properties(&route, &["database_name"]),
            "schema should have database_name property"
        );
    }

    #[test]
    fn get_table_schema_route_has_correct_name_and_schema() {
        let route = get_table_schema_route();
        assert_eq!(route.attr.name.as_ref(), "get_table_schema");
        assert!(
            has_properties(&route, &["database_name", "table_name"]),
            "schema should have database_name and table_name properties"
        );
    }

    #[test]
    fn get_table_schema_with_relations_route_has_correct_name() {
        let route = get_table_schema_with_relations_route();
        assert_eq!(route.attr.name.as_ref(), "get_table_schema_with_relations");
        assert!(
            has_properties(&route, &["database_name", "table_name"]),
            "schema should have database_name and table_name properties"
        );
    }

    #[test]
    fn read_query_route_has_correct_name_and_schema() {
        let route = read_query_route();
        assert_eq!(route.attr.name.as_ref(), "read_query");
        assert!(
            route
                .attr
                .description
                .as_deref()
                .is_some_and(|d| d.contains("read-only")),
            "description should mention read-only"
        );
        assert!(
            has_properties(&route, &["sql_query", "database_name"]),
            "schema should have sql_query and database_name properties"
        );
    }

    #[test]
    fn write_query_route_has_correct_name_and_schema() {
        let route = write_query_route();
        assert_eq!(route.attr.name.as_ref(), "write_query");
        assert!(
            route.attr.description.as_deref().is_some_and(|d| d.contains("write")),
            "description should mention write"
        );
        assert!(
            has_properties(&route, &["sql_query", "database_name"]),
            "schema should have sql_query and database_name properties"
        );
    }

    #[test]
    fn create_database_route_has_correct_name_and_schema() {
        let route = create_database_route();
        assert_eq!(route.attr.name.as_ref(), "create_database");
        assert!(
            route.attr.description.as_deref().is_some_and(|d| d.contains("SQLite")),
            "description should mention SQLite not supported"
        );
        assert!(
            has_properties(&route, &["database_name"]),
            "schema should have database_name property"
        );
    }

    #[test]
    fn read_and_write_query_share_same_schema_shape() {
        let read = read_query_route();
        let write = write_query_route();
        let read_props = read.attr.input_schema.get("properties").and_then(|v| v.as_object());
        let write_props = write.attr.input_schema.get("properties").and_then(|v| v.as_object());
        assert!(read_props.is_some());
        assert_eq!(
            read_props.map(|p| p.keys().collect::<std::collections::BTreeSet<_>>()),
            write_props.map(|p| p.keys().collect::<std::collections::BTreeSet<_>>()),
            "read_query and write_query should have the same input schema properties"
        );
    }

    // ----- annotations -----------------------------------------------------

    #[test]
    fn list_databases_annotations_are_read_only_closed_world() {
        assert_read_only_closed_world(&list_databases_route());
    }

    #[test]
    fn list_tables_annotations_are_read_only_closed_world() {
        assert_read_only_closed_world(&list_tables_route());
    }

    #[test]
    fn get_table_schema_annotations_are_read_only_closed_world() {
        assert_read_only_closed_world(&get_table_schema_route());
    }

    #[test]
    fn get_table_schema_with_relations_annotations_are_read_only_closed_world() {
        assert_read_only_closed_world(&get_table_schema_with_relations_route());
    }

    #[test]
    fn read_query_annotations_are_read_only_open_world() {
        let route = read_query_route();
        let ann = annotations(&route);
        assert_eq!(ann.read_only_hint, Some(true));
        assert_eq!(ann.destructive_hint, Some(false));
        assert_eq!(ann.idempotent_hint, Some(true));
        assert_eq!(ann.open_world_hint, Some(true));
    }

    #[test]
    fn write_query_annotations_are_destructive_open_world() {
        let route = write_query_route();
        let ann = annotations(&route);
        assert_eq!(ann.read_only_hint, Some(false));
        assert_eq!(ann.destructive_hint, Some(true));
        assert_eq!(ann.idempotent_hint, Some(false));
        assert_eq!(ann.open_world_hint, Some(true));
    }

    #[test]
    fn create_database_annotations_are_non_destructive_closed_world() {
        let route = create_database_route();
        let ann = annotations(&route);
        assert_eq!(ann.read_only_hint, Some(false));
        assert_eq!(ann.destructive_hint, Some(false));
        assert_eq!(ann.idempotent_hint, Some(false));
        assert_eq!(ann.open_world_hint, Some(false));
    }

    #[tokio::test]
    async fn all_router_tools_have_annotations() {
        let backend = sqlite_backend(false);
        let tools = Server::build_tool_router(&backend).list_all();
        for tool in &tools {
            assert!(
                tool.annotations.is_some(),
                "tool '{}' should have annotations",
                tool.name
            );
        }
    }

    // ----- router assembly per backend -------------------------------------

    #[tokio::test]
    async fn router_sqlite_read_only_returns_4_tools() {
        let names = router_tool_names(&sqlite_backend(true));
        assert_eq!(names.len(), 4);
        assert!(!names.contains(&"list_databases".to_string()));
        assert!(names.contains(&"list_tables".to_string()));
        assert!(names.contains(&"get_table_schema".to_string()));
        assert!(names.contains(&"get_table_schema_with_relations".to_string()));
        assert!(names.contains(&"read_query".to_string()));
    }

    #[tokio::test]
    async fn router_sqlite_read_write_returns_5_tools() {
        let names = router_tool_names(&sqlite_backend(false));
        assert_eq!(names.len(), 5);
        assert!(!names.contains(&"list_databases".to_string()));
        assert!(names.contains(&"write_query".to_string()));
        assert!(!names.contains(&"create_database".to_string()));
    }
}