1use std::any::TypeId;
7
8use datafusion::sql::sqlparser::{
9 ast::{Expr, SelectItem, SetExpr, Statement},
10 dialect::{Dialect, GenericDialect},
11 parser::Parser,
12 tokenizer::{Token, Tokenizer},
13};
14
15use lance_core::{Error, Result};
16use snafu::location;
17#[derive(Debug, Default)]
18struct LanceDialect(GenericDialect);
19
20impl LanceDialect {
21 fn new() -> Self {
22 Self(GenericDialect {})
23 }
24}
25
26impl Dialect for LanceDialect {
27 fn dialect(&self) -> TypeId {
28 self.0.dialect()
29 }
30
31 fn is_identifier_start(&self, ch: char) -> bool {
32 self.0.is_identifier_start(ch)
33 }
34
35 fn is_identifier_part(&self, ch: char) -> bool {
36 self.0.is_identifier_part(ch)
37 }
38
39 fn is_delimited_identifier_start(&self, ch: char) -> bool {
40 ch == '`'
41 }
42}
43
44pub(crate) fn parse_sql_filter(filter: &str) -> Result<Expr> {
46 let sql = format!("SELECT 1 FROM t WHERE {filter}");
47 let statement = parse_statement(&sql)?;
48
49 let selection = if let Statement::Query(query) = &statement {
50 if let SetExpr::Select(s) = query.body.as_ref() {
51 s.selection.as_ref()
52 } else {
53 None
54 }
55 } else {
56 None
57 };
58 let expr = selection.ok_or_else(|| {
59 Error::invalid_input(format!("Filter is not valid: {filter}"), location!())
60 })?;
61 Ok(expr.clone())
62}
63
64pub(crate) fn parse_sql_expr(expr: &str) -> Result<Expr> {
67 let sql = format!("SELECT {expr} FROM t");
68 let statement = parse_statement(&sql)?;
69
70 let selection = if let Statement::Query(query) = &statement {
71 if let SetExpr::Select(s) = query.body.as_ref() {
72 if let SelectItem::UnnamedExpr(expr) = &s.projection[0] {
73 Some(expr)
74 } else {
75 None
76 }
77 } else {
78 None
79 }
80 } else {
81 None
82 };
83 let expr = selection
84 .ok_or_else(|| Error::io(format!("Expression is not valid: {expr}"), location!()))?;
85 Ok(expr.clone())
86}
87
88fn parse_statement(statement: &str) -> Result<Statement> {
89 let dialect = LanceDialect::new();
90
91 let mut tokenizer = Tokenizer::new(&dialect, statement);
95 let mut tokens = Vec::new();
96 let mut token_iter = tokenizer
97 .tokenize()
98 .map_err(|e| {
99 Error::invalid_input(
100 format!("Error tokenizing statement: {statement} ({e})"),
101 location!(),
102 )
103 })?
104 .into_iter();
105 let mut prev_token = token_iter.next().unwrap();
106 for next_token in token_iter {
107 if let (Token::Eq, Token::Eq) = (&prev_token, &next_token) {
108 continue; }
110 let token = std::mem::replace(&mut prev_token, next_token);
111 tokens.push(token);
112 }
113 tokens.push(prev_token);
114
115 Parser::new(&dialect)
116 .with_tokens(tokens)
117 .parse_statement()
118 .map_err(|e| {
119 Error::invalid_input(
120 format!("Error parsing statement: {statement} ({e})"),
121 location!(),
122 )
123 })
124}
125
#[cfg(test)]
mod tests {
    use super::*;

    use datafusion::sql::sqlparser::{
        ast::{BinaryOperator, Ident, Value, ValueWithSpan},
        tokenizer::Span,
    };

    /// Builds the `left = right` node the equality tests compare against.
    fn equals(left: Expr, right: Expr) -> Expr {
        Expr::BinaryOp {
            left: Box::new(left),
            op: BinaryOperator::Eq,
            right: Box::new(right),
        }
    }

    /// Wraps a literal value in an `Expr::Value` carrying an empty span.
    // NOTE(review): the expected values use empty spans, so these tests rely
    // on the span not taking part in equality — confirm against sqlparser.
    fn literal(value: Value) -> Expr {
        Expr::Value(ValueWithSpan {
            value,
            span: Span::empty(),
        })
    }

    /// Builds an identifier expression quoted with backticks.
    fn backticked(name: &str) -> Expr {
        Expr::Identifier(Ident::with_quote('`', name))
    }

    /// `==` must be parsed as the ordinary equality operator.
    #[test]
    fn test_double_equal() {
        let expected = equals(
            Expr::Identifier(Ident::new("a")),
            Expr::Identifier(Ident::new("b")),
        );
        let expr = parse_sql_filter("a == b").unwrap();
        assert_eq!(expected, expr);
    }

    /// A plain `LIKE` filter yields a non-negated `Expr::Like` without an
    /// escape character.
    #[test]
    fn test_like() {
        let expected = Expr::Like {
            negated: false,
            expr: Box::new(Expr::Identifier(Ident::new("a"))),
            pattern: Box::new(literal(Value::SingleQuotedString("abc%".to_string()))),
            escape_char: None,
            any: false,
        };
        let expr = parse_sql_filter("a LIKE 'abc%'").unwrap();
        assert_eq!(expected, expr);
    }

    /// Backtick-quoted names are parsed as quoted identifiers, both on their
    /// own and as parts of a compound (dotted) identifier.
    #[test]
    fn test_quoted_ident() {
        // Quoted names containing `:` and a quoted `CUBE` on either side of `==`.
        let expected = equals(backticked("a:Test_Something"), backticked("CUBE"));
        let expr = parse_sql_filter("`a:Test_Something` == `CUBE`").unwrap();
        assert_eq!(expected, expr);

        // Quoted path segments containing spaces, compared with a number.
        let expected = equals(
            Expr::CompoundIdentifier(vec![
                Ident::with_quote('`', "outer field"),
                Ident::with_quote('`', "inner field"),
            ]),
            literal(Value::Number("1".to_string(), false)),
        );
        let expr = parse_sql_filter("`outer field`.`inner field` == 1").unwrap();
        assert_eq!(expected, expr);
    }
}