Skip to main content

datapress_core/
sql.rs

1//! Shared safety gate for the raw-SQL endpoint (`POST /api/v1/sql`).
2//!
3//! Raw SQL is a much larger attack surface than the structured `/query`
4//! endpoint, so every statement is parsed and validated *before* it is
5//! handed to a backend engine. The same gate runs for DuckDB and
6//! DataFusion, giving both backends identical safety semantics — and
7//! keeping the "which tables may this query touch?" policy in one place.
8//!
9//! Guarantees enforced by [`validate`]:
10//! - exactly one statement, and it is a read-only `SELECT` / `WITH … SELECT`
11//!   or a `DESCRIBE` / `DESC <table>` schema lookup,
12//! - every referenced table is a registered dataset — no file-reading
13//!   table functions (`read_parquet`, `read_csv`, …), no unknown tables,
14//! - no file-reading scalar functions (`read_text`, `read_blob`, …),
15//! - at most `max_datasets` distinct datasets are referenced. Phase 1
16//!   passes `1`, enforcing the single-dataset rule; raising this bound is
17//!   all that's needed to allow cross-dataset joins later.
18//!
19//! CTE-defined names are tracked per query scope and excluded from the
20//! dataset allowlist check, so `WITH t AS (SELECT … FROM events) SELECT …`
21//! is accepted (it still only touches `events`).
22
use std::collections::{HashMap, HashSet};
use std::ops::ControlFlow;

use sqlparser::ast::{
    DescribeAlias, Expr, Ident, ObjectName, ObjectNamePart, Query, Statement, TableFactor, Visit,
    VisitMut, Visitor, VisitorMut,
};
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::Parser;

use crate::errors::AppError;
34
/// Functions that read files or otherwise reach outside the registered
/// datasets; the SQL endpoint refuses them in table or scalar position.
///
/// Calls in `FROM` position are already stopped by the relation checks, so
/// this list exists to close the scalar-position gap, e.g.
/// `SELECT read_text('/etc/passwd')`. Every entry is lowercase.
const DENIED_FUNCTIONS: &[&str] = &[
    // Raw file contents.
    "read_text",
    "read_blob",
    // CSV.
    "read_csv",
    "read_csv_auto",
    // Parquet.
    "read_parquet",
    "parquet_scan",
    // JSON and newline-delimited JSON.
    "read_json",
    "read_json_auto",
    "read_json_objects",
    "read_ndjson",
    "read_ndjson_auto",
    "read_ndjson_objects",
    // CSV dialect sniffing and filename globbing.
    "sniff_csv",
    "glob",
];
55
/// A statement that passed [`validate`], ready for a backend to execute.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidatedSql {
    /// The trimmed statement text with its trailing `;` terminator removed,
    /// ready to be wrapped in an outer query and executed.
    ///
    /// NOTE(review): comments are not stripped, so a trailing `-- …` line
    /// comment survives; confirm the backend's wrapper puts a newline before
    /// its closing parenthesis.
    pub sql: String,
    /// The distinct dataset names the query references (lowercased), sorted.
    /// Empty for table-less queries such as `SELECT 1`.
    pub datasets: Vec<String>,
}
65
66/// Validate `sql` for the raw-SQL endpoint.
67///
68/// Accepts a single read-only `SELECT` / `WITH … SELECT` or a `DESCRIBE` /
69/// `DESC <table>` statement. `allowed` is the set of registered dataset
70/// caller (matching is case-insensitive). `max_datasets` caps how many
71/// distinct datasets a single statement may touch (phase 1 = `1`).
72///
73/// On success returns the cleaned SQL ready to be wrapped in an outer
74/// `LIMIT` and executed by the backend (DESCRIBE is run as-is; see
75/// [`is_describe`]).
76pub fn validate(
77    sql: &str,
78    allowed: &HashSet<String>,
79    max_datasets: usize,
80) -> Result<ValidatedSql, AppError> {
81    let trimmed = sql.trim().trim_end_matches(';').trim();
82    if trimmed.is_empty() {
83        return Err(AppError::InvalidValue("sql must not be empty".into()));
84    }
85
86    let statements = Parser::parse_sql(&GenericDialect {}, trimmed)
87        .map_err(|e| AppError::InvalidValue(format!("could not parse SQL: {e}")))?;
88    if statements.len() != 1 {
89        return Err(AppError::InvalidValue(
90            "exactly one SQL statement is allowed".into(),
91        ));
92    }
93    let stmt = &statements[0];
94    // Read-only statements only: a `SELECT` / `WITH … SELECT`, or a
95    // `DESCRIBE`/`DESC <table>` schema lookup. `DESCRIBE` still flows through
96    // the visitor below, so its target table is subject to the same dataset
97    // allowlist as a query. The `EXPLAIN` alias is deliberately excluded.
98    match stmt {
99        Statement::Query(_) => {}
100        Statement::ExplainTable {
101            describe_alias: DescribeAlias::Describe | DescribeAlias::Desc,
102            ..
103        } => {}
104        _ => {
105            return Err(AppError::InvalidValue(
106                "only read-only SELECT and DESCRIBE statements are allowed".into(),
107            ));
108        }
109    }
110
111    let mut checker = ScopeCheck {
112        allowed,
113        cte_names: HashSet::new(),
114        referenced: HashSet::new(),
115        violation: None,
116    };
117    let _ = stmt.visit(&mut checker);
118    if let Some(err) = checker.violation {
119        return Err(AppError::InvalidValue(err));
120    }
121
122    let mut datasets: Vec<String> = checker.referenced.into_iter().collect();
123    datasets.sort();
124    if datasets.len() > max_datasets {
125        return Err(AppError::InvalidValue(format!(
126            "this endpoint allows at most {max_datasets} dataset(s) per query; \
127             the statement references {}",
128            datasets.len()
129        )));
130    }
131
132    Ok(ValidatedSql {
133        sql: trimmed.to_string(),
134        datasets,
135    })
136}
137
138/// Returns `true` if `sql` is a single `DESCRIBE` / `DESC <table>` statement.
139///
140/// `DESCRIBE` yields a schema listing rather than a row stream and cannot be
141/// nested inside an outer `SELECT … LIMIT` subquery on every backend (notably
142/// DataFusion), so callers run it directly instead of through the bounded
143/// wrapper. Returns `false` for anything that does not parse to exactly one
144/// `DESCRIBE`/`DESC` statement.
145pub fn is_describe(sql: &str) -> bool {
146    let trimmed = sql.trim().trim_end_matches(';').trim();
147    matches!(
148        Parser::parse_sql(&GenericDialect {}, trimmed).as_deref(),
149        Ok([Statement::ExplainTable {
150            describe_alias: DescribeAlias::Describe | DescribeAlias::Desc,
151            ..
152        }])
153    )
154}
155
156/// Rewrite references to registered tables and their columns so they
157/// match **case-insensitively**, the way DuckDB does.
158///
159/// DataFusion lowercases unquoted identifiers by default, so a query like
160/// `SELECT State FROM accidents` is looked up as `state` and fails against
161/// a case-sensitive Parquet column literally named `State`. Rather than
162/// disable normalization globally (which would also make *aliases* and
163/// *CTE names* case-sensitive), we rewrite only the identifiers that name
164/// a known dataset or column into their **canonical casing, quoted**.
165/// Quoted identifiers bypass the engine's lowercasing and match the stored
166/// name exactly, while every other identifier (aliases, CTE names, …) is
167/// left untouched so the engine's normal case-insensitive handling still
168/// applies.
169///
170/// `tables` and `columns` map a **lowercased** name to its canonical
171/// spelling. On any parse failure the input is returned unchanged so the
172/// backend can surface a meaningful error.
173pub fn canonicalize_identifiers(
174    sql: &str,
175    tables: &HashMap<String, String>,
176    columns: &HashMap<String, String>,
177) -> String {
178    let mut statements = match Parser::parse_sql(&GenericDialect {}, sql) {
179        Ok(s) if s.len() == 1 => s,
180        _ => return sql.to_string(),
181    };
182    let mut canon = Canonicalizer { tables, columns };
183    let _ = VisitMut::visit(&mut statements[0], &mut canon);
184    statements[0].to_string()
185}
186
187struct Canonicalizer<'a> {
188    tables: &'a HashMap<String, String>,
189    columns: &'a HashMap<String, String>,
190}
191
192impl Canonicalizer<'_> {
193    /// If `ident` (case-folded) names something in `map`, replace it with
194    /// the canonical spelling and force double-quoting so the engine keeps
195    /// the exact case.
196    fn rewrite(ident: &mut Ident, map: &HashMap<String, String>) {
197        if let Some(canonical) = map.get(&ident.value.to_lowercase()) {
198            ident.value = canonical.clone();
199            ident.quote_style = Some('"');
200        }
201    }
202}
203
204impl VisitorMut for Canonicalizer<'_> {
205    type Break = ();
206
207    fn pre_visit_relation(&mut self, relation: &mut ObjectName) -> ControlFlow<Self::Break> {
208        for part in relation.0.iter_mut() {
209            if let ObjectNamePart::Identifier(ident) = part {
210                Self::rewrite(ident, self.tables);
211            }
212        }
213        ControlFlow::Continue(())
214    }
215
216    fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
217        match expr {
218            // Bare column reference: `State`.
219            Expr::Identifier(ident) => Self::rewrite(ident, self.columns),
220            // Qualified reference: `accidents.State` / `a.State`. The last
221            // part is the column; earlier parts qualify it with a table or
222            // alias (only real table names are rewritten).
223            Expr::CompoundIdentifier(idents) => {
224                if let Some((column, qualifiers)) = idents.split_last_mut() {
225                    Self::rewrite(column, self.columns);
226                    for qualifier in qualifiers {
227                        Self::rewrite(qualifier, self.tables);
228                    }
229                }
230            }
231            _ => {}
232        }
233        ControlFlow::Continue(())
234    }
235}
236
237struct ScopeCheck<'a> {
238    allowed: &'a HashSet<String>,
239    cte_names: HashSet<String>,
240    referenced: HashSet<String>,
241    violation: Option<String>,
242}
243
244impl Visitor for ScopeCheck<'_> {
245    type Break = ();
246
247    fn pre_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
248        // Record CTE names *before* visiting the query body so references
249        // to them inside the body are recognised and not mistaken for
250        // unknown tables. Nested `WITH` clauses are handled the same way
251        // as the visitor descends into subqueries.
252        if let Some(with) = &query.with {
253            for cte in &with.cte_tables {
254                self.cte_names.insert(cte.alias.name.value.to_lowercase());
255            }
256        }
257        ControlFlow::Continue(())
258    }
259
260    fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
261        let ident = relation
262            .0
263            .last()
264            .and_then(|p| p.as_ident())
265            .map(|i| i.value.to_lowercase())
266            .unwrap_or_default();
267
268        if self.cte_names.contains(&ident) {
269            return ControlFlow::Continue(());
270        }
271        if let Some(name) = self.allowed.get(&ident) {
272            self.referenced.insert(name.clone());
273            return ControlFlow::Continue(());
274        }
275        self.violation = Some(format!(
276            "table '{ident}' is not a registered dataset accessible from the SQL endpoint"
277        ));
278        ControlFlow::Break(())
279    }
280
281    fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
282        if let Expr::Function(func) = expr {
283            let fname = func
284                .name
285                .0
286                .last()
287                .and_then(|p| p.as_ident())
288                .map(|i| i.value.to_lowercase())
289                .unwrap_or_default();
290            if DENIED_FUNCTIONS.contains(&fname.as_str()) {
291                self.violation =
292                    Some(format!("function '{fname}' is not allowed in the SQL endpoint"));
293                return ControlFlow::Break(());
294            }
295        }
296        ControlFlow::Continue(())
297    }
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// The lowercased allowlist shape `validate` expects.
    fn allowed(names: &[&str]) -> HashSet<String> {
        names.iter().map(|name| name.to_lowercase()).collect()
    }

    /// Asserts that `validate` refuses `sql` with `AppError::InvalidValue`.
    fn assert_rejected(sql: &str, names: &[&str], max_datasets: usize) {
        let err = validate(sql, &allowed(names), max_datasets).unwrap_err();
        assert!(matches!(err, AppError::InvalidValue(_)));
    }

    /// Builds a lowercased-name -> canonical-spelling lookup table.
    fn lookup(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(key, canonical)| (key.to_owned(), canonical.to_owned()))
            .collect()
    }

    // ---- validate: accepted statements ----

    #[test]
    fn accepts_single_dataset_select() {
        let sql = "SELECT a, b FROM events WHERE a > 1";
        let validated = validate(sql, &allowed(&["events"]), 1).unwrap();
        assert_eq!(validated.datasets, ["events"]);
    }

    #[test]
    fn case_insensitive_table_match() {
        let validated = validate("SELECT * FROM Events", &allowed(&["events"]), 1).unwrap();
        assert_eq!(validated.datasets, ["events"]);
    }

    #[test]
    fn strips_trailing_semicolon() {
        let validated = validate("SELECT 1 FROM events;", &allowed(&["events"]), 1).unwrap();
        assert_eq!(validated.sql, "SELECT 1 FROM events");
    }

    #[test]
    fn allows_cte_over_single_dataset() {
        let sql = "WITH t AS (SELECT * FROM events) SELECT count(*) FROM t";
        let validated = validate(sql, &allowed(&["events"]), 1).unwrap();
        assert_eq!(validated.datasets, ["events"]);
    }

    #[test]
    fn allows_tableless_select() {
        let validated = validate("SELECT 1 + 1", &allowed(&["events"]), 1).unwrap();
        assert!(validated.datasets.is_empty());
    }

    #[test]
    fn allows_two_datasets_when_limit_raised() {
        let sql = "SELECT * FROM events e JOIN other o ON e.id = o.id";
        let validated = validate(sql, &allowed(&["events", "other"]), 2).unwrap();
        assert_eq!(validated.datasets.len(), 2);
    }

    #[test]
    fn accepts_describe_table() {
        let validated = validate("DESCRIBE events", &allowed(&["events"]), 1).unwrap();
        assert_eq!(validated.datasets, ["events"]);
        assert!(is_describe(&validated.sql));
    }

    #[test]
    fn accepts_desc_table_case_insensitive() {
        let validated = validate("DESC Events", &allowed(&["events"]), 1).unwrap();
        assert_eq!(validated.datasets, ["events"]);
        assert!(is_describe(&validated.sql));
    }

    // ---- validate: rejected statements ----

    #[test]
    fn rejects_unknown_table() {
        assert_rejected("SELECT * FROM secrets", &["events"], 1);
    }

    #[test]
    fn rejects_second_dataset_join() {
        assert_rejected(
            "SELECT * FROM events e JOIN other o ON e.id = o.id",
            &["events", "other"],
            1,
        );
    }

    #[test]
    fn rejects_non_select() {
        assert_rejected("DELETE FROM events", &["events"], 1);
    }

    #[test]
    fn describe_rejects_unknown_table() {
        assert_rejected("DESCRIBE secrets", &["events"], 1);
    }

    #[test]
    fn rejects_multiple_statements() {
        assert_rejected("SELECT 1 FROM events; SELECT 2 FROM events", &["events"], 1);
    }

    #[test]
    fn rejects_file_table_function() {
        assert_rejected("SELECT * FROM read_parquet('/etc/passwd')", &["events"], 1);
    }

    #[test]
    fn rejects_file_scalar_function() {
        assert_rejected("SELECT read_text('/etc/passwd') FROM events", &["events"], 1);
    }

    #[test]
    fn rejects_empty_sql() {
        assert_rejected("   ", &["events"], 1);
    }

    // ---- is_describe ----

    #[test]
    fn is_describe_false_for_select() {
        for sql in ["SELECT * FROM events", "SELECT 1"] {
            assert!(!is_describe(sql));
        }
    }

    // ---- canonicalize_identifiers ----

    #[test]
    fn canonicalizes_mixed_case_column_and_table() {
        let tables = lookup(&[("accidents", "accidents")]);
        let columns = lookup(&[("state", "State"), ("id", "ID")]);
        let out = canonicalize_identifiers(
            "SELECT state, COUNT(*) AS n FROM Accidents GROUP BY STATE ORDER BY n DESC",
            &tables,
            &columns,
        );
        // Column refs become quoted canonical names; the table name is
        // quoted canonical; the alias `n` is left untouched.
        assert!(out.contains("\"State\""), "got: {out}");
        assert!(out.contains("FROM \"accidents\""), "got: {out}");
        assert!(out.contains("AS n"), "got: {out}");
        assert!(!out.contains("\"n\""), "alias must not be quoted: {out}");
    }

    #[test]
    fn canonicalizes_qualified_column() {
        let tables = lookup(&[("accidents", "accidents")]);
        let columns = lookup(&[("state", "State")]);
        let out = canonicalize_identifiers("SELECT a.state FROM accidents AS a", &tables, &columns);
        // The column part is canonicalized; the table alias `a` is not a
        // registered table, so it is left alone.
        assert!(out.contains("a.\"State\""), "got: {out}");
    }

    #[test]
    fn leaves_unknown_identifiers_untouched() {
        let tables = lookup(&[("events", "events")]);
        let columns = lookup(&[("id", "id")]);
        let out = canonicalize_identifiers("SELECT foo, bar FROM events", &tables, &columns);
        for name in ["foo", "bar"] {
            assert!(out.contains(name), "got: {out}");
        }
        assert!(!out.contains("\"foo\""), "got: {out}");
    }

    #[test]
    fn returns_input_unchanged_on_parse_error() {
        let empty = HashMap::new();
        let garbage = "SELECT FROM WHERE";
        assert_eq!(canonicalize_identifiers(garbage, &empty, &empty), garbage);
    }
}