1use async_trait::async_trait;
7use serde_json::{json, Value};
8use sqlx::{Column, Row, TypeInfo};
9
10use crate::protocol::CallToolResult;
11use crate::tool::{Context, Tool};
12
/// Tool (`database-schema`) that reports every table's columns, types, and nullability for
/// whichever database driver the context's pool uses.
pub struct DatabaseSchema;
16
17#[async_trait]
18impl Tool for DatabaseSchema {
19 fn name(&self) -> &'static str {
20 "database-schema"
21 }
22 fn description(&self) -> &'static str {
23 "Dump the live database schema: every table and its columns, types, and nullability. Reads from information_schema (Postgres/MySQL) or sqlite_master (SQLite)."
24 }
25
26 async fn call(&self, ctx: &Context, _args: Value) -> CallToolResult {
27 let pool = ctx.container.driver_pool();
28 match pool {
29 cast_core::Pool::Postgres(p) => {
30 let rows: Result<Vec<(String, String, String, String)>, _> = sqlx::query_as(
31 "SELECT table_name::TEXT, column_name::TEXT, data_type::TEXT, is_nullable::TEXT
32 FROM information_schema.columns
33 WHERE table_schema = 'public'
34 ORDER BY table_name, ordinal_position",
35 )
36 .fetch_all(&p)
37 .await;
38 pg_my_to_result(rows, "postgres")
39 }
40 cast_core::Pool::MySql(p) => {
41 let rows: Result<Vec<(String, String, String, String)>, _> = sqlx::query_as(
42 "SELECT table_name, column_name, column_type, is_nullable
43 FROM information_schema.columns
44 WHERE table_schema = DATABASE()
45 ORDER BY table_name, ordinal_position",
46 )
47 .fetch_all(&p)
48 .await;
49 pg_my_to_result(rows, "mysql")
50 }
51 cast_core::Pool::Sqlite(p) => sqlite_schema(&p).await,
52 }
53 }
54}
55
56fn pg_my_to_result(
57 rows: Result<Vec<(String, String, String, String)>, sqlx::Error>,
58 driver: &str,
59) -> CallToolResult {
60 let rows = match rows {
61 Ok(r) => r,
62 Err(e) => return CallToolResult::error(format!("schema query failed: {e}")),
63 };
64 let mut by_table: indexmap::IndexMap<String, Vec<Value>> = indexmap::IndexMap::new();
65 for (table, col, ty, nullable) in rows {
66 by_table.entry(table).or_default().push(json!({
67 "name": col,
68 "type": ty,
69 "nullable": nullable.eq_ignore_ascii_case("yes"),
70 }));
71 }
72 CallToolResult::json(&json!({
73 "driver": driver,
74 "tables": by_table,
75 }))
76}
77
78async fn sqlite_schema(p: &sqlx::SqlitePool) -> CallToolResult {
79 let tables: Result<Vec<(String,)>, _> = sqlx::query_as(
80 "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' ORDER BY name",
81 )
82 .fetch_all(p)
83 .await;
84 let tables = match tables {
85 Ok(t) => t,
86 Err(e) => return CallToolResult::error(format!("schema query failed: {e}")),
87 };
88 let mut by_table = serde_json::Map::new();
89 for (name,) in tables {
90 let cols: Result<Vec<(i64, String, String, i64, Option<String>, i64)>, _> =
91 sqlx::query_as(&format!("PRAGMA table_info({name})"))
92 .fetch_all(p)
93 .await;
94 if let Ok(rows) = cols {
95 let columns: Vec<Value> = rows
96 .into_iter()
97 .map(|(cid, col, ty, notnull, dflt, pk)| {
98 json!({
99 "cid": cid,
100 "name": col,
101 "type": ty,
102 "nullable": notnull == 0,
103 "default": dflt,
104 "pk": pk != 0,
105 })
106 })
107 .collect();
108 by_table.insert(name, Value::Array(columns));
109 }
110 }
111 CallToolResult::json(&json!({
112 "driver": "sqlite",
113 "tables": by_table,
114 }))
115}
116
/// Tool (`database-query`) that runs a caller-supplied SQL statement, gated to read-only
/// statement keywords, and returns the resulting rows as JSON.
pub struct DatabaseQuery;
120
121#[async_trait]
122impl Tool for DatabaseQuery {
123 fn name(&self) -> &'static str {
124 "database-query"
125 }
126 fn description(&self) -> &'static str {
127 "Run a read-only SQL query and return rows as JSON. Only SELECT, WITH, EXPLAIN, SHOW, and PRAGMA are accepted. Optional `limit` caps the result count (default 100, max 1000)."
128 }
129 fn input_schema(&self) -> Value {
130 json!({
131 "type": "object",
132 "required": ["sql"],
133 "properties": {
134 "sql": { "type": "string", "description": "Read-only SQL statement." },
135 "limit": { "type": "integer", "description": "Max rows to return.", "default": 100, "minimum": 1, "maximum": 1000 }
136 }
137 })
138 }
139
140 async fn call(&self, ctx: &Context, args: Value) -> CallToolResult {
141 let sql = match args.get("sql").and_then(|v| v.as_str()) {
142 Some(s) if !s.trim().is_empty() => s.trim().to_string(),
143 _ => return CallToolResult::error("`sql` is required"),
144 };
145 if !is_readonly(&sql) {
146 return CallToolResult::error(
147 "rejected: only SELECT, WITH, EXPLAIN, SHOW, and PRAGMA statements are allowed by database-query",
148 );
149 }
150 let limit = args
151 .get("limit")
152 .and_then(|v| v.as_u64())
153 .unwrap_or(100)
154 .min(1000)
155 .max(1) as usize;
156
157 let pool = ctx.container.driver_pool();
158 let result = match pool {
159 cast_core::Pool::Postgres(p) => run_postgres(&sql, &p, limit).await,
160 cast_core::Pool::MySql(p) => run_mysql(&sql, &p, limit).await,
161 cast_core::Pool::Sqlite(p) => run_sqlite(&sql, &p, limit).await,
162 };
163 match result {
164 Ok((rows, truncated)) => CallToolResult::json(&json!({
165 "rows": rows,
166 "row_count": rows.len(),
167 "truncated": truncated,
168 })),
169 Err(e) => CallToolResult::error(e),
170 }
171 }
172}
173
/// Lexical gate for `database-query`. Accepts a single statement whose first token is one of
/// `SELECT`, `WITH`, `EXPLAIN`, `SHOW`, or `PRAGMA`.
///
/// Rules:
/// * Leading whitespace is ignored. The first token (a run of ASCII letters, digits, or `_`)
///   must equal an allowed keyword, case-insensitively. `selectfoo` is rejected rather than
///   read as `SELECT`.
/// * After the first `;` outside quotes and comments, only whitespace, further `;`, and
///   comments may follow. This matters most for SQLite, where sqlx executes every statement in
///   the string, so `SELECT 1; DROP TABLE t` would otherwise run the `DROP`.
///
/// The scanner skips `'...'`, `"..."`, `` `...` `` and `[...]` quoted spans. A doubled quote
/// simply re-opens on the next character. It also skips `--` line comments and `/* */` block
/// comments. MySQL `/*! ... */` executable comments are scanned as code.
///
/// NOTE(review): this is a heuristic, not a guarantee.
/// * Backslash escapes and Postgres `$$` quoting are not modelled. Those dialects presumably
///   reject multi-statement prepared queries server-side (verify).
/// * An allowed keyword does not make a statement side-effect free. Examples: Postgres
///   `WITH ... DELETE`, `EXPLAIN ANALYZE <dml>`, `SELECT ... INTO`, and PRAGMAs that set values.
fn is_readonly(sql: &str) -> bool {
    const ALLOWED: [&str; 5] = ["select", "with", "explain", "show", "pragma"];

    /// True when nothing but whitespace, `;`, and comments follows the first terminator.
    fn single_statement(sql: &str) -> bool {
        let mut chars = sql.chars().peekable();
        let mut terminated = false;
        while let Some(c) = chars.next() {
            match c {
                '\'' | '"' | '`' | '[' => {
                    if terminated {
                        return false;
                    }
                    let close = if c == '[' { ']' } else { c };
                    // Skip the quoted body; an unterminated quote consumes the rest of the input.
                    for d in chars.by_ref() {
                        if d == close {
                            break;
                        }
                    }
                }
                '-' if chars.peek() == Some(&'-') => {
                    for d in chars.by_ref() {
                        if d == '\n' {
                            break;
                        }
                    }
                }
                '/' if chars.peek() == Some(&'*') => {
                    chars.next();
                    if chars.peek() == Some(&'!') {
                        // MySQL executable comment: its body is real SQL, so scan it as code.
                        if terminated {
                            return false;
                        }
                        continue;
                    }
                    // `prev` starts empty so the opening `*` cannot close the comment (`/*/`).
                    let mut prev = '\0';
                    for d in chars.by_ref() {
                        if prev == '*' && d == '/' {
                            break;
                        }
                        prev = d;
                    }
                }
                ';' => terminated = true,
                c if c.is_whitespace() => {}
                _ => {
                    if terminated {
                        return false;
                    }
                }
            }
        }
        true
    }

    let sql = sql.trim_start();
    let keyword_end = sql
        .find(|c: char| !(c.is_ascii_alphanumeric() || c == '_'))
        .unwrap_or(sql.len());
    let keyword = sql[..keyword_end].to_ascii_lowercase();
    ALLOWED.contains(&keyword.as_str()) && single_statement(sql)
}
184
185async fn run_postgres(
186 sql: &str,
187 pool: &sqlx::PgPool,
188 limit: usize,
189) -> Result<(Vec<Value>, bool), String> {
190 let rows = sqlx::query(sql)
191 .fetch_all(pool)
192 .await
193 .map_err(|e| format!("query error: {e}"))?;
194 let truncated = rows.len() > limit;
195 let take_n = rows.len().min(limit);
196 let mut out = Vec::with_capacity(take_n);
197 for row in rows.iter().take(take_n) {
198 let mut obj = serde_json::Map::new();
199 for (i, col) in row.columns().iter().enumerate() {
200 let key = col.name().to_string();
201 let value = pg_value(row, i, col.type_info().name());
202 obj.insert(key, value);
203 }
204 out.push(Value::Object(obj));
205 }
206 Ok((out, truncated))
207}
208
209fn pg_value(row: &sqlx::postgres::PgRow, idx: usize, ty: &str) -> Value {
210 if let Ok(Some(v)) = row.try_get::<Option<i64>, _>(idx) {
211 return json!(v);
212 }
213 if let Ok(Some(v)) = row.try_get::<Option<i32>, _>(idx) {
214 return json!(v);
215 }
216 if let Ok(Some(v)) = row.try_get::<Option<f64>, _>(idx) {
217 return json!(v);
218 }
219 if let Ok(Some(v)) = row.try_get::<Option<bool>, _>(idx) {
220 return json!(v);
221 }
222 if let Ok(Some(v)) = row.try_get::<Option<String>, _>(idx) {
223 return json!(v);
224 }
225 if let Ok(Some(v)) = row.try_get::<Option<serde_json::Value>, _>(idx) {
226 return v;
227 }
228 if let Ok(None::<String>) = row.try_get::<Option<String>, _>(idx) {
229 return Value::Null;
230 }
231 json!({ "_unknown_type": ty })
232}
233
234async fn run_mysql(
235 sql: &str,
236 pool: &sqlx::MySqlPool,
237 limit: usize,
238) -> Result<(Vec<Value>, bool), String> {
239 let rows = sqlx::query(sql)
240 .fetch_all(pool)
241 .await
242 .map_err(|e| format!("query error: {e}"))?;
243 let truncated = rows.len() > limit;
244 let take_n = rows.len().min(limit);
245 let mut out = Vec::with_capacity(take_n);
246 for row in rows.iter().take(take_n) {
247 let mut obj = serde_json::Map::new();
248 for (i, col) in row.columns().iter().enumerate() {
249 let key = col.name().to_string();
250 obj.insert(key, mysql_value(row, i, col.type_info().name()));
251 }
252 out.push(Value::Object(obj));
253 }
254 Ok((out, truncated))
255}
256
257fn mysql_value(row: &sqlx::mysql::MySqlRow, idx: usize, ty: &str) -> Value {
258 if let Ok(Some(v)) = row.try_get::<Option<i64>, _>(idx) {
259 return json!(v);
260 }
261 if let Ok(Some(v)) = row.try_get::<Option<f64>, _>(idx) {
262 return json!(v);
263 }
264 if let Ok(Some(v)) = row.try_get::<Option<bool>, _>(idx) {
265 return json!(v);
266 }
267 if let Ok(Some(v)) = row.try_get::<Option<String>, _>(idx) {
268 return json!(v);
269 }
270 if let Ok(None::<String>) = row.try_get::<Option<String>, _>(idx) {
271 return Value::Null;
272 }
273 json!({ "_unknown_type": ty })
274}
275
276async fn run_sqlite(
277 sql: &str,
278 pool: &sqlx::SqlitePool,
279 limit: usize,
280) -> Result<(Vec<Value>, bool), String> {
281 let rows = sqlx::query(sql)
282 .fetch_all(pool)
283 .await
284 .map_err(|e| format!("query error: {e}"))?;
285 let truncated = rows.len() > limit;
286 let take_n = rows.len().min(limit);
287 let mut out = Vec::with_capacity(take_n);
288 for row in rows.iter().take(take_n) {
289 let mut obj = serde_json::Map::new();
290 for (i, col) in row.columns().iter().enumerate() {
291 let key = col.name().to_string();
292 obj.insert(key, sqlite_value(row, i, col.type_info().name()));
293 }
294 out.push(Value::Object(obj));
295 }
296 Ok((out, truncated))
297}
298
299fn sqlite_value(row: &sqlx::sqlite::SqliteRow, idx: usize, ty: &str) -> Value {
300 if let Ok(Some(v)) = row.try_get::<Option<i64>, _>(idx) {
301 return json!(v);
302 }
303 if let Ok(Some(v)) = row.try_get::<Option<f64>, _>(idx) {
304 return json!(v);
305 }
306 if let Ok(Some(v)) = row.try_get::<Option<bool>, _>(idx) {
307 return json!(v);
308 }
309 if let Ok(Some(v)) = row.try_get::<Option<String>, _>(idx) {
310 return json!(v);
311 }
312 if let Ok(None::<String>) = row.try_get::<Option<String>, _>(idx) {
313 return Value::Null;
314 }
315 json!({ "_unknown_type": ty })
316}