use std::collections::{HashMap, HashSet};
use std::ops::ControlFlow;

use sqlparser::ast::{
    Expr, Ident, ObjectName, ObjectNamePart, Query, SetExpr, Statement, TableFactor, Visit,
    VisitMut, Visitor, VisitorMut,
};
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::Parser;

use crate::errors::AppError;
33
/// Function names the SQL endpoint refuses to execute, wherever they appear.
///
/// Every entry reads files or probes the filesystem, queries an external database, runs SQL
/// supplied as a string, or exposes engine secrets/environment. Any of these would escape the
/// registered-dataset sandbox. The names match DuckDB built-in/extension functions (presumably
/// the query engine is DuckDB — confirm).
///
/// Entries must be lowercase: lookups compare against the lowercased last component of the
/// called function's name.
///
/// NOTE(review): a deny-list is inherently incomplete (loadable extensions can add new readers).
/// It should be paired with engine-level hardening such as disabling external access — TODO
/// confirm how the connection is configured.
const DENIED_FUNCTIONS: &[&str] = &[
    // File readers and filesystem probes.
    "read_text",
    "read_blob",
    "read_csv",
    "read_csv_auto",
    "read_parquet",
    "parquet_scan",
    "parquet_metadata",
    "parquet_schema",
    "parquet_file_metadata",
    "parquet_kv_metadata",
    "read_json",
    "read_json_auto",
    "read_json_objects",
    "read_json_objects_auto",
    "read_ndjson",
    "read_ndjson_auto",
    "read_ndjson_objects",
    "read_xlsx",
    "sniff_csv",
    "glob",
    "st_read",
    "iceberg_scan",
    "iceberg_metadata",
    "iceberg_snapshots",
    "delta_scan",
    // Functions that execute SQL passed in as a string.
    "query",
    "query_table",
    // Scanners over external databases.
    "sqlite_scan",
    "sqlite_attach",
    "postgres_scan",
    "postgres_query",
    "mysql_query",
    // Engine secrets and process environment.
    "duckdb_secrets",
    "getenv",
];
54
/// A statement that passed [`validate`], ready to hand to the query engine.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidatedSql {
    /// The caller's SQL with surrounding whitespace and trailing semicolons removed.
    pub sql: String,
    /// Registered datasets the statement references, deduplicated and sorted.
    pub datasets: Vec<String>,
}
64
65pub fn validate(
74 sql: &str,
75 allowed: &HashSet<String>,
76 max_datasets: usize,
77) -> Result<ValidatedSql, AppError> {
78 let trimmed = sql.trim().trim_end_matches(';').trim();
79 if trimmed.is_empty() {
80 return Err(AppError::InvalidValue("sql must not be empty".into()));
81 }
82
83 let statements = Parser::parse_sql(&GenericDialect {}, trimmed)
84 .map_err(|e| AppError::InvalidValue(format!("could not parse SQL: {e}")))?;
85 if statements.len() != 1 {
86 return Err(AppError::InvalidValue(
87 "exactly one SQL statement is allowed".into(),
88 ));
89 }
90 let stmt = &statements[0];
91 if !matches!(stmt, Statement::Query(_)) {
92 return Err(AppError::InvalidValue(
93 "only read-only SELECT queries are allowed".into(),
94 ));
95 }
96
97 let mut checker = ScopeCheck {
98 allowed,
99 cte_names: HashSet::new(),
100 referenced: HashSet::new(),
101 violation: None,
102 };
103 let _ = stmt.visit(&mut checker);
104 if let Some(err) = checker.violation {
105 return Err(AppError::InvalidValue(err));
106 }
107
108 let mut datasets: Vec<String> = checker.referenced.into_iter().collect();
109 datasets.sort();
110 if datasets.len() > max_datasets {
111 return Err(AppError::InvalidValue(format!(
112 "this endpoint allows at most {max_datasets} dataset(s) per query; \
113 the statement references {}",
114 datasets.len()
115 )));
116 }
117
118 Ok(ValidatedSql {
119 sql: trimmed.to_string(),
120 datasets,
121 })
122}
123
124pub fn canonicalize_identifiers(
142 sql: &str,
143 tables: &HashMap<String, String>,
144 columns: &HashMap<String, String>,
145) -> String {
146 let mut statements = match Parser::parse_sql(&GenericDialect {}, sql) {
147 Ok(s) if s.len() == 1 => s,
148 _ => return sql.to_string(),
149 };
150 let mut canon = Canonicalizer { tables, columns };
151 let _ = VisitMut::visit(&mut statements[0], &mut canon);
152 statements[0].to_string()
153}
154
155struct Canonicalizer<'a> {
156 tables: &'a HashMap<String, String>,
157 columns: &'a HashMap<String, String>,
158}
159
160impl Canonicalizer<'_> {
161 fn rewrite(ident: &mut Ident, map: &HashMap<String, String>) {
165 if let Some(canonical) = map.get(&ident.value.to_lowercase()) {
166 ident.value = canonical.clone();
167 ident.quote_style = Some('"');
168 }
169 }
170}
171
172impl VisitorMut for Canonicalizer<'_> {
173 type Break = ();
174
175 fn pre_visit_relation(&mut self, relation: &mut ObjectName) -> ControlFlow<Self::Break> {
176 for part in relation.0.iter_mut() {
177 if let ObjectNamePart::Identifier(ident) = part {
178 Self::rewrite(ident, self.tables);
179 }
180 }
181 ControlFlow::Continue(())
182 }
183
184 fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
185 match expr {
186 Expr::Identifier(ident) => Self::rewrite(ident, self.columns),
188 Expr::CompoundIdentifier(idents) => {
192 if let Some((column, qualifiers)) = idents.split_last_mut() {
193 Self::rewrite(column, self.columns);
194 for qualifier in qualifiers {
195 Self::rewrite(qualifier, self.tables);
196 }
197 }
198 }
199 _ => {}
200 }
201 ControlFlow::Continue(())
202 }
203}
204
205struct ScopeCheck<'a> {
206 allowed: &'a HashSet<String>,
207 cte_names: HashSet<String>,
208 referenced: HashSet<String>,
209 violation: Option<String>,
210}
211
212impl Visitor for ScopeCheck<'_> {
213 type Break = ();
214
215 fn pre_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
216 if let Some(with) = &query.with {
221 for cte in &with.cte_tables {
222 self.cte_names.insert(cte.alias.name.value.to_lowercase());
223 }
224 }
225 ControlFlow::Continue(())
226 }
227
228 fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
229 let ident = relation
230 .0
231 .last()
232 .and_then(|p| p.as_ident())
233 .map(|i| i.value.to_lowercase())
234 .unwrap_or_default();
235
236 if self.cte_names.contains(&ident) {
237 return ControlFlow::Continue(());
238 }
239 if let Some(name) = self.allowed.get(&ident) {
240 self.referenced.insert(name.clone());
241 return ControlFlow::Continue(());
242 }
243 self.violation = Some(format!(
244 "table '{ident}' is not a registered dataset accessible from the SQL endpoint"
245 ));
246 ControlFlow::Break(())
247 }
248
249 fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
250 if let Expr::Function(func) = expr {
251 let fname = func
252 .name
253 .0
254 .last()
255 .and_then(|p| p.as_ident())
256 .map(|i| i.value.to_lowercase())
257 .unwrap_or_default();
258 if DENIED_FUNCTIONS.contains(&fname.as_str()) {
259 self.violation =
260 Some(format!("function '{fname}' is not allowed in the SQL endpoint"));
261 return ControlFlow::Break(());
262 }
263 }
264 ControlFlow::Continue(())
265 }
266}
267
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an allow-list, lowercasing each name the way production callers do.
    fn allowed(names: &[&str]) -> HashSet<String> {
        names.iter().copied().map(str::to_lowercase).collect()
    }

    /// Validates `sql` against `names` and returns the referenced datasets (panics on error).
    fn datasets_for(sql: &str, names: &[&str], max: usize) -> Vec<String> {
        validate(sql, &allowed(names), max).unwrap().datasets
    }

    /// Asserts that `sql` is rejected with `AppError::InvalidValue`.
    fn assert_rejected(sql: &str, names: &[&str], max: usize) {
        let outcome = validate(sql, &allowed(names), max);
        assert!(matches!(outcome, Err(AppError::InvalidValue(_))));
    }

    #[test]
    fn accepts_single_dataset_select() {
        let found = datasets_for("SELECT a, b FROM events WHERE a > 1", &["events"], 1);
        assert_eq!(found, ["events"]);
    }

    #[test]
    fn case_insensitive_table_match() {
        assert_eq!(datasets_for("SELECT * FROM Events", &["events"], 1), ["events"]);
    }

    #[test]
    fn strips_trailing_semicolon() {
        let checked = validate("SELECT 1 FROM events;", &allowed(&["events"]), 1).unwrap();
        assert_eq!(checked.sql, "SELECT 1 FROM events");
    }

    #[test]
    fn allows_cte_over_single_dataset() {
        let query = "WITH t AS (SELECT * FROM events) SELECT count(*) FROM t";
        assert_eq!(datasets_for(query, &["events"], 1), ["events"]);
    }

    #[test]
    fn allows_tableless_select() {
        assert!(datasets_for("SELECT 1 + 1", &["events"], 1).is_empty());
    }

    #[test]
    fn rejects_unknown_table() {
        assert_rejected("SELECT * FROM secrets", &["events"], 1);
    }

    #[test]
    fn rejects_second_dataset_join() {
        assert_rejected(
            "SELECT * FROM events e JOIN other o ON e.id = o.id",
            &["events", "other"],
            1,
        );
    }

    #[test]
    fn allows_two_datasets_when_limit_raised() {
        let found = datasets_for(
            "SELECT * FROM events e JOIN other o ON e.id = o.id",
            &["events", "other"],
            2,
        );
        assert_eq!(found.len(), 2);
    }

    #[test]
    fn rejects_non_select() {
        assert_rejected("DELETE FROM events", &["events"], 1);
    }

    #[test]
    fn rejects_multiple_statements() {
        assert_rejected("SELECT 1 FROM events; SELECT 2 FROM events", &["events"], 1);
    }

    #[test]
    fn rejects_file_table_function() {
        assert_rejected("SELECT * FROM read_parquet('/etc/passwd')", &["events"], 1);
    }

    #[test]
    fn rejects_file_scalar_function() {
        assert_rejected("SELECT read_text('/etc/passwd') FROM events", &["events"], 1);
    }

    #[test]
    fn rejects_empty_sql() {
        assert_rejected("   ", &["events"], 1);
    }

    /// Builds owned (table, column) lookup maps from borrowed key/value pairs.
    fn maps(
        tables: &[(&str, &str)],
        columns: &[(&str, &str)],
    ) -> (HashMap<String, String>, HashMap<String, String>) {
        fn owned(pairs: &[(&str, &str)]) -> HashMap<String, String> {
            pairs
                .iter()
                .map(|&(key, value)| (key.to_owned(), value.to_owned()))
                .collect()
        }
        (owned(tables), owned(columns))
    }

    #[test]
    fn canonicalizes_mixed_case_column_and_table() {
        let (tables, columns) = maps(
            &[("accidents", "accidents")],
            &[("state", "State"), ("id", "ID")],
        );
        let out = canonicalize_identifiers(
            "SELECT state, COUNT(*) AS n FROM Accidents GROUP BY STATE ORDER BY n DESC",
            &tables,
            &columns,
        );
        assert!(out.contains("\"State\""), "got: {out}");
        assert!(out.contains("FROM \"accidents\""), "got: {out}");
        assert!(out.contains("AS n"), "got: {out}");
        assert!(!out.contains("\"n\""), "alias must not be quoted: {out}");
    }

    #[test]
    fn canonicalizes_qualified_column() {
        let (tables, columns) = maps(&[("accidents", "accidents")], &[("state", "State")]);
        let out = canonicalize_identifiers("SELECT a.state FROM accidents AS a", &tables, &columns);
        assert!(out.contains("a.\"State\""), "got: {out}");
    }

    #[test]
    fn leaves_unknown_identifiers_untouched() {
        let (tables, columns) = maps(&[("events", "events")], &[("id", "id")]);
        let out = canonicalize_identifiers("SELECT foo, bar FROM events", &tables, &columns);
        for needle in ["foo", "bar"] {
            assert!(out.contains(needle), "got: {out}");
        }
        assert!(!out.contains("\"foo\""), "got: {out}");
    }

    #[test]
    fn returns_input_unchanged_on_parse_error() {
        let (tables, columns) = maps(&[], &[]);
        let garbage = "SELECT FROM WHERE";
        let out = canonicalize_identifiers(garbage, &tables, &columns);
        assert_eq!(out, garbage);
    }
}