1use tokio::net::{TcpListener, TcpStream};
2use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
3use serde_json::{json, Value};
4use tracing::{error, warn};
5use std::sync::Arc;
6
7use crate::config::Config;
8use crate::errors::{MCPError, Result as MCPResult};
9use crate::metrics;
10use crate::pool::ConnectionPool;
11use crate::protocol::{JsonRpcRequest, JsonRpcResponse};
12use crate::actions;
13use once_cell::sync::Lazy;
14
15static TOOLS_LIST: Lazy<Value> = Lazy::new(|| {
16 let tools_json = include_str!("../tools.json");
17 let tools: Vec<Value> = serde_json::from_str(tools_json)
18 .expect("Failed to parse tools.json");
19 json!({ "tools": tools })
20});
21
22const BUFFER_CAPACITY: usize = 16384;
23const SOCKET_BUFFER_SIZE: libc::c_int = 4 * 1024; const NEWLINE: &[u8] = b"\n";
25
26#[inline]
27#[cold]
28fn parse_error(msg: String) -> JsonRpcResponse {
29 let mcp_error = MCPError::ParseError(msg);
30 JsonRpcResponse::error(None, mcp_error.error_code(), mcp_error.to_string())
31}
32
33#[inline]
34fn parse_request(line: &str) -> Result<JsonRpcRequest, String> {
35 let trimmed = line.trim();
36 if trimmed.is_empty() {
37 return Err("Empty request".to_string());
38 }
39 serde_json::from_str::<JsonRpcRequest>(trimmed)
40 .map_err(|e| e.to_string())
41}
42
/// MCP server state shared by the stdio and TCP transports.
pub struct MCPServer {
    /// Server configuration; supplies the listen host/port and is handed to
    /// request processing.
    config: Config,
    /// Connection pool handle; cloned (`Arc::clone`) into every spawned client
    /// task.
    pool: Arc<ConnectionPool>,
}
47
48impl MCPServer {
49 pub fn new(config: Config, pool: Arc<ConnectionPool>) -> Self {
50 Self { config, pool }
51 }
52
53 pub async fn run_stdio(&self) -> MCPResult<()> {
55 let stdin = tokio::io::stdin();
56 let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, stdin);
57 let mut stdout = tokio::io::stdout();
58 let mut line = String::with_capacity(512);
59 let mut response_buf = Vec::with_capacity(65536);
60
61 loop {
62 line.clear();
63 match reader.read_line(&mut line).await {
64 Ok(0) => break,
65 Ok(_) => {
66 process_one_line(&line, &self.pool, &self.config, &mut response_buf, &mut stdout).await?;
67 }
68 Err(e) => {
69 error!("IO error: {}", e);
70 break;
71 }
72 }
73 }
74 Ok(())
75 }
76
77 pub async fn run(&self) -> MCPResult<()> {
78 let addr = format!("{}:{}", self.config.server.host, self.config.server.port);
79 let listener = TcpListener::bind(&addr).await?;
80
81 tracing::info!("MCP server listening on {}", addr);
82
83 loop {
84 let (socket, peer_addr) = listener.accept().await?;
85
86 if let Err(e) = socket.set_nodelay(true) {
87 warn!("Failed to set TCP_NODELAY: {}", e);
88 }
89 use std::os::unix::io::AsRawFd;
91 let raw = socket.as_raw_fd();
92 let on: libc::c_int = 1;
93 unsafe {
94 libc::setsockopt(raw, libc::SOL_SOCKET, libc::SO_KEEPALIVE, &on as *const _ as *const libc::c_void, std::mem::size_of_val(&on) as libc::socklen_t);
95 libc::setsockopt(raw, libc::SOL_SOCKET, libc::SO_RCVBUF, &SOCKET_BUFFER_SIZE as *const _ as *const libc::c_void, std::mem::size_of_val(&SOCKET_BUFFER_SIZE) as libc::socklen_t);
96 libc::setsockopt(raw, libc::SOL_SOCKET, libc::SO_SNDBUF, &SOCKET_BUFFER_SIZE as *const _ as *const libc::c_void, std::mem::size_of_val(&SOCKET_BUFFER_SIZE) as libc::socklen_t);
97 }
98
99 let pool = Arc::clone(&self.pool);
100 let config = self.config.clone();
101
102 tokio::spawn(async move {
103 if let Err(e) = handle_client(socket, pool, config).await {
104 error!("Client {} error: {}", peer_addr, e);
105 }
106 });
107 }
108 }
109}
110
111#[inline(never)]
112async fn handle_client(socket: TcpStream, pool: Arc<ConnectionPool>, config: Config) -> MCPResult<()> {
113 let (reader, mut writer) = socket.into_split();
114 let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, reader);
115 let mut line = String::with_capacity(512);
116 let mut response_buf = Vec::with_capacity(65536);
117
118 loop {
119 line.clear();
120 match reader.read_line(&mut line).await {
121 Ok(0) => break,
122 Ok(_) => {
123 process_one_line(&line, &pool, &config, &mut response_buf, &mut writer).await?;
124 }
125 Err(e) => {
126 error!("IO error: {}", e);
127 break;
128 }
129 }
130 }
131
132 Ok(())
133}
134
135#[inline]
137async fn process_one_line<W: AsyncWriteExt + Unpin>(
138 line: &str,
139 pool: &Arc<ConnectionPool>,
140 config: &Config,
141 response_buf: &mut Vec<u8>,
142 writer: &mut W,
143) -> MCPResult<()> {
144 metrics::inc_requests();
145
146 let response = match parse_request(line) {
147 Ok(req) => match process_request(&req, pool, config).await {
148 Ok(result) => JsonRpcResponse::success(req.id, result),
149 Err(e) => {
150 metrics::inc_errors();
151 JsonRpcResponse::error(req.id, e.error_code(), e.to_string())
152 }
153 },
154 Err(e) => {
155 metrics::inc_errors();
156 parse_error(e)
157 }
158 };
159
160 response_buf.clear();
161 serde_json::to_writer(&mut *response_buf, &response)?;
162 response_buf.extend_from_slice(NEWLINE);
163
164 writer.write_all(response_buf).await?;
165 writer.flush().await?;
166 Ok(())
167}
168
169#[inline]
170async fn process_request(
171 req: &JsonRpcRequest,
172 pool: &Arc<ConnectionPool>,
173 config: &Config,
174) -> MCPResult<Value> {
175 match req.method.as_str() {
176 "initialize" => handle_initialize(req),
177 "tools/list" => handle_tools_list(),
178 "tools/call" => handle_tools_call(req, pool, config).await,
179 _ => Err(MCPError::MethodNotFound(req.method.clone())),
180 }
181}
182
183#[inline]
184fn handle_initialize(_req: &JsonRpcRequest) -> MCPResult<Value> {
185 Ok(json!({
186 "protocolVersion": "2024-11-05",
187 "capabilities": {
188 "tools": {
189 "listChanged": false
190 },
191 "resources": {
192 "subscribe": false,
193 "listChanged": false
194 },
195 "prompts": {
196 "listChanged": false
197 }
198 },
199 "serverInfo": {
200 "name": "mcp-postgres",
201 "version": env!("CARGO_PKG_VERSION")
202 }
203 }))
204}
205
206#[inline]
207fn handle_tools_list() -> MCPResult<Value> {
208 Ok((*TOOLS_LIST).clone())
209}
210
211async fn handle_tools_call(
212 req: &JsonRpcRequest,
213 pool: &Arc<ConnectionPool>,
214 config: &Config,
215) -> MCPResult<Value> {
216 let tool_name = req
217 .params
218 .as_ref()
219 .and_then(|p| p.get("name").and_then(|v| v.as_str()))
220 .ok_or_else(|| MCPError::InvalidParams("Missing 'name' parameter".into()))?;
221
222 let tool_args = req.params.as_ref().and_then(|p| p.get("arguments").cloned());
223
224 let write_tools: &[&str] = &[
226 "execute_insert", "execute_update", "execute_delete",
227 "batch_insert", "batch_update", "batch_delete", "batch_insert_copy",
228 "vacuum_analyze", "analyze_table", "reindex_table",
229 "reset_statistics", "kill_connection",
230 "begin_transaction", "commit_transaction", "rollback_transaction",
231 ];
232
233 if config.server.access_mode == crate::config::AccessMode::Restricted
234 && write_tools.contains(&tool_name)
235 {
236 return Err(MCPError::InvalidParams(format!(
237 "Operation '{tool_name}' is not allowed in restricted (read-only) mode"
238 )));
239 }
240
241 let no_db_tools: &[&str] = &["list_tables", "list_schemas", "show_constraints"];
243 if !no_db_tools.contains(&tool_name) {
244 let tool_exists = matches!(tool_name,
246 "describe_table" | "list_indexes" | "execute_query" | "execute_insert"
247 | "execute_update" | "execute_delete" | "explain_query"
248 | "batch_insert" | "batch_update" | "batch_delete" | "batch_insert_copy"
249 | "get_table_stats" | "get_index_stats" | "show_database_size"
250 | "show_table_size" | "get_cache_hit_ratio"
251 | "list_connections" | "kill_connection" | "show_current_user"
252 | "show_running_queries" | "show_connection_summary"
253 | "vacuum_analyze" | "analyze_table" | "reindex_table"
254 | "get_pg_stat_statements" | "reset_statistics"
255 | "list_users" | "list_user_privileges" | "list_role_memberships"
256 | "list_database_privileges" | "show_session_info"
257 | "show_all_settings" | "get_setting" | "show_memory_settings"
258 | "show_performance_settings" | "show_log_settings"
259 | "show_replication_status" | "list_replication_slots"
260 | "list_standby_servers" | "show_wal_info" | "show_base_backup_progress"
261 | "show_active_transactions" | "show_locks" | "show_waiting_locks"
262 | "begin_transaction" | "commit_transaction" | "rollback_transaction"
263 | "show_transaction_isolation" | "show_deadlocks"
264 | "show_autocommit_status" | "show_transaction_timeout"
265 | "analyze_db_health" | "list_unused_indexes" | "list_duplicate_indexes"
266 | "show_vacuum_progress" | "get_object_details"
267 );
268 if !tool_exists {
269 return Err(method_not_found(tool_name));
270 }
271 }
272
273 let client = pool.acquire().await?;
275
276 let result = match tool_name {
277 "list_tables" => actions::schema::list_tables(&client, &tool_args).await,
279 "describe_table" => actions::schema::describe_table(&client, &tool_args).await,
280 "list_indexes" => actions::schema::list_indexes(&client, &tool_args).await,
281 "list_schemas" => actions::schema::list_schemas(&client, &tool_args).await,
282 "show_constraints" => actions::schema::show_constraints(&client, &tool_args).await,
283 "execute_query" => actions::query::execute_query(&client, &tool_args).await,
285 "execute_insert" => actions::query::execute_insert(&client, &tool_args).await,
286 "execute_update" => actions::query::execute_update(&client, &tool_args).await,
287 "execute_delete" => actions::query::execute_delete(&client, &tool_args).await,
288 "explain_query" => actions::query::explain_query(&client, &tool_args).await,
289 "batch_insert" => actions::batch::batch_insert(&client, &tool_args).await,
291 "batch_update" => actions::batch::batch_update(&client, &tool_args).await,
292 "batch_delete" => actions::batch::batch_delete(&client, &tool_args).await,
293 "batch_insert_copy" => actions::batch::batch_insert_copy(&client, &tool_args).await,
294 "get_table_stats" => actions::monitoring::get_table_stats(&client, &tool_args).await,
296 "get_index_stats" => actions::monitoring::get_index_stats(&client, &tool_args).await,
297 "show_database_size" => actions::monitoring::show_database_size(&client, &tool_args).await,
298 "show_table_size" => actions::monitoring::show_table_size(&client, &tool_args).await,
299 "get_cache_hit_ratio" => actions::monitoring::get_cache_hit_ratio(&client, &tool_args).await,
300 "list_connections" => actions::connections::list_connections(&client, &tool_args).await,
302 "kill_connection" => actions::connections::kill_connection(&client, &tool_args).await,
303 "show_current_user" => actions::connections::show_current_user(&client, &tool_args).await,
304 "show_running_queries" => actions::connections::show_running_queries(&client, &tool_args).await,
305 "show_connection_summary" => actions::connections::show_connection_summary(&client, &tool_args).await,
306 "vacuum_analyze" => actions::maintenance::vacuum_analyze(&client, &tool_args).await,
308 "analyze_table" => actions::maintenance::analyze_table(&client, &tool_args).await,
309 "reindex_table" => actions::maintenance::reindex_table(&client, &tool_args).await,
310 "get_pg_stat_statements" => actions::maintenance::get_pg_stat_statements(&client, &tool_args).await,
311 "reset_statistics" => actions::maintenance::reset_statistics(&client, &tool_args).await,
312 "list_users" => actions::security::list_users(&client, &tool_args).await,
314 "list_user_privileges" => actions::security::list_user_privileges(&client, &tool_args).await,
315 "list_role_memberships" => actions::security::list_role_memberships(&client, &tool_args).await,
316 "list_database_privileges" => actions::security::list_database_privileges(&client, &tool_args).await,
317 "show_session_info" => actions::security::show_session_info(&client, &tool_args).await,
318 "show_all_settings" => actions::config::show_all_settings(&client, &tool_args).await,
320 "get_setting" => actions::config::get_setting(&client, &tool_args).await,
321 "show_memory_settings" => actions::config::show_memory_settings(&client, &tool_args).await,
322 "show_performance_settings" => actions::config::show_performance_settings(&client, &tool_args).await,
323 "show_log_settings" => actions::config::show_log_settings(&client, &tool_args).await,
324 "show_replication_status" => actions::replication::show_replication_status(&client, &tool_args).await,
326 "list_replication_slots" => actions::replication::list_replication_slots(&client, &tool_args).await,
327 "list_standby_servers" => actions::replication::list_standby_servers(&client, &tool_args).await,
328 "show_wal_info" => actions::replication::show_wal_info(&client, &tool_args).await,
329 "show_base_backup_progress" => actions::replication::show_base_backup_progress(&client, &tool_args).await,
330 "show_active_transactions" => actions::transactions::show_active_transactions(&client, &tool_args).await,
332 "show_locks" => actions::transactions::show_locks(&client, &tool_args).await,
333 "show_waiting_locks" => actions::transactions::show_waiting_locks(&client, &tool_args).await,
334 "begin_transaction" => actions::transactions::begin_transaction(&client, &tool_args).await,
335 "commit_transaction" => actions::transactions::commit_transaction(&client, &tool_args).await,
336 "rollback_transaction" => actions::transactions::rollback_transaction(&client, &tool_args).await,
337 "show_transaction_isolation" => actions::transactions::show_transaction_isolation(&client, &tool_args).await,
338 "show_deadlocks" => actions::transactions::show_deadlocks(&client, &tool_args).await,
339 "show_autocommit_status" => actions::transactions::show_autocommit_status(&client, &tool_args).await,
340 "show_transaction_timeout" => actions::transactions::show_transaction_timeout(&client, &tool_args).await,
341 "analyze_db_health" => actions::health::analyze_db_health(&client, &tool_args).await,
343 "list_unused_indexes" => actions::health::list_unused_indexes(&client, &tool_args).await,
344 "list_duplicate_indexes" => actions::health::list_duplicate_indexes(&client, &tool_args).await,
345 "show_vacuum_progress" => actions::health::show_vacuum_progress(&client, &tool_args).await,
346 "get_object_details" => actions::schema::get_object_details(&client, &tool_args).await,
348 tool => Err(method_not_found(tool)),
349 };
350
351 pool.release(client);
352 result
353}
354
355#[cold]
356fn method_not_found(name: &str) -> MCPError {
357 MCPError::MethodNotFound(name.to_string())
358}
359
#[cfg(test)]
mod tests {
    use super::*;

    /// A well-formed request yields its method and numeric id.
    #[test]
    fn test_parse_valid_request() {
        let line = r#"{"jsonrpc":"2.0","method":"initialize","id":1}"#;
        let req = parse_request(line).unwrap();
        assert_eq!(req.method, "initialize");
        assert_eq!(req.id, Some(Value::Number(1.into())));
    }

    /// Lines arrive from `read_line` with their terminator still attached;
    /// the parser must tolerate it.
    #[test]
    fn test_parse_request_with_trailing_newline() {
        let line = concat!(r#"{"jsonrpc":"2.0","method":"tools/list","id":2}"#, "\n");
        let req = parse_request(line).unwrap();
        assert_eq!(req.method, "tools/list");
    }

    /// Surrounding whitespace is trimmed before parsing.
    #[test]
    fn test_parse_request_with_whitespace() {
        let line = " {\"jsonrpc\":\"2.0\",\"method\":\"ping\",\"id\":3} ";
        let req = parse_request(line).unwrap();
        assert_eq!(req.method, "ping");
    }

    #[test]
    fn test_parse_empty_request() {
        let err = parse_request("").unwrap_err();
        assert_eq!(err, "Empty request");
    }

    #[test]
    fn test_parse_whitespace_only() {
        let err = parse_request(" \n ").unwrap_err();
        assert_eq!(err, "Empty request");
    }

    #[test]
    fn test_parse_invalid_json() {
        let err = parse_request("{invalid}").unwrap_err();
        assert!(!err.is_empty(), "Invalid JSON should produce an error message");
    }

    #[test]
    fn test_parse_missing_method() {
        let err = parse_request(r#"{"jsonrpc":"2.0","id":1}"#).unwrap_err();
        assert!(err.contains("method"));
    }

    /// The parser does not validate the `jsonrpc` version string.
    #[test]
    fn test_parse_wrong_version() {
        let req = parse_request(r#"{"jsonrpc":"1.0","method":"init","id":1}"#).unwrap();
        assert_eq!(req.jsonrpc, "1.0");
    }

    #[test]
    fn test_method_not_found_error() {
        let err = method_not_found("test_tool");
        assert_eq!(err.error_code(), -32601);
        assert!(err.to_string().contains("test_tool"));
    }

    #[test]
    fn test_tools_list_static() {
        let list = &*TOOLS_LIST;
        let tools = list.get("tools").and_then(|v| v.as_array());
        assert!(tools.is_some(), "TOOLS_LIST should contain a tools array");
        assert!(!tools.unwrap().is_empty(), "Tools list should not be empty");
    }

    /// Only checks that a request can be assembled by hand. `process_request`
    /// itself is not called, since it needs a pool and a config that this
    /// module does not build.
    #[test]
    fn test_process_request_method_dispatch() {
        let _req = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            method: "nonexistent".to_string(),
            params: None,
            id: Some(Value::Number(1.into())),
        };
    }

    #[test]
    fn test_handle_initialize_response() {
        let req = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            method: "initialize".to_string(),
            params: None,
            id: Some(Value::Number(1.into())),
        };
        let result = handle_initialize(&req).unwrap();
        assert_eq!(result["protocolVersion"], "2024-11-05");
        assert!(result["capabilities"]["tools"]["listChanged"].is_boolean());
        assert_eq!(result["serverInfo"]["version"], env!("CARGO_PKG_VERSION"));
    }
}