1use async_trait::async_trait;
7use serde_json::{json, Value};
8use sqlx::{Column, Row, TypeInfo};
9
10use crate::protocol::CallToolResult;
11use crate::tool::{Context, Tool};
12
/// One row of `PRAGMA table_info(...)` output, decoded positionally.
///
/// Field order, as destructured by `sqlite_schema`:
/// `(cid, name, type, notnull, default, pk)`.
type SqliteColumnRow = (i64, String, String, i64, Option<String>, i64);
15
/// Tool that dumps the live database schema (tables, columns, types and
/// nullability) for the configured driver.
///
/// The type is a stateless unit struct; the common traits are derived so it
/// can be logged, copied and default-constructed like any other value.
#[derive(Debug, Clone, Copy, Default)]
pub struct DatabaseSchema;
19
20#[async_trait]
21impl Tool for DatabaseSchema {
22 fn name(&self) -> &'static str {
23 "database-schema"
24 }
25 fn description(&self) -> &'static str {
26 "Dump the live database schema: every table and its columns, types, and nullability. Reads from information_schema (Postgres/MySQL) or sqlite_master (SQLite)."
27 }
28
29 async fn call(&self, ctx: &Context, _args: Value) -> CallToolResult {
30 let pool = ctx.container.driver_pool();
31 match pool {
32 cast_core::Pool::Postgres(p) => {
33 let rows: Result<Vec<(String, String, String, String)>, _> = sqlx::query_as(
34 "SELECT table_name::TEXT, column_name::TEXT, data_type::TEXT, is_nullable::TEXT
35 FROM information_schema.columns
36 WHERE table_schema = 'public'
37 ORDER BY table_name, ordinal_position",
38 )
39 .fetch_all(&p)
40 .await;
41 pg_my_to_result(rows, "postgres")
42 }
43 cast_core::Pool::MySql(p) => {
44 let rows: Result<Vec<(String, String, String, String)>, _> = sqlx::query_as(
45 "SELECT table_name, column_name, column_type, is_nullable
46 FROM information_schema.columns
47 WHERE table_schema = DATABASE()
48 ORDER BY table_name, ordinal_position",
49 )
50 .fetch_all(&p)
51 .await;
52 pg_my_to_result(rows, "mysql")
53 }
54 cast_core::Pool::Sqlite(p) => sqlite_schema(&p).await,
55 }
56 }
57}
58
59fn pg_my_to_result(
60 rows: Result<Vec<(String, String, String, String)>, sqlx::Error>,
61 driver: &str,
62) -> CallToolResult {
63 let rows = match rows {
64 Ok(r) => r,
65 Err(e) => return CallToolResult::error(format!("schema query failed: {e}")),
66 };
67 let mut by_table: indexmap::IndexMap<String, Vec<Value>> = indexmap::IndexMap::new();
68 for (table, col, ty, nullable) in rows {
69 by_table.entry(table).or_default().push(json!({
70 "name": col,
71 "type": ty,
72 "nullable": nullable.eq_ignore_ascii_case("yes"),
73 }));
74 }
75 CallToolResult::json(&json!({
76 "driver": driver,
77 "tables": by_table,
78 }))
79}
80
81async fn sqlite_schema(p: &sqlx::SqlitePool) -> CallToolResult {
82 let tables: Result<Vec<(String,)>, _> = sqlx::query_as(
83 "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' ORDER BY name",
84 )
85 .fetch_all(p)
86 .await;
87 let tables = match tables {
88 Ok(t) => t,
89 Err(e) => return CallToolResult::error(format!("schema query failed: {e}")),
90 };
91 let mut by_table = serde_json::Map::new();
92 for (name,) in tables {
93 let cols: Result<Vec<SqliteColumnRow>, _> =
94 sqlx::query_as(&format!("PRAGMA table_info({name})"))
95 .fetch_all(p)
96 .await;
97 if let Ok(rows) = cols {
98 let columns: Vec<Value> = rows
99 .into_iter()
100 .map(|(cid, col, ty, notnull, dflt, pk)| {
101 json!({
102 "cid": cid,
103 "name": col,
104 "type": ty,
105 "nullable": notnull == 0,
106 "default": dflt,
107 "pk": pk != 0,
108 })
109 })
110 .collect();
111 by_table.insert(name, Value::Array(columns));
112 }
113 }
114 CallToolResult::json(&json!({
115 "driver": "sqlite",
116 "tables": by_table,
117 }))
118}
119
/// Tool that runs a single read-only SQL statement and returns the rows as
/// JSON.
///
/// The type is a stateless unit struct; the common traits are derived so it
/// can be logged, copied and default-constructed like any other value.
#[derive(Debug, Clone, Copy, Default)]
pub struct DatabaseQuery;
123
124#[async_trait]
125impl Tool for DatabaseQuery {
126 fn name(&self) -> &'static str {
127 "database-query"
128 }
129 fn description(&self) -> &'static str {
130 "Run a read-only SQL query and return rows as JSON. Only SELECT, WITH, EXPLAIN, SHOW, and PRAGMA are accepted. Optional `limit` caps the result count (default 100, max 1000)."
131 }
132 fn input_schema(&self) -> Value {
133 json!({
134 "type": "object",
135 "required": ["sql"],
136 "properties": {
137 "sql": { "type": "string", "description": "Read-only SQL statement." },
138 "limit": { "type": "integer", "description": "Max rows to return.", "default": 100, "minimum": 1, "maximum": 1000 }
139 }
140 })
141 }
142
143 async fn call(&self, ctx: &Context, args: Value) -> CallToolResult {
144 let sql = match args.get("sql").and_then(|v| v.as_str()) {
145 Some(s) if !s.trim().is_empty() => s.trim().to_string(),
146 _ => return CallToolResult::error("`sql` is required"),
147 };
148 if !is_readonly(&sql) {
149 return CallToolResult::error(
150 "rejected: only SELECT, WITH, EXPLAIN, SHOW, and PRAGMA statements are allowed by database-query",
151 );
152 }
153 let limit = args
154 .get("limit")
155 .and_then(|v| v.as_u64())
156 .unwrap_or(100)
157 .clamp(1, 1000) as usize;
158
159 let pool = ctx.container.driver_pool();
160 let result = match pool {
161 cast_core::Pool::Postgres(p) => run_postgres(&sql, &p, limit).await,
162 cast_core::Pool::MySql(p) => run_mysql(&sql, &p, limit).await,
163 cast_core::Pool::Sqlite(p) => run_sqlite(&sql, &p, limit).await,
164 };
165 match result {
166 Ok((rows, truncated)) => CallToolResult::json(&json!({
167 "rows": rows,
168 "row_count": rows.len(),
169 "truncated": truncated,
170 })),
171 Err(e) => CallToolResult::error(e),
172 }
173 }
174}
175
/// Lexical guard deciding whether `sql` may be run by the read-only query tool.
///
/// Returns `true` only when both hold:
///
/// * the first word of the statement is exactly one of `SELECT`, `WITH`,
///   `EXPLAIN`, `SHOW` or `PRAGMA` (ASCII case-insensitive) - a mere prefix
///   match such as `selectx` or `withdraw` is rejected;
/// * the text holds a single statement: trailing semicolons and whitespace are
///   ignored, but any other `;` is rejected so that a write cannot be stacked
///   behind an allowed statement (`SELECT 1; DROP TABLE t`).
///
/// The semicolon check is deliberately conservative and not quote-aware, so a
/// statement with `;` inside a string literal or followed by a trailing
/// comment is rejected as well.
///
/// NOTE(review): this inspects text only. It cannot see writes embedded in an
/// otherwise allowed statement (for example a data-modifying CTE, `EXPLAIN
/// ANALYZE` of a write, or a PRAGMA assignment); a read-only database
/// role/transaction is still advisable for untrusted input.
fn is_readonly(sql: &str) -> bool {
    const ALLOWED: [&str; 5] = ["select", "with", "explain", "show", "pragma"];

    let stmt = sql
        .trim()
        .trim_end_matches(|c: char| c == ';' || c.is_whitespace());
    if stmt.contains(';') {
        return false;
    }

    // The leading keyword is the maximal run of identifier characters, so
    // `select(1)` and `select*` still start with the word `select`.
    let keyword_len = stmt
        .find(|c: char| !(c.is_ascii_alphanumeric() || c == '_'))
        .unwrap_or(stmt.len());
    let keyword = &stmt[..keyword_len];
    ALLOWED.iter().any(|kw| keyword.eq_ignore_ascii_case(kw))
}
186
187async fn run_postgres(
188 sql: &str,
189 pool: &sqlx::PgPool,
190 limit: usize,
191) -> Result<(Vec<Value>, bool), String> {
192 let rows = sqlx::query(sql)
193 .fetch_all(pool)
194 .await
195 .map_err(|e| format!("query error: {e}"))?;
196 let truncated = rows.len() > limit;
197 let take_n = rows.len().min(limit);
198 let mut out = Vec::with_capacity(take_n);
199 for row in rows.iter().take(take_n) {
200 let mut obj = serde_json::Map::new();
201 for (i, col) in row.columns().iter().enumerate() {
202 let key = col.name().to_string();
203 let value = pg_value(row, i, col.type_info().name());
204 obj.insert(key, value);
205 }
206 out.push(Value::Object(obj));
207 }
208 Ok((out, truncated))
209}
210
211fn pg_value(row: &sqlx::postgres::PgRow, idx: usize, ty: &str) -> Value {
212 if let Ok(Some(v)) = row.try_get::<Option<i64>, _>(idx) {
213 return json!(v);
214 }
215 if let Ok(Some(v)) = row.try_get::<Option<i32>, _>(idx) {
216 return json!(v);
217 }
218 if let Ok(Some(v)) = row.try_get::<Option<f64>, _>(idx) {
219 return json!(v);
220 }
221 if let Ok(Some(v)) = row.try_get::<Option<bool>, _>(idx) {
222 return json!(v);
223 }
224 if let Ok(Some(v)) = row.try_get::<Option<String>, _>(idx) {
225 return json!(v);
226 }
227 if let Ok(Some(v)) = row.try_get::<Option<serde_json::Value>, _>(idx) {
228 return v;
229 }
230 if let Ok(None::<String>) = row.try_get::<Option<String>, _>(idx) {
231 return Value::Null;
232 }
233 json!({ "_unknown_type": ty })
234}
235
236async fn run_mysql(
237 sql: &str,
238 pool: &sqlx::MySqlPool,
239 limit: usize,
240) -> Result<(Vec<Value>, bool), String> {
241 let rows = sqlx::query(sql)
242 .fetch_all(pool)
243 .await
244 .map_err(|e| format!("query error: {e}"))?;
245 let truncated = rows.len() > limit;
246 let take_n = rows.len().min(limit);
247 let mut out = Vec::with_capacity(take_n);
248 for row in rows.iter().take(take_n) {
249 let mut obj = serde_json::Map::new();
250 for (i, col) in row.columns().iter().enumerate() {
251 let key = col.name().to_string();
252 obj.insert(key, mysql_value(row, i, col.type_info().name()));
253 }
254 out.push(Value::Object(obj));
255 }
256 Ok((out, truncated))
257}
258
259fn mysql_value(row: &sqlx::mysql::MySqlRow, idx: usize, ty: &str) -> Value {
260 if let Ok(Some(v)) = row.try_get::<Option<i64>, _>(idx) {
261 return json!(v);
262 }
263 if let Ok(Some(v)) = row.try_get::<Option<f64>, _>(idx) {
264 return json!(v);
265 }
266 if let Ok(Some(v)) = row.try_get::<Option<bool>, _>(idx) {
267 return json!(v);
268 }
269 if let Ok(Some(v)) = row.try_get::<Option<String>, _>(idx) {
270 return json!(v);
271 }
272 if let Ok(None::<String>) = row.try_get::<Option<String>, _>(idx) {
273 return Value::Null;
274 }
275 json!({ "_unknown_type": ty })
276}
277
278async fn run_sqlite(
279 sql: &str,
280 pool: &sqlx::SqlitePool,
281 limit: usize,
282) -> Result<(Vec<Value>, bool), String> {
283 let rows = sqlx::query(sql)
284 .fetch_all(pool)
285 .await
286 .map_err(|e| format!("query error: {e}"))?;
287 let truncated = rows.len() > limit;
288 let take_n = rows.len().min(limit);
289 let mut out = Vec::with_capacity(take_n);
290 for row in rows.iter().take(take_n) {
291 let mut obj = serde_json::Map::new();
292 for (i, col) in row.columns().iter().enumerate() {
293 let key = col.name().to_string();
294 obj.insert(key, sqlite_value(row, i, col.type_info().name()));
295 }
296 out.push(Value::Object(obj));
297 }
298 Ok((out, truncated))
299}
300
301fn sqlite_value(row: &sqlx::sqlite::SqliteRow, idx: usize, ty: &str) -> Value {
302 if let Ok(Some(v)) = row.try_get::<Option<i64>, _>(idx) {
303 return json!(v);
304 }
305 if let Ok(Some(v)) = row.try_get::<Option<f64>, _>(idx) {
306 return json!(v);
307 }
308 if let Ok(Some(v)) = row.try_get::<Option<bool>, _>(idx) {
309 return json!(v);
310 }
311 if let Ok(Some(v)) = row.try_get::<Option<String>, _>(idx) {
312 return json!(v);
313 }
314 if let Ok(None::<String>) = row.try_get::<Option<String>, _>(idx) {
315 return Value::Null;
316 }
317 json!({ "_unknown_type": ty })
318}