1use tokio::net::{TcpListener, TcpStream};
2use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
3use serde_json::{json, Value};
4use tracing::{error, warn};
5use std::sync::Arc;
6
7use crate::config::Config;
8use crate::errors::{MCPError, Result as MCPResult};
9use crate::metrics;
10use crate::pool::ConnectionPool;
11use crate::protocol::{JsonRpcRequest, JsonRpcResponse};
12use crate::actions;
13use once_cell::sync::Lazy;
14
15static TOOLS_LIST: Lazy<Value> = Lazy::new(|| {
16 let tools_json = include_str!("../tools.json");
17 let tools: Vec<Value> = serde_json::from_str(tools_json)
18 .expect("Failed to parse tools.json");
19 json!({ "tools": tools })
20});
21
/// Capacity, in bytes, of the `BufReader` wrapped around each input stream.
const BUFFER_CAPACITY: usize = 4 * 1024;
/// Terminator appended after every serialized response.
const NEWLINE: &[u8] = &[b'\n'];
24
25#[inline]
26#[cold]
27fn parse_error(msg: String) -> JsonRpcResponse {
28 let mcp_error = MCPError::ParseError(msg);
29 JsonRpcResponse::error(None, mcp_error.error_code(), mcp_error.to_string())
30}
31
32#[inline]
33fn parse_request(line: &str) -> Result<JsonRpcRequest, String> {
34 let trimmed = line.trim();
35 if trimmed.is_empty() {
36 return Err("Empty request".to_string());
37 }
38 serde_json::from_str::<JsonRpcRequest>(trimmed)
39 .map_err(|e| e.to_string())
40}
41
42pub struct MCPServer {
43 config: Config,
44 pool: Arc<ConnectionPool>,
45}
46
47impl MCPServer {
48 pub fn new(config: Config, pool: Arc<ConnectionPool>) -> Self {
49 Self { config, pool }
50 }
51
52 pub async fn run_stdio(&self) -> MCPResult<()> {
54 let stdin = tokio::io::stdin();
55 let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, stdin);
56 let mut stdout = tokio::io::stdout();
57 let mut line = String::with_capacity(512);
58 let mut response_buf = Vec::with_capacity(65536);
59
60 loop {
61 line.clear();
62 match reader.read_line(&mut line).await {
63 Ok(0) => break,
64 Ok(_) => {
65 process_one_line(&line, &self.pool, &self.config, &mut response_buf, &mut stdout).await?;
66 }
67 Err(e) => {
68 error!("IO error: {}", e);
69 break;
70 }
71 }
72 }
73 Ok(())
74 }
75
76 pub async fn run(&self) -> MCPResult<()> {
77 let addr = format!("{}:{}", self.config.server.host, self.config.server.port);
78 let listener = TcpListener::bind(&addr).await?;
79
80 tracing::info!("MCP server listening on {}", addr);
81
82 loop {
83 let (socket, peer_addr) = listener.accept().await?;
84
85 if let Err(e) = socket.set_nodelay(true) {
86 warn!("Failed to set TCP_NODELAY: {}", e);
87 }
88
89 let pool = Arc::clone(&self.pool);
90 let config = self.config.clone();
91
92 tokio::spawn(async move {
93 if let Err(e) = handle_client(socket, pool, config).await {
94 error!("Client {} error: {}", peer_addr, e);
95 }
96 });
97 }
98 }
99}
100
101#[inline(never)]
102async fn handle_client(socket: TcpStream, pool: Arc<ConnectionPool>, config: Config) -> MCPResult<()> {
103 let (reader, mut writer) = socket.into_split();
104 let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, reader);
105 let mut line = String::with_capacity(512);
106 let mut response_buf = Vec::with_capacity(65536);
107
108 loop {
109 line.clear();
110 match reader.read_line(&mut line).await {
111 Ok(0) => break,
112 Ok(_) => {
113 process_one_line(&line, &pool, &config, &mut response_buf, &mut writer).await?;
114 }
115 Err(e) => {
116 error!("IO error: {}", e);
117 break;
118 }
119 }
120 }
121
122 Ok(())
123}
124
125#[inline]
127async fn process_one_line<W: AsyncWriteExt + Unpin>(
128 line: &str,
129 pool: &Arc<ConnectionPool>,
130 config: &Config,
131 response_buf: &mut Vec<u8>,
132 writer: &mut W,
133) -> MCPResult<()> {
134 metrics::inc_requests();
135
136 let response = match parse_request(line) {
137 Ok(req) => match process_request(&req, pool, config).await {
138 Ok(result) => JsonRpcResponse::success(req.id, result),
139 Err(e) => {
140 metrics::inc_errors();
141 JsonRpcResponse::error(req.id, e.error_code(), e.to_string())
142 }
143 },
144 Err(e) => {
145 metrics::inc_errors();
146 parse_error(e)
147 }
148 };
149
150 response_buf.clear();
151 serde_json::to_writer(&mut *response_buf, &response)?;
152 response_buf.extend_from_slice(NEWLINE);
153
154 writer.write_all(response_buf).await?;
155 writer.flush().await?;
156 Ok(())
157}
158
159#[inline]
161pub async fn process_request(
162 req: &JsonRpcRequest,
163 pool: &Arc<ConnectionPool>,
164 config: &Config,
165) -> MCPResult<Value> {
166 match req.method.as_str() {
167 "initialize" => handle_initialize(req),
168 "tools/list" => handle_tools_list(),
169 "tools/call" => handle_tools_call(req, pool, config).await,
170 _ => Err(MCPError::MethodNotFound(req.method.clone())),
171 }
172}
173
174pub async fn process_request_http(
176 req: &JsonRpcRequest,
177 pool: &Arc<ConnectionPool>,
178 config: &Config,
179) -> JsonRpcResponse {
180 metrics::inc_requests();
181
182 let response = match process_request(req, pool, config).await {
183 Ok(result) => JsonRpcResponse::success(req.id.clone(), result),
184 Err(e) => {
185 metrics::inc_errors();
186 JsonRpcResponse::error(req.id.clone(), e.error_code(), e.to_string())
187 }
188 };
189
190 response
191}
192
193#[inline]
194fn handle_initialize(_req: &JsonRpcRequest) -> MCPResult<Value> {
195 Ok(json!({
196 "protocolVersion": "2024-11-05",
197 "capabilities": {
198 "tools": {
199 "listChanged": false
200 },
201 "resources": {
202 "subscribe": false,
203 "listChanged": false
204 },
205 "prompts": {
206 "listChanged": false
207 }
208 },
209 "serverInfo": {
210 "name": "mcp-postgres",
211 "version": env!("CARGO_PKG_VERSION")
212 }
213 }))
214}
215
216#[inline]
217fn handle_tools_list() -> MCPResult<Value> {
218 Ok((*TOOLS_LIST).clone())
219}
220
221async fn handle_tools_call(
222 req: &JsonRpcRequest,
223 pool: &Arc<ConnectionPool>,
224 config: &Config,
225) -> MCPResult<Value> {
226 let tool_name = req
227 .params
228 .as_ref()
229 .and_then(|p| p.get("name").and_then(|v| v.as_str()))
230 .ok_or_else(|| MCPError::InvalidParams("Missing 'name' parameter".into()))?;
231
232 let tool_args = req.params.as_ref().and_then(|p| p.get("arguments"));
233
234 let write_tools: &[&str] = &[
236 "execute_insert", "execute_update", "execute_delete",
237 "async_execute_insert", "async_execute_update", "async_execute_delete",
238 "async_batch_insert", "async_batch_update", "async_batch_delete", "async_batch_insert_copy",
239 "create_table", "drop_table", "create_view", "drop_view", "alter_view", "create_schema", "drop_schema", "create_sequence", "drop_sequence", "alter_index", "backup_table", "create_index", "drop_index", "create_partition", "drop_partition",
240 "vacuum_analyze", "analyze_table", "reindex_table",
241 "reset_statistics", "truncate_table",
242 ];
243
244 if config.server.access_mode == crate::config::AccessMode::Restricted
245 && write_tools.contains(&tool_name)
246 {
247 return Err(MCPError::InvalidParams(format!(
248 "Operation '{tool_name}' is not allowed in restricted (read-only) mode"
249 )));
250 }
251
252 let no_db_tools: &[&str] = &["list_tables", "list_schemas", "show_constraints"];
254 if !no_db_tools.contains(&tool_name) {
255 let tool_exists = matches!(tool_name,
257 "describe_table" | "list_triggers" | "list_indexes" | "execute_query" | "execute_insert"
258 | "execute_update" | "execute_delete" | "explain_query"
259 | "async_execute_insert" | "async_execute_update" | "async_execute_delete"
260 | "async_batch_insert" | "async_batch_update" | "async_batch_delete" | "async_batch_insert_copy"
261 | "create_table" | "drop_table" | "create_view" | "drop_view" | "alter_view" | "create_schema" | "drop_schema" | "create_sequence" | "drop_sequence" | "alter_index" | "create_index" | "list_partitions" | "drop_index" | "create_partition" | "drop_partition"
262 | "get_table_stats" | "get_index_stats" | "show_database_size"
263 | "show_table_size" | "get_cache_hit_ratio"
264 | "list_connections" | "show_current_user"
265 | "show_running_queries" | "show_connection_summary"
266 | "vacuum_analyze" | "analyze_table" | "reindex_table"
267 | "get_pg_stat_statements" | "reset_statistics" | "truncate_table"
268 | "list_users" | "list_user_privileges" | "list_role_memberships"
269 | "list_database_privileges" | "show_session_info"
270 | "show_all_settings" | "get_setting" | "show_memory_settings"
271 | "show_performance_settings" | "show_log_settings"
272 | "show_replication_status" | "list_replication_slots"
273 | "list_standby_servers" | "show_wal_info" | "show_base_backup_progress"
274 | "show_active_transactions" | "show_locks" | "show_waiting_locks"
275 | "show_transaction_isolation" | "show_deadlocks"
276 | "show_autocommit_status" | "show_transaction_timeout"
277 | "analyze_db_health" | "list_unused_indexes" | "list_duplicate_indexes"
278 | "show_vacuum_progress" | "get_object_details"
279 );
280 if !tool_exists {
281 return Err(method_not_found(tool_name));
282 }
283 }
284
285 let client = pool.acquire().await?;
287
288 let result = match tool_name {
289 "list_tables" => actions::schema::list_tables(&client, &tool_args).await,
291 "describe_table" => actions::schema::describe_table(&client, &tool_args).await,
292 "list_indexes" => actions::schema::list_indexes(&client, &tool_args).await,
293 "list_schemas" => actions::schema::list_schemas(&client, &tool_args).await,
294 "show_constraints" => actions::schema::show_constraints(&client, &tool_args).await,
295 "list_triggers" => actions::schema::list_triggers(&client, &tool_args).await,
296 "create_table" => actions::schema::create_table(&client, &tool_args).await,
297 "drop_table" => actions::schema::drop_table(&client, &tool_args).await,
298 "create_view" => actions::schema::create_view(&client, &tool_args).await,
299 "drop_view" => actions::schema::drop_view(&client, &tool_args).await,
300 "alter_view" => actions::schema::alter_view(&client, &tool_args).await,
301 "create_schema" => actions::schema::create_schema(&client, &tool_args).await,
302 "drop_schema" => actions::schema::drop_schema(&client, &tool_args).await,
303 "create_sequence" => actions::schema::create_sequence(&client, &tool_args).await,
304 "drop_sequence" => actions::schema::drop_sequence(&client, &tool_args).await,
305 "alter_index" => actions::schema::alter_index(&client, &tool_args).await,
306 "list_partitions" => actions::schema::list_partitions(&client, &tool_args).await,
307 "backup_table" => actions::schema::backup_table(&client, &tool_args).await,
308 "create_index" => actions::schema::create_index(&client, &tool_args).await,
309 "drop_index" => actions::schema::drop_index(&client, &tool_args).await,
310 "create_partition" => actions::schema::create_partition(&client, &tool_args).await,
311 "drop_partition" => actions::schema::drop_partition(&client, &tool_args).await,
312 "execute_query" => actions::query::execute_query(&client, &tool_args).await,
314 "execute_insert" => actions::query::execute_insert(&client, &tool_args).await,
315 "execute_update" => actions::query::execute_update(&client, &tool_args).await,
316 "execute_delete" => actions::query::execute_delete(&client, &tool_args).await,
317 "async_execute_insert" => actions::query::async_execute_insert(&client, &tool_args).await,
318 "async_execute_update" => actions::query::async_execute_update(&client, &tool_args).await,
319 "async_execute_delete" => actions::query::async_execute_delete(&client, &tool_args).await,
320 "explain_query" => actions::query::explain_query(&client, &tool_args).await,
321 "async_batch_insert" => actions::batch::async_batch_insert(&client, &tool_args).await,
323 "async_batch_update" => actions::batch::async_batch_update(&client, &tool_args).await,
324 "async_batch_delete" => actions::batch::async_batch_delete(&client, &tool_args).await,
325 "async_batch_insert_copy" => actions::batch::async_batch_insert_copy(&client, &tool_args).await,
326 "get_table_stats" => actions::monitoring::get_table_stats(&client, &tool_args).await,
328 "get_index_stats" => actions::monitoring::get_index_stats(&client, &tool_args).await,
329 "show_database_size" => actions::monitoring::show_database_size(&client, &tool_args).await,
330 "show_table_size" => actions::monitoring::show_table_size(&client, &tool_args).await,
331 "get_cache_hit_ratio" => actions::monitoring::get_cache_hit_ratio(&client, &tool_args).await,
332 "list_connections" => actions::connections::list_connections(&client, &tool_args).await,
334 "show_current_user" => actions::connections::show_current_user(&client, &tool_args).await,
335 "show_running_queries" => actions::connections::show_running_queries(&client, &tool_args).await,
336 "show_connection_summary" => actions::connections::show_connection_summary(&client, &tool_args).await,
337 "vacuum_analyze" => actions::maintenance::vacuum_analyze(&client, &tool_args).await,
339 "analyze_table" => actions::maintenance::analyze_table(&client, &tool_args).await,
340 "reindex_table" => actions::maintenance::reindex_table(&client, &tool_args).await,
341 "get_pg_stat_statements" => actions::maintenance::get_pg_stat_statements(&client, &tool_args).await,
342 "reset_statistics" => actions::maintenance::reset_statistics(&client, &tool_args).await,
343 "truncate_table" => actions::maintenance::truncate_table(&client, &tool_args).await,
344 "list_users" => actions::security::list_users(&client, &tool_args).await,
346 "list_user_privileges" => actions::security::list_user_privileges(&client, &tool_args).await,
347 "list_role_memberships" => actions::security::list_role_memberships(&client, &tool_args).await,
348 "list_database_privileges" => actions::security::list_database_privileges(&client, &tool_args).await,
349 "show_session_info" => actions::security::show_session_info(&client, &tool_args).await,
350 "show_all_settings" => actions::config::show_all_settings(&client, &tool_args).await,
352 "get_setting" => actions::config::get_setting(&client, &tool_args).await,
353 "show_memory_settings" => actions::config::show_memory_settings(&client, &tool_args).await,
354 "show_performance_settings" => actions::config::show_performance_settings(&client, &tool_args).await,
355 "show_log_settings" => actions::config::show_log_settings(&client, &tool_args).await,
356 "show_replication_status" => actions::replication::show_replication_status(&client, &tool_args).await,
358 "list_replication_slots" => actions::replication::list_replication_slots(&client, &tool_args).await,
359 "list_standby_servers" => actions::replication::list_standby_servers(&client, &tool_args).await,
360 "show_wal_info" => actions::replication::show_wal_info(&client, &tool_args).await,
361 "show_base_backup_progress" => actions::replication::show_base_backup_progress(&client, &tool_args).await,
362 "show_active_transactions" => actions::transactions::show_active_transactions(&client, &tool_args).await,
364 "show_locks" => actions::transactions::show_locks(&client, &tool_args).await,
365 "show_waiting_locks" => actions::transactions::show_waiting_locks(&client, &tool_args).await,
366 "show_transaction_isolation" => actions::transactions::show_transaction_isolation(&client, &tool_args).await,
367 "show_deadlocks" => actions::transactions::show_deadlocks(&client, &tool_args).await,
368 "show_autocommit_status" => actions::transactions::show_autocommit_status(&client, &tool_args).await,
369 "show_transaction_timeout" => actions::transactions::show_transaction_timeout(&client, &tool_args).await,
370 "analyze_db_health" => actions::health::analyze_db_health(&client, &tool_args).await,
372 "list_unused_indexes" => actions::health::list_unused_indexes(&client, &tool_args).await,
373 "list_duplicate_indexes" => actions::health::list_duplicate_indexes(&client, &tool_args).await,
374 "show_vacuum_progress" => actions::health::show_vacuum_progress(&client, &tool_args).await,
375 "get_object_details" => actions::schema::get_object_details(&client, &tool_args).await,
377 tool => Err(method_not_found(tool)),
378 };
379
380 pool.release(client);
381 result
382}
383
384#[cold]
385fn method_not_found(name: &str) -> MCPError {
386 MCPError::MethodNotFound(name.to_string())
387}
388
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_valid_request() {
        let line = r#"{"jsonrpc":"2.0","method":"initialize","id":1}"#;
        let req = parse_request(line).unwrap();
        assert_eq!(req.method, "initialize");
        assert_eq!(req.id, Some(Value::Number(1.into())));
    }

    #[test]
    fn test_parse_request_with_trailing_newline() {
        // `read_line` hands the parser the line together with its terminator,
        // so the input here must actually end with one.
        let line = "{\"jsonrpc\":\"2.0\",\"method\":\"tools/list\",\"id\":2}\n";
        let req = parse_request(line).unwrap();
        assert_eq!(req.method, "tools/list");

        let crlf_line = "{\"jsonrpc\":\"2.0\",\"method\":\"tools/list\",\"id\":2}\r\n";
        let req = parse_request(crlf_line).unwrap();
        assert_eq!(req.method, "tools/list");
    }

    #[test]
    fn test_parse_request_with_whitespace() {
        let line = "  {\"jsonrpc\":\"2.0\",\"method\":\"ping\",\"id\":3}  ";
        let req = parse_request(line).unwrap();
        assert_eq!(req.method, "ping");
    }

    #[test]
    fn test_parse_empty_request() {
        let err = parse_request("").unwrap_err();
        assert_eq!(err, "Empty request");
    }

    #[test]
    fn test_parse_whitespace_only() {
        let err = parse_request("  \n  ").unwrap_err();
        assert_eq!(err, "Empty request");
    }

    #[test]
    fn test_parse_invalid_json() {
        let err = parse_request("{invalid}").unwrap_err();
        assert!(!err.is_empty(), "Invalid JSON should produce an error message");
    }

    #[test]
    fn test_parse_missing_method() {
        let err = parse_request(r#"{"jsonrpc":"2.0","id":1}"#).unwrap_err();
        assert!(err.contains("method"));
    }

    #[test]
    fn test_parse_wrong_version() {
        // The parser does not validate the version string; it is passed through.
        let req = parse_request(r#"{"jsonrpc":"1.0","method":"init","id":1}"#).unwrap();
        assert_eq!(req.jsonrpc, "1.0");
    }

    #[test]
    fn test_method_not_found_error() {
        let err = method_not_found("test_tool");
        assert_eq!(err.error_code(), -32601);
        assert!(err.to_string().contains("test_tool"));
    }

    #[test]
    fn test_tools_list_static() {
        let list = &*TOOLS_LIST;
        let tools = list.get("tools").and_then(|v| v.as_array());
        assert!(tools.is_some(), "TOOLS_LIST should contain a tools array");
        assert!(!tools.unwrap().is_empty(), "Tools list should not be empty");
    }

    #[test]
    fn test_process_request_method_dispatch() {
        // Only checks that a request can be constructed: `process_request`
        // needs a connection pool and a config, neither of which is built here.
        let _req = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            method: "nonexistent".to_string(),
            params: None,
            id: Some(Value::Number(1.into())),
        };
    }

    #[test]
    fn test_handle_initialize_response() {
        let req = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            method: "initialize".to_string(),
            params: None,
            id: Some(Value::Number(1.into())),
        };
        let result = handle_initialize(&req).unwrap();
        assert_eq!(result["protocolVersion"], "2024-11-05");
        assert!(result["capabilities"]["tools"]["listChanged"].is_boolean());
        assert_eq!(result["serverInfo"]["version"], env!("CARGO_PKG_VERSION"));
    }
}