lance_datafusion/
planner.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Exec plan planner
5
6use std::borrow::Cow;
7use std::collections::{BTreeSet, VecDeque};
8use std::sync::Arc;
9
10use crate::expr::safe_coerce_scalar;
11use crate::logical_expr::{coerce_filter_type_to_boolean, get_as_string_scalar_opt, resolve_expr};
12use crate::sql::{parse_sql_expr, parse_sql_filter};
13use arrow::compute::CastOptions;
14use arrow_array::ListArray;
15use arrow_buffer::OffsetBuffer;
16use arrow_schema::{DataType as ArrowDataType, Field, SchemaRef, TimeUnit};
17use arrow_select::concat::concat;
18use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
19use datafusion::common::DFSchema;
20use datafusion::config::ConfigOptions;
21use datafusion::error::Result as DFResult;
22use datafusion::execution::config::SessionConfig;
23use datafusion::execution::context::SessionState;
24use datafusion::execution::runtime_env::RuntimeEnvBuilder;
25use datafusion::execution::session_state::SessionStateBuilder;
26use datafusion::logical_expr::expr::ScalarFunction;
27use datafusion::logical_expr::planner::{ExprPlanner, PlannerResult, RawFieldAccessExpr};
28use datafusion::logical_expr::{
29    AggregateUDF, ColumnarValue, GetFieldAccess, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
30    Signature, Volatility, WindowUDF,
31};
32use datafusion::optimizer::simplify_expressions::SimplifyContext;
33use datafusion::sql::planner::{ContextProvider, ParserOptions, PlannerContext, SqlToRel};
34use datafusion::sql::sqlparser::ast::{
35    AccessExpr, Array as SQLArray, BinaryOperator, DataType as SQLDataType, ExactNumberInfo,
36    Expr as SQLExpr, Function, FunctionArg, FunctionArgExpr, FunctionArguments, Ident,
37    ObjectNamePart, Subscript, TimezoneInfo, UnaryOperator, Value, ValueWithSpan,
38};
39use datafusion::{
40    common::Column,
41    logical_expr::{col, Between, BinaryExpr, Like, Operator},
42    physical_expr::execution_props::ExecutionProps,
43    physical_plan::PhysicalExpr,
44    prelude::Expr,
45    scalar::ScalarValue,
46};
47use datafusion_functions::core::getfield::GetFieldFunc;
48use lance_arrow::cast::cast_with_options;
49use lance_core::datatypes::Schema;
50use lance_core::error::LanceOptionExt;
51use snafu::location;
52
53use lance_core::{Error, Result};
54
/// Scalar UDF (`_cast_list_f16`) that casts a `List`/`FixedSizeList` of
/// `Float32`/`Float16` values into the same list shape with `Float16` items.
#[derive(Debug, Clone)]
struct CastListF16Udf {
    // Accepts any single argument; the actual container/element types are
    // validated in `return_type` / `invoke_with_args`.
    signature: Signature,
}

impl CastListF16Udf {
    /// Creates the UDF with a one-argument, immutable (pure) signature.
    pub fn new() -> Self {
        Self {
            signature: Signature::any(1, Volatility::Immutable),
        }
    }
}
67
68impl ScalarUDFImpl for CastListF16Udf {
69    fn as_any(&self) -> &dyn std::any::Any {
70        self
71    }
72
73    fn name(&self) -> &str {
74        "_cast_list_f16"
75    }
76
77    fn signature(&self) -> &Signature {
78        &self.signature
79    }
80
81    fn return_type(&self, arg_types: &[ArrowDataType]) -> DFResult<ArrowDataType> {
82        let input = &arg_types[0];
83        match input {
84            ArrowDataType::FixedSizeList(field, size) => {
85                if field.data_type() != &ArrowDataType::Float32
86                    && field.data_type() != &ArrowDataType::Float16
87                {
88                    return Err(datafusion::error::DataFusionError::Execution(
89                        "cast_list_f16 only supports list of float32 or float16".to_string(),
90                    ));
91                }
92                Ok(ArrowDataType::FixedSizeList(
93                    Arc::new(Field::new(
94                        field.name(),
95                        ArrowDataType::Float16,
96                        field.is_nullable(),
97                    )),
98                    *size,
99                ))
100            }
101            ArrowDataType::List(field) => {
102                if field.data_type() != &ArrowDataType::Float32
103                    && field.data_type() != &ArrowDataType::Float16
104                {
105                    return Err(datafusion::error::DataFusionError::Execution(
106                        "cast_list_f16 only supports list of float32 or float16".to_string(),
107                    ));
108                }
109                Ok(ArrowDataType::List(Arc::new(Field::new(
110                    field.name(),
111                    ArrowDataType::Float16,
112                    field.is_nullable(),
113                ))))
114            }
115            _ => Err(datafusion::error::DataFusionError::Execution(
116                "cast_list_f16 only supports FixedSizeList/List arguments".to_string(),
117            )),
118        }
119    }
120
121    fn invoke_with_args(&self, func_args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
122        let ColumnarValue::Array(arr) = &func_args.args[0] else {
123            return Err(datafusion::error::DataFusionError::Execution(
124                "cast_list_f16 only supports array arguments".to_string(),
125            ));
126        };
127
128        let to_type = match arr.data_type() {
129            ArrowDataType::FixedSizeList(field, size) => ArrowDataType::FixedSizeList(
130                Arc::new(Field::new(
131                    field.name(),
132                    ArrowDataType::Float16,
133                    field.is_nullable(),
134                )),
135                *size,
136            ),
137            ArrowDataType::List(field) => ArrowDataType::List(Arc::new(Field::new(
138                field.name(),
139                ArrowDataType::Float16,
140                field.is_nullable(),
141            ))),
142            _ => {
143                return Err(datafusion::error::DataFusionError::Execution(
144                    "cast_list_f16 only supports array arguments".to_string(),
145                ));
146            }
147        };
148
149        let res = cast_with_options(arr.as_ref(), &to_type, &CastOptions::default())?;
150        Ok(ColumnarValue::Array(res))
151    }
152}
153
// Adapter that instructs datafusion how lance expects expressions to be interpreted
struct LanceContextProvider {
    // Config options reported to the SQL planner via `ContextProvider::options`.
    options: datafusion::config::ConfigOptions,
    // Session used purely as a registry of the default scalar/aggregate/window
    // functions (see the `ContextProvider` impl).
    state: SessionState,
    // Default expr planners captured from the builder, because SessionState
    // does not expose them after construction.
    expr_planners: Vec<Arc<dyn ExprPlanner>>,
}
160
161impl Default for LanceContextProvider {
162    fn default() -> Self {
163        let config = SessionConfig::new();
164        let runtime = RuntimeEnvBuilder::new().build_arc().unwrap();
165        let mut state_builder = SessionStateBuilder::new()
166            .with_config(config)
167            .with_runtime_env(runtime)
168            .with_default_features();
169
170        // SessionState does not expose expr_planners, so we need to get the default ones from
171        // the builder and store them to return from get_expr_planners
172
173        // unwrap safe because with_default_features sets expr_planners
174        let expr_planners = state_builder.expr_planners().as_ref().unwrap().clone();
175
176        Self {
177            options: ConfigOptions::default(),
178            state: state_builder.build(),
179            expr_planners,
180        }
181    }
182}
183
184impl ContextProvider for LanceContextProvider {
185    fn get_table_source(
186        &self,
187        name: datafusion::sql::TableReference,
188    ) -> DFResult<Arc<dyn datafusion::logical_expr::TableSource>> {
189        Err(datafusion::error::DataFusionError::NotImplemented(format!(
190            "Attempt to reference inner table {} not supported",
191            name
192        )))
193    }
194
195    fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
196        self.state.aggregate_functions().get(name).cloned()
197    }
198
199    fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
200        self.state.window_functions().get(name).cloned()
201    }
202
203    fn get_function_meta(&self, f: &str) -> Option<Arc<ScalarUDF>> {
204        match f {
205            // TODO: cast should go thru CAST syntax instead of UDF
206            // Going thru UDF makes it hard for the optimizer to find no-ops
207            "_cast_list_f16" => Some(Arc::new(ScalarUDF::new_from_impl(CastListF16Udf::new()))),
208            _ => self.state.scalar_functions().get(f).cloned(),
209        }
210    }
211
212    fn get_variable_type(&self, _: &[String]) -> Option<ArrowDataType> {
213        // Variables (things like @@LANGUAGE) not supported
214        None
215    }
216
217    fn options(&self) -> &datafusion::config::ConfigOptions {
218        &self.options
219    }
220
221    fn udf_names(&self) -> Vec<String> {
222        self.state.scalar_functions().keys().cloned().collect()
223    }
224
225    fn udaf_names(&self) -> Vec<String> {
226        self.state.aggregate_functions().keys().cloned().collect()
227    }
228
229    fn udwf_names(&self) -> Vec<String> {
230        self.state.window_functions().keys().cloned().collect()
231    }
232
233    fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
234        &self.expr_planners
235    }
236}
237
/// Parses Lance SQL filter/projection expressions against an Arrow schema and
/// plans them into datafusion expressions.
pub struct Planner {
    // Arrow schema that column references are resolved against.
    schema: SchemaRef,
    // Supplies function lookup and expr planners to datafusion's SqlToRel.
    context_provider: LanceContextProvider,
    // When true, the first identifier of a compound reference is treated as a
    // relation name (see `with_enable_relations`).
    enable_relations: bool,
}
243
244impl Planner {
245    pub fn new(schema: SchemaRef) -> Self {
246        Self {
247            schema,
248            context_provider: LanceContextProvider::default(),
249            enable_relations: false,
250        }
251    }
252
253    /// If passed with `true`, then the first identifier in column reference
254    /// is parsed as the relation. For example, `table.field.inner` will be
255    /// read as the nested field `field.inner` (`inner` on struct field `field`)
256    /// on the `table` relation. If `false` (the default), then no relations
257    /// are used and all identifiers are assumed to be a nested column path.
258    pub fn with_enable_relations(mut self, enable_relations: bool) -> Self {
259        self.enable_relations = enable_relations;
260        self
261    }
262
263    fn column(&self, idents: &[Ident]) -> Expr {
264        fn handle_remaining_idents(expr: &mut Expr, idents: &[Ident]) {
265            for ident in idents {
266                *expr = Expr::ScalarFunction(ScalarFunction {
267                    args: vec![
268                        std::mem::take(expr),
269                        Expr::Literal(ScalarValue::Utf8(Some(ident.value.clone())), None),
270                    ],
271                    func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
272                });
273            }
274        }
275
276        if self.enable_relations && idents.len() > 1 {
277            // Create qualified column reference (relation.column)
278            let relation = &idents[0].value;
279            let column_name = &idents[1].value;
280            let column = Expr::Column(Column::new(Some(relation.clone()), column_name.clone()));
281            let mut result = column;
282            handle_remaining_idents(&mut result, &idents[2..]);
283            result
284        } else {
285            // Default behavior - treat as struct field access
286            let mut column = col(&idents[0].value);
287            handle_remaining_idents(&mut column, &idents[1..]);
288            column
289        }
290    }
291
292    fn binary_op(&self, op: &BinaryOperator) -> Result<Operator> {
293        Ok(match op {
294            BinaryOperator::Plus => Operator::Plus,
295            BinaryOperator::Minus => Operator::Minus,
296            BinaryOperator::Multiply => Operator::Multiply,
297            BinaryOperator::Divide => Operator::Divide,
298            BinaryOperator::Modulo => Operator::Modulo,
299            BinaryOperator::StringConcat => Operator::StringConcat,
300            BinaryOperator::Gt => Operator::Gt,
301            BinaryOperator::Lt => Operator::Lt,
302            BinaryOperator::GtEq => Operator::GtEq,
303            BinaryOperator::LtEq => Operator::LtEq,
304            BinaryOperator::Eq => Operator::Eq,
305            BinaryOperator::NotEq => Operator::NotEq,
306            BinaryOperator::And => Operator::And,
307            BinaryOperator::Or => Operator::Or,
308            _ => {
309                return Err(Error::invalid_input(
310                    format!("Operator {op} is not supported"),
311                    location!(),
312                ));
313            }
314        })
315    }
316
317    fn binary_expr(&self, left: &SQLExpr, op: &BinaryOperator, right: &SQLExpr) -> Result<Expr> {
318        Ok(Expr::BinaryExpr(BinaryExpr::new(
319            Box::new(self.parse_sql_expr(left)?),
320            self.binary_op(op)?,
321            Box::new(self.parse_sql_expr(right)?),
322        )))
323    }
324
325    fn unary_expr(&self, op: &UnaryOperator, expr: &SQLExpr) -> Result<Expr> {
326        Ok(match op {
327            UnaryOperator::Not | UnaryOperator::PGBitwiseNot => {
328                Expr::Not(Box::new(self.parse_sql_expr(expr)?))
329            }
330
331            UnaryOperator::Minus => {
332                use datafusion::logical_expr::lit;
333                match expr {
334                    SQLExpr::Value(ValueWithSpan { value: Value::Number(n, _), ..}) => match n.parse::<i64>() {
335                        Ok(n) => lit(-n),
336                        Err(_) => lit(-n
337                            .parse::<f64>()
338                            .map_err(|_e| {
339                                Error::invalid_input(
340                                    format!("negative operator can be only applied to integer and float operands, got: {n}"),
341                                    location!(),
342                                )
343                            })?),
344                    },
345                    _ => {
346                        Expr::Negative(Box::new(self.parse_sql_expr(expr)?))
347                    }
348                }
349            }
350
351            _ => {
352                return Err(Error::invalid_input(
353                    format!("Unary operator '{:?}' is not supported", op),
354                    location!(),
355                ));
356            }
357        })
358    }
359
360    // See datafusion `sqlToRel::parse_sql_number()`
361    fn number(&self, value: &str, negative: bool) -> Result<Expr> {
362        use datafusion::logical_expr::lit;
363        let value: Cow<str> = if negative {
364            Cow::Owned(format!("-{}", value))
365        } else {
366            Cow::Borrowed(value)
367        };
368        if let Ok(n) = value.parse::<i64>() {
369            Ok(lit(n))
370        } else {
371            value.parse::<f64>().map(lit).map_err(|_| {
372                Error::invalid_input(
373                    format!("'{value}' is not supported number value."),
374                    location!(),
375                )
376            })
377        }
378    }
379
380    fn value(&self, value: &Value) -> Result<Expr> {
381        Ok(match value {
382            Value::Number(v, _) => self.number(v.as_str(), false)?,
383            Value::SingleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone())), None),
384            Value::HexStringLiteral(hsl) => {
385                Expr::Literal(ScalarValue::Binary(Self::try_decode_hex_literal(hsl)), None)
386            }
387            Value::DoubleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone())), None),
388            Value::Boolean(v) => Expr::Literal(ScalarValue::Boolean(Some(*v)), None),
389            Value::Null => Expr::Literal(ScalarValue::Null, None),
390            _ => todo!(),
391        })
392    }
393
394    fn parse_function_args(&self, func_args: &FunctionArg) -> Result<Expr> {
395        match func_args {
396            FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => self.parse_sql_expr(expr),
397            _ => Err(Error::invalid_input(
398                format!("Unsupported function args: {:?}", func_args),
399                location!(),
400            )),
401        }
402    }
403
404    // We now use datafusion to parse functions.  This allows us to use datafusion's
405    // entire collection of functions (previously we had just hard-coded support for two functions).
406    //
407    // Unfortunately, one of those two functions was is_valid and the reason we needed it was because
408    // this is a function that comes from duckdb.  Datafusion does not consider is_valid to be a function
409    // but rather an AST node (Expr::IsNotNull) and so we need to handle this case specially.
410    fn legacy_parse_function(&self, func: &Function) -> Result<Expr> {
411        match &func.args {
412            FunctionArguments::List(args) => {
413                if func.name.0.len() != 1 {
414                    return Err(Error::invalid_input(
415                        format!("Function name must have 1 part, got: {:?}", func.name.0),
416                        location!(),
417                    ));
418                }
419                Ok(Expr::IsNotNull(Box::new(
420                    self.parse_function_args(&args.args[0])?,
421                )))
422            }
423            _ => Err(Error::invalid_input(
424                format!("Unsupported function args: {:?}", &func.args),
425                location!(),
426            )),
427        }
428    }
429
430    fn parse_function(&self, function: SQLExpr) -> Result<Expr> {
431        if let SQLExpr::Function(function) = &function {
432            if let Some(ObjectNamePart::Identifier(name)) = &function.name.0.first() {
433                if &name.value == "is_valid" {
434                    return self.legacy_parse_function(function);
435                }
436            }
437        }
438        let sql_to_rel = SqlToRel::new_with_options(
439            &self.context_provider,
440            ParserOptions {
441                parse_float_as_decimal: false,
442                enable_ident_normalization: false,
443                support_varchar_with_length: false,
444                enable_options_value_normalization: false,
445                collect_spans: false,
446                map_varchar_to_utf8view: false,
447            },
448        );
449
450        let mut planner_context = PlannerContext::default();
451        let schema = DFSchema::try_from(self.schema.as_ref().clone())?;
452        sql_to_rel
453            .sql_to_expr(function, &schema, &mut planner_context)
454            .map_err(|e| Error::invalid_input(format!("Error parsing function: {e}"), location!()))
455    }
456
457    fn parse_type(&self, data_type: &SQLDataType) -> Result<ArrowDataType> {
458        const SUPPORTED_TYPES: [&str; 13] = [
459            "int [unsigned]",
460            "tinyint [unsigned]",
461            "smallint [unsigned]",
462            "bigint [unsigned]",
463            "float",
464            "double",
465            "string",
466            "binary",
467            "date",
468            "timestamp(precision)",
469            "datetime(precision)",
470            "decimal(precision,scale)",
471            "boolean",
472        ];
473        match data_type {
474            SQLDataType::String(_) => Ok(ArrowDataType::Utf8),
475            SQLDataType::Binary(_) => Ok(ArrowDataType::Binary),
476            SQLDataType::Float(_) => Ok(ArrowDataType::Float32),
477            SQLDataType::Double(_) => Ok(ArrowDataType::Float64),
478            SQLDataType::Boolean => Ok(ArrowDataType::Boolean),
479            SQLDataType::TinyInt(_) => Ok(ArrowDataType::Int8),
480            SQLDataType::SmallInt(_) => Ok(ArrowDataType::Int16),
481            SQLDataType::Int(_) | SQLDataType::Integer(_) => Ok(ArrowDataType::Int32),
482            SQLDataType::BigInt(_) => Ok(ArrowDataType::Int64),
483            SQLDataType::TinyIntUnsigned(_) => Ok(ArrowDataType::UInt8),
484            SQLDataType::SmallIntUnsigned(_) => Ok(ArrowDataType::UInt16),
485            SQLDataType::IntUnsigned(_) | SQLDataType::IntegerUnsigned(_) => {
486                Ok(ArrowDataType::UInt32)
487            }
488            SQLDataType::BigIntUnsigned(_) => Ok(ArrowDataType::UInt64),
489            SQLDataType::Date => Ok(ArrowDataType::Date32),
490            SQLDataType::Timestamp(resolution, tz) => {
491                match tz {
492                    TimezoneInfo::None => {}
493                    _ => {
494                        return Err(Error::invalid_input(
495                            "Timezone not supported in timestamp".to_string(),
496                            location!(),
497                        ));
498                    }
499                };
500                let time_unit = match resolution {
501                    // Default to microsecond to match PyArrow
502                    None => TimeUnit::Microsecond,
503                    Some(0) => TimeUnit::Second,
504                    Some(3) => TimeUnit::Millisecond,
505                    Some(6) => TimeUnit::Microsecond,
506                    Some(9) => TimeUnit::Nanosecond,
507                    _ => {
508                        return Err(Error::invalid_input(
509                            format!("Unsupported datetime resolution: {:?}", resolution),
510                            location!(),
511                        ));
512                    }
513                };
514                Ok(ArrowDataType::Timestamp(time_unit, None))
515            }
516            SQLDataType::Datetime(resolution) => {
517                let time_unit = match resolution {
518                    None => TimeUnit::Microsecond,
519                    Some(0) => TimeUnit::Second,
520                    Some(3) => TimeUnit::Millisecond,
521                    Some(6) => TimeUnit::Microsecond,
522                    Some(9) => TimeUnit::Nanosecond,
523                    _ => {
524                        return Err(Error::invalid_input(
525                            format!("Unsupported datetime resolution: {:?}", resolution),
526                            location!(),
527                        ));
528                    }
529                };
530                Ok(ArrowDataType::Timestamp(time_unit, None))
531            }
532            SQLDataType::Decimal(number_info) => match number_info {
533                ExactNumberInfo::PrecisionAndScale(precision, scale) => {
534                    Ok(ArrowDataType::Decimal128(*precision as u8, *scale as i8))
535                }
536                _ => Err(Error::invalid_input(
537                    format!(
538                        "Must provide precision and scale for decimal: {:?}",
539                        number_info
540                    ),
541                    location!(),
542                )),
543            },
544            _ => Err(Error::invalid_input(
545                format!(
546                    "Unsupported data type: {:?}. Supported types: {:?}",
547                    data_type, SUPPORTED_TYPES
548                ),
549                location!(),
550            )),
551        }
552    }
553
554    fn plan_field_access(&self, mut field_access_expr: RawFieldAccessExpr) -> Result<Expr> {
555        let df_schema = DFSchema::try_from(self.schema.as_ref().clone())?;
556        for planner in self.context_provider.get_expr_planners() {
557            match planner.plan_field_access(field_access_expr, &df_schema)? {
558                PlannerResult::Planned(expr) => return Ok(expr),
559                PlannerResult::Original(expr) => {
560                    field_access_expr = expr;
561                }
562            }
563        }
564        Err(Error::invalid_input(
565            "Field access could not be planned",
566            location!(),
567        ))
568    }
569
570    fn parse_sql_expr(&self, expr: &SQLExpr) -> Result<Expr> {
571        match expr {
572            SQLExpr::Identifier(id) => {
573                // Users can pass string literals wrapped in `"`.
574                // (Normally SQL only allows single quotes.)
575                if id.quote_style == Some('"') {
576                    Ok(Expr::Literal(
577                        ScalarValue::Utf8(Some(id.value.clone())),
578                        None,
579                    ))
580                // Users can wrap identifiers with ` to reference non-standard
581                // names, such as uppercase or spaces.
582                } else if id.quote_style == Some('`') {
583                    Ok(Expr::Column(Column::from_name(id.value.clone())))
584                } else {
585                    Ok(self.column(vec![id.clone()].as_slice()))
586                }
587            }
588            SQLExpr::CompoundIdentifier(ids) => Ok(self.column(ids.as_slice())),
589            SQLExpr::BinaryOp { left, op, right } => self.binary_expr(left, op, right),
590            SQLExpr::UnaryOp { op, expr } => self.unary_expr(op, expr),
591            SQLExpr::Value(value) => self.value(&value.value),
592            SQLExpr::Array(SQLArray { elem, .. }) => {
593                let mut values = vec![];
594
595                let array_literal_error = |pos: usize, value: &_| {
596                    Err(Error::invalid_input(
597                        format!(
598                            "Expected a literal value in array, instead got {} at position {}",
599                            value, pos
600                        ),
601                        location!(),
602                    ))
603                };
604
605                for (pos, expr) in elem.iter().enumerate() {
606                    match expr {
607                        SQLExpr::Value(value) => {
608                            if let Expr::Literal(value, _) = self.value(&value.value)? {
609                                values.push(value);
610                            } else {
611                                return array_literal_error(pos, expr);
612                            }
613                        }
614                        SQLExpr::UnaryOp {
615                            op: UnaryOperator::Minus,
616                            expr,
617                        } => {
618                            if let SQLExpr::Value(ValueWithSpan {
619                                value: Value::Number(number, _),
620                                ..
621                            }) = expr.as_ref()
622                            {
623                                if let Expr::Literal(value, _) = self.number(number, true)? {
624                                    values.push(value);
625                                } else {
626                                    return array_literal_error(pos, expr);
627                                }
628                            } else {
629                                return array_literal_error(pos, expr);
630                            }
631                        }
632                        _ => {
633                            return array_literal_error(pos, expr);
634                        }
635                    }
636                }
637
638                let field = if !values.is_empty() {
639                    let data_type = values[0].data_type();
640
641                    for value in &mut values {
642                        if value.data_type() != data_type {
643                            *value = safe_coerce_scalar(value, &data_type).ok_or_else(|| Error::invalid_input(
644                                format!("Array expressions must have a consistent datatype. Expected: {}, got: {}", data_type, value.data_type()),
645                                location!()
646                            ))?;
647                        }
648                    }
649                    Field::new("item", data_type, true)
650                } else {
651                    Field::new("item", ArrowDataType::Null, true)
652                };
653
654                let values = values
655                    .into_iter()
656                    .map(|v| v.to_array().map_err(Error::from))
657                    .collect::<Result<Vec<_>>>()?;
658                let array_refs = values.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
659                let values = concat(&array_refs)?;
660                let values = ListArray::try_new(
661                    field.into(),
662                    OffsetBuffer::from_lengths([values.len()]),
663                    values,
664                    None,
665                )?;
666
667                Ok(Expr::Literal(ScalarValue::List(Arc::new(values)), None))
668            }
669            // For example, DATE '2020-01-01'
670            SQLExpr::TypedString { data_type, value } => {
671                let value = value.clone().into_string().expect_ok()?;
672                Ok(Expr::Cast(datafusion::logical_expr::Cast {
673                    expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some(value)), None)),
674                    data_type: self.parse_type(data_type)?,
675                }))
676            }
677            SQLExpr::IsFalse(expr) => Ok(Expr::IsFalse(Box::new(self.parse_sql_expr(expr)?))),
678            SQLExpr::IsNotFalse(expr) => Ok(Expr::IsNotFalse(Box::new(self.parse_sql_expr(expr)?))),
679            SQLExpr::IsTrue(expr) => Ok(Expr::IsTrue(Box::new(self.parse_sql_expr(expr)?))),
680            SQLExpr::IsNotTrue(expr) => Ok(Expr::IsNotTrue(Box::new(self.parse_sql_expr(expr)?))),
681            SQLExpr::IsNull(expr) => Ok(Expr::IsNull(Box::new(self.parse_sql_expr(expr)?))),
682            SQLExpr::IsNotNull(expr) => Ok(Expr::IsNotNull(Box::new(self.parse_sql_expr(expr)?))),
683            SQLExpr::InList {
684                expr,
685                list,
686                negated,
687            } => {
688                let value_expr = self.parse_sql_expr(expr)?;
689                let list_exprs = list
690                    .iter()
691                    .map(|e| self.parse_sql_expr(e))
692                    .collect::<Result<Vec<_>>>()?;
693                Ok(value_expr.in_list(list_exprs, *negated))
694            }
695            SQLExpr::Nested(inner) => self.parse_sql_expr(inner.as_ref()),
696            SQLExpr::Function(_) => self.parse_function(expr.clone()),
697            SQLExpr::ILike {
698                negated,
699                expr,
700                pattern,
701                escape_char,
702                any: _,
703            } => Ok(Expr::Like(Like::new(
704                *negated,
705                Box::new(self.parse_sql_expr(expr)?),
706                Box::new(self.parse_sql_expr(pattern)?),
707                escape_char.as_ref().and_then(|c| c.chars().next()),
708                true,
709            ))),
710            SQLExpr::Like {
711                negated,
712                expr,
713                pattern,
714                escape_char,
715                any: _,
716            } => Ok(Expr::Like(Like::new(
717                *negated,
718                Box::new(self.parse_sql_expr(expr)?),
719                Box::new(self.parse_sql_expr(pattern)?),
720                escape_char.as_ref().and_then(|c| c.chars().next()),
721                false,
722            ))),
723            SQLExpr::Cast {
724                expr,
725                data_type,
726                kind,
727                ..
728            } => match kind {
729                datafusion::sql::sqlparser::ast::CastKind::TryCast
730                | datafusion::sql::sqlparser::ast::CastKind::SafeCast => {
731                    Ok(Expr::TryCast(datafusion::logical_expr::TryCast {
732                        expr: Box::new(self.parse_sql_expr(expr)?),
733                        data_type: self.parse_type(data_type)?,
734                    }))
735                }
736                _ => Ok(Expr::Cast(datafusion::logical_expr::Cast {
737                    expr: Box::new(self.parse_sql_expr(expr)?),
738                    data_type: self.parse_type(data_type)?,
739                })),
740            },
741            SQLExpr::JsonAccess { .. } => Err(Error::invalid_input(
742                "JSON access is not supported",
743                location!(),
744            )),
745            SQLExpr::CompoundFieldAccess { root, access_chain } => {
746                let mut expr = self.parse_sql_expr(root)?;
747
748                for access in access_chain {
749                    let field_access = match access {
750                        // x.y or x['y']
751                        AccessExpr::Dot(SQLExpr::Identifier(Ident { value: s, .. }))
752                        | AccessExpr::Subscript(Subscript::Index {
753                            index:
754                                SQLExpr::Value(ValueWithSpan {
755                                    value:
756                                        Value::SingleQuotedString(s) | Value::DoubleQuotedString(s),
757                                    ..
758                                }),
759                        }) => GetFieldAccess::NamedStructField {
760                            name: ScalarValue::from(s.as_str()),
761                        },
762                        AccessExpr::Subscript(Subscript::Index { index }) => {
763                            let key = Box::new(self.parse_sql_expr(index)?);
764                            GetFieldAccess::ListIndex { key }
765                        }
766                        AccessExpr::Subscript(Subscript::Slice { .. }) => {
767                            return Err(Error::invalid_input(
768                                "Slice subscript is not supported",
769                                location!(),
770                            ));
771                        }
772                        _ => {
773                            // Handle other cases like JSON access
774                            // Note: JSON access is not supported in lance
775                            return Err(Error::invalid_input(
776                                "Only dot notation or index access is supported for field access",
777                                location!(),
778                            ));
779                        }
780                    };
781
782                    let field_access_expr = RawFieldAccessExpr { expr, field_access };
783                    expr = self.plan_field_access(field_access_expr)?;
784                }
785
786                Ok(expr)
787            }
788            SQLExpr::Between {
789                expr,
790                negated,
791                low,
792                high,
793            } => {
794                // Parse the main expression and bounds
795                let expr = self.parse_sql_expr(expr)?;
796                let low = self.parse_sql_expr(low)?;
797                let high = self.parse_sql_expr(high)?;
798
799                let between = Expr::Between(Between::new(
800                    Box::new(expr),
801                    *negated,
802                    Box::new(low),
803                    Box::new(high),
804                ));
805                Ok(between)
806            }
807            _ => Err(Error::invalid_input(
808                format!("Expression '{expr}' is not supported SQL in lance"),
809                location!(),
810            )),
811        }
812    }
813
814    /// Create Logical [Expr] from a SQL filter clause.
815    ///
816    /// Note: the returned expression must be passed through [optimize_expr()]
817    /// before being passed to [create_physical_expr()].
818    pub fn parse_filter(&self, filter: &str) -> Result<Expr> {
819        // Allow sqlparser to parse filter as part of ONE SQL statement.
820        let ast_expr = parse_sql_filter(filter)?;
821        let expr = self.parse_sql_expr(&ast_expr)?;
822        let schema = Schema::try_from(self.schema.as_ref())?;
823        let resolved = resolve_expr(&expr, &schema).map_err(|e| {
824            Error::invalid_input(
825                format!("Error resolving filter expression {filter}: {e}"),
826                location!(),
827            )
828        })?;
829
830        Ok(coerce_filter_type_to_boolean(resolved))
831    }
832
833    /// Create Logical [Expr] from a SQL expression.
834    ///
835    /// Note: the returned expression must be passed through [optimize_filter()]
836    /// before being passed to [create_physical_expr()].
837    pub fn parse_expr(&self, expr: &str) -> Result<Expr> {
838        let ast_expr = parse_sql_expr(expr)?;
839        let expr = self.parse_sql_expr(&ast_expr)?;
840        let schema = Schema::try_from(self.schema.as_ref())?;
841        let resolved = resolve_expr(&expr, &schema)?;
842        Ok(resolved)
843    }
844
845    /// Try to decode bytes from hex literal string.
846    ///
847    /// Copied from datafusion because this is not public.
848    ///
849    /// TODO: use SqlToRel from Datafusion directly?
850    fn try_decode_hex_literal(s: &str) -> Option<Vec<u8>> {
851        let hex_bytes = s.as_bytes();
852        let mut decoded_bytes = Vec::with_capacity(hex_bytes.len().div_ceil(2));
853
854        let start_idx = hex_bytes.len() % 2;
855        if start_idx > 0 {
856            // The first byte is formed of only one char.
857            decoded_bytes.push(Self::try_decode_hex_char(hex_bytes[0])?);
858        }
859
860        for i in (start_idx..hex_bytes.len()).step_by(2) {
861            let high = Self::try_decode_hex_char(hex_bytes[i])?;
862            let low = Self::try_decode_hex_char(hex_bytes[i + 1])?;
863            decoded_bytes.push((high << 4) | low);
864        }
865
866        Some(decoded_bytes)
867    }
868
869    /// Try to decode a byte from a hex char.
870    ///
871    /// None will be returned if the input char is hex-invalid.
872    const fn try_decode_hex_char(c: u8) -> Option<u8> {
873        match c {
874            b'A'..=b'F' => Some(c - b'A' + 10),
875            b'a'..=b'f' => Some(c - b'a' + 10),
876            b'0'..=b'9' => Some(c - b'0'),
877            _ => None,
878        }
879    }
880
881    /// Optimize the filter expression and coerce data types.
882    pub fn optimize_expr(&self, expr: Expr) -> Result<Expr> {
883        let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
884
885        // DataFusion needs the simplify and coerce passes to be applied before
886        // expressions can be handled by the physical planner.
887        let props = ExecutionProps::default();
888        let simplify_context = SimplifyContext::new(&props).with_schema(df_schema.clone());
889        let simplifier =
890            datafusion::optimizer::simplify_expressions::ExprSimplifier::new(simplify_context);
891
892        let expr = simplifier.simplify(expr)?;
893        let expr = simplifier.coerce(expr, &df_schema)?;
894
895        Ok(expr)
896    }
897
898    /// Create the [`PhysicalExpr`] from a logical [`Expr`]
899    pub fn create_physical_expr(&self, expr: &Expr) -> Result<Arc<dyn PhysicalExpr>> {
900        let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
901        Ok(datafusion::physical_expr::create_physical_expr(
902            expr,
903            df_schema.as_ref(),
904            &Default::default(),
905        )?)
906    }
907
908    /// Collect the columns in the expression.
909    ///
910    /// The columns are returned in sorted order.
911    ///
912    /// If the expr refers to nested columns these will be returned
913    /// as dotted paths (x.y.z)
914    pub fn column_names_in_expr(expr: &Expr) -> Vec<String> {
915        let mut visitor = ColumnCapturingVisitor {
916            current_path: VecDeque::new(),
917            columns: BTreeSet::new(),
918        };
919        expr.visit(&mut visitor).unwrap();
920        visitor.columns.into_iter().collect()
921    }
922}
923
/// Expression-tree visitor that records every column referenced by an
/// [`Expr`], flattening nested field accesses into dotted paths (`x.y.z`).
struct ColumnCapturingVisitor {
    // Current column path. If this is empty, we are not in a column expression.
    current_path: VecDeque<String>,
    // Completed column paths; a BTreeSet keeps them sorted and deduplicated.
    columns: BTreeSet<String>,
}
929
930impl TreeNodeVisitor<'_> for ColumnCapturingVisitor {
931    type Node = Expr;
932
933    fn f_down(&mut self, node: &Self::Node) -> DFResult<TreeNodeRecursion> {
934        match node {
935            Expr::Column(Column { name, .. }) => {
936                let mut path = name.clone();
937                for part in self.current_path.drain(..) {
938                    path.push('.');
939                    path.push_str(&part);
940                }
941                self.columns.insert(path);
942                self.current_path.clear();
943            }
944            Expr::ScalarFunction(udf) => {
945                if udf.name() == GetFieldFunc::default().name() {
946                    if let Some(name) = get_as_string_scalar_opt(&udf.args[1]) {
947                        self.current_path.push_front(name.to_string())
948                    } else {
949                        self.current_path.clear();
950                    }
951                } else {
952                    self.current_path.clear();
953                }
954            }
955            _ => {
956                self.current_path.clear();
957            }
958        }
959
960        Ok(TreeNodeRecursion::Continue)
961    }
962}
963
964#[cfg(test)]
965mod tests {
966
967    use crate::logical_expr::ExprExt;
968
969    use super::*;
970
971    use arrow::datatypes::Float64Type;
972    use arrow_array::{
973        ArrayRef, BooleanArray, Float32Array, Int32Array, Int64Array, RecordBatch, StringArray,
974        StructArray, TimestampMicrosecondArray, TimestampMillisecondArray,
975        TimestampNanosecondArray, TimestampSecondArray,
976    };
977    use arrow_schema::{DataType, Fields, Schema};
978    use datafusion::{
979        logical_expr::{lit, Cast},
980        prelude::{array_element, get_field},
981    };
982    use datafusion_functions::core::expr_ext::FieldAccessor;
983
    #[test]
    fn test_parse_filter_simple() {
        // Schema with a scalar int, a nullable string, and a nested struct so
        // the filter exercises comparisons, IN lists, and nested field access.
        let schema = Arc::new(Schema::new(vec![
            Field::new("i", DataType::Int32, false),
            Field::new("s", DataType::Utf8, true),
            Field::new(
                "st",
                DataType::Struct(Fields::from(vec![
                    Field::new("x", DataType::Float32, false),
                    Field::new("y", DataType::Float32, false),
                ])),
                true,
            ),
        ]));

        let planner = Planner::new(schema.clone());

        let expected = col("i")
            .gt(lit(3_i32))
            .and(col("st").field_newstyle("x").lt_eq(lit(5.0_f32)))
            .and(
                col("s")
                    .eq(lit("str-4"))
                    .or(col("s").in_list(vec![lit("str-4"), lit("str-5")], false)),
            );

        // `==` form of equality
        let expr = planner
            .parse_filter("i > 3 AND st.x <= 5.0 AND (s == 'str-4' OR s in ('str-4', 'str-5'))")
            .unwrap();
        assert_eq!(expr, expected);

        // `=` form of equality should parse to the same predicate
        let expr = planner
            .parse_filter("i > 3 AND st.x <= 5.0 AND (s = 'str-4' OR s in ('str-4', 'str-5'))")
            .unwrap();

        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        // Rows: i = 0..10, s = "str-0".."str-9", st.x = 0.0..9.0, st.y = 10*x
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
                Arc::new(StringArray::from_iter_values(
                    (0..10).map(|v| format!("str-{}", v)),
                )),
                Arc::new(StructArray::from(vec![
                    (
                        Arc::new(Field::new("x", DataType::Float32, false)),
                        Arc::new(Float32Array::from_iter_values((0..10).map(|v| v as f32)))
                            as ArrayRef,
                    ),
                    (
                        Arc::new(Field::new("y", DataType::Float32, false)),
                        Arc::new(Float32Array::from_iter_values(
                            (0..10).map(|v| (v * 10) as f32),
                        )),
                    ),
                ])),
            ],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        // Only rows 4 and 5 satisfy i > 3 AND st.x <= 5.0 AND s IN (...).
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, false, true, true, false, false, false, false
            ])
        );
    }
1054
    #[test]
    fn test_nested_col_refs() {
        // Two levels of struct nesting so the planner must chain `get_field`
        // calls: s0, st.s1, and st.st.s2.
        let schema = Arc::new(Schema::new(vec![
            Field::new("s0", DataType::Utf8, true),
            Field::new(
                "st",
                DataType::Struct(Fields::from(vec![
                    Field::new("s1", DataType::Utf8, true),
                    Field::new(
                        "st",
                        DataType::Struct(Fields::from(vec![Field::new(
                            "s2",
                            DataType::Utf8,
                            true,
                        )])),
                        true,
                    ),
                ])),
                true,
            ),
        ]));

        let planner = Planner::new(schema);

        // Parses `<expr> = 'val'` and asserts that the left-hand side is the
        // expected column/field-access expression.
        fn assert_column_eq(planner: &Planner, expr: &str, expected: &Expr) {
            let expr = planner.parse_filter(&format!("{expr} = 'val'")).unwrap();
            assert!(matches!(
                expr,
                Expr::BinaryExpr(BinaryExpr {
                    left: _,
                    op: Operator::Eq,
                    right: _
                })
            ));
            if let Expr::BinaryExpr(BinaryExpr { left, .. }) = expr {
                assert_eq!(left.as_ref(), expected);
            }
        }

        // Plain column, with and without backtick quoting.
        let expected = Expr::Column(Column::new_unqualified("s0"));
        assert_column_eq(&planner, "s0", &expected);
        assert_column_eq(&planner, "`s0`", &expected);

        // One level of nesting -> a single get_field(st, 's1') call.
        let expected = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
            args: vec![
                Expr::Column(Column::new_unqualified("st")),
                Expr::Literal(ScalarValue::Utf8(Some("s1".to_string())), None),
            ],
        });
        assert_column_eq(&planner, "st.s1", &expected);
        assert_column_eq(&planner, "`st`.`s1`", &expected);
        assert_column_eq(&planner, "st.`s1`", &expected);

        // Two levels of nesting -> get_field(get_field(st, 'st'), 's2').
        let expected = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
            args: vec![
                Expr::ScalarFunction(ScalarFunction {
                    func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
                    args: vec![
                        Expr::Column(Column::new_unqualified("st")),
                        Expr::Literal(ScalarValue::Utf8(Some("st".to_string())), None),
                    ],
                }),
                Expr::Literal(ScalarValue::Utf8(Some("s2".to_string())), None),
            ],
        });

        // Dot notation, backticks, and subscript syntax all resolve the same.
        assert_column_eq(&planner, "st.st.s2", &expected);
        assert_column_eq(&planner, "`st`.`st`.`s2`", &expected);
        assert_column_eq(&planner, "st.st.`s2`", &expected);
        assert_column_eq(&planner, "st['st'][\"s2\"]", &expected);
    }
1128
1129    #[test]
1130    fn test_nested_list_refs() {
1131        let schema = Arc::new(Schema::new(vec![Field::new(
1132            "l",
1133            DataType::List(Arc::new(Field::new(
1134                "item",
1135                DataType::Struct(Fields::from(vec![Field::new("f1", DataType::Utf8, true)])),
1136                true,
1137            ))),
1138            true,
1139        )]));
1140
1141        let planner = Planner::new(schema);
1142
1143        let expected = array_element(col("l"), lit(0_i64));
1144        let expr = planner.parse_expr("l[0]").unwrap();
1145        assert_eq!(expr, expected);
1146
1147        let expected = get_field(array_element(col("l"), lit(0_i64)), "f1");
1148        let expr = planner.parse_expr("l[0]['f1']").unwrap();
1149        assert_eq!(expr, expected);
1150
1151        // FIXME: This should work, but sqlparser doesn't recognize anything
1152        // after the period for some reason.
1153        // let expr = planner.parse_expr("l[0].f1").unwrap();
1154        // assert_eq!(expr, expected);
1155    }
1156
1157    #[test]
1158    fn test_negative_expressions() {
1159        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
1160
1161        let planner = Planner::new(schema.clone());
1162
1163        let expected = col("x")
1164            .gt(lit(-3_i64))
1165            .and(col("x").lt(-(lit(-5_i64) + lit(3_i64))));
1166
1167        let expr = planner.parse_filter("x > -3 AND x < -(-5 + 3)").unwrap();
1168
1169        assert_eq!(expr, expected);
1170
1171        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1172
1173        let batch = RecordBatch::try_new(
1174            schema,
1175            vec![Arc::new(Int64Array::from_iter_values(-5..5)) as ArrayRef],
1176        )
1177        .unwrap();
1178        let predicates = physical_expr.evaluate(&batch).unwrap();
1179        assert_eq!(
1180            predicates.into_array(0).unwrap().as_ref(),
1181            &BooleanArray::from(vec![
1182                false, false, false, true, true, true, true, false, false, false
1183            ])
1184        );
1185    }
1186
1187    #[test]
1188    fn test_negative_array_expressions() {
1189        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
1190
1191        let planner = Planner::new(schema);
1192
1193        let expected = Expr::Literal(
1194            ScalarValue::List(Arc::new(
1195                ListArray::from_iter_primitive::<Float64Type, _, _>(vec![Some(
1196                    [-1_f64, -2.0, -3.0, -4.0, -5.0].map(Some),
1197                )]),
1198            )),
1199            None,
1200        );
1201
1202        let expr = planner
1203            .parse_expr("[-1.0, -2.0, -3.0, -4.0, -5.0]")
1204            .unwrap();
1205
1206        assert_eq!(expr, expected);
1207    }
1208
1209    #[test]
1210    fn test_sql_like() {
1211        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1212
1213        let planner = Planner::new(schema.clone());
1214
1215        let expected = col("s").like(lit("str-4"));
1216        // single quote
1217        let expr = planner.parse_filter("s LIKE 'str-4'").unwrap();
1218        assert_eq!(expr, expected);
1219        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1220
1221        let batch = RecordBatch::try_new(
1222            schema,
1223            vec![Arc::new(StringArray::from_iter_values(
1224                (0..10).map(|v| format!("str-{}", v)),
1225            ))],
1226        )
1227        .unwrap();
1228        let predicates = physical_expr.evaluate(&batch).unwrap();
1229        assert_eq!(
1230            predicates.into_array(0).unwrap().as_ref(),
1231            &BooleanArray::from(vec![
1232                false, false, false, false, true, false, false, false, false, false
1233            ])
1234        );
1235    }
1236
1237    #[test]
1238    fn test_not_like() {
1239        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1240
1241        let planner = Planner::new(schema.clone());
1242
1243        let expected = col("s").not_like(lit("str-4"));
1244        // single quote
1245        let expr = planner.parse_filter("s NOT LIKE 'str-4'").unwrap();
1246        assert_eq!(expr, expected);
1247        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1248
1249        let batch = RecordBatch::try_new(
1250            schema,
1251            vec![Arc::new(StringArray::from_iter_values(
1252                (0..10).map(|v| format!("str-{}", v)),
1253            ))],
1254        )
1255        .unwrap();
1256        let predicates = physical_expr.evaluate(&batch).unwrap();
1257        assert_eq!(
1258            predicates.into_array(0).unwrap().as_ref(),
1259            &BooleanArray::from(vec![
1260                true, true, true, true, false, true, true, true, true, true
1261            ])
1262        );
1263    }
1264
1265    #[test]
1266    fn test_sql_is_in() {
1267        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1268
1269        let planner = Planner::new(schema.clone());
1270
1271        let expected = col("s").in_list(vec![lit("str-4"), lit("str-5")], false);
1272        // single quote
1273        let expr = planner.parse_filter("s IN ('str-4', 'str-5')").unwrap();
1274        assert_eq!(expr, expected);
1275        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1276
1277        let batch = RecordBatch::try_new(
1278            schema,
1279            vec![Arc::new(StringArray::from_iter_values(
1280                (0..10).map(|v| format!("str-{}", v)),
1281            ))],
1282        )
1283        .unwrap();
1284        let predicates = physical_expr.evaluate(&batch).unwrap();
1285        assert_eq!(
1286            predicates.into_array(0).unwrap().as_ref(),
1287            &BooleanArray::from(vec![
1288                false, false, false, false, true, true, false, false, false, false
1289            ])
1290        );
1291    }
1292
1293    #[test]
1294    fn test_sql_is_null() {
1295        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1296
1297        let planner = Planner::new(schema.clone());
1298
1299        let expected = col("s").is_null();
1300        let expr = planner.parse_filter("s IS NULL").unwrap();
1301        assert_eq!(expr, expected);
1302        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1303
1304        let batch = RecordBatch::try_new(
1305            schema,
1306            vec![Arc::new(StringArray::from_iter((0..10).map(|v| {
1307                if v % 3 == 0 {
1308                    Some(format!("str-{}", v))
1309                } else {
1310                    None
1311                }
1312            })))],
1313        )
1314        .unwrap();
1315        let predicates = physical_expr.evaluate(&batch).unwrap();
1316        assert_eq!(
1317            predicates.into_array(0).unwrap().as_ref(),
1318            &BooleanArray::from(vec![
1319                false, true, true, false, true, true, false, true, true, false
1320            ])
1321        );
1322
1323        let expr = planner.parse_filter("s IS NOT NULL").unwrap();
1324        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1325        let predicates = physical_expr.evaluate(&batch).unwrap();
1326        assert_eq!(
1327            predicates.into_array(0).unwrap().as_ref(),
1328            &BooleanArray::from(vec![
1329                true, false, false, true, false, false, true, false, false, true,
1330            ])
1331        );
1332    }
1333
1334    #[test]
1335    fn test_sql_invert() {
1336        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Boolean, true)]));
1337
1338        let planner = Planner::new(schema.clone());
1339
1340        let expr = planner.parse_filter("NOT s").unwrap();
1341        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1342
1343        let batch = RecordBatch::try_new(
1344            schema,
1345            vec![Arc::new(BooleanArray::from_iter(
1346                (0..10).map(|v| Some(v % 3 == 0)),
1347            ))],
1348        )
1349        .unwrap();
1350        let predicates = physical_expr.evaluate(&batch).unwrap();
1351        assert_eq!(
1352            predicates.into_array(0).unwrap().as_ref(),
1353            &BooleanArray::from(vec![
1354                false, true, true, false, true, true, false, true, true, false
1355            ])
1356        );
1357    }
1358
    #[test]
    fn test_sql_cast() {
        // Each case pairs a filter containing a CAST with the Arrow type the
        // SQL type name should map to.
        let cases = &[
            (
                "x = cast('2021-01-01 00:00:00' as timestamp)",
                ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00' as timestamp(0))",
                ArrowDataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00.123' as timestamp(9))",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00.123' as datetime(9))",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            ("x = cast('2021-01-01' as date)", ArrowDataType::Date32),
            (
                "x = cast('1.238' as decimal(9,3))",
                ArrowDataType::Decimal128(9, 3),
            ),
            ("x = cast(1 as float)", ArrowDataType::Float32),
            ("x = cast(1 as double)", ArrowDataType::Float64),
            ("x = cast(1 as tinyint)", ArrowDataType::Int8),
            ("x = cast(1 as smallint)", ArrowDataType::Int16),
            ("x = cast(1 as int)", ArrowDataType::Int32),
            ("x = cast(1 as integer)", ArrowDataType::Int32),
            ("x = cast(1 as bigint)", ArrowDataType::Int64),
            ("x = cast(1 as tinyint unsigned)", ArrowDataType::UInt8),
            ("x = cast(1 as smallint unsigned)", ArrowDataType::UInt16),
            ("x = cast(1 as int unsigned)", ArrowDataType::UInt32),
            ("x = cast(1 as integer unsigned)", ArrowDataType::UInt32),
            ("x = cast(1 as bigint unsigned)", ArrowDataType::UInt64),
            ("x = cast(1 as boolean)", ArrowDataType::Boolean),
            ("x = cast(1 as string)", ArrowDataType::Utf8),
        ];

        for (sql, expected_data_type) in cases {
            let schema = Arc::new(Schema::new(vec![Field::new(
                "x",
                expected_data_type.clone(),
                true,
            )]));
            let planner = Planner::new(schema.clone());
            let expr = planner.parse_filter(sql).unwrap();

            // Get the thing after 'cast(` but before ' as'.
            let expected_value_str = sql
                .split("cast(")
                .nth(1)
                .unwrap()
                .split(" as")
                .next()
                .unwrap();
            // Remove any quote marks
            let expected_value_str = expected_value_str.trim_matches('\'');

            // The right-hand side must be a Cast of the original literal to
            // the expected Arrow data type (not an evaluated constant).
            match expr {
                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
                    Expr::Cast(Cast { expr, data_type }) => {
                        match expr.as_ref() {
                            Expr::Literal(ScalarValue::Utf8(Some(value_str)), _) => {
                                assert_eq!(value_str, expected_value_str);
                            }
                            Expr::Literal(ScalarValue::Int64(Some(value)), _) => {
                                assert_eq!(*value, 1);
                            }
                            _ => panic!("Expected cast to be applied to literal"),
                        }
                        assert_eq!(data_type, expected_data_type);
                    }
                    _ => panic!("Expected right to be a cast"),
                },
                _ => panic!("Expected binary expression"),
            }
        }
    }
1439
    #[test]
    fn test_sql_literals() {
        // Typed literal syntax (`timestamp '...'`, `date '...'`, ...) should
        // parse to a Cast of the string literal to the expected Arrow type.
        let cases = &[
            (
                "x = timestamp '2021-01-01 00:00:00'",
                ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
            ),
            (
                "x = timestamp(0) '2021-01-01 00:00:00'",
                ArrowDataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                "x = timestamp(9) '2021-01-01 00:00:00.123'",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            ("x = date '2021-01-01'", ArrowDataType::Date32),
            ("x = decimal(9,3) '1.238'", ArrowDataType::Decimal128(9, 3)),
        ];

        for (sql, expected_data_type) in cases {
            let schema = Arc::new(Schema::new(vec![Field::new(
                "x",
                expected_data_type.clone(),
                true,
            )]));
            let planner = Planner::new(schema.clone());
            let expr = planner.parse_filter(sql).unwrap();

            // The quoted literal text is the part between the single quotes.
            let expected_value_str = sql.split('\'').nth(1).unwrap();

            match expr {
                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
                    Expr::Cast(Cast { expr, data_type }) => {
                        match expr.as_ref() {
                            Expr::Literal(ScalarValue::Utf8(Some(value_str)), _) => {
                                assert_eq!(value_str, expected_value_str);
                            }
                            _ => panic!("Expected cast to be applied to literal"),
                        }
                        assert_eq!(data_type, expected_data_type);
                    }
                    _ => panic!("Expected right to be a cast"),
                },
                _ => panic!("Expected binary expression"),
            }
        }
    }
1487
1488    #[test]
1489    fn test_sql_array_literals() {
1490        let cases = [
1491            (
1492                "x = [1, 2, 3]",
1493                ArrowDataType::List(Arc::new(Field::new("item", ArrowDataType::Int64, true))),
1494            ),
1495            (
1496                "x = [1, 2, 3]",
1497                ArrowDataType::FixedSizeList(
1498                    Arc::new(Field::new("item", ArrowDataType::Int64, true)),
1499                    3,
1500                ),
1501            ),
1502        ];
1503
1504        for (sql, expected_data_type) in cases {
1505            let schema = Arc::new(Schema::new(vec![Field::new(
1506                "x",
1507                expected_data_type.clone(),
1508                true,
1509            )]));
1510            let planner = Planner::new(schema.clone());
1511            let expr = planner.parse_filter(sql).unwrap();
1512            let expr = planner.optimize_expr(expr).unwrap();
1513
1514            match expr {
1515                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
1516                    Expr::Literal(value, _) => {
1517                        assert_eq!(&value.data_type(), &expected_data_type);
1518                    }
1519                    _ => panic!("Expected right to be a literal"),
1520                },
1521                _ => panic!("Expected binary expression"),
1522            }
1523        }
1524    }
1525
1526    #[test]
1527    fn test_sql_between() {
1528        use arrow_array::{Float64Array, Int32Array, TimestampMicrosecondArray};
1529        use arrow_schema::{DataType, Field, Schema, TimeUnit};
1530        use std::sync::Arc;
1531
1532        let schema = Arc::new(Schema::new(vec![
1533            Field::new("x", DataType::Int32, false),
1534            Field::new("y", DataType::Float64, false),
1535            Field::new(
1536                "ts",
1537                DataType::Timestamp(TimeUnit::Microsecond, None),
1538                false,
1539            ),
1540        ]));
1541
1542        let planner = Planner::new(schema.clone());
1543
1544        // Test integer BETWEEN
1545        let expr = planner
1546            .parse_filter("x BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
1547            .unwrap();
1548        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1549
1550        // Create timestamp array with values representing:
1551        // 2024-01-01 00:00:00 to 2024-01-01 00:00:09 (in microseconds)
1552        let base_ts = 1704067200000000_i64; // 2024-01-01 00:00:00
1553        let ts_array = TimestampMicrosecondArray::from_iter_values(
1554            (0..10).map(|i| base_ts + i * 1_000_000), // Each value is 1 second apart
1555        );
1556
1557        let batch = RecordBatch::try_new(
1558            schema,
1559            vec![
1560                Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
1561                Arc::new(Float64Array::from_iter_values((0..10).map(|v| v as f64))),
1562                Arc::new(ts_array),
1563            ],
1564        )
1565        .unwrap();
1566
1567        let predicates = physical_expr.evaluate(&batch).unwrap();
1568        assert_eq!(
1569            predicates.into_array(0).unwrap().as_ref(),
1570            &BooleanArray::from(vec![
1571                false, false, false, true, true, true, true, true, false, false
1572            ])
1573        );
1574
1575        // Test NOT BETWEEN
1576        let expr = planner
1577            .parse_filter("x NOT BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
1578            .unwrap();
1579        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1580
1581        let predicates = physical_expr.evaluate(&batch).unwrap();
1582        assert_eq!(
1583            predicates.into_array(0).unwrap().as_ref(),
1584            &BooleanArray::from(vec![
1585                true, true, true, false, false, false, false, false, true, true
1586            ])
1587        );
1588
1589        // Test floating point BETWEEN
1590        let expr = planner.parse_filter("y BETWEEN 2.5 AND 6.5").unwrap();
1591        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1592
1593        let predicates = physical_expr.evaluate(&batch).unwrap();
1594        assert_eq!(
1595            predicates.into_array(0).unwrap().as_ref(),
1596            &BooleanArray::from(vec![
1597                false, false, false, true, true, true, true, false, false, false
1598            ])
1599        );
1600
1601        // Test timestamp BETWEEN
1602        let expr = planner
1603            .parse_filter(
1604                "ts BETWEEN timestamp '2024-01-01 00:00:03' AND timestamp '2024-01-01 00:00:07'",
1605            )
1606            .unwrap();
1607        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1608
1609        let predicates = physical_expr.evaluate(&batch).unwrap();
1610        assert_eq!(
1611            predicates.into_array(0).unwrap().as_ref(),
1612            &BooleanArray::from(vec![
1613                false, false, false, true, true, true, true, true, false, false
1614            ])
1615        );
1616    }
1617
1618    #[test]
1619    fn test_sql_comparison() {
1620        // Create a batch with all data types
1621        let batch: Vec<(&str, ArrayRef)> = vec![
1622            (
1623                "timestamp_s",
1624                Arc::new(TimestampSecondArray::from_iter_values(0..10)),
1625            ),
1626            (
1627                "timestamp_ms",
1628                Arc::new(TimestampMillisecondArray::from_iter_values(0..10)),
1629            ),
1630            (
1631                "timestamp_us",
1632                Arc::new(TimestampMicrosecondArray::from_iter_values(0..10)),
1633            ),
1634            (
1635                "timestamp_ns",
1636                Arc::new(TimestampNanosecondArray::from_iter_values(4995..5005)),
1637            ),
1638        ];
1639        let batch = RecordBatch::try_from_iter(batch).unwrap();
1640
1641        let planner = Planner::new(batch.schema());
1642
1643        // Each expression is meant to select the final 5 rows
1644        let expressions = &[
1645            "timestamp_s >= TIMESTAMP '1970-01-01 00:00:05'",
1646            "timestamp_ms >= TIMESTAMP '1970-01-01 00:00:00.005'",
1647            "timestamp_us >= TIMESTAMP '1970-01-01 00:00:00.000005'",
1648            "timestamp_ns >= TIMESTAMP '1970-01-01 00:00:00.000005'",
1649        ];
1650
1651        let expected: ArrayRef = Arc::new(BooleanArray::from_iter(
1652            std::iter::repeat_n(Some(false), 5).chain(std::iter::repeat_n(Some(true), 5)),
1653        ));
1654        for expression in expressions {
1655            // convert to physical expression
1656            let logical_expr = planner.parse_filter(expression).unwrap();
1657            let logical_expr = planner.optimize_expr(logical_expr).unwrap();
1658            let physical_expr = planner.create_physical_expr(&logical_expr).unwrap();
1659
1660            // Evaluate and assert they have correct results
1661            let result = physical_expr.evaluate(&batch).unwrap();
1662            let result = result.into_array(batch.num_rows()).unwrap();
1663            assert_eq!(&expected, &result, "unexpected result for {}", expression);
1664        }
1665    }
1666
1667    #[test]
1668    fn test_columns_in_expr() {
1669        let expr = col("s0").gt(lit("value")).and(
1670            col("st")
1671                .field("st")
1672                .field("s2")
1673                .eq(lit("value"))
1674                .or(col("st")
1675                    .field("s1")
1676                    .in_list(vec![lit("value 1"), lit("value 2")], false)),
1677        );
1678
1679        let columns = Planner::column_names_in_expr(&expr);
1680        assert_eq!(columns, vec!["s0", "st.s1", "st.st.s2"]);
1681    }
1682
1683    #[test]
1684    fn test_parse_binary_expr() {
1685        let bin_str = "x'616263'";
1686
1687        let schema = Arc::new(Schema::new(vec![Field::new(
1688            "binary",
1689            DataType::Binary,
1690            true,
1691        )]));
1692        let planner = Planner::new(schema);
1693        let expr = planner.parse_expr(bin_str).unwrap();
1694        assert_eq!(
1695            expr,
1696            Expr::Literal(ScalarValue::Binary(Some(vec![b'a', b'b', b'c'])), None)
1697        );
1698    }
1699
1700    #[test]
1701    fn test_lance_context_provider_expr_planners() {
1702        let ctx_provider = LanceContextProvider::default();
1703        assert!(!ctx_provider.get_expr_planners().is_empty());
1704    }
1705}