Skip to main content

boost/tools/
database.rs

//! Database introspection tools: `database-schema` and `database-query`.
//!
//! Both are intended to be read-only. `database-query` rejects any statement
//! that doesn't start with SELECT/WITH/EXPLAIN/SHOW/PRAGMA (case-insensitive).
//! That check is a textual filter, so it does not replace read-only database
//! credentials.
5
6use async_trait::async_trait;
7use serde_json::{json, Value};
8use sqlx::{Column, Row, TypeInfo};
9
10use crate::protocol::CallToolResult;
11use crate::tool::{Context, Tool};
12
/// One row of `PRAGMA table_info(...)`, decoded positionally in the order SQLite
/// returns the columns.
type SqliteColumnRow = (
    i64,            // cid
    String,         // name
    String,         // type
    i64,            // notnull
    Option<String>, // dflt_value
    i64,            // pk
);
15
// ─── database-schema ────────────────────────────────────────────────────────
17
18pub struct DatabaseSchema;
19
20#[async_trait]
21impl Tool for DatabaseSchema {
22    fn name(&self) -> &'static str {
23        "database-schema"
24    }
25    fn description(&self) -> &'static str {
26        "Dump the live database schema: every table and its columns, types, and nullability. Reads from information_schema (Postgres/MySQL) or sqlite_master (SQLite)."
27    }
28
29    async fn call(&self, ctx: &Context, _args: Value) -> CallToolResult {
30        let pool = ctx.container.driver_pool();
31        match pool {
32            cast_core::Pool::Postgres(p) => {
33                let rows: Result<Vec<(String, String, String, String)>, _> = sqlx::query_as(
34                    "SELECT table_name::TEXT, column_name::TEXT, data_type::TEXT, is_nullable::TEXT
35                       FROM information_schema.columns
36                      WHERE table_schema = 'public'
37                      ORDER BY table_name, ordinal_position",
38                )
39                .fetch_all(&p)
40                .await;
41                pg_my_to_result(rows, "postgres")
42            }
43            cast_core::Pool::MySql(p) => {
44                let rows: Result<Vec<(String, String, String, String)>, _> = sqlx::query_as(
45                    "SELECT table_name, column_name, column_type, is_nullable
46                       FROM information_schema.columns
47                      WHERE table_schema = DATABASE()
48                      ORDER BY table_name, ordinal_position",
49                )
50                .fetch_all(&p)
51                .await;
52                pg_my_to_result(rows, "mysql")
53            }
54            cast_core::Pool::Sqlite(p) => sqlite_schema(&p).await,
55        }
56    }
57}
58
59fn pg_my_to_result(
60    rows: Result<Vec<(String, String, String, String)>, sqlx::Error>,
61    driver: &str,
62) -> CallToolResult {
63    let rows = match rows {
64        Ok(r) => r,
65        Err(e) => return CallToolResult::error(format!("schema query failed: {e}")),
66    };
67    let mut by_table: indexmap::IndexMap<String, Vec<Value>> = indexmap::IndexMap::new();
68    for (table, col, ty, nullable) in rows {
69        by_table.entry(table).or_default().push(json!({
70            "name": col,
71            "type": ty,
72            "nullable": nullable.eq_ignore_ascii_case("yes"),
73        }));
74    }
75    CallToolResult::json(&json!({
76        "driver": driver,
77        "tables": by_table,
78    }))
79}
80
81async fn sqlite_schema(p: &sqlx::SqlitePool) -> CallToolResult {
82    let tables: Result<Vec<(String,)>, _> = sqlx::query_as(
83        "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' ORDER BY name",
84    )
85    .fetch_all(p)
86    .await;
87    let tables = match tables {
88        Ok(t) => t,
89        Err(e) => return CallToolResult::error(format!("schema query failed: {e}")),
90    };
91    let mut by_table = serde_json::Map::new();
92    for (name,) in tables {
93        let cols: Result<Vec<SqliteColumnRow>, _> =
94            sqlx::query_as(&format!("PRAGMA table_info({name})"))
95                .fetch_all(p)
96                .await;
97        if let Ok(rows) = cols {
98            let columns: Vec<Value> = rows
99                .into_iter()
100                .map(|(cid, col, ty, notnull, dflt, pk)| {
101                    json!({
102                        "cid": cid,
103                        "name": col,
104                        "type": ty,
105                        "nullable": notnull == 0,
106                        "default": dflt,
107                        "pk": pk != 0,
108                    })
109                })
110                .collect();
111            by_table.insert(name, Value::Array(columns));
112        }
113    }
114    CallToolResult::json(&json!({
115        "driver": "sqlite",
116        "tables": by_table,
117    }))
118}
119
// ─── database-query ─────────────────────────────────────────────────────────
121
122pub struct DatabaseQuery;
123
124#[async_trait]
125impl Tool for DatabaseQuery {
126    fn name(&self) -> &'static str {
127        "database-query"
128    }
129    fn description(&self) -> &'static str {
130        "Run a read-only SQL query and return rows as JSON. Only SELECT, WITH, EXPLAIN, SHOW, and PRAGMA are accepted. Optional `limit` caps the result count (default 100, max 1000)."
131    }
132    fn input_schema(&self) -> Value {
133        json!({
134            "type": "object",
135            "required": ["sql"],
136            "properties": {
137                "sql": { "type": "string", "description": "Read-only SQL statement." },
138                "limit": { "type": "integer", "description": "Max rows to return.", "default": 100, "minimum": 1, "maximum": 1000 }
139            }
140        })
141    }
142
143    async fn call(&self, ctx: &Context, args: Value) -> CallToolResult {
144        let sql = match args.get("sql").and_then(|v| v.as_str()) {
145            Some(s) if !s.trim().is_empty() => s.trim().to_string(),
146            _ => return CallToolResult::error("`sql` is required"),
147        };
148        if !is_readonly(&sql) {
149            return CallToolResult::error(
150                "rejected: only SELECT, WITH, EXPLAIN, SHOW, and PRAGMA statements are allowed by database-query",
151            );
152        }
153        let limit = args
154            .get("limit")
155            .and_then(|v| v.as_u64())
156            .unwrap_or(100)
157            .clamp(1, 1000) as usize;
158
159        let pool = ctx.container.driver_pool();
160        let result = match pool {
161            cast_core::Pool::Postgres(p) => run_postgres(&sql, &p, limit).await,
162            cast_core::Pool::MySql(p) => run_mysql(&sql, &p, limit).await,
163            cast_core::Pool::Sqlite(p) => run_sqlite(&sql, &p, limit).await,
164        };
165        match result {
166            Ok((rows, truncated)) => CallToolResult::json(&json!({
167                "rows": rows,
168                "row_count": rows.len(),
169                "truncated": truncated,
170            })),
171            Err(e) => CallToolResult::error(e),
172        }
173    }
174}
175
/// Conservative gate for `database-query`: `true` only when `sql` looks like a
/// single read-only statement.
///
/// The text is tokenized once under each supported dialect's lexing rules:
/// Postgres, MySQL (`#` line comments) and SQLite (`[identifier]` quoting).
/// It must pass every time:
///
/// 1. The first word is exactly SELECT, WITH, EXPLAIN, SHOW or PRAGMA
///    (case-insensitive, whole word; `SELECTED ...` is not accepted).
/// 2. Only further `;` may follow the first `;`, so there is a single statement.
///    Some drivers, e.g. SQLite's, execute every statement in a string.
/// 3. None of INSERT, UPDATE, DELETE, MERGE, INTO, CREATE or EXECUTE appears
///    outside quotes, except `SHOW CREATE ...`. This catches
///    `WITH d AS (DELETE ...)`, `EXPLAIN ANALYZE UPDATE ...`, `SELECT ... INTO`
///    and `SELECT ... FOR UPDATE`.
/// 4. A PRAGMA contains no `=`, because the assignment form changes settings.
///
/// Some constructs lex differently between dialects and could hide text from the
/// scan, so they are rejected outright:
/// - a backslash or NUL anywhere;
/// - a `--` or `/* */` comment;
/// - `$` (Postgres dollar quoting / parameters) outside quotes.
///
/// This is defense in depth, not a guarantee. It cannot see side-effecting
/// function calls (e.g. Postgres `nextval()`) or `PRAGMA name(value)` setters.
/// Real read-only enforcement has to come from database permissions.
fn is_readonly(sql: &str) -> bool {
    // Statement keywords accepted as the first word.
    const ALLOWED_LEADING: [&str; 5] = ["select", "with", "explain", "show", "pragma"];
    // Words that can make an otherwise-allowed statement write or execute something.
    const WRITE_KEYWORDS: [&str; 7] =
        ["insert", "update", "delete", "merge", "into", "create", "execute"];
    // `(hash_comments, bracket_quotes)` per dialect: Postgres, MySQL, SQLite.
    const DIALECTS: [(bool, bool); 3] = [(false, false), (true, false), (false, true)];

    // Backslash escaping differs between dialects (MySQL vs. standard strings),
    // which would let a quote end in one dialect but not in another.
    if sql.chars().any(|c| c == '\\' || c == '\0') {
        return false;
    }
    DIALECTS.iter().all(|&(hash_comments, bracket_quotes)| {
        match scan_sql(sql, hash_comments, bracket_quotes) {
            Some(tokens) => sql_tokens_are_readonly(&tokens, &ALLOWED_LEADING, &WRITE_KEYWORDS),
            None => false,
        }
    })
}

/// A token outside quoted text, as produced by [`scan_sql`].
#[derive(Debug, PartialEq)]
enum SqlToken {
    /// A run of alphanumerics/underscores, ASCII-lowercased.
    Word(String),
    /// Any other single non-whitespace character.
    Punct(char),
}

/// Tokenize `sql`, dropping quoted strings/identifiers (and, in MySQL mode, `#`
/// comments).
///
/// Returns `None` for an unterminated quote, or for a `--` / `/*` comment or `$`
/// outside quotes.
fn scan_sql(sql: &str, hash_comments: bool, bracket_quotes: bool) -> Option<Vec<SqlToken>> {
    let chars: Vec<char> = sql.chars().collect();
    let mut tokens = Vec::new();
    let mut i = 0;
    while let Some(&c) = chars.get(i) {
        let next = chars.get(i + 1).copied();
        match c {
            '\'' | '"' | '`' => i = skip_sql_quoted(&chars, i, c, true)?,
            '[' if bracket_quotes => i = skip_sql_quoted(&chars, i, ']', false)?,
            '#' if hash_comments => {
                // MySQL line comment. Stop at either line-break character so the
                // skip never covers more text than MySQL itself treats as comment.
                while let Some(&ch) = chars.get(i) {
                    if ch == '\n' || ch == '\r' {
                        break;
                    }
                    i += 1;
                }
            }
            // Comments (including MySQL's executable `/*! ... */`) and `$` lex
            // differently per dialect; refuse them rather than guess.
            '-' if next == Some('-') => return None,
            '/' if next == Some('*') => return None,
            '$' => return None,
            c if c.is_whitespace() => i += 1,
            c if is_sql_word_char(c) => {
                let start = i;
                while matches!(chars.get(i), Some(&ch) if is_sql_word_char(ch)) {
                    i += 1;
                }
                let word: String = chars[start..i].iter().collect();
                tokens.push(SqlToken::Word(word.to_ascii_lowercase()));
            }
            other => {
                tokens.push(SqlToken::Punct(other));
                i += 1;
            }
        }
    }
    Some(tokens)
}

/// Given that `chars[open]` opens a quoted region, return the index just past its
/// closing `close` character.
///
/// With `doubled_escape`, two consecutive `close` characters stand for a literal
/// one (SQL `'it''s'`). Returns `None` if the region is never closed.
fn skip_sql_quoted(chars: &[char], open: usize, close: char, doubled_escape: bool) -> Option<usize> {
    let mut i = open + 1;
    loop {
        match chars.get(i) {
            None => return None,
            Some(&ch) if ch == close => {
                if doubled_escape && chars.get(i + 1) == Some(&close) {
                    i += 2;
                } else {
                    return Some(i + 1);
                }
            }
            Some(_) => i += 1,
        }
    }
}

/// Identifier-ish characters that form a single bare word.
fn is_sql_word_char(c: char) -> bool {
    c.is_alphanumeric() || c == '_'
}

/// Apply the statement-level rules (see [`is_readonly`]) to one dialect's tokens.
fn sql_tokens_are_readonly(tokens: &[SqlToken], allowed: &[&str], forbidden: &[&str]) -> bool {
    let leading = match tokens.first() {
        Some(SqlToken::Word(w)) if allowed.contains(&w.as_str()) => w.as_str(),
        _ => return false,
    };
    // Single statement: nothing but more `;` after the first `;`.
    if let Some(pos) = tokens.iter().position(|t| *t == SqlToken::Punct(';')) {
        if tokens[pos..].iter().any(|t| *t != SqlToken::Punct(';')) {
            return false;
        }
    }
    // `PRAGMA name = value` assigns a setting.
    if leading == "pragma" && tokens.contains(&SqlToken::Punct('=')) {
        return false;
    }
    tokens.iter().enumerate().all(|(idx, tok)| match tok {
        // MySQL `SHOW CREATE TABLE ...` only reads a definition.
        SqlToken::Word(w) if idx == 1 && leading == "show" && w == "create" => true,
        // Some lexers split `1into` into `1` + `into`, so look inside digit-led words.
        SqlToken::Word(w) if w.starts_with(|c: char| c.is_ascii_digit()) => {
            !forbidden.iter().any(|kw| w.contains(*kw))
        }
        SqlToken::Word(w) => !forbidden.contains(&w.as_str()),
        SqlToken::Punct(_) => true,
    })
}
186
187async fn run_postgres(
188    sql: &str,
189    pool: &sqlx::PgPool,
190    limit: usize,
191) -> Result<(Vec<Value>, bool), String> {
192    let rows = sqlx::query(sql)
193        .fetch_all(pool)
194        .await
195        .map_err(|e| format!("query error: {e}"))?;
196    let truncated = rows.len() > limit;
197    let take_n = rows.len().min(limit);
198    let mut out = Vec::with_capacity(take_n);
199    for row in rows.iter().take(take_n) {
200        let mut obj = serde_json::Map::new();
201        for (i, col) in row.columns().iter().enumerate() {
202            let key = col.name().to_string();
203            let value = pg_value(row, i, col.type_info().name());
204            obj.insert(key, value);
205        }
206        out.push(Value::Object(obj));
207    }
208    Ok((out, truncated))
209}
210
211fn pg_value(row: &sqlx::postgres::PgRow, idx: usize, ty: &str) -> Value {
212    if let Ok(Some(v)) = row.try_get::<Option<i64>, _>(idx) {
213        return json!(v);
214    }
215    if let Ok(Some(v)) = row.try_get::<Option<i32>, _>(idx) {
216        return json!(v);
217    }
218    if let Ok(Some(v)) = row.try_get::<Option<f64>, _>(idx) {
219        return json!(v);
220    }
221    if let Ok(Some(v)) = row.try_get::<Option<bool>, _>(idx) {
222        return json!(v);
223    }
224    if let Ok(Some(v)) = row.try_get::<Option<String>, _>(idx) {
225        return json!(v);
226    }
227    if let Ok(Some(v)) = row.try_get::<Option<serde_json::Value>, _>(idx) {
228        return v;
229    }
230    if let Ok(None::<String>) = row.try_get::<Option<String>, _>(idx) {
231        return Value::Null;
232    }
233    json!({ "_unknown_type": ty })
234}
235
236async fn run_mysql(
237    sql: &str,
238    pool: &sqlx::MySqlPool,
239    limit: usize,
240) -> Result<(Vec<Value>, bool), String> {
241    let rows = sqlx::query(sql)
242        .fetch_all(pool)
243        .await
244        .map_err(|e| format!("query error: {e}"))?;
245    let truncated = rows.len() > limit;
246    let take_n = rows.len().min(limit);
247    let mut out = Vec::with_capacity(take_n);
248    for row in rows.iter().take(take_n) {
249        let mut obj = serde_json::Map::new();
250        for (i, col) in row.columns().iter().enumerate() {
251            let key = col.name().to_string();
252            obj.insert(key, mysql_value(row, i, col.type_info().name()));
253        }
254        out.push(Value::Object(obj));
255    }
256    Ok((out, truncated))
257}
258
259fn mysql_value(row: &sqlx::mysql::MySqlRow, idx: usize, ty: &str) -> Value {
260    if let Ok(Some(v)) = row.try_get::<Option<i64>, _>(idx) {
261        return json!(v);
262    }
263    if let Ok(Some(v)) = row.try_get::<Option<f64>, _>(idx) {
264        return json!(v);
265    }
266    if let Ok(Some(v)) = row.try_get::<Option<bool>, _>(idx) {
267        return json!(v);
268    }
269    if let Ok(Some(v)) = row.try_get::<Option<String>, _>(idx) {
270        return json!(v);
271    }
272    if let Ok(None::<String>) = row.try_get::<Option<String>, _>(idx) {
273        return Value::Null;
274    }
275    json!({ "_unknown_type": ty })
276}
277
278async fn run_sqlite(
279    sql: &str,
280    pool: &sqlx::SqlitePool,
281    limit: usize,
282) -> Result<(Vec<Value>, bool), String> {
283    let rows = sqlx::query(sql)
284        .fetch_all(pool)
285        .await
286        .map_err(|e| format!("query error: {e}"))?;
287    let truncated = rows.len() > limit;
288    let take_n = rows.len().min(limit);
289    let mut out = Vec::with_capacity(take_n);
290    for row in rows.iter().take(take_n) {
291        let mut obj = serde_json::Map::new();
292        for (i, col) in row.columns().iter().enumerate() {
293            let key = col.name().to_string();
294            obj.insert(key, sqlite_value(row, i, col.type_info().name()));
295        }
296        out.push(Value::Object(obj));
297    }
298    Ok((out, truncated))
299}
300
301fn sqlite_value(row: &sqlx::sqlite::SqliteRow, idx: usize, ty: &str) -> Value {
302    if let Ok(Some(v)) = row.try_get::<Option<i64>, _>(idx) {
303        return json!(v);
304    }
305    if let Ok(Some(v)) = row.try_get::<Option<f64>, _>(idx) {
306        return json!(v);
307    }
308    if let Ok(Some(v)) = row.try_get::<Option<bool>, _>(idx) {
309        return json!(v);
310    }
311    if let Ok(Some(v)) = row.try_get::<Option<String>, _>(idx) {
312        return json!(v);
313    }
314    if let Ok(None::<String>) = row.try_get::<Option<String>, _>(idx) {
315        return Value::Null;
316    }
317    json!({ "_unknown_type": ty })
318}