// lance_datafusion/planner.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Exec plan planner
5
6use std::borrow::Cow;
7use std::collections::{BTreeSet, VecDeque};
8use std::sync::Arc;
9
10use crate::expr::safe_coerce_scalar;
11use crate::logical_expr::{coerce_filter_type_to_boolean, get_as_string_scalar_opt, resolve_expr};
12use crate::sql::{parse_sql_expr, parse_sql_filter};
13use arrow::compute::CastOptions;
14use arrow_array::ListArray;
15use arrow_buffer::OffsetBuffer;
16use arrow_schema::{DataType as ArrowDataType, Field, SchemaRef, TimeUnit};
17use arrow_select::concat::concat;
18use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
19use datafusion::common::DFSchema;
20use datafusion::config::ConfigOptions;
21use datafusion::error::Result as DFResult;
22use datafusion::execution::config::SessionConfig;
23use datafusion::execution::context::{SessionContext, SessionState};
24use datafusion::execution::runtime_env::RuntimeEnvBuilder;
25use datafusion::execution::session_state::SessionStateBuilder;
26use datafusion::logical_expr::expr::ScalarFunction;
27use datafusion::logical_expr::planner::{ExprPlanner, PlannerResult, RawFieldAccessExpr};
28use datafusion::logical_expr::{
29    AggregateUDF, ColumnarValue, GetFieldAccess, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
30    Signature, Volatility, WindowUDF,
31};
32use datafusion::optimizer::simplify_expressions::SimplifyContext;
33use datafusion::sql::planner::{
34    ContextProvider, NullOrdering, ParserOptions, PlannerContext, SqlToRel,
35};
36use datafusion::sql::sqlparser::ast::{
37    AccessExpr, Array as SQLArray, BinaryOperator, DataType as SQLDataType, ExactNumberInfo,
38    Expr as SQLExpr, Function, FunctionArg, FunctionArgExpr, FunctionArguments, Ident,
39    ObjectNamePart, Subscript, TimezoneInfo, UnaryOperator, Value, ValueWithSpan,
40};
41use datafusion::{
42    common::Column,
43    logical_expr::{col, Between, BinaryExpr, Like, Operator},
44    physical_expr::execution_props::ExecutionProps,
45    physical_plan::PhysicalExpr,
46    prelude::Expr,
47    scalar::ScalarValue,
48};
49use datafusion_functions::core::getfield::GetFieldFunc;
50use lance_arrow::cast::cast_with_options;
51use lance_core::datatypes::Schema;
52use lance_core::error::LanceOptionExt;
53use snafu::location;
54
55use lance_core::{Error, Result};
56
/// Scalar UDF (`_cast_list_f16`) that casts a `List` / `FixedSizeList` of
/// `Float32` or `Float16` elements to the equivalent list of `Float16`.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
struct CastListF16Udf {
    // Accepts exactly one argument of any type; the actual list/element type
    // is validated in `return_type` / `invoke_with_args`.
    signature: Signature,
}
61
62impl CastListF16Udf {
63    pub fn new() -> Self {
64        Self {
65            signature: Signature::any(1, Volatility::Immutable),
66        }
67    }
68}
69
70impl ScalarUDFImpl for CastListF16Udf {
71    fn as_any(&self) -> &dyn std::any::Any {
72        self
73    }
74
75    fn name(&self) -> &str {
76        "_cast_list_f16"
77    }
78
79    fn signature(&self) -> &Signature {
80        &self.signature
81    }
82
83    fn return_type(&self, arg_types: &[ArrowDataType]) -> DFResult<ArrowDataType> {
84        let input = &arg_types[0];
85        match input {
86            ArrowDataType::FixedSizeList(field, size) => {
87                if field.data_type() != &ArrowDataType::Float32
88                    && field.data_type() != &ArrowDataType::Float16
89                {
90                    return Err(datafusion::error::DataFusionError::Execution(
91                        "cast_list_f16 only supports list of float32 or float16".to_string(),
92                    ));
93                }
94                Ok(ArrowDataType::FixedSizeList(
95                    Arc::new(Field::new(
96                        field.name(),
97                        ArrowDataType::Float16,
98                        field.is_nullable(),
99                    )),
100                    *size,
101                ))
102            }
103            ArrowDataType::List(field) => {
104                if field.data_type() != &ArrowDataType::Float32
105                    && field.data_type() != &ArrowDataType::Float16
106                {
107                    return Err(datafusion::error::DataFusionError::Execution(
108                        "cast_list_f16 only supports list of float32 or float16".to_string(),
109                    ));
110                }
111                Ok(ArrowDataType::List(Arc::new(Field::new(
112                    field.name(),
113                    ArrowDataType::Float16,
114                    field.is_nullable(),
115                ))))
116            }
117            _ => Err(datafusion::error::DataFusionError::Execution(
118                "cast_list_f16 only supports FixedSizeList/List arguments".to_string(),
119            )),
120        }
121    }
122
123    fn invoke_with_args(&self, func_args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
124        let ColumnarValue::Array(arr) = &func_args.args[0] else {
125            return Err(datafusion::error::DataFusionError::Execution(
126                "cast_list_f16 only supports array arguments".to_string(),
127            ));
128        };
129
130        let to_type = match arr.data_type() {
131            ArrowDataType::FixedSizeList(field, size) => ArrowDataType::FixedSizeList(
132                Arc::new(Field::new(
133                    field.name(),
134                    ArrowDataType::Float16,
135                    field.is_nullable(),
136                )),
137                *size,
138            ),
139            ArrowDataType::List(field) => ArrowDataType::List(Arc::new(Field::new(
140                field.name(),
141                ArrowDataType::Float16,
142                field.is_nullable(),
143            ))),
144            _ => {
145                return Err(datafusion::error::DataFusionError::Execution(
146                    "cast_list_f16 only supports array arguments".to_string(),
147                ));
148            }
149        };
150
151        let res = cast_with_options(arr.as_ref(), &to_type, &CastOptions::default())?;
152        Ok(ColumnarValue::Array(res))
153    }
154}
155
// Adapter that instructs datafusion how lance expects expressions to be interpreted
struct LanceContextProvider {
    // Config options reported to the SQL planner via `ContextProvider::options`.
    options: datafusion::config::ConfigOptions,
    // Session state used to look up scalar/aggregate/window functions.
    state: SessionState,
    // Expression planners collected separately because `SessionState` does not
    // expose them directly (see `Default` impl below in the original file).
    expr_planners: Vec<Arc<dyn ExprPlanner>>,
}
162
impl Default for LanceContextProvider {
    fn default() -> Self {
        let config = SessionConfig::new();
        // NOTE(review): `build_arc` is assumed infallible for a default
        // runtime environment — confirm before relying on this in new paths.
        let runtime = RuntimeEnvBuilder::new().build_arc().unwrap();

        // Build a session so Lance-specific UDFs get registered into its state.
        let ctx = SessionContext::new_with_config_rt(config.clone(), runtime.clone());
        crate::udf::register_functions(&ctx);

        let state = ctx.state();

        // SessionState does not expose expr_planners, so we need to get them separately
        let mut state_builder = SessionStateBuilder::new()
            .with_config(config)
            .with_runtime_env(runtime)
            .with_default_features();

        // unwrap safe because with_default_features sets expr_planners
        let expr_planners = state_builder.expr_planners().as_ref().unwrap().clone();

        Self {
            options: ConfigOptions::default(),
            state,
            expr_planners,
        }
    }
}
189
190impl ContextProvider for LanceContextProvider {
191    fn get_table_source(
192        &self,
193        name: datafusion::sql::TableReference,
194    ) -> DFResult<Arc<dyn datafusion::logical_expr::TableSource>> {
195        Err(datafusion::error::DataFusionError::NotImplemented(format!(
196            "Attempt to reference inner table {} not supported",
197            name
198        )))
199    }
200
201    fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
202        self.state.aggregate_functions().get(name).cloned()
203    }
204
205    fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
206        self.state.window_functions().get(name).cloned()
207    }
208
209    fn get_function_meta(&self, f: &str) -> Option<Arc<ScalarUDF>> {
210        match f {
211            // TODO: cast should go thru CAST syntax instead of UDF
212            // Going thru UDF makes it hard for the optimizer to find no-ops
213            "_cast_list_f16" => Some(Arc::new(ScalarUDF::new_from_impl(CastListF16Udf::new()))),
214            _ => self.state.scalar_functions().get(f).cloned(),
215        }
216    }
217
218    fn get_variable_type(&self, _: &[String]) -> Option<ArrowDataType> {
219        // Variables (things like @@LANGUAGE) not supported
220        None
221    }
222
223    fn options(&self) -> &datafusion::config::ConfigOptions {
224        &self.options
225    }
226
227    fn udf_names(&self) -> Vec<String> {
228        self.state.scalar_functions().keys().cloned().collect()
229    }
230
231    fn udaf_names(&self) -> Vec<String> {
232        self.state.aggregate_functions().keys().cloned().collect()
233    }
234
235    fn udwf_names(&self) -> Vec<String> {
236        self.state.window_functions().keys().cloned().collect()
237    }
238
239    fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
240        &self.expr_planners
241    }
242}
243
/// Translates SQL expression strings / ASTs into DataFusion logical
/// expressions resolved against a fixed Arrow schema.
pub struct Planner {
    // Arrow schema column references are resolved against.
    schema: SchemaRef,
    // Supplies UDFs and expression planners to the SQL-to-rel translation.
    context_provider: LanceContextProvider,
    // When true, the first identifier of a compound reference is treated as a
    // relation name; see `with_enable_relations`.
    enable_relations: bool,
}
249
250impl Planner {
251    pub fn new(schema: SchemaRef) -> Self {
252        Self {
253            schema,
254            context_provider: LanceContextProvider::default(),
255            enable_relations: false,
256        }
257    }
258
    /// If passed with `true`, then the first identifier in column reference
    /// is parsed as the relation. For example, `table.field.inner` will be
    /// read as the nested field `field.inner` (`inner` on struct field `field`)
    /// on the `table` relation. If `false` (the default), then no relations
    /// are used and all identifiers are assumed to be a nested column path.
    ///
    /// Builder-style setter: consumes and returns `self`.
    pub fn with_enable_relations(mut self, enable_relations: bool) -> Self {
        self.enable_relations = enable_relations;
        self
    }
268
269    fn column(&self, idents: &[Ident]) -> Expr {
270        fn handle_remaining_idents(expr: &mut Expr, idents: &[Ident]) {
271            for ident in idents {
272                *expr = Expr::ScalarFunction(ScalarFunction {
273                    args: vec![
274                        std::mem::take(expr),
275                        Expr::Literal(ScalarValue::Utf8(Some(ident.value.clone())), None),
276                    ],
277                    func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
278                });
279            }
280        }
281
282        if self.enable_relations && idents.len() > 1 {
283            // Create qualified column reference (relation.column)
284            let relation = &idents[0].value;
285            let column_name = &idents[1].value;
286            let column = Expr::Column(Column::new(Some(relation.clone()), column_name.clone()));
287            let mut result = column;
288            handle_remaining_idents(&mut result, &idents[2..]);
289            result
290        } else {
291            // Default behavior - treat as struct field access
292            let mut column = col(&idents[0].value);
293            handle_remaining_idents(&mut column, &idents[1..]);
294            column
295        }
296    }
297
298    fn binary_op(&self, op: &BinaryOperator) -> Result<Operator> {
299        Ok(match op {
300            BinaryOperator::Plus => Operator::Plus,
301            BinaryOperator::Minus => Operator::Minus,
302            BinaryOperator::Multiply => Operator::Multiply,
303            BinaryOperator::Divide => Operator::Divide,
304            BinaryOperator::Modulo => Operator::Modulo,
305            BinaryOperator::StringConcat => Operator::StringConcat,
306            BinaryOperator::Gt => Operator::Gt,
307            BinaryOperator::Lt => Operator::Lt,
308            BinaryOperator::GtEq => Operator::GtEq,
309            BinaryOperator::LtEq => Operator::LtEq,
310            BinaryOperator::Eq => Operator::Eq,
311            BinaryOperator::NotEq => Operator::NotEq,
312            BinaryOperator::And => Operator::And,
313            BinaryOperator::Or => Operator::Or,
314            _ => {
315                return Err(Error::invalid_input(
316                    format!("Operator {op} is not supported"),
317                    location!(),
318                ));
319            }
320        })
321    }
322
323    fn binary_expr(&self, left: &SQLExpr, op: &BinaryOperator, right: &SQLExpr) -> Result<Expr> {
324        Ok(Expr::BinaryExpr(BinaryExpr::new(
325            Box::new(self.parse_sql_expr(left)?),
326            self.binary_op(op)?,
327            Box::new(self.parse_sql_expr(right)?),
328        )))
329    }
330
    /// Parses a SQL unary operator applied to `expr`.
    ///
    /// `NOT` / `!` become `Expr::Not`; unary minus applied to a numeric
    /// literal is folded directly into a negative literal (i64 first, then
    /// f64), while minus on any other expression becomes `Expr::Negative`.
    fn unary_expr(&self, op: &UnaryOperator, expr: &SQLExpr) -> Result<Expr> {
        Ok(match op {
            UnaryOperator::Not | UnaryOperator::PGBitwiseNot => {
                Expr::Not(Box::new(self.parse_sql_expr(expr)?))
            }

            UnaryOperator::Minus => {
                use datafusion::logical_expr::lit;
                match expr {
                    // NOTE(review): for "-9223372036854775808" (i64::MIN) the
                    // unsigned digits overflow the i64 parse and silently fall
                    // back to an f64 literal, losing precision — confirm this
                    // is acceptable.
                    SQLExpr::Value(ValueWithSpan { value: Value::Number(n, _), ..}) => match n.parse::<i64>() {
                        Ok(n) => lit(-n),
                        Err(_) => lit(-n
                            .parse::<f64>()
                            .map_err(|_e| {
                                Error::invalid_input(
                                    format!("negative operator can be only applied to integer and float operands, got: {n}"),
                                    location!(),
                                )
                            })?),
                    },
                    _ => {
                        Expr::Negative(Box::new(self.parse_sql_expr(expr)?))
                    }
                }
            }

            _ => {
                return Err(Error::invalid_input(
                    format!("Unary operator '{:?}' is not supported", op),
                    location!(),
                ));
            }
        })
    }
365
366    // See datafusion `sqlToRel::parse_sql_number()`
367    fn number(&self, value: &str, negative: bool) -> Result<Expr> {
368        use datafusion::logical_expr::lit;
369        let value: Cow<str> = if negative {
370            Cow::Owned(format!("-{}", value))
371        } else {
372            Cow::Borrowed(value)
373        };
374        if let Ok(n) = value.parse::<i64>() {
375            Ok(lit(n))
376        } else {
377            value.parse::<f64>().map(lit).map_err(|_| {
378                Error::invalid_input(
379                    format!("'{value}' is not supported number value."),
380                    location!(),
381                )
382            })
383        }
384    }
385
386    fn value(&self, value: &Value) -> Result<Expr> {
387        Ok(match value {
388            Value::Number(v, _) => self.number(v.as_str(), false)?,
389            Value::SingleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone())), None),
390            Value::HexStringLiteral(hsl) => {
391                Expr::Literal(ScalarValue::Binary(Self::try_decode_hex_literal(hsl)), None)
392            }
393            Value::DoubleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone())), None),
394            Value::Boolean(v) => Expr::Literal(ScalarValue::Boolean(Some(*v)), None),
395            Value::Null => Expr::Literal(ScalarValue::Null, None),
396            _ => todo!(),
397        })
398    }
399
400    fn parse_function_args(&self, func_args: &FunctionArg) -> Result<Expr> {
401        match func_args {
402            FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => self.parse_sql_expr(expr),
403            _ => Err(Error::invalid_input(
404                format!("Unsupported function args: {:?}", func_args),
405                location!(),
406            )),
407        }
408    }
409
410    // We now use datafusion to parse functions.  This allows us to use datafusion's
411    // entire collection of functions (previously we had just hard-coded support for two functions).
412    //
413    // Unfortunately, one of those two functions was is_valid and the reason we needed it was because
414    // this is a function that comes from duckdb.  Datafusion does not consider is_valid to be a function
415    // but rather an AST node (Expr::IsNotNull) and so we need to handle this case specially.
416    fn legacy_parse_function(&self, func: &Function) -> Result<Expr> {
417        match &func.args {
418            FunctionArguments::List(args) => {
419                if func.name.0.len() != 1 {
420                    return Err(Error::invalid_input(
421                        format!("Function name must have 1 part, got: {:?}", func.name.0),
422                        location!(),
423                    ));
424                }
425                Ok(Expr::IsNotNull(Box::new(
426                    self.parse_function_args(&args.args[0])?,
427                )))
428            }
429            _ => Err(Error::invalid_input(
430                format!("Unsupported function args: {:?}", &func.args),
431                location!(),
432            )),
433        }
434    }
435
436    fn parse_function(&self, function: SQLExpr) -> Result<Expr> {
437        if let SQLExpr::Function(function) = &function {
438            if let Some(ObjectNamePart::Identifier(name)) = &function.name.0.first() {
439                if &name.value == "is_valid" {
440                    return self.legacy_parse_function(function);
441                }
442            }
443        }
444        let sql_to_rel = SqlToRel::new_with_options(
445            &self.context_provider,
446            ParserOptions {
447                parse_float_as_decimal: false,
448                enable_ident_normalization: false,
449                support_varchar_with_length: false,
450                enable_options_value_normalization: false,
451                collect_spans: false,
452                map_string_types_to_utf8view: false,
453                default_null_ordering: NullOrdering::NullsMax,
454            },
455        );
456
457        let mut planner_context = PlannerContext::default();
458        let schema = DFSchema::try_from(self.schema.as_ref().clone())?;
459        sql_to_rel
460            .sql_to_expr(function, &schema, &mut planner_context)
461            .map_err(|e| Error::invalid_input(format!("Error parsing function: {e}"), location!()))
462    }
463
464    fn parse_type(&self, data_type: &SQLDataType) -> Result<ArrowDataType> {
465        const SUPPORTED_TYPES: [&str; 13] = [
466            "int [unsigned]",
467            "tinyint [unsigned]",
468            "smallint [unsigned]",
469            "bigint [unsigned]",
470            "float",
471            "double",
472            "string",
473            "binary",
474            "date",
475            "timestamp(precision)",
476            "datetime(precision)",
477            "decimal(precision,scale)",
478            "boolean",
479        ];
480        match data_type {
481            SQLDataType::String(_) => Ok(ArrowDataType::Utf8),
482            SQLDataType::Binary(_) => Ok(ArrowDataType::Binary),
483            SQLDataType::Float(_) => Ok(ArrowDataType::Float32),
484            SQLDataType::Double(_) => Ok(ArrowDataType::Float64),
485            SQLDataType::Boolean => Ok(ArrowDataType::Boolean),
486            SQLDataType::TinyInt(_) => Ok(ArrowDataType::Int8),
487            SQLDataType::SmallInt(_) => Ok(ArrowDataType::Int16),
488            SQLDataType::Int(_) | SQLDataType::Integer(_) => Ok(ArrowDataType::Int32),
489            SQLDataType::BigInt(_) => Ok(ArrowDataType::Int64),
490            SQLDataType::TinyIntUnsigned(_) => Ok(ArrowDataType::UInt8),
491            SQLDataType::SmallIntUnsigned(_) => Ok(ArrowDataType::UInt16),
492            SQLDataType::IntUnsigned(_) | SQLDataType::IntegerUnsigned(_) => {
493                Ok(ArrowDataType::UInt32)
494            }
495            SQLDataType::BigIntUnsigned(_) => Ok(ArrowDataType::UInt64),
496            SQLDataType::Date => Ok(ArrowDataType::Date32),
497            SQLDataType::Timestamp(resolution, tz) => {
498                match tz {
499                    TimezoneInfo::None => {}
500                    _ => {
501                        return Err(Error::invalid_input(
502                            "Timezone not supported in timestamp".to_string(),
503                            location!(),
504                        ));
505                    }
506                };
507                let time_unit = match resolution {
508                    // Default to microsecond to match PyArrow
509                    None => TimeUnit::Microsecond,
510                    Some(0) => TimeUnit::Second,
511                    Some(3) => TimeUnit::Millisecond,
512                    Some(6) => TimeUnit::Microsecond,
513                    Some(9) => TimeUnit::Nanosecond,
514                    _ => {
515                        return Err(Error::invalid_input(
516                            format!("Unsupported datetime resolution: {:?}", resolution),
517                            location!(),
518                        ));
519                    }
520                };
521                Ok(ArrowDataType::Timestamp(time_unit, None))
522            }
523            SQLDataType::Datetime(resolution) => {
524                let time_unit = match resolution {
525                    None => TimeUnit::Microsecond,
526                    Some(0) => TimeUnit::Second,
527                    Some(3) => TimeUnit::Millisecond,
528                    Some(6) => TimeUnit::Microsecond,
529                    Some(9) => TimeUnit::Nanosecond,
530                    _ => {
531                        return Err(Error::invalid_input(
532                            format!("Unsupported datetime resolution: {:?}", resolution),
533                            location!(),
534                        ));
535                    }
536                };
537                Ok(ArrowDataType::Timestamp(time_unit, None))
538            }
539            SQLDataType::Decimal(number_info) => match number_info {
540                ExactNumberInfo::PrecisionAndScale(precision, scale) => {
541                    Ok(ArrowDataType::Decimal128(*precision as u8, *scale as i8))
542                }
543                _ => Err(Error::invalid_input(
544                    format!(
545                        "Must provide precision and scale for decimal: {:?}",
546                        number_info
547                    ),
548                    location!(),
549                )),
550            },
551            _ => Err(Error::invalid_input(
552                format!(
553                    "Unsupported data type: {:?}. Supported types: {:?}",
554                    data_type, SUPPORTED_TYPES
555                ),
556                location!(),
557            )),
558        }
559    }
560
561    fn plan_field_access(&self, mut field_access_expr: RawFieldAccessExpr) -> Result<Expr> {
562        let df_schema = DFSchema::try_from(self.schema.as_ref().clone())?;
563        for planner in self.context_provider.get_expr_planners() {
564            match planner.plan_field_access(field_access_expr, &df_schema)? {
565                PlannerResult::Planned(expr) => return Ok(expr),
566                PlannerResult::Original(expr) => {
567                    field_access_expr = expr;
568                }
569            }
570        }
571        Err(Error::invalid_input(
572            "Field access could not be planned",
573            location!(),
574        ))
575    }
576
577    fn parse_sql_expr(&self, expr: &SQLExpr) -> Result<Expr> {
578        match expr {
579            SQLExpr::Identifier(id) => {
580                // Users can pass string literals wrapped in `"`.
581                // (Normally SQL only allows single quotes.)
582                if id.quote_style == Some('"') {
583                    Ok(Expr::Literal(
584                        ScalarValue::Utf8(Some(id.value.clone())),
585                        None,
586                    ))
587                // Users can wrap identifiers with ` to reference non-standard
588                // names, such as uppercase or spaces.
589                } else if id.quote_style == Some('`') {
590                    Ok(Expr::Column(Column::from_name(id.value.clone())))
591                } else {
592                    Ok(self.column(vec![id.clone()].as_slice()))
593                }
594            }
595            SQLExpr::CompoundIdentifier(ids) => Ok(self.column(ids.as_slice())),
596            SQLExpr::BinaryOp { left, op, right } => self.binary_expr(left, op, right),
597            SQLExpr::UnaryOp { op, expr } => self.unary_expr(op, expr),
598            SQLExpr::Value(value) => self.value(&value.value),
599            SQLExpr::Array(SQLArray { elem, .. }) => {
600                let mut values = vec![];
601
602                let array_literal_error = |pos: usize, value: &_| {
603                    Err(Error::invalid_input(
604                        format!(
605                            "Expected a literal value in array, instead got {} at position {}",
606                            value, pos
607                        ),
608                        location!(),
609                    ))
610                };
611
612                for (pos, expr) in elem.iter().enumerate() {
613                    match expr {
614                        SQLExpr::Value(value) => {
615                            if let Expr::Literal(value, _) = self.value(&value.value)? {
616                                values.push(value);
617                            } else {
618                                return array_literal_error(pos, expr);
619                            }
620                        }
621                        SQLExpr::UnaryOp {
622                            op: UnaryOperator::Minus,
623                            expr,
624                        } => {
625                            if let SQLExpr::Value(ValueWithSpan {
626                                value: Value::Number(number, _),
627                                ..
628                            }) = expr.as_ref()
629                            {
630                                if let Expr::Literal(value, _) = self.number(number, true)? {
631                                    values.push(value);
632                                } else {
633                                    return array_literal_error(pos, expr);
634                                }
635                            } else {
636                                return array_literal_error(pos, expr);
637                            }
638                        }
639                        _ => {
640                            return array_literal_error(pos, expr);
641                        }
642                    }
643                }
644
645                let field = if !values.is_empty() {
646                    let data_type = values[0].data_type();
647
648                    for value in &mut values {
649                        if value.data_type() != data_type {
650                            *value = safe_coerce_scalar(value, &data_type).ok_or_else(|| Error::invalid_input(
651                                format!("Array expressions must have a consistent datatype. Expected: {}, got: {}", data_type, value.data_type()),
652                                location!()
653                            ))?;
654                        }
655                    }
656                    Field::new("item", data_type, true)
657                } else {
658                    Field::new("item", ArrowDataType::Null, true)
659                };
660
661                let values = values
662                    .into_iter()
663                    .map(|v| v.to_array().map_err(Error::from))
664                    .collect::<Result<Vec<_>>>()?;
665                let array_refs = values.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
666                let values = concat(&array_refs)?;
667                let values = ListArray::try_new(
668                    field.into(),
669                    OffsetBuffer::from_lengths([values.len()]),
670                    values,
671                    None,
672                )?;
673
674                Ok(Expr::Literal(ScalarValue::List(Arc::new(values)), None))
675            }
676            // For example, DATE '2020-01-01'
677            SQLExpr::TypedString { data_type, value } => {
678                let value = value.clone().into_string().expect_ok()?;
679                Ok(Expr::Cast(datafusion::logical_expr::Cast {
680                    expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some(value)), None)),
681                    data_type: self.parse_type(data_type)?,
682                }))
683            }
684            SQLExpr::IsFalse(expr) => Ok(Expr::IsFalse(Box::new(self.parse_sql_expr(expr)?))),
685            SQLExpr::IsNotFalse(expr) => Ok(Expr::IsNotFalse(Box::new(self.parse_sql_expr(expr)?))),
686            SQLExpr::IsTrue(expr) => Ok(Expr::IsTrue(Box::new(self.parse_sql_expr(expr)?))),
687            SQLExpr::IsNotTrue(expr) => Ok(Expr::IsNotTrue(Box::new(self.parse_sql_expr(expr)?))),
688            SQLExpr::IsNull(expr) => Ok(Expr::IsNull(Box::new(self.parse_sql_expr(expr)?))),
689            SQLExpr::IsNotNull(expr) => Ok(Expr::IsNotNull(Box::new(self.parse_sql_expr(expr)?))),
690            SQLExpr::InList {
691                expr,
692                list,
693                negated,
694            } => {
695                let value_expr = self.parse_sql_expr(expr)?;
696                let list_exprs = list
697                    .iter()
698                    .map(|e| self.parse_sql_expr(e))
699                    .collect::<Result<Vec<_>>>()?;
700                Ok(value_expr.in_list(list_exprs, *negated))
701            }
702            SQLExpr::Nested(inner) => self.parse_sql_expr(inner.as_ref()),
703            SQLExpr::Function(_) => self.parse_function(expr.clone()),
704            SQLExpr::ILike {
705                negated,
706                expr,
707                pattern,
708                escape_char,
709                any: _,
710            } => Ok(Expr::Like(Like::new(
711                *negated,
712                Box::new(self.parse_sql_expr(expr)?),
713                Box::new(self.parse_sql_expr(pattern)?),
714                match escape_char {
715                    Some(Value::SingleQuotedString(char)) => char.chars().next(),
716                    Some(value) => return Err(Error::invalid_input(
717                        format!("Invalid escape character in LIKE expression. Expected a single character wrapped with single quotes, got {}", value),
718                        location!()
719                    )),
720                    None => None,
721                },
722                true,
723            ))),
724            SQLExpr::Like {
725                negated,
726                expr,
727                pattern,
728                escape_char,
729                any: _,
730            } => Ok(Expr::Like(Like::new(
731                *negated,
732                Box::new(self.parse_sql_expr(expr)?),
733                Box::new(self.parse_sql_expr(pattern)?),
734                match escape_char {
735                    Some(Value::SingleQuotedString(char)) => char.chars().next(),
736                    Some(value) => return Err(Error::invalid_input(
737                        format!("Invalid escape character in LIKE expression. Expected a single character wrapped with single quotes, got {}", value),
738                        location!()
739                    )),
740                    None => None,
741                },
742                false,
743            ))),
744            SQLExpr::Cast {
745                expr,
746                data_type,
747                kind,
748                ..
749            } => match kind {
750                datafusion::sql::sqlparser::ast::CastKind::TryCast
751                | datafusion::sql::sqlparser::ast::CastKind::SafeCast => {
752                    Ok(Expr::TryCast(datafusion::logical_expr::TryCast {
753                        expr: Box::new(self.parse_sql_expr(expr)?),
754                        data_type: self.parse_type(data_type)?,
755                    }))
756                }
757                _ => Ok(Expr::Cast(datafusion::logical_expr::Cast {
758                    expr: Box::new(self.parse_sql_expr(expr)?),
759                    data_type: self.parse_type(data_type)?,
760                })),
761            },
762            SQLExpr::JsonAccess { .. } => Err(Error::invalid_input(
763                "JSON access is not supported",
764                location!(),
765            )),
766            SQLExpr::CompoundFieldAccess { root, access_chain } => {
767                let mut expr = self.parse_sql_expr(root)?;
768
769                for access in access_chain {
770                    let field_access = match access {
771                        // x.y or x['y']
772                        AccessExpr::Dot(SQLExpr::Identifier(Ident { value: s, .. }))
773                        | AccessExpr::Subscript(Subscript::Index {
774                            index:
775                                SQLExpr::Value(ValueWithSpan {
776                                    value:
777                                        Value::SingleQuotedString(s) | Value::DoubleQuotedString(s),
778                                    ..
779                                }),
780                        }) => GetFieldAccess::NamedStructField {
781                            name: ScalarValue::from(s.as_str()),
782                        },
783                        AccessExpr::Subscript(Subscript::Index { index }) => {
784                            let key = Box::new(self.parse_sql_expr(index)?);
785                            GetFieldAccess::ListIndex { key }
786                        }
787                        AccessExpr::Subscript(Subscript::Slice { .. }) => {
788                            return Err(Error::invalid_input(
789                                "Slice subscript is not supported",
790                                location!(),
791                            ));
792                        }
793                        _ => {
794                            // Handle other cases like JSON access
795                            // Note: JSON access is not supported in lance
796                            return Err(Error::invalid_input(
797                                "Only dot notation or index access is supported for field access",
798                                location!(),
799                            ));
800                        }
801                    };
802
803                    let field_access_expr = RawFieldAccessExpr { expr, field_access };
804                    expr = self.plan_field_access(field_access_expr)?;
805                }
806
807                Ok(expr)
808            }
809            SQLExpr::Between {
810                expr,
811                negated,
812                low,
813                high,
814            } => {
815                // Parse the main expression and bounds
816                let expr = self.parse_sql_expr(expr)?;
817                let low = self.parse_sql_expr(low)?;
818                let high = self.parse_sql_expr(high)?;
819
820                let between = Expr::Between(Between::new(
821                    Box::new(expr),
822                    *negated,
823                    Box::new(low),
824                    Box::new(high),
825                ));
826                Ok(between)
827            }
828            _ => Err(Error::invalid_input(
829                format!("Expression '{expr}' is not supported SQL in lance"),
830                location!(),
831            )),
832        }
833    }
834
835    /// Create Logical [Expr] from a SQL filter clause.
836    ///
837    /// Note: the returned expression must be passed through [optimize_expr()]
838    /// before being passed to [create_physical_expr()].
839    pub fn parse_filter(&self, filter: &str) -> Result<Expr> {
840        // Allow sqlparser to parse filter as part of ONE SQL statement.
841        let ast_expr = parse_sql_filter(filter)?;
842        let expr = self.parse_sql_expr(&ast_expr)?;
843        let schema = Schema::try_from(self.schema.as_ref())?;
844        let resolved = resolve_expr(&expr, &schema).map_err(|e| {
845            Error::invalid_input(
846                format!("Error resolving filter expression {filter}: {e}"),
847                location!(),
848            )
849        })?;
850
851        Ok(coerce_filter_type_to_boolean(resolved))
852    }
853
854    /// Create Logical [Expr] from a SQL expression.
855    ///
856    /// Note: the returned expression must be passed through [optimize_filter()]
857    /// before being passed to [create_physical_expr()].
858    pub fn parse_expr(&self, expr: &str) -> Result<Expr> {
859        if self.schema.field_with_name(expr).is_ok() {
860            return Ok(col(expr));
861        }
862
863        let ast_expr = parse_sql_expr(expr)?;
864        let expr = self.parse_sql_expr(&ast_expr)?;
865        let schema = Schema::try_from(self.schema.as_ref())?;
866        let resolved = resolve_expr(&expr, &schema)?;
867        Ok(resolved)
868    }
869
870    /// Try to decode bytes from hex literal string.
871    ///
872    /// Copied from datafusion because this is not public.
873    ///
874    /// TODO: use SqlToRel from Datafusion directly?
875    fn try_decode_hex_literal(s: &str) -> Option<Vec<u8>> {
876        let hex_bytes = s.as_bytes();
877        let mut decoded_bytes = Vec::with_capacity(hex_bytes.len().div_ceil(2));
878
879        let start_idx = hex_bytes.len() % 2;
880        if start_idx > 0 {
881            // The first byte is formed of only one char.
882            decoded_bytes.push(Self::try_decode_hex_char(hex_bytes[0])?);
883        }
884
885        for i in (start_idx..hex_bytes.len()).step_by(2) {
886            let high = Self::try_decode_hex_char(hex_bytes[i])?;
887            let low = Self::try_decode_hex_char(hex_bytes[i + 1])?;
888            decoded_bytes.push((high << 4) | low);
889        }
890
891        Some(decoded_bytes)
892    }
893
894    /// Try to decode a byte from a hex char.
895    ///
896    /// None will be returned if the input char is hex-invalid.
897    const fn try_decode_hex_char(c: u8) -> Option<u8> {
898        match c {
899            b'A'..=b'F' => Some(c - b'A' + 10),
900            b'a'..=b'f' => Some(c - b'a' + 10),
901            b'0'..=b'9' => Some(c - b'0'),
902            _ => None,
903        }
904    }
905
906    /// Optimize the filter expression and coerce data types.
907    pub fn optimize_expr(&self, expr: Expr) -> Result<Expr> {
908        let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
909
910        // DataFusion needs the simplify and coerce passes to be applied before
911        // expressions can be handled by the physical planner.
912        let props = ExecutionProps::default();
913        let simplify_context = SimplifyContext::new(&props).with_schema(df_schema.clone());
914        let simplifier =
915            datafusion::optimizer::simplify_expressions::ExprSimplifier::new(simplify_context);
916
917        let expr = simplifier.simplify(expr)?;
918        let expr = simplifier.coerce(expr, &df_schema)?;
919
920        Ok(expr)
921    }
922
923    /// Create the [`PhysicalExpr`] from a logical [`Expr`]
924    pub fn create_physical_expr(&self, expr: &Expr) -> Result<Arc<dyn PhysicalExpr>> {
925        let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
926        Ok(datafusion::physical_expr::create_physical_expr(
927            expr,
928            df_schema.as_ref(),
929            &Default::default(),
930        )?)
931    }
932
933    /// Collect the columns in the expression.
934    ///
935    /// The columns are returned in sorted order.
936    ///
937    /// If the expr refers to nested columns these will be returned
938    /// as dotted paths (x.y.z)
939    pub fn column_names_in_expr(expr: &Expr) -> Vec<String> {
940        let mut visitor = ColumnCapturingVisitor {
941            current_path: VecDeque::new(),
942            columns: BTreeSet::new(),
943        };
944        expr.visit(&mut visitor).unwrap();
945        visitor.columns.into_iter().collect()
946    }
947}
948
/// Expression visitor that records every column referenced, flattening
/// nested `get_field` accesses into dotted paths (e.g. `a.b.c`).
struct ColumnCapturingVisitor {
    // Current column path. If this is empty, we are not in a column expression.
    current_path: VecDeque<String>,
    // Accumulated column paths; a BTreeSet keeps them sorted and de-duplicated.
    columns: BTreeSet<String>,
}
954
955impl TreeNodeVisitor<'_> for ColumnCapturingVisitor {
956    type Node = Expr;
957
958    fn f_down(&mut self, node: &Self::Node) -> DFResult<TreeNodeRecursion> {
959        match node {
960            Expr::Column(Column { name, .. }) => {
961                // Build the field path from the column name and any nested fields
962                // The nested field names from get_field already come as literal strings,
963                // so we just need to concatenate them properly
964                let mut path = name.clone();
965                for part in self.current_path.drain(..) {
966                    path.push('.');
967                    // Check if the part needs quoting (contains dots)
968                    if part.contains('.') || part.contains('`') {
969                        // Quote the field name with backticks and escape any existing backticks
970                        let escaped = part.replace('`', "``");
971                        path.push('`');
972                        path.push_str(&escaped);
973                        path.push('`');
974                    } else {
975                        path.push_str(&part);
976                    }
977                }
978                self.columns.insert(path);
979                self.current_path.clear();
980            }
981            Expr::ScalarFunction(udf) => {
982                if udf.name() == GetFieldFunc::default().name() {
983                    if let Some(name) = get_as_string_scalar_opt(&udf.args[1]) {
984                        self.current_path.push_front(name.to_string())
985                    } else {
986                        self.current_path.clear();
987                    }
988                } else {
989                    self.current_path.clear();
990                }
991            }
992            _ => {
993                self.current_path.clear();
994            }
995        }
996
997        Ok(TreeNodeRecursion::Continue)
998    }
999}
1000
1001#[cfg(test)]
1002mod tests {
1003
1004    use crate::logical_expr::ExprExt;
1005
1006    use super::*;
1007
1008    use arrow::datatypes::Float64Type;
1009    use arrow_array::{
1010        ArrayRef, BooleanArray, Float32Array, Int32Array, Int64Array, RecordBatch, StringArray,
1011        StructArray, TimestampMicrosecondArray, TimestampMillisecondArray,
1012        TimestampNanosecondArray, TimestampSecondArray,
1013    };
1014    use arrow_schema::{DataType, Fields, Schema};
1015    use datafusion::{
1016        logical_expr::{lit, Cast},
1017        prelude::{array_element, get_field},
1018    };
1019    use datafusion_functions::core::expr_ext::FieldAccessor;
1020
    #[test]
    fn test_parse_filter_simple() {
        // End-to-end check: parse a compound filter, lower it to a physical
        // expression, and evaluate it against a small record batch.
        let schema = Arc::new(Schema::new(vec![
            Field::new("i", DataType::Int32, false),
            Field::new("s", DataType::Utf8, true),
            Field::new(
                "st",
                DataType::Struct(Fields::from(vec![
                    Field::new("x", DataType::Float32, false),
                    Field::new("y", DataType::Float32, false),
                ])),
                true,
            ),
        ]));

        let planner = Planner::new(schema.clone());

        let expected = col("i")
            .gt(lit(3_i32))
            .and(col("st").field_newstyle("x").lt_eq(lit(5.0_f32)))
            .and(
                col("s")
                    .eq(lit("str-4"))
                    .or(col("s").in_list(vec![lit("str-4"), lit("str-5")], false)),
            );

        // `==` equality operator
        let expr = planner
            .parse_filter("i > 3 AND st.x <= 5.0 AND (s == 'str-4' OR s in ('str-4', 'str-5'))")
            .unwrap();
        assert_eq!(expr, expected);

        // single `=` equality operator
        let expr = planner
            .parse_filter("i > 3 AND st.x <= 5.0 AND (s = 'str-4' OR s in ('str-4', 'str-5'))")
            .unwrap();

        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
                Arc::new(StringArray::from_iter_values(
                    (0..10).map(|v| format!("str-{}", v)),
                )),
                Arc::new(StructArray::from(vec![
                    (
                        Arc::new(Field::new("x", DataType::Float32, false)),
                        Arc::new(Float32Array::from_iter_values((0..10).map(|v| v as f32)))
                            as ArrayRef,
                    ),
                    (
                        Arc::new(Field::new("y", DataType::Float32, false)),
                        Arc::new(Float32Array::from_iter_values(
                            (0..10).map(|v| (v * 10) as f32),
                        )),
                    ),
                ])),
            ],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        // Only rows 4 and 5 satisfy all three clauses (rows 6+ fail x <= 5.0).
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, false, true, true, false, false, false, false
            ])
        );
    }
1091
    #[test]
    fn test_nested_col_refs() {
        // Nested struct fields should be addressable via dot notation,
        // backtick-quoted identifiers, and subscript syntax alike, and all
        // forms should plan to the same get_field call chain.
        let schema = Arc::new(Schema::new(vec![
            Field::new("s0", DataType::Utf8, true),
            Field::new(
                "st",
                DataType::Struct(Fields::from(vec![
                    Field::new("s1", DataType::Utf8, true),
                    Field::new(
                        "st",
                        DataType::Struct(Fields::from(vec![Field::new(
                            "s2",
                            DataType::Utf8,
                            true,
                        )])),
                        true,
                    ),
                ])),
                true,
            ),
        ]));

        let planner = Planner::new(schema);

        // Parses `{expr} = 'val'` and asserts the left-hand side matches `expected`.
        fn assert_column_eq(planner: &Planner, expr: &str, expected: &Expr) {
            let expr = planner.parse_filter(&format!("{expr} = 'val'")).unwrap();
            assert!(matches!(
                expr,
                Expr::BinaryExpr(BinaryExpr {
                    left: _,
                    op: Operator::Eq,
                    right: _
                })
            ));
            if let Expr::BinaryExpr(BinaryExpr { left, .. }) = expr {
                assert_eq!(left.as_ref(), expected);
            }
        }

        // Top-level column, bare and backtick-quoted.
        let expected = Expr::Column(Column::new_unqualified("s0"));
        assert_column_eq(&planner, "s0", &expected);
        assert_column_eq(&planner, "`s0`", &expected);

        // One level deep: get_field(st, "s1").
        let expected = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
            args: vec![
                Expr::Column(Column::new_unqualified("st")),
                Expr::Literal(ScalarValue::Utf8(Some("s1".to_string())), None),
            ],
        });
        assert_column_eq(&planner, "st.s1", &expected);
        assert_column_eq(&planner, "`st`.`s1`", &expected);
        assert_column_eq(&planner, "st.`s1`", &expected);

        // Two levels deep: get_field(get_field(st, "st"), "s2").
        let expected = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
            args: vec![
                Expr::ScalarFunction(ScalarFunction {
                    func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
                    args: vec![
                        Expr::Column(Column::new_unqualified("st")),
                        Expr::Literal(ScalarValue::Utf8(Some("st".to_string())), None),
                    ],
                }),
                Expr::Literal(ScalarValue::Utf8(Some("s2".to_string())), None),
            ],
        });

        assert_column_eq(&planner, "st.st.s2", &expected);
        assert_column_eq(&planner, "`st`.`st`.`s2`", &expected);
        assert_column_eq(&planner, "st.st.`s2`", &expected);
        assert_column_eq(&planner, "st['st'][\"s2\"]", &expected);
    }
1165
1166    #[test]
1167    fn test_nested_list_refs() {
1168        let schema = Arc::new(Schema::new(vec![Field::new(
1169            "l",
1170            DataType::List(Arc::new(Field::new(
1171                "item",
1172                DataType::Struct(Fields::from(vec![Field::new("f1", DataType::Utf8, true)])),
1173                true,
1174            ))),
1175            true,
1176        )]));
1177
1178        let planner = Planner::new(schema);
1179
1180        let expected = array_element(col("l"), lit(0_i64));
1181        let expr = planner.parse_expr("l[0]").unwrap();
1182        assert_eq!(expr, expected);
1183
1184        let expected = get_field(array_element(col("l"), lit(0_i64)), "f1");
1185        let expr = planner.parse_expr("l[0]['f1']").unwrap();
1186        assert_eq!(expr, expected);
1187
1188        // FIXME: This should work, but sqlparser doesn't recognize anything
1189        // after the period for some reason.
1190        // let expr = planner.parse_expr("l[0].f1").unwrap();
1191        // assert_eq!(expr, expected);
1192    }
1193
1194    #[test]
1195    fn test_negative_expressions() {
1196        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
1197
1198        let planner = Planner::new(schema.clone());
1199
1200        let expected = col("x")
1201            .gt(lit(-3_i64))
1202            .and(col("x").lt(-(lit(-5_i64) + lit(3_i64))));
1203
1204        let expr = planner.parse_filter("x > -3 AND x < -(-5 + 3)").unwrap();
1205
1206        assert_eq!(expr, expected);
1207
1208        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1209
1210        let batch = RecordBatch::try_new(
1211            schema,
1212            vec![Arc::new(Int64Array::from_iter_values(-5..5)) as ArrayRef],
1213        )
1214        .unwrap();
1215        let predicates = physical_expr.evaluate(&batch).unwrap();
1216        assert_eq!(
1217            predicates.into_array(0).unwrap().as_ref(),
1218            &BooleanArray::from(vec![
1219                false, false, false, true, true, true, true, false, false, false
1220            ])
1221        );
1222    }
1223
1224    #[test]
1225    fn test_negative_array_expressions() {
1226        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
1227
1228        let planner = Planner::new(schema);
1229
1230        let expected = Expr::Literal(
1231            ScalarValue::List(Arc::new(
1232                ListArray::from_iter_primitive::<Float64Type, _, _>(vec![Some(
1233                    [-1_f64, -2.0, -3.0, -4.0, -5.0].map(Some),
1234                )]),
1235            )),
1236            None,
1237        );
1238
1239        let expr = planner
1240            .parse_expr("[-1.0, -2.0, -3.0, -4.0, -5.0]")
1241            .unwrap();
1242
1243        assert_eq!(expr, expected);
1244    }
1245
1246    #[test]
1247    fn test_sql_like() {
1248        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1249
1250        let planner = Planner::new(schema.clone());
1251
1252        let expected = col("s").like(lit("str-4"));
1253        // single quote
1254        let expr = planner.parse_filter("s LIKE 'str-4'").unwrap();
1255        assert_eq!(expr, expected);
1256        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1257
1258        let batch = RecordBatch::try_new(
1259            schema,
1260            vec![Arc::new(StringArray::from_iter_values(
1261                (0..10).map(|v| format!("str-{}", v)),
1262            ))],
1263        )
1264        .unwrap();
1265        let predicates = physical_expr.evaluate(&batch).unwrap();
1266        assert_eq!(
1267            predicates.into_array(0).unwrap().as_ref(),
1268            &BooleanArray::from(vec![
1269                false, false, false, false, true, false, false, false, false, false
1270            ])
1271        );
1272    }
1273
1274    #[test]
1275    fn test_not_like() {
1276        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1277
1278        let planner = Planner::new(schema.clone());
1279
1280        let expected = col("s").not_like(lit("str-4"));
1281        // single quote
1282        let expr = planner.parse_filter("s NOT LIKE 'str-4'").unwrap();
1283        assert_eq!(expr, expected);
1284        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1285
1286        let batch = RecordBatch::try_new(
1287            schema,
1288            vec![Arc::new(StringArray::from_iter_values(
1289                (0..10).map(|v| format!("str-{}", v)),
1290            ))],
1291        )
1292        .unwrap();
1293        let predicates = physical_expr.evaluate(&batch).unwrap();
1294        assert_eq!(
1295            predicates.into_array(0).unwrap().as_ref(),
1296            &BooleanArray::from(vec![
1297                true, true, true, true, false, true, true, true, true, true
1298            ])
1299        );
1300    }
1301
1302    #[test]
1303    fn test_sql_is_in() {
1304        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1305
1306        let planner = Planner::new(schema.clone());
1307
1308        let expected = col("s").in_list(vec![lit("str-4"), lit("str-5")], false);
1309        // single quote
1310        let expr = planner.parse_filter("s IN ('str-4', 'str-5')").unwrap();
1311        assert_eq!(expr, expected);
1312        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1313
1314        let batch = RecordBatch::try_new(
1315            schema,
1316            vec![Arc::new(StringArray::from_iter_values(
1317                (0..10).map(|v| format!("str-{}", v)),
1318            ))],
1319        )
1320        .unwrap();
1321        let predicates = physical_expr.evaluate(&batch).unwrap();
1322        assert_eq!(
1323            predicates.into_array(0).unwrap().as_ref(),
1324            &BooleanArray::from(vec![
1325                false, false, false, false, true, true, false, false, false, false
1326            ])
1327        );
1328    }
1329
1330    #[test]
1331    fn test_sql_is_null() {
1332        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1333
1334        let planner = Planner::new(schema.clone());
1335
1336        let expected = col("s").is_null();
1337        let expr = planner.parse_filter("s IS NULL").unwrap();
1338        assert_eq!(expr, expected);
1339        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1340
1341        let batch = RecordBatch::try_new(
1342            schema,
1343            vec![Arc::new(StringArray::from_iter((0..10).map(|v| {
1344                if v % 3 == 0 {
1345                    Some(format!("str-{}", v))
1346                } else {
1347                    None
1348                }
1349            })))],
1350        )
1351        .unwrap();
1352        let predicates = physical_expr.evaluate(&batch).unwrap();
1353        assert_eq!(
1354            predicates.into_array(0).unwrap().as_ref(),
1355            &BooleanArray::from(vec![
1356                false, true, true, false, true, true, false, true, true, false
1357            ])
1358        );
1359
1360        let expr = planner.parse_filter("s IS NOT NULL").unwrap();
1361        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1362        let predicates = physical_expr.evaluate(&batch).unwrap();
1363        assert_eq!(
1364            predicates.into_array(0).unwrap().as_ref(),
1365            &BooleanArray::from(vec![
1366                true, false, false, true, false, false, true, false, false, true,
1367            ])
1368        );
1369    }
1370
1371    #[test]
1372    fn test_sql_invert() {
1373        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Boolean, true)]));
1374
1375        let planner = Planner::new(schema.clone());
1376
1377        let expr = planner.parse_filter("NOT s").unwrap();
1378        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1379
1380        let batch = RecordBatch::try_new(
1381            schema,
1382            vec![Arc::new(BooleanArray::from_iter(
1383                (0..10).map(|v| Some(v % 3 == 0)),
1384            ))],
1385        )
1386        .unwrap();
1387        let predicates = physical_expr.evaluate(&batch).unwrap();
1388        assert_eq!(
1389            predicates.into_array(0).unwrap().as_ref(),
1390            &BooleanArray::from(vec![
1391                false, true, true, false, true, true, false, true, true, false
1392            ])
1393        );
1394    }
1395
    #[test]
    fn test_sql_cast() {
        // Each case pairs a SQL CAST expression with the Arrow data type the
        // SQL type name is expected to map to.
        let cases = &[
            (
                "x = cast('2021-01-01 00:00:00' as timestamp)",
                ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00' as timestamp(0))",
                ArrowDataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00.123' as timestamp(9))",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00.123' as datetime(9))",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            ("x = cast('2021-01-01' as date)", ArrowDataType::Date32),
            (
                "x = cast('1.238' as decimal(9,3))",
                ArrowDataType::Decimal128(9, 3),
            ),
            ("x = cast(1 as float)", ArrowDataType::Float32),
            ("x = cast(1 as double)", ArrowDataType::Float64),
            ("x = cast(1 as tinyint)", ArrowDataType::Int8),
            ("x = cast(1 as smallint)", ArrowDataType::Int16),
            ("x = cast(1 as int)", ArrowDataType::Int32),
            ("x = cast(1 as integer)", ArrowDataType::Int32),
            ("x = cast(1 as bigint)", ArrowDataType::Int64),
            ("x = cast(1 as tinyint unsigned)", ArrowDataType::UInt8),
            ("x = cast(1 as smallint unsigned)", ArrowDataType::UInt16),
            ("x = cast(1 as int unsigned)", ArrowDataType::UInt32),
            ("x = cast(1 as integer unsigned)", ArrowDataType::UInt32),
            ("x = cast(1 as bigint unsigned)", ArrowDataType::UInt64),
            ("x = cast(1 as boolean)", ArrowDataType::Boolean),
            ("x = cast(1 as string)", ArrowDataType::Utf8),
        ];

        for (sql, expected_data_type) in cases {
            // The column type matches the cast target so the filter resolves.
            let schema = Arc::new(Schema::new(vec![Field::new(
                "x",
                expected_data_type.clone(),
                true,
            )]));
            let planner = Planner::new(schema.clone());
            let expr = planner.parse_filter(sql).unwrap();

            // Get the thing after 'cast(` but before ' as'.
            let expected_value_str = sql
                .split("cast(")
                .nth(1)
                .unwrap()
                .split(" as")
                .next()
                .unwrap();
            // Remove any quote marks
            let expected_value_str = expected_value_str.trim_matches('\'');

            match expr {
                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
                    Expr::Cast(Cast { expr, data_type }) => {
                        match expr.as_ref() {
                            // Quoted values stay as Utf8 literals under the cast.
                            Expr::Literal(ScalarValue::Utf8(Some(value_str)), _) => {
                                assert_eq!(value_str, expected_value_str);
                            }
                            // Bare integer literals parse as Int64.
                            Expr::Literal(ScalarValue::Int64(Some(value)), _) => {
                                assert_eq!(*value, 1);
                            }
                            _ => panic!("Expected cast to be applied to literal"),
                        }
                        assert_eq!(data_type, expected_data_type);
                    }
                    _ => panic!("Expected right to be a cast"),
                },
                _ => panic!("Expected binary expression"),
            }
        }
    }
1476
    #[test]
    fn test_sql_literals() {
        // Typed literals (`timestamp '...'`, `date '...'`, ...) should parse
        // into a cast of the quoted string to the corresponding Arrow type.
        let cases = &[
            (
                "x = timestamp '2021-01-01 00:00:00'",
                ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
            ),
            (
                "x = timestamp(0) '2021-01-01 00:00:00'",
                ArrowDataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                "x = timestamp(9) '2021-01-01 00:00:00.123'",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            ("x = date '2021-01-01'", ArrowDataType::Date32),
            ("x = decimal(9,3) '1.238'", ArrowDataType::Decimal128(9, 3)),
        ];

        for (sql, expected_data_type) in cases {
            let schema = Arc::new(Schema::new(vec![Field::new(
                "x",
                expected_data_type.clone(),
                true,
            )]));
            let planner = Planner::new(schema.clone());
            let expr = planner.parse_filter(sql).unwrap();

            // The literal value is the text between the first pair of quotes.
            let expected_value_str = sql.split('\'').nth(1).unwrap();

            match expr {
                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
                    Expr::Cast(Cast { expr, data_type }) => {
                        match expr.as_ref() {
                            Expr::Literal(ScalarValue::Utf8(Some(value_str)), _) => {
                                assert_eq!(value_str, expected_value_str);
                            }
                            _ => panic!("Expected cast to be applied to literal"),
                        }
                        assert_eq!(data_type, expected_data_type);
                    }
                    _ => panic!("Expected right to be a cast"),
                },
                _ => panic!("Expected binary expression"),
            }
        }
    }
1524
1525    #[test]
1526    fn test_sql_array_literals() {
1527        let cases = [
1528            (
1529                "x = [1, 2, 3]",
1530                ArrowDataType::List(Arc::new(Field::new("item", ArrowDataType::Int64, true))),
1531            ),
1532            (
1533                "x = [1, 2, 3]",
1534                ArrowDataType::FixedSizeList(
1535                    Arc::new(Field::new("item", ArrowDataType::Int64, true)),
1536                    3,
1537                ),
1538            ),
1539        ];
1540
1541        for (sql, expected_data_type) in cases {
1542            let schema = Arc::new(Schema::new(vec![Field::new(
1543                "x",
1544                expected_data_type.clone(),
1545                true,
1546            )]));
1547            let planner = Planner::new(schema.clone());
1548            let expr = planner.parse_filter(sql).unwrap();
1549            let expr = planner.optimize_expr(expr).unwrap();
1550
1551            match expr {
1552                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
1553                    Expr::Literal(value, _) => {
1554                        assert_eq!(&value.data_type(), &expected_data_type);
1555                    }
1556                    _ => panic!("Expected right to be a literal"),
1557                },
1558                _ => panic!("Expected binary expression"),
1559            }
1560        }
1561    }
1562
1563    #[test]
1564    fn test_sql_between() {
1565        use arrow_array::{Float64Array, Int32Array, TimestampMicrosecondArray};
1566        use arrow_schema::{DataType, Field, Schema, TimeUnit};
1567        use std::sync::Arc;
1568
1569        let schema = Arc::new(Schema::new(vec![
1570            Field::new("x", DataType::Int32, false),
1571            Field::new("y", DataType::Float64, false),
1572            Field::new(
1573                "ts",
1574                DataType::Timestamp(TimeUnit::Microsecond, None),
1575                false,
1576            ),
1577        ]));
1578
1579        let planner = Planner::new(schema.clone());
1580
1581        // Test integer BETWEEN
1582        let expr = planner
1583            .parse_filter("x BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
1584            .unwrap();
1585        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1586
1587        // Create timestamp array with values representing:
1588        // 2024-01-01 00:00:00 to 2024-01-01 00:00:09 (in microseconds)
1589        let base_ts = 1704067200000000_i64; // 2024-01-01 00:00:00
1590        let ts_array = TimestampMicrosecondArray::from_iter_values(
1591            (0..10).map(|i| base_ts + i * 1_000_000), // Each value is 1 second apart
1592        );
1593
1594        let batch = RecordBatch::try_new(
1595            schema,
1596            vec![
1597                Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
1598                Arc::new(Float64Array::from_iter_values((0..10).map(|v| v as f64))),
1599                Arc::new(ts_array),
1600            ],
1601        )
1602        .unwrap();
1603
1604        let predicates = physical_expr.evaluate(&batch).unwrap();
1605        assert_eq!(
1606            predicates.into_array(0).unwrap().as_ref(),
1607            &BooleanArray::from(vec![
1608                false, false, false, true, true, true, true, true, false, false
1609            ])
1610        );
1611
1612        // Test NOT BETWEEN
1613        let expr = planner
1614            .parse_filter("x NOT BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
1615            .unwrap();
1616        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1617
1618        let predicates = physical_expr.evaluate(&batch).unwrap();
1619        assert_eq!(
1620            predicates.into_array(0).unwrap().as_ref(),
1621            &BooleanArray::from(vec![
1622                true, true, true, false, false, false, false, false, true, true
1623            ])
1624        );
1625
1626        // Test floating point BETWEEN
1627        let expr = planner.parse_filter("y BETWEEN 2.5 AND 6.5").unwrap();
1628        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1629
1630        let predicates = physical_expr.evaluate(&batch).unwrap();
1631        assert_eq!(
1632            predicates.into_array(0).unwrap().as_ref(),
1633            &BooleanArray::from(vec![
1634                false, false, false, true, true, true, true, false, false, false
1635            ])
1636        );
1637
1638        // Test timestamp BETWEEN
1639        let expr = planner
1640            .parse_filter(
1641                "ts BETWEEN timestamp '2024-01-01 00:00:03' AND timestamp '2024-01-01 00:00:07'",
1642            )
1643            .unwrap();
1644        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1645
1646        let predicates = physical_expr.evaluate(&batch).unwrap();
1647        assert_eq!(
1648            predicates.into_array(0).unwrap().as_ref(),
1649            &BooleanArray::from(vec![
1650                false, false, false, true, true, true, true, true, false, false
1651            ])
1652        );
1653    }
1654
1655    #[test]
1656    fn test_sql_comparison() {
1657        // Create a batch with all data types
1658        let batch: Vec<(&str, ArrayRef)> = vec![
1659            (
1660                "timestamp_s",
1661                Arc::new(TimestampSecondArray::from_iter_values(0..10)),
1662            ),
1663            (
1664                "timestamp_ms",
1665                Arc::new(TimestampMillisecondArray::from_iter_values(0..10)),
1666            ),
1667            (
1668                "timestamp_us",
1669                Arc::new(TimestampMicrosecondArray::from_iter_values(0..10)),
1670            ),
1671            (
1672                "timestamp_ns",
1673                Arc::new(TimestampNanosecondArray::from_iter_values(4995..5005)),
1674            ),
1675        ];
1676        let batch = RecordBatch::try_from_iter(batch).unwrap();
1677
1678        let planner = Planner::new(batch.schema());
1679
1680        // Each expression is meant to select the final 5 rows
1681        let expressions = &[
1682            "timestamp_s >= TIMESTAMP '1970-01-01 00:00:05'",
1683            "timestamp_ms >= TIMESTAMP '1970-01-01 00:00:00.005'",
1684            "timestamp_us >= TIMESTAMP '1970-01-01 00:00:00.000005'",
1685            "timestamp_ns >= TIMESTAMP '1970-01-01 00:00:00.000005'",
1686        ];
1687
1688        let expected: ArrayRef = Arc::new(BooleanArray::from_iter(
1689            std::iter::repeat_n(Some(false), 5).chain(std::iter::repeat_n(Some(true), 5)),
1690        ));
1691        for expression in expressions {
1692            // convert to physical expression
1693            let logical_expr = planner.parse_filter(expression).unwrap();
1694            let logical_expr = planner.optimize_expr(logical_expr).unwrap();
1695            let physical_expr = planner.create_physical_expr(&logical_expr).unwrap();
1696
1697            // Evaluate and assert they have correct results
1698            let result = physical_expr.evaluate(&batch).unwrap();
1699            let result = result.into_array(batch.num_rows()).unwrap();
1700            assert_eq!(&expected, &result, "unexpected result for {}", expression);
1701        }
1702    }
1703
1704    #[test]
1705    fn test_columns_in_expr() {
1706        let expr = col("s0").gt(lit("value")).and(
1707            col("st")
1708                .field("st")
1709                .field("s2")
1710                .eq(lit("value"))
1711                .or(col("st")
1712                    .field("s1")
1713                    .in_list(vec![lit("value 1"), lit("value 2")], false)),
1714        );
1715
1716        let columns = Planner::column_names_in_expr(&expr);
1717        assert_eq!(columns, vec!["s0", "st.s1", "st.st.s2"]);
1718    }
1719
1720    #[test]
1721    fn test_parse_binary_expr() {
1722        let bin_str = "x'616263'";
1723
1724        let schema = Arc::new(Schema::new(vec![Field::new(
1725            "binary",
1726            DataType::Binary,
1727            true,
1728        )]));
1729        let planner = Planner::new(schema);
1730        let expr = planner.parse_expr(bin_str).unwrap();
1731        assert_eq!(
1732            expr,
1733            Expr::Literal(ScalarValue::Binary(Some(vec![b'a', b'b', b'c'])), None)
1734        );
1735    }
1736
1737    #[test]
1738    fn test_lance_context_provider_expr_planners() {
1739        let ctx_provider = LanceContextProvider::default();
1740        assert!(!ctx_provider.get_expr_planners().is_empty());
1741    }
1742}