1use std::any::TypeId;
7
8use datafusion::sql::sqlparser::{
9 ast::{Expr, SelectItem, SetExpr, Statement},
10 dialect::{Dialect, GenericDialect},
11 parser::Parser,
12 tokenizer::{Token, Tokenizer},
13};
14
15use lance_core::{Error, Result};
16use snafu::location;
17#[derive(Debug, Default)]
18struct LanceDialect(GenericDialect);
19
20impl LanceDialect {
21 fn new() -> Self {
22 Self(GenericDialect {})
23 }
24}
25
26impl Dialect for LanceDialect {
27 fn dialect(&self) -> TypeId {
28 self.0.dialect()
29 }
30
31 fn is_identifier_start(&self, ch: char) -> bool {
32 self.0.is_identifier_start(ch)
33 }
34
35 fn is_identifier_part(&self, ch: char) -> bool {
36 self.0.is_identifier_part(ch)
37 }
38
39 fn is_delimited_identifier_start(&self, ch: char) -> bool {
40 ch == '`'
41 }
42}
43
44pub(crate) fn parse_sql_filter(filter: &str) -> Result<Expr> {
46 let sql = format!("SELECT 1 FROM t WHERE {filter}");
47 let statement = parse_statement(&sql)?;
48
49 let selection = if let Statement::Query(query) = &statement {
50 if let SetExpr::Select(s) = query.body.as_ref() {
51 s.selection.as_ref()
52 } else {
53 None
54 }
55 } else {
56 None
57 };
58 let expr = selection.ok_or_else(|| {
59 Error::invalid_input(format!("Filter is not valid: {filter}"), location!())
60 })?;
61 Ok(expr.clone())
62}
63
64pub(crate) fn parse_sql_expr(expr: &str) -> Result<Expr> {
67 let sql = format!("SELECT {expr} FROM t");
68 let statement = parse_statement(&sql)?;
69
70 let selection = if let Statement::Query(query) = &statement {
71 if let SetExpr::Select(s) = query.body.as_ref() {
72 if let SelectItem::UnnamedExpr(expr) = &s.projection[0] {
73 Some(expr)
74 } else {
75 None
76 }
77 } else {
78 None
79 }
80 } else {
81 None
82 };
83 let expr = selection.ok_or_else(|| {
84 Error::invalid_input(format!("Expression is not valid: {expr}"), location!())
85 })?;
86 Ok(expr.clone())
87}
88
89fn parse_statement(statement: &str) -> Result<Statement> {
90 let dialect = LanceDialect::new();
91
92 let mut tokenizer = Tokenizer::new(&dialect, statement);
96 let mut tokens = Vec::new();
97 let mut token_iter = tokenizer
98 .tokenize()
99 .map_err(|e| {
100 Error::invalid_input(
101 format!("Error tokenizing statement: {statement} ({e})"),
102 location!(),
103 )
104 })?
105 .into_iter();
106 let mut prev_token = token_iter.next().unwrap();
107 for next_token in token_iter {
108 if let (Token::Eq, Token::Eq) = (&prev_token, &next_token) {
109 continue; }
111 let token = std::mem::replace(&mut prev_token, next_token);
112 tokens.push(token);
113 }
114 tokens.push(prev_token);
115
116 Parser::new(&dialect)
117 .with_tokens(tokens)
118 .parse_statement()
119 .map_err(|e| {
120 Error::invalid_input(
121 format!("Error parsing statement: {statement} ({e})"),
122 location!(),
123 )
124 })
125}
126
#[cfg(test)]
mod tests {
    use super::*;

    use datafusion::sql::sqlparser::{
        ast::{BinaryOperator, Ident, Value, ValueWithSpan},
        tokenizer::Span,
    };

    /// Wraps a literal [`Value`] in an [`Expr::Value`] with an empty span.
    fn literal(value: Value) -> Expr {
        Expr::Value(ValueWithSpan {
            value,
            span: Span::empty(),
        })
    }

    /// Builds the `left = right` binary expression the parser should produce.
    fn equals(left: Expr, right: Expr) -> Expr {
        Expr::BinaryOp {
            left: Box::new(left),
            op: BinaryOperator::Eq,
            right: Box::new(right),
        }
    }

    /// Builds a backtick-quoted identifier.
    fn backticked(name: &str) -> Ident {
        Ident::with_quote('`', name)
    }

    /// `==` in a filter is parsed as a plain equality comparison.
    #[test]
    fn test_double_equal() {
        let parsed = parse_sql_filter("a == b").unwrap();
        let expected = equals(
            Expr::Identifier(Ident::new("a")),
            Expr::Identifier(Ident::new("b")),
        );
        assert_eq!(expected, parsed);
    }

    /// A `LIKE` filter keeps its pattern as a single-quoted string literal.
    #[test]
    fn test_like() {
        let parsed = parse_sql_filter("a LIKE 'abc%'").unwrap();
        let expected = Expr::Like {
            negated: false,
            expr: Box::new(Expr::Identifier(Ident::new("a"))),
            pattern: Box::new(literal(Value::SingleQuotedString("abc%".to_string()))),
            escape_char: None,
            any: false,
        };
        assert_eq!(expected, parsed);
    }

    /// Backtick-quoted identifiers, including compound ones, are preserved.
    #[test]
    fn test_quoted_ident() {
        let parsed = parse_sql_filter("`a:Test_Something` == `CUBE`").unwrap();
        let expected = equals(
            Expr::Identifier(backticked("a:Test_Something")),
            Expr::Identifier(backticked("CUBE")),
        );
        assert_eq!(expected, parsed);

        let parsed = parse_sql_filter("`outer field`.`inner field` == 1").unwrap();
        let expected = equals(
            Expr::CompoundIdentifier(vec![
                backticked("outer field"),
                backticked("inner field"),
            ]),
            literal(Value::Number("1".to_string(), false)),
        );
        assert_eq!(expected, parsed);
    }
}