lance_datafusion/
planner.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors
5
6//! Exec plan planner
7
8use std::borrow::Cow;
9use std::collections::{BTreeSet, VecDeque};
10use std::sync::Arc;
11
12use crate::expr::safe_coerce_scalar;
13use crate::logical_expr::{coerce_filter_type_to_boolean, get_as_string_scalar_opt, resolve_expr};
14use crate::sql::{parse_sql_expr, parse_sql_filter};
15use arrow::compute::CastOptions;
16use arrow_array::ListArray;
17use arrow_buffer::OffsetBuffer;
18use arrow_schema::{DataType as ArrowDataType, Field, SchemaRef, TimeUnit};
19use arrow_select::concat::concat;
20use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
21use datafusion::common::DFSchema;
22use datafusion::config::ConfigOptions;
23use datafusion::error::Result as DFResult;
24use datafusion::execution::config::SessionConfig;
25use datafusion::execution::context::SessionState;
26use datafusion::execution::runtime_env::RuntimeEnvBuilder;
27use datafusion::execution::session_state::SessionStateBuilder;
28use datafusion::logical_expr::expr::ScalarFunction;
29use datafusion::logical_expr::planner::{ExprPlanner, PlannerResult, RawFieldAccessExpr};
30use datafusion::logical_expr::{
31    AggregateUDF, ColumnarValue, GetFieldAccess, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
32    WindowUDF,
33};
34use datafusion::optimizer::simplify_expressions::SimplifyContext;
35use datafusion::sql::planner::{ContextProvider, ParserOptions, PlannerContext, SqlToRel};
36use datafusion::sql::sqlparser::ast::{
37    AccessExpr, Array as SQLArray, BinaryOperator, DataType as SQLDataType, ExactNumberInfo,
38    Expr as SQLExpr, Function, FunctionArg, FunctionArgExpr, FunctionArguments, Ident, Subscript,
39    TimezoneInfo, UnaryOperator, Value,
40};
41use datafusion::{
42    common::Column,
43    logical_expr::{col, Between, BinaryExpr, Like, Operator},
44    physical_expr::execution_props::ExecutionProps,
45    physical_plan::PhysicalExpr,
46    prelude::Expr,
47    scalar::ScalarValue,
48};
49use datafusion_functions::core::getfield::GetFieldFunc;
50use lance_arrow::cast::cast_with_options;
51use lance_core::datatypes::Schema;
52use snafu::location;
53
54use lance_core::{Error, Result};
55
/// Scalar UDF (`_cast_list_f16`) that casts a `List`/`FixedSizeList` of
/// float32/float16 elements to the same list shape with `Float16` elements.
#[derive(Debug, Clone)]
struct CastListF16Udf {
    // Unary signature accepting any type; the argument type is validated in
    // `ScalarUDFImpl::return_type`.
    signature: Signature,
}
60
61impl CastListF16Udf {
62    pub fn new() -> Self {
63        Self {
64            signature: Signature::any(1, Volatility::Immutable),
65        }
66    }
67}
68
69impl ScalarUDFImpl for CastListF16Udf {
70    fn as_any(&self) -> &dyn std::any::Any {
71        self
72    }
73
74    fn name(&self) -> &str {
75        "_cast_list_f16"
76    }
77
78    fn signature(&self) -> &Signature {
79        &self.signature
80    }
81
82    fn return_type(&self, arg_types: &[ArrowDataType]) -> DFResult<ArrowDataType> {
83        let input = &arg_types[0];
84        match input {
85            ArrowDataType::FixedSizeList(field, size) => {
86                if field.data_type() != &ArrowDataType::Float32
87                    && field.data_type() != &ArrowDataType::Float16
88                {
89                    return Err(datafusion::error::DataFusionError::Execution(
90                        "cast_list_f16 only supports list of float32 or float16".to_string(),
91                    ));
92                }
93                Ok(ArrowDataType::FixedSizeList(
94                    Arc::new(Field::new(
95                        field.name(),
96                        ArrowDataType::Float16,
97                        field.is_nullable(),
98                    )),
99                    *size,
100                ))
101            }
102            ArrowDataType::List(field) => {
103                if field.data_type() != &ArrowDataType::Float32
104                    && field.data_type() != &ArrowDataType::Float16
105                {
106                    return Err(datafusion::error::DataFusionError::Execution(
107                        "cast_list_f16 only supports list of float32 or float16".to_string(),
108                    ));
109                }
110                Ok(ArrowDataType::List(Arc::new(Field::new(
111                    field.name(),
112                    ArrowDataType::Float16,
113                    field.is_nullable(),
114                ))))
115            }
116            _ => Err(datafusion::error::DataFusionError::Execution(
117                "cast_list_f16 only supports FixedSizeList/List arguments".to_string(),
118            )),
119        }
120    }
121
122    fn invoke(&self, args: &[ColumnarValue]) -> DFResult<ColumnarValue> {
123        let ColumnarValue::Array(arr) = &args[0] else {
124            return Err(datafusion::error::DataFusionError::Execution(
125                "cast_list_f16 only supports array arguments".to_string(),
126            ));
127        };
128
129        let to_type = match arr.data_type() {
130            ArrowDataType::FixedSizeList(field, size) => ArrowDataType::FixedSizeList(
131                Arc::new(Field::new(
132                    field.name(),
133                    ArrowDataType::Float16,
134                    field.is_nullable(),
135                )),
136                *size,
137            ),
138            ArrowDataType::List(field) => ArrowDataType::List(Arc::new(Field::new(
139                field.name(),
140                ArrowDataType::Float16,
141                field.is_nullable(),
142            ))),
143            _ => {
144                return Err(datafusion::error::DataFusionError::Execution(
145                    "cast_list_f16 only supports array arguments".to_string(),
146                ));
147            }
148        };
149
150        let res = cast_with_options(arr.as_ref(), &to_type, &CastOptions::default())?;
151        Ok(ColumnarValue::Array(res))
152    }
153}
154
// Adapter that instructs datafusion how lance expects expressions to be interpreted
struct LanceContextProvider {
    // Configuration returned from `ContextProvider::options`
    options: datafusion::config::ConfigOptions,
    // Default session state, used to look up built-in scalar/aggregate/window functions
    state: SessionState,
    // Expression planners captured from the builder in `Default::default`
    // (the built SessionState does not expose them)
    expr_planners: Vec<Arc<dyn ExprPlanner>>,
}
161
162impl Default for LanceContextProvider {
163    fn default() -> Self {
164        let config = SessionConfig::new();
165        let runtime = RuntimeEnvBuilder::new().build_arc().unwrap();
166        let mut state_builder = SessionStateBuilder::new()
167            .with_config(config)
168            .with_runtime_env(runtime)
169            .with_default_features();
170
171        // SessionState does not expose expr_planners, so we need to get the default ones from
172        // the builder and store them to return from get_expr_planners
173
174        // unwrap safe because with_default_features sets expr_planners
175        let expr_planners = state_builder.expr_planners().as_ref().unwrap().clone();
176
177        Self {
178            options: ConfigOptions::default(),
179            state: state_builder.build(),
180            expr_planners,
181        }
182    }
183}
184
impl ContextProvider for LanceContextProvider {
    // Lance expressions operate over a single implicit table, so any named
    // table reference inside the SQL is rejected.
    fn get_table_source(
        &self,
        name: datafusion::sql::TableReference,
    ) -> DFResult<Arc<dyn datafusion::logical_expr::TableSource>> {
        Err(datafusion::error::DataFusionError::NotImplemented(format!(
            "Attempt to reference inner table {} not supported",
            name
        )))
    }

    // Aggregate/window/scalar lookups all delegate to the default session state.
    fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
        self.state.aggregate_functions().get(name).cloned()
    }

    fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
        self.state.window_functions().get(name).cloned()
    }

    fn get_function_meta(&self, f: &str) -> Option<Arc<ScalarUDF>> {
        match f {
            // TODO: cast should go thru CAST syntax instead of UDF
            // Going thru UDF makes it hard for the optimizer to find no-ops
            "_cast_list_f16" => Some(Arc::new(ScalarUDF::new_from_impl(CastListF16Udf::new()))),
            _ => self.state.scalar_functions().get(f).cloned(),
        }
    }

    fn get_variable_type(&self, _: &[String]) -> Option<ArrowDataType> {
        // Variables (things like @@LANGUAGE) not supported
        None
    }

    fn options(&self) -> &datafusion::config::ConfigOptions {
        &self.options
    }

    fn udf_names(&self) -> Vec<String> {
        self.state.scalar_functions().keys().cloned().collect()
    }

    fn udaf_names(&self) -> Vec<String> {
        self.state.aggregate_functions().keys().cloned().collect()
    }

    fn udwf_names(&self) -> Vec<String> {
        self.state.window_functions().keys().cloned().collect()
    }

    // Returns the planners captured at construction time; see the comment in
    // `Default::default` for why they are stored separately.
    fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
        &self.expr_planners
    }
}
238
/// Translates SQL filter/expression strings into DataFusion logical
/// expressions resolved against a fixed Arrow schema.
pub struct Planner {
    // Arrow schema that column references are resolved against
    schema: SchemaRef,
    // Function/planner registry handed to DataFusion's SqlToRel
    context_provider: LanceContextProvider,
}
243
244impl Planner {
    /// Create a planner that resolves expressions against `schema`.
    pub fn new(schema: SchemaRef) -> Self {
        Self {
            schema,
            context_provider: LanceContextProvider::default(),
        }
    }
251
252    fn column(idents: &[Ident]) -> Expr {
253        let mut column = col(&idents[0].value);
254        for ident in &idents[1..] {
255            column = Expr::ScalarFunction(ScalarFunction {
256                args: vec![
257                    column,
258                    Expr::Literal(ScalarValue::Utf8(Some(ident.value.clone()))),
259                ],
260                func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
261            });
262        }
263        column
264    }
265
    /// Map a sqlparser binary operator onto the corresponding DataFusion
    /// [`Operator`]. Operators without a mapping here (e.g. bitwise ops)
    /// are rejected with `invalid_input`.
    fn binary_op(&self, op: &BinaryOperator) -> Result<Operator> {
        Ok(match op {
            BinaryOperator::Plus => Operator::Plus,
            BinaryOperator::Minus => Operator::Minus,
            BinaryOperator::Multiply => Operator::Multiply,
            BinaryOperator::Divide => Operator::Divide,
            BinaryOperator::Modulo => Operator::Modulo,
            BinaryOperator::StringConcat => Operator::StringConcat,
            BinaryOperator::Gt => Operator::Gt,
            BinaryOperator::Lt => Operator::Lt,
            BinaryOperator::GtEq => Operator::GtEq,
            BinaryOperator::LtEq => Operator::LtEq,
            BinaryOperator::Eq => Operator::Eq,
            BinaryOperator::NotEq => Operator::NotEq,
            BinaryOperator::And => Operator::And,
            BinaryOperator::Or => Operator::Or,
            _ => {
                return Err(Error::invalid_input(
                    format!("Operator {op} is not supported"),
                    location!(),
                ));
            }
        })
    }
290
291    fn binary_expr(&self, left: &SQLExpr, op: &BinaryOperator, right: &SQLExpr) -> Result<Expr> {
292        Ok(Expr::BinaryExpr(BinaryExpr::new(
293            Box::new(self.parse_sql_expr(left)?),
294            self.binary_op(op)?,
295            Box::new(self.parse_sql_expr(right)?),
296        )))
297    }
298
    /// Translate a SQL unary operator application.
    ///
    /// `NOT`/`!` become [`Expr::Not`]. Unary minus applied directly to a
    /// numeric literal is folded into a negative literal (trying i64 first so
    /// integers stay integral, then f64); applied to anything else it becomes
    /// [`Expr::Negative`].
    fn unary_expr(&self, op: &UnaryOperator, expr: &SQLExpr) -> Result<Expr> {
        Ok(match op {
            UnaryOperator::Not | UnaryOperator::PGBitwiseNot => {
                Expr::Not(Box::new(self.parse_sql_expr(expr)?))
            }

            UnaryOperator::Minus => {
                use datafusion::logical_expr::lit;
                match expr {
                    // Fold `-<number literal>` at plan time.
                    SQLExpr::Value(Value::Number(n, _)) => match n.parse::<i64>() {
                        Ok(n) => lit(-n),
                        Err(_) => lit(-n
                            .parse::<f64>()
                            .map_err(|_e| {
                                Error::invalid_input(
                                    format!("negative operator can be only applied to integer and float operands, got: {n}"),
                                    location!(),
                                )
                            })?),
                    },
                    _ => {
                        Expr::Negative(Box::new(self.parse_sql_expr(expr)?))
                    }
                }
            }

            _ => {
                return Err(Error::invalid_input(
                    format!("Unary operator '{:?}' is not supported", op),
                    location!(),
                ));
            }
        })
    }
333
334    // See datafusion `sqlToRel::parse_sql_number()`
335    fn number(&self, value: &str, negative: bool) -> Result<Expr> {
336        use datafusion::logical_expr::lit;
337        let value: Cow<str> = if negative {
338            Cow::Owned(format!("-{}", value))
339        } else {
340            Cow::Borrowed(value)
341        };
342        if let Ok(n) = value.parse::<i64>() {
343            Ok(lit(n))
344        } else {
345            value.parse::<f64>().map(lit).map_err(|_| {
346                Error::invalid_input(
347                    format!("'{value}' is not supported number value."),
348                    location!(),
349                )
350            })
351        }
352    }
353
354    fn value(&self, value: &Value) -> Result<Expr> {
355        Ok(match value {
356            Value::Number(v, _) => self.number(v.as_str(), false)?,
357            Value::SingleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone()))),
358            Value::HexStringLiteral(hsl) => {
359                Expr::Literal(ScalarValue::Binary(Self::try_decode_hex_literal(hsl)))
360            }
361            Value::DoubleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone()))),
362            Value::Boolean(v) => Expr::Literal(ScalarValue::Boolean(Some(*v))),
363            Value::Null => Expr::Literal(ScalarValue::Null),
364            _ => todo!(),
365        })
366    }
367
    /// Extract and plan the expression inside a single unnamed function
    /// argument. Named arguments and wildcards are not supported.
    fn parse_function_args(&self, func_args: &FunctionArg) -> Result<Expr> {
        match func_args {
            FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => self.parse_sql_expr(expr),
            _ => Err(Error::invalid_input(
                format!("Unsupported function args: {:?}", func_args),
                location!(),
            )),
        }
    }
377
378    // We now use datafusion to parse functions.  This allows us to use datafusion's
379    // entire collection of functions (previously we had just hard-coded support for two functions).
380    //
381    // Unfortunately, one of those two functions was is_valid and the reason we needed it was because
382    // this is a function that comes from duckdb.  Datafusion does not consider is_valid to be a function
383    // but rather an AST node (Expr::IsNotNull) and so we need to handle this case specially.
384    fn legacy_parse_function(&self, func: &Function) -> Result<Expr> {
385        match &func.args {
386            FunctionArguments::List(args) => {
387                if func.name.0.len() != 1 {
388                    return Err(Error::invalid_input(
389                        format!("Function name must have 1 part, got: {:?}", func.name.0),
390                        location!(),
391                    ));
392                }
393                Ok(Expr::IsNotNull(Box::new(
394                    self.parse_function_args(&args.args[0])?,
395                )))
396            }
397            _ => Err(Error::invalid_input(
398                format!("Unsupported function args: {:?}", &func.args),
399                location!(),
400            )),
401        }
402    }
403
    /// Plan a SQL function call.
    ///
    /// `is_valid` (a duckdb-ism DataFusion treats as an AST node rather than a
    /// function) is routed to [`Self::legacy_parse_function`]; everything else
    /// is delegated to DataFusion's `SqlToRel` so the full set of registered
    /// functions is available.
    fn parse_function(&self, function: SQLExpr) -> Result<Expr> {
        if let SQLExpr::Function(function) = &function {
            if !function.name.0.is_empty() && function.name.0[0].value == "is_valid" {
                return self.legacy_parse_function(function);
            }
        }
        let sql_to_rel = SqlToRel::new_with_options(
            &self.context_provider,
            ParserOptions {
                parse_float_as_decimal: false,
                // Keep identifiers exactly as written (no case folding)
                enable_ident_normalization: false,
                support_varchar_with_length: false,
                enable_options_value_normalization: false,
                collect_spans: false,
            },
        );

        let mut planner_context = PlannerContext::default();
        let schema = DFSchema::try_from(self.schema.as_ref().clone())?;
        Ok(sql_to_rel.sql_to_expr(function, &schema, &mut planner_context)?)
    }
425
426    fn parse_type(&self, data_type: &SQLDataType) -> Result<ArrowDataType> {
427        const SUPPORTED_TYPES: [&str; 13] = [
428            "int [unsigned]",
429            "tinyint [unsigned]",
430            "smallint [unsigned]",
431            "bigint [unsigned]",
432            "float",
433            "double",
434            "string",
435            "binary",
436            "date",
437            "timestamp(precision)",
438            "datetime(precision)",
439            "decimal(precision,scale)",
440            "boolean",
441        ];
442        match data_type {
443            SQLDataType::String(_) => Ok(ArrowDataType::Utf8),
444            SQLDataType::Binary(_) => Ok(ArrowDataType::Binary),
445            SQLDataType::Float(_) => Ok(ArrowDataType::Float32),
446            SQLDataType::Double(_) => Ok(ArrowDataType::Float64),
447            SQLDataType::Boolean => Ok(ArrowDataType::Boolean),
448            SQLDataType::TinyInt(_) => Ok(ArrowDataType::Int8),
449            SQLDataType::SmallInt(_) => Ok(ArrowDataType::Int16),
450            SQLDataType::Int(_) | SQLDataType::Integer(_) => Ok(ArrowDataType::Int32),
451            SQLDataType::BigInt(_) => Ok(ArrowDataType::Int64),
452            SQLDataType::UnsignedTinyInt(_) => Ok(ArrowDataType::UInt8),
453            SQLDataType::UnsignedSmallInt(_) => Ok(ArrowDataType::UInt16),
454            SQLDataType::UnsignedInt(_) | SQLDataType::UnsignedInteger(_) => {
455                Ok(ArrowDataType::UInt32)
456            }
457            SQLDataType::UnsignedBigInt(_) => Ok(ArrowDataType::UInt64),
458            SQLDataType::Date => Ok(ArrowDataType::Date32),
459            SQLDataType::Timestamp(resolution, tz) => {
460                match tz {
461                    TimezoneInfo::None => {}
462                    _ => {
463                        return Err(Error::invalid_input(
464                            "Timezone not supported in timestamp".to_string(),
465                            location!(),
466                        ));
467                    }
468                };
469                let time_unit = match resolution {
470                    // Default to microsecond to match PyArrow
471                    None => TimeUnit::Microsecond,
472                    Some(0) => TimeUnit::Second,
473                    Some(3) => TimeUnit::Millisecond,
474                    Some(6) => TimeUnit::Microsecond,
475                    Some(9) => TimeUnit::Nanosecond,
476                    _ => {
477                        return Err(Error::invalid_input(
478                            format!("Unsupported datetime resolution: {:?}", resolution),
479                            location!(),
480                        ));
481                    }
482                };
483                Ok(ArrowDataType::Timestamp(time_unit, None))
484            }
485            SQLDataType::Datetime(resolution) => {
486                let time_unit = match resolution {
487                    None => TimeUnit::Microsecond,
488                    Some(0) => TimeUnit::Second,
489                    Some(3) => TimeUnit::Millisecond,
490                    Some(6) => TimeUnit::Microsecond,
491                    Some(9) => TimeUnit::Nanosecond,
492                    _ => {
493                        return Err(Error::invalid_input(
494                            format!("Unsupported datetime resolution: {:?}", resolution),
495                            location!(),
496                        ));
497                    }
498                };
499                Ok(ArrowDataType::Timestamp(time_unit, None))
500            }
501            SQLDataType::Decimal(number_info) => match number_info {
502                ExactNumberInfo::PrecisionAndScale(precision, scale) => {
503                    Ok(ArrowDataType::Decimal128(*precision as u8, *scale as i8))
504                }
505                _ => Err(Error::invalid_input(
506                    format!(
507                        "Must provide precision and scale for decimal: {:?}",
508                        number_info
509                    ),
510                    location!(),
511                )),
512            },
513            _ => Err(Error::invalid_input(
514                format!(
515                    "Unsupported data type: {:?}. Supported types: {:?}",
516                    data_type, SUPPORTED_TYPES
517                ),
518                location!(),
519            )),
520        }
521    }
522
    /// Plan a struct-field / list-index access by delegating to the registered
    /// [`ExprPlanner`]s, trying each in turn until one claims the expression.
    fn plan_field_access(&self, mut field_access_expr: RawFieldAccessExpr) -> Result<Expr> {
        let df_schema = DFSchema::try_from(self.schema.as_ref().clone())?;
        for planner in self.context_provider.get_expr_planners() {
            match planner.plan_field_access(field_access_expr, &df_schema)? {
                PlannerResult::Planned(expr) => return Ok(expr),
                // Planner declined: it hands the expression back so the next
                // planner can try.
                PlannerResult::Original(expr) => {
                    field_access_expr = expr;
                }
            }
        }
        Err(Error::invalid_input(
            "Field access could not be planned",
            location!(),
        ))
    }
538
    /// Recursively translate a sqlparser AST expression into a DataFusion
    /// logical [`Expr`].
    ///
    /// This is the core of the planner: each supported SQL construct is
    /// either converted inline or delegated to a helper. Unsupported
    /// constructs return `invalid_input` errors.
    fn parse_sql_expr(&self, expr: &SQLExpr) -> Result<Expr> {
        match expr {
            SQLExpr::Identifier(id) => {
                // Users can pass string literals wrapped in `"`.
                // (Normally SQL only allows single quotes.)
                if id.quote_style == Some('"') {
                    Ok(Expr::Literal(ScalarValue::Utf8(Some(id.value.clone()))))
                // Users can wrap identifiers with ` to reference non-standard
                // names, such as uppercase or spaces.
                } else if id.quote_style == Some('`') {
                    Ok(Expr::Column(Column::from_name(id.value.clone())))
                } else {
                    Ok(Self::column(vec![id.clone()].as_slice()))
                }
            }
            SQLExpr::CompoundIdentifier(ids) => Ok(Self::column(ids.as_slice())),
            SQLExpr::BinaryOp { left, op, right } => self.binary_expr(left, op, right),
            SQLExpr::UnaryOp { op, expr } => self.unary_expr(op, expr),
            SQLExpr::Value(value) => self.value(value),
            // Array literals: every element must be a literal (optionally a
            // negated number); elements are coerced to a common datatype and
            // materialized into a single-row ListArray scalar.
            SQLExpr::Array(SQLArray { elem, .. }) => {
                let mut values = vec![];

                let array_literal_error = |pos: usize, value: &_| {
                    Err(Error::invalid_input(
                        format!(
                            "Expected a literal value in array, instead got {} at position {}",
                            value, pos
                        ),
                        location!(),
                    ))
                };

                for (pos, expr) in elem.iter().enumerate() {
                    match expr {
                        SQLExpr::Value(value) => {
                            if let Expr::Literal(value) = self.value(value)? {
                                values.push(value);
                            } else {
                                return array_literal_error(pos, expr);
                            }
                        }
                        // Allow negative number literals, e.g. `[-1, 2]`
                        SQLExpr::UnaryOp {
                            op: UnaryOperator::Minus,
                            expr,
                        } => {
                            if let SQLExpr::Value(Value::Number(number, _)) = expr.as_ref() {
                                if let Expr::Literal(value) = self.number(number, true)? {
                                    values.push(value);
                                } else {
                                    return array_literal_error(pos, expr);
                                }
                            } else {
                                return array_literal_error(pos, expr);
                            }
                        }
                        _ => {
                            return array_literal_error(pos, expr);
                        }
                    }
                }

                // Coerce all elements to the datatype of the first element;
                // an empty array becomes a list of Null.
                let field = if !values.is_empty() {
                    let data_type = values[0].data_type();

                    for value in &mut values {
                        if value.data_type() != data_type {
                            *value = safe_coerce_scalar(value, &data_type).ok_or_else(|| Error::invalid_input(
                                format!("Array expressions must have a consistent datatype. Expected: {}, got: {}", data_type, value.data_type()),
                                location!()
                            ))?;
                        }
                    }
                    Field::new("item", data_type, true)
                } else {
                    Field::new("item", ArrowDataType::Null, true)
                };

                let values = values
                    .into_iter()
                    .map(|v| v.to_array().map_err(Error::from))
                    .collect::<Result<Vec<_>>>()?;
                let array_refs = values.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
                let values = concat(&array_refs)?;
                // One offset range covering all elements => a single list row.
                let values = ListArray::try_new(
                    field.into(),
                    OffsetBuffer::from_lengths([values.len()]),
                    values,
                    None,
                )?;

                Ok(Expr::Literal(ScalarValue::List(Arc::new(values))))
            }
            // For example, DATE '2020-01-01'
            SQLExpr::TypedString { data_type, value } => {
                Ok(Expr::Cast(datafusion::logical_expr::Cast {
                    expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some(value.clone())))),
                    data_type: self.parse_type(data_type)?,
                }))
            }
            SQLExpr::IsFalse(expr) => Ok(Expr::IsFalse(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotFalse(expr) => Ok(Expr::IsNotFalse(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsTrue(expr) => Ok(Expr::IsTrue(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotTrue(expr) => Ok(Expr::IsNotTrue(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNull(expr) => Ok(Expr::IsNull(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotNull(expr) => Ok(Expr::IsNotNull(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::InList {
                expr,
                list,
                negated,
            } => {
                let value_expr = self.parse_sql_expr(expr)?;
                let list_exprs = list
                    .iter()
                    .map(|e| self.parse_sql_expr(e))
                    .collect::<Result<Vec<_>>>()?;
                Ok(value_expr.in_list(list_exprs, *negated))
            }
            // Parenthesized expression: plan the inner expression directly.
            SQLExpr::Nested(inner) => self.parse_sql_expr(inner.as_ref()),
            SQLExpr::Function(_) => self.parse_function(expr.clone()),
            // ILIKE/LIKE differ only in the case-insensitivity flag passed to
            // `Like::new` (the fourth argument below).
            SQLExpr::ILike {
                negated,
                expr,
                pattern,
                escape_char,
                any: _,
            } => Ok(Expr::Like(Like::new(
                *negated,
                Box::new(self.parse_sql_expr(expr)?),
                Box::new(self.parse_sql_expr(pattern)?),
                escape_char.as_ref().and_then(|c| c.chars().next()),
                true,
            ))),
            SQLExpr::Like {
                negated,
                expr,
                pattern,
                escape_char,
                any: _,
            } => Ok(Expr::Like(Like::new(
                *negated,
                Box::new(self.parse_sql_expr(expr)?),
                Box::new(self.parse_sql_expr(pattern)?),
                escape_char.as_ref().and_then(|c| c.chars().next()),
                false,
            ))),
            SQLExpr::Cast {
                expr, data_type, ..
            } => Ok(Expr::Cast(datafusion::logical_expr::Cast {
                expr: Box::new(self.parse_sql_expr(expr)?),
                data_type: self.parse_type(data_type)?,
            })),
            SQLExpr::JsonAccess { .. } => Err(Error::invalid_input(
                "JSON access is not supported",
                location!(),
            )),
            // Chained accesses like `a.b['c'][0]`: plan the root, then apply
            // each access step via the registered expression planners.
            SQLExpr::CompoundFieldAccess { root, access_chain } => {
                let mut expr = self.parse_sql_expr(root)?;

                for access in access_chain {
                    let field_access = match access {
                        // x.y or x['y']
                        AccessExpr::Dot(SQLExpr::Identifier(Ident { value: s, .. }))
                        | AccessExpr::Subscript(Subscript::Index {
                            index:
                                SQLExpr::Value(
                                    Value::SingleQuotedString(s) | Value::DoubleQuotedString(s),
                                ),
                        }) => GetFieldAccess::NamedStructField {
                            name: ScalarValue::from(s.as_str()),
                        },
                        AccessExpr::Subscript(Subscript::Index { index }) => {
                            let key = Box::new(self.parse_sql_expr(index)?);
                            GetFieldAccess::ListIndex { key }
                        }
                        AccessExpr::Subscript(Subscript::Slice { .. }) => {
                            return Err(Error::invalid_input(
                                "Slice subscript is not supported",
                                location!(),
                            ));
                        }
                        _ => {
                            // Handle other cases like JSON access
                            // Note: JSON access is not supported in lance
                            return Err(Error::invalid_input(
                                "Only dot notation or index access is supported for field access",
                                location!(),
                            ));
                        }
                    };

                    let field_access_expr = RawFieldAccessExpr { expr, field_access };
                    expr = self.plan_field_access(field_access_expr)?;
                }

                Ok(expr)
            }
            SQLExpr::Between {
                expr,
                negated,
                low,
                high,
            } => {
                // Parse the main expression and bounds
                let expr = self.parse_sql_expr(expr)?;
                let low = self.parse_sql_expr(low)?;
                let high = self.parse_sql_expr(high)?;

                let between = Expr::Between(Between::new(
                    Box::new(expr),
                    *negated,
                    Box::new(low),
                    Box::new(high),
                ));
                Ok(between)
            }
            _ => Err(Error::invalid_input(
                format!("Expression '{expr}' is not supported SQL in lance"),
                location!(),
            )),
        }
    }
760
761    /// Create Logical [Expr] from a SQL filter clause.
762    ///
763    /// Note: the returned expression must be passed through [optimize_expr()]
764    /// before being passed to [create_physical_expr()].
765    pub fn parse_filter(&self, filter: &str) -> Result<Expr> {
766        // Allow sqlparser to parse filter as part of ONE SQL statement.
767        let ast_expr = parse_sql_filter(filter)?;
768        let expr = self.parse_sql_expr(&ast_expr)?;
769        let schema = Schema::try_from(self.schema.as_ref())?;
770        let resolved = resolve_expr(&expr, &schema)?;
771        coerce_filter_type_to_boolean(resolved)
772    }
773
    /// Create Logical [Expr] from a SQL expression.
    ///
    /// Note: the returned expression must be passed through [optimize_expr()]
    /// before being passed to [create_physical_expr()].
    pub fn parse_expr(&self, expr: &str) -> Result<Expr> {
        // Parse the SQL text into an AST, then lower it to a DataFusion Expr.
        let ast_expr = parse_sql_expr(expr)?;
        let expr = self.parse_sql_expr(&ast_expr)?;
        // Resolve column references against the lance schema.
        let schema = Schema::try_from(self.schema.as_ref())?;
        let resolved = resolve_expr(&expr, &schema)?;
        Ok(resolved)
    }
785
786    /// Try to decode bytes from hex literal string.
787    ///
788    /// Copied from datafusion because this is not public.
789    ///
790    /// TODO: use SqlToRel from Datafusion directly?
791    fn try_decode_hex_literal(s: &str) -> Option<Vec<u8>> {
792        let hex_bytes = s.as_bytes();
793        let mut decoded_bytes = Vec::with_capacity((hex_bytes.len() + 1) / 2);
794
795        let start_idx = hex_bytes.len() % 2;
796        if start_idx > 0 {
797            // The first byte is formed of only one char.
798            decoded_bytes.push(Self::try_decode_hex_char(hex_bytes[0])?);
799        }
800
801        for i in (start_idx..hex_bytes.len()).step_by(2) {
802            let high = Self::try_decode_hex_char(hex_bytes[i])?;
803            let low = Self::try_decode_hex_char(hex_bytes[i + 1])?;
804            decoded_bytes.push((high << 4) | low);
805        }
806
807        Some(decoded_bytes)
808    }
809
810    /// Try to decode a byte from a hex char.
811    ///
812    /// None will be returned if the input char is hex-invalid.
813    const fn try_decode_hex_char(c: u8) -> Option<u8> {
814        match c {
815            b'A'..=b'F' => Some(c - b'A' + 10),
816            b'a'..=b'f' => Some(c - b'a' + 10),
817            b'0'..=b'9' => Some(c - b'0'),
818            _ => None,
819        }
820    }
821
822    /// Optimize the filter expression and coerce data types.
823    pub fn optimize_expr(&self, expr: Expr) -> Result<Expr> {
824        let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
825
826        // DataFusion needs the simplify and coerce passes to be applied before
827        // expressions can be handled by the physical planner.
828        let props = ExecutionProps::default();
829        let simplify_context = SimplifyContext::new(&props).with_schema(df_schema.clone());
830        let simplifier =
831            datafusion::optimizer::simplify_expressions::ExprSimplifier::new(simplify_context);
832
833        let expr = simplifier.simplify(expr)?;
834        let expr = simplifier.coerce(expr, &df_schema)?;
835
836        Ok(expr)
837    }
838
839    /// Create the [`PhysicalExpr`] from a logical [`Expr`]
840    pub fn create_physical_expr(&self, expr: &Expr) -> Result<Arc<dyn PhysicalExpr>> {
841        let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
842
843        Ok(datafusion::physical_expr::create_physical_expr(
844            expr,
845            df_schema.as_ref(),
846            &Default::default(),
847        )?)
848    }
849
850    /// Collect the columns in the expression.
851    ///
852    /// The columns are returned in sorted order.
853    pub fn column_names_in_expr(expr: &Expr) -> Vec<String> {
854        let mut visitor = ColumnCapturingVisitor {
855            current_path: VecDeque::new(),
856            columns: BTreeSet::new(),
857        };
858        expr.visit(&mut visitor).unwrap();
859        visitor.columns.into_iter().collect()
860    }
861}
862
/// Visitor that collects the (possibly nested) column names referenced by an
/// expression. Nested struct accesses via `get_field` are folded into dotted
/// paths such as `st.x` (see the `TreeNodeVisitor` impl below).
struct ColumnCapturingVisitor {
    // Current column path. If this is empty, we are not in a column expression.
    current_path: VecDeque<String>,
    // Completed dotted column paths; BTreeSet keeps them sorted and unique.
    columns: BTreeSet<String>,
}
868
impl TreeNodeVisitor<'_> for ColumnCapturingVisitor {
    type Node = Expr;

    /// Pre-order hook: accumulate `get_field` field names on the way down and
    /// emit a dotted column path once the underlying column node is reached.
    fn f_down(&mut self, node: &Self::Node) -> DFResult<TreeNodeRecursion> {
        match node {
            Expr::Column(Column { name, .. }) => {
                // Base column: join the field names gathered from enclosing
                // get_field calls into one dotted path (e.g. "st.x.y").
                let mut path = name.clone();
                for part in self.current_path.drain(..) {
                    path.push('.');
                    path.push_str(&part);
                }
                self.columns.insert(path);
                self.current_path.clear();
            }
            Expr::ScalarFunction(udf) => {
                if udf.name() == GetFieldFunc::default().name() {
                    // get_field(child, "name"): remember the field name. The
                    // child is visited after this node, so push to the front
                    // to keep outer-to-inner names in path order.
                    if let Some(name) = get_as_string_scalar_opt(&udf.args[1]) {
                        self.current_path.push_front(name.to_string())
                    } else {
                        // Non-literal field name: the path cannot be known.
                        self.current_path.clear();
                    }
                } else {
                    self.current_path.clear();
                }
            }
            _ => {
                // Any other node breaks a pending field-access chain.
                self.current_path.clear();
            }
        }

        Ok(TreeNodeRecursion::Continue)
    }
}
902
903#[cfg(test)]
904mod tests {
905
906    use crate::logical_expr::ExprExt;
907
908    use super::*;
909
910    use arrow::datatypes::Float64Type;
911    use arrow_array::{
912        ArrayRef, BooleanArray, Float32Array, Int32Array, Int64Array, RecordBatch, StringArray,
913        StructArray, TimestampMicrosecondArray, TimestampMillisecondArray,
914        TimestampNanosecondArray, TimestampSecondArray,
915    };
916    use arrow_schema::{DataType, Fields, Schema};
917    use datafusion::{
918        logical_expr::{lit, Cast},
919        prelude::{array_element, get_field},
920    };
921    use datafusion_functions::core::expr_ext::FieldAccessor;
922
    #[test]
    fn test_parse_filter_simple() {
        // Schema with a scalar int, a nullable string, and a nested struct.
        let schema = Arc::new(Schema::new(vec![
            Field::new("i", DataType::Int32, false),
            Field::new("s", DataType::Utf8, true),
            Field::new(
                "st",
                DataType::Struct(Fields::from(vec![
                    Field::new("x", DataType::Float32, false),
                    Field::new("y", DataType::Float32, false),
                ])),
                true,
            ),
        ]));

        let planner = Planner::new(schema.clone());

        let expected = col("i")
            .gt(lit(3_i32))
            .and(col("st").field_newstyle("x").lt_eq(lit(5.0_f32)))
            .and(
                col("s")
                    .eq(lit("str-4"))
                    .or(col("s").in_list(vec![lit("str-4"), lit("str-5")], false)),
            );

        // "==" (double equals) is accepted as equality
        let expr = planner
            .parse_filter("i > 3 AND st.x <= 5.0 AND (s == 'str-4' OR s in ('str-4', 'str-5'))")
            .unwrap();
        assert_eq!(expr, expected);

        // "=" (single equals) parses to the same expression
        let expr = planner
            .parse_filter("i > 3 AND st.x <= 5.0 AND (s = 'str-4' OR s in ('str-4', 'str-5'))")
            .unwrap();

        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        // 10 rows: i = 0..10, s = "str-0".."str-9", st.x = 0..10, st.y = 0,10,..,90.
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
                Arc::new(StringArray::from_iter_values(
                    (0..10).map(|v| format!("str-{}", v)),
                )),
                Arc::new(StructArray::from(vec![
                    (
                        Arc::new(Field::new("x", DataType::Float32, false)),
                        Arc::new(Float32Array::from_iter_values((0..10).map(|v| v as f32)))
                            as ArrayRef,
                    ),
                    (
                        Arc::new(Field::new("y", DataType::Float32, false)),
                        Arc::new(Float32Array::from_iter_values(
                            (0..10).map(|v| (v * 10) as f32),
                        )),
                    ),
                ])),
            ],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        // Only rows 4 and 5 satisfy all three conjuncts.
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, false, true, true, false, false, false, false
            ])
        );
    }
993
    #[test]
    fn test_nested_col_refs() {
        // Schema with two levels of struct nesting: st.s1 and st.st.s2.
        let schema = Arc::new(Schema::new(vec![
            Field::new("s0", DataType::Utf8, true),
            Field::new(
                "st",
                DataType::Struct(Fields::from(vec![
                    Field::new("s1", DataType::Utf8, true),
                    Field::new(
                        "st",
                        DataType::Struct(Fields::from(vec![Field::new(
                            "s2",
                            DataType::Utf8,
                            true,
                        )])),
                        true,
                    ),
                ])),
                true,
            ),
        ]));

        let planner = Planner::new(schema);

        // Parses "<expr> = 'val'" and asserts the left-hand side of the
        // resulting equality equals `expected`.
        fn assert_column_eq(planner: &Planner, expr: &str, expected: &Expr) {
            let expr = planner.parse_filter(&format!("{expr} = 'val'")).unwrap();
            assert!(matches!(
                expr,
                Expr::BinaryExpr(BinaryExpr {
                    left: _,
                    op: Operator::Eq,
                    right: _
                })
            ));
            if let Expr::BinaryExpr(BinaryExpr { left, .. }) = expr {
                assert_eq!(left.as_ref(), expected);
            }
        }

        // Top-level column: bare and backquoted forms resolve identically.
        let expected = Expr::Column(Column::new_unqualified("s0"));
        assert_column_eq(&planner, "s0", &expected);
        assert_column_eq(&planner, "`s0`", &expected);

        // One level of nesting resolves to get_field(st, "s1").
        let expected = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
            args: vec![
                Expr::Column(Column::new_unqualified("st")),
                Expr::Literal(ScalarValue::Utf8(Some("s1".to_string()))),
            ],
        });
        assert_column_eq(&planner, "st.s1", &expected);
        assert_column_eq(&planner, "`st`.`s1`", &expected);
        assert_column_eq(&planner, "st.`s1`", &expected);

        // Two levels resolve to get_field(get_field(st, "st"), "s2").
        let expected = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
            args: vec![
                Expr::ScalarFunction(ScalarFunction {
                    func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
                    args: vec![
                        Expr::Column(Column::new_unqualified("st")),
                        Expr::Literal(ScalarValue::Utf8(Some("st".to_string()))),
                    ],
                }),
                Expr::Literal(ScalarValue::Utf8(Some("s2".to_string()))),
            ],
        });

        // Dot notation, backquoted identifiers, and subscript syntax all agree.
        assert_column_eq(&planner, "st.st.s2", &expected);
        assert_column_eq(&planner, "`st`.`st`.`s2`", &expected);
        assert_column_eq(&planner, "st.st.`s2`", &expected);
        assert_column_eq(&planner, "st['st'][\"s2\"]", &expected);
    }
1067
1068    #[test]
1069    fn test_nested_list_refs() {
1070        let schema = Arc::new(Schema::new(vec![Field::new(
1071            "l",
1072            DataType::List(Arc::new(Field::new(
1073                "item",
1074                DataType::Struct(Fields::from(vec![Field::new("f1", DataType::Utf8, true)])),
1075                true,
1076            ))),
1077            true,
1078        )]));
1079
1080        let planner = Planner::new(schema);
1081
1082        let expected = array_element(col("l"), lit(0_i64));
1083        let expr = planner.parse_expr("l[0]").unwrap();
1084        assert_eq!(expr, expected);
1085
1086        let expected = get_field(array_element(col("l"), lit(0_i64)), "f1");
1087        let expr = planner.parse_expr("l[0]['f1']").unwrap();
1088        assert_eq!(expr, expected);
1089
1090        // FIXME: This should work, but sqlparser doesn't recognize anything
1091        // after the period for some reason.
1092        // let expr = planner.parse_expr("l[0].f1").unwrap();
1093        // assert_eq!(expr, expected);
1094    }
1095
1096    #[test]
1097    fn test_negative_expressions() {
1098        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
1099
1100        let planner = Planner::new(schema.clone());
1101
1102        let expected = col("x")
1103            .gt(lit(-3_i64))
1104            .and(col("x").lt(-(lit(-5_i64) + lit(3_i64))));
1105
1106        let expr = planner.parse_filter("x > -3 AND x < -(-5 + 3)").unwrap();
1107
1108        assert_eq!(expr, expected);
1109
1110        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1111
1112        let batch = RecordBatch::try_new(
1113            schema,
1114            vec![Arc::new(Int64Array::from_iter_values(-5..5)) as ArrayRef],
1115        )
1116        .unwrap();
1117        let predicates = physical_expr.evaluate(&batch).unwrap();
1118        assert_eq!(
1119            predicates.into_array(0).unwrap().as_ref(),
1120            &BooleanArray::from(vec![
1121                false, false, false, true, true, true, true, false, false, false
1122            ])
1123        );
1124    }
1125
1126    #[test]
1127    fn test_negative_array_expressions() {
1128        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
1129
1130        let planner = Planner::new(schema);
1131
1132        let expected = Expr::Literal(ScalarValue::List(Arc::new(
1133            ListArray::from_iter_primitive::<Float64Type, _, _>(vec![Some(
1134                [-1_f64, -2.0, -3.0, -4.0, -5.0].map(Some),
1135            )]),
1136        )));
1137
1138        let expr = planner
1139            .parse_expr("[-1.0, -2.0, -3.0, -4.0, -5.0]")
1140            .unwrap();
1141
1142        assert_eq!(expr, expected);
1143    }
1144
1145    #[test]
1146    fn test_sql_like() {
1147        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1148
1149        let planner = Planner::new(schema.clone());
1150
1151        let expected = col("s").like(lit("str-4"));
1152        // single quote
1153        let expr = planner.parse_filter("s LIKE 'str-4'").unwrap();
1154        assert_eq!(expr, expected);
1155        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1156
1157        let batch = RecordBatch::try_new(
1158            schema,
1159            vec![Arc::new(StringArray::from_iter_values(
1160                (0..10).map(|v| format!("str-{}", v)),
1161            ))],
1162        )
1163        .unwrap();
1164        let predicates = physical_expr.evaluate(&batch).unwrap();
1165        assert_eq!(
1166            predicates.into_array(0).unwrap().as_ref(),
1167            &BooleanArray::from(vec![
1168                false, false, false, false, true, false, false, false, false, false
1169            ])
1170        );
1171    }
1172
1173    #[test]
1174    fn test_not_like() {
1175        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1176
1177        let planner = Planner::new(schema.clone());
1178
1179        let expected = col("s").not_like(lit("str-4"));
1180        // single quote
1181        let expr = planner.parse_filter("s NOT LIKE 'str-4'").unwrap();
1182        assert_eq!(expr, expected);
1183        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1184
1185        let batch = RecordBatch::try_new(
1186            schema,
1187            vec![Arc::new(StringArray::from_iter_values(
1188                (0..10).map(|v| format!("str-{}", v)),
1189            ))],
1190        )
1191        .unwrap();
1192        let predicates = physical_expr.evaluate(&batch).unwrap();
1193        assert_eq!(
1194            predicates.into_array(0).unwrap().as_ref(),
1195            &BooleanArray::from(vec![
1196                true, true, true, true, false, true, true, true, true, true
1197            ])
1198        );
1199    }
1200
1201    #[test]
1202    fn test_sql_is_in() {
1203        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1204
1205        let planner = Planner::new(schema.clone());
1206
1207        let expected = col("s").in_list(vec![lit("str-4"), lit("str-5")], false);
1208        // single quote
1209        let expr = planner.parse_filter("s IN ('str-4', 'str-5')").unwrap();
1210        assert_eq!(expr, expected);
1211        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1212
1213        let batch = RecordBatch::try_new(
1214            schema,
1215            vec![Arc::new(StringArray::from_iter_values(
1216                (0..10).map(|v| format!("str-{}", v)),
1217            ))],
1218        )
1219        .unwrap();
1220        let predicates = physical_expr.evaluate(&batch).unwrap();
1221        assert_eq!(
1222            predicates.into_array(0).unwrap().as_ref(),
1223            &BooleanArray::from(vec![
1224                false, false, false, false, true, true, false, false, false, false
1225            ])
1226        );
1227    }
1228
1229    #[test]
1230    fn test_sql_is_null() {
1231        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1232
1233        let planner = Planner::new(schema.clone());
1234
1235        let expected = col("s").is_null();
1236        let expr = planner.parse_filter("s IS NULL").unwrap();
1237        assert_eq!(expr, expected);
1238        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1239
1240        let batch = RecordBatch::try_new(
1241            schema,
1242            vec![Arc::new(StringArray::from_iter((0..10).map(|v| {
1243                if v % 3 == 0 {
1244                    Some(format!("str-{}", v))
1245                } else {
1246                    None
1247                }
1248            })))],
1249        )
1250        .unwrap();
1251        let predicates = physical_expr.evaluate(&batch).unwrap();
1252        assert_eq!(
1253            predicates.into_array(0).unwrap().as_ref(),
1254            &BooleanArray::from(vec![
1255                false, true, true, false, true, true, false, true, true, false
1256            ])
1257        );
1258
1259        let expr = planner.parse_filter("s IS NOT NULL").unwrap();
1260        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1261        let predicates = physical_expr.evaluate(&batch).unwrap();
1262        assert_eq!(
1263            predicates.into_array(0).unwrap().as_ref(),
1264            &BooleanArray::from(vec![
1265                true, false, false, true, false, false, true, false, false, true,
1266            ])
1267        );
1268    }
1269
1270    #[test]
1271    fn test_sql_invert() {
1272        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Boolean, true)]));
1273
1274        let planner = Planner::new(schema.clone());
1275
1276        let expr = planner.parse_filter("NOT s").unwrap();
1277        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1278
1279        let batch = RecordBatch::try_new(
1280            schema,
1281            vec![Arc::new(BooleanArray::from_iter(
1282                (0..10).map(|v| Some(v % 3 == 0)),
1283            ))],
1284        )
1285        .unwrap();
1286        let predicates = physical_expr.evaluate(&batch).unwrap();
1287        assert_eq!(
1288            predicates.into_array(0).unwrap().as_ref(),
1289            &BooleanArray::from(vec![
1290                false, true, true, false, true, true, false, true, true, false
1291            ])
1292        );
1293    }
1294
    #[test]
    fn test_sql_cast() {
        // Each case pairs a SQL CAST expression with the Arrow type that the
        // cast target should map to.
        let cases = &[
            (
                "x = cast('2021-01-01 00:00:00' as timestamp)",
                ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00' as timestamp(0))",
                ArrowDataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00.123' as timestamp(9))",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00.123' as datetime(9))",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            ("x = cast('2021-01-01' as date)", ArrowDataType::Date32),
            (
                "x = cast('1.238' as decimal(9,3))",
                ArrowDataType::Decimal128(9, 3),
            ),
            ("x = cast(1 as float)", ArrowDataType::Float32),
            ("x = cast(1 as double)", ArrowDataType::Float64),
            ("x = cast(1 as tinyint)", ArrowDataType::Int8),
            ("x = cast(1 as smallint)", ArrowDataType::Int16),
            ("x = cast(1 as int)", ArrowDataType::Int32),
            ("x = cast(1 as integer)", ArrowDataType::Int32),
            ("x = cast(1 as bigint)", ArrowDataType::Int64),
            ("x = cast(1 as tinyint unsigned)", ArrowDataType::UInt8),
            ("x = cast(1 as smallint unsigned)", ArrowDataType::UInt16),
            ("x = cast(1 as int unsigned)", ArrowDataType::UInt32),
            ("x = cast(1 as integer unsigned)", ArrowDataType::UInt32),
            ("x = cast(1 as bigint unsigned)", ArrowDataType::UInt64),
            ("x = cast(1 as boolean)", ArrowDataType::Boolean),
            ("x = cast(1 as string)", ArrowDataType::Utf8),
        ];

        for (sql, expected_data_type) in cases {
            let schema = Arc::new(Schema::new(vec![Field::new(
                "x",
                expected_data_type.clone(),
                true,
            )]));
            let planner = Planner::new(schema.clone());
            let expr = planner.parse_filter(sql).unwrap();

            // Extract the literal between "cast(" and " as".
            let expected_value_str = sql
                .split("cast(")
                .nth(1)
                .unwrap()
                .split(" as")
                .next()
                .unwrap();
            // Remove any quote marks
            let expected_value_str = expected_value_str.trim_matches('\'');

            match expr {
                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
                    Expr::Cast(Cast { expr, data_type }) => {
                        match expr.as_ref() {
                            // String literals keep their text form under the cast.
                            Expr::Literal(ScalarValue::Utf8(Some(value_str))) => {
                                assert_eq!(value_str, expected_value_str);
                            }
                            // The numeric cases all cast the literal 1.
                            Expr::Literal(ScalarValue::Int64(Some(value))) => {
                                assert_eq!(*value, 1);
                            }
                            _ => panic!("Expected cast to be applied to literal"),
                        }
                        assert_eq!(data_type, expected_data_type);
                    }
                    _ => panic!("Expected right to be a cast"),
                },
                _ => panic!("Expected binary expression"),
            }
        }
    }
1375
    #[test]
    fn test_sql_literals() {
        // Typed literals (e.g. `timestamp '...'`) should parse as a cast of
        // the quoted string to the corresponding Arrow type.
        let cases = &[
            (
                "x = timestamp '2021-01-01 00:00:00'",
                ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
            ),
            (
                "x = timestamp(0) '2021-01-01 00:00:00'",
                ArrowDataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                "x = timestamp(9) '2021-01-01 00:00:00.123'",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            ("x = date '2021-01-01'", ArrowDataType::Date32),
            ("x = decimal(9,3) '1.238'", ArrowDataType::Decimal128(9, 3)),
        ];

        for (sql, expected_data_type) in cases {
            let schema = Arc::new(Schema::new(vec![Field::new(
                "x",
                expected_data_type.clone(),
                true,
            )]));
            let planner = Planner::new(schema.clone());
            let expr = planner.parse_filter(sql).unwrap();

            // The single-quoted literal text from the SQL string.
            let expected_value_str = sql.split('\'').nth(1).unwrap();

            match expr {
                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
                    Expr::Cast(Cast { expr, data_type }) => {
                        match expr.as_ref() {
                            Expr::Literal(ScalarValue::Utf8(Some(value_str))) => {
                                assert_eq!(value_str, expected_value_str);
                            }
                            _ => panic!("Expected cast to be applied to literal"),
                        }
                        assert_eq!(data_type, expected_data_type);
                    }
                    _ => panic!("Expected right to be a cast"),
                },
                _ => panic!("Expected binary expression"),
            }
        }
    }
1423
1424    #[test]
1425    fn test_sql_array_literals() {
1426        let cases = [
1427            (
1428                "x = [1, 2, 3]",
1429                ArrowDataType::List(Arc::new(Field::new("item", ArrowDataType::Int64, true))),
1430            ),
1431            (
1432                "x = [1, 2, 3]",
1433                ArrowDataType::FixedSizeList(
1434                    Arc::new(Field::new("item", ArrowDataType::Int64, true)),
1435                    3,
1436                ),
1437            ),
1438        ];
1439
1440        for (sql, expected_data_type) in cases {
1441            let schema = Arc::new(Schema::new(vec![Field::new(
1442                "x",
1443                expected_data_type.clone(),
1444                true,
1445            )]));
1446            let planner = Planner::new(schema.clone());
1447            let expr = planner.parse_filter(sql).unwrap();
1448            let expr = planner.optimize_expr(expr).unwrap();
1449
1450            match expr {
1451                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
1452                    Expr::Literal(value) => {
1453                        assert_eq!(&value.data_type(), &expected_data_type);
1454                    }
1455                    _ => panic!("Expected right to be a literal"),
1456                },
1457                _ => panic!("Expected binary expression"),
1458            }
1459        }
1460    }
1461
    #[test]
    fn test_sql_between() {
        use arrow_array::{Float64Array, Int32Array, TimestampMicrosecondArray};
        use arrow_schema::{DataType, Field, Schema, TimeUnit};
        use std::sync::Arc;

        let schema = Arc::new(Schema::new(vec![
            Field::new("x", DataType::Int32, false),
            Field::new("y", DataType::Float64, false),
            Field::new(
                "ts",
                DataType::Timestamp(TimeUnit::Microsecond, None),
                false,
            ),
        ]));

        let planner = Planner::new(schema.clone());

        // Test integer BETWEEN (bounds are inclusive on both ends)
        let expr = planner
            .parse_filter("x BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
            .unwrap();
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        // Create timestamp array with values representing:
        // 2024-01-01 00:00:00 to 2024-01-01 00:00:09 (in microseconds)
        let base_ts = 1704067200000000_i64; // 2024-01-01 00:00:00
        let ts_array = TimestampMicrosecondArray::from_iter_values(
            (0..10).map(|i| base_ts + i * 1_000_000), // Each value is 1 second apart
        );

        // x = 0..10, y = 0.0..10.0, ts = the ten consecutive seconds above.
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
                Arc::new(Float64Array::from_iter_values((0..10).map(|v| v as f64))),
                Arc::new(ts_array),
            ],
        )
        .unwrap();

        // Rows with x in 3..=7 are selected.
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, true, true, true, true, true, false, false
            ])
        );

        // Test NOT BETWEEN (the complement of the selection above)
        let expr = planner
            .parse_filter("x NOT BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
            .unwrap();
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                true, true, true, false, false, false, false, false, true, true
            ])
        );

        // Test floating point BETWEEN (selects y in 3.0..=6.0 given integer data)
        let expr = planner.parse_filter("y BETWEEN 2.5 AND 6.5").unwrap();
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, true, true, true, true, false, false, false
            ])
        );

        // Test timestamp BETWEEN (seconds 3 through 7, inclusive)
        let expr = planner
            .parse_filter(
                "ts BETWEEN timestamp '2024-01-01 00:00:03' AND timestamp '2024-01-01 00:00:07'",
            )
            .unwrap();
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, true, true, true, true, true, false, false
            ])
        );
    }
1553
1554    #[test]
1555    fn test_sql_comparison() {
1556        // Create a batch with all data types
1557        let batch: Vec<(&str, ArrayRef)> = vec![
1558            (
1559                "timestamp_s",
1560                Arc::new(TimestampSecondArray::from_iter_values(0..10)),
1561            ),
1562            (
1563                "timestamp_ms",
1564                Arc::new(TimestampMillisecondArray::from_iter_values(0..10)),
1565            ),
1566            (
1567                "timestamp_us",
1568                Arc::new(TimestampMicrosecondArray::from_iter_values(0..10)),
1569            ),
1570            (
1571                "timestamp_ns",
1572                Arc::new(TimestampNanosecondArray::from_iter_values(4995..5005)),
1573            ),
1574        ];
1575        let batch = RecordBatch::try_from_iter(batch).unwrap();
1576
1577        let planner = Planner::new(batch.schema());
1578
1579        // Each expression is meant to select the final 5 rows
1580        let expressions = &[
1581            "timestamp_s >= TIMESTAMP '1970-01-01 00:00:05'",
1582            "timestamp_ms >= TIMESTAMP '1970-01-01 00:00:00.005'",
1583            "timestamp_us >= TIMESTAMP '1970-01-01 00:00:00.000005'",
1584            "timestamp_ns >= TIMESTAMP '1970-01-01 00:00:00.000005'",
1585        ];
1586
1587        let expected: ArrayRef = Arc::new(BooleanArray::from_iter(
1588            std::iter::repeat(Some(false))
1589                .take(5)
1590                .chain(std::iter::repeat(Some(true)).take(5)),
1591        ));
1592        for expression in expressions {
1593            // convert to physical expression
1594            let logical_expr = planner.parse_filter(expression).unwrap();
1595            let logical_expr = planner.optimize_expr(logical_expr).unwrap();
1596            let physical_expr = planner.create_physical_expr(&logical_expr).unwrap();
1597
1598            // Evaluate and assert they have correct results
1599            let result = physical_expr.evaluate(&batch).unwrap();
1600            let result = result.into_array(batch.num_rows()).unwrap();
1601            assert_eq!(&expected, &result, "unexpected result for {}", expression);
1602        }
1603    }
1604
1605    #[test]
1606    fn test_columns_in_expr() {
1607        let expr = col("s0").gt(lit("value")).and(
1608            col("st")
1609                .field("st")
1610                .field("s2")
1611                .eq(lit("value"))
1612                .or(col("st")
1613                    .field("s1")
1614                    .in_list(vec![lit("value 1"), lit("value 2")], false)),
1615        );
1616
1617        let columns = Planner::column_names_in_expr(&expr);
1618        assert_eq!(columns, vec!["s0", "st.s1", "st.st.s2"]);
1619    }
1620
1621    #[test]
1622    fn test_parse_binary_expr() {
1623        let bin_str = "x'616263'";
1624
1625        let schema = Arc::new(Schema::new(vec![Field::new(
1626            "binary",
1627            DataType::Binary,
1628            true,
1629        )]));
1630        let planner = Planner::new(schema);
1631        let expr = planner.parse_expr(bin_str).unwrap();
1632        assert_eq!(
1633            expr,
1634            Expr::Literal(ScalarValue::Binary(Some(vec![b'a', b'b', b'c'])))
1635        );
1636    }
1637
1638    #[test]
1639    fn test_lance_context_provider_expr_planners() {
1640        let ctx_provider = LanceContextProvider::default();
1641        assert!(!ctx_provider.get_expr_planners().is_empty());
1642    }
1643}