1use std::sync::Arc;
16use std::time::Duration;
17
18use serde_json::{json, Value};
19use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
20use tokio::net::TcpListener;
21
22use crate::agent_contract::{self, AgentContract};
23use crate::backend::{tls::default_client_config, BackendClient, BackendConfig, TlsMode};
24use crate::backend::client::QueryResult;
25use crate::backend::types::TextValue;
26use crate::config::McpConfig;
27use crate::{ProxyError, Result};
28
/// MCP protocol revision reported as `protocolVersion` in the `initialize` response.
const MCP_PROTOCOL_VERSION: &str = "2024-11-05";
30
31pub struct McpServer {
33 config: McpConfig,
34 contract: Option<AgentContract>,
35}
36
37impl McpServer {
38 pub fn new(config: McpConfig, contract: Option<AgentContract>) -> Self {
39 Self { config, contract }
40 }
41
42 pub async fn run(self) -> Result<()> {
44 let listener = TcpListener::bind(&self.config.listen_address)
45 .await
46 .map_err(|e| ProxyError::Network(format!("MCP bind {}: {}", self.config.listen_address, e)))?;
47 tracing::info!(addr = %self.config.listen_address, read_only = self.config.read_only,
48 contract = ?self.contract.as_ref().map(|c| &c.id), "MCP agent gateway listening");
49 let cfg = Arc::new(self.config);
50 let contract = Arc::new(self.contract);
51 loop {
52 let (stream, peer) = match listener.accept().await {
53 Ok(x) => x,
54 Err(e) => {
55 tracing::warn!("MCP accept error: {}", e);
56 continue;
57 }
58 };
59 let cfg = cfg.clone();
60 let contract = contract.clone();
61 tokio::spawn(async move {
62 if let Err(e) = Self::handle_connection(stream, cfg, contract).await {
63 tracing::debug!(%peer, "MCP connection error: {}", e);
64 }
65 });
66 }
67 }
68
69 async fn handle_connection(
70 mut stream: tokio::net::TcpStream,
71 cfg: Arc<McpConfig>,
72 contract: Arc<Option<AgentContract>>,
73 ) -> Result<()> {
74 let (reader, mut writer) = stream.split();
75 let mut reader = BufReader::new(reader);
76 let mut line = String::new();
77 let mut content_length = 0usize;
78 use tokio::io::AsyncBufReadExt;
80 let mut first = true;
81 loop {
82 line.clear();
83 let n = reader
84 .read_line(&mut line)
85 .await
86 .map_err(|e| ProxyError::Network(format!("MCP read: {}", e)))?;
87 if n == 0 || line == "\r\n" {
88 break;
89 }
90 if first {
91 first = false; } else if line.to_ascii_lowercase().starts_with("content-length:") {
93 if let Some(v) = line.split(':').nth(1) {
94 content_length = v.trim().parse().unwrap_or(0);
95 }
96 }
97 }
98 let body = if content_length > 0 {
99 let mut buf = vec![0u8; content_length];
100 reader
101 .read_exact(&mut buf)
102 .await
103 .map_err(|e| ProxyError::Network(format!("MCP body read: {}", e)))?;
104 String::from_utf8_lossy(&buf).to_string()
105 } else {
106 String::new()
107 };
108
109 let response = Self::dispatch(&body, &cfg, (*contract).as_ref()).await;
110 match response {
111 Some(v) => {
112 let payload = serde_json::to_string(&v).unwrap_or_else(|_| "{}".to_string());
113 Self::write_http(&mut writer, 200, "application/json", payload.as_bytes()).await
114 }
115 None => Self::write_http(&mut writer, 202, "application/json", b"").await,
117 }
118 }
119
120 async fn dispatch(body: &str, cfg: &McpConfig, contract: Option<&AgentContract>) -> Option<Value> {
122 let req: Value = match serde_json::from_str(body) {
123 Ok(v) => v,
124 Err(e) => return Some(rpc_error(Value::Null, -32700, &format!("parse error: {}", e))),
125 };
126 let id = req.get("id").cloned().unwrap_or(Value::Null);
127 let method = req.get("method").and_then(|m| m.as_str()).unwrap_or("");
128 let params = req.get("params").cloned().unwrap_or(json!({}));
129
130 match method {
131 "initialize" => Some(rpc_ok(
132 id,
133 json!({
134 "protocolVersion": MCP_PROTOCOL_VERSION,
135 "serverInfo": { "name": "heliosproxy-mcp", "version": crate::VERSION },
136 "capabilities": { "tools": { "listChanged": false } }
137 }),
138 )),
139 "notifications/initialized" | "notifications/cancelled" => None,
141 "ping" => Some(rpc_ok(id, json!({}))),
142 "tools/list" => Some(rpc_ok(id, json!({ "tools": Self::tool_defs(cfg) }))),
143 "tools/call" => Some(Self::handle_tool_call(id, ¶ms, cfg, contract).await),
144 other => Some(rpc_error(id, -32601, &format!("method not found: {}", other))),
145 }
146 }
147
148 fn tool_defs(cfg: &McpConfig) -> Value {
149 let query_desc = if cfg.read_only {
150 "Run a read-only SQL query and return rows. Writes/DDL are refused."
151 } else {
152 "Run a SQL query and return rows (or the command tag for writes)."
153 };
154 json!([
155 {
156 "name": "query",
157 "description": query_desc,
158 "inputSchema": {
159 "type": "object",
160 "properties": { "sql": { "type": "string", "description": "SQL to execute" } },
161 "required": ["sql"]
162 }
163 },
164 {
165 "name": "list_tables",
166 "description": "List user tables (schema.table) in the connected database.",
167 "inputSchema": { "type": "object", "properties": {} }
168 },
169 {
170 "name": "explain",
171 "description": "Return the query plan for a SQL statement (EXPLAIN).",
172 "inputSchema": {
173 "type": "object",
174 "properties": { "sql": { "type": "string" } },
175 "required": ["sql"]
176 }
177 }
178 ])
179 }
180
181 async fn handle_tool_call(
182 id: Value,
183 params: &Value,
184 cfg: &McpConfig,
185 contract: Option<&AgentContract>,
186 ) -> Value {
187 let name = params.get("name").and_then(|n| n.as_str()).unwrap_or("");
188 let args = params.get("arguments").cloned().unwrap_or(json!({}));
189
190 let result: std::result::Result<String, String> = match name {
191 "query" => {
192 let sql = args.get("sql").and_then(|s| s.as_str()).unwrap_or("").trim();
193 if sql.is_empty() {
194 Err("missing 'sql'".to_string())
195 } else {
196 match Self::check_policy(cfg, contract, sql) {
197 Err(hint) => Err(hint),
198 Ok(()) => Self::run_sql(cfg, sql).await.map(|r| format_result(&r)),
199 }
200 }
201 }
202 "list_tables" => {
203 let sql = "SELECT table_schema, table_name FROM information_schema.tables \
204 WHERE table_schema NOT IN ('pg_catalog','information_schema') \
205 ORDER BY table_schema, table_name";
206 Self::run_sql(cfg, sql).await.map(|r| format_result(&r))
207 }
208 "explain" => {
209 let sql = args.get("sql").and_then(|s| s.as_str()).unwrap_or("").trim();
210 if sql.is_empty() {
211 Err("missing 'sql'".to_string())
212 } else {
213 match Self::check_policy(cfg, contract, sql) {
214 Err(hint) => Err(hint),
215 Ok(()) => Self::run_sql(cfg, &format!("EXPLAIN {}", sql))
216 .await
217 .map(|r| format_result(&r)),
218 }
219 }
220 }
221 other => Err(format!("unknown tool: {}", other)),
222 };
223
224 match result {
225 Ok(text) => {
226 tracing::info!(tool = %name, "MCP tool call ok");
227 rpc_ok(
228 id,
229 json!({ "content": [{ "type": "text", "text": text }], "isError": false }),
230 )
231 }
232 Err(e) => {
233 tracing::info!(tool = %name, error = %e, "MCP tool call error");
234 rpc_ok(
237 id,
238 json!({ "content": [{ "type": "text", "text": e }], "isError": true }),
239 )
240 }
241 }
242 }
243
244 fn check_policy(
248 cfg: &McpConfig,
249 contract: Option<&AgentContract>,
250 sql: &str,
251 ) -> std::result::Result<(), String> {
252 if let Some(c) = contract {
253 agent_contract::validate(sql, c).map_err(|v| v.to_json())
254 } else if cfg.read_only && is_write_sql(sql) {
255 Err("write/DDL refused: the MCP gateway is read-only".to_string())
256 } else {
257 Ok(())
258 }
259 }
260
261 async fn run_sql(cfg: &McpConfig, sql: &str) -> std::result::Result<QueryResult, String> {
263 let bcfg = BackendConfig {
264 host: cfg.backend_host.clone(),
265 port: cfg.backend_port,
266 user: cfg.backend_user.clone(),
267 password: cfg.backend_password.clone(),
268 database: cfg.backend_database.clone(),
269 application_name: Some("heliosproxy-mcp".to_string()),
270 tls_mode: TlsMode::Disable,
271 connect_timeout: Duration::from_secs(5),
272 query_timeout: Duration::from_secs(30),
273 tls_config: default_client_config(),
274 };
275 let mut client = BackendClient::connect(&bcfg)
276 .await
277 .map_err(|e| format!("backend connect: {}", e))?;
278 let res = client.simple_query(sql).await.map_err(|e| format!("{}", e));
279 client.close().await;
280 res
281 }
282
283 async fn write_http(
284 writer: &mut tokio::net::tcp::WriteHalf<'_>,
285 status: u16,
286 content_type: &str,
287 body: &[u8],
288 ) -> Result<()> {
289 let head = format!(
290 "HTTP/1.1 {} {}\r\nContent-Type: {}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n",
291 status,
292 if status == 200 { "OK" } else { "Accepted" },
293 content_type,
294 body.len()
295 );
296 writer
297 .write_all(head.as_bytes())
298 .await
299 .map_err(|e| ProxyError::Network(format!("MCP write: {}", e)))?;
300 if !body.is_empty() {
301 writer
302 .write_all(body)
303 .await
304 .map_err(|e| ProxyError::Network(format!("MCP write: {}", e)))?;
305 }
306 Ok(())
307 }
308}
309
310fn rpc_ok(id: Value, result: Value) -> Value {
311 json!({ "jsonrpc": "2.0", "id": id, "result": result })
312}
313
314fn rpc_error(id: Value, code: i32, message: &str) -> Value {
315 json!({ "jsonrpc": "2.0", "id": id, "error": { "code": code, "message": message } })
316}
317
318fn format_result(r: &QueryResult) -> String {
320 if r.columns.is_empty() {
321 return r.command_tag.clone();
322 }
323 let header: Vec<&str> = r.columns.iter().map(|c| c.name.as_str()).collect();
324 let mut out = String::new();
325 out.push_str(&header.join(" | "));
326 out.push('\n');
327 for row in &r.rows {
328 let cells: Vec<String> = row
329 .iter()
330 .map(|v| match v {
331 TextValue::Null => "NULL".to_string(),
332 TextValue::Text(s) => s.clone(),
333 })
334 .collect();
335 out.push_str(&cells.join(" | "));
336 out.push('\n');
337 }
338 out.push_str(&format!("({} rows)", r.rows.len()));
339 out
340}
341
/// Statement keywords that make a statement a write, or another side effect the
/// read-only gateway must refuse, when they lead it.
const WRITE_LEADING_KEYWORDS: &[&str] = &[
    "INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER", "TRUNCATE", "GRANT", "REVOKE",
    "COPY", "MERGE", "CALL", "DO", "VACUUM", "REINDEX", "CLUSTER", "LOCK", "COMMENT", "SET",
    // ANALYZE/ANALYSE also block `EXPLAIN`-tool input such as `ANALYZE DELETE …`.
    "ANALYZE", "ANALYSE", "REFRESH", "REASSIGN", "SECURITY", "IMPORT", "CHECKPOINT", "LOAD",
    "NOTIFY", "RESET", "EXECUTE",
];

/// Keywords that modify data wherever they appear. They catch:
/// - data-modifying CTEs;
/// - `SELECT … INTO`;
/// - `EXPLAIN ANALYZE <dml>`;
/// - `SELECT … FOR UPDATE`.
const WRITE_ANYWHERE_KEYWORDS: &[&str] = &["INSERT", "UPDATE", "DELETE", "MERGE", "INTO"];

/// Lexical token produced by [`scan_sql_tokens`]. Literals, quoted identifiers and
/// comments never surface as `Word`s.
#[derive(Debug, PartialEq, Eq)]
enum SqlToken {
    /// Bare keyword or identifier, upper-cased.
    Word(String),
    /// `(`
    OpenParen,
    /// `;`
    Semicolon,
    /// Any other significant token: literals, quoted identifiers, numbers, operators.
    Other,
}

/// Conservatively reports whether `sql` could modify data, schema or session state.
///
/// This is the lexical guard behind the gateway's read-only mode, and it errs
/// toward refusing. A statement counts as a write when any of these holds:
/// - its first keyword is a write/DDL/maintenance/session command; leading
///   whitespace, comments and `(` are skipped first;
/// - a data-modifying keyword appears anywhere outside string literals, quoted
///   identifiers and comments;
/// - the text holds more than one statement, i.e. a `;` followed by more tokens;
/// - a literal, quoted identifier or comment is left unterminated.
///
/// A backslash inside a plain `'…'` literal means different things depending on
/// the server's `standard_conforming_strings` setting. The text is therefore
/// scanned under both interpretations and counts as a write if either flags it.
fn is_write_sql(sql: &str) -> bool {
    [false, true]
        .iter()
        .any(|&backslash_escapes| match scan_sql_tokens(sql, backslash_escapes) {
            Some(tokens) => sql_tokens_write(&tokens),
            None => true,
        })
}

/// Splits `sql` into [`SqlToken`]s, skipping whitespace, comments and literal
/// contents. Returns `None` when a literal, quoted identifier or comment is
/// unterminated. Block comments are not nested, which only ever exposes more
/// text as code, so the scan stays conservative.
fn scan_sql_tokens(sql: &str, backslash_escapes: bool) -> Option<Vec<SqlToken>> {
    let b = sql.as_bytes();
    let mut tokens = Vec::new();
    let mut i = 0;
    while let Some(&c) = b.get(i) {
        match c {
            _ if c.is_ascii_whitespace() => i += 1,
            b'-' if b.get(i + 1) == Some(&b'-') => {
                i = b[i..]
                    .iter()
                    .position(|&x| x == b'\n')
                    .map_or(b.len(), |p| i + p + 1);
            }
            b'/' if b.get(i + 1) == Some(&b'*') => {
                let body = i + 2;
                i = body + find_bytes(&b[body..], b"*/")? + 2;
            }
            b'\'' => {
                i = skip_sql_quoted(b, i, b'\'', backslash_escapes)?;
                tokens.push(SqlToken::Other);
            }
            b'"' => {
                i = skip_sql_quoted(b, i, b'"', false)?;
                tokens.push(SqlToken::Other);
            }
            b'$' => {
                match sql_dollar_tag_len(&b[i..]) {
                    Some(tag_len) => {
                        let body = i + tag_len;
                        let tag = &b[i..body];
                        i = body + find_bytes(&b[body..], tag)? + tag_len;
                    }
                    // `$1`-style parameter or stray `$`.
                    None => i += 1,
                }
                tokens.push(SqlToken::Other);
            }
            b';' => {
                tokens.push(SqlToken::Semicolon);
                i += 1;
            }
            b'(' => {
                tokens.push(SqlToken::OpenParen);
                i += 1;
            }
            b'0'..=b'9' => {
                // Numbers end before any letter so `1INTO` still exposes `INTO`.
                while matches!(b.get(i), Some(&x) if x.is_ascii_digit() || x == b'.') {
                    i += 1;
                }
                tokens.push(SqlToken::Other);
            }
            _ if is_sql_ident_start(c) => {
                let start = i;
                // `$` may continue an identifier, so `a$b$` is one word, as in PostgreSQL.
                while matches!(b.get(i), Some(&x) if is_sql_ident_start(x) || x.is_ascii_digit() || x == b'$')
                {
                    i += 1;
                }
                let word = String::from_utf8_lossy(&b[start..i]).to_ascii_uppercase();
                if word == "E" && b.get(i) == Some(&b'\'') {
                    // E'…' literals always honour backslash escapes.
                    i = skip_sql_quoted(b, i, b'\'', true)?;
                    tokens.push(SqlToken::Other);
                } else {
                    tokens.push(SqlToken::Word(word));
                }
            }
            _ => {
                tokens.push(SqlToken::Other);
                i += 1;
            }
        }
    }
    Some(tokens)
}

/// Returns whether `c` can start an identifier: ASCII letter, `_` or any non-ASCII byte.
fn is_sql_ident_start(c: u8) -> bool {
    c.is_ascii_alphabetic() || c == b'_' || c >= 0x80
}

/// Returns the index just past the closing `quote` of the literal opening at
/// `start`, or `None` if it is unterminated. A doubled quote is an escaped quote.
/// With `backslash_escapes`, a backslash also escapes the byte after it.
fn skip_sql_quoted(b: &[u8], start: usize, quote: u8, backslash_escapes: bool) -> Option<usize> {
    let mut i = start + 1;
    while let Some(&c) = b.get(i) {
        if backslash_escapes && c == b'\\' {
            i += 2;
        } else if c == quote {
            if b.get(i + 1) == Some(&quote) {
                i += 2;
            } else {
                return Some(i + 1);
            }
        } else {
            i += 1;
        }
    }
    None
}

/// If `b` (which starts with `$`) opens a dollar quote (`$$` or `$tag$`), returns
/// the delimiter's length in bytes.
fn sql_dollar_tag_len(b: &[u8]) -> Option<usize> {
    let mut end = 1;
    if matches!(b.get(1), Some(&c) if is_sql_ident_start(c)) {
        end = 2;
        while matches!(b.get(end), Some(&c) if is_sql_ident_start(c) || c.is_ascii_digit()) {
            end += 1;
        }
    }
    if b.get(end) == Some(&b'$') {
        Some(end + 1)
    } else {
        None
    }
}

/// Returns the offset of the first occurrence of `needle` (non-empty) in `haystack`.
fn find_bytes(haystack: &[u8], needle: &[u8]) -> Option<usize> {
    haystack.windows(needle.len()).position(|w| w == needle)
}

/// Returns whether `token` is a bare word contained in `keywords`.
fn sql_word_in(token: &SqlToken, keywords: &[&str]) -> bool {
    matches!(token, SqlToken::Word(w) if keywords.contains(&w.as_str()))
}

/// Applies the write rules described on [`is_write_sql`] to a token stream.
fn sql_tokens_write(tokens: &[SqlToken]) -> bool {
    // More than one statement: anything but further `;` after the first `;`.
    if let Some(first) = tokens.iter().position(|t| *t == SqlToken::Semicolon) {
        if tokens[first..].iter().any(|t| *t != SqlToken::Semicolon) {
            return true;
        }
    }
    let leading = tokens.iter().find(|t| **t != SqlToken::OpenParen);
    matches!(leading, Some(t) if sql_word_in(t, WRITE_LEADING_KEYWORDS))
        || tokens.iter().any(|t| sql_word_in(t, WRITE_ANYWHERE_KEYWORDS))
}
356
#[cfg(test)]
mod tests {
    use super::*;

    /// Sends one raw JSON-RPC message through `dispatch` using the default
    /// config and no agent contract.
    async fn call(body: &str) -> Option<Value> {
        let cfg = McpConfig::default();
        McpServer::dispatch(body, &cfg, None).await
    }

    #[test]
    fn read_only_guardrail() {
        let writes = ["INSERT INTO t VALUES (1)", " drop table t", "CREATE TABLE t(x int)"];
        let reads = ["SELECT * FROM t", " with x as (select 1) select * from x"];
        for sql in writes {
            assert!(is_write_sql(sql), "expected {:?} to be classified as a write", sql);
        }
        for sql in reads {
            assert!(!is_write_sql(sql), "expected {:?} to be classified as a read", sql);
        }
    }

    #[tokio::test]
    async fn initialize_and_tools_list() {
        let init = call(r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{}}"#)
            .await
            .expect("initialize must be answered");
        assert_eq!(init["result"]["protocolVersion"], MCP_PROTOCOL_VERSION);
        assert_eq!(init["result"]["serverInfo"]["name"], "heliosproxy-mcp");

        let listed = call(r#"{"jsonrpc":"2.0","id":2,"method":"tools/list","params":{}}"#)
            .await
            .expect("tools/list must be answered");
        let tools = listed["result"]["tools"]
            .as_array()
            .expect("tools must be an array");
        for wanted in ["query", "list_tables", "explain"] {
            assert!(
                tools.iter().any(|tool| tool["name"] == wanted),
                "missing tool {:?}",
                wanted
            );
        }
    }

    #[tokio::test]
    async fn notification_has_no_response() {
        let reply = call(r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#).await;
        assert!(reply.is_none());
    }

    #[tokio::test]
    async fn read_only_blocks_write_tool_call() {
        let reply = call(
            r#"{"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"query","arguments":{"sql":"DELETE FROM t"}}}"#,
        )
        .await
        .expect("tools/call must be answered");
        let result = &reply["result"];
        assert_eq!(result["isError"], true);
        let text = result["content"][0]["text"]
            .as_str()
            .expect("error content must be text");
        assert!(text.contains("read-only"));
    }
}