Skip to main content

heliosdb_proxy/
mcp.rs

1//! MCP (Model Context Protocol) agent gateway.
2//!
3//! When `[mcp] enabled = true`, the proxy exposes a native MCP server so AI
4//! agents call structured, policy-gated tools (`query`, `list_tables`,
5//! `explain`) instead of opening raw SQL connections. This is the AI-data-
6//! plane differentiator: every tool call goes through one auditable surface
7//! and a read-only-by-default guardrail, and runs over the proxy's backend
8//! PG-wire client so it is backend-agnostic (PostgreSQL or HeliosDB-Nano).
9//!
10//! Transport: JSON-RPC 2.0 over HTTP POST (the simplest MCP transport; an
11//! SSE/Streamable-HTTP upgrade is a follow-on). Methods implemented:
12//! `initialize`, `notifications/initialized`, `ping`, `tools/list`,
13//! `tools/call`.
14
15use std::sync::Arc;
16use std::time::Duration;
17
18use serde_json::{json, Value};
19use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
20use tokio::net::TcpListener;
21
22use crate::agent_contract::{self, AgentContract};
23use crate::backend::client::QueryResult;
24use crate::backend::types::TextValue;
25use crate::backend::{tls::default_client_config, BackendClient, BackendConfig, TlsMode};
26use crate::config::McpConfig;
27use crate::{ProxyError, Result};
28
/// MCP protocol revision reported as `protocolVersion` in the `initialize` result.
const MCP_PROTOCOL_VERSION: &str = "2024-11-05";
30
31/// The MCP gateway server.
32pub struct McpServer {
33    config: McpConfig,
34    contract: Option<AgentContract>,
35}
36
37impl McpServer {
38    pub fn new(config: McpConfig, contract: Option<AgentContract>) -> Self {
39        Self { config, contract }
40    }
41
42    /// Bind and serve the MCP HTTP endpoint until the task is dropped.
43    pub async fn run(self) -> Result<()> {
44        let listener = TcpListener::bind(&self.config.listen_address)
45            .await
46            .map_err(|e| {
47                ProxyError::Network(format!("MCP bind {}: {}", self.config.listen_address, e))
48            })?;
49        tracing::info!(addr = %self.config.listen_address, read_only = self.config.read_only,
50            contract = ?self.contract.as_ref().map(|c| &c.id), "MCP agent gateway listening");
51        let cfg = Arc::new(self.config);
52        let contract = Arc::new(self.contract);
53        loop {
54            let (stream, peer) = match listener.accept().await {
55                Ok(x) => x,
56                Err(e) => {
57                    tracing::warn!("MCP accept error: {}", e);
58                    continue;
59                }
60            };
61            let cfg = cfg.clone();
62            let contract = contract.clone();
63            tokio::spawn(async move {
64                if let Err(e) = Self::handle_connection(stream, cfg, contract).await {
65                    tracing::debug!(%peer, "MCP connection error: {}", e);
66                }
67            });
68        }
69    }
70
71    async fn handle_connection(
72        mut stream: tokio::net::TcpStream,
73        cfg: Arc<McpConfig>,
74        contract: Arc<Option<AgentContract>>,
75    ) -> Result<()> {
76        let (reader, mut writer) = stream.split();
77        let mut reader = BufReader::new(reader);
78        let mut line = String::new();
79        let mut content_length = 0usize;
80        // Read request line + headers.
81        use tokio::io::AsyncBufReadExt;
82        let mut first = true;
83        loop {
84            line.clear();
85            let n = reader
86                .read_line(&mut line)
87                .await
88                .map_err(|e| ProxyError::Network(format!("MCP read: {}", e)))?;
89            if n == 0 || line == "\r\n" {
90                break;
91            }
92            if first {
93                first = false; // request line; we accept any method/path
94            } else if line.to_ascii_lowercase().starts_with("content-length:") {
95                if let Some(v) = line.split(':').nth(1) {
96                    content_length = v.trim().parse().unwrap_or(0);
97                }
98            }
99        }
100        let body = if content_length > 0 {
101            let mut buf = vec![0u8; content_length];
102            reader
103                .read_exact(&mut buf)
104                .await
105                .map_err(|e| ProxyError::Network(format!("MCP body read: {}", e)))?;
106            String::from_utf8_lossy(&buf).to_string()
107        } else {
108            String::new()
109        };
110
111        let response = Self::dispatch(&body, &cfg, (*contract).as_ref()).await;
112        match response {
113            Some(v) => {
114                let payload = serde_json::to_string(&v).unwrap_or_else(|_| "{}".to_string());
115                Self::write_http(&mut writer, 200, "application/json", payload.as_bytes()).await
116            }
117            // Notifications get a bare 202 with no JSON-RPC body.
118            None => Self::write_http(&mut writer, 202, "application/json", b"").await,
119        }
120    }
121
122    /// Dispatch one JSON-RPC request. Returns `None` for notifications.
123    async fn dispatch(
124        body: &str,
125        cfg: &McpConfig,
126        contract: Option<&AgentContract>,
127    ) -> Option<Value> {
128        let req: Value = match serde_json::from_str(body) {
129            Ok(v) => v,
130            Err(e) => {
131                return Some(rpc_error(
132                    Value::Null,
133                    -32700,
134                    &format!("parse error: {}", e),
135                ))
136            }
137        };
138        let id = req.get("id").cloned().unwrap_or(Value::Null);
139        let method = req.get("method").and_then(|m| m.as_str()).unwrap_or("");
140        let params = req.get("params").cloned().unwrap_or(json!({}));
141
142        match method {
143            "initialize" => Some(rpc_ok(
144                id,
145                json!({
146                    "protocolVersion": MCP_PROTOCOL_VERSION,
147                    "serverInfo": { "name": "heliosproxy-mcp", "version": crate::VERSION },
148                    "capabilities": { "tools": { "listChanged": false } }
149                }),
150            )),
151            // Notifications (no id) — no response.
152            "notifications/initialized" | "notifications/cancelled" => None,
153            "ping" => Some(rpc_ok(id, json!({}))),
154            "tools/list" => Some(rpc_ok(id, json!({ "tools": Self::tool_defs(cfg) }))),
155            "tools/call" => Some(Self::handle_tool_call(id, &params, cfg, contract).await),
156            other => Some(rpc_error(
157                id,
158                -32601,
159                &format!("method not found: {}", other),
160            )),
161        }
162    }
163
164    fn tool_defs(cfg: &McpConfig) -> Value {
165        let query_desc = if cfg.read_only {
166            "Run a read-only SQL query and return rows. Writes/DDL are refused."
167        } else {
168            "Run a SQL query and return rows (or the command tag for writes)."
169        };
170        json!([
171            {
172                "name": "query",
173                "description": query_desc,
174                "inputSchema": {
175                    "type": "object",
176                    "properties": { "sql": { "type": "string", "description": "SQL to execute" } },
177                    "required": ["sql"]
178                }
179            },
180            {
181                "name": "list_tables",
182                "description": "List user tables (schema.table) in the connected database.",
183                "inputSchema": { "type": "object", "properties": {} }
184            },
185            {
186                "name": "explain",
187                "description": "Return the query plan for a SQL statement (EXPLAIN).",
188                "inputSchema": {
189                    "type": "object",
190                    "properties": { "sql": { "type": "string" } },
191                    "required": ["sql"]
192                }
193            }
194        ])
195    }
196
197    async fn handle_tool_call(
198        id: Value,
199        params: &Value,
200        cfg: &McpConfig,
201        contract: Option<&AgentContract>,
202    ) -> Value {
203        let name = params.get("name").and_then(|n| n.as_str()).unwrap_or("");
204        let args = params.get("arguments").cloned().unwrap_or(json!({}));
205
206        let result: std::result::Result<String, String> = match name {
207            "query" => {
208                let sql = args
209                    .get("sql")
210                    .and_then(|s| s.as_str())
211                    .unwrap_or("")
212                    .trim();
213                if sql.is_empty() {
214                    Err("missing 'sql'".to_string())
215                } else {
216                    match Self::check_policy(cfg, contract, sql) {
217                        Err(hint) => Err(hint),
218                        Ok(()) => Self::run_sql(cfg, sql).await.map(|r| format_result(&r)),
219                    }
220                }
221            }
222            "list_tables" => {
223                let sql = "SELECT table_schema, table_name FROM information_schema.tables \
224                           WHERE table_schema NOT IN ('pg_catalog','information_schema') \
225                           ORDER BY table_schema, table_name";
226                Self::run_sql(cfg, sql).await.map(|r| format_result(&r))
227            }
228            "explain" => {
229                let sql = args
230                    .get("sql")
231                    .and_then(|s| s.as_str())
232                    .unwrap_or("")
233                    .trim();
234                if sql.is_empty() {
235                    Err("missing 'sql'".to_string())
236                } else {
237                    match Self::check_policy(cfg, contract, sql) {
238                        Err(hint) => Err(hint),
239                        Ok(()) => Self::run_sql(cfg, &format!("EXPLAIN {}", sql))
240                            .await
241                            .map(|r| format_result(&r)),
242                    }
243                }
244            }
245            other => Err(format!("unknown tool: {}", other)),
246        };
247
248        match result {
249            Ok(text) => {
250                tracing::info!(tool = %name, "MCP tool call ok");
251                rpc_ok(
252                    id,
253                    json!({ "content": [{ "type": "text", "text": text }], "isError": false }),
254                )
255            }
256            Err(e) => {
257                tracing::info!(tool = %name, error = %e, "MCP tool call error");
258                // Tool errors are reported in-band (isError) per MCP, not as a
259                // protocol error, so the agent can read + self-correct.
260                rpc_ok(
261                    id,
262                    json!({ "content": [{ "type": "text", "text": e }], "isError": true }),
263                )
264            }
265        }
266    }
267
268    /// Gate a SQL statement: when an agent contract is configured, validate
269    /// against it and return a structured JSON repair hint on violation;
270    /// otherwise apply the plain read-only guardrail.
271    fn check_policy(
272        cfg: &McpConfig,
273        contract: Option<&AgentContract>,
274        sql: &str,
275    ) -> std::result::Result<(), String> {
276        if let Some(c) = contract {
277            agent_contract::validate(sql, c).map_err(|v| v.to_json())
278        } else if cfg.read_only && is_write_sql(sql) {
279            Err("write/DDL refused: the MCP gateway is read-only".to_string())
280        } else {
281            Ok(())
282        }
283    }
284
285    /// Connect to the configured backend, run one statement, return rows.
286    async fn run_sql(cfg: &McpConfig, sql: &str) -> std::result::Result<QueryResult, String> {
287        let bcfg = BackendConfig {
288            host: cfg.backend_host.clone(),
289            port: cfg.backend_port,
290            user: cfg.backend_user.clone(),
291            password: cfg.backend_password.clone(),
292            database: cfg.backend_database.clone(),
293            application_name: Some("heliosproxy-mcp".to_string()),
294            tls_mode: TlsMode::Disable,
295            connect_timeout: Duration::from_secs(5),
296            query_timeout: Duration::from_secs(30),
297            tls_config: default_client_config(),
298        };
299        let mut client = BackendClient::connect(&bcfg)
300            .await
301            .map_err(|e| format!("backend connect: {}", e))?;
302        let res = client.simple_query(sql).await.map_err(|e| format!("{}", e));
303        client.close().await;
304        res
305    }
306
307    async fn write_http(
308        writer: &mut tokio::net::tcp::WriteHalf<'_>,
309        status: u16,
310        content_type: &str,
311        body: &[u8],
312    ) -> Result<()> {
313        let head = format!(
314            "HTTP/1.1 {} {}\r\nContent-Type: {}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n",
315            status,
316            if status == 200 { "OK" } else { "Accepted" },
317            content_type,
318            body.len()
319        );
320        writer
321            .write_all(head.as_bytes())
322            .await
323            .map_err(|e| ProxyError::Network(format!("MCP write: {}", e)))?;
324        if !body.is_empty() {
325            writer
326                .write_all(body)
327                .await
328                .map_err(|e| ProxyError::Network(format!("MCP write: {}", e)))?;
329        }
330        Ok(())
331    }
332}
333
334fn rpc_ok(id: Value, result: Value) -> Value {
335    json!({ "jsonrpc": "2.0", "id": id, "result": result })
336}
337
338fn rpc_error(id: Value, code: i32, message: &str) -> Value {
339    json!({ "jsonrpc": "2.0", "id": id, "error": { "code": code, "message": message } })
340}
341
342/// Render a QueryResult as a compact text table for the agent.
343fn format_result(r: &QueryResult) -> String {
344    if r.columns.is_empty() {
345        return r.command_tag.clone();
346    }
347    let header: Vec<&str> = r.columns.iter().map(|c| c.name.as_str()).collect();
348    let mut out = String::new();
349    out.push_str(&header.join(" | "));
350    out.push('\n');
351    for row in &r.rows {
352        let cells: Vec<String> = row
353            .iter()
354            .map(|v| match v {
355                TextValue::Null => "NULL".to_string(),
356                TextValue::Text(s) => s.clone(),
357            })
358            .collect();
359        out.push_str(&cells.join(" | "));
360        out.push('\n');
361    }
362    out.push_str(&format!("({} rows)", r.rows.len()));
363    out
364}
365
/// Statement-leading keywords the read-only guardrail refuses: DML, DDL,
/// privileges, maintenance, session state and server-side side effects.
const WRITE_LEADING_KEYWORDS: &[&str] = &[
    "INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER", "TRUNCATE", "GRANT", "REVOKE",
    "COPY", "MERGE", "CALL", "DO", "VACUUM", "REINDEX", "CLUSTER", "LOCK", "COMMENT", "SET",
    "ANALYZE", "ANALYSE", "REFRESH", "IMPORT", "SECURITY", "REASSIGN", "LOAD", "CHECKPOINT",
    "DISCARD", "RESET", "PREPARE", "EXECUTE", "NOTIFY",
];

/// Bare words that turn an otherwise read-looking statement into a write:
/// data-modifying CTEs, `SELECT … INTO`, and `FOR UPDATE` row locks.
const WRITE_EMBEDDED_KEYWORDS: &[&str] = &["INSERT", "UPDATE", "DELETE", "MERGE", "INTO"];

/// Read-only guardrail: `true` when `sql` may modify data, schema, privileges,
/// or session/server state.
///
/// The text is lexed first. Whitespace, `--` and (nested) `/* */` comments,
/// `'…'` / `E'…'` strings, `"…"` identifiers and `$tag$…$tag$` bodies are
/// skipped. The text is then split on top-level `;`, because the backend's
/// simple-query path runs every statement in the string.
///
/// Each statement is classified as follows:
/// * A leading keyword from [`WRITE_LEADING_KEYWORDS`] makes it a write.
/// * `EXPLAIN [ (options) ] [ANALYZE] [VERBOSE] <stmt>` is judged by `<stmt>`,
///   since `EXPLAIN ANALYZE` really executes it.
/// * Any other statement containing a bare [`WRITE_EMBEDDED_KEYWORDS`] word
///   is a write. An unquoted column literally named e.g. `delete` is
///   therefore refused too (a deliberate false positive).
///
/// This is a best-effort lexical check, not a SQL parser. Side-effecting
/// *functions* called from a `SELECT` (e.g. `nextval`, `set_config`) are not
/// detected. Backslash is treated as an escape only inside `E'…'` strings,
/// which matches `standard_conforming_strings = on` (the PostgreSQL default).
fn is_write_sql(sql: &str) -> bool {
    split_sql_statements(sql)
        .iter()
        .any(|stmt| statement_is_write(stmt))
}

/// One lexical token of a SQL statement, as far as the guardrail cares.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SqlToken<'a> {
    /// An unquoted word: keyword, identifier or number.
    Word(&'a str),
    /// Any other single ASCII byte outside strings/comments (`(`, `)`, `*`, …).
    Punct(u8),
}

/// Lex `sql` into statements separated by top-level `;`, dropping empty ones.
fn split_sql_statements(sql: &str) -> Vec<Vec<SqlToken<'_>>> {
    let bytes = sql.as_bytes();
    let mut statements = Vec::new();
    let mut current: Vec<SqlToken<'_>> = Vec::new();
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            b';' => {
                if !current.is_empty() {
                    statements.push(std::mem::take(&mut current));
                }
                i += 1;
            }
            b'-' if bytes.get(i + 1) == Some(&b'-') => {
                i = match bytes[i..].iter().position(|&c| c == b'\n') {
                    Some(offset) => i + offset + 1,
                    None => bytes.len(),
                };
            }
            b'/' if bytes.get(i + 1) == Some(&b'*') => i = skip_block_comment(bytes, i),
            b'\'' => {
                // `E'…'` (the letter E glued to the quote) honours backslash escapes.
                let backslash_escapes = i > 0
                    && matches!(bytes[i - 1], b'e' | b'E')
                    && matches!(current.last(), Some(SqlToken::Word(w)) if w.eq_ignore_ascii_case("e"));
                i = skip_quoted(bytes, i, b'\'', backslash_escapes);
            }
            b'"' => i = skip_quoted(bytes, i, b'"', false),
            b'$' => i = skip_dollar_quote(bytes, i),
            b if b.is_ascii_whitespace() => i += 1,
            b if is_ident_byte(b) => {
                // Word boundaries fall on ASCII bytes (all bytes >= 0x80 are
                // word bytes), so slicing here is always on a char boundary.
                let start = i;
                i += 1;
                while i < bytes.len() && (is_ident_byte(bytes[i]) || bytes[i] == b'$') {
                    i += 1;
                }
                current.push(SqlToken::Word(&sql[start..i]));
            }
            b => {
                current.push(SqlToken::Punct(b));
                i += 1;
            }
        }
    }
    if !current.is_empty() {
        statements.push(current);
    }
    statements
}

/// Classify one lexed statement (see [`is_write_sql`]). Iterative, so a long
/// `EXPLAIN EXPLAIN …` chain cannot exhaust the stack.
fn statement_is_write(tokens: &[SqlToken<'_>]) -> bool {
    let mut tokens = tokens;
    loop {
        // First word of the statement, skipping leading punctuation like `(`.
        let first = tokens.iter().enumerate().find_map(|(idx, tok)| match tok {
            SqlToken::Word(w) => Some((idx, *w)),
            SqlToken::Punct(_) => None,
        });
        let (idx, keyword) = match first {
            Some(found) => found,
            None => return false,
        };
        let rest = &tokens[idx + 1..];
        if keyword.eq_ignore_ascii_case("EXPLAIN") {
            tokens = skip_explain_options(rest);
            continue;
        }
        if WRITE_LEADING_KEYWORDS
            .iter()
            .any(|kw| keyword.eq_ignore_ascii_case(kw))
        {
            return true;
        }
        return rest.iter().any(|tok| {
            matches!(tok, SqlToken::Word(w)
                if WRITE_EMBEDDED_KEYWORDS.iter().any(|kw| w.eq_ignore_ascii_case(kw)))
        });
    }
}

/// Skip an `EXPLAIN` option list: a parenthesised `( … )` group and/or the
/// bare `ANALYZE`/`ANALYSE`/`VERBOSE` words, returning the explained statement.
fn skip_explain_options<'t, 'a>(tokens: &'t [SqlToken<'a>]) -> &'t [SqlToken<'a>] {
    let mut rest = tokens;
    if rest.first() == Some(&SqlToken::Punct(b'(')) {
        let mut depth = 0usize;
        let mut close = rest.len();
        for (idx, tok) in rest.iter().enumerate() {
            match tok {
                SqlToken::Punct(b'(') => depth += 1,
                SqlToken::Punct(b')') => {
                    depth -= 1;
                    if depth == 0 {
                        close = idx + 1;
                        break;
                    }
                }
                _ => {}
            }
        }
        rest = &rest[close..];
    }
    while let Some(SqlToken::Word(w)) = rest.first() {
        if ["ANALYZE", "ANALYSE", "VERBOSE"]
            .iter()
            .any(|opt| w.eq_ignore_ascii_case(opt))
        {
            rest = &rest[1..];
        } else {
            break;
        }
    }
    rest
}

/// Identifier/keyword byte: ASCII alphanumeric, `_`, or any non-ASCII byte.
fn is_ident_byte(b: u8) -> bool {
    b.is_ascii_alphanumeric() || b == b'_' || b >= 0x80
}

/// Skip a `quote`-delimited literal starting at `start`. A doubled quote is
/// an escaped quote; with `backslash_escapes`, `\x` is skipped as a unit.
/// An unterminated literal runs to the end of input.
fn skip_quoted(bytes: &[u8], start: usize, quote: u8, backslash_escapes: bool) -> usize {
    let mut i = start + 1;
    while i < bytes.len() {
        let c = bytes[i];
        if backslash_escapes && c == b'\\' {
            i += 2;
        } else if c == quote {
            if bytes.get(i + 1) == Some(&quote) {
                i += 2;
            } else {
                return i + 1;
            }
        } else {
            i += 1;
        }
    }
    bytes.len()
}

/// Skip a (possibly nested) `/* … */` comment starting at `start`. An
/// unterminated comment runs to the end of input.
fn skip_block_comment(bytes: &[u8], start: usize) -> usize {
    let mut depth = 0usize;
    let mut i = start;
    while i < bytes.len() {
        if bytes[i] == b'/' && bytes.get(i + 1) == Some(&b'*') {
            depth += 1;
            i += 2;
        } else if bytes[i] == b'*' && bytes.get(i + 1) == Some(&b'/') {
            depth -= 1;
            i += 2;
            if depth == 0 {
                return i;
            }
        } else {
            i += 1;
        }
    }
    bytes.len()
}

/// Skip a `$tag$ … $tag$` dollar-quoted string starting at the `$` at
/// `start`. When the `$` does not open a dollar quote (e.g. a `$1`
/// parameter), only the `$` itself is skipped.
fn skip_dollar_quote(bytes: &[u8], start: usize) -> usize {
    let mut end = start + 1;
    // A tag may not start with a digit; `$$` has an empty tag.
    if matches!(bytes.get(end), Some(b) if !b.is_ascii_digit()) {
        while end < bytes.len() && is_ident_byte(bytes[end]) {
            end += 1;
        }
    }
    if bytes.get(end) != Some(&b'$') {
        return start + 1;
    }
    let delimiter = &bytes[start..=end];
    let body = end + 1;
    match bytes[body..]
        .windows(delimiter.len())
        .position(|w| w == delimiter)
    {
        Some(offset) => body + offset + delimiter.len(),
        None => bytes.len(),
    }
}
380
#[cfg(test)]
mod tests {
    use super::*;

    /// Push one raw JSON-RPC message through the dispatcher using the default
    /// configuration and no agent contract.
    async fn call(raw: &str) -> Option<Value> {
        McpServer::dispatch(raw, &McpConfig::default(), None).await
    }

    #[test]
    fn read_only_guardrail() {
        for write in ["INSERT INTO t VALUES (1)", "  drop table t", "CREATE TABLE t(x int)"] {
            assert!(is_write_sql(write), "expected a write: {}", write);
        }
        for read in ["SELECT * FROM t", "  with x as (select 1) select * from x"] {
            assert!(!is_write_sql(read), "expected a read: {}", read);
        }
    }

    #[tokio::test]
    async fn initialize_and_tools_list() {
        let init = call(r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{}}"#)
            .await
            .expect("initialize must produce a response");
        let server = &init["result"];
        assert_eq!(server["protocolVersion"], MCP_PROTOCOL_VERSION);
        assert_eq!(server["serverInfo"]["name"], "heliosproxy-mcp");

        let listed = call(r#"{"jsonrpc":"2.0","id":2,"method":"tools/list","params":{}}"#)
            .await
            .expect("tools/list must produce a response");
        let tool_names: Vec<&str> = listed["result"]["tools"]
            .as_array()
            .expect("tools must be an array")
            .iter()
            .map(|tool| tool["name"].as_str().expect("tool name must be a string"))
            .collect();
        for expected in ["query", "list_tables", "explain"] {
            assert!(tool_names.contains(&expected), "missing tool: {}", expected);
        }
    }

    #[tokio::test]
    async fn notification_has_no_response() {
        let reply = call(r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#).await;
        assert!(reply.is_none());
    }

    #[tokio::test]
    async fn read_only_blocks_write_tool_call() {
        // McpConfig::default() has read_only = true.
        let reply = call(
            r#"{"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"query","arguments":{"sql":"DELETE FROM t"}}}"#,
        )
        .await
        .expect("tools/call must produce a response");
        let outcome = &reply["result"];
        assert_eq!(outcome["isError"], true);
        let text = outcome["content"][0]["text"]
            .as_str()
            .expect("tool output must be text");
        assert!(text.contains("read-only"));
    }
}