1use serde_json::Value;
2use sqlx::{AnyPool, Row};
3
4use crate::db::convert::row_to_json;
5use crate::db::DbBackend;
6use crate::error::McpSqlError;
7
8pub async fn list_tables(pool: &AnyPool, backend: DbBackend) -> Result<Vec<Value>, McpSqlError> {
10 let sql = match backend {
11 DbBackend::Postgres => {
12 "SELECT schemaname || '.' || tablename AS table_name, \
13 COALESCE(n_live_tup, 0) AS row_count \
14 FROM pg_tables \
15 LEFT JOIN pg_stat_user_tables ON tablename = relname AND schemaname = pg_stat_user_tables.schemaname \
16 WHERE pg_tables.schemaname NOT IN ('pg_catalog', 'information_schema') \
17 ORDER BY table_name"
18 }
19 DbBackend::Sqlite => {
20 "SELECT name AS table_name, 0 AS row_count \
21 FROM sqlite_master \
22 WHERE type = 'table' AND name NOT LIKE 'sqlite_%' \
23 ORDER BY name"
24 }
25 DbBackend::Mysql => {
26 "SELECT table_name, table_rows AS row_count \
27 FROM information_schema.tables \
28 WHERE table_schema = DATABASE() \
29 ORDER BY table_name"
30 }
31 };
32
33 let rows = sqlx::query(sql).fetch_all(pool).await?;
34 Ok(rows.iter().map(row_to_json).collect())
35}
36
37pub async fn describe_table(
39 pool: &AnyPool,
40 backend: DbBackend,
41 table: &str,
42) -> Result<Vec<Value>, McpSqlError> {
43 match backend {
44 DbBackend::Postgres => describe_table_postgres(pool, table).await,
45 DbBackend::Sqlite => describe_table_sqlite(pool, table).await,
46 DbBackend::Mysql => describe_table_mysql(pool, table).await,
47 }
48}
49
50async fn describe_table_postgres(pool: &AnyPool, table: &str) -> Result<Vec<Value>, McpSqlError> {
51 let (schema, tbl) = if let Some((s, t)) = table.split_once('.') {
53 (s, t)
54 } else {
55 ("public", table)
56 };
57
58 let sql = "SELECT c.column_name AS name, c.data_type AS type, \
59 c.is_nullable AS nullable, c.column_default AS default_value, \
60 CASE WHEN tc.constraint_type = 'PRIMARY KEY' THEN 'YES' ELSE 'NO' END AS primary_key \
61 FROM information_schema.columns c \
62 LEFT JOIN information_schema.key_column_usage kcu \
63 ON c.table_schema = kcu.table_schema \
64 AND c.table_name = kcu.table_name \
65 AND c.column_name = kcu.column_name \
66 LEFT JOIN information_schema.table_constraints tc \
67 ON kcu.constraint_name = tc.constraint_name \
68 AND kcu.table_schema = tc.table_schema \
69 AND tc.constraint_type = 'PRIMARY KEY' \
70 WHERE c.table_schema = $1 AND c.table_name = $2 \
71 ORDER BY c.ordinal_position";
72
73 let rows = sqlx::query(sql)
74 .bind(schema)
75 .bind(tbl)
76 .fetch_all(pool)
77 .await?;
78
79 if rows.is_empty() {
80 return Err(McpSqlError::Other(format!("Table '{table}' not found")));
81 }
82
83 let fk_sql = "SELECT kcu.column_name, ccu.table_schema || '.' || ccu.table_name || '.' || ccu.column_name AS references_col \
85 FROM information_schema.key_column_usage kcu \
86 JOIN information_schema.referential_constraints rc \
87 ON kcu.constraint_name = rc.constraint_name AND kcu.constraint_schema = rc.constraint_schema \
88 JOIN information_schema.constraint_column_usage ccu \
89 ON rc.unique_constraint_name = ccu.constraint_name AND rc.unique_constraint_schema = ccu.constraint_schema \
90 WHERE kcu.table_schema = $1 AND kcu.table_name = $2";
91
92 let fk_rows = sqlx::query(fk_sql)
93 .bind(schema)
94 .bind(tbl)
95 .fetch_all(pool)
96 .await
97 .unwrap_or_default();
98
99 let fk_map: std::collections::HashMap<String, String> = fk_rows
100 .iter()
101 .filter_map(|r| {
102 let col: String = r.try_get("column_name").ok()?;
103 let refs: String = r.try_get("references_col").ok()?;
104 Some((col, refs))
105 })
106 .collect();
107
108 let mut result: Vec<Value> = rows.iter().map(row_to_json).collect();
109 for col in &mut result {
110 if let Value::Object(map) = col {
111 let col_name = map.get("name").and_then(|v| v.as_str()).unwrap_or("");
112 let fk = fk_map.get(col_name).map(|s| Value::String(s.clone())).unwrap_or(Value::Null);
113 map.insert("foreign_key".to_string(), fk);
114 }
115 }
116
117 Ok(result)
118}
119
120async fn describe_table_sqlite(pool: &AnyPool, table: &str) -> Result<Vec<Value>, McpSqlError> {
121 let safe_table = sanitize_identifier(table)?;
123 let sql = format!("PRAGMA table_info(\"{}\")", safe_table);
124 let rows = sqlx::query(&sql).fetch_all(pool).await?;
125
126 if rows.is_empty() {
127 return Err(McpSqlError::Other(format!("Table '{table}' not found")));
128 }
129
130 let fk_sql = format!("PRAGMA foreign_key_list(\"{}\")", safe_table);
132 let fk_rows = sqlx::query(&fk_sql).fetch_all(pool).await.unwrap_or_default();
133
134 let fk_map: std::collections::HashMap<String, String> = fk_rows
135 .iter()
136 .filter_map(|r| {
137 let from: String = r.try_get("from").ok()?;
138 let ref_table: String = r.try_get("table").ok()?;
139 let ref_col: String = r.try_get("to").ok()?;
140 Some((from, format!("{ref_table}.{ref_col}")))
141 })
142 .collect();
143
144 let mut result = Vec::new();
145 for row in &rows {
146 let name: String = row.try_get("name").unwrap_or_default();
147 let col_type: String = row.try_get("type").unwrap_or_default();
148 let notnull: i32 = row.try_get("notnull").unwrap_or(0);
149 let dflt_value: Option<String> = row.try_get("dflt_value").ok();
150 let pk: i32 = row.try_get("pk").unwrap_or(0);
151 let fk = fk_map.get(&name).cloned();
152
153 result.push(serde_json::json!({
154 "name": name,
155 "type": col_type,
156 "nullable": if notnull == 0 { "YES" } else { "NO" },
157 "default_value": dflt_value,
158 "primary_key": if pk > 0 { "YES" } else { "NO" },
159 "foreign_key": fk,
160 }));
161 }
162
163 Ok(result)
164}
165
166async fn describe_table_mysql(pool: &AnyPool, table: &str) -> Result<Vec<Value>, McpSqlError> {
167 let sql = "SELECT column_name AS name, column_type AS type, \
168 is_nullable AS nullable, column_default AS default_value, \
169 CASE WHEN column_key = 'PRI' THEN 'YES' ELSE 'NO' END AS primary_key \
170 FROM information_schema.columns \
171 WHERE table_schema = DATABASE() AND table_name = ? \
172 ORDER BY ordinal_position";
173
174 let rows = sqlx::query(sql).bind(table).fetch_all(pool).await?;
175
176 if rows.is_empty() {
177 return Err(McpSqlError::Other(format!("Table '{table}' not found")));
178 }
179
180 let fk_sql = "SELECT column_name, CONCAT(referenced_table_name, '.', referenced_column_name) AS references_col \
182 FROM information_schema.key_column_usage \
183 WHERE table_schema = DATABASE() AND table_name = ? AND referenced_table_name IS NOT NULL";
184
185 let fk_rows = sqlx::query(fk_sql).bind(table).fetch_all(pool).await.unwrap_or_default();
186
187 let fk_map: std::collections::HashMap<String, String> = fk_rows
188 .iter()
189 .filter_map(|r| {
190 let col: String = r.try_get("column_name").ok()?;
191 let refs: String = r.try_get("references_col").ok()?;
192 Some((col, refs))
193 })
194 .collect();
195
196 let mut result: Vec<Value> = rows.iter().map(row_to_json).collect();
197 for col in &mut result {
198 if let Value::Object(map) = col {
199 let col_name = map.get("name").and_then(|v| v.as_str()).unwrap_or("");
200 let fk = fk_map.get(col_name).map(|s| Value::String(s.clone())).unwrap_or(Value::Null);
201 map.insert("foreign_key".to_string(), fk);
202 }
203 }
204
205 Ok(result)
206}
207
208pub async fn sample_data(
210 pool: &AnyPool,
211 backend: DbBackend,
212 table: &str,
213 limit: u32,
214) -> Result<Vec<Value>, McpSqlError> {
215 let safe_table = sanitize_identifier(table)?;
216 let sql = match backend {
217 DbBackend::Postgres => format!(
218 "SELECT * FROM \"{}\" TABLESAMPLE BERNOULLI (100) LIMIT {}",
219 safe_table, limit
220 ),
221 DbBackend::Sqlite => format!("SELECT * FROM \"{}\" LIMIT {}", safe_table, limit),
222 DbBackend::Mysql => format!(
223 "SELECT * FROM `{}` ORDER BY RAND() LIMIT {}",
224 safe_table, limit
225 ),
226 };
227
228 let rows = sqlx::query(&sql).fetch_all(pool).await?;
229 Ok(rows.iter().map(row_to_json).collect())
230}
231
232pub fn explain_prefix(backend: DbBackend) -> &'static str {
234 match backend {
235 DbBackend::Postgres => "EXPLAIN (FORMAT TEXT) ",
236 DbBackend::Sqlite => "EXPLAIN QUERY PLAN ",
237 DbBackend::Mysql => "EXPLAIN ",
238 }
239}
240
241fn sanitize_identifier(name: &str) -> Result<String, McpSqlError> {
243 if name
245 .chars()
246 .all(|c| c.is_alphanumeric() || c == '_' || c == '.' || c == '-')
247 && !name.is_empty()
248 {
249 Ok(name.to_string())
250 } else {
251 Err(McpSqlError::InvalidSql(format!(
252 "Invalid identifier: '{name}'"
253 )))
254 }
255}
256
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain, schema-qualified and hyphenated names pass. Empty strings,
    /// statement separators and embedded quotes are rejected.
    #[test]
    fn test_sanitize_identifier() {
        let accepted = ["users", "public.users", "my_table", "my-table"];
        for name in accepted {
            assert!(
                sanitize_identifier(name).is_ok(),
                "expected {name:?} to be accepted"
            );
        }

        let rejected = ["", "users; DROP TABLE users", "users\""];
        for name in rejected {
            assert!(
                sanitize_identifier(name).is_err(),
                "expected {name:?} to be rejected"
            );
        }
    }

    /// Each backend maps to its exact EXPLAIN prefix, including the trailing space.
    #[test]
    fn test_explain_prefix() {
        let cases = [
            (DbBackend::Postgres, "EXPLAIN (FORMAT TEXT) "),
            (DbBackend::Sqlite, "EXPLAIN QUERY PLAN "),
            (DbBackend::Mysql, "EXPLAIN "),
        ];
        for (backend, expected) in cases {
            assert_eq!(explain_prefix(backend), expected);
        }
    }
}