proof_of_sql_parser/
sqlparser.rs

//! This module adapts the PoSQL intermediate AST to `sqlparser` by implementing `From` conversions into `sqlparser`'s AST types.
2use crate::{
3    intermediate_ast::{
4        AliasedResultExpr, BinaryOperator as PoSqlBinaryOperator, Expression, Literal,
5        OrderBy as PoSqlOrderBy, OrderByDirection, SelectResultExpr, SetExpression,
6        TableExpression, UnaryOperator as PoSqlUnaryOperator,
7    },
8    Identifier, ResourceId, SelectStatement,
9};
10use alloc::{
11    boxed::Box,
12    string::{String, ToString},
13    vec,
14};
15use core::fmt::{Display, Write};
16use sqlparser::ast::{
17    BinaryOperator, DataType, Expr, Function, FunctionArg, FunctionArgExpr, GroupByExpr, Ident,
18    ObjectName, Offset, OffsetRows, OrderByExpr, Query, Select, SelectItem, SetExpr, TableFactor,
19    TableWithJoins, TimezoneInfo, UnaryOperator, Value, WildcardAdditionalOptions,
20};
21
22/// Convert a number into a [`Expr`].
23fn number<T>(val: T) -> Expr
24where
25    T: Display,
26{
27    Expr::Value(Value::Number(val.to_string(), false))
28}
29
30/// Convert an [`Identifier`] into a [`Expr`].
31fn id(id: Identifier) -> Expr {
32    Expr::Identifier(id.into())
33}
34
35impl From<Identifier> for Ident {
36    fn from(id: Identifier) -> Self {
37        Ident::new(id.as_str())
38    }
39}
40
41impl From<ResourceId> for ObjectName {
42    fn from(id: ResourceId) -> Self {
43        ObjectName(vec![id.schema().into(), id.object_name().into()])
44    }
45}
46
47impl From<TableExpression> for TableFactor {
48    fn from(table: TableExpression) -> Self {
49        match table {
50            TableExpression::Named { table, schema } => {
51                let object_name = if let Some(schema) = schema {
52                    ObjectName(vec![schema.into(), table.into()])
53                } else {
54                    ObjectName(vec![table.into()])
55                };
56                TableFactor::Table {
57                    name: object_name,
58                    alias: None,
59                    args: None,
60                    with_hints: vec![],
61                    version: None,
62                    partitions: vec![],
63                }
64            }
65        }
66    }
67}
68
69impl From<Literal> for Expr {
70    fn from(literal: Literal) -> Self {
71        match literal {
72            Literal::VarChar(s) => Expr::Value(Value::SingleQuotedString(s)),
73            Literal::VarBinary(bytes) => {
74                // Convert binary data to hex string for SQL representation
75                let hex_string =
76                    bytes
77                        .iter()
78                        .fold(String::with_capacity(bytes.len() * 2), |mut acc, byte| {
79                            write!(&mut acc, "{byte:02x}").unwrap();
80                            acc
81                        });
82                Expr::Value(Value::HexStringLiteral(hex_string))
83            }
84            Literal::BigInt(n) => Expr::Value(Value::Number(n.to_string(), false)),
85            Literal::Int128(n) => Expr::Value(Value::Number(n.to_string(), false)),
86            Literal::Decimal(n) => Expr::Value(Value::Number(n.to_string(), false)),
87            Literal::Boolean(b) => Expr::Value(Value::Boolean(b)),
88            Literal::Timestamp(timestamp) => {
89                // We currently exclusively store timestamps in UTC.
90                Expr::TypedString {
91                    data_type: DataType::Timestamp(
92                        Some(timestamp.timeunit().into()),
93                        TimezoneInfo::None,
94                    ),
95                    value: timestamp.timestamp().to_string(),
96                }
97            }
98        }
99    }
100}
101
102impl From<PoSqlBinaryOperator> for BinaryOperator {
103    fn from(op: PoSqlBinaryOperator) -> Self {
104        match op {
105            PoSqlBinaryOperator::And => BinaryOperator::And,
106            PoSqlBinaryOperator::Or => BinaryOperator::Or,
107            PoSqlBinaryOperator::Equal => BinaryOperator::Eq,
108            PoSqlBinaryOperator::LessThan => BinaryOperator::Lt,
109            PoSqlBinaryOperator::GreaterThan => BinaryOperator::Gt,
110            PoSqlBinaryOperator::Add => BinaryOperator::Plus,
111            PoSqlBinaryOperator::Subtract => BinaryOperator::Minus,
112            PoSqlBinaryOperator::Multiply => BinaryOperator::Multiply,
113            PoSqlBinaryOperator::Division => BinaryOperator::Divide,
114        }
115    }
116}
117
118impl From<PoSqlUnaryOperator> for UnaryOperator {
119    fn from(op: PoSqlUnaryOperator) -> Self {
120        match op {
121            PoSqlUnaryOperator::Not => UnaryOperator::Not,
122        }
123    }
124}
125
126impl From<PoSqlOrderBy> for OrderByExpr {
127    fn from(order_by: PoSqlOrderBy) -> Self {
128        let asc = match order_by.direction {
129            OrderByDirection::Asc => Some(true),
130            OrderByDirection::Desc => Some(false),
131        };
132        OrderByExpr {
133            expr: id(order_by.expr),
134            asc,
135            nulls_first: None,
136        }
137    }
138}
139
140impl From<Expression> for Expr {
141    fn from(expr: Expression) -> Self {
142        match expr {
143            Expression::Literal(literal) => literal.into(),
144            Expression::Column(identifier) => id(identifier),
145            Expression::Unary { op, expr } => Expr::UnaryOp {
146                op: op.into(),
147                expr: Box::new((*expr).into()),
148            },
149            Expression::Binary { op, left, right } => Expr::BinaryOp {
150                left: Box::new((*left).into()),
151                op: op.into(),
152                right: Box::new((*right).into()),
153            },
154            Expression::Wildcard => Expr::Wildcard,
155            Expression::Aggregation { op, expr } => Expr::Function(Function {
156                name: ObjectName(vec![Ident::new(op.to_string())]),
157                args: vec![FunctionArg::Unnamed((*expr).into())],
158                filter: None,
159                null_treatment: None,
160                over: None,
161                distinct: false,
162                special: false,
163                order_by: vec![],
164            }),
165        }
166    }
167}
168
169// Note that sqlparser singles out `Wildcard` as a separate case, so we have to handle it separately.
170impl From<Expression> for FunctionArgExpr {
171    fn from(expr: Expression) -> Self {
172        match expr {
173            Expression::Wildcard => FunctionArgExpr::Wildcard,
174            _ => FunctionArgExpr::Expr(expr.into()),
175        }
176    }
177}
178
179impl From<SelectResultExpr> for SelectItem {
180    fn from(select: SelectResultExpr) -> Self {
181        match select {
182            SelectResultExpr::ALL => SelectItem::Wildcard(WildcardAdditionalOptions {
183                opt_exclude: None,
184                opt_except: None,
185                opt_rename: None,
186                opt_replace: None,
187            }),
188            SelectResultExpr::AliasedResultExpr(AliasedResultExpr { expr, alias }) => {
189                SelectItem::ExprWithAlias {
190                    expr: (*expr).into(),
191                    alias: alias.into(),
192                }
193            }
194        }
195    }
196}
197
198impl From<SetExpression> for Select {
199    fn from(select: SetExpression) -> Self {
200        match select {
201            SetExpression::Query {
202                result_exprs,
203                from,
204                where_expr,
205                group_by,
206            } => Select {
207                distinct: None,
208                top: None,
209                projection: result_exprs.into_iter().map(SelectItem::from).collect(),
210                into: None,
211                from: from
212                    .into_iter()
213                    .map(|table_expression| TableWithJoins {
214                        relation: (*table_expression).into(),
215                        joins: vec![],
216                    })
217                    .collect(),
218                lateral_views: vec![],
219                selection: where_expr.map(|expr| (*expr).into()),
220                group_by: GroupByExpr::Expressions(group_by.into_iter().map(id).collect()),
221                cluster_by: vec![],
222                distribute_by: vec![],
223                sort_by: vec![],
224                having: None,
225                named_window: vec![],
226                qualify: None,
227                value_table_mode: None,
228            },
229        }
230    }
231}
232
233impl From<SelectStatement> for Query {
234    fn from(select: SelectStatement) -> Self {
235        Query {
236            with: None,
237            body: Box::new(SetExpr::Select(Box::new((*select.expr).into()))),
238            order_by: select.order_by.into_iter().map(OrderByExpr::from).collect(),
239            limit: select.slice.clone().map(|slice| number(slice.number_rows)),
240            limit_by: vec![],
241            offset: select.slice.map(|slice| Offset {
242                value: number(slice.offset_value),
243                rows: OffsetRows::None,
244            }),
245            fetch: None,
246            locks: vec![],
247            for_clause: None,
248        }
249    }
250}
251
#[cfg(test)]
mod test {
    use super::*;
    use sqlparser::{ast::Statement, dialect::PostgreSqlDialect, parser::Parser};

    /// Assert that parsing `posql_sql` into the PoSQL intermediate AST and converting it yields
    /// the same `sqlparser` statement as parsing `sqlparser_sql` directly (PostgreSQL dialect).
    ///
    /// The two strings may differ to absorb quirks of the `PoSQL` parser:
    /// - When `LIMIT` is given, the converted query always carries an `OFFSET` too, so the
    ///   `sqlparser` side must spell out `OFFSET 0`.
    /// - Every converted column carries an explicit alias (a bare `a` becomes `a AS a`), so the
    ///   `sqlparser` side must write the aliases out.
    ///
    /// In such cases `sqlparser_sql` is an equivalent query that clearly demonstrates the same
    /// functionality.
    fn check_posql_intermediate_ast_to_sqlparser_equivalence(posql_sql: &str, sqlparser_sql: &str) {
        let intermediate = posql_sql.parse::<SelectStatement>().unwrap();
        let converted = Statement::Query(Box::new(Query::from(intermediate)));
        let parsed = Parser::parse_sql(&PostgreSqlDialect {}, sqlparser_sql).unwrap();
        assert_eq!(&converted, &parsed[0]);
    }

    #[test]
    fn we_can_convert_posql_intermediate_ast_to_sqlparser_with_slight_modification() {
        let cases = [
            (
                "select a, b from t limit 10;",
                "select a as a, b as b from t limit 10 offset 0;",
            ),
            (
                "select timestamp '2024-11-07T04:55:12+00:00' as time from t;",
                "select timestamp(0) '2024-11-07 04:55:12 UTC' as time from t;",
            ),
            (
                "select timestamp '2024-11-07T04:55:12.345+03:00' as time from t;",
                "select timestamp(3) '2024-11-07 01:55:12.345 UTC' as time from t;",
            ),
        ];
        for (posql_sql, sqlparser_sql) in cases {
            check_posql_intermediate_ast_to_sqlparser_equivalence(posql_sql, sqlparser_sql);
        }
    }

    /// Stricter form of the equivalence check: the same SQL text must round-trip unchanged.
    /// Queries affected by `PoSQL` quirks must use
    /// `check_posql_intermediate_ast_to_sqlparser_equivalence` instead.
    fn check_posql_intermediate_ast_to_sqlparser_equality(sql: &str) {
        check_posql_intermediate_ast_to_sqlparser_equivalence(sql, sql);
    }

    #[test]
    fn we_can_convert_posql_intermediate_ast_to_sqlparser() {
        let queries = [
            "SELECT * FROM t",
            "select a as a, 4.7 * b as b from namespace.table where c = 2.5;",
            "select a as a, b as b from namespace.table where c = 4;",
            "select a as a, b as b from namespace.table where c = 4 order by a desc;",
            "select 1 as a, 'Meow' as d, b as b from namespace.table where c = 4 order by a desc limit 10 offset 0;",
            "select true as cons, a and b or c > 4 as comp from tab where d = 'Space and Time';",
            "select cat as cat, true as cons, max(meow) as max_meow from tab where d = 'Space and Time' group by cat;",
            "select cat as cat, sum(a) as s, count(*) as rows from tab where d = 'Space and Time' group by cat;",
        ];
        for sql in queries {
            check_posql_intermediate_ast_to_sqlparser_equality(sql);
        }
    }
}