lance_datafusion/
planner.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Exec plan planner
5
6use std::borrow::Cow;
7use std::collections::{BTreeSet, VecDeque};
8use std::sync::Arc;
9
10use crate::expr::safe_coerce_scalar;
11use crate::logical_expr::{coerce_filter_type_to_boolean, get_as_string_scalar_opt, resolve_expr};
12use crate::sql::{parse_sql_expr, parse_sql_filter};
13use arrow::compute::CastOptions;
14use arrow_array::ListArray;
15use arrow_buffer::OffsetBuffer;
16use arrow_schema::{DataType as ArrowDataType, Field, SchemaRef, TimeUnit};
17use arrow_select::concat::concat;
18use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
19use datafusion::common::DFSchema;
20use datafusion::config::ConfigOptions;
21use datafusion::error::Result as DFResult;
22use datafusion::execution::config::SessionConfig;
23use datafusion::execution::context::SessionState;
24use datafusion::execution::runtime_env::RuntimeEnvBuilder;
25use datafusion::execution::session_state::SessionStateBuilder;
26use datafusion::logical_expr::expr::ScalarFunction;
27use datafusion::logical_expr::planner::{ExprPlanner, PlannerResult, RawFieldAccessExpr};
28use datafusion::logical_expr::{
29    AggregateUDF, ColumnarValue, GetFieldAccess, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
30    WindowUDF,
31};
32use datafusion::optimizer::simplify_expressions::SimplifyContext;
33use datafusion::sql::planner::{ContextProvider, ParserOptions, PlannerContext, SqlToRel};
34use datafusion::sql::sqlparser::ast::{
35    AccessExpr, Array as SQLArray, BinaryOperator, DataType as SQLDataType, ExactNumberInfo,
36    Expr as SQLExpr, Function, FunctionArg, FunctionArgExpr, FunctionArguments, Ident, Subscript,
37    TimezoneInfo, UnaryOperator, Value,
38};
39use datafusion::{
40    common::Column,
41    logical_expr::{col, Between, BinaryExpr, Like, Operator},
42    physical_expr::execution_props::ExecutionProps,
43    physical_plan::PhysicalExpr,
44    prelude::Expr,
45    scalar::ScalarValue,
46};
47use datafusion_functions::core::getfield::GetFieldFunc;
48use lance_arrow::cast::cast_with_options;
49use lance_core::datatypes::Schema;
50use snafu::location;
51
52use lance_core::{Error, Result};
53
/// Scalar UDF that casts a `List`/`FixedSizeList` of float32/float16 items
/// to the same list layout with float16 items. Registered under the name
/// `_cast_list_f16` (see `LanceContextProvider::get_function_meta`).
#[derive(Debug, Clone)]
struct CastListF16Udf {
    // Accepts exactly one argument of any type; validation happens in
    // `return_type`/`invoke`.
    signature: Signature,
}
58
59impl CastListF16Udf {
60    pub fn new() -> Self {
61        Self {
62            signature: Signature::any(1, Volatility::Immutable),
63        }
64    }
65}
66
67impl ScalarUDFImpl for CastListF16Udf {
68    fn as_any(&self) -> &dyn std::any::Any {
69        self
70    }
71
72    fn name(&self) -> &str {
73        "_cast_list_f16"
74    }
75
76    fn signature(&self) -> &Signature {
77        &self.signature
78    }
79
80    fn return_type(&self, arg_types: &[ArrowDataType]) -> DFResult<ArrowDataType> {
81        let input = &arg_types[0];
82        match input {
83            ArrowDataType::FixedSizeList(field, size) => {
84                if field.data_type() != &ArrowDataType::Float32
85                    && field.data_type() != &ArrowDataType::Float16
86                {
87                    return Err(datafusion::error::DataFusionError::Execution(
88                        "cast_list_f16 only supports list of float32 or float16".to_string(),
89                    ));
90                }
91                Ok(ArrowDataType::FixedSizeList(
92                    Arc::new(Field::new(
93                        field.name(),
94                        ArrowDataType::Float16,
95                        field.is_nullable(),
96                    )),
97                    *size,
98                ))
99            }
100            ArrowDataType::List(field) => {
101                if field.data_type() != &ArrowDataType::Float32
102                    && field.data_type() != &ArrowDataType::Float16
103                {
104                    return Err(datafusion::error::DataFusionError::Execution(
105                        "cast_list_f16 only supports list of float32 or float16".to_string(),
106                    ));
107                }
108                Ok(ArrowDataType::List(Arc::new(Field::new(
109                    field.name(),
110                    ArrowDataType::Float16,
111                    field.is_nullable(),
112                ))))
113            }
114            _ => Err(datafusion::error::DataFusionError::Execution(
115                "cast_list_f16 only supports FixedSizeList/List arguments".to_string(),
116            )),
117        }
118    }
119
120    fn invoke(&self, args: &[ColumnarValue]) -> DFResult<ColumnarValue> {
121        let ColumnarValue::Array(arr) = &args[0] else {
122            return Err(datafusion::error::DataFusionError::Execution(
123                "cast_list_f16 only supports array arguments".to_string(),
124            ));
125        };
126
127        let to_type = match arr.data_type() {
128            ArrowDataType::FixedSizeList(field, size) => ArrowDataType::FixedSizeList(
129                Arc::new(Field::new(
130                    field.name(),
131                    ArrowDataType::Float16,
132                    field.is_nullable(),
133                )),
134                *size,
135            ),
136            ArrowDataType::List(field) => ArrowDataType::List(Arc::new(Field::new(
137                field.name(),
138                ArrowDataType::Float16,
139                field.is_nullable(),
140            ))),
141            _ => {
142                return Err(datafusion::error::DataFusionError::Execution(
143                    "cast_list_f16 only supports array arguments".to_string(),
144                ));
145            }
146        };
147
148        let res = cast_with_options(arr.as_ref(), &to_type, &CastOptions::default())?;
149        Ok(ColumnarValue::Array(res))
150    }
151}
152
// Adapter that instructs datafusion how lance expects expressions to be interpreted
struct LanceContextProvider {
    // Session-level DataFusion configuration, returned from `options()`.
    options: datafusion::config::ConfigOptions,
    // Backing session used for scalar/aggregate/window function lookup.
    state: SessionState,
    // Default expression planners captured at build time, because
    // `SessionState` does not expose them after construction.
    expr_planners: Vec<Arc<dyn ExprPlanner>>,
}
159
160impl Default for LanceContextProvider {
161    fn default() -> Self {
162        let config = SessionConfig::new();
163        let runtime = RuntimeEnvBuilder::new().build_arc().unwrap();
164        let mut state_builder = SessionStateBuilder::new()
165            .with_config(config)
166            .with_runtime_env(runtime)
167            .with_default_features();
168
169        // SessionState does not expose expr_planners, so we need to get the default ones from
170        // the builder and store them to return from get_expr_planners
171
172        // unwrap safe because with_default_features sets expr_planners
173        let expr_planners = state_builder.expr_planners().as_ref().unwrap().clone();
174
175        Self {
176            options: ConfigOptions::default(),
177            state: state_builder.build(),
178            expr_planners,
179        }
180    }
181}
182
183impl ContextProvider for LanceContextProvider {
184    fn get_table_source(
185        &self,
186        name: datafusion::sql::TableReference,
187    ) -> DFResult<Arc<dyn datafusion::logical_expr::TableSource>> {
188        Err(datafusion::error::DataFusionError::NotImplemented(format!(
189            "Attempt to reference inner table {} not supported",
190            name
191        )))
192    }
193
194    fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
195        self.state.aggregate_functions().get(name).cloned()
196    }
197
198    fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
199        self.state.window_functions().get(name).cloned()
200    }
201
202    fn get_function_meta(&self, f: &str) -> Option<Arc<ScalarUDF>> {
203        match f {
204            // TODO: cast should go thru CAST syntax instead of UDF
205            // Going thru UDF makes it hard for the optimizer to find no-ops
206            "_cast_list_f16" => Some(Arc::new(ScalarUDF::new_from_impl(CastListF16Udf::new()))),
207            _ => self.state.scalar_functions().get(f).cloned(),
208        }
209    }
210
211    fn get_variable_type(&self, _: &[String]) -> Option<ArrowDataType> {
212        // Variables (things like @@LANGUAGE) not supported
213        None
214    }
215
216    fn options(&self) -> &datafusion::config::ConfigOptions {
217        &self.options
218    }
219
220    fn udf_names(&self) -> Vec<String> {
221        self.state.scalar_functions().keys().cloned().collect()
222    }
223
224    fn udaf_names(&self) -> Vec<String> {
225        self.state.aggregate_functions().keys().cloned().collect()
226    }
227
228    fn udwf_names(&self) -> Vec<String> {
229        self.state.window_functions().keys().cloned().collect()
230    }
231
232    fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
233        &self.expr_planners
234    }
235}
236
/// Plans SQL filter / expression strings into DataFusion logical [`Expr`]s
/// resolved against a fixed Arrow schema.
pub struct Planner {
    // Arrow schema expressions are resolved against.
    schema: SchemaRef,
    // Supplies function lookup and expression planners for SQL planning.
    context_provider: LanceContextProvider,
}
241
242impl Planner {
243    pub fn new(schema: SchemaRef) -> Self {
244        Self {
245            schema,
246            context_provider: LanceContextProvider::default(),
247        }
248    }
249
250    fn column(idents: &[Ident]) -> Expr {
251        let mut column = col(&idents[0].value);
252        for ident in &idents[1..] {
253            column = Expr::ScalarFunction(ScalarFunction {
254                args: vec![
255                    column,
256                    Expr::Literal(ScalarValue::Utf8(Some(ident.value.clone()))),
257                ],
258                func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
259            });
260        }
261        column
262    }
263
    /// Translate a SQL [`BinaryOperator`] into the DataFusion [`Operator`].
    ///
    /// Only arithmetic, string-concat, comparison, and boolean operators are
    /// supported; anything else yields an invalid-input error.
    fn binary_op(&self, op: &BinaryOperator) -> Result<Operator> {
        Ok(match op {
            BinaryOperator::Plus => Operator::Plus,
            BinaryOperator::Minus => Operator::Minus,
            BinaryOperator::Multiply => Operator::Multiply,
            BinaryOperator::Divide => Operator::Divide,
            BinaryOperator::Modulo => Operator::Modulo,
            BinaryOperator::StringConcat => Operator::StringConcat,
            BinaryOperator::Gt => Operator::Gt,
            BinaryOperator::Lt => Operator::Lt,
            BinaryOperator::GtEq => Operator::GtEq,
            BinaryOperator::LtEq => Operator::LtEq,
            BinaryOperator::Eq => Operator::Eq,
            BinaryOperator::NotEq => Operator::NotEq,
            BinaryOperator::And => Operator::And,
            BinaryOperator::Or => Operator::Or,
            _ => {
                return Err(Error::invalid_input(
                    format!("Operator {op} is not supported"),
                    location!(),
                ));
            }
        })
    }
288
289    fn binary_expr(&self, left: &SQLExpr, op: &BinaryOperator, right: &SQLExpr) -> Result<Expr> {
290        Ok(Expr::BinaryExpr(BinaryExpr::new(
291            Box::new(self.parse_sql_expr(left)?),
292            self.binary_op(op)?,
293            Box::new(self.parse_sql_expr(right)?),
294        )))
295    }
296
297    fn unary_expr(&self, op: &UnaryOperator, expr: &SQLExpr) -> Result<Expr> {
298        Ok(match op {
299            UnaryOperator::Not | UnaryOperator::PGBitwiseNot => {
300                Expr::Not(Box::new(self.parse_sql_expr(expr)?))
301            }
302
303            UnaryOperator::Minus => {
304                use datafusion::logical_expr::lit;
305                match expr {
306                    SQLExpr::Value(Value::Number(n, _)) => match n.parse::<i64>() {
307                        Ok(n) => lit(-n),
308                        Err(_) => lit(-n
309                            .parse::<f64>()
310                            .map_err(|_e| {
311                                Error::invalid_input(
312                                    format!("negative operator can be only applied to integer and float operands, got: {n}"),
313                                    location!(),
314                                )
315                            })?),
316                    },
317                    _ => {
318                        Expr::Negative(Box::new(self.parse_sql_expr(expr)?))
319                    }
320                }
321            }
322
323            _ => {
324                return Err(Error::invalid_input(
325                    format!("Unary operator '{:?}' is not supported", op),
326                    location!(),
327                ));
328            }
329        })
330    }
331
332    // See datafusion `sqlToRel::parse_sql_number()`
333    fn number(&self, value: &str, negative: bool) -> Result<Expr> {
334        use datafusion::logical_expr::lit;
335        let value: Cow<str> = if negative {
336            Cow::Owned(format!("-{}", value))
337        } else {
338            Cow::Borrowed(value)
339        };
340        if let Ok(n) = value.parse::<i64>() {
341            Ok(lit(n))
342        } else {
343            value.parse::<f64>().map(lit).map_err(|_| {
344                Error::invalid_input(
345                    format!("'{value}' is not supported number value."),
346                    location!(),
347                )
348            })
349        }
350    }
351
352    fn value(&self, value: &Value) -> Result<Expr> {
353        Ok(match value {
354            Value::Number(v, _) => self.number(v.as_str(), false)?,
355            Value::SingleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone()))),
356            Value::HexStringLiteral(hsl) => {
357                Expr::Literal(ScalarValue::Binary(Self::try_decode_hex_literal(hsl)))
358            }
359            Value::DoubleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone()))),
360            Value::Boolean(v) => Expr::Literal(ScalarValue::Boolean(Some(*v))),
361            Value::Null => Expr::Literal(ScalarValue::Null),
362            _ => todo!(),
363        })
364    }
365
366    fn parse_function_args(&self, func_args: &FunctionArg) -> Result<Expr> {
367        match func_args {
368            FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => self.parse_sql_expr(expr),
369            _ => Err(Error::invalid_input(
370                format!("Unsupported function args: {:?}", func_args),
371                location!(),
372            )),
373        }
374    }
375
376    // We now use datafusion to parse functions.  This allows us to use datafusion's
377    // entire collection of functions (previously we had just hard-coded support for two functions).
378    //
379    // Unfortunately, one of those two functions was is_valid and the reason we needed it was because
380    // this is a function that comes from duckdb.  Datafusion does not consider is_valid to be a function
381    // but rather an AST node (Expr::IsNotNull) and so we need to handle this case specially.
382    fn legacy_parse_function(&self, func: &Function) -> Result<Expr> {
383        match &func.args {
384            FunctionArguments::List(args) => {
385                if func.name.0.len() != 1 {
386                    return Err(Error::invalid_input(
387                        format!("Function name must have 1 part, got: {:?}", func.name.0),
388                        location!(),
389                    ));
390                }
391                Ok(Expr::IsNotNull(Box::new(
392                    self.parse_function_args(&args.args[0])?,
393                )))
394            }
395            _ => Err(Error::invalid_input(
396                format!("Unsupported function args: {:?}", &func.args),
397                location!(),
398            )),
399        }
400    }
401
    /// Plan a SQL function call expression.
    ///
    /// `is_valid` (a duckdb-ism) is special-cased through
    /// [`Self::legacy_parse_function`]; every other function is delegated to
    /// DataFusion's `SqlToRel`, which resolves names via the
    /// `LanceContextProvider`.
    fn parse_function(&self, function: SQLExpr) -> Result<Expr> {
        if let SQLExpr::Function(function) = &function {
            if !function.name.0.is_empty() && function.name.0[0].value == "is_valid" {
                return self.legacy_parse_function(function);
            }
        }
        let sql_to_rel = SqlToRel::new_with_options(
            &self.context_provider,
            ParserOptions {
                parse_float_as_decimal: false,
                // Keep identifiers exactly as written; lance schemas are
                // case-sensitive.
                enable_ident_normalization: false,
                support_varchar_with_length: false,
                enable_options_value_normalization: false,
                collect_spans: false,
            },
        );

        let mut planner_context = PlannerContext::default();
        let schema = DFSchema::try_from(self.schema.as_ref().clone())?;
        Ok(sql_to_rel.sql_to_expr(function, &schema, &mut planner_context)?)
    }
423
424    fn parse_type(&self, data_type: &SQLDataType) -> Result<ArrowDataType> {
425        const SUPPORTED_TYPES: [&str; 13] = [
426            "int [unsigned]",
427            "tinyint [unsigned]",
428            "smallint [unsigned]",
429            "bigint [unsigned]",
430            "float",
431            "double",
432            "string",
433            "binary",
434            "date",
435            "timestamp(precision)",
436            "datetime(precision)",
437            "decimal(precision,scale)",
438            "boolean",
439        ];
440        match data_type {
441            SQLDataType::String(_) => Ok(ArrowDataType::Utf8),
442            SQLDataType::Binary(_) => Ok(ArrowDataType::Binary),
443            SQLDataType::Float(_) => Ok(ArrowDataType::Float32),
444            SQLDataType::Double(_) => Ok(ArrowDataType::Float64),
445            SQLDataType::Boolean => Ok(ArrowDataType::Boolean),
446            SQLDataType::TinyInt(_) => Ok(ArrowDataType::Int8),
447            SQLDataType::SmallInt(_) => Ok(ArrowDataType::Int16),
448            SQLDataType::Int(_) | SQLDataType::Integer(_) => Ok(ArrowDataType::Int32),
449            SQLDataType::BigInt(_) => Ok(ArrowDataType::Int64),
450            SQLDataType::UnsignedTinyInt(_) => Ok(ArrowDataType::UInt8),
451            SQLDataType::UnsignedSmallInt(_) => Ok(ArrowDataType::UInt16),
452            SQLDataType::UnsignedInt(_) | SQLDataType::UnsignedInteger(_) => {
453                Ok(ArrowDataType::UInt32)
454            }
455            SQLDataType::UnsignedBigInt(_) => Ok(ArrowDataType::UInt64),
456            SQLDataType::Date => Ok(ArrowDataType::Date32),
457            SQLDataType::Timestamp(resolution, tz) => {
458                match tz {
459                    TimezoneInfo::None => {}
460                    _ => {
461                        return Err(Error::invalid_input(
462                            "Timezone not supported in timestamp".to_string(),
463                            location!(),
464                        ));
465                    }
466                };
467                let time_unit = match resolution {
468                    // Default to microsecond to match PyArrow
469                    None => TimeUnit::Microsecond,
470                    Some(0) => TimeUnit::Second,
471                    Some(3) => TimeUnit::Millisecond,
472                    Some(6) => TimeUnit::Microsecond,
473                    Some(9) => TimeUnit::Nanosecond,
474                    _ => {
475                        return Err(Error::invalid_input(
476                            format!("Unsupported datetime resolution: {:?}", resolution),
477                            location!(),
478                        ));
479                    }
480                };
481                Ok(ArrowDataType::Timestamp(time_unit, None))
482            }
483            SQLDataType::Datetime(resolution) => {
484                let time_unit = match resolution {
485                    None => TimeUnit::Microsecond,
486                    Some(0) => TimeUnit::Second,
487                    Some(3) => TimeUnit::Millisecond,
488                    Some(6) => TimeUnit::Microsecond,
489                    Some(9) => TimeUnit::Nanosecond,
490                    _ => {
491                        return Err(Error::invalid_input(
492                            format!("Unsupported datetime resolution: {:?}", resolution),
493                            location!(),
494                        ));
495                    }
496                };
497                Ok(ArrowDataType::Timestamp(time_unit, None))
498            }
499            SQLDataType::Decimal(number_info) => match number_info {
500                ExactNumberInfo::PrecisionAndScale(precision, scale) => {
501                    Ok(ArrowDataType::Decimal128(*precision as u8, *scale as i8))
502                }
503                _ => Err(Error::invalid_input(
504                    format!(
505                        "Must provide precision and scale for decimal: {:?}",
506                        number_info
507                    ),
508                    location!(),
509                )),
510            },
511            _ => Err(Error::invalid_input(
512                format!(
513                    "Unsupported data type: {:?}. Supported types: {:?}",
514                    data_type, SUPPORTED_TYPES
515                ),
516                location!(),
517            )),
518        }
519    }
520
521    fn plan_field_access(&self, mut field_access_expr: RawFieldAccessExpr) -> Result<Expr> {
522        let df_schema = DFSchema::try_from(self.schema.as_ref().clone())?;
523        for planner in self.context_provider.get_expr_planners() {
524            match planner.plan_field_access(field_access_expr, &df_schema)? {
525                PlannerResult::Planned(expr) => return Ok(expr),
526                PlannerResult::Original(expr) => {
527                    field_access_expr = expr;
528                }
529            }
530        }
531        Err(Error::invalid_input(
532            "Field access could not be planned",
533            location!(),
534        ))
535    }
536
    /// Recursively convert a sqlparser AST expression into a DataFusion
    /// logical [`Expr`].
    ///
    /// This is the planner's core dispatcher: operators, literal values, data
    /// types, functions, and field access are delegated to the sibling helper
    /// methods; unsupported node kinds produce an invalid-input error.
    fn parse_sql_expr(&self, expr: &SQLExpr) -> Result<Expr> {
        match expr {
            SQLExpr::Identifier(id) => {
                // Users can pass string literals wrapped in `"`.
                // (Normally SQL only allows single quotes.)
                if id.quote_style == Some('"') {
                    Ok(Expr::Literal(ScalarValue::Utf8(Some(id.value.clone()))))
                // Users can wrap identifiers with ` to reference non-standard
                // names, such as uppercase or spaces.
                } else if id.quote_style == Some('`') {
                    Ok(Expr::Column(Column::from_name(id.value.clone())))
                } else {
                    Ok(Self::column(vec![id.clone()].as_slice()))
                }
            }
            // e.g. `a.b.c` — nested field reference.
            SQLExpr::CompoundIdentifier(ids) => Ok(Self::column(ids.as_slice())),
            SQLExpr::BinaryOp { left, op, right } => self.binary_expr(left, op, right),
            SQLExpr::UnaryOp { op, expr } => self.unary_expr(op, expr),
            SQLExpr::Value(value) => self.value(value),
            // Array literal, e.g. `[1, -2, 3]`. Every element must be a
            // literal (optionally negated number); elements are coerced to a
            // single consistent datatype and packed into a one-row ListArray.
            SQLExpr::Array(SQLArray { elem, .. }) => {
                let mut values = vec![];

                let array_literal_error = |pos: usize, value: &_| {
                    Err(Error::invalid_input(
                        format!(
                            "Expected a literal value in array, instead got {} at position {}",
                            value, pos
                        ),
                        location!(),
                    ))
                };

                for (pos, expr) in elem.iter().enumerate() {
                    match expr {
                        SQLExpr::Value(value) => {
                            if let Expr::Literal(value) = self.value(value)? {
                                values.push(value);
                            } else {
                                return array_literal_error(pos, expr);
                            }
                        }
                        // Negative numeric literal inside the array.
                        SQLExpr::UnaryOp {
                            op: UnaryOperator::Minus,
                            expr,
                        } => {
                            if let SQLExpr::Value(Value::Number(number, _)) = expr.as_ref() {
                                if let Expr::Literal(value) = self.number(number, true)? {
                                    values.push(value);
                                } else {
                                    return array_literal_error(pos, expr);
                                }
                            } else {
                                return array_literal_error(pos, expr);
                            }
                        }
                        _ => {
                            return array_literal_error(pos, expr);
                        }
                    }
                }

                // Determine the item type from the first element and coerce
                // the rest to it; an empty array gets a Null item type.
                let field = if !values.is_empty() {
                    let data_type = values[0].data_type();

                    for value in &mut values {
                        if value.data_type() != data_type {
                            *value = safe_coerce_scalar(value, &data_type).ok_or_else(|| Error::invalid_input(
                                format!("Array expressions must have a consistent datatype. Expected: {}, got: {}", data_type, value.data_type()),
                                location!()
                            ))?;
                        }
                    }
                    Field::new("item", data_type, true)
                } else {
                    Field::new("item", ArrowDataType::Null, true)
                };

                // Materialize the scalars and concatenate them into a single
                // list literal.
                let values = values
                    .into_iter()
                    .map(|v| v.to_array().map_err(Error::from))
                    .collect::<Result<Vec<_>>>()?;
                let array_refs = values.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
                let values = concat(&array_refs)?;
                let values = ListArray::try_new(
                    field.into(),
                    OffsetBuffer::from_lengths([values.len()]),
                    values,
                    None,
                )?;

                Ok(Expr::Literal(ScalarValue::List(Arc::new(values))))
            }
            // For example, DATE '2020-01-01'
            SQLExpr::TypedString { data_type, value } => {
                Ok(Expr::Cast(datafusion::logical_expr::Cast {
                    expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some(value.clone())))),
                    data_type: self.parse_type(data_type)?,
                }))
            }
            SQLExpr::IsFalse(expr) => Ok(Expr::IsFalse(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotFalse(expr) => Ok(Expr::IsNotFalse(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsTrue(expr) => Ok(Expr::IsTrue(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotTrue(expr) => Ok(Expr::IsNotTrue(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNull(expr) => Ok(Expr::IsNull(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotNull(expr) => Ok(Expr::IsNotNull(Box::new(self.parse_sql_expr(expr)?))),
            // `expr [NOT] IN (a, b, ...)`
            SQLExpr::InList {
                expr,
                list,
                negated,
            } => {
                let value_expr = self.parse_sql_expr(expr)?;
                let list_exprs = list
                    .iter()
                    .map(|e| self.parse_sql_expr(e))
                    .collect::<Result<Vec<_>>>()?;
                Ok(value_expr.in_list(list_exprs, *negated))
            }
            // Parenthesized expression — parentheses carry no semantics here.
            SQLExpr::Nested(inner) => self.parse_sql_expr(inner.as_ref()),
            SQLExpr::Function(_) => self.parse_function(expr.clone()),
            // ILIKE = case-insensitive LIKE (last Like::new arg is
            // case_insensitive).
            SQLExpr::ILike {
                negated,
                expr,
                pattern,
                escape_char,
                any: _,
            } => Ok(Expr::Like(Like::new(
                *negated,
                Box::new(self.parse_sql_expr(expr)?),
                Box::new(self.parse_sql_expr(pattern)?),
                escape_char.as_ref().and_then(|c| c.chars().next()),
                true,
            ))),
            SQLExpr::Like {
                negated,
                expr,
                pattern,
                escape_char,
                any: _,
            } => Ok(Expr::Like(Like::new(
                *negated,
                Box::new(self.parse_sql_expr(expr)?),
                Box::new(self.parse_sql_expr(pattern)?),
                escape_char.as_ref().and_then(|c| c.chars().next()),
                false,
            ))),
            SQLExpr::Cast {
                expr, data_type, ..
            } => Ok(Expr::Cast(datafusion::logical_expr::Cast {
                expr: Box::new(self.parse_sql_expr(expr)?),
                data_type: self.parse_type(data_type)?,
            })),
            SQLExpr::JsonAccess { .. } => Err(Error::invalid_input(
                "JSON access is not supported",
                location!(),
            )),
            // Chained field/index access, e.g. `a.b['c'][0]`; each link is
            // planned left-to-right via the registered expr planners.
            SQLExpr::CompoundFieldAccess { root, access_chain } => {
                let mut expr = self.parse_sql_expr(root)?;

                for access in access_chain {
                    let field_access = match access {
                        // x.y or x['y']
                        AccessExpr::Dot(SQLExpr::Identifier(Ident { value: s, .. }))
                        | AccessExpr::Subscript(Subscript::Index {
                            index:
                                SQLExpr::Value(
                                    Value::SingleQuotedString(s) | Value::DoubleQuotedString(s),
                                ),
                        }) => GetFieldAccess::NamedStructField {
                            name: ScalarValue::from(s.as_str()),
                        },
                        AccessExpr::Subscript(Subscript::Index { index }) => {
                            let key = Box::new(self.parse_sql_expr(index)?);
                            GetFieldAccess::ListIndex { key }
                        }
                        AccessExpr::Subscript(Subscript::Slice { .. }) => {
                            return Err(Error::invalid_input(
                                "Slice subscript is not supported",
                                location!(),
                            ));
                        }
                        _ => {
                            // Handle other cases like JSON access
                            // Note: JSON access is not supported in lance
                            return Err(Error::invalid_input(
                                "Only dot notation or index access is supported for field access",
                                location!(),
                            ));
                        }
                    };

                    let field_access_expr = RawFieldAccessExpr { expr, field_access };
                    expr = self.plan_field_access(field_access_expr)?;
                }

                Ok(expr)
            }
            SQLExpr::Between {
                expr,
                negated,
                low,
                high,
            } => {
                // Parse the main expression and bounds
                let expr = self.parse_sql_expr(expr)?;
                let low = self.parse_sql_expr(low)?;
                let high = self.parse_sql_expr(high)?;

                let between = Expr::Between(Between::new(
                    Box::new(expr),
                    *negated,
                    Box::new(low),
                    Box::new(high),
                ));
                Ok(between)
            }
            _ => Err(Error::invalid_input(
                format!("Expression '{expr}' is not supported SQL in lance"),
                location!(),
            )),
        }
    }
758
759    /// Create Logical [Expr] from a SQL filter clause.
760    ///
761    /// Note: the returned expression must be passed through [optimize_expr()]
762    /// before being passed to [create_physical_expr()].
763    pub fn parse_filter(&self, filter: &str) -> Result<Expr> {
764        // Allow sqlparser to parse filter as part of ONE SQL statement.
765        let ast_expr = parse_sql_filter(filter)?;
766        let expr = self.parse_sql_expr(&ast_expr)?;
767        let schema = Schema::try_from(self.schema.as_ref())?;
768        let resolved = resolve_expr(&expr, &schema)?;
769        coerce_filter_type_to_boolean(resolved)
770    }
771
    /// Create Logical [Expr] from a SQL expression.
    ///
    /// Note: the returned expression must be passed through [optimize_expr()]
    /// before being passed to [create_physical_expr()].
    pub fn parse_expr(&self, expr: &str) -> Result<Expr> {
        let ast_expr = parse_sql_expr(expr)?;
        let expr = self.parse_sql_expr(&ast_expr)?;
        // Resolve column references (including nested field paths) against
        // the planner's schema before handing the expression back.
        let schema = Schema::try_from(self.schema.as_ref())?;
        let resolved = resolve_expr(&expr, &schema)?;
        Ok(resolved)
    }
783
784    /// Try to decode bytes from hex literal string.
785    ///
786    /// Copied from datafusion because this is not public.
787    ///
788    /// TODO: use SqlToRel from Datafusion directly?
789    fn try_decode_hex_literal(s: &str) -> Option<Vec<u8>> {
790        let hex_bytes = s.as_bytes();
791        let mut decoded_bytes = Vec::with_capacity(hex_bytes.len().div_ceil(2));
792
793        let start_idx = hex_bytes.len() % 2;
794        if start_idx > 0 {
795            // The first byte is formed of only one char.
796            decoded_bytes.push(Self::try_decode_hex_char(hex_bytes[0])?);
797        }
798
799        for i in (start_idx..hex_bytes.len()).step_by(2) {
800            let high = Self::try_decode_hex_char(hex_bytes[i])?;
801            let low = Self::try_decode_hex_char(hex_bytes[i + 1])?;
802            decoded_bytes.push((high << 4) | low);
803        }
804
805        Some(decoded_bytes)
806    }
807
808    /// Try to decode a byte from a hex char.
809    ///
810    /// None will be returned if the input char is hex-invalid.
811    const fn try_decode_hex_char(c: u8) -> Option<u8> {
812        match c {
813            b'A'..=b'F' => Some(c - b'A' + 10),
814            b'a'..=b'f' => Some(c - b'a' + 10),
815            b'0'..=b'9' => Some(c - b'0'),
816            _ => None,
817        }
818    }
819
820    /// Optimize the filter expression and coerce data types.
821    pub fn optimize_expr(&self, expr: Expr) -> Result<Expr> {
822        let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
823
824        // DataFusion needs the simplify and coerce passes to be applied before
825        // expressions can be handled by the physical planner.
826        let props = ExecutionProps::default();
827        let simplify_context = SimplifyContext::new(&props).with_schema(df_schema.clone());
828        let simplifier =
829            datafusion::optimizer::simplify_expressions::ExprSimplifier::new(simplify_context);
830
831        let expr = simplifier.simplify(expr)?;
832        let expr = simplifier.coerce(expr, &df_schema)?;
833
834        Ok(expr)
835    }
836
837    /// Create the [`PhysicalExpr`] from a logical [`Expr`]
838    pub fn create_physical_expr(&self, expr: &Expr) -> Result<Arc<dyn PhysicalExpr>> {
839        let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
840
841        Ok(datafusion::physical_expr::create_physical_expr(
842            expr,
843            df_schema.as_ref(),
844            &Default::default(),
845        )?)
846    }
847
848    /// Collect the columns in the expression.
849    ///
850    /// The columns are returned in sorted order.
851    pub fn column_names_in_expr(expr: &Expr) -> Vec<String> {
852        let mut visitor = ColumnCapturingVisitor {
853            current_path: VecDeque::new(),
854            columns: BTreeSet::new(),
855        };
856        expr.visit(&mut visitor).unwrap();
857        visitor.columns.into_iter().collect()
858    }
859}
860
/// Tree visitor that collects the (possibly nested) column names referenced
/// by an expression; used by [Planner::column_names_in_expr].
struct ColumnCapturingVisitor {
    // Current column path. If this is empty, we are not in a column expression.
    // Field names from `get_field` accesses are pushed to the front while
    // descending, then drained into a dotted path once the base column is hit.
    current_path: VecDeque<String>,
    // Dot-separated column paths collected so far; BTreeSet keeps them sorted
    // and de-duplicated.
    columns: BTreeSet<String>,
}
866
impl TreeNodeVisitor<'_> for ColumnCapturingVisitor {
    type Node = Expr;

    /// Pre-order (top-down) visit. Nested accesses such as
    /// `get_field(get_field(col, "a"), "b")` are seen outermost-first, so
    /// field names are pushed to the *front* of `current_path`; when the base
    /// `Expr::Column` is reached the queue is drained to build `col.a.b`.
    fn f_down(&mut self, node: &Self::Node) -> DFResult<TreeNodeRecursion> {
        match node {
            Expr::Column(Column { name, .. }) => {
                // Base of a column reference: join the column name with any
                // pending field-name segments into a dotted path.
                let mut path = name.clone();
                for part in self.current_path.drain(..) {
                    path.push('.');
                    path.push_str(&part);
                }
                self.columns.insert(path);
                // `drain(..)` already emptied the queue; this is a defensive
                // no-op.
                self.current_path.clear();
            }
            Expr::ScalarFunction(udf) => {
                if udf.name() == GetFieldFunc::default().name() {
                    // `get_field(child, name)`: record the field name if it is
                    // a string literal; otherwise this is not a plain column
                    // path, so reset.
                    if let Some(name) = get_as_string_scalar_opt(&udf.args[1]) {
                        self.current_path.push_front(name.to_string())
                    } else {
                        self.current_path.clear();
                    }
                } else {
                    // Any other function call breaks the pending path.
                    self.current_path.clear();
                }
            }
            _ => {
                // Any other node kind breaks the pending path.
                self.current_path.clear();
            }
        }

        Ok(TreeNodeRecursion::Continue)
    }
}
900
901#[cfg(test)]
902mod tests {
903
904    use crate::logical_expr::ExprExt;
905
906    use super::*;
907
908    use arrow::datatypes::Float64Type;
909    use arrow_array::{
910        ArrayRef, BooleanArray, Float32Array, Int32Array, Int64Array, RecordBatch, StringArray,
911        StructArray, TimestampMicrosecondArray, TimestampMillisecondArray,
912        TimestampNanosecondArray, TimestampSecondArray,
913    };
914    use arrow_schema::{DataType, Fields, Schema};
915    use datafusion::{
916        logical_expr::{lit, Cast},
917        prelude::{array_element, get_field},
918    };
919    use datafusion_functions::core::expr_ext::FieldAccessor;
920
    #[test]
    fn test_parse_filter_simple() {
        // Schema with a scalar int, a nullable string, and a nested struct so
        // the filter can exercise dotted field access (`st.x`).
        let schema = Arc::new(Schema::new(vec![
            Field::new("i", DataType::Int32, false),
            Field::new("s", DataType::Utf8, true),
            Field::new(
                "st",
                DataType::Struct(Fields::from(vec![
                    Field::new("x", DataType::Float32, false),
                    Field::new("y", DataType::Float32, false),
                ])),
                true,
            ),
        ]));

        let planner = Planner::new(schema.clone());

        let expected = col("i")
            .gt(lit(3_i32))
            .and(col("st").field_newstyle("x").lt_eq(lit(5.0_f32)))
            .and(
                col("s")
                    .eq(lit("str-4"))
                    .or(col("s").in_list(vec![lit("str-4"), lit("str-5")], false)),
            );

        // `==` form of equality
        let expr = planner
            .parse_filter("i > 3 AND st.x <= 5.0 AND (s == 'str-4' OR s in ('str-4', 'str-5'))")
            .unwrap();
        assert_eq!(expr, expected);

        // `=` form of equality
        let expr = planner
            .parse_filter("i > 3 AND st.x <= 5.0 AND (s = 'str-4' OR s in ('str-4', 'str-5'))")
            .unwrap();

        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        // Rows: i = 0..10, s = "str-0".."str-9", st.x = 0.0..9.0,
        // st.y = 0.0, 10.0, ..., 90.0.
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
                Arc::new(StringArray::from_iter_values(
                    (0..10).map(|v| format!("str-{}", v)),
                )),
                Arc::new(StructArray::from(vec![
                    (
                        Arc::new(Field::new("x", DataType::Float32, false)),
                        Arc::new(Float32Array::from_iter_values((0..10).map(|v| v as f32)))
                            as ArrayRef,
                    ),
                    (
                        Arc::new(Field::new("y", DataType::Float32, false)),
                        Arc::new(Float32Array::from_iter_values(
                            (0..10).map(|v| (v * 10) as f32),
                        )),
                    ),
                ])),
            ],
        )
        .unwrap();
        // Only rows 4 and 5 satisfy i > 3, st.x <= 5.0 and s ∈ {str-4, str-5}.
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, false, true, true, false, false, false, false
            ])
        );
    }
991
    #[test]
    fn test_nested_col_refs() {
        // Schema: a top-level string plus a doubly-nested struct so one- and
        // two-level field paths (st.s1, st.st.s2) can be tested.
        let schema = Arc::new(Schema::new(vec![
            Field::new("s0", DataType::Utf8, true),
            Field::new(
                "st",
                DataType::Struct(Fields::from(vec![
                    Field::new("s1", DataType::Utf8, true),
                    Field::new(
                        "st",
                        DataType::Struct(Fields::from(vec![Field::new(
                            "s2",
                            DataType::Utf8,
                            true,
                        )])),
                        true,
                    ),
                ])),
                true,
            ),
        ]));

        let planner = Planner::new(schema);

        // Parses `<expr> = 'val'` and asserts that the left-hand side of the
        // resulting Eq binary expression equals `expected`.
        fn assert_column_eq(planner: &Planner, expr: &str, expected: &Expr) {
            let expr = planner.parse_filter(&format!("{expr} = 'val'")).unwrap();
            assert!(matches!(
                expr,
                Expr::BinaryExpr(BinaryExpr {
                    left: _,
                    op: Operator::Eq,
                    right: _
                })
            ));
            if let Expr::BinaryExpr(BinaryExpr { left, .. }) = expr {
                assert_eq!(left.as_ref(), expected);
            }
        }

        // Plain column reference, bare and backtick-quoted.
        let expected = Expr::Column(Column::new_unqualified("s0"));
        assert_column_eq(&planner, "s0", &expected);
        assert_column_eq(&planner, "`s0`", &expected);

        // One level of nesting resolves to get_field(st, "s1").
        let expected = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
            args: vec![
                Expr::Column(Column::new_unqualified("st")),
                Expr::Literal(ScalarValue::Utf8(Some("s1".to_string()))),
            ],
        });
        assert_column_eq(&planner, "st.s1", &expected);
        assert_column_eq(&planner, "`st`.`s1`", &expected);
        assert_column_eq(&planner, "st.`s1`", &expected);

        // Two levels of nesting resolve to nested get_field calls.
        let expected = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
            args: vec![
                Expr::ScalarFunction(ScalarFunction {
                    func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
                    args: vec![
                        Expr::Column(Column::new_unqualified("st")),
                        Expr::Literal(ScalarValue::Utf8(Some("st".to_string()))),
                    ],
                }),
                Expr::Literal(ScalarValue::Utf8(Some("s2".to_string()))),
            ],
        });

        assert_column_eq(&planner, "st.st.s2", &expected);
        assert_column_eq(&planner, "`st`.`st`.`s2`", &expected);
        assert_column_eq(&planner, "st.st.`s2`", &expected);
        // Subscript syntax is equivalent to dot notation.
        assert_column_eq(&planner, "st['st'][\"s2\"]", &expected);
    }
1065
1066    #[test]
1067    fn test_nested_list_refs() {
1068        let schema = Arc::new(Schema::new(vec![Field::new(
1069            "l",
1070            DataType::List(Arc::new(Field::new(
1071                "item",
1072                DataType::Struct(Fields::from(vec![Field::new("f1", DataType::Utf8, true)])),
1073                true,
1074            ))),
1075            true,
1076        )]));
1077
1078        let planner = Planner::new(schema);
1079
1080        let expected = array_element(col("l"), lit(0_i64));
1081        let expr = planner.parse_expr("l[0]").unwrap();
1082        assert_eq!(expr, expected);
1083
1084        let expected = get_field(array_element(col("l"), lit(0_i64)), "f1");
1085        let expr = planner.parse_expr("l[0]['f1']").unwrap();
1086        assert_eq!(expr, expected);
1087
1088        // FIXME: This should work, but sqlparser doesn't recognize anything
1089        // after the period for some reason.
1090        // let expr = planner.parse_expr("l[0].f1").unwrap();
1091        // assert_eq!(expr, expected);
1092    }
1093
    #[test]
    fn test_negative_expressions() {
        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));

        let planner = Planner::new(schema.clone());

        // Unary minus must apply both to literals (-3) and to parenthesized
        // sub-expressions (-(-5 + 3), i.e. 2).
        let expected = col("x")
            .gt(lit(-3_i64))
            .and(col("x").lt(-(lit(-5_i64) + lit(3_i64))));

        let expr = planner.parse_filter("x > -3 AND x < -(-5 + 3)").unwrap();

        assert_eq!(expr, expected);

        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        // x = -5..4; the filter selects -3 < x < 2, i.e. {-2, -1, 0, 1}.
        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(Int64Array::from_iter_values(-5..5)) as ArrayRef],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, true, true, true, true, false, false, false
            ])
        );
    }
1123
1124    #[test]
1125    fn test_negative_array_expressions() {
1126        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
1127
1128        let planner = Planner::new(schema);
1129
1130        let expected = Expr::Literal(ScalarValue::List(Arc::new(
1131            ListArray::from_iter_primitive::<Float64Type, _, _>(vec![Some(
1132                [-1_f64, -2.0, -3.0, -4.0, -5.0].map(Some),
1133            )]),
1134        )));
1135
1136        let expr = planner
1137            .parse_expr("[-1.0, -2.0, -3.0, -4.0, -5.0]")
1138            .unwrap();
1139
1140        assert_eq!(expr, expected);
1141    }
1142
1143    #[test]
1144    fn test_sql_like() {
1145        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1146
1147        let planner = Planner::new(schema.clone());
1148
1149        let expected = col("s").like(lit("str-4"));
1150        // single quote
1151        let expr = planner.parse_filter("s LIKE 'str-4'").unwrap();
1152        assert_eq!(expr, expected);
1153        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1154
1155        let batch = RecordBatch::try_new(
1156            schema,
1157            vec![Arc::new(StringArray::from_iter_values(
1158                (0..10).map(|v| format!("str-{}", v)),
1159            ))],
1160        )
1161        .unwrap();
1162        let predicates = physical_expr.evaluate(&batch).unwrap();
1163        assert_eq!(
1164            predicates.into_array(0).unwrap().as_ref(),
1165            &BooleanArray::from(vec![
1166                false, false, false, false, true, false, false, false, false, false
1167            ])
1168        );
1169    }
1170
1171    #[test]
1172    fn test_not_like() {
1173        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1174
1175        let planner = Planner::new(schema.clone());
1176
1177        let expected = col("s").not_like(lit("str-4"));
1178        // single quote
1179        let expr = planner.parse_filter("s NOT LIKE 'str-4'").unwrap();
1180        assert_eq!(expr, expected);
1181        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1182
1183        let batch = RecordBatch::try_new(
1184            schema,
1185            vec![Arc::new(StringArray::from_iter_values(
1186                (0..10).map(|v| format!("str-{}", v)),
1187            ))],
1188        )
1189        .unwrap();
1190        let predicates = physical_expr.evaluate(&batch).unwrap();
1191        assert_eq!(
1192            predicates.into_array(0).unwrap().as_ref(),
1193            &BooleanArray::from(vec![
1194                true, true, true, true, false, true, true, true, true, true
1195            ])
1196        );
1197    }
1198
1199    #[test]
1200    fn test_sql_is_in() {
1201        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1202
1203        let planner = Planner::new(schema.clone());
1204
1205        let expected = col("s").in_list(vec![lit("str-4"), lit("str-5")], false);
1206        // single quote
1207        let expr = planner.parse_filter("s IN ('str-4', 'str-5')").unwrap();
1208        assert_eq!(expr, expected);
1209        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1210
1211        let batch = RecordBatch::try_new(
1212            schema,
1213            vec![Arc::new(StringArray::from_iter_values(
1214                (0..10).map(|v| format!("str-{}", v)),
1215            ))],
1216        )
1217        .unwrap();
1218        let predicates = physical_expr.evaluate(&batch).unwrap();
1219        assert_eq!(
1220            predicates.into_array(0).unwrap().as_ref(),
1221            &BooleanArray::from(vec![
1222                false, false, false, false, true, true, false, false, false, false
1223            ])
1224        );
1225    }
1226
    #[test]
    fn test_sql_is_null() {
        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));

        let planner = Planner::new(schema.clone());

        let expected = col("s").is_null();
        let expr = planner.parse_filter("s IS NULL").unwrap();
        assert_eq!(expr, expected);
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        // s is non-null only at indices divisible by 3 (0, 3, 6, 9).
        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(StringArray::from_iter((0..10).map(|v| {
                if v % 3 == 0 {
                    Some(format!("str-{}", v))
                } else {
                    None
                }
            })))],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, true, true, false, true, true, false, true, true, false
            ])
        );

        // IS NOT NULL must produce the exact complement on the same batch.
        let expr = planner.parse_filter("s IS NOT NULL").unwrap();
        let physical_expr = planner.create_physical_expr(&expr).unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                true, false, false, true, false, false, true, false, false, true,
            ])
        );
    }
1267
1268    #[test]
1269    fn test_sql_invert() {
1270        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Boolean, true)]));
1271
1272        let planner = Planner::new(schema.clone());
1273
1274        let expr = planner.parse_filter("NOT s").unwrap();
1275        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1276
1277        let batch = RecordBatch::try_new(
1278            schema,
1279            vec![Arc::new(BooleanArray::from_iter(
1280                (0..10).map(|v| Some(v % 3 == 0)),
1281            ))],
1282        )
1283        .unwrap();
1284        let predicates = physical_expr.evaluate(&batch).unwrap();
1285        assert_eq!(
1286            predicates.into_array(0).unwrap().as_ref(),
1287            &BooleanArray::from(vec![
1288                false, true, true, false, true, true, false, true, true, false
1289            ])
1290        );
1291    }
1292
    #[test]
    fn test_sql_cast() {
        // Each case pairs a CAST filter with the Arrow type its SQL type name
        // must map to, covering timestamp precisions, decimal, floats, signed
        // and unsigned integers, boolean, and string.
        let cases = &[
            (
                "x = cast('2021-01-01 00:00:00' as timestamp)",
                ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00' as timestamp(0))",
                ArrowDataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00.123' as timestamp(9))",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00.123' as datetime(9))",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            ("x = cast('2021-01-01' as date)", ArrowDataType::Date32),
            (
                "x = cast('1.238' as decimal(9,3))",
                ArrowDataType::Decimal128(9, 3),
            ),
            ("x = cast(1 as float)", ArrowDataType::Float32),
            ("x = cast(1 as double)", ArrowDataType::Float64),
            ("x = cast(1 as tinyint)", ArrowDataType::Int8),
            ("x = cast(1 as smallint)", ArrowDataType::Int16),
            ("x = cast(1 as int)", ArrowDataType::Int32),
            ("x = cast(1 as integer)", ArrowDataType::Int32),
            ("x = cast(1 as bigint)", ArrowDataType::Int64),
            ("x = cast(1 as tinyint unsigned)", ArrowDataType::UInt8),
            ("x = cast(1 as smallint unsigned)", ArrowDataType::UInt16),
            ("x = cast(1 as int unsigned)", ArrowDataType::UInt32),
            ("x = cast(1 as integer unsigned)", ArrowDataType::UInt32),
            ("x = cast(1 as bigint unsigned)", ArrowDataType::UInt64),
            ("x = cast(1 as boolean)", ArrowDataType::Boolean),
            ("x = cast(1 as string)", ArrowDataType::Utf8),
        ];

        for (sql, expected_data_type) in cases {
            // Give column x the same type as the cast target so the filter
            // type-checks against the schema.
            let schema = Arc::new(Schema::new(vec![Field::new(
                "x",
                expected_data_type.clone(),
                true,
            )]));
            let planner = Planner::new(schema.clone());
            let expr = planner.parse_filter(sql).unwrap();

            // Get the thing after 'cast(` but before ' as'.
            let expected_value_str = sql
                .split("cast(")
                .nth(1)
                .unwrap()
                .split(" as")
                .next()
                .unwrap();
            // Remove any quote marks
            let expected_value_str = expected_value_str.trim_matches('\'');

            // The parsed filter must be `x = CAST(<literal> AS <type>)` with
            // the literal left untouched (Utf8 string or Int64 `1`) and the
            // cast carrying the expected Arrow type.
            match expr {
                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
                    Expr::Cast(Cast { expr, data_type }) => {
                        match expr.as_ref() {
                            Expr::Literal(ScalarValue::Utf8(Some(value_str))) => {
                                assert_eq!(value_str, expected_value_str);
                            }
                            Expr::Literal(ScalarValue::Int64(Some(value))) => {
                                assert_eq!(*value, 1);
                            }
                            _ => panic!("Expected cast to be applied to literal"),
                        }
                        assert_eq!(data_type, expected_data_type);
                    }
                    _ => panic!("Expected right to be a cast"),
                },
                _ => panic!("Expected binary expression"),
            }
        }
    }
1373
    #[test]
    fn test_sql_literals() {
        // Typed literals (`timestamp '...'`, `date '...'`, `decimal(p,s) '...'`)
        // must parse to a cast of the quoted string to the matching Arrow type.
        let cases = &[
            (
                "x = timestamp '2021-01-01 00:00:00'",
                ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
            ),
            (
                "x = timestamp(0) '2021-01-01 00:00:00'",
                ArrowDataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                "x = timestamp(9) '2021-01-01 00:00:00.123'",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            ("x = date '2021-01-01'", ArrowDataType::Date32),
            ("x = decimal(9,3) '1.238'", ArrowDataType::Decimal128(9, 3)),
        ];

        for (sql, expected_data_type) in cases {
            // Column x gets the literal's target type so the filter
            // type-checks.
            let schema = Arc::new(Schema::new(vec![Field::new(
                "x",
                expected_data_type.clone(),
                true,
            )]));
            let planner = Planner::new(schema.clone());
            let expr = planner.parse_filter(sql).unwrap();

            // The quoted part of the SQL is the expected inner string literal.
            let expected_value_str = sql.split('\'').nth(1).unwrap();

            match expr {
                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
                    Expr::Cast(Cast { expr, data_type }) => {
                        match expr.as_ref() {
                            Expr::Literal(ScalarValue::Utf8(Some(value_str))) => {
                                assert_eq!(value_str, expected_value_str);
                            }
                            _ => panic!("Expected cast to be applied to literal"),
                        }
                        assert_eq!(data_type, expected_data_type);
                    }
                    _ => panic!("Expected right to be a cast"),
                },
                _ => panic!("Expected binary expression"),
            }
        }
    }
1421
    #[test]
    fn test_sql_array_literals() {
        // The same array literal should coerce to either List or
        // FixedSizeList, depending on the type of the column it is compared
        // against.
        let cases = [
            (
                "x = [1, 2, 3]",
                ArrowDataType::List(Arc::new(Field::new("item", ArrowDataType::Int64, true))),
            ),
            (
                "x = [1, 2, 3]",
                ArrowDataType::FixedSizeList(
                    Arc::new(Field::new("item", ArrowDataType::Int64, true)),
                    3,
                ),
            ),
        ];

        for (sql, expected_data_type) in cases {
            let schema = Arc::new(Schema::new(vec![Field::new(
                "x",
                expected_data_type.clone(),
                true,
            )]));
            let planner = Planner::new(schema.clone());
            let expr = planner.parse_filter(sql).unwrap();
            // optimize_expr runs the coerce pass that rewrites the list
            // literal to the column's type.
            let expr = planner.optimize_expr(expr).unwrap();

            match expr {
                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
                    Expr::Literal(value) => {
                        assert_eq!(&value.data_type(), &expected_data_type);
                    }
                    _ => panic!("Expected right to be a literal"),
                },
                _ => panic!("Expected binary expression"),
            }
        }
    }
1459
    #[test]
    fn test_sql_between() {
        use arrow_array::{Float64Array, Int32Array, TimestampMicrosecondArray};
        use arrow_schema::{DataType, Field, Schema, TimeUnit};
        use std::sync::Arc;

        let schema = Arc::new(Schema::new(vec![
            Field::new("x", DataType::Int32, false),
            Field::new("y", DataType::Float64, false),
            Field::new(
                "ts",
                DataType::Timestamp(TimeUnit::Microsecond, None),
                false,
            ),
        ]));

        let planner = Planner::new(schema.clone());

        // Test integer BETWEEN
        let expr = planner
            .parse_filter("x BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
            .unwrap();
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        // Create timestamp array with values representing:
        // 2024-01-01 00:00:00 to 2024-01-01 00:00:09 (in microseconds)
        let base_ts = 1704067200000000_i64; // 2024-01-01 00:00:00
        let ts_array = TimestampMicrosecondArray::from_iter_values(
            (0..10).map(|i| base_ts + i * 1_000_000), // Each value is 1 second apart
        );

        // x = 0..10, y = 0.0..9.0, ts = ten consecutive seconds.
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
                Arc::new(Float64Array::from_iter_values((0..10).map(|v| v as f64))),
                Arc::new(ts_array),
            ],
        )
        .unwrap();

        // BETWEEN is inclusive on both bounds: 3 <= x <= 7.
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, true, true, true, true, true, false, false
            ])
        );

        // Test NOT BETWEEN: the exact complement of the BETWEEN result.
        let expr = planner
            .parse_filter("x NOT BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
            .unwrap();
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                true, true, true, false, false, false, false, false, true, true
            ])
        );

        // Test floating point BETWEEN: selects 2.5 <= y <= 6.5, i.e. 3..=6.
        let expr = planner.parse_filter("y BETWEEN 2.5 AND 6.5").unwrap();
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, true, true, true, true, false, false, false
            ])
        );

        // Test timestamp BETWEEN: inclusive on seconds 3 through 7.
        let expr = planner
            .parse_filter(
                "ts BETWEEN timestamp '2024-01-01 00:00:03' AND timestamp '2024-01-01 00:00:07'",
            )
            .unwrap();
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, true, true, true, true, true, false, false
            ])
        );
    }
1551
1552    #[test]
1553    fn test_sql_comparison() {
1554        // Create a batch with all data types
1555        let batch: Vec<(&str, ArrayRef)> = vec![
1556            (
1557                "timestamp_s",
1558                Arc::new(TimestampSecondArray::from_iter_values(0..10)),
1559            ),
1560            (
1561                "timestamp_ms",
1562                Arc::new(TimestampMillisecondArray::from_iter_values(0..10)),
1563            ),
1564            (
1565                "timestamp_us",
1566                Arc::new(TimestampMicrosecondArray::from_iter_values(0..10)),
1567            ),
1568            (
1569                "timestamp_ns",
1570                Arc::new(TimestampNanosecondArray::from_iter_values(4995..5005)),
1571            ),
1572        ];
1573        let batch = RecordBatch::try_from_iter(batch).unwrap();
1574
1575        let planner = Planner::new(batch.schema());
1576
1577        // Each expression is meant to select the final 5 rows
1578        let expressions = &[
1579            "timestamp_s >= TIMESTAMP '1970-01-01 00:00:05'",
1580            "timestamp_ms >= TIMESTAMP '1970-01-01 00:00:00.005'",
1581            "timestamp_us >= TIMESTAMP '1970-01-01 00:00:00.000005'",
1582            "timestamp_ns >= TIMESTAMP '1970-01-01 00:00:00.000005'",
1583        ];
1584
1585        let expected: ArrayRef = Arc::new(BooleanArray::from_iter(
1586            std::iter::repeat_n(Some(false), 5).chain(std::iter::repeat_n(Some(true), 5)),
1587        ));
1588        for expression in expressions {
1589            // convert to physical expression
1590            let logical_expr = planner.parse_filter(expression).unwrap();
1591            let logical_expr = planner.optimize_expr(logical_expr).unwrap();
1592            let physical_expr = planner.create_physical_expr(&logical_expr).unwrap();
1593
1594            // Evaluate and assert they have correct results
1595            let result = physical_expr.evaluate(&batch).unwrap();
1596            let result = result.into_array(batch.num_rows()).unwrap();
1597            assert_eq!(&expected, &result, "unexpected result for {}", expression);
1598        }
1599    }
1600
1601    #[test]
1602    fn test_columns_in_expr() {
1603        let expr = col("s0").gt(lit("value")).and(
1604            col("st")
1605                .field("st")
1606                .field("s2")
1607                .eq(lit("value"))
1608                .or(col("st")
1609                    .field("s1")
1610                    .in_list(vec![lit("value 1"), lit("value 2")], false)),
1611        );
1612
1613        let columns = Planner::column_names_in_expr(&expr);
1614        assert_eq!(columns, vec!["s0", "st.s1", "st.st.s2"]);
1615    }
1616
1617    #[test]
1618    fn test_parse_binary_expr() {
1619        let bin_str = "x'616263'";
1620
1621        let schema = Arc::new(Schema::new(vec![Field::new(
1622            "binary",
1623            DataType::Binary,
1624            true,
1625        )]));
1626        let planner = Planner::new(schema);
1627        let expr = planner.parse_expr(bin_str).unwrap();
1628        assert_eq!(
1629            expr,
1630            Expr::Literal(ScalarValue::Binary(Some(vec![b'a', b'b', b'c'])))
1631        );
1632    }
1633
1634    #[test]
1635    fn test_lance_context_provider_expr_planners() {
1636        let ctx_provider = LanceContextProvider::default();
1637        assert!(!ctx_provider.get_expr_planners().is_empty());
1638    }
1639}