Skip to main content

lance_datafusion/
sql.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! SQL Parser utility
5
6use std::any::TypeId;
7
8use datafusion::sql::sqlparser::{
9    ast::{Expr, SelectItem, SetExpr, Statement},
10    dialect::{Dialect, GenericDialect},
11    parser::Parser,
12    tokenizer::{Token, Tokenizer},
13};
14
15use lance_core::{Error, Result};
16#[derive(Debug, Default)]
17struct LanceDialect(GenericDialect);
18
19impl LanceDialect {
20    fn new() -> Self {
21        Self(GenericDialect {})
22    }
23}
24
25impl Dialect for LanceDialect {
26    fn dialect(&self) -> TypeId {
27        self.0.dialect()
28    }
29
30    fn is_identifier_start(&self, ch: char) -> bool {
31        self.0.is_identifier_start(ch)
32    }
33
34    fn is_identifier_part(&self, ch: char) -> bool {
35        self.0.is_identifier_part(ch)
36    }
37
38    fn is_delimited_identifier_start(&self, ch: char) -> bool {
39        ch == '`'
40    }
41}
42
43/// Parse sql filter to Expression.
44pub(crate) fn parse_sql_filter(filter: &str) -> Result<Expr> {
45    let sql = format!("SELECT 1 FROM t WHERE {filter}");
46    let statement = parse_statement(&sql)?;
47
48    let selection = if let Statement::Query(query) = &statement {
49        if let SetExpr::Select(s) = query.body.as_ref() {
50            s.selection.as_ref()
51        } else {
52            None
53        }
54    } else {
55        None
56    };
57    let expr =
58        selection.ok_or_else(|| Error::invalid_input(format!("Filter is not valid: {filter}")))?;
59    Ok(expr.clone())
60}
61
62/// Parse a SQL expression to Expression. This is more lenient than parse_sql_filter
63/// as it can be used for projection expressions as well.
64pub(crate) fn parse_sql_expr(expr: &str) -> Result<Expr> {
65    let sql = format!("SELECT {expr} FROM t");
66    let statement = parse_statement(&sql)?;
67
68    let selection = if let Statement::Query(query) = &statement {
69        if let SetExpr::Select(s) = query.body.as_ref() {
70            if let SelectItem::UnnamedExpr(expr) = &s.projection[0] {
71                Some(expr)
72            } else {
73                None
74            }
75        } else {
76            None
77        }
78    } else {
79        None
80    };
81    let expr = selection
82        .ok_or_else(|| Error::invalid_input(format!("Expression is not valid: {expr}")))?;
83    Ok(expr.clone())
84}
85
86fn parse_statement(statement: &str) -> Result<Statement> {
87    let dialect = LanceDialect::new();
88
89    // Hack to allow == as equals
90    // This is used to parse PyArrow expressions from strings.
91    // See: https://github.com/sqlparser-rs/sqlparser-rs/pull/815#issuecomment-1450714278
92    let mut tokenizer = Tokenizer::new(&dialect, statement);
93    let mut tokens = Vec::new();
94    let mut token_iter = tokenizer
95        .tokenize()
96        .map_err(|e| {
97            Error::invalid_input(format!("Error tokenizing statement: {statement} ({e})"))
98        })?
99        .into_iter();
100    let mut prev_token = token_iter.next().unwrap();
101    for next_token in token_iter {
102        if let (Token::Eq, Token::Eq) = (&prev_token, &next_token) {
103            continue; // skip second equals
104        }
105        let token = std::mem::replace(&mut prev_token, next_token);
106        tokens.push(token);
107    }
108    tokens.push(prev_token);
109
110    Parser::new(&dialect)
111        .with_tokens(tokens)
112        .parse_statement()
113        .map_err(|e| Error::invalid_input(format!("Error parsing statement: {statement} ({e})")))
114}
115
#[cfg(test)]
mod tests {
    use super::*;

    use datafusion::sql::sqlparser::{
        ast::{BinaryOperator, Ident, Value, ValueWithSpan},
        tokenizer::Span,
    };

    /// Unquoted identifier expression.
    fn plain(name: &str) -> Expr {
        Expr::Identifier(Ident::new(name))
    }

    /// Backtick-quoted identifier.
    fn backticked(name: &str) -> Ident {
        Ident::with_quote('`', name)
    }

    /// Literal value expression carrying an empty span.
    fn literal(value: Value) -> Expr {
        Expr::Value(ValueWithSpan {
            value,
            span: Span::empty(),
        })
    }

    /// Equality comparison between two expressions.
    fn equals(left: Expr, right: Expr) -> Expr {
        Expr::BinaryOp {
            left: Box::new(left),
            op: BinaryOperator::Eq,
            right: Box::new(right),
        }
    }

    #[test]
    fn test_double_equal() {
        let expected = equals(plain("a"), plain("b"));
        let actual = parse_sql_filter("a == b").unwrap();
        assert_eq!(expected, actual);
    }

    #[test]
    fn test_like() {
        let expected = Expr::Like {
            negated: false,
            expr: Box::new(plain("a")),
            pattern: Box::new(literal(Value::SingleQuotedString("abc%".to_string()))),
            escape_char: None,
            any: false,
        };
        let actual = parse_sql_filter("a LIKE 'abc%'").unwrap();
        assert_eq!(expected, actual);
    }

    #[test]
    fn test_quoted_ident() {
        // CUBE is a SQL keyword, so it must be quoted.
        let expected = equals(
            Expr::Identifier(backticked("a:Test_Something")),
            Expr::Identifier(backticked("CUBE")),
        );
        let actual = parse_sql_filter("`a:Test_Something` == `CUBE`").unwrap();
        assert_eq!(expected, actual);

        // Quoted identifiers may contain spaces and be chained with `.`.
        let expected = equals(
            Expr::CompoundIdentifier(vec![
                backticked("outer field"),
                backticked("inner field"),
            ]),
            literal(Value::Number("1".to_string(), false)),
        );
        let actual = parse_sql_filter("`outer field`.`inner field` == 1").unwrap();
        assert_eq!(expected, actual);
    }
}