Skip to main content

rivet/
mcp.rs

1//! MCP (Model Context Protocol) server — read-only DB introspection tools.
2//!
3//! The public entry point is [`run_stdio`]; it is invoked by the `rivet-mcp`
4//! binary in `src/bin/rivet-mcp.rs`. See that file for client integration
5//! examples (Claude Desktop, Claude Code).
6//!
7//! The server speaks JSON-RPC 2.0 over stdin/stdout (one object per line).
8//! All tools are read-only — no writes, no DDL.
9
10use serde_json::{Value, json};
11use std::io::{BufRead, Write};
12
13// ─── Public entry point ────────────────────────────────────────────────────
14
15/// Run the MCP server loop on stdin/stdout until EOF.
16///
17/// `pg_url` and `mysql_url` are used for tool calls against each database.
18/// Either (or both) may be `None`; tools for a missing database return an error.
19pub fn run_stdio(pg_url: Option<&str>, mysql_url: Option<&str>) -> anyhow::Result<()> {
20    let stdin = std::io::stdin();
21    let mut stdout = std::io::stdout();
22
23    for line in stdin.lock().lines() {
24        let line = line?;
25        if line.trim().is_empty() {
26            continue;
27        }
28        let msg: Value = match serde_json::from_str(&line) {
29            Ok(v) => v,
30            Err(_) => continue,
31        };
32
33        // Notifications have no `id` and require no response.
34        let id = msg.get("id").cloned();
35        let Some(id) = id else { continue };
36
37        let method = msg.get("method").and_then(|m| m.as_str()).unwrap_or("");
38        let envelope = match dispatch(method, &msg, pg_url, mysql_url) {
39            Ok(result) => json!({ "jsonrpc": "2.0", "id": id, "result": result }),
40            Err(e) => json!({
41                "jsonrpc": "2.0",
42                "id": id,
43                "error": { "code": -32_000, "message": e.to_string() }
44            }),
45        };
46
47        writeln!(stdout, "{}", serde_json::to_string(&envelope)?)?;
48        stdout.flush()?;
49    }
50    Ok(())
51}
52
53// ─── Dispatcher ────────────────────────────────────────────────────────────
54
55fn dispatch(
56    method: &str,
57    msg: &Value,
58    pg_url: Option<&str>,
59    mysql_url: Option<&str>,
60) -> anyhow::Result<Value> {
61    match method {
62        "initialize" => Ok(json!({
63            "protocolVersion": "2024-11-05",
64            "capabilities": { "tools": {} },
65            "serverInfo": {
66                "name": "rivet-mcp",
67                "version": env!("CARGO_PKG_VERSION")
68            }
69        })),
70
71        "tools/list" => Ok(json!({ "tools": tools_list() })),
72
73        "tools/call" => {
74            let params = msg
75                .get("params")
76                .ok_or_else(|| anyhow::anyhow!("missing params"))?;
77            let name = params
78                .get("name")
79                .and_then(|n| n.as_str())
80                .ok_or_else(|| anyhow::anyhow!("missing tool name"))?;
81            let args = params.get("arguments").unwrap_or(&Value::Null);
82            // Per MCP spec, tool execution errors are text content, not JSON-RPC errors.
83            Ok(match call_tool(name, args, pg_url, mysql_url) {
84                Ok(v) => v,
85                Err(e) => json!({
86                    "content": [{ "type": "text", "text": format!("error: {e}") }],
87                    "isError": true
88                }),
89            })
90        }
91
92        _ => Err(anyhow::anyhow!("unknown method: {method}")),
93    }
94}
95
96// ─── Tool registry ─────────────────────────────────────────────────────────
97
98fn tools_list() -> Value {
99    json!([
100        {
101            "name": "pg_active_sessions",
102            "description": "Show non-idle Postgres sessions: pid, state, wait event, query snippet, user, application. Useful to spot blocked or long-running queries during an export.",
103            "inputSchema": { "type": "object", "properties": {}, "required": [] }
104        },
105        {
106            "name": "pg_checkpoint_pressure",
107            "description": "Show pg_stat_bgwriter counters: checkpoints_timed, checkpoints_req (write-pressure indicator), write/sync times, and buffer stats. Rivet adaptive mode reacts to checkpoints_req delta.",
108            "inputSchema": { "type": "object", "properties": {}, "required": [] }
109        },
110        {
111            "name": "pg_table_stats",
112            "description": "Top 20 Postgres tables by live row count: n_live_tup, n_dead_tup, dead ratio, last vacuum/analyze timestamps.",
113            "inputSchema": {
114                "type": "object",
115                "properties": {
116                    "schema": {
117                        "type": "string",
118                        "description": "Restrict to a specific schema (default: all user schemas)"
119                    }
120                },
121                "required": []
122            }
123        },
124        {
125            "name": "pg_locks",
126            "description": "Show relation-level Postgres locks: pid, relation, mode, granted. Useful to diagnose lock contention during an export.",
127            "inputSchema": { "type": "object", "properties": {}, "required": [] }
128        },
129        {
130            "name": "pg_top_queries_by_io",
131            "description": "Top 10 queries by total I/O wait time from pg_stat_statements. Requires the pg_stat_statements extension; returns a clear error if unavailable.",
132            "inputSchema": { "type": "object", "properties": {}, "required": [] }
133        },
134        {
135            "name": "mysql_processlist",
136            "description": "Show MySQL SHOW PROCESSLIST: id, user, db, command, time, state, query snippet.",
137            "inputSchema": { "type": "object", "properties": {}, "required": [] }
138        },
139        {
140            "name": "mysql_key_metrics",
141            "description": "Key MySQL global status counters: Innodb_log_waits, Threads_running, Queries, Slow_queries, Innodb_row_lock_waits, Connections.",
142            "inputSchema": { "type": "object", "properties": {}, "required": [] }
143        },
144        {
145            "name": "mysql_table_stats",
146            "description": "Top 20 MySQL InnoDB tables by row count from information_schema.TABLES.",
147            "inputSchema": {
148                "type": "object",
149                "properties": {
150                    "schema": {
151                        "type": "string",
152                        "description": "Restrict to a specific schema/database (default: all non-system schemas)"
153                    }
154                },
155                "required": []
156            }
157        },
158        {
159            "name": "pgbouncer_pools",
160            "description": "Show pgBouncer pool stats (SHOW POOLS) via the pgBouncer admin interface. Requires PGBOUNCER_ADMIN_URL env var (e.g. postgresql://pgbouncer@127.0.0.1:6432/pgbouncer).",
161            "inputSchema": { "type": "object", "properties": {}, "required": [] }
162        },
163        {
164            "name": "pgbouncer_stats",
165            "description": "Show pgBouncer per-database stats (SHOW STATS). Requires PGBOUNCER_ADMIN_URL env var.",
166            "inputSchema": { "type": "object", "properties": {}, "required": [] }
167        }
168    ])
169}
170
171// ─── Tool dispatch ──────────────────────────────────────────────────────────
172
173fn call_tool(
174    name: &str,
175    args: &Value,
176    pg_url: Option<&str>,
177    mysql_url: Option<&str>,
178) -> anyhow::Result<Value> {
179    match name {
180        "pg_active_sessions" => text(pg_active_sessions(require_pg(pg_url)?)),
181        "pg_checkpoint_pressure" => text(pg_checkpoint_pressure(require_pg(pg_url)?)),
182        "pg_table_stats" => {
183            let schema = args.get("schema").and_then(|v| v.as_str());
184            text(pg_table_stats(require_pg(pg_url)?, schema))
185        }
186        "pg_locks" => text(pg_locks(require_pg(pg_url)?)),
187        "pg_top_queries_by_io" => text(pg_top_queries_by_io(require_pg(pg_url)?)),
188        "mysql_processlist" => text(mysql_processlist(require_mysql(mysql_url)?)),
189        "mysql_key_metrics" => text(mysql_key_metrics(require_mysql(mysql_url)?)),
190        "mysql_table_stats" => {
191            let schema = args.get("schema").and_then(|v| v.as_str());
192            text(mysql_table_stats(require_mysql(mysql_url)?, schema))
193        }
194        "pgbouncer_pools" => text(pgbouncer_query("SHOW POOLS")),
195        "pgbouncer_stats" => text(pgbouncer_query("SHOW STATS")),
196        other => Err(anyhow::anyhow!("unknown tool: {other}")),
197    }
198}
199
200fn require_pg(url: Option<&str>) -> anyhow::Result<&str> {
201    url.ok_or_else(|| {
202        anyhow::anyhow!("no Postgres URL configured — pass --pg-url or set DATABASE_URL")
203    })
204}
205
206fn require_mysql(url: Option<&str>) -> anyhow::Result<&str> {
207    url.ok_or_else(|| {
208        anyhow::anyhow!("no MySQL URL configured — pass --mysql-url or set DATABASE_URL")
209    })
210}
211
212fn text(result: anyhow::Result<String>) -> anyhow::Result<Value> {
213    let body = result.unwrap_or_else(|e| format!("error: {e}"));
214    Ok(json!({ "content": [{ "type": "text", "text": body }] }))
215}
216
217// ─── Postgres tools ────────────────────────────────────────────────────────
218
219fn pg_connect(url: &str) -> anyhow::Result<postgres::Client> {
220    use postgres::NoTls;
221    Ok(postgres::Client::connect(url, NoTls)?)
222}
223
224/// Convert a Postgres row cell to a displayable string.
225fn pg_val(row: &postgres::Row, idx: usize) -> String {
226    // Try common types in priority order; most pg_stat* columns are int8/float8/text.
227    if let Ok(v) = row.try_get::<_, Option<String>>(idx) {
228        return v.unwrap_or_else(|| "NULL".into());
229    }
230    if let Ok(v) = row.try_get::<_, Option<i64>>(idx) {
231        return v.map(|n| n.to_string()).unwrap_or_else(|| "NULL".into());
232    }
233    if let Ok(v) = row.try_get::<_, Option<i32>>(idx) {
234        return v.map(|n| n.to_string()).unwrap_or_else(|| "NULL".into());
235    }
236    if let Ok(v) = row.try_get::<_, Option<f64>>(idx) {
237        return v
238            .map(|n| format!("{n:.2}"))
239            .unwrap_or_else(|| "NULL".into());
240    }
241    if let Ok(v) = row.try_get::<_, Option<bool>>(idx) {
242        return v.map(|b| b.to_string()).unwrap_or_else(|| "NULL".into());
243    }
244    if let Ok(v) = row.try_get::<_, Option<chrono::DateTime<chrono::Utc>>>(idx) {
245        return v
246            .map(|t| t.format("%Y-%m-%d %H:%M:%S").to_string())
247            .unwrap_or_else(|| "NULL".into());
248    }
249    if let Ok(v) = row.try_get::<_, Option<chrono::NaiveDateTime>>(idx) {
250        return v
251            .map(|t| t.format("%Y-%m-%d %H:%M:%S").to_string())
252            .unwrap_or_else(|| "NULL".into());
253    }
254    "?".into()
255}
256
257fn pg_rows_to_table(rows: &[postgres::Row]) -> String {
258    if rows.is_empty() {
259        return "(no rows)".into();
260    }
261    let headers: Vec<String> = rows[0]
262        .columns()
263        .iter()
264        .map(|c| c.name().to_string())
265        .collect();
266    let data: Vec<Vec<String>> = rows
267        .iter()
268        .map(|row| (0..headers.len()).map(|i| pg_val(row, i)).collect())
269        .collect();
270    ascii_table(&headers, &data)
271}
272
273fn pg_active_sessions(url: &str) -> anyhow::Result<String> {
274    let mut client = pg_connect(url)?;
275    let rows = client.query(
276        "SELECT pid::text, state, COALESCE(wait_event_type,'') AS wait_type,
277                COALESCE(wait_event,'') AS wait_event,
278                LEFT(COALESCE(query,''),100) AS query_snippet,
279                usename, application_name
280         FROM pg_stat_activity
281         WHERE state IS DISTINCT FROM 'idle'
282         ORDER BY state, pid",
283        &[],
284    )?;
285    Ok(format!(
286        "Active sessions ({})\n\n{}",
287        rows.len(),
288        pg_rows_to_table(&rows)
289    ))
290}
291
292fn pg_checkpoint_pressure(url: &str) -> anyhow::Result<String> {
293    let mut client = pg_connect(url)?;
294    let rows = client.query(
295        "SELECT checkpoints_timed, checkpoints_req,
296                ROUND(checkpoint_write_time) AS write_ms,
297                ROUND(checkpoint_sync_time) AS sync_ms,
298                buffers_checkpoint, buffers_clean, buffers_backend,
299                maxwritten_clean
300         FROM pg_stat_bgwriter",
301        &[],
302    )?;
303    Ok(format!("pg_stat_bgwriter\n\n{}", pg_rows_to_table(&rows)))
304}
305
306fn pg_table_stats(url: &str, schema: Option<&str>) -> anyhow::Result<String> {
307    let mut client = pg_connect(url)?;
308    let schema_filter = match schema {
309        Some(s) => format!("AND schemaname = '{s}'"),
310        None => "AND schemaname NOT IN ('pg_catalog','information_schema','pg_toast')".into(),
311    };
312    let sql = format!(
313        "SELECT schemaname, relname AS tablename, n_live_tup, n_dead_tup,
314                (n_dead_tup * 100 / NULLIF(n_live_tup + n_dead_tup, 0)) AS dead_pct,
315                COALESCE(to_char(last_vacuum, 'YYYY-MM-DD HH24:MI'), '-') AS last_vacuum,
316                COALESCE(to_char(last_analyze, 'YYYY-MM-DD HH24:MI'), '-') AS last_analyze
317         FROM pg_stat_user_tables
318         WHERE TRUE {schema_filter}
319         ORDER BY n_live_tup DESC
320         LIMIT 20"
321    );
322    let rows = client.query(&sql, &[])?;
323    Ok(format!(
324        "Table stats (top 20)\n\n{}",
325        pg_rows_to_table(&rows)
326    ))
327}
328
329fn pg_locks(url: &str) -> anyhow::Result<String> {
330    let mut client = pg_connect(url)?;
331    let rows = client.query(
332        "SELECT l.pid::text, c.relname AS relation, l.mode, l.granted::text
333         FROM pg_locks l
334         LEFT JOIN pg_class c ON c.oid = l.relation
335         WHERE l.relation IS NOT NULL
336         ORDER BY l.granted, l.pid",
337        &[],
338    )?;
339    if rows.is_empty() {
340        return Ok("No relation-level locks held.".into());
341    }
342    Ok(format!(
343        "Relation locks ({})\n\n{}",
344        rows.len(),
345        pg_rows_to_table(&rows)
346    ))
347}
348
349fn pg_top_queries_by_io(url: &str) -> anyhow::Result<String> {
350    let mut client = pg_connect(url)?;
351    // Check that pg_stat_statements is available.
352    let available: bool = client
353        .query_one(
354            "SELECT COUNT(*) > 0 FROM pg_extension WHERE extname = 'pg_stat_statements'",
355            &[],
356        )
357        .ok()
358        .and_then(|r| r.try_get::<_, bool>(0).ok())
359        .unwrap_or(false);
360    if !available {
361        return Ok("pg_stat_statements extension is not installed. \
362             Run: CREATE EXTENSION IF NOT EXISTS pg_stat_statements;"
363            .into());
364    }
365    let rows = client.query(
366        "SELECT LEFT(query, 80) AS query, calls,
367                ROUND(blk_read_time + blk_write_time) AS io_ms,
368                ROUND(total_exec_time) AS total_exec_ms
369         FROM pg_stat_statements
370         ORDER BY blk_read_time + blk_write_time DESC
371         LIMIT 10",
372        &[],
373    )?;
374    Ok(format!(
375        "Top 10 queries by I/O time\n\n{}",
376        pg_rows_to_table(&rows)
377    ))
378}
379
380// ─── MySQL tools ───────────────────────────────────────────────────────────
381
382fn mysql_pool(url: &str) -> anyhow::Result<mysql::Pool> {
383    use mysql::{Opts, OptsBuilder, PoolConstraints, PoolOpts};
384    let opts = Opts::from(
385        OptsBuilder::from_opts(Opts::from_url(url)?).pool_opts(
386            PoolOpts::default()
387                .with_constraints(PoolConstraints::new(1, 1).expect("valid pool constraints")),
388        ),
389    );
390    Ok(mysql::Pool::new(opts)?)
391}
392
393fn mysql_rows_to_table(rows: &[Vec<String>], headers: &[String]) -> String {
394    if rows.is_empty() {
395        return "(no rows)".into();
396    }
397    ascii_table(headers, rows)
398}
399
400fn mysql_processlist(url: &str) -> anyhow::Result<String> {
401    use mysql::prelude::*;
402    let pool = mysql_pool(url)?;
403    let mut conn = pool.get_conn()?;
404    let mut result = conn.exec_iter("SHOW PROCESSLIST", ())?;
405    let cols: Vec<String> = result
406        .columns()
407        .as_ref()
408        .iter()
409        .map(|c| c.name_str().to_string())
410        .collect();
411    let row_set = result
412        .iter()
413        .ok_or_else(|| anyhow::anyhow!("no result set"))?;
414    let rows: Vec<Vec<String>> = row_set
415        .filter_map(|r| r.ok())
416        .map(|row| {
417            (0..cols.len())
418                .map(|i| match row.as_ref(i) {
419                    Some(mysql::Value::Bytes(b)) => String::from_utf8_lossy(b).into_owned(),
420                    Some(mysql::Value::Int(n)) => n.to_string(),
421                    Some(mysql::Value::UInt(n)) => n.to_string(),
422                    Some(mysql::Value::NULL) | None => "NULL".into(),
423                    _ => "?".into(),
424                })
425                .collect()
426        })
427        .collect();
428    Ok(format!(
429        "SHOW PROCESSLIST ({})\n\n{}",
430        rows.len(),
431        mysql_rows_to_table(&rows, &cols)
432    ))
433}
434
435fn mysql_key_metrics(url: &str) -> anyhow::Result<String> {
436    use mysql::prelude::*;
437    let pool = mysql_pool(url)?;
438    let mut conn = pool.get_conn()?;
439    let metrics = [
440        "Innodb_log_waits",
441        "Innodb_row_lock_waits",
442        "Innodb_row_lock_time_avg",
443        "Threads_running",
444        "Threads_connected",
445        "Queries",
446        "Slow_queries",
447        "Connections",
448        "Aborted_connects",
449    ];
450    let in_clause = metrics
451        .iter()
452        .map(|m| format!("'{m}'"))
453        .collect::<Vec<_>>()
454        .join(",");
455    let sql = format!(
456        "SELECT variable_name, variable_value \
457         FROM information_schema.global_status \
458         WHERE variable_name IN ({in_clause})"
459    );
460    let rows: Vec<(String, String)> = conn.query(sql)?;
461    if rows.is_empty() {
462        return Ok("(no metrics returned)".into());
463    }
464    let headers = vec!["metric".to_string(), "value".to_string()];
465    let data: Vec<Vec<String>> = rows.into_iter().map(|(k, v)| vec![k, v]).collect();
466    Ok(format!(
467        "MySQL key metrics\n\n{}",
468        ascii_table(&headers, &data)
469    ))
470}
471
472fn mysql_table_stats(url: &str, schema: Option<&str>) -> anyhow::Result<String> {
473    use mysql::prelude::*;
474    let pool = mysql_pool(url)?;
475    let mut conn = pool.get_conn()?;
476    let schema_filter = match schema {
477        Some(s) => format!("AND table_schema = '{s}'"),
478        None => "AND table_schema NOT IN ('information_schema','performance_schema','mysql','sys')"
479            .into(),
480    };
481    let sql = format!(
482        "SELECT table_schema, table_name, table_rows, \
483                data_length, index_length, engine \
484         FROM information_schema.TABLES \
485         WHERE table_type = 'BASE TABLE' {schema_filter} \
486         ORDER BY table_rows DESC \
487         LIMIT 20"
488    );
489    let mut result = conn.exec_iter(&sql, ())?;
490    let cols: Vec<String> = result
491        .columns()
492        .as_ref()
493        .iter()
494        .map(|c| c.name_str().to_string())
495        .collect();
496    let row_set = result
497        .iter()
498        .ok_or_else(|| anyhow::anyhow!("no result set"))?;
499    let rows: Vec<Vec<String>> = row_set
500        .filter_map(|r| r.ok())
501        .map(|row| {
502            (0..cols.len())
503                .map(|i| match row.as_ref(i) {
504                    Some(mysql::Value::Bytes(b)) => String::from_utf8_lossy(b).into_owned(),
505                    Some(mysql::Value::Int(n)) => n.to_string(),
506                    Some(mysql::Value::UInt(n)) => n.to_string(),
507                    Some(mysql::Value::NULL) | None => "NULL".into(),
508                    _ => "?".into(),
509                })
510                .collect()
511        })
512        .collect();
513    Ok(format!(
514        "Table stats (top 20)\n\n{}",
515        mysql_rows_to_table(&rows, &cols)
516    ))
517}
518
519// ─── pgBouncer tools ───────────────────────────────────────────────────────
520
521fn pgbouncer_query(sql: &str) -> anyhow::Result<String> {
522    let admin_url = std::env::var("PGBOUNCER_ADMIN_URL").map_err(|_| {
523        anyhow::anyhow!(
524            "PGBOUNCER_ADMIN_URL not set. \
525             Example: postgresql://pgbouncer@127.0.0.1:6432/pgbouncer"
526        )
527    })?;
528    use postgres::NoTls;
529    let mut client = postgres::Client::connect(&admin_url, NoTls)?;
530    let rows = client.query(sql, &[])?;
531    Ok(pg_rows_to_table(&rows))
532}
533
534// ─── ASCII table formatter ─────────────────────────────────────────────────
535
/// Render `headers` and `rows` as a plain-text table.
///
/// Layout: the header line, a dashed separator, then one line per row, joined
/// by newlines with no trailing newline. Cells are left-aligned and separated
/// by ` | `; separator segments are joined by `-+-`.
///
/// Column widths are the largest `str::len()` (byte length) among the header
/// and the cells of that column. Cells beyond the header count are printed
/// unpadded and do not affect any width.
fn ascii_table(headers: &[impl AsRef<str>], rows: &[Vec<String>]) -> String {
    /// Left-align each cell to its column width and join the cells with ` | `.
    fn render_line<S: AsRef<str>>(cells: &[S], widths: &[usize]) -> String {
        let mut line = String::new();
        for (col, cell) in cells.iter().enumerate() {
            if col > 0 {
                line.push_str(" | ");
            }
            // Columns past the header get width 0, i.e. no padding.
            let width = widths.get(col).copied().unwrap_or(0);
            line.push_str(&format!("{:<w$}", cell.as_ref(), w = width));
        }
        line
    }

    // Start from the header widths, then widen using the row cells; `zip`
    // stops at the header count, so surplus cells are ignored here.
    let mut widths: Vec<usize> = headers.iter().map(|h| h.as_ref().len()).collect();
    for row in rows {
        for (width, cell) in widths.iter_mut().zip(row) {
            *width = (*width).max(cell.len());
        }
    }

    let mut separator = String::new();
    for (col, width) in widths.iter().enumerate() {
        if col > 0 {
            separator.push_str("-+-");
        }
        separator.push_str(&"-".repeat(*width));
    }

    let body = rows
        .iter()
        .map(|row| render_line(row, &widths))
        .collect::<Vec<_>>()
        .join("\n");

    let mut table = render_line(headers, &widths);
    table.push('\n');
    table.push_str(&separator);
    // An empty body (no rows, or a single row that renders empty) adds no line.
    if !body.is_empty() {
        table.push('\n');
        table.push_str(&body);
    }
    table
}
574
#[cfg(test)]
mod tests {
    use super::*;

    /// Each column is as wide as its longest cell (or header, if longer).
    #[test]
    fn ascii_table_widens_columns_to_longest_cell() {
        let rows: Vec<Vec<String>> = vec![
            vec!["1".to_owned(), "active".to_owned()],
            vec!["10000".to_owned(), "idle".to_owned()],
        ];
        let rendered = ascii_table(&["pid", "state"], &rows);
        let got: Vec<&str> = rendered.lines().collect();
        assert_eq!(got.len(), 4, "header + separator + 2 rows");

        // "pid" widens to 5 (len of "10000"); "state" widens to 6 (len of "active").
        // The "-+-" joiner contributes one extra dash on each side of '+'.
        let want = [
            "pid   | state ",
            "------+-------",
            "1     | active",
            "10000 | idle  ",
        ];
        for (idx, expected) in want.iter().enumerate() {
            assert_eq!(got[idx], *expected);
        }
    }

    /// With no rows the output is just the header and the separator.
    #[test]
    fn ascii_table_renders_header_only_when_no_rows() {
        let no_rows: Vec<Vec<String>> = Vec::new();
        let rendered = ascii_table(&["col_a", "col_b"], &no_rows);
        assert_eq!(rendered, "col_a | col_b\n------+------");
    }

    /// Pins the current byte-based width so a width-aware change is deliberate.
    #[test]
    fn ascii_table_handles_unicode_byte_width() {
        // `String::len()` counts bytes, not graphemes; Cyrillic "ы" is 2 bytes.
        let rows: Vec<Vec<String>> = vec![vec!["ы".to_owned()]];
        let rendered = ascii_table(&["x"], &rows);
        // Column width = max(1, 2) = 2, so the header gets one padding space.
        assert!(rendered.contains("x "), "header padded to byte-width 2");
    }
}
615}