Skip to main content

datapress_core/
sql.rs

1//! Shared safety gate for the raw-SQL endpoint (`POST /api/v1/sql`).
2//!
3//! Raw SQL is a much larger attack surface than the structured `/query`
4//! endpoint, so every statement is parsed and validated *before* it is
5//! handed to a backend engine. The same gate runs for DuckDB and
6//! DataFusion, giving both backends identical safety semantics — and
7//! keeping the "which tables may this query touch?" policy in one place.
8//!
9//! Guarantees enforced by [`validate`]:
10//! - exactly one statement, and it is a read-only `SELECT` / `WITH … SELECT`,
11//! - every referenced table is a registered dataset — no file-reading
12//!   table functions (`read_parquet`, `read_csv`, …), no unknown tables,
13//! - no file-reading scalar functions (`read_text`, `read_blob`, …),
14//! - at most `max_datasets` distinct datasets are referenced. Phase 1
15//!   passes `1`, enforcing the single-dataset rule; raising this bound is
16//!   all that's needed to allow cross-dataset joins later.
17//!
18//! CTE-defined names are tracked per query scope and excluded from the
19//! dataset allowlist check, so `WITH t AS (SELECT … FROM events) SELECT …`
20//! is accepted (it still only touches `events`).
21
use std::collections::{HashMap, HashSet};
use std::ops::ControlFlow;

use sqlparser::ast::{
    Expr, Ident, ObjectName, ObjectNamePart, Query, SetExpr, Statement, TableFactor, Visit,
    VisitMut, Visitor, VisitorMut,
};
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::Parser;

use crate::errors::AppError;
33
/// File-reading / external-access functions that must never run through
/// the SQL endpoint, in either table or scalar position. Table-position
/// functions are already blocked by the relation allowlist; this list
/// closes the scalar-position gap (e.g. `SELECT read_text('/etc/passwd')`).
///
/// Entries must be written in **lowercase**: the function-call check
/// lowercases the last part of the called name before looking it up here,
/// so a mixed-case entry would never match.
const DENIED_FUNCTIONS: &[&str] = &[
    "read_text",
    "read_blob",
    "read_csv",
    "read_csv_auto",
    "read_parquet",
    "parquet_scan",
    "read_json",
    "read_json_auto",
    "read_json_objects",
    "read_ndjson",
    "read_ndjson_auto",
    "read_ndjson_objects",
    "sniff_csv",
    "glob",
];
54
/// A validated, ready-to-execute SQL query.
///
/// Plain data: cloning and comparing are cheap and make the value easy to
/// cache, log and assert on in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidatedSql {
    /// The trimmed, semicolon-free SQL string, safe to wrap and execute.
    pub sql: String,
    /// The distinct dataset names the query references (lowercased). Empty
    /// for table-less queries such as `SELECT 1`.
    pub datasets: Vec<String>,
}
64
65/// Validate `sql` for the raw-SQL endpoint.
66///
67/// `allowed` is the set of registered dataset names, **lowercased** by the
68/// caller (matching is case-insensitive). `max_datasets` caps how many
69/// distinct datasets a single statement may touch (phase 1 = `1`).
70///
71/// On success returns the cleaned SQL ready to be wrapped in an outer
72/// `LIMIT` and executed by the backend.
73pub fn validate(
74    sql: &str,
75    allowed: &HashSet<String>,
76    max_datasets: usize,
77) -> Result<ValidatedSql, AppError> {
78    let trimmed = sql.trim().trim_end_matches(';').trim();
79    if trimmed.is_empty() {
80        return Err(AppError::InvalidValue("sql must not be empty".into()));
81    }
82
83    let statements = Parser::parse_sql(&GenericDialect {}, trimmed)
84        .map_err(|e| AppError::InvalidValue(format!("could not parse SQL: {e}")))?;
85    if statements.len() != 1 {
86        return Err(AppError::InvalidValue(
87            "exactly one SQL statement is allowed".into(),
88        ));
89    }
90    let stmt = &statements[0];
91    if !matches!(stmt, Statement::Query(_)) {
92        return Err(AppError::InvalidValue(
93            "only read-only SELECT queries are allowed".into(),
94        ));
95    }
96
97    let mut checker = ScopeCheck {
98        allowed,
99        cte_names: HashSet::new(),
100        referenced: HashSet::new(),
101        violation: None,
102    };
103    let _ = stmt.visit(&mut checker);
104    if let Some(err) = checker.violation {
105        return Err(AppError::InvalidValue(err));
106    }
107
108    let mut datasets: Vec<String> = checker.referenced.into_iter().collect();
109    datasets.sort();
110    if datasets.len() > max_datasets {
111        return Err(AppError::InvalidValue(format!(
112            "this endpoint allows at most {max_datasets} dataset(s) per query; \
113             the statement references {}",
114            datasets.len()
115        )));
116    }
117
118    Ok(ValidatedSql {
119        sql: trimmed.to_string(),
120        datasets,
121    })
122}
123
124/// Rewrite references to registered tables and their columns so they
125/// match **case-insensitively**, the way DuckDB does.
126///
127/// DataFusion lowercases unquoted identifiers by default, so a query like
128/// `SELECT State FROM accidents` is looked up as `state` and fails against
129/// a case-sensitive Parquet column literally named `State`. Rather than
130/// disable normalization globally (which would also make *aliases* and
131/// *CTE names* case-sensitive), we rewrite only the identifiers that name
132/// a known dataset or column into their **canonical casing, quoted**.
133/// Quoted identifiers bypass the engine's lowercasing and match the stored
134/// name exactly, while every other identifier (aliases, CTE names, …) is
135/// left untouched so the engine's normal case-insensitive handling still
136/// applies.
137///
138/// `tables` and `columns` map a **lowercased** name to its canonical
139/// spelling. On any parse failure the input is returned unchanged so the
140/// backend can surface a meaningful error.
141pub fn canonicalize_identifiers(
142    sql: &str,
143    tables: &HashMap<String, String>,
144    columns: &HashMap<String, String>,
145) -> String {
146    let mut statements = match Parser::parse_sql(&GenericDialect {}, sql) {
147        Ok(s) if s.len() == 1 => s,
148        _ => return sql.to_string(),
149    };
150    let mut canon = Canonicalizer { tables, columns };
151    let _ = VisitMut::visit(&mut statements[0], &mut canon);
152    statements[0].to_string()
153}
154
155struct Canonicalizer<'a> {
156    tables: &'a HashMap<String, String>,
157    columns: &'a HashMap<String, String>,
158}
159
160impl Canonicalizer<'_> {
161    /// If `ident` (case-folded) names something in `map`, replace it with
162    /// the canonical spelling and force double-quoting so the engine keeps
163    /// the exact case.
164    fn rewrite(ident: &mut Ident, map: &HashMap<String, String>) {
165        if let Some(canonical) = map.get(&ident.value.to_lowercase()) {
166            ident.value = canonical.clone();
167            ident.quote_style = Some('"');
168        }
169    }
170}
171
172impl VisitorMut for Canonicalizer<'_> {
173    type Break = ();
174
175    fn pre_visit_relation(&mut self, relation: &mut ObjectName) -> ControlFlow<Self::Break> {
176        for part in relation.0.iter_mut() {
177            if let ObjectNamePart::Identifier(ident) = part {
178                Self::rewrite(ident, self.tables);
179            }
180        }
181        ControlFlow::Continue(())
182    }
183
184    fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
185        match expr {
186            // Bare column reference: `State`.
187            Expr::Identifier(ident) => Self::rewrite(ident, self.columns),
188            // Qualified reference: `accidents.State` / `a.State`. The last
189            // part is the column; earlier parts qualify it with a table or
190            // alias (only real table names are rewritten).
191            Expr::CompoundIdentifier(idents) => {
192                if let Some((column, qualifiers)) = idents.split_last_mut() {
193                    Self::rewrite(column, self.columns);
194                    for qualifier in qualifiers {
195                        Self::rewrite(qualifier, self.tables);
196                    }
197                }
198            }
199            _ => {}
200        }
201        ControlFlow::Continue(())
202    }
203}
204
205struct ScopeCheck<'a> {
206    allowed: &'a HashSet<String>,
207    cte_names: HashSet<String>,
208    referenced: HashSet<String>,
209    violation: Option<String>,
210}
211
212impl Visitor for ScopeCheck<'_> {
213    type Break = ();
214
215    fn pre_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
216        // Record CTE names *before* visiting the query body so references
217        // to them inside the body are recognised and not mistaken for
218        // unknown tables. Nested `WITH` clauses are handled the same way
219        // as the visitor descends into subqueries.
220        if let Some(with) = &query.with {
221            for cte in &with.cte_tables {
222                self.cte_names.insert(cte.alias.name.value.to_lowercase());
223            }
224        }
225        ControlFlow::Continue(())
226    }
227
228    fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
229        let ident = relation
230            .0
231            .last()
232            .and_then(|p| p.as_ident())
233            .map(|i| i.value.to_lowercase())
234            .unwrap_or_default();
235
236        if self.cte_names.contains(&ident) {
237            return ControlFlow::Continue(());
238        }
239        if let Some(name) = self.allowed.get(&ident) {
240            self.referenced.insert(name.clone());
241            return ControlFlow::Continue(());
242        }
243        self.violation = Some(format!(
244            "table '{ident}' is not a registered dataset accessible from the SQL endpoint"
245        ));
246        ControlFlow::Break(())
247    }
248
249    fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
250        if let Expr::Function(func) = expr {
251            let fname = func
252                .name
253                .0
254                .last()
255                .and_then(|p| p.as_ident())
256                .map(|i| i.value.to_lowercase())
257                .unwrap_or_default();
258            if DENIED_FUNCTIONS.contains(&fname.as_str()) {
259                self.violation =
260                    Some(format!("function '{fname}' is not allowed in the SQL endpoint"));
261                return ControlFlow::Break(());
262            }
263        }
264        ControlFlow::Continue(())
265    }
266}
267
#[cfg(test)]
mod tests {
    use super::*;

    /// Run the gate against `sql` with `datasets` registered (lowercased
    /// here, as the real caller does) and the given dataset cap.
    fn gate(sql: &str, datasets: &[&str], max_datasets: usize) -> Result<ValidatedSql, AppError> {
        let mut allowlist = HashSet::with_capacity(datasets.len());
        for name in datasets {
            allowlist.insert(name.to_lowercase());
        }
        validate(sql, &allowlist, max_datasets)
    }

    /// Shorthand for the common case: only `events` registered, limit 1.
    fn gate_events(sql: &str) -> Result<ValidatedSql, AppError> {
        gate(sql, &["events"], 1)
    }

    /// Assert that the gate refused the query with `InvalidValue`.
    fn assert_invalid(outcome: Result<ValidatedSql, AppError>) {
        assert!(matches!(outcome, Err(AppError::InvalidValue(_))));
    }

    /// Turn `(lowercased, canonical)` pairs into an owned lookup map.
    fn lookup(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        let mut map = HashMap::with_capacity(pairs.len());
        for (key, canonical) in pairs {
            map.insert(key.to_string(), canonical.to_string());
        }
        map
    }

    // ---- validate -------------------------------------------------------

    #[test]
    fn accepts_single_dataset_select() {
        let accepted = gate_events("SELECT a, b FROM events WHERE a > 1").unwrap();
        assert_eq!(accepted.datasets, ["events"]);
    }

    #[test]
    fn case_insensitive_table_match() {
        let accepted = gate_events("SELECT * FROM Events").unwrap();
        assert_eq!(accepted.datasets, ["events"]);
    }

    #[test]
    fn strips_trailing_semicolon() {
        let accepted = gate_events("SELECT 1 FROM events;").unwrap();
        assert_eq!(accepted.sql, "SELECT 1 FROM events");
    }

    #[test]
    fn allows_cte_over_single_dataset() {
        let accepted =
            gate_events("WITH t AS (SELECT * FROM events) SELECT count(*) FROM t").unwrap();
        assert_eq!(accepted.datasets, ["events"]);
    }

    #[test]
    fn allows_tableless_select() {
        let accepted = gate_events("SELECT 1 + 1").unwrap();
        assert!(accepted.datasets.is_empty());
    }

    #[test]
    fn rejects_unknown_table() {
        assert_invalid(gate_events("SELECT * FROM secrets"));
    }

    #[test]
    fn rejects_second_dataset_join() {
        assert_invalid(gate(
            "SELECT * FROM events e JOIN other o ON e.id = o.id",
            &["events", "other"],
            1,
        ));
    }

    #[test]
    fn allows_two_datasets_when_limit_raised() {
        let accepted = gate(
            "SELECT * FROM events e JOIN other o ON e.id = o.id",
            &["events", "other"],
            2,
        )
        .unwrap();
        assert_eq!(accepted.datasets.len(), 2);
    }

    #[test]
    fn rejects_non_select() {
        assert_invalid(gate_events("DELETE FROM events"));
    }

    #[test]
    fn rejects_multiple_statements() {
        assert_invalid(gate_events("SELECT 1 FROM events; SELECT 2 FROM events"));
    }

    #[test]
    fn rejects_file_table_function() {
        assert_invalid(gate_events("SELECT * FROM read_parquet('/etc/passwd')"));
    }

    #[test]
    fn rejects_file_scalar_function() {
        assert_invalid(gate_events("SELECT read_text('/etc/passwd') FROM events"));
    }

    #[test]
    fn rejects_empty_sql() {
        assert_invalid(gate_events("   "));
    }

    // ---- canonicalize_identifiers ---------------------------------------

    #[test]
    fn canonicalizes_mixed_case_column_and_table() {
        let tables = lookup(&[("accidents", "accidents")]);
        let columns = lookup(&[("state", "State"), ("id", "ID")]);
        let out = canonicalize_identifiers(
            "SELECT state, COUNT(*) AS n FROM Accidents GROUP BY STATE ORDER BY n DESC",
            &tables,
            &columns,
        );
        // Columns and the table come back quoted in canonical case, while
        // the alias `n` must pass through without quotes.
        assert!(out.contains("\"State\""), "got: {out}");
        assert!(out.contains("FROM \"accidents\""), "got: {out}");
        assert!(out.contains("AS n"), "got: {out}");
        assert!(!out.contains("\"n\""), "alias must not be quoted: {out}");
    }

    #[test]
    fn canonicalizes_qualified_column() {
        let tables = lookup(&[("accidents", "accidents")]);
        let columns = lookup(&[("state", "State")]);
        let out = canonicalize_identifiers("SELECT a.state FROM accidents AS a", &tables, &columns);
        // Only the column part is rewritten; the alias `a` names no
        // registered table and stays as written.
        assert!(out.contains("a.\"State\""), "got: {out}");
    }

    #[test]
    fn leaves_unknown_identifiers_untouched() {
        let tables = lookup(&[("events", "events")]);
        let columns = lookup(&[("id", "id")]);
        let out = canonicalize_identifiers("SELECT foo, bar FROM events", &tables, &columns);
        assert!(out.contains("foo"), "got: {out}");
        assert!(out.contains("bar"), "got: {out}");
        assert!(!out.contains("\"foo\""), "got: {out}");
    }

    #[test]
    fn returns_input_unchanged_on_parse_error() {
        let empty = lookup(&[]);
        let garbage = "SELECT FROM WHERE";
        assert_eq!(canonicalize_identifiers(garbage, &empty, &empty), garbage);
    }
}
430}