proof_of_sql_parser/
sqlparser.rs

//! This module adapts the `PoSQL` parser's intermediate AST to the AST types of the `sqlparser` crate.
2use crate::{
3    intermediate_ast::{
4        AliasedResultExpr, BinaryOperator as PoSqlBinaryOperator, Expression, Literal,
5        OrderBy as PoSqlOrderBy, OrderByDirection, SelectResultExpr, SetExpression,
6        TableExpression, UnaryOperator as PoSqlUnaryOperator,
7    },
8    Identifier, ResourceId, SelectStatement,
9};
10use alloc::{
11    boxed::Box,
12    format,
13    string::{String, ToString},
14    vec,
15};
16use core::fmt::Display;
17use sqlparser::ast::{
18    BinaryOperator, DataType, Expr, Function, FunctionArg, FunctionArgExpr, GroupByExpr, Ident,
19    ObjectName, Offset, OffsetRows, OrderByExpr, Query, Select, SelectItem, SetExpr, TableFactor,
20    TableWithJoins, TimezoneInfo, UnaryOperator, Value, WildcardAdditionalOptions,
21};
22
23/// Convert a number into a [`Expr`].
24fn number<T>(val: T) -> Expr
25where
26    T: Display,
27{
28    Expr::Value(Value::Number(val.to_string(), false))
29}
30
31/// Convert an [`Identifier`] into a [`Expr`].
32fn id(id: Identifier) -> Expr {
33    Expr::Identifier(id.into())
34}
35
36impl From<Identifier> for Ident {
37    fn from(id: Identifier) -> Self {
38        Ident::new(id.as_str())
39    }
40}
41
42impl From<ResourceId> for ObjectName {
43    fn from(id: ResourceId) -> Self {
44        ObjectName(vec![id.schema().into(), id.object_name().into()])
45    }
46}
47
48impl From<TableExpression> for TableFactor {
49    fn from(table: TableExpression) -> Self {
50        match table {
51            TableExpression::Named { table, schema } => {
52                let object_name = if let Some(schema) = schema {
53                    ObjectName(vec![schema.into(), table.into()])
54                } else {
55                    ObjectName(vec![table.into()])
56                };
57                TableFactor::Table {
58                    name: object_name,
59                    alias: None,
60                    args: None,
61                    with_hints: vec![],
62                    version: None,
63                    partitions: vec![],
64                }
65            }
66        }
67    }
68}
69
70impl From<Literal> for Expr {
71    fn from(literal: Literal) -> Self {
72        match literal {
73            Literal::VarChar(s) => Expr::Value(Value::SingleQuotedString(s)),
74            Literal::VarBinary(bytes) => {
75                // Convert binary data to hex string for SQL representation
76                let hex_string =
77                    bytes
78                        .iter()
79                        .fold(String::with_capacity(bytes.len() * 2), |mut acc, byte| {
80                            acc.push_str(&format!("{byte:02x}"));
81                            acc
82                        });
83                Expr::Value(Value::HexStringLiteral(hex_string))
84            }
85            Literal::BigInt(n) => Expr::Value(Value::Number(n.to_string(), false)),
86            Literal::Int128(n) => Expr::Value(Value::Number(n.to_string(), false)),
87            Literal::Decimal(n) => Expr::Value(Value::Number(n.to_string(), false)),
88            Literal::Boolean(b) => Expr::Value(Value::Boolean(b)),
89            Literal::Timestamp(timestamp) => {
90                // We currently exclusively store timestamps in UTC.
91                Expr::TypedString {
92                    data_type: DataType::Timestamp(
93                        Some(timestamp.timeunit().into()),
94                        TimezoneInfo::None,
95                    ),
96                    value: timestamp.timestamp().to_string(),
97                }
98            }
99        }
100    }
101}
102
103impl From<PoSqlBinaryOperator> for BinaryOperator {
104    fn from(op: PoSqlBinaryOperator) -> Self {
105        match op {
106            PoSqlBinaryOperator::And => BinaryOperator::And,
107            PoSqlBinaryOperator::Or => BinaryOperator::Or,
108            PoSqlBinaryOperator::Equal => BinaryOperator::Eq,
109            PoSqlBinaryOperator::LessThan => BinaryOperator::Lt,
110            PoSqlBinaryOperator::GreaterThan => BinaryOperator::Gt,
111            PoSqlBinaryOperator::Add => BinaryOperator::Plus,
112            PoSqlBinaryOperator::Subtract => BinaryOperator::Minus,
113            PoSqlBinaryOperator::Multiply => BinaryOperator::Multiply,
114            PoSqlBinaryOperator::Division => BinaryOperator::Divide,
115        }
116    }
117}
118
119impl From<PoSqlUnaryOperator> for UnaryOperator {
120    fn from(op: PoSqlUnaryOperator) -> Self {
121        match op {
122            PoSqlUnaryOperator::Not => UnaryOperator::Not,
123        }
124    }
125}
126
127impl From<PoSqlOrderBy> for OrderByExpr {
128    fn from(order_by: PoSqlOrderBy) -> Self {
129        let asc = match order_by.direction {
130            OrderByDirection::Asc => Some(true),
131            OrderByDirection::Desc => Some(false),
132        };
133        OrderByExpr {
134            expr: id(order_by.expr),
135            asc,
136            nulls_first: None,
137        }
138    }
139}
140
141impl From<Expression> for Expr {
142    fn from(expr: Expression) -> Self {
143        match expr {
144            Expression::Literal(literal) => literal.into(),
145            Expression::Column(identifier) => id(identifier),
146            Expression::Unary { op, expr } => Expr::UnaryOp {
147                op: op.into(),
148                expr: Box::new((*expr).into()),
149            },
150            Expression::Binary { op, left, right } => Expr::BinaryOp {
151                left: Box::new((*left).into()),
152                op: op.into(),
153                right: Box::new((*right).into()),
154            },
155            Expression::Wildcard => Expr::Wildcard,
156            Expression::Aggregation { op, expr } => Expr::Function(Function {
157                name: ObjectName(vec![Ident::new(op.to_string())]),
158                args: vec![FunctionArg::Unnamed((*expr).into())],
159                filter: None,
160                null_treatment: None,
161                over: None,
162                distinct: false,
163                special: false,
164                order_by: vec![],
165            }),
166        }
167    }
168}
169
170// Note that sqlparser singles out `Wildcard` as a separate case, so we have to handle it separately.
171impl From<Expression> for FunctionArgExpr {
172    fn from(expr: Expression) -> Self {
173        match expr {
174            Expression::Wildcard => FunctionArgExpr::Wildcard,
175            _ => FunctionArgExpr::Expr(expr.into()),
176        }
177    }
178}
179
180impl From<SelectResultExpr> for SelectItem {
181    fn from(select: SelectResultExpr) -> Self {
182        match select {
183            SelectResultExpr::ALL => SelectItem::Wildcard(WildcardAdditionalOptions {
184                opt_exclude: None,
185                opt_except: None,
186                opt_rename: None,
187                opt_replace: None,
188            }),
189            SelectResultExpr::AliasedResultExpr(AliasedResultExpr { expr, alias }) => {
190                SelectItem::ExprWithAlias {
191                    expr: (*expr).into(),
192                    alias: alias.into(),
193                }
194            }
195        }
196    }
197}
198
199impl From<SetExpression> for Select {
200    fn from(select: SetExpression) -> Self {
201        match select {
202            SetExpression::Query {
203                result_exprs,
204                from,
205                where_expr,
206                group_by,
207            } => Select {
208                distinct: None,
209                top: None,
210                projection: result_exprs.into_iter().map(SelectItem::from).collect(),
211                into: None,
212                from: from
213                    .into_iter()
214                    .map(|table_expression| TableWithJoins {
215                        relation: (*table_expression).into(),
216                        joins: vec![],
217                    })
218                    .collect(),
219                lateral_views: vec![],
220                selection: where_expr.map(|expr| (*expr).into()),
221                group_by: GroupByExpr::Expressions(group_by.into_iter().map(id).collect()),
222                cluster_by: vec![],
223                distribute_by: vec![],
224                sort_by: vec![],
225                having: None,
226                named_window: vec![],
227                qualify: None,
228                value_table_mode: None,
229            },
230        }
231    }
232}
233
234impl From<SelectStatement> for Query {
235    fn from(select: SelectStatement) -> Self {
236        Query {
237            with: None,
238            body: Box::new(SetExpr::Select(Box::new((*select.expr).into()))),
239            order_by: select.order_by.into_iter().map(OrderByExpr::from).collect(),
240            limit: select.slice.clone().map(|slice| number(slice.number_rows)),
241            limit_by: vec![],
242            offset: select.slice.map(|slice| Offset {
243                value: number(slice.offset_value),
244                rows: OffsetRows::None,
245            }),
246            fetch: None,
247            locks: vec![],
248            for_clause: None,
249        }
250    }
251}
252
#[cfg(test)]
mod test {
    use super::*;
    use sqlparser::{ast::Statement, dialect::PostgreSqlDialect, parser::Parser};

    // Parse `posql_sql` with the `PoSQL` parser, convert the result to the sqlparser AST, and
    // assert that it equals the first statement sqlparser produces for `sqlparser_sql`.
    // The two strings may differ because the `PoSQL` parser has some quirks:
    // - If LIMIT is specified, OFFSET must also be specified, so `OFFSET 0` has to be appended.
    // - Explicit aliases are mandatory for all columns.
    // In those cases the sqlparser query is an equivalent one that spells these out.
    fn check_posql_intermediate_ast_to_sqlparser_equivalence(posql_sql: &str, sqlparser_sql: &str) {
        let parsed_by_posql = posql_sql.parse::<SelectStatement>().unwrap();
        let converted = Statement::Query(Box::new(Query::from(parsed_by_posql)));
        let parsed_by_sqlparser = Parser::parse_sql(&PostgreSqlDialect {}, sqlparser_sql).unwrap();
        assert_eq!(&converted, &parsed_by_sqlparser[0]);
    }

    #[test]
    fn we_can_convert_posql_intermediate_ast_to_sqlparser_with_slight_modification() {
        let cases = [
            (
                "select a, b from t limit 10;",
                "select a as a, b as b from t limit 10 offset 0;",
            ),
            (
                "select timestamp '2024-11-07T04:55:12+00:00' as time from t;",
                "select timestamp(0) '2024-11-07 04:55:12 UTC' as time from t;",
            ),
            (
                "select timestamp '2024-11-07T04:55:12.345+03:00' as time from t;",
                "select timestamp(3) '2024-11-07 01:55:12.345 UTC' as time from t;",
            ),
        ];
        for (posql_sql, sqlparser_sql) in cases {
            check_posql_intermediate_ast_to_sqlparser_equivalence(posql_sql, sqlparser_sql);
        }
    }

    // Stricter variant of the check above: the very same SQL string is fed to both parsers.
    // Queries that hit a `PoSQL` quirk cannot use this and must go through
    // `check_posql_intermediate_ast_to_sqlparser_equivalence` instead.
    fn check_posql_intermediate_ast_to_sqlparser_equality(sql: &str) {
        check_posql_intermediate_ast_to_sqlparser_equivalence(sql, sql);
    }

    #[test]
    fn we_can_convert_posql_intermediate_ast_to_sqlparser() {
        let queries = [
            "SELECT * FROM t",
            "select a as a, 4.7 * b as b from namespace.table where c = 2.5;",
            "select a as a, b as b from namespace.table where c = 4;",
            "select a as a, b as b from namespace.table where c = 4 order by a desc;",
            "select 1 as a, 'Meow' as d, b as b from namespace.table where c = 4 order by a desc limit 10 offset 0;",
            "select true as cons, a and b or c > 4 as comp from tab where d = 'Space and Time';",
            "select cat as cat, true as cons, max(meow) as max_meow from tab where d = 'Space and Time' group by cat;",
            "select cat as cat, sum(a) as s, count(*) as rows from tab where d = 'Space and Time' group by cat;",
        ];
        for sql in queries {
            check_posql_intermediate_ast_to_sqlparser_equality(sql);
        }
    }
}