1use std::any::TypeId;
7
8use datafusion::sql::sqlparser::{
9 ast::{Expr, SelectItem, SetExpr, Statement},
10 dialect::{Dialect, GenericDialect},
11 parser::Parser,
12 tokenizer::{Token, Tokenizer},
13};
14
15use lance_core::{Error, Result};
16#[derive(Debug, Default)]
17struct LanceDialect(GenericDialect);
18
19impl LanceDialect {
20 fn new() -> Self {
21 Self(GenericDialect {})
22 }
23}
24
25impl Dialect for LanceDialect {
26 fn dialect(&self) -> TypeId {
27 self.0.dialect()
28 }
29
30 fn is_identifier_start(&self, ch: char) -> bool {
31 self.0.is_identifier_start(ch)
32 }
33
34 fn is_identifier_part(&self, ch: char) -> bool {
35 self.0.is_identifier_part(ch)
36 }
37
38 fn is_delimited_identifier_start(&self, ch: char) -> bool {
39 ch == '`'
40 }
41}
42
43pub(crate) fn parse_sql_filter(filter: &str) -> Result<Expr> {
45 let sql = format!("SELECT 1 FROM t WHERE {filter}");
46 let statement = parse_statement(&sql)?;
47
48 let selection = if let Statement::Query(query) = &statement {
49 if let SetExpr::Select(s) = query.body.as_ref() {
50 s.selection.as_ref()
51 } else {
52 None
53 }
54 } else {
55 None
56 };
57 let expr =
58 selection.ok_or_else(|| Error::invalid_input(format!("Filter is not valid: {filter}")))?;
59 Ok(expr.clone())
60}
61
62pub(crate) fn parse_sql_expr(expr: &str) -> Result<Expr> {
65 let sql = format!("SELECT {expr} FROM t");
66 let statement = parse_statement(&sql)?;
67
68 let selection = if let Statement::Query(query) = &statement {
69 if let SetExpr::Select(s) = query.body.as_ref() {
70 if let SelectItem::UnnamedExpr(expr) = &s.projection[0] {
71 Some(expr)
72 } else {
73 None
74 }
75 } else {
76 None
77 }
78 } else {
79 None
80 };
81 let expr = selection
82 .ok_or_else(|| Error::invalid_input(format!("Expression is not valid: {expr}")))?;
83 Ok(expr.clone())
84}
85
86fn parse_statement(statement: &str) -> Result<Statement> {
87 let dialect = LanceDialect::new();
88
89 let mut tokenizer = Tokenizer::new(&dialect, statement);
93 let mut tokens = Vec::new();
94 let mut token_iter = tokenizer
95 .tokenize()
96 .map_err(|e| {
97 Error::invalid_input(format!("Error tokenizing statement: {statement} ({e})"))
98 })?
99 .into_iter();
100 let mut prev_token = token_iter.next().unwrap();
101 for next_token in token_iter {
102 if let (Token::Eq, Token::Eq) = (&prev_token, &next_token) {
103 continue; }
105 let token = std::mem::replace(&mut prev_token, next_token);
106 tokens.push(token);
107 }
108 tokens.push(prev_token);
109
110 Parser::new(&dialect)
111 .with_tokens(tokens)
112 .parse_statement()
113 .map_err(|e| Error::invalid_input(format!("Error parsing statement: {statement} ({e})")))
114}
115
#[cfg(test)]
mod tests {
    use super::*;

    use datafusion::sql::sqlparser::{
        ast::{BinaryOperator, Ident, Value, ValueWithSpan},
        tokenizer::Span,
    };

    /// Boxes `value` as a literal expression with an empty span.
    fn literal(value: Value) -> Box<Expr> {
        Box::new(Expr::Value(ValueWithSpan {
            value,
            span: Span::empty(),
        }))
    }

    /// Builds the expected `left = right` comparison node.
    fn equality(left: Expr, right: Box<Expr>) -> Expr {
        Expr::BinaryOp {
            left: Box::new(left),
            op: BinaryOperator::Eq,
            right,
        }
    }

    /// `==` must be accepted and parsed as a plain equality operator.
    #[test]
    fn test_double_equal() {
        let parsed = parse_sql_filter("a == b").unwrap();
        let expected = equality(
            Expr::Identifier(Ident::new("a")),
            Box::new(Expr::Identifier(Ident::new("b"))),
        );
        assert_eq!(expected, parsed);
    }

    /// A `LIKE` predicate with a string pattern parses into `Expr::Like`.
    #[test]
    fn test_like() {
        let parsed = parse_sql_filter("a LIKE 'abc%'").unwrap();
        let expected = Expr::Like {
            negated: false,
            expr: Box::new(Expr::Identifier(Ident::new("a"))),
            pattern: literal(Value::SingleQuotedString("abc%".to_string())),
            escape_char: None,
            any: false,
        };
        assert_eq!(expected, parsed);
    }

    /// Backtick-quoted identifiers keep their exact text, including
    /// punctuation, keywords, and spaces, both standalone and in compound form.
    #[test]
    fn test_quoted_ident() {
        let parsed = parse_sql_filter("`a:Test_Something` == `CUBE`").unwrap();
        let expected = equality(
            Expr::Identifier(Ident::with_quote('`', "a:Test_Something")),
            Box::new(Expr::Identifier(Ident::with_quote('`', "CUBE"))),
        );
        assert_eq!(expected, parsed);

        let parsed = parse_sql_filter("`outer field`.`inner field` == 1").unwrap();
        let expected = equality(
            Expr::CompoundIdentifier(vec![
                Ident::with_quote('`', "outer field"),
                Ident::with_quote('`', "inner field"),
            ]),
            literal(Value::Number("1".to_string(), false)),
        );
        assert_eq!(expected, parsed);
    }
}