lance_datafusion/
planner.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Exec plan planner
5
6use std::borrow::Cow;
7use std::collections::{BTreeSet, VecDeque};
8use std::sync::Arc;
9
10use crate::expr::safe_coerce_scalar;
11use crate::logical_expr::{coerce_filter_type_to_boolean, get_as_string_scalar_opt, resolve_expr};
12use crate::sql::{parse_sql_expr, parse_sql_filter};
13use arrow::compute::CastOptions;
14use arrow_array::ListArray;
15use arrow_buffer::OffsetBuffer;
16use arrow_schema::{DataType as ArrowDataType, Field, SchemaRef, TimeUnit};
17use arrow_select::concat::concat;
18use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
19use datafusion::common::DFSchema;
20use datafusion::config::ConfigOptions;
21use datafusion::error::Result as DFResult;
22use datafusion::execution::config::SessionConfig;
23use datafusion::execution::context::{SessionContext, SessionState};
24use datafusion::execution::runtime_env::RuntimeEnvBuilder;
25use datafusion::execution::session_state::SessionStateBuilder;
26use datafusion::logical_expr::expr::ScalarFunction;
27use datafusion::logical_expr::planner::{ExprPlanner, PlannerResult, RawFieldAccessExpr};
28use datafusion::logical_expr::{
29    AggregateUDF, ColumnarValue, GetFieldAccess, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
30    Signature, Volatility, WindowUDF,
31};
32use datafusion::optimizer::simplify_expressions::SimplifyContext;
33use datafusion::sql::planner::{ContextProvider, ParserOptions, PlannerContext, SqlToRel};
34use datafusion::sql::sqlparser::ast::{
35    AccessExpr, Array as SQLArray, BinaryOperator, DataType as SQLDataType, ExactNumberInfo,
36    Expr as SQLExpr, Function, FunctionArg, FunctionArgExpr, FunctionArguments, Ident,
37    ObjectNamePart, Subscript, TimezoneInfo, UnaryOperator, Value, ValueWithSpan,
38};
39use datafusion::{
40    common::Column,
41    logical_expr::{col, Between, BinaryExpr, Like, Operator},
42    physical_expr::execution_props::ExecutionProps,
43    physical_plan::PhysicalExpr,
44    prelude::Expr,
45    scalar::ScalarValue,
46};
47use datafusion_functions::core::getfield::GetFieldFunc;
48use lance_arrow::cast::cast_with_options;
49use lance_core::datatypes::Schema;
50use lance_core::error::LanceOptionExt;
51use snafu::location;
52
53use lance_core::{Error, Result};
54
/// Scalar UDF (`_cast_list_f16`) that casts a `List`/`FixedSizeList` of
/// `Float32`/`Float16` values into the same list shape with `Float16` items.
#[derive(Debug, Clone)]
struct CastListF16Udf {
    // Accepts any single argument; element-type validation happens in
    // `return_type` / `invoke_with_args`.
    signature: Signature,
}
59
60impl CastListF16Udf {
61    pub fn new() -> Self {
62        Self {
63            signature: Signature::any(1, Volatility::Immutable),
64        }
65    }
66}
67
68impl ScalarUDFImpl for CastListF16Udf {
69    fn as_any(&self) -> &dyn std::any::Any {
70        self
71    }
72
73    fn name(&self) -> &str {
74        "_cast_list_f16"
75    }
76
77    fn signature(&self) -> &Signature {
78        &self.signature
79    }
80
81    fn return_type(&self, arg_types: &[ArrowDataType]) -> DFResult<ArrowDataType> {
82        let input = &arg_types[0];
83        match input {
84            ArrowDataType::FixedSizeList(field, size) => {
85                if field.data_type() != &ArrowDataType::Float32
86                    && field.data_type() != &ArrowDataType::Float16
87                {
88                    return Err(datafusion::error::DataFusionError::Execution(
89                        "cast_list_f16 only supports list of float32 or float16".to_string(),
90                    ));
91                }
92                Ok(ArrowDataType::FixedSizeList(
93                    Arc::new(Field::new(
94                        field.name(),
95                        ArrowDataType::Float16,
96                        field.is_nullable(),
97                    )),
98                    *size,
99                ))
100            }
101            ArrowDataType::List(field) => {
102                if field.data_type() != &ArrowDataType::Float32
103                    && field.data_type() != &ArrowDataType::Float16
104                {
105                    return Err(datafusion::error::DataFusionError::Execution(
106                        "cast_list_f16 only supports list of float32 or float16".to_string(),
107                    ));
108                }
109                Ok(ArrowDataType::List(Arc::new(Field::new(
110                    field.name(),
111                    ArrowDataType::Float16,
112                    field.is_nullable(),
113                ))))
114            }
115            _ => Err(datafusion::error::DataFusionError::Execution(
116                "cast_list_f16 only supports FixedSizeList/List arguments".to_string(),
117            )),
118        }
119    }
120
121    fn invoke_with_args(&self, func_args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
122        let ColumnarValue::Array(arr) = &func_args.args[0] else {
123            return Err(datafusion::error::DataFusionError::Execution(
124                "cast_list_f16 only supports array arguments".to_string(),
125            ));
126        };
127
128        let to_type = match arr.data_type() {
129            ArrowDataType::FixedSizeList(field, size) => ArrowDataType::FixedSizeList(
130                Arc::new(Field::new(
131                    field.name(),
132                    ArrowDataType::Float16,
133                    field.is_nullable(),
134                )),
135                *size,
136            ),
137            ArrowDataType::List(field) => ArrowDataType::List(Arc::new(Field::new(
138                field.name(),
139                ArrowDataType::Float16,
140                field.is_nullable(),
141            ))),
142            _ => {
143                return Err(datafusion::error::DataFusionError::Execution(
144                    "cast_list_f16 only supports array arguments".to_string(),
145                ));
146            }
147        };
148
149        let res = cast_with_options(arr.as_ref(), &to_type, &CastOptions::default())?;
150        Ok(ColumnarValue::Array(res))
151    }
152}
153
// Adapter that instructs datafusion how lance expects expressions to be interpreted
struct LanceContextProvider {
    // Config options handed back verbatim from `ContextProvider::options`.
    options: datafusion::config::ConfigOptions,
    // Session state used to look up scalar/aggregate/window functions.
    state: SessionState,
    // Planners consulted when planning field access (struct fields, subscripts).
    expr_planners: Vec<Arc<dyn ExprPlanner>>,
}
160
161impl Default for LanceContextProvider {
162    fn default() -> Self {
163        let config = SessionConfig::new();
164        let runtime = RuntimeEnvBuilder::new().build_arc().unwrap();
165
166        let ctx = SessionContext::new_with_config_rt(config.clone(), runtime.clone());
167        crate::udf::register_functions(&ctx);
168
169        let state = ctx.state();
170
171        // SessionState does not expose expr_planners, so we need to get them separately
172        let mut state_builder = SessionStateBuilder::new()
173            .with_config(config)
174            .with_runtime_env(runtime)
175            .with_default_features();
176
177        // unwrap safe because with_default_features sets expr_planners
178        let expr_planners = state_builder.expr_planners().as_ref().unwrap().clone();
179
180        Self {
181            options: ConfigOptions::default(),
182            state,
183            expr_planners,
184        }
185    }
186}
187
188impl ContextProvider for LanceContextProvider {
189    fn get_table_source(
190        &self,
191        name: datafusion::sql::TableReference,
192    ) -> DFResult<Arc<dyn datafusion::logical_expr::TableSource>> {
193        Err(datafusion::error::DataFusionError::NotImplemented(format!(
194            "Attempt to reference inner table {} not supported",
195            name
196        )))
197    }
198
199    fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
200        self.state.aggregate_functions().get(name).cloned()
201    }
202
203    fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
204        self.state.window_functions().get(name).cloned()
205    }
206
207    fn get_function_meta(&self, f: &str) -> Option<Arc<ScalarUDF>> {
208        match f {
209            // TODO: cast should go thru CAST syntax instead of UDF
210            // Going thru UDF makes it hard for the optimizer to find no-ops
211            "_cast_list_f16" => Some(Arc::new(ScalarUDF::new_from_impl(CastListF16Udf::new()))),
212            _ => self.state.scalar_functions().get(f).cloned(),
213        }
214    }
215
216    fn get_variable_type(&self, _: &[String]) -> Option<ArrowDataType> {
217        // Variables (things like @@LANGUAGE) not supported
218        None
219    }
220
221    fn options(&self) -> &datafusion::config::ConfigOptions {
222        &self.options
223    }
224
225    fn udf_names(&self) -> Vec<String> {
226        self.state.scalar_functions().keys().cloned().collect()
227    }
228
229    fn udaf_names(&self) -> Vec<String> {
230        self.state.aggregate_functions().keys().cloned().collect()
231    }
232
233    fn udwf_names(&self) -> Vec<String> {
234        self.state.window_functions().keys().cloned().collect()
235    }
236
237    fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
238        &self.expr_planners
239    }
240}
241
/// Translates SQL expression strings/ASTs into DataFusion expressions
/// resolved against a fixed Arrow schema.
pub struct Planner {
    // Arrow schema that column references are resolved against.
    schema: SchemaRef,
    // Supplies function lookup and expr planners to the SQL planner.
    context_provider: LanceContextProvider,
    // When true, the first identifier of a compound reference is a relation.
    enable_relations: bool,
}
247
248impl Planner {
249    pub fn new(schema: SchemaRef) -> Self {
250        Self {
251            schema,
252            context_provider: LanceContextProvider::default(),
253            enable_relations: false,
254        }
255    }
256
257    /// If passed with `true`, then the first identifier in column reference
258    /// is parsed as the relation. For example, `table.field.inner` will be
259    /// read as the nested field `field.inner` (`inner` on struct field `field`)
260    /// on the `table` relation. If `false` (the default), then no relations
261    /// are used and all identifiers are assumed to be a nested column path.
262    pub fn with_enable_relations(mut self, enable_relations: bool) -> Self {
263        self.enable_relations = enable_relations;
264        self
265    }
266
267    fn column(&self, idents: &[Ident]) -> Expr {
268        fn handle_remaining_idents(expr: &mut Expr, idents: &[Ident]) {
269            for ident in idents {
270                *expr = Expr::ScalarFunction(ScalarFunction {
271                    args: vec![
272                        std::mem::take(expr),
273                        Expr::Literal(ScalarValue::Utf8(Some(ident.value.clone())), None),
274                    ],
275                    func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
276                });
277            }
278        }
279
280        if self.enable_relations && idents.len() > 1 {
281            // Create qualified column reference (relation.column)
282            let relation = &idents[0].value;
283            let column_name = &idents[1].value;
284            let column = Expr::Column(Column::new(Some(relation.clone()), column_name.clone()));
285            let mut result = column;
286            handle_remaining_idents(&mut result, &idents[2..]);
287            result
288        } else {
289            // Default behavior - treat as struct field access
290            let mut column = col(&idents[0].value);
291            handle_remaining_idents(&mut column, &idents[1..]);
292            column
293        }
294    }
295
296    fn binary_op(&self, op: &BinaryOperator) -> Result<Operator> {
297        Ok(match op {
298            BinaryOperator::Plus => Operator::Plus,
299            BinaryOperator::Minus => Operator::Minus,
300            BinaryOperator::Multiply => Operator::Multiply,
301            BinaryOperator::Divide => Operator::Divide,
302            BinaryOperator::Modulo => Operator::Modulo,
303            BinaryOperator::StringConcat => Operator::StringConcat,
304            BinaryOperator::Gt => Operator::Gt,
305            BinaryOperator::Lt => Operator::Lt,
306            BinaryOperator::GtEq => Operator::GtEq,
307            BinaryOperator::LtEq => Operator::LtEq,
308            BinaryOperator::Eq => Operator::Eq,
309            BinaryOperator::NotEq => Operator::NotEq,
310            BinaryOperator::And => Operator::And,
311            BinaryOperator::Or => Operator::Or,
312            _ => {
313                return Err(Error::invalid_input(
314                    format!("Operator {op} is not supported"),
315                    location!(),
316                ));
317            }
318        })
319    }
320
321    fn binary_expr(&self, left: &SQLExpr, op: &BinaryOperator, right: &SQLExpr) -> Result<Expr> {
322        Ok(Expr::BinaryExpr(BinaryExpr::new(
323            Box::new(self.parse_sql_expr(left)?),
324            self.binary_op(op)?,
325            Box::new(self.parse_sql_expr(right)?),
326        )))
327    }
328
329    fn unary_expr(&self, op: &UnaryOperator, expr: &SQLExpr) -> Result<Expr> {
330        Ok(match op {
331            UnaryOperator::Not | UnaryOperator::PGBitwiseNot => {
332                Expr::Not(Box::new(self.parse_sql_expr(expr)?))
333            }
334
335            UnaryOperator::Minus => {
336                use datafusion::logical_expr::lit;
337                match expr {
338                    SQLExpr::Value(ValueWithSpan { value: Value::Number(n, _), ..}) => match n.parse::<i64>() {
339                        Ok(n) => lit(-n),
340                        Err(_) => lit(-n
341                            .parse::<f64>()
342                            .map_err(|_e| {
343                                Error::invalid_input(
344                                    format!("negative operator can be only applied to integer and float operands, got: {n}"),
345                                    location!(),
346                                )
347                            })?),
348                    },
349                    _ => {
350                        Expr::Negative(Box::new(self.parse_sql_expr(expr)?))
351                    }
352                }
353            }
354
355            _ => {
356                return Err(Error::invalid_input(
357                    format!("Unary operator '{:?}' is not supported", op),
358                    location!(),
359                ));
360            }
361        })
362    }
363
364    // See datafusion `sqlToRel::parse_sql_number()`
365    fn number(&self, value: &str, negative: bool) -> Result<Expr> {
366        use datafusion::logical_expr::lit;
367        let value: Cow<str> = if negative {
368            Cow::Owned(format!("-{}", value))
369        } else {
370            Cow::Borrowed(value)
371        };
372        if let Ok(n) = value.parse::<i64>() {
373            Ok(lit(n))
374        } else {
375            value.parse::<f64>().map(lit).map_err(|_| {
376                Error::invalid_input(
377                    format!("'{value}' is not supported number value."),
378                    location!(),
379                )
380            })
381        }
382    }
383
384    fn value(&self, value: &Value) -> Result<Expr> {
385        Ok(match value {
386            Value::Number(v, _) => self.number(v.as_str(), false)?,
387            Value::SingleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone())), None),
388            Value::HexStringLiteral(hsl) => {
389                Expr::Literal(ScalarValue::Binary(Self::try_decode_hex_literal(hsl)), None)
390            }
391            Value::DoubleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone())), None),
392            Value::Boolean(v) => Expr::Literal(ScalarValue::Boolean(Some(*v)), None),
393            Value::Null => Expr::Literal(ScalarValue::Null, None),
394            _ => todo!(),
395        })
396    }
397
398    fn parse_function_args(&self, func_args: &FunctionArg) -> Result<Expr> {
399        match func_args {
400            FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => self.parse_sql_expr(expr),
401            _ => Err(Error::invalid_input(
402                format!("Unsupported function args: {:?}", func_args),
403                location!(),
404            )),
405        }
406    }
407
408    // We now use datafusion to parse functions.  This allows us to use datafusion's
409    // entire collection of functions (previously we had just hard-coded support for two functions).
410    //
411    // Unfortunately, one of those two functions was is_valid and the reason we needed it was because
412    // this is a function that comes from duckdb.  Datafusion does not consider is_valid to be a function
413    // but rather an AST node (Expr::IsNotNull) and so we need to handle this case specially.
414    fn legacy_parse_function(&self, func: &Function) -> Result<Expr> {
415        match &func.args {
416            FunctionArguments::List(args) => {
417                if func.name.0.len() != 1 {
418                    return Err(Error::invalid_input(
419                        format!("Function name must have 1 part, got: {:?}", func.name.0),
420                        location!(),
421                    ));
422                }
423                Ok(Expr::IsNotNull(Box::new(
424                    self.parse_function_args(&args.args[0])?,
425                )))
426            }
427            _ => Err(Error::invalid_input(
428                format!("Unsupported function args: {:?}", &func.args),
429                location!(),
430            )),
431        }
432    }
433
434    fn parse_function(&self, function: SQLExpr) -> Result<Expr> {
435        if let SQLExpr::Function(function) = &function {
436            if let Some(ObjectNamePart::Identifier(name)) = &function.name.0.first() {
437                if &name.value == "is_valid" {
438                    return self.legacy_parse_function(function);
439                }
440            }
441        }
442        let sql_to_rel = SqlToRel::new_with_options(
443            &self.context_provider,
444            ParserOptions {
445                parse_float_as_decimal: false,
446                enable_ident_normalization: false,
447                support_varchar_with_length: false,
448                enable_options_value_normalization: false,
449                collect_spans: false,
450                map_string_types_to_utf8view: false,
451            },
452        );
453
454        let mut planner_context = PlannerContext::default();
455        let schema = DFSchema::try_from(self.schema.as_ref().clone())?;
456        sql_to_rel
457            .sql_to_expr(function, &schema, &mut planner_context)
458            .map_err(|e| Error::invalid_input(format!("Error parsing function: {e}"), location!()))
459    }
460
461    fn parse_type(&self, data_type: &SQLDataType) -> Result<ArrowDataType> {
462        const SUPPORTED_TYPES: [&str; 13] = [
463            "int [unsigned]",
464            "tinyint [unsigned]",
465            "smallint [unsigned]",
466            "bigint [unsigned]",
467            "float",
468            "double",
469            "string",
470            "binary",
471            "date",
472            "timestamp(precision)",
473            "datetime(precision)",
474            "decimal(precision,scale)",
475            "boolean",
476        ];
477        match data_type {
478            SQLDataType::String(_) => Ok(ArrowDataType::Utf8),
479            SQLDataType::Binary(_) => Ok(ArrowDataType::Binary),
480            SQLDataType::Float(_) => Ok(ArrowDataType::Float32),
481            SQLDataType::Double(_) => Ok(ArrowDataType::Float64),
482            SQLDataType::Boolean => Ok(ArrowDataType::Boolean),
483            SQLDataType::TinyInt(_) => Ok(ArrowDataType::Int8),
484            SQLDataType::SmallInt(_) => Ok(ArrowDataType::Int16),
485            SQLDataType::Int(_) | SQLDataType::Integer(_) => Ok(ArrowDataType::Int32),
486            SQLDataType::BigInt(_) => Ok(ArrowDataType::Int64),
487            SQLDataType::TinyIntUnsigned(_) => Ok(ArrowDataType::UInt8),
488            SQLDataType::SmallIntUnsigned(_) => Ok(ArrowDataType::UInt16),
489            SQLDataType::IntUnsigned(_) | SQLDataType::IntegerUnsigned(_) => {
490                Ok(ArrowDataType::UInt32)
491            }
492            SQLDataType::BigIntUnsigned(_) => Ok(ArrowDataType::UInt64),
493            SQLDataType::Date => Ok(ArrowDataType::Date32),
494            SQLDataType::Timestamp(resolution, tz) => {
495                match tz {
496                    TimezoneInfo::None => {}
497                    _ => {
498                        return Err(Error::invalid_input(
499                            "Timezone not supported in timestamp".to_string(),
500                            location!(),
501                        ));
502                    }
503                };
504                let time_unit = match resolution {
505                    // Default to microsecond to match PyArrow
506                    None => TimeUnit::Microsecond,
507                    Some(0) => TimeUnit::Second,
508                    Some(3) => TimeUnit::Millisecond,
509                    Some(6) => TimeUnit::Microsecond,
510                    Some(9) => TimeUnit::Nanosecond,
511                    _ => {
512                        return Err(Error::invalid_input(
513                            format!("Unsupported datetime resolution: {:?}", resolution),
514                            location!(),
515                        ));
516                    }
517                };
518                Ok(ArrowDataType::Timestamp(time_unit, None))
519            }
520            SQLDataType::Datetime(resolution) => {
521                let time_unit = match resolution {
522                    None => TimeUnit::Microsecond,
523                    Some(0) => TimeUnit::Second,
524                    Some(3) => TimeUnit::Millisecond,
525                    Some(6) => TimeUnit::Microsecond,
526                    Some(9) => TimeUnit::Nanosecond,
527                    _ => {
528                        return Err(Error::invalid_input(
529                            format!("Unsupported datetime resolution: {:?}", resolution),
530                            location!(),
531                        ));
532                    }
533                };
534                Ok(ArrowDataType::Timestamp(time_unit, None))
535            }
536            SQLDataType::Decimal(number_info) => match number_info {
537                ExactNumberInfo::PrecisionAndScale(precision, scale) => {
538                    Ok(ArrowDataType::Decimal128(*precision as u8, *scale as i8))
539                }
540                _ => Err(Error::invalid_input(
541                    format!(
542                        "Must provide precision and scale for decimal: {:?}",
543                        number_info
544                    ),
545                    location!(),
546                )),
547            },
548            _ => Err(Error::invalid_input(
549                format!(
550                    "Unsupported data type: {:?}. Supported types: {:?}",
551                    data_type, SUPPORTED_TYPES
552                ),
553                location!(),
554            )),
555        }
556    }
557
558    fn plan_field_access(&self, mut field_access_expr: RawFieldAccessExpr) -> Result<Expr> {
559        let df_schema = DFSchema::try_from(self.schema.as_ref().clone())?;
560        for planner in self.context_provider.get_expr_planners() {
561            match planner.plan_field_access(field_access_expr, &df_schema)? {
562                PlannerResult::Planned(expr) => return Ok(expr),
563                PlannerResult::Original(expr) => {
564                    field_access_expr = expr;
565                }
566            }
567        }
568        Err(Error::invalid_input(
569            "Field access could not be planned",
570            location!(),
571        ))
572    }
573
574    fn parse_sql_expr(&self, expr: &SQLExpr) -> Result<Expr> {
575        match expr {
576            SQLExpr::Identifier(id) => {
577                // Users can pass string literals wrapped in `"`.
578                // (Normally SQL only allows single quotes.)
579                if id.quote_style == Some('"') {
580                    Ok(Expr::Literal(
581                        ScalarValue::Utf8(Some(id.value.clone())),
582                        None,
583                    ))
584                // Users can wrap identifiers with ` to reference non-standard
585                // names, such as uppercase or spaces.
586                } else if id.quote_style == Some('`') {
587                    Ok(Expr::Column(Column::from_name(id.value.clone())))
588                } else {
589                    Ok(self.column(vec![id.clone()].as_slice()))
590                }
591            }
592            SQLExpr::CompoundIdentifier(ids) => Ok(self.column(ids.as_slice())),
593            SQLExpr::BinaryOp { left, op, right } => self.binary_expr(left, op, right),
594            SQLExpr::UnaryOp { op, expr } => self.unary_expr(op, expr),
595            SQLExpr::Value(value) => self.value(&value.value),
596            SQLExpr::Array(SQLArray { elem, .. }) => {
597                let mut values = vec![];
598
599                let array_literal_error = |pos: usize, value: &_| {
600                    Err(Error::invalid_input(
601                        format!(
602                            "Expected a literal value in array, instead got {} at position {}",
603                            value, pos
604                        ),
605                        location!(),
606                    ))
607                };
608
609                for (pos, expr) in elem.iter().enumerate() {
610                    match expr {
611                        SQLExpr::Value(value) => {
612                            if let Expr::Literal(value, _) = self.value(&value.value)? {
613                                values.push(value);
614                            } else {
615                                return array_literal_error(pos, expr);
616                            }
617                        }
618                        SQLExpr::UnaryOp {
619                            op: UnaryOperator::Minus,
620                            expr,
621                        } => {
622                            if let SQLExpr::Value(ValueWithSpan {
623                                value: Value::Number(number, _),
624                                ..
625                            }) = expr.as_ref()
626                            {
627                                if let Expr::Literal(value, _) = self.number(number, true)? {
628                                    values.push(value);
629                                } else {
630                                    return array_literal_error(pos, expr);
631                                }
632                            } else {
633                                return array_literal_error(pos, expr);
634                            }
635                        }
636                        _ => {
637                            return array_literal_error(pos, expr);
638                        }
639                    }
640                }
641
642                let field = if !values.is_empty() {
643                    let data_type = values[0].data_type();
644
645                    for value in &mut values {
646                        if value.data_type() != data_type {
647                            *value = safe_coerce_scalar(value, &data_type).ok_or_else(|| Error::invalid_input(
648                                format!("Array expressions must have a consistent datatype. Expected: {}, got: {}", data_type, value.data_type()),
649                                location!()
650                            ))?;
651                        }
652                    }
653                    Field::new("item", data_type, true)
654                } else {
655                    Field::new("item", ArrowDataType::Null, true)
656                };
657
658                let values = values
659                    .into_iter()
660                    .map(|v| v.to_array().map_err(Error::from))
661                    .collect::<Result<Vec<_>>>()?;
662                let array_refs = values.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
663                let values = concat(&array_refs)?;
664                let values = ListArray::try_new(
665                    field.into(),
666                    OffsetBuffer::from_lengths([values.len()]),
667                    values,
668                    None,
669                )?;
670
671                Ok(Expr::Literal(ScalarValue::List(Arc::new(values)), None))
672            }
673            // For example, DATE '2020-01-01'
674            SQLExpr::TypedString { data_type, value } => {
675                let value = value.clone().into_string().expect_ok()?;
676                Ok(Expr::Cast(datafusion::logical_expr::Cast {
677                    expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some(value)), None)),
678                    data_type: self.parse_type(data_type)?,
679                }))
680            }
681            SQLExpr::IsFalse(expr) => Ok(Expr::IsFalse(Box::new(self.parse_sql_expr(expr)?))),
682            SQLExpr::IsNotFalse(expr) => Ok(Expr::IsNotFalse(Box::new(self.parse_sql_expr(expr)?))),
683            SQLExpr::IsTrue(expr) => Ok(Expr::IsTrue(Box::new(self.parse_sql_expr(expr)?))),
684            SQLExpr::IsNotTrue(expr) => Ok(Expr::IsNotTrue(Box::new(self.parse_sql_expr(expr)?))),
685            SQLExpr::IsNull(expr) => Ok(Expr::IsNull(Box::new(self.parse_sql_expr(expr)?))),
686            SQLExpr::IsNotNull(expr) => Ok(Expr::IsNotNull(Box::new(self.parse_sql_expr(expr)?))),
687            SQLExpr::InList {
688                expr,
689                list,
690                negated,
691            } => {
692                let value_expr = self.parse_sql_expr(expr)?;
693                let list_exprs = list
694                    .iter()
695                    .map(|e| self.parse_sql_expr(e))
696                    .collect::<Result<Vec<_>>>()?;
697                Ok(value_expr.in_list(list_exprs, *negated))
698            }
699            SQLExpr::Nested(inner) => self.parse_sql_expr(inner.as_ref()),
700            SQLExpr::Function(_) => self.parse_function(expr.clone()),
701            SQLExpr::ILike {
702                negated,
703                expr,
704                pattern,
705                escape_char,
706                any: _,
707            } => Ok(Expr::Like(Like::new(
708                *negated,
709                Box::new(self.parse_sql_expr(expr)?),
710                Box::new(self.parse_sql_expr(pattern)?),
711                escape_char.as_ref().and_then(|c| c.chars().next()),
712                true,
713            ))),
714            SQLExpr::Like {
715                negated,
716                expr,
717                pattern,
718                escape_char,
719                any: _,
720            } => Ok(Expr::Like(Like::new(
721                *negated,
722                Box::new(self.parse_sql_expr(expr)?),
723                Box::new(self.parse_sql_expr(pattern)?),
724                escape_char.as_ref().and_then(|c| c.chars().next()),
725                false,
726            ))),
727            SQLExpr::Cast {
728                expr,
729                data_type,
730                kind,
731                ..
732            } => match kind {
733                datafusion::sql::sqlparser::ast::CastKind::TryCast
734                | datafusion::sql::sqlparser::ast::CastKind::SafeCast => {
735                    Ok(Expr::TryCast(datafusion::logical_expr::TryCast {
736                        expr: Box::new(self.parse_sql_expr(expr)?),
737                        data_type: self.parse_type(data_type)?,
738                    }))
739                }
740                _ => Ok(Expr::Cast(datafusion::logical_expr::Cast {
741                    expr: Box::new(self.parse_sql_expr(expr)?),
742                    data_type: self.parse_type(data_type)?,
743                })),
744            },
745            SQLExpr::JsonAccess { .. } => Err(Error::invalid_input(
746                "JSON access is not supported",
747                location!(),
748            )),
749            SQLExpr::CompoundFieldAccess { root, access_chain } => {
750                let mut expr = self.parse_sql_expr(root)?;
751
752                for access in access_chain {
753                    let field_access = match access {
754                        // x.y or x['y']
755                        AccessExpr::Dot(SQLExpr::Identifier(Ident { value: s, .. }))
756                        | AccessExpr::Subscript(Subscript::Index {
757                            index:
758                                SQLExpr::Value(ValueWithSpan {
759                                    value:
760                                        Value::SingleQuotedString(s) | Value::DoubleQuotedString(s),
761                                    ..
762                                }),
763                        }) => GetFieldAccess::NamedStructField {
764                            name: ScalarValue::from(s.as_str()),
765                        },
766                        AccessExpr::Subscript(Subscript::Index { index }) => {
767                            let key = Box::new(self.parse_sql_expr(index)?);
768                            GetFieldAccess::ListIndex { key }
769                        }
770                        AccessExpr::Subscript(Subscript::Slice { .. }) => {
771                            return Err(Error::invalid_input(
772                                "Slice subscript is not supported",
773                                location!(),
774                            ));
775                        }
776                        _ => {
777                            // Handle other cases like JSON access
778                            // Note: JSON access is not supported in lance
779                            return Err(Error::invalid_input(
780                                "Only dot notation or index access is supported for field access",
781                                location!(),
782                            ));
783                        }
784                    };
785
786                    let field_access_expr = RawFieldAccessExpr { expr, field_access };
787                    expr = self.plan_field_access(field_access_expr)?;
788                }
789
790                Ok(expr)
791            }
792            SQLExpr::Between {
793                expr,
794                negated,
795                low,
796                high,
797            } => {
798                // Parse the main expression and bounds
799                let expr = self.parse_sql_expr(expr)?;
800                let low = self.parse_sql_expr(low)?;
801                let high = self.parse_sql_expr(high)?;
802
803                let between = Expr::Between(Between::new(
804                    Box::new(expr),
805                    *negated,
806                    Box::new(low),
807                    Box::new(high),
808                ));
809                Ok(between)
810            }
811            _ => Err(Error::invalid_input(
812                format!("Expression '{expr}' is not supported SQL in lance"),
813                location!(),
814            )),
815        }
816    }
817
818    /// Create Logical [Expr] from a SQL filter clause.
819    ///
820    /// Note: the returned expression must be passed through [optimize_expr()]
821    /// before being passed to [create_physical_expr()].
822    pub fn parse_filter(&self, filter: &str) -> Result<Expr> {
823        // Allow sqlparser to parse filter as part of ONE SQL statement.
824        let ast_expr = parse_sql_filter(filter)?;
825        let expr = self.parse_sql_expr(&ast_expr)?;
826        let schema = Schema::try_from(self.schema.as_ref())?;
827        let resolved = resolve_expr(&expr, &schema).map_err(|e| {
828            Error::invalid_input(
829                format!("Error resolving filter expression {filter}: {e}"),
830                location!(),
831            )
832        })?;
833
834        Ok(coerce_filter_type_to_boolean(resolved))
835    }
836
    /// Create Logical [Expr] from a SQL expression.
    ///
    /// Note: the returned expression must be passed through [optimize_expr()]
    /// before being passed to [create_physical_expr()].
    pub fn parse_expr(&self, expr: &str) -> Result<Expr> {
        let ast_expr = parse_sql_expr(expr)?;
        let expr = self.parse_sql_expr(&ast_expr)?;
        let schema = Schema::try_from(self.schema.as_ref())?;
        // Resolve column references in the parsed expression against the schema.
        let resolved = resolve_expr(&expr, &schema)?;
        Ok(resolved)
    }
848
849    /// Try to decode bytes from hex literal string.
850    ///
851    /// Copied from datafusion because this is not public.
852    ///
853    /// TODO: use SqlToRel from Datafusion directly?
854    fn try_decode_hex_literal(s: &str) -> Option<Vec<u8>> {
855        let hex_bytes = s.as_bytes();
856        let mut decoded_bytes = Vec::with_capacity(hex_bytes.len().div_ceil(2));
857
858        let start_idx = hex_bytes.len() % 2;
859        if start_idx > 0 {
860            // The first byte is formed of only one char.
861            decoded_bytes.push(Self::try_decode_hex_char(hex_bytes[0])?);
862        }
863
864        for i in (start_idx..hex_bytes.len()).step_by(2) {
865            let high = Self::try_decode_hex_char(hex_bytes[i])?;
866            let low = Self::try_decode_hex_char(hex_bytes[i + 1])?;
867            decoded_bytes.push((high << 4) | low);
868        }
869
870        Some(decoded_bytes)
871    }
872
873    /// Try to decode a byte from a hex char.
874    ///
875    /// None will be returned if the input char is hex-invalid.
876    const fn try_decode_hex_char(c: u8) -> Option<u8> {
877        match c {
878            b'A'..=b'F' => Some(c - b'A' + 10),
879            b'a'..=b'f' => Some(c - b'a' + 10),
880            b'0'..=b'9' => Some(c - b'0'),
881            _ => None,
882        }
883    }
884
885    /// Optimize the filter expression and coerce data types.
886    pub fn optimize_expr(&self, expr: Expr) -> Result<Expr> {
887        let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
888
889        // DataFusion needs the simplify and coerce passes to be applied before
890        // expressions can be handled by the physical planner.
891        let props = ExecutionProps::default();
892        let simplify_context = SimplifyContext::new(&props).with_schema(df_schema.clone());
893        let simplifier =
894            datafusion::optimizer::simplify_expressions::ExprSimplifier::new(simplify_context);
895
896        let expr = simplifier.simplify(expr)?;
897        let expr = simplifier.coerce(expr, &df_schema)?;
898
899        Ok(expr)
900    }
901
902    /// Create the [`PhysicalExpr`] from a logical [`Expr`]
903    pub fn create_physical_expr(&self, expr: &Expr) -> Result<Arc<dyn PhysicalExpr>> {
904        let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
905        Ok(datafusion::physical_expr::create_physical_expr(
906            expr,
907            df_schema.as_ref(),
908            &Default::default(),
909        )?)
910    }
911
912    /// Collect the columns in the expression.
913    ///
914    /// The columns are returned in sorted order.
915    ///
916    /// If the expr refers to nested columns these will be returned
917    /// as dotted paths (x.y.z)
918    pub fn column_names_in_expr(expr: &Expr) -> Vec<String> {
919        let mut visitor = ColumnCapturingVisitor {
920            current_path: VecDeque::new(),
921            columns: BTreeSet::new(),
922        };
923        expr.visit(&mut visitor).unwrap();
924        visitor.columns.into_iter().collect()
925    }
926}
927
/// Expression-tree visitor that collects the names of referenced columns,
/// assembling nested field references into dotted paths (e.g. `x.y.z`).
struct ColumnCapturingVisitor {
    // Current column path. If this is empty, we are not in a column expression.
    current_path: VecDeque<String>,
    // Completed column paths; BTreeSet keeps them sorted and de-duplicated.
    columns: BTreeSet<String>,
}
933
impl TreeNodeVisitor<'_> for ColumnCapturingVisitor {
    type Node = Expr;

    /// Top-down visit. A chain like `get_field(get_field(col, "a"), "b")` is
    /// visited outermost-first, so field names are pushed onto the *front* of
    /// `current_path`; the terminating `Column` node then assembles the full
    /// dotted path.
    fn f_down(&mut self, node: &Self::Node) -> DFResult<TreeNodeRecursion> {
        match node {
            Expr::Column(Column { name, .. }) => {
                // Build the field path from the column name and any nested fields
                // The nested field names from get_field already come as literal strings,
                // so we just need to concatenate them properly
                let mut path = name.clone();
                for part in self.current_path.drain(..) {
                    path.push('.');
                    // Check if the part needs quoting (contains dots)
                    if part.contains('.') || part.contains('`') {
                        // Quote the field name with backticks and escape any existing backticks
                        let escaped = part.replace('`', "``");
                        path.push('`');
                        path.push_str(&escaped);
                        path.push('`');
                    } else {
                        path.push_str(&part);
                    }
                }
                self.columns.insert(path);
                self.current_path.clear();
            }
            Expr::ScalarFunction(udf) => {
                // A `get_field` call contributes one path segment: its second
                // argument (the field name). Any other function, or a
                // non-string field argument, interrupts the pending path.
                if udf.name() == GetFieldFunc::default().name() {
                    if let Some(name) = get_as_string_scalar_opt(&udf.args[1]) {
                        self.current_path.push_front(name.to_string())
                    } else {
                        self.current_path.clear();
                    }
                } else {
                    self.current_path.clear();
                }
            }
            _ => {
                // Any other node breaks a partially-built column path.
                self.current_path.clear();
            }
        }

        Ok(TreeNodeRecursion::Continue)
    }
}
979
980#[cfg(test)]
981mod tests {
982
983    use crate::logical_expr::ExprExt;
984
985    use super::*;
986
987    use arrow::datatypes::Float64Type;
988    use arrow_array::{
989        ArrayRef, BooleanArray, Float32Array, Int32Array, Int64Array, RecordBatch, StringArray,
990        StructArray, TimestampMicrosecondArray, TimestampMillisecondArray,
991        TimestampNanosecondArray, TimestampSecondArray,
992    };
993    use arrow_schema::{DataType, Fields, Schema};
994    use datafusion::{
995        logical_expr::{lit, Cast},
996        prelude::{array_element, get_field},
997    };
998    use datafusion_functions::core::expr_ext::FieldAccessor;
999
    #[test]
    fn test_parse_filter_simple() {
        // Schema with a scalar int, a nullable string, and a nested struct.
        let schema = Arc::new(Schema::new(vec![
            Field::new("i", DataType::Int32, false),
            Field::new("s", DataType::Utf8, true),
            Field::new(
                "st",
                DataType::Struct(Fields::from(vec![
                    Field::new("x", DataType::Float32, false),
                    Field::new("y", DataType::Float32, false),
                ])),
                true,
            ),
        ]));

        let planner = Planner::new(schema.clone());

        let expected = col("i")
            .gt(lit(3_i32))
            .and(col("st").field_newstyle("x").lt_eq(lit(5.0_f32)))
            .and(
                col("s")
                    .eq(lit("str-4"))
                    .or(col("s").in_list(vec![lit("str-4"), lit("str-5")], false)),
            );

        // double quotes
        let expr = planner
            .parse_filter("i > 3 AND st.x <= 5.0 AND (s == 'str-4' OR s in ('str-4', 'str-5'))")
            .unwrap();
        assert_eq!(expr, expected);

        // single quote
        let expr = planner
            .parse_filter("i > 3 AND st.x <= 5.0 AND (s = 'str-4' OR s in ('str-4', 'str-5'))")
            .unwrap();

        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        // 10 rows: i = 0..10, s = "str-0".."str-9", st.x = 0..10, st.y = x * 10.
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
                Arc::new(StringArray::from_iter_values(
                    (0..10).map(|v| format!("str-{}", v)),
                )),
                Arc::new(StructArray::from(vec![
                    (
                        Arc::new(Field::new("x", DataType::Float32, false)),
                        Arc::new(Float32Array::from_iter_values((0..10).map(|v| v as f32)))
                            as ArrayRef,
                    ),
                    (
                        Arc::new(Field::new("y", DataType::Float32, false)),
                        Arc::new(Float32Array::from_iter_values(
                            (0..10).map(|v| (v * 10) as f32),
                        )),
                    ),
                ])),
            ],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        // Only rows 4 and 5 satisfy i > 3 AND st.x <= 5.0 AND s in {str-4, str-5}.
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, false, true, true, false, false, false, false
            ])
        );
    }
1070
    #[test]
    fn test_nested_col_refs() {
        // Schema with a doubly-nested struct: st.s1 and st.st.s2.
        let schema = Arc::new(Schema::new(vec![
            Field::new("s0", DataType::Utf8, true),
            Field::new(
                "st",
                DataType::Struct(Fields::from(vec![
                    Field::new("s1", DataType::Utf8, true),
                    Field::new(
                        "st",
                        DataType::Struct(Fields::from(vec![Field::new(
                            "s2",
                            DataType::Utf8,
                            true,
                        )])),
                        true,
                    ),
                ])),
                true,
            ),
        ]));

        let planner = Planner::new(schema);

        // Parses `{expr} = 'val'` and asserts the left-hand side of the
        // resulting binary expression equals `expected`.
        fn assert_column_eq(planner: &Planner, expr: &str, expected: &Expr) {
            let expr = planner.parse_filter(&format!("{expr} = 'val'")).unwrap();
            assert!(matches!(
                expr,
                Expr::BinaryExpr(BinaryExpr {
                    left: _,
                    op: Operator::Eq,
                    right: _
                })
            ));
            if let Expr::BinaryExpr(BinaryExpr { left, .. }) = expr {
                assert_eq!(left.as_ref(), expected);
            }
        }

        // Top-level column, bare and backtick-quoted.
        let expected = Expr::Column(Column::new_unqualified("s0"));
        assert_column_eq(&planner, "s0", &expected);
        assert_column_eq(&planner, "`s0`", &expected);

        // One level of nesting becomes a single get_field call.
        let expected = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
            args: vec![
                Expr::Column(Column::new_unqualified("st")),
                Expr::Literal(ScalarValue::Utf8(Some("s1".to_string())), None),
            ],
        });
        assert_column_eq(&planner, "st.s1", &expected);
        assert_column_eq(&planner, "`st`.`s1`", &expected);
        assert_column_eq(&planner, "st.`s1`", &expected);

        // Two levels of nesting become nested get_field calls.
        let expected = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
            args: vec![
                Expr::ScalarFunction(ScalarFunction {
                    func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
                    args: vec![
                        Expr::Column(Column::new_unqualified("st")),
                        Expr::Literal(ScalarValue::Utf8(Some("st".to_string())), None),
                    ],
                }),
                Expr::Literal(ScalarValue::Utf8(Some("s2".to_string())), None),
            ],
        });

        // Dot notation, quoted variants, and subscript access are equivalent.
        assert_column_eq(&planner, "st.st.s2", &expected);
        assert_column_eq(&planner, "`st`.`st`.`s2`", &expected);
        assert_column_eq(&planner, "st.st.`s2`", &expected);
        assert_column_eq(&planner, "st['st'][\"s2\"]", &expected);
    }
1144
1145    #[test]
1146    fn test_nested_list_refs() {
1147        let schema = Arc::new(Schema::new(vec![Field::new(
1148            "l",
1149            DataType::List(Arc::new(Field::new(
1150                "item",
1151                DataType::Struct(Fields::from(vec![Field::new("f1", DataType::Utf8, true)])),
1152                true,
1153            ))),
1154            true,
1155        )]));
1156
1157        let planner = Planner::new(schema);
1158
1159        let expected = array_element(col("l"), lit(0_i64));
1160        let expr = planner.parse_expr("l[0]").unwrap();
1161        assert_eq!(expr, expected);
1162
1163        let expected = get_field(array_element(col("l"), lit(0_i64)), "f1");
1164        let expr = planner.parse_expr("l[0]['f1']").unwrap();
1165        assert_eq!(expr, expected);
1166
1167        // FIXME: This should work, but sqlparser doesn't recognize anything
1168        // after the period for some reason.
1169        // let expr = planner.parse_expr("l[0].f1").unwrap();
1170        // assert_eq!(expr, expected);
1171    }
1172
1173    #[test]
1174    fn test_negative_expressions() {
1175        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
1176
1177        let planner = Planner::new(schema.clone());
1178
1179        let expected = col("x")
1180            .gt(lit(-3_i64))
1181            .and(col("x").lt(-(lit(-5_i64) + lit(3_i64))));
1182
1183        let expr = planner.parse_filter("x > -3 AND x < -(-5 + 3)").unwrap();
1184
1185        assert_eq!(expr, expected);
1186
1187        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1188
1189        let batch = RecordBatch::try_new(
1190            schema,
1191            vec![Arc::new(Int64Array::from_iter_values(-5..5)) as ArrayRef],
1192        )
1193        .unwrap();
1194        let predicates = physical_expr.evaluate(&batch).unwrap();
1195        assert_eq!(
1196            predicates.into_array(0).unwrap().as_ref(),
1197            &BooleanArray::from(vec![
1198                false, false, false, true, true, true, true, false, false, false
1199            ])
1200        );
1201    }
1202
1203    #[test]
1204    fn test_negative_array_expressions() {
1205        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
1206
1207        let planner = Planner::new(schema);
1208
1209        let expected = Expr::Literal(
1210            ScalarValue::List(Arc::new(
1211                ListArray::from_iter_primitive::<Float64Type, _, _>(vec![Some(
1212                    [-1_f64, -2.0, -3.0, -4.0, -5.0].map(Some),
1213                )]),
1214            )),
1215            None,
1216        );
1217
1218        let expr = planner
1219            .parse_expr("[-1.0, -2.0, -3.0, -4.0, -5.0]")
1220            .unwrap();
1221
1222        assert_eq!(expr, expected);
1223    }
1224
1225    #[test]
1226    fn test_sql_like() {
1227        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1228
1229        let planner = Planner::new(schema.clone());
1230
1231        let expected = col("s").like(lit("str-4"));
1232        // single quote
1233        let expr = planner.parse_filter("s LIKE 'str-4'").unwrap();
1234        assert_eq!(expr, expected);
1235        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1236
1237        let batch = RecordBatch::try_new(
1238            schema,
1239            vec![Arc::new(StringArray::from_iter_values(
1240                (0..10).map(|v| format!("str-{}", v)),
1241            ))],
1242        )
1243        .unwrap();
1244        let predicates = physical_expr.evaluate(&batch).unwrap();
1245        assert_eq!(
1246            predicates.into_array(0).unwrap().as_ref(),
1247            &BooleanArray::from(vec![
1248                false, false, false, false, true, false, false, false, false, false
1249            ])
1250        );
1251    }
1252
1253    #[test]
1254    fn test_not_like() {
1255        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1256
1257        let planner = Planner::new(schema.clone());
1258
1259        let expected = col("s").not_like(lit("str-4"));
1260        // single quote
1261        let expr = planner.parse_filter("s NOT LIKE 'str-4'").unwrap();
1262        assert_eq!(expr, expected);
1263        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1264
1265        let batch = RecordBatch::try_new(
1266            schema,
1267            vec![Arc::new(StringArray::from_iter_values(
1268                (0..10).map(|v| format!("str-{}", v)),
1269            ))],
1270        )
1271        .unwrap();
1272        let predicates = physical_expr.evaluate(&batch).unwrap();
1273        assert_eq!(
1274            predicates.into_array(0).unwrap().as_ref(),
1275            &BooleanArray::from(vec![
1276                true, true, true, true, false, true, true, true, true, true
1277            ])
1278        );
1279    }
1280
1281    #[test]
1282    fn test_sql_is_in() {
1283        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1284
1285        let planner = Planner::new(schema.clone());
1286
1287        let expected = col("s").in_list(vec![lit("str-4"), lit("str-5")], false);
1288        // single quote
1289        let expr = planner.parse_filter("s IN ('str-4', 'str-5')").unwrap();
1290        assert_eq!(expr, expected);
1291        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1292
1293        let batch = RecordBatch::try_new(
1294            schema,
1295            vec![Arc::new(StringArray::from_iter_values(
1296                (0..10).map(|v| format!("str-{}", v)),
1297            ))],
1298        )
1299        .unwrap();
1300        let predicates = physical_expr.evaluate(&batch).unwrap();
1301        assert_eq!(
1302            predicates.into_array(0).unwrap().as_ref(),
1303            &BooleanArray::from(vec![
1304                false, false, false, false, true, true, false, false, false, false
1305            ])
1306        );
1307    }
1308
1309    #[test]
1310    fn test_sql_is_null() {
1311        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1312
1313        let planner = Planner::new(schema.clone());
1314
1315        let expected = col("s").is_null();
1316        let expr = planner.parse_filter("s IS NULL").unwrap();
1317        assert_eq!(expr, expected);
1318        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1319
1320        let batch = RecordBatch::try_new(
1321            schema,
1322            vec![Arc::new(StringArray::from_iter((0..10).map(|v| {
1323                if v % 3 == 0 {
1324                    Some(format!("str-{}", v))
1325                } else {
1326                    None
1327                }
1328            })))],
1329        )
1330        .unwrap();
1331        let predicates = physical_expr.evaluate(&batch).unwrap();
1332        assert_eq!(
1333            predicates.into_array(0).unwrap().as_ref(),
1334            &BooleanArray::from(vec![
1335                false, true, true, false, true, true, false, true, true, false
1336            ])
1337        );
1338
1339        let expr = planner.parse_filter("s IS NOT NULL").unwrap();
1340        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1341        let predicates = physical_expr.evaluate(&batch).unwrap();
1342        assert_eq!(
1343            predicates.into_array(0).unwrap().as_ref(),
1344            &BooleanArray::from(vec![
1345                true, false, false, true, false, false, true, false, false, true,
1346            ])
1347        );
1348    }
1349
1350    #[test]
1351    fn test_sql_invert() {
1352        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Boolean, true)]));
1353
1354        let planner = Planner::new(schema.clone());
1355
1356        let expr = planner.parse_filter("NOT s").unwrap();
1357        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1358
1359        let batch = RecordBatch::try_new(
1360            schema,
1361            vec![Arc::new(BooleanArray::from_iter(
1362                (0..10).map(|v| Some(v % 3 == 0)),
1363            ))],
1364        )
1365        .unwrap();
1366        let predicates = physical_expr.evaluate(&batch).unwrap();
1367        assert_eq!(
1368            predicates.into_array(0).unwrap().as_ref(),
1369            &BooleanArray::from(vec![
1370                false, true, true, false, true, true, false, true, true, false
1371            ])
1372        );
1373    }
1374
    #[test]
    fn test_sql_cast() {
        // Each case pairs a CAST predicate with the Arrow type it should map
        // to; the column `x` is created with the same type so the comparison
        // is well-typed.
        let cases = &[
            (
                "x = cast('2021-01-01 00:00:00' as timestamp)",
                ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00' as timestamp(0))",
                ArrowDataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00.123' as timestamp(9))",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00.123' as datetime(9))",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            ("x = cast('2021-01-01' as date)", ArrowDataType::Date32),
            (
                "x = cast('1.238' as decimal(9,3))",
                ArrowDataType::Decimal128(9, 3),
            ),
            ("x = cast(1 as float)", ArrowDataType::Float32),
            ("x = cast(1 as double)", ArrowDataType::Float64),
            ("x = cast(1 as tinyint)", ArrowDataType::Int8),
            ("x = cast(1 as smallint)", ArrowDataType::Int16),
            ("x = cast(1 as int)", ArrowDataType::Int32),
            ("x = cast(1 as integer)", ArrowDataType::Int32),
            ("x = cast(1 as bigint)", ArrowDataType::Int64),
            ("x = cast(1 as tinyint unsigned)", ArrowDataType::UInt8),
            ("x = cast(1 as smallint unsigned)", ArrowDataType::UInt16),
            ("x = cast(1 as int unsigned)", ArrowDataType::UInt32),
            ("x = cast(1 as integer unsigned)", ArrowDataType::UInt32),
            ("x = cast(1 as bigint unsigned)", ArrowDataType::UInt64),
            ("x = cast(1 as boolean)", ArrowDataType::Boolean),
            ("x = cast(1 as string)", ArrowDataType::Utf8),
        ];

        for (sql, expected_data_type) in cases {
            let schema = Arc::new(Schema::new(vec![Field::new(
                "x",
                expected_data_type.clone(),
                true,
            )]));
            let planner = Planner::new(schema.clone());
            let expr = planner.parse_filter(sql).unwrap();

            // Get the thing after 'cast(` but before ' as'.
            let expected_value_str = sql
                .split("cast(")
                .nth(1)
                .unwrap()
                .split(" as")
                .next()
                .unwrap();
            // Remove any quote marks
            let expected_value_str = expected_value_str.trim_matches('\'');

            // The parsed filter must be `x = CAST(<literal> AS <type>)` with
            // the literal text preserved and the target type mapped correctly.
            match expr {
                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
                    Expr::Cast(Cast { expr, data_type }) => {
                        match expr.as_ref() {
                            Expr::Literal(ScalarValue::Utf8(Some(value_str)), _) => {
                                assert_eq!(value_str, expected_value_str);
                            }
                            Expr::Literal(ScalarValue::Int64(Some(value)), _) => {
                                assert_eq!(*value, 1);
                            }
                            _ => panic!("Expected cast to be applied to literal"),
                        }
                        assert_eq!(data_type, expected_data_type);
                    }
                    _ => panic!("Expected right to be a cast"),
                },
                _ => panic!("Expected binary expression"),
            }
        }
    }
1455
    #[test]
    fn test_sql_literals() {
        // Each case pairs a typed-literal predicate with the Arrow type it
        // should map to.
        let cases = &[
            (
                "x = timestamp '2021-01-01 00:00:00'",
                ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
            ),
            (
                "x = timestamp(0) '2021-01-01 00:00:00'",
                ArrowDataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                "x = timestamp(9) '2021-01-01 00:00:00.123'",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            ("x = date '2021-01-01'", ArrowDataType::Date32),
            ("x = decimal(9,3) '1.238'", ArrowDataType::Decimal128(9, 3)),
        ];

        for (sql, expected_data_type) in cases {
            let schema = Arc::new(Schema::new(vec![Field::new(
                "x",
                expected_data_type.clone(),
                true,
            )]));
            let planner = Planner::new(schema.clone());
            let expr = planner.parse_filter(sql).unwrap();

            // The literal text sits between the first pair of single quotes.
            let expected_value_str = sql.split('\'').nth(1).unwrap();

            // Typed literals are parsed as a cast applied to the quoted string.
            match expr {
                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
                    Expr::Cast(Cast { expr, data_type }) => {
                        match expr.as_ref() {
                            Expr::Literal(ScalarValue::Utf8(Some(value_str)), _) => {
                                assert_eq!(value_str, expected_value_str);
                            }
                            _ => panic!("Expected cast to be applied to literal"),
                        }
                        assert_eq!(data_type, expected_data_type);
                    }
                    _ => panic!("Expected right to be a cast"),
                },
                _ => panic!("Expected binary expression"),
            }
        }
    }
1503
1504    #[test]
1505    fn test_sql_array_literals() {
1506        let cases = [
1507            (
1508                "x = [1, 2, 3]",
1509                ArrowDataType::List(Arc::new(Field::new("item", ArrowDataType::Int64, true))),
1510            ),
1511            (
1512                "x = [1, 2, 3]",
1513                ArrowDataType::FixedSizeList(
1514                    Arc::new(Field::new("item", ArrowDataType::Int64, true)),
1515                    3,
1516                ),
1517            ),
1518        ];
1519
1520        for (sql, expected_data_type) in cases {
1521            let schema = Arc::new(Schema::new(vec![Field::new(
1522                "x",
1523                expected_data_type.clone(),
1524                true,
1525            )]));
1526            let planner = Planner::new(schema.clone());
1527            let expr = planner.parse_filter(sql).unwrap();
1528            let expr = planner.optimize_expr(expr).unwrap();
1529
1530            match expr {
1531                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
1532                    Expr::Literal(value, _) => {
1533                        assert_eq!(&value.data_type(), &expected_data_type);
1534                    }
1535                    _ => panic!("Expected right to be a literal"),
1536                },
1537                _ => panic!("Expected binary expression"),
1538            }
1539        }
1540    }
1541
1542    #[test]
1543    fn test_sql_between() {
1544        use arrow_array::{Float64Array, Int32Array, TimestampMicrosecondArray};
1545        use arrow_schema::{DataType, Field, Schema, TimeUnit};
1546        use std::sync::Arc;
1547
1548        let schema = Arc::new(Schema::new(vec![
1549            Field::new("x", DataType::Int32, false),
1550            Field::new("y", DataType::Float64, false),
1551            Field::new(
1552                "ts",
1553                DataType::Timestamp(TimeUnit::Microsecond, None),
1554                false,
1555            ),
1556        ]));
1557
1558        let planner = Planner::new(schema.clone());
1559
1560        // Test integer BETWEEN
1561        let expr = planner
1562            .parse_filter("x BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
1563            .unwrap();
1564        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1565
1566        // Create timestamp array with values representing:
1567        // 2024-01-01 00:00:00 to 2024-01-01 00:00:09 (in microseconds)
1568        let base_ts = 1704067200000000_i64; // 2024-01-01 00:00:00
1569        let ts_array = TimestampMicrosecondArray::from_iter_values(
1570            (0..10).map(|i| base_ts + i * 1_000_000), // Each value is 1 second apart
1571        );
1572
1573        let batch = RecordBatch::try_new(
1574            schema,
1575            vec![
1576                Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
1577                Arc::new(Float64Array::from_iter_values((0..10).map(|v| v as f64))),
1578                Arc::new(ts_array),
1579            ],
1580        )
1581        .unwrap();
1582
1583        let predicates = physical_expr.evaluate(&batch).unwrap();
1584        assert_eq!(
1585            predicates.into_array(0).unwrap().as_ref(),
1586            &BooleanArray::from(vec![
1587                false, false, false, true, true, true, true, true, false, false
1588            ])
1589        );
1590
1591        // Test NOT BETWEEN
1592        let expr = planner
1593            .parse_filter("x NOT BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
1594            .unwrap();
1595        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1596
1597        let predicates = physical_expr.evaluate(&batch).unwrap();
1598        assert_eq!(
1599            predicates.into_array(0).unwrap().as_ref(),
1600            &BooleanArray::from(vec![
1601                true, true, true, false, false, false, false, false, true, true
1602            ])
1603        );
1604
1605        // Test floating point BETWEEN
1606        let expr = planner.parse_filter("y BETWEEN 2.5 AND 6.5").unwrap();
1607        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1608
1609        let predicates = physical_expr.evaluate(&batch).unwrap();
1610        assert_eq!(
1611            predicates.into_array(0).unwrap().as_ref(),
1612            &BooleanArray::from(vec![
1613                false, false, false, true, true, true, true, false, false, false
1614            ])
1615        );
1616
1617        // Test timestamp BETWEEN
1618        let expr = planner
1619            .parse_filter(
1620                "ts BETWEEN timestamp '2024-01-01 00:00:03' AND timestamp '2024-01-01 00:00:07'",
1621            )
1622            .unwrap();
1623        let physical_expr = planner.create_physical_expr(&expr).unwrap();
1624
1625        let predicates = physical_expr.evaluate(&batch).unwrap();
1626        assert_eq!(
1627            predicates.into_array(0).unwrap().as_ref(),
1628            &BooleanArray::from(vec![
1629                false, false, false, true, true, true, true, true, false, false
1630            ])
1631        );
1632    }
1633
1634    #[test]
1635    fn test_sql_comparison() {
1636        // Create a batch with all data types
1637        let batch: Vec<(&str, ArrayRef)> = vec![
1638            (
1639                "timestamp_s",
1640                Arc::new(TimestampSecondArray::from_iter_values(0..10)),
1641            ),
1642            (
1643                "timestamp_ms",
1644                Arc::new(TimestampMillisecondArray::from_iter_values(0..10)),
1645            ),
1646            (
1647                "timestamp_us",
1648                Arc::new(TimestampMicrosecondArray::from_iter_values(0..10)),
1649            ),
1650            (
1651                "timestamp_ns",
1652                Arc::new(TimestampNanosecondArray::from_iter_values(4995..5005)),
1653            ),
1654        ];
1655        let batch = RecordBatch::try_from_iter(batch).unwrap();
1656
1657        let planner = Planner::new(batch.schema());
1658
1659        // Each expression is meant to select the final 5 rows
1660        let expressions = &[
1661            "timestamp_s >= TIMESTAMP '1970-01-01 00:00:05'",
1662            "timestamp_ms >= TIMESTAMP '1970-01-01 00:00:00.005'",
1663            "timestamp_us >= TIMESTAMP '1970-01-01 00:00:00.000005'",
1664            "timestamp_ns >= TIMESTAMP '1970-01-01 00:00:00.000005'",
1665        ];
1666
1667        let expected: ArrayRef = Arc::new(BooleanArray::from_iter(
1668            std::iter::repeat_n(Some(false), 5).chain(std::iter::repeat_n(Some(true), 5)),
1669        ));
1670        for expression in expressions {
1671            // convert to physical expression
1672            let logical_expr = planner.parse_filter(expression).unwrap();
1673            let logical_expr = planner.optimize_expr(logical_expr).unwrap();
1674            let physical_expr = planner.create_physical_expr(&logical_expr).unwrap();
1675
1676            // Evaluate and assert they have correct results
1677            let result = physical_expr.evaluate(&batch).unwrap();
1678            let result = result.into_array(batch.num_rows()).unwrap();
1679            assert_eq!(&expected, &result, "unexpected result for {}", expression);
1680        }
1681    }
1682
1683    #[test]
1684    fn test_columns_in_expr() {
1685        let expr = col("s0").gt(lit("value")).and(
1686            col("st")
1687                .field("st")
1688                .field("s2")
1689                .eq(lit("value"))
1690                .or(col("st")
1691                    .field("s1")
1692                    .in_list(vec![lit("value 1"), lit("value 2")], false)),
1693        );
1694
1695        let columns = Planner::column_names_in_expr(&expr);
1696        assert_eq!(columns, vec!["s0", "st.s1", "st.st.s2"]);
1697    }
1698
1699    #[test]
1700    fn test_parse_binary_expr() {
1701        let bin_str = "x'616263'";
1702
1703        let schema = Arc::new(Schema::new(vec![Field::new(
1704            "binary",
1705            DataType::Binary,
1706            true,
1707        )]));
1708        let planner = Planner::new(schema);
1709        let expr = planner.parse_expr(bin_str).unwrap();
1710        assert_eq!(
1711            expr,
1712            Expr::Literal(ScalarValue::Binary(Some(vec![b'a', b'b', b'c'])), None)
1713        );
1714    }
1715
1716    #[test]
1717    fn test_lance_context_provider_expr_planners() {
1718        let ctx_provider = LanceContextProvider::default();
1719        assert!(!ctx_provider.get_expr_planners().is_empty());
1720    }
1721}