1use serde_json::{Value, json};
2use std::sync::Arc;
3use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
4use tokio::net::{TcpListener, TcpStream};
5use tracing::{error, warn};
6
7use crate::actions;
8use crate::config::Config;
9use crate::errors::{MCPError, Result as MCPResult};
10use crate::metrics;
11use crate::pool::ConnectionPool;
12use crate::protocol::{JsonRpcRequest, JsonRpcResponse};
13use once_cell::sync::Lazy;
14
15static TOOLS_LIST_RESPONSE: Lazy<Vec<u8>> = Lazy::new(|| {
19 let tools_json = include_str!("../tools.json");
20 let tools: Vec<Value> = serde_json::from_str(tools_json).expect("Failed to parse tools.json");
21 let resp = json!({ "tools": tools });
22 serde_json::to_vec(&resp).expect("Failed to serialize tools/list response")
23});
24
/// Capacity of the `BufReader` wrapped around each input stream.
const BUFFER_CAPACITY: usize = 4 * 1024;
/// Terminator appended after every serialized response line.
const NEWLINE: &[u8] = &[b'\n'];

/// Maximum size, in bytes, of a single request line; longer lines are rejected
/// instead of being buffered.
const MAX_REQUEST_BYTES: usize = 16 << 20;

/// Time a TCP client is given to send its authentication line.
const AUTH_HANDSHAKE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
36
37async fn read_line_capped<R>(
45 reader: &mut R,
46 line: &mut String,
47 max_bytes: usize,
48) -> std::io::Result<usize>
49where
50 R: AsyncBufReadExt + Unpin,
51{
52 use std::io::{Error, ErrorKind};
53 line.clear();
54 let mut total: usize = 0;
55 loop {
56 let chunk = reader.fill_buf().await?;
57 if chunk.is_empty() {
58 break;
59 }
60 let (take, done) = match chunk.iter().position(|&b| b == b'\n') {
61 Some(i) => (i + 1, true),
62 None => (chunk.len(), false),
63 };
64 if total + take > max_bytes {
65 reader.consume(take);
66 return Err(Error::new(
67 ErrorKind::InvalidData,
68 "request line exceeds maximum length",
69 ));
70 }
71 let s = std::str::from_utf8(&chunk[..take])
72 .map_err(|_| Error::new(ErrorKind::InvalidData, "request line is not valid UTF-8"))?;
73 line.push_str(s);
74 total += take;
75 reader.consume(take);
76 if done {
77 break;
78 }
79 }
80 if line.is_empty() {
81 return Ok(0);
82 }
83 Ok(line.len())
84}
85
86#[inline]
87#[cold]
88fn parse_error(msg: String) -> JsonRpcResponse {
89 let mcp_error = MCPError::ParseError(msg);
90 JsonRpcResponse::error(None, mcp_error.error_code(), mcp_error.to_string())
91}
92
93#[inline]
94fn parse_request(line: &str) -> Result<JsonRpcRequest, String> {
95 let trimmed = line.trim();
96 if trimmed.is_empty() {
97 return Err("Empty request".to_string());
98 }
99 serde_json::from_str::<JsonRpcRequest>(trimmed).map_err(|e| e.to_string())
100}
101
102pub struct MCPServer {
103 config: Arc<Config>,
104 pool: Arc<ConnectionPool>,
105}
106
107impl MCPServer {
108 pub fn new(config: Config, pool: Arc<ConnectionPool>) -> Self {
109 Self {
110 config: Arc::new(config),
111 pool,
112 }
113 }
114
115 pub async fn run_stdio(&self) -> MCPResult<()> {
117 let stdin = tokio::io::stdin();
118 let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, stdin);
119 let mut stdout = tokio::io::stdout();
120 let mut line = String::with_capacity(512);
121 let mut response_buf = Vec::with_capacity(4096);
125
126 loop {
127 line.clear();
128 match reader.read_line(&mut line).await {
129 Ok(0) => break,
130 Ok(_) => {
131 process_one_line(
132 &line,
133 &self.pool,
134 &self.config,
135 &mut response_buf,
136 &mut stdout,
137 )
138 .await?;
139 }
140 Err(e) => {
141 error!("IO error: {}", e);
142 break;
143 }
144 }
145 }
146 Ok(())
147 }
148
149 pub async fn run(&self) -> MCPResult<()> {
150 let addr = format!("{}:{}", self.config.server.host, self.config.server.port);
151 let listener = TcpListener::bind(&addr).await?;
152
153 tracing::info!("MCP server listening on {}", addr);
154
155 loop {
156 let (socket, peer_addr) = listener.accept().await?;
157
158 if let Err(e) = socket.set_nodelay(true) {
159 warn!("Failed to set TCP_NODELAY: {}", e);
160 }
161
162 let pool = Arc::clone(&self.pool);
163 let config = Arc::clone(&self.config);
164
165 tokio::spawn(async move {
166 if let Err(e) = handle_client(socket, pool, config).await {
167 error!("Client {} error: {}", peer_addr, e);
168 }
169 });
170 }
171 }
172}
173
174#[inline(never)]
175async fn handle_client(
176 socket: TcpStream,
177 pool: Arc<ConnectionPool>,
178 config: Arc<Config>,
179) -> MCPResult<()> {
180 let (reader, mut writer) = socket.into_split();
181 let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, reader);
182 let mut line = String::with_capacity(512);
183 let mut response_buf = Vec::with_capacity(4096);
185
186 if let Some(ref token) = config.server.auth_token {
189 let read = tokio::time::timeout(
190 AUTH_HANDSHAKE_TIMEOUT,
191 read_line_capped(&mut reader, &mut line, MAX_REQUEST_BYTES),
192 )
193 .await;
194 match read {
195 Ok(Ok(0)) => return Ok(()),
196 Ok(Ok(_)) => {
197 if !crate::auth::verify_token(token, line.trim()) {
198 warn!("Authentication failed; closing connection");
199 let _ = writer
200 .write_all(
201 b"{\"jsonrpc\":\"2.0\",\"error\":{\"code\":-32600,\
202 \"message\":\"Unauthorized\"},\"id\":null}\n",
203 )
204 .await;
205 let _ = writer.flush().await;
206 return Ok(());
207 }
208 }
209 Ok(Err(e)) => {
210 error!("IO error during auth: {}", e);
211 return Ok(());
212 }
213 Err(_) => {
214 warn!("Authentication handshake timed out; closing connection");
215 return Ok(());
216 }
217 }
218 }
219
220 loop {
221 match read_line_capped(&mut reader, &mut line, MAX_REQUEST_BYTES).await {
222 Ok(0) => break,
223 Ok(_) => {
224 process_one_line(&line, &pool, &config, &mut response_buf, &mut writer).await?;
225 }
226 Err(e) => {
227 error!("IO error: {}", e);
228 break;
229 }
230 }
231 }
232
233 Ok(())
234}
235
236#[inline]
239async fn process_one_line<W: AsyncWriteExt + Unpin>(
240 line: &str,
241 pool: &Arc<ConnectionPool>,
242 config: &Config,
243 response_buf: &mut Vec<u8>,
244 writer: &mut W,
245) -> MCPResult<()> {
246 metrics::inc_requests();
247
248 let (response, is_notification) = match parse_request(line) {
249 Ok(req) => {
250 if req.method == "tools/list" {
254 if let Some(id) = req.id.as_ref() {
255 response_buf.clear();
256 response_buf.extend_from_slice(b"{\"jsonrpc\":\"2.0\",\"result\":");
257 response_buf.extend_from_slice(&TOOLS_LIST_RESPONSE);
258 response_buf.extend_from_slice(b",\"id\":");
259 serde_json::to_writer(&mut *response_buf, id)?;
260 response_buf.extend_from_slice(b"}");
261 response_buf.extend_from_slice(NEWLINE);
262 writer.write_all(response_buf).await?;
263 writer.flush().await?;
264 maybe_shrink_buf(response_buf);
265 }
266 return Ok(());
268 }
269
270 let is_notif = req.id.is_none();
271 match process_request(&req, pool, config).await {
272 Ok(result) => (JsonRpcResponse::success(req.id, result), is_notif),
273 Err(e) => {
274 metrics::inc_errors();
275 (
276 JsonRpcResponse::error(req.id, e.error_code(), e.to_string()),
277 is_notif,
278 )
279 }
280 }
281 }
282 Err(e) => {
283 metrics::inc_errors();
284 (parse_error(e), false)
285 }
286 };
287
288 if is_notification {
290 return Ok(());
291 }
292
293 response_buf.clear();
294 serde_json::to_writer(&mut *response_buf, &response)?;
295 response_buf.extend_from_slice(NEWLINE);
296
297 writer.write_all(response_buf).await?;
298 writer.flush().await?;
299 maybe_shrink_buf(response_buf);
300 Ok(())
301}
302
/// Swaps a scratch buffer whose capacity has grown past 64 KiB for a fresh,
/// empty 4 KiB one, so that one large response doesn't pin memory for the
/// rest of the session. Smaller buffers are left untouched, contents included.
fn maybe_shrink_buf(buf: &mut Vec<u8>) {
    const SHRINK_ABOVE: usize = 64 * 1024;
    const FRESH_CAPACITY: usize = 4 * 1024;

    if buf.capacity() <= SHRINK_ABOVE {
        return;
    }
    *buf = Vec::with_capacity(FRESH_CAPACITY);
}
311
312#[inline]
314pub async fn process_request(
315 req: &JsonRpcRequest,
316 pool: &Arc<ConnectionPool>,
317 config: &Config,
318) -> MCPResult<Value> {
319 match req.method.as_str() {
320 "initialize" => handle_initialize(req),
321 "tools/list" => handle_tools_list(),
322 "tools/call" => handle_tools_call(req, pool, config).await,
323 "ping" => handle_ping(),
324 method if method.starts_with("notifications/") => handle_notification(method),
325 _ => Err(MCPError::MethodNotFound(req.method.clone())),
326 }
327}
328
329#[inline]
331const fn handle_ping() -> MCPResult<Value> {
332 Ok(Value::Null)
333}
334
335#[inline]
337fn handle_notification(method: &str) -> MCPResult<Value> {
338 tracing::trace!("Received notification: {method}");
339 Ok(Value::Null)
340}
341
342pub async fn process_request_http(
344 req: &JsonRpcRequest,
345 pool: &Arc<ConnectionPool>,
346 config: &Config,
347) -> JsonRpcResponse {
348 metrics::inc_requests();
349
350 match process_request(req, pool, config).await {
351 Ok(result) => JsonRpcResponse::success(req.id.clone(), result),
352 Err(e) => {
353 metrics::inc_errors();
354 JsonRpcResponse::error(req.id.clone(), e.error_code(), e.to_string())
355 }
356 }
357}
358
359fn handle_initialize(_req: &JsonRpcRequest) -> MCPResult<Value> {
360 static INIT_RESPONSE: Lazy<Value> = Lazy::new(|| {
362 json!({
363 "protocolVersion": "2024-11-05",
364 "capabilities": {
365 "tools": { "listChanged": false },
366 "resources": { "subscribe": false, "listChanged": false },
367 "prompts": { "listChanged": false }
368 },
369 "serverInfo": {
370 "name": "mcp-postgres",
371 "version": env!("CARGO_PKG_VERSION")
372 }
373 })
374 });
375
376 Ok(INIT_RESPONSE.clone())
377}
378
379#[inline]
380fn handle_tools_list() -> MCPResult<Value> {
381 Ok(serde_json::from_slice(&TOOLS_LIST_RESPONSE)?)
383}
384
/// Executes an MCP `tools/call` request and returns the tool's result value.
///
/// Reads `params.name` (required) and `params.arguments` (optional, passed by
/// reference to the selected action) from the request. It then applies these
/// gates, in order, before touching the database:
///
/// 1. With `AccessMode::Restricted`, any tool for which
///    `crate::tools::is_write_tool` returns true is rejected.
/// 2. `import_from_url` is rejected unless `config.server.allow_url_import` is set.
/// 3. A name for which `crate::tools::tool_exists` returns false is rejected.
///
/// Only after these checks is a client acquired from `pool`, so rejected calls
/// never take a connection. The name is then dispatched to the matching
/// `actions::*` function. A name that passes `tool_exists` but has no arm below
/// also yields `MethodNotFound`.
///
/// # Errors
///
/// * `MCPError::InvalidParams` for a missing `name`, or for a tool blocked by
///   gate 1 or 2.
/// * `MCPError::MethodNotFound` for an unknown tool.
/// * Any error from `pool.acquire()`.
/// * Whatever the action returns. Action errors are also logged at `error` level.
async fn handle_tools_call(
    req: &JsonRpcRequest,
    pool: &Arc<ConnectionPool>,
    config: &Config,
) -> MCPResult<Value> {
    let tool_name = req
        .params
        .as_ref()
        .and_then(|p| p.get("name").and_then(|v| v.as_str()))
        .ok_or_else(|| MCPError::InvalidParams("Missing 'name' parameter".into()))?;

    // Optional; each action decides how to handle absent arguments.
    let tool_args = req.params.as_ref().and_then(|p| p.get("arguments"));

    // Gate 1: read-only mode blocks tools classified as writes.
    if config.server.access_mode == crate::config::AccessMode::Restricted
        && crate::tools::is_write_tool(tool_name)
    {
        return Err(MCPError::InvalidParams(format!(
            "Operation '{tool_name}' is not allowed in restricted (read-only) mode"
        )));
    }

    // Gate 2: URL import stays disabled unless explicitly opted in.
    if tool_name == "import_from_url" && !config.server.allow_url_import {
        return Err(MCPError::InvalidParams(
            "'import_from_url' is disabled; start the server with --allow-url-import to enable it"
                .into(),
        ));
    }

    // Gate 3: reject unknown tools before acquiring a pooled client.
    if !crate::tools::tool_exists(tool_name) {
        return Err(method_not_found(tool_name));
    }

    let client = pool.acquire().await?;

    let result = match tool_name {
        // actions::schema
        "list_tables" => actions::schema::list_tables(&client, &tool_args).await,
        "describe_table" => actions::schema::describe_table(&client, &tool_args).await,
        "list_indexes" => actions::schema::list_indexes(&client, &tool_args).await,
        "list_schemas" => actions::schema::list_schemas(&client, &tool_args).await,
        "show_constraints" => actions::schema::show_constraints(&client, &tool_args).await,
        "list_triggers" => actions::schema::list_triggers(&client, &tool_args).await,
        "create_table" => actions::schema::create_table(&client, &tool_args).await,
        "drop_table" => actions::schema::drop_table(&client, &tool_args).await,
        "create_view" => actions::schema::create_view(&client, &tool_args).await,
        "drop_view" => actions::schema::drop_view(&client, &tool_args).await,
        "alter_view" => actions::schema::alter_view(&client, &tool_args).await,
        "create_schema" => actions::schema::create_schema(&client, &tool_args).await,
        "drop_schema" => actions::schema::drop_schema(&client, &tool_args).await,
        "create_sequence" => actions::schema::create_sequence(&client, &tool_args).await,
        "drop_sequence" => actions::schema::drop_sequence(&client, &tool_args).await,
        "alter_index" => actions::schema::alter_index(&client, &tool_args).await,
        "list_partitions" => actions::schema::list_partitions(&client, &tool_args).await,
        "backup_table" => actions::schema::backup_table(&client, &tool_args).await,
        "create_index" => actions::schema::create_index(&client, &tool_args).await,
        "drop_index" => actions::schema::drop_index(&client, &tool_args).await,
        "create_partition" => actions::schema::create_partition(&client, &tool_args).await,
        "drop_partition" => actions::schema::drop_partition(&client, &tool_args).await,
        // actions::query
        "execute_query" => actions::query::execute_query(&client, &tool_args).await,
        "execute_insert" => actions::query::execute_insert(&client, &tool_args).await,
        "execute_update" => actions::query::execute_update(&client, &tool_args).await,
        "execute_delete" => actions::query::execute_delete(&client, &tool_args).await,
        "async_execute_insert" => actions::query::async_execute_insert(&client, &tool_args).await,
        "async_execute_update" => actions::query::async_execute_update(&client, &tool_args).await,
        "async_execute_delete" => actions::query::async_execute_delete(&client, &tool_args).await,
        "explain_query" => actions::query::explain_query(&client, &tool_args).await,
        // actions::batch
        "async_batch_insert" => actions::batch::async_batch_insert(&client, &tool_args).await,
        "async_batch_update" => actions::batch::async_batch_update(&client, &tool_args).await,
        "async_batch_delete" => actions::batch::async_batch_delete(&client, &tool_args).await,
        "async_batch_insert_copy" => {
            actions::batch::async_batch_insert_copy(&client, &tool_args).await
        }
        // actions::monitoring
        "get_table_stats" => actions::monitoring::get_table_stats(&client, &tool_args).await,
        "get_index_stats" => actions::monitoring::get_index_stats(&client, &tool_args).await,
        "show_database_size" => actions::monitoring::show_database_size(&client, &tool_args).await,
        "show_table_size" => actions::monitoring::show_table_size(&client, &tool_args).await,
        "get_cache_hit_ratio" => {
            actions::monitoring::get_cache_hit_ratio(&client, &tool_args).await
        }
        // actions::connections
        "list_connections" => actions::connections::list_connections(&client, &tool_args).await,
        "show_current_user" => actions::connections::show_current_user(&client, &tool_args).await,
        "show_running_queries" => {
            actions::connections::show_running_queries(&client, &tool_args).await
        }
        "show_connection_summary" => {
            actions::connections::show_connection_summary(&client, &tool_args).await
        }
        // actions::maintenance
        "vacuum_analyze" => actions::maintenance::vacuum_analyze(&client, &tool_args).await,
        "analyze_table" => actions::maintenance::analyze_table(&client, &tool_args).await,
        "reindex_table" => actions::maintenance::reindex_table(&client, &tool_args).await,
        "get_pg_stat_statements" => {
            actions::maintenance::get_pg_stat_statements(&client, &tool_args).await
        }
        "reset_statistics" => actions::maintenance::reset_statistics(&client, &tool_args).await,
        "truncate_table" => actions::maintenance::truncate_table(&client, &tool_args).await,
        // actions::security
        "list_users" => actions::security::list_users(&client, &tool_args).await,
        "list_user_privileges" => {
            actions::security::list_user_privileges(&client, &tool_args).await
        }
        "list_role_memberships" => {
            actions::security::list_role_memberships(&client, &tool_args).await
        }
        "list_database_privileges" => {
            actions::security::list_database_privileges(&client, &tool_args).await
        }
        "show_session_info" => actions::security::show_session_info(&client, &tool_args).await,
        // actions::config
        "show_all_settings" => actions::config::show_all_settings(&client, &tool_args).await,
        "get_setting" => actions::config::get_setting(&client, &tool_args).await,
        "show_memory_settings" => actions::config::show_memory_settings(&client, &tool_args).await,
        "show_performance_settings" => {
            actions::config::show_performance_settings(&client, &tool_args).await
        }
        "show_log_settings" => actions::config::show_log_settings(&client, &tool_args).await,
        // actions::replication
        "show_replication_status" => {
            actions::replication::show_replication_status(&client, &tool_args).await
        }
        "list_replication_slots" => {
            actions::replication::list_replication_slots(&client, &tool_args).await
        }
        "list_standby_servers" => {
            actions::replication::list_standby_servers(&client, &tool_args).await
        }
        "show_wal_info" => actions::replication::show_wal_info(&client, &tool_args).await,
        "show_base_backup_progress" => {
            actions::replication::show_base_backup_progress(&client, &tool_args).await
        }
        // actions::transactions
        "show_active_transactions" => {
            actions::transactions::show_active_transactions(&client, &tool_args).await
        }
        "show_locks" => actions::transactions::show_locks(&client, &tool_args).await,
        "show_waiting_locks" => {
            actions::transactions::show_waiting_locks(&client, &tool_args).await
        }
        "show_transaction_isolation" => {
            actions::transactions::show_transaction_isolation(&client, &tool_args).await
        }
        "show_deadlocks" => actions::transactions::show_deadlocks(&client, &tool_args).await,
        "show_autocommit_status" => {
            actions::transactions::show_autocommit_status(&client, &tool_args).await
        }
        "show_transaction_timeout" => {
            actions::transactions::show_transaction_timeout(&client, &tool_args).await
        }
        // actions::health
        "analyze_db_health" => actions::health::analyze_db_health(&client, &tool_args).await,
        "list_unused_indexes" => actions::health::list_unused_indexes(&client, &tool_args).await,
        "list_duplicate_indexes" => {
            actions::health::list_duplicate_indexes(&client, &tool_args).await
        }
        "show_vacuum_progress" => actions::health::show_vacuum_progress(&client, &tool_args).await,
        "get_object_details" => actions::schema::get_object_details(&client, &tool_args).await,
        // actions::user_mgmt
        "create_user" => actions::user_mgmt::create_user(&client, &tool_args).await,
        "alter_user" => actions::user_mgmt::alter_user(&client, &tool_args).await,
        "drop_user" => actions::user_mgmt::drop_user(&client, &tool_args).await,
        "create_role" => actions::user_mgmt::create_role(&client, &tool_args).await,
        "alter_role" => actions::user_mgmt::alter_role(&client, &tool_args).await,
        "drop_role" => actions::user_mgmt::drop_role(&client, &tool_args).await,
        "grant_privileges" => actions::user_mgmt::grant_privileges(&client, &tool_args).await,
        "revoke_privileges" => actions::user_mgmt::revoke_privileges(&client, &tool_args).await,
        // actions::schema_alter
        "add_column" => actions::schema_alter::add_column(&client, &tool_args).await,
        "drop_column" => actions::schema_alter::drop_column(&client, &tool_args).await,
        "rename_column" => actions::schema_alter::rename_column(&client, &tool_args).await,
        "alter_column_type" => actions::schema_alter::alter_column_type(&client, &tool_args).await,
        "rename_table" => actions::schema_alter::rename_table(&client, &tool_args).await,
        "rename_index" => actions::schema_alter::rename_index(&client, &tool_args).await,
        "rename_schema" => actions::schema_alter::rename_schema(&client, &tool_args).await,
        "add_foreign_key" => actions::schema_alter::add_foreign_key(&client, &tool_args).await,
        "drop_foreign_key" => actions::schema_alter::drop_foreign_key(&client, &tool_args).await,
        "add_unique_constraint" => {
            actions::schema_alter::add_unique_constraint(&client, &tool_args).await
        }
        "drop_constraint" => actions::schema_alter::drop_constraint(&client, &tool_args).await,
        // actions::session_mgmt
        "cancel_query" => actions::session_mgmt::cancel_query(&client, &tool_args).await,
        "terminate_connection" => {
            actions::session_mgmt::terminate_connection(&client, &tool_args).await
        }
        "show_blocked_queries" => {
            actions::session_mgmt::show_blocked_queries(&client, &tool_args).await
        }
        // actions::ext_mgmt
        "list_extensions" => actions::ext_mgmt::list_extensions(&client, &tool_args).await,
        "create_extension" => actions::ext_mgmt::create_extension(&client, &tool_args).await,
        "drop_extension" => actions::ext_mgmt::drop_extension(&client, &tool_args).await,
        // actions::db_mgmt
        "list_databases" => actions::db_mgmt::list_databases(&client, &tool_args).await,
        "create_database" => actions::db_mgmt::create_database(&client, &tool_args).await,
        // actions::maint_ext
        "vacuum" => actions::maint_ext::vacuum(&client, &tool_args).await,
        "vacuum_full" => actions::maint_ext::vacuum_full(&client, &tool_args).await,
        "reindex_database" => actions::maint_ext::reindex_database(&client, &tool_args).await,
        // actions::migration_helpers
        "generate_create_table_ddl" => {
            actions::migration_helpers::generate_create_table_ddl(&client, &tool_args).await
        }
        "generate_create_index_ddl" => {
            actions::migration_helpers::generate_create_index_ddl(&client, &tool_args).await
        }
        "table_dependencies" => {
            actions::migration_helpers::table_dependencies(&client, &tool_args).await
        }
        // actions::pgvector
        "list_vector_columns" => actions::pgvector::list_vector_columns(&client, &tool_args).await,
        "vector_search" => actions::pgvector::vector_search(&client, &tool_args).await,
        "create_vector_index" => actions::pgvector::create_vector_index(&client, &tool_args).await,
        // actions::timescaledb
        "create_hypertable" => actions::timescaledb::create_hypertable(&client, &tool_args).await,
        "show_hypertable_details" => {
            actions::timescaledb::show_hypertable_details(&client, &tool_args).await
        }
        "show_chunks" => actions::timescaledb::show_chunks(&client, &tool_args).await,
        "add_retention_policy" => {
            actions::timescaledb::add_retention_policy(&client, &tool_args).await
        }
        "add_compression_policy" => {
            actions::timescaledb::add_compression_policy(&client, &tool_args).await
        }
        "compress_chunk" => actions::timescaledb::compress_chunk(&client, &tool_args).await,
        "add_continuous_aggregate" => {
            actions::timescaledb::add_continuous_aggregate(&client, &tool_args).await
        }
        // actions::pg_textsearch
        "list_bm25_indexes" => actions::pg_textsearch::list_bm25_indexes(&client, &tool_args).await,
        "search_bm25" => actions::pg_textsearch::search_bm25(&client, &tool_args).await,
        "create_bm25_index" => actions::pg_textsearch::create_bm25_index(&client, &tool_args).await,
        "drop_bm25_index" => actions::pg_textsearch::drop_bm25_index(&client, &tool_args).await,
        "bm25_force_merge" => actions::pg_textsearch::bm25_force_merge(&client, &tool_args).await,
        "bm25_index_stats" => actions::pg_textsearch::bm25_index_stats(&client, &tool_args).await,
        // actions::data_io (import_from_url was already gated above)
        "import_from_url" => actions::data_io::import_from_url(&client, &tool_args).await,
        "export_csv" => actions::data_io::export_csv(&client, &tool_args).await,
        // actions::index_advisor
        "suggest_indexes" => actions::index_advisor::suggest_indexes(&client, &tool_args).await,
        // actions::schema_health
        "find_tables_without_pk" => {
            actions::schema_health::find_tables_without_pk(&client, &tool_args).await
        }
        "find_missing_fk_indexes" => {
            actions::schema_health::find_missing_fk_indexes(&client, &tool_args).await
        }
        "analyze_table_bloat" => {
            actions::schema_health::analyze_table_bloat(&client, &tool_args).await
        }
        "clone_table_schema" => {
            actions::schema_health::clone_table_schema(&client, &tool_args).await
        }
        // actions::security_audit
        "security_audit" => actions::security_audit::security_audit(&client, &tool_args).await,
        "audit_role_usage" => actions::security_audit::audit_role_usage(&client, &tool_args).await,
        // actions::data_tools
        "sample_data" => actions::data_tools::sample_data(&client, &tool_args).await,
        // Listed by tool_exists but not wired up here.
        tool => Err(method_not_found(tool)),
    };

    if let Err(ref e) = result {
        error!("Tool '{}' error: {:?}", tool_name, e);
    }
    // Drop the acquired client explicitly before returning. Presumably this hands
    // it back to the pool; that depends on ConnectionPool's guard type.
    drop(client);
    result
}
662
663#[cold]
664fn method_not_found(name: &str) -> MCPError {
665 MCPError::MethodNotFound(name.to_string())
666}
667
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tools_list_splice_matches_generic() {
        // The tools/list fast path in `process_one_line` hand-splices the cached
        // result bytes into the envelope. It must produce exactly what the
        // generic `JsonRpcResponse` serialization would.
        let id = Value::Number(7.into());
        let result: Value = serde_json::from_slice(&TOOLS_LIST_RESPONSE).unwrap();
        let generic =
            serde_json::to_vec(&JsonRpcResponse::success(Some(id.clone()), result)).unwrap();

        let mut spliced = Vec::new();
        spliced.extend_from_slice(b"{\"jsonrpc\":\"2.0\",\"result\":");
        spliced.extend_from_slice(&TOOLS_LIST_RESPONSE);
        spliced.extend_from_slice(b",\"id\":");
        serde_json::to_writer(&mut spliced, &id).unwrap();
        spliced.extend_from_slice(b"}");

        assert_eq!(spliced, generic);
    }

    #[tokio::test]
    async fn test_read_line_capped_normal() {
        let data = b"hello world\nsecond line\n";
        let mut reader = BufReader::new(&data[..]);
        let mut line = String::new();

        let n = read_line_capped(&mut reader, &mut line, 1024)
            .await
            .unwrap();
        assert_eq!(n, "hello world\n".len());
        assert_eq!(line, "hello world\n");

        // The previous contents must be replaced, not appended to.
        let n = read_line_capped(&mut reader, &mut line, 1024)
            .await
            .unwrap();
        assert_eq!(n, "second line\n".len());
        assert_eq!(line, "second line\n");
    }

    #[tokio::test]
    async fn test_read_line_capped_eof() {
        let data = b"";
        let mut reader = BufReader::new(&data[..]);
        let mut line = String::new();
        let n = read_line_capped(&mut reader, &mut line, 1024)
            .await
            .unwrap();
        assert_eq!(n, 0);
    }

    #[tokio::test]
    async fn test_read_line_capped_partial_line_at_eof() {
        let data = b"no newline";
        let mut reader = BufReader::new(&data[..]);
        let mut line = String::new();
        let n = read_line_capped(&mut reader, &mut line, 1024)
            .await
            .unwrap();
        assert_eq!(n, data.len());
        assert_eq!(line, "no newline");
    }

    #[tokio::test]
    async fn test_read_line_capped_rejects_oversized() {
        let data = vec![b'a'; 5000];
        let mut reader = BufReader::new(&data[..]);
        let mut line = String::new();
        let err = read_line_capped(&mut reader, &mut line, 1024)
            .await
            .unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_maybe_shrink_buf() {
        let mut small: Vec<u8> = Vec::with_capacity(1024);
        small.extend_from_slice(b"keep");
        maybe_shrink_buf(&mut small);
        assert_eq!(small, b"keep".to_vec());

        let mut big: Vec<u8> = Vec::with_capacity(128 * 1024);
        big.extend_from_slice(b"drop");
        maybe_shrink_buf(&mut big);
        assert!(big.is_empty());
        assert!(big.capacity() < 128 * 1024);
    }

    #[test]
    fn test_parse_valid_request() {
        let line = r#"{"jsonrpc":"2.0","method":"initialize","id":1}"#;
        let req = parse_request(line).unwrap();
        assert_eq!(req.method, "initialize");
        assert_eq!(req.id, Some(Value::Number(1.into())));
    }

    #[test]
    fn test_parse_request_with_trailing_newline() {
        // Lines arrive from the reader with their terminator still attached.
        let line = "{\"jsonrpc\":\"2.0\",\"method\":\"tools/list\",\"id\":2}\n";
        let req = parse_request(line).unwrap();
        assert_eq!(req.method, "tools/list");
    }

    #[test]
    fn test_parse_request_with_whitespace() {
        let line = " {\"jsonrpc\":\"2.0\",\"method\":\"ping\",\"id\":3} ";
        let req = parse_request(line).unwrap();
        assert_eq!(req.method, "ping");
    }

    #[test]
    fn test_parse_empty_request() {
        let err = parse_request("").unwrap_err();
        assert_eq!(err, "Empty request");
    }

    #[test]
    fn test_parse_whitespace_only() {
        let err = parse_request(" \n ").unwrap_err();
        assert_eq!(err, "Empty request");
    }

    #[test]
    fn test_parse_invalid_json() {
        let err = parse_request("{invalid}").unwrap_err();
        assert!(
            !err.is_empty(),
            "Invalid JSON should produce an error message"
        );
    }

    #[test]
    fn test_parse_missing_method() {
        let err = parse_request(r#"{"jsonrpc":"2.0","id":1}"#).unwrap_err();
        assert!(err.contains("method"));
    }

    #[test]
    fn test_parse_wrong_version() {
        // The `jsonrpc` field is not validated at parse time.
        let req = parse_request(r#"{"jsonrpc":"1.0","method":"init","id":1}"#).unwrap();
        assert_eq!(req.jsonrpc, "1.0");
    }

    #[test]
    fn test_method_not_found_error() {
        let err = method_not_found("test_tool");
        assert_eq!(err.error_code(), -32601);
        assert!(err.to_string().contains("test_tool"));
    }

    #[test]
    fn test_tools_list_static() {
        let list: Value = serde_json::from_slice(&TOOLS_LIST_RESPONSE).unwrap();
        let tools = list.get("tools").and_then(|v| v.as_array());
        assert!(
            tools.is_some(),
            "TOOLS_LIST_RESPONSE should contain a tools array"
        );
        assert!(!tools.unwrap().is_empty(), "Tools list should not be empty");
    }

    #[test]
    fn test_process_request_method_dispatch() {
        // `process_request` needs a live ConnectionPool, so exercise the
        // pool-free handlers it dispatches to and the error its fallback arm
        // builds for an unknown method.
        let req = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            method: "nonexistent".to_string(),
            params: None,
            id: Some(Value::Number(1.into())),
        };

        assert_eq!(handle_ping().unwrap(), Value::Null);
        assert_eq!(
            handle_notification("notifications/initialized").unwrap(),
            Value::Null
        );

        let err = MCPError::MethodNotFound(req.method.clone());
        assert_eq!(err.error_code(), -32601);
        assert!(err.to_string().contains("nonexistent"));
    }

    #[test]
    fn test_handle_initialize_response() {
        let req = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            method: "initialize".to_string(),
            params: None,
            id: Some(Value::Number(1.into())),
        };
        let result = handle_initialize(&req).unwrap();
        assert_eq!(result["protocolVersion"], "2024-11-05");
        assert!(result["capabilities"]["tools"]["listChanged"].is_boolean());
        assert_eq!(result["serverInfo"]["version"], env!("CARGO_PKG_VERSION"));
    }

    /// Guards against session-level `SET` statements in the query/batch
    /// actions. A bare `SET` on a pooled connection leaks into whatever request
    /// reuses that connection next; `SET LOCAL` inside a transaction does not.
    #[test]
    fn test_no_bare_set_outside_transaction() {
        let source_files = [
            ("query.rs", include_str!("../src/actions/query.rs")),
            ("batch.rs", include_str!("../src/actions/batch.rs")),
        ];
        for &(name, source) in source_files.iter() {
            for (line_no, line) in source.lines().enumerate() {
                let trimmed = line.trim();
                // Skip comment lines.
                if trimmed.starts_with("//")
                    || trimmed.starts_with("/*")
                    || trimmed.starts_with('*')
                {
                    continue;
                }
                // `UPDATE ... SET ...` is DML, not a session setting.
                if trimmed.contains("UPDATE ") && trimmed.contains("SET ") {
                    continue;
                }
                if trimmed.contains("SET LOCAL") {
                    continue;
                }
                if trimmed.contains("client.execute(\"SET ") {
                    panic!(
                        "Phase 1.5 violation: bare `SET` (not SET LOCAL) found in {}:{} — \
                         use BEGIN + SET LOCAL + COMMIT pattern to avoid session leakage.\n\
                         Line: {}",
                        name,
                        line_no + 1,
                        trimmed
                    );
                }
            }
        }
    }
}