Skip to main content

mcp_postgres/
server.rs

1use tokio::net::{TcpListener, TcpStream};
2use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
3use serde_json::{json, Value};
4use tracing::{error, warn};
5use std::sync::Arc;
6
7use crate::config::Config;
8use crate::errors::{MCPError, Result as MCPResult};
9use crate::metrics;
10use crate::pool::ConnectionPool;
11use crate::protocol::{JsonRpcRequest, JsonRpcResponse};
12use crate::actions;
13use once_cell::sync::Lazy;
14
15static TOOLS_LIST: Lazy<Value> = Lazy::new(|| {
16    let tools_json = include_str!("../tools.json");
17    let tools: Vec<Value> = serde_json::from_str(tools_json)
18        .expect("Failed to parse tools.json");
19    json!({ "tools": tools })
20});
21
/// Capacity (16 KiB) of the `BufReader` wrapped around each input stream.
const BUFFER_CAPACITY: usize = 16 * 1024;
/// Line terminator appended after every serialized JSON-RPC response.
const NEWLINE: &[u8] = &[b'\n'];
24
25#[inline]
26#[cold]
27fn parse_error(msg: String) -> JsonRpcResponse {
28    let mcp_error = MCPError::ParseError(msg);
29    JsonRpcResponse::error(None, mcp_error.error_code(), mcp_error.to_string())
30}
31
32#[inline]
33fn parse_request(line: &str) -> Result<JsonRpcRequest, String> {
34    let trimmed = line.trim();
35    if trimmed.is_empty() {
36        return Err("Empty request".to_string());
37    }
38    serde_json::from_str::<JsonRpcRequest>(trimmed)
39        .map_err(|e| e.to_string())
40}
41
42pub struct MCPServer {
43    config: Config,
44    pool: Arc<ConnectionPool>,
45}
46
47impl MCPServer {
48    pub fn new(config: Config, pool: Arc<ConnectionPool>) -> Self {
49        Self { config, pool }
50    }
51
52    /// Run in stdio mode for MCP compatibility (Claude Desktop, etc.)
53    pub async fn run_stdio(&self) -> MCPResult<()> {
54        let stdin = tokio::io::stdin();
55        let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, stdin);
56        let mut stdout = tokio::io::stdout();
57        let mut line = String::with_capacity(512);
58        let mut response_buf = Vec::with_capacity(65536);
59
60        loop {
61            line.clear();
62            match reader.read_line(&mut line).await {
63                Ok(0) => break,
64                Ok(_) => {
65                    process_one_line(&line, &self.pool, &self.config, &mut response_buf, &mut stdout).await?;
66                }
67                Err(e) => {
68                    error!("IO error: {}", e);
69                    break;
70                }
71            }
72        }
73        Ok(())
74    }
75
76    pub async fn run(&self) -> MCPResult<()> {
77        let addr = format!("{}:{}", self.config.server.host, self.config.server.port);
78        let listener = TcpListener::bind(&addr).await?;
79
80        tracing::info!("MCP server listening on {}", addr);
81
82        loop {
83            let (socket, peer_addr) = listener.accept().await?;
84
85            if let Err(e) = socket.set_nodelay(true) {
86                warn!("Failed to set TCP_NODELAY: {}", e);
87            }
88            // Apply TCP socket options via raw fd (SO_KEEPALIVE, optimized buffer sizes)
89            use std::os::unix::io::AsRawFd;
90            let raw = socket.as_raw_fd();
91            let on: libc::c_int = 1;
92            unsafe {
93                libc::setsockopt(raw, libc::SOL_SOCKET, libc::SO_KEEPALIVE, &on as *const _ as *const libc::c_void, std::mem::size_of_val(&on) as libc::socklen_t);
94                // Optimized buffer sizes for JSON-RPC: 256KB instead of 4MB
95                // 4MB was excessive for typical JSON requests/responses
96                let buf_size: libc::c_int = 256 * 1024;
97                libc::setsockopt(raw, libc::SOL_SOCKET, libc::SO_RCVBUF, &buf_size as *const _ as *const libc::c_void, std::mem::size_of_val(&buf_size) as libc::socklen_t);
98                libc::setsockopt(raw, libc::SOL_SOCKET, libc::SO_SNDBUF, &buf_size as *const _ as *const libc::c_void, std::mem::size_of_val(&buf_size) as libc::socklen_t);
99            }
100
101            let pool = Arc::clone(&self.pool);
102            let config = self.config.clone();
103
104            tokio::spawn(async move {
105                if let Err(e) = handle_client(socket, pool, config).await {
106                    error!("Client {} error: {}", peer_addr, e);
107                }
108            });
109        }
110    }
111}
112
113#[inline(never)]
114async fn handle_client(socket: TcpStream, pool: Arc<ConnectionPool>, config: Config) -> MCPResult<()> {
115    let (reader, mut writer) = socket.into_split();
116    let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, reader);
117    let mut line = String::with_capacity(512);
118    let mut response_buf = Vec::with_capacity(65536);
119
120    loop {
121        line.clear();
122        match reader.read_line(&mut line).await {
123            Ok(0) => break,
124            Ok(_) => {
125                process_one_line(&line, &pool, &config, &mut response_buf, &mut writer).await?;
126            }
127            Err(e) => {
128                error!("IO error: {}", e);
129                break;
130            }
131        }
132    }
133
134    Ok(())
135}
136
137/// Core per-line processing shared by TCP and stdio transports.
138#[inline]
139async fn process_one_line<W: AsyncWriteExt + Unpin>(
140    line: &str,
141    pool: &Arc<ConnectionPool>,
142    config: &Config,
143    response_buf: &mut Vec<u8>,
144    writer: &mut W,
145) -> MCPResult<()> {
146    metrics::inc_requests();
147
148    let response = match parse_request(line) {
149        Ok(req) => match process_request(&req, pool, config).await {
150            Ok(result) => JsonRpcResponse::success(req.id, result),
151            Err(e) => {
152                metrics::inc_errors();
153                JsonRpcResponse::error(req.id, e.error_code(), e.to_string())
154            }
155        },
156        Err(e) => {
157            metrics::inc_errors();
158            parse_error(e)
159        }
160    };
161
162    response_buf.clear();
163    serde_json::to_writer(&mut *response_buf, &response)?;
164    response_buf.extend_from_slice(NEWLINE);
165
166    writer.write_all(response_buf).await?;
167    writer.flush().await?;
168    Ok(())
169}
170
171#[inline]
172async fn process_request(
173    req: &JsonRpcRequest,
174    pool: &Arc<ConnectionPool>,
175    config: &Config,
176) -> MCPResult<Value> {
177    match req.method.as_str() {
178        "initialize" => handle_initialize(req),
179        "tools/list" => handle_tools_list(),
180        "tools/call" => handle_tools_call(req, pool, config).await,
181        _ => Err(MCPError::MethodNotFound(req.method.clone())),
182    }
183}
184
185#[inline]
186fn handle_initialize(_req: &JsonRpcRequest) -> MCPResult<Value> {
187    Ok(json!({
188        "protocolVersion": "2024-11-05",
189        "capabilities": {
190            "tools": {
191                "listChanged": false
192            },
193            "resources": {
194                "subscribe": false,
195                "listChanged": false
196            },
197            "prompts": {
198                "listChanged": false
199            }
200        },
201        "serverInfo": {
202            "name": "mcp-postgres",
203            "version": env!("CARGO_PKG_VERSION")
204        }
205    }))
206}
207
208#[inline]
209fn handle_tools_list() -> MCPResult<Value> {
210    Ok((*TOOLS_LIST).clone())
211}
212
213async fn handle_tools_call(
214    req: &JsonRpcRequest,
215    pool: &Arc<ConnectionPool>,
216    config: &Config,
217) -> MCPResult<Value> {
218    let tool_name = req
219        .params
220        .as_ref()
221        .and_then(|p| p.get("name").and_then(|v| v.as_str()))
222        .ok_or_else(|| MCPError::InvalidParams("Missing 'name' parameter".into()))?;
223
224    let tool_args = req.params.as_ref().and_then(|p| p.get("arguments").cloned());
225
226    // Restricted mode check + unknown tool check BEFORE pool acquire
227    let write_tools: &[&str] = &[
228        "execute_insert", "execute_update", "execute_delete",
229        "batch_insert", "batch_update", "batch_delete", "batch_insert_copy",
230        "vacuum_analyze", "analyze_table", "reindex_table",
231        "reset_statistics", "kill_connection",
232        "begin_transaction", "commit_transaction", "rollback_transaction",
233    ];
234
235    if config.server.access_mode == crate::config::AccessMode::Restricted
236        && write_tools.contains(&tool_name)
237    {
238        return Err(MCPError::InvalidParams(format!(
239            "Operation '{tool_name}' is not allowed in restricted (read-only) mode"
240        )));
241    }
242
243    // Fast-path simple tools that don't need a DB connection
244    let no_db_tools: &[&str] = &["list_tables", "list_schemas", "show_constraints"];
245    if !no_db_tools.contains(&tool_name) {
246        // Verify tool exists before acquiring a connection
247        let tool_exists = matches!(tool_name,
248            "describe_table" | "list_indexes" | "execute_query" | "execute_insert"
249            | "execute_update" | "execute_delete" | "explain_query"
250            | "batch_insert" | "batch_update" | "batch_delete" | "batch_insert_copy"
251            | "get_table_stats" | "get_index_stats" | "show_database_size"
252            | "show_table_size" | "get_cache_hit_ratio"
253            | "list_connections" | "kill_connection" | "show_current_user"
254            | "show_running_queries" | "show_connection_summary"
255            | "vacuum_analyze" | "analyze_table" | "reindex_table"
256            | "get_pg_stat_statements" | "reset_statistics"
257            | "list_users" | "list_user_privileges" | "list_role_memberships"
258            | "list_database_privileges" | "show_session_info"
259            | "show_all_settings" | "get_setting" | "show_memory_settings"
260            | "show_performance_settings" | "show_log_settings"
261            | "show_replication_status" | "list_replication_slots"
262            | "list_standby_servers" | "show_wal_info" | "show_base_backup_progress"
263            | "show_active_transactions" | "show_locks" | "show_waiting_locks"
264            | "begin_transaction" | "commit_transaction" | "rollback_transaction"
265            | "show_transaction_isolation" | "show_deadlocks"
266            | "show_autocommit_status" | "show_transaction_timeout"
267            | "analyze_db_health" | "list_unused_indexes" | "list_duplicate_indexes"
268            | "show_vacuum_progress" | "get_object_details"
269        );
270        if !tool_exists {
271            return Err(method_not_found(tool_name));
272        }
273    }
274
275    // Acquire pool connection only for known tools
276    let client = pool.acquire().await?;
277
278    let result = match tool_name {
279        // Schema actions
280        "list_tables" => actions::schema::list_tables(&client, &tool_args).await,
281        "describe_table" => actions::schema::describe_table(&client, &tool_args).await,
282        "list_indexes" => actions::schema::list_indexes(&client, &tool_args).await,
283        "list_schemas" => actions::schema::list_schemas(&client, &tool_args).await,
284        "show_constraints" => actions::schema::show_constraints(&client, &tool_args).await,
285        // Query actions
286        "execute_query" => actions::query::execute_query(&client, &tool_args).await,
287        "execute_insert" => actions::query::execute_insert(&client, &tool_args).await,
288        "execute_update" => actions::query::execute_update(&client, &tool_args).await,
289        "execute_delete" => actions::query::execute_delete(&client, &tool_args).await,
290        "explain_query" => actions::query::explain_query(&client, &tool_args).await,
291        // Batch operations
292        "batch_insert" => actions::batch::batch_insert(&client, &tool_args).await,
293        "batch_update" => actions::batch::batch_update(&client, &tool_args).await,
294        "batch_delete" => actions::batch::batch_delete(&client, &tool_args).await,
295        "batch_insert_copy" => actions::batch::batch_insert_copy(&client, &tool_args).await,
296        // Monitoring actions
297        "get_table_stats" => actions::monitoring::get_table_stats(&client, &tool_args).await,
298        "get_index_stats" => actions::monitoring::get_index_stats(&client, &tool_args).await,
299        "show_database_size" => actions::monitoring::show_database_size(&client, &tool_args).await,
300        "show_table_size" => actions::monitoring::show_table_size(&client, &tool_args).await,
301        "get_cache_hit_ratio" => actions::monitoring::get_cache_hit_ratio(&client, &tool_args).await,
302        // Connection actions
303        "list_connections" => actions::connections::list_connections(&client, &tool_args).await,
304        "kill_connection" => actions::connections::kill_connection(&client, &tool_args).await,
305        "show_current_user" => actions::connections::show_current_user(&client, &tool_args).await,
306        "show_running_queries" => actions::connections::show_running_queries(&client, &tool_args).await,
307        "show_connection_summary" => actions::connections::show_connection_summary(&client, &tool_args).await,
308        // Maintenance actions
309        "vacuum_analyze" => actions::maintenance::vacuum_analyze(&client, &tool_args).await,
310        "analyze_table" => actions::maintenance::analyze_table(&client, &tool_args).await,
311        "reindex_table" => actions::maintenance::reindex_table(&client, &tool_args).await,
312        "get_pg_stat_statements" => actions::maintenance::get_pg_stat_statements(&client, &tool_args).await,
313        "reset_statistics" => actions::maintenance::reset_statistics(&client, &tool_args).await,
314        // Security actions
315        "list_users" => actions::security::list_users(&client, &tool_args).await,
316        "list_user_privileges" => actions::security::list_user_privileges(&client, &tool_args).await,
317        "list_role_memberships" => actions::security::list_role_memberships(&client, &tool_args).await,
318        "list_database_privileges" => actions::security::list_database_privileges(&client, &tool_args).await,
319        "show_session_info" => actions::security::show_session_info(&client, &tool_args).await,
320        // Config actions
321        "show_all_settings" => actions::config::show_all_settings(&client, &tool_args).await,
322        "get_setting" => actions::config::get_setting(&client, &tool_args).await,
323        "show_memory_settings" => actions::config::show_memory_settings(&client, &tool_args).await,
324        "show_performance_settings" => actions::config::show_performance_settings(&client, &tool_args).await,
325        "show_log_settings" => actions::config::show_log_settings(&client, &tool_args).await,
326        // Replication actions
327        "show_replication_status" => actions::replication::show_replication_status(&client, &tool_args).await,
328        "list_replication_slots" => actions::replication::list_replication_slots(&client, &tool_args).await,
329        "list_standby_servers" => actions::replication::list_standby_servers(&client, &tool_args).await,
330        "show_wal_info" => actions::replication::show_wal_info(&client, &tool_args).await,
331        "show_base_backup_progress" => actions::replication::show_base_backup_progress(&client, &tool_args).await,
332        // Transaction actions
333        "show_active_transactions" => actions::transactions::show_active_transactions(&client, &tool_args).await,
334        "show_locks" => actions::transactions::show_locks(&client, &tool_args).await,
335        "show_waiting_locks" => actions::transactions::show_waiting_locks(&client, &tool_args).await,
336        "begin_transaction" => actions::transactions::begin_transaction(&client, &tool_args).await,
337        "commit_transaction" => actions::transactions::commit_transaction(&client, &tool_args).await,
338        "rollback_transaction" => actions::transactions::rollback_transaction(&client, &tool_args).await,
339        "show_transaction_isolation" => actions::transactions::show_transaction_isolation(&client, &tool_args).await,
340        "show_deadlocks" => actions::transactions::show_deadlocks(&client, &tool_args).await,
341        "show_autocommit_status" => actions::transactions::show_autocommit_status(&client, &tool_args).await,
342        "show_transaction_timeout" => actions::transactions::show_transaction_timeout(&client, &tool_args).await,
343        // Health actions
344        "analyze_db_health" => actions::health::analyze_db_health(&client, &tool_args).await,
345        "list_unused_indexes" => actions::health::list_unused_indexes(&client, &tool_args).await,
346        "list_duplicate_indexes" => actions::health::list_duplicate_indexes(&client, &tool_args).await,
347        "show_vacuum_progress" => actions::health::show_vacuum_progress(&client, &tool_args).await,
348        // Enhanced schema
349        "get_object_details" => actions::schema::get_object_details(&client, &tool_args).await,
350        tool => Err(method_not_found(tool)),
351    };
352
353    pool.release(client);
354    result
355}
356
357#[cold]
358fn method_not_found(name: &str) -> MCPError {
359    MCPError::MethodNotFound(name.to_string())
360}
361
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_valid_request() {
        let line = r#"{"jsonrpc":"2.0","method":"initialize","id":1}"#;
        let req = parse_request(line).unwrap();
        assert_eq!(req.method, "initialize");
        assert_eq!(req.id, Some(Value::Number(1.into())));
    }

    #[test]
    fn test_parse_request_with_trailing_newline() {
        // `read_line` keeps the terminator, so real input always ends in `\n`.
        let line = concat!(r#"{"jsonrpc":"2.0","method":"tools/list","id":2}"#, "\n");
        let req = parse_request(line).unwrap();
        assert_eq!(req.method, "tools/list");
    }

    #[test]
    fn test_parse_request_with_crlf() {
        let line = concat!(r#"{"jsonrpc":"2.0","method":"tools/list","id":4}"#, "\r\n");
        let req = parse_request(line).unwrap();
        assert_eq!(req.method, "tools/list");
        assert_eq!(req.id, Some(Value::Number(4.into())));
    }

    #[test]
    fn test_parse_request_with_whitespace() {
        let line = "  {\"jsonrpc\":\"2.0\",\"method\":\"ping\",\"id\":3}  ";
        let req = parse_request(line).unwrap();
        assert_eq!(req.method, "ping");
    }

    #[test]
    fn test_parse_empty_request() {
        let err = parse_request("").unwrap_err();
        assert_eq!(err, "Empty request");
    }

    #[test]
    fn test_parse_whitespace_only() {
        let err = parse_request("   \n  ").unwrap_err();
        assert_eq!(err, "Empty request");
    }

    #[test]
    fn test_parse_invalid_json() {
        let err = parse_request("{invalid}").unwrap_err();
        assert!(!err.is_empty(), "Invalid JSON should produce an error message");
    }

    #[test]
    fn test_parse_missing_method() {
        let err = parse_request(r#"{"jsonrpc":"2.0","id":1}"#).unwrap_err();
        assert!(err.contains("method"));
    }

    #[test]
    fn test_parse_wrong_version() {
        // The `jsonrpc` field is not validated at parse time.
        let req = parse_request(r#"{"jsonrpc":"1.0","method":"init","id":1}"#).unwrap();
        assert_eq!(req.jsonrpc, "1.0");
    }

    #[test]
    fn test_method_not_found_error() {
        let err = method_not_found("test_tool");
        assert_eq!(err.error_code(), -32601);
        assert!(err.to_string().contains("test_tool"));
    }

    #[test]
    fn test_tools_list_static() {
        let list = &*TOOLS_LIST;
        let tools = list.get("tools").and_then(|v| v.as_array());
        assert!(tools.is_some(), "TOOLS_LIST should contain a tools array");
        assert!(!tools.unwrap().is_empty(), "Tools list should not be empty");
    }

    #[test]
    fn test_handle_tools_list_returns_static_list() {
        // `process_request` needs a live connection pool, so exercise the
        // pool-free handler that `tools/list` is dispatched to instead.
        let result = handle_tools_list().unwrap();
        assert_eq!(result, *TOOLS_LIST);
    }

    #[test]
    fn test_handle_initialize_response() {
        let req = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            method: "initialize".to_string(),
            params: None,
            id: Some(Value::Number(1.into())),
        };
        let result = handle_initialize(&req).unwrap();
        assert_eq!(result["protocolVersion"], "2024-11-05");
        assert!(result["capabilities"]["tools"]["listChanged"].is_boolean());
        assert_eq!(result["serverInfo"]["version"], env!("CARGO_PKG_VERSION"));
    }
}