1use std::sync::Arc;
16use std::time::Duration;
17
18use serde_json::{json, Value};
19use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
20use tokio::net::TcpListener;
21
22use crate::agent_contract::{self, AgentContract};
23use crate::backend::client::QueryResult;
24use crate::backend::types::TextValue;
25use crate::backend::{tls::default_client_config, BackendClient, BackendConfig, TlsMode};
26use crate::config::McpConfig;
27use crate::{ProxyError, Result};
28
/// MCP protocol revision this gateway reports as `protocolVersion` in its
/// `initialize` response.
const MCP_PROTOCOL_VERSION: &str = "2024-11-05";
30
31pub struct McpServer {
33 config: McpConfig,
34 contract: Option<AgentContract>,
35}
36
37impl McpServer {
38 pub fn new(config: McpConfig, contract: Option<AgentContract>) -> Self {
39 Self { config, contract }
40 }
41
42 pub async fn run(self) -> Result<()> {
44 let listener = TcpListener::bind(&self.config.listen_address)
45 .await
46 .map_err(|e| {
47 ProxyError::Network(format!("MCP bind {}: {}", self.config.listen_address, e))
48 })?;
49 tracing::info!(addr = %self.config.listen_address, read_only = self.config.read_only,
50 contract = ?self.contract.as_ref().map(|c| &c.id), "MCP agent gateway listening");
51 let cfg = Arc::new(self.config);
52 let contract = Arc::new(self.contract);
53 loop {
54 let (stream, peer) = match listener.accept().await {
55 Ok(x) => x,
56 Err(e) => {
57 tracing::warn!("MCP accept error: {}", e);
58 continue;
59 }
60 };
61 let cfg = cfg.clone();
62 let contract = contract.clone();
63 tokio::spawn(async move {
64 if let Err(e) = Self::handle_connection(stream, cfg, contract).await {
65 tracing::debug!(%peer, "MCP connection error: {}", e);
66 }
67 });
68 }
69 }
70
71 async fn handle_connection(
72 mut stream: tokio::net::TcpStream,
73 cfg: Arc<McpConfig>,
74 contract: Arc<Option<AgentContract>>,
75 ) -> Result<()> {
76 let (reader, mut writer) = stream.split();
77 let mut reader = BufReader::new(reader);
78 let mut line = String::new();
79 let mut content_length = 0usize;
80 use tokio::io::AsyncBufReadExt;
82 let mut first = true;
83 loop {
84 line.clear();
85 let n = reader
86 .read_line(&mut line)
87 .await
88 .map_err(|e| ProxyError::Network(format!("MCP read: {}", e)))?;
89 if n == 0 || line == "\r\n" {
90 break;
91 }
92 if first {
93 first = false; } else if line.to_ascii_lowercase().starts_with("content-length:") {
95 if let Some(v) = line.split(':').nth(1) {
96 content_length = v.trim().parse().unwrap_or(0);
97 }
98 }
99 }
100 let body = if content_length > 0 {
101 let mut buf = vec![0u8; content_length];
102 reader
103 .read_exact(&mut buf)
104 .await
105 .map_err(|e| ProxyError::Network(format!("MCP body read: {}", e)))?;
106 String::from_utf8_lossy(&buf).to_string()
107 } else {
108 String::new()
109 };
110
111 let response = Self::dispatch(&body, &cfg, (*contract).as_ref()).await;
112 match response {
113 Some(v) => {
114 let payload = serde_json::to_string(&v).unwrap_or_else(|_| "{}".to_string());
115 Self::write_http(&mut writer, 200, "application/json", payload.as_bytes()).await
116 }
117 None => Self::write_http(&mut writer, 202, "application/json", b"").await,
119 }
120 }
121
122 async fn dispatch(
124 body: &str,
125 cfg: &McpConfig,
126 contract: Option<&AgentContract>,
127 ) -> Option<Value> {
128 let req: Value = match serde_json::from_str(body) {
129 Ok(v) => v,
130 Err(e) => {
131 return Some(rpc_error(
132 Value::Null,
133 -32700,
134 &format!("parse error: {}", e),
135 ))
136 }
137 };
138 let id = req.get("id").cloned().unwrap_or(Value::Null);
139 let method = req.get("method").and_then(|m| m.as_str()).unwrap_or("");
140 let params = req.get("params").cloned().unwrap_or(json!({}));
141
142 match method {
143 "initialize" => Some(rpc_ok(
144 id,
145 json!({
146 "protocolVersion": MCP_PROTOCOL_VERSION,
147 "serverInfo": { "name": "heliosproxy-mcp", "version": crate::VERSION },
148 "capabilities": { "tools": { "listChanged": false } }
149 }),
150 )),
151 "notifications/initialized" | "notifications/cancelled" => None,
153 "ping" => Some(rpc_ok(id, json!({}))),
154 "tools/list" => Some(rpc_ok(id, json!({ "tools": Self::tool_defs(cfg) }))),
155 "tools/call" => Some(Self::handle_tool_call(id, ¶ms, cfg, contract).await),
156 other => Some(rpc_error(
157 id,
158 -32601,
159 &format!("method not found: {}", other),
160 )),
161 }
162 }
163
164 fn tool_defs(cfg: &McpConfig) -> Value {
165 let query_desc = if cfg.read_only {
166 "Run a read-only SQL query and return rows. Writes/DDL are refused."
167 } else {
168 "Run a SQL query and return rows (or the command tag for writes)."
169 };
170 json!([
171 {
172 "name": "query",
173 "description": query_desc,
174 "inputSchema": {
175 "type": "object",
176 "properties": { "sql": { "type": "string", "description": "SQL to execute" } },
177 "required": ["sql"]
178 }
179 },
180 {
181 "name": "list_tables",
182 "description": "List user tables (schema.table) in the connected database.",
183 "inputSchema": { "type": "object", "properties": {} }
184 },
185 {
186 "name": "explain",
187 "description": "Return the query plan for a SQL statement (EXPLAIN).",
188 "inputSchema": {
189 "type": "object",
190 "properties": { "sql": { "type": "string" } },
191 "required": ["sql"]
192 }
193 }
194 ])
195 }
196
197 async fn handle_tool_call(
198 id: Value,
199 params: &Value,
200 cfg: &McpConfig,
201 contract: Option<&AgentContract>,
202 ) -> Value {
203 let name = params.get("name").and_then(|n| n.as_str()).unwrap_or("");
204 let args = params.get("arguments").cloned().unwrap_or(json!({}));
205
206 let result: std::result::Result<String, String> = match name {
207 "query" => {
208 let sql = args
209 .get("sql")
210 .and_then(|s| s.as_str())
211 .unwrap_or("")
212 .trim();
213 if sql.is_empty() {
214 Err("missing 'sql'".to_string())
215 } else {
216 match Self::check_policy(cfg, contract, sql) {
217 Err(hint) => Err(hint),
218 Ok(()) => Self::run_sql(cfg, sql).await.map(|r| format_result(&r)),
219 }
220 }
221 }
222 "list_tables" => {
223 let sql = "SELECT table_schema, table_name FROM information_schema.tables \
224 WHERE table_schema NOT IN ('pg_catalog','information_schema') \
225 ORDER BY table_schema, table_name";
226 Self::run_sql(cfg, sql).await.map(|r| format_result(&r))
227 }
228 "explain" => {
229 let sql = args
230 .get("sql")
231 .and_then(|s| s.as_str())
232 .unwrap_or("")
233 .trim();
234 if sql.is_empty() {
235 Err("missing 'sql'".to_string())
236 } else {
237 match Self::check_policy(cfg, contract, sql) {
238 Err(hint) => Err(hint),
239 Ok(()) => Self::run_sql(cfg, &format!("EXPLAIN {}", sql))
240 .await
241 .map(|r| format_result(&r)),
242 }
243 }
244 }
245 other => Err(format!("unknown tool: {}", other)),
246 };
247
248 match result {
249 Ok(text) => {
250 tracing::info!(tool = %name, "MCP tool call ok");
251 rpc_ok(
252 id,
253 json!({ "content": [{ "type": "text", "text": text }], "isError": false }),
254 )
255 }
256 Err(e) => {
257 tracing::info!(tool = %name, error = %e, "MCP tool call error");
258 rpc_ok(
261 id,
262 json!({ "content": [{ "type": "text", "text": e }], "isError": true }),
263 )
264 }
265 }
266 }
267
268 fn check_policy(
272 cfg: &McpConfig,
273 contract: Option<&AgentContract>,
274 sql: &str,
275 ) -> std::result::Result<(), String> {
276 if let Some(c) = contract {
277 agent_contract::validate(sql, c).map_err(|v| v.to_json())
278 } else if cfg.read_only && is_write_sql(sql) {
279 Err("write/DDL refused: the MCP gateway is read-only".to_string())
280 } else {
281 Ok(())
282 }
283 }
284
285 async fn run_sql(cfg: &McpConfig, sql: &str) -> std::result::Result<QueryResult, String> {
287 let bcfg = BackendConfig {
288 host: cfg.backend_host.clone(),
289 port: cfg.backend_port,
290 user: cfg.backend_user.clone(),
291 password: cfg.backend_password.clone(),
292 database: cfg.backend_database.clone(),
293 application_name: Some("heliosproxy-mcp".to_string()),
294 tls_mode: TlsMode::Disable,
295 connect_timeout: Duration::from_secs(5),
296 query_timeout: Duration::from_secs(30),
297 tls_config: default_client_config(),
298 };
299 let mut client = BackendClient::connect(&bcfg)
300 .await
301 .map_err(|e| format!("backend connect: {}", e))?;
302 let res = client.simple_query(sql).await.map_err(|e| format!("{}", e));
303 client.close().await;
304 res
305 }
306
307 async fn write_http(
308 writer: &mut tokio::net::tcp::WriteHalf<'_>,
309 status: u16,
310 content_type: &str,
311 body: &[u8],
312 ) -> Result<()> {
313 let head = format!(
314 "HTTP/1.1 {} {}\r\nContent-Type: {}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n",
315 status,
316 if status == 200 { "OK" } else { "Accepted" },
317 content_type,
318 body.len()
319 );
320 writer
321 .write_all(head.as_bytes())
322 .await
323 .map_err(|e| ProxyError::Network(format!("MCP write: {}", e)))?;
324 if !body.is_empty() {
325 writer
326 .write_all(body)
327 .await
328 .map_err(|e| ProxyError::Network(format!("MCP write: {}", e)))?;
329 }
330 Ok(())
331 }
332}
333
334fn rpc_ok(id: Value, result: Value) -> Value {
335 json!({ "jsonrpc": "2.0", "id": id, "result": result })
336}
337
338fn rpc_error(id: Value, code: i32, message: &str) -> Value {
339 json!({ "jsonrpc": "2.0", "id": id, "error": { "code": code, "message": message } })
340}
341
342fn format_result(r: &QueryResult) -> String {
344 if r.columns.is_empty() {
345 return r.command_tag.clone();
346 }
347 let header: Vec<&str> = r.columns.iter().map(|c| c.name.as_str()).collect();
348 let mut out = String::new();
349 out.push_str(&header.join(" | "));
350 out.push('\n');
351 for row in &r.rows {
352 let cells: Vec<String> = row
353 .iter()
354 .map(|v| match v {
355 TextValue::Null => "NULL".to_string(),
356 TextValue::Text(s) => s.clone(),
357 })
358 .collect();
359 out.push_str(&cells.join(" | "));
360 out.push('\n');
361 }
362 out.push_str(&format!("({} rows)", r.rows.len()));
363 out
364}
365
/// Conservative read-only guardrail: returns `true` when `sql` may modify data,
/// schema, or server state and must therefore be refused by a read-only gateway.
///
/// `sql` is lexed with PostgreSQL's rules: comments (`--` ends at CR or LF,
/// `/* */` nests), `'...'`, `E'...'` and `$tag$...$tag$` literals, and `"..."`
/// identifiers. Keywords inside comments or literals are therefore ignored,
/// while keywords hidden *behind* them are still found. Returns `true` when:
///
/// - the input cannot be lexed (unterminated literal or comment);
/// - it holds more than one statement — only one statement is inspected, so
///   `SELECT 1; DELETE FROM t` is refused outright;
/// - the statement contains `INSERT`, `UPDATE`, `DELETE`, `MERGE` or `INTO` as
///   a bare keyword anywhere (data-modifying CTEs, `SELECT ... INTO`,
///   `SELECT ... FOR UPDATE`);
/// - its leading keyword is a write/DDL/maintenance keyword, or the statement
///   does not start with a keyword at all. The leading keyword is taken after
///   any opening parentheses and after `EXPLAIN` plus its options, since
///   `EXPLAIN ANALYZE` executes the statement.
///
/// This is a heuristic, not a parser. Calls to side-effecting functions
/// (e.g. `SELECT nextval('s')`) are not detected, and plain string literals
/// are assumed to follow `standard_conforming_strings = on`. A read-only
/// database role remains the authoritative control.
fn is_write_sql(sql: &str) -> bool {
    /// Lexical classes that matter to the guardrail.
    enum Tok {
        /// Unquoted word, upper-cased with ASCII rules.
        Word(String),
        Semicolon,
        Open,
        Close,
        /// Literal, quoted identifier, number, operator or other punctuation.
        Other,
    }

    const WRITE_LEADING: &[&str] = &[
        "INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER", "TRUNCATE", "GRANT", "REVOKE",
        "COPY", "MERGE", "CALL", "DO", "VACUUM", "REINDEX", "CLUSTER", "LOCK", "COMMENT", "SET",
        "ANALYZE", "ANALYSE", "REFRESH", "REASSIGN", "SECURITY", "IMPORT", "NOTIFY", "LOAD",
        "CHECKPOINT",
    ];
    const WRITE_ANYWHERE: &[&str] = &["INSERT", "UPDATE", "DELETE", "MERGE", "INTO"];
    /// Words that open a parenthesised query (as opposed to an option list).
    const QUERY_START: &[&str] = &["SELECT", "WITH", "VALUES", "TABLE"];

    fn is_word_start(c: char) -> bool {
        // PostgreSQL treats every non-ASCII character as an identifier character.
        c.is_ascii_alphabetic() || c == '_' || !c.is_ascii()
    }

    fn is_word_char(c: char) -> bool {
        is_word_start(c) || c.is_ascii_digit() || c == '$'
    }

    /// Index just past the literal opened at `open`, or `None` if unterminated.
    /// A doubled quote is an escaped quote; with `backslash` (`E'...'`), `\`
    /// escapes the following character.
    fn skip_quoted(s: &[char], open: usize, backslash: bool) -> Option<usize> {
        let quote = s[open];
        let mut i = open + 1;
        while i < s.len() {
            if backslash && s[i] == '\\' {
                i += 2;
            } else if s[i] == quote {
                if s.get(i + 1) == Some(&quote) {
                    i += 2;
                } else {
                    return Some(i + 1);
                }
            } else {
                i += 1;
            }
        }
        None
    }

    /// For a `$` at `open`, the index of the `$` ending a dollar-quote opening
    /// tag (`$$` or `$tag$`), or `None` if this `$` does not open one.
    fn dollar_tag_end(s: &[char], open: usize) -> Option<usize> {
        let mut i = open + 1;
        while let Some(&c) = s.get(i) {
            let in_tag = if i == open + 1 {
                is_word_start(c)
            } else {
                is_word_start(c) || c.is_ascii_digit()
            };
            if !in_tag {
                break;
            }
            i += 1;
        }
        if s.get(i) == Some(&'$') {
            Some(i)
        } else {
            None
        }
    }

    /// Splits `sql` into tokens, or `None` on an unterminated literal/comment.
    fn tokenize(sql: &str) -> Option<Vec<Tok>> {
        let s: Vec<char> = sql.chars().collect();
        let mut out = Vec::new();
        let mut i = 0;
        while i < s.len() {
            let c = s[i];
            let next = s.get(i + 1).copied();
            if matches!(c, ' ' | '\t' | '\n' | '\r' | '\x0B' | '\x0C') {
                i += 1;
            } else if c == '-' && next == Some('-') {
                while i < s.len() && s[i] != '\n' && s[i] != '\r' {
                    i += 1;
                }
            } else if c == '/' && next == Some('*') {
                let mut depth = 1usize;
                i += 2;
                while depth > 0 {
                    match (s.get(i), s.get(i + 1)) {
                        (Some('/'), Some('*')) => {
                            depth += 1;
                            i += 2;
                        }
                        (Some('*'), Some('/')) => {
                            depth -= 1;
                            i += 2;
                        }
                        (Some(_), _) => i += 1,
                        (None, _) => return None,
                    }
                }
            } else if c == '\'' || c == '"' {
                i = skip_quoted(&s, i, false)?;
                out.push(Tok::Other);
            } else if c == '$' {
                match dollar_tag_end(&s, i) {
                    Some(tag_end) => {
                        let tag = &s[i..=tag_end];
                        let body = &s[tag_end + 1..];
                        let close = body.windows(tag.len()).position(|w| w == tag)?;
                        i = tag_end + 1 + close + tag.len();
                    }
                    None => i += 1,
                }
                out.push(Tok::Other);
            } else if is_word_start(c) {
                let start = i;
                while i < s.len() && is_word_char(s[i]) {
                    i += 1;
                }
                let word = s[start..i].iter().collect::<String>().to_ascii_uppercase();
                if word == "E" && s.get(i) == Some(&'\'') {
                    i = skip_quoted(&s, i, true)?;
                    out.push(Tok::Other);
                } else {
                    out.push(Tok::Word(word));
                }
            } else {
                out.push(match c {
                    ';' => Tok::Semicolon,
                    '(' => Tok::Open,
                    ')' => Tok::Close,
                    _ => Tok::Other,
                });
                i += 1;
            }
        }
        Some(out)
    }

    /// Skips `EXPLAIN`'s options — a `( ... )` list or the legacy bare
    /// `ANALYZE`/`VERBOSE` words — and returns the explained statement.
    fn skip_explain_options(mut rest: &[Tok]) -> &[Tok] {
        let opens_query = match rest.get(1) {
            Some(Tok::Word(w)) => QUERY_START.contains(&w.as_str()),
            Some(Tok::Open) => true,
            _ => false,
        };
        if matches!(rest.first(), Some(Tok::Open)) && !opens_query {
            let mut depth = 0usize;
            for (idx, tok) in rest.iter().enumerate() {
                match tok {
                    Tok::Open => depth += 1,
                    Tok::Close => {
                        depth -= 1;
                        if depth == 0 {
                            return &rest[idx + 1..];
                        }
                    }
                    _ => {}
                }
            }
            return &[];
        }
        while let Some(Tok::Word(w)) = rest.first() {
            if !matches!(w.as_str(), "ANALYZE" | "ANALYSE" | "VERBOSE") {
                break;
            }
            rest = &rest[1..];
        }
        rest
    }

    /// The leading keyword after opening parentheses and any `EXPLAIN` prefix;
    /// `None` if the statement does not start with a keyword.
    fn leading_keyword(mut rest: &[Tok]) -> Option<&str> {
        loop {
            match rest.first()? {
                Tok::Open => rest = &rest[1..],
                Tok::Word(w) if w == "EXPLAIN" => rest = skip_explain_options(&rest[1..]),
                Tok::Word(w) => return Some(w.as_str()),
                _ => return None,
            }
        }
    }

    let tokens = match tokenize(sql) {
        Some(tokens) => tokens,
        // The server would reject it as well; refusing avoids guessing what it sees.
        None => return true,
    };
    let mut statements = tokens
        .split(|t| matches!(t, Tok::Semicolon))
        .filter(|stmt| !stmt.is_empty());
    let stmt = match statements.next() {
        Some(stmt) => stmt,
        None => return false,
    };
    if statements.next().is_some() {
        return true;
    }
    let modifies_anywhere = stmt
        .iter()
        .any(|t| matches!(t, Tok::Word(w) if WRITE_ANYWHERE.contains(&w.as_str())));
    modifies_anywhere || leading_keyword(stmt).map_or(true, |kw| WRITE_LEADING.contains(&kw))
}
380
#[cfg(test)]
mod tests {
    use super::*;

    /// Sends one JSON-RPC message through `McpServer::dispatch` with the default
    /// config and no agent contract. `read_only_blocks_write_tool_call` below
    /// relies on that default config being read-only.
    async fn send(body: &str) -> Option<Value> {
        let cfg = McpConfig::default();
        McpServer::dispatch(body, &cfg, None).await
    }

    #[test]
    fn read_only_guardrail() {
        for sql in ["INSERT INTO t VALUES (1)", " drop table t", "CREATE TABLE t(x int)"] {
            assert!(is_write_sql(sql), "should be classified as a write: {:?}", sql);
        }
        for sql in ["SELECT * FROM t", " with x as (select 1) select * from x"] {
            assert!(!is_write_sql(sql), "should be classified as read-only: {:?}", sql);
        }
    }

    #[tokio::test]
    async fn initialize_and_tools_list() {
        let init = send(r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{}}"#)
            .await
            .expect("initialize must be answered");
        assert_eq!(init["result"]["protocolVersion"], MCP_PROTOCOL_VERSION);
        assert_eq!(init["result"]["serverInfo"]["name"], "heliosproxy-mcp");

        let tools = send(r#"{"jsonrpc":"2.0","id":2,"method":"tools/list","params":{}}"#)
            .await
            .expect("tools/list must be answered");
        let listed = tools["result"]["tools"]
            .as_array()
            .expect("tools must be an array");
        for wanted in ["query", "list_tables", "explain"] {
            assert!(
                listed.iter().any(|tool| tool["name"] == wanted),
                "missing tool {}",
                wanted
            );
        }
    }

    #[tokio::test]
    async fn notification_has_no_response() {
        let reply = send(r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#).await;
        assert!(reply.is_none());
    }

    #[tokio::test]
    async fn read_only_blocks_write_tool_call() {
        let reply = send(
            r#"{"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"query","arguments":{"sql":"DELETE FROM t"}}}"#,
        )
        .await
        .expect("tools/call must be answered");
        assert_eq!(reply["result"]["isError"], true);
        let text = reply["result"]["content"][0]["text"]
            .as_str()
            .expect("error text must be a string");
        assert!(text.contains("read-only"));
    }
}
453}