Skip to main content

lance_datafusion/
sql.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! SQL Parser utility
5
6use std::any::TypeId;
7
8use datafusion::sql::sqlparser::{
9    ast::{Expr, SelectItem, SetExpr, Statement},
10    dialect::{Dialect, GenericDialect},
11    parser::Parser,
12    tokenizer::{Token, Tokenizer},
13};
14
15use lance_core::{Error, Result};
16use snafu::location;
17#[derive(Debug, Default)]
18struct LanceDialect(GenericDialect);
19
20impl LanceDialect {
21    fn new() -> Self {
22        Self(GenericDialect {})
23    }
24}
25
26impl Dialect for LanceDialect {
27    fn dialect(&self) -> TypeId {
28        self.0.dialect()
29    }
30
31    fn is_identifier_start(&self, ch: char) -> bool {
32        self.0.is_identifier_start(ch)
33    }
34
35    fn is_identifier_part(&self, ch: char) -> bool {
36        self.0.is_identifier_part(ch)
37    }
38
39    fn is_delimited_identifier_start(&self, ch: char) -> bool {
40        ch == '`'
41    }
42}
43
44/// Parse sql filter to Expression.
45pub(crate) fn parse_sql_filter(filter: &str) -> Result<Expr> {
46    let sql = format!("SELECT 1 FROM t WHERE {filter}");
47    let statement = parse_statement(&sql)?;
48
49    let selection = if let Statement::Query(query) = &statement {
50        if let SetExpr::Select(s) = query.body.as_ref() {
51            s.selection.as_ref()
52        } else {
53            None
54        }
55    } else {
56        None
57    };
58    let expr = selection.ok_or_else(|| {
59        Error::invalid_input(format!("Filter is not valid: {filter}"), location!())
60    })?;
61    Ok(expr.clone())
62}
63
64/// Parse a SQL expression to Expression. This is more lenient than parse_sql_filter
65/// as it can be used for projection expressions as well.
66pub(crate) fn parse_sql_expr(expr: &str) -> Result<Expr> {
67    let sql = format!("SELECT {expr} FROM t");
68    let statement = parse_statement(&sql)?;
69
70    let selection = if let Statement::Query(query) = &statement {
71        if let SetExpr::Select(s) = query.body.as_ref() {
72            if let SelectItem::UnnamedExpr(expr) = &s.projection[0] {
73                Some(expr)
74            } else {
75                None
76            }
77        } else {
78            None
79        }
80    } else {
81        None
82    };
83    let expr = selection.ok_or_else(|| {
84        Error::invalid_input(format!("Expression is not valid: {expr}"), location!())
85    })?;
86    Ok(expr.clone())
87}
88
89fn parse_statement(statement: &str) -> Result<Statement> {
90    let dialect = LanceDialect::new();
91
92    // Hack to allow == as equals
93    // This is used to parse PyArrow expressions from strings.
94    // See: https://github.com/sqlparser-rs/sqlparser-rs/pull/815#issuecomment-1450714278
95    let mut tokenizer = Tokenizer::new(&dialect, statement);
96    let mut tokens = Vec::new();
97    let mut token_iter = tokenizer
98        .tokenize()
99        .map_err(|e| {
100            Error::invalid_input(
101                format!("Error tokenizing statement: {statement} ({e})"),
102                location!(),
103            )
104        })?
105        .into_iter();
106    let mut prev_token = token_iter.next().unwrap();
107    for next_token in token_iter {
108        if let (Token::Eq, Token::Eq) = (&prev_token, &next_token) {
109            continue; // skip second equals
110        }
111        let token = std::mem::replace(&mut prev_token, next_token);
112        tokens.push(token);
113    }
114    tokens.push(prev_token);
115
116    Parser::new(&dialect)
117        .with_tokens(tokens)
118        .parse_statement()
119        .map_err(|e| {
120            Error::invalid_input(
121                format!("Error parsing statement: {statement} ({e})"),
122                location!(),
123            )
124        })
125}
126
#[cfg(test)]
mod tests {
    use super::*;

    use datafusion::sql::sqlparser::{
        ast::{BinaryOperator, Ident, Value, ValueWithSpan},
        tokenizer::Span,
    };

    /// Unquoted identifier expression, boxed for use inside AST nodes.
    fn ident(name: &str) -> Box<Expr> {
        Box::new(Expr::Identifier(Ident::new(name)))
    }

    /// Literal value expression with an empty span, boxed for use inside AST nodes.
    fn literal(value: Value) -> Box<Expr> {
        Box::new(Expr::Value(ValueWithSpan {
            value,
            span: Span::empty(),
        }))
    }

    #[test]
    fn test_double_equal() {
        let expr = parse_sql_filter("a == b").unwrap();
        assert_eq!(
            Expr::BinaryOp {
                left: ident("a"),
                op: BinaryOperator::Eq,
                right: ident("b"),
            },
            expr
        );
    }

    #[test]
    fn test_like() {
        let expr = parse_sql_filter("a LIKE 'abc%'").unwrap();
        assert_eq!(
            Expr::Like {
                negated: false,
                expr: ident("a"),
                pattern: literal(Value::SingleQuotedString("abc%".to_string())),
                escape_char: None,
                any: false,
            },
            expr
        );
    }

    #[test]
    fn test_quoted_ident() {
        // CUBE is a SQL keyword, so it must be quoted.
        let expr = parse_sql_filter("`a:Test_Something` == `CUBE`").unwrap();
        assert_eq!(
            Expr::BinaryOp {
                left: Box::new(Expr::Identifier(Ident::with_quote('`', "a:Test_Something"))),
                op: BinaryOperator::Eq,
                right: Box::new(Expr::Identifier(Ident::with_quote('`', "CUBE"))),
            },
            expr
        );

        let expr = parse_sql_filter("`outer field`.`inner field` == 1").unwrap();
        assert_eq!(
            Expr::BinaryOp {
                left: Box::new(Expr::CompoundIdentifier(vec![
                    Ident::with_quote('`', "outer field"),
                    Ident::with_quote('`', "inner field"),
                ])),
                op: BinaryOperator::Eq,
                right: literal(Value::Number("1".to_string(), false)),
            },
            expr
        );
    }

    #[test]
    fn test_parse_sql_expr() {
        let expr = parse_sql_expr("a + 1").unwrap();
        assert_eq!(
            Expr::BinaryOp {
                left: ident("a"),
                op: BinaryOperator::Plus,
                right: literal(Value::Number("1".to_string(), false)),
            },
            expr
        );
    }

    #[test]
    fn test_invalid_input() {
        // Predicate is missing its right-hand side, so parsing fails.
        assert!(parse_sql_filter("a ==").is_err());
        // An aliased projection item is not a bare expression.
        assert!(parse_sql_expr("a AS b").is_err());
    }
}