Skip to main content

lance_datafusion/
planner.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Exec plan planner
5
6use std::borrow::Cow;
7use std::collections::{BTreeSet, VecDeque};
8use std::sync::Arc;
9
10use crate::exec::{LanceExecutionOptions, get_session_context};
11use crate::expr::safe_coerce_scalar;
12use crate::logical_expr::{coerce_filter_type_to_boolean, get_as_string_scalar_opt, resolve_expr};
13use crate::sql::{parse_sql_expr, parse_sql_filter};
14use arrow::compute::CastOptions;
15use arrow_array::ListArray;
16use arrow_buffer::OffsetBuffer;
17use arrow_schema::{DataType as ArrowDataType, Field, SchemaRef, TimeUnit};
18use arrow_select::concat::concat;
19use datafusion::common::DFSchema;
20use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
21use datafusion::config::ConfigOptions;
22use datafusion::error::Result as DFResult;
23use datafusion::execution::context::SessionState;
24use datafusion::logical_expr::expr::ScalarFunction;
25use datafusion::logical_expr::planner::{ExprPlanner, PlannerResult, RawFieldAccessExpr};
26use datafusion::logical_expr::{
27    AggregateUDF, ColumnarValue, GetFieldAccess, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
28    Signature, Volatility, WindowUDF,
29};
30use datafusion::optimizer::simplify_expressions::SimplifyContext;
31use datafusion::sql::planner::{
32    ContextProvider, NullOrdering, ParserOptions, PlannerContext, SqlToRel,
33};
34use datafusion::sql::sqlparser::ast::{
35    AccessExpr, Array as SQLArray, BinaryOperator, DataType as SQLDataType, ExactNumberInfo,
36    Expr as SQLExpr, Function, FunctionArg, FunctionArgExpr, FunctionArguments, Ident,
37    ObjectNamePart, Subscript, TimezoneInfo, TypedString, UnaryOperator, Value, ValueWithSpan,
38};
39use datafusion::{
40    common::Column,
41    logical_expr::{Between, BinaryExpr, Like, Operator},
42    physical_expr::execution_props::ExecutionProps,
43    physical_plan::PhysicalExpr,
44    prelude::Expr,
45    scalar::ScalarValue,
46};
47use datafusion_functions::core::getfield::GetFieldFunc;
48use lance_arrow::cast::cast_with_options;
49use lance_core::datatypes::Schema;
50use lance_core::error::LanceOptionExt;
51
52use chrono::Utc;
53use lance_core::{Error, Result};
54
55/// Encode a JSON string into a JSONB `LargeBinary` literal expression.
56fn encode_jsonb(json_str: &str) -> Result<Expr> {
57    let bytes = lance_arrow::json::encode_json(json_str)
58        .map_err(|e| Error::invalid_input(format!("Failed to encode JSONB: {e}")))?;
59    Ok(Expr::Literal(ScalarValue::LargeBinary(Some(bytes)), None))
60}
61
/// Scalar UDF `_cast_list_f16`: casts a `List`/`FixedSizeList` of Float32
/// (or Float16) to the same list kind with Float16 elements.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
struct CastListF16Udf {
    // Accepts any single argument; element types are validated in `return_type`.
    signature: Signature,
}
66
67impl CastListF16Udf {
68    pub fn new() -> Self {
69        Self {
70            signature: Signature::any(1, Volatility::Immutable),
71        }
72    }
73}
74
75impl ScalarUDFImpl for CastListF16Udf {
76    fn as_any(&self) -> &dyn std::any::Any {
77        self
78    }
79
80    fn name(&self) -> &str {
81        "_cast_list_f16"
82    }
83
84    fn signature(&self) -> &Signature {
85        &self.signature
86    }
87
88    fn return_type(&self, arg_types: &[ArrowDataType]) -> DFResult<ArrowDataType> {
89        let input = &arg_types[0];
90        match input {
91            ArrowDataType::FixedSizeList(field, size) => {
92                if field.data_type() != &ArrowDataType::Float32
93                    && field.data_type() != &ArrowDataType::Float16
94                {
95                    return Err(datafusion::error::DataFusionError::Execution(
96                        "cast_list_f16 only supports list of float32 or float16".to_string(),
97                    ));
98                }
99                Ok(ArrowDataType::FixedSizeList(
100                    Arc::new(Field::new(
101                        field.name(),
102                        ArrowDataType::Float16,
103                        field.is_nullable(),
104                    )),
105                    *size,
106                ))
107            }
108            ArrowDataType::List(field) => {
109                if field.data_type() != &ArrowDataType::Float32
110                    && field.data_type() != &ArrowDataType::Float16
111                {
112                    return Err(datafusion::error::DataFusionError::Execution(
113                        "cast_list_f16 only supports list of float32 or float16".to_string(),
114                    ));
115                }
116                Ok(ArrowDataType::List(Arc::new(Field::new(
117                    field.name(),
118                    ArrowDataType::Float16,
119                    field.is_nullable(),
120                ))))
121            }
122            _ => Err(datafusion::error::DataFusionError::Execution(
123                "cast_list_f16 only supports FixedSizeList/List arguments".to_string(),
124            )),
125        }
126    }
127
128    fn invoke_with_args(&self, func_args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
129        let ColumnarValue::Array(arr) = &func_args.args[0] else {
130            return Err(datafusion::error::DataFusionError::Execution(
131                "cast_list_f16 only supports array arguments".to_string(),
132            ));
133        };
134
135        let to_type = match arr.data_type() {
136            ArrowDataType::FixedSizeList(field, size) => ArrowDataType::FixedSizeList(
137                Arc::new(Field::new(
138                    field.name(),
139                    ArrowDataType::Float16,
140                    field.is_nullable(),
141                )),
142                *size,
143            ),
144            ArrowDataType::List(field) => ArrowDataType::List(Arc::new(Field::new(
145                field.name(),
146                ArrowDataType::Float16,
147                field.is_nullable(),
148            ))),
149            _ => {
150                return Err(datafusion::error::DataFusionError::Execution(
151                    "cast_list_f16 only supports array arguments".to_string(),
152                ));
153            }
154        };
155
156        let res = cast_with_options(arr.as_ref(), &to_type, &CastOptions::default())?;
157        Ok(ColumnarValue::Array(res))
158    }
159}
160
// Adapter that instructs datafusion how lance expects expressions to be interpreted
struct LanceContextProvider {
    // Config handed back to datafusion via `ContextProvider::options`.
    options: datafusion::config::ConfigOptions,
    // Session state used to look up scalar/aggregate/window functions.
    state: SessionState,
    // Expression planners copied out of `state` at construction time.
    expr_planners: Vec<Arc<dyn ExprPlanner>>,
}
167
168impl Default for LanceContextProvider {
169    fn default() -> Self {
170        let ctx = get_session_context(&LanceExecutionOptions::default());
171        let state = ctx.state();
172        let expr_planners = state.expr_planners().to_vec();
173
174        Self {
175            options: ConfigOptions::default(),
176            state,
177            expr_planners,
178        }
179    }
180}
181
impl ContextProvider for LanceContextProvider {
    // Lance expressions never reference tables, so any table lookup is an error.
    fn get_table_source(
        &self,
        name: datafusion::sql::TableReference,
    ) -> DFResult<Arc<dyn datafusion::logical_expr::TableSource>> {
        Err(datafusion::error::DataFusionError::NotImplemented(format!(
            "Attempt to reference inner table {} not supported",
            name
        )))
    }

    // Aggregate functions come straight from the session state registry.
    fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
        self.state.aggregate_functions().get(name).cloned()
    }

    // Window functions come straight from the session state registry.
    fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
        self.state.window_functions().get(name).cloned()
    }

    // Scalar functions: `_cast_list_f16` is special-cased; everything else is
    // looked up in the session state registry.
    fn get_function_meta(&self, f: &str) -> Option<Arc<ScalarUDF>> {
        match f {
            // TODO: cast should go thru CAST syntax instead of UDF
            // Going thru UDF makes it hard for the optimizer to find no-ops
            "_cast_list_f16" => Some(Arc::new(ScalarUDF::new_from_impl(CastListF16Udf::new()))),
            _ => self.state.scalar_functions().get(f).cloned(),
        }
    }

    fn get_variable_type(&self, _: &[String]) -> Option<ArrowDataType> {
        // Variables (things like @@LANGUAGE) not supported
        None
    }

    fn options(&self) -> &datafusion::config::ConfigOptions {
        &self.options
    }

    // Note: does not include the special `_cast_list_f16` entry that
    // `get_function_meta` can return.
    fn udf_names(&self) -> Vec<String> {
        self.state.scalar_functions().keys().cloned().collect()
    }

    fn udaf_names(&self) -> Vec<String> {
        self.state.aggregate_functions().keys().cloned().collect()
    }

    fn udwf_names(&self) -> Vec<String> {
        self.state.window_functions().keys().cloned().collect()
    }

    fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
        &self.expr_planners
    }
}
235
/// Plans SQL expression strings into DataFusion logical expressions against a
/// fixed Arrow schema.
pub struct Planner {
    // Arrow schema that column references are resolved against.
    schema: SchemaRef,
    // Supplies functions/planners when delegating parsing to datafusion.
    context_provider: LanceContextProvider,
    // When true, the first identifier of a compound reference names a relation.
    enable_relations: bool,
}
241
242impl Planner {
243    pub fn new(schema: SchemaRef) -> Self {
244        Self {
245            schema,
246            context_provider: LanceContextProvider::default(),
247            enable_relations: false,
248        }
249    }
250
251    /// If passed with `true`, then the first identifier in column reference
252    /// is parsed as the relation. For example, `table.field.inner` will be
253    /// read as the nested field `field.inner` (`inner` on struct field `field`)
254    /// on the `table` relation. If `false` (the default), then no relations
255    /// are used and all identifiers are assumed to be a nested column path.
256    pub fn with_enable_relations(mut self, enable_relations: bool) -> Self {
257        self.enable_relations = enable_relations;
258        self
259    }
260
261    /// Resolve a column name using case-insensitive matching against the schema.
262    /// Returns the actual field name if found, otherwise returns the original name.
263    fn resolve_column_name(&self, name: &str) -> String {
264        // Try exact match first
265        if self.schema.field_with_name(name).is_ok() {
266            return name.to_string();
267        }
268        // Fall back to case-insensitive match
269        for field in self.schema.fields() {
270            if field.name().eq_ignore_ascii_case(name) {
271                return field.name().clone();
272            }
273        }
274        // Not found in schema - return original (might be computed column, system column, etc.)
275        name.to_string()
276    }
277
278    fn column(&self, idents: &[Ident]) -> Expr {
279        fn handle_remaining_idents(expr: &mut Expr, idents: &[Ident]) {
280            for ident in idents {
281                *expr = Expr::ScalarFunction(ScalarFunction {
282                    args: vec![
283                        std::mem::take(expr),
284                        Expr::Literal(ScalarValue::Utf8(Some(ident.value.clone())), None),
285                    ],
286                    func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
287                });
288            }
289        }
290
291        if self.enable_relations && idents.len() > 1 {
292            // Create qualified column reference (relation.column)
293            let relation = &idents[0].value;
294            let column_name = self.resolve_column_name(&idents[1].value);
295            let column = Expr::Column(Column::new(Some(relation.clone()), column_name));
296            let mut result = column;
297            handle_remaining_idents(&mut result, &idents[2..]);
298            result
299        } else {
300            // Default behavior - treat as struct field access
301            // Use resolved column name to handle case-insensitive matching
302            let resolved_name = self.resolve_column_name(&idents[0].value);
303            let mut column = Expr::Column(Column::from_name(resolved_name));
304            handle_remaining_idents(&mut column, &idents[1..]);
305            column
306        }
307    }
308
309    fn binary_op(&self, op: &BinaryOperator) -> Result<Operator> {
310        Ok(match op {
311            BinaryOperator::Plus => Operator::Plus,
312            BinaryOperator::Minus => Operator::Minus,
313            BinaryOperator::Multiply => Operator::Multiply,
314            BinaryOperator::Divide => Operator::Divide,
315            BinaryOperator::Modulo => Operator::Modulo,
316            BinaryOperator::StringConcat => Operator::StringConcat,
317            BinaryOperator::Gt => Operator::Gt,
318            BinaryOperator::Lt => Operator::Lt,
319            BinaryOperator::GtEq => Operator::GtEq,
320            BinaryOperator::LtEq => Operator::LtEq,
321            BinaryOperator::Eq => Operator::Eq,
322            BinaryOperator::NotEq => Operator::NotEq,
323            BinaryOperator::And => Operator::And,
324            BinaryOperator::Or => Operator::Or,
325            _ => {
326                return Err(Error::invalid_input(format!(
327                    "Operator {op} is not supported"
328                )));
329            }
330        })
331    }
332
333    fn binary_expr(&self, left: &SQLExpr, op: &BinaryOperator, right: &SQLExpr) -> Result<Expr> {
334        Ok(Expr::BinaryExpr(BinaryExpr::new(
335            Box::new(self.parse_sql_expr(left)?),
336            self.binary_op(op)?,
337            Box::new(self.parse_sql_expr(right)?),
338        )))
339    }
340
341    fn unary_expr(&self, op: &UnaryOperator, expr: &SQLExpr) -> Result<Expr> {
342        Ok(match op {
343            UnaryOperator::Not | UnaryOperator::PGBitwiseNot => {
344                Expr::Not(Box::new(self.parse_sql_expr(expr)?))
345            }
346
347            UnaryOperator::Minus => {
348                use datafusion::logical_expr::lit;
349                match expr {
350                    SQLExpr::Value(ValueWithSpan { value: Value::Number(n, _), ..}) => match n.parse::<i64>() {
351                        Ok(n) => lit(-n),
352                        Err(_) => lit(-n
353                            .parse::<f64>()
354                            .map_err(|_e| {
355                                Error::invalid_input(format!("negative operator can be only applied to integer and float operands, got: {n}"))
356                            })?),
357                    },
358                    _ => {
359                        Expr::Negative(Box::new(self.parse_sql_expr(expr)?))
360                    }
361                }
362            }
363
364            _ => {
365                return Err(Error::invalid_input(format!(
366                    "Unary operator '{:?}' is not supported",
367                    op
368                )));
369            }
370        })
371    }
372
373    // See datafusion `sqlToRel::parse_sql_number()`
374    fn number(&self, value: &str, negative: bool) -> Result<Expr> {
375        use datafusion::logical_expr::lit;
376        let value: Cow<str> = if negative {
377            Cow::Owned(format!("-{}", value))
378        } else {
379            Cow::Borrowed(value)
380        };
381        if let Ok(n) = value.parse::<i64>() {
382            Ok(lit(n))
383        } else {
384            value.parse::<f64>().map(lit).map_err(|_| {
385                Error::invalid_input(format!("'{value}' is not supported number value."))
386            })
387        }
388    }
389
390    fn value(&self, value: &Value) -> Result<Expr> {
391        Ok(match value {
392            Value::Number(v, _) => self.number(v.as_str(), false)?,
393            Value::SingleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone())), None),
394            Value::HexStringLiteral(hsl) => {
395                Expr::Literal(ScalarValue::Binary(Self::try_decode_hex_literal(hsl)), None)
396            }
397            Value::DoubleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone())), None),
398            Value::Boolean(v) => Expr::Literal(ScalarValue::Boolean(Some(*v)), None),
399            Value::Null => Expr::Literal(ScalarValue::Null, None),
400            _ => todo!(),
401        })
402    }
403
404    fn parse_function_args(&self, func_args: &FunctionArg) -> Result<Expr> {
405        match func_args {
406            FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => self.parse_sql_expr(expr),
407            _ => Err(Error::invalid_input(format!(
408                "Unsupported function args: {:?}",
409                func_args
410            ))),
411        }
412    }
413
414    // We now use datafusion to parse functions.  This allows us to use datafusion's
415    // entire collection of functions (previously we had just hard-coded support for two functions).
416    //
417    // Unfortunately, one of those two functions was is_valid and the reason we needed it was because
418    // this is a function that comes from duckdb.  Datafusion does not consider is_valid to be a function
419    // but rather an AST node (Expr::IsNotNull) and so we need to handle this case specially.
420    fn legacy_parse_function(&self, func: &Function) -> Result<Expr> {
421        match &func.args {
422            FunctionArguments::List(args) => {
423                if func.name.0.len() != 1 {
424                    return Err(Error::invalid_input(format!(
425                        "Function name must have 1 part, got: {:?}",
426                        func.name.0
427                    )));
428                }
429                Ok(Expr::IsNotNull(Box::new(
430                    self.parse_function_args(&args.args[0])?,
431                )))
432            }
433            _ => Err(Error::invalid_input(format!(
434                "Unsupported function args: {:?}",
435                &func.args
436            ))),
437        }
438    }
439
440    fn parse_function(&self, function: SQLExpr) -> Result<Expr> {
441        if let SQLExpr::Function(function) = &function
442            && let Some(ObjectNamePart::Identifier(name)) = &function.name.0.first()
443            && &name.value == "is_valid"
444        {
445            return self.legacy_parse_function(function);
446        }
447        let sql_to_rel = SqlToRel::new_with_options(
448            &self.context_provider,
449            ParserOptions {
450                parse_float_as_decimal: false,
451                enable_ident_normalization: false,
452                support_varchar_with_length: false,
453                enable_options_value_normalization: false,
454                collect_spans: false,
455                map_string_types_to_utf8view: false,
456                default_null_ordering: NullOrdering::NullsMax,
457            },
458        );
459
460        let mut planner_context = PlannerContext::default();
461        let schema = DFSchema::try_from(self.schema.as_ref().clone())?;
462        sql_to_rel
463            .sql_to_expr(function, &schema, &mut planner_context)
464            .map_err(|e| Error::invalid_input(format!("Error parsing function: {e}")))
465    }
466
467    fn parse_type(&self, data_type: &SQLDataType) -> Result<ArrowDataType> {
468        const SUPPORTED_TYPES: [&str; 13] = [
469            "int [unsigned]",
470            "tinyint [unsigned]",
471            "smallint [unsigned]",
472            "bigint [unsigned]",
473            "float",
474            "double",
475            "string",
476            "binary",
477            "date",
478            "timestamp(precision)",
479            "datetime(precision)",
480            "decimal(precision,scale)",
481            "boolean",
482        ];
483        match data_type {
484            SQLDataType::String(_) => Ok(ArrowDataType::Utf8),
485            SQLDataType::Binary(_) => Ok(ArrowDataType::Binary),
486            SQLDataType::Float(_) => Ok(ArrowDataType::Float32),
487            SQLDataType::Double(_) => Ok(ArrowDataType::Float64),
488            SQLDataType::Boolean => Ok(ArrowDataType::Boolean),
489            SQLDataType::TinyInt(_) => Ok(ArrowDataType::Int8),
490            SQLDataType::SmallInt(_) => Ok(ArrowDataType::Int16),
491            SQLDataType::Int(_) | SQLDataType::Integer(_) => Ok(ArrowDataType::Int32),
492            SQLDataType::BigInt(_) => Ok(ArrowDataType::Int64),
493            SQLDataType::TinyIntUnsigned(_) => Ok(ArrowDataType::UInt8),
494            SQLDataType::SmallIntUnsigned(_) => Ok(ArrowDataType::UInt16),
495            SQLDataType::IntUnsigned(_) | SQLDataType::IntegerUnsigned(_) => {
496                Ok(ArrowDataType::UInt32)
497            }
498            SQLDataType::BigIntUnsigned(_) => Ok(ArrowDataType::UInt64),
499            SQLDataType::Date => Ok(ArrowDataType::Date32),
500            SQLDataType::Timestamp(resolution, tz) => {
501                match tz {
502                    TimezoneInfo::None => {}
503                    _ => {
504                        return Err(Error::invalid_input(
505                            "Timezone not supported in timestamp".to_string(),
506                        ));
507                    }
508                };
509                let time_unit = match resolution {
510                    // Default to microsecond to match PyArrow
511                    None => TimeUnit::Microsecond,
512                    Some(0) => TimeUnit::Second,
513                    Some(3) => TimeUnit::Millisecond,
514                    Some(6) => TimeUnit::Microsecond,
515                    Some(9) => TimeUnit::Nanosecond,
516                    _ => {
517                        return Err(Error::invalid_input(format!(
518                            "Unsupported datetime resolution: {:?}",
519                            resolution
520                        )));
521                    }
522                };
523                Ok(ArrowDataType::Timestamp(time_unit, None))
524            }
525            SQLDataType::Datetime(resolution) => {
526                let time_unit = match resolution {
527                    None => TimeUnit::Microsecond,
528                    Some(0) => TimeUnit::Second,
529                    Some(3) => TimeUnit::Millisecond,
530                    Some(6) => TimeUnit::Microsecond,
531                    Some(9) => TimeUnit::Nanosecond,
532                    _ => {
533                        return Err(Error::invalid_input(format!(
534                            "Unsupported datetime resolution: {:?}",
535                            resolution
536                        )));
537                    }
538                };
539                Ok(ArrowDataType::Timestamp(time_unit, None))
540            }
541            SQLDataType::Decimal(number_info) => match number_info {
542                ExactNumberInfo::PrecisionAndScale(precision, scale) => {
543                    Ok(ArrowDataType::Decimal128(*precision as u8, *scale as i8))
544                }
545                _ => Err(Error::invalid_input(format!(
546                    "Must provide precision and scale for decimal: {:?}",
547                    number_info
548                ))),
549            },
550            _ => Err(Error::invalid_input(format!(
551                "Unsupported data type: {:?}. Supported types: {:?}",
552                data_type, SUPPORTED_TYPES
553            ))),
554        }
555    }
556
557    fn plan_field_access(&self, mut field_access_expr: RawFieldAccessExpr) -> Result<Expr> {
558        let df_schema = DFSchema::try_from(self.schema.as_ref().clone())?;
559        for planner in self.context_provider.get_expr_planners() {
560            match planner.plan_field_access(field_access_expr, &df_schema)? {
561                PlannerResult::Planned(expr) => return Ok(expr),
562                PlannerResult::Original(expr) => {
563                    field_access_expr = expr;
564                }
565            }
566        }
567        Err(Error::invalid_input("Field access could not be planned"))
568    }
569
570    fn parse_sql_expr(&self, expr: &SQLExpr) -> Result<Expr> {
571        match expr {
572            SQLExpr::Identifier(id) => {
573                // Users can pass string literals wrapped in `"`.
574                // (Normally SQL only allows single quotes.)
575                if id.quote_style == Some('"') {
576                    Ok(Expr::Literal(
577                        ScalarValue::Utf8(Some(id.value.clone())),
578                        None,
579                    ))
580                // Users can wrap identifiers with ` to reference non-standard
581                // names, such as uppercase or spaces.
582                } else if id.quote_style == Some('`') {
583                    Ok(Expr::Column(Column::from_name(id.value.clone())))
584                } else {
585                    Ok(self.column(vec![id.clone()].as_slice()))
586                }
587            }
588            SQLExpr::CompoundIdentifier(ids) => Ok(self.column(ids.as_slice())),
589            SQLExpr::BinaryOp { left, op, right } => self.binary_expr(left, op, right),
590            SQLExpr::UnaryOp { op, expr } => self.unary_expr(op, expr),
591            SQLExpr::Value(value) => self.value(&value.value),
592            SQLExpr::Array(SQLArray { elem, .. }) => {
593                let mut values = vec![];
594
595                let array_literal_error = |pos: usize, value: &_| {
596                    Err(Error::invalid_input(format!(
597                        "Expected a literal value in array, instead got {} at position {}",
598                        value, pos
599                    )))
600                };
601
602                for (pos, expr) in elem.iter().enumerate() {
603                    match expr {
604                        SQLExpr::Value(value) => {
605                            if let Expr::Literal(value, _) = self.value(&value.value)? {
606                                values.push(value);
607                            } else {
608                                return array_literal_error(pos, expr);
609                            }
610                        }
611                        SQLExpr::UnaryOp {
612                            op: UnaryOperator::Minus,
613                            expr,
614                        } => {
615                            if let SQLExpr::Value(ValueWithSpan {
616                                value: Value::Number(number, _),
617                                ..
618                            }) = expr.as_ref()
619                            {
620                                if let Expr::Literal(value, _) = self.number(number, true)? {
621                                    values.push(value);
622                                } else {
623                                    return array_literal_error(pos, expr);
624                                }
625                            } else {
626                                return array_literal_error(pos, expr);
627                            }
628                        }
629                        _ => {
630                            return array_literal_error(pos, expr);
631                        }
632                    }
633                }
634
635                let field = if !values.is_empty() {
636                    let data_type = values[0].data_type();
637
638                    for value in &mut values {
639                        if value.data_type() != data_type {
640                            *value = safe_coerce_scalar(value, &data_type).ok_or_else(|| Error::invalid_input(format!("Array expressions must have a consistent datatype. Expected: {}, got: {}", data_type, value.data_type())))?;
641                        }
642                    }
643                    Field::new("item", data_type, true)
644                } else {
645                    Field::new("item", ArrowDataType::Null, true)
646                };
647
648                let values = values
649                    .into_iter()
650                    .map(|v| v.to_array().map_err(Error::from))
651                    .collect::<Result<Vec<_>>>()?;
652                let array_refs = values.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
653                let values = concat(&array_refs)?;
654                let values = ListArray::try_new(
655                    field.into(),
656                    OffsetBuffer::from_lengths([values.len()]),
657                    values,
658                    None,
659                )?;
660
661                Ok(Expr::Literal(ScalarValue::List(Arc::new(values)), None))
662            }
663            // JSONB literal: jsonb '{"key": "value"}'
664            SQLExpr::TypedString(TypedString {
665                data_type: SQLDataType::JSONB,
666                value,
667                ..
668            }) => match &value.value {
669                Value::SingleQuotedString(s) | Value::DoubleQuotedString(s) => encode_jsonb(s),
670                _ => Err(Error::invalid_input(
671                    "Expected a string value for JSONB literal",
672                )),
673            },
674            // For example, DATE '2020-01-01'
675            SQLExpr::TypedString(TypedString {
676                data_type, value, ..
677            }) => {
678                let value = value.clone().into_string().expect_ok()?;
679                Ok(Expr::Cast(datafusion::logical_expr::Cast {
680                    expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some(value)), None)),
681                    data_type: self.parse_type(data_type)?,
682                }))
683            }
684            SQLExpr::IsFalse(expr) => Ok(Expr::IsFalse(Box::new(self.parse_sql_expr(expr)?))),
685            SQLExpr::IsNotFalse(expr) => Ok(Expr::IsNotFalse(Box::new(self.parse_sql_expr(expr)?))),
686            SQLExpr::IsTrue(expr) => Ok(Expr::IsTrue(Box::new(self.parse_sql_expr(expr)?))),
687            SQLExpr::IsNotTrue(expr) => Ok(Expr::IsNotTrue(Box::new(self.parse_sql_expr(expr)?))),
688            SQLExpr::IsNull(expr) => Ok(Expr::IsNull(Box::new(self.parse_sql_expr(expr)?))),
689            SQLExpr::IsNotNull(expr) => Ok(Expr::IsNotNull(Box::new(self.parse_sql_expr(expr)?))),
690            SQLExpr::InList {
691                expr,
692                list,
693                negated,
694            } => {
695                let value_expr = self.parse_sql_expr(expr)?;
696                let list_exprs = list
697                    .iter()
698                    .map(|e| self.parse_sql_expr(e))
699                    .collect::<Result<Vec<_>>>()?;
700                Ok(value_expr.in_list(list_exprs, *negated))
701            }
702            SQLExpr::Nested(inner) => self.parse_sql_expr(inner.as_ref()),
703            SQLExpr::Function(_) => self.parse_function(expr.clone()),
704            SQLExpr::ILike {
705                negated,
706                expr,
707                pattern,
708                escape_char,
709                any: _,
710            } => Ok(Expr::Like(Like::new(
711                *negated,
712                Box::new(self.parse_sql_expr(expr)?),
713                Box::new(self.parse_sql_expr(pattern)?),
714                match escape_char {
715                    Some(Value::SingleQuotedString(char)) => char.chars().next(),
716                    Some(value) => {
717                        return Err(Error::invalid_input(format!(
718                            "Invalid escape character in LIKE expression. Expected a single character wrapped with single quotes, got {}",
719                            value
720                        )));
721                    }
722                    None => None,
723                },
724                true,
725            ))),
726            SQLExpr::Like {
727                negated,
728                expr,
729                pattern,
730                escape_char,
731                any: _,
732            } => Ok(Expr::Like(Like::new(
733                *negated,
734                Box::new(self.parse_sql_expr(expr)?),
735                Box::new(self.parse_sql_expr(pattern)?),
736                match escape_char {
737                    Some(Value::SingleQuotedString(char)) => char.chars().next(),
738                    Some(value) => {
739                        return Err(Error::invalid_input(format!(
740                            "Invalid escape character in LIKE expression. Expected a single character wrapped with single quotes, got {}",
741                            value
742                        )));
743                    }
744                    None => None,
745                },
746                false,
747            ))),
748            // JSONB cast: CAST('...' AS JSONB) or '...'::jsonb
749            SQLExpr::Cast {
750                data_type: SQLDataType::JSONB,
751                expr: inner,
752                ..
753            } => match inner.as_ref() {
754                SQLExpr::Value(ValueWithSpan {
755                    value: Value::SingleQuotedString(s) | Value::DoubleQuotedString(s),
756                    ..
757                }) => encode_jsonb(s),
758                _ => Err(Error::invalid_input(
759                    "CAST to JSONB only supports string literals",
760                )),
761            },
762            SQLExpr::Cast {
763                expr,
764                data_type,
765                kind,
766                ..
767            } => match kind {
768                datafusion::sql::sqlparser::ast::CastKind::TryCast
769                | datafusion::sql::sqlparser::ast::CastKind::SafeCast => {
770                    Ok(Expr::TryCast(datafusion::logical_expr::TryCast {
771                        expr: Box::new(self.parse_sql_expr(expr)?),
772                        data_type: self.parse_type(data_type)?,
773                    }))
774                }
775                _ => Ok(Expr::Cast(datafusion::logical_expr::Cast {
776                    expr: Box::new(self.parse_sql_expr(expr)?),
777                    data_type: self.parse_type(data_type)?,
778                })),
779            },
780            SQLExpr::JsonAccess { .. } => Err(Error::invalid_input("JSON access is not supported")),
781            SQLExpr::CompoundFieldAccess { root, access_chain } => {
782                let mut expr = self.parse_sql_expr(root)?;
783
784                for access in access_chain {
785                    let field_access = match access {
786                        // x.y or x['y']
787                        AccessExpr::Dot(SQLExpr::Identifier(Ident { value: s, .. }))
788                        | AccessExpr::Subscript(Subscript::Index {
789                            index:
790                                SQLExpr::Value(ValueWithSpan {
791                                    value:
792                                        Value::SingleQuotedString(s) | Value::DoubleQuotedString(s),
793                                    ..
794                                }),
795                        }) => GetFieldAccess::NamedStructField {
796                            name: ScalarValue::from(s.as_str()),
797                        },
798                        AccessExpr::Subscript(Subscript::Index { index }) => {
799                            let key = Box::new(self.parse_sql_expr(index)?);
800                            GetFieldAccess::ListIndex { key }
801                        }
802                        AccessExpr::Subscript(Subscript::Slice { .. }) => {
803                            return Err(Error::invalid_input("Slice subscript is not supported"));
804                        }
805                        _ => {
806                            // Handle other cases like JSON access
807                            // Note: JSON access is not supported in lance
808                            return Err(Error::invalid_input(
809                                "Only dot notation or index access is supported for field access",
810                            ));
811                        }
812                    };
813
814                    let field_access_expr = RawFieldAccessExpr { expr, field_access };
815                    expr = self.plan_field_access(field_access_expr)?;
816                }
817
818                Ok(expr)
819            }
820            SQLExpr::Between {
821                expr,
822                negated,
823                low,
824                high,
825            } => {
826                // Parse the main expression and bounds
827                let expr = self.parse_sql_expr(expr)?;
828                let low = self.parse_sql_expr(low)?;
829                let high = self.parse_sql_expr(high)?;
830
831                let between = Expr::Between(Between::new(
832                    Box::new(expr),
833                    *negated,
834                    Box::new(low),
835                    Box::new(high),
836                ));
837                Ok(between)
838            }
839            _ => Err(Error::invalid_input(format!(
840                "Expression '{expr}' is not supported SQL in lance"
841            ))),
842        }
843    }
844
845    /// Create Logical [Expr] from a SQL filter clause.
846    ///
847    /// Note: the returned expression must be passed through `optimize_expr()`
848    /// before being passed to `create_physical_expr()`.
849    pub fn parse_filter(&self, filter: &str) -> Result<Expr> {
850        // Allow sqlparser to parse filter as part of ONE SQL statement.
851        let ast_expr = parse_sql_filter(filter)?;
852        let expr = self.parse_sql_expr(&ast_expr)?;
853        let schema = Schema::try_from(self.schema.as_ref())?;
854        let resolved = resolve_expr(&expr, &schema).map_err(|e| {
855            Error::invalid_input(format!("Error resolving filter expression {filter}: {e}"))
856        })?;
857
858        Ok(coerce_filter_type_to_boolean(resolved))
859    }
860
861    /// Create Logical [Expr] from a SQL expression.
862    ///
863    /// Note: the returned expression must be passed through `optimize_filter()`
864    /// before being passed to `create_physical_expr()`.
865    pub fn parse_expr(&self, expr: &str) -> Result<Expr> {
866        // First check if it's a simple column reference (no operators, functions, etc.)
867        // resolve_column_name tries exact match first, then falls back to case-insensitive
868        let resolved_name = self.resolve_column_name(expr);
869        if self.schema.field_with_name(&resolved_name).is_ok() {
870            return Ok(Expr::Column(Column::from_name(resolved_name)));
871        }
872
873        // Parse as SQL expression
874        let ast_expr = parse_sql_expr(expr)?;
875        let expr = self.parse_sql_expr(&ast_expr)?;
876        let schema = Schema::try_from(self.schema.as_ref())?;
877        let resolved = resolve_expr(&expr, &schema)?;
878        Ok(resolved)
879    }
880
881    /// Try to decode bytes from hex literal string.
882    ///
883    /// Copied from datafusion because this is not public.
884    ///
885    /// TODO: use SqlToRel from Datafusion directly?
886    fn try_decode_hex_literal(s: &str) -> Option<Vec<u8>> {
887        let hex_bytes = s.as_bytes();
888        let mut decoded_bytes = Vec::with_capacity(hex_bytes.len().div_ceil(2));
889
890        let start_idx = hex_bytes.len() % 2;
891        if start_idx > 0 {
892            // The first byte is formed of only one char.
893            decoded_bytes.push(Self::try_decode_hex_char(hex_bytes[0])?);
894        }
895
896        for i in (start_idx..hex_bytes.len()).step_by(2) {
897            let high = Self::try_decode_hex_char(hex_bytes[i])?;
898            let low = Self::try_decode_hex_char(hex_bytes[i + 1])?;
899            decoded_bytes.push((high << 4) | low);
900        }
901
902        Some(decoded_bytes)
903    }
904
905    /// Try to decode a byte from a hex char.
906    ///
907    /// None will be returned if the input char is hex-invalid.
908    const fn try_decode_hex_char(c: u8) -> Option<u8> {
909        match c {
910            b'A'..=b'F' => Some(c - b'A' + 10),
911            b'a'..=b'f' => Some(c - b'a' + 10),
912            b'0'..=b'9' => Some(c - b'0'),
913            _ => None,
914        }
915    }
916
917    /// Optimize the filter expression and coerce data types.
918    pub fn optimize_expr(&self, expr: Expr) -> Result<Expr> {
919        let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
920
921        // DataFusion needs the simplify and coerce passes to be applied before
922        // expressions can be handled by the physical planner.
923        let props = ExecutionProps::new().with_query_execution_start_time(Utc::now());
924        let simplify_context = SimplifyContext::new(&props).with_schema(df_schema.clone());
925        let simplifier =
926            datafusion::optimizer::simplify_expressions::ExprSimplifier::new(simplify_context);
927
928        let expr = simplifier.simplify(expr)?;
929        let expr = simplifier.coerce(expr, &df_schema)?;
930
931        Ok(expr)
932    }
933
934    /// Create the [`PhysicalExpr`] from a logical [`Expr`]
935    pub fn create_physical_expr(&self, expr: &Expr) -> Result<Arc<dyn PhysicalExpr>> {
936        let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
937        Ok(datafusion::physical_expr::create_physical_expr(
938            expr,
939            df_schema.as_ref(),
940            &Default::default(),
941        )?)
942    }
943
944    /// Collect the columns in the expression.
945    ///
946    /// The columns are returned in sorted order.
947    ///
948    /// If the expr refers to nested columns these will be returned
949    /// as dotted paths (x.y.z)
950    pub fn column_names_in_expr(expr: &Expr) -> Vec<String> {
951        let mut visitor = ColumnCapturingVisitor {
952            current_path: VecDeque::new(),
953            columns: BTreeSet::new(),
954        };
955        expr.visit(&mut visitor).unwrap();
956        visitor.columns.into_iter().collect()
957    }
958}
959
/// Visitor that records every column referenced by an expression,
/// flattening nested `get_field` accesses into dotted paths (`a.b.c`).
struct ColumnCapturingVisitor {
    // Pending nested-field names collected while descending through
    // `get_field` calls; drained when a column is reached, cleared when
    // any other node breaks the chain.
    // Current column path. If this is empty, we are not in a column expression.
    current_path: VecDeque<String>,
    // Unique captured column paths; BTreeSet keeps them sorted.
    columns: BTreeSet<String>,
}
965
impl TreeNodeVisitor<'_> for ColumnCapturingVisitor {
    type Node = Expr;

    // Pre-order visit (`f_down`): for a nested access such as
    // `get_field(get_field(col, "a"), "b")` the outermost call is visited
    // first, so pushing each name to the *front* of `current_path` yields
    // the names in source order by the time the column itself is reached.
    fn f_down(&mut self, node: &Self::Node) -> DFResult<TreeNodeRecursion> {
        match node {
            Expr::Column(Column { name, .. }) => {
                // Build the field path from the column name and any nested fields
                // The nested field names from get_field already come as literal strings,
                // so we just need to concatenate them properly
                let mut path = name.clone();
                for part in self.current_path.drain(..) {
                    path.push('.');
                    // Check if the part needs quoting (contains dots)
                    if part.contains('.') || part.contains('`') {
                        // Quote the field name with backticks and escape any existing backticks
                        let escaped = part.replace('`', "``");
                        path.push('`');
                        path.push_str(&escaped);
                        path.push('`');
                    } else {
                        path.push_str(&part);
                    }
                }
                self.columns.insert(path);
                self.current_path.clear();
            }
            Expr::ScalarFunction(udf) => {
                // Only `get_field` calls extend the pending path; the field
                // name is the second argument when it is a string literal.
                if udf.name() == GetFieldFunc::default().name() {
                    if let Some(name) = get_as_string_scalar_opt(&udf.args[1]) {
                        self.current_path.push_front(name.to_string())
                    } else {
                        // Dynamic (non-literal) field name: no static path
                        // can be captured for this access chain.
                        self.current_path.clear();
                    }
                } else {
                    self.current_path.clear();
                }
            }
            _ => {
                // Any other node breaks a chain of nested field accesses.
                self.current_path.clear();
            }
        }

        Ok(TreeNodeRecursion::Continue)
    }
}
1011
1012#[cfg(test)]
1013mod tests {
1014
1015    use crate::logical_expr::ExprExt;
1016
1017    use super::*;
1018
1019    use arrow::datatypes::Float64Type;
1020    use arrow_array::{
1021        ArrayRef, BooleanArray, Float32Array, Int32Array, Int64Array, RecordBatch, StringArray,
1022        StructArray, TimestampMicrosecondArray, TimestampMillisecondArray,
1023        TimestampNanosecondArray, TimestampSecondArray,
1024    };
1025    use arrow_schema::{DataType, Fields, Schema};
1026    use datafusion::{
1027        logical_expr::{Cast, col, lit},
1028        prelude::{array_element, get_field},
1029    };
1030    use datafusion_functions::core::expr_ext::FieldAccessor;
1031
    #[test]
    fn test_parse_filter_simple() {
        // Schema with a scalar int, a nullable string, and a nested struct.
        let schema = Arc::new(Schema::new(vec![
            Field::new("i", DataType::Int32, false),
            Field::new("s", DataType::Utf8, true),
            Field::new(
                "st",
                DataType::Struct(Fields::from(vec![
                    Field::new("x", DataType::Float32, false),
                    Field::new("y", DataType::Float32, false),
                ])),
                true,
            ),
        ]));

        let planner = Planner::new(schema.clone());

        let expected = col("i")
            .gt(lit(3_i32))
            .and(col("st").field_newstyle("x").lt_eq(lit(5.0_f32)))
            .and(
                col("s")
                    .eq(lit("str-4"))
                    .or(col("s").in_list(vec![lit("str-4"), lit("str-5")], false)),
            );

        // `==` (double-equals) must parse the same as `=`.
        let expr = planner
            .parse_filter("i > 3 AND st.x <= 5.0 AND (s == 'str-4' OR s in ('str-4', 'str-5'))")
            .unwrap();
        assert_eq!(expr, expected);

        // Single `=` equality operator.
        let expr = planner
            .parse_filter("i > 3 AND st.x <= 5.0 AND (s = 'str-4' OR s in ('str-4', 'str-5'))")
            .unwrap();

        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        // Rows: i = 0..10, s = "str-0".."str-9", st.x = 0..10, st.y = 0,10,..,90.
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
                Arc::new(StringArray::from_iter_values(
                    (0..10).map(|v| format!("str-{}", v)),
                )),
                Arc::new(StructArray::from(vec![
                    (
                        Arc::new(Field::new("x", DataType::Float32, false)),
                        Arc::new(Float32Array::from_iter_values((0..10).map(|v| v as f32)))
                            as ArrayRef,
                    ),
                    (
                        Arc::new(Field::new("y", DataType::Float32, false)),
                        Arc::new(Float32Array::from_iter_values(
                            (0..10).map(|v| (v * 10) as f32),
                        )),
                    ),
                ])),
            ],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        // Only rows 4 and 5 satisfy i > 3 AND st.x <= 5.0 AND s IN (...).
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, false, true, true, false, false, false, false
            ])
        );
    }
1102
    #[test]
    fn test_nested_col_refs() {
        // Two levels of struct nesting, with an inner field named like the
        // outer column ("st") to make sure resolution walks the levels.
        let schema = Arc::new(Schema::new(vec![
            Field::new("s0", DataType::Utf8, true),
            Field::new(
                "st",
                DataType::Struct(Fields::from(vec![
                    Field::new("s1", DataType::Utf8, true),
                    Field::new(
                        "st",
                        DataType::Struct(Fields::from(vec![Field::new(
                            "s2",
                            DataType::Utf8,
                            true,
                        )])),
                        true,
                    ),
                ])),
                true,
            ),
        ]));

        let planner = Planner::new(schema);

        // Parses "{expr} = 'val'" and asserts the left-hand side of the
        // resulting equality equals `expected`.
        fn assert_column_eq(planner: &Planner, expr: &str, expected: &Expr) {
            let expr = planner.parse_filter(&format!("{expr} = 'val'")).unwrap();
            assert!(matches!(
                expr,
                Expr::BinaryExpr(BinaryExpr {
                    left: _,
                    op: Operator::Eq,
                    right: _
                })
            ));
            if let Expr::BinaryExpr(BinaryExpr { left, .. }) = expr {
                assert_eq!(left.as_ref(), expected);
            }
        }

        // Top-level column, bare and backquoted.
        let expected = Expr::Column(Column::new_unqualified("s0"));
        assert_column_eq(&planner, "s0", &expected);
        assert_column_eq(&planner, "`s0`", &expected);

        // One level of nesting becomes a single get_field call.
        let expected = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
            args: vec![
                Expr::Column(Column::new_unqualified("st")),
                Expr::Literal(ScalarValue::Utf8(Some("s1".to_string())), None),
            ],
        });
        assert_column_eq(&planner, "st.s1", &expected);
        assert_column_eq(&planner, "`st`.`s1`", &expected);
        assert_column_eq(&planner, "st.`s1`", &expected);

        // Two levels of nesting become nested get_field calls.
        let expected = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
            args: vec![
                Expr::ScalarFunction(ScalarFunction {
                    func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
                    args: vec![
                        Expr::Column(Column::new_unqualified("st")),
                        Expr::Literal(ScalarValue::Utf8(Some("st".to_string())), None),
                    ],
                }),
                Expr::Literal(ScalarValue::Utf8(Some("s2".to_string())), None),
            ],
        });

        assert_column_eq(&planner, "st.st.s2", &expected);
        assert_column_eq(&planner, "`st`.`st`.`s2`", &expected);
        assert_column_eq(&planner, "st.st.`s2`", &expected);
        // Subscript syntax (single or double quoted) is equivalent to dots.
        assert_column_eq(&planner, "st['st'][\"s2\"]", &expected);
    }
1176
1177    #[test]
1178    fn test_nested_list_refs() {
1179        let schema = Arc::new(Schema::new(vec![Field::new(
1180            "l",
1181            DataType::List(Arc::new(Field::new(
1182                "item",
1183                DataType::Struct(Fields::from(vec![Field::new("f1", DataType::Utf8, true)])),
1184                true,
1185            ))),
1186            true,
1187        )]));
1188
1189        let planner = Planner::new(schema);
1190
1191        let expected = array_element(col("l"), lit(0_i64));
1192        let expr = planner.parse_expr("l[0]").unwrap();
1193        assert_eq!(expr, expected);
1194
1195        let expected = get_field(array_element(col("l"), lit(0_i64)), "f1");
1196        let expr = planner.parse_expr("l[0]['f1']").unwrap();
1197        assert_eq!(expr, expected);
1198
1199        // FIXME: This should work, but sqlparser doesn't recognize anything
1200        // after the period for some reason.
1201        // let expr = planner.parse_expr("l[0].f1").unwrap();
1202        // assert_eq!(expr, expected);
1203    }
1204
1205    #[test]
1206    fn test_negative_expressions() {
1207        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
1208
1209        let planner = Planner::new(schema.clone());
1210
1211        let expected = col("x")
1212            .gt(lit(-3_i64))
1213            .and(col("x").lt(-(lit(-5_i64) + lit(3_i64))));
1214
1215        let expr = planner.parse_filter("x > -3 AND x < -(-5 + 3)").unwrap();
1216
1217        assert_eq!(expr, expected);
1218
1219        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1220
1221        let batch = RecordBatch::try_new(
1222            schema,
1223            vec![Arc::new(Int64Array::from_iter_values(-5..5)) as ArrayRef],
1224        )
1225        .unwrap();
1226        let predicates = physical_expr.evaluate(&batch).unwrap();
1227        assert_eq!(
1228            predicates.into_array(0).unwrap().as_ref(),
1229            &BooleanArray::from(vec![
1230                false, false, false, true, true, true, true, false, false, false
1231            ])
1232        );
1233    }
1234
1235    #[test]
1236    fn test_negative_array_expressions() {
1237        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
1238
1239        let planner = Planner::new(schema);
1240
1241        let expected = Expr::Literal(
1242            ScalarValue::List(Arc::new(
1243                ListArray::from_iter_primitive::<Float64Type, _, _>(vec![Some(
1244                    [-1_f64, -2.0, -3.0, -4.0, -5.0].map(Some),
1245                )]),
1246            )),
1247            None,
1248        );
1249
1250        let expr = planner
1251            .parse_expr("[-1.0, -2.0, -3.0, -4.0, -5.0]")
1252            .unwrap();
1253
1254        assert_eq!(expr, expected);
1255    }
1256
1257    #[test]
1258    fn test_sql_like() {
1259        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1260
1261        let planner = Planner::new(schema.clone());
1262
1263        let expected = col("s").like(lit("str-4"));
1264        // single quote
1265        let expr = planner.parse_filter("s LIKE 'str-4'").unwrap();
1266        assert_eq!(expr, expected);
1267        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1268
1269        let batch = RecordBatch::try_new(
1270            schema,
1271            vec![Arc::new(StringArray::from_iter_values(
1272                (0..10).map(|v| format!("str-{}", v)),
1273            ))],
1274        )
1275        .unwrap();
1276        let predicates = physical_expr.evaluate(&batch).unwrap();
1277        assert_eq!(
1278            predicates.into_array(0).unwrap().as_ref(),
1279            &BooleanArray::from(vec![
1280                false, false, false, false, true, false, false, false, false, false
1281            ])
1282        );
1283    }
1284
1285    #[test]
1286    fn test_not_like() {
1287        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1288
1289        let planner = Planner::new(schema.clone());
1290
1291        let expected = col("s").not_like(lit("str-4"));
1292        // single quote
1293        let expr = planner.parse_filter("s NOT LIKE 'str-4'").unwrap();
1294        assert_eq!(expr, expected);
1295        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1296
1297        let batch = RecordBatch::try_new(
1298            schema,
1299            vec![Arc::new(StringArray::from_iter_values(
1300                (0..10).map(|v| format!("str-{}", v)),
1301            ))],
1302        )
1303        .unwrap();
1304        let predicates = physical_expr.evaluate(&batch).unwrap();
1305        assert_eq!(
1306            predicates.into_array(0).unwrap().as_ref(),
1307            &BooleanArray::from(vec![
1308                true, true, true, true, false, true, true, true, true, true
1309            ])
1310        );
1311    }
1312
1313    #[test]
1314    fn test_sql_is_in() {
1315        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1316
1317        let planner = Planner::new(schema.clone());
1318
1319        let expected = col("s").in_list(vec![lit("str-4"), lit("str-5")], false);
1320        // single quote
1321        let expr = planner.parse_filter("s IN ('str-4', 'str-5')").unwrap();
1322        assert_eq!(expr, expected);
1323        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1324
1325        let batch = RecordBatch::try_new(
1326            schema,
1327            vec![Arc::new(StringArray::from_iter_values(
1328                (0..10).map(|v| format!("str-{}", v)),
1329            ))],
1330        )
1331        .unwrap();
1332        let predicates = physical_expr.evaluate(&batch).unwrap();
1333        assert_eq!(
1334            predicates.into_array(0).unwrap().as_ref(),
1335            &BooleanArray::from(vec![
1336                false, false, false, false, true, true, false, false, false, false
1337            ])
1338        );
1339    }
1340
    #[test]
    fn test_sql_is_null() {
        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));

        let planner = Planner::new(schema.clone());

        let expected = col("s").is_null();
        let expr = planner.parse_filter("s IS NULL").unwrap();
        assert_eq!(expr, expected);
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        // Rows are non-null only when the index is divisible by 3.
        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(StringArray::from_iter((0..10).map(|v| {
                if v % 3 == 0 {
                    Some(format!("str-{}", v))
                } else {
                    None
                }
            })))],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, true, true, false, true, true, false, true, true, false
            ])
        );

        // IS NOT NULL produces the complementary mask on the same batch.
        let expr = planner.parse_filter("s IS NOT NULL").unwrap();
        let physical_expr = planner.create_physical_expr(&expr).unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                true, false, false, true, false, false, true, false, false, true,
            ])
        );
    }
1381
1382    #[test]
1383    fn test_sql_invert() {
1384        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Boolean, true)]));
1385
1386        let planner = Planner::new(schema.clone());
1387
1388        let expr = planner.parse_filter("NOT s").unwrap();
1389        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1390
1391        let batch = RecordBatch::try_new(
1392            schema,
1393            vec![Arc::new(BooleanArray::from_iter(
1394                (0..10).map(|v| Some(v % 3 == 0)),
1395            ))],
1396        )
1397        .unwrap();
1398        let predicates = physical_expr.evaluate(&batch).unwrap();
1399        assert_eq!(
1400            predicates.into_array(0).unwrap().as_ref(),
1401            &BooleanArray::from(vec![
1402                false, true, true, false, true, true, false, true, true, false
1403            ])
1404        );
1405    }
1406
    #[test]
    fn test_sql_cast() {
        // Each case is (filter SQL, Arrow type the CAST should produce).
        let cases = &[
            (
                "x = cast('2021-01-01 00:00:00' as timestamp)",
                ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00' as timestamp(0))",
                ArrowDataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00.123' as timestamp(9))",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00.123' as datetime(9))",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            ("x = cast('2021-01-01' as date)", ArrowDataType::Date32),
            (
                "x = cast('1.238' as decimal(9,3))",
                ArrowDataType::Decimal128(9, 3),
            ),
            ("x = cast(1 as float)", ArrowDataType::Float32),
            ("x = cast(1 as double)", ArrowDataType::Float64),
            ("x = cast(1 as tinyint)", ArrowDataType::Int8),
            ("x = cast(1 as smallint)", ArrowDataType::Int16),
            ("x = cast(1 as int)", ArrowDataType::Int32),
            ("x = cast(1 as integer)", ArrowDataType::Int32),
            ("x = cast(1 as bigint)", ArrowDataType::Int64),
            ("x = cast(1 as tinyint unsigned)", ArrowDataType::UInt8),
            ("x = cast(1 as smallint unsigned)", ArrowDataType::UInt16),
            ("x = cast(1 as int unsigned)", ArrowDataType::UInt32),
            ("x = cast(1 as integer unsigned)", ArrowDataType::UInt32),
            ("x = cast(1 as bigint unsigned)", ArrowDataType::UInt64),
            ("x = cast(1 as boolean)", ArrowDataType::Boolean),
            ("x = cast(1 as string)", ArrowDataType::Utf8),
        ];

        for (sql, expected_data_type) in cases {
            // Column "x" is given the target type so the comparison is valid.
            let schema = Arc::new(Schema::new(vec![Field::new(
                "x",
                expected_data_type.clone(),
                true,
            )]));
            let planner = Planner::new(schema.clone());
            let expr = planner.parse_filter(sql).unwrap();

            // Extract the literal between "cast(" and " as" in the SQL text.
            let expected_value_str = sql
                .split("cast(")
                .nth(1)
                .unwrap()
                .split(" as")
                .next()
                .unwrap();
            // Remove any quote marks
            let expected_value_str = expected_value_str.trim_matches('\'');

            // The parsed filter must have the shape x = CAST(<literal> AS <type>).
            match expr {
                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
                    Expr::Cast(Cast { expr, data_type }) => {
                        match expr.as_ref() {
                            Expr::Literal(ScalarValue::Utf8(Some(value_str)), _) => {
                                assert_eq!(value_str, expected_value_str);
                            }
                            Expr::Literal(ScalarValue::Int64(Some(value)), _) => {
                                assert_eq!(*value, 1);
                            }
                            _ => panic!("Expected cast to be applied to literal"),
                        }
                        assert_eq!(data_type, expected_data_type);
                    }
                    _ => panic!("Expected right to be a cast"),
                },
                _ => panic!("Expected binary expression"),
            }
        }
    }
1487
    #[test]
    fn test_sql_literals() {
        // Each case is (filter SQL with a typed literal, expected Arrow type).
        let cases = &[
            (
                "x = timestamp '2021-01-01 00:00:00'",
                ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
            ),
            (
                "x = timestamp(0) '2021-01-01 00:00:00'",
                ArrowDataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                "x = timestamp(9) '2021-01-01 00:00:00.123'",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            ("x = date '2021-01-01'", ArrowDataType::Date32),
            ("x = decimal(9,3) '1.238'", ArrowDataType::Decimal128(9, 3)),
        ];

        for (sql, expected_data_type) in cases {
            let schema = Arc::new(Schema::new(vec![Field::new(
                "x",
                expected_data_type.clone(),
                true,
            )]));
            let planner = Planner::new(schema.clone());
            let expr = planner.parse_filter(sql).unwrap();

            // The quoted literal text in the SQL string.
            let expected_value_str = sql.split('\'').nth(1).unwrap();

            // Typed literals must parse as CAST(<string literal> AS <type>).
            match expr {
                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
                    Expr::Cast(Cast { expr, data_type }) => {
                        match expr.as_ref() {
                            Expr::Literal(ScalarValue::Utf8(Some(value_str)), _) => {
                                assert_eq!(value_str, expected_value_str);
                            }
                            _ => panic!("Expected cast to be applied to literal"),
                        }
                        assert_eq!(data_type, expected_data_type);
                    }
                    _ => panic!("Expected right to be a cast"),
                },
                _ => panic!("Expected binary expression"),
            }
        }
    }
1535
1536    #[test]
1537    fn test_sql_array_literals() {
1538        let cases = [
1539            (
1540                "x = [1, 2, 3]",
1541                ArrowDataType::List(Arc::new(Field::new("item", ArrowDataType::Int64, true))),
1542            ),
1543            (
1544                "x = [1, 2, 3]",
1545                ArrowDataType::FixedSizeList(
1546                    Arc::new(Field::new("item", ArrowDataType::Int64, true)),
1547                    3,
1548                ),
1549            ),
1550        ];
1551
1552        for (sql, expected_data_type) in cases {
1553            let schema = Arc::new(Schema::new(vec![Field::new(
1554                "x",
1555                expected_data_type.clone(),
1556                true,
1557            )]));
1558            let planner = Planner::new(schema.clone());
1559            let expr = planner.parse_filter(sql).unwrap();
1560            let expr = planner.optimize_expr(expr).unwrap();
1561
1562            match expr {
1563                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
1564                    Expr::Literal(value, _) => {
1565                        assert_eq!(&value.data_type(), &expected_data_type);
1566                    }
1567                    _ => panic!("Expected right to be a literal"),
1568                },
1569                _ => panic!("Expected binary expression"),
1570            }
1571        }
1572    }
1573
1574    #[test]
1575    fn test_sql_between() {
1576        use arrow_array::{Float64Array, Int32Array, TimestampMicrosecondArray};
1577        use arrow_schema::{DataType, Field, Schema, TimeUnit};
1578        use std::sync::Arc;
1579
1580        let schema = Arc::new(Schema::new(vec![
1581            Field::new("x", DataType::Int32, false),
1582            Field::new("y", DataType::Float64, false),
1583            Field::new(
1584                "ts",
1585                DataType::Timestamp(TimeUnit::Microsecond, None),
1586                false,
1587            ),
1588        ]));
1589
1590        let planner = Planner::new(schema.clone());
1591
1592        // Test integer BETWEEN
1593        let expr = planner
1594            .parse_filter("x BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
1595            .unwrap();
1596        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1597
1598        // Create timestamp array with values representing:
1599        // 2024-01-01 00:00:00 to 2024-01-01 00:00:09 (in microseconds)
1600        let base_ts = 1704067200000000_i64; // 2024-01-01 00:00:00
1601        let ts_array = TimestampMicrosecondArray::from_iter_values(
1602            (0..10).map(|i| base_ts + i * 1_000_000), // Each value is 1 second apart
1603        );
1604
1605        let batch = RecordBatch::try_new(
1606            schema,
1607            vec![
1608                Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
1609                Arc::new(Float64Array::from_iter_values((0..10).map(|v| v as f64))),
1610                Arc::new(ts_array),
1611            ],
1612        )
1613        .unwrap();
1614
1615        let predicates = physical_expr.evaluate(&batch).unwrap();
1616        assert_eq!(
1617            predicates.into_array(0).unwrap().as_ref(),
1618            &BooleanArray::from(vec![
1619                false, false, false, true, true, true, true, true, false, false
1620            ])
1621        );
1622
1623        // Test NOT BETWEEN
1624        let expr = planner
1625            .parse_filter("x NOT BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
1626            .unwrap();
1627        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1628
1629        let predicates = physical_expr.evaluate(&batch).unwrap();
1630        assert_eq!(
1631            predicates.into_array(0).unwrap().as_ref(),
1632            &BooleanArray::from(vec![
1633                true, true, true, false, false, false, false, false, true, true
1634            ])
1635        );
1636
1637        // Test floating point BETWEEN
1638        let expr = planner.parse_filter("y BETWEEN 2.5 AND 6.5").unwrap();
1639        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1640
1641        let predicates = physical_expr.evaluate(&batch).unwrap();
1642        assert_eq!(
1643            predicates.into_array(0).unwrap().as_ref(),
1644            &BooleanArray::from(vec![
1645                false, false, false, true, true, true, true, false, false, false
1646            ])
1647        );
1648
1649        // Test timestamp BETWEEN
1650        let expr = planner
1651            .parse_filter(
1652                "ts BETWEEN timestamp '2024-01-01 00:00:03' AND timestamp '2024-01-01 00:00:07'",
1653            )
1654            .unwrap();
1655        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1656
1657        let predicates = physical_expr.evaluate(&batch).unwrap();
1658        assert_eq!(
1659            predicates.into_array(0).unwrap().as_ref(),
1660            &BooleanArray::from(vec![
1661                false, false, false, true, true, true, true, true, false, false
1662            ])
1663        );
1664    }
1665
1666    #[test]
1667    fn test_sql_comparison() {
1668        // Create a batch with all data types
1669        let batch: Vec<(&str, ArrayRef)> = vec![
1670            (
1671                "timestamp_s",
1672                Arc::new(TimestampSecondArray::from_iter_values(0..10)),
1673            ),
1674            (
1675                "timestamp_ms",
1676                Arc::new(TimestampMillisecondArray::from_iter_values(0..10)),
1677            ),
1678            (
1679                "timestamp_us",
1680                Arc::new(TimestampMicrosecondArray::from_iter_values(0..10)),
1681            ),
1682            (
1683                "timestamp_ns",
1684                Arc::new(TimestampNanosecondArray::from_iter_values(4995..5005)),
1685            ),
1686        ];
1687        let batch = RecordBatch::try_from_iter(batch).unwrap();
1688
1689        let planner = Planner::new(batch.schema());
1690
1691        // Each expression is meant to select the final 5 rows
1692        let expressions = &[
1693            "timestamp_s >= TIMESTAMP '1970-01-01 00:00:05'",
1694            "timestamp_ms >= TIMESTAMP '1970-01-01 00:00:00.005'",
1695            "timestamp_us >= TIMESTAMP '1970-01-01 00:00:00.000005'",
1696            "timestamp_ns >= TIMESTAMP '1970-01-01 00:00:00.000005'",
1697        ];
1698
1699        let expected: ArrayRef = Arc::new(BooleanArray::from_iter(
1700            std::iter::repeat_n(Some(false), 5).chain(std::iter::repeat_n(Some(true), 5)),
1701        ));
1702        for expression in expressions {
1703            // convert to physical expression
1704            let logical_expr = planner.parse_filter(expression).unwrap();
1705            let logical_expr = planner.optimize_expr(logical_expr).unwrap();
1706            let physical_expr = planner.create_physical_expr(&logical_expr).unwrap();
1707
1708            // Evaluate and assert they have correct results
1709            let result = physical_expr.evaluate(&batch).unwrap();
1710            let result = result.into_array(batch.num_rows()).unwrap();
1711            assert_eq!(&expected, &result, "unexpected result for {}", expression);
1712        }
1713    }
1714
1715    #[test]
1716    fn test_columns_in_expr() {
1717        let expr = col("s0").gt(lit("value")).and(
1718            col("st")
1719                .field("st")
1720                .field("s2")
1721                .eq(lit("value"))
1722                .or(col("st")
1723                    .field("s1")
1724                    .in_list(vec![lit("value 1"), lit("value 2")], false)),
1725        );
1726
1727        let columns = Planner::column_names_in_expr(&expr);
1728        assert_eq!(columns, vec!["s0", "st.s1", "st.st.s2"]);
1729    }
1730
1731    #[test]
1732    fn test_parse_binary_expr() {
1733        let bin_str = "x'616263'";
1734
1735        let schema = Arc::new(Schema::new(vec![Field::new(
1736            "binary",
1737            DataType::Binary,
1738            true,
1739        )]));
1740        let planner = Planner::new(schema);
1741        let expr = planner.parse_expr(bin_str).unwrap();
1742        assert_eq!(
1743            expr,
1744            Expr::Literal(ScalarValue::Binary(Some(vec![b'a', b'b', b'c'])), None)
1745        );
1746    }
1747
1748    #[test]
1749    fn test_lance_context_provider_expr_planners() {
1750        let ctx_provider = LanceContextProvider::default();
1751        assert!(!ctx_provider.get_expr_planners().is_empty());
1752    }
1753
1754    #[test]
1755    fn test_regexp_match_and_non_empty_captions() {
1756        // Repro for a bug where regexp_match inside an AND chain wasn't coerced to boolean,
1757        // causing planning/evaluation failures. This should evaluate successfully.
1758        let schema = Arc::new(Schema::new(vec![
1759            Field::new("keywords", DataType::Utf8, true),
1760            Field::new("natural_caption", DataType::Utf8, true),
1761            Field::new("poetic_caption", DataType::Utf8, true),
1762        ]));
1763
1764        let planner = Planner::new(schema.clone());
1765
1766        let expr = planner
1767            .parse_filter(
1768                "regexp_match(keywords, 'Liberty|revolution') AND \
1769                 (natural_caption IS NOT NULL AND natural_caption <> '' AND \
1770                  poetic_caption IS NOT NULL AND poetic_caption <> '')",
1771            )
1772            .unwrap();
1773
1774        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1775
1776        let batch = RecordBatch::try_new(
1777            schema,
1778            vec![
1779                Arc::new(StringArray::from(vec![
1780                    Some("Liberty for all"),
1781                    Some("peace"),
1782                    Some("revolution now"),
1783                    Some("Liberty"),
1784                    Some("revolutionary"),
1785                    Some("none"),
1786                ])) as ArrayRef,
1787                Arc::new(StringArray::from(vec![
1788                    Some("a"),
1789                    Some("b"),
1790                    None,
1791                    Some(""),
1792                    Some("c"),
1793                    Some("d"),
1794                ])) as ArrayRef,
1795                Arc::new(StringArray::from(vec![
1796                    Some("x"),
1797                    Some(""),
1798                    Some("y"),
1799                    Some("z"),
1800                    None,
1801                    Some("w"),
1802                ])) as ArrayRef,
1803            ],
1804        )
1805        .unwrap();
1806
1807        let result = physical_expr.evaluate(&batch).unwrap();
1808        assert_eq!(
1809            result.into_array(0).unwrap().as_ref(),
1810            &BooleanArray::from(vec![true, false, false, false, false, false])
1811        );
1812    }
1813
1814    #[test]
1815    fn test_regexp_match_infer_error_without_boolean_coercion() {
1816        // With the fix applied, using parse_filter should coerce regexp_match to boolean
1817        // even when nested in a larger AND expression, so this should plan successfully.
1818        let schema = Arc::new(Schema::new(vec![
1819            Field::new("keywords", DataType::Utf8, true),
1820            Field::new("natural_caption", DataType::Utf8, true),
1821            Field::new("poetic_caption", DataType::Utf8, true),
1822        ]));
1823
1824        let planner = Planner::new(schema);
1825
1826        let expr = planner
1827            .parse_filter(
1828                "regexp_match(keywords, 'Liberty|revolution') AND \
1829                 (natural_caption IS NOT NULL AND natural_caption <> '' AND \
1830                  poetic_caption IS NOT NULL AND poetic_caption <> '')",
1831            )
1832            .unwrap();
1833
1834        // Should not panic
1835        let _physical = planner.create_physical_expr(&expr).unwrap();
1836    }
1837
1838    #[test]
1839    fn test_jsonb_literals() {
1840        let schema = Arc::new(Schema::new(vec![Field::new(
1841            "j",
1842            DataType::LargeBinary,
1843            true,
1844        )]));
1845        let planner = Planner::new(schema);
1846
1847        let cases = [
1848            ("jsonb '{\"key\": \"value\"}'", r#"{"key":"value"}"#),
1849            ("cast('{\"a\": 1}' as jsonb)", r#"{"a":1}"#),
1850            ("'{\"a\": 1}'::jsonb", r#"{"a":1}"#),
1851        ];
1852        for (sql, expected) in cases {
1853            let ast = parse_sql_expr(sql).unwrap();
1854            let expr = planner.parse_sql_expr(&ast).unwrap();
1855            match expr {
1856                Expr::Literal(ScalarValue::LargeBinary(Some(bytes)), _) => {
1857                    assert_eq!(
1858                        lance_arrow::json::decode_json(&bytes),
1859                        expected,
1860                        "failed for: {sql}"
1861                    );
1862                }
1863                other => panic!("Expected LargeBinary literal for '{sql}', got: {other:?}"),
1864            }
1865        }
1866    }
1867
1868    #[test]
1869    fn test_jsonb_literal_errors() {
1870        let schema = Arc::new(Schema::new(vec![Field::new(
1871            "j",
1872            DataType::LargeBinary,
1873            true,
1874        )]));
1875        let planner = Planner::new(schema);
1876
1877        // Invalid JSON content
1878        let ast = parse_sql_expr("jsonb 'not valid json'").unwrap();
1879        let err = planner.parse_sql_expr(&ast).unwrap_err();
1880        assert!(
1881            err.to_string().contains("Failed to encode JSONB"),
1882            "expected JSONB encoding error, got: {err}"
1883        );
1884
1885        // CAST with non-literal expression
1886        let ast = parse_sql_expr("cast(j as jsonb)").unwrap();
1887        let err = planner.parse_sql_expr(&ast).unwrap_err();
1888        assert!(
1889            err.to_string()
1890                .contains("CAST to JSONB only supports string literals"),
1891            "got: {err}"
1892        );
1893    }
1894}