lance_datafusion/
planner.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Exec plan planner
5
6use std::borrow::Cow;
7use std::collections::{BTreeSet, VecDeque};
8use std::sync::Arc;
9
10use crate::expr::safe_coerce_scalar;
11use crate::logical_expr::{coerce_filter_type_to_boolean, get_as_string_scalar_opt, resolve_expr};
12use crate::sql::{parse_sql_expr, parse_sql_filter};
13use arrow::compute::CastOptions;
14use arrow_array::ListArray;
15use arrow_buffer::OffsetBuffer;
16use arrow_schema::{DataType as ArrowDataType, Field, SchemaRef, TimeUnit};
17use arrow_select::concat::concat;
18use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
19use datafusion::common::DFSchema;
20use datafusion::config::ConfigOptions;
21use datafusion::error::Result as DFResult;
22use datafusion::execution::config::SessionConfig;
23use datafusion::execution::context::{SessionContext, SessionState};
24use datafusion::execution::runtime_env::RuntimeEnvBuilder;
25use datafusion::execution::session_state::SessionStateBuilder;
26use datafusion::logical_expr::expr::ScalarFunction;
27use datafusion::logical_expr::planner::{ExprPlanner, PlannerResult, RawFieldAccessExpr};
28use datafusion::logical_expr::{
29    AggregateUDF, ColumnarValue, GetFieldAccess, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
30    Signature, Volatility, WindowUDF,
31};
32use datafusion::optimizer::simplify_expressions::SimplifyContext;
33use datafusion::sql::planner::{
34    ContextProvider, NullOrdering, ParserOptions, PlannerContext, SqlToRel,
35};
36use datafusion::sql::sqlparser::ast::{
37    AccessExpr, Array as SQLArray, BinaryOperator, DataType as SQLDataType, ExactNumberInfo,
38    Expr as SQLExpr, Function, FunctionArg, FunctionArgExpr, FunctionArguments, Ident,
39    ObjectNamePart, Subscript, TimezoneInfo, UnaryOperator, Value, ValueWithSpan,
40};
41use datafusion::{
42    common::Column,
43    logical_expr::{Between, BinaryExpr, Like, Operator},
44    physical_expr::execution_props::ExecutionProps,
45    physical_plan::PhysicalExpr,
46    prelude::Expr,
47    scalar::ScalarValue,
48};
49use datafusion_functions::core::getfield::GetFieldFunc;
50use lance_arrow::cast::cast_with_options;
51use lance_core::datatypes::Schema;
52use lance_core::error::LanceOptionExt;
53use snafu::location;
54
55use chrono::Utc;
56use lance_core::{Error, Result};
57
/// Scalar UDF that casts a `List`/`FixedSizeList` of `Float32`/`Float16`
/// items to the same list shape with `Float16` items.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
struct CastListF16Udf {
    // Exactly one argument of any type; actual type checking happens in
    // `return_type` / `invoke_with_args`.
    signature: Signature,
}
62
impl CastListF16Udf {
    /// Create the UDF with a one-argument, immutable-volatility signature.
    pub fn new() -> Self {
        Self {
            signature: Signature::any(1, Volatility::Immutable),
        }
    }
}
70
71impl ScalarUDFImpl for CastListF16Udf {
72    fn as_any(&self) -> &dyn std::any::Any {
73        self
74    }
75
76    fn name(&self) -> &str {
77        "_cast_list_f16"
78    }
79
80    fn signature(&self) -> &Signature {
81        &self.signature
82    }
83
84    fn return_type(&self, arg_types: &[ArrowDataType]) -> DFResult<ArrowDataType> {
85        let input = &arg_types[0];
86        match input {
87            ArrowDataType::FixedSizeList(field, size) => {
88                if field.data_type() != &ArrowDataType::Float32
89                    && field.data_type() != &ArrowDataType::Float16
90                {
91                    return Err(datafusion::error::DataFusionError::Execution(
92                        "cast_list_f16 only supports list of float32 or float16".to_string(),
93                    ));
94                }
95                Ok(ArrowDataType::FixedSizeList(
96                    Arc::new(Field::new(
97                        field.name(),
98                        ArrowDataType::Float16,
99                        field.is_nullable(),
100                    )),
101                    *size,
102                ))
103            }
104            ArrowDataType::List(field) => {
105                if field.data_type() != &ArrowDataType::Float32
106                    && field.data_type() != &ArrowDataType::Float16
107                {
108                    return Err(datafusion::error::DataFusionError::Execution(
109                        "cast_list_f16 only supports list of float32 or float16".to_string(),
110                    ));
111                }
112                Ok(ArrowDataType::List(Arc::new(Field::new(
113                    field.name(),
114                    ArrowDataType::Float16,
115                    field.is_nullable(),
116                ))))
117            }
118            _ => Err(datafusion::error::DataFusionError::Execution(
119                "cast_list_f16 only supports FixedSizeList/List arguments".to_string(),
120            )),
121        }
122    }
123
124    fn invoke_with_args(&self, func_args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
125        let ColumnarValue::Array(arr) = &func_args.args[0] else {
126            return Err(datafusion::error::DataFusionError::Execution(
127                "cast_list_f16 only supports array arguments".to_string(),
128            ));
129        };
130
131        let to_type = match arr.data_type() {
132            ArrowDataType::FixedSizeList(field, size) => ArrowDataType::FixedSizeList(
133                Arc::new(Field::new(
134                    field.name(),
135                    ArrowDataType::Float16,
136                    field.is_nullable(),
137                )),
138                *size,
139            ),
140            ArrowDataType::List(field) => ArrowDataType::List(Arc::new(Field::new(
141                field.name(),
142                ArrowDataType::Float16,
143                field.is_nullable(),
144            ))),
145            _ => {
146                return Err(datafusion::error::DataFusionError::Execution(
147                    "cast_list_f16 only supports array arguments".to_string(),
148                ));
149            }
150        };
151
152        let res = cast_with_options(arr.as_ref(), &to_type, &CastOptions::default())?;
153        Ok(ColumnarValue::Array(res))
154    }
155}
156
// Adapter that instructs datafusion how lance expects expressions to be interpreted
struct LanceContextProvider {
    // DataFusion configuration handed to the SQL planner via `options()`.
    options: datafusion::config::ConfigOptions,
    // Session state used to look up registered scalar/aggregate/window functions.
    state: SessionState,
    // Expression planners obtained from a default SessionStateBuilder,
    // because SessionState itself does not expose them.
    expr_planners: Vec<Arc<dyn ExprPlanner>>,
}
163
164impl Default for LanceContextProvider {
165    fn default() -> Self {
166        let config = SessionConfig::new();
167        let runtime = RuntimeEnvBuilder::new().build_arc().unwrap();
168
169        let ctx = SessionContext::new_with_config_rt(config.clone(), runtime.clone());
170        crate::udf::register_functions(&ctx);
171
172        let state = ctx.state();
173
174        // SessionState does not expose expr_planners, so we need to get them separately
175        let mut state_builder = SessionStateBuilder::new()
176            .with_config(config)
177            .with_runtime_env(runtime)
178            .with_default_features();
179
180        // unwrap safe because with_default_features sets expr_planners
181        let expr_planners = state_builder.expr_planners().as_ref().unwrap().clone();
182
183        Self {
184            options: ConfigOptions::default(),
185            state,
186            expr_planners,
187        }
188    }
189}
190
191impl ContextProvider for LanceContextProvider {
192    fn get_table_source(
193        &self,
194        name: datafusion::sql::TableReference,
195    ) -> DFResult<Arc<dyn datafusion::logical_expr::TableSource>> {
196        Err(datafusion::error::DataFusionError::NotImplemented(format!(
197            "Attempt to reference inner table {} not supported",
198            name
199        )))
200    }
201
202    fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
203        self.state.aggregate_functions().get(name).cloned()
204    }
205
206    fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
207        self.state.window_functions().get(name).cloned()
208    }
209
210    fn get_function_meta(&self, f: &str) -> Option<Arc<ScalarUDF>> {
211        match f {
212            // TODO: cast should go thru CAST syntax instead of UDF
213            // Going thru UDF makes it hard for the optimizer to find no-ops
214            "_cast_list_f16" => Some(Arc::new(ScalarUDF::new_from_impl(CastListF16Udf::new()))),
215            _ => self.state.scalar_functions().get(f).cloned(),
216        }
217    }
218
219    fn get_variable_type(&self, _: &[String]) -> Option<ArrowDataType> {
220        // Variables (things like @@LANGUAGE) not supported
221        None
222    }
223
224    fn options(&self) -> &datafusion::config::ConfigOptions {
225        &self.options
226    }
227
228    fn udf_names(&self) -> Vec<String> {
229        self.state.scalar_functions().keys().cloned().collect()
230    }
231
232    fn udaf_names(&self) -> Vec<String> {
233        self.state.aggregate_functions().keys().cloned().collect()
234    }
235
236    fn udwf_names(&self) -> Vec<String> {
237        self.state.window_functions().keys().cloned().collect()
238    }
239
240    fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
241        &self.expr_planners
242    }
243}
244
/// Translates Lance SQL filter/expression strings into datafusion
/// logical/physical expressions, resolved against a fixed Arrow schema.
pub struct Planner {
    // Arrow schema that column references are resolved against.
    schema: SchemaRef,
    // Supplies functions/options to datafusion's SQL-to-expr planner.
    context_provider: LanceContextProvider,
    // When true, the first identifier of a compound reference names a
    // relation rather than a column (see `with_enable_relations`).
    enable_relations: bool,
}
250
251impl Planner {
    /// Create a planner that resolves expressions against `schema`.
    ///
    /// Relation-qualified identifiers are disabled by default; see
    /// `with_enable_relations` to opt in.
    pub fn new(schema: SchemaRef) -> Self {
        Self {
            schema,
            context_provider: LanceContextProvider::default(),
            enable_relations: false,
        }
    }
259
260    /// If passed with `true`, then the first identifier in column reference
261    /// is parsed as the relation. For example, `table.field.inner` will be
262    /// read as the nested field `field.inner` (`inner` on struct field `field`)
263    /// on the `table` relation. If `false` (the default), then no relations
264    /// are used and all identifiers are assumed to be a nested column path.
265    pub fn with_enable_relations(mut self, enable_relations: bool) -> Self {
266        self.enable_relations = enable_relations;
267        self
268    }
269
270    /// Resolve a column name using case-insensitive matching against the schema.
271    /// Returns the actual field name if found, otherwise returns the original name.
272    fn resolve_column_name(&self, name: &str) -> String {
273        // Try exact match first
274        if self.schema.field_with_name(name).is_ok() {
275            return name.to_string();
276        }
277        // Fall back to case-insensitive match
278        for field in self.schema.fields() {
279            if field.name().eq_ignore_ascii_case(name) {
280                return field.name().clone();
281            }
282        }
283        // Not found in schema - return original (might be computed column, system column, etc.)
284        name.to_string()
285    }
286
287    fn column(&self, idents: &[Ident]) -> Expr {
288        fn handle_remaining_idents(expr: &mut Expr, idents: &[Ident]) {
289            for ident in idents {
290                *expr = Expr::ScalarFunction(ScalarFunction {
291                    args: vec![
292                        std::mem::take(expr),
293                        Expr::Literal(ScalarValue::Utf8(Some(ident.value.clone())), None),
294                    ],
295                    func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
296                });
297            }
298        }
299
300        if self.enable_relations && idents.len() > 1 {
301            // Create qualified column reference (relation.column)
302            let relation = &idents[0].value;
303            let column_name = self.resolve_column_name(&idents[1].value);
304            let column = Expr::Column(Column::new(Some(relation.clone()), column_name));
305            let mut result = column;
306            handle_remaining_idents(&mut result, &idents[2..]);
307            result
308        } else {
309            // Default behavior - treat as struct field access
310            // Use resolved column name to handle case-insensitive matching
311            let resolved_name = self.resolve_column_name(&idents[0].value);
312            let mut column = Expr::Column(Column::from_name(resolved_name));
313            handle_remaining_idents(&mut column, &idents[1..]);
314            column
315        }
316    }
317
318    fn binary_op(&self, op: &BinaryOperator) -> Result<Operator> {
319        Ok(match op {
320            BinaryOperator::Plus => Operator::Plus,
321            BinaryOperator::Minus => Operator::Minus,
322            BinaryOperator::Multiply => Operator::Multiply,
323            BinaryOperator::Divide => Operator::Divide,
324            BinaryOperator::Modulo => Operator::Modulo,
325            BinaryOperator::StringConcat => Operator::StringConcat,
326            BinaryOperator::Gt => Operator::Gt,
327            BinaryOperator::Lt => Operator::Lt,
328            BinaryOperator::GtEq => Operator::GtEq,
329            BinaryOperator::LtEq => Operator::LtEq,
330            BinaryOperator::Eq => Operator::Eq,
331            BinaryOperator::NotEq => Operator::NotEq,
332            BinaryOperator::And => Operator::And,
333            BinaryOperator::Or => Operator::Or,
334            _ => {
335                return Err(Error::invalid_input(
336                    format!("Operator {op} is not supported"),
337                    location!(),
338                ));
339            }
340        })
341    }
342
343    fn binary_expr(&self, left: &SQLExpr, op: &BinaryOperator, right: &SQLExpr) -> Result<Expr> {
344        Ok(Expr::BinaryExpr(BinaryExpr::new(
345            Box::new(self.parse_sql_expr(left)?),
346            self.binary_op(op)?,
347            Box::new(self.parse_sql_expr(right)?),
348        )))
349    }
350
    /// Parse a SQL unary operation.
    ///
    /// `NOT`/`~` become logical negation; unary minus on a numeric literal is
    /// folded into a negative literal directly (trying i64 first, then f64),
    /// while minus on any other expression becomes `Expr::Negative`. All other
    /// unary operators are rejected.
    fn unary_expr(&self, op: &UnaryOperator, expr: &SQLExpr) -> Result<Expr> {
        Ok(match op {
            UnaryOperator::Not | UnaryOperator::PGBitwiseNot => {
                Expr::Not(Box::new(self.parse_sql_expr(expr)?))
            }

            UnaryOperator::Minus => {
                use datafusion::logical_expr::lit;
                match expr {
                    // Fold `-<number>` into a literal: prefer i64, fall back
                    // to f64 for values that don't fit or carry a fraction.
                    SQLExpr::Value(ValueWithSpan { value: Value::Number(n, _), ..}) => match n.parse::<i64>() {
                        Ok(n) => lit(-n),
                        Err(_) => lit(-n
                            .parse::<f64>()
                            .map_err(|_e| {
                                Error::invalid_input(
                                    format!("negative operator can be only applied to integer and float operands, got: {n}"),
                                    location!(),
                                )
                            })?),
                    },
                    // Non-literal operand: negate the parsed expression.
                    _ => {
                        Expr::Negative(Box::new(self.parse_sql_expr(expr)?))
                    }
                }
            }

            _ => {
                return Err(Error::invalid_input(
                    format!("Unary operator '{:?}' is not supported", op),
                    location!(),
                ));
            }
        })
    }
385
386    // See datafusion `sqlToRel::parse_sql_number()`
387    fn number(&self, value: &str, negative: bool) -> Result<Expr> {
388        use datafusion::logical_expr::lit;
389        let value: Cow<str> = if negative {
390            Cow::Owned(format!("-{}", value))
391        } else {
392            Cow::Borrowed(value)
393        };
394        if let Ok(n) = value.parse::<i64>() {
395            Ok(lit(n))
396        } else {
397            value.parse::<f64>().map(lit).map_err(|_| {
398                Error::invalid_input(
399                    format!("'{value}' is not supported number value."),
400                    location!(),
401                )
402            })
403        }
404    }
405
406    fn value(&self, value: &Value) -> Result<Expr> {
407        Ok(match value {
408            Value::Number(v, _) => self.number(v.as_str(), false)?,
409            Value::SingleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone())), None),
410            Value::HexStringLiteral(hsl) => {
411                Expr::Literal(ScalarValue::Binary(Self::try_decode_hex_literal(hsl)), None)
412            }
413            Value::DoubleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone())), None),
414            Value::Boolean(v) => Expr::Literal(ScalarValue::Boolean(Some(*v)), None),
415            Value::Null => Expr::Literal(ScalarValue::Null, None),
416            _ => todo!(),
417        })
418    }
419
420    fn parse_function_args(&self, func_args: &FunctionArg) -> Result<Expr> {
421        match func_args {
422            FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => self.parse_sql_expr(expr),
423            _ => Err(Error::invalid_input(
424                format!("Unsupported function args: {:?}", func_args),
425                location!(),
426            )),
427        }
428    }
429
430    // We now use datafusion to parse functions.  This allows us to use datafusion's
431    // entire collection of functions (previously we had just hard-coded support for two functions).
432    //
433    // Unfortunately, one of those two functions was is_valid and the reason we needed it was because
434    // this is a function that comes from duckdb.  Datafusion does not consider is_valid to be a function
435    // but rather an AST node (Expr::IsNotNull) and so we need to handle this case specially.
436    fn legacy_parse_function(&self, func: &Function) -> Result<Expr> {
437        match &func.args {
438            FunctionArguments::List(args) => {
439                if func.name.0.len() != 1 {
440                    return Err(Error::invalid_input(
441                        format!("Function name must have 1 part, got: {:?}", func.name.0),
442                        location!(),
443                    ));
444                }
445                Ok(Expr::IsNotNull(Box::new(
446                    self.parse_function_args(&args.args[0])?,
447                )))
448            }
449            _ => Err(Error::invalid_input(
450                format!("Unsupported function args: {:?}", &func.args),
451                location!(),
452            )),
453        }
454    }
455
    /// Parse a SQL function call into a datafusion expression.
    ///
    /// `is_valid` is handled specially: it is a duckdb-ism that datafusion
    /// models as the `IsNotNull` AST node rather than a function. Everything
    /// else is delegated to datafusion's SQL planner, with identifier
    /// normalization disabled so column names keep their original casing.
    fn parse_function(&self, function: SQLExpr) -> Result<Expr> {
        if let SQLExpr::Function(function) = &function {
            if let Some(ObjectNamePart::Identifier(name)) = &function.name.0.first() {
                if &name.value == "is_valid" {
                    return self.legacy_parse_function(function);
                }
            }
        }
        let sql_to_rel = SqlToRel::new_with_options(
            &self.context_provider,
            ParserOptions {
                parse_float_as_decimal: false,
                enable_ident_normalization: false,
                support_varchar_with_length: false,
                enable_options_value_normalization: false,
                collect_spans: false,
                map_string_types_to_utf8view: false,
                default_null_ordering: NullOrdering::NullsMax,
            },
        );

        let mut planner_context = PlannerContext::default();
        let schema = DFSchema::try_from(self.schema.as_ref().clone())?;
        sql_to_rel
            .sql_to_expr(function, &schema, &mut planner_context)
            .map_err(|e| Error::invalid_input(format!("Error parsing function: {e}"), location!()))
    }
483
484    fn parse_type(&self, data_type: &SQLDataType) -> Result<ArrowDataType> {
485        const SUPPORTED_TYPES: [&str; 13] = [
486            "int [unsigned]",
487            "tinyint [unsigned]",
488            "smallint [unsigned]",
489            "bigint [unsigned]",
490            "float",
491            "double",
492            "string",
493            "binary",
494            "date",
495            "timestamp(precision)",
496            "datetime(precision)",
497            "decimal(precision,scale)",
498            "boolean",
499        ];
500        match data_type {
501            SQLDataType::String(_) => Ok(ArrowDataType::Utf8),
502            SQLDataType::Binary(_) => Ok(ArrowDataType::Binary),
503            SQLDataType::Float(_) => Ok(ArrowDataType::Float32),
504            SQLDataType::Double(_) => Ok(ArrowDataType::Float64),
505            SQLDataType::Boolean => Ok(ArrowDataType::Boolean),
506            SQLDataType::TinyInt(_) => Ok(ArrowDataType::Int8),
507            SQLDataType::SmallInt(_) => Ok(ArrowDataType::Int16),
508            SQLDataType::Int(_) | SQLDataType::Integer(_) => Ok(ArrowDataType::Int32),
509            SQLDataType::BigInt(_) => Ok(ArrowDataType::Int64),
510            SQLDataType::TinyIntUnsigned(_) => Ok(ArrowDataType::UInt8),
511            SQLDataType::SmallIntUnsigned(_) => Ok(ArrowDataType::UInt16),
512            SQLDataType::IntUnsigned(_) | SQLDataType::IntegerUnsigned(_) => {
513                Ok(ArrowDataType::UInt32)
514            }
515            SQLDataType::BigIntUnsigned(_) => Ok(ArrowDataType::UInt64),
516            SQLDataType::Date => Ok(ArrowDataType::Date32),
517            SQLDataType::Timestamp(resolution, tz) => {
518                match tz {
519                    TimezoneInfo::None => {}
520                    _ => {
521                        return Err(Error::invalid_input(
522                            "Timezone not supported in timestamp".to_string(),
523                            location!(),
524                        ));
525                    }
526                };
527                let time_unit = match resolution {
528                    // Default to microsecond to match PyArrow
529                    None => TimeUnit::Microsecond,
530                    Some(0) => TimeUnit::Second,
531                    Some(3) => TimeUnit::Millisecond,
532                    Some(6) => TimeUnit::Microsecond,
533                    Some(9) => TimeUnit::Nanosecond,
534                    _ => {
535                        return Err(Error::invalid_input(
536                            format!("Unsupported datetime resolution: {:?}", resolution),
537                            location!(),
538                        ));
539                    }
540                };
541                Ok(ArrowDataType::Timestamp(time_unit, None))
542            }
543            SQLDataType::Datetime(resolution) => {
544                let time_unit = match resolution {
545                    None => TimeUnit::Microsecond,
546                    Some(0) => TimeUnit::Second,
547                    Some(3) => TimeUnit::Millisecond,
548                    Some(6) => TimeUnit::Microsecond,
549                    Some(9) => TimeUnit::Nanosecond,
550                    _ => {
551                        return Err(Error::invalid_input(
552                            format!("Unsupported datetime resolution: {:?}", resolution),
553                            location!(),
554                        ));
555                    }
556                };
557                Ok(ArrowDataType::Timestamp(time_unit, None))
558            }
559            SQLDataType::Decimal(number_info) => match number_info {
560                ExactNumberInfo::PrecisionAndScale(precision, scale) => {
561                    Ok(ArrowDataType::Decimal128(*precision as u8, *scale as i8))
562                }
563                _ => Err(Error::invalid_input(
564                    format!(
565                        "Must provide precision and scale for decimal: {:?}",
566                        number_info
567                    ),
568                    location!(),
569                )),
570            },
571            _ => Err(Error::invalid_input(
572                format!(
573                    "Unsupported data type: {:?}. Supported types: {:?}",
574                    data_type, SUPPORTED_TYPES
575                ),
576                location!(),
577            )),
578        }
579    }
580
581    fn plan_field_access(&self, mut field_access_expr: RawFieldAccessExpr) -> Result<Expr> {
582        let df_schema = DFSchema::try_from(self.schema.as_ref().clone())?;
583        for planner in self.context_provider.get_expr_planners() {
584            match planner.plan_field_access(field_access_expr, &df_schema)? {
585                PlannerResult::Planned(expr) => return Ok(expr),
586                PlannerResult::Original(expr) => {
587                    field_access_expr = expr;
588                }
589            }
590        }
591        Err(Error::invalid_input(
592            "Field access could not be planned",
593            location!(),
594        ))
595    }
596
597    fn parse_sql_expr(&self, expr: &SQLExpr) -> Result<Expr> {
598        match expr {
599            SQLExpr::Identifier(id) => {
600                // Users can pass string literals wrapped in `"`.
601                // (Normally SQL only allows single quotes.)
602                if id.quote_style == Some('"') {
603                    Ok(Expr::Literal(
604                        ScalarValue::Utf8(Some(id.value.clone())),
605                        None,
606                    ))
607                // Users can wrap identifiers with ` to reference non-standard
608                // names, such as uppercase or spaces.
609                } else if id.quote_style == Some('`') {
610                    Ok(Expr::Column(Column::from_name(id.value.clone())))
611                } else {
612                    Ok(self.column(vec![id.clone()].as_slice()))
613                }
614            }
615            SQLExpr::CompoundIdentifier(ids) => Ok(self.column(ids.as_slice())),
616            SQLExpr::BinaryOp { left, op, right } => self.binary_expr(left, op, right),
617            SQLExpr::UnaryOp { op, expr } => self.unary_expr(op, expr),
618            SQLExpr::Value(value) => self.value(&value.value),
619            SQLExpr::Array(SQLArray { elem, .. }) => {
620                let mut values = vec![];
621
622                let array_literal_error = |pos: usize, value: &_| {
623                    Err(Error::invalid_input(
624                        format!(
625                            "Expected a literal value in array, instead got {} at position {}",
626                            value, pos
627                        ),
628                        location!(),
629                    ))
630                };
631
632                for (pos, expr) in elem.iter().enumerate() {
633                    match expr {
634                        SQLExpr::Value(value) => {
635                            if let Expr::Literal(value, _) = self.value(&value.value)? {
636                                values.push(value);
637                            } else {
638                                return array_literal_error(pos, expr);
639                            }
640                        }
641                        SQLExpr::UnaryOp {
642                            op: UnaryOperator::Minus,
643                            expr,
644                        } => {
645                            if let SQLExpr::Value(ValueWithSpan {
646                                value: Value::Number(number, _),
647                                ..
648                            }) = expr.as_ref()
649                            {
650                                if let Expr::Literal(value, _) = self.number(number, true)? {
651                                    values.push(value);
652                                } else {
653                                    return array_literal_error(pos, expr);
654                                }
655                            } else {
656                                return array_literal_error(pos, expr);
657                            }
658                        }
659                        _ => {
660                            return array_literal_error(pos, expr);
661                        }
662                    }
663                }
664
665                let field = if !values.is_empty() {
666                    let data_type = values[0].data_type();
667
668                    for value in &mut values {
669                        if value.data_type() != data_type {
670                            *value = safe_coerce_scalar(value, &data_type).ok_or_else(|| Error::invalid_input(
671                                format!("Array expressions must have a consistent datatype. Expected: {}, got: {}", data_type, value.data_type()),
672                                location!()
673                            ))?;
674                        }
675                    }
676                    Field::new("item", data_type, true)
677                } else {
678                    Field::new("item", ArrowDataType::Null, true)
679                };
680
681                let values = values
682                    .into_iter()
683                    .map(|v| v.to_array().map_err(Error::from))
684                    .collect::<Result<Vec<_>>>()?;
685                let array_refs = values.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
686                let values = concat(&array_refs)?;
687                let values = ListArray::try_new(
688                    field.into(),
689                    OffsetBuffer::from_lengths([values.len()]),
690                    values,
691                    None,
692                )?;
693
694                Ok(Expr::Literal(ScalarValue::List(Arc::new(values)), None))
695            }
696            // For example, DATE '2020-01-01'
697            SQLExpr::TypedString { data_type, value } => {
698                let value = value.clone().into_string().expect_ok()?;
699                Ok(Expr::Cast(datafusion::logical_expr::Cast {
700                    expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some(value)), None)),
701                    data_type: self.parse_type(data_type)?,
702                }))
703            }
704            SQLExpr::IsFalse(expr) => Ok(Expr::IsFalse(Box::new(self.parse_sql_expr(expr)?))),
705            SQLExpr::IsNotFalse(expr) => Ok(Expr::IsNotFalse(Box::new(self.parse_sql_expr(expr)?))),
706            SQLExpr::IsTrue(expr) => Ok(Expr::IsTrue(Box::new(self.parse_sql_expr(expr)?))),
707            SQLExpr::IsNotTrue(expr) => Ok(Expr::IsNotTrue(Box::new(self.parse_sql_expr(expr)?))),
708            SQLExpr::IsNull(expr) => Ok(Expr::IsNull(Box::new(self.parse_sql_expr(expr)?))),
709            SQLExpr::IsNotNull(expr) => Ok(Expr::IsNotNull(Box::new(self.parse_sql_expr(expr)?))),
710            SQLExpr::InList {
711                expr,
712                list,
713                negated,
714            } => {
715                let value_expr = self.parse_sql_expr(expr)?;
716                let list_exprs = list
717                    .iter()
718                    .map(|e| self.parse_sql_expr(e))
719                    .collect::<Result<Vec<_>>>()?;
720                Ok(value_expr.in_list(list_exprs, *negated))
721            }
722            SQLExpr::Nested(inner) => self.parse_sql_expr(inner.as_ref()),
723            SQLExpr::Function(_) => self.parse_function(expr.clone()),
724            SQLExpr::ILike {
725                negated,
726                expr,
727                pattern,
728                escape_char,
729                any: _,
730            } => Ok(Expr::Like(Like::new(
731                *negated,
732                Box::new(self.parse_sql_expr(expr)?),
733                Box::new(self.parse_sql_expr(pattern)?),
734                match escape_char {
735                    Some(Value::SingleQuotedString(char)) => char.chars().next(),
736                    Some(value) => return Err(Error::invalid_input(
737                        format!("Invalid escape character in LIKE expression. Expected a single character wrapped with single quotes, got {}", value),
738                        location!()
739                    )),
740                    None => None,
741                },
742                true,
743            ))),
744            SQLExpr::Like {
745                negated,
746                expr,
747                pattern,
748                escape_char,
749                any: _,
750            } => Ok(Expr::Like(Like::new(
751                *negated,
752                Box::new(self.parse_sql_expr(expr)?),
753                Box::new(self.parse_sql_expr(pattern)?),
754                match escape_char {
755                    Some(Value::SingleQuotedString(char)) => char.chars().next(),
756                    Some(value) => return Err(Error::invalid_input(
757                        format!("Invalid escape character in LIKE expression. Expected a single character wrapped with single quotes, got {}", value),
758                        location!()
759                    )),
760                    None => None,
761                },
762                false,
763            ))),
764            SQLExpr::Cast {
765                expr,
766                data_type,
767                kind,
768                ..
769            } => match kind {
770                datafusion::sql::sqlparser::ast::CastKind::TryCast
771                | datafusion::sql::sqlparser::ast::CastKind::SafeCast => {
772                    Ok(Expr::TryCast(datafusion::logical_expr::TryCast {
773                        expr: Box::new(self.parse_sql_expr(expr)?),
774                        data_type: self.parse_type(data_type)?,
775                    }))
776                }
777                _ => Ok(Expr::Cast(datafusion::logical_expr::Cast {
778                    expr: Box::new(self.parse_sql_expr(expr)?),
779                    data_type: self.parse_type(data_type)?,
780                })),
781            },
782            SQLExpr::JsonAccess { .. } => Err(Error::invalid_input(
783                "JSON access is not supported",
784                location!(),
785            )),
786            SQLExpr::CompoundFieldAccess { root, access_chain } => {
787                let mut expr = self.parse_sql_expr(root)?;
788
789                for access in access_chain {
790                    let field_access = match access {
791                        // x.y or x['y']
792                        AccessExpr::Dot(SQLExpr::Identifier(Ident { value: s, .. }))
793                        | AccessExpr::Subscript(Subscript::Index {
794                            index:
795                                SQLExpr::Value(ValueWithSpan {
796                                    value:
797                                        Value::SingleQuotedString(s) | Value::DoubleQuotedString(s),
798                                    ..
799                                }),
800                        }) => GetFieldAccess::NamedStructField {
801                            name: ScalarValue::from(s.as_str()),
802                        },
803                        AccessExpr::Subscript(Subscript::Index { index }) => {
804                            let key = Box::new(self.parse_sql_expr(index)?);
805                            GetFieldAccess::ListIndex { key }
806                        }
807                        AccessExpr::Subscript(Subscript::Slice { .. }) => {
808                            return Err(Error::invalid_input(
809                                "Slice subscript is not supported",
810                                location!(),
811                            ));
812                        }
813                        _ => {
814                            // Handle other cases like JSON access
815                            // Note: JSON access is not supported in lance
816                            return Err(Error::invalid_input(
817                                "Only dot notation or index access is supported for field access",
818                                location!(),
819                            ));
820                        }
821                    };
822
823                    let field_access_expr = RawFieldAccessExpr { expr, field_access };
824                    expr = self.plan_field_access(field_access_expr)?;
825                }
826
827                Ok(expr)
828            }
829            SQLExpr::Between {
830                expr,
831                negated,
832                low,
833                high,
834            } => {
835                // Parse the main expression and bounds
836                let expr = self.parse_sql_expr(expr)?;
837                let low = self.parse_sql_expr(low)?;
838                let high = self.parse_sql_expr(high)?;
839
840                let between = Expr::Between(Between::new(
841                    Box::new(expr),
842                    *negated,
843                    Box::new(low),
844                    Box::new(high),
845                ));
846                Ok(between)
847            }
848            _ => Err(Error::invalid_input(
849                format!("Expression '{expr}' is not supported SQL in lance"),
850                location!(),
851            )),
852        }
853    }
854
855    /// Create Logical [Expr] from a SQL filter clause.
856    ///
857    /// Note: the returned expression must be passed through [optimize_expr()]
858    /// before being passed to [create_physical_expr()].
859    pub fn parse_filter(&self, filter: &str) -> Result<Expr> {
860        // Allow sqlparser to parse filter as part of ONE SQL statement.
861        let ast_expr = parse_sql_filter(filter)?;
862        let expr = self.parse_sql_expr(&ast_expr)?;
863        let schema = Schema::try_from(self.schema.as_ref())?;
864        let resolved = resolve_expr(&expr, &schema).map_err(|e| {
865            Error::invalid_input(
866                format!("Error resolving filter expression {filter}: {e}"),
867                location!(),
868            )
869        })?;
870
871        Ok(coerce_filter_type_to_boolean(resolved))
872    }
873
874    /// Create Logical [Expr] from a SQL expression.
875    ///
876    /// Note: the returned expression must be passed through [optimize_filter()]
877    /// before being passed to [create_physical_expr()].
878    pub fn parse_expr(&self, expr: &str) -> Result<Expr> {
879        // First check if it's a simple column reference (no operators, functions, etc.)
880        // resolve_column_name tries exact match first, then falls back to case-insensitive
881        let resolved_name = self.resolve_column_name(expr);
882        if self.schema.field_with_name(&resolved_name).is_ok() {
883            return Ok(Expr::Column(Column::from_name(resolved_name)));
884        }
885
886        // Parse as SQL expression
887        let ast_expr = parse_sql_expr(expr)?;
888        let expr = self.parse_sql_expr(&ast_expr)?;
889        let schema = Schema::try_from(self.schema.as_ref())?;
890        let resolved = resolve_expr(&expr, &schema)?;
891        Ok(resolved)
892    }
893
894    /// Try to decode bytes from hex literal string.
895    ///
896    /// Copied from datafusion because this is not public.
897    ///
898    /// TODO: use SqlToRel from Datafusion directly?
899    fn try_decode_hex_literal(s: &str) -> Option<Vec<u8>> {
900        let hex_bytes = s.as_bytes();
901        let mut decoded_bytes = Vec::with_capacity(hex_bytes.len().div_ceil(2));
902
903        let start_idx = hex_bytes.len() % 2;
904        if start_idx > 0 {
905            // The first byte is formed of only one char.
906            decoded_bytes.push(Self::try_decode_hex_char(hex_bytes[0])?);
907        }
908
909        for i in (start_idx..hex_bytes.len()).step_by(2) {
910            let high = Self::try_decode_hex_char(hex_bytes[i])?;
911            let low = Self::try_decode_hex_char(hex_bytes[i + 1])?;
912            decoded_bytes.push((high << 4) | low);
913        }
914
915        Some(decoded_bytes)
916    }
917
918    /// Try to decode a byte from a hex char.
919    ///
920    /// None will be returned if the input char is hex-invalid.
921    const fn try_decode_hex_char(c: u8) -> Option<u8> {
922        match c {
923            b'A'..=b'F' => Some(c - b'A' + 10),
924            b'a'..=b'f' => Some(c - b'a' + 10),
925            b'0'..=b'9' => Some(c - b'0'),
926            _ => None,
927        }
928    }
929
930    /// Optimize the filter expression and coerce data types.
931    pub fn optimize_expr(&self, expr: Expr) -> Result<Expr> {
932        let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
933
934        // DataFusion needs the simplify and coerce passes to be applied before
935        // expressions can be handled by the physical planner.
936        let props = ExecutionProps::new().with_query_execution_start_time(Utc::now());
937        let simplify_context = SimplifyContext::new(&props).with_schema(df_schema.clone());
938        let simplifier =
939            datafusion::optimizer::simplify_expressions::ExprSimplifier::new(simplify_context);
940
941        let expr = simplifier.simplify(expr)?;
942        let expr = simplifier.coerce(expr, &df_schema)?;
943
944        Ok(expr)
945    }
946
947    /// Create the [`PhysicalExpr`] from a logical [`Expr`]
948    pub fn create_physical_expr(&self, expr: &Expr) -> Result<Arc<dyn PhysicalExpr>> {
949        let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
950        Ok(datafusion::physical_expr::create_physical_expr(
951            expr,
952            df_schema.as_ref(),
953            &Default::default(),
954        )?)
955    }
956
957    /// Collect the columns in the expression.
958    ///
959    /// The columns are returned in sorted order.
960    ///
961    /// If the expr refers to nested columns these will be returned
962    /// as dotted paths (x.y.z)
963    pub fn column_names_in_expr(expr: &Expr) -> Vec<String> {
964        let mut visitor = ColumnCapturingVisitor {
965            current_path: VecDeque::new(),
966            columns: BTreeSet::new(),
967        };
968        expr.visit(&mut visitor).unwrap();
969        visitor.columns.into_iter().collect()
970    }
971}
972
/// Tree-node visitor that records every column referenced by an expression,
/// rendering nested struct field accesses as dotted paths (e.g. `x.y.z`).
struct ColumnCapturingVisitor {
    // Current column path. If this is empty, we are not in a column expression.
    // Field names are pushed to the front as we descend through `get_field`
    // calls, so the deque reads root-to-leaf once the column is reached.
    current_path: VecDeque<String>,
    // All complete column paths collected so far (sorted and deduplicated).
    columns: BTreeSet<String>,
}
978
979impl TreeNodeVisitor<'_> for ColumnCapturingVisitor {
980    type Node = Expr;
981
982    fn f_down(&mut self, node: &Self::Node) -> DFResult<TreeNodeRecursion> {
983        match node {
984            Expr::Column(Column { name, .. }) => {
985                // Build the field path from the column name and any nested fields
986                // The nested field names from get_field already come as literal strings,
987                // so we just need to concatenate them properly
988                let mut path = name.clone();
989                for part in self.current_path.drain(..) {
990                    path.push('.');
991                    // Check if the part needs quoting (contains dots)
992                    if part.contains('.') || part.contains('`') {
993                        // Quote the field name with backticks and escape any existing backticks
994                        let escaped = part.replace('`', "``");
995                        path.push('`');
996                        path.push_str(&escaped);
997                        path.push('`');
998                    } else {
999                        path.push_str(&part);
1000                    }
1001                }
1002                self.columns.insert(path);
1003                self.current_path.clear();
1004            }
1005            Expr::ScalarFunction(udf) => {
1006                if udf.name() == GetFieldFunc::default().name() {
1007                    if let Some(name) = get_as_string_scalar_opt(&udf.args[1]) {
1008                        self.current_path.push_front(name.to_string())
1009                    } else {
1010                        self.current_path.clear();
1011                    }
1012                } else {
1013                    self.current_path.clear();
1014                }
1015            }
1016            _ => {
1017                self.current_path.clear();
1018            }
1019        }
1020
1021        Ok(TreeNodeRecursion::Continue)
1022    }
1023}
1024
1025#[cfg(test)]
1026mod tests {
1027
1028    use crate::logical_expr::ExprExt;
1029
1030    use super::*;
1031
1032    use arrow::datatypes::Float64Type;
1033    use arrow_array::{
1034        ArrayRef, BooleanArray, Float32Array, Int32Array, Int64Array, RecordBatch, StringArray,
1035        StructArray, TimestampMicrosecondArray, TimestampMillisecondArray,
1036        TimestampNanosecondArray, TimestampSecondArray,
1037    };
1038    use arrow_schema::{DataType, Fields, Schema};
1039    use datafusion::{
1040        logical_expr::{col, lit, Cast},
1041        prelude::{array_element, get_field},
1042    };
1043    use datafusion_functions::core::expr_ext::FieldAccessor;
1044
    /// Round-trips a compound filter (comparison, nested struct access,
    /// OR / IN list) through `parse_filter`, checks the logical expression,
    /// then evaluates the physical expression against a small batch.
    #[test]
    fn test_parse_filter_simple() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("i", DataType::Int32, false),
            Field::new("s", DataType::Utf8, true),
            Field::new(
                "st",
                DataType::Struct(Fields::from(vec![
                    Field::new("x", DataType::Float32, false),
                    Field::new("y", DataType::Float32, false),
                ])),
                true,
            ),
        ]));

        let planner = Planner::new(schema.clone());

        let expected = col("i")
            .gt(lit(3_i32))
            .and(col("st").field_newstyle("x").lt_eq(lit(5.0_f32)))
            .and(
                col("s")
                    .eq(lit("str-4"))
                    .or(col("s").in_list(vec![lit("str-4"), lit("str-5")], false)),
            );

        // `==` (double-equals) form of equality
        let expr = planner
            .parse_filter("i > 3 AND st.x <= 5.0 AND (s == 'str-4' OR s in ('str-4', 'str-5'))")
            .unwrap();
        assert_eq!(expr, expected);

        // `=` (single-equals) form of equality
        let expr = planner
            .parse_filter("i > 3 AND st.x <= 5.0 AND (s = 'str-4' OR s in ('str-4', 'str-5'))")
            .unwrap();

        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        // Rows v = 0..10 with i = v, s = "str-v", st.x = v, st.y = 10 * v.
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
                Arc::new(StringArray::from_iter_values(
                    (0..10).map(|v| format!("str-{}", v)),
                )),
                Arc::new(StructArray::from(vec![
                    (
                        Arc::new(Field::new("x", DataType::Float32, false)),
                        Arc::new(Float32Array::from_iter_values((0..10).map(|v| v as f32)))
                            as ArrayRef,
                    ),
                    (
                        Arc::new(Field::new("y", DataType::Float32, false)),
                        Arc::new(Float32Array::from_iter_values(
                            (0..10).map(|v| (v * 10) as f32),
                        )),
                    ),
                ])),
            ],
        )
        .unwrap();
        // Only rows 4 and 5 satisfy all three conjuncts.
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, false, true, true, false, false, false, false
            ])
        );
    }
1115
    /// Checks that nested struct field references — dot notation, backtick
    /// quoting, and subscript syntax — all resolve to chains of `get_field`
    /// scalar-function calls rooted at the column.
    #[test]
    fn test_nested_col_refs() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("s0", DataType::Utf8, true),
            Field::new(
                "st",
                DataType::Struct(Fields::from(vec![
                    Field::new("s1", DataType::Utf8, true),
                    Field::new(
                        "st",
                        DataType::Struct(Fields::from(vec![Field::new(
                            "s2",
                            DataType::Utf8,
                            true,
                        )])),
                        true,
                    ),
                ])),
                true,
            ),
        ]));

        let planner = Planner::new(schema);

        // Parses "<expr> = 'val'" and asserts the left-hand side of the
        // resulting binary expression equals `expected`.
        fn assert_column_eq(planner: &Planner, expr: &str, expected: &Expr) {
            let expr = planner.parse_filter(&format!("{expr} = 'val'")).unwrap();
            assert!(matches!(
                expr,
                Expr::BinaryExpr(BinaryExpr {
                    left: _,
                    op: Operator::Eq,
                    right: _
                })
            ));
            if let Expr::BinaryExpr(BinaryExpr { left, .. }) = expr {
                assert_eq!(left.as_ref(), expected);
            }
        }

        // Plain (non-nested) column, bare and backtick-quoted.
        let expected = Expr::Column(Column::new_unqualified("s0"));
        assert_column_eq(&planner, "s0", &expected);
        assert_column_eq(&planner, "`s0`", &expected);

        // One level of nesting: get_field(st, "s1").
        let expected = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
            args: vec![
                Expr::Column(Column::new_unqualified("st")),
                Expr::Literal(ScalarValue::Utf8(Some("s1".to_string())), None),
            ],
        });
        assert_column_eq(&planner, "st.s1", &expected);
        assert_column_eq(&planner, "`st`.`s1`", &expected);
        assert_column_eq(&planner, "st.`s1`", &expected);

        // Two levels of nesting: get_field(get_field(st, "st"), "s2").
        let expected = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
            args: vec![
                Expr::ScalarFunction(ScalarFunction {
                    func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
                    args: vec![
                        Expr::Column(Column::new_unqualified("st")),
                        Expr::Literal(ScalarValue::Utf8(Some("st".to_string())), None),
                    ],
                }),
                Expr::Literal(ScalarValue::Utf8(Some("s2".to_string())), None),
            ],
        });

        assert_column_eq(&planner, "st.st.s2", &expected);
        assert_column_eq(&planner, "`st`.`st`.`s2`", &expected);
        assert_column_eq(&planner, "st.st.`s2`", &expected);
        assert_column_eq(&planner, "st['st'][\"s2\"]", &expected);
    }
1189
    /// Checks list subscripting: `l[0]` becomes `array_element` and
    /// `l[0]['f1']` wraps it in a `get_field` on the extracted element.
    #[test]
    fn test_nested_list_refs() {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "l",
            DataType::List(Arc::new(Field::new(
                "item",
                DataType::Struct(Fields::from(vec![Field::new("f1", DataType::Utf8, true)])),
                true,
            ))),
            true,
        )]));

        let planner = Planner::new(schema);

        let expected = array_element(col("l"), lit(0_i64));
        let expr = planner.parse_expr("l[0]").unwrap();
        assert_eq!(expr, expected);

        let expected = get_field(array_element(col("l"), lit(0_i64)), "f1");
        let expr = planner.parse_expr("l[0]['f1']").unwrap();
        assert_eq!(expr, expected);

        // FIXME: This should work, but sqlparser doesn't recognize anything
        // after the period for some reason.
        // let expr = planner.parse_expr("l[0].f1").unwrap();
        // assert_eq!(expr, expected);
    }
1217
    /// Checks unary negation of both literals and parenthesized
    /// sub-expressions, then evaluates the filter over x = -5..5.
    #[test]
    fn test_negative_expressions() {
        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));

        let planner = Planner::new(schema.clone());

        let expected = col("x")
            .gt(lit(-3_i64))
            .and(col("x").lt(-(lit(-5_i64) + lit(3_i64))));

        let expr = planner.parse_filter("x > -3 AND x < -(-5 + 3)").unwrap();

        assert_eq!(expr, expected);

        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        // x in -5..5; the filter selects -3 < x < 2 (rows -2..=1).
        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(Int64Array::from_iter_values(-5..5)) as ArrayRef],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, true, true, true, true, false, false, false
            ])
        );
    }
1247
    /// Checks that an array literal of negative floats parses to a single
    /// `ScalarValue::List` literal (Float64 elements).
    #[test]
    fn test_negative_array_expressions() {
        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));

        let planner = Planner::new(schema);

        let expected = Expr::Literal(
            ScalarValue::List(Arc::new(
                ListArray::from_iter_primitive::<Float64Type, _, _>(vec![Some(
                    [-1_f64, -2.0, -3.0, -4.0, -5.0].map(Some),
                )]),
            )),
            None,
        );

        let expr = planner
            .parse_expr("[-1.0, -2.0, -3.0, -4.0, -5.0]")
            .unwrap();

        assert_eq!(expr, expected);
    }
1269
    /// Checks that `LIKE` with a single-quoted pattern parses to
    /// `Expr::Like` and matches only the exact row when evaluated.
    #[test]
    fn test_sql_like() {
        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));

        let planner = Planner::new(schema.clone());

        let expected = col("s").like(lit("str-4"));
        // pattern is a single-quoted string literal
        let expr = planner.parse_filter("s LIKE 'str-4'").unwrap();
        assert_eq!(expr, expected);
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        // Rows "str-0".."str-9"; only "str-4" matches.
        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(StringArray::from_iter_values(
                (0..10).map(|v| format!("str-{}", v)),
            ))],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, false, true, false, false, false, false, false
            ])
        );
    }
1297
    /// Checks that `NOT LIKE` parses to a negated `Expr::Like` and excludes
    /// only the matching row when evaluated.
    #[test]
    fn test_not_like() {
        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));

        let planner = Planner::new(schema.clone());

        let expected = col("s").not_like(lit("str-4"));
        // pattern is a single-quoted string literal
        let expr = planner.parse_filter("s NOT LIKE 'str-4'").unwrap();
        assert_eq!(expr, expected);
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        // Rows "str-0".."str-9"; every row except "str-4" passes.
        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(StringArray::from_iter_values(
                (0..10).map(|v| format!("str-{}", v)),
            ))],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                true, true, true, true, false, true, true, true, true, true
            ])
        );
    }
1325
    /// Checks that `IN (...)` parses to `Expr::InList` and selects exactly
    /// the listed values when evaluated.
    #[test]
    fn test_sql_is_in() {
        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));

        let planner = Planner::new(schema.clone());

        let expected = col("s").in_list(vec![lit("str-4"), lit("str-5")], false);
        // list elements are single-quoted string literals
        let expr = planner.parse_filter("s IN ('str-4', 'str-5')").unwrap();
        assert_eq!(expr, expected);
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        // Rows "str-0".."str-9"; rows 4 and 5 are in the list.
        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(StringArray::from_iter_values(
                (0..10).map(|v| format!("str-{}", v)),
            ))],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, false, true, true, false, false, false, false
            ])
        );
    }
1353
    /// Checks `IS NULL` and `IS NOT NULL` over a string column where every
    /// row whose index is not a multiple of 3 is null.
    #[test]
    fn test_sql_is_null() {
        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));

        let planner = Planner::new(schema.clone());

        let expected = col("s").is_null();
        let expr = planner.parse_filter("s IS NULL").unwrap();
        assert_eq!(expr, expected);
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        // Rows 0, 3, 6, 9 have values; the rest are null.
        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(StringArray::from_iter((0..10).map(|v| {
                if v % 3 == 0 {
                    Some(format!("str-{}", v))
                } else {
                    None
                }
            })))],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, true, true, false, true, true, false, true, true, false
            ])
        );

        // IS NOT NULL is the exact complement.
        let expr = planner.parse_filter("s IS NOT NULL").unwrap();
        let physical_expr = planner.create_physical_expr(&expr).unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                true, false, false, true, false, false, true, false, false, true,
            ])
        );
    }
1394
    /// Checks that `NOT <bool column>` inverts the column values when
    /// evaluated.
    #[test]
    fn test_sql_invert() {
        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Boolean, true)]));

        let planner = Planner::new(schema.clone());

        let expr = planner.parse_filter("NOT s").unwrap();
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        // s is true for rows 0, 3, 6, 9; NOT s flips every value.
        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(BooleanArray::from_iter(
                (0..10).map(|v| Some(v % 3 == 0)),
            ))],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, true, true, false, true, true, false, true, true, false
            ])
        );
    }
1419
    /// Table-driven check that `CAST(<literal> AS <sql type>)` maps each SQL
    /// type name to the expected Arrow data type and leaves the literal
    /// operand intact.
    #[test]
    fn test_sql_cast() {
        let cases = &[
            (
                "x = cast('2021-01-01 00:00:00' as timestamp)",
                ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00' as timestamp(0))",
                ArrowDataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00.123' as timestamp(9))",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00.123' as datetime(9))",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            ("x = cast('2021-01-01' as date)", ArrowDataType::Date32),
            (
                "x = cast('1.238' as decimal(9,3))",
                ArrowDataType::Decimal128(9, 3),
            ),
            ("x = cast(1 as float)", ArrowDataType::Float32),
            ("x = cast(1 as double)", ArrowDataType::Float64),
            ("x = cast(1 as tinyint)", ArrowDataType::Int8),
            ("x = cast(1 as smallint)", ArrowDataType::Int16),
            ("x = cast(1 as int)", ArrowDataType::Int32),
            ("x = cast(1 as integer)", ArrowDataType::Int32),
            ("x = cast(1 as bigint)", ArrowDataType::Int64),
            ("x = cast(1 as tinyint unsigned)", ArrowDataType::UInt8),
            ("x = cast(1 as smallint unsigned)", ArrowDataType::UInt16),
            ("x = cast(1 as int unsigned)", ArrowDataType::UInt32),
            ("x = cast(1 as integer unsigned)", ArrowDataType::UInt32),
            ("x = cast(1 as bigint unsigned)", ArrowDataType::UInt64),
            ("x = cast(1 as boolean)", ArrowDataType::Boolean),
            ("x = cast(1 as string)", ArrowDataType::Utf8),
        ];

        for (sql, expected_data_type) in cases {
            let schema = Arc::new(Schema::new(vec![Field::new(
                "x",
                expected_data_type.clone(),
                true,
            )]));
            let planner = Planner::new(schema.clone());
            let expr = planner.parse_filter(sql).unwrap();

            // Get the thing after 'cast(` but before ' as'.
            let expected_value_str = sql
                .split("cast(")
                .nth(1)
                .unwrap()
                .split(" as")
                .next()
                .unwrap();
            // Remove any quote marks
            let expected_value_str = expected_value_str.trim_matches('\'');

            // The cast should sit on the right of the comparison, wrapping the
            // original literal (string literals stay Utf8, bare `1` is Int64).
            match expr {
                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
                    Expr::Cast(Cast { expr, data_type }) => {
                        match expr.as_ref() {
                            Expr::Literal(ScalarValue::Utf8(Some(value_str)), _) => {
                                assert_eq!(value_str, expected_value_str);
                            }
                            Expr::Literal(ScalarValue::Int64(Some(value)), _) => {
                                assert_eq!(*value, 1);
                            }
                            _ => panic!("Expected cast to be applied to literal"),
                        }
                        assert_eq!(data_type, expected_data_type);
                    }
                    _ => panic!("Expected right to be a cast"),
                },
                _ => panic!("Expected binary expression"),
            }
        }
    }
1500
    /// Table-driven check that typed SQL literals (e.g. `timestamp '...'`,
    /// `date '...'`, `decimal(9,3) '...'`) parse to a cast of the quoted
    /// string to the expected Arrow data type.
    #[test]
    fn test_sql_literals() {
        let cases = &[
            (
                "x = timestamp '2021-01-01 00:00:00'",
                ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
            ),
            (
                "x = timestamp(0) '2021-01-01 00:00:00'",
                ArrowDataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                "x = timestamp(9) '2021-01-01 00:00:00.123'",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            ("x = date '2021-01-01'", ArrowDataType::Date32),
            ("x = decimal(9,3) '1.238'", ArrowDataType::Decimal128(9, 3)),
        ];

        for (sql, expected_data_type) in cases {
            let schema = Arc::new(Schema::new(vec![Field::new(
                "x",
                expected_data_type.clone(),
                true,
            )]));
            let planner = Planner::new(schema.clone());
            let expr = planner.parse_filter(sql).unwrap();

            // The literal's text is the first single-quoted segment of the SQL.
            let expected_value_str = sql.split('\'').nth(1).unwrap();

            match expr {
                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
                    Expr::Cast(Cast { expr, data_type }) => {
                        match expr.as_ref() {
                            Expr::Literal(ScalarValue::Utf8(Some(value_str)), _) => {
                                assert_eq!(value_str, expected_value_str);
                            }
                            _ => panic!("Expected cast to be applied to literal"),
                        }
                        assert_eq!(data_type, expected_data_type);
                    }
                    _ => panic!("Expected right to be a cast"),
                },
                _ => panic!("Expected binary expression"),
            }
        }
    }
1548
1549    #[test]
1550    fn test_sql_array_literals() {
1551        let cases = [
1552            (
1553                "x = [1, 2, 3]",
1554                ArrowDataType::List(Arc::new(Field::new("item", ArrowDataType::Int64, true))),
1555            ),
1556            (
1557                "x = [1, 2, 3]",
1558                ArrowDataType::FixedSizeList(
1559                    Arc::new(Field::new("item", ArrowDataType::Int64, true)),
1560                    3,
1561                ),
1562            ),
1563        ];
1564
1565        for (sql, expected_data_type) in cases {
1566            let schema = Arc::new(Schema::new(vec![Field::new(
1567                "x",
1568                expected_data_type.clone(),
1569                true,
1570            )]));
1571            let planner = Planner::new(schema.clone());
1572            let expr = planner.parse_filter(sql).unwrap();
1573            let expr = planner.optimize_expr(expr).unwrap();
1574
1575            match expr {
1576                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
1577                    Expr::Literal(value, _) => {
1578                        assert_eq!(&value.data_type(), &expected_data_type);
1579                    }
1580                    _ => panic!("Expected right to be a literal"),
1581                },
1582                _ => panic!("Expected binary expression"),
1583            }
1584        }
1585    }
1586
1587    #[test]
1588    fn test_sql_between() {
1589        use arrow_array::{Float64Array, Int32Array, TimestampMicrosecondArray};
1590        use arrow_schema::{DataType, Field, Schema, TimeUnit};
1591        use std::sync::Arc;
1592
1593        let schema = Arc::new(Schema::new(vec![
1594            Field::new("x", DataType::Int32, false),
1595            Field::new("y", DataType::Float64, false),
1596            Field::new(
1597                "ts",
1598                DataType::Timestamp(TimeUnit::Microsecond, None),
1599                false,
1600            ),
1601        ]));
1602
1603        let planner = Planner::new(schema.clone());
1604
1605        // Test integer BETWEEN
1606        let expr = planner
1607            .parse_filter("x BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
1608            .unwrap();
1609        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1610
1611        // Create timestamp array with values representing:
1612        // 2024-01-01 00:00:00 to 2024-01-01 00:00:09 (in microseconds)
1613        let base_ts = 1704067200000000_i64; // 2024-01-01 00:00:00
1614        let ts_array = TimestampMicrosecondArray::from_iter_values(
1615            (0..10).map(|i| base_ts + i * 1_000_000), // Each value is 1 second apart
1616        );
1617
1618        let batch = RecordBatch::try_new(
1619            schema,
1620            vec![
1621                Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
1622                Arc::new(Float64Array::from_iter_values((0..10).map(|v| v as f64))),
1623                Arc::new(ts_array),
1624            ],
1625        )
1626        .unwrap();
1627
1628        let predicates = physical_expr.evaluate(&batch).unwrap();
1629        assert_eq!(
1630            predicates.into_array(0).unwrap().as_ref(),
1631            &BooleanArray::from(vec![
1632                false, false, false, true, true, true, true, true, false, false
1633            ])
1634        );
1635
1636        // Test NOT BETWEEN
1637        let expr = planner
1638            .parse_filter("x NOT BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
1639            .unwrap();
1640        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1641
1642        let predicates = physical_expr.evaluate(&batch).unwrap();
1643        assert_eq!(
1644            predicates.into_array(0).unwrap().as_ref(),
1645            &BooleanArray::from(vec![
1646                true, true, true, false, false, false, false, false, true, true
1647            ])
1648        );
1649
1650        // Test floating point BETWEEN
1651        let expr = planner.parse_filter("y BETWEEN 2.5 AND 6.5").unwrap();
1652        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1653
1654        let predicates = physical_expr.evaluate(&batch).unwrap();
1655        assert_eq!(
1656            predicates.into_array(0).unwrap().as_ref(),
1657            &BooleanArray::from(vec![
1658                false, false, false, true, true, true, true, false, false, false
1659            ])
1660        );
1661
1662        // Test timestamp BETWEEN
1663        let expr = planner
1664            .parse_filter(
1665                "ts BETWEEN timestamp '2024-01-01 00:00:03' AND timestamp '2024-01-01 00:00:07'",
1666            )
1667            .unwrap();
1668        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1669
1670        let predicates = physical_expr.evaluate(&batch).unwrap();
1671        assert_eq!(
1672            predicates.into_array(0).unwrap().as_ref(),
1673            &BooleanArray::from(vec![
1674                false, false, false, true, true, true, true, true, false, false
1675            ])
1676        );
1677    }
1678
1679    #[test]
1680    fn test_sql_comparison() {
1681        // Create a batch with all data types
1682        let batch: Vec<(&str, ArrayRef)> = vec![
1683            (
1684                "timestamp_s",
1685                Arc::new(TimestampSecondArray::from_iter_values(0..10)),
1686            ),
1687            (
1688                "timestamp_ms",
1689                Arc::new(TimestampMillisecondArray::from_iter_values(0..10)),
1690            ),
1691            (
1692                "timestamp_us",
1693                Arc::new(TimestampMicrosecondArray::from_iter_values(0..10)),
1694            ),
1695            (
1696                "timestamp_ns",
1697                Arc::new(TimestampNanosecondArray::from_iter_values(4995..5005)),
1698            ),
1699        ];
1700        let batch = RecordBatch::try_from_iter(batch).unwrap();
1701
1702        let planner = Planner::new(batch.schema());
1703
1704        // Each expression is meant to select the final 5 rows
1705        let expressions = &[
1706            "timestamp_s >= TIMESTAMP '1970-01-01 00:00:05'",
1707            "timestamp_ms >= TIMESTAMP '1970-01-01 00:00:00.005'",
1708            "timestamp_us >= TIMESTAMP '1970-01-01 00:00:00.000005'",
1709            "timestamp_ns >= TIMESTAMP '1970-01-01 00:00:00.000005'",
1710        ];
1711
1712        let expected: ArrayRef = Arc::new(BooleanArray::from_iter(
1713            std::iter::repeat_n(Some(false), 5).chain(std::iter::repeat_n(Some(true), 5)),
1714        ));
1715        for expression in expressions {
1716            // convert to physical expression
1717            let logical_expr = planner.parse_filter(expression).unwrap();
1718            let logical_expr = planner.optimize_expr(logical_expr).unwrap();
1719            let physical_expr = planner.create_physical_expr(&logical_expr).unwrap();
1720
1721            // Evaluate and assert they have correct results
1722            let result = physical_expr.evaluate(&batch).unwrap();
1723            let result = result.into_array(batch.num_rows()).unwrap();
1724            assert_eq!(&expected, &result, "unexpected result for {}", expression);
1725        }
1726    }
1727
1728    #[test]
1729    fn test_columns_in_expr() {
1730        let expr = col("s0").gt(lit("value")).and(
1731            col("st")
1732                .field("st")
1733                .field("s2")
1734                .eq(lit("value"))
1735                .or(col("st")
1736                    .field("s1")
1737                    .in_list(vec![lit("value 1"), lit("value 2")], false)),
1738        );
1739
1740        let columns = Planner::column_names_in_expr(&expr);
1741        assert_eq!(columns, vec!["s0", "st.s1", "st.st.s2"]);
1742    }
1743
1744    #[test]
1745    fn test_parse_binary_expr() {
1746        let bin_str = "x'616263'";
1747
1748        let schema = Arc::new(Schema::new(vec![Field::new(
1749            "binary",
1750            DataType::Binary,
1751            true,
1752        )]));
1753        let planner = Planner::new(schema);
1754        let expr = planner.parse_expr(bin_str).unwrap();
1755        assert_eq!(
1756            expr,
1757            Expr::Literal(ScalarValue::Binary(Some(vec![b'a', b'b', b'c'])), None)
1758        );
1759    }
1760
1761    #[test]
1762    fn test_lance_context_provider_expr_planners() {
1763        let ctx_provider = LanceContextProvider::default();
1764        assert!(!ctx_provider.get_expr_planners().is_empty());
1765    }
1766
1767    #[test]
1768    fn test_regexp_match_and_non_empty_captions() {
1769        // Repro for a bug where regexp_match inside an AND chain wasn't coerced to boolean,
1770        // causing planning/evaluation failures. This should evaluate successfully.
1771        let schema = Arc::new(Schema::new(vec![
1772            Field::new("keywords", DataType::Utf8, true),
1773            Field::new("natural_caption", DataType::Utf8, true),
1774            Field::new("poetic_caption", DataType::Utf8, true),
1775        ]));
1776
1777        let planner = Planner::new(schema.clone());
1778
1779        let expr = planner
1780            .parse_filter(
1781                "regexp_match(keywords, 'Liberty|revolution') AND \
1782                 (natural_caption IS NOT NULL AND natural_caption <> '' AND \
1783                  poetic_caption IS NOT NULL AND poetic_caption <> '')",
1784            )
1785            .unwrap();
1786
1787        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1788
1789        let batch = RecordBatch::try_new(
1790            schema,
1791            vec![
1792                Arc::new(StringArray::from(vec![
1793                    Some("Liberty for all"),
1794                    Some("peace"),
1795                    Some("revolution now"),
1796                    Some("Liberty"),
1797                    Some("revolutionary"),
1798                    Some("none"),
1799                ])) as ArrayRef,
1800                Arc::new(StringArray::from(vec![
1801                    Some("a"),
1802                    Some("b"),
1803                    None,
1804                    Some(""),
1805                    Some("c"),
1806                    Some("d"),
1807                ])) as ArrayRef,
1808                Arc::new(StringArray::from(vec![
1809                    Some("x"),
1810                    Some(""),
1811                    Some("y"),
1812                    Some("z"),
1813                    None,
1814                    Some("w"),
1815                ])) as ArrayRef,
1816            ],
1817        )
1818        .unwrap();
1819
1820        let result = physical_expr.evaluate(&batch).unwrap();
1821        assert_eq!(
1822            result.into_array(0).unwrap().as_ref(),
1823            &BooleanArray::from(vec![true, false, false, false, false, false])
1824        );
1825    }
1826
1827    #[test]
1828    fn test_regexp_match_infer_error_without_boolean_coercion() {
1829        // With the fix applied, using parse_filter should coerce regexp_match to boolean
1830        // even when nested in a larger AND expression, so this should plan successfully.
1831        let schema = Arc::new(Schema::new(vec![
1832            Field::new("keywords", DataType::Utf8, true),
1833            Field::new("natural_caption", DataType::Utf8, true),
1834            Field::new("poetic_caption", DataType::Utf8, true),
1835        ]));
1836
1837        let planner = Planner::new(schema);
1838
1839        let expr = planner
1840            .parse_filter(
1841                "regexp_match(keywords, 'Liberty|revolution') AND \
1842                 (natural_caption IS NOT NULL AND natural_caption <> '' AND \
1843                  poetic_caption IS NOT NULL AND poetic_caption <> '')",
1844            )
1845            .unwrap();
1846
1847        // Should not panic
1848        let _physical = planner.create_physical_expr(&expr).unwrap();
1849    }
1850}