Skip to main content

citadel_sql/
dialect.rs

1use sqlparser::ast::{BinaryOperator, Expr as SpExpr, Statement as SpStatement};
2use sqlparser::dialect::{Dialect, GenericDialect, PostgreSqlDialect, Precedence};
3use sqlparser::parser::{Parser, ParserError};
4use sqlparser::tokenizer::Token;
5
6/// PG-superset dialect: delegates every method to `PostgreSqlDialect` and adds the
7/// `@?_tz` / `@@_tz` infix operators (tokenized as `[AtQuestion|AtAt, Word("_tz")]`,
8/// stitched via `parse_infix` into `BinaryOperator::Custom`).
9#[derive(Debug)]
10pub struct CitadelDialect {
11    inner: PostgreSqlDialect,
12}
13
14impl CitadelDialect {
15    pub fn new() -> Self {
16        Self {
17            inner: PostgreSqlDialect {},
18        }
19    }
20}
21
22impl Default for CitadelDialect {
23    fn default() -> Self {
24        Self::new()
25    }
26}
27
28impl Dialect for CitadelDialect {
29    fn dialect(&self) -> std::any::TypeId {
30        self.inner.dialect()
31    }
32
33    fn identifier_quote_style(&self, identifier: &str) -> Option<char> {
34        self.inner.identifier_quote_style(identifier)
35    }
36    fn is_identifier_start(&self, ch: char) -> bool {
37        self.inner.is_identifier_start(ch)
38    }
39    fn is_identifier_part(&self, ch: char) -> bool {
40        self.inner.is_identifier_part(ch)
41    }
42    fn is_delimited_identifier_start(&self, ch: char) -> bool {
43        self.inner.is_delimited_identifier_start(ch)
44    }
45
46    fn is_custom_operator_part(&self, ch: char) -> bool {
47        self.inner.is_custom_operator_part(ch)
48    }
49    fn supports_unicode_string_literal(&self) -> bool {
50        self.inner.supports_unicode_string_literal()
51    }
52    fn supports_string_literal_backslash_escape(&self) -> bool {
53        self.inner.supports_string_literal_backslash_escape()
54    }
55    fn supports_string_escape_constant(&self) -> bool {
56        self.inner.supports_string_escape_constant()
57    }
58    fn supports_numeric_literal_underscores(&self) -> bool {
59        self.inner.supports_numeric_literal_underscores()
60    }
61    fn supports_nested_comments(&self) -> bool {
62        self.inner.supports_nested_comments()
63    }
64    fn supports_factorial_operator(&self) -> bool {
65        self.inner.supports_factorial_operator()
66    }
67    fn supports_bitwise_shift_operators(&self) -> bool {
68        self.inner.supports_bitwise_shift_operators()
69    }
70    fn supports_geometric_types(&self) -> bool {
71        self.inner.supports_geometric_types()
72    }
73
74    fn get_next_precedence(&self, parser: &Parser) -> Option<Result<u8, ParserError>> {
75        self.inner.get_next_precedence(parser)
76    }
77    fn prec_value(&self, prec: Precedence) -> u8 {
78        self.inner.prec_value(prec)
79    }
80
81    fn supports_filter_during_aggregation(&self) -> bool {
82        self.inner.supports_filter_during_aggregation()
83    }
84    fn supports_within_after_array_aggregation(&self) -> bool {
85        self.inner.supports_within_after_array_aggregation()
86    }
87    fn supports_group_by_expr(&self) -> bool {
88        self.inner.supports_group_by_expr()
89    }
90    fn supports_named_fn_args_with_eq_operator(&self) -> bool {
91        self.inner.supports_named_fn_args_with_eq_operator()
92    }
93    fn supports_named_fn_args_with_assignment_operator(&self) -> bool {
94        self.inner.supports_named_fn_args_with_assignment_operator()
95    }
96    fn supports_named_fn_args_with_rarrow_operator(&self) -> bool {
97        self.inner.supports_named_fn_args_with_rarrow_operator()
98    }
99    fn supports_named_fn_args_with_colon_operator(&self) -> bool {
100        self.inner.supports_named_fn_args_with_colon_operator()
101    }
102    fn supports_named_fn_args_with_expr_name(&self) -> bool {
103        self.inner.supports_named_fn_args_with_expr_name()
104    }
105    fn supports_window_function_null_treatment_arg(&self) -> bool {
106        self.inner.supports_window_function_null_treatment_arg()
107    }
108    fn supports_dictionary_syntax(&self) -> bool {
109        self.inner.supports_dictionary_syntax()
110    }
111    fn supports_lambda_functions(&self) -> bool {
112        self.inner.supports_lambda_functions()
113    }
114
115    fn supports_in_empty_list(&self) -> bool {
116        self.inner.supports_in_empty_list()
117    }
118    fn supports_start_transaction_modifier(&self) -> bool {
119        self.inner.supports_start_transaction_modifier()
120    }
121    fn supports_parenthesized_set_variables(&self) -> bool {
122        self.inner.supports_parenthesized_set_variables()
123    }
124    fn supports_select_wildcard_except(&self) -> bool {
125        self.inner.supports_select_wildcard_except()
126    }
127    fn supports_empty_projections(&self) -> bool {
128        self.inner.supports_empty_projections()
129    }
130    fn convert_type_before_value(&self) -> bool {
131        self.inner.convert_type_before_value()
132    }
133    fn supports_triple_quoted_string(&self) -> bool {
134        self.inner.supports_triple_quoted_string()
135    }
136    fn supports_array_typedef_with_brackets(&self) -> bool {
137        self.inner.supports_array_typedef_with_brackets()
138    }
139    fn supports_create_index_with_clause(&self) -> bool {
140        self.inner.supports_create_index_with_clause()
141    }
142    fn supports_explain_with_utility_options(&self) -> bool {
143        self.inner.supports_explain_with_utility_options()
144    }
145    fn supports_listen_notify(&self) -> bool {
146        self.inner.supports_listen_notify()
147    }
148    fn supports_comment_on(&self) -> bool {
149        self.inner.supports_comment_on()
150    }
151    fn supports_load_extension(&self) -> bool {
152        self.inner.supports_load_extension()
153    }
154    fn supports_set_names(&self) -> bool {
155        self.inner.supports_set_names()
156    }
157    fn supports_alter_column_type_using(&self) -> bool {
158        self.inner.supports_alter_column_type_using()
159    }
160    fn supports_notnull_operator(&self) -> bool {
161        self.inner.supports_notnull_operator()
162    }
163    fn supports_interval_options(&self) -> bool {
164        self.inner.supports_interval_options()
165    }
166    fn allow_extract_custom(&self) -> bool {
167        self.inner.allow_extract_custom()
168    }
169    fn allow_extract_single_quotes(&self) -> bool {
170        self.inner.allow_extract_single_quotes()
171    }
172
173    /// Without this hook sqlparser sees `@?` followed by identifier `_tz` and parse-errors.
174    fn parse_infix(
175        &self,
176        parser: &mut Parser,
177        expr: &SpExpr,
178        precedence: u8,
179    ) -> Option<Result<SpExpr, ParserError>> {
180        let next = parser.peek_token().token;
181        let custom_op = match next {
182            Token::AtQuestion => "@?_tz",
183            Token::AtAt => "@@_tz",
184            _ => return self.inner.parse_infix(parser, expr, precedence),
185        };
186        let after = parser.peek_nth_token(1).token;
187        let Token::Word(w) = after else {
188            return self.inner.parse_infix(parser, expr, precedence);
189        };
190        if !w.value.eq_ignore_ascii_case("_tz") {
191            return self.inner.parse_infix(parser, expr, precedence);
192        }
193        parser.advance_token();
194        parser.advance_token();
195        let right = match parser.parse_subexpr(precedence) {
196            Ok(r) => r,
197            Err(e) => return Some(Err(e)),
198        };
199        Some(Ok(SpExpr::BinaryOp {
200            left: Box::new(expr.clone()),
201            op: BinaryOperator::Custom(custom_op.to_string()),
202            right: Box::new(right),
203        }))
204    }
205}
206
207/// Falls back to `GenericDialect` only for SQLite-style quirks PG rejects (e.g. `TRIM(s, c)`).
208pub fn parse_statements(sql: &str) -> Result<Vec<SpStatement>, ParserError> {
209    parse_with_fallback(sql, Parser::parse_sql)
210}
211
212pub fn parse_expr(sql: &str) -> Result<SpExpr, ParserError> {
213    parse_with_fallback(sql, |dialect, sql| {
214        Parser::new(dialect).try_with_sql(sql)?.parse_expr()
215    })
216}
217
218fn parse_with_fallback<T, F>(sql: &str, parse_fn: F) -> Result<T, ParserError>
219where
220    F: Fn(&dyn Dialect, &str) -> Result<T, ParserError>,
221{
222    let citadel = CitadelDialect::new();
223    match parse_fn(&citadel, sql) {
224        Ok(r) => Ok(r),
225        Err(pg_err) => {
226            let generic = GenericDialect {};
227            parse_fn(&generic, sql).map_err(|_| pg_err)
228        }
229    }
230}
231
232#[cfg(test)]
233#[path = "dialect_tests.rs"]
234mod tests;