Skip to main content

mcp_sql/db/
dialect.rs

1use serde_json::Value;
2use sqlx::{AnyPool, Row};
3
4use crate::db::convert::row_to_json;
5use crate::db::DbBackend;
6use crate::error::McpSqlError;
7
8/// List tables with approximate row counts.
9pub async fn list_tables(pool: &AnyPool, backend: DbBackend) -> Result<Vec<Value>, McpSqlError> {
10    let sql = match backend {
11        DbBackend::Postgres => {
12            "SELECT schemaname || '.' || tablename AS table_name, \
13                    COALESCE(n_live_tup, 0) AS row_count \
14             FROM pg_tables \
15             LEFT JOIN pg_stat_user_tables ON tablename = relname AND schemaname = pg_stat_user_tables.schemaname \
16             WHERE pg_tables.schemaname NOT IN ('pg_catalog', 'information_schema') \
17             ORDER BY table_name"
18        }
19        DbBackend::Sqlite => {
20            "SELECT name AS table_name, 0 AS row_count \
21             FROM sqlite_master \
22             WHERE type = 'table' AND name NOT LIKE 'sqlite_%' \
23             ORDER BY name"
24        }
25        DbBackend::Mysql => {
26            "SELECT table_name, table_rows AS row_count \
27             FROM information_schema.tables \
28             WHERE table_schema = DATABASE() \
29             ORDER BY table_name"
30        }
31    };
32
33    let rows = sqlx::query(sql).fetch_all(pool).await?;
34    Ok(rows.iter().map(row_to_json).collect())
35}
36
37/// Describe a table's columns.
38pub async fn describe_table(
39    pool: &AnyPool,
40    backend: DbBackend,
41    table: &str,
42) -> Result<Vec<Value>, McpSqlError> {
43    match backend {
44        DbBackend::Postgres => describe_table_postgres(pool, table).await,
45        DbBackend::Sqlite => describe_table_sqlite(pool, table).await,
46        DbBackend::Mysql => describe_table_mysql(pool, table).await,
47    }
48}
49
50async fn describe_table_postgres(pool: &AnyPool, table: &str) -> Result<Vec<Value>, McpSqlError> {
51    // Handle schema.table format
52    let (schema, tbl) = if let Some((s, t)) = table.split_once('.') {
53        (s, t)
54    } else {
55        ("public", table)
56    };
57
58    let sql = "SELECT c.column_name AS name, c.data_type AS type, \
59               c.is_nullable AS nullable, c.column_default AS default_value, \
60               CASE WHEN tc.constraint_type = 'PRIMARY KEY' THEN 'YES' ELSE 'NO' END AS primary_key \
61               FROM information_schema.columns c \
62               LEFT JOIN information_schema.key_column_usage kcu \
63                 ON c.table_schema = kcu.table_schema \
64                 AND c.table_name = kcu.table_name \
65                 AND c.column_name = kcu.column_name \
66               LEFT JOIN information_schema.table_constraints tc \
67                 ON kcu.constraint_name = tc.constraint_name \
68                 AND kcu.table_schema = tc.table_schema \
69                 AND tc.constraint_type = 'PRIMARY KEY' \
70               WHERE c.table_schema = $1 AND c.table_name = $2 \
71               ORDER BY c.ordinal_position";
72
73    let rows = sqlx::query(sql)
74        .bind(schema)
75        .bind(tbl)
76        .fetch_all(pool)
77        .await?;
78
79    if rows.is_empty() {
80        return Err(McpSqlError::Other(format!("Table '{table}' not found")));
81    }
82
83    // Fetch FK info
84    let fk_sql = "SELECT kcu.column_name, ccu.table_schema || '.' || ccu.table_name || '.' || ccu.column_name AS references_col \
85                   FROM information_schema.key_column_usage kcu \
86                   JOIN information_schema.referential_constraints rc \
87                     ON kcu.constraint_name = rc.constraint_name AND kcu.constraint_schema = rc.constraint_schema \
88                   JOIN information_schema.constraint_column_usage ccu \
89                     ON rc.unique_constraint_name = ccu.constraint_name AND rc.unique_constraint_schema = ccu.constraint_schema \
90                   WHERE kcu.table_schema = $1 AND kcu.table_name = $2";
91
92    let fk_rows = sqlx::query(fk_sql)
93        .bind(schema)
94        .bind(tbl)
95        .fetch_all(pool)
96        .await
97        .unwrap_or_default();
98
99    let fk_map: std::collections::HashMap<String, String> = fk_rows
100        .iter()
101        .filter_map(|r| {
102            let col: String = r.try_get("column_name").ok()?;
103            let refs: String = r.try_get("references_col").ok()?;
104            Some((col, refs))
105        })
106        .collect();
107
108    let mut result: Vec<Value> = rows.iter().map(row_to_json).collect();
109    for col in &mut result {
110        if let Value::Object(map) = col {
111            let col_name = map.get("name").and_then(|v| v.as_str()).unwrap_or("");
112            let fk = fk_map.get(col_name).map(|s| Value::String(s.clone())).unwrap_or(Value::Null);
113            map.insert("foreign_key".to_string(), fk);
114        }
115    }
116
117    Ok(result)
118}
119
120async fn describe_table_sqlite(pool: &AnyPool, table: &str) -> Result<Vec<Value>, McpSqlError> {
121    // SQLite PRAGMA doesn't support parameterized queries, so we validate the table name
122    let safe_table = sanitize_identifier(table)?;
123    let sql = format!("PRAGMA table_info(\"{}\")", safe_table);
124    let rows = sqlx::query(&sql).fetch_all(pool).await?;
125
126    if rows.is_empty() {
127        return Err(McpSqlError::Other(format!("Table '{table}' not found")));
128    }
129
130    // Fetch FK info via PRAGMA foreign_key_list
131    let fk_sql = format!("PRAGMA foreign_key_list(\"{}\")", safe_table);
132    let fk_rows = sqlx::query(&fk_sql).fetch_all(pool).await.unwrap_or_default();
133
134    let fk_map: std::collections::HashMap<String, String> = fk_rows
135        .iter()
136        .filter_map(|r| {
137            let from: String = r.try_get("from").ok()?;
138            let ref_table: String = r.try_get("table").ok()?;
139            let ref_col: String = r.try_get("to").ok()?;
140            Some((from, format!("{ref_table}.{ref_col}")))
141        })
142        .collect();
143
144    let mut result = Vec::new();
145    for row in &rows {
146        let name: String = row.try_get("name").unwrap_or_default();
147        let col_type: String = row.try_get("type").unwrap_or_default();
148        let notnull: i32 = row.try_get("notnull").unwrap_or(0);
149        let dflt_value: Option<String> = row.try_get("dflt_value").ok();
150        let pk: i32 = row.try_get("pk").unwrap_or(0);
151        let fk = fk_map.get(&name).cloned();
152
153        result.push(serde_json::json!({
154            "name": name,
155            "type": col_type,
156            "nullable": if notnull == 0 { "YES" } else { "NO" },
157            "default_value": dflt_value,
158            "primary_key": if pk > 0 { "YES" } else { "NO" },
159            "foreign_key": fk,
160        }));
161    }
162
163    Ok(result)
164}
165
166async fn describe_table_mysql(pool: &AnyPool, table: &str) -> Result<Vec<Value>, McpSqlError> {
167    let sql = "SELECT column_name AS name, column_type AS type, \
168               is_nullable AS nullable, column_default AS default_value, \
169               CASE WHEN column_key = 'PRI' THEN 'YES' ELSE 'NO' END AS primary_key \
170               FROM information_schema.columns \
171               WHERE table_schema = DATABASE() AND table_name = ? \
172               ORDER BY ordinal_position";
173
174    let rows = sqlx::query(sql).bind(table).fetch_all(pool).await?;
175
176    if rows.is_empty() {
177        return Err(McpSqlError::Other(format!("Table '{table}' not found")));
178    }
179
180    // Fetch FK info
181    let fk_sql = "SELECT column_name, CONCAT(referenced_table_name, '.', referenced_column_name) AS references_col \
182                   FROM information_schema.key_column_usage \
183                   WHERE table_schema = DATABASE() AND table_name = ? AND referenced_table_name IS NOT NULL";
184
185    let fk_rows = sqlx::query(fk_sql).bind(table).fetch_all(pool).await.unwrap_or_default();
186
187    let fk_map: std::collections::HashMap<String, String> = fk_rows
188        .iter()
189        .filter_map(|r| {
190            let col: String = r.try_get("column_name").ok()?;
191            let refs: String = r.try_get("references_col").ok()?;
192            Some((col, refs))
193        })
194        .collect();
195
196    let mut result: Vec<Value> = rows.iter().map(row_to_json).collect();
197    for col in &mut result {
198        if let Value::Object(map) = col {
199            let col_name = map.get("name").and_then(|v| v.as_str()).unwrap_or("");
200            let fk = fk_map.get(col_name).map(|s| Value::String(s.clone())).unwrap_or(Value::Null);
201            map.insert("foreign_key".to_string(), fk);
202        }
203    }
204
205    Ok(result)
206}
207
208/// Sample N rows from a table.
209pub async fn sample_data(
210    pool: &AnyPool,
211    backend: DbBackend,
212    table: &str,
213    limit: u32,
214) -> Result<Vec<Value>, McpSqlError> {
215    let safe_table = sanitize_identifier(table)?;
216    let sql = match backend {
217        DbBackend::Postgres => format!(
218            "SELECT * FROM \"{}\" TABLESAMPLE BERNOULLI (100) LIMIT {}",
219            safe_table, limit
220        ),
221        DbBackend::Sqlite => format!("SELECT * FROM \"{}\" LIMIT {}", safe_table, limit),
222        DbBackend::Mysql => format!(
223            "SELECT * FROM `{}` ORDER BY RAND() LIMIT {}",
224            safe_table, limit
225        ),
226    };
227
228    let rows = sqlx::query(&sql).fetch_all(pool).await?;
229    Ok(rows.iter().map(row_to_json).collect())
230}
231
232/// Get the correct EXPLAIN prefix for each backend.
233pub fn explain_prefix(backend: DbBackend) -> &'static str {
234    match backend {
235        DbBackend::Postgres => "EXPLAIN (FORMAT TEXT) ",
236        DbBackend::Sqlite => "EXPLAIN QUERY PLAN ",
237        DbBackend::Mysql => "EXPLAIN ",
238    }
239}
240
241/// Validate and sanitize a SQL identifier to prevent injection.
242fn sanitize_identifier(name: &str) -> Result<String, McpSqlError> {
243    // Allow alphanumeric, underscore, dot (for schema.table), and hyphen
244    if name
245        .chars()
246        .all(|c| c.is_alphanumeric() || c == '_' || c == '.' || c == '-')
247        && !name.is_empty()
248    {
249        Ok(name.to_string())
250    } else {
251        Err(McpSqlError::InvalidSql(format!(
252            "Invalid identifier: '{name}'"
253        )))
254    }
255}
256
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sanitize_identifier() {
        assert!(sanitize_identifier("users").is_ok());
        assert!(sanitize_identifier("public.users").is_ok());
        assert!(sanitize_identifier("my_table").is_ok());
        assert!(sanitize_identifier("my-table").is_ok());
        assert!(sanitize_identifier("").is_err());
        assert!(sanitize_identifier("users; DROP TABLE users").is_err());
        assert!(sanitize_identifier("users\"").is_err());
    }

    /// Validated names are interpolated inside backticks (MySQL) or double
    /// quotes, so neither quote character, nor single quotes or whitespace,
    /// may ever pass validation.
    #[test]
    fn test_sanitize_identifier_rejects_quote_breakouts() {
        assert!(sanitize_identifier("users`").is_err());
        assert!(sanitize_identifier("`users` WHERE 1=1").is_err());
        assert!(sanitize_identifier("users'").is_err());
        assert!(sanitize_identifier("my table").is_err());
        assert!(sanitize_identifier("users\n").is_err());
    }

    #[test]
    fn test_sanitize_identifier_returns_input_unchanged() {
        assert_eq!(
            sanitize_identifier("public.users").ok().as_deref(),
            Some("public.users")
        );
    }

    #[test]
    fn test_explain_prefix() {
        assert_eq!(explain_prefix(DbBackend::Postgres), "EXPLAIN (FORMAT TEXT) ");
        assert_eq!(explain_prefix(DbBackend::Sqlite), "EXPLAIN QUERY PLAN ");
        assert_eq!(explain_prefix(DbBackend::Mysql), "EXPLAIN ");
    }
}