Skip to main content

rivet/
mcp.rs

//! MCP (Model Context Protocol) server — read-only DB introspection tools.
//!
//! The public entry point is [`run_stdio`]; it is invoked by the `rivet-mcp`
//! binary in `src/bin/rivet-mcp.rs`. See that file for client integration
//! examples (Claude Desktop, Claude Code).
//!
//! The server speaks JSON-RPC 2.0 over stdin/stdout (one object per line).
//! All tools are read-only — no writes, no DDL.
9
10use serde_json::{Value, json};
11use std::io::{BufRead, Write};
12
13use crate::config::{TlsConfig, TlsMode};
14
15// ─── Public entry point ────────────────────────────────────────────────────
16
17/// Run the MCP server loop on stdin/stdout until EOF.
18///
19/// `pg_url` and `mysql_url` are used for tool calls against each database.
20/// Either (or both) may be `None`; tools for a missing database return an error.
21pub fn run_stdio(pg_url: Option<&str>, mysql_url: Option<&str>) -> anyhow::Result<()> {
22    let stdin = std::io::stdin();
23    let mut stdout = std::io::stdout();
24
25    for line in stdin.lock().lines() {
26        let line = line?;
27        if line.trim().is_empty() {
28            continue;
29        }
30        let msg: Value = match serde_json::from_str(&line) {
31            Ok(v) => v,
32            Err(_) => continue,
33        };
34
35        // Notifications have no `id` and require no response.
36        let id = msg.get("id").cloned();
37        let Some(id) = id else { continue };
38
39        let method = msg.get("method").and_then(|m| m.as_str()).unwrap_or("");
40        let envelope = match dispatch(method, &msg, pg_url, mysql_url) {
41            Ok(result) => json!({ "jsonrpc": "2.0", "id": id, "result": result }),
42            // The MCP client is prompt-injectable (it is an LLM); route every
43            // error through the shared redaction chokepoint so a connect error
44            // carrying a `scheme://user:password@host` URL never reaches it in
45            // cleartext (CWE-209).
46            Err(e) => json!({
47                "jsonrpc": "2.0",
48                "id": id,
49                "error": { "code": -32_000, "message": crate::redact::redact_error(&e) }
50            }),
51        };
52
53        writeln!(stdout, "{}", serde_json::to_string(&envelope)?)?;
54        stdout.flush()?;
55    }
56    Ok(())
57}
58
59// ─── Dispatcher ────────────────────────────────────────────────────────────
60
61fn dispatch(
62    method: &str,
63    msg: &Value,
64    pg_url: Option<&str>,
65    mysql_url: Option<&str>,
66) -> anyhow::Result<Value> {
67    match method {
68        "initialize" => Ok(json!({
69            "protocolVersion": "2024-11-05",
70            "capabilities": { "tools": {} },
71            "serverInfo": {
72                "name": "rivet-mcp",
73                "version": env!("CARGO_PKG_VERSION")
74            }
75        })),
76
77        "tools/list" => Ok(json!({ "tools": tools_list() })),
78
79        "tools/call" => {
80            let params = msg
81                .get("params")
82                .ok_or_else(|| anyhow::anyhow!("missing params"))?;
83            let name = params
84                .get("name")
85                .and_then(|n| n.as_str())
86                .ok_or_else(|| anyhow::anyhow!("missing tool name"))?;
87            let args = params.get("arguments").unwrap_or(&Value::Null);
88            // Per MCP spec, tool execution errors are text content, not JSON-RPC errors.
89            Ok(match call_tool(name, args, pg_url, mysql_url) {
90                Ok(v) => v,
91                // Redact before the error reaches the (prompt-injectable) client.
92                Err(e) => json!({
93                    "content": [{ "type": "text", "text": format!("error: {}", crate::redact::redact_error(&e)) }],
94                    "isError": true
95                }),
96            })
97        }
98
99        _ => Err(anyhow::anyhow!("unknown method: {method}")),
100    }
101}
102
103// ─── Tool registry ─────────────────────────────────────────────────────────
104
105fn tools_list() -> Value {
106    json!([
107        {
108            "name": "pg_active_sessions",
109            "description": "Show non-idle Postgres sessions: pid, state, wait event, query snippet, user, application. Useful to spot blocked or long-running queries during an export.",
110            "inputSchema": { "type": "object", "properties": {}, "required": [] }
111        },
112        {
113            "name": "pg_checkpoint_pressure",
114            "description": "Show pg_stat_bgwriter counters: checkpoints_timed, checkpoints_req (write-pressure indicator), write/sync times, and buffer stats. Rivet adaptive mode reacts to checkpoints_req delta.",
115            "inputSchema": { "type": "object", "properties": {}, "required": [] }
116        },
117        {
118            "name": "pg_table_stats",
119            "description": "Top 20 Postgres tables by live row count: n_live_tup, n_dead_tup, dead ratio, last vacuum/analyze timestamps.",
120            "inputSchema": {
121                "type": "object",
122                "properties": {
123                    "schema": {
124                        "type": "string",
125                        "description": "Restrict to a specific schema (default: all user schemas)"
126                    }
127                },
128                "required": []
129            }
130        },
131        {
132            "name": "pg_locks",
133            "description": "Show relation-level Postgres locks: pid, relation, mode, granted. Useful to diagnose lock contention during an export.",
134            "inputSchema": { "type": "object", "properties": {}, "required": [] }
135        },
136        {
137            "name": "pg_top_queries_by_io",
138            "description": "Top 10 queries by total I/O wait time from pg_stat_statements. Requires the pg_stat_statements extension; returns a clear error if unavailable.",
139            "inputSchema": { "type": "object", "properties": {}, "required": [] }
140        },
141        {
142            "name": "mysql_processlist",
143            "description": "Show MySQL SHOW PROCESSLIST: id, user, db, command, time, state, query snippet.",
144            "inputSchema": { "type": "object", "properties": {}, "required": [] }
145        },
146        {
147            "name": "mysql_key_metrics",
148            "description": "Key MySQL global status counters: Innodb_log_waits, Threads_running, Queries, Slow_queries, Innodb_row_lock_waits, Connections.",
149            "inputSchema": { "type": "object", "properties": {}, "required": [] }
150        },
151        {
152            "name": "mysql_table_stats",
153            "description": "Top 20 MySQL InnoDB tables by row count from information_schema.TABLES.",
154            "inputSchema": {
155                "type": "object",
156                "properties": {
157                    "schema": {
158                        "type": "string",
159                        "description": "Restrict to a specific schema/database (default: all non-system schemas)"
160                    }
161                },
162                "required": []
163            }
164        },
165        {
166            "name": "pgbouncer_pools",
167            "description": "Show pgBouncer pool stats (SHOW POOLS) via the pgBouncer admin interface. Requires PGBOUNCER_ADMIN_URL env var (e.g. postgresql://pgbouncer@127.0.0.1:6432/pgbouncer).",
168            "inputSchema": { "type": "object", "properties": {}, "required": [] }
169        },
170        {
171            "name": "pgbouncer_stats",
172            "description": "Show pgBouncer per-database stats (SHOW STATS). Requires PGBOUNCER_ADMIN_URL env var.",
173            "inputSchema": { "type": "object", "properties": {}, "required": [] }
174        }
175    ])
176}
177
178// ─── Tool dispatch ──────────────────────────────────────────────────────────
179
180fn call_tool(
181    name: &str,
182    args: &Value,
183    pg_url: Option<&str>,
184    mysql_url: Option<&str>,
185) -> anyhow::Result<Value> {
186    match name {
187        "pg_active_sessions" => text(pg_active_sessions(require_pg(pg_url)?)),
188        "pg_checkpoint_pressure" => text(pg_checkpoint_pressure(require_pg(pg_url)?)),
189        "pg_table_stats" => {
190            let schema = args.get("schema").and_then(|v| v.as_str());
191            text(pg_table_stats(require_pg(pg_url)?, schema))
192        }
193        "pg_locks" => text(pg_locks(require_pg(pg_url)?)),
194        "pg_top_queries_by_io" => text(pg_top_queries_by_io(require_pg(pg_url)?)),
195        "mysql_processlist" => text(mysql_processlist(require_mysql(mysql_url)?)),
196        "mysql_key_metrics" => text(mysql_key_metrics(require_mysql(mysql_url)?)),
197        "mysql_table_stats" => {
198            let schema = args.get("schema").and_then(|v| v.as_str());
199            text(mysql_table_stats(require_mysql(mysql_url)?, schema))
200        }
201        "pgbouncer_pools" => text(pgbouncer_query("SHOW POOLS")),
202        "pgbouncer_stats" => text(pgbouncer_query("SHOW STATS")),
203        other => Err(anyhow::anyhow!("unknown tool: {other}")),
204    }
205}
206
207fn require_pg(url: Option<&str>) -> anyhow::Result<&str> {
208    url.ok_or_else(|| {
209        anyhow::anyhow!("no Postgres URL configured — pass --pg-url or set DATABASE_URL")
210    })
211}
212
213fn require_mysql(url: Option<&str>) -> anyhow::Result<&str> {
214    url.ok_or_else(|| {
215        anyhow::anyhow!("no MySQL URL configured — pass --mysql-url or set DATABASE_URL")
216    })
217}
218
219fn text(result: anyhow::Result<String>) -> anyhow::Result<Value> {
220    // On the error branch the body is a stringified driver/connect error that
221    // may embed a `scheme://user:password@host` URL; redact before it reaches
222    // the prompt-injectable client (CWE-209). The Ok branch is tool output
223    // (already-formatted DB rows), which carries no credential material.
224    let body = result.unwrap_or_else(|e| format!("error: {}", crate::redact::redact_error(&e)));
225    Ok(json!({ "content": [{ "type": "text", "text": body }] }))
226}
227
228// ─── Postgres tools ────────────────────────────────────────────────────────
229
230/// Derive a [`TlsConfig`] from a connection URL's `sslmode` query parameter so
231/// the MCP diagnostics tools honor transport security (CWE-319). The MCP server
232/// runs before/without any YAML `tls:` block (it takes raw URLs on the command
233/// line), so the URL's `sslmode` is the only policy signal — the same source
234/// `rivet init` and the state backend use.
235///
236/// `require` / `verify-ca` / `verify-full` map to the enforced mode; everything
237/// else (missing, `disable`, `prefer`, `allow`, unrecognized) returns `None`
238/// (plaintext), which keeps loopback/local-dev URLs working. Last occurrence
239/// wins, matching libpq. Returning `None` for a remote plaintext URL is what
240/// lets the shared connect seams refuse it via `require_tls_or_loopback`.
241fn tls_config_from_url(url: &str) -> Option<TlsConfig> {
242    let (_, query) = url.split_once('?')?;
243    let mut mode = None;
244    for pair in query.split('&') {
245        let (key, value) = pair.split_once('=').unwrap_or((pair, ""));
246        if key != "sslmode" {
247            continue;
248        }
249        mode = match value {
250            "require" => Some(TlsMode::Require),
251            "verify-ca" => Some(TlsMode::VerifyCa),
252            "verify-full" => Some(TlsMode::VerifyFull),
253            _ => None,
254        };
255    }
256    mode.map(|mode| TlsConfig {
257        mode,
258        ..TlsConfig::default()
259    })
260}
261
262fn pg_connect(url: &str) -> anyhow::Result<postgres::Client> {
263    // Route through the shared TLS-aware seam (same path as doctor/check/init)
264    // so `sslmode=require|verify-ca|verify-full` is honored and remote plaintext
265    // is refused before any dial.
266    let tls = tls_config_from_url(url);
267    crate::source::postgres::connect_client(url, tls.as_ref())
268}
269
270/// Convert a Postgres row cell to a displayable string.
271fn pg_val(row: &postgres::Row, idx: usize) -> String {
272    // Try common types in priority order; most pg_stat* columns are int8/float8/text.
273    if let Ok(v) = row.try_get::<_, Option<String>>(idx) {
274        return v.unwrap_or_else(|| "NULL".into());
275    }
276    if let Ok(v) = row.try_get::<_, Option<i64>>(idx) {
277        return v.map(|n| n.to_string()).unwrap_or_else(|| "NULL".into());
278    }
279    if let Ok(v) = row.try_get::<_, Option<i32>>(idx) {
280        return v.map(|n| n.to_string()).unwrap_or_else(|| "NULL".into());
281    }
282    if let Ok(v) = row.try_get::<_, Option<f64>>(idx) {
283        return v
284            .map(|n| format!("{n:.2}"))
285            .unwrap_or_else(|| "NULL".into());
286    }
287    if let Ok(v) = row.try_get::<_, Option<bool>>(idx) {
288        return v.map(|b| b.to_string()).unwrap_or_else(|| "NULL".into());
289    }
290    if let Ok(v) = row.try_get::<_, Option<chrono::DateTime<chrono::Utc>>>(idx) {
291        return v
292            .map(|t| t.format("%Y-%m-%d %H:%M:%S").to_string())
293            .unwrap_or_else(|| "NULL".into());
294    }
295    if let Ok(v) = row.try_get::<_, Option<chrono::NaiveDateTime>>(idx) {
296        return v
297            .map(|t| t.format("%Y-%m-%d %H:%M:%S").to_string())
298            .unwrap_or_else(|| "NULL".into());
299    }
300    "?".into()
301}
302
303fn pg_rows_to_table(rows: &[postgres::Row]) -> String {
304    if rows.is_empty() {
305        return "(no rows)".into();
306    }
307    let headers: Vec<String> = rows[0]
308        .columns()
309        .iter()
310        .map(|c| c.name().to_string())
311        .collect();
312    let data: Vec<Vec<String>> = rows
313        .iter()
314        .map(|row| (0..headers.len()).map(|i| pg_val(row, i)).collect())
315        .collect();
316    ascii_table(&headers, &data)
317}
318
319fn pg_active_sessions(url: &str) -> anyhow::Result<String> {
320    let mut client = pg_connect(url)?;
321    let rows = client.query(
322        "SELECT pid::text, state, COALESCE(wait_event_type,'') AS wait_type,
323                COALESCE(wait_event,'') AS wait_event,
324                LEFT(COALESCE(query,''),100) AS query_snippet,
325                usename, application_name
326         FROM pg_stat_activity
327         WHERE state IS DISTINCT FROM 'idle'
328         ORDER BY state, pid",
329        &[],
330    )?;
331    Ok(format!(
332        "Active sessions ({})\n\n{}",
333        rows.len(),
334        pg_rows_to_table(&rows)
335    ))
336}
337
338fn pg_checkpoint_pressure(url: &str) -> anyhow::Result<String> {
339    let mut client = pg_connect(url)?;
340    let rows = client.query(
341        "SELECT checkpoints_timed, checkpoints_req,
342                ROUND(checkpoint_write_time) AS write_ms,
343                ROUND(checkpoint_sync_time) AS sync_ms,
344                buffers_checkpoint, buffers_clean, buffers_backend,
345                maxwritten_clean
346         FROM pg_stat_bgwriter",
347        &[],
348    )?;
349    Ok(format!("pg_stat_bgwriter\n\n{}", pg_rows_to_table(&rows)))
350}
351
/// Pick the SQL text for `pg_table_stats`.
///
/// The schema argument only chooses between the two fixed statements; its
/// value is never spliced into the text. With a schema the statement filters
/// on the `$1` bind placeholder, which the caller must supply. Without one it
/// takes no parameters and excludes the system schemas. The tool's read-only
/// promise depends on client input staying out of the SQL text.
fn pg_table_stats_sql(schema: Option<&str>) -> &'static str {
    if schema.is_some() {
        "SELECT schemaname, relname AS tablename, n_live_tup, n_dead_tup,
                    (n_dead_tup * 100 / NULLIF(n_live_tup + n_dead_tup, 0)) AS dead_pct,
                    COALESCE(to_char(last_vacuum, 'YYYY-MM-DD HH24:MI'), '-') AS last_vacuum,
                    COALESCE(to_char(last_analyze, 'YYYY-MM-DD HH24:MI'), '-') AS last_analyze
             FROM pg_stat_user_tables
             WHERE schemaname = $1::text
             ORDER BY n_live_tup DESC
             LIMIT 20"
    } else {
        "SELECT schemaname, relname AS tablename, n_live_tup, n_dead_tup,
                    (n_dead_tup * 100 / NULLIF(n_live_tup + n_dead_tup, 0)) AS dead_pct,
                    COALESCE(to_char(last_vacuum, 'YYYY-MM-DD HH24:MI'), '-') AS last_vacuum,
                    COALESCE(to_char(last_analyze, 'YYYY-MM-DD HH24:MI'), '-') AS last_analyze
             FROM pg_stat_user_tables
             WHERE schemaname NOT IN ('pg_catalog','information_schema','pg_toast')
             ORDER BY n_live_tup DESC
             LIMIT 20"
    }
}
379
380fn pg_table_stats(url: &str, schema: Option<&str>) -> anyhow::Result<String> {
381    let mut client = pg_connect(url)?;
382    let sql = pg_table_stats_sql(schema);
383    let rows = match schema {
384        Some(s) => client.query(sql, &[&s])?,
385        None => client.query(sql, &[])?,
386    };
387    Ok(format!(
388        "Table stats (top 20)\n\n{}",
389        pg_rows_to_table(&rows)
390    ))
391}
392
393fn pg_locks(url: &str) -> anyhow::Result<String> {
394    let mut client = pg_connect(url)?;
395    let rows = client.query(
396        "SELECT l.pid::text, c.relname AS relation, l.mode, l.granted::text
397         FROM pg_locks l
398         LEFT JOIN pg_class c ON c.oid = l.relation
399         WHERE l.relation IS NOT NULL
400         ORDER BY l.granted, l.pid",
401        &[],
402    )?;
403    if rows.is_empty() {
404        return Ok("No relation-level locks held.".into());
405    }
406    Ok(format!(
407        "Relation locks ({})\n\n{}",
408        rows.len(),
409        pg_rows_to_table(&rows)
410    ))
411}
412
413fn pg_top_queries_by_io(url: &str) -> anyhow::Result<String> {
414    let mut client = pg_connect(url)?;
415    // Check that pg_stat_statements is available.
416    let available: bool = client
417        .query_one(
418            "SELECT COUNT(*) > 0 FROM pg_extension WHERE extname = 'pg_stat_statements'",
419            &[],
420        )
421        .ok()
422        .and_then(|r| r.try_get::<_, bool>(0).ok())
423        .unwrap_or(false);
424    if !available {
425        return Ok("pg_stat_statements extension is not installed. \
426             Run: CREATE EXTENSION IF NOT EXISTS pg_stat_statements;"
427            .into());
428    }
429    let rows = client.query(
430        "SELECT LEFT(query, 80) AS query, calls,
431                ROUND(blk_read_time + blk_write_time) AS io_ms,
432                ROUND(total_exec_time) AS total_exec_ms
433         FROM pg_stat_statements
434         ORDER BY blk_read_time + blk_write_time DESC
435         LIMIT 10",
436        &[],
437    )?;
438    Ok(format!(
439        "Top 10 queries by I/O time\n\n{}",
440        pg_rows_to_table(&rows)
441    ))
442}
443
444// ─── MySQL tools ───────────────────────────────────────────────────────────
445
446fn mysql_pool(url: &str) -> anyhow::Result<mysql::Pool> {
447    // Route through the shared TLS-aware seam so `sslmode` from the URL enables
448    // `ssl_opts` (previously the MCP pool built Opts with no TLS — CWE-319) and
449    // remote plaintext is refused before any dial. The seam pins lean pool opts
450    // (no eager pre-connection) for these short-lived diagnostics calls.
451    let tls = tls_config_from_url(url);
452    crate::source::mysql::connect_pool(url, tls.as_ref())
453}
454
455fn mysql_rows_to_table(rows: &[Vec<String>], headers: &[String]) -> String {
456    if rows.is_empty() {
457        return "(no rows)".into();
458    }
459    ascii_table(headers, rows)
460}
461
462fn mysql_processlist(url: &str) -> anyhow::Result<String> {
463    use mysql::prelude::*;
464    let pool = mysql_pool(url)?;
465    let mut conn = pool.get_conn()?;
466    let mut result = conn.exec_iter("SHOW PROCESSLIST", ())?;
467    let cols: Vec<String> = result
468        .columns()
469        .as_ref()
470        .iter()
471        .map(|c| c.name_str().to_string())
472        .collect();
473    let row_set = result
474        .iter()
475        .ok_or_else(|| anyhow::anyhow!("no result set"))?;
476    let rows: Vec<Vec<String>> = row_set
477        .filter_map(|r| r.ok())
478        .map(|row| {
479            (0..cols.len())
480                .map(|i| match row.as_ref(i) {
481                    Some(mysql::Value::Bytes(b)) => String::from_utf8_lossy(b).into_owned(),
482                    Some(mysql::Value::Int(n)) => n.to_string(),
483                    Some(mysql::Value::UInt(n)) => n.to_string(),
484                    Some(mysql::Value::NULL) | None => "NULL".into(),
485                    _ => "?".into(),
486                })
487                .collect()
488        })
489        .collect();
490    Ok(format!(
491        "SHOW PROCESSLIST ({})\n\n{}",
492        rows.len(),
493        mysql_rows_to_table(&rows, &cols)
494    ))
495}
496
497fn mysql_key_metrics(url: &str) -> anyhow::Result<String> {
498    use mysql::prelude::*;
499    let pool = mysql_pool(url)?;
500    let mut conn = pool.get_conn()?;
501    let metrics = [
502        "Innodb_log_waits",
503        "Innodb_row_lock_waits",
504        "Innodb_row_lock_time_avg",
505        "Threads_running",
506        "Threads_connected",
507        "Queries",
508        "Slow_queries",
509        "Connections",
510        "Aborted_connects",
511    ];
512    let in_clause = metrics
513        .iter()
514        .map(|m| format!("'{m}'"))
515        .collect::<Vec<_>>()
516        .join(",");
517    let sql = format!(
518        "SELECT variable_name, variable_value \
519         FROM information_schema.global_status \
520         WHERE variable_name IN ({in_clause})"
521    );
522    let rows: Vec<(String, String)> = conn.query(sql)?;
523    if rows.is_empty() {
524        return Ok("(no metrics returned)".into());
525    }
526    let headers = vec!["metric".to_string(), "value".to_string()];
527    let data: Vec<Vec<String>> = rows.into_iter().map(|(k, v)| vec![k, v]).collect();
528    Ok(format!(
529        "MySQL key metrics\n\n{}",
530        ascii_table(&headers, &data)
531    ))
532}
533
/// Pick the SQL text for `mysql_table_stats`.
///
/// The schema argument only chooses between the two fixed statements; its
/// value is never spliced into the text. With a schema the statement filters
/// on a `?` bind placeholder, which the caller must supply. Without one it
/// takes no parameters and excludes the system schemas. The tool's read-only
/// promise depends on client input staying out of the SQL text.
fn mysql_table_stats_sql(schema: Option<&str>) -> &'static str {
    const ONE_SCHEMA: &str = concat!(
        "SELECT table_schema, table_name, table_rows, ",
        "data_length, index_length, engine ",
        "FROM information_schema.TABLES ",
        "WHERE table_type = 'BASE TABLE' AND table_schema = ? ",
        "ORDER BY table_rows DESC ",
        "LIMIT 20",
    );
    const ALL_USER_SCHEMAS: &str = concat!(
        "SELECT table_schema, table_name, table_rows, ",
        "data_length, index_length, engine ",
        "FROM information_schema.TABLES ",
        "WHERE table_type = 'BASE TABLE' ",
        "AND table_schema NOT IN ('information_schema','performance_schema','mysql','sys') ",
        "ORDER BY table_rows DESC ",
        "LIMIT 20",
    );

    if schema.is_some() {
        ONE_SCHEMA
    } else {
        ALL_USER_SCHEMAS
    }
}
558
559fn mysql_table_stats(url: &str, schema: Option<&str>) -> anyhow::Result<String> {
560    use mysql::prelude::*;
561    let pool = mysql_pool(url)?;
562    let mut conn = pool.get_conn()?;
563    let sql = mysql_table_stats_sql(schema);
564    let mut result = match schema {
565        Some(s) => conn.exec_iter(sql, (s,))?,
566        None => conn.exec_iter(sql, ())?,
567    };
568    let cols: Vec<String> = result
569        .columns()
570        .as_ref()
571        .iter()
572        .map(|c| c.name_str().to_string())
573        .collect();
574    let row_set = result
575        .iter()
576        .ok_or_else(|| anyhow::anyhow!("no result set"))?;
577    let rows: Vec<Vec<String>> = row_set
578        .filter_map(|r| r.ok())
579        .map(|row| {
580            (0..cols.len())
581                .map(|i| match row.as_ref(i) {
582                    Some(mysql::Value::Bytes(b)) => String::from_utf8_lossy(b).into_owned(),
583                    Some(mysql::Value::Int(n)) => n.to_string(),
584                    Some(mysql::Value::UInt(n)) => n.to_string(),
585                    Some(mysql::Value::NULL) | None => "NULL".into(),
586                    _ => "?".into(),
587                })
588                .collect()
589        })
590        .collect();
591    Ok(format!(
592        "Table stats (top 20)\n\n{}",
593        mysql_rows_to_table(&rows, &cols)
594    ))
595}
596
597// ─── pgBouncer tools ───────────────────────────────────────────────────────
598
599fn pgbouncer_query(sql: &str) -> anyhow::Result<String> {
600    let admin_url = std::env::var("PGBOUNCER_ADMIN_URL").map_err(|_| {
601        anyhow::anyhow!(
602            "PGBOUNCER_ADMIN_URL not set. \
603             Example: postgresql://pgbouncer@127.0.0.1:6432/pgbouncer"
604        )
605    })?;
606    let mut client = pg_connect(&admin_url)?;
607    let rows = client.query(sql, &[])?;
608    Ok(pg_rows_to_table(&rows))
609}
610
611// ─── ASCII table formatter ─────────────────────────────────────────────────
612
/// Render `headers` and `rows` as a plain-text table.
///
/// Cells are left-aligned and joined with ` | `; a dashed separator joined
/// with `-+-` follows the header line. Each column is as wide as the longest
/// of its header and cells, measured with `str::len` (bytes, not characters).
/// Cells beyond the header count are emitted unpadded. When the rendered body
/// is empty, only the header and separator lines are returned.
fn ascii_table(headers: &[impl AsRef<str>], rows: &[Vec<String>]) -> String {
    /// Left-pad each cell to its column width and join with ` | `.
    fn render_line<S: AsRef<str>>(cells: &[S], widths: &[usize]) -> String {
        let mut line = String::new();
        for (idx, cell) in cells.iter().enumerate() {
            if idx > 0 {
                line.push_str(" | ");
            }
            // Cells without a matching header column get no padding.
            let width = widths.get(idx).copied().unwrap_or(0);
            line.push_str(&format!("{:<width$}", cell.as_ref(), width = width));
        }
        line
    }

    // Start from the header widths, then widen per column; `zip` ignores any
    // cells past the last header column.
    let mut widths: Vec<usize> = headers.iter().map(|h| h.as_ref().len()).collect();
    for row in rows {
        for (slot, cell) in widths.iter_mut().zip(row) {
            *slot = (*slot).max(cell.len());
        }
    }

    let header_line = render_line(headers, &widths);
    let separator = widths
        .iter()
        .map(|w| "-".repeat(*w))
        .collect::<Vec<_>>()
        .join("-+-");
    let body = rows
        .iter()
        .map(|row| render_line(row, &widths))
        .collect::<Vec<_>>()
        .join("\n");

    let mut table = header_line;
    table.push('\n');
    table.push_str(&separator);
    if !body.is_empty() {
        table.push('\n');
        table.push_str(&body);
    }
    table
}
651
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an owned table row from string literals.
    fn row(cells: &[&str]) -> Vec<String> {
        cells.iter().map(|cell| cell.to_string()).collect()
    }

    #[test]
    fn ascii_table_widens_columns_to_longest_cell() {
        let headers = ["pid", "state"];
        let rows = vec![row(&["1", "active"]), row(&["10000", "idle"])];
        let rendered = ascii_table(&headers, &rows);
        let lines: Vec<&str> = rendered.lines().collect();
        assert_eq!(lines.len(), 4, "header + separator + 2 rows");
        // The pid column grows to 5 ("10000") and the state column to 6
        // ("active"). The separator joins the dashes with "-+-", i.e. one
        // extra dash on each side of the '+'.
        let expected = [
            "pid   | state ",
            "------+-------",
            "1     | active",
            "10000 | idle  ",
        ];
        for (got, want) in lines.iter().zip(expected.iter()) {
            assert_eq!(got, want);
        }
    }

    #[test]
    fn ascii_table_renders_header_only_when_no_rows() {
        let headers = ["col_a", "col_b"];
        // With no body there is just the header line and the separator.
        let rendered = ascii_table(&headers, &[]);
        assert_eq!(rendered, "col_a | col_b\n------+------");
    }

    // Hostile `schema` values an MCP client (an LLM) could supply. Each one
    // stays inside a single SELECT, so single-statement defaults would not
    // stop it — only bind parameters do.
    const HOSTILE_PG: &str = "x' UNION SELECT usename, passwd, 0, 0, 0, '-', '-' FROM pg_shadow --";
    const HOSTILE_MYSQL: &str =
        "x' UNION SELECT user, authentication_string, 0, 0, 0, 'x' FROM mysql.user -- ";

    #[test]
    fn pg_table_stats_sql_binds_schema_instead_of_interpolating() {
        let sql = pg_table_stats_sql(Some(HOSTILE_PG));
        let uses_placeholder = sql.contains("schemaname = $1");
        assert!(
            uses_placeholder,
            "schema filter must use a bind placeholder, got: {sql}"
        );
        let leaked = sql.contains(HOSTILE_PG) || sql.contains("UNION");
        assert!(
            !leaked,
            "client input must never land in the SQL text, got: {sql}"
        );
        // Whatever the schema value, the SQL text is the same.
        assert_eq!(sql, pg_table_stats_sql(Some("public")));
    }

    #[test]
    fn pg_table_stats_sql_no_schema_is_static_with_no_placeholder() {
        let sql = pg_table_stats_sql(None);
        assert!(sql.contains("schemaname NOT IN"));
        assert!(!sql.contains("$1"), "fallback takes no bind params: {sql}");
    }

    #[test]
    fn mysql_table_stats_sql_binds_schema_instead_of_interpolating() {
        let sql = mysql_table_stats_sql(Some(HOSTILE_MYSQL));
        let uses_placeholder = sql.contains("table_schema = ?");
        assert!(
            uses_placeholder,
            "schema filter must use a bind placeholder, got: {sql}"
        );
        let leaked = sql.contains(HOSTILE_MYSQL) || sql.contains("UNION");
        assert!(
            !leaked,
            "client input must never land in the SQL text, got: {sql}"
        );
        assert_eq!(sql, mysql_table_stats_sql(Some("appdb")));
    }

    #[test]
    fn mysql_table_stats_sql_no_schema_is_static_with_no_placeholder() {
        let sql = mysql_table_stats_sql(None);
        assert!(sql.contains("table_schema NOT IN"));
        assert!(!sql.contains('?'), "fallback takes no bind params: {sql}");
    }

    // ── V10/V18: sslmode → TlsConfig derivation ──────────────────────────────

    #[test]
    fn tls_config_from_url_enforces_when_sslmode_requested() {
        let cases = [
            (
                "postgresql://u:p@db.prod:5432/d?sslmode=require",
                TlsMode::Require,
            ),
            (
                "postgresql://u:p@db.prod/d?sslmode=verify-ca",
                TlsMode::VerifyCa,
            ),
            (
                "mysql://u:p@db.prod:3306/d?sslmode=verify-full",
                TlsMode::VerifyFull,
            ),
        ];
        for (url, want) in cases {
            let Some(cfg) = tls_config_from_url(url) else {
                panic!("expected enforced TLS for {url}");
            };
            assert_eq!(cfg.mode, want, "url {url}");
            assert!(cfg.mode.is_enforced(), "url {url} must enforce TLS");
        }
    }

    #[test]
    fn tls_config_from_url_none_for_plaintext_or_missing() {
        // Missing, disable, prefer/allow, wrong case, unrecognized or empty
        // values all map to None (plaintext); on a remote host that is what
        // makes the shared seam refuse the dial.
        let plaintext_urls = [
            "postgresql://u:p@localhost/d",
            "mysql://u:p@127.0.0.1:3306/d",
            "postgresql://u:p@db/d?sslmode=disable",
            "postgresql://u:p@db/d?sslmode=prefer",
            "postgresql://u:p@db/d?sslmode=allow",
            "postgresql://u:p@db/d?sslmode=REQUIRE",
            "postgresql://u:p@db/d?sslmode=garbage",
            "postgresql://u:p@db/d?sslmode",
            "postgresql://u:p@db/d?sslmode=",
        ];
        for url in plaintext_urls {
            assert!(tls_config_from_url(url).is_none(), "url {url} must be None");
        }
    }

    #[test]
    fn tls_config_from_url_exact_key_and_last_occurrence_wins() {
        // `xsslmode` is another parameter; only the exact `sslmode` key counts.
        let other_key = tls_config_from_url("postgresql://u:p@db/d?xsslmode=require");
        assert!(other_key.is_none());

        // `sslmode` is picked up even in the middle of the query string.
        let mid_query = tls_config_from_url(
            "postgresql://u:p@db/d?connect_timeout=10&sslmode=require&application_name=x",
        )
        .expect("enforced");
        assert_eq!(mid_query.mode, TlsMode::Require);

        // The last occurrence wins (matches libpq).
        let overridden =
            tls_config_from_url("postgresql://u:p@db/d?sslmode=require&sslmode=disable");
        assert!(overridden.is_none());
    }

    // ── V11: MCP error emission must pass through the redaction chokepoint ─────

    #[test]
    fn sec_mcp_error_is_redacted() {
        // Connect errors commonly stringify the URL they failed to reach,
        // `user:password@host` included. The MCP client is an LLM (prompt-
        // injectable), so the tool-output `text()` path has to redact before
        // the password reaches it.
        let connect_failure = anyhow::anyhow!(
            "could not connect to postgresql://rivet:s3cret@db.prod:5432/orders: timeout"
        );
        let envelope = text(Err(connect_failure)).expect("text() always returns Ok envelope");
        let body = envelope["content"][0]["text"]
            .as_str()
            .expect("text content present");
        assert!(
            !body.contains("s3cret"),
            "password must be redacted in MCP error output: {body}"
        );
        assert!(
            body.contains("postgresql://REDACTED@db.prod:5432/orders"),
            "host/path retained, userinfo redacted: {body}"
        );
    }

    #[test]
    fn ascii_table_handles_unicode_byte_width() {
        // `String::len()` counts bytes, not graphemes, and a Cyrillic letter
        // takes 2 bytes. This pins the current behavior so that any future
        // width-aware change is a deliberate one.
        let headers = ["x"];
        let rows = vec![row(&["ы"])]; // 2 bytes
        let rendered = ascii_table(&headers, &rows);
        // Column width = max(1, 2) = 2 bytes, so the header is padded to 2.
        assert!(rendered.contains("x "), "header padded to byte-width 2");
    }
}