Skip to main content

heliosdb_proxy/
mcp.rs

1//! MCP (Model Context Protocol) agent gateway.
2//!
3//! When `[mcp] enabled = true`, the proxy exposes a native MCP server so AI
4//! agents call structured, policy-gated tools (`query`, `list_tables`,
5//! `explain`) instead of opening raw SQL connections. This is the AI-data-
6//! plane differentiator: every tool call goes through one auditable surface
7//! and a read-only-by-default guardrail, and runs over the proxy's backend
8//! PG-wire client so it is backend-agnostic (PostgreSQL or HeliosDB-Nano).
9//!
10//! Transport: JSON-RPC 2.0 over HTTP POST (the simplest MCP transport; an
11//! SSE/Streamable-HTTP upgrade is a follow-on). Methods implemented:
12//! `initialize`, `notifications/initialized`, `ping`, `tools/list`,
13//! `tools/call`.
14
15use std::sync::Arc;
16use std::time::Duration;
17
18use serde_json::{json, Value};
19use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
20use tokio::net::TcpListener;
21
22use crate::agent_contract::{self, AgentContract};
23use crate::backend::{tls::default_client_config, BackendClient, BackendConfig, TlsMode};
24use crate::backend::client::QueryResult;
25use crate::backend::types::TextValue;
26use crate::config::McpConfig;
27use crate::{ProxyError, Result};
28
/// MCP protocol revision reported as `protocolVersion` in the `initialize` response.
const MCP_PROTOCOL_VERSION: &str = "2024-11-05";
30
31/// The MCP gateway server.
32pub struct McpServer {
33    config: McpConfig,
34    contract: Option<AgentContract>,
35}
36
37impl McpServer {
38    pub fn new(config: McpConfig, contract: Option<AgentContract>) -> Self {
39        Self { config, contract }
40    }
41
42    /// Bind and serve the MCP HTTP endpoint until the task is dropped.
43    pub async fn run(self) -> Result<()> {
44        let listener = TcpListener::bind(&self.config.listen_address)
45            .await
46            .map_err(|e| ProxyError::Network(format!("MCP bind {}: {}", self.config.listen_address, e)))?;
47        tracing::info!(addr = %self.config.listen_address, read_only = self.config.read_only,
48            contract = ?self.contract.as_ref().map(|c| &c.id), "MCP agent gateway listening");
49        let cfg = Arc::new(self.config);
50        let contract = Arc::new(self.contract);
51        loop {
52            let (stream, peer) = match listener.accept().await {
53                Ok(x) => x,
54                Err(e) => {
55                    tracing::warn!("MCP accept error: {}", e);
56                    continue;
57                }
58            };
59            let cfg = cfg.clone();
60            let contract = contract.clone();
61            tokio::spawn(async move {
62                if let Err(e) = Self::handle_connection(stream, cfg, contract).await {
63                    tracing::debug!(%peer, "MCP connection error: {}", e);
64                }
65            });
66        }
67    }
68
69    async fn handle_connection(
70        mut stream: tokio::net::TcpStream,
71        cfg: Arc<McpConfig>,
72        contract: Arc<Option<AgentContract>>,
73    ) -> Result<()> {
74        let (reader, mut writer) = stream.split();
75        let mut reader = BufReader::new(reader);
76        let mut line = String::new();
77        let mut content_length = 0usize;
78        // Read request line + headers.
79        use tokio::io::AsyncBufReadExt;
80        let mut first = true;
81        loop {
82            line.clear();
83            let n = reader
84                .read_line(&mut line)
85                .await
86                .map_err(|e| ProxyError::Network(format!("MCP read: {}", e)))?;
87            if n == 0 || line == "\r\n" {
88                break;
89            }
90            if first {
91                first = false; // request line; we accept any method/path
92            } else if line.to_ascii_lowercase().starts_with("content-length:") {
93                if let Some(v) = line.split(':').nth(1) {
94                    content_length = v.trim().parse().unwrap_or(0);
95                }
96            }
97        }
98        let body = if content_length > 0 {
99            let mut buf = vec![0u8; content_length];
100            reader
101                .read_exact(&mut buf)
102                .await
103                .map_err(|e| ProxyError::Network(format!("MCP body read: {}", e)))?;
104            String::from_utf8_lossy(&buf).to_string()
105        } else {
106            String::new()
107        };
108
109        let response = Self::dispatch(&body, &cfg, (*contract).as_ref()).await;
110        match response {
111            Some(v) => {
112                let payload = serde_json::to_string(&v).unwrap_or_else(|_| "{}".to_string());
113                Self::write_http(&mut writer, 200, "application/json", payload.as_bytes()).await
114            }
115            // Notifications get a bare 202 with no JSON-RPC body.
116            None => Self::write_http(&mut writer, 202, "application/json", b"").await,
117        }
118    }
119
120    /// Dispatch one JSON-RPC request. Returns `None` for notifications.
121    async fn dispatch(body: &str, cfg: &McpConfig, contract: Option<&AgentContract>) -> Option<Value> {
122        let req: Value = match serde_json::from_str(body) {
123            Ok(v) => v,
124            Err(e) => return Some(rpc_error(Value::Null, -32700, &format!("parse error: {}", e))),
125        };
126        let id = req.get("id").cloned().unwrap_or(Value::Null);
127        let method = req.get("method").and_then(|m| m.as_str()).unwrap_or("");
128        let params = req.get("params").cloned().unwrap_or(json!({}));
129
130        match method {
131            "initialize" => Some(rpc_ok(
132                id,
133                json!({
134                    "protocolVersion": MCP_PROTOCOL_VERSION,
135                    "serverInfo": { "name": "heliosproxy-mcp", "version": crate::VERSION },
136                    "capabilities": { "tools": { "listChanged": false } }
137                }),
138            )),
139            // Notifications (no id) — no response.
140            "notifications/initialized" | "notifications/cancelled" => None,
141            "ping" => Some(rpc_ok(id, json!({}))),
142            "tools/list" => Some(rpc_ok(id, json!({ "tools": Self::tool_defs(cfg) }))),
143            "tools/call" => Some(Self::handle_tool_call(id, &params, cfg, contract).await),
144            other => Some(rpc_error(id, -32601, &format!("method not found: {}", other))),
145        }
146    }
147
148    fn tool_defs(cfg: &McpConfig) -> Value {
149        let query_desc = if cfg.read_only {
150            "Run a read-only SQL query and return rows. Writes/DDL are refused."
151        } else {
152            "Run a SQL query and return rows (or the command tag for writes)."
153        };
154        json!([
155            {
156                "name": "query",
157                "description": query_desc,
158                "inputSchema": {
159                    "type": "object",
160                    "properties": { "sql": { "type": "string", "description": "SQL to execute" } },
161                    "required": ["sql"]
162                }
163            },
164            {
165                "name": "list_tables",
166                "description": "List user tables (schema.table) in the connected database.",
167                "inputSchema": { "type": "object", "properties": {} }
168            },
169            {
170                "name": "explain",
171                "description": "Return the query plan for a SQL statement (EXPLAIN).",
172                "inputSchema": {
173                    "type": "object",
174                    "properties": { "sql": { "type": "string" } },
175                    "required": ["sql"]
176                }
177            }
178        ])
179    }
180
181    async fn handle_tool_call(
182        id: Value,
183        params: &Value,
184        cfg: &McpConfig,
185        contract: Option<&AgentContract>,
186    ) -> Value {
187        let name = params.get("name").and_then(|n| n.as_str()).unwrap_or("");
188        let args = params.get("arguments").cloned().unwrap_or(json!({}));
189
190        let result: std::result::Result<String, String> = match name {
191            "query" => {
192                let sql = args.get("sql").and_then(|s| s.as_str()).unwrap_or("").trim();
193                if sql.is_empty() {
194                    Err("missing 'sql'".to_string())
195                } else {
196                    match Self::check_policy(cfg, contract, sql) {
197                        Err(hint) => Err(hint),
198                        Ok(()) => Self::run_sql(cfg, sql).await.map(|r| format_result(&r)),
199                    }
200                }
201            }
202            "list_tables" => {
203                let sql = "SELECT table_schema, table_name FROM information_schema.tables \
204                           WHERE table_schema NOT IN ('pg_catalog','information_schema') \
205                           ORDER BY table_schema, table_name";
206                Self::run_sql(cfg, sql).await.map(|r| format_result(&r))
207            }
208            "explain" => {
209                let sql = args.get("sql").and_then(|s| s.as_str()).unwrap_or("").trim();
210                if sql.is_empty() {
211                    Err("missing 'sql'".to_string())
212                } else {
213                    match Self::check_policy(cfg, contract, sql) {
214                        Err(hint) => Err(hint),
215                        Ok(()) => Self::run_sql(cfg, &format!("EXPLAIN {}", sql))
216                            .await
217                            .map(|r| format_result(&r)),
218                    }
219                }
220            }
221            other => Err(format!("unknown tool: {}", other)),
222        };
223
224        match result {
225            Ok(text) => {
226                tracing::info!(tool = %name, "MCP tool call ok");
227                rpc_ok(
228                    id,
229                    json!({ "content": [{ "type": "text", "text": text }], "isError": false }),
230                )
231            }
232            Err(e) => {
233                tracing::info!(tool = %name, error = %e, "MCP tool call error");
234                // Tool errors are reported in-band (isError) per MCP, not as a
235                // protocol error, so the agent can read + self-correct.
236                rpc_ok(
237                    id,
238                    json!({ "content": [{ "type": "text", "text": e }], "isError": true }),
239                )
240            }
241        }
242    }
243
244    /// Gate a SQL statement: when an agent contract is configured, validate
245    /// against it and return a structured JSON repair hint on violation;
246    /// otherwise apply the plain read-only guardrail.
247    fn check_policy(
248        cfg: &McpConfig,
249        contract: Option<&AgentContract>,
250        sql: &str,
251    ) -> std::result::Result<(), String> {
252        if let Some(c) = contract {
253            agent_contract::validate(sql, c).map_err(|v| v.to_json())
254        } else if cfg.read_only && is_write_sql(sql) {
255            Err("write/DDL refused: the MCP gateway is read-only".to_string())
256        } else {
257            Ok(())
258        }
259    }
260
261    /// Connect to the configured backend, run one statement, return rows.
262    async fn run_sql(cfg: &McpConfig, sql: &str) -> std::result::Result<QueryResult, String> {
263        let bcfg = BackendConfig {
264            host: cfg.backend_host.clone(),
265            port: cfg.backend_port,
266            user: cfg.backend_user.clone(),
267            password: cfg.backend_password.clone(),
268            database: cfg.backend_database.clone(),
269            application_name: Some("heliosproxy-mcp".to_string()),
270            tls_mode: TlsMode::Disable,
271            connect_timeout: Duration::from_secs(5),
272            query_timeout: Duration::from_secs(30),
273            tls_config: default_client_config(),
274        };
275        let mut client = BackendClient::connect(&bcfg)
276            .await
277            .map_err(|e| format!("backend connect: {}", e))?;
278        let res = client.simple_query(sql).await.map_err(|e| format!("{}", e));
279        client.close().await;
280        res
281    }
282
283    async fn write_http(
284        writer: &mut tokio::net::tcp::WriteHalf<'_>,
285        status: u16,
286        content_type: &str,
287        body: &[u8],
288    ) -> Result<()> {
289        let head = format!(
290            "HTTP/1.1 {} {}\r\nContent-Type: {}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n",
291            status,
292            if status == 200 { "OK" } else { "Accepted" },
293            content_type,
294            body.len()
295        );
296        writer
297            .write_all(head.as_bytes())
298            .await
299            .map_err(|e| ProxyError::Network(format!("MCP write: {}", e)))?;
300        if !body.is_empty() {
301            writer
302                .write_all(body)
303                .await
304                .map_err(|e| ProxyError::Network(format!("MCP write: {}", e)))?;
305        }
306        Ok(())
307    }
308}
309
310fn rpc_ok(id: Value, result: Value) -> Value {
311    json!({ "jsonrpc": "2.0", "id": id, "result": result })
312}
313
314fn rpc_error(id: Value, code: i32, message: &str) -> Value {
315    json!({ "jsonrpc": "2.0", "id": id, "error": { "code": code, "message": message } })
316}
317
318/// Render a QueryResult as a compact text table for the agent.
319fn format_result(r: &QueryResult) -> String {
320    if r.columns.is_empty() {
321        return r.command_tag.clone();
322    }
323    let header: Vec<&str> = r.columns.iter().map(|c| c.name.as_str()).collect();
324    let mut out = String::new();
325    out.push_str(&header.join(" | "));
326    out.push('\n');
327    for row in &r.rows {
328        let cells: Vec<String> = row
329            .iter()
330            .map(|v| match v {
331                TextValue::Null => "NULL".to_string(),
332                TextValue::Text(s) => s.clone(),
333            })
334            .collect();
335        out.push_str(&cells.join(" | "));
336        out.push('\n');
337    }
338    out.push_str(&format!("({} rows)", r.rows.len()));
339    out
340}
341
/// Best-effort write/DDL detection for the read-only guardrail.
///
/// Returns `true` (refuse) when any of these hold:
///
/// * The leading keyword is a write, DDL or session-changing command. The
///   keyword is found after skipping whitespace, `(` and leading
///   `--` / `/* */` comments.
/// * A data-modifying keyword appears *anywhere* in the text. This catches:
///   - writable CTEs (`WITH d AS (DELETE ...) SELECT ...`);
///   - `SELECT ... INTO`;
///   - `EXPLAIN ANALYZE <write>`, which really executes the statement.
/// * The text holds more than one statement (a `;` that is not merely
///   trailing), because the backend would execute every one of them.
/// * A leading comment is unterminated or lexer-dependent (see
///   [`skip_sql_preamble`]).
///
/// The scan is deliberately conservative. String literals, quoted identifiers
/// and comments are not parsed, so a literal such as `'DELETE'` or `';'` also
/// causes a refusal. It cannot detect side effects of functions called from a
/// `SELECT`; pair it with a read-only database role if writes must be
/// impossible.
fn is_write_sql(sql: &str) -> bool {
    // Write/DDL/session commands, matched against the statement's first word.
    const LEADING: &[&str] = &[
        "INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER", "TRUNCATE", "GRANT", "REVOKE",
        "COPY", "MERGE", "CALL", "DO", "VACUUM", "REINDEX", "CLUSTER", "LOCK", "COMMENT", "SET",
        "RESET", "ANALYZE", "ANALYSE", "REFRESH", "REASSIGN", "SECURITY", "IMPORT", "LOAD",
    ];
    // Keywords that modify data wherever they appear in a statement.
    const ANYWHERE: &[&str] = &["INSERT", "UPDATE", "DELETE", "MERGE", "INTO", "CREATE"];

    let is_word_char = |c: char| c.is_ascii_alphanumeric() || c == '_';

    // A `;` before the (whitespace/semicolon) tail means several statements,
    // all of which the simple-query protocol would execute.
    let body = sql.trim_end_matches(|c: char| c == ';' || c.is_whitespace());
    if body.contains(';') {
        return true;
    }

    let leading_is_write = match skip_sql_preamble(body) {
        // Ambiguous or unterminated leading comment: refuse rather than guess.
        None => return true,
        Some(rest) => {
            let rest = rest.trim_start_matches(|c: char| !is_word_char(c));
            let end = rest.find(|c: char| !is_word_char(c)).unwrap_or(rest.len());
            let word = &rest[..end];
            LEADING.iter().any(|kw| kw.eq_ignore_ascii_case(word))
        }
    };

    leading_is_write
        || body
            .split(|c: char| !is_word_char(c))
            .any(|word| ANYWHERE.iter().any(|kw| kw.eq_ignore_ascii_case(word)))
}

/// Skip leading whitespace, opening parentheses and comments in `sql`.
///
/// Returns the text starting at the first other character. Returns `None`
/// when a leading comment is unterminated or lexer-dependent:
/// * a block comment containing a nested `/*` (PostgreSQL nests block
///   comments; some other lexers do not);
/// * a `--` comment containing a `\r` that is not part of a CRLF (PostgreSQL
///   ends line comments at `\r`; some other lexers only at `\n`).
fn skip_sql_preamble(sql: &str) -> Option<&str> {
    let mut rest = sql;
    loop {
        rest = rest.trim_start_matches(|c: char| c.is_whitespace() || c == '(');
        if let Some(after) = rest.strip_prefix("--") {
            let end = after.find('\n').unwrap_or(after.len());
            let comment = &after[..end];
            let comment = comment.strip_suffix('\r').unwrap_or(comment);
            if comment.contains('\r') {
                return None;
            }
            rest = &after[end..];
        } else if let Some(after) = rest.strip_prefix("/*") {
            let end = after.find("*/")?;
            if after[..end].contains("/*") {
                return None;
            }
            rest = &after[end + 2..];
        } else {
            return Some(rest);
        }
    }
}
356
#[cfg(test)]
mod tests {
    use super::*;

    /// Dispatch `body` against `McpConfig::default()` with no agent contract.
    async fn rpc(body: &str) -> Option<Value> {
        let cfg = McpConfig::default();
        McpServer::dispatch(body, &cfg, None).await
    }

    #[test]
    fn read_only_guardrail() {
        for write in ["INSERT INTO t VALUES (1)", "  drop table t", "CREATE TABLE t(x int)"] {
            assert!(is_write_sql(write), "expected a write: {}", write);
        }
        for read in ["SELECT * FROM t", "  with x as (select 1) select * from x"] {
            assert!(!is_write_sql(read), "expected a read: {}", read);
        }
    }

    #[tokio::test]
    async fn initialize_and_tools_list() {
        let init = rpc(r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{}}"#)
            .await
            .expect("initialize must be answered");
        assert_eq!(init["result"]["protocolVersion"], MCP_PROTOCOL_VERSION);
        assert_eq!(init["result"]["serverInfo"]["name"], "heliosproxy-mcp");

        let tools = rpc(r#"{"jsonrpc":"2.0","id":2,"method":"tools/list","params":{}}"#)
            .await
            .expect("tools/list must be answered");
        let names: Vec<&str> = tools["result"]["tools"]
            .as_array()
            .expect("tools must be an array")
            .iter()
            .map(|t| t["name"].as_str().expect("tool name must be a string"))
            .collect();
        for expected in ["query", "list_tables", "explain"] {
            assert!(names.contains(&expected), "missing tool {}", expected);
        }
    }

    #[tokio::test]
    async fn notification_has_no_response() {
        let reply = rpc(r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#).await;
        assert!(reply.is_none());
    }

    #[tokio::test]
    async fn read_only_blocks_write_tool_call() {
        // McpConfig::default() has read_only = true.
        let reply = rpc(
            r#"{"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"query","arguments":{"sql":"DELETE FROM t"}}}"#,
        )
        .await
        .expect("tools/call must be answered");
        assert_eq!(reply["result"]["isError"], true);
        let text = reply["result"]["content"][0]["text"]
            .as_str()
            .expect("tool result text must be a string");
        assert!(text.contains("read-only"));
    }
}