1use std::collections::{HashMap, HashSet};
24use std::ops::ControlFlow;
25
26use sqlparser::ast::{
27 DescribeAlias, Expr, Ident, ObjectName, ObjectNamePart, Query, SelectItem, SetExpr, Statement,
28 TableFactor, Visit, VisitMut, Visitor, VisitorMut,
29};
30use sqlparser::dialect::GenericDialect;
31use sqlparser::parser::Parser;
32
33use crate::errors::AppError;
34use crate::schema::DatasetSchema;
35
/// Function names the SQL endpoint refuses to call.
///
/// A function call is rejected when the lowercased final component of its name appears
/// here, so every entry must be lowercase. The entries read data from outside the registered
/// datasets (files, file metadata, external databases) or run SQL / resolve tables from a
/// string. Names are those used by DuckDB and its common extensions.
const DENIED_FUNCTIONS: &[&str] = &[
    // Raw file contents and file listings.
    "read_text",
    "read_blob",
    "glob",
    // CSV.
    "read_csv",
    "read_csv_auto",
    "sniff_csv",
    // Parquet data and metadata.
    "read_parquet",
    "parquet_scan",
    "parquet_metadata",
    "parquet_schema",
    "parquet_file_metadata",
    "parquet_kv_metadata",
    // JSON / NDJSON.
    "read_json",
    "read_json_auto",
    "read_json_objects",
    "read_json_objects_auto",
    "read_ndjson",
    "read_ndjson_auto",
    "read_ndjson_objects",
    // Other file formats provided by extensions.
    "read_xlsx",
    "st_read",
    "delta_scan",
    "iceberg_scan",
    // Scans of external databases provided by extensions.
    "sqlite_scan",
    "postgres_scan",
    // Run SQL or resolve a table from a string, which would sidestep the relation allowlist.
    "query",
    "query_table",
];
56
/// A SQL statement that passed [`validate`] and may be run on the SQL endpoint.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidatedSql {
    /// The statement text with surrounding whitespace and trailing semicolons removed.
    pub sql: String,
    /// Registered datasets the statement reads, sorted and de-duplicated.
    pub datasets: Vec<String>,
}
66
67pub fn validate(
78 sql: &str,
79 allowed: &HashSet<String>,
80 max_datasets: usize,
81) -> Result<ValidatedSql, AppError> {
82 let trimmed = sql.trim().trim_end_matches(';').trim();
83 if trimmed.is_empty() {
84 return Err(AppError::InvalidValue("sql must not be empty".into()));
85 }
86
87 let statements = Parser::parse_sql(&GenericDialect {}, trimmed)
88 .map_err(|e| AppError::InvalidValue(format!("could not parse SQL: {e}")))?;
89 if statements.len() != 1 {
90 return Err(AppError::InvalidValue(
91 "exactly one SQL statement is allowed".into(),
92 ));
93 }
94 let stmt = &statements[0];
95 match stmt {
100 Statement::Query(_) => {}
101 Statement::ExplainTable {
102 describe_alias: DescribeAlias::Describe | DescribeAlias::Desc,
103 ..
104 } => {}
105 _ => {
106 return Err(AppError::InvalidValue(
107 "only read-only SELECT and DESCRIBE statements are allowed".into(),
108 ));
109 }
110 }
111
112 let mut checker = ScopeCheck {
113 allowed,
114 cte_names: HashSet::new(),
115 referenced: HashSet::new(),
116 violation: None,
117 };
118 let _ = stmt.visit(&mut checker);
119 if let Some(err) = checker.violation {
120 return Err(AppError::InvalidValue(err));
121 }
122
123 let mut datasets: Vec<String> = checker.referenced.into_iter().collect();
124 datasets.sort();
125 if datasets.len() > max_datasets {
126 return Err(AppError::InvalidValue(format!(
127 "this endpoint allows at most {max_datasets} dataset(s) per query; \
128 the statement references {}",
129 datasets.len()
130 )));
131 }
132
133 Ok(ValidatedSql {
134 sql: trimmed.to_string(),
135 datasets,
136 })
137}
138
139pub fn is_describe(sql: &str) -> bool {
147 let trimmed = sql.trim().trim_end_matches(';').trim();
148 matches!(
149 Parser::parse_sql(&GenericDialect {}, trimmed).as_deref(),
150 Ok([Statement::ExplainTable {
151 describe_alias: DescribeAlias::Describe | DescribeAlias::Desc,
152 ..
153 }])
154 )
155}
156
157pub fn canonicalize_identifiers(
175 sql: &str,
176 tables: &HashMap<String, String>,
177 columns: &HashMap<String, String>,
178) -> String {
179 let mut statements = match Parser::parse_sql(&GenericDialect {}, sql) {
180 Ok(s) if s.len() == 1 => s,
181 _ => return sql.to_string(),
182 };
183 let mut canon = Canonicalizer { tables, columns };
184 let _ = VisitMut::visit(&mut statements[0], &mut canon);
185 statements[0].to_string()
186}
187
188struct Canonicalizer<'a> {
189 tables: &'a HashMap<String, String>,
190 columns: &'a HashMap<String, String>,
191}
192
193impl Canonicalizer<'_> {
194 fn rewrite(ident: &mut Ident, map: &HashMap<String, String>) {
198 if let Some(canonical) = map.get(&ident.value.to_lowercase()) {
199 ident.value = canonical.clone();
200 ident.quote_style = Some('"');
201 }
202 }
203}
204
205impl VisitorMut for Canonicalizer<'_> {
206 type Break = ();
207
208 fn pre_visit_relation(&mut self, relation: &mut ObjectName) -> ControlFlow<Self::Break> {
209 for part in relation.0.iter_mut() {
210 if let ObjectNamePart::Identifier(ident) = part {
211 Self::rewrite(ident, self.tables);
212 }
213 }
214 ControlFlow::Continue(())
215 }
216
217 fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
218 match expr {
219 Expr::Identifier(ident) => Self::rewrite(ident, self.columns),
221 Expr::CompoundIdentifier(idents) => {
225 if let Some((column, qualifiers)) = idents.split_last_mut() {
226 Self::rewrite(column, self.columns);
227 for qualifier in qualifiers {
228 Self::rewrite(qualifier, self.tables);
229 }
230 }
231 }
232 _ => {}
233 }
234 ControlFlow::Continue(())
235 }
236}
237
238struct ScopeCheck<'a> {
239 allowed: &'a HashSet<String>,
240 cte_names: HashSet<String>,
241 referenced: HashSet<String>,
242 violation: Option<String>,
243}
244
245impl Visitor for ScopeCheck<'_> {
246 type Break = ();
247
248 fn pre_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
249 if let Some(with) = &query.with {
254 for cte in &with.cte_tables {
255 self.cte_names.insert(cte.alias.name.value.to_lowercase());
256 }
257 }
258 ControlFlow::Continue(())
259 }
260
261 fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
262 let ident = relation
263 .0
264 .last()
265 .and_then(|p| p.as_ident())
266 .map(|i| i.value.to_lowercase())
267 .unwrap_or_default();
268
269 if self.cte_names.contains(&ident) {
270 return ControlFlow::Continue(());
271 }
272 if let Some(name) = self.allowed.get(&ident) {
273 self.referenced.insert(name.clone());
274 return ControlFlow::Continue(());
275 }
276 self.violation = Some(format!(
277 "table '{ident}' is not a registered dataset accessible from the SQL endpoint"
278 ));
279 ControlFlow::Break(())
280 }
281
282 fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
283 if let Expr::Function(func) = expr {
284 let fname = func
285 .name
286 .0
287 .last()
288 .and_then(|p| p.as_ident())
289 .map(|i| i.value.to_lowercase())
290 .unwrap_or_default();
291 if DENIED_FUNCTIONS.contains(&fname.as_str()) {
292 self.violation =
293 Some(format!("function '{fname}' is not allowed in the SQL endpoint"));
294 return ControlFlow::Break(());
295 }
296 }
297 ControlFlow::Continue(())
298 }
299}
300
301pub fn enforce_column_access(sql: &str, schema: &DatasetSchema) -> Result<(), AppError> {
317 if !schema.has_column_filters() {
318 return Ok(());
319 }
320 let trimmed = sql.trim().trim_end_matches(';').trim();
321 let statements = match Parser::parse_sql(&GenericDialect {}, trimmed) {
322 Ok(s) if s.len() == 1 => s,
323 _ => return Ok(()),
324 };
325 let stmt = &statements[0];
326
327 if schema.projection_filter.is_active() && statement_has_wildcard(stmt) {
328 return Err(AppError::Forbidden(format!(
329 "SELECT * is not allowed on dataset '{}' because it hides columns; \
330 list the columns explicitly",
331 schema.name
332 )));
333 }
334
335 let mut refs = ColumnRefCollector {
336 columns: HashSet::new(),
337 };
338 let _ = stmt.visit(&mut refs);
339 for lc in &refs.columns {
340 let Some(col) = schema.by_name.get(lc).map(|&i| &schema.columns[i]) else {
343 continue;
344 };
345 if !schema.projection_filter.allows(&col.name) {
346 return Err(AppError::UnknownColumn(col.name.clone()));
347 }
348 if !schema.predicate_filter.allows(&col.name) {
349 return Err(AppError::Forbidden(format!(
350 "column '{}' may not be used on the SQL endpoint for dataset '{}'",
351 col.name, schema.name
352 )));
353 }
354 }
355 Ok(())
356}
357
358struct ColumnRefCollector {
362 columns: HashSet<String>,
363}
364
365impl Visitor for ColumnRefCollector {
366 type Break = ();
367
368 fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
369 match expr {
370 Expr::Identifier(ident) => {
371 self.columns.insert(ident.value.to_lowercase());
372 }
373 Expr::CompoundIdentifier(idents) => {
374 if let Some(last) = idents.last() {
375 self.columns.insert(last.value.to_lowercase());
376 }
377 }
378 _ => {}
379 }
380 ControlFlow::Continue(())
381 }
382}
383
384fn statement_has_wildcard(stmt: &Statement) -> bool {
388 match stmt {
389 Statement::Query(q) => query_has_wildcard(q),
390 _ => false,
391 }
392}
393
394fn query_has_wildcard(query: &Query) -> bool {
395 if let Some(with) = &query.with
396 && with
397 .cte_tables
398 .iter()
399 .any(|cte| query_has_wildcard(&cte.query))
400 {
401 return true;
402 }
403 set_expr_has_wildcard(&query.body)
404}
405
406fn set_expr_has_wildcard(set: &SetExpr) -> bool {
407 match set {
408 SetExpr::Select(select) => {
409 let proj_wildcard = select.projection.iter().any(|item| {
410 matches!(
411 item,
412 SelectItem::Wildcard(_) | SelectItem::QualifiedWildcard(_, _)
413 )
414 });
415 proj_wildcard
416 || select
417 .from
418 .iter()
419 .any(|twj| table_factor_has_wildcard(&twj.relation))
420 }
421 SetExpr::Query(q) => query_has_wildcard(q),
422 SetExpr::SetOperation { left, right, .. } => {
423 set_expr_has_wildcard(left) || set_expr_has_wildcard(right)
424 }
425 _ => false,
426 }
427}
428
429fn table_factor_has_wildcard(factor: &TableFactor) -> bool {
430 match factor {
431 TableFactor::Derived { subquery, .. } => query_has_wildcard(subquery),
432 TableFactor::NestedJoin {
433 table_with_joins, ..
434 } => table_factor_has_wildcard(&table_with_joins.relation),
435 _ => false,
436 }
437}
438
#[cfg(test)]
mod tests {
    //! Unit tests for SQL validation, identifier canonicalization and column-access checks.

    use super::*;

    /// Builds a lowercased allowlist, matching the lowercased lookups in `validate`.
    fn allowed(names: &[&str]) -> HashSet<String> {
        names.iter().map(|s| s.to_lowercase()).collect()
    }

    #[test]
    fn accepts_single_dataset_select() {
        let v = validate("SELECT a, b FROM events WHERE a > 1", &allowed(&["events"]), 1).unwrap();
        assert_eq!(v.datasets, vec!["events".to_string()]);
    }

    #[test]
    fn case_insensitive_table_match() {
        let v = validate("SELECT * FROM Events", &allowed(&["events"]), 1).unwrap();
        assert_eq!(v.datasets, vec!["events".to_string()]);
    }

    #[test]
    fn strips_trailing_semicolon() {
        let v = validate("SELECT 1 FROM events;", &allowed(&["events"]), 1).unwrap();
        assert_eq!(v.sql, "SELECT 1 FROM events");
    }

    #[test]
    fn allows_cte_over_single_dataset() {
        let sql = "WITH t AS (SELECT * FROM events) SELECT count(*) FROM t";
        let v = validate(sql, &allowed(&["events"]), 1).unwrap();
        assert_eq!(v.datasets, vec!["events".to_string()]);
    }

    #[test]
    fn allows_cte_referenced_from_nested_subquery() {
        let sql = "WITH t AS (SELECT id FROM events) SELECT * FROM (SELECT id FROM t) s";
        let v = validate(sql, &allowed(&["events"]), 1).unwrap();
        assert_eq!(v.datasets, vec!["events".to_string()]);
    }

    #[test]
    fn allows_tableless_select() {
        let v = validate("SELECT 1 + 1", &allowed(&["events"]), 1).unwrap();
        assert!(v.datasets.is_empty());
    }

    #[test]
    fn rejects_unknown_table() {
        let err = validate("SELECT * FROM secrets", &allowed(&["events"]), 1).unwrap_err();
        assert!(matches!(err, AppError::InvalidValue(_)));
    }

    #[test]
    fn rejects_unknown_table_in_union_branch() {
        let err = validate(
            "SELECT id FROM events UNION ALL SELECT id FROM secrets",
            &allowed(&["events"]),
            1,
        )
        .unwrap_err();
        assert!(matches!(err, AppError::InvalidValue(_)));
    }

    #[test]
    fn rejects_second_dataset_join() {
        let err = validate(
            "SELECT * FROM events e JOIN other o ON e.id = o.id",
            &allowed(&["events", "other"]),
            1,
        )
        .unwrap_err();
        assert!(matches!(err, AppError::InvalidValue(_)));
    }

    #[test]
    fn allows_two_datasets_when_limit_raised() {
        let v = validate(
            "SELECT * FROM events e JOIN other o ON e.id = o.id",
            &allowed(&["events", "other"]),
            2,
        )
        .unwrap();
        assert_eq!(v.datasets.len(), 2);
    }

    #[test]
    fn counts_repeated_dataset_once() {
        let v = validate(
            "SELECT a.id FROM events a JOIN events b ON a.id = b.id",
            &allowed(&["events"]),
            1,
        )
        .unwrap();
        assert_eq!(v.datasets, vec!["events".to_string()]);
    }

    #[test]
    fn rejects_non_select() {
        let err = validate("DELETE FROM events", &allowed(&["events"]), 1).unwrap_err();
        assert!(matches!(err, AppError::InvalidValue(_)));
    }

    #[test]
    fn accepts_describe_table() {
        let v = validate("DESCRIBE events", &allowed(&["events"]), 1).unwrap();
        assert_eq!(v.datasets, vec!["events".to_string()]);
        assert!(is_describe(&v.sql));
    }

    #[test]
    fn accepts_desc_table_case_insensitive() {
        let v = validate("DESC Events", &allowed(&["events"]), 1).unwrap();
        assert_eq!(v.datasets, vec!["events".to_string()]);
        assert!(is_describe(&v.sql));
    }

    #[test]
    fn describe_rejects_unknown_table() {
        let err = validate("DESCRIBE secrets", &allowed(&["events"]), 1).unwrap_err();
        assert!(matches!(err, AppError::InvalidValue(_)));
    }

    #[test]
    fn is_describe_false_for_select() {
        assert!(!is_describe("SELECT * FROM events"));
        assert!(!is_describe("SELECT 1"));
    }

    #[test]
    fn is_describe_ignores_trailing_semicolon() {
        assert!(is_describe("DESCRIBE events;"));
    }

    #[test]
    fn rejects_multiple_statements() {
        let err = validate("SELECT 1 FROM events; SELECT 2 FROM events", &allowed(&["events"]), 1)
            .unwrap_err();
        assert!(matches!(err, AppError::InvalidValue(_)));
    }

    #[test]
    fn rejects_file_table_function() {
        let err = validate("SELECT * FROM read_parquet('/etc/passwd')", &allowed(&["events"]), 1)
            .unwrap_err();
        assert!(matches!(err, AppError::InvalidValue(_)));
    }

    #[test]
    fn rejects_unregistered_table_function() {
        let err = validate("SELECT * FROM generate_series(1, 3)", &allowed(&["events"]), 1)
            .unwrap_err();
        assert!(matches!(err, AppError::InvalidValue(_)));
    }

    #[test]
    fn rejects_file_scalar_function() {
        let err = validate(
            "SELECT read_text('/etc/passwd') FROM events",
            &allowed(&["events"]),
            1,
        )
        .unwrap_err();
        assert!(matches!(err, AppError::InvalidValue(_)));
    }

    #[test]
    fn rejects_empty_sql() {
        let err = validate("   ", &allowed(&["events"]), 1).unwrap_err();
        assert!(matches!(err, AppError::InvalidValue(_)));
    }

    /// Builds owned table and column canonicalization maps from literal pairs.
    fn maps(
        tables: &[(&str, &str)],
        columns: &[(&str, &str)],
    ) -> (HashMap<String, String>, HashMap<String, String>) {
        let t = tables
            .iter()
            .map(|(k, v)| (k.to_string(), v.to_string()))
            .collect();
        let c = columns
            .iter()
            .map(|(k, v)| (k.to_string(), v.to_string()))
            .collect();
        (t, c)
    }

    #[test]
    fn canonicalizes_mixed_case_column_and_table() {
        let (t, c) = maps(
            &[("accidents", "accidents")],
            &[("state", "State"), ("id", "ID")],
        );
        let out = canonicalize_identifiers(
            "SELECT state, COUNT(*) AS n FROM Accidents GROUP BY STATE ORDER BY n DESC",
            &t,
            &c,
        );
        assert!(out.contains("\"State\""), "got: {out}");
        assert!(out.contains("FROM \"accidents\""), "got: {out}");
        assert!(out.contains("AS n"), "got: {out}");
        assert!(!out.contains("\"n\""), "alias must not be quoted: {out}");
    }

    #[test]
    fn canonicalizes_qualified_column() {
        let (t, c) = maps(&[("accidents", "accidents")], &[("state", "State")]);
        let out = canonicalize_identifiers("SELECT a.state FROM accidents AS a", &t, &c);
        assert!(out.contains("a.\"State\""), "got: {out}");
    }

    #[test]
    fn leaves_unknown_identifiers_untouched() {
        let (t, c) = maps(&[("events", "events")], &[("id", "id")]);
        let out = canonicalize_identifiers("SELECT foo, bar FROM events", &t, &c);
        assert!(out.contains("foo"), "got: {out}");
        assert!(out.contains("bar"), "got: {out}");
        assert!(!out.contains("\"foo\""), "got: {out}");
    }

    #[test]
    fn returns_input_unchanged_on_parse_error() {
        let (t, c) = maps(&[], &[]);
        let garbage = "SELECT FROM WHERE";
        assert_eq!(canonicalize_identifiers(garbage, &t, &c), garbage);
    }

    /// Builds an `events(id, email, ts)` schema whose predicate and projection filters
    /// exclude the given columns.
    fn filtered_schema(pred_excl: &[&str], proj_excl: &[&str]) -> DatasetSchema {
        use crate::schema::{ColumnInfo, LogicalType};
        let col = |name: &str| ColumnInfo {
            name: name.into(),
            logical: LogicalType::Int,
            sql_type: "BIGINT".into(),
            nullable: true,
        };
        let excl = |cols: &[&str]| crate::config::ColumnFilter {
            include: vec![],
            exclude: cols.iter().map(|s| s.to_string()).collect(),
        };
        DatasetSchema::new("events", vec![col("id"), col("email"), col("ts")])
            .with_filters(excl(pred_excl), excl(proj_excl))
            .unwrap()
    }

    #[test]
    fn access_noop_without_filters() {
        use crate::schema::{ColumnInfo, LogicalType};
        let sch = DatasetSchema::new(
            "events",
            vec![ColumnInfo {
                name: "id".into(),
                logical: LogicalType::Int,
                sql_type: "BIGINT".into(),
                nullable: false,
            }],
        );
        assert!(enforce_column_access("SELECT * FROM events", &sch).is_ok());
    }

    #[test]
    fn access_rejects_wildcard_when_projection_hides() {
        let sch = filtered_schema(&[], &["email"]);
        let err = enforce_column_access("SELECT * FROM events", &sch).unwrap_err();
        assert!(matches!(err, AppError::Forbidden(_)));
    }

    #[test]
    fn access_rejects_qualified_wildcard_when_projection_hides() {
        let sch = filtered_schema(&[], &["email"]);
        let err = enforce_column_access("SELECT e.* FROM events e", &sch).unwrap_err();
        assert!(matches!(err, AppError::Forbidden(_)));
    }

    #[test]
    fn access_rejects_wildcard_inside_cte() {
        let sch = filtered_schema(&[], &["email"]);
        let err = enforce_column_access(
            "WITH t AS (SELECT * FROM events) SELECT id FROM t",
            &sch,
        )
        .unwrap_err();
        assert!(matches!(err, AppError::Forbidden(_)));
    }

    #[test]
    fn access_rejects_wildcard_inside_derived_table() {
        let sch = filtered_schema(&[], &["email"]);
        let err =
            enforce_column_access("SELECT id FROM (SELECT * FROM events) d", &sch).unwrap_err();
        assert!(matches!(err, AppError::Forbidden(_)));
    }

    #[test]
    fn access_allows_wildcard_when_only_predicate_filter() {
        let sch = filtered_schema(&["email"], &[]);
        assert!(enforce_column_access("SELECT * FROM events", &sch).is_ok());
    }

    #[test]
    fn access_rejects_hidden_column_reference() {
        let sch = filtered_schema(&[], &["email"]);
        let err = enforce_column_access("SELECT id, email FROM events", &sch).unwrap_err();
        assert!(matches!(err, AppError::UnknownColumn(_)));
    }

    #[test]
    fn access_rejects_predicate_restricted_column_reference() {
        let sch = filtered_schema(&["email"], &[]);
        let err =
            enforce_column_access("SELECT id FROM events WHERE email = 'x'", &sch).unwrap_err();
        assert!(matches!(err, AppError::Forbidden(_)));
    }

    #[test]
    fn access_allows_visible_columns() {
        let sch = filtered_schema(&[], &["email"]);
        assert!(enforce_column_access("SELECT id, ts FROM events WHERE id > 1", &sch).is_ok());
    }

    #[test]
    fn access_ignores_aliases_and_functions() {
        let sch = filtered_schema(&[], &["email"]);
        // `total` is an output alias and `count` a function name; neither is a column.
        assert!(
            enforce_column_access("SELECT count(id) AS total FROM events", &sch).is_ok()
        );
    }

    #[test]
    fn access_matches_qualified_column() {
        let sch = filtered_schema(&[], &["email"]);
        let err =
            enforce_column_access("SELECT e.email FROM events e", &sch).unwrap_err();
        assert!(matches!(err, AppError::UnknownColumn(_)));
    }
}
712}