lance_datafusion/
planner.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Exec plan planner
5
6use std::borrow::Cow;
7use std::collections::{BTreeSet, VecDeque};
8use std::sync::Arc;
9
10use crate::expr::safe_coerce_scalar;
11use crate::logical_expr::{coerce_filter_type_to_boolean, get_as_string_scalar_opt, resolve_expr};
12use crate::sql::{parse_sql_expr, parse_sql_filter};
13use arrow::compute::CastOptions;
14use arrow_array::ListArray;
15use arrow_buffer::OffsetBuffer;
16use arrow_schema::{DataType as ArrowDataType, Field, SchemaRef, TimeUnit};
17use arrow_select::concat::concat;
18use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
19use datafusion::common::DFSchema;
20use datafusion::config::ConfigOptions;
21use datafusion::error::Result as DFResult;
22use datafusion::execution::config::SessionConfig;
23use datafusion::execution::context::SessionState;
24use datafusion::execution::runtime_env::RuntimeEnvBuilder;
25use datafusion::execution::session_state::SessionStateBuilder;
26use datafusion::logical_expr::expr::ScalarFunction;
27use datafusion::logical_expr::planner::{ExprPlanner, PlannerResult, RawFieldAccessExpr};
28use datafusion::logical_expr::{
29    AggregateUDF, ColumnarValue, GetFieldAccess, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
30    Signature, Volatility, WindowUDF,
31};
32use datafusion::optimizer::simplify_expressions::SimplifyContext;
33use datafusion::sql::planner::{ContextProvider, ParserOptions, PlannerContext, SqlToRel};
34use datafusion::sql::sqlparser::ast::{
35    AccessExpr, Array as SQLArray, BinaryOperator, DataType as SQLDataType, ExactNumberInfo,
36    Expr as SQLExpr, Function, FunctionArg, FunctionArgExpr, FunctionArguments, Ident,
37    ObjectNamePart, Subscript, TimezoneInfo, UnaryOperator, Value, ValueWithSpan,
38};
39use datafusion::{
40    common::Column,
41    logical_expr::{col, Between, BinaryExpr, Like, Operator},
42    physical_expr::execution_props::ExecutionProps,
43    physical_plan::PhysicalExpr,
44    prelude::Expr,
45    scalar::ScalarValue,
46};
47use datafusion_functions::core::getfield::GetFieldFunc;
48use lance_arrow::cast::cast_with_options;
49use lance_core::datatypes::Schema;
50use lance_core::error::LanceOptionExt;
51use snafu::location;
52
53use lance_core::{Error, Result};
54
55#[derive(Debug, Clone)]
56struct CastListF16Udf {
57    signature: Signature,
58}
59
60impl CastListF16Udf {
61    pub fn new() -> Self {
62        Self {
63            signature: Signature::any(1, Volatility::Immutable),
64        }
65    }
66}
67
68impl ScalarUDFImpl for CastListF16Udf {
69    fn as_any(&self) -> &dyn std::any::Any {
70        self
71    }
72
73    fn name(&self) -> &str {
74        "_cast_list_f16"
75    }
76
77    fn signature(&self) -> &Signature {
78        &self.signature
79    }
80
81    fn return_type(&self, arg_types: &[ArrowDataType]) -> DFResult<ArrowDataType> {
82        let input = &arg_types[0];
83        match input {
84            ArrowDataType::FixedSizeList(field, size) => {
85                if field.data_type() != &ArrowDataType::Float32
86                    && field.data_type() != &ArrowDataType::Float16
87                {
88                    return Err(datafusion::error::DataFusionError::Execution(
89                        "cast_list_f16 only supports list of float32 or float16".to_string(),
90                    ));
91                }
92                Ok(ArrowDataType::FixedSizeList(
93                    Arc::new(Field::new(
94                        field.name(),
95                        ArrowDataType::Float16,
96                        field.is_nullable(),
97                    )),
98                    *size,
99                ))
100            }
101            ArrowDataType::List(field) => {
102                if field.data_type() != &ArrowDataType::Float32
103                    && field.data_type() != &ArrowDataType::Float16
104                {
105                    return Err(datafusion::error::DataFusionError::Execution(
106                        "cast_list_f16 only supports list of float32 or float16".to_string(),
107                    ));
108                }
109                Ok(ArrowDataType::List(Arc::new(Field::new(
110                    field.name(),
111                    ArrowDataType::Float16,
112                    field.is_nullable(),
113                ))))
114            }
115            _ => Err(datafusion::error::DataFusionError::Execution(
116                "cast_list_f16 only supports FixedSizeList/List arguments".to_string(),
117            )),
118        }
119    }
120
121    fn invoke_with_args(&self, func_args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
122        let ColumnarValue::Array(arr) = &func_args.args[0] else {
123            return Err(datafusion::error::DataFusionError::Execution(
124                "cast_list_f16 only supports array arguments".to_string(),
125            ));
126        };
127
128        let to_type = match arr.data_type() {
129            ArrowDataType::FixedSizeList(field, size) => ArrowDataType::FixedSizeList(
130                Arc::new(Field::new(
131                    field.name(),
132                    ArrowDataType::Float16,
133                    field.is_nullable(),
134                )),
135                *size,
136            ),
137            ArrowDataType::List(field) => ArrowDataType::List(Arc::new(Field::new(
138                field.name(),
139                ArrowDataType::Float16,
140                field.is_nullable(),
141            ))),
142            _ => {
143                return Err(datafusion::error::DataFusionError::Execution(
144                    "cast_list_f16 only supports array arguments".to_string(),
145                ));
146            }
147        };
148
149        let res = cast_with_options(arr.as_ref(), &to_type, &CastOptions::default())?;
150        Ok(ColumnarValue::Array(res))
151    }
152}
153
154// Adapter that instructs datafusion how lance expects expressions to be interpreted
155struct LanceContextProvider {
156    options: datafusion::config::ConfigOptions,
157    state: SessionState,
158    expr_planners: Vec<Arc<dyn ExprPlanner>>,
159}
160
161impl Default for LanceContextProvider {
162    fn default() -> Self {
163        let config = SessionConfig::new();
164        let runtime = RuntimeEnvBuilder::new().build_arc().unwrap();
165        let mut state_builder = SessionStateBuilder::new()
166            .with_config(config)
167            .with_runtime_env(runtime)
168            .with_default_features();
169
170        // SessionState does not expose expr_planners, so we need to get the default ones from
171        // the builder and store them to return from get_expr_planners
172
173        // unwrap safe because with_default_features sets expr_planners
174        let expr_planners = state_builder.expr_planners().as_ref().unwrap().clone();
175
176        Self {
177            options: ConfigOptions::default(),
178            state: state_builder.build(),
179            expr_planners,
180        }
181    }
182}
183
// ContextProvider is how datafusion's SqlToRel looks up tables, functions,
// variables and options while planning. Lance only plans scalar expressions
// against a single schema, so table and variable lookups are intentionally
// unsupported; function lookups delegate to a default SessionState.
impl ContextProvider for LanceContextProvider {
    fn get_table_source(
        &self,
        name: datafusion::sql::TableReference,
    ) -> DFResult<Arc<dyn datafusion::logical_expr::TableSource>> {
        // Filter expressions cannot reference other tables or subqueries.
        Err(datafusion::error::DataFusionError::NotImplemented(format!(
            "Attempt to reference inner table {} not supported",
            name
        )))
    }

    fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
        self.state.aggregate_functions().get(name).cloned()
    }

    fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
        self.state.window_functions().get(name).cloned()
    }

    fn get_function_meta(&self, f: &str) -> Option<Arc<ScalarUDF>> {
        match f {
            // TODO: cast should go thru CAST syntax instead of UDF
            // Going thru UDF makes it hard for the optimizer to find no-ops
            "_cast_list_f16" => Some(Arc::new(ScalarUDF::new_from_impl(CastListF16Udf::new()))),
            _ => self.state.scalar_functions().get(f).cloned(),
        }
    }

    fn get_variable_type(&self, _: &[String]) -> Option<ArrowDataType> {
        // Variables (things like @@LANGUAGE) not supported
        None
    }

    fn options(&self) -> &datafusion::config::ConfigOptions {
        &self.options
    }

    fn udf_names(&self) -> Vec<String> {
        self.state.scalar_functions().keys().cloned().collect()
    }

    fn udaf_names(&self) -> Vec<String> {
        self.state.aggregate_functions().keys().cloned().collect()
    }

    fn udwf_names(&self) -> Vec<String> {
        self.state.window_functions().keys().cloned().collect()
    }

    fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
        // Planners cached when the provider was constructed.
        &self.expr_planners
    }
}
237
/// Converts SQL expression/filter strings into DataFusion logical
/// expressions, resolving column references against a fixed Arrow schema.
pub struct Planner {
    // Schema column references are resolved against.
    schema: SchemaRef,
    // Supplies function registries and expr planners to SqlToRel.
    context_provider: LanceContextProvider,
}
242
243impl Planner {
    /// Creates a planner that resolves columns against `schema`.
    pub fn new(schema: SchemaRef) -> Self {
        Self {
            schema,
            context_provider: LanceContextProvider::default(),
        }
    }
250
251    fn column(idents: &[Ident]) -> Expr {
252        let mut column = col(&idents[0].value);
253        for ident in &idents[1..] {
254            column = Expr::ScalarFunction(ScalarFunction {
255                args: vec![
256                    column,
257                    Expr::Literal(ScalarValue::Utf8(Some(ident.value.clone())), None),
258                ],
259                func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
260            });
261        }
262        column
263    }
264
265    fn binary_op(&self, op: &BinaryOperator) -> Result<Operator> {
266        Ok(match op {
267            BinaryOperator::Plus => Operator::Plus,
268            BinaryOperator::Minus => Operator::Minus,
269            BinaryOperator::Multiply => Operator::Multiply,
270            BinaryOperator::Divide => Operator::Divide,
271            BinaryOperator::Modulo => Operator::Modulo,
272            BinaryOperator::StringConcat => Operator::StringConcat,
273            BinaryOperator::Gt => Operator::Gt,
274            BinaryOperator::Lt => Operator::Lt,
275            BinaryOperator::GtEq => Operator::GtEq,
276            BinaryOperator::LtEq => Operator::LtEq,
277            BinaryOperator::Eq => Operator::Eq,
278            BinaryOperator::NotEq => Operator::NotEq,
279            BinaryOperator::And => Operator::And,
280            BinaryOperator::Or => Operator::Or,
281            _ => {
282                return Err(Error::invalid_input(
283                    format!("Operator {op} is not supported"),
284                    location!(),
285                ));
286            }
287        })
288    }
289
    /// Parses both operands recursively and joins them with the mapped
    /// DataFusion operator.
    fn binary_expr(&self, left: &SQLExpr, op: &BinaryOperator, right: &SQLExpr) -> Result<Expr> {
        Ok(Expr::BinaryExpr(BinaryExpr::new(
            Box::new(self.parse_sql_expr(left)?),
            self.binary_op(op)?,
            Box::new(self.parse_sql_expr(right)?),
        )))
    }
297
298    fn unary_expr(&self, op: &UnaryOperator, expr: &SQLExpr) -> Result<Expr> {
299        Ok(match op {
300            UnaryOperator::Not | UnaryOperator::PGBitwiseNot => {
301                Expr::Not(Box::new(self.parse_sql_expr(expr)?))
302            }
303
304            UnaryOperator::Minus => {
305                use datafusion::logical_expr::lit;
306                match expr {
307                    SQLExpr::Value(ValueWithSpan { value: Value::Number(n, _), ..}) => match n.parse::<i64>() {
308                        Ok(n) => lit(-n),
309                        Err(_) => lit(-n
310                            .parse::<f64>()
311                            .map_err(|_e| {
312                                Error::invalid_input(
313                                    format!("negative operator can be only applied to integer and float operands, got: {n}"),
314                                    location!(),
315                                )
316                            })?),
317                    },
318                    _ => {
319                        Expr::Negative(Box::new(self.parse_sql_expr(expr)?))
320                    }
321                }
322            }
323
324            _ => {
325                return Err(Error::invalid_input(
326                    format!("Unary operator '{:?}' is not supported", op),
327                    location!(),
328                ));
329            }
330        })
331    }
332
333    // See datafusion `sqlToRel::parse_sql_number()`
334    fn number(&self, value: &str, negative: bool) -> Result<Expr> {
335        use datafusion::logical_expr::lit;
336        let value: Cow<str> = if negative {
337            Cow::Owned(format!("-{}", value))
338        } else {
339            Cow::Borrowed(value)
340        };
341        if let Ok(n) = value.parse::<i64>() {
342            Ok(lit(n))
343        } else {
344            value.parse::<f64>().map(lit).map_err(|_| {
345                Error::invalid_input(
346                    format!("'{value}' is not supported number value."),
347                    location!(),
348                )
349            })
350        }
351    }
352
353    fn value(&self, value: &Value) -> Result<Expr> {
354        Ok(match value {
355            Value::Number(v, _) => self.number(v.as_str(), false)?,
356            Value::SingleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone())), None),
357            Value::HexStringLiteral(hsl) => {
358                Expr::Literal(ScalarValue::Binary(Self::try_decode_hex_literal(hsl)), None)
359            }
360            Value::DoubleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone())), None),
361            Value::Boolean(v) => Expr::Literal(ScalarValue::Boolean(Some(*v)), None),
362            Value::Null => Expr::Literal(ScalarValue::Null, None),
363            _ => todo!(),
364        })
365    }
366
    /// Parses a single unnamed positional function argument.
    ///
    /// Named arguments, wildcards, and other argument forms are rejected.
    fn parse_function_args(&self, func_args: &FunctionArg) -> Result<Expr> {
        match func_args {
            FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => self.parse_sql_expr(expr),
            _ => Err(Error::invalid_input(
                format!("Unsupported function args: {:?}", func_args),
                location!(),
            )),
        }
    }
376
377    // We now use datafusion to parse functions.  This allows us to use datafusion's
378    // entire collection of functions (previously we had just hard-coded support for two functions).
379    //
380    // Unfortunately, one of those two functions was is_valid and the reason we needed it was because
381    // this is a function that comes from duckdb.  Datafusion does not consider is_valid to be a function
382    // but rather an AST node (Expr::IsNotNull) and so we need to handle this case specially.
383    fn legacy_parse_function(&self, func: &Function) -> Result<Expr> {
384        match &func.args {
385            FunctionArguments::List(args) => {
386                if func.name.0.len() != 1 {
387                    return Err(Error::invalid_input(
388                        format!("Function name must have 1 part, got: {:?}", func.name.0),
389                        location!(),
390                    ));
391                }
392                Ok(Expr::IsNotNull(Box::new(
393                    self.parse_function_args(&args.args[0])?,
394                )))
395            }
396            _ => Err(Error::invalid_input(
397                format!("Unsupported function args: {:?}", &func.args),
398                location!(),
399            )),
400        }
401    }
402
    /// Plans a SQL function call.
    ///
    /// `is_valid` (a duckdb-ism with no datafusion equivalent) is
    /// special-cased to `IS NOT NULL`; every other function is delegated to
    /// datafusion's SQL planner so its whole function catalog is available.
    fn parse_function(&self, function: SQLExpr) -> Result<Expr> {
        if let SQLExpr::Function(function) = &function {
            if let Some(ObjectNamePart::Identifier(name)) = &function.name.0.first() {
                if &name.value == "is_valid" {
                    return self.legacy_parse_function(function);
                }
            }
        }
        let sql_to_rel = SqlToRel::new_with_options(
            &self.context_provider,
            ParserOptions {
                parse_float_as_decimal: false,
                // Keep unquoted identifiers case-sensitive (no lowercasing).
                enable_ident_normalization: false,
                support_varchar_with_length: false,
                enable_options_value_normalization: false,
                collect_spans: false,
                map_varchar_to_utf8view: false,
            },
        );

        let mut planner_context = PlannerContext::default();
        let schema = DFSchema::try_from(self.schema.as_ref().clone())?;
        sql_to_rel
            .sql_to_expr(function, &schema, &mut planner_context)
            .map_err(|e| Error::invalid_input(format!("Error parsing function: {e}"), location!()))
    }
429
430    fn parse_type(&self, data_type: &SQLDataType) -> Result<ArrowDataType> {
431        const SUPPORTED_TYPES: [&str; 13] = [
432            "int [unsigned]",
433            "tinyint [unsigned]",
434            "smallint [unsigned]",
435            "bigint [unsigned]",
436            "float",
437            "double",
438            "string",
439            "binary",
440            "date",
441            "timestamp(precision)",
442            "datetime(precision)",
443            "decimal(precision,scale)",
444            "boolean",
445        ];
446        match data_type {
447            SQLDataType::String(_) => Ok(ArrowDataType::Utf8),
448            SQLDataType::Binary(_) => Ok(ArrowDataType::Binary),
449            SQLDataType::Float(_) => Ok(ArrowDataType::Float32),
450            SQLDataType::Double(_) => Ok(ArrowDataType::Float64),
451            SQLDataType::Boolean => Ok(ArrowDataType::Boolean),
452            SQLDataType::TinyInt(_) => Ok(ArrowDataType::Int8),
453            SQLDataType::SmallInt(_) => Ok(ArrowDataType::Int16),
454            SQLDataType::Int(_) | SQLDataType::Integer(_) => Ok(ArrowDataType::Int32),
455            SQLDataType::BigInt(_) => Ok(ArrowDataType::Int64),
456            SQLDataType::TinyIntUnsigned(_) => Ok(ArrowDataType::UInt8),
457            SQLDataType::SmallIntUnsigned(_) => Ok(ArrowDataType::UInt16),
458            SQLDataType::IntUnsigned(_) | SQLDataType::IntegerUnsigned(_) => {
459                Ok(ArrowDataType::UInt32)
460            }
461            SQLDataType::BigIntUnsigned(_) => Ok(ArrowDataType::UInt64),
462            SQLDataType::Date => Ok(ArrowDataType::Date32),
463            SQLDataType::Timestamp(resolution, tz) => {
464                match tz {
465                    TimezoneInfo::None => {}
466                    _ => {
467                        return Err(Error::invalid_input(
468                            "Timezone not supported in timestamp".to_string(),
469                            location!(),
470                        ));
471                    }
472                };
473                let time_unit = match resolution {
474                    // Default to microsecond to match PyArrow
475                    None => TimeUnit::Microsecond,
476                    Some(0) => TimeUnit::Second,
477                    Some(3) => TimeUnit::Millisecond,
478                    Some(6) => TimeUnit::Microsecond,
479                    Some(9) => TimeUnit::Nanosecond,
480                    _ => {
481                        return Err(Error::invalid_input(
482                            format!("Unsupported datetime resolution: {:?}", resolution),
483                            location!(),
484                        ));
485                    }
486                };
487                Ok(ArrowDataType::Timestamp(time_unit, None))
488            }
489            SQLDataType::Datetime(resolution) => {
490                let time_unit = match resolution {
491                    None => TimeUnit::Microsecond,
492                    Some(0) => TimeUnit::Second,
493                    Some(3) => TimeUnit::Millisecond,
494                    Some(6) => TimeUnit::Microsecond,
495                    Some(9) => TimeUnit::Nanosecond,
496                    _ => {
497                        return Err(Error::invalid_input(
498                            format!("Unsupported datetime resolution: {:?}", resolution),
499                            location!(),
500                        ));
501                    }
502                };
503                Ok(ArrowDataType::Timestamp(time_unit, None))
504            }
505            SQLDataType::Decimal(number_info) => match number_info {
506                ExactNumberInfo::PrecisionAndScale(precision, scale) => {
507                    Ok(ArrowDataType::Decimal128(*precision as u8, *scale as i8))
508                }
509                _ => Err(Error::invalid_input(
510                    format!(
511                        "Must provide precision and scale for decimal: {:?}",
512                        number_info
513                    ),
514                    location!(),
515                )),
516            },
517            _ => Err(Error::invalid_input(
518                format!(
519                    "Unsupported data type: {:?}. Supported types: {:?}",
520                    data_type, SUPPORTED_TYPES
521                ),
522                location!(),
523            )),
524        }
525    }
526
    /// Plans a struct-field / list-index access (`x.y`, `x['y']`, `x[0]`)
    /// by delegating to the registered [`ExprPlanner`]s in order.
    ///
    /// Each planner either finishes the job (`Planned`) or hands the
    /// expression back unchanged (`Original`) for the next planner to try;
    /// if none succeeds, an invalid-input error is returned.
    fn plan_field_access(&self, mut field_access_expr: RawFieldAccessExpr) -> Result<Expr> {
        let df_schema = DFSchema::try_from(self.schema.as_ref().clone())?;
        for planner in self.context_provider.get_expr_planners() {
            match planner.plan_field_access(field_access_expr, &df_schema)? {
                PlannerResult::Planned(expr) => return Ok(expr),
                PlannerResult::Original(expr) => {
                    field_access_expr = expr;
                }
            }
        }
        Err(Error::invalid_input(
            "Field access could not be planned",
            location!(),
        ))
    }
542
    /// Recursively lowers a sqlparser AST expression to a DataFusion logical
    /// [`Expr`], applying lance-specific conventions: `"x"` is a string
    /// literal, backticked identifiers are raw column names, and `is_valid`
    /// maps to `IS NOT NULL`.
    ///
    /// # Errors
    /// Returns an invalid-input error for SQL constructs lance does not
    /// support (JSON access, slice subscripts, and any unhandled AST node).
    fn parse_sql_expr(&self, expr: &SQLExpr) -> Result<Expr> {
        match expr {
            SQLExpr::Identifier(id) => {
                // Users can pass string literals wrapped in `"`.
                // (Normally SQL only allows single quotes.)
                if id.quote_style == Some('"') {
                    Ok(Expr::Literal(
                        ScalarValue::Utf8(Some(id.value.clone())),
                        None,
                    ))
                // Users can wrap identifiers with ` to reference non-standard
                // names, such as uppercase or spaces.
                } else if id.quote_style == Some('`') {
                    Ok(Expr::Column(Column::from_name(id.value.clone())))
                } else {
                    Ok(Self::column(vec![id.clone()].as_slice()))
                }
            }
            // a.b.c style nested-field references
            SQLExpr::CompoundIdentifier(ids) => Ok(Self::column(ids.as_slice())),
            SQLExpr::BinaryOp { left, op, right } => self.binary_expr(left, op, right),
            SQLExpr::UnaryOp { op, expr } => self.unary_expr(op, expr),
            SQLExpr::Value(value) => self.value(&value.value),
            // Array literals like [1, 2, 3]; all elements must be literals
            // of a mutually coercible type.
            SQLExpr::Array(SQLArray { elem, .. }) => {
                let mut values = vec![];

                let array_literal_error = |pos: usize, value: &_| {
                    Err(Error::invalid_input(
                        format!(
                            "Expected a literal value in array, instead got {} at position {}",
                            value, pos
                        ),
                        location!(),
                    ))
                };

                for (pos, expr) in elem.iter().enumerate() {
                    match expr {
                        SQLExpr::Value(value) => {
                            if let Expr::Literal(value, _) = self.value(&value.value)? {
                                values.push(value);
                            } else {
                                return array_literal_error(pos, expr);
                            }
                        }
                        // Negative numbers arrive as unary minus applied to
                        // a number literal; fold the sign into the literal.
                        SQLExpr::UnaryOp {
                            op: UnaryOperator::Minus,
                            expr,
                        } => {
                            if let SQLExpr::Value(ValueWithSpan {
                                value: Value::Number(number, _),
                                ..
                            }) = expr.as_ref()
                            {
                                if let Expr::Literal(value, _) = self.number(number, true)? {
                                    values.push(value);
                                } else {
                                    return array_literal_error(pos, expr);
                                }
                            } else {
                                return array_literal_error(pos, expr);
                            }
                        }
                        _ => {
                            return array_literal_error(pos, expr);
                        }
                    }
                }

                // Coerce every element to the first element's type so the
                // resulting list is homogeneous; empty arrays get Null type.
                let field = if !values.is_empty() {
                    let data_type = values[0].data_type();

                    for value in &mut values {
                        if value.data_type() != data_type {
                            *value = safe_coerce_scalar(value, &data_type).ok_or_else(|| Error::invalid_input(
                                format!("Array expressions must have a consistent datatype. Expected: {}, got: {}", data_type, value.data_type()),
                                location!()
                            ))?;
                        }
                    }
                    Field::new("item", data_type, true)
                } else {
                    Field::new("item", ArrowDataType::Null, true)
                };

                // Materialize the scalars into a single-row ListArray literal.
                let values = values
                    .into_iter()
                    .map(|v| v.to_array().map_err(Error::from))
                    .collect::<Result<Vec<_>>>()?;
                let array_refs = values.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
                let values = concat(&array_refs)?;
                let values = ListArray::try_new(
                    field.into(),
                    OffsetBuffer::from_lengths([values.len()]),
                    values,
                    None,
                )?;

                Ok(Expr::Literal(ScalarValue::List(Arc::new(values)), None))
            }
            // For example, DATE '2020-01-01'
            SQLExpr::TypedString { data_type, value } => {
                let value = value.clone().into_string().expect_ok()?;
                Ok(Expr::Cast(datafusion::logical_expr::Cast {
                    expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some(value)), None)),
                    data_type: self.parse_type(data_type)?,
                }))
            }
            SQLExpr::IsFalse(expr) => Ok(Expr::IsFalse(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotFalse(expr) => Ok(Expr::IsNotFalse(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsTrue(expr) => Ok(Expr::IsTrue(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotTrue(expr) => Ok(Expr::IsNotTrue(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNull(expr) => Ok(Expr::IsNull(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotNull(expr) => Ok(Expr::IsNotNull(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::InList {
                expr,
                list,
                negated,
            } => {
                let value_expr = self.parse_sql_expr(expr)?;
                let list_exprs = list
                    .iter()
                    .map(|e| self.parse_sql_expr(e))
                    .collect::<Result<Vec<_>>>()?;
                Ok(value_expr.in_list(list_exprs, *negated))
            }
            // Parenthesized expressions: unwrap and recurse.
            SQLExpr::Nested(inner) => self.parse_sql_expr(inner.as_ref()),
            SQLExpr::Function(_) => self.parse_function(expr.clone()),
            // ILIKE = case-insensitive LIKE (last Like::new arg is the
            // case_insensitive flag).
            SQLExpr::ILike {
                negated,
                expr,
                pattern,
                escape_char,
                any: _,
            } => Ok(Expr::Like(Like::new(
                *negated,
                Box::new(self.parse_sql_expr(expr)?),
                Box::new(self.parse_sql_expr(pattern)?),
                escape_char.as_ref().and_then(|c| c.chars().next()),
                true,
            ))),
            SQLExpr::Like {
                negated,
                expr,
                pattern,
                escape_char,
                any: _,
            } => Ok(Expr::Like(Like::new(
                *negated,
                Box::new(self.parse_sql_expr(expr)?),
                Box::new(self.parse_sql_expr(pattern)?),
                escape_char.as_ref().and_then(|c| c.chars().next()),
                false,
            ))),
            SQLExpr::Cast {
                expr, data_type, ..
            } => Ok(Expr::Cast(datafusion::logical_expr::Cast {
                expr: Box::new(self.parse_sql_expr(expr)?),
                data_type: self.parse_type(data_type)?,
            })),
            SQLExpr::JsonAccess { .. } => Err(Error::invalid_input(
                "JSON access is not supported",
                location!(),
            )),
            // x.y, x['y'], x[0], and chains thereof: lower each step to a
            // field access and hand it to the expr planners.
            SQLExpr::CompoundFieldAccess { root, access_chain } => {
                let mut expr = self.parse_sql_expr(root)?;

                for access in access_chain {
                    let field_access = match access {
                        // x.y or x['y']
                        AccessExpr::Dot(SQLExpr::Identifier(Ident { value: s, .. }))
                        | AccessExpr::Subscript(Subscript::Index {
                            index:
                                SQLExpr::Value(ValueWithSpan {
                                    value:
                                        Value::SingleQuotedString(s) | Value::DoubleQuotedString(s),
                                    ..
                                }),
                        }) => GetFieldAccess::NamedStructField {
                            name: ScalarValue::from(s.as_str()),
                        },
                        AccessExpr::Subscript(Subscript::Index { index }) => {
                            let key = Box::new(self.parse_sql_expr(index)?);
                            GetFieldAccess::ListIndex { key }
                        }
                        AccessExpr::Subscript(Subscript::Slice { .. }) => {
                            return Err(Error::invalid_input(
                                "Slice subscript is not supported",
                                location!(),
                            ));
                        }
                        _ => {
                            // Handle other cases like JSON access
                            // Note: JSON access is not supported in lance
                            return Err(Error::invalid_input(
                                "Only dot notation or index access is supported for field access",
                                location!(),
                            ));
                        }
                    };

                    let field_access_expr = RawFieldAccessExpr { expr, field_access };
                    expr = self.plan_field_access(field_access_expr)?;
                }

                Ok(expr)
            }
            SQLExpr::Between {
                expr,
                negated,
                low,
                high,
            } => {
                // Parse the main expression and bounds
                let expr = self.parse_sql_expr(expr)?;
                let low = self.parse_sql_expr(low)?;
                let high = self.parse_sql_expr(high)?;

                let between = Expr::Between(Between::new(
                    Box::new(expr),
                    *negated,
                    Box::new(low),
                    Box::new(high),
                ));
                Ok(between)
            }
            _ => Err(Error::invalid_input(
                format!("Expression '{expr}' is not supported SQL in lance"),
                location!(),
            )),
        }
    }
774
775    /// Create Logical [Expr] from a SQL filter clause.
776    ///
777    /// Note: the returned expression must be passed through [optimize_expr()]
778    /// before being passed to [create_physical_expr()].
779    pub fn parse_filter(&self, filter: &str) -> Result<Expr> {
780        // Allow sqlparser to parse filter as part of ONE SQL statement.
781        let ast_expr = parse_sql_filter(filter)?;
782        let expr = self.parse_sql_expr(&ast_expr)?;
783        let schema = Schema::try_from(self.schema.as_ref())?;
784        let resolved = resolve_expr(&expr, &schema).map_err(|e| {
785            Error::invalid_input(
786                format!("Error resolving filter expression {filter}: {e}"),
787                location!(),
788            )
789        })?;
790        Ok(coerce_filter_type_to_boolean(resolved))
791    }
792
793    /// Create Logical [Expr] from a SQL expression.
794    ///
795    /// Note: the returned expression must be passed through [optimize_filter()]
796    /// before being passed to [create_physical_expr()].
797    pub fn parse_expr(&self, expr: &str) -> Result<Expr> {
798        let ast_expr = parse_sql_expr(expr)?;
799        let expr = self.parse_sql_expr(&ast_expr)?;
800        let schema = Schema::try_from(self.schema.as_ref())?;
801        let resolved = resolve_expr(&expr, &schema)?;
802        Ok(resolved)
803    }
804
805    /// Try to decode bytes from hex literal string.
806    ///
807    /// Copied from datafusion because this is not public.
808    ///
809    /// TODO: use SqlToRel from Datafusion directly?
810    fn try_decode_hex_literal(s: &str) -> Option<Vec<u8>> {
811        let hex_bytes = s.as_bytes();
812        let mut decoded_bytes = Vec::with_capacity(hex_bytes.len().div_ceil(2));
813
814        let start_idx = hex_bytes.len() % 2;
815        if start_idx > 0 {
816            // The first byte is formed of only one char.
817            decoded_bytes.push(Self::try_decode_hex_char(hex_bytes[0])?);
818        }
819
820        for i in (start_idx..hex_bytes.len()).step_by(2) {
821            let high = Self::try_decode_hex_char(hex_bytes[i])?;
822            let low = Self::try_decode_hex_char(hex_bytes[i + 1])?;
823            decoded_bytes.push((high << 4) | low);
824        }
825
826        Some(decoded_bytes)
827    }
828
829    /// Try to decode a byte from a hex char.
830    ///
831    /// None will be returned if the input char is hex-invalid.
832    const fn try_decode_hex_char(c: u8) -> Option<u8> {
833        match c {
834            b'A'..=b'F' => Some(c - b'A' + 10),
835            b'a'..=b'f' => Some(c - b'a' + 10),
836            b'0'..=b'9' => Some(c - b'0'),
837            _ => None,
838        }
839    }
840
841    /// Optimize the filter expression and coerce data types.
842    pub fn optimize_expr(&self, expr: Expr) -> Result<Expr> {
843        let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
844
845        // DataFusion needs the simplify and coerce passes to be applied before
846        // expressions can be handled by the physical planner.
847        let props = ExecutionProps::default();
848        let simplify_context = SimplifyContext::new(&props).with_schema(df_schema.clone());
849        let simplifier =
850            datafusion::optimizer::simplify_expressions::ExprSimplifier::new(simplify_context);
851
852        let expr = simplifier.simplify(expr)?;
853        let expr = simplifier.coerce(expr, &df_schema)?;
854
855        Ok(expr)
856    }
857
858    /// Create the [`PhysicalExpr`] from a logical [`Expr`]
859    pub fn create_physical_expr(&self, expr: &Expr) -> Result<Arc<dyn PhysicalExpr>> {
860        let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
861
862        Ok(datafusion::physical_expr::create_physical_expr(
863            expr,
864            df_schema.as_ref(),
865            &Default::default(),
866        )?)
867    }
868
869    /// Collect the columns in the expression.
870    ///
871    /// The columns are returned in sorted order.
872    pub fn column_names_in_expr(expr: &Expr) -> Vec<String> {
873        let mut visitor = ColumnCapturingVisitor {
874            current_path: VecDeque::new(),
875            columns: BTreeSet::new(),
876        };
877        expr.visit(&mut visitor).unwrap();
878        visitor.columns.into_iter().collect()
879    }
880}
881
/// Tree-node visitor that records the (possibly dotted) names of all columns
/// referenced by an [`Expr`], including nested field accesses such as `a.b.c`.
struct ColumnCapturingVisitor {
    // Current column path. If this is empty, we are not in a column expression.
    current_path: VecDeque<String>,
    // Fully-resolved column names collected so far (BTreeSet keeps them sorted).
    columns: BTreeSet<String>,
}
887
888impl TreeNodeVisitor<'_> for ColumnCapturingVisitor {
889    type Node = Expr;
890
891    fn f_down(&mut self, node: &Self::Node) -> DFResult<TreeNodeRecursion> {
892        match node {
893            Expr::Column(Column { name, .. }) => {
894                let mut path = name.clone();
895                for part in self.current_path.drain(..) {
896                    path.push('.');
897                    path.push_str(&part);
898                }
899                self.columns.insert(path);
900                self.current_path.clear();
901            }
902            Expr::ScalarFunction(udf) => {
903                if udf.name() == GetFieldFunc::default().name() {
904                    if let Some(name) = get_as_string_scalar_opt(&udf.args[1]) {
905                        self.current_path.push_front(name.to_string())
906                    } else {
907                        self.current_path.clear();
908                    }
909                } else {
910                    self.current_path.clear();
911                }
912            }
913            _ => {
914                self.current_path.clear();
915            }
916        }
917
918        Ok(TreeNodeRecursion::Continue)
919    }
920}
921
922#[cfg(test)]
923mod tests {
924
925    use crate::logical_expr::ExprExt;
926
927    use super::*;
928
929    use arrow::datatypes::Float64Type;
930    use arrow_array::{
931        ArrayRef, BooleanArray, Float32Array, Int32Array, Int64Array, RecordBatch, StringArray,
932        StructArray, TimestampMicrosecondArray, TimestampMillisecondArray,
933        TimestampNanosecondArray, TimestampSecondArray,
934    };
935    use arrow_schema::{DataType, Fields, Schema};
936    use datafusion::{
937        logical_expr::{lit, Cast},
938        prelude::{array_element, get_field},
939    };
940    use datafusion_functions::core::expr_ext::FieldAccessor;
941
    // End-to-end filter check: SQL text -> logical Expr (comparisons, struct
    // field access, IN lists) -> physical expr evaluated over a RecordBatch.
    #[test]
    fn test_parse_filter_simple() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("i", DataType::Int32, false),
            Field::new("s", DataType::Utf8, true),
            Field::new(
                "st",
                DataType::Struct(Fields::from(vec![
                    Field::new("x", DataType::Float32, false),
                    Field::new("y", DataType::Float32, false),
                ])),
                true,
            ),
        ]));

        let planner = Planner::new(schema.clone());

        let expected = col("i")
            .gt(lit(3_i32))
            .and(col("st").field_newstyle("x").lt_eq(lit(5.0_f32)))
            .and(
                col("s")
                    .eq(lit("str-4"))
                    .or(col("s").in_list(vec![lit("str-4"), lit("str-5")], false)),
            );

        // double quotes
        let expr = planner
            .parse_filter("i > 3 AND st.x <= 5.0 AND (s == 'str-4' OR s in ('str-4', 'str-5'))")
            .unwrap();
        assert_eq!(expr, expected);

        // single quote
        let expr = planner
            .parse_filter("i > 3 AND st.x <= 5.0 AND (s = 'str-4' OR s in ('str-4', 'str-5'))")
            .unwrap();

        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        // Batch: i = 0..10, s = "str-0".."str-9", st.x = 0.0..9.0, st.y = 10*x.
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
                Arc::new(StringArray::from_iter_values(
                    (0..10).map(|v| format!("str-{}", v)),
                )),
                Arc::new(StructArray::from(vec![
                    (
                        Arc::new(Field::new("x", DataType::Float32, false)),
                        Arc::new(Float32Array::from_iter_values((0..10).map(|v| v as f32)))
                            as ArrayRef,
                    ),
                    (
                        Arc::new(Field::new("y", DataType::Float32, false)),
                        Arc::new(Float32Array::from_iter_values(
                            (0..10).map(|v| (v * 10) as f32),
                        )),
                    ),
                ])),
            ],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        // Only rows 4 and 5 satisfy i > 3, st.x <= 5.0, and s in {str-4, str-5}.
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, false, true, true, false, false, false, false
            ])
        );
    }
1012
    // Nested struct column references (dot, backtick-quoted, and subscript
    // syntax) should all lower to chained `get_field` scalar function calls.
    #[test]
    fn test_nested_col_refs() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("s0", DataType::Utf8, true),
            Field::new(
                "st",
                DataType::Struct(Fields::from(vec![
                    Field::new("s1", DataType::Utf8, true),
                    Field::new(
                        "st",
                        DataType::Struct(Fields::from(vec![Field::new(
                            "s2",
                            DataType::Utf8,
                            true,
                        )])),
                        true,
                    ),
                ])),
                true,
            ),
        ]));

        let planner = Planner::new(schema);

        // Parses `{expr} = 'val'` and asserts the left side of the resulting
        // equality matches the expected column/field-access expression.
        fn assert_column_eq(planner: &Planner, expr: &str, expected: &Expr) {
            let expr = planner.parse_filter(&format!("{expr} = 'val'")).unwrap();
            assert!(matches!(
                expr,
                Expr::BinaryExpr(BinaryExpr {
                    left: _,
                    op: Operator::Eq,
                    right: _
                })
            ));
            if let Expr::BinaryExpr(BinaryExpr { left, .. }) = expr {
                assert_eq!(left.as_ref(), expected);
            }
        }

        let expected = Expr::Column(Column::new_unqualified("s0"));
        assert_column_eq(&planner, "s0", &expected);
        assert_column_eq(&planner, "`s0`", &expected);

        let expected = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
            args: vec![
                Expr::Column(Column::new_unqualified("st")),
                Expr::Literal(ScalarValue::Utf8(Some("s1".to_string())), None),
            ],
        });
        assert_column_eq(&planner, "st.s1", &expected);
        assert_column_eq(&planner, "`st`.`s1`", &expected);
        assert_column_eq(&planner, "st.`s1`", &expected);

        // Two levels of nesting: get_field(get_field(st, "st"), "s2").
        let expected = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
            args: vec![
                Expr::ScalarFunction(ScalarFunction {
                    func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
                    args: vec![
                        Expr::Column(Column::new_unqualified("st")),
                        Expr::Literal(ScalarValue::Utf8(Some("st".to_string())), None),
                    ],
                }),
                Expr::Literal(ScalarValue::Utf8(Some("s2".to_string())), None),
            ],
        });

        assert_column_eq(&planner, "st.st.s2", &expected);
        assert_column_eq(&planner, "`st`.`st`.`s2`", &expected);
        assert_column_eq(&planner, "st.st.`s2`", &expected);
        assert_column_eq(&planner, "st['st'][\"s2\"]", &expected);
    }
1086
    // List subscripts lower to `array_element`, and a string subscript on the
    // result lowers to `get_field`.
    #[test]
    fn test_nested_list_refs() {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "l",
            DataType::List(Arc::new(Field::new(
                "item",
                DataType::Struct(Fields::from(vec![Field::new("f1", DataType::Utf8, true)])),
                true,
            ))),
            true,
        )]));

        let planner = Planner::new(schema);

        let expected = array_element(col("l"), lit(0_i64));
        let expr = planner.parse_expr("l[0]").unwrap();
        assert_eq!(expr, expected);

        let expected = get_field(array_element(col("l"), lit(0_i64)), "f1");
        let expr = planner.parse_expr("l[0]['f1']").unwrap();
        assert_eq!(expr, expected);

        // FIXME: This should work, but sqlparser doesn't recognize anything
        // after the period for some reason.
        // let expr = planner.parse_expr("l[0].f1").unwrap();
        // assert_eq!(expr, expected);
    }
1114
    // Unary minus on literals and on parenthesized expressions must be parsed
    // and evaluated correctly.
    #[test]
    fn test_negative_expressions() {
        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));

        let planner = Planner::new(schema.clone());

        let expected = col("x")
            .gt(lit(-3_i64))
            .and(col("x").lt(-(lit(-5_i64) + lit(3_i64))));

        let expr = planner.parse_filter("x > -3 AND x < -(-5 + 3)").unwrap();

        assert_eq!(expr, expected);

        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        // x = -5..5; expect -3 < x < 2.
        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(Int64Array::from_iter_values(-5..5)) as ArrayRef],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, true, true, true, true, false, false, false
            ])
        );
    }
1144
    // An array literal of negative floats parses to a single List scalar
    // literal (Float64 elements).
    #[test]
    fn test_negative_array_expressions() {
        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));

        let planner = Planner::new(schema);

        let expected = Expr::Literal(
            ScalarValue::List(Arc::new(
                ListArray::from_iter_primitive::<Float64Type, _, _>(vec![Some(
                    [-1_f64, -2.0, -3.0, -4.0, -5.0].map(Some),
                )]),
            )),
            None,
        );

        let expr = planner
            .parse_expr("[-1.0, -2.0, -3.0, -4.0, -5.0]")
            .unwrap();

        assert_eq!(expr, expected);
    }
1166
    // LIKE with no wildcards behaves as exact match; only "str-4" matches.
    #[test]
    fn test_sql_like() {
        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));

        let planner = Planner::new(schema.clone());

        let expected = col("s").like(lit("str-4"));
        // single quote
        let expr = planner.parse_filter("s LIKE 'str-4'").unwrap();
        assert_eq!(expr, expected);
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(StringArray::from_iter_values(
                (0..10).map(|v| format!("str-{}", v)),
            ))],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, false, true, false, false, false, false, false
            ])
        );
    }
1194
    // NOT LIKE is the complement of LIKE: every row except "str-4" matches.
    #[test]
    fn test_not_like() {
        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));

        let planner = Planner::new(schema.clone());

        let expected = col("s").not_like(lit("str-4"));
        // single quote
        let expr = planner.parse_filter("s NOT LIKE 'str-4'").unwrap();
        assert_eq!(expr, expected);
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(StringArray::from_iter_values(
                (0..10).map(|v| format!("str-{}", v)),
            ))],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                true, true, true, true, false, true, true, true, true, true
            ])
        );
    }
1222
    // IN-list lowering and evaluation: rows "str-4" and "str-5" match.
    #[test]
    fn test_sql_is_in() {
        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));

        let planner = Planner::new(schema.clone());

        let expected = col("s").in_list(vec![lit("str-4"), lit("str-5")], false);
        // single quote
        let expr = planner.parse_filter("s IN ('str-4', 'str-5')").unwrap();
        assert_eq!(expr, expected);
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(StringArray::from_iter_values(
                (0..10).map(|v| format!("str-{}", v)),
            ))],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, false, true, true, false, false, false, false
            ])
        );
    }
1250
    // IS NULL / IS NOT NULL over a column where every index not divisible by 3
    // is null; the two predicates must be exact complements.
    #[test]
    fn test_sql_is_null() {
        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));

        let planner = Planner::new(schema.clone());

        let expected = col("s").is_null();
        let expr = planner.parse_filter("s IS NULL").unwrap();
        assert_eq!(expr, expected);
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(StringArray::from_iter((0..10).map(|v| {
                if v % 3 == 0 {
                    Some(format!("str-{}", v))
                } else {
                    None
                }
            })))],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, true, true, false, true, true, false, true, true, false
            ])
        );

        let expr = planner.parse_filter("s IS NOT NULL").unwrap();
        let physical_expr = planner.create_physical_expr(&expr).unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                true, false, false, true, false, false, true, false, false, true,
            ])
        );
    }
1291
    // `NOT <bool column>` inverts each value of the boolean column.
    #[test]
    fn test_sql_invert() {
        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Boolean, true)]));

        let planner = Planner::new(schema.clone());

        let expr = planner.parse_filter("NOT s").unwrap();
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        // s is true at indices divisible by 3; NOT s flips that.
        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(BooleanArray::from_iter(
                (0..10).map(|v| Some(v % 3 == 0)),
            ))],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, true, true, false, true, true, false, true, true, false
            ])
        );
    }
1316
    // Each supported SQL CAST target type must map to the expected Arrow
    // data type, and the cast must wrap the original literal unchanged.
    #[test]
    fn test_sql_cast() {
        let cases = &[
            (
                "x = cast('2021-01-01 00:00:00' as timestamp)",
                ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00' as timestamp(0))",
                ArrowDataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00.123' as timestamp(9))",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00.123' as datetime(9))",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            ("x = cast('2021-01-01' as date)", ArrowDataType::Date32),
            (
                "x = cast('1.238' as decimal(9,3))",
                ArrowDataType::Decimal128(9, 3),
            ),
            ("x = cast(1 as float)", ArrowDataType::Float32),
            ("x = cast(1 as double)", ArrowDataType::Float64),
            ("x = cast(1 as tinyint)", ArrowDataType::Int8),
            ("x = cast(1 as smallint)", ArrowDataType::Int16),
            ("x = cast(1 as int)", ArrowDataType::Int32),
            ("x = cast(1 as integer)", ArrowDataType::Int32),
            ("x = cast(1 as bigint)", ArrowDataType::Int64),
            ("x = cast(1 as tinyint unsigned)", ArrowDataType::UInt8),
            ("x = cast(1 as smallint unsigned)", ArrowDataType::UInt16),
            ("x = cast(1 as int unsigned)", ArrowDataType::UInt32),
            ("x = cast(1 as integer unsigned)", ArrowDataType::UInt32),
            ("x = cast(1 as bigint unsigned)", ArrowDataType::UInt64),
            ("x = cast(1 as boolean)", ArrowDataType::Boolean),
            ("x = cast(1 as string)", ArrowDataType::Utf8),
        ];

        for (sql, expected_data_type) in cases {
            let schema = Arc::new(Schema::new(vec![Field::new(
                "x",
                expected_data_type.clone(),
                true,
            )]));
            let planner = Planner::new(schema.clone());
            let expr = planner.parse_filter(sql).unwrap();

            // Get the thing after 'cast(` but before ' as'.
            let expected_value_str = sql
                .split("cast(")
                .nth(1)
                .unwrap()
                .split(" as")
                .next()
                .unwrap();
            // Remove any quote marks
            let expected_value_str = expected_value_str.trim_matches('\'');

            // Expect `x = CAST(<literal> AS <type>)`: check the literal text
            // (Utf8 for quoted values, Int64 for the bare `1`) and the type.
            match expr {
                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
                    Expr::Cast(Cast { expr, data_type }) => {
                        match expr.as_ref() {
                            Expr::Literal(ScalarValue::Utf8(Some(value_str)), _) => {
                                assert_eq!(value_str, expected_value_str);
                            }
                            Expr::Literal(ScalarValue::Int64(Some(value)), _) => {
                                assert_eq!(*value, 1);
                            }
                            _ => panic!("Expected cast to be applied to literal"),
                        }
                        assert_eq!(data_type, expected_data_type);
                    }
                    _ => panic!("Expected right to be a cast"),
                },
                _ => panic!("Expected binary expression"),
            }
        }
    }
1397
    // Typed SQL literals (`timestamp '…'`, `date '…'`, `decimal(p,s) '…'`)
    // must lower to a CAST of the string literal to the expected Arrow type.
    #[test]
    fn test_sql_literals() {
        let cases = &[
            (
                "x = timestamp '2021-01-01 00:00:00'",
                ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
            ),
            (
                "x = timestamp(0) '2021-01-01 00:00:00'",
                ArrowDataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                "x = timestamp(9) '2021-01-01 00:00:00.123'",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            ("x = date '2021-01-01'", ArrowDataType::Date32),
            ("x = decimal(9,3) '1.238'", ArrowDataType::Decimal128(9, 3)),
        ];

        for (sql, expected_data_type) in cases {
            let schema = Arc::new(Schema::new(vec![Field::new(
                "x",
                expected_data_type.clone(),
                true,
            )]));
            let planner = Planner::new(schema.clone());
            let expr = planner.parse_filter(sql).unwrap();

            // The literal text is the first single-quoted section of the SQL.
            let expected_value_str = sql.split('\'').nth(1).unwrap();

            match expr {
                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
                    Expr::Cast(Cast { expr, data_type }) => {
                        match expr.as_ref() {
                            Expr::Literal(ScalarValue::Utf8(Some(value_str)), _) => {
                                assert_eq!(value_str, expected_value_str);
                            }
                            _ => panic!("Expected cast to be applied to literal"),
                        }
                        assert_eq!(data_type, expected_data_type);
                    }
                    _ => panic!("Expected right to be a cast"),
                },
                _ => panic!("Expected binary expression"),
            }
        }
    }
1445
    // After optimize_expr, an array literal compared against a List or
    // FixedSizeList column must be coerced to that column's data type.
    #[test]
    fn test_sql_array_literals() {
        let cases = [
            (
                "x = [1, 2, 3]",
                ArrowDataType::List(Arc::new(Field::new("item", ArrowDataType::Int64, true))),
            ),
            (
                "x = [1, 2, 3]",
                ArrowDataType::FixedSizeList(
                    Arc::new(Field::new("item", ArrowDataType::Int64, true)),
                    3,
                ),
            ),
        ];

        for (sql, expected_data_type) in cases {
            let schema = Arc::new(Schema::new(vec![Field::new(
                "x",
                expected_data_type.clone(),
                true,
            )]));
            let planner = Planner::new(schema.clone());
            let expr = planner.parse_filter(sql).unwrap();
            let expr = planner.optimize_expr(expr).unwrap();

            match expr {
                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
                    Expr::Literal(value, _) => {
                        assert_eq!(&value.data_type(), &expected_data_type);
                    }
                    _ => panic!("Expected right to be a literal"),
                },
                _ => panic!("Expected binary expression"),
            }
        }
    }
1483
    // BETWEEN / NOT BETWEEN over integer, float, and timestamp columns;
    // bounds are inclusive on both ends.
    #[test]
    fn test_sql_between() {
        use arrow_array::{Float64Array, Int32Array, TimestampMicrosecondArray};
        use arrow_schema::{DataType, Field, Schema, TimeUnit};
        use std::sync::Arc;

        let schema = Arc::new(Schema::new(vec![
            Field::new("x", DataType::Int32, false),
            Field::new("y", DataType::Float64, false),
            Field::new(
                "ts",
                DataType::Timestamp(TimeUnit::Microsecond, None),
                false,
            ),
        ]));

        let planner = Planner::new(schema.clone());

        // Test integer BETWEEN
        let expr = planner
            .parse_filter("x BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
            .unwrap();
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        // Create timestamp array with values representing:
        // 2024-01-01 00:00:00 to 2024-01-01 00:00:09 (in microseconds)
        let base_ts = 1704067200000000_i64; // 2024-01-01 00:00:00
        let ts_array = TimestampMicrosecondArray::from_iter_values(
            (0..10).map(|i| base_ts + i * 1_000_000), // Each value is 1 second apart
        );

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
                Arc::new(Float64Array::from_iter_values((0..10).map(|v| v as f64))),
                Arc::new(ts_array),
            ],
        )
        .unwrap();

        // Inclusive bounds: 3 <= x <= 7.
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, true, true, true, true, true, false, false
            ])
        );

        // Test NOT BETWEEN
        let expr = planner
            .parse_filter("x NOT BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
            .unwrap();
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                true, true, true, false, false, false, false, false, true, true
            ])
        );

        // Test floating point BETWEEN
        let expr = planner.parse_filter("y BETWEEN 2.5 AND 6.5").unwrap();
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, true, true, true, true, false, false, false
            ])
        );

        // Test timestamp BETWEEN
        let expr = planner
            .parse_filter(
                "ts BETWEEN timestamp '2024-01-01 00:00:03' AND timestamp '2024-01-01 00:00:07'",
            )
            .unwrap();
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, true, true, true, true, true, false, false
            ])
        );
    }
1575
1576    #[test]
1577    fn test_sql_comparison() {
1578        // Create a batch with all data types
1579        let batch: Vec<(&str, ArrayRef)> = vec![
1580            (
1581                "timestamp_s",
1582                Arc::new(TimestampSecondArray::from_iter_values(0..10)),
1583            ),
1584            (
1585                "timestamp_ms",
1586                Arc::new(TimestampMillisecondArray::from_iter_values(0..10)),
1587            ),
1588            (
1589                "timestamp_us",
1590                Arc::new(TimestampMicrosecondArray::from_iter_values(0..10)),
1591            ),
1592            (
1593                "timestamp_ns",
1594                Arc::new(TimestampNanosecondArray::from_iter_values(4995..5005)),
1595            ),
1596        ];
1597        let batch = RecordBatch::try_from_iter(batch).unwrap();
1598
1599        let planner = Planner::new(batch.schema());
1600
1601        // Each expression is meant to select the final 5 rows
1602        let expressions = &[
1603            "timestamp_s >= TIMESTAMP '1970-01-01 00:00:05'",
1604            "timestamp_ms >= TIMESTAMP '1970-01-01 00:00:00.005'",
1605            "timestamp_us >= TIMESTAMP '1970-01-01 00:00:00.000005'",
1606            "timestamp_ns >= TIMESTAMP '1970-01-01 00:00:00.000005'",
1607        ];
1608
1609        let expected: ArrayRef = Arc::new(BooleanArray::from_iter(
1610            std::iter::repeat_n(Some(false), 5).chain(std::iter::repeat_n(Some(true), 5)),
1611        ));
1612        for expression in expressions {
1613            // convert to physical expression
1614            let logical_expr = planner.parse_filter(expression).unwrap();
1615            let logical_expr = planner.optimize_expr(logical_expr).unwrap();
1616            let physical_expr = planner.create_physical_expr(&logical_expr).unwrap();
1617
1618            // Evaluate and assert they have correct results
1619            let result = physical_expr.evaluate(&batch).unwrap();
1620            let result = result.into_array(batch.num_rows()).unwrap();
1621            assert_eq!(&expected, &result, "unexpected result for {}", expression);
1622        }
1623    }
1624
1625    #[test]
1626    fn test_columns_in_expr() {
1627        let expr = col("s0").gt(lit("value")).and(
1628            col("st")
1629                .field("st")
1630                .field("s2")
1631                .eq(lit("value"))
1632                .or(col("st")
1633                    .field("s1")
1634                    .in_list(vec![lit("value 1"), lit("value 2")], false)),
1635        );
1636
1637        let columns = Planner::column_names_in_expr(&expr);
1638        assert_eq!(columns, vec!["s0", "st.s1", "st.st.s2"]);
1639    }
1640
1641    #[test]
1642    fn test_parse_binary_expr() {
1643        let bin_str = "x'616263'";
1644
1645        let schema = Arc::new(Schema::new(vec![Field::new(
1646            "binary",
1647            DataType::Binary,
1648            true,
1649        )]));
1650        let planner = Planner::new(schema);
1651        let expr = planner.parse_expr(bin_str).unwrap();
1652        assert_eq!(
1653            expr,
1654            Expr::Literal(ScalarValue::Binary(Some(vec![b'a', b'b', b'c'])), None)
1655        );
1656    }
1657
1658    #[test]
1659    fn test_lance_context_provider_expr_planners() {
1660        let ctx_provider = LanceContextProvider::default();
1661        assert!(!ctx_provider.get_expr_planners().is_empty());
1662    }
1663}