1use std::sync::Arc;
2
3use datafusion::sql::sqlparser::ast::Statement;
4use datafusion::sql::sqlparser::dialect::PostgreSqlDialect;
5use datafusion::sql::sqlparser::parser::Parser;
6use datafusion::sql::sqlparser::parser::ParserError;
7use datafusion::sql::sqlparser::tokenizer::Token;
8use datafusion::sql::sqlparser::tokenizer::TokenWithSpan;
9
10use super::rules::AliasDuplicatedProjectionRewrite;
11use super::rules::CurrentUserVariableToSessionUserFunctionCall;
12use super::rules::FixArrayLiteral;
13use super::rules::FixCollate;
14use super::rules::FixVersionColumnName;
15use super::rules::PrependUnqualifiedPgTableName;
16use super::rules::RemoveQualifier;
17use super::rules::RemoveSubqueryFromProjection;
18use super::rules::RemoveUnsupportedTypes;
19use super::rules::ResolveUnqualifiedIdentifer;
20use super::rules::RewriteArrayAnyAllOperation;
21use super::rules::SqlStatementRewriteRule;
22
/// SQL snippets that are swapped for a stand-in before parsing.
///
/// Each entry is a `(pattern, replacement)` pair of raw SQL text. Both sides
/// are tokenized in `PostgresCompatibilityParser::new`, where whitespace
/// (including SQL comments, presumably — TODO confirm against the tokenizer)
/// and semicolons are dropped, so matching is done on tokens rather than on
/// the exact text layout. A `$n` placeholder inside a pattern is matched
/// against any single input token by `maybe_replace_tokens`.
///
/// Most replacements are `SELECT ... WHERE false` queries, i.e. they yield no
/// rows while keeping named, typed columns.
const BLACKLIST_SQL_MAPPING: &[(&str, &str)] = &[
    // Foreign-key parent/child column listing built on pg_constraint
    // -> empty result with the same six column names.
    (
"SELECT s_p.nspname AS parentschema,
 t_p.relname AS parenttable,
 unnest((
 select
 array_agg(attname ORDER BY i)
 from
 (select unnest(confkey) as attnum, generate_subscripts(confkey, 1) as i) x
 JOIN pg_catalog.pg_attribute c USING(attnum)
 WHERE c.attrelid = fk.confrelid
 )) AS parentcolumn,
 s_c.nspname AS childschema,
 t_c.relname AS childtable,
 unnest((
 select
 array_agg(attname ORDER BY i)
 from
 (select unnest(conkey) as attnum, generate_subscripts(conkey, 1) as i) x
 JOIN pg_catalog.pg_attribute c USING(attnum)
 WHERE c.attrelid = fk.conrelid
 )) AS childcolumn
 FROM pg_catalog.pg_constraint fk
 JOIN pg_catalog.pg_class t_p ON t_p.oid = fk.confrelid
 JOIN pg_catalog.pg_namespace s_p ON s_p.oid = t_p.relnamespace
 JOIN pg_catalog.pg_class t_c ON t_c.oid = fk.conrelid
 JOIN pg_catalog.pg_namespace s_c ON s_c.oid = t_c.relnamespace
 WHERE fk.contype = 'f'",
"SELECT
 NULL::TEXT AS parentschema,
 NULL::TEXT AS parenttable,
 NULL::TEXT AS parentcolumn,
 NULL::TEXT AS childschema,
 NULL::TEXT AS childtable,
 NULL::TEXT AS childcolumn
 WHERE false"),

    // Listing of types from pg_type outside pg_catalog / information_schema
    // -> empty result with the same two column names.
    (
"SELECT n.nspname schema_name,
 t.typname type_name
 FROM pg_catalog.pg_type t
 INNER JOIN pg_catalog.pg_namespace n
 ON n.oid = t.typnamespace
 WHERE ( t.typrelid = 0 -- non-composite types
 OR ( -- composite type, but not a table
 SELECT c.relkind = 'c'
 FROM pg_catalog.pg_class c
 WHERE c.oid = t.typrelid
 )
 )
 AND NOT EXISTS( -- ignore array types
 SELECT 1
 FROM pg_catalog.pg_type el
 WHERE el.oid = t.typelem AND el.typarray = t.oid
 )
 AND n.nspname <> 'pg_catalog'
 AND n.nspname <> 'information_schema'
 ORDER BY 1, 2;",
"SELECT NULL::TEXT AS schema_name, NULL::TEXT AS type_name WHERE false"
    ),

    // Policies of one relation read from pg_policy (`$1` is the relation)
    // -> empty result.
    (
"SELECT pol.polname, pol.polpermissive,
 CASE WHEN pol.polroles = '{0}' THEN NULL ELSE pg_catalog.array_to_string(array(select rolname from pg_catalog.pg_roles where oid = any (pol.polroles) order by 1),',') END,
 pg_catalog.pg_get_expr(pol.polqual, pol.polrelid),
 pg_catalog.pg_get_expr(pol.polwithcheck, pol.polrelid),
 CASE pol.polcmd
 WHEN 'r' THEN 'SELECT'
 WHEN 'a' THEN 'INSERT'
 WHEN 'w' THEN 'UPDATE'
 WHEN 'd' THEN 'DELETE'
 END AS cmd
 FROM pg_catalog.pg_policy pol
 WHERE pol.polrelid = $1 ORDER BY 1;",
"SELECT
 NULL::TEXT AS polname,
 NULL::TEXT AS polpermissive,
 NULL::TEXT AS array_to_string,
 NULL::TEXT AS pg_get_expr_1,
 NULL::TEXT AS pg_get_expr_2,
 NULL::TEXT AS cmd
 WHERE false"
    ),

    // Extended statistics objects of one relation read from pg_statistic_ext
    // -> empty result.
    (
"SELECT oid, stxrelid::pg_catalog.regclass, stxnamespace::pg_catalog.regnamespace::pg_catalog.text AS nsp, stxname,
 pg_catalog.pg_get_statisticsobjdef_columns(oid) AS columns,
 'd' = any(stxkind) AS ndist_enabled,
 'f' = any(stxkind) AS deps_enabled,
 'm' = any(stxkind) AS mcv_enabled,
 stxstattarget
 FROM pg_catalog.pg_statistic_ext
 WHERE stxrelid = $1
 ORDER BY nsp, stxname;",
"SELECT
 NULL::INT AS oid,
 NULL::TEXT AS stxrelid,
 NULL::TEXT AS nsp,
 NULL::TEXT AS stxname,
 NULL::TEXT AS columns,
 NULL::BOOLEAN AS ndist_enabled,
 NULL::BOOLEAN AS deps_enabled,
 NULL::BOOLEAN AS mcv_enabled,
 NULL::TEXT AS stxstattarget
 WHERE false"
    ),

    // Publications that cover one relation (three-way UNION over
    // pg_publication) -> empty result with three columns.
    (
"SELECT pubname
 , NULL
 , NULL
 FROM pg_catalog.pg_publication p
 JOIN pg_catalog.pg_publication_namespace pn ON p.oid = pn.pnpubid
 JOIN pg_catalog.pg_class pc ON pc.relnamespace = pn.pnnspid
 WHERE pc.oid = $1 and pg_catalog.pg_relation_is_publishable($1)
 UNION
 SELECT pubname
 , pg_get_expr(pr.prqual, c.oid)
 , (CASE WHEN pr.prattrs IS NOT NULL THEN
 (SELECT string_agg(attname, ', ')
 FROM pg_catalog.generate_series(0, pg_catalog.array_upper(pr.prattrs::pg_catalog.int2[], 1)) s,
 pg_catalog.pg_attribute
 WHERE attrelid = pr.prrelid AND attnum = prattrs[s])
 ELSE NULL END) FROM pg_catalog.pg_publication p
 JOIN pg_catalog.pg_publication_rel pr ON p.oid = pr.prpubid
 JOIN pg_catalog.pg_class c ON c.oid = pr.prrelid
 WHERE pr.prrelid = $1
 UNION
 SELECT pubname
 , NULL
 , NULL
 FROM pg_catalog.pg_publication p
 WHERE p.puballtables AND pg_catalog.pg_relation_is_publishable($1)
 ORDER BY 1;",
"SELECT
 NULL::TEXT AS pubname,
 NULL::TEXT AS _1,
 NULL::TEXT AS _2
 WHERE false"
    ),

    // Sub-select that expands `search_path` into schema names -> the empty
    // string literal `''`. Unlike the entries above this is a fragment: the
    // matcher scans every token position, so it is replaced wherever it
    // occurs inside a larger query.
    (r#"SELECT
 CASE WHEN trim(s[i]) = '"$user"' THEN user ELSE trim(s[i]) END
 FROM
 generate_series(
 array_lower(string_to_array(current_setting('search_path'),','),1),
 array_upper(string_to_array(current_setting('search_path'),','),1)
 ) as i,
 string_to_array(current_setting('search_path'),',') s"#,
"''")
];
178
/// SQL parser that pre-processes statements before and after parsing with the
/// PostgreSQL dialect.
///
/// Before parsing, token sequences listed in `BLACKLIST_SQL_MAPPING` are
/// replaced by their stand-ins; after parsing, every statement is passed
/// through the configured rewrite rules.
#[derive(Debug)]
pub struct PostgresCompatibilityParser {
    /// Tokenized `(pattern, replacement)` pairs built from
    /// `BLACKLIST_SQL_MAPPING`, with whitespace and semicolon tokens removed.
    blacklist: Vec<(Vec<Token>, Vec<Token>)>,
    /// Statement rewrite rules, applied in order by `rewrite`.
    rewrite_rules: Vec<Arc<dyn SqlStatementRewriteRule>>,
}
189
190impl Default for PostgresCompatibilityParser {
191 fn default() -> Self {
192 Self::new()
193 }
194}
195
196impl PostgresCompatibilityParser {
197 pub fn new() -> Self {
198 let mut mapping = Vec::with_capacity(BLACKLIST_SQL_MAPPING.len());
199
200 for (sql_from, sql_to) in BLACKLIST_SQL_MAPPING {
201 mapping.push((
202 Parser::new(&PostgreSqlDialect {})
203 .try_with_sql(sql_from)
204 .unwrap()
205 .into_tokens()
206 .into_iter()
207 .map(|t| t.token)
208 .filter(|t| !matches!(t, Token::Whitespace(_) | Token::SemiColon))
209 .collect(),
210 Parser::new(&PostgreSqlDialect {})
211 .try_with_sql(sql_to)
212 .unwrap()
213 .into_tokens()
214 .into_iter()
215 .map(|t| t.token)
216 .filter(|t| !matches!(t, Token::Whitespace(_) | Token::SemiColon))
217 .collect(),
218 ));
219 }
220
221 Self {
222 blacklist: mapping,
223 rewrite_rules: vec![
224 Arc::new(AliasDuplicatedProjectionRewrite),
227 Arc::new(ResolveUnqualifiedIdentifer),
228 Arc::new(RewriteArrayAnyAllOperation),
229 Arc::new(PrependUnqualifiedPgTableName),
230 Arc::new(RemoveQualifier),
231 Arc::new(RemoveUnsupportedTypes::new()),
232 Arc::new(FixArrayLiteral),
233 Arc::new(CurrentUserVariableToSessionUserFunctionCall),
234 Arc::new(FixCollate),
235 Arc::new(RemoveSubqueryFromProjection),
236 Arc::new(FixVersionColumnName),
237 ],
238 }
239 }
240
241 fn maybe_replace_tokens(&self, input: &str) -> Result<Vec<Token>, ParserError> {
243 let parser = Parser::new(&PostgreSqlDialect {});
244 let tokens = parser.try_with_sql(input)?.into_tokens();
245
246 let filtered_tokens: Vec<Token> = tokens
249 .iter()
250 .map(|t| t.token.clone())
251 .filter(|t| !matches!(t, Token::Whitespace(_)))
252 .collect();
253
254 if filtered_tokens.is_empty() {
256 return Ok(Vec::new());
257 }
258
259 let mut result = Vec::new();
261 let mut i = 0;
262
263 while i < filtered_tokens.len() {
264 if matches!(&filtered_tokens[i], Token::SemiColon) {
266 result.push(filtered_tokens[i].clone());
267 i += 1;
268 continue;
269 }
270
271 let mut matched = false;
273 for (pattern, replacement) in &self.blacklist {
274 if pattern.is_empty() {
275 continue;
276 }
277
278 let mut j = 0;
280 let mut pattern_idx = 0;
281 while i + j < filtered_tokens.len() && pattern_idx < pattern.len() {
282 if matches!(&filtered_tokens[i + j], Token::SemiColon) {
284 j += 1;
285 continue;
286 }
287
288 match &pattern[pattern_idx] {
289 Token::Placeholder(_) => {
290 pattern_idx += 1;
292 j += 1;
293 }
294 _ => {
295 if filtered_tokens[i + j] != pattern[pattern_idx] {
296 break;
297 }
298 pattern_idx += 1;
299 j += 1;
300 }
301 }
302 }
303
304 if pattern_idx == pattern.len() {
306 result.extend(replacement.iter().cloned());
308 i += j;
310 matched = true;
311 break;
312 }
313 }
314
315 if !matched {
316 result.push(filtered_tokens[i].clone());
318 i += 1;
319 }
320 }
321
322 Ok(result)
323 }
324
325 fn parse_tokens(&self, tokens: Vec<Token>) -> Result<Vec<Statement>, ParserError> {
326 let parser = Parser::new(&PostgreSqlDialect {});
327 let tokens_with_spans: Vec<TokenWithSpan> = tokens
329 .into_iter()
330 .map(|token| TokenWithSpan {
331 token,
332 span: datafusion::sql::sqlparser::tokenizer::Span::empty(),
333 })
334 .collect();
335 parser
336 .with_tokens_with_locations(tokens_with_spans)
337 .parse_statements()
338 }
339
340 pub fn parse(&self, input: &str) -> Result<Vec<Statement>, ParserError> {
341 let tokens = self.maybe_replace_tokens(input)?;
342 let statements = self.parse_tokens(tokens)?;
343
344 let statements: Vec<_> = statements.into_iter().map(|s| self.rewrite(s)).collect();
345 Ok(statements)
346 }
347
348 pub fn rewrite(&self, mut s: Statement) -> Statement {
349 for rule in &self.rewrite_rules {
350 s = rule.rewrite(s);
351 }
352
353 s
354 }
355}
356
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` for tokens that the comparisons below ignore.
    fn is_insignificant(token: &Token) -> bool {
        matches!(token, Token::Whitespace(_) | Token::SemiColon)
    }

    /// Tokenizes `sql` with the PostgreSQL dialect and drops whitespace and
    /// semicolon tokens.
    fn significant_tokens(sql: &str) -> Vec<Token> {
        Parser::new(&PostgreSqlDialect {})
            .try_with_sql(sql)
            .expect("failed to tokenize sql")
            .into_tokens()
            .into_iter()
            .map(|t| t.token)
            .filter(|t| !is_insignificant(t))
            .collect()
    }

    /// Runs `sql` through `maybe_replace_tokens` and asserts that the result
    /// equals the tokens of `expected_sql`, ignoring whitespace and
    /// semicolons on both sides.
    fn assert_replaced(sql: &str, expected_sql: &str) {
        let parser = PostgresCompatibilityParser::new();

        let actual_tokens: Vec<Token> = parser
            .maybe_replace_tokens(sql)
            .expect("failed to parse sql")
            .into_iter()
            .filter(|t| !is_insignificant(t))
            .collect();

        assert!(!actual_tokens.is_empty());
        assert_eq!(actual_tokens, significant_tokens(expected_sql));
    }

    #[test]
    fn test_full_match() {
        // Placeholder `$1` in the pattern is matched by the literal '16384'.
        let sql = "SELECT pol.polname, pol.polpermissive,
 CASE WHEN pol.polroles = '{0}' THEN NULL ELSE pg_catalog.array_to_string(array(select rolname from pg_catalog.pg_roles where oid = any (pol.polroles) order by 1),',') END,
 pg_catalog.pg_get_expr(pol.polqual, pol.polrelid),
 pg_catalog.pg_get_expr(pol.polwithcheck, pol.polrelid),
 CASE pol.polcmd
 WHEN 'r' THEN 'SELECT'
 WHEN 'a' THEN 'INSERT'
 WHEN 'w' THEN 'UPDATE'
 WHEN 'd' THEN 'DELETE'
 END AS cmd
 FROM pg_catalog.pg_policy pol
 WHERE pol.polrelid = '16384' ORDER BY 1;";

        let expected_sql = r#"SELECT
 NULL::TEXT AS polname,
 NULL::TEXT AS polpermissive,
 NULL::TEXT AS array_to_string,
 NULL::TEXT AS pg_get_expr_1,
 NULL::TEXT AS pg_get_expr_2,
 NULL::TEXT AS cmd
 WHERE false"#;

        assert_replaced(sql, expected_sql);

        // Same query as the blacklist entry, but without the trailing `;`.
        let sql = "SELECT n.nspname schema_name,
 t.typname type_name
 FROM pg_catalog.pg_type t
 INNER JOIN pg_catalog.pg_namespace n
 ON n.oid = t.typnamespace
 WHERE ( t.typrelid = 0 -- non-composite types
 OR ( -- composite type, but not a table
 SELECT c.relkind = 'c'
 FROM pg_catalog.pg_class c
 WHERE c.oid = t.typrelid
 )
 )
 AND NOT EXISTS( -- ignore array types
 SELECT 1
 FROM pg_catalog.pg_type el
 WHERE el.oid = t.typelem AND el.typarray = t.oid
 )
 AND n.nspname <> 'pg_catalog'
 AND n.nspname <> 'information_schema'
 ORDER BY 1, 2";

        let expected_sql =
            r#"SELECT NULL::TEXT AS schema_name, NULL::TEXT AS type_name WHERE false"#;

        assert_replaced(sql, expected_sql);

        // Several placeholders in one pattern.
        let sql = "SELECT pubname
 , NULL
 , NULL
 FROM pg_catalog.pg_publication p
 JOIN pg_catalog.pg_publication_namespace pn ON p.oid = pn.pnpubid
 JOIN pg_catalog.pg_class pc ON pc.relnamespace = pn.pnnspid
 WHERE pc.oid ='16384' and pg_catalog.pg_relation_is_publishable('16384')
 UNION
 SELECT pubname
 , pg_get_expr(pr.prqual, c.oid)
 , (CASE WHEN pr.prattrs IS NOT NULL THEN
 (SELECT string_agg(attname, ', ')
 FROM pg_catalog.generate_series(0, pg_catalog.array_upper(pr.prattrs::pg_catalog.int2[], 1)) s,
 pg_catalog.pg_attribute
 WHERE attrelid = pr.prrelid AND attnum = prattrs[s])
 ELSE NULL END) FROM pg_catalog.pg_publication p
 JOIN pg_catalog.pg_publication_rel pr ON p.oid = pr.prpubid
 JOIN pg_catalog.pg_class c ON c.oid = pr.prrelid
 WHERE pr.prrelid = '16384'
 UNION
 SELECT pubname
 , NULL
 , NULL
 FROM pg_catalog.pg_publication p
 WHERE p.puballtables AND pg_catalog.pg_relation_is_publishable('16384')
 ORDER BY 1;";

        let expected_sql = r#"SELECT
 NULL::TEXT AS pubname,
 NULL::TEXT AS _1,
 NULL::TEXT AS _2
 WHERE false"#;

        assert_replaced(sql, expected_sql);
    }

    #[test]
    fn test_empty_query() {
        let parser = PostgresCompatibilityParser::new();
        let result = parser.parse(" ").expect("failed to parse sql");
        assert!(result.is_empty());

        let result = parser.parse("").expect("failed to parse sql");
        assert!(result.is_empty());

        let result = parser.parse(";").expect("failed to parse sql");
        assert!(result.is_empty());
    }

    #[test]
    fn test_no_match_is_passthrough() {
        // Input that matches no blacklist pattern must come back unchanged.
        assert_replaced("SELECT 1", "SELECT 1");
    }

    #[test]
    fn test_partial_match() {
        // The search_path sub-select occurs twice inside a larger query; both
        // occurrences must be replaced while the rest stays untouched.
        let sql = r#"SELECT
 CASE WHEN
 quote_ident(table_schema) IN (
 SELECT
 CASE WHEN trim(s[i]) = '"$user"' THEN user ELSE trim(s[i]) END
 FROM
 generate_series(
 array_lower(string_to_array(current_setting('search_path'),','),1),
 array_upper(string_to_array(current_setting('search_path'),','),1)
 ) as i,
 string_to_array(current_setting('search_path'),',') s
 )
 THEN quote_ident(table_name)
 ELSE quote_ident(table_schema) || '.' || quote_ident(table_name)
 END AS "table"
 FROM information_schema.tables
 WHERE quote_ident(table_schema) NOT IN ('information_schema',
 'pg_catalog',
 '_timescaledb_cache',
 '_timescaledb_catalog',
 '_timescaledb_internal',
 '_timescaledb_config',
 'timescaledb_information',
 'timescaledb_experimental')
 ORDER BY CASE WHEN
 quote_ident(table_schema) IN (
 SELECT
 CASE WHEN trim(s[i]) = '"$user"' THEN user ELSE trim(s[i]) END
 FROM
 generate_series(
 array_lower(string_to_array(current_setting('search_path'),','),1),
 array_upper(string_to_array(current_setting('search_path'),','),1)
 ) as i,
 string_to_array(current_setting('search_path'),',') s
 ) THEN 0 ELSE 1 END, 1"#;

        let expected_sql = r#"SELECT
 CASE WHEN
 quote_ident(table_schema) IN (
 '')
 THEN quote_ident(table_name)
 ELSE quote_ident(table_schema) || '.' || quote_ident(table_name)
 END AS "table"
 FROM information_schema.tables
 WHERE quote_ident(table_schema) NOT IN ('information_schema',
 'pg_catalog',
 '_timescaledb_cache',
 '_timescaledb_catalog',
 '_timescaledb_internal',
 '_timescaledb_config',
 'timescaledb_information',
 'timescaledb_experimental')
 ORDER BY CASE WHEN
 quote_ident(table_schema) IN (
 ''
 ) THEN 0 ELSE 1 END, 1"#;

        assert_replaced(sql, expected_sql);
    }
}