1use std::collections::{HashMap, HashSet};
24use std::ops::ControlFlow;
25
26use sqlparser::ast::{
27 DescribeAlias, Expr, Ident, ObjectName, ObjectNamePart, Query, Statement, Visit, VisitMut,
28 Visitor, VisitorMut,
29};
30use sqlparser::dialect::GenericDialect;
31use sqlparser::parser::Parser;
32
33use crate::errors::AppError;
34
/// Functions that must never be called from the SQL endpoint.
///
/// The validator compares these against the lower-cased last component of a
/// called function's name, so every entry must be lower-case.
///
/// The names are DuckDB-style functions that read arbitrary files, enumerate
/// the filesystem, or execute caller-supplied SQL/table names
/// (NOTE(review): assumes the execution engine is DuckDB — confirm).
const DENIED_FUNCTIONS: &[&str] = &[
    // Raw file readers.
    "read_text",
    "read_blob",
    // CSV.
    "read_csv",
    "read_csv_auto",
    "sniff_csv",
    // Parquet, including the metadata readers that also open arbitrary paths.
    "read_parquet",
    "parquet_scan",
    "parquet_metadata",
    "parquet_schema",
    "parquet_file_metadata",
    "parquet_kv_metadata",
    // JSON / NDJSON.
    "read_json",
    "read_json_auto",
    "read_json_objects",
    "read_json_objects_auto",
    "read_ndjson",
    "read_ndjson_auto",
    "read_ndjson_objects",
    // Filesystem enumeration.
    "glob",
    // Dynamic SQL: these resolve a query or table name passed as a string,
    // which would side-step the table allow-list.
    "query",
    "query_table",
];
55
/// Outcome of a successful [`validate`] call.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidatedSql {
    /// The statement text with surrounding whitespace and the trailing
    /// semicolon removed.
    pub sql: String,
    /// Allowed datasets the statement references, sorted ascending.
    pub datasets: Vec<String>,
}
65
66pub fn validate(
77 sql: &str,
78 allowed: &HashSet<String>,
79 max_datasets: usize,
80) -> Result<ValidatedSql, AppError> {
81 let trimmed = sql.trim().trim_end_matches(';').trim();
82 if trimmed.is_empty() {
83 return Err(AppError::InvalidValue("sql must not be empty".into()));
84 }
85
86 let statements = Parser::parse_sql(&GenericDialect {}, trimmed)
87 .map_err(|e| AppError::InvalidValue(format!("could not parse SQL: {e}")))?;
88 if statements.len() != 1 {
89 return Err(AppError::InvalidValue(
90 "exactly one SQL statement is allowed".into(),
91 ));
92 }
93 let stmt = &statements[0];
94 match stmt {
99 Statement::Query(_) => {}
100 Statement::ExplainTable {
101 describe_alias: DescribeAlias::Describe | DescribeAlias::Desc,
102 ..
103 } => {}
104 _ => {
105 return Err(AppError::InvalidValue(
106 "only read-only SELECT and DESCRIBE statements are allowed".into(),
107 ));
108 }
109 }
110
111 let mut checker = ScopeCheck {
112 allowed,
113 cte_names: HashSet::new(),
114 referenced: HashSet::new(),
115 violation: None,
116 };
117 let _ = stmt.visit(&mut checker);
118 if let Some(err) = checker.violation {
119 return Err(AppError::InvalidValue(err));
120 }
121
122 let mut datasets: Vec<String> = checker.referenced.into_iter().collect();
123 datasets.sort();
124 if datasets.len() > max_datasets {
125 return Err(AppError::InvalidValue(format!(
126 "this endpoint allows at most {max_datasets} dataset(s) per query; \
127 the statement references {}",
128 datasets.len()
129 )));
130 }
131
132 Ok(ValidatedSql {
133 sql: trimmed.to_string(),
134 datasets,
135 })
136}
137
138pub fn is_describe(sql: &str) -> bool {
146 let trimmed = sql.trim().trim_end_matches(';').trim();
147 matches!(
148 Parser::parse_sql(&GenericDialect {}, trimmed).as_deref(),
149 Ok([Statement::ExplainTable {
150 describe_alias: DescribeAlias::Describe | DescribeAlias::Desc,
151 ..
152 }])
153 )
154}
155
156pub fn canonicalize_identifiers(
174 sql: &str,
175 tables: &HashMap<String, String>,
176 columns: &HashMap<String, String>,
177) -> String {
178 let mut statements = match Parser::parse_sql(&GenericDialect {}, sql) {
179 Ok(s) if s.len() == 1 => s,
180 _ => return sql.to_string(),
181 };
182 let mut canon = Canonicalizer { tables, columns };
183 let _ = VisitMut::visit(&mut statements[0], &mut canon);
184 statements[0].to_string()
185}
186
187struct Canonicalizer<'a> {
188 tables: &'a HashMap<String, String>,
189 columns: &'a HashMap<String, String>,
190}
191
192impl Canonicalizer<'_> {
193 fn rewrite(ident: &mut Ident, map: &HashMap<String, String>) {
197 if let Some(canonical) = map.get(&ident.value.to_lowercase()) {
198 ident.value = canonical.clone();
199 ident.quote_style = Some('"');
200 }
201 }
202}
203
204impl VisitorMut for Canonicalizer<'_> {
205 type Break = ();
206
207 fn pre_visit_relation(&mut self, relation: &mut ObjectName) -> ControlFlow<Self::Break> {
208 for part in relation.0.iter_mut() {
209 if let ObjectNamePart::Identifier(ident) = part {
210 Self::rewrite(ident, self.tables);
211 }
212 }
213 ControlFlow::Continue(())
214 }
215
216 fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
217 match expr {
218 Expr::Identifier(ident) => Self::rewrite(ident, self.columns),
220 Expr::CompoundIdentifier(idents) => {
224 if let Some((column, qualifiers)) = idents.split_last_mut() {
225 Self::rewrite(column, self.columns);
226 for qualifier in qualifiers {
227 Self::rewrite(qualifier, self.tables);
228 }
229 }
230 }
231 _ => {}
232 }
233 ControlFlow::Continue(())
234 }
235}
236
237struct ScopeCheck<'a> {
238 allowed: &'a HashSet<String>,
239 cte_names: HashSet<String>,
240 referenced: HashSet<String>,
241 violation: Option<String>,
242}
243
244impl Visitor for ScopeCheck<'_> {
245 type Break = ();
246
247 fn pre_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
248 if let Some(with) = &query.with {
253 for cte in &with.cte_tables {
254 self.cte_names.insert(cte.alias.name.value.to_lowercase());
255 }
256 }
257 ControlFlow::Continue(())
258 }
259
260 fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
261 let ident = relation
262 .0
263 .last()
264 .and_then(|p| p.as_ident())
265 .map(|i| i.value.to_lowercase())
266 .unwrap_or_default();
267
268 if self.cte_names.contains(&ident) {
269 return ControlFlow::Continue(());
270 }
271 if let Some(name) = self.allowed.get(&ident) {
272 self.referenced.insert(name.clone());
273 return ControlFlow::Continue(());
274 }
275 self.violation = Some(format!(
276 "table '{ident}' is not a registered dataset accessible from the SQL endpoint"
277 ));
278 ControlFlow::Break(())
279 }
280
281 fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
282 if let Expr::Function(func) = expr {
283 let fname = func
284 .name
285 .0
286 .last()
287 .and_then(|p| p.as_ident())
288 .map(|i| i.value.to_lowercase())
289 .unwrap_or_default();
290 if DENIED_FUNCTIONS.contains(&fname.as_str()) {
291 self.violation =
292 Some(format!("function '{fname}' is not allowed in the SQL endpoint"));
293 return ControlFlow::Break(());
294 }
295 }
296 ControlFlow::Continue(())
297 }
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the lower-cased allow-list handed to `validate`.
    fn allowed(names: &[&str]) -> HashSet<String> {
        let mut set = HashSet::with_capacity(names.len());
        for name in names {
            set.insert(name.to_lowercase());
        }
        set
    }

    /// Validates `sql` and returns the accepted statement, panicking on rejection.
    fn accept(sql: &str, names: &[&str], max_datasets: usize) -> ValidatedSql {
        validate(sql, &allowed(names), max_datasets).unwrap()
    }

    /// Asserts that `sql` is rejected with `AppError::InvalidValue`.
    fn assert_rejected(sql: &str, names: &[&str], max_datasets: usize) {
        let err = validate(sql, &allowed(names), max_datasets).unwrap_err();
        assert!(matches!(err, AppError::InvalidValue(_)));
    }

    #[test]
    fn accepts_single_dataset_select() {
        let v = accept("SELECT a, b FROM events WHERE a > 1", &["events"], 1);
        assert_eq!(v.datasets, ["events"]);
    }

    #[test]
    fn case_insensitive_table_match() {
        let v = accept("SELECT * FROM Events", &["events"], 1);
        assert_eq!(v.datasets, ["events"]);
    }

    #[test]
    fn strips_trailing_semicolon() {
        let v = accept("SELECT 1 FROM events;", &["events"], 1);
        assert_eq!(v.sql, "SELECT 1 FROM events");
    }

    #[test]
    fn allows_cte_over_single_dataset() {
        let v = accept(
            "WITH t AS (SELECT * FROM events) SELECT count(*) FROM t",
            &["events"],
            1,
        );
        assert_eq!(v.datasets, ["events"]);
    }

    #[test]
    fn allows_tableless_select() {
        let v = accept("SELECT 1 + 1", &["events"], 1);
        assert!(v.datasets.is_empty());
    }

    #[test]
    fn rejects_unknown_table() {
        assert_rejected("SELECT * FROM secrets", &["events"], 1);
    }

    #[test]
    fn rejects_second_dataset_join() {
        assert_rejected(
            "SELECT * FROM events e JOIN other o ON e.id = o.id",
            &["events", "other"],
            1,
        );
    }

    #[test]
    fn allows_two_datasets_when_limit_raised() {
        let v = accept(
            "SELECT * FROM events e JOIN other o ON e.id = o.id",
            &["events", "other"],
            2,
        );
        assert_eq!(v.datasets.len(), 2);
    }

    #[test]
    fn rejects_non_select() {
        assert_rejected("DELETE FROM events", &["events"], 1);
    }

    #[test]
    fn accepts_describe_table() {
        let v = accept("DESCRIBE events", &["events"], 1);
        assert_eq!(v.datasets, ["events"]);
        assert!(is_describe(&v.sql));
    }

    #[test]
    fn accepts_desc_table_case_insensitive() {
        let v = accept("DESC Events", &["events"], 1);
        assert_eq!(v.datasets, ["events"]);
        assert!(is_describe(&v.sql));
    }

    #[test]
    fn describe_rejects_unknown_table() {
        assert_rejected("DESCRIBE secrets", &["events"], 1);
    }

    #[test]
    fn is_describe_false_for_select() {
        for sql in ["SELECT * FROM events", "SELECT 1"] {
            assert!(!is_describe(sql));
        }
    }

    #[test]
    fn rejects_multiple_statements() {
        assert_rejected(
            "SELECT 1 FROM events; SELECT 2 FROM events",
            &["events"],
            1,
        );
    }

    #[test]
    fn rejects_file_table_function() {
        assert_rejected("SELECT * FROM read_parquet('/etc/passwd')", &["events"], 1);
    }

    #[test]
    fn rejects_file_scalar_function() {
        assert_rejected("SELECT read_text('/etc/passwd') FROM events", &["events"], 1);
    }

    #[test]
    fn rejects_empty_sql() {
        assert_rejected("   ", &["events"], 1);
    }

    /// Turns `(key, value)` string pairs into the owned table and column maps
    /// expected by `canonicalize_identifiers`.
    fn maps(
        tables: &[(&str, &str)],
        columns: &[(&str, &str)],
    ) -> (HashMap<String, String>, HashMap<String, String>) {
        let to_map = |pairs: &[(&str, &str)]| -> HashMap<String, String> {
            pairs
                .iter()
                .map(|(k, v)| (k.to_string(), v.to_string()))
                .collect()
        };
        (to_map(tables), to_map(columns))
    }

    /// Runs `canonicalize_identifiers` with maps built from string pairs.
    fn canonicalize(sql: &str, tables: &[(&str, &str)], columns: &[(&str, &str)]) -> String {
        let (t, c) = maps(tables, columns);
        canonicalize_identifiers(sql, &t, &c)
    }

    #[test]
    fn canonicalizes_mixed_case_column_and_table() {
        let out = canonicalize(
            "SELECT state, COUNT(*) AS n FROM Accidents GROUP BY STATE ORDER BY n DESC",
            &[("accidents", "accidents")],
            &[("state", "State"), ("id", "ID")],
        );
        assert!(out.contains("\"State\""), "got: {out}");
        assert!(out.contains("FROM \"accidents\""), "got: {out}");
        assert!(out.contains("AS n"), "got: {out}");
        assert!(!out.contains("\"n\""), "alias must not be quoted: {out}");
    }

    #[test]
    fn canonicalizes_qualified_column() {
        let out = canonicalize(
            "SELECT a.state FROM accidents AS a",
            &[("accidents", "accidents")],
            &[("state", "State")],
        );
        assert!(out.contains("a.\"State\""), "got: {out}");
    }

    #[test]
    fn leaves_unknown_identifiers_untouched() {
        let out = canonicalize(
            "SELECT foo, bar FROM events",
            &[("events", "events")],
            &[("id", "id")],
        );
        assert!(out.contains("foo"), "got: {out}");
        assert!(out.contains("bar"), "got: {out}");
        assert!(!out.contains("\"foo\""), "got: {out}");
    }

    #[test]
    fn returns_input_unchanged_on_parse_error() {
        let garbage = "SELECT FROM WHERE";
        assert_eq!(canonicalize(garbage, &[], &[]), garbage);
    }
}