Skip to main content

lance_datafusion/
planner.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Exec plan planner
5
6use std::borrow::Cow;
7use std::collections::{BTreeSet, VecDeque};
8use std::sync::Arc;
9
10use crate::exec::{get_session_context, LanceExecutionOptions};
11use crate::expr::safe_coerce_scalar;
12use crate::logical_expr::{coerce_filter_type_to_boolean, get_as_string_scalar_opt, resolve_expr};
13use crate::sql::{parse_sql_expr, parse_sql_filter};
14use arrow::compute::CastOptions;
15use arrow_array::ListArray;
16use arrow_buffer::OffsetBuffer;
17use arrow_schema::{DataType as ArrowDataType, Field, SchemaRef, TimeUnit};
18use arrow_select::concat::concat;
19use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
20use datafusion::common::DFSchema;
21use datafusion::config::ConfigOptions;
22use datafusion::error::Result as DFResult;
23use datafusion::execution::context::SessionState;
24use datafusion::logical_expr::expr::ScalarFunction;
25use datafusion::logical_expr::planner::{ExprPlanner, PlannerResult, RawFieldAccessExpr};
26use datafusion::logical_expr::{
27    AggregateUDF, ColumnarValue, GetFieldAccess, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
28    Signature, Volatility, WindowUDF,
29};
30use datafusion::optimizer::simplify_expressions::SimplifyContext;
31use datafusion::sql::planner::{
32    ContextProvider, NullOrdering, ParserOptions, PlannerContext, SqlToRel,
33};
34use datafusion::sql::sqlparser::ast::{
35    AccessExpr, Array as SQLArray, BinaryOperator, DataType as SQLDataType, ExactNumberInfo,
36    Expr as SQLExpr, Function, FunctionArg, FunctionArgExpr, FunctionArguments, Ident,
37    ObjectNamePart, Subscript, TimezoneInfo, UnaryOperator, Value, ValueWithSpan,
38};
39use datafusion::{
40    common::Column,
41    logical_expr::{Between, BinaryExpr, Like, Operator},
42    physical_expr::execution_props::ExecutionProps,
43    physical_plan::PhysicalExpr,
44    prelude::Expr,
45    scalar::ScalarValue,
46};
47use datafusion_functions::core::getfield::GetFieldFunc;
48use lance_arrow::cast::cast_with_options;
49use lance_core::datatypes::Schema;
50use lance_core::error::LanceOptionExt;
51use snafu::location;
52
53use chrono::Utc;
54use lance_core::{Error, Result};
55
56#[derive(Debug, Clone, Eq, PartialEq, Hash)]
57struct CastListF16Udf {
58    signature: Signature,
59}
60
61impl CastListF16Udf {
62    pub fn new() -> Self {
63        Self {
64            signature: Signature::any(1, Volatility::Immutable),
65        }
66    }
67}
68
69impl ScalarUDFImpl for CastListF16Udf {
70    fn as_any(&self) -> &dyn std::any::Any {
71        self
72    }
73
74    fn name(&self) -> &str {
75        "_cast_list_f16"
76    }
77
78    fn signature(&self) -> &Signature {
79        &self.signature
80    }
81
82    fn return_type(&self, arg_types: &[ArrowDataType]) -> DFResult<ArrowDataType> {
83        let input = &arg_types[0];
84        match input {
85            ArrowDataType::FixedSizeList(field, size) => {
86                if field.data_type() != &ArrowDataType::Float32
87                    && field.data_type() != &ArrowDataType::Float16
88                {
89                    return Err(datafusion::error::DataFusionError::Execution(
90                        "cast_list_f16 only supports list of float32 or float16".to_string(),
91                    ));
92                }
93                Ok(ArrowDataType::FixedSizeList(
94                    Arc::new(Field::new(
95                        field.name(),
96                        ArrowDataType::Float16,
97                        field.is_nullable(),
98                    )),
99                    *size,
100                ))
101            }
102            ArrowDataType::List(field) => {
103                if field.data_type() != &ArrowDataType::Float32
104                    && field.data_type() != &ArrowDataType::Float16
105                {
106                    return Err(datafusion::error::DataFusionError::Execution(
107                        "cast_list_f16 only supports list of float32 or float16".to_string(),
108                    ));
109                }
110                Ok(ArrowDataType::List(Arc::new(Field::new(
111                    field.name(),
112                    ArrowDataType::Float16,
113                    field.is_nullable(),
114                ))))
115            }
116            _ => Err(datafusion::error::DataFusionError::Execution(
117                "cast_list_f16 only supports FixedSizeList/List arguments".to_string(),
118            )),
119        }
120    }
121
122    fn invoke_with_args(&self, func_args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
123        let ColumnarValue::Array(arr) = &func_args.args[0] else {
124            return Err(datafusion::error::DataFusionError::Execution(
125                "cast_list_f16 only supports array arguments".to_string(),
126            ));
127        };
128
129        let to_type = match arr.data_type() {
130            ArrowDataType::FixedSizeList(field, size) => ArrowDataType::FixedSizeList(
131                Arc::new(Field::new(
132                    field.name(),
133                    ArrowDataType::Float16,
134                    field.is_nullable(),
135                )),
136                *size,
137            ),
138            ArrowDataType::List(field) => ArrowDataType::List(Arc::new(Field::new(
139                field.name(),
140                ArrowDataType::Float16,
141                field.is_nullable(),
142            ))),
143            _ => {
144                return Err(datafusion::error::DataFusionError::Execution(
145                    "cast_list_f16 only supports array arguments".to_string(),
146                ));
147            }
148        };
149
150        let res = cast_with_options(arr.as_ref(), &to_type, &CastOptions::default())?;
151        Ok(ColumnarValue::Array(res))
152    }
153}
154
155// Adapter that instructs datafusion how lance expects expressions to be interpreted
156struct LanceContextProvider {
157    options: datafusion::config::ConfigOptions,
158    state: SessionState,
159    expr_planners: Vec<Arc<dyn ExprPlanner>>,
160}
161
162impl Default for LanceContextProvider {
163    fn default() -> Self {
164        let ctx = get_session_context(&LanceExecutionOptions::default());
165        let state = ctx.state();
166        let expr_planners = state.expr_planners().to_vec();
167
168        Self {
169            options: ConfigOptions::default(),
170            state,
171            expr_planners,
172        }
173    }
174}
175
176impl ContextProvider for LanceContextProvider {
177    fn get_table_source(
178        &self,
179        name: datafusion::sql::TableReference,
180    ) -> DFResult<Arc<dyn datafusion::logical_expr::TableSource>> {
181        Err(datafusion::error::DataFusionError::NotImplemented(format!(
182            "Attempt to reference inner table {} not supported",
183            name
184        )))
185    }
186
187    fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
188        self.state.aggregate_functions().get(name).cloned()
189    }
190
191    fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
192        self.state.window_functions().get(name).cloned()
193    }
194
195    fn get_function_meta(&self, f: &str) -> Option<Arc<ScalarUDF>> {
196        match f {
197            // TODO: cast should go thru CAST syntax instead of UDF
198            // Going thru UDF makes it hard for the optimizer to find no-ops
199            "_cast_list_f16" => Some(Arc::new(ScalarUDF::new_from_impl(CastListF16Udf::new()))),
200            _ => self.state.scalar_functions().get(f).cloned(),
201        }
202    }
203
204    fn get_variable_type(&self, _: &[String]) -> Option<ArrowDataType> {
205        // Variables (things like @@LANGUAGE) not supported
206        None
207    }
208
209    fn options(&self) -> &datafusion::config::ConfigOptions {
210        &self.options
211    }
212
213    fn udf_names(&self) -> Vec<String> {
214        self.state.scalar_functions().keys().cloned().collect()
215    }
216
217    fn udaf_names(&self) -> Vec<String> {
218        self.state.aggregate_functions().keys().cloned().collect()
219    }
220
221    fn udwf_names(&self) -> Vec<String> {
222        self.state.window_functions().keys().cloned().collect()
223    }
224
225    fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
226        &self.expr_planners
227    }
228}
229
230pub struct Planner {
231    schema: SchemaRef,
232    context_provider: LanceContextProvider,
233    enable_relations: bool,
234}
235
236impl Planner {
237    pub fn new(schema: SchemaRef) -> Self {
238        Self {
239            schema,
240            context_provider: LanceContextProvider::default(),
241            enable_relations: false,
242        }
243    }
244
245    /// If passed with `true`, then the first identifier in column reference
246    /// is parsed as the relation. For example, `table.field.inner` will be
247    /// read as the nested field `field.inner` (`inner` on struct field `field`)
248    /// on the `table` relation. If `false` (the default), then no relations
249    /// are used and all identifiers are assumed to be a nested column path.
250    pub fn with_enable_relations(mut self, enable_relations: bool) -> Self {
251        self.enable_relations = enable_relations;
252        self
253    }
254
255    /// Resolve a column name using case-insensitive matching against the schema.
256    /// Returns the actual field name if found, otherwise returns the original name.
257    fn resolve_column_name(&self, name: &str) -> String {
258        // Try exact match first
259        if self.schema.field_with_name(name).is_ok() {
260            return name.to_string();
261        }
262        // Fall back to case-insensitive match
263        for field in self.schema.fields() {
264            if field.name().eq_ignore_ascii_case(name) {
265                return field.name().clone();
266            }
267        }
268        // Not found in schema - return original (might be computed column, system column, etc.)
269        name.to_string()
270    }
271
272    fn column(&self, idents: &[Ident]) -> Expr {
273        fn handle_remaining_idents(expr: &mut Expr, idents: &[Ident]) {
274            for ident in idents {
275                *expr = Expr::ScalarFunction(ScalarFunction {
276                    args: vec![
277                        std::mem::take(expr),
278                        Expr::Literal(ScalarValue::Utf8(Some(ident.value.clone())), None),
279                    ],
280                    func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
281                });
282            }
283        }
284
285        if self.enable_relations && idents.len() > 1 {
286            // Create qualified column reference (relation.column)
287            let relation = &idents[0].value;
288            let column_name = self.resolve_column_name(&idents[1].value);
289            let column = Expr::Column(Column::new(Some(relation.clone()), column_name));
290            let mut result = column;
291            handle_remaining_idents(&mut result, &idents[2..]);
292            result
293        } else {
294            // Default behavior - treat as struct field access
295            // Use resolved column name to handle case-insensitive matching
296            let resolved_name = self.resolve_column_name(&idents[0].value);
297            let mut column = Expr::Column(Column::from_name(resolved_name));
298            handle_remaining_idents(&mut column, &idents[1..]);
299            column
300        }
301    }
302
303    fn binary_op(&self, op: &BinaryOperator) -> Result<Operator> {
304        Ok(match op {
305            BinaryOperator::Plus => Operator::Plus,
306            BinaryOperator::Minus => Operator::Minus,
307            BinaryOperator::Multiply => Operator::Multiply,
308            BinaryOperator::Divide => Operator::Divide,
309            BinaryOperator::Modulo => Operator::Modulo,
310            BinaryOperator::StringConcat => Operator::StringConcat,
311            BinaryOperator::Gt => Operator::Gt,
312            BinaryOperator::Lt => Operator::Lt,
313            BinaryOperator::GtEq => Operator::GtEq,
314            BinaryOperator::LtEq => Operator::LtEq,
315            BinaryOperator::Eq => Operator::Eq,
316            BinaryOperator::NotEq => Operator::NotEq,
317            BinaryOperator::And => Operator::And,
318            BinaryOperator::Or => Operator::Or,
319            _ => {
320                return Err(Error::invalid_input(
321                    format!("Operator {op} is not supported"),
322                    location!(),
323                ));
324            }
325        })
326    }
327
328    fn binary_expr(&self, left: &SQLExpr, op: &BinaryOperator, right: &SQLExpr) -> Result<Expr> {
329        Ok(Expr::BinaryExpr(BinaryExpr::new(
330            Box::new(self.parse_sql_expr(left)?),
331            self.binary_op(op)?,
332            Box::new(self.parse_sql_expr(right)?),
333        )))
334    }
335
336    fn unary_expr(&self, op: &UnaryOperator, expr: &SQLExpr) -> Result<Expr> {
337        Ok(match op {
338            UnaryOperator::Not | UnaryOperator::PGBitwiseNot => {
339                Expr::Not(Box::new(self.parse_sql_expr(expr)?))
340            }
341
342            UnaryOperator::Minus => {
343                use datafusion::logical_expr::lit;
344                match expr {
345                    SQLExpr::Value(ValueWithSpan { value: Value::Number(n, _), ..}) => match n.parse::<i64>() {
346                        Ok(n) => lit(-n),
347                        Err(_) => lit(-n
348                            .parse::<f64>()
349                            .map_err(|_e| {
350                                Error::invalid_input(
351                                    format!("negative operator can be only applied to integer and float operands, got: {n}"),
352                                    location!(),
353                                )
354                            })?),
355                    },
356                    _ => {
357                        Expr::Negative(Box::new(self.parse_sql_expr(expr)?))
358                    }
359                }
360            }
361
362            _ => {
363                return Err(Error::invalid_input(
364                    format!("Unary operator '{:?}' is not supported", op),
365                    location!(),
366                ));
367            }
368        })
369    }
370
371    // See datafusion `sqlToRel::parse_sql_number()`
372    fn number(&self, value: &str, negative: bool) -> Result<Expr> {
373        use datafusion::logical_expr::lit;
374        let value: Cow<str> = if negative {
375            Cow::Owned(format!("-{}", value))
376        } else {
377            Cow::Borrowed(value)
378        };
379        if let Ok(n) = value.parse::<i64>() {
380            Ok(lit(n))
381        } else {
382            value.parse::<f64>().map(lit).map_err(|_| {
383                Error::invalid_input(
384                    format!("'{value}' is not supported number value."),
385                    location!(),
386                )
387            })
388        }
389    }
390
391    fn value(&self, value: &Value) -> Result<Expr> {
392        Ok(match value {
393            Value::Number(v, _) => self.number(v.as_str(), false)?,
394            Value::SingleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone())), None),
395            Value::HexStringLiteral(hsl) => {
396                Expr::Literal(ScalarValue::Binary(Self::try_decode_hex_literal(hsl)), None)
397            }
398            Value::DoubleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone())), None),
399            Value::Boolean(v) => Expr::Literal(ScalarValue::Boolean(Some(*v)), None),
400            Value::Null => Expr::Literal(ScalarValue::Null, None),
401            _ => todo!(),
402        })
403    }
404
405    fn parse_function_args(&self, func_args: &FunctionArg) -> Result<Expr> {
406        match func_args {
407            FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => self.parse_sql_expr(expr),
408            _ => Err(Error::invalid_input(
409                format!("Unsupported function args: {:?}", func_args),
410                location!(),
411            )),
412        }
413    }
414
415    // We now use datafusion to parse functions.  This allows us to use datafusion's
416    // entire collection of functions (previously we had just hard-coded support for two functions).
417    //
418    // Unfortunately, one of those two functions was is_valid and the reason we needed it was because
419    // this is a function that comes from duckdb.  Datafusion does not consider is_valid to be a function
420    // but rather an AST node (Expr::IsNotNull) and so we need to handle this case specially.
421    fn legacy_parse_function(&self, func: &Function) -> Result<Expr> {
422        match &func.args {
423            FunctionArguments::List(args) => {
424                if func.name.0.len() != 1 {
425                    return Err(Error::invalid_input(
426                        format!("Function name must have 1 part, got: {:?}", func.name.0),
427                        location!(),
428                    ));
429                }
430                Ok(Expr::IsNotNull(Box::new(
431                    self.parse_function_args(&args.args[0])?,
432                )))
433            }
434            _ => Err(Error::invalid_input(
435                format!("Unsupported function args: {:?}", &func.args),
436                location!(),
437            )),
438        }
439    }
440
441    fn parse_function(&self, function: SQLExpr) -> Result<Expr> {
442        if let SQLExpr::Function(function) = &function {
443            if let Some(ObjectNamePart::Identifier(name)) = &function.name.0.first() {
444                if &name.value == "is_valid" {
445                    return self.legacy_parse_function(function);
446                }
447            }
448        }
449        let sql_to_rel = SqlToRel::new_with_options(
450            &self.context_provider,
451            ParserOptions {
452                parse_float_as_decimal: false,
453                enable_ident_normalization: false,
454                support_varchar_with_length: false,
455                enable_options_value_normalization: false,
456                collect_spans: false,
457                map_string_types_to_utf8view: false,
458                default_null_ordering: NullOrdering::NullsMax,
459            },
460        );
461
462        let mut planner_context = PlannerContext::default();
463        let schema = DFSchema::try_from(self.schema.as_ref().clone())?;
464        sql_to_rel
465            .sql_to_expr(function, &schema, &mut planner_context)
466            .map_err(|e| Error::invalid_input(format!("Error parsing function: {e}"), location!()))
467    }
468
469    fn parse_type(&self, data_type: &SQLDataType) -> Result<ArrowDataType> {
470        const SUPPORTED_TYPES: [&str; 13] = [
471            "int [unsigned]",
472            "tinyint [unsigned]",
473            "smallint [unsigned]",
474            "bigint [unsigned]",
475            "float",
476            "double",
477            "string",
478            "binary",
479            "date",
480            "timestamp(precision)",
481            "datetime(precision)",
482            "decimal(precision,scale)",
483            "boolean",
484        ];
485        match data_type {
486            SQLDataType::String(_) => Ok(ArrowDataType::Utf8),
487            SQLDataType::Binary(_) => Ok(ArrowDataType::Binary),
488            SQLDataType::Float(_) => Ok(ArrowDataType::Float32),
489            SQLDataType::Double(_) => Ok(ArrowDataType::Float64),
490            SQLDataType::Boolean => Ok(ArrowDataType::Boolean),
491            SQLDataType::TinyInt(_) => Ok(ArrowDataType::Int8),
492            SQLDataType::SmallInt(_) => Ok(ArrowDataType::Int16),
493            SQLDataType::Int(_) | SQLDataType::Integer(_) => Ok(ArrowDataType::Int32),
494            SQLDataType::BigInt(_) => Ok(ArrowDataType::Int64),
495            SQLDataType::TinyIntUnsigned(_) => Ok(ArrowDataType::UInt8),
496            SQLDataType::SmallIntUnsigned(_) => Ok(ArrowDataType::UInt16),
497            SQLDataType::IntUnsigned(_) | SQLDataType::IntegerUnsigned(_) => {
498                Ok(ArrowDataType::UInt32)
499            }
500            SQLDataType::BigIntUnsigned(_) => Ok(ArrowDataType::UInt64),
501            SQLDataType::Date => Ok(ArrowDataType::Date32),
502            SQLDataType::Timestamp(resolution, tz) => {
503                match tz {
504                    TimezoneInfo::None => {}
505                    _ => {
506                        return Err(Error::invalid_input(
507                            "Timezone not supported in timestamp".to_string(),
508                            location!(),
509                        ));
510                    }
511                };
512                let time_unit = match resolution {
513                    // Default to microsecond to match PyArrow
514                    None => TimeUnit::Microsecond,
515                    Some(0) => TimeUnit::Second,
516                    Some(3) => TimeUnit::Millisecond,
517                    Some(6) => TimeUnit::Microsecond,
518                    Some(9) => TimeUnit::Nanosecond,
519                    _ => {
520                        return Err(Error::invalid_input(
521                            format!("Unsupported datetime resolution: {:?}", resolution),
522                            location!(),
523                        ));
524                    }
525                };
526                Ok(ArrowDataType::Timestamp(time_unit, None))
527            }
528            SQLDataType::Datetime(resolution) => {
529                let time_unit = match resolution {
530                    None => TimeUnit::Microsecond,
531                    Some(0) => TimeUnit::Second,
532                    Some(3) => TimeUnit::Millisecond,
533                    Some(6) => TimeUnit::Microsecond,
534                    Some(9) => TimeUnit::Nanosecond,
535                    _ => {
536                        return Err(Error::invalid_input(
537                            format!("Unsupported datetime resolution: {:?}", resolution),
538                            location!(),
539                        ));
540                    }
541                };
542                Ok(ArrowDataType::Timestamp(time_unit, None))
543            }
544            SQLDataType::Decimal(number_info) => match number_info {
545                ExactNumberInfo::PrecisionAndScale(precision, scale) => {
546                    Ok(ArrowDataType::Decimal128(*precision as u8, *scale as i8))
547                }
548                _ => Err(Error::invalid_input(
549                    format!(
550                        "Must provide precision and scale for decimal: {:?}",
551                        number_info
552                    ),
553                    location!(),
554                )),
555            },
556            _ => Err(Error::invalid_input(
557                format!(
558                    "Unsupported data type: {:?}. Supported types: {:?}",
559                    data_type, SUPPORTED_TYPES
560                ),
561                location!(),
562            )),
563        }
564    }
565
566    fn plan_field_access(&self, mut field_access_expr: RawFieldAccessExpr) -> Result<Expr> {
567        let df_schema = DFSchema::try_from(self.schema.as_ref().clone())?;
568        for planner in self.context_provider.get_expr_planners() {
569            match planner.plan_field_access(field_access_expr, &df_schema)? {
570                PlannerResult::Planned(expr) => return Ok(expr),
571                PlannerResult::Original(expr) => {
572                    field_access_expr = expr;
573                }
574            }
575        }
576        Err(Error::invalid_input(
577            "Field access could not be planned",
578            location!(),
579        ))
580    }
581
582    fn parse_sql_expr(&self, expr: &SQLExpr) -> Result<Expr> {
583        match expr {
584            SQLExpr::Identifier(id) => {
585                // Users can pass string literals wrapped in `"`.
586                // (Normally SQL only allows single quotes.)
587                if id.quote_style == Some('"') {
588                    Ok(Expr::Literal(
589                        ScalarValue::Utf8(Some(id.value.clone())),
590                        None,
591                    ))
592                // Users can wrap identifiers with ` to reference non-standard
593                // names, such as uppercase or spaces.
594                } else if id.quote_style == Some('`') {
595                    Ok(Expr::Column(Column::from_name(id.value.clone())))
596                } else {
597                    Ok(self.column(vec![id.clone()].as_slice()))
598                }
599            }
600            SQLExpr::CompoundIdentifier(ids) => Ok(self.column(ids.as_slice())),
601            SQLExpr::BinaryOp { left, op, right } => self.binary_expr(left, op, right),
602            SQLExpr::UnaryOp { op, expr } => self.unary_expr(op, expr),
603            SQLExpr::Value(value) => self.value(&value.value),
604            SQLExpr::Array(SQLArray { elem, .. }) => {
605                let mut values = vec![];
606
607                let array_literal_error = |pos: usize, value: &_| {
608                    Err(Error::invalid_input(
609                        format!(
610                            "Expected a literal value in array, instead got {} at position {}",
611                            value, pos
612                        ),
613                        location!(),
614                    ))
615                };
616
617                for (pos, expr) in elem.iter().enumerate() {
618                    match expr {
619                        SQLExpr::Value(value) => {
620                            if let Expr::Literal(value, _) = self.value(&value.value)? {
621                                values.push(value);
622                            } else {
623                                return array_literal_error(pos, expr);
624                            }
625                        }
626                        SQLExpr::UnaryOp {
627                            op: UnaryOperator::Minus,
628                            expr,
629                        } => {
630                            if let SQLExpr::Value(ValueWithSpan {
631                                value: Value::Number(number, _),
632                                ..
633                            }) = expr.as_ref()
634                            {
635                                if let Expr::Literal(value, _) = self.number(number, true)? {
636                                    values.push(value);
637                                } else {
638                                    return array_literal_error(pos, expr);
639                                }
640                            } else {
641                                return array_literal_error(pos, expr);
642                            }
643                        }
644                        _ => {
645                            return array_literal_error(pos, expr);
646                        }
647                    }
648                }
649
650                let field = if !values.is_empty() {
651                    let data_type = values[0].data_type();
652
653                    for value in &mut values {
654                        if value.data_type() != data_type {
655                            *value = safe_coerce_scalar(value, &data_type).ok_or_else(|| Error::invalid_input(
656                                format!("Array expressions must have a consistent datatype. Expected: {}, got: {}", data_type, value.data_type()),
657                                location!()
658                            ))?;
659                        }
660                    }
661                    Field::new("item", data_type, true)
662                } else {
663                    Field::new("item", ArrowDataType::Null, true)
664                };
665
666                let values = values
667                    .into_iter()
668                    .map(|v| v.to_array().map_err(Error::from))
669                    .collect::<Result<Vec<_>>>()?;
670                let array_refs = values.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
671                let values = concat(&array_refs)?;
672                let values = ListArray::try_new(
673                    field.into(),
674                    OffsetBuffer::from_lengths([values.len()]),
675                    values,
676                    None,
677                )?;
678
679                Ok(Expr::Literal(ScalarValue::List(Arc::new(values)), None))
680            }
681            // For example, DATE '2020-01-01'
682            SQLExpr::TypedString { data_type, value } => {
683                let value = value.clone().into_string().expect_ok()?;
684                Ok(Expr::Cast(datafusion::logical_expr::Cast {
685                    expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some(value)), None)),
686                    data_type: self.parse_type(data_type)?,
687                }))
688            }
689            SQLExpr::IsFalse(expr) => Ok(Expr::IsFalse(Box::new(self.parse_sql_expr(expr)?))),
690            SQLExpr::IsNotFalse(expr) => Ok(Expr::IsNotFalse(Box::new(self.parse_sql_expr(expr)?))),
691            SQLExpr::IsTrue(expr) => Ok(Expr::IsTrue(Box::new(self.parse_sql_expr(expr)?))),
692            SQLExpr::IsNotTrue(expr) => Ok(Expr::IsNotTrue(Box::new(self.parse_sql_expr(expr)?))),
693            SQLExpr::IsNull(expr) => Ok(Expr::IsNull(Box::new(self.parse_sql_expr(expr)?))),
694            SQLExpr::IsNotNull(expr) => Ok(Expr::IsNotNull(Box::new(self.parse_sql_expr(expr)?))),
695            SQLExpr::InList {
696                expr,
697                list,
698                negated,
699            } => {
700                let value_expr = self.parse_sql_expr(expr)?;
701                let list_exprs = list
702                    .iter()
703                    .map(|e| self.parse_sql_expr(e))
704                    .collect::<Result<Vec<_>>>()?;
705                Ok(value_expr.in_list(list_exprs, *negated))
706            }
707            SQLExpr::Nested(inner) => self.parse_sql_expr(inner.as_ref()),
708            SQLExpr::Function(_) => self.parse_function(expr.clone()),
709            SQLExpr::ILike {
710                negated,
711                expr,
712                pattern,
713                escape_char,
714                any: _,
715            } => Ok(Expr::Like(Like::new(
716                *negated,
717                Box::new(self.parse_sql_expr(expr)?),
718                Box::new(self.parse_sql_expr(pattern)?),
719                match escape_char {
720                    Some(Value::SingleQuotedString(char)) => char.chars().next(),
721                    Some(value) => return Err(Error::invalid_input(
722                        format!("Invalid escape character in LIKE expression. Expected a single character wrapped with single quotes, got {}", value),
723                        location!()
724                    )),
725                    None => None,
726                },
727                true,
728            ))),
729            SQLExpr::Like {
730                negated,
731                expr,
732                pattern,
733                escape_char,
734                any: _,
735            } => Ok(Expr::Like(Like::new(
736                *negated,
737                Box::new(self.parse_sql_expr(expr)?),
738                Box::new(self.parse_sql_expr(pattern)?),
739                match escape_char {
740                    Some(Value::SingleQuotedString(char)) => char.chars().next(),
741                    Some(value) => return Err(Error::invalid_input(
742                        format!("Invalid escape character in LIKE expression. Expected a single character wrapped with single quotes, got {}", value),
743                        location!()
744                    )),
745                    None => None,
746                },
747                false,
748            ))),
749            SQLExpr::Cast {
750                expr,
751                data_type,
752                kind,
753                ..
754            } => match kind {
755                datafusion::sql::sqlparser::ast::CastKind::TryCast
756                | datafusion::sql::sqlparser::ast::CastKind::SafeCast => {
757                    Ok(Expr::TryCast(datafusion::logical_expr::TryCast {
758                        expr: Box::new(self.parse_sql_expr(expr)?),
759                        data_type: self.parse_type(data_type)?,
760                    }))
761                }
762                _ => Ok(Expr::Cast(datafusion::logical_expr::Cast {
763                    expr: Box::new(self.parse_sql_expr(expr)?),
764                    data_type: self.parse_type(data_type)?,
765                })),
766            },
767            SQLExpr::JsonAccess { .. } => Err(Error::invalid_input(
768                "JSON access is not supported",
769                location!(),
770            )),
771            SQLExpr::CompoundFieldAccess { root, access_chain } => {
772                let mut expr = self.parse_sql_expr(root)?;
773
774                for access in access_chain {
775                    let field_access = match access {
776                        // x.y or x['y']
777                        AccessExpr::Dot(SQLExpr::Identifier(Ident { value: s, .. }))
778                        | AccessExpr::Subscript(Subscript::Index {
779                            index:
780                                SQLExpr::Value(ValueWithSpan {
781                                    value:
782                                        Value::SingleQuotedString(s) | Value::DoubleQuotedString(s),
783                                    ..
784                                }),
785                        }) => GetFieldAccess::NamedStructField {
786                            name: ScalarValue::from(s.as_str()),
787                        },
788                        AccessExpr::Subscript(Subscript::Index { index }) => {
789                            let key = Box::new(self.parse_sql_expr(index)?);
790                            GetFieldAccess::ListIndex { key }
791                        }
792                        AccessExpr::Subscript(Subscript::Slice { .. }) => {
793                            return Err(Error::invalid_input(
794                                "Slice subscript is not supported",
795                                location!(),
796                            ));
797                        }
798                        _ => {
799                            // Handle other cases like JSON access
800                            // Note: JSON access is not supported in lance
801                            return Err(Error::invalid_input(
802                                "Only dot notation or index access is supported for field access",
803                                location!(),
804                            ));
805                        }
806                    };
807
808                    let field_access_expr = RawFieldAccessExpr { expr, field_access };
809                    expr = self.plan_field_access(field_access_expr)?;
810                }
811
812                Ok(expr)
813            }
814            SQLExpr::Between {
815                expr,
816                negated,
817                low,
818                high,
819            } => {
820                // Parse the main expression and bounds
821                let expr = self.parse_sql_expr(expr)?;
822                let low = self.parse_sql_expr(low)?;
823                let high = self.parse_sql_expr(high)?;
824
825                let between = Expr::Between(Between::new(
826                    Box::new(expr),
827                    *negated,
828                    Box::new(low),
829                    Box::new(high),
830                ));
831                Ok(between)
832            }
833            _ => Err(Error::invalid_input(
834                format!("Expression '{expr}' is not supported SQL in lance"),
835                location!(),
836            )),
837        }
838    }
839
840    /// Create Logical [Expr] from a SQL filter clause.
841    ///
842    /// Note: the returned expression must be passed through [optimize_expr()]
843    /// before being passed to [create_physical_expr()].
844    pub fn parse_filter(&self, filter: &str) -> Result<Expr> {
845        // Allow sqlparser to parse filter as part of ONE SQL statement.
846        let ast_expr = parse_sql_filter(filter)?;
847        let expr = self.parse_sql_expr(&ast_expr)?;
848        let schema = Schema::try_from(self.schema.as_ref())?;
849        let resolved = resolve_expr(&expr, &schema).map_err(|e| {
850            Error::invalid_input(
851                format!("Error resolving filter expression {filter}: {e}"),
852                location!(),
853            )
854        })?;
855
856        Ok(coerce_filter_type_to_boolean(resolved))
857    }
858
859    /// Create Logical [Expr] from a SQL expression.
860    ///
861    /// Note: the returned expression must be passed through [optimize_filter()]
862    /// before being passed to [create_physical_expr()].
863    pub fn parse_expr(&self, expr: &str) -> Result<Expr> {
864        // First check if it's a simple column reference (no operators, functions, etc.)
865        // resolve_column_name tries exact match first, then falls back to case-insensitive
866        let resolved_name = self.resolve_column_name(expr);
867        if self.schema.field_with_name(&resolved_name).is_ok() {
868            return Ok(Expr::Column(Column::from_name(resolved_name)));
869        }
870
871        // Parse as SQL expression
872        let ast_expr = parse_sql_expr(expr)?;
873        let expr = self.parse_sql_expr(&ast_expr)?;
874        let schema = Schema::try_from(self.schema.as_ref())?;
875        let resolved = resolve_expr(&expr, &schema)?;
876        Ok(resolved)
877    }
878
879    /// Try to decode bytes from hex literal string.
880    ///
881    /// Copied from datafusion because this is not public.
882    ///
883    /// TODO: use SqlToRel from Datafusion directly?
884    fn try_decode_hex_literal(s: &str) -> Option<Vec<u8>> {
885        let hex_bytes = s.as_bytes();
886        let mut decoded_bytes = Vec::with_capacity(hex_bytes.len().div_ceil(2));
887
888        let start_idx = hex_bytes.len() % 2;
889        if start_idx > 0 {
890            // The first byte is formed of only one char.
891            decoded_bytes.push(Self::try_decode_hex_char(hex_bytes[0])?);
892        }
893
894        for i in (start_idx..hex_bytes.len()).step_by(2) {
895            let high = Self::try_decode_hex_char(hex_bytes[i])?;
896            let low = Self::try_decode_hex_char(hex_bytes[i + 1])?;
897            decoded_bytes.push((high << 4) | low);
898        }
899
900        Some(decoded_bytes)
901    }
902
903    /// Try to decode a byte from a hex char.
904    ///
905    /// None will be returned if the input char is hex-invalid.
906    const fn try_decode_hex_char(c: u8) -> Option<u8> {
907        match c {
908            b'A'..=b'F' => Some(c - b'A' + 10),
909            b'a'..=b'f' => Some(c - b'a' + 10),
910            b'0'..=b'9' => Some(c - b'0'),
911            _ => None,
912        }
913    }
914
915    /// Optimize the filter expression and coerce data types.
916    pub fn optimize_expr(&self, expr: Expr) -> Result<Expr> {
917        let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
918
919        // DataFusion needs the simplify and coerce passes to be applied before
920        // expressions can be handled by the physical planner.
921        let props = ExecutionProps::new().with_query_execution_start_time(Utc::now());
922        let simplify_context = SimplifyContext::new(&props).with_schema(df_schema.clone());
923        let simplifier =
924            datafusion::optimizer::simplify_expressions::ExprSimplifier::new(simplify_context);
925
926        let expr = simplifier.simplify(expr)?;
927        let expr = simplifier.coerce(expr, &df_schema)?;
928
929        Ok(expr)
930    }
931
932    /// Create the [`PhysicalExpr`] from a logical [`Expr`]
933    pub fn create_physical_expr(&self, expr: &Expr) -> Result<Arc<dyn PhysicalExpr>> {
934        let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
935        Ok(datafusion::physical_expr::create_physical_expr(
936            expr,
937            df_schema.as_ref(),
938            &Default::default(),
939        )?)
940    }
941
942    /// Collect the columns in the expression.
943    ///
944    /// The columns are returned in sorted order.
945    ///
946    /// If the expr refers to nested columns these will be returned
947    /// as dotted paths (x.y.z)
948    pub fn column_names_in_expr(expr: &Expr) -> Vec<String> {
949        let mut visitor = ColumnCapturingVisitor {
950            current_path: VecDeque::new(),
951            columns: BTreeSet::new(),
952        };
953        expr.visit(&mut visitor).unwrap();
954        visitor.columns.into_iter().collect()
955    }
956}
957
// Expression visitor that records every column referenced by an expression,
// rendering nested `get_field` accesses as dotted paths (e.g. `x.y.z`).
struct ColumnCapturingVisitor {
    // Field names of a partially-traversed `get_field` chain, outermost first.
    // Current column path. If this is empty, we are not in a column expression.
    current_path: VecDeque<String>,
    // Completed column paths; BTreeSet keeps them sorted and de-duplicated.
    columns: BTreeSet<String>,
}
963
964impl TreeNodeVisitor<'_> for ColumnCapturingVisitor {
965    type Node = Expr;
966
967    fn f_down(&mut self, node: &Self::Node) -> DFResult<TreeNodeRecursion> {
968        match node {
969            Expr::Column(Column { name, .. }) => {
970                // Build the field path from the column name and any nested fields
971                // The nested field names from get_field already come as literal strings,
972                // so we just need to concatenate them properly
973                let mut path = name.clone();
974                for part in self.current_path.drain(..) {
975                    path.push('.');
976                    // Check if the part needs quoting (contains dots)
977                    if part.contains('.') || part.contains('`') {
978                        // Quote the field name with backticks and escape any existing backticks
979                        let escaped = part.replace('`', "``");
980                        path.push('`');
981                        path.push_str(&escaped);
982                        path.push('`');
983                    } else {
984                        path.push_str(&part);
985                    }
986                }
987                self.columns.insert(path);
988                self.current_path.clear();
989            }
990            Expr::ScalarFunction(udf) => {
991                if udf.name() == GetFieldFunc::default().name() {
992                    if let Some(name) = get_as_string_scalar_opt(&udf.args[1]) {
993                        self.current_path.push_front(name.to_string())
994                    } else {
995                        self.current_path.clear();
996                    }
997                } else {
998                    self.current_path.clear();
999                }
1000            }
1001            _ => {
1002                self.current_path.clear();
1003            }
1004        }
1005
1006        Ok(TreeNodeRecursion::Continue)
1007    }
1008}
1009
1010#[cfg(test)]
1011mod tests {
1012
1013    use crate::logical_expr::ExprExt;
1014
1015    use super::*;
1016
1017    use arrow::datatypes::Float64Type;
1018    use arrow_array::{
1019        ArrayRef, BooleanArray, Float32Array, Int32Array, Int64Array, RecordBatch, StringArray,
1020        StructArray, TimestampMicrosecondArray, TimestampMillisecondArray,
1021        TimestampNanosecondArray, TimestampSecondArray,
1022    };
1023    use arrow_schema::{DataType, Fields, Schema};
1024    use datafusion::{
1025        logical_expr::{col, lit, Cast},
1026        prelude::{array_element, get_field},
1027    };
1028    use datafusion_functions::core::expr_ext::FieldAccessor;
1029
    #[test]
    fn test_parse_filter_simple() {
        // Schema with a scalar int, a nullable string, and a nested struct so
        // the filter exercises both top-level and nested column references.
        let schema = Arc::new(Schema::new(vec![
            Field::new("i", DataType::Int32, false),
            Field::new("s", DataType::Utf8, true),
            Field::new(
                "st",
                DataType::Struct(Fields::from(vec![
                    Field::new("x", DataType::Float32, false),
                    Field::new("y", DataType::Float32, false),
                ])),
                true,
            ),
        ]));

        let planner = Planner::new(schema.clone());

        let expected = col("i")
            .gt(lit(3_i32))
            .and(col("st").field_newstyle("x").lt_eq(lit(5.0_f32)))
            .and(
                col("s")
                    .eq(lit("str-4"))
                    .or(col("s").in_list(vec![lit("str-4"), lit("str-5")], false)),
            );

        // double-equals (`==`) form of equality
        let expr = planner
            .parse_filter("i > 3 AND st.x <= 5.0 AND (s == 'str-4' OR s in ('str-4', 'str-5'))")
            .unwrap();
        assert_eq!(expr, expected);

        // single-equals (`=`) form of equality
        let expr = planner
            .parse_filter("i > 3 AND st.x <= 5.0 AND (s = 'str-4' OR s in ('str-4', 'str-5'))")
            .unwrap();

        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        // Rows: i = 0..10, s = "str-0".."str-9", st.x = 0.0..9.0, st.y = 10*x.
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
                Arc::new(StringArray::from_iter_values(
                    (0..10).map(|v| format!("str-{}", v)),
                )),
                Arc::new(StructArray::from(vec![
                    (
                        Arc::new(Field::new("x", DataType::Float32, false)),
                        Arc::new(Float32Array::from_iter_values((0..10).map(|v| v as f32)))
                            as ArrayRef,
                    ),
                    (
                        Arc::new(Field::new("y", DataType::Float32, false)),
                        Arc::new(Float32Array::from_iter_values(
                            (0..10).map(|v| (v * 10) as f32),
                        )),
                    ),
                ])),
            ],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        // Only rows 4 and 5 satisfy i > 3 AND st.x <= 5.0 AND s IN (...).
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, false, true, true, false, false, false, false
            ])
        );
    }
1100
    #[test]
    fn test_nested_col_refs() {
        // Two levels of struct nesting so both single and chained `get_field`
        // accesses can be checked.
        let schema = Arc::new(Schema::new(vec![
            Field::new("s0", DataType::Utf8, true),
            Field::new(
                "st",
                DataType::Struct(Fields::from(vec![
                    Field::new("s1", DataType::Utf8, true),
                    Field::new(
                        "st",
                        DataType::Struct(Fields::from(vec![Field::new(
                            "s2",
                            DataType::Utf8,
                            true,
                        )])),
                        true,
                    ),
                ])),
                true,
            ),
        ]));

        let planner = Planner::new(schema);

        // Parses `<expr> = 'val'` and asserts the left-hand side of the
        // resulting equality matches `expected`.
        fn assert_column_eq(planner: &Planner, expr: &str, expected: &Expr) {
            let expr = planner.parse_filter(&format!("{expr} = 'val'")).unwrap();
            assert!(matches!(
                expr,
                Expr::BinaryExpr(BinaryExpr {
                    left: _,
                    op: Operator::Eq,
                    right: _
                })
            ));
            if let Expr::BinaryExpr(BinaryExpr { left, .. }) = expr {
                assert_eq!(left.as_ref(), expected);
            }
        }

        // Top-level column, bare and backquoted.
        let expected = Expr::Column(Column::new_unqualified("s0"));
        assert_column_eq(&planner, "s0", &expected);
        assert_column_eq(&planner, "`s0`", &expected);

        // One level of nesting resolves to get_field(st, "s1").
        let expected = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
            args: vec![
                Expr::Column(Column::new_unqualified("st")),
                Expr::Literal(ScalarValue::Utf8(Some("s1".to_string())), None),
            ],
        });
        assert_column_eq(&planner, "st.s1", &expected);
        assert_column_eq(&planner, "`st`.`s1`", &expected);
        assert_column_eq(&planner, "st.`s1`", &expected);

        // Two levels of nesting resolves to nested get_field calls.
        let expected = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
            args: vec![
                Expr::ScalarFunction(ScalarFunction {
                    func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
                    args: vec![
                        Expr::Column(Column::new_unqualified("st")),
                        Expr::Literal(ScalarValue::Utf8(Some("st".to_string())), None),
                    ],
                }),
                Expr::Literal(ScalarValue::Utf8(Some("s2".to_string())), None),
            ],
        });

        // Dotted, backquoted, and subscript syntaxes all resolve identically.
        assert_column_eq(&planner, "st.st.s2", &expected);
        assert_column_eq(&planner, "`st`.`st`.`s2`", &expected);
        assert_column_eq(&planner, "st.st.`s2`", &expected);
        assert_column_eq(&planner, "st['st'][\"s2\"]", &expected);
    }
1174
1175    #[test]
1176    fn test_nested_list_refs() {
1177        let schema = Arc::new(Schema::new(vec![Field::new(
1178            "l",
1179            DataType::List(Arc::new(Field::new(
1180                "item",
1181                DataType::Struct(Fields::from(vec![Field::new("f1", DataType::Utf8, true)])),
1182                true,
1183            ))),
1184            true,
1185        )]));
1186
1187        let planner = Planner::new(schema);
1188
1189        let expected = array_element(col("l"), lit(0_i64));
1190        let expr = planner.parse_expr("l[0]").unwrap();
1191        assert_eq!(expr, expected);
1192
1193        let expected = get_field(array_element(col("l"), lit(0_i64)), "f1");
1194        let expr = planner.parse_expr("l[0]['f1']").unwrap();
1195        assert_eq!(expr, expected);
1196
1197        // FIXME: This should work, but sqlparser doesn't recognize anything
1198        // after the period for some reason.
1199        // let expr = planner.parse_expr("l[0].f1").unwrap();
1200        // assert_eq!(expr, expected);
1201    }
1202
    #[test]
    fn test_negative_expressions() {
        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));

        let planner = Planner::new(schema.clone());

        // Covers both a negative literal (-3) and unary negation of a
        // parenthesized expression: -(-5 + 3) == 2.
        let expected = col("x")
            .gt(lit(-3_i64))
            .and(col("x").lt(-(lit(-5_i64) + lit(3_i64))));

        let expr = planner.parse_filter("x > -3 AND x < -(-5 + 3)").unwrap();

        assert_eq!(expr, expected);

        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        // Evaluate over x in -5..5; the predicate (x > -3 AND x < 2) holds
        // only for x in {-2, -1, 0, 1}, i.e. indices 3..=6.
        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(Int64Array::from_iter_values(-5..5)) as ArrayRef],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, true, true, true, true, false, false, false
            ])
        );
    }
1232
1233    #[test]
1234    fn test_negative_array_expressions() {
1235        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
1236
1237        let planner = Planner::new(schema);
1238
1239        let expected = Expr::Literal(
1240            ScalarValue::List(Arc::new(
1241                ListArray::from_iter_primitive::<Float64Type, _, _>(vec![Some(
1242                    [-1_f64, -2.0, -3.0, -4.0, -5.0].map(Some),
1243                )]),
1244            )),
1245            None,
1246        );
1247
1248        let expr = planner
1249            .parse_expr("[-1.0, -2.0, -3.0, -4.0, -5.0]")
1250            .unwrap();
1251
1252        assert_eq!(expr, expected);
1253    }
1254
1255    #[test]
1256    fn test_sql_like() {
1257        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1258
1259        let planner = Planner::new(schema.clone());
1260
1261        let expected = col("s").like(lit("str-4"));
1262        // single quote
1263        let expr = planner.parse_filter("s LIKE 'str-4'").unwrap();
1264        assert_eq!(expr, expected);
1265        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1266
1267        let batch = RecordBatch::try_new(
1268            schema,
1269            vec![Arc::new(StringArray::from_iter_values(
1270                (0..10).map(|v| format!("str-{}", v)),
1271            ))],
1272        )
1273        .unwrap();
1274        let predicates = physical_expr.evaluate(&batch).unwrap();
1275        assert_eq!(
1276            predicates.into_array(0).unwrap().as_ref(),
1277            &BooleanArray::from(vec![
1278                false, false, false, false, true, false, false, false, false, false
1279            ])
1280        );
1281    }
1282
    #[test]
    fn test_not_like() {
        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));

        let planner = Planner::new(schema.clone());

        // NOT LIKE with a wildcard-free pattern excludes only the exact match.
        let expected = col("s").not_like(lit("str-4"));
        // single quote
        let expr = planner.parse_filter("s NOT LIKE 'str-4'").unwrap();
        assert_eq!(expr, expected);
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        // Rows are "str-0".."str-9"; every row except index 4 passes.
        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(StringArray::from_iter_values(
                (0..10).map(|v| format!("str-{}", v)),
            ))],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                true, true, true, true, false, true, true, true, true, true
            ])
        );
    }
1310
    #[test]
    fn test_sql_is_in() {
        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));

        let planner = Planner::new(schema.clone());

        // IN list parses to a non-negated in_list expression.
        let expected = col("s").in_list(vec![lit("str-4"), lit("str-5")], false);
        // single quote
        let expr = planner.parse_filter("s IN ('str-4', 'str-5')").unwrap();
        assert_eq!(expr, expected);
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        // Rows are "str-0".."str-9"; only indices 4 and 5 are in the list.
        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(StringArray::from_iter_values(
                (0..10).map(|v| format!("str-{}", v)),
            ))],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, false, true, true, false, false, false, false
            ])
        );
    }
1338
    #[test]
    fn test_sql_is_null() {
        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));

        let planner = Planner::new(schema.clone());

        let expected = col("s").is_null();
        let expr = planner.parse_filter("s IS NULL").unwrap();
        assert_eq!(expr, expected);
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        // Rows 0, 3, 6, 9 (v % 3 == 0) hold values; all others are null.
        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(StringArray::from_iter((0..10).map(|v| {
                if v % 3 == 0 {
                    Some(format!("str-{}", v))
                } else {
                    None
                }
            })))],
        )
        .unwrap();
        // IS NULL: true exactly where the column is null.
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, true, true, false, true, true, false, true, true, false
            ])
        );

        // IS NOT NULL on the same batch yields the complementary mask.
        let expr = planner.parse_filter("s IS NOT NULL").unwrap();
        let physical_expr = planner.create_physical_expr(&expr).unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                true, false, false, true, false, false, true, false, false, true,
            ])
        );
    }
1379
1380    #[test]
1381    fn test_sql_invert() {
1382        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Boolean, true)]));
1383
1384        let planner = Planner::new(schema.clone());
1385
1386        let expr = planner.parse_filter("NOT s").unwrap();
1387        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1388
1389        let batch = RecordBatch::try_new(
1390            schema,
1391            vec![Arc::new(BooleanArray::from_iter(
1392                (0..10).map(|v| Some(v % 3 == 0)),
1393            ))],
1394        )
1395        .unwrap();
1396        let predicates = physical_expr.evaluate(&batch).unwrap();
1397        assert_eq!(
1398            predicates.into_array(0).unwrap().as_ref(),
1399            &BooleanArray::from(vec![
1400                false, true, true, false, true, true, false, true, true, false
1401            ])
1402        );
1403    }
1404
    #[test]
    fn test_sql_cast() {
        // Each case pairs a CAST expression with the Arrow type the cast
        // target should resolve to.
        let cases = &[
            (
                "x = cast('2021-01-01 00:00:00' as timestamp)",
                ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00' as timestamp(0))",
                ArrowDataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00.123' as timestamp(9))",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00.123' as datetime(9))",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            ("x = cast('2021-01-01' as date)", ArrowDataType::Date32),
            (
                "x = cast('1.238' as decimal(9,3))",
                ArrowDataType::Decimal128(9, 3),
            ),
            ("x = cast(1 as float)", ArrowDataType::Float32),
            ("x = cast(1 as double)", ArrowDataType::Float64),
            ("x = cast(1 as tinyint)", ArrowDataType::Int8),
            ("x = cast(1 as smallint)", ArrowDataType::Int16),
            ("x = cast(1 as int)", ArrowDataType::Int32),
            ("x = cast(1 as integer)", ArrowDataType::Int32),
            ("x = cast(1 as bigint)", ArrowDataType::Int64),
            ("x = cast(1 as tinyint unsigned)", ArrowDataType::UInt8),
            ("x = cast(1 as smallint unsigned)", ArrowDataType::UInt16),
            ("x = cast(1 as int unsigned)", ArrowDataType::UInt32),
            ("x = cast(1 as integer unsigned)", ArrowDataType::UInt32),
            ("x = cast(1 as bigint unsigned)", ArrowDataType::UInt64),
            ("x = cast(1 as boolean)", ArrowDataType::Boolean),
            ("x = cast(1 as string)", ArrowDataType::Utf8),
        ];

        for (sql, expected_data_type) in cases {
            let schema = Arc::new(Schema::new(vec![Field::new(
                "x",
                expected_data_type.clone(),
                true,
            )]));
            let planner = Planner::new(schema.clone());
            let expr = planner.parse_filter(sql).unwrap();

            // Extract the literal between "cast(" and " as" from the SQL text
            // to compare against the parsed literal below.
            let expected_value_str = sql
                .split("cast(")
                .nth(1)
                .unwrap()
                .split(" as")
                .next()
                .unwrap();
            // Remove any quote marks
            let expected_value_str = expected_value_str.trim_matches('\'');

            // The parse result must be `x = Cast(literal, expected type)`; the
            // literal is Utf8 for quoted values and Int64 for the bare `1`.
            match expr {
                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
                    Expr::Cast(Cast { expr, data_type }) => {
                        match expr.as_ref() {
                            Expr::Literal(ScalarValue::Utf8(Some(value_str)), _) => {
                                assert_eq!(value_str, expected_value_str);
                            }
                            Expr::Literal(ScalarValue::Int64(Some(value)), _) => {
                                assert_eq!(*value, 1);
                            }
                            _ => panic!("Expected cast to be applied to literal"),
                        }
                        assert_eq!(data_type, expected_data_type);
                    }
                    _ => panic!("Expected right to be a cast"),
                },
                _ => panic!("Expected binary expression"),
            }
        }
    }
1485
    #[test]
    fn test_sql_literals() {
        // Typed literal syntax (`TYPE 'value'`) should parse to a cast of the
        // string literal to the corresponding Arrow type.
        let cases = &[
            (
                "x = timestamp '2021-01-01 00:00:00'",
                ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
            ),
            (
                "x = timestamp(0) '2021-01-01 00:00:00'",
                ArrowDataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                "x = timestamp(9) '2021-01-01 00:00:00.123'",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            ("x = date '2021-01-01'", ArrowDataType::Date32),
            ("x = decimal(9,3) '1.238'", ArrowDataType::Decimal128(9, 3)),
        ];

        for (sql, expected_data_type) in cases {
            let schema = Arc::new(Schema::new(vec![Field::new(
                "x",
                expected_data_type.clone(),
                true,
            )]));
            let planner = Planner::new(schema.clone());
            let expr = planner.parse_filter(sql).unwrap();

            // The single-quoted value from the SQL text is the expected literal.
            let expected_value_str = sql.split('\'').nth(1).unwrap();

            // The parse result must be `x = Cast(Utf8 literal, expected type)`.
            match expr {
                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
                    Expr::Cast(Cast { expr, data_type }) => {
                        match expr.as_ref() {
                            Expr::Literal(ScalarValue::Utf8(Some(value_str)), _) => {
                                assert_eq!(value_str, expected_value_str);
                            }
                            _ => panic!("Expected cast to be applied to literal"),
                        }
                        assert_eq!(data_type, expected_data_type);
                    }
                    _ => panic!("Expected right to be a cast"),
                },
                _ => panic!("Expected binary expression"),
            }
        }
    }
1533
1534    #[test]
1535    fn test_sql_array_literals() {
1536        let cases = [
1537            (
1538                "x = [1, 2, 3]",
1539                ArrowDataType::List(Arc::new(Field::new("item", ArrowDataType::Int64, true))),
1540            ),
1541            (
1542                "x = [1, 2, 3]",
1543                ArrowDataType::FixedSizeList(
1544                    Arc::new(Field::new("item", ArrowDataType::Int64, true)),
1545                    3,
1546                ),
1547            ),
1548        ];
1549
1550        for (sql, expected_data_type) in cases {
1551            let schema = Arc::new(Schema::new(vec![Field::new(
1552                "x",
1553                expected_data_type.clone(),
1554                true,
1555            )]));
1556            let planner = Planner::new(schema.clone());
1557            let expr = planner.parse_filter(sql).unwrap();
1558            let expr = planner.optimize_expr(expr).unwrap();
1559
1560            match expr {
1561                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
1562                    Expr::Literal(value, _) => {
1563                        assert_eq!(&value.data_type(), &expected_data_type);
1564                    }
1565                    _ => panic!("Expected right to be a literal"),
1566                },
1567                _ => panic!("Expected binary expression"),
1568            }
1569        }
1570    }
1571
1572    #[test]
1573    fn test_sql_between() {
1574        use arrow_array::{Float64Array, Int32Array, TimestampMicrosecondArray};
1575        use arrow_schema::{DataType, Field, Schema, TimeUnit};
1576        use std::sync::Arc;
1577
1578        let schema = Arc::new(Schema::new(vec![
1579            Field::new("x", DataType::Int32, false),
1580            Field::new("y", DataType::Float64, false),
1581            Field::new(
1582                "ts",
1583                DataType::Timestamp(TimeUnit::Microsecond, None),
1584                false,
1585            ),
1586        ]));
1587
1588        let planner = Planner::new(schema.clone());
1589
1590        // Test integer BETWEEN
1591        let expr = planner
1592            .parse_filter("x BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
1593            .unwrap();
1594        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1595
1596        // Create timestamp array with values representing:
1597        // 2024-01-01 00:00:00 to 2024-01-01 00:00:09 (in microseconds)
1598        let base_ts = 1704067200000000_i64; // 2024-01-01 00:00:00
1599        let ts_array = TimestampMicrosecondArray::from_iter_values(
1600            (0..10).map(|i| base_ts + i * 1_000_000), // Each value is 1 second apart
1601        );
1602
1603        let batch = RecordBatch::try_new(
1604            schema,
1605            vec![
1606                Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
1607                Arc::new(Float64Array::from_iter_values((0..10).map(|v| v as f64))),
1608                Arc::new(ts_array),
1609            ],
1610        )
1611        .unwrap();
1612
1613        let predicates = physical_expr.evaluate(&batch).unwrap();
1614        assert_eq!(
1615            predicates.into_array(0).unwrap().as_ref(),
1616            &BooleanArray::from(vec![
1617                false, false, false, true, true, true, true, true, false, false
1618            ])
1619        );
1620
1621        // Test NOT BETWEEN
1622        let expr = planner
1623            .parse_filter("x NOT BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
1624            .unwrap();
1625        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1626
1627        let predicates = physical_expr.evaluate(&batch).unwrap();
1628        assert_eq!(
1629            predicates.into_array(0).unwrap().as_ref(),
1630            &BooleanArray::from(vec![
1631                true, true, true, false, false, false, false, false, true, true
1632            ])
1633        );
1634
1635        // Test floating point BETWEEN
1636        let expr = planner.parse_filter("y BETWEEN 2.5 AND 6.5").unwrap();
1637        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1638
1639        let predicates = physical_expr.evaluate(&batch).unwrap();
1640        assert_eq!(
1641            predicates.into_array(0).unwrap().as_ref(),
1642            &BooleanArray::from(vec![
1643                false, false, false, true, true, true, true, false, false, false
1644            ])
1645        );
1646
1647        // Test timestamp BETWEEN
1648        let expr = planner
1649            .parse_filter(
1650                "ts BETWEEN timestamp '2024-01-01 00:00:03' AND timestamp '2024-01-01 00:00:07'",
1651            )
1652            .unwrap();
1653        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1654
1655        let predicates = physical_expr.evaluate(&batch).unwrap();
1656        assert_eq!(
1657            predicates.into_array(0).unwrap().as_ref(),
1658            &BooleanArray::from(vec![
1659                false, false, false, true, true, true, true, true, false, false
1660            ])
1661        );
1662    }
1663
1664    #[test]
1665    fn test_sql_comparison() {
1666        // Create a batch with all data types
1667        let batch: Vec<(&str, ArrayRef)> = vec![
1668            (
1669                "timestamp_s",
1670                Arc::new(TimestampSecondArray::from_iter_values(0..10)),
1671            ),
1672            (
1673                "timestamp_ms",
1674                Arc::new(TimestampMillisecondArray::from_iter_values(0..10)),
1675            ),
1676            (
1677                "timestamp_us",
1678                Arc::new(TimestampMicrosecondArray::from_iter_values(0..10)),
1679            ),
1680            (
1681                "timestamp_ns",
1682                Arc::new(TimestampNanosecondArray::from_iter_values(4995..5005)),
1683            ),
1684        ];
1685        let batch = RecordBatch::try_from_iter(batch).unwrap();
1686
1687        let planner = Planner::new(batch.schema());
1688
1689        // Each expression is meant to select the final 5 rows
1690        let expressions = &[
1691            "timestamp_s >= TIMESTAMP '1970-01-01 00:00:05'",
1692            "timestamp_ms >= TIMESTAMP '1970-01-01 00:00:00.005'",
1693            "timestamp_us >= TIMESTAMP '1970-01-01 00:00:00.000005'",
1694            "timestamp_ns >= TIMESTAMP '1970-01-01 00:00:00.000005'",
1695        ];
1696
1697        let expected: ArrayRef = Arc::new(BooleanArray::from_iter(
1698            std::iter::repeat_n(Some(false), 5).chain(std::iter::repeat_n(Some(true), 5)),
1699        ));
1700        for expression in expressions {
1701            // convert to physical expression
1702            let logical_expr = planner.parse_filter(expression).unwrap();
1703            let logical_expr = planner.optimize_expr(logical_expr).unwrap();
1704            let physical_expr = planner.create_physical_expr(&logical_expr).unwrap();
1705
1706            // Evaluate and assert they have correct results
1707            let result = physical_expr.evaluate(&batch).unwrap();
1708            let result = result.into_array(batch.num_rows()).unwrap();
1709            assert_eq!(&expected, &result, "unexpected result for {}", expression);
1710        }
1711    }
1712
1713    #[test]
1714    fn test_columns_in_expr() {
1715        let expr = col("s0").gt(lit("value")).and(
1716            col("st")
1717                .field("st")
1718                .field("s2")
1719                .eq(lit("value"))
1720                .or(col("st")
1721                    .field("s1")
1722                    .in_list(vec![lit("value 1"), lit("value 2")], false)),
1723        );
1724
1725        let columns = Planner::column_names_in_expr(&expr);
1726        assert_eq!(columns, vec!["s0", "st.s1", "st.st.s2"]);
1727    }
1728
1729    #[test]
1730    fn test_parse_binary_expr() {
1731        let bin_str = "x'616263'";
1732
1733        let schema = Arc::new(Schema::new(vec![Field::new(
1734            "binary",
1735            DataType::Binary,
1736            true,
1737        )]));
1738        let planner = Planner::new(schema);
1739        let expr = planner.parse_expr(bin_str).unwrap();
1740        assert_eq!(
1741            expr,
1742            Expr::Literal(ScalarValue::Binary(Some(vec![b'a', b'b', b'c'])), None)
1743        );
1744    }
1745
1746    #[test]
1747    fn test_lance_context_provider_expr_planners() {
1748        let ctx_provider = LanceContextProvider::default();
1749        assert!(!ctx_provider.get_expr_planners().is_empty());
1750    }
1751
1752    #[test]
1753    fn test_regexp_match_and_non_empty_captions() {
1754        // Repro for a bug where regexp_match inside an AND chain wasn't coerced to boolean,
1755        // causing planning/evaluation failures. This should evaluate successfully.
1756        let schema = Arc::new(Schema::new(vec![
1757            Field::new("keywords", DataType::Utf8, true),
1758            Field::new("natural_caption", DataType::Utf8, true),
1759            Field::new("poetic_caption", DataType::Utf8, true),
1760        ]));
1761
1762        let planner = Planner::new(schema.clone());
1763
1764        let expr = planner
1765            .parse_filter(
1766                "regexp_match(keywords, 'Liberty|revolution') AND \
1767                 (natural_caption IS NOT NULL AND natural_caption <> '' AND \
1768                  poetic_caption IS NOT NULL AND poetic_caption <> '')",
1769            )
1770            .unwrap();
1771
1772        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1773
1774        let batch = RecordBatch::try_new(
1775            schema,
1776            vec![
1777                Arc::new(StringArray::from(vec![
1778                    Some("Liberty for all"),
1779                    Some("peace"),
1780                    Some("revolution now"),
1781                    Some("Liberty"),
1782                    Some("revolutionary"),
1783                    Some("none"),
1784                ])) as ArrayRef,
1785                Arc::new(StringArray::from(vec![
1786                    Some("a"),
1787                    Some("b"),
1788                    None,
1789                    Some(""),
1790                    Some("c"),
1791                    Some("d"),
1792                ])) as ArrayRef,
1793                Arc::new(StringArray::from(vec![
1794                    Some("x"),
1795                    Some(""),
1796                    Some("y"),
1797                    Some("z"),
1798                    None,
1799                    Some("w"),
1800                ])) as ArrayRef,
1801            ],
1802        )
1803        .unwrap();
1804
1805        let result = physical_expr.evaluate(&batch).unwrap();
1806        assert_eq!(
1807            result.into_array(0).unwrap().as_ref(),
1808            &BooleanArray::from(vec![true, false, false, false, false, false])
1809        );
1810    }
1811
1812    #[test]
1813    fn test_regexp_match_infer_error_without_boolean_coercion() {
1814        // With the fix applied, using parse_filter should coerce regexp_match to boolean
1815        // even when nested in a larger AND expression, so this should plan successfully.
1816        let schema = Arc::new(Schema::new(vec![
1817            Field::new("keywords", DataType::Utf8, true),
1818            Field::new("natural_caption", DataType::Utf8, true),
1819            Field::new("poetic_caption", DataType::Utf8, true),
1820        ]));
1821
1822        let planner = Planner::new(schema);
1823
1824        let expr = planner
1825            .parse_filter(
1826                "regexp_match(keywords, 'Liberty|revolution') AND \
1827                 (natural_caption IS NOT NULL AND natural_caption <> '' AND \
1828                  poetic_caption IS NOT NULL AND poetic_caption <> '')",
1829            )
1830            .unwrap();
1831
1832        // Should not panic
1833        let _physical = planner.create_physical_expr(&expr).unwrap();
1834    }
1835}