proof_of_sql_parser/
sqlparser.rs

1//! This module exists to adapt the current parser to `sqlparser`.
2use crate::{
3    intermediate_ast::{
4        AliasedResultExpr, BinaryOperator as PoSqlBinaryOperator, Expression, Literal,
5        OrderBy as PoSqlOrderBy, OrderByDirection, SelectResultExpr, SetExpression,
6        TableExpression, UnaryOperator as PoSqlUnaryOperator,
7    },
8    Identifier, ResourceId, SelectStatement,
9};
10use alloc::{boxed::Box, string::ToString, vec};
11use core::fmt::Display;
12use sqlparser::ast::{
13    BinaryOperator, DataType, Expr, Function, FunctionArg, FunctionArgExpr, GroupByExpr, Ident,
14    ObjectName, Offset, OffsetRows, OrderByExpr, Query, Select, SelectItem, SetExpr, TableFactor,
15    TableWithJoins, TimezoneInfo, UnaryOperator, Value, WildcardAdditionalOptions,
16};
17
18/// Convert a number into a [`Expr`].
19fn number<T>(val: T) -> Expr
20where
21    T: Display,
22{
23    Expr::Value(Value::Number(val.to_string(), false))
24}
25
26/// Convert an [`Identifier`] into a [`Expr`].
27fn id(id: Identifier) -> Expr {
28    Expr::Identifier(id.into())
29}
30
31impl From<Identifier> for Ident {
32    fn from(id: Identifier) -> Self {
33        Ident::new(id.as_str())
34    }
35}
36
37impl From<ResourceId> for ObjectName {
38    fn from(id: ResourceId) -> Self {
39        ObjectName(vec![id.schema().into(), id.object_name().into()])
40    }
41}
42
43impl From<TableExpression> for TableFactor {
44    fn from(table: TableExpression) -> Self {
45        match table {
46            TableExpression::Named { table, schema } => {
47                let object_name = if let Some(schema) = schema {
48                    ObjectName(vec![schema.into(), table.into()])
49                } else {
50                    ObjectName(vec![table.into()])
51                };
52                TableFactor::Table {
53                    name: object_name,
54                    alias: None,
55                    args: None,
56                    with_hints: vec![],
57                    version: None,
58                    partitions: vec![],
59                }
60            }
61        }
62    }
63}
64
65impl From<Literal> for Expr {
66    fn from(literal: Literal) -> Self {
67        match literal {
68            Literal::VarChar(s) => Expr::Value(Value::SingleQuotedString(s)),
69            Literal::BigInt(n) => Expr::Value(Value::Number(n.to_string(), false)),
70            Literal::Int128(n) => Expr::Value(Value::Number(n.to_string(), false)),
71            Literal::Decimal(n) => Expr::Value(Value::Number(n.to_string(), false)),
72            Literal::Boolean(b) => Expr::Value(Value::Boolean(b)),
73            Literal::Timestamp(timestamp) => {
74                // We currently exclusively store timestamps in UTC.
75                Expr::TypedString {
76                    data_type: DataType::Timestamp(
77                        Some(timestamp.timeunit().into()),
78                        TimezoneInfo::None,
79                    ),
80                    value: timestamp.timestamp().to_string(),
81                }
82            }
83        }
84    }
85}
86
87impl From<PoSqlBinaryOperator> for BinaryOperator {
88    fn from(op: PoSqlBinaryOperator) -> Self {
89        match op {
90            PoSqlBinaryOperator::And => BinaryOperator::And,
91            PoSqlBinaryOperator::Or => BinaryOperator::Or,
92            PoSqlBinaryOperator::Equal => BinaryOperator::Eq,
93            PoSqlBinaryOperator::LessThan => BinaryOperator::Lt,
94            PoSqlBinaryOperator::GreaterThan => BinaryOperator::Gt,
95            PoSqlBinaryOperator::Add => BinaryOperator::Plus,
96            PoSqlBinaryOperator::Subtract => BinaryOperator::Minus,
97            PoSqlBinaryOperator::Multiply => BinaryOperator::Multiply,
98            PoSqlBinaryOperator::Division => BinaryOperator::Divide,
99        }
100    }
101}
102
103impl From<PoSqlUnaryOperator> for UnaryOperator {
104    fn from(op: PoSqlUnaryOperator) -> Self {
105        match op {
106            PoSqlUnaryOperator::Not => UnaryOperator::Not,
107        }
108    }
109}
110
111impl From<PoSqlOrderBy> for OrderByExpr {
112    fn from(order_by: PoSqlOrderBy) -> Self {
113        let asc = match order_by.direction {
114            OrderByDirection::Asc => Some(true),
115            OrderByDirection::Desc => Some(false),
116        };
117        OrderByExpr {
118            expr: id(order_by.expr),
119            asc,
120            nulls_first: None,
121        }
122    }
123}
124
125impl From<Expression> for Expr {
126    fn from(expr: Expression) -> Self {
127        match expr {
128            Expression::Literal(literal) => literal.into(),
129            Expression::Column(identifier) => id(identifier),
130            Expression::Unary { op, expr } => Expr::UnaryOp {
131                op: op.into(),
132                expr: Box::new((*expr).into()),
133            },
134            Expression::Binary { op, left, right } => Expr::BinaryOp {
135                left: Box::new((*left).into()),
136                op: op.into(),
137                right: Box::new((*right).into()),
138            },
139            Expression::Wildcard => Expr::Wildcard,
140            Expression::Aggregation { op, expr } => Expr::Function(Function {
141                name: ObjectName(vec![Ident::new(op.to_string())]),
142                args: vec![FunctionArg::Unnamed((*expr).into())],
143                filter: None,
144                null_treatment: None,
145                over: None,
146                distinct: false,
147                special: false,
148                order_by: vec![],
149            }),
150        }
151    }
152}
153
154// Note that sqlparser singles out `Wildcard` as a separate case, so we have to handle it separately.
155impl From<Expression> for FunctionArgExpr {
156    fn from(expr: Expression) -> Self {
157        match expr {
158            Expression::Wildcard => FunctionArgExpr::Wildcard,
159            _ => FunctionArgExpr::Expr(expr.into()),
160        }
161    }
162}
163
164impl From<SelectResultExpr> for SelectItem {
165    fn from(select: SelectResultExpr) -> Self {
166        match select {
167            SelectResultExpr::ALL => SelectItem::Wildcard(WildcardAdditionalOptions {
168                opt_exclude: None,
169                opt_except: None,
170                opt_rename: None,
171                opt_replace: None,
172            }),
173            SelectResultExpr::AliasedResultExpr(AliasedResultExpr { expr, alias }) => {
174                SelectItem::ExprWithAlias {
175                    expr: (*expr).into(),
176                    alias: alias.into(),
177                }
178            }
179        }
180    }
181}
182
183impl From<SetExpression> for Select {
184    fn from(select: SetExpression) -> Self {
185        match select {
186            SetExpression::Query {
187                result_exprs,
188                from,
189                where_expr,
190                group_by,
191            } => Select {
192                distinct: None,
193                top: None,
194                projection: result_exprs.into_iter().map(SelectItem::from).collect(),
195                into: None,
196                from: from
197                    .into_iter()
198                    .map(|table_expression| TableWithJoins {
199                        relation: (*table_expression).into(),
200                        joins: vec![],
201                    })
202                    .collect(),
203                lateral_views: vec![],
204                selection: where_expr.map(|expr| (*expr).into()),
205                group_by: GroupByExpr::Expressions(group_by.into_iter().map(id).collect()),
206                cluster_by: vec![],
207                distribute_by: vec![],
208                sort_by: vec![],
209                having: None,
210                named_window: vec![],
211                qualify: None,
212                value_table_mode: None,
213            },
214        }
215    }
216}
217
218impl From<SelectStatement> for Query {
219    fn from(select: SelectStatement) -> Self {
220        Query {
221            with: None,
222            body: Box::new(SetExpr::Select(Box::new((*select.expr).into()))),
223            order_by: select.order_by.into_iter().map(OrderByExpr::from).collect(),
224            limit: select.slice.clone().map(|slice| number(slice.number_rows)),
225            limit_by: vec![],
226            offset: select.slice.map(|slice| Offset {
227                value: number(slice.offset_value),
228                rows: OffsetRows::None,
229            }),
230            fetch: None,
231            locks: vec![],
232            for_clause: None,
233        }
234    }
235}
236
#[cfg(test)]
mod test {
    use super::*;
    use sqlparser::{ast::Statement, dialect::PostgreSqlDialect, parser::Parser};

    /// Asserts that converting the PoSQL intermediate AST for `posql_sql` yields the same
    /// `sqlparser` statement as parsing `sqlparser_sql` directly with the PostgreSQL dialect.
    ///
    /// The two strings may differ to absorb PoSQL quirks:
    /// - whenever `LIMIT` is present the conversion also emits an `OFFSET`, so the `sqlparser`
    ///   side needs an explicit `OFFSET 0`;
    /// - every projected column carries an explicit alias in the converted AST, so a bare `a`
    ///   must be written as `a as a` on the `sqlparser` side.
    fn check_posql_intermediate_ast_to_sqlparser_equivalence(posql_sql: &str, sqlparser_sql: &str) {
        let posql_ast: SelectStatement = posql_sql.parse().unwrap();
        let converted = Statement::Query(Box::new(Query::from(posql_ast)));
        let parsed = Parser::parse_sql(&PostgreSqlDialect {}, sqlparser_sql).unwrap();
        assert_eq!(converted, parsed[0]);
    }

    #[test]
    fn we_can_convert_posql_intermediate_ast_to_sqlparser_with_slight_modification() {
        let cases = [
            // Implicit aliases and a LIMIT without OFFSET must be spelled out for sqlparser.
            (
                "select a, b from t limit 10;",
                "select a as a, b as b from t limit 10 offset 0;",
            ),
            // Timestamp literals come out in UTC, with a precision matching the input.
            (
                "select timestamp '2024-11-07T04:55:12+00:00' as time from t;",
                "select timestamp(0) '2024-11-07 04:55:12 UTC' as time from t;",
            ),
            (
                "select timestamp '2024-11-07T04:55:12.345+03:00' as time from t;",
                "select timestamp(3) '2024-11-07 01:55:12.345 UTC' as time from t;",
            ),
        ];
        for &(posql_sql, sqlparser_sql) in &cases {
            check_posql_intermediate_ast_to_sqlparser_equivalence(posql_sql, sqlparser_sql);
        }
    }

    /// Stricter form of the equivalence check: the exact same SQL text must round-trip, so the
    /// query may not depend on any of the PoSQL quirks described above.
    fn check_posql_intermediate_ast_to_sqlparser_equality(sql: &str) {
        check_posql_intermediate_ast_to_sqlparser_equivalence(sql, sql);
    }

    #[test]
    fn we_can_convert_posql_intermediate_ast_to_sqlparser() {
        let queries = [
            "SELECT * FROM t",
            "select a as a, 4.7 * b as b from namespace.table where c = 2.5;",
            "select a as a, b as b from namespace.table where c = 4;",
            "select a as a, b as b from namespace.table where c = 4 order by a desc;",
            "select 1 as a, 'Meow' as d, b as b from namespace.table where c = 4 order by a desc limit 10 offset 0;",
            "select true as cons, a and b or c > 4 as comp from tab where d = 'Space and Time';",
            "select cat as cat, true as cons, max(meow) as max_meow from tab where d = 'Space and Time' group by cat;",
            "select cat as cat, sum(a) as s, count(*) as rows from tab where d = 'Space and Time' group by cat;",
        ];
        for sql in queries.iter().copied() {
            check_posql_intermediate_ast_to_sqlparser_equality(sql);
        }
    }
}