Skip to main content

boost/tools/
database.rs

1//! Database introspection tools: `database-schema` and `database-query`.
2//!
3//! Both are strictly read-only. `database-query` rejects any statement that
4//! doesn't start with SELECT/WITH/EXPLAIN/SHOW/PRAGMA (case-insensitive).
5
6use async_trait::async_trait;
7use serde_json::{json, Value};
8use sqlx::{Column, Row, TypeInfo};
9
10use crate::protocol::CallToolResult;
11use crate::tool::{Context, Tool};
12
13// ─── database-schema ────────────────────────────────────────────────────────
14
15pub struct DatabaseSchema;
16
17#[async_trait]
18impl Tool for DatabaseSchema {
19    fn name(&self) -> &'static str {
20        "database-schema"
21    }
22    fn description(&self) -> &'static str {
23        "Dump the live database schema: every table and its columns, types, and nullability. Reads from information_schema (Postgres/MySQL) or sqlite_master (SQLite)."
24    }
25
26    async fn call(&self, ctx: &Context, _args: Value) -> CallToolResult {
27        let pool = ctx.container.driver_pool();
28        match pool {
29            cast_core::Pool::Postgres(p) => {
30                let rows: Result<Vec<(String, String, String, String)>, _> = sqlx::query_as(
31                    "SELECT table_name::TEXT, column_name::TEXT, data_type::TEXT, is_nullable::TEXT
32                       FROM information_schema.columns
33                      WHERE table_schema = 'public'
34                      ORDER BY table_name, ordinal_position",
35                )
36                .fetch_all(&p)
37                .await;
38                pg_my_to_result(rows, "postgres")
39            }
40            cast_core::Pool::MySql(p) => {
41                let rows: Result<Vec<(String, String, String, String)>, _> = sqlx::query_as(
42                    "SELECT table_name, column_name, column_type, is_nullable
43                       FROM information_schema.columns
44                      WHERE table_schema = DATABASE()
45                      ORDER BY table_name, ordinal_position",
46                )
47                .fetch_all(&p)
48                .await;
49                pg_my_to_result(rows, "mysql")
50            }
51            cast_core::Pool::Sqlite(p) => sqlite_schema(&p).await,
52        }
53    }
54}
55
56fn pg_my_to_result(
57    rows: Result<Vec<(String, String, String, String)>, sqlx::Error>,
58    driver: &str,
59) -> CallToolResult {
60    let rows = match rows {
61        Ok(r) => r,
62        Err(e) => return CallToolResult::error(format!("schema query failed: {e}")),
63    };
64    let mut by_table: indexmap::IndexMap<String, Vec<Value>> = indexmap::IndexMap::new();
65    for (table, col, ty, nullable) in rows {
66        by_table.entry(table).or_default().push(json!({
67            "name": col,
68            "type": ty,
69            "nullable": nullable.eq_ignore_ascii_case("yes"),
70        }));
71    }
72    CallToolResult::json(&json!({
73        "driver": driver,
74        "tables": by_table,
75    }))
76}
77
78async fn sqlite_schema(p: &sqlx::SqlitePool) -> CallToolResult {
79    let tables: Result<Vec<(String,)>, _> = sqlx::query_as(
80        "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' ORDER BY name",
81    )
82    .fetch_all(p)
83    .await;
84    let tables = match tables {
85        Ok(t) => t,
86        Err(e) => return CallToolResult::error(format!("schema query failed: {e}")),
87    };
88    let mut by_table = serde_json::Map::new();
89    for (name,) in tables {
90        let cols: Result<Vec<(i64, String, String, i64, Option<String>, i64)>, _> =
91            sqlx::query_as(&format!("PRAGMA table_info({name})"))
92                .fetch_all(p)
93                .await;
94        if let Ok(rows) = cols {
95            let columns: Vec<Value> = rows
96                .into_iter()
97                .map(|(cid, col, ty, notnull, dflt, pk)| {
98                    json!({
99                        "cid": cid,
100                        "name": col,
101                        "type": ty,
102                        "nullable": notnull == 0,
103                        "default": dflt,
104                        "pk": pk != 0,
105                    })
106                })
107                .collect();
108            by_table.insert(name, Value::Array(columns));
109        }
110    }
111    CallToolResult::json(&json!({
112        "driver": "sqlite",
113        "tables": by_table,
114    }))
115}
116
117// ─── database-query ─────────────────────────────────────────────────────────
118
119pub struct DatabaseQuery;
120
121#[async_trait]
122impl Tool for DatabaseQuery {
123    fn name(&self) -> &'static str {
124        "database-query"
125    }
126    fn description(&self) -> &'static str {
127        "Run a read-only SQL query and return rows as JSON. Only SELECT, WITH, EXPLAIN, SHOW, and PRAGMA are accepted. Optional `limit` caps the result count (default 100, max 1000)."
128    }
129    fn input_schema(&self) -> Value {
130        json!({
131            "type": "object",
132            "required": ["sql"],
133            "properties": {
134                "sql": { "type": "string", "description": "Read-only SQL statement." },
135                "limit": { "type": "integer", "description": "Max rows to return.", "default": 100, "minimum": 1, "maximum": 1000 }
136            }
137        })
138    }
139
140    async fn call(&self, ctx: &Context, args: Value) -> CallToolResult {
141        let sql = match args.get("sql").and_then(|v| v.as_str()) {
142            Some(s) if !s.trim().is_empty() => s.trim().to_string(),
143            _ => return CallToolResult::error("`sql` is required"),
144        };
145        if !is_readonly(&sql) {
146            return CallToolResult::error(
147                "rejected: only SELECT, WITH, EXPLAIN, SHOW, and PRAGMA statements are allowed by database-query",
148            );
149        }
150        let limit = args
151            .get("limit")
152            .and_then(|v| v.as_u64())
153            .unwrap_or(100)
154            .min(1000)
155            .max(1) as usize;
156
157        let pool = ctx.container.driver_pool();
158        let result = match pool {
159            cast_core::Pool::Postgres(p) => run_postgres(&sql, &p, limit).await,
160            cast_core::Pool::MySql(p) => run_mysql(&sql, &p, limit).await,
161            cast_core::Pool::Sqlite(p) => run_sqlite(&sql, &p, limit).await,
162        };
163        match result {
164            Ok((rows, truncated)) => CallToolResult::json(&json!({
165                "rows": rows,
166                "row_count": rows.len(),
167                "truncated": truncated,
168            })),
169            Err(e) => CallToolResult::error(e),
170        }
171    }
172}
173
/// Conservative lexical screen deciding whether `database-query` may run `sql`.
///
/// This is not a SQL parser. Keywords are compared case-insensitively and as whole words.
/// A statement is accepted only when all of the following hold:
///
/// * After leading whitespace it starts with `SELECT`, `WITH`, `EXPLAIN`, `SHOW` or
///   `PRAGMA`. A bare prefix such as `selector …` does not count.
/// * It is a single statement: no `;` other than trailing terminators. sqlx's SQLite
///   driver runs every statement in the string, so `SELECT 1; DROP TABLE t` must not pass.
/// * No word anywhere is a data- or schema-writing keyword (`INSERT`, `UPDATE`, `DELETE`,
///   `DROP`, `INTO`, …). This catches writes behind an allowed prefix, such as
///   `WITH d AS (DELETE …) SELECT …`, `EXPLAIN ANALYZE DELETE …`, `SELECT … INTO t` and
///   `SELECT … FOR UPDATE`. The one exception is the word directly after `SHOW`, so
///   MySQL's read-only `SHOW CREATE TABLE …` keeps working.
/// * A `PRAGMA` does not use the `name = value` assignment form.
///
/// String literals, quoted identifiers and comments are deliberately scanned as plain text.
/// The supported dialects lex them differently, and not skipping them means nothing can be
/// hidden from the check. The cost is a safe false rejection when, for example, a literal
/// contains `'delete'` or `';'`.
fn is_readonly(sql: &str) -> bool {
    // Keywords an accepted statement may start with.
    const LEADING: [&str; 5] = ["select", "with", "explain", "show", "pragma"];
    // Words that only occur in statements or clauses that write data or change the schema.
    const WRITE_KEYWORDS: &[&str] = &[
        "insert", "update", "delete", "merge", "into", "create", "drop", "alter", "truncate",
        "rename", "grant", "revoke", "attach", "detach", "vacuum", "reindex",
    ];

    fn is_word_char(c: char) -> bool {
        c.is_ascii_alphanumeric() || c == '_'
    }

    let lower = sql.trim().to_ascii_lowercase();

    let leading = LEADING
        .iter()
        .copied()
        .find(|&kw| match lower.strip_prefix(kw) {
            // Whole word only: `selector` must not be mistaken for `select`.
            Some(rest) => !rest.starts_with(is_word_char),
            None => false,
        });
    let leading = match leading {
        Some(kw) => kw,
        None => return false,
    };

    // Trailing terminators are harmless; any other `;` would start a second statement.
    let statement = lower.trim_end_matches(|c: char| c == ';' || c.is_whitespace());
    if statement.contains(';') {
        return false;
    }

    // Skip `show` and its argument so `SHOW CREATE TABLE …` is not rejected for `create`.
    let exempt_words = if leading == "show" { 2 } else { 0 };
    let writes = statement
        .split(|c: char| !is_word_char(c))
        .filter(|word| !word.is_empty())
        .skip(exempt_words)
        .any(|word| WRITE_KEYWORDS.contains(&word));
    if writes {
        return false;
    }

    // `PRAGMA name = value` changes connection settings or the database file itself.
    !(leading == "pragma" && statement.contains('='))
}
184
185async fn run_postgres(
186    sql: &str,
187    pool: &sqlx::PgPool,
188    limit: usize,
189) -> Result<(Vec<Value>, bool), String> {
190    let rows = sqlx::query(sql)
191        .fetch_all(pool)
192        .await
193        .map_err(|e| format!("query error: {e}"))?;
194    let truncated = rows.len() > limit;
195    let take_n = rows.len().min(limit);
196    let mut out = Vec::with_capacity(take_n);
197    for row in rows.iter().take(take_n) {
198        let mut obj = serde_json::Map::new();
199        for (i, col) in row.columns().iter().enumerate() {
200            let key = col.name().to_string();
201            let value = pg_value(row, i, col.type_info().name());
202            obj.insert(key, value);
203        }
204        out.push(Value::Object(obj));
205    }
206    Ok((out, truncated))
207}
208
209fn pg_value(row: &sqlx::postgres::PgRow, idx: usize, ty: &str) -> Value {
210    if let Ok(Some(v)) = row.try_get::<Option<i64>, _>(idx) {
211        return json!(v);
212    }
213    if let Ok(Some(v)) = row.try_get::<Option<i32>, _>(idx) {
214        return json!(v);
215    }
216    if let Ok(Some(v)) = row.try_get::<Option<f64>, _>(idx) {
217        return json!(v);
218    }
219    if let Ok(Some(v)) = row.try_get::<Option<bool>, _>(idx) {
220        return json!(v);
221    }
222    if let Ok(Some(v)) = row.try_get::<Option<String>, _>(idx) {
223        return json!(v);
224    }
225    if let Ok(Some(v)) = row.try_get::<Option<serde_json::Value>, _>(idx) {
226        return v;
227    }
228    if let Ok(None::<String>) = row.try_get::<Option<String>, _>(idx) {
229        return Value::Null;
230    }
231    json!({ "_unknown_type": ty })
232}
233
234async fn run_mysql(
235    sql: &str,
236    pool: &sqlx::MySqlPool,
237    limit: usize,
238) -> Result<(Vec<Value>, bool), String> {
239    let rows = sqlx::query(sql)
240        .fetch_all(pool)
241        .await
242        .map_err(|e| format!("query error: {e}"))?;
243    let truncated = rows.len() > limit;
244    let take_n = rows.len().min(limit);
245    let mut out = Vec::with_capacity(take_n);
246    for row in rows.iter().take(take_n) {
247        let mut obj = serde_json::Map::new();
248        for (i, col) in row.columns().iter().enumerate() {
249            let key = col.name().to_string();
250            obj.insert(key, mysql_value(row, i, col.type_info().name()));
251        }
252        out.push(Value::Object(obj));
253    }
254    Ok((out, truncated))
255}
256
257fn mysql_value(row: &sqlx::mysql::MySqlRow, idx: usize, ty: &str) -> Value {
258    if let Ok(Some(v)) = row.try_get::<Option<i64>, _>(idx) {
259        return json!(v);
260    }
261    if let Ok(Some(v)) = row.try_get::<Option<f64>, _>(idx) {
262        return json!(v);
263    }
264    if let Ok(Some(v)) = row.try_get::<Option<bool>, _>(idx) {
265        return json!(v);
266    }
267    if let Ok(Some(v)) = row.try_get::<Option<String>, _>(idx) {
268        return json!(v);
269    }
270    if let Ok(None::<String>) = row.try_get::<Option<String>, _>(idx) {
271        return Value::Null;
272    }
273    json!({ "_unknown_type": ty })
274}
275
276async fn run_sqlite(
277    sql: &str,
278    pool: &sqlx::SqlitePool,
279    limit: usize,
280) -> Result<(Vec<Value>, bool), String> {
281    let rows = sqlx::query(sql)
282        .fetch_all(pool)
283        .await
284        .map_err(|e| format!("query error: {e}"))?;
285    let truncated = rows.len() > limit;
286    let take_n = rows.len().min(limit);
287    let mut out = Vec::with_capacity(take_n);
288    for row in rows.iter().take(take_n) {
289        let mut obj = serde_json::Map::new();
290        for (i, col) in row.columns().iter().enumerate() {
291            let key = col.name().to_string();
292            obj.insert(key, sqlite_value(row, i, col.type_info().name()));
293        }
294        out.push(Value::Object(obj));
295    }
296    Ok((out, truncated))
297}
298
299fn sqlite_value(row: &sqlx::sqlite::SqliteRow, idx: usize, ty: &str) -> Value {
300    if let Ok(Some(v)) = row.try_get::<Option<i64>, _>(idx) {
301        return json!(v);
302    }
303    if let Ok(Some(v)) = row.try_get::<Option<f64>, _>(idx) {
304        return json!(v);
305    }
306    if let Ok(Some(v)) = row.try_get::<Option<bool>, _>(idx) {
307        return json!(v);
308    }
309    if let Ok(Some(v)) = row.try_get::<Option<String>, _>(idx) {
310        return json!(v);
311    }
312    if let Ok(None::<String>) = row.try_get::<Option<String>, _>(idx) {
313        return Value::Null;
314    }
315    json!({ "_unknown_type": ty })
316}