lance_datafusion/
planner.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Exec plan planner
5
6use std::borrow::Cow;
7use std::collections::{BTreeSet, VecDeque};
8use std::sync::Arc;
9
10use crate::expr::safe_coerce_scalar;
11use crate::logical_expr::{coerce_filter_type_to_boolean, get_as_string_scalar_opt, resolve_expr};
12use crate::sql::{parse_sql_expr, parse_sql_filter};
13use arrow::compute::CastOptions;
14use arrow_array::ListArray;
15use arrow_buffer::OffsetBuffer;
16use arrow_schema::{DataType as ArrowDataType, Field, SchemaRef, TimeUnit};
17use arrow_select::concat::concat;
18use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
19use datafusion::common::DFSchema;
20use datafusion::config::ConfigOptions;
21use datafusion::error::Result as DFResult;
22use datafusion::execution::config::SessionConfig;
23use datafusion::execution::context::{SessionContext, SessionState};
24use datafusion::execution::runtime_env::RuntimeEnvBuilder;
25use datafusion::execution::session_state::SessionStateBuilder;
26use datafusion::logical_expr::expr::ScalarFunction;
27use datafusion::logical_expr::planner::{ExprPlanner, PlannerResult, RawFieldAccessExpr};
28use datafusion::logical_expr::{
29    AggregateUDF, ColumnarValue, GetFieldAccess, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
30    Signature, Volatility, WindowUDF,
31};
32use datafusion::optimizer::simplify_expressions::SimplifyContext;
33use datafusion::sql::planner::{
34    ContextProvider, NullOrdering, ParserOptions, PlannerContext, SqlToRel,
35};
36use datafusion::sql::sqlparser::ast::{
37    AccessExpr, Array as SQLArray, BinaryOperator, DataType as SQLDataType, ExactNumberInfo,
38    Expr as SQLExpr, Function, FunctionArg, FunctionArgExpr, FunctionArguments, Ident,
39    ObjectNamePart, Subscript, TimezoneInfo, UnaryOperator, Value, ValueWithSpan,
40};
41use datafusion::{
42    common::Column,
43    logical_expr::{col, Between, BinaryExpr, Like, Operator},
44    physical_expr::execution_props::ExecutionProps,
45    physical_plan::PhysicalExpr,
46    prelude::Expr,
47    scalar::ScalarValue,
48};
49use datafusion_functions::core::getfield::GetFieldFunc;
50use lance_arrow::cast::cast_with_options;
51use lance_core::datatypes::Schema;
52use lance_core::error::LanceOptionExt;
53use snafu::location;
54
55use chrono::Utc;
56use lance_core::{Error, Result};
57
58#[derive(Debug, Clone, Eq, PartialEq, Hash)]
59struct CastListF16Udf {
60    signature: Signature,
61}
62
63impl CastListF16Udf {
64    pub fn new() -> Self {
65        Self {
66            signature: Signature::any(1, Volatility::Immutable),
67        }
68    }
69}
70
71impl ScalarUDFImpl for CastListF16Udf {
72    fn as_any(&self) -> &dyn std::any::Any {
73        self
74    }
75
76    fn name(&self) -> &str {
77        "_cast_list_f16"
78    }
79
80    fn signature(&self) -> &Signature {
81        &self.signature
82    }
83
84    fn return_type(&self, arg_types: &[ArrowDataType]) -> DFResult<ArrowDataType> {
85        let input = &arg_types[0];
86        match input {
87            ArrowDataType::FixedSizeList(field, size) => {
88                if field.data_type() != &ArrowDataType::Float32
89                    && field.data_type() != &ArrowDataType::Float16
90                {
91                    return Err(datafusion::error::DataFusionError::Execution(
92                        "cast_list_f16 only supports list of float32 or float16".to_string(),
93                    ));
94                }
95                Ok(ArrowDataType::FixedSizeList(
96                    Arc::new(Field::new(
97                        field.name(),
98                        ArrowDataType::Float16,
99                        field.is_nullable(),
100                    )),
101                    *size,
102                ))
103            }
104            ArrowDataType::List(field) => {
105                if field.data_type() != &ArrowDataType::Float32
106                    && field.data_type() != &ArrowDataType::Float16
107                {
108                    return Err(datafusion::error::DataFusionError::Execution(
109                        "cast_list_f16 only supports list of float32 or float16".to_string(),
110                    ));
111                }
112                Ok(ArrowDataType::List(Arc::new(Field::new(
113                    field.name(),
114                    ArrowDataType::Float16,
115                    field.is_nullable(),
116                ))))
117            }
118            _ => Err(datafusion::error::DataFusionError::Execution(
119                "cast_list_f16 only supports FixedSizeList/List arguments".to_string(),
120            )),
121        }
122    }
123
124    fn invoke_with_args(&self, func_args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
125        let ColumnarValue::Array(arr) = &func_args.args[0] else {
126            return Err(datafusion::error::DataFusionError::Execution(
127                "cast_list_f16 only supports array arguments".to_string(),
128            ));
129        };
130
131        let to_type = match arr.data_type() {
132            ArrowDataType::FixedSizeList(field, size) => ArrowDataType::FixedSizeList(
133                Arc::new(Field::new(
134                    field.name(),
135                    ArrowDataType::Float16,
136                    field.is_nullable(),
137                )),
138                *size,
139            ),
140            ArrowDataType::List(field) => ArrowDataType::List(Arc::new(Field::new(
141                field.name(),
142                ArrowDataType::Float16,
143                field.is_nullable(),
144            ))),
145            _ => {
146                return Err(datafusion::error::DataFusionError::Execution(
147                    "cast_list_f16 only supports array arguments".to_string(),
148                ));
149            }
150        };
151
152        let res = cast_with_options(arr.as_ref(), &to_type, &CastOptions::default())?;
153        Ok(ColumnarValue::Array(res))
154    }
155}
156
157// Adapter that instructs datafusion how lance expects expressions to be interpreted
158struct LanceContextProvider {
159    options: datafusion::config::ConfigOptions,
160    state: SessionState,
161    expr_planners: Vec<Arc<dyn ExprPlanner>>,
162}
163
164impl Default for LanceContextProvider {
165    fn default() -> Self {
166        let config = SessionConfig::new();
167        let runtime = RuntimeEnvBuilder::new().build_arc().unwrap();
168
169        let ctx = SessionContext::new_with_config_rt(config.clone(), runtime.clone());
170        crate::udf::register_functions(&ctx);
171
172        let state = ctx.state();
173
174        // SessionState does not expose expr_planners, so we need to get them separately
175        let mut state_builder = SessionStateBuilder::new()
176            .with_config(config)
177            .with_runtime_env(runtime)
178            .with_default_features();
179
180        // unwrap safe because with_default_features sets expr_planners
181        let expr_planners = state_builder.expr_planners().as_ref().unwrap().clone();
182
183        Self {
184            options: ConfigOptions::default(),
185            state,
186            expr_planners,
187        }
188    }
189}
190
191impl ContextProvider for LanceContextProvider {
192    fn get_table_source(
193        &self,
194        name: datafusion::sql::TableReference,
195    ) -> DFResult<Arc<dyn datafusion::logical_expr::TableSource>> {
196        Err(datafusion::error::DataFusionError::NotImplemented(format!(
197            "Attempt to reference inner table {} not supported",
198            name
199        )))
200    }
201
202    fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
203        self.state.aggregate_functions().get(name).cloned()
204    }
205
206    fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
207        self.state.window_functions().get(name).cloned()
208    }
209
210    fn get_function_meta(&self, f: &str) -> Option<Arc<ScalarUDF>> {
211        match f {
212            // TODO: cast should go thru CAST syntax instead of UDF
213            // Going thru UDF makes it hard for the optimizer to find no-ops
214            "_cast_list_f16" => Some(Arc::new(ScalarUDF::new_from_impl(CastListF16Udf::new()))),
215            _ => self.state.scalar_functions().get(f).cloned(),
216        }
217    }
218
219    fn get_variable_type(&self, _: &[String]) -> Option<ArrowDataType> {
220        // Variables (things like @@LANGUAGE) not supported
221        None
222    }
223
224    fn options(&self) -> &datafusion::config::ConfigOptions {
225        &self.options
226    }
227
228    fn udf_names(&self) -> Vec<String> {
229        self.state.scalar_functions().keys().cloned().collect()
230    }
231
232    fn udaf_names(&self) -> Vec<String> {
233        self.state.aggregate_functions().keys().cloned().collect()
234    }
235
236    fn udwf_names(&self) -> Vec<String> {
237        self.state.window_functions().keys().cloned().collect()
238    }
239
240    fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
241        &self.expr_planners
242    }
243}
244
/// Translates parsed SQL expressions into DataFusion logical expressions
/// against a fixed Arrow schema.
pub struct Planner {
    /// Arrow schema converted to a `DFSchema` when planning functions and
    /// field accesses.
    schema: SchemaRef,
    /// Supplies function lookups and expression planners to DataFusion.
    context_provider: LanceContextProvider,
    /// When `true`, the first identifier of a multi-part column reference is
    /// treated as a relation name (see `with_enable_relations`).
    enable_relations: bool,
}
250
251impl Planner {
252    pub fn new(schema: SchemaRef) -> Self {
253        Self {
254            schema,
255            context_provider: LanceContextProvider::default(),
256            enable_relations: false,
257        }
258    }
259
260    /// If passed with `true`, then the first identifier in column reference
261    /// is parsed as the relation. For example, `table.field.inner` will be
262    /// read as the nested field `field.inner` (`inner` on struct field `field`)
263    /// on the `table` relation. If `false` (the default), then no relations
264    /// are used and all identifiers are assumed to be a nested column path.
265    pub fn with_enable_relations(mut self, enable_relations: bool) -> Self {
266        self.enable_relations = enable_relations;
267        self
268    }
269
270    fn column(&self, idents: &[Ident]) -> Expr {
271        fn handle_remaining_idents(expr: &mut Expr, idents: &[Ident]) {
272            for ident in idents {
273                *expr = Expr::ScalarFunction(ScalarFunction {
274                    args: vec![
275                        std::mem::take(expr),
276                        Expr::Literal(ScalarValue::Utf8(Some(ident.value.clone())), None),
277                    ],
278                    func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
279                });
280            }
281        }
282
283        if self.enable_relations && idents.len() > 1 {
284            // Create qualified column reference (relation.column)
285            let relation = &idents[0].value;
286            let column_name = &idents[1].value;
287            let column = Expr::Column(Column::new(Some(relation.clone()), column_name.clone()));
288            let mut result = column;
289            handle_remaining_idents(&mut result, &idents[2..]);
290            result
291        } else {
292            // Default behavior - treat as struct field access
293            let mut column = col(&idents[0].value);
294            handle_remaining_idents(&mut column, &idents[1..]);
295            column
296        }
297    }
298
299    fn binary_op(&self, op: &BinaryOperator) -> Result<Operator> {
300        Ok(match op {
301            BinaryOperator::Plus => Operator::Plus,
302            BinaryOperator::Minus => Operator::Minus,
303            BinaryOperator::Multiply => Operator::Multiply,
304            BinaryOperator::Divide => Operator::Divide,
305            BinaryOperator::Modulo => Operator::Modulo,
306            BinaryOperator::StringConcat => Operator::StringConcat,
307            BinaryOperator::Gt => Operator::Gt,
308            BinaryOperator::Lt => Operator::Lt,
309            BinaryOperator::GtEq => Operator::GtEq,
310            BinaryOperator::LtEq => Operator::LtEq,
311            BinaryOperator::Eq => Operator::Eq,
312            BinaryOperator::NotEq => Operator::NotEq,
313            BinaryOperator::And => Operator::And,
314            BinaryOperator::Or => Operator::Or,
315            _ => {
316                return Err(Error::invalid_input(
317                    format!("Operator {op} is not supported"),
318                    location!(),
319                ));
320            }
321        })
322    }
323
324    fn binary_expr(&self, left: &SQLExpr, op: &BinaryOperator, right: &SQLExpr) -> Result<Expr> {
325        Ok(Expr::BinaryExpr(BinaryExpr::new(
326            Box::new(self.parse_sql_expr(left)?),
327            self.binary_op(op)?,
328            Box::new(self.parse_sql_expr(right)?),
329        )))
330    }
331
332    fn unary_expr(&self, op: &UnaryOperator, expr: &SQLExpr) -> Result<Expr> {
333        Ok(match op {
334            UnaryOperator::Not | UnaryOperator::PGBitwiseNot => {
335                Expr::Not(Box::new(self.parse_sql_expr(expr)?))
336            }
337
338            UnaryOperator::Minus => {
339                use datafusion::logical_expr::lit;
340                match expr {
341                    SQLExpr::Value(ValueWithSpan { value: Value::Number(n, _), ..}) => match n.parse::<i64>() {
342                        Ok(n) => lit(-n),
343                        Err(_) => lit(-n
344                            .parse::<f64>()
345                            .map_err(|_e| {
346                                Error::invalid_input(
347                                    format!("negative operator can be only applied to integer and float operands, got: {n}"),
348                                    location!(),
349                                )
350                            })?),
351                    },
352                    _ => {
353                        Expr::Negative(Box::new(self.parse_sql_expr(expr)?))
354                    }
355                }
356            }
357
358            _ => {
359                return Err(Error::invalid_input(
360                    format!("Unary operator '{:?}' is not supported", op),
361                    location!(),
362                ));
363            }
364        })
365    }
366
367    // See datafusion `sqlToRel::parse_sql_number()`
368    fn number(&self, value: &str, negative: bool) -> Result<Expr> {
369        use datafusion::logical_expr::lit;
370        let value: Cow<str> = if negative {
371            Cow::Owned(format!("-{}", value))
372        } else {
373            Cow::Borrowed(value)
374        };
375        if let Ok(n) = value.parse::<i64>() {
376            Ok(lit(n))
377        } else {
378            value.parse::<f64>().map(lit).map_err(|_| {
379                Error::invalid_input(
380                    format!("'{value}' is not supported number value."),
381                    location!(),
382                )
383            })
384        }
385    }
386
387    fn value(&self, value: &Value) -> Result<Expr> {
388        Ok(match value {
389            Value::Number(v, _) => self.number(v.as_str(), false)?,
390            Value::SingleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone())), None),
391            Value::HexStringLiteral(hsl) => {
392                Expr::Literal(ScalarValue::Binary(Self::try_decode_hex_literal(hsl)), None)
393            }
394            Value::DoubleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone())), None),
395            Value::Boolean(v) => Expr::Literal(ScalarValue::Boolean(Some(*v)), None),
396            Value::Null => Expr::Literal(ScalarValue::Null, None),
397            _ => todo!(),
398        })
399    }
400
401    fn parse_function_args(&self, func_args: &FunctionArg) -> Result<Expr> {
402        match func_args {
403            FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => self.parse_sql_expr(expr),
404            _ => Err(Error::invalid_input(
405                format!("Unsupported function args: {:?}", func_args),
406                location!(),
407            )),
408        }
409    }
410
411    // We now use datafusion to parse functions.  This allows us to use datafusion's
412    // entire collection of functions (previously we had just hard-coded support for two functions).
413    //
414    // Unfortunately, one of those two functions was is_valid and the reason we needed it was because
415    // this is a function that comes from duckdb.  Datafusion does not consider is_valid to be a function
416    // but rather an AST node (Expr::IsNotNull) and so we need to handle this case specially.
417    fn legacy_parse_function(&self, func: &Function) -> Result<Expr> {
418        match &func.args {
419            FunctionArguments::List(args) => {
420                if func.name.0.len() != 1 {
421                    return Err(Error::invalid_input(
422                        format!("Function name must have 1 part, got: {:?}", func.name.0),
423                        location!(),
424                    ));
425                }
426                Ok(Expr::IsNotNull(Box::new(
427                    self.parse_function_args(&args.args[0])?,
428                )))
429            }
430            _ => Err(Error::invalid_input(
431                format!("Unsupported function args: {:?}", &func.args),
432                location!(),
433            )),
434        }
435    }
436
437    fn parse_function(&self, function: SQLExpr) -> Result<Expr> {
438        if let SQLExpr::Function(function) = &function {
439            if let Some(ObjectNamePart::Identifier(name)) = &function.name.0.first() {
440                if &name.value == "is_valid" {
441                    return self.legacy_parse_function(function);
442                }
443            }
444        }
445        let sql_to_rel = SqlToRel::new_with_options(
446            &self.context_provider,
447            ParserOptions {
448                parse_float_as_decimal: false,
449                enable_ident_normalization: false,
450                support_varchar_with_length: false,
451                enable_options_value_normalization: false,
452                collect_spans: false,
453                map_string_types_to_utf8view: false,
454                default_null_ordering: NullOrdering::NullsMax,
455            },
456        );
457
458        let mut planner_context = PlannerContext::default();
459        let schema = DFSchema::try_from(self.schema.as_ref().clone())?;
460        sql_to_rel
461            .sql_to_expr(function, &schema, &mut planner_context)
462            .map_err(|e| Error::invalid_input(format!("Error parsing function: {e}"), location!()))
463    }
464
465    fn parse_type(&self, data_type: &SQLDataType) -> Result<ArrowDataType> {
466        const SUPPORTED_TYPES: [&str; 13] = [
467            "int [unsigned]",
468            "tinyint [unsigned]",
469            "smallint [unsigned]",
470            "bigint [unsigned]",
471            "float",
472            "double",
473            "string",
474            "binary",
475            "date",
476            "timestamp(precision)",
477            "datetime(precision)",
478            "decimal(precision,scale)",
479            "boolean",
480        ];
481        match data_type {
482            SQLDataType::String(_) => Ok(ArrowDataType::Utf8),
483            SQLDataType::Binary(_) => Ok(ArrowDataType::Binary),
484            SQLDataType::Float(_) => Ok(ArrowDataType::Float32),
485            SQLDataType::Double(_) => Ok(ArrowDataType::Float64),
486            SQLDataType::Boolean => Ok(ArrowDataType::Boolean),
487            SQLDataType::TinyInt(_) => Ok(ArrowDataType::Int8),
488            SQLDataType::SmallInt(_) => Ok(ArrowDataType::Int16),
489            SQLDataType::Int(_) | SQLDataType::Integer(_) => Ok(ArrowDataType::Int32),
490            SQLDataType::BigInt(_) => Ok(ArrowDataType::Int64),
491            SQLDataType::TinyIntUnsigned(_) => Ok(ArrowDataType::UInt8),
492            SQLDataType::SmallIntUnsigned(_) => Ok(ArrowDataType::UInt16),
493            SQLDataType::IntUnsigned(_) | SQLDataType::IntegerUnsigned(_) => {
494                Ok(ArrowDataType::UInt32)
495            }
496            SQLDataType::BigIntUnsigned(_) => Ok(ArrowDataType::UInt64),
497            SQLDataType::Date => Ok(ArrowDataType::Date32),
498            SQLDataType::Timestamp(resolution, tz) => {
499                match tz {
500                    TimezoneInfo::None => {}
501                    _ => {
502                        return Err(Error::invalid_input(
503                            "Timezone not supported in timestamp".to_string(),
504                            location!(),
505                        ));
506                    }
507                };
508                let time_unit = match resolution {
509                    // Default to microsecond to match PyArrow
510                    None => TimeUnit::Microsecond,
511                    Some(0) => TimeUnit::Second,
512                    Some(3) => TimeUnit::Millisecond,
513                    Some(6) => TimeUnit::Microsecond,
514                    Some(9) => TimeUnit::Nanosecond,
515                    _ => {
516                        return Err(Error::invalid_input(
517                            format!("Unsupported datetime resolution: {:?}", resolution),
518                            location!(),
519                        ));
520                    }
521                };
522                Ok(ArrowDataType::Timestamp(time_unit, None))
523            }
524            SQLDataType::Datetime(resolution) => {
525                let time_unit = match resolution {
526                    None => TimeUnit::Microsecond,
527                    Some(0) => TimeUnit::Second,
528                    Some(3) => TimeUnit::Millisecond,
529                    Some(6) => TimeUnit::Microsecond,
530                    Some(9) => TimeUnit::Nanosecond,
531                    _ => {
532                        return Err(Error::invalid_input(
533                            format!("Unsupported datetime resolution: {:?}", resolution),
534                            location!(),
535                        ));
536                    }
537                };
538                Ok(ArrowDataType::Timestamp(time_unit, None))
539            }
540            SQLDataType::Decimal(number_info) => match number_info {
541                ExactNumberInfo::PrecisionAndScale(precision, scale) => {
542                    Ok(ArrowDataType::Decimal128(*precision as u8, *scale as i8))
543                }
544                _ => Err(Error::invalid_input(
545                    format!(
546                        "Must provide precision and scale for decimal: {:?}",
547                        number_info
548                    ),
549                    location!(),
550                )),
551            },
552            _ => Err(Error::invalid_input(
553                format!(
554                    "Unsupported data type: {:?}. Supported types: {:?}",
555                    data_type, SUPPORTED_TYPES
556                ),
557                location!(),
558            )),
559        }
560    }
561
562    fn plan_field_access(&self, mut field_access_expr: RawFieldAccessExpr) -> Result<Expr> {
563        let df_schema = DFSchema::try_from(self.schema.as_ref().clone())?;
564        for planner in self.context_provider.get_expr_planners() {
565            match planner.plan_field_access(field_access_expr, &df_schema)? {
566                PlannerResult::Planned(expr) => return Ok(expr),
567                PlannerResult::Original(expr) => {
568                    field_access_expr = expr;
569                }
570            }
571        }
572        Err(Error::invalid_input(
573            "Field access could not be planned",
574            location!(),
575        ))
576    }
577
578    fn parse_sql_expr(&self, expr: &SQLExpr) -> Result<Expr> {
579        match expr {
580            SQLExpr::Identifier(id) => {
581                // Users can pass string literals wrapped in `"`.
582                // (Normally SQL only allows single quotes.)
583                if id.quote_style == Some('"') {
584                    Ok(Expr::Literal(
585                        ScalarValue::Utf8(Some(id.value.clone())),
586                        None,
587                    ))
588                // Users can wrap identifiers with ` to reference non-standard
589                // names, such as uppercase or spaces.
590                } else if id.quote_style == Some('`') {
591                    Ok(Expr::Column(Column::from_name(id.value.clone())))
592                } else {
593                    Ok(self.column(vec![id.clone()].as_slice()))
594                }
595            }
596            SQLExpr::CompoundIdentifier(ids) => Ok(self.column(ids.as_slice())),
597            SQLExpr::BinaryOp { left, op, right } => self.binary_expr(left, op, right),
598            SQLExpr::UnaryOp { op, expr } => self.unary_expr(op, expr),
599            SQLExpr::Value(value) => self.value(&value.value),
600            SQLExpr::Array(SQLArray { elem, .. }) => {
601                let mut values = vec![];
602
603                let array_literal_error = |pos: usize, value: &_| {
604                    Err(Error::invalid_input(
605                        format!(
606                            "Expected a literal value in array, instead got {} at position {}",
607                            value, pos
608                        ),
609                        location!(),
610                    ))
611                };
612
613                for (pos, expr) in elem.iter().enumerate() {
614                    match expr {
615                        SQLExpr::Value(value) => {
616                            if let Expr::Literal(value, _) = self.value(&value.value)? {
617                                values.push(value);
618                            } else {
619                                return array_literal_error(pos, expr);
620                            }
621                        }
622                        SQLExpr::UnaryOp {
623                            op: UnaryOperator::Minus,
624                            expr,
625                        } => {
626                            if let SQLExpr::Value(ValueWithSpan {
627                                value: Value::Number(number, _),
628                                ..
629                            }) = expr.as_ref()
630                            {
631                                if let Expr::Literal(value, _) = self.number(number, true)? {
632                                    values.push(value);
633                                } else {
634                                    return array_literal_error(pos, expr);
635                                }
636                            } else {
637                                return array_literal_error(pos, expr);
638                            }
639                        }
640                        _ => {
641                            return array_literal_error(pos, expr);
642                        }
643                    }
644                }
645
646                let field = if !values.is_empty() {
647                    let data_type = values[0].data_type();
648
649                    for value in &mut values {
650                        if value.data_type() != data_type {
651                            *value = safe_coerce_scalar(value, &data_type).ok_or_else(|| Error::invalid_input(
652                                format!("Array expressions must have a consistent datatype. Expected: {}, got: {}", data_type, value.data_type()),
653                                location!()
654                            ))?;
655                        }
656                    }
657                    Field::new("item", data_type, true)
658                } else {
659                    Field::new("item", ArrowDataType::Null, true)
660                };
661
662                let values = values
663                    .into_iter()
664                    .map(|v| v.to_array().map_err(Error::from))
665                    .collect::<Result<Vec<_>>>()?;
666                let array_refs = values.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
667                let values = concat(&array_refs)?;
668                let values = ListArray::try_new(
669                    field.into(),
670                    OffsetBuffer::from_lengths([values.len()]),
671                    values,
672                    None,
673                )?;
674
675                Ok(Expr::Literal(ScalarValue::List(Arc::new(values)), None))
676            }
677            // For example, DATE '2020-01-01'
678            SQLExpr::TypedString { data_type, value } => {
679                let value = value.clone().into_string().expect_ok()?;
680                Ok(Expr::Cast(datafusion::logical_expr::Cast {
681                    expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some(value)), None)),
682                    data_type: self.parse_type(data_type)?,
683                }))
684            }
685            SQLExpr::IsFalse(expr) => Ok(Expr::IsFalse(Box::new(self.parse_sql_expr(expr)?))),
686            SQLExpr::IsNotFalse(expr) => Ok(Expr::IsNotFalse(Box::new(self.parse_sql_expr(expr)?))),
687            SQLExpr::IsTrue(expr) => Ok(Expr::IsTrue(Box::new(self.parse_sql_expr(expr)?))),
688            SQLExpr::IsNotTrue(expr) => Ok(Expr::IsNotTrue(Box::new(self.parse_sql_expr(expr)?))),
689            SQLExpr::IsNull(expr) => Ok(Expr::IsNull(Box::new(self.parse_sql_expr(expr)?))),
690            SQLExpr::IsNotNull(expr) => Ok(Expr::IsNotNull(Box::new(self.parse_sql_expr(expr)?))),
691            SQLExpr::InList {
692                expr,
693                list,
694                negated,
695            } => {
696                let value_expr = self.parse_sql_expr(expr)?;
697                let list_exprs = list
698                    .iter()
699                    .map(|e| self.parse_sql_expr(e))
700                    .collect::<Result<Vec<_>>>()?;
701                Ok(value_expr.in_list(list_exprs, *negated))
702            }
703            SQLExpr::Nested(inner) => self.parse_sql_expr(inner.as_ref()),
704            SQLExpr::Function(_) => self.parse_function(expr.clone()),
705            SQLExpr::ILike {
706                negated,
707                expr,
708                pattern,
709                escape_char,
710                any: _,
711            } => Ok(Expr::Like(Like::new(
712                *negated,
713                Box::new(self.parse_sql_expr(expr)?),
714                Box::new(self.parse_sql_expr(pattern)?),
715                match escape_char {
716                    Some(Value::SingleQuotedString(char)) => char.chars().next(),
717                    Some(value) => return Err(Error::invalid_input(
718                        format!("Invalid escape character in LIKE expression. Expected a single character wrapped with single quotes, got {}", value),
719                        location!()
720                    )),
721                    None => None,
722                },
723                true,
724            ))),
725            SQLExpr::Like {
726                negated,
727                expr,
728                pattern,
729                escape_char,
730                any: _,
731            } => Ok(Expr::Like(Like::new(
732                *negated,
733                Box::new(self.parse_sql_expr(expr)?),
734                Box::new(self.parse_sql_expr(pattern)?),
735                match escape_char {
736                    Some(Value::SingleQuotedString(char)) => char.chars().next(),
737                    Some(value) => return Err(Error::invalid_input(
738                        format!("Invalid escape character in LIKE expression. Expected a single character wrapped with single quotes, got {}", value),
739                        location!()
740                    )),
741                    None => None,
742                },
743                false,
744            ))),
745            SQLExpr::Cast {
746                expr,
747                data_type,
748                kind,
749                ..
750            } => match kind {
751                datafusion::sql::sqlparser::ast::CastKind::TryCast
752                | datafusion::sql::sqlparser::ast::CastKind::SafeCast => {
753                    Ok(Expr::TryCast(datafusion::logical_expr::TryCast {
754                        expr: Box::new(self.parse_sql_expr(expr)?),
755                        data_type: self.parse_type(data_type)?,
756                    }))
757                }
758                _ => Ok(Expr::Cast(datafusion::logical_expr::Cast {
759                    expr: Box::new(self.parse_sql_expr(expr)?),
760                    data_type: self.parse_type(data_type)?,
761                })),
762            },
763            SQLExpr::JsonAccess { .. } => Err(Error::invalid_input(
764                "JSON access is not supported",
765                location!(),
766            )),
767            SQLExpr::CompoundFieldAccess { root, access_chain } => {
768                let mut expr = self.parse_sql_expr(root)?;
769
770                for access in access_chain {
771                    let field_access = match access {
772                        // x.y or x['y']
773                        AccessExpr::Dot(SQLExpr::Identifier(Ident { value: s, .. }))
774                        | AccessExpr::Subscript(Subscript::Index {
775                            index:
776                                SQLExpr::Value(ValueWithSpan {
777                                    value:
778                                        Value::SingleQuotedString(s) | Value::DoubleQuotedString(s),
779                                    ..
780                                }),
781                        }) => GetFieldAccess::NamedStructField {
782                            name: ScalarValue::from(s.as_str()),
783                        },
784                        AccessExpr::Subscript(Subscript::Index { index }) => {
785                            let key = Box::new(self.parse_sql_expr(index)?);
786                            GetFieldAccess::ListIndex { key }
787                        }
788                        AccessExpr::Subscript(Subscript::Slice { .. }) => {
789                            return Err(Error::invalid_input(
790                                "Slice subscript is not supported",
791                                location!(),
792                            ));
793                        }
794                        _ => {
795                            // Handle other cases like JSON access
796                            // Note: JSON access is not supported in lance
797                            return Err(Error::invalid_input(
798                                "Only dot notation or index access is supported for field access",
799                                location!(),
800                            ));
801                        }
802                    };
803
804                    let field_access_expr = RawFieldAccessExpr { expr, field_access };
805                    expr = self.plan_field_access(field_access_expr)?;
806                }
807
808                Ok(expr)
809            }
810            SQLExpr::Between {
811                expr,
812                negated,
813                low,
814                high,
815            } => {
816                // Parse the main expression and bounds
817                let expr = self.parse_sql_expr(expr)?;
818                let low = self.parse_sql_expr(low)?;
819                let high = self.parse_sql_expr(high)?;
820
821                let between = Expr::Between(Between::new(
822                    Box::new(expr),
823                    *negated,
824                    Box::new(low),
825                    Box::new(high),
826                ));
827                Ok(between)
828            }
829            _ => Err(Error::invalid_input(
830                format!("Expression '{expr}' is not supported SQL in lance"),
831                location!(),
832            )),
833        }
834    }
835
836    /// Create Logical [Expr] from a SQL filter clause.
837    ///
838    /// Note: the returned expression must be passed through [optimize_expr()]
839    /// before being passed to [create_physical_expr()].
840    pub fn parse_filter(&self, filter: &str) -> Result<Expr> {
841        // Allow sqlparser to parse filter as part of ONE SQL statement.
842        let ast_expr = parse_sql_filter(filter)?;
843        let expr = self.parse_sql_expr(&ast_expr)?;
844        let schema = Schema::try_from(self.schema.as_ref())?;
845        let resolved = resolve_expr(&expr, &schema).map_err(|e| {
846            Error::invalid_input(
847                format!("Error resolving filter expression {filter}: {e}"),
848                location!(),
849            )
850        })?;
851
852        Ok(coerce_filter_type_to_boolean(resolved))
853    }
854
855    /// Create Logical [Expr] from a SQL expression.
856    ///
857    /// Note: the returned expression must be passed through [optimize_filter()]
858    /// before being passed to [create_physical_expr()].
859    pub fn parse_expr(&self, expr: &str) -> Result<Expr> {
860        if self.schema.field_with_name(expr).is_ok() {
861            return Ok(col(expr));
862        }
863
864        let ast_expr = parse_sql_expr(expr)?;
865        let expr = self.parse_sql_expr(&ast_expr)?;
866        let schema = Schema::try_from(self.schema.as_ref())?;
867        let resolved = resolve_expr(&expr, &schema)?;
868        Ok(resolved)
869    }
870
871    /// Try to decode bytes from hex literal string.
872    ///
873    /// Copied from datafusion because this is not public.
874    ///
875    /// TODO: use SqlToRel from Datafusion directly?
876    fn try_decode_hex_literal(s: &str) -> Option<Vec<u8>> {
877        let hex_bytes = s.as_bytes();
878        let mut decoded_bytes = Vec::with_capacity(hex_bytes.len().div_ceil(2));
879
880        let start_idx = hex_bytes.len() % 2;
881        if start_idx > 0 {
882            // The first byte is formed of only one char.
883            decoded_bytes.push(Self::try_decode_hex_char(hex_bytes[0])?);
884        }
885
886        for i in (start_idx..hex_bytes.len()).step_by(2) {
887            let high = Self::try_decode_hex_char(hex_bytes[i])?;
888            let low = Self::try_decode_hex_char(hex_bytes[i + 1])?;
889            decoded_bytes.push((high << 4) | low);
890        }
891
892        Some(decoded_bytes)
893    }
894
895    /// Try to decode a byte from a hex char.
896    ///
897    /// None will be returned if the input char is hex-invalid.
898    const fn try_decode_hex_char(c: u8) -> Option<u8> {
899        match c {
900            b'A'..=b'F' => Some(c - b'A' + 10),
901            b'a'..=b'f' => Some(c - b'a' + 10),
902            b'0'..=b'9' => Some(c - b'0'),
903            _ => None,
904        }
905    }
906
907    /// Optimize the filter expression and coerce data types.
908    pub fn optimize_expr(&self, expr: Expr) -> Result<Expr> {
909        let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
910
911        // DataFusion needs the simplify and coerce passes to be applied before
912        // expressions can be handled by the physical planner.
913        let props = ExecutionProps::new().with_query_execution_start_time(Utc::now());
914        let simplify_context = SimplifyContext::new(&props).with_schema(df_schema.clone());
915        let simplifier =
916            datafusion::optimizer::simplify_expressions::ExprSimplifier::new(simplify_context);
917
918        let expr = simplifier.simplify(expr)?;
919        let expr = simplifier.coerce(expr, &df_schema)?;
920
921        Ok(expr)
922    }
923
924    /// Create the [`PhysicalExpr`] from a logical [`Expr`]
925    pub fn create_physical_expr(&self, expr: &Expr) -> Result<Arc<dyn PhysicalExpr>> {
926        let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
927        Ok(datafusion::physical_expr::create_physical_expr(
928            expr,
929            df_schema.as_ref(),
930            &Default::default(),
931        )?)
932    }
933
934    /// Collect the columns in the expression.
935    ///
936    /// The columns are returned in sorted order.
937    ///
938    /// If the expr refers to nested columns these will be returned
939    /// as dotted paths (x.y.z)
940    pub fn column_names_in_expr(expr: &Expr) -> Vec<String> {
941        let mut visitor = ColumnCapturingVisitor {
942            current_path: VecDeque::new(),
943            columns: BTreeSet::new(),
944        };
945        expr.visit(&mut visitor).unwrap();
946        visitor.columns.into_iter().collect()
947    }
948}
949
/// Expression visitor that records the column paths referenced by an expression.
struct ColumnCapturingVisitor {
    /// Distinct column paths found so far, kept in sorted order.
    columns: BTreeSet<String>,
    /// Nested field names collected for the column path currently being walked.
    /// If this is empty, we are not in a column expression.
    current_path: VecDeque<String>,
}
955
956impl TreeNodeVisitor<'_> for ColumnCapturingVisitor {
957    type Node = Expr;
958
959    fn f_down(&mut self, node: &Self::Node) -> DFResult<TreeNodeRecursion> {
960        match node {
961            Expr::Column(Column { name, .. }) => {
962                // Build the field path from the column name and any nested fields
963                // The nested field names from get_field already come as literal strings,
964                // so we just need to concatenate them properly
965                let mut path = name.clone();
966                for part in self.current_path.drain(..) {
967                    path.push('.');
968                    // Check if the part needs quoting (contains dots)
969                    if part.contains('.') || part.contains('`') {
970                        // Quote the field name with backticks and escape any existing backticks
971                        let escaped = part.replace('`', "``");
972                        path.push('`');
973                        path.push_str(&escaped);
974                        path.push('`');
975                    } else {
976                        path.push_str(&part);
977                    }
978                }
979                self.columns.insert(path);
980                self.current_path.clear();
981            }
982            Expr::ScalarFunction(udf) => {
983                if udf.name() == GetFieldFunc::default().name() {
984                    if let Some(name) = get_as_string_scalar_opt(&udf.args[1]) {
985                        self.current_path.push_front(name.to_string())
986                    } else {
987                        self.current_path.clear();
988                    }
989                } else {
990                    self.current_path.clear();
991                }
992            }
993            _ => {
994                self.current_path.clear();
995            }
996        }
997
998        Ok(TreeNodeRecursion::Continue)
999    }
1000}
1001
1002#[cfg(test)]
1003mod tests {
1004
1005    use crate::logical_expr::ExprExt;
1006
1007    use super::*;
1008
1009    use arrow::datatypes::Float64Type;
1010    use arrow_array::{
1011        ArrayRef, BooleanArray, Float32Array, Int32Array, Int64Array, RecordBatch, StringArray,
1012        StructArray, TimestampMicrosecondArray, TimestampMillisecondArray,
1013        TimestampNanosecondArray, TimestampSecondArray,
1014    };
1015    use arrow_schema::{DataType, Fields, Schema};
1016    use datafusion::{
1017        logical_expr::{lit, Cast},
1018        prelude::{array_element, get_field},
1019    };
1020    use datafusion_functions::core::expr_ext::FieldAccessor;
1021
1022    #[test]
1023    fn test_parse_filter_simple() {
1024        let schema = Arc::new(Schema::new(vec![
1025            Field::new("i", DataType::Int32, false),
1026            Field::new("s", DataType::Utf8, true),
1027            Field::new(
1028                "st",
1029                DataType::Struct(Fields::from(vec![
1030                    Field::new("x", DataType::Float32, false),
1031                    Field::new("y", DataType::Float32, false),
1032                ])),
1033                true,
1034            ),
1035        ]));
1036
1037        let planner = Planner::new(schema.clone());
1038
1039        let expected = col("i")
1040            .gt(lit(3_i32))
1041            .and(col("st").field_newstyle("x").lt_eq(lit(5.0_f32)))
1042            .and(
1043                col("s")
1044                    .eq(lit("str-4"))
1045                    .or(col("s").in_list(vec![lit("str-4"), lit("str-5")], false)),
1046            );
1047
1048        // double quotes
1049        let expr = planner
1050            .parse_filter("i > 3 AND st.x <= 5.0 AND (s == 'str-4' OR s in ('str-4', 'str-5'))")
1051            .unwrap();
1052        assert_eq!(expr, expected);
1053
1054        // single quote
1055        let expr = planner
1056            .parse_filter("i > 3 AND st.x <= 5.0 AND (s = 'str-4' OR s in ('str-4', 'str-5'))")
1057            .unwrap();
1058
1059        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1060
1061        let batch = RecordBatch::try_new(
1062            schema,
1063            vec![
1064                Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
1065                Arc::new(StringArray::from_iter_values(
1066                    (0..10).map(|v| format!("str-{}", v)),
1067                )),
1068                Arc::new(StructArray::from(vec![
1069                    (
1070                        Arc::new(Field::new("x", DataType::Float32, false)),
1071                        Arc::new(Float32Array::from_iter_values((0..10).map(|v| v as f32)))
1072                            as ArrayRef,
1073                    ),
1074                    (
1075                        Arc::new(Field::new("y", DataType::Float32, false)),
1076                        Arc::new(Float32Array::from_iter_values(
1077                            (0..10).map(|v| (v * 10) as f32),
1078                        )),
1079                    ),
1080                ])),
1081            ],
1082        )
1083        .unwrap();
1084        let predicates = physical_expr.evaluate(&batch).unwrap();
1085        assert_eq!(
1086            predicates.into_array(0).unwrap().as_ref(),
1087            &BooleanArray::from(vec![
1088                false, false, false, false, true, true, false, false, false, false
1089            ])
1090        );
1091    }
1092
1093    #[test]
1094    fn test_nested_col_refs() {
1095        let schema = Arc::new(Schema::new(vec![
1096            Field::new("s0", DataType::Utf8, true),
1097            Field::new(
1098                "st",
1099                DataType::Struct(Fields::from(vec![
1100                    Field::new("s1", DataType::Utf8, true),
1101                    Field::new(
1102                        "st",
1103                        DataType::Struct(Fields::from(vec![Field::new(
1104                            "s2",
1105                            DataType::Utf8,
1106                            true,
1107                        )])),
1108                        true,
1109                    ),
1110                ])),
1111                true,
1112            ),
1113        ]));
1114
1115        let planner = Planner::new(schema);
1116
1117        fn assert_column_eq(planner: &Planner, expr: &str, expected: &Expr) {
1118            let expr = planner.parse_filter(&format!("{expr} = 'val'")).unwrap();
1119            assert!(matches!(
1120                expr,
1121                Expr::BinaryExpr(BinaryExpr {
1122                    left: _,
1123                    op: Operator::Eq,
1124                    right: _
1125                })
1126            ));
1127            if let Expr::BinaryExpr(BinaryExpr { left, .. }) = expr {
1128                assert_eq!(left.as_ref(), expected);
1129            }
1130        }
1131
1132        let expected = Expr::Column(Column::new_unqualified("s0"));
1133        assert_column_eq(&planner, "s0", &expected);
1134        assert_column_eq(&planner, "`s0`", &expected);
1135
1136        let expected = Expr::ScalarFunction(ScalarFunction {
1137            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
1138            args: vec![
1139                Expr::Column(Column::new_unqualified("st")),
1140                Expr::Literal(ScalarValue::Utf8(Some("s1".to_string())), None),
1141            ],
1142        });
1143        assert_column_eq(&planner, "st.s1", &expected);
1144        assert_column_eq(&planner, "`st`.`s1`", &expected);
1145        assert_column_eq(&planner, "st.`s1`", &expected);
1146
1147        let expected = Expr::ScalarFunction(ScalarFunction {
1148            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
1149            args: vec![
1150                Expr::ScalarFunction(ScalarFunction {
1151                    func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
1152                    args: vec![
1153                        Expr::Column(Column::new_unqualified("st")),
1154                        Expr::Literal(ScalarValue::Utf8(Some("st".to_string())), None),
1155                    ],
1156                }),
1157                Expr::Literal(ScalarValue::Utf8(Some("s2".to_string())), None),
1158            ],
1159        });
1160
1161        assert_column_eq(&planner, "st.st.s2", &expected);
1162        assert_column_eq(&planner, "`st`.`st`.`s2`", &expected);
1163        assert_column_eq(&planner, "st.st.`s2`", &expected);
1164        assert_column_eq(&planner, "st['st'][\"s2\"]", &expected);
1165    }
1166
1167    #[test]
1168    fn test_nested_list_refs() {
1169        let schema = Arc::new(Schema::new(vec![Field::new(
1170            "l",
1171            DataType::List(Arc::new(Field::new(
1172                "item",
1173                DataType::Struct(Fields::from(vec![Field::new("f1", DataType::Utf8, true)])),
1174                true,
1175            ))),
1176            true,
1177        )]));
1178
1179        let planner = Planner::new(schema);
1180
1181        let expected = array_element(col("l"), lit(0_i64));
1182        let expr = planner.parse_expr("l[0]").unwrap();
1183        assert_eq!(expr, expected);
1184
1185        let expected = get_field(array_element(col("l"), lit(0_i64)), "f1");
1186        let expr = planner.parse_expr("l[0]['f1']").unwrap();
1187        assert_eq!(expr, expected);
1188
1189        // FIXME: This should work, but sqlparser doesn't recognize anything
1190        // after the period for some reason.
1191        // let expr = planner.parse_expr("l[0].f1").unwrap();
1192        // assert_eq!(expr, expected);
1193    }
1194
1195    #[test]
1196    fn test_negative_expressions() {
1197        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
1198
1199        let planner = Planner::new(schema.clone());
1200
1201        let expected = col("x")
1202            .gt(lit(-3_i64))
1203            .and(col("x").lt(-(lit(-5_i64) + lit(3_i64))));
1204
1205        let expr = planner.parse_filter("x > -3 AND x < -(-5 + 3)").unwrap();
1206
1207        assert_eq!(expr, expected);
1208
1209        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1210
1211        let batch = RecordBatch::try_new(
1212            schema,
1213            vec![Arc::new(Int64Array::from_iter_values(-5..5)) as ArrayRef],
1214        )
1215        .unwrap();
1216        let predicates = physical_expr.evaluate(&batch).unwrap();
1217        assert_eq!(
1218            predicates.into_array(0).unwrap().as_ref(),
1219            &BooleanArray::from(vec![
1220                false, false, false, true, true, true, true, false, false, false
1221            ])
1222        );
1223    }
1224
1225    #[test]
1226    fn test_negative_array_expressions() {
1227        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
1228
1229        let planner = Planner::new(schema);
1230
1231        let expected = Expr::Literal(
1232            ScalarValue::List(Arc::new(
1233                ListArray::from_iter_primitive::<Float64Type, _, _>(vec![Some(
1234                    [-1_f64, -2.0, -3.0, -4.0, -5.0].map(Some),
1235                )]),
1236            )),
1237            None,
1238        );
1239
1240        let expr = planner
1241            .parse_expr("[-1.0, -2.0, -3.0, -4.0, -5.0]")
1242            .unwrap();
1243
1244        assert_eq!(expr, expected);
1245    }
1246
1247    #[test]
1248    fn test_sql_like() {
1249        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1250
1251        let planner = Planner::new(schema.clone());
1252
1253        let expected = col("s").like(lit("str-4"));
1254        // single quote
1255        let expr = planner.parse_filter("s LIKE 'str-4'").unwrap();
1256        assert_eq!(expr, expected);
1257        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1258
1259        let batch = RecordBatch::try_new(
1260            schema,
1261            vec![Arc::new(StringArray::from_iter_values(
1262                (0..10).map(|v| format!("str-{}", v)),
1263            ))],
1264        )
1265        .unwrap();
1266        let predicates = physical_expr.evaluate(&batch).unwrap();
1267        assert_eq!(
1268            predicates.into_array(0).unwrap().as_ref(),
1269            &BooleanArray::from(vec![
1270                false, false, false, false, true, false, false, false, false, false
1271            ])
1272        );
1273    }
1274
1275    #[test]
1276    fn test_not_like() {
1277        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1278
1279        let planner = Planner::new(schema.clone());
1280
1281        let expected = col("s").not_like(lit("str-4"));
1282        // single quote
1283        let expr = planner.parse_filter("s NOT LIKE 'str-4'").unwrap();
1284        assert_eq!(expr, expected);
1285        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1286
1287        let batch = RecordBatch::try_new(
1288            schema,
1289            vec![Arc::new(StringArray::from_iter_values(
1290                (0..10).map(|v| format!("str-{}", v)),
1291            ))],
1292        )
1293        .unwrap();
1294        let predicates = physical_expr.evaluate(&batch).unwrap();
1295        assert_eq!(
1296            predicates.into_array(0).unwrap().as_ref(),
1297            &BooleanArray::from(vec![
1298                true, true, true, true, false, true, true, true, true, true
1299            ])
1300        );
1301    }
1302
1303    #[test]
1304    fn test_sql_is_in() {
1305        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1306
1307        let planner = Planner::new(schema.clone());
1308
1309        let expected = col("s").in_list(vec![lit("str-4"), lit("str-5")], false);
1310        // single quote
1311        let expr = planner.parse_filter("s IN ('str-4', 'str-5')").unwrap();
1312        assert_eq!(expr, expected);
1313        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1314
1315        let batch = RecordBatch::try_new(
1316            schema,
1317            vec![Arc::new(StringArray::from_iter_values(
1318                (0..10).map(|v| format!("str-{}", v)),
1319            ))],
1320        )
1321        .unwrap();
1322        let predicates = physical_expr.evaluate(&batch).unwrap();
1323        assert_eq!(
1324            predicates.into_array(0).unwrap().as_ref(),
1325            &BooleanArray::from(vec![
1326                false, false, false, false, true, true, false, false, false, false
1327            ])
1328        );
1329    }
1330
1331    #[test]
1332    fn test_sql_is_null() {
1333        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1334
1335        let planner = Planner::new(schema.clone());
1336
1337        let expected = col("s").is_null();
1338        let expr = planner.parse_filter("s IS NULL").unwrap();
1339        assert_eq!(expr, expected);
1340        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1341
1342        let batch = RecordBatch::try_new(
1343            schema,
1344            vec![Arc::new(StringArray::from_iter((0..10).map(|v| {
1345                if v % 3 == 0 {
1346                    Some(format!("str-{}", v))
1347                } else {
1348                    None
1349                }
1350            })))],
1351        )
1352        .unwrap();
1353        let predicates = physical_expr.evaluate(&batch).unwrap();
1354        assert_eq!(
1355            predicates.into_array(0).unwrap().as_ref(),
1356            &BooleanArray::from(vec![
1357                false, true, true, false, true, true, false, true, true, false
1358            ])
1359        );
1360
1361        let expr = planner.parse_filter("s IS NOT NULL").unwrap();
1362        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1363        let predicates = physical_expr.evaluate(&batch).unwrap();
1364        assert_eq!(
1365            predicates.into_array(0).unwrap().as_ref(),
1366            &BooleanArray::from(vec![
1367                true, false, false, true, false, false, true, false, false, true,
1368            ])
1369        );
1370    }
1371
1372    #[test]
1373    fn test_sql_invert() {
1374        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Boolean, true)]));
1375
1376        let planner = Planner::new(schema.clone());
1377
1378        let expr = planner.parse_filter("NOT s").unwrap();
1379        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1380
1381        let batch = RecordBatch::try_new(
1382            schema,
1383            vec![Arc::new(BooleanArray::from_iter(
1384                (0..10).map(|v| Some(v % 3 == 0)),
1385            ))],
1386        )
1387        .unwrap();
1388        let predicates = physical_expr.evaluate(&batch).unwrap();
1389        assert_eq!(
1390            predicates.into_array(0).unwrap().as_ref(),
1391            &BooleanArray::from(vec![
1392                false, true, true, false, true, true, false, true, true, false
1393            ])
1394        );
1395    }
1396
1397    #[test]
1398    fn test_sql_cast() {
1399        let cases = &[
1400            (
1401                "x = cast('2021-01-01 00:00:00' as timestamp)",
1402                ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
1403            ),
1404            (
1405                "x = cast('2021-01-01 00:00:00' as timestamp(0))",
1406                ArrowDataType::Timestamp(TimeUnit::Second, None),
1407            ),
1408            (
1409                "x = cast('2021-01-01 00:00:00.123' as timestamp(9))",
1410                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
1411            ),
1412            (
1413                "x = cast('2021-01-01 00:00:00.123' as datetime(9))",
1414                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
1415            ),
1416            ("x = cast('2021-01-01' as date)", ArrowDataType::Date32),
1417            (
1418                "x = cast('1.238' as decimal(9,3))",
1419                ArrowDataType::Decimal128(9, 3),
1420            ),
1421            ("x = cast(1 as float)", ArrowDataType::Float32),
1422            ("x = cast(1 as double)", ArrowDataType::Float64),
1423            ("x = cast(1 as tinyint)", ArrowDataType::Int8),
1424            ("x = cast(1 as smallint)", ArrowDataType::Int16),
1425            ("x = cast(1 as int)", ArrowDataType::Int32),
1426            ("x = cast(1 as integer)", ArrowDataType::Int32),
1427            ("x = cast(1 as bigint)", ArrowDataType::Int64),
1428            ("x = cast(1 as tinyint unsigned)", ArrowDataType::UInt8),
1429            ("x = cast(1 as smallint unsigned)", ArrowDataType::UInt16),
1430            ("x = cast(1 as int unsigned)", ArrowDataType::UInt32),
1431            ("x = cast(1 as integer unsigned)", ArrowDataType::UInt32),
1432            ("x = cast(1 as bigint unsigned)", ArrowDataType::UInt64),
1433            ("x = cast(1 as boolean)", ArrowDataType::Boolean),
1434            ("x = cast(1 as string)", ArrowDataType::Utf8),
1435        ];
1436
1437        for (sql, expected_data_type) in cases {
1438            let schema = Arc::new(Schema::new(vec![Field::new(
1439                "x",
1440                expected_data_type.clone(),
1441                true,
1442            )]));
1443            let planner = Planner::new(schema.clone());
1444            let expr = planner.parse_filter(sql).unwrap();
1445
1446            // Get the thing after 'cast(` but before ' as'.
1447            let expected_value_str = sql
1448                .split("cast(")
1449                .nth(1)
1450                .unwrap()
1451                .split(" as")
1452                .next()
1453                .unwrap();
1454            // Remove any quote marks
1455            let expected_value_str = expected_value_str.trim_matches('\'');
1456
1457            match expr {
1458                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
1459                    Expr::Cast(Cast { expr, data_type }) => {
1460                        match expr.as_ref() {
1461                            Expr::Literal(ScalarValue::Utf8(Some(value_str)), _) => {
1462                                assert_eq!(value_str, expected_value_str);
1463                            }
1464                            Expr::Literal(ScalarValue::Int64(Some(value)), _) => {
1465                                assert_eq!(*value, 1);
1466                            }
1467                            _ => panic!("Expected cast to be applied to literal"),
1468                        }
1469                        assert_eq!(data_type, expected_data_type);
1470                    }
1471                    _ => panic!("Expected right to be a cast"),
1472                },
1473                _ => panic!("Expected binary expression"),
1474            }
1475        }
1476    }
1477
1478    #[test]
1479    fn test_sql_literals() {
1480        let cases = &[
1481            (
1482                "x = timestamp '2021-01-01 00:00:00'",
1483                ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
1484            ),
1485            (
1486                "x = timestamp(0) '2021-01-01 00:00:00'",
1487                ArrowDataType::Timestamp(TimeUnit::Second, None),
1488            ),
1489            (
1490                "x = timestamp(9) '2021-01-01 00:00:00.123'",
1491                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
1492            ),
1493            ("x = date '2021-01-01'", ArrowDataType::Date32),
1494            ("x = decimal(9,3) '1.238'", ArrowDataType::Decimal128(9, 3)),
1495        ];
1496
1497        for (sql, expected_data_type) in cases {
1498            let schema = Arc::new(Schema::new(vec![Field::new(
1499                "x",
1500                expected_data_type.clone(),
1501                true,
1502            )]));
1503            let planner = Planner::new(schema.clone());
1504            let expr = planner.parse_filter(sql).unwrap();
1505
1506            let expected_value_str = sql.split('\'').nth(1).unwrap();
1507
1508            match expr {
1509                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
1510                    Expr::Cast(Cast { expr, data_type }) => {
1511                        match expr.as_ref() {
1512                            Expr::Literal(ScalarValue::Utf8(Some(value_str)), _) => {
1513                                assert_eq!(value_str, expected_value_str);
1514                            }
1515                            _ => panic!("Expected cast to be applied to literal"),
1516                        }
1517                        assert_eq!(data_type, expected_data_type);
1518                    }
1519                    _ => panic!("Expected right to be a cast"),
1520                },
1521                _ => panic!("Expected binary expression"),
1522            }
1523        }
1524    }
1525
1526    #[test]
1527    fn test_sql_array_literals() {
1528        let cases = [
1529            (
1530                "x = [1, 2, 3]",
1531                ArrowDataType::List(Arc::new(Field::new("item", ArrowDataType::Int64, true))),
1532            ),
1533            (
1534                "x = [1, 2, 3]",
1535                ArrowDataType::FixedSizeList(
1536                    Arc::new(Field::new("item", ArrowDataType::Int64, true)),
1537                    3,
1538                ),
1539            ),
1540        ];
1541
1542        for (sql, expected_data_type) in cases {
1543            let schema = Arc::new(Schema::new(vec![Field::new(
1544                "x",
1545                expected_data_type.clone(),
1546                true,
1547            )]));
1548            let planner = Planner::new(schema.clone());
1549            let expr = planner.parse_filter(sql).unwrap();
1550            let expr = planner.optimize_expr(expr).unwrap();
1551
1552            match expr {
1553                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
1554                    Expr::Literal(value, _) => {
1555                        assert_eq!(&value.data_type(), &expected_data_type);
1556                    }
1557                    _ => panic!("Expected right to be a literal"),
1558                },
1559                _ => panic!("Expected binary expression"),
1560            }
1561        }
1562    }
1563
1564    #[test]
1565    fn test_sql_between() {
1566        use arrow_array::{Float64Array, Int32Array, TimestampMicrosecondArray};
1567        use arrow_schema::{DataType, Field, Schema, TimeUnit};
1568        use std::sync::Arc;
1569
1570        let schema = Arc::new(Schema::new(vec![
1571            Field::new("x", DataType::Int32, false),
1572            Field::new("y", DataType::Float64, false),
1573            Field::new(
1574                "ts",
1575                DataType::Timestamp(TimeUnit::Microsecond, None),
1576                false,
1577            ),
1578        ]));
1579
1580        let planner = Planner::new(schema.clone());
1581
1582        // Test integer BETWEEN
1583        let expr = planner
1584            .parse_filter("x BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
1585            .unwrap();
1586        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1587
1588        // Create timestamp array with values representing:
1589        // 2024-01-01 00:00:00 to 2024-01-01 00:00:09 (in microseconds)
1590        let base_ts = 1704067200000000_i64; // 2024-01-01 00:00:00
1591        let ts_array = TimestampMicrosecondArray::from_iter_values(
1592            (0..10).map(|i| base_ts + i * 1_000_000), // Each value is 1 second apart
1593        );
1594
1595        let batch = RecordBatch::try_new(
1596            schema,
1597            vec![
1598                Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
1599                Arc::new(Float64Array::from_iter_values((0..10).map(|v| v as f64))),
1600                Arc::new(ts_array),
1601            ],
1602        )
1603        .unwrap();
1604
1605        let predicates = physical_expr.evaluate(&batch).unwrap();
1606        assert_eq!(
1607            predicates.into_array(0).unwrap().as_ref(),
1608            &BooleanArray::from(vec![
1609                false, false, false, true, true, true, true, true, false, false
1610            ])
1611        );
1612
1613        // Test NOT BETWEEN
1614        let expr = planner
1615            .parse_filter("x NOT BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
1616            .unwrap();
1617        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1618
1619        let predicates = physical_expr.evaluate(&batch).unwrap();
1620        assert_eq!(
1621            predicates.into_array(0).unwrap().as_ref(),
1622            &BooleanArray::from(vec![
1623                true, true, true, false, false, false, false, false, true, true
1624            ])
1625        );
1626
1627        // Test floating point BETWEEN
1628        let expr = planner.parse_filter("y BETWEEN 2.5 AND 6.5").unwrap();
1629        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1630
1631        let predicates = physical_expr.evaluate(&batch).unwrap();
1632        assert_eq!(
1633            predicates.into_array(0).unwrap().as_ref(),
1634            &BooleanArray::from(vec![
1635                false, false, false, true, true, true, true, false, false, false
1636            ])
1637        );
1638
1639        // Test timestamp BETWEEN
1640        let expr = planner
1641            .parse_filter(
1642                "ts BETWEEN timestamp '2024-01-01 00:00:03' AND timestamp '2024-01-01 00:00:07'",
1643            )
1644            .unwrap();
1645        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1646
1647        let predicates = physical_expr.evaluate(&batch).unwrap();
1648        assert_eq!(
1649            predicates.into_array(0).unwrap().as_ref(),
1650            &BooleanArray::from(vec![
1651                false, false, false, true, true, true, true, true, false, false
1652            ])
1653        );
1654    }
1655
1656    #[test]
1657    fn test_sql_comparison() {
1658        // Create a batch with all data types
1659        let batch: Vec<(&str, ArrayRef)> = vec![
1660            (
1661                "timestamp_s",
1662                Arc::new(TimestampSecondArray::from_iter_values(0..10)),
1663            ),
1664            (
1665                "timestamp_ms",
1666                Arc::new(TimestampMillisecondArray::from_iter_values(0..10)),
1667            ),
1668            (
1669                "timestamp_us",
1670                Arc::new(TimestampMicrosecondArray::from_iter_values(0..10)),
1671            ),
1672            (
1673                "timestamp_ns",
1674                Arc::new(TimestampNanosecondArray::from_iter_values(4995..5005)),
1675            ),
1676        ];
1677        let batch = RecordBatch::try_from_iter(batch).unwrap();
1678
1679        let planner = Planner::new(batch.schema());
1680
1681        // Each expression is meant to select the final 5 rows
1682        let expressions = &[
1683            "timestamp_s >= TIMESTAMP '1970-01-01 00:00:05'",
1684            "timestamp_ms >= TIMESTAMP '1970-01-01 00:00:00.005'",
1685            "timestamp_us >= TIMESTAMP '1970-01-01 00:00:00.000005'",
1686            "timestamp_ns >= TIMESTAMP '1970-01-01 00:00:00.000005'",
1687        ];
1688
1689        let expected: ArrayRef = Arc::new(BooleanArray::from_iter(
1690            std::iter::repeat_n(Some(false), 5).chain(std::iter::repeat_n(Some(true), 5)),
1691        ));
1692        for expression in expressions {
1693            // convert to physical expression
1694            let logical_expr = planner.parse_filter(expression).unwrap();
1695            let logical_expr = planner.optimize_expr(logical_expr).unwrap();
1696            let physical_expr = planner.create_physical_expr(&logical_expr).unwrap();
1697
1698            // Evaluate and assert they have correct results
1699            let result = physical_expr.evaluate(&batch).unwrap();
1700            let result = result.into_array(batch.num_rows()).unwrap();
1701            assert_eq!(&expected, &result, "unexpected result for {}", expression);
1702        }
1703    }
1704
1705    #[test]
1706    fn test_columns_in_expr() {
1707        let expr = col("s0").gt(lit("value")).and(
1708            col("st")
1709                .field("st")
1710                .field("s2")
1711                .eq(lit("value"))
1712                .or(col("st")
1713                    .field("s1")
1714                    .in_list(vec![lit("value 1"), lit("value 2")], false)),
1715        );
1716
1717        let columns = Planner::column_names_in_expr(&expr);
1718        assert_eq!(columns, vec!["s0", "st.s1", "st.st.s2"]);
1719    }
1720
1721    #[test]
1722    fn test_parse_binary_expr() {
1723        let bin_str = "x'616263'";
1724
1725        let schema = Arc::new(Schema::new(vec![Field::new(
1726            "binary",
1727            DataType::Binary,
1728            true,
1729        )]));
1730        let planner = Planner::new(schema);
1731        let expr = planner.parse_expr(bin_str).unwrap();
1732        assert_eq!(
1733            expr,
1734            Expr::Literal(ScalarValue::Binary(Some(vec![b'a', b'b', b'c'])), None)
1735        );
1736    }
1737
1738    #[test]
1739    fn test_lance_context_provider_expr_planners() {
1740        let ctx_provider = LanceContextProvider::default();
1741        assert!(!ctx_provider.get_expr_planners().is_empty());
1742    }
1743
1744    #[test]
1745    fn test_regexp_match_and_non_empty_captions() {
1746        // Repro for a bug where regexp_match inside an AND chain wasn't coerced to boolean,
1747        // causing planning/evaluation failures. This should evaluate successfully.
1748        let schema = Arc::new(Schema::new(vec![
1749            Field::new("keywords", DataType::Utf8, true),
1750            Field::new("natural_caption", DataType::Utf8, true),
1751            Field::new("poetic_caption", DataType::Utf8, true),
1752        ]));
1753
1754        let planner = Planner::new(schema.clone());
1755
1756        let expr = planner
1757            .parse_filter(
1758                "regexp_match(keywords, 'Liberty|revolution') AND \
1759                 (natural_caption IS NOT NULL AND natural_caption <> '' AND \
1760                  poetic_caption IS NOT NULL AND poetic_caption <> '')",
1761            )
1762            .unwrap();
1763
1764        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1765
1766        let batch = RecordBatch::try_new(
1767            schema,
1768            vec![
1769                Arc::new(StringArray::from(vec![
1770                    Some("Liberty for all"),
1771                    Some("peace"),
1772                    Some("revolution now"),
1773                    Some("Liberty"),
1774                    Some("revolutionary"),
1775                    Some("none"),
1776                ])) as ArrayRef,
1777                Arc::new(StringArray::from(vec![
1778                    Some("a"),
1779                    Some("b"),
1780                    None,
1781                    Some(""),
1782                    Some("c"),
1783                    Some("d"),
1784                ])) as ArrayRef,
1785                Arc::new(StringArray::from(vec![
1786                    Some("x"),
1787                    Some(""),
1788                    Some("y"),
1789                    Some("z"),
1790                    None,
1791                    Some("w"),
1792                ])) as ArrayRef,
1793            ],
1794        )
1795        .unwrap();
1796
1797        let result = physical_expr.evaluate(&batch).unwrap();
1798        assert_eq!(
1799            result.into_array(0).unwrap().as_ref(),
1800            &BooleanArray::from(vec![true, false, false, false, false, false])
1801        );
1802    }
1803
1804    #[test]
1805    fn test_regexp_match_infer_error_without_boolean_coercion() {
1806        // With the fix applied, using parse_filter should coerce regexp_match to boolean
1807        // even when nested in a larger AND expression, so this should plan successfully.
1808        let schema = Arc::new(Schema::new(vec![
1809            Field::new("keywords", DataType::Utf8, true),
1810            Field::new("natural_caption", DataType::Utf8, true),
1811            Field::new("poetic_caption", DataType::Utf8, true),
1812        ]));
1813
1814        let planner = Planner::new(schema);
1815
1816        let expr = planner
1817            .parse_filter(
1818                "regexp_match(keywords, 'Liberty|revolution') AND \
1819                 (natural_caption IS NOT NULL AND natural_caption <> '' AND \
1820                  poetic_caption IS NOT NULL AND poetic_caption <> '')",
1821            )
1822            .unwrap();
1823
1824        // Should not panic
1825        let _physical = planner.create_physical_expr(&expr).unwrap();
1826    }
1827}