Skip to main content

mcp_postgres/
server.rs

1use tokio::net::{TcpListener, TcpStream};
2use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
3use serde_json::{json, Value};
4use tracing::{error, warn};
5use std::sync::Arc;
6
7use crate::config::Config;
8use crate::errors::{MCPError, Result as MCPResult};
9use crate::metrics;
10use crate::pool::ConnectionPool;
11use crate::protocol::{JsonRpcRequest, JsonRpcResponse};
12use crate::actions;
13use once_cell::sync::Lazy;
14
15static TOOLS_LIST: Lazy<Value> = Lazy::new(|| {
16    let tools_json = include_str!("../tools.json");
17    let tools: Vec<Value> = serde_json::from_str(tools_json)
18        .expect("Failed to parse tools.json");
19    json!({ "tools": tools })
20});
21
/// Capacity, in bytes, of the buffered reader placed in front of each transport.
const BUFFER_CAPACITY: usize = 4 * 1024;
/// Delimiter appended after every serialized JSON-RPC response.
const NEWLINE: &[u8] = &[b'\n'];
24
25#[inline]
26#[cold]
27fn parse_error(msg: String) -> JsonRpcResponse {
28    let mcp_error = MCPError::ParseError(msg);
29    JsonRpcResponse::error(None, mcp_error.error_code(), mcp_error.to_string())
30}
31
32#[inline]
33fn parse_request(line: &str) -> Result<JsonRpcRequest, String> {
34    let trimmed = line.trim();
35    if trimmed.is_empty() {
36        return Err("Empty request".to_string());
37    }
38    serde_json::from_str::<JsonRpcRequest>(trimmed)
39        .map_err(|e| e.to_string())
40}
41
42pub struct MCPServer {
43    config: Config,
44    pool: Arc<ConnectionPool>,
45}
46
47impl MCPServer {
48    pub fn new(config: Config, pool: Arc<ConnectionPool>) -> Self {
49        Self { config, pool }
50    }
51
52    /// Run in stdio mode for MCP compatibility (Claude Desktop, etc.)
53    pub async fn run_stdio(&self) -> MCPResult<()> {
54        let stdin = tokio::io::stdin();
55        let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, stdin);
56        let mut stdout = tokio::io::stdout();
57        let mut line = String::with_capacity(512);
58        let mut response_buf = Vec::with_capacity(65536);
59
60        loop {
61            line.clear();
62            match reader.read_line(&mut line).await {
63                Ok(0) => break,
64                Ok(_) => {
65                    process_one_line(&line, &self.pool, &self.config, &mut response_buf, &mut stdout).await?;
66                }
67                Err(e) => {
68                    error!("IO error: {}", e);
69                    break;
70                }
71            }
72        }
73        Ok(())
74    }
75
76    pub async fn run(&self) -> MCPResult<()> {
77        let addr = format!("{}:{}", self.config.server.host, self.config.server.port);
78        let listener = TcpListener::bind(&addr).await?;
79
80        tracing::info!("MCP server listening on {}", addr);
81
82        loop {
83            let (socket, peer_addr) = listener.accept().await?;
84
85            if let Err(e) = socket.set_nodelay(true) {
86                warn!("Failed to set TCP_NODELAY: {}", e);
87            }
88
89            let pool = Arc::clone(&self.pool);
90            let config = self.config.clone();
91
92            tokio::spawn(async move {
93                if let Err(e) = handle_client(socket, pool, config).await {
94                    error!("Client {} error: {}", peer_addr, e);
95                }
96            });
97        }
98    }
99}
100
101#[inline(never)]
102async fn handle_client(socket: TcpStream, pool: Arc<ConnectionPool>, config: Config) -> MCPResult<()> {
103    let (reader, mut writer) = socket.into_split();
104    let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, reader);
105    let mut line = String::with_capacity(512);
106    let mut response_buf = Vec::with_capacity(65536);
107
108    loop {
109        line.clear();
110        match reader.read_line(&mut line).await {
111            Ok(0) => break,
112            Ok(_) => {
113                process_one_line(&line, &pool, &config, &mut response_buf, &mut writer).await?;
114            }
115            Err(e) => {
116                error!("IO error: {}", e);
117                break;
118            }
119        }
120    }
121
122    Ok(())
123}
124
125/// Core per-line processing shared by TCP and stdio transports.
126/// For notifications (JSON-RPC messages without `id`), no response is sent.
127#[inline]
128async fn process_one_line<W: AsyncWriteExt + Unpin>(
129    line: &str,
130    pool: &Arc<ConnectionPool>,
131    config: &Config,
132    response_buf: &mut Vec<u8>,
133    writer: &mut W,
134) -> MCPResult<()> {
135    metrics::inc_requests();
136
137    let (response, is_notification) = match parse_request(line) {
138        Ok(req) => {
139            let is_notif = req.id.is_none();
140            match process_request(&req, pool, config).await {
141                Ok(result) => (JsonRpcResponse::success(req.id, result), is_notif),
142                Err(e) => {
143                    metrics::inc_errors();
144                    (JsonRpcResponse::error(req.id, e.error_code(), e.to_string()), is_notif)
145                }
146            }
147        }
148        Err(e) => {
149            metrics::inc_errors();
150            (parse_error(e), false)
151        }
152    };
153
154    // JSON-RPC notifications (no `id`) do not expect a response
155    if is_notification {
156        return Ok(());
157    }
158
159    response_buf.clear();
160    serde_json::to_writer(&mut *response_buf, &response)?;
161    response_buf.extend_from_slice(NEWLINE);
162
163    writer.write_all(response_buf).await?;
164    writer.flush().await?;
165    Ok(())
166}
167
168/// Process a JSON-RPC request (used by both TCP and HTTP transports)
169#[inline]
170pub async fn process_request(
171    req: &JsonRpcRequest,
172    pool: &Arc<ConnectionPool>,
173    config: &Config,
174) -> MCPResult<Value> {
175    match req.method.as_str() {
176        "initialize" => handle_initialize(req),
177        "tools/list" => handle_tools_list(),
178        "tools/call" => handle_tools_call(req, pool, config).await,
179        "ping" => handle_ping(),
180        method if method.starts_with("notifications/") => handle_notification(method),
181        _ => Err(MCPError::MethodNotFound(req.method.clone())),
182    }
183}
184
185/// Handle JSON-RPC ping (respond with empty success)
186#[inline]
187fn handle_ping() -> MCPResult<Value> {
188    Ok(Value::Null)
189}
190
191/// Handle MCP notifications (silently accepted, no response needed per JSON-RPC spec)
192#[inline]
193fn handle_notification(method: &str) -> MCPResult<Value> {
194    tracing::trace!("Received notification: {method}");
195    Ok(Value::Null)
196}
197
198/// Public wrapper for HTTP handlers - returns complete JSON-RPC response
199pub async fn process_request_http(
200    req: &JsonRpcRequest,
201    pool: &Arc<ConnectionPool>,
202    config: &Config,
203) -> JsonRpcResponse {
204    metrics::inc_requests();
205
206    match process_request(req, pool, config).await {
207        Ok(result) => JsonRpcResponse::success(req.id.clone(), result),
208        Err(e) => {
209            metrics::inc_errors();
210            JsonRpcResponse::error(req.id.clone(), e.error_code(), e.to_string())
211        }
212    }
213}
214
215#[inline]
216fn handle_initialize(_req: &JsonRpcRequest) -> MCPResult<Value> {
217    Ok(json!({
218        "protocolVersion": "2024-11-05",
219        "capabilities": {
220            "tools": {
221                "listChanged": false
222            },
223            "resources": {
224                "subscribe": false,
225                "listChanged": false
226            },
227            "prompts": {
228                "listChanged": false
229            }
230        },
231        "serverInfo": {
232            "name": "mcp-postgres",
233            "version": env!("CARGO_PKG_VERSION")
234        }
235    }))
236}
237
238#[inline]
239fn handle_tools_list() -> MCPResult<Value> {
240    Ok((*TOOLS_LIST).clone())
241}
242
243async fn handle_tools_call(
244    req: &JsonRpcRequest,
245    pool: &Arc<ConnectionPool>,
246    config: &Config,
247) -> MCPResult<Value> {
248    let tool_name = req
249        .params
250        .as_ref()
251        .and_then(|p| p.get("name").and_then(|v| v.as_str()))
252        .ok_or_else(|| MCPError::InvalidParams("Missing 'name' parameter".into()))?;
253
254    let tool_args = req.params.as_ref().and_then(|p| p.get("arguments"));
255
256    // Restricted mode check + unknown tool check BEFORE pool acquire
257    let write_tools: &[&str] = &[
258        "execute_insert", "execute_update", "execute_delete",
259        "async_execute_insert", "async_execute_update", "async_execute_delete",
260        "async_batch_insert", "async_batch_update", "async_batch_delete", "async_batch_insert_copy",
261        "create_table", "drop_table", "create_view", "drop_view", "alter_view", "create_schema", "drop_schema", "create_sequence", "drop_sequence", "alter_index", "backup_table", "create_index", "drop_index", "create_partition", "drop_partition",
262        "vacuum_analyze", "analyze_table", "reindex_table",
263        "reset_statistics", "truncate_table",
264    ];
265
266    if config.server.access_mode == crate::config::AccessMode::Restricted
267        && write_tools.contains(&tool_name)
268    {
269        return Err(MCPError::InvalidParams(format!(
270            "Operation '{tool_name}' is not allowed in restricted (read-only) mode"
271        )));
272    }
273
274    // Fast-path simple tools that don't need a DB connection
275    let no_db_tools: &[&str] = &["list_tables", "list_schemas", "show_constraints"];
276    if !no_db_tools.contains(&tool_name) {
277        // Verify tool exists before acquiring a connection
278        let tool_exists = matches!(tool_name,
279            "describe_table" | "list_triggers" | "list_indexes" | "execute_query" | "execute_insert"
280            | "execute_update" | "execute_delete" | "explain_query"
281            | "async_execute_insert" | "async_execute_update" | "async_execute_delete"
282            | "async_batch_insert" | "async_batch_update" | "async_batch_delete" | "async_batch_insert_copy"
283            | "create_table" | "drop_table" | "create_view" | "drop_view" | "alter_view" | "create_schema" | "drop_schema" | "create_sequence" | "drop_sequence" | "alter_index" | "backup_table" | "create_index" | "list_partitions" | "drop_index" | "create_partition" | "drop_partition"
284            | "get_table_stats" | "get_index_stats" | "show_database_size"
285            | "show_table_size" | "get_cache_hit_ratio"
286            | "list_connections" | "show_current_user"
287            | "show_running_queries" | "show_connection_summary"
288            | "vacuum_analyze" | "analyze_table" | "reindex_table"
289            | "get_pg_stat_statements" | "reset_statistics" | "truncate_table"
290            | "list_users" | "list_user_privileges" | "list_role_memberships"
291            | "list_database_privileges" | "show_session_info"
292            | "show_all_settings" | "get_setting" | "show_memory_settings"
293            | "show_performance_settings" | "show_log_settings"
294            | "show_replication_status" | "list_replication_slots"
295            | "list_standby_servers" | "show_wal_info" | "show_base_backup_progress"
296            | "show_active_transactions" | "show_locks" | "show_waiting_locks"
297            | "show_transaction_isolation" | "show_deadlocks"
298            | "show_autocommit_status" | "show_transaction_timeout"
299            | "analyze_db_health" | "list_unused_indexes" | "list_duplicate_indexes"
300            | "show_vacuum_progress" | "get_object_details"
301        );
302        if !tool_exists {
303            return Err(method_not_found(tool_name));
304        }
305    }
306
307    // Acquire pool connection only for known tools
308    let client = pool.acquire().await?;
309
310    let result = match tool_name {
311        // Schema actions
312        "list_tables" => actions::schema::list_tables(&client, &tool_args).await,
313        "describe_table" => actions::schema::describe_table(&client, &tool_args).await,
314        "list_indexes" => actions::schema::list_indexes(&client, &tool_args).await,
315        "list_schemas" => actions::schema::list_schemas(&client, &tool_args).await,
316        "show_constraints" => actions::schema::show_constraints(&client, &tool_args).await,
317        "list_triggers" => actions::schema::list_triggers(&client, &tool_args).await,
318        "create_table" => actions::schema::create_table(&client, &tool_args).await,
319        "drop_table" => actions::schema::drop_table(&client, &tool_args).await,
320        "create_view" => actions::schema::create_view(&client, &tool_args).await,
321        "drop_view" => actions::schema::drop_view(&client, &tool_args).await,
322        "alter_view" => actions::schema::alter_view(&client, &tool_args).await,
323        "create_schema" => actions::schema::create_schema(&client, &tool_args).await,
324        "drop_schema" => actions::schema::drop_schema(&client, &tool_args).await,
325        "create_sequence" => actions::schema::create_sequence(&client, &tool_args).await,
326        "drop_sequence" => actions::schema::drop_sequence(&client, &tool_args).await,
327        "alter_index" => actions::schema::alter_index(&client, &tool_args).await,
328        "list_partitions" => actions::schema::list_partitions(&client, &tool_args).await,
329        "backup_table" => actions::schema::backup_table(&client, &tool_args).await,
330        "create_index" => actions::schema::create_index(&client, &tool_args).await,
331        "drop_index" => actions::schema::drop_index(&client, &tool_args).await,
332        "create_partition" => actions::schema::create_partition(&client, &tool_args).await,
333        "drop_partition" => actions::schema::drop_partition(&client, &tool_args).await,
334        // Query actions
335        "execute_query" => actions::query::execute_query(&client, &tool_args).await,
336        "execute_insert" => actions::query::execute_insert(&client, &tool_args).await,
337        "execute_update" => actions::query::execute_update(&client, &tool_args).await,
338        "execute_delete" => actions::query::execute_delete(&client, &tool_args).await,
339        "async_execute_insert" => actions::query::async_execute_insert(&client, &tool_args).await,
340        "async_execute_update" => actions::query::async_execute_update(&client, &tool_args).await,
341        "async_execute_delete" => actions::query::async_execute_delete(&client, &tool_args).await,
342        "explain_query" => actions::query::explain_query(&client, &tool_args).await,
343        // Batch operations
344        "async_batch_insert" => actions::batch::async_batch_insert(&client, &tool_args).await,
345        "async_batch_update" => actions::batch::async_batch_update(&client, &tool_args).await,
346        "async_batch_delete" => actions::batch::async_batch_delete(&client, &tool_args).await,
347        "async_batch_insert_copy" => actions::batch::async_batch_insert_copy(&client, &tool_args).await,
348        // Monitoring actions
349        "get_table_stats" => actions::monitoring::get_table_stats(&client, &tool_args).await,
350        "get_index_stats" => actions::monitoring::get_index_stats(&client, &tool_args).await,
351        "show_database_size" => actions::monitoring::show_database_size(&client, &tool_args).await,
352        "show_table_size" => actions::monitoring::show_table_size(&client, &tool_args).await,
353        "get_cache_hit_ratio" => actions::monitoring::get_cache_hit_ratio(&client, &tool_args).await,
354        // Connection actions
355        "list_connections" => actions::connections::list_connections(&client, &tool_args).await,
356        "show_current_user" => actions::connections::show_current_user(&client, &tool_args).await,
357        "show_running_queries" => actions::connections::show_running_queries(&client, &tool_args).await,
358        "show_connection_summary" => actions::connections::show_connection_summary(&client, &tool_args).await,
359        // Maintenance actions
360        "vacuum_analyze" => actions::maintenance::vacuum_analyze(&client, &tool_args).await,
361        "analyze_table" => actions::maintenance::analyze_table(&client, &tool_args).await,
362        "reindex_table" => actions::maintenance::reindex_table(&client, &tool_args).await,
363        "get_pg_stat_statements" => actions::maintenance::get_pg_stat_statements(&client, &tool_args).await,
364        "reset_statistics" => actions::maintenance::reset_statistics(&client, &tool_args).await,
365        "truncate_table" => actions::maintenance::truncate_table(&client, &tool_args).await,
366        // Security actions
367        "list_users" => actions::security::list_users(&client, &tool_args).await,
368        "list_user_privileges" => actions::security::list_user_privileges(&client, &tool_args).await,
369        "list_role_memberships" => actions::security::list_role_memberships(&client, &tool_args).await,
370        "list_database_privileges" => actions::security::list_database_privileges(&client, &tool_args).await,
371        "show_session_info" => actions::security::show_session_info(&client, &tool_args).await,
372        // Config actions
373        "show_all_settings" => actions::config::show_all_settings(&client, &tool_args).await,
374        "get_setting" => actions::config::get_setting(&client, &tool_args).await,
375        "show_memory_settings" => actions::config::show_memory_settings(&client, &tool_args).await,
376        "show_performance_settings" => actions::config::show_performance_settings(&client, &tool_args).await,
377        "show_log_settings" => actions::config::show_log_settings(&client, &tool_args).await,
378        // Replication actions
379        "show_replication_status" => actions::replication::show_replication_status(&client, &tool_args).await,
380        "list_replication_slots" => actions::replication::list_replication_slots(&client, &tool_args).await,
381        "list_standby_servers" => actions::replication::list_standby_servers(&client, &tool_args).await,
382        "show_wal_info" => actions::replication::show_wal_info(&client, &tool_args).await,
383        "show_base_backup_progress" => actions::replication::show_base_backup_progress(&client, &tool_args).await,
384        // Transaction actions
385        "show_active_transactions" => actions::transactions::show_active_transactions(&client, &tool_args).await,
386        "show_locks" => actions::transactions::show_locks(&client, &tool_args).await,
387        "show_waiting_locks" => actions::transactions::show_waiting_locks(&client, &tool_args).await,
388        "show_transaction_isolation" => actions::transactions::show_transaction_isolation(&client, &tool_args).await,
389        "show_deadlocks" => actions::transactions::show_deadlocks(&client, &tool_args).await,
390        "show_autocommit_status" => actions::transactions::show_autocommit_status(&client, &tool_args).await,
391        "show_transaction_timeout" => actions::transactions::show_transaction_timeout(&client, &tool_args).await,
392        // Health actions
393        "analyze_db_health" => actions::health::analyze_db_health(&client, &tool_args).await,
394        "list_unused_indexes" => actions::health::list_unused_indexes(&client, &tool_args).await,
395        "list_duplicate_indexes" => actions::health::list_duplicate_indexes(&client, &tool_args).await,
396        "show_vacuum_progress" => actions::health::show_vacuum_progress(&client, &tool_args).await,
397        // Enhanced schema
398        "get_object_details" => actions::schema::get_object_details(&client, &tool_args).await,
399        tool => Err(method_not_found(tool)),
400    };
401
402    if let Err(ref e) = result {
403        error!("Tool '{}' error: {:?}", tool_name, e);
404    }
405    pool.release(client);
406    result
407}
408
409#[cold]
410fn method_not_found(name: &str) -> MCPError {
411    MCPError::MethodNotFound(name.to_string())
412}
413
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_valid_request() {
        let line = r#"{"jsonrpc":"2.0","method":"initialize","id":1}"#;
        let req = parse_request(line).unwrap();
        assert_eq!(req.method, "initialize");
        assert_eq!(req.id, Some(Value::Number(1.into())));
    }

    #[test]
    fn test_parse_request_with_trailing_newline() {
        // `read_line` keeps the line terminator, so the parser must tolerate it.
        let line = "{\"jsonrpc\":\"2.0\",\"method\":\"tools/list\",\"id\":2}\n";
        let req = parse_request(line).unwrap();
        assert_eq!(req.method, "tools/list");
        assert_eq!(req.id, Some(Value::Number(2.into())));
    }

    #[test]
    fn test_parse_request_with_whitespace() {
        let line = "  {\"jsonrpc\":\"2.0\",\"method\":\"ping\",\"id\":3}  ";
        let req = parse_request(line).unwrap();
        assert_eq!(req.method, "ping");
    }

    #[test]
    fn test_parse_empty_request() {
        let err = parse_request("").unwrap_err();
        assert_eq!(err, "Empty request");
    }

    #[test]
    fn test_parse_whitespace_only() {
        let err = parse_request("   \n  ").unwrap_err();
        assert_eq!(err, "Empty request");
    }

    #[test]
    fn test_parse_invalid_json() {
        let err = parse_request("{invalid}").unwrap_err();
        assert!(!err.is_empty(), "Invalid JSON should produce an error message");
    }

    #[test]
    fn test_parse_missing_method() {
        let err = parse_request(r#"{"jsonrpc":"2.0","id":1}"#).unwrap_err();
        assert!(err.contains("method"));
    }

    #[test]
    fn test_parse_wrong_version() {
        // The parser does not validate the `jsonrpc` version field.
        let req = parse_request(r#"{"jsonrpc":"1.0","method":"init","id":1}"#).unwrap();
        assert_eq!(req.jsonrpc, "1.0");
    }

    #[test]
    fn test_method_not_found_error() {
        let err = method_not_found("test_tool");
        assert_eq!(err.error_code(), -32601);
        assert!(err.to_string().contains("test_tool"));
    }

    #[test]
    fn test_tools_list_static() {
        let list = &*TOOLS_LIST;
        let tools = list.get("tools").and_then(|v| v.as_array());
        assert!(tools.is_some(), "TOOLS_LIST should contain a tools array");
        assert!(!tools.unwrap().is_empty(), "Tools list should not be empty");
    }

    #[test]
    fn test_process_request_method_dispatch() {
        // Construction-only smoke test: `process_request` needs a connection
        // pool, which is not available here, so this only checks that a
        // request for an unknown method can be built with the expected shape.
        let _req = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            method: "nonexistent".to_string(),
            params: None,
            id: Some(Value::Number(1.into())),
        };
    }

    #[test]
    fn test_handle_initialize_response() {
        let req = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            method: "initialize".to_string(),
            params: None,
            id: Some(Value::Number(1.into())),
        };
        let result = handle_initialize(&req).unwrap();
        assert_eq!(result["protocolVersion"], "2024-11-05");
        assert!(result["capabilities"]["tools"]["listChanged"].is_boolean());
        assert_eq!(result["serverInfo"]["version"], env!("CARGO_PKG_VERSION"));
    }

    /// Enforce Phase 1.5: no bare `SET ` outside transaction blocks.
    /// Every session-level SET must use `SET LOCAL` inside a `BEGIN`/`COMMIT` pair.
    /// This line-based scan fails the test run (it panics) if a
    /// `client.execute("SET ...")` call without `SET LOCAL` is found in the
    /// scanned action sources.
    #[test]
    fn test_no_bare_set_outside_transaction() {
        // Each source is paired with its name so the panic message cannot
        // drift out of sync with the list of scanned files.
        let sources: [(&str, &str); 2] = [
            ("query.rs", include_str!("../src/actions/query.rs")),
            ("batch.rs", include_str!("../src/actions/batch.rs")),
        ];
        for &(file_name, source) in sources.iter() {
            for (line_no, line) in source.lines().enumerate() {
                let trimmed = line.trim();
                // Skip comment lines.
                if trimmed.starts_with("//") || trimmed.starts_with("/*") || trimmed.starts_with("*") {
                    continue;
                }
                // Skip SQL `UPDATE ... SET ...` statements.
                if trimmed.contains("UPDATE ") && trimmed.contains("SET ") {
                    continue;
                }
                // `SET LOCAL` is the permitted form.
                if trimmed.contains("SET LOCAL") {
                    continue;
                }
                // Check for bare client.execute("SET ...") outside txn
                if trimmed.contains("client.execute(\"SET ") {
                    panic!(
                        "Phase 1.5 violation: bare `SET` (not SET LOCAL) found in {}:{} — \
                         use BEGIN + SET LOCAL + COMMIT pattern to avoid session leakage.\n\
                         Line: {}",
                        file_name, line_no + 1, trimmed
                    );
                }
            }
        }
    }
}