1use std::borrow::Cow;
7use std::collections::{BTreeSet, VecDeque};
8use std::sync::Arc;
9
10use crate::exec::{get_session_context, LanceExecutionOptions};
11use crate::expr::safe_coerce_scalar;
12use crate::logical_expr::{coerce_filter_type_to_boolean, get_as_string_scalar_opt, resolve_expr};
13use crate::sql::{parse_sql_expr, parse_sql_filter};
14use arrow::compute::CastOptions;
15use arrow_array::ListArray;
16use arrow_buffer::OffsetBuffer;
17use arrow_schema::{DataType as ArrowDataType, Field, SchemaRef, TimeUnit};
18use arrow_select::concat::concat;
19use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
20use datafusion::common::DFSchema;
21use datafusion::config::ConfigOptions;
22use datafusion::error::Result as DFResult;
23use datafusion::execution::context::SessionState;
24use datafusion::logical_expr::expr::ScalarFunction;
25use datafusion::logical_expr::planner::{ExprPlanner, PlannerResult, RawFieldAccessExpr};
26use datafusion::logical_expr::{
27 AggregateUDF, ColumnarValue, GetFieldAccess, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
28 Signature, Volatility, WindowUDF,
29};
30use datafusion::optimizer::simplify_expressions::SimplifyContext;
31use datafusion::sql::planner::{
32 ContextProvider, NullOrdering, ParserOptions, PlannerContext, SqlToRel,
33};
34use datafusion::sql::sqlparser::ast::{
35 AccessExpr, Array as SQLArray, BinaryOperator, DataType as SQLDataType, ExactNumberInfo,
36 Expr as SQLExpr, Function, FunctionArg, FunctionArgExpr, FunctionArguments, Ident,
37 ObjectNamePart, Subscript, TimezoneInfo, UnaryOperator, Value, ValueWithSpan,
38};
39use datafusion::{
40 common::Column,
41 logical_expr::{Between, BinaryExpr, Like, Operator},
42 physical_expr::execution_props::ExecutionProps,
43 physical_plan::PhysicalExpr,
44 prelude::Expr,
45 scalar::ScalarValue,
46};
47use datafusion_functions::core::getfield::GetFieldFunc;
48use lance_arrow::cast::cast_with_options;
49use lance_core::datatypes::Schema;
50use lance_core::error::LanceOptionExt;
51use snafu::location;
52
53use chrono::Utc;
54use lance_core::{Error, Result};
55
/// Scalar UDF `_cast_list_f16`: casts a `List`/`FixedSizeList` of
/// `Float32`/`Float16` values to the equivalent list type with `Float16`
/// elements (see `return_type`/`invoke_with_args` below).
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
struct CastListF16Udf {
    /// Accepts any single argument; element types are validated in `return_type`.
    signature: Signature,
}
60
impl CastListF16Udf {
    /// Creates the UDF with a signature that accepts exactly one argument of
    /// any type; the actual list/element-type check happens in `return_type`.
    pub fn new() -> Self {
        Self {
            signature: Signature::any(1, Volatility::Immutable),
        }
    }
}
68
69impl ScalarUDFImpl for CastListF16Udf {
70 fn as_any(&self) -> &dyn std::any::Any {
71 self
72 }
73
74 fn name(&self) -> &str {
75 "_cast_list_f16"
76 }
77
78 fn signature(&self) -> &Signature {
79 &self.signature
80 }
81
82 fn return_type(&self, arg_types: &[ArrowDataType]) -> DFResult<ArrowDataType> {
83 let input = &arg_types[0];
84 match input {
85 ArrowDataType::FixedSizeList(field, size) => {
86 if field.data_type() != &ArrowDataType::Float32
87 && field.data_type() != &ArrowDataType::Float16
88 {
89 return Err(datafusion::error::DataFusionError::Execution(
90 "cast_list_f16 only supports list of float32 or float16".to_string(),
91 ));
92 }
93 Ok(ArrowDataType::FixedSizeList(
94 Arc::new(Field::new(
95 field.name(),
96 ArrowDataType::Float16,
97 field.is_nullable(),
98 )),
99 *size,
100 ))
101 }
102 ArrowDataType::List(field) => {
103 if field.data_type() != &ArrowDataType::Float32
104 && field.data_type() != &ArrowDataType::Float16
105 {
106 return Err(datafusion::error::DataFusionError::Execution(
107 "cast_list_f16 only supports list of float32 or float16".to_string(),
108 ));
109 }
110 Ok(ArrowDataType::List(Arc::new(Field::new(
111 field.name(),
112 ArrowDataType::Float16,
113 field.is_nullable(),
114 ))))
115 }
116 _ => Err(datafusion::error::DataFusionError::Execution(
117 "cast_list_f16 only supports FixedSizeList/List arguments".to_string(),
118 )),
119 }
120 }
121
122 fn invoke_with_args(&self, func_args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
123 let ColumnarValue::Array(arr) = &func_args.args[0] else {
124 return Err(datafusion::error::DataFusionError::Execution(
125 "cast_list_f16 only supports array arguments".to_string(),
126 ));
127 };
128
129 let to_type = match arr.data_type() {
130 ArrowDataType::FixedSizeList(field, size) => ArrowDataType::FixedSizeList(
131 Arc::new(Field::new(
132 field.name(),
133 ArrowDataType::Float16,
134 field.is_nullable(),
135 )),
136 *size,
137 ),
138 ArrowDataType::List(field) => ArrowDataType::List(Arc::new(Field::new(
139 field.name(),
140 ArrowDataType::Float16,
141 field.is_nullable(),
142 ))),
143 _ => {
144 return Err(datafusion::error::DataFusionError::Execution(
145 "cast_list_f16 only supports array arguments".to_string(),
146 ));
147 }
148 };
149
150 let res = cast_with_options(arr.as_ref(), &to_type, &CastOptions::default())?;
151 Ok(ColumnarValue::Array(res))
152 }
153}
154
/// Minimal DataFusion `ContextProvider` backed by a default Lance session.
///
/// Exposes the session's registered scalar/aggregate/window functions and
/// expression planners to `SqlToRel`, but rejects table references.
struct LanceContextProvider {
    options: datafusion::config::ConfigOptions,
    state: SessionState,
    expr_planners: Vec<Arc<dyn ExprPlanner>>,
}
161
impl Default for LanceContextProvider {
    fn default() -> Self {
        // Build a default Lance session and snapshot its state so that the
        // function registries and expression planners are available.
        let ctx = get_session_context(&LanceExecutionOptions::default());
        let state = ctx.state();
        let expr_planners = state.expr_planners().to_vec();

        Self {
            // NOTE(review): uses fresh default ConfigOptions rather than the
            // session state's own options — confirm this divergence is intended.
            options: ConfigOptions::default(),
            state,
            expr_planners,
        }
    }
}
175
176impl ContextProvider for LanceContextProvider {
177 fn get_table_source(
178 &self,
179 name: datafusion::sql::TableReference,
180 ) -> DFResult<Arc<dyn datafusion::logical_expr::TableSource>> {
181 Err(datafusion::error::DataFusionError::NotImplemented(format!(
182 "Attempt to reference inner table {} not supported",
183 name
184 )))
185 }
186
187 fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
188 self.state.aggregate_functions().get(name).cloned()
189 }
190
191 fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
192 self.state.window_functions().get(name).cloned()
193 }
194
195 fn get_function_meta(&self, f: &str) -> Option<Arc<ScalarUDF>> {
196 match f {
197 "_cast_list_f16" => Some(Arc::new(ScalarUDF::new_from_impl(CastListF16Udf::new()))),
200 _ => self.state.scalar_functions().get(f).cloned(),
201 }
202 }
203
204 fn get_variable_type(&self, _: &[String]) -> Option<ArrowDataType> {
205 None
207 }
208
209 fn options(&self) -> &datafusion::config::ConfigOptions {
210 &self.options
211 }
212
213 fn udf_names(&self) -> Vec<String> {
214 self.state.scalar_functions().keys().cloned().collect()
215 }
216
217 fn udaf_names(&self) -> Vec<String> {
218 self.state.aggregate_functions().keys().cloned().collect()
219 }
220
221 fn udwf_names(&self) -> Vec<String> {
222 self.state.window_functions().keys().cloned().collect()
223 }
224
225 fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
226 &self.expr_planners
227 }
228}
229
/// Parses Lance SQL filter/expression strings into DataFusion logical and
/// physical expressions against a fixed Arrow schema.
pub struct Planner {
    schema: SchemaRef,
    context_provider: LanceContextProvider,
    /// When true, the first identifier of a compound reference (`rel.col`)
    /// is treated as a relation qualifier instead of a struct column.
    enable_relations: bool,
}
235
236impl Planner {
    /// Creates a planner over `schema` with relation qualifiers disabled.
    pub fn new(schema: SchemaRef) -> Self {
        Self {
            schema,
            context_provider: LanceContextProvider::default(),
            enable_relations: false,
        }
    }
244
    /// Builder-style toggle: when enabled, `a.b` parses as column `b` of
    /// relation `a` rather than field `b` of struct column `a`.
    pub fn with_enable_relations(mut self, enable_relations: bool) -> Self {
        self.enable_relations = enable_relations;
        self
    }
254
255 fn resolve_column_name(&self, name: &str) -> String {
258 if self.schema.field_with_name(name).is_ok() {
260 return name.to_string();
261 }
262 for field in self.schema.fields() {
264 if field.name().eq_ignore_ascii_case(name) {
265 return field.name().clone();
266 }
267 }
268 name.to_string()
270 }
271
    /// Converts a (possibly compound) identifier into a column expression,
    /// mapping trailing identifiers to nested `get_field` struct accesses.
    fn column(&self, idents: &[Ident]) -> Expr {
        // Wraps `expr` in nested `get_field(expr, name)` calls, one per ident.
        fn handle_remaining_idents(expr: &mut Expr, idents: &[Ident]) {
            for ident in idents {
                *expr = Expr::ScalarFunction(ScalarFunction {
                    args: vec![
                        // `take` moves the current expr out so it can become
                        // the child of the new get_field call.
                        std::mem::take(expr),
                        Expr::Literal(ScalarValue::Utf8(Some(ident.value.clone())), None),
                    ],
                    func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
                });
            }
        }

        if self.enable_relations && idents.len() > 1 {
            // First ident is the relation qualifier, second the column name;
            // anything further is a struct field access.
            let relation = &idents[0].value;
            let column_name = self.resolve_column_name(&idents[1].value);
            let column = Expr::Column(Column::new(Some(relation.clone()), column_name));
            let mut result = column;
            handle_remaining_idents(&mut result, &idents[2..]);
            result
        } else {
            // First ident is the column (case-insensitively resolved against
            // the schema); the rest are struct field accesses.
            let resolved_name = self.resolve_column_name(&idents[0].value);
            let mut column = Expr::Column(Column::from_name(resolved_name));
            handle_remaining_idents(&mut column, &idents[1..]);
            column
        }
    }
302
303 fn binary_op(&self, op: &BinaryOperator) -> Result<Operator> {
304 Ok(match op {
305 BinaryOperator::Plus => Operator::Plus,
306 BinaryOperator::Minus => Operator::Minus,
307 BinaryOperator::Multiply => Operator::Multiply,
308 BinaryOperator::Divide => Operator::Divide,
309 BinaryOperator::Modulo => Operator::Modulo,
310 BinaryOperator::StringConcat => Operator::StringConcat,
311 BinaryOperator::Gt => Operator::Gt,
312 BinaryOperator::Lt => Operator::Lt,
313 BinaryOperator::GtEq => Operator::GtEq,
314 BinaryOperator::LtEq => Operator::LtEq,
315 BinaryOperator::Eq => Operator::Eq,
316 BinaryOperator::NotEq => Operator::NotEq,
317 BinaryOperator::And => Operator::And,
318 BinaryOperator::Or => Operator::Or,
319 _ => {
320 return Err(Error::invalid_input(
321 format!("Operator {op} is not supported"),
322 location!(),
323 ));
324 }
325 })
326 }
327
328 fn binary_expr(&self, left: &SQLExpr, op: &BinaryOperator, right: &SQLExpr) -> Result<Expr> {
329 Ok(Expr::BinaryExpr(BinaryExpr::new(
330 Box::new(self.parse_sql_expr(left)?),
331 self.binary_op(op)?,
332 Box::new(self.parse_sql_expr(right)?),
333 )))
334 }
335
    /// Parses a SQL unary operation. `NOT`/`!` become logical negation; unary
    /// minus on a numeric literal is folded into the literal (falling back to
    /// f64 when the digits overflow i64, e.g. `-9223372036854775808`).
    fn unary_expr(&self, op: &UnaryOperator, expr: &SQLExpr) -> Result<Expr> {
        Ok(match op {
            UnaryOperator::Not | UnaryOperator::PGBitwiseNot => {
                Expr::Not(Box::new(self.parse_sql_expr(expr)?))
            }

            UnaryOperator::Minus => {
                use datafusion::logical_expr::lit;
                match expr {
                    // Fold `-<number literal>` directly into a literal value.
                    SQLExpr::Value(ValueWithSpan { value: Value::Number(n, _), ..}) => match n.parse::<i64>() {
                        Ok(n) => lit(-n),
                        // i64 overflow: retry as f64 before giving up.
                        Err(_) => lit(-n
                            .parse::<f64>()
                            .map_err(|_e| {
                                Error::invalid_input(
                                    format!("negative operator can be only applied to integer and float operands, got: {n}"),
                                    location!(),
                                )
                            })?),
                    },
                    // Non-literal operand: keep an explicit Negative node.
                    _ => {
                        Expr::Negative(Box::new(self.parse_sql_expr(expr)?))
                    }
                }
            }

            _ => {
                return Err(Error::invalid_input(
                    format!("Unary operator '{:?}' is not supported", op),
                    location!(),
                ));
            }
        })
    }
370
371 fn number(&self, value: &str, negative: bool) -> Result<Expr> {
373 use datafusion::logical_expr::lit;
374 let value: Cow<str> = if negative {
375 Cow::Owned(format!("-{}", value))
376 } else {
377 Cow::Borrowed(value)
378 };
379 if let Ok(n) = value.parse::<i64>() {
380 Ok(lit(n))
381 } else {
382 value.parse::<f64>().map(lit).map_err(|_| {
383 Error::invalid_input(
384 format!("'{value}' is not supported number value."),
385 location!(),
386 )
387 })
388 }
389 }
390
391 fn value(&self, value: &Value) -> Result<Expr> {
392 Ok(match value {
393 Value::Number(v, _) => self.number(v.as_str(), false)?,
394 Value::SingleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone())), None),
395 Value::HexStringLiteral(hsl) => {
396 Expr::Literal(ScalarValue::Binary(Self::try_decode_hex_literal(hsl)), None)
397 }
398 Value::DoubleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone())), None),
399 Value::Boolean(v) => Expr::Literal(ScalarValue::Boolean(Some(*v)), None),
400 Value::Null => Expr::Literal(ScalarValue::Null, None),
401 _ => todo!(),
402 })
403 }
404
405 fn parse_function_args(&self, func_args: &FunctionArg) -> Result<Expr> {
406 match func_args {
407 FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => self.parse_sql_expr(expr),
408 _ => Err(Error::invalid_input(
409 format!("Unsupported function args: {:?}", func_args),
410 location!(),
411 )),
412 }
413 }
414
415 fn legacy_parse_function(&self, func: &Function) -> Result<Expr> {
422 match &func.args {
423 FunctionArguments::List(args) => {
424 if func.name.0.len() != 1 {
425 return Err(Error::invalid_input(
426 format!("Function name must have 1 part, got: {:?}", func.name.0),
427 location!(),
428 ));
429 }
430 Ok(Expr::IsNotNull(Box::new(
431 self.parse_function_args(&args.args[0])?,
432 )))
433 }
434 _ => Err(Error::invalid_input(
435 format!("Unsupported function args: {:?}", &func.args),
436 location!(),
437 )),
438 }
439 }
440
    /// Plans a SQL function call. The legacy Lance `is_valid(col)` function is
    /// special-cased (rewritten to IS NOT NULL); all other functions are
    /// delegated to DataFusion's `SqlToRel`.
    fn parse_function(&self, function: SQLExpr) -> Result<Expr> {
        if let SQLExpr::Function(function) = &function {
            if let Some(ObjectNamePart::Identifier(name)) = &function.name.0.first() {
                if &name.value == "is_valid" {
                    return self.legacy_parse_function(function);
                }
            }
        }
        let sql_to_rel = SqlToRel::new_with_options(
            &self.context_provider,
            // Keep identifiers case-sensitive so Lance column names match exactly.
            ParserOptions {
                parse_float_as_decimal: false,
                enable_ident_normalization: false,
                support_varchar_with_length: false,
                enable_options_value_normalization: false,
                collect_spans: false,
                map_string_types_to_utf8view: false,
                default_null_ordering: NullOrdering::NullsMax,
            },
        );

        let mut planner_context = PlannerContext::default();
        let schema = DFSchema::try_from(self.schema.as_ref().clone())?;
        sql_to_rel
            .sql_to_expr(function, &schema, &mut planner_context)
            .map_err(|e| Error::invalid_input(format!("Error parsing function: {e}"), location!()))
    }
468
    /// Maps a SQL type name (from CAST or a typed literal) to an Arrow
    /// `DataType`. Unsupported types produce an error listing the supported
    /// names; timestamps with a time zone are rejected.
    fn parse_type(&self, data_type: &SQLDataType) -> Result<ArrowDataType> {
        // Human-readable list used in the "unsupported type" error message.
        const SUPPORTED_TYPES: [&str; 13] = [
            "int [unsigned]",
            "tinyint [unsigned]",
            "smallint [unsigned]",
            "bigint [unsigned]",
            "float",
            "double",
            "string",
            "binary",
            "date",
            "timestamp(precision)",
            "datetime(precision)",
            "decimal(precision,scale)",
            "boolean",
        ];
        match data_type {
            SQLDataType::String(_) => Ok(ArrowDataType::Utf8),
            SQLDataType::Binary(_) => Ok(ArrowDataType::Binary),
            SQLDataType::Float(_) => Ok(ArrowDataType::Float32),
            SQLDataType::Double(_) => Ok(ArrowDataType::Float64),
            SQLDataType::Boolean => Ok(ArrowDataType::Boolean),
            SQLDataType::TinyInt(_) => Ok(ArrowDataType::Int8),
            SQLDataType::SmallInt(_) => Ok(ArrowDataType::Int16),
            SQLDataType::Int(_) | SQLDataType::Integer(_) => Ok(ArrowDataType::Int32),
            SQLDataType::BigInt(_) => Ok(ArrowDataType::Int64),
            SQLDataType::TinyIntUnsigned(_) => Ok(ArrowDataType::UInt8),
            SQLDataType::SmallIntUnsigned(_) => Ok(ArrowDataType::UInt16),
            SQLDataType::IntUnsigned(_) | SQLDataType::IntegerUnsigned(_) => {
                Ok(ArrowDataType::UInt32)
            }
            SQLDataType::BigIntUnsigned(_) => Ok(ArrowDataType::UInt64),
            SQLDataType::Date => Ok(ArrowDataType::Date32),
            SQLDataType::Timestamp(resolution, tz) => {
                // Only timezone-less timestamps are supported.
                match tz {
                    TimezoneInfo::None => {}
                    _ => {
                        return Err(Error::invalid_input(
                            "Timezone not supported in timestamp".to_string(),
                            location!(),
                        ));
                    }
                };
                // SQL fractional-second precision -> Arrow time unit
                // (no precision defaults to microseconds).
                let time_unit = match resolution {
                    None => TimeUnit::Microsecond,
                    Some(0) => TimeUnit::Second,
                    Some(3) => TimeUnit::Millisecond,
                    Some(6) => TimeUnit::Microsecond,
                    Some(9) => TimeUnit::Nanosecond,
                    _ => {
                        return Err(Error::invalid_input(
                            format!("Unsupported datetime resolution: {:?}", resolution),
                            location!(),
                        ));
                    }
                };
                Ok(ArrowDataType::Timestamp(time_unit, None))
            }
            SQLDataType::Datetime(resolution) => {
                // Same precision mapping as Timestamp (datetime carries no tz).
                let time_unit = match resolution {
                    None => TimeUnit::Microsecond,
                    Some(0) => TimeUnit::Second,
                    Some(3) => TimeUnit::Millisecond,
                    Some(6) => TimeUnit::Microsecond,
                    Some(9) => TimeUnit::Nanosecond,
                    _ => {
                        return Err(Error::invalid_input(
                            format!("Unsupported datetime resolution: {:?}", resolution),
                            location!(),
                        ));
                    }
                };
                Ok(ArrowDataType::Timestamp(time_unit, None))
            }
            SQLDataType::Decimal(number_info) => match number_info {
                ExactNumberInfo::PrecisionAndScale(precision, scale) => {
                    Ok(ArrowDataType::Decimal128(*precision as u8, *scale as i8))
                }
                _ => Err(Error::invalid_input(
                    format!(
                        "Must provide precision and scale for decimal: {:?}",
                        number_info
                    ),
                    location!(),
                )),
            },
            _ => Err(Error::invalid_input(
                format!(
                    "Unsupported data type: {:?}. Supported types: {:?}",
                    data_type, SUPPORTED_TYPES
                ),
                location!(),
            )),
        }
    }
565
566 fn plan_field_access(&self, mut field_access_expr: RawFieldAccessExpr) -> Result<Expr> {
567 let df_schema = DFSchema::try_from(self.schema.as_ref().clone())?;
568 for planner in self.context_provider.get_expr_planners() {
569 match planner.plan_field_access(field_access_expr, &df_schema)? {
570 PlannerResult::Planned(expr) => return Ok(expr),
571 PlannerResult::Original(expr) => {
572 field_access_expr = expr;
573 }
574 }
575 }
576 Err(Error::invalid_input(
577 "Field access could not be planned",
578 location!(),
579 ))
580 }
581
    /// Recursively converts a sqlparser AST expression into a DataFusion
    /// logical `Expr`, applying Lance-specific rules (quote styles, legacy
    /// functions, array literals, field access) along the way.
    fn parse_sql_expr(&self, expr: &SQLExpr) -> Result<Expr> {
        match expr {
            SQLExpr::Identifier(id) => {
                // Double-quoted identifiers are treated as string literals,
                // backtick-quoted ones as exact (unresolved) column names.
                if id.quote_style == Some('"') {
                    Ok(Expr::Literal(
                        ScalarValue::Utf8(Some(id.value.clone())),
                        None,
                    ))
                } else if id.quote_style == Some('`') {
                    Ok(Expr::Column(Column::from_name(id.value.clone())))
                } else {
                    Ok(self.column(vec![id.clone()].as_slice()))
                }
            }
            SQLExpr::CompoundIdentifier(ids) => Ok(self.column(ids.as_slice())),
            SQLExpr::BinaryOp { left, op, right } => self.binary_expr(left, op, right),
            SQLExpr::UnaryOp { op, expr } => self.unary_expr(op, expr),
            SQLExpr::Value(value) => self.value(&value.value),
            // Array literal: every element must be a literal (possibly
            // negated number); elements are coerced to one common type and
            // packed into a single-row ListArray scalar.
            SQLExpr::Array(SQLArray { elem, .. }) => {
                let mut values = vec![];

                let array_literal_error = |pos: usize, value: &_| {
                    Err(Error::invalid_input(
                        format!(
                            "Expected a literal value in array, instead got {} at position {}",
                            value, pos
                        ),
                        location!(),
                    ))
                };

                for (pos, expr) in elem.iter().enumerate() {
                    match expr {
                        SQLExpr::Value(value) => {
                            if let Expr::Literal(value, _) = self.value(&value.value)? {
                                values.push(value);
                            } else {
                                return array_literal_error(pos, expr);
                            }
                        }
                        // Negated numeric literal, e.g. `[-1, -2]`.
                        SQLExpr::UnaryOp {
                            op: UnaryOperator::Minus,
                            expr,
                        } => {
                            if let SQLExpr::Value(ValueWithSpan {
                                value: Value::Number(number, _),
                                ..
                            }) = expr.as_ref()
                            {
                                if let Expr::Literal(value, _) = self.number(number, true)? {
                                    values.push(value);
                                } else {
                                    return array_literal_error(pos, expr);
                                }
                            } else {
                                return array_literal_error(pos, expr);
                            }
                        }
                        _ => {
                            return array_literal_error(pos, expr);
                        }
                    }
                }

                // Coerce all elements to the first element's type; an empty
                // array gets a Null element type.
                let field = if !values.is_empty() {
                    let data_type = values[0].data_type();

                    for value in &mut values {
                        if value.data_type() != data_type {
                            *value = safe_coerce_scalar(value, &data_type).ok_or_else(|| Error::invalid_input(
                                format!("Array expressions must have a consistent datatype. Expected: {}, got: {}", data_type, value.data_type()),
                                location!()
                            ))?;
                        }
                    }
                    Field::new("item", data_type, true)
                } else {
                    Field::new("item", ArrowDataType::Null, true)
                };

                let values = values
                    .into_iter()
                    .map(|v| v.to_array().map_err(Error::from))
                    .collect::<Result<Vec<_>>>()?;
                let array_refs = values.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
                let values = concat(&array_refs)?;
                // Single list entry containing all elements.
                let values = ListArray::try_new(
                    field.into(),
                    OffsetBuffer::from_lengths([values.len()]),
                    values,
                    None,
                )?;

                Ok(Expr::Literal(ScalarValue::List(Arc::new(values)), None))
            }
            // e.g. `DATE '2020-01-01'` becomes CAST('2020-01-01' AS Date32).
            SQLExpr::TypedString { data_type, value } => {
                let value = value.clone().into_string().expect_ok()?;
                Ok(Expr::Cast(datafusion::logical_expr::Cast {
                    expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some(value)), None)),
                    data_type: self.parse_type(data_type)?,
                }))
            }
            SQLExpr::IsFalse(expr) => Ok(Expr::IsFalse(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotFalse(expr) => Ok(Expr::IsNotFalse(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsTrue(expr) => Ok(Expr::IsTrue(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotTrue(expr) => Ok(Expr::IsNotTrue(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNull(expr) => Ok(Expr::IsNull(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotNull(expr) => Ok(Expr::IsNotNull(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::InList {
                expr,
                list,
                negated,
            } => {
                let value_expr = self.parse_sql_expr(expr)?;
                let list_exprs = list
                    .iter()
                    .map(|e| self.parse_sql_expr(e))
                    .collect::<Result<Vec<_>>>()?;
                Ok(value_expr.in_list(list_exprs, *negated))
            }
            SQLExpr::Nested(inner) => self.parse_sql_expr(inner.as_ref()),
            SQLExpr::Function(_) => self.parse_function(expr.clone()),
            // ILIKE: case-insensitive LIKE (final `true` argument).
            SQLExpr::ILike {
                negated,
                expr,
                pattern,
                escape_char,
                any: _,
            } => Ok(Expr::Like(Like::new(
                *negated,
                Box::new(self.parse_sql_expr(expr)?),
                Box::new(self.parse_sql_expr(pattern)?),
                match escape_char {
                    Some(Value::SingleQuotedString(char)) => char.chars().next(),
                    Some(value) => return Err(Error::invalid_input(
                        format!("Invalid escape character in LIKE expression. Expected a single character wrapped with single quotes, got {}", value),
                        location!()
                    )),
                    None => None,
                },
                true,
            ))),
            // LIKE: case-sensitive (final `false` argument).
            SQLExpr::Like {
                negated,
                expr,
                pattern,
                escape_char,
                any: _,
            } => Ok(Expr::Like(Like::new(
                *negated,
                Box::new(self.parse_sql_expr(expr)?),
                Box::new(self.parse_sql_expr(pattern)?),
                match escape_char {
                    Some(Value::SingleQuotedString(char)) => char.chars().next(),
                    Some(value) => return Err(Error::invalid_input(
                        format!("Invalid escape character in LIKE expression. Expected a single character wrapped with single quotes, got {}", value),
                        location!()
                    )),
                    None => None,
                },
                false,
            ))),
            // TRY_CAST/SAFE_CAST return NULL on failure; plain CAST errors.
            SQLExpr::Cast {
                expr,
                data_type,
                kind,
                ..
            } => match kind {
                datafusion::sql::sqlparser::ast::CastKind::TryCast
                | datafusion::sql::sqlparser::ast::CastKind::SafeCast => {
                    Ok(Expr::TryCast(datafusion::logical_expr::TryCast {
                        expr: Box::new(self.parse_sql_expr(expr)?),
                        data_type: self.parse_type(data_type)?,
                    }))
                }
                _ => Ok(Expr::Cast(datafusion::logical_expr::Cast {
                    expr: Box::new(self.parse_sql_expr(expr)?),
                    data_type: self.parse_type(data_type)?,
                })),
            },
            SQLExpr::JsonAccess { .. } => Err(Error::invalid_input(
                "JSON access is not supported",
                location!(),
            )),
            // `root.field`, `root['field']`, `root[idx]`: each link in the
            // chain becomes a field access planned by the expr planners.
            SQLExpr::CompoundFieldAccess { root, access_chain } => {
                let mut expr = self.parse_sql_expr(root)?;

                for access in access_chain {
                    let field_access = match access {
                        // Dot or string subscript -> named struct field.
                        AccessExpr::Dot(SQLExpr::Identifier(Ident { value: s, .. }))
                        | AccessExpr::Subscript(Subscript::Index {
                            index:
                                SQLExpr::Value(ValueWithSpan {
                                    value:
                                        Value::SingleQuotedString(s) | Value::DoubleQuotedString(s),
                                    ..
                                }),
                        }) => GetFieldAccess::NamedStructField {
                            name: ScalarValue::from(s.as_str()),
                        },
                        // Any other subscript expression -> list index.
                        AccessExpr::Subscript(Subscript::Index { index }) => {
                            let key = Box::new(self.parse_sql_expr(index)?);
                            GetFieldAccess::ListIndex { key }
                        }
                        AccessExpr::Subscript(Subscript::Slice { .. }) => {
                            return Err(Error::invalid_input(
                                "Slice subscript is not supported",
                                location!(),
                            ));
                        }
                        _ => {
                            return Err(Error::invalid_input(
                                "Only dot notation or index access is supported for field access",
                                location!(),
                            ));
                        }
                    };

                    let field_access_expr = RawFieldAccessExpr { expr, field_access };
                    expr = self.plan_field_access(field_access_expr)?;
                }

                Ok(expr)
            }
            SQLExpr::Between {
                expr,
                negated,
                low,
                high,
            } => {
                let expr = self.parse_sql_expr(expr)?;
                let low = self.parse_sql_expr(low)?;
                let high = self.parse_sql_expr(high)?;

                let between = Expr::Between(Between::new(
                    Box::new(expr),
                    *negated,
                    Box::new(low),
                    Box::new(high),
                ));
                Ok(between)
            }
            _ => Err(Error::invalid_input(
                format!("Expression '{expr}' is not supported SQL in lance"),
                location!(),
            )),
        }
    }
839
840 pub fn parse_filter(&self, filter: &str) -> Result<Expr> {
845 let ast_expr = parse_sql_filter(filter)?;
847 let expr = self.parse_sql_expr(&ast_expr)?;
848 let schema = Schema::try_from(self.schema.as_ref())?;
849 let resolved = resolve_expr(&expr, &schema).map_err(|e| {
850 Error::invalid_input(
851 format!("Error resolving filter expression {filter}: {e}"),
852 location!(),
853 )
854 })?;
855
856 Ok(coerce_filter_type_to_boolean(resolved))
857 }
858
    /// Parses a standalone SQL expression (e.g. a projection). A bare column
    /// name present in the schema short-circuits to a column reference
    /// without invoking the SQL parser.
    pub fn parse_expr(&self, expr: &str) -> Result<Expr> {
        let resolved_name = self.resolve_column_name(expr);
        if self.schema.field_with_name(&resolved_name).is_ok() {
            return Ok(Expr::Column(Column::from_name(resolved_name)));
        }

        let ast_expr = parse_sql_expr(expr)?;
        let expr = self.parse_sql_expr(&ast_expr)?;
        let schema = Schema::try_from(self.schema.as_ref())?;
        let resolved = resolve_expr(&expr, &schema)?;
        Ok(resolved)
    }
878
879 fn try_decode_hex_literal(s: &str) -> Option<Vec<u8>> {
885 let hex_bytes = s.as_bytes();
886 let mut decoded_bytes = Vec::with_capacity(hex_bytes.len().div_ceil(2));
887
888 let start_idx = hex_bytes.len() % 2;
889 if start_idx > 0 {
890 decoded_bytes.push(Self::try_decode_hex_char(hex_bytes[0])?);
892 }
893
894 for i in (start_idx..hex_bytes.len()).step_by(2) {
895 let high = Self::try_decode_hex_char(hex_bytes[i])?;
896 let low = Self::try_decode_hex_char(hex_bytes[i + 1])?;
897 decoded_bytes.push((high << 4) | low);
898 }
899
900 Some(decoded_bytes)
901 }
902
903 const fn try_decode_hex_char(c: u8) -> Option<u8> {
907 match c {
908 b'A'..=b'F' => Some(c - b'A' + 10),
909 b'a'..=b'f' => Some(c - b'a' + 10),
910 b'0'..=b'9' => Some(c - b'0'),
911 _ => None,
912 }
913 }
914
    /// Simplifies (constant folding, etc.) and type-coerces `expr` against
    /// the planner's schema using DataFusion's expression simplifier.
    pub fn optimize_expr(&self, expr: Expr) -> Result<Expr> {
        let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);

        // Pin the query start time so now()-style functions fold deterministically.
        let props = ExecutionProps::new().with_query_execution_start_time(Utc::now());
        let simplify_context = SimplifyContext::new(&props).with_schema(df_schema.clone());
        let simplifier =
            datafusion::optimizer::simplify_expressions::ExprSimplifier::new(simplify_context);

        let expr = simplifier.simplify(expr)?;
        let expr = simplifier.coerce(expr, &df_schema)?;

        Ok(expr)
    }
931
    /// Lowers a logical expression to a physical expression evaluable against
    /// record batches of the planner's schema.
    pub fn create_physical_expr(&self, expr: &Expr) -> Result<Arc<dyn PhysicalExpr>> {
        let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
        Ok(datafusion::physical_expr::create_physical_expr(
            expr,
            df_schema.as_ref(),
            &Default::default(),
        )?)
    }
941
    /// Collects the de-duplicated, sorted names of all columns referenced by
    /// `expr`, folding nested `get_field` chains into dotted paths.
    pub fn column_names_in_expr(expr: &Expr) -> Vec<String> {
        let mut visitor = ColumnCapturingVisitor {
            current_path: VecDeque::new(),
            columns: BTreeSet::new(),
        };
        // Traversal is infallible for this visitor, so unwrap cannot panic.
        expr.visit(&mut visitor).unwrap();
        visitor.columns.into_iter().collect()
    }
956}
957
/// Tree visitor that records column references, reassembling nested
/// `get_field` chains into dotted paths (e.g. `st.inner.leaf`).
struct ColumnCapturingVisitor {
    /// Field names collected while descending a `get_field` chain,
    /// innermost (closest to the column) first.
    current_path: VecDeque<String>,
    /// Accumulated column paths (BTreeSet keeps them sorted and unique).
    columns: BTreeSet<String>,
}
963
impl TreeNodeVisitor<'_> for ColumnCapturingVisitor {
    type Node = Expr;

    fn f_down(&mut self, node: &Self::Node) -> DFResult<TreeNodeRecursion> {
        match node {
            Expr::Column(Column { name, .. }) => {
                // Reached the root column of a get_field chain: emit
                // `name.part1.part2...`, backtick-quoting any part that
                // contains '.' or '`' (backticks doubled for escaping).
                let mut path = name.clone();
                for part in self.current_path.drain(..) {
                    path.push('.');
                    if part.contains('.') || part.contains('`') {
                        let escaped = part.replace('`', "``");
                        path.push('`');
                        path.push_str(&escaped);
                        path.push('`');
                    } else {
                        path.push_str(&part);
                    }
                }
                self.columns.insert(path);
                self.current_path.clear();
            }
            Expr::ScalarFunction(udf) => {
                if udf.name() == GetFieldFunc::default().name() {
                    // get_field(child, "name"): prepend the field name so the
                    // dotted path reads root-first once the column is reached.
                    if let Some(name) = get_as_string_scalar_opt(&udf.args[1]) {
                        self.current_path.push_front(name.to_string())
                    } else {
                        // Non-literal field name: abandon the partial path.
                        self.current_path.clear();
                    }
                } else {
                    self.current_path.clear();
                }
            }
            _ => {
                // Any other node breaks the get_field chain.
                self.current_path.clear();
            }
        }

        Ok(TreeNodeRecursion::Continue)
    }
}
1009
1010#[cfg(test)]
1011mod tests {
1012
1013 use crate::logical_expr::ExprExt;
1014
1015 use super::*;
1016
1017 use arrow::datatypes::Float64Type;
1018 use arrow_array::{
1019 ArrayRef, BooleanArray, Float32Array, Int32Array, Int64Array, RecordBatch, StringArray,
1020 StructArray, TimestampMicrosecondArray, TimestampMillisecondArray,
1021 TimestampNanosecondArray, TimestampSecondArray,
1022 };
1023 use arrow_schema::{DataType, Fields, Schema};
1024 use datafusion::{
1025 logical_expr::{col, lit, Cast},
1026 prelude::{array_element, get_field},
1027 };
1028 use datafusion_functions::core::expr_ext::FieldAccessor;
1029
1030 #[test]
1031 fn test_parse_filter_simple() {
1032 let schema = Arc::new(Schema::new(vec![
1033 Field::new("i", DataType::Int32, false),
1034 Field::new("s", DataType::Utf8, true),
1035 Field::new(
1036 "st",
1037 DataType::Struct(Fields::from(vec![
1038 Field::new("x", DataType::Float32, false),
1039 Field::new("y", DataType::Float32, false),
1040 ])),
1041 true,
1042 ),
1043 ]));
1044
1045 let planner = Planner::new(schema.clone());
1046
1047 let expected = col("i")
1048 .gt(lit(3_i32))
1049 .and(col("st").field_newstyle("x").lt_eq(lit(5.0_f32)))
1050 .and(
1051 col("s")
1052 .eq(lit("str-4"))
1053 .or(col("s").in_list(vec![lit("str-4"), lit("str-5")], false)),
1054 );
1055
1056 let expr = planner
1058 .parse_filter("i > 3 AND st.x <= 5.0 AND (s == 'str-4' OR s in ('str-4', 'str-5'))")
1059 .unwrap();
1060 assert_eq!(expr, expected);
1061
1062 let expr = planner
1064 .parse_filter("i > 3 AND st.x <= 5.0 AND (s = 'str-4' OR s in ('str-4', 'str-5'))")
1065 .unwrap();
1066
1067 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1068
1069 let batch = RecordBatch::try_new(
1070 schema,
1071 vec![
1072 Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
1073 Arc::new(StringArray::from_iter_values(
1074 (0..10).map(|v| format!("str-{}", v)),
1075 )),
1076 Arc::new(StructArray::from(vec![
1077 (
1078 Arc::new(Field::new("x", DataType::Float32, false)),
1079 Arc::new(Float32Array::from_iter_values((0..10).map(|v| v as f32)))
1080 as ArrayRef,
1081 ),
1082 (
1083 Arc::new(Field::new("y", DataType::Float32, false)),
1084 Arc::new(Float32Array::from_iter_values(
1085 (0..10).map(|v| (v * 10) as f32),
1086 )),
1087 ),
1088 ])),
1089 ],
1090 )
1091 .unwrap();
1092 let predicates = physical_expr.evaluate(&batch).unwrap();
1093 assert_eq!(
1094 predicates.into_array(0).unwrap().as_ref(),
1095 &BooleanArray::from(vec![
1096 false, false, false, false, true, true, false, false, false, false
1097 ])
1098 );
1099 }
1100
    #[test]
    fn test_nested_col_refs() {
        // Schema with two levels of struct nesting:
        //   s0: Utf8
        //   st: { s1: Utf8, st: { s2: Utf8 } }
        let schema = Arc::new(Schema::new(vec![
            Field::new("s0", DataType::Utf8, true),
            Field::new(
                "st",
                DataType::Struct(Fields::from(vec![
                    Field::new("s1", DataType::Utf8, true),
                    Field::new(
                        "st",
                        DataType::Struct(Fields::from(vec![Field::new(
                            "s2",
                            DataType::Utf8,
                            true,
                        )])),
                        true,
                    ),
                ])),
                true,
            ),
        ]));

        let planner = Planner::new(schema);

        // Parses `<expr> = 'val'` and asserts the parsed left-hand side equals
        // `expected`; the right-hand literal is only there to make a valid filter.
        fn assert_column_eq(planner: &Planner, expr: &str, expected: &Expr) {
            let expr = planner.parse_filter(&format!("{expr} = 'val'")).unwrap();
            assert!(matches!(
                expr,
                Expr::BinaryExpr(BinaryExpr {
                    left: _,
                    op: Operator::Eq,
                    right: _
                })
            ));
            if let Expr::BinaryExpr(BinaryExpr { left, .. }) = expr {
                assert_eq!(left.as_ref(), expected);
            }
        }

        // Top-level column, bare and backtick-quoted.
        let expected = Expr::Column(Column::new_unqualified("s0"));
        assert_column_eq(&planner, "s0", &expected);
        assert_column_eq(&planner, "`s0`", &expected);

        // One level of nesting resolves to get_field(st, "s1").
        let expected = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
            args: vec![
                Expr::Column(Column::new_unqualified("st")),
                Expr::Literal(ScalarValue::Utf8(Some("s1".to_string())), None),
            ],
        });
        assert_column_eq(&planner, "st.s1", &expected);
        assert_column_eq(&planner, "`st`.`s1`", &expected);
        assert_column_eq(&planner, "st.`s1`", &expected);

        // Two levels of nesting resolve to nested get_field calls,
        // innermost (st.st) first.
        let expected = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
            args: vec![
                Expr::ScalarFunction(ScalarFunction {
                    func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
                    args: vec![
                        Expr::Column(Column::new_unqualified("st")),
                        Expr::Literal(ScalarValue::Utf8(Some("st".to_string())), None),
                    ],
                }),
                Expr::Literal(ScalarValue::Utf8(Some("s2".to_string())), None),
            ],
        });

        // Dotted paths, quoted identifiers, and subscript syntax all resolve
        // to the same expression.
        assert_column_eq(&planner, "st.st.s2", &expected);
        assert_column_eq(&planner, "`st`.`st`.`s2`", &expected);
        assert_column_eq(&planner, "st.st.`s2`", &expected);
        assert_column_eq(&planner, "st['st'][\"s2\"]", &expected);
    }
1174
1175 #[test]
1176 fn test_nested_list_refs() {
1177 let schema = Arc::new(Schema::new(vec![Field::new(
1178 "l",
1179 DataType::List(Arc::new(Field::new(
1180 "item",
1181 DataType::Struct(Fields::from(vec![Field::new("f1", DataType::Utf8, true)])),
1182 true,
1183 ))),
1184 true,
1185 )]));
1186
1187 let planner = Planner::new(schema);
1188
1189 let expected = array_element(col("l"), lit(0_i64));
1190 let expr = planner.parse_expr("l[0]").unwrap();
1191 assert_eq!(expr, expected);
1192
1193 let expected = get_field(array_element(col("l"), lit(0_i64)), "f1");
1194 let expr = planner.parse_expr("l[0]['f1']").unwrap();
1195 assert_eq!(expr, expected);
1196
1197 }
1202
1203 #[test]
1204 fn test_negative_expressions() {
1205 let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
1206
1207 let planner = Planner::new(schema.clone());
1208
1209 let expected = col("x")
1210 .gt(lit(-3_i64))
1211 .and(col("x").lt(-(lit(-5_i64) + lit(3_i64))));
1212
1213 let expr = planner.parse_filter("x > -3 AND x < -(-5 + 3)").unwrap();
1214
1215 assert_eq!(expr, expected);
1216
1217 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1218
1219 let batch = RecordBatch::try_new(
1220 schema,
1221 vec![Arc::new(Int64Array::from_iter_values(-5..5)) as ArrayRef],
1222 )
1223 .unwrap();
1224 let predicates = physical_expr.evaluate(&batch).unwrap();
1225 assert_eq!(
1226 predicates.into_array(0).unwrap().as_ref(),
1227 &BooleanArray::from(vec![
1228 false, false, false, true, true, true, true, false, false, false
1229 ])
1230 );
1231 }
1232
1233 #[test]
1234 fn test_negative_array_expressions() {
1235 let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
1236
1237 let planner = Planner::new(schema);
1238
1239 let expected = Expr::Literal(
1240 ScalarValue::List(Arc::new(
1241 ListArray::from_iter_primitive::<Float64Type, _, _>(vec![Some(
1242 [-1_f64, -2.0, -3.0, -4.0, -5.0].map(Some),
1243 )]),
1244 )),
1245 None,
1246 );
1247
1248 let expr = planner
1249 .parse_expr("[-1.0, -2.0, -3.0, -4.0, -5.0]")
1250 .unwrap();
1251
1252 assert_eq!(expr, expected);
1253 }
1254
1255 #[test]
1256 fn test_sql_like() {
1257 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1258
1259 let planner = Planner::new(schema.clone());
1260
1261 let expected = col("s").like(lit("str-4"));
1262 let expr = planner.parse_filter("s LIKE 'str-4'").unwrap();
1264 assert_eq!(expr, expected);
1265 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1266
1267 let batch = RecordBatch::try_new(
1268 schema,
1269 vec![Arc::new(StringArray::from_iter_values(
1270 (0..10).map(|v| format!("str-{}", v)),
1271 ))],
1272 )
1273 .unwrap();
1274 let predicates = physical_expr.evaluate(&batch).unwrap();
1275 assert_eq!(
1276 predicates.into_array(0).unwrap().as_ref(),
1277 &BooleanArray::from(vec![
1278 false, false, false, false, true, false, false, false, false, false
1279 ])
1280 );
1281 }
1282
1283 #[test]
1284 fn test_not_like() {
1285 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1286
1287 let planner = Planner::new(schema.clone());
1288
1289 let expected = col("s").not_like(lit("str-4"));
1290 let expr = planner.parse_filter("s NOT LIKE 'str-4'").unwrap();
1292 assert_eq!(expr, expected);
1293 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1294
1295 let batch = RecordBatch::try_new(
1296 schema,
1297 vec![Arc::new(StringArray::from_iter_values(
1298 (0..10).map(|v| format!("str-{}", v)),
1299 ))],
1300 )
1301 .unwrap();
1302 let predicates = physical_expr.evaluate(&batch).unwrap();
1303 assert_eq!(
1304 predicates.into_array(0).unwrap().as_ref(),
1305 &BooleanArray::from(vec![
1306 true, true, true, true, false, true, true, true, true, true
1307 ])
1308 );
1309 }
1310
1311 #[test]
1312 fn test_sql_is_in() {
1313 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1314
1315 let planner = Planner::new(schema.clone());
1316
1317 let expected = col("s").in_list(vec![lit("str-4"), lit("str-5")], false);
1318 let expr = planner.parse_filter("s IN ('str-4', 'str-5')").unwrap();
1320 assert_eq!(expr, expected);
1321 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1322
1323 let batch = RecordBatch::try_new(
1324 schema,
1325 vec![Arc::new(StringArray::from_iter_values(
1326 (0..10).map(|v| format!("str-{}", v)),
1327 ))],
1328 )
1329 .unwrap();
1330 let predicates = physical_expr.evaluate(&batch).unwrap();
1331 assert_eq!(
1332 predicates.into_array(0).unwrap().as_ref(),
1333 &BooleanArray::from(vec![
1334 false, false, false, false, true, true, false, false, false, false
1335 ])
1336 );
1337 }
1338
1339 #[test]
1340 fn test_sql_is_null() {
1341 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1342
1343 let planner = Planner::new(schema.clone());
1344
1345 let expected = col("s").is_null();
1346 let expr = planner.parse_filter("s IS NULL").unwrap();
1347 assert_eq!(expr, expected);
1348 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1349
1350 let batch = RecordBatch::try_new(
1351 schema,
1352 vec![Arc::new(StringArray::from_iter((0..10).map(|v| {
1353 if v % 3 == 0 {
1354 Some(format!("str-{}", v))
1355 } else {
1356 None
1357 }
1358 })))],
1359 )
1360 .unwrap();
1361 let predicates = physical_expr.evaluate(&batch).unwrap();
1362 assert_eq!(
1363 predicates.into_array(0).unwrap().as_ref(),
1364 &BooleanArray::from(vec![
1365 false, true, true, false, true, true, false, true, true, false
1366 ])
1367 );
1368
1369 let expr = planner.parse_filter("s IS NOT NULL").unwrap();
1370 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1371 let predicates = physical_expr.evaluate(&batch).unwrap();
1372 assert_eq!(
1373 predicates.into_array(0).unwrap().as_ref(),
1374 &BooleanArray::from(vec![
1375 true, false, false, true, false, false, true, false, false, true,
1376 ])
1377 );
1378 }
1379
1380 #[test]
1381 fn test_sql_invert() {
1382 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Boolean, true)]));
1383
1384 let planner = Planner::new(schema.clone());
1385
1386 let expr = planner.parse_filter("NOT s").unwrap();
1387 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1388
1389 let batch = RecordBatch::try_new(
1390 schema,
1391 vec![Arc::new(BooleanArray::from_iter(
1392 (0..10).map(|v| Some(v % 3 == 0)),
1393 ))],
1394 )
1395 .unwrap();
1396 let predicates = physical_expr.evaluate(&batch).unwrap();
1397 assert_eq!(
1398 predicates.into_array(0).unwrap().as_ref(),
1399 &BooleanArray::from(vec![
1400 false, true, true, false, true, true, false, true, true, false
1401 ])
1402 );
1403 }
1404
    #[test]
    fn test_sql_cast() {
        // (filter SQL containing a CAST, Arrow type the CAST should yield)
        let cases = &[
            (
                "x = cast('2021-01-01 00:00:00' as timestamp)",
                ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00' as timestamp(0))",
                ArrowDataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00.123' as timestamp(9))",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00.123' as datetime(9))",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            ("x = cast('2021-01-01' as date)", ArrowDataType::Date32),
            (
                "x = cast('1.238' as decimal(9,3))",
                ArrowDataType::Decimal128(9, 3),
            ),
            ("x = cast(1 as float)", ArrowDataType::Float32),
            ("x = cast(1 as double)", ArrowDataType::Float64),
            ("x = cast(1 as tinyint)", ArrowDataType::Int8),
            ("x = cast(1 as smallint)", ArrowDataType::Int16),
            ("x = cast(1 as int)", ArrowDataType::Int32),
            ("x = cast(1 as integer)", ArrowDataType::Int32),
            ("x = cast(1 as bigint)", ArrowDataType::Int64),
            ("x = cast(1 as tinyint unsigned)", ArrowDataType::UInt8),
            ("x = cast(1 as smallint unsigned)", ArrowDataType::UInt16),
            ("x = cast(1 as int unsigned)", ArrowDataType::UInt32),
            ("x = cast(1 as integer unsigned)", ArrowDataType::UInt32),
            ("x = cast(1 as bigint unsigned)", ArrowDataType::UInt64),
            ("x = cast(1 as boolean)", ArrowDataType::Boolean),
            ("x = cast(1 as string)", ArrowDataType::Utf8),
        ];

        for (sql, expected_data_type) in cases {
            // Column "x" is given the cast's target type so `x = CAST(...)`
            // type-checks without extra coercion.
            let schema = Arc::new(Schema::new(vec![Field::new(
                "x",
                expected_data_type.clone(),
                true,
            )]));
            let planner = Planner::new(schema.clone());
            let expr = planner.parse_filter(sql).unwrap();

            // Recover the expected literal from the SQL text itself: the token
            // between "cast(" and " as", with surrounding single quotes stripped.
            let expected_value_str = sql
                .split("cast(")
                .nth(1)
                .unwrap()
                .split(" as")
                .next()
                .unwrap();
            let expected_value_str = expected_value_str.trim_matches('\'');

            // Expect `x = CAST(<literal> AS <type>)`: quoted values parse as
            // Utf8 literals, the bare `1` as Int64.
            match expr {
                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
                    Expr::Cast(Cast { expr, data_type }) => {
                        match expr.as_ref() {
                            Expr::Literal(ScalarValue::Utf8(Some(value_str)), _) => {
                                assert_eq!(value_str, expected_value_str);
                            }
                            Expr::Literal(ScalarValue::Int64(Some(value)), _) => {
                                assert_eq!(*value, 1);
                            }
                            _ => panic!("Expected cast to be applied to literal"),
                        }
                        assert_eq!(data_type, expected_data_type);
                    }
                    _ => panic!("Expected right to be a cast"),
                },
                _ => panic!("Expected binary expression"),
            }
        }
    }
1485
    #[test]
    fn test_sql_literals() {
        // (filter SQL using a typed literal, Arrow type it should cast to)
        let cases = &[
            (
                "x = timestamp '2021-01-01 00:00:00'",
                ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
            ),
            (
                "x = timestamp(0) '2021-01-01 00:00:00'",
                ArrowDataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                "x = timestamp(9) '2021-01-01 00:00:00.123'",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            ("x = date '2021-01-01'", ArrowDataType::Date32),
            ("x = decimal(9,3) '1.238'", ArrowDataType::Decimal128(9, 3)),
        ];

        for (sql, expected_data_type) in cases {
            // Column "x" is given the literal's target type so the comparison
            // type-checks without extra coercion.
            let schema = Arc::new(Schema::new(vec![Field::new(
                "x",
                expected_data_type.clone(),
                true,
            )]));
            let planner = Planner::new(schema.clone());
            let expr = planner.parse_filter(sql).unwrap();

            // The expected literal text is whatever sits between the first
            // pair of single quotes in the SQL.
            let expected_value_str = sql.split('\'').nth(1).unwrap();

            // Typed literals lower to CAST(<utf8 literal> AS <type>).
            match expr {
                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
                    Expr::Cast(Cast { expr, data_type }) => {
                        match expr.as_ref() {
                            Expr::Literal(ScalarValue::Utf8(Some(value_str)), _) => {
                                assert_eq!(value_str, expected_value_str);
                            }
                            _ => panic!("Expected cast to be applied to literal"),
                        }
                        assert_eq!(data_type, expected_data_type);
                    }
                    _ => panic!("Expected right to be a cast"),
                },
                _ => panic!("Expected binary expression"),
            }
        }
    }
1533
    #[test]
    fn test_sql_array_literals() {
        // The same array literal should coerce to either List or FixedSizeList
        // depending on the type of the column it is compared against.
        let cases = [
            (
                "x = [1, 2, 3]",
                ArrowDataType::List(Arc::new(Field::new("item", ArrowDataType::Int64, true))),
            ),
            (
                "x = [1, 2, 3]",
                ArrowDataType::FixedSizeList(
                    Arc::new(Field::new("item", ArrowDataType::Int64, true)),
                    3,
                ),
            ),
        ];

        for (sql, expected_data_type) in cases {
            let schema = Arc::new(Schema::new(vec![Field::new(
                "x",
                expected_data_type.clone(),
                true,
            )]));
            let planner = Planner::new(schema.clone());
            let expr = planner.parse_filter(sql).unwrap();
            // Optimization folds the array construction into a single literal.
            let expr = planner.optimize_expr(expr).unwrap();

            match expr {
                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
                    Expr::Literal(value, _) => {
                        assert_eq!(&value.data_type(), &expected_data_type);
                    }
                    _ => panic!("Expected right to be a literal"),
                },
                _ => panic!("Expected binary expression"),
            }
        }
    }
1571
    #[test]
    fn test_sql_between() {
        use arrow_array::{Float64Array, Int32Array, TimestampMicrosecondArray};
        use arrow_schema::{DataType, Field, Schema, TimeUnit};
        use std::sync::Arc;

        let schema = Arc::new(Schema::new(vec![
            Field::new("x", DataType::Int32, false),
            Field::new("y", DataType::Float64, false),
            Field::new(
                "ts",
                DataType::Timestamp(TimeUnit::Microsecond, None),
                false,
            ),
        ]));

        let planner = Planner::new(schema.clone());

        // Integer BETWEEN is inclusive on both bounds.
        let expr = planner
            .parse_filter("x BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
            .unwrap();
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        // base_ts is 2024-01-01 00:00:00 in microseconds; rows are spaced one
        // second (1_000_000 us) apart.
        let base_ts = 1704067200000000_i64;
        let ts_array = TimestampMicrosecondArray::from_iter_values(
            (0..10).map(|i| base_ts + i * 1_000_000),
        );

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
                Arc::new(Float64Array::from_iter_values((0..10).map(|v| v as f64))),
                Arc::new(ts_array),
            ],
        )
        .unwrap();

        // x in [3, 7]
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, true, true, true, true, true, false, false
            ])
        );

        // NOT BETWEEN is the exact complement.
        let expr = planner
            .parse_filter("x NOT BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
            .unwrap();
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                true, true, true, false, false, false, false, false, true, true
            ])
        );

        // Float bounds: y in [2.5, 6.5] selects rows 3..=6.
        let expr = planner.parse_filter("y BETWEEN 2.5 AND 6.5").unwrap();
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, true, true, true, true, false, false, false
            ])
        );

        // Timestamp bounds, inclusive: seconds 3 through 7 after base_ts.
        let expr = planner
            .parse_filter(
                "ts BETWEEN timestamp '2024-01-01 00:00:03' AND timestamp '2024-01-01 00:00:07'",
            )
            .unwrap();
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, true, true, true, true, true, false, false
            ])
        );
    }
1663
    #[test]
    fn test_sql_comparison() {
        // One column per timestamp resolution. The first three hold 0..10 in
        // their own unit; the nanosecond column is offset around 5000 ns so the
        // 5-microsecond literal below still splits its rows 5/5.
        let batch: Vec<(&str, ArrayRef)> = vec![
            (
                "timestamp_s",
                Arc::new(TimestampSecondArray::from_iter_values(0..10)),
            ),
            (
                "timestamp_ms",
                Arc::new(TimestampMillisecondArray::from_iter_values(0..10)),
            ),
            (
                "timestamp_us",
                Arc::new(TimestampMicrosecondArray::from_iter_values(0..10)),
            ),
            (
                "timestamp_ns",
                Arc::new(TimestampNanosecondArray::from_iter_values(4995..5005)),
            ),
        ];
        let batch = RecordBatch::try_from_iter(batch).unwrap();

        let planner = Planner::new(batch.schema());

        // Each literal equals "5" expressed in the column's own unit.
        let expressions = &[
            "timestamp_s >= TIMESTAMP '1970-01-01 00:00:05'",
            "timestamp_ms >= TIMESTAMP '1970-01-01 00:00:00.005'",
            "timestamp_us >= TIMESTAMP '1970-01-01 00:00:00.000005'",
            "timestamp_ns >= TIMESTAMP '1970-01-01 00:00:00.000005'",
        ];

        // First five rows fall below the threshold, last five at or above it.
        let expected: ArrayRef = Arc::new(BooleanArray::from_iter(
            std::iter::repeat_n(Some(false), 5).chain(std::iter::repeat_n(Some(true), 5)),
        ));
        for expression in expressions {
            let logical_expr = planner.parse_filter(expression).unwrap();
            let logical_expr = planner.optimize_expr(logical_expr).unwrap();
            let physical_expr = planner.create_physical_expr(&logical_expr).unwrap();

            let result = physical_expr.evaluate(&batch).unwrap();
            let result = result.into_array(batch.num_rows()).unwrap();
            assert_eq!(&expected, &result, "unexpected result for {}", expression);
        }
    }
1712
1713 #[test]
1714 fn test_columns_in_expr() {
1715 let expr = col("s0").gt(lit("value")).and(
1716 col("st")
1717 .field("st")
1718 .field("s2")
1719 .eq(lit("value"))
1720 .or(col("st")
1721 .field("s1")
1722 .in_list(vec![lit("value 1"), lit("value 2")], false)),
1723 );
1724
1725 let columns = Planner::column_names_in_expr(&expr);
1726 assert_eq!(columns, vec!["s0", "st.s1", "st.st.s2"]);
1727 }
1728
1729 #[test]
1730 fn test_parse_binary_expr() {
1731 let bin_str = "x'616263'";
1732
1733 let schema = Arc::new(Schema::new(vec![Field::new(
1734 "binary",
1735 DataType::Binary,
1736 true,
1737 )]));
1738 let planner = Planner::new(schema);
1739 let expr = planner.parse_expr(bin_str).unwrap();
1740 assert_eq!(
1741 expr,
1742 Expr::Literal(ScalarValue::Binary(Some(vec![b'a', b'b', b'c'])), None)
1743 );
1744 }
1745
1746 #[test]
1747 fn test_lance_context_provider_expr_planners() {
1748 let ctx_provider = LanceContextProvider::default();
1749 assert!(!ctx_provider.get_expr_planners().is_empty());
1750 }
1751
    #[test]
    fn test_regexp_match_and_non_empty_captions() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("keywords", DataType::Utf8, true),
            Field::new("natural_caption", DataType::Utf8, true),
            Field::new("poetic_caption", DataType::Utf8, true),
        ]));

        let planner = Planner::new(schema.clone());

        // regexp_match combined with boolean predicates; the planner must make
        // the regexp_match result usable as a boolean operand of AND
        // (presumably via the filter's boolean coercion — see the companion
        // test below).
        let expr = planner
            .parse_filter(
                "regexp_match(keywords, 'Liberty|revolution') AND \
                (natural_caption IS NOT NULL AND natural_caption <> '' AND \
                poetic_caption IS NOT NULL AND poetic_caption <> '')",
            )
            .unwrap();

        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        // Row-by-row: only row 0 has both a regex match and two non-null,
        // non-empty captions. Rows 2 and 4 match the regex but have a null
        // caption; row 3 matches but has an empty caption; rows 1 and 5 do
        // not match the regex at all.
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(vec![
                    Some("Liberty for all"),
                    Some("peace"),
                    Some("revolution now"),
                    Some("Liberty"),
                    Some("revolutionary"),
                    Some("none"),
                ])) as ArrayRef,
                Arc::new(StringArray::from(vec![
                    Some("a"),
                    Some("b"),
                    None,
                    Some(""),
                    Some("c"),
                    Some("d"),
                ])) as ArrayRef,
                Arc::new(StringArray::from(vec![
                    Some("x"),
                    Some(""),
                    Some("y"),
                    Some("z"),
                    None,
                    Some("w"),
                ])) as ArrayRef,
            ],
        )
        .unwrap();

        let result = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            result.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![true, false, false, false, false, false])
        );
    }
1811
    #[test]
    fn test_regexp_match_infer_error_without_boolean_coercion() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("keywords", DataType::Utf8, true),
            Field::new("natural_caption", DataType::Utf8, true),
            Field::new("poetic_caption", DataType::Utf8, true),
        ]));

        let planner = Planner::new(schema);

        let expr = planner
            .parse_filter(
                "regexp_match(keywords, 'Liberty|revolution') AND \
                (natural_caption IS NOT NULL AND natural_caption <> '' AND \
                poetic_caption IS NOT NULL AND poetic_caption <> '')",
            )
            .unwrap();

        // NOTE(review): the function name suggests an inference *error* is
        // expected without boolean coercion, yet the body asserts physical
        // planning succeeds — presumably because the planner applies the
        // coercion itself (coerce_filter_type_to_boolean is imported at the
        // top of this file). Confirm the intent and consider renaming.
        let _physical = planner.create_physical_expr(&expr).unwrap();
    }
1835}