1use std::borrow::Cow;
7use std::collections::{BTreeSet, VecDeque};
8use std::sync::Arc;
9
10use crate::exec::{LanceExecutionOptions, get_session_context};
11use crate::expr::safe_coerce_scalar;
12use crate::logical_expr::{coerce_filter_type_to_boolean, get_as_string_scalar_opt, resolve_expr};
13use crate::sql::{parse_sql_expr, parse_sql_filter};
14use arrow::compute::CastOptions;
15use arrow_array::ListArray;
16use arrow_buffer::OffsetBuffer;
17use arrow_schema::{DataType as ArrowDataType, Field, SchemaRef, TimeUnit};
18use arrow_select::concat::concat;
19use datafusion::common::DFSchema;
20use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
21use datafusion::config::ConfigOptions;
22use datafusion::error::Result as DFResult;
23use datafusion::execution::context::SessionState;
24use datafusion::logical_expr::expr::ScalarFunction;
25use datafusion::logical_expr::planner::{ExprPlanner, PlannerResult, RawFieldAccessExpr};
26use datafusion::logical_expr::{
27 AggregateUDF, ColumnarValue, GetFieldAccess, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
28 Signature, Volatility, WindowUDF,
29};
30use datafusion::optimizer::simplify_expressions::SimplifyContext;
31use datafusion::sql::planner::{
32 ContextProvider, NullOrdering, ParserOptions, PlannerContext, SqlToRel,
33};
34use datafusion::sql::sqlparser::ast::{
35 AccessExpr, Array as SQLArray, BinaryOperator, DataType as SQLDataType, ExactNumberInfo,
36 Expr as SQLExpr, Function, FunctionArg, FunctionArgExpr, FunctionArguments, Ident,
37 ObjectNamePart, Subscript, TimezoneInfo, TypedString, UnaryOperator, Value, ValueWithSpan,
38};
39use datafusion::{
40 common::Column,
41 logical_expr::{Between, BinaryExpr, Like, Operator},
42 physical_expr::execution_props::ExecutionProps,
43 physical_plan::PhysicalExpr,
44 prelude::Expr,
45 scalar::ScalarValue,
46};
47use datafusion_functions::core::getfield::GetFieldFunc;
48use lance_arrow::cast::cast_with_options;
49use lance_core::datatypes::Schema;
50use lance_core::error::LanceOptionExt;
51
52use chrono::Utc;
53use lance_core::{Error, Result};
54
55fn encode_jsonb(json_str: &str) -> Result<Expr> {
57 let bytes = lance_arrow::json::encode_json(json_str)
58 .map_err(|e| Error::invalid_input(format!("Failed to encode JSONB: {e}")))?;
59 Ok(Expr::Literal(ScalarValue::LargeBinary(Some(bytes)), None))
60}
61
62#[derive(Debug, Clone, Eq, PartialEq, Hash)]
63struct CastListF16Udf {
64 signature: Signature,
65}
66
67impl CastListF16Udf {
68 pub fn new() -> Self {
69 Self {
70 signature: Signature::any(1, Volatility::Immutable),
71 }
72 }
73}
74
75impl ScalarUDFImpl for CastListF16Udf {
76 fn as_any(&self) -> &dyn std::any::Any {
77 self
78 }
79
80 fn name(&self) -> &str {
81 "_cast_list_f16"
82 }
83
84 fn signature(&self) -> &Signature {
85 &self.signature
86 }
87
88 fn return_type(&self, arg_types: &[ArrowDataType]) -> DFResult<ArrowDataType> {
89 let input = &arg_types[0];
90 match input {
91 ArrowDataType::FixedSizeList(field, size) => {
92 if field.data_type() != &ArrowDataType::Float32
93 && field.data_type() != &ArrowDataType::Float16
94 {
95 return Err(datafusion::error::DataFusionError::Execution(
96 "cast_list_f16 only supports list of float32 or float16".to_string(),
97 ));
98 }
99 Ok(ArrowDataType::FixedSizeList(
100 Arc::new(Field::new(
101 field.name(),
102 ArrowDataType::Float16,
103 field.is_nullable(),
104 )),
105 *size,
106 ))
107 }
108 ArrowDataType::List(field) => {
109 if field.data_type() != &ArrowDataType::Float32
110 && field.data_type() != &ArrowDataType::Float16
111 {
112 return Err(datafusion::error::DataFusionError::Execution(
113 "cast_list_f16 only supports list of float32 or float16".to_string(),
114 ));
115 }
116 Ok(ArrowDataType::List(Arc::new(Field::new(
117 field.name(),
118 ArrowDataType::Float16,
119 field.is_nullable(),
120 ))))
121 }
122 _ => Err(datafusion::error::DataFusionError::Execution(
123 "cast_list_f16 only supports FixedSizeList/List arguments".to_string(),
124 )),
125 }
126 }
127
128 fn invoke_with_args(&self, func_args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
129 let ColumnarValue::Array(arr) = &func_args.args[0] else {
130 return Err(datafusion::error::DataFusionError::Execution(
131 "cast_list_f16 only supports array arguments".to_string(),
132 ));
133 };
134
135 let to_type = match arr.data_type() {
136 ArrowDataType::FixedSizeList(field, size) => ArrowDataType::FixedSizeList(
137 Arc::new(Field::new(
138 field.name(),
139 ArrowDataType::Float16,
140 field.is_nullable(),
141 )),
142 *size,
143 ),
144 ArrowDataType::List(field) => ArrowDataType::List(Arc::new(Field::new(
145 field.name(),
146 ArrowDataType::Float16,
147 field.is_nullable(),
148 ))),
149 _ => {
150 return Err(datafusion::error::DataFusionError::Execution(
151 "cast_list_f16 only supports array arguments".to_string(),
152 ));
153 }
154 };
155
156 let res = cast_with_options(arr.as_ref(), &to_type, &CastOptions::default())?;
157 Ok(ColumnarValue::Array(res))
158 }
159}
160
/// `ContextProvider` backed by Lance's default DataFusion session, exposing
/// its registered UDFs and expression planners to the SQL planner.
struct LanceContextProvider {
    // Configuration returned from `ContextProvider::options`.
    options: datafusion::config::ConfigOptions,
    // Session state used to look up scalar/aggregate/window functions.
    state: SessionState,
    // Expression planners captured from the session state at construction.
    expr_planners: Vec<Arc<dyn ExprPlanner>>,
}
167
168impl Default for LanceContextProvider {
169 fn default() -> Self {
170 let ctx = get_session_context(&LanceExecutionOptions::default());
171 let state = ctx.state();
172 let expr_planners = state.expr_planners().to_vec();
173
174 Self {
175 options: ConfigOptions::default(),
176 state,
177 expr_planners,
178 }
179 }
180}
181
182impl ContextProvider for LanceContextProvider {
183 fn get_table_source(
184 &self,
185 name: datafusion::sql::TableReference,
186 ) -> DFResult<Arc<dyn datafusion::logical_expr::TableSource>> {
187 Err(datafusion::error::DataFusionError::NotImplemented(format!(
188 "Attempt to reference inner table {} not supported",
189 name
190 )))
191 }
192
193 fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
194 self.state.aggregate_functions().get(name).cloned()
195 }
196
197 fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
198 self.state.window_functions().get(name).cloned()
199 }
200
201 fn get_function_meta(&self, f: &str) -> Option<Arc<ScalarUDF>> {
202 match f {
203 "_cast_list_f16" => Some(Arc::new(ScalarUDF::new_from_impl(CastListF16Udf::new()))),
206 _ => self.state.scalar_functions().get(f).cloned(),
207 }
208 }
209
210 fn get_variable_type(&self, _: &[String]) -> Option<ArrowDataType> {
211 None
213 }
214
215 fn options(&self) -> &datafusion::config::ConfigOptions {
216 &self.options
217 }
218
219 fn udf_names(&self) -> Vec<String> {
220 self.state.scalar_functions().keys().cloned().collect()
221 }
222
223 fn udaf_names(&self) -> Vec<String> {
224 self.state.aggregate_functions().keys().cloned().collect()
225 }
226
227 fn udwf_names(&self) -> Vec<String> {
228 self.state.window_functions().keys().cloned().collect()
229 }
230
231 fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
232 &self.expr_planners
233 }
234}
235
/// Translates SQL filter/expression strings into DataFusion logical and
/// physical expressions bound to a fixed Arrow schema.
pub struct Planner {
    // Arrow schema that column references are resolved against.
    schema: SchemaRef,
    // Supplies UDF lookup and expression planners during SQL planning.
    context_provider: LanceContextProvider,
    // When true, a leading identifier in a compound reference may name a
    // relation (e.g. `t.col`) instead of a nested struct field.
    enable_relations: bool,
}
241
242impl Planner {
243 pub fn new(schema: SchemaRef) -> Self {
244 Self {
245 schema,
246 context_provider: LanceContextProvider::default(),
247 enable_relations: false,
248 }
249 }
250
251 pub fn with_enable_relations(mut self, enable_relations: bool) -> Self {
257 self.enable_relations = enable_relations;
258 self
259 }
260
261 fn resolve_column_name(&self, name: &str) -> String {
264 if self.schema.field_with_name(name).is_ok() {
266 return name.to_string();
267 }
268 for field in self.schema.fields() {
270 if field.name().eq_ignore_ascii_case(name) {
271 return field.name().clone();
272 }
273 }
274 name.to_string()
276 }
277
278 fn column(&self, idents: &[Ident]) -> Expr {
279 fn handle_remaining_idents(expr: &mut Expr, idents: &[Ident]) {
280 for ident in idents {
281 *expr = Expr::ScalarFunction(ScalarFunction {
282 args: vec![
283 std::mem::take(expr),
284 Expr::Literal(ScalarValue::Utf8(Some(ident.value.clone())), None),
285 ],
286 func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
287 });
288 }
289 }
290
291 if self.enable_relations && idents.len() > 1 {
292 let relation = &idents[0].value;
294 let column_name = self.resolve_column_name(&idents[1].value);
295 let column = Expr::Column(Column::new(Some(relation.clone()), column_name));
296 let mut result = column;
297 handle_remaining_idents(&mut result, &idents[2..]);
298 result
299 } else {
300 let resolved_name = self.resolve_column_name(&idents[0].value);
303 let mut column = Expr::Column(Column::from_name(resolved_name));
304 handle_remaining_idents(&mut column, &idents[1..]);
305 column
306 }
307 }
308
309 fn binary_op(&self, op: &BinaryOperator) -> Result<Operator> {
310 Ok(match op {
311 BinaryOperator::Plus => Operator::Plus,
312 BinaryOperator::Minus => Operator::Minus,
313 BinaryOperator::Multiply => Operator::Multiply,
314 BinaryOperator::Divide => Operator::Divide,
315 BinaryOperator::Modulo => Operator::Modulo,
316 BinaryOperator::StringConcat => Operator::StringConcat,
317 BinaryOperator::Gt => Operator::Gt,
318 BinaryOperator::Lt => Operator::Lt,
319 BinaryOperator::GtEq => Operator::GtEq,
320 BinaryOperator::LtEq => Operator::LtEq,
321 BinaryOperator::Eq => Operator::Eq,
322 BinaryOperator::NotEq => Operator::NotEq,
323 BinaryOperator::And => Operator::And,
324 BinaryOperator::Or => Operator::Or,
325 _ => {
326 return Err(Error::invalid_input(format!(
327 "Operator {op} is not supported"
328 )));
329 }
330 })
331 }
332
333 fn binary_expr(&self, left: &SQLExpr, op: &BinaryOperator, right: &SQLExpr) -> Result<Expr> {
334 Ok(Expr::BinaryExpr(BinaryExpr::new(
335 Box::new(self.parse_sql_expr(left)?),
336 self.binary_op(op)?,
337 Box::new(self.parse_sql_expr(right)?),
338 )))
339 }
340
341 fn unary_expr(&self, op: &UnaryOperator, expr: &SQLExpr) -> Result<Expr> {
342 Ok(match op {
343 UnaryOperator::Not | UnaryOperator::PGBitwiseNot => {
344 Expr::Not(Box::new(self.parse_sql_expr(expr)?))
345 }
346
347 UnaryOperator::Minus => {
348 use datafusion::logical_expr::lit;
349 match expr {
350 SQLExpr::Value(ValueWithSpan { value: Value::Number(n, _), ..}) => match n.parse::<i64>() {
351 Ok(n) => lit(-n),
352 Err(_) => lit(-n
353 .parse::<f64>()
354 .map_err(|_e| {
355 Error::invalid_input(format!("negative operator can be only applied to integer and float operands, got: {n}"))
356 })?),
357 },
358 _ => {
359 Expr::Negative(Box::new(self.parse_sql_expr(expr)?))
360 }
361 }
362 }
363
364 _ => {
365 return Err(Error::invalid_input(format!(
366 "Unary operator '{:?}' is not supported",
367 op
368 )));
369 }
370 })
371 }
372
373 fn number(&self, value: &str, negative: bool) -> Result<Expr> {
375 use datafusion::logical_expr::lit;
376 let value: Cow<str> = if negative {
377 Cow::Owned(format!("-{}", value))
378 } else {
379 Cow::Borrowed(value)
380 };
381 if let Ok(n) = value.parse::<i64>() {
382 Ok(lit(n))
383 } else {
384 value.parse::<f64>().map(lit).map_err(|_| {
385 Error::invalid_input(format!("'{value}' is not supported number value."))
386 })
387 }
388 }
389
390 fn value(&self, value: &Value) -> Result<Expr> {
391 Ok(match value {
392 Value::Number(v, _) => self.number(v.as_str(), false)?,
393 Value::SingleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone())), None),
394 Value::HexStringLiteral(hsl) => {
395 Expr::Literal(ScalarValue::Binary(Self::try_decode_hex_literal(hsl)), None)
396 }
397 Value::DoubleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone())), None),
398 Value::Boolean(v) => Expr::Literal(ScalarValue::Boolean(Some(*v)), None),
399 Value::Null => Expr::Literal(ScalarValue::Null, None),
400 _ => todo!(),
401 })
402 }
403
404 fn parse_function_args(&self, func_args: &FunctionArg) -> Result<Expr> {
405 match func_args {
406 FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => self.parse_sql_expr(expr),
407 _ => Err(Error::invalid_input(format!(
408 "Unsupported function args: {:?}",
409 func_args
410 ))),
411 }
412 }
413
414 fn legacy_parse_function(&self, func: &Function) -> Result<Expr> {
421 match &func.args {
422 FunctionArguments::List(args) => {
423 if func.name.0.len() != 1 {
424 return Err(Error::invalid_input(format!(
425 "Function name must have 1 part, got: {:?}",
426 func.name.0
427 )));
428 }
429 Ok(Expr::IsNotNull(Box::new(
430 self.parse_function_args(&args.args[0])?,
431 )))
432 }
433 _ => Err(Error::invalid_input(format!(
434 "Unsupported function args: {:?}",
435 &func.args
436 ))),
437 }
438 }
439
    /// Plan a SQL function-call expression.
    ///
    /// The legacy `is_valid(col)` helper is special-cased (rewritten to
    /// `IS NOT NULL`); every other function is delegated to DataFusion's
    /// SQL-to-expression planner. Ident normalization is disabled so column
    /// names keep their original case.
    fn parse_function(&self, function: SQLExpr) -> Result<Expr> {
        if let SQLExpr::Function(function) = &function
            && let Some(ObjectNamePart::Identifier(name)) = &function.name.0.first()
            && &name.value == "is_valid"
        {
            return self.legacy_parse_function(function);
        }
        let sql_to_rel = SqlToRel::new_with_options(
            &self.context_provider,
            ParserOptions {
                parse_float_as_decimal: false,
                enable_ident_normalization: false,
                support_varchar_with_length: false,
                enable_options_value_normalization: false,
                collect_spans: false,
                map_string_types_to_utf8view: false,
                default_null_ordering: NullOrdering::NullsMax,
            },
        );

        let mut planner_context = PlannerContext::default();
        let schema = DFSchema::try_from(self.schema.as_ref().clone())?;
        sql_to_rel
            .sql_to_expr(function, &schema, &mut planner_context)
            .map_err(|e| Error::invalid_input(format!("Error parsing function: {e}")))
    }
466
467 fn parse_type(&self, data_type: &SQLDataType) -> Result<ArrowDataType> {
468 const SUPPORTED_TYPES: [&str; 13] = [
469 "int [unsigned]",
470 "tinyint [unsigned]",
471 "smallint [unsigned]",
472 "bigint [unsigned]",
473 "float",
474 "double",
475 "string",
476 "binary",
477 "date",
478 "timestamp(precision)",
479 "datetime(precision)",
480 "decimal(precision,scale)",
481 "boolean",
482 ];
483 match data_type {
484 SQLDataType::String(_) => Ok(ArrowDataType::Utf8),
485 SQLDataType::Binary(_) => Ok(ArrowDataType::Binary),
486 SQLDataType::Float(_) => Ok(ArrowDataType::Float32),
487 SQLDataType::Double(_) => Ok(ArrowDataType::Float64),
488 SQLDataType::Boolean => Ok(ArrowDataType::Boolean),
489 SQLDataType::TinyInt(_) => Ok(ArrowDataType::Int8),
490 SQLDataType::SmallInt(_) => Ok(ArrowDataType::Int16),
491 SQLDataType::Int(_) | SQLDataType::Integer(_) => Ok(ArrowDataType::Int32),
492 SQLDataType::BigInt(_) => Ok(ArrowDataType::Int64),
493 SQLDataType::TinyIntUnsigned(_) => Ok(ArrowDataType::UInt8),
494 SQLDataType::SmallIntUnsigned(_) => Ok(ArrowDataType::UInt16),
495 SQLDataType::IntUnsigned(_) | SQLDataType::IntegerUnsigned(_) => {
496 Ok(ArrowDataType::UInt32)
497 }
498 SQLDataType::BigIntUnsigned(_) => Ok(ArrowDataType::UInt64),
499 SQLDataType::Date => Ok(ArrowDataType::Date32),
500 SQLDataType::Timestamp(resolution, tz) => {
501 match tz {
502 TimezoneInfo::None => {}
503 _ => {
504 return Err(Error::invalid_input(
505 "Timezone not supported in timestamp".to_string(),
506 ));
507 }
508 };
509 let time_unit = match resolution {
510 None => TimeUnit::Microsecond,
512 Some(0) => TimeUnit::Second,
513 Some(3) => TimeUnit::Millisecond,
514 Some(6) => TimeUnit::Microsecond,
515 Some(9) => TimeUnit::Nanosecond,
516 _ => {
517 return Err(Error::invalid_input(format!(
518 "Unsupported datetime resolution: {:?}",
519 resolution
520 )));
521 }
522 };
523 Ok(ArrowDataType::Timestamp(time_unit, None))
524 }
525 SQLDataType::Datetime(resolution) => {
526 let time_unit = match resolution {
527 None => TimeUnit::Microsecond,
528 Some(0) => TimeUnit::Second,
529 Some(3) => TimeUnit::Millisecond,
530 Some(6) => TimeUnit::Microsecond,
531 Some(9) => TimeUnit::Nanosecond,
532 _ => {
533 return Err(Error::invalid_input(format!(
534 "Unsupported datetime resolution: {:?}",
535 resolution
536 )));
537 }
538 };
539 Ok(ArrowDataType::Timestamp(time_unit, None))
540 }
541 SQLDataType::Decimal(number_info) => match number_info {
542 ExactNumberInfo::PrecisionAndScale(precision, scale) => {
543 Ok(ArrowDataType::Decimal128(*precision as u8, *scale as i8))
544 }
545 _ => Err(Error::invalid_input(format!(
546 "Must provide precision and scale for decimal: {:?}",
547 number_info
548 ))),
549 },
550 _ => Err(Error::invalid_input(format!(
551 "Unsupported data type: {:?}. Supported types: {:?}",
552 data_type, SUPPORTED_TYPES
553 ))),
554 }
555 }
556
557 fn plan_field_access(&self, mut field_access_expr: RawFieldAccessExpr) -> Result<Expr> {
558 let df_schema = DFSchema::try_from(self.schema.as_ref().clone())?;
559 for planner in self.context_provider.get_expr_planners() {
560 match planner.plan_field_access(field_access_expr, &df_schema)? {
561 PlannerResult::Planned(expr) => return Ok(expr),
562 PlannerResult::Original(expr) => {
563 field_access_expr = expr;
564 }
565 }
566 }
567 Err(Error::invalid_input("Field access could not be planned"))
568 }
569
    /// Convert a parsed SQL AST expression into a DataFusion logical `Expr`.
    ///
    /// This is the recursive core of the planner; each match arm handles one
    /// SQL syntax form. Unsupported forms return an invalid-input error.
    fn parse_sql_expr(&self, expr: &SQLExpr) -> Result<Expr> {
        match expr {
            SQLExpr::Identifier(id) => {
                // Lance convention: double-quoted identifiers are treated as
                // string literals; backtick-quoted identifiers are verbatim
                // column names (no case-insensitive resolution); bare
                // identifiers go through `column` resolution.
                if id.quote_style == Some('"') {
                    Ok(Expr::Literal(
                        ScalarValue::Utf8(Some(id.value.clone())),
                        None,
                    ))
                } else if id.quote_style == Some('`') {
                    Ok(Expr::Column(Column::from_name(id.value.clone())))
                } else {
                    Ok(self.column(vec![id.clone()].as_slice()))
                }
            }
            SQLExpr::CompoundIdentifier(ids) => Ok(self.column(ids.as_slice())),
            SQLExpr::BinaryOp { left, op, right } => self.binary_expr(left, op, right),
            SQLExpr::UnaryOp { op, expr } => self.unary_expr(op, expr),
            SQLExpr::Value(value) => self.value(&value.value),
            // Array literal: every element must be a literal (optionally a
            // negated number); all elements are coerced to one common type.
            SQLExpr::Array(SQLArray { elem, .. }) => {
                let mut values = vec![];

                let array_literal_error = |pos: usize, value: &_| {
                    Err(Error::invalid_input(format!(
                        "Expected a literal value in array, instead got {} at position {}",
                        value, pos
                    )))
                };

                for (pos, expr) in elem.iter().enumerate() {
                    match expr {
                        SQLExpr::Value(value) => {
                            if let Expr::Literal(value, _) = self.value(&value.value)? {
                                values.push(value);
                            } else {
                                return array_literal_error(pos, expr);
                            }
                        }
                        SQLExpr::UnaryOp {
                            op: UnaryOperator::Minus,
                            expr,
                        } => {
                            if let SQLExpr::Value(ValueWithSpan {
                                value: Value::Number(number, _),
                                ..
                            }) = expr.as_ref()
                            {
                                if let Expr::Literal(value, _) = self.number(number, true)? {
                                    values.push(value);
                                } else {
                                    return array_literal_error(pos, expr);
                                }
                            } else {
                                return array_literal_error(pos, expr);
                            }
                        }
                        _ => {
                            return array_literal_error(pos, expr);
                        }
                    }
                }

                // Element type is inferred from the first value; later values
                // are coerced to it. An empty array gets a Null element type.
                let field = if !values.is_empty() {
                    let data_type = values[0].data_type();

                    for value in &mut values {
                        if value.data_type() != data_type {
                            *value = safe_coerce_scalar(value, &data_type).ok_or_else(|| Error::invalid_input(format!("Array expressions must have a consistent datatype. Expected: {}, got: {}", data_type, value.data_type())))?;
                        }
                    }
                    Field::new("item", data_type, true)
                } else {
                    Field::new("item", ArrowDataType::Null, true)
                };

                // Materialize scalars into a single-row ListArray literal.
                let values = values
                    .into_iter()
                    .map(|v| v.to_array().map_err(Error::from))
                    .collect::<Result<Vec<_>>>()?;
                let array_refs = values.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
                let values = concat(&array_refs)?;
                let values = ListArray::try_new(
                    field.into(),
                    OffsetBuffer::from_lengths([values.len()]),
                    values,
                    None,
                )?;

                Ok(Expr::Literal(ScalarValue::List(Arc::new(values)), None))
            }
            // JSONB 'literal': encode the JSON text to JSONB bytes at plan
            // time.
            SQLExpr::TypedString(TypedString {
                data_type: SQLDataType::JSONB,
                value,
                ..
            }) => match &value.value {
                Value::SingleQuotedString(s) | Value::DoubleQuotedString(s) => encode_jsonb(s),
                _ => Err(Error::invalid_input(
                    "Expected a string value for JSONB literal",
                )),
            },
            // Other typed strings (e.g. DATE '2020-01-01') become a CAST of
            // a Utf8 literal to the named type.
            SQLExpr::TypedString(TypedString {
                data_type, value, ..
            }) => {
                let value = value.clone().into_string().expect_ok()?;
                Ok(Expr::Cast(datafusion::logical_expr::Cast {
                    expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some(value)), None)),
                    data_type: self.parse_type(data_type)?,
                }))
            }
            SQLExpr::IsFalse(expr) => Ok(Expr::IsFalse(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotFalse(expr) => Ok(Expr::IsNotFalse(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsTrue(expr) => Ok(Expr::IsTrue(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotTrue(expr) => Ok(Expr::IsNotTrue(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNull(expr) => Ok(Expr::IsNull(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotNull(expr) => Ok(Expr::IsNotNull(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::InList {
                expr,
                list,
                negated,
            } => {
                let value_expr = self.parse_sql_expr(expr)?;
                let list_exprs = list
                    .iter()
                    .map(|e| self.parse_sql_expr(e))
                    .collect::<Result<Vec<_>>>()?;
                Ok(value_expr.in_list(list_exprs, *negated))
            }
            SQLExpr::Nested(inner) => self.parse_sql_expr(inner.as_ref()),
            SQLExpr::Function(_) => self.parse_function(expr.clone()),
            // ILIKE: case-insensitive LIKE (final `true` flag on Like::new).
            SQLExpr::ILike {
                negated,
                expr,
                pattern,
                escape_char,
                any: _,
            } => Ok(Expr::Like(Like::new(
                *negated,
                Box::new(self.parse_sql_expr(expr)?),
                Box::new(self.parse_sql_expr(pattern)?),
                match escape_char {
                    Some(Value::SingleQuotedString(char)) => char.chars().next(),
                    Some(value) => {
                        return Err(Error::invalid_input(format!(
                            "Invalid escape character in LIKE expression. Expected a single character wrapped with single quotes, got {}",
                            value
                        )));
                    }
                    None => None,
                },
                true,
            ))),
            SQLExpr::Like {
                negated,
                expr,
                pattern,
                escape_char,
                any: _,
            } => Ok(Expr::Like(Like::new(
                *negated,
                Box::new(self.parse_sql_expr(expr)?),
                Box::new(self.parse_sql_expr(pattern)?),
                match escape_char {
                    Some(Value::SingleQuotedString(char)) => char.chars().next(),
                    Some(value) => {
                        return Err(Error::invalid_input(format!(
                            "Invalid escape character in LIKE expression. Expected a single character wrapped with single quotes, got {}",
                            value
                        )));
                    }
                    None => None,
                },
                false,
            ))),
            // CAST(x AS JSONB): only string literals are accepted.
            SQLExpr::Cast {
                data_type: SQLDataType::JSONB,
                expr: inner,
                ..
            } => match inner.as_ref() {
                SQLExpr::Value(ValueWithSpan {
                    value: Value::SingleQuotedString(s) | Value::DoubleQuotedString(s),
                    ..
                }) => encode_jsonb(s),
                _ => Err(Error::invalid_input(
                    "CAST to JSONB only supports string literals",
                )),
            },
            // TRY_CAST / SAFE_CAST yield NULL on failure; plain CAST errors.
            SQLExpr::Cast {
                expr,
                data_type,
                kind,
                ..
            } => match kind {
                datafusion::sql::sqlparser::ast::CastKind::TryCast
                | datafusion::sql::sqlparser::ast::CastKind::SafeCast => {
                    Ok(Expr::TryCast(datafusion::logical_expr::TryCast {
                        expr: Box::new(self.parse_sql_expr(expr)?),
                        data_type: self.parse_type(data_type)?,
                    }))
                }
                _ => Ok(Expr::Cast(datafusion::logical_expr::Cast {
                    expr: Box::new(self.parse_sql_expr(expr)?),
                    data_type: self.parse_type(data_type)?,
                })),
            },
            SQLExpr::JsonAccess { .. } => Err(Error::invalid_input("JSON access is not supported")),
            // Dot / subscript chains: string keys become struct field access,
            // any other index becomes list indexing; slices are rejected.
            SQLExpr::CompoundFieldAccess { root, access_chain } => {
                let mut expr = self.parse_sql_expr(root)?;

                for access in access_chain {
                    let field_access = match access {
                        AccessExpr::Dot(SQLExpr::Identifier(Ident { value: s, .. }))
                        | AccessExpr::Subscript(Subscript::Index {
                            index:
                                SQLExpr::Value(ValueWithSpan {
                                    value:
                                        Value::SingleQuotedString(s) | Value::DoubleQuotedString(s),
                                    ..
                                }),
                        }) => GetFieldAccess::NamedStructField {
                            name: ScalarValue::from(s.as_str()),
                        },
                        AccessExpr::Subscript(Subscript::Index { index }) => {
                            let key = Box::new(self.parse_sql_expr(index)?);
                            GetFieldAccess::ListIndex { key }
                        }
                        AccessExpr::Subscript(Subscript::Slice { .. }) => {
                            return Err(Error::invalid_input("Slice subscript is not supported"));
                        }
                        _ => {
                            return Err(Error::invalid_input(
                                "Only dot notation or index access is supported for field access",
                            ));
                        }
                    };

                    // Each step is planned via the registered expr planners.
                    let field_access_expr = RawFieldAccessExpr { expr, field_access };
                    expr = self.plan_field_access(field_access_expr)?;
                }

                Ok(expr)
            }
            SQLExpr::Between {
                expr,
                negated,
                low,
                high,
            } => {
                let expr = self.parse_sql_expr(expr)?;
                let low = self.parse_sql_expr(low)?;
                let high = self.parse_sql_expr(high)?;

                let between = Expr::Between(Between::new(
                    Box::new(expr),
                    *negated,
                    Box::new(low),
                    Box::new(high),
                ));
                Ok(between)
            }
            _ => Err(Error::invalid_input(format!(
                "Expression '{expr}' is not supported SQL in lance"
            ))),
        }
    }
844
845 pub fn parse_filter(&self, filter: &str) -> Result<Expr> {
850 let ast_expr = parse_sql_filter(filter)?;
852 let expr = self.parse_sql_expr(&ast_expr)?;
853 let schema = Schema::try_from(self.schema.as_ref())?;
854 let resolved = resolve_expr(&expr, &schema).map_err(|e| {
855 Error::invalid_input(format!("Error resolving filter expression {filter}: {e}"))
856 })?;
857
858 Ok(coerce_filter_type_to_boolean(resolved))
859 }
860
861 pub fn parse_expr(&self, expr: &str) -> Result<Expr> {
866 let resolved_name = self.resolve_column_name(expr);
869 if self.schema.field_with_name(&resolved_name).is_ok() {
870 return Ok(Expr::Column(Column::from_name(resolved_name)));
871 }
872
873 let ast_expr = parse_sql_expr(expr)?;
875 let expr = self.parse_sql_expr(&ast_expr)?;
876 let schema = Schema::try_from(self.schema.as_ref())?;
877 let resolved = resolve_expr(&expr, &schema)?;
878 Ok(resolved)
879 }
880
881 fn try_decode_hex_literal(s: &str) -> Option<Vec<u8>> {
887 let hex_bytes = s.as_bytes();
888 let mut decoded_bytes = Vec::with_capacity(hex_bytes.len().div_ceil(2));
889
890 let start_idx = hex_bytes.len() % 2;
891 if start_idx > 0 {
892 decoded_bytes.push(Self::try_decode_hex_char(hex_bytes[0])?);
894 }
895
896 for i in (start_idx..hex_bytes.len()).step_by(2) {
897 let high = Self::try_decode_hex_char(hex_bytes[i])?;
898 let low = Self::try_decode_hex_char(hex_bytes[i + 1])?;
899 decoded_bytes.push((high << 4) | low);
900 }
901
902 Some(decoded_bytes)
903 }
904
905 const fn try_decode_hex_char(c: u8) -> Option<u8> {
909 match c {
910 b'A'..=b'F' => Some(c - b'A' + 10),
911 b'a'..=b'f' => Some(c - b'a' + 10),
912 b'0'..=b'9' => Some(c - b'0'),
913 _ => None,
914 }
915 }
916
917 pub fn optimize_expr(&self, expr: Expr) -> Result<Expr> {
919 let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
920
921 let props = ExecutionProps::new().with_query_execution_start_time(Utc::now());
924 let simplify_context = SimplifyContext::new(&props).with_schema(df_schema.clone());
925 let simplifier =
926 datafusion::optimizer::simplify_expressions::ExprSimplifier::new(simplify_context);
927
928 let expr = simplifier.simplify(expr)?;
929 let expr = simplifier.coerce(expr, &df_schema)?;
930
931 Ok(expr)
932 }
933
934 pub fn create_physical_expr(&self, expr: &Expr) -> Result<Arc<dyn PhysicalExpr>> {
936 let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
937 Ok(datafusion::physical_expr::create_physical_expr(
938 expr,
939 df_schema.as_ref(),
940 &Default::default(),
941 )?)
942 }
943
944 pub fn column_names_in_expr(expr: &Expr) -> Vec<String> {
951 let mut visitor = ColumnCapturingVisitor {
952 current_path: VecDeque::new(),
953 columns: BTreeSet::new(),
954 };
955 expr.visit(&mut visitor).unwrap();
956 visitor.columns.into_iter().collect()
957 }
958}
959
/// Tree visitor that records every column path referenced by an expression.
struct ColumnCapturingVisitor {
    // Pending nested-field name parts, accumulated (front-first) while
    // descending through `get_field` calls toward the base column.
    current_path: VecDeque<String>,
    // Completed column paths; BTreeSet keeps them sorted and de-duplicated.
    columns: BTreeSet<String>,
}
965
impl TreeNodeVisitor<'_> for ColumnCapturingVisitor {
    type Node = Expr;

    // Pre-order visit: `get_field(get_field(col, "a"), "b")` is seen
    // outermost-first, so field names are pushed to the FRONT of
    // `current_path`; by the time the base `Expr::Column` is reached the
    // parts are already in column-to-leaf order.
    fn f_down(&mut self, node: &Self::Node) -> DFResult<TreeNodeRecursion> {
        match node {
            Expr::Column(Column { name, .. }) => {
                let mut path = name.clone();
                for part in self.current_path.drain(..) {
                    path.push('.');
                    // Parts containing '.' or '`' are backtick-quoted (with
                    // '`' doubled) so the joined path remains unambiguous.
                    if part.contains('.') || part.contains('`') {
                        let escaped = part.replace('`', "``");
                        path.push('`');
                        path.push_str(&escaped);
                        path.push('`');
                    } else {
                        path.push_str(&part);
                    }
                }
                self.columns.insert(path);
                self.current_path.clear();
            }
            Expr::ScalarFunction(udf) => {
                // Only `get_field` calls with a literal string key extend the
                // pending path; any other function breaks the chain.
                if udf.name() == GetFieldFunc::default().name() {
                    if let Some(name) = get_as_string_scalar_opt(&udf.args[1]) {
                        self.current_path.push_front(name.to_string())
                    } else {
                        self.current_path.clear();
                    }
                } else {
                    self.current_path.clear();
                }
            }
            _ => {
                // Any other node resets partially-collected field parts.
                self.current_path.clear();
            }
        }

        Ok(TreeNodeRecursion::Continue)
    }
}
1011
1012#[cfg(test)]
1013mod tests {
1014
1015 use crate::logical_expr::ExprExt;
1016
1017 use super::*;
1018
1019 use arrow::datatypes::Float64Type;
1020 use arrow_array::{
1021 ArrayRef, BooleanArray, Float32Array, Int32Array, Int64Array, RecordBatch, StringArray,
1022 StructArray, TimestampMicrosecondArray, TimestampMillisecondArray,
1023 TimestampNanosecondArray, TimestampSecondArray,
1024 };
1025 use arrow_schema::{DataType, Fields, Schema};
1026 use datafusion::{
1027 logical_expr::{Cast, col, lit},
1028 prelude::{array_element, get_field},
1029 };
1030 use datafusion_functions::core::expr_ext::FieldAccessor;
1031
    // End-to-end check: a compound filter parses to the expected logical
    // expression (including nested-field and IN-list forms) and, once lowered
    // to a physical expression, evaluates correctly over a record batch.
    #[test]
    fn test_parse_filter_simple() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("i", DataType::Int32, false),
            Field::new("s", DataType::Utf8, true),
            Field::new(
                "st",
                DataType::Struct(Fields::from(vec![
                    Field::new("x", DataType::Float32, false),
                    Field::new("y", DataType::Float32, false),
                ])),
                true,
            ),
        ]));

        let planner = Planner::new(schema.clone());

        let expected = col("i")
            .gt(lit(3_i32))
            .and(col("st").field_newstyle("x").lt_eq(lit(5.0_f32)))
            .and(
                col("s")
                    .eq(lit("str-4"))
                    .or(col("s").in_list(vec![lit("str-4"), lit("str-5")], false)),
            );

        let expr = planner
            .parse_filter("i > 3 AND st.x <= 5.0 AND (s == 'str-4' OR s in ('str-4', 'str-5'))")
            .unwrap();
        assert_eq!(expr, expected);

        let expr = planner
            .parse_filter("i > 3 AND st.x <= 5.0 AND (s = 'str-4' OR s in ('str-4', 'str-5'))")
            .unwrap();

        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
                Arc::new(StringArray::from_iter_values(
                    (0..10).map(|v| format!("str-{}", v)),
                )),
                Arc::new(StructArray::from(vec![
                    (
                        Arc::new(Field::new("x", DataType::Float32, false)),
                        Arc::new(Float32Array::from_iter_values((0..10).map(|v| v as f32)))
                            as ArrayRef,
                    ),
                    (
                        Arc::new(Field::new("y", DataType::Float32, false)),
                        Arc::new(Float32Array::from_iter_values(
                            (0..10).map(|v| (v * 10) as f32),
                        )),
                    ),
                ])),
            ],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, false, true, true, false, false, false, false
            ])
        );
    }
1102
    // Nested struct field references in every supported quoting/subscript
    // syntax must all plan to the same chain of `get_field` calls.
    #[test]
    fn test_nested_col_refs() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("s0", DataType::Utf8, true),
            Field::new(
                "st",
                DataType::Struct(Fields::from(vec![
                    Field::new("s1", DataType::Utf8, true),
                    Field::new(
                        "st",
                        DataType::Struct(Fields::from(vec![Field::new(
                            "s2",
                            DataType::Utf8,
                            true,
                        )])),
                        true,
                    ),
                ])),
                true,
            ),
        ]));

        let planner = Planner::new(schema);

        // Parses `{expr} = 'val'` and asserts the left-hand side matches the
        // expected column/field-access expression.
        fn assert_column_eq(planner: &Planner, expr: &str, expected: &Expr) {
            let expr = planner.parse_filter(&format!("{expr} = 'val'")).unwrap();
            assert!(matches!(
                expr,
                Expr::BinaryExpr(BinaryExpr {
                    left: _,
                    op: Operator::Eq,
                    right: _
                })
            ));
            if let Expr::BinaryExpr(BinaryExpr { left, .. }) = expr {
                assert_eq!(left.as_ref(), expected);
            }
        }

        let expected = Expr::Column(Column::new_unqualified("s0"));
        assert_column_eq(&planner, "s0", &expected);
        assert_column_eq(&planner, "`s0`", &expected);

        let expected = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
            args: vec![
                Expr::Column(Column::new_unqualified("st")),
                Expr::Literal(ScalarValue::Utf8(Some("s1".to_string())), None),
            ],
        });
        assert_column_eq(&planner, "st.s1", &expected);
        assert_column_eq(&planner, "`st`.`s1`", &expected);
        assert_column_eq(&planner, "st.`s1`", &expected);

        let expected = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
            args: vec![
                Expr::ScalarFunction(ScalarFunction {
                    func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
                    args: vec![
                        Expr::Column(Column::new_unqualified("st")),
                        Expr::Literal(ScalarValue::Utf8(Some("st".to_string())), None),
                    ],
                }),
                Expr::Literal(ScalarValue::Utf8(Some("s2".to_string())), None),
            ],
        });

        assert_column_eq(&planner, "st.st.s2", &expected);
        assert_column_eq(&planner, "`st`.`st`.`s2`", &expected);
        assert_column_eq(&planner, "st.st.`s2`", &expected);
        assert_column_eq(&planner, "st['st'][\"s2\"]", &expected);
    }
1176
1177 #[test]
1178 fn test_nested_list_refs() {
1179 let schema = Arc::new(Schema::new(vec![Field::new(
1180 "l",
1181 DataType::List(Arc::new(Field::new(
1182 "item",
1183 DataType::Struct(Fields::from(vec![Field::new("f1", DataType::Utf8, true)])),
1184 true,
1185 ))),
1186 true,
1187 )]));
1188
1189 let planner = Planner::new(schema);
1190
1191 let expected = array_element(col("l"), lit(0_i64));
1192 let expr = planner.parse_expr("l[0]").unwrap();
1193 assert_eq!(expr, expected);
1194
1195 let expected = get_field(array_element(col("l"), lit(0_i64)), "f1");
1196 let expr = planner.parse_expr("l[0]['f1']").unwrap();
1197 assert_eq!(expr, expected);
1198
1199 }
1204
1205 #[test]
1206 fn test_negative_expressions() {
1207 let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
1208
1209 let planner = Planner::new(schema.clone());
1210
1211 let expected = col("x")
1212 .gt(lit(-3_i64))
1213 .and(col("x").lt(-(lit(-5_i64) + lit(3_i64))));
1214
1215 let expr = planner.parse_filter("x > -3 AND x < -(-5 + 3)").unwrap();
1216
1217 assert_eq!(expr, expected);
1218
1219 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1220
1221 let batch = RecordBatch::try_new(
1222 schema,
1223 vec![Arc::new(Int64Array::from_iter_values(-5..5)) as ArrayRef],
1224 )
1225 .unwrap();
1226 let predicates = physical_expr.evaluate(&batch).unwrap();
1227 assert_eq!(
1228 predicates.into_array(0).unwrap().as_ref(),
1229 &BooleanArray::from(vec![
1230 false, false, false, true, true, true, true, false, false, false
1231 ])
1232 );
1233 }
1234
1235 #[test]
1236 fn test_negative_array_expressions() {
1237 let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
1238
1239 let planner = Planner::new(schema);
1240
1241 let expected = Expr::Literal(
1242 ScalarValue::List(Arc::new(
1243 ListArray::from_iter_primitive::<Float64Type, _, _>(vec![Some(
1244 [-1_f64, -2.0, -3.0, -4.0, -5.0].map(Some),
1245 )]),
1246 )),
1247 None,
1248 );
1249
1250 let expr = planner
1251 .parse_expr("[-1.0, -2.0, -3.0, -4.0, -5.0]")
1252 .unwrap();
1253
1254 assert_eq!(expr, expected);
1255 }
1256
1257 #[test]
1258 fn test_sql_like() {
1259 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1260
1261 let planner = Planner::new(schema.clone());
1262
1263 let expected = col("s").like(lit("str-4"));
1264 let expr = planner.parse_filter("s LIKE 'str-4'").unwrap();
1266 assert_eq!(expr, expected);
1267 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1268
1269 let batch = RecordBatch::try_new(
1270 schema,
1271 vec![Arc::new(StringArray::from_iter_values(
1272 (0..10).map(|v| format!("str-{}", v)),
1273 ))],
1274 )
1275 .unwrap();
1276 let predicates = physical_expr.evaluate(&batch).unwrap();
1277 assert_eq!(
1278 predicates.into_array(0).unwrap().as_ref(),
1279 &BooleanArray::from(vec![
1280 false, false, false, false, true, false, false, false, false, false
1281 ])
1282 );
1283 }
1284
1285 #[test]
1286 fn test_not_like() {
1287 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1288
1289 let planner = Planner::new(schema.clone());
1290
1291 let expected = col("s").not_like(lit("str-4"));
1292 let expr = planner.parse_filter("s NOT LIKE 'str-4'").unwrap();
1294 assert_eq!(expr, expected);
1295 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1296
1297 let batch = RecordBatch::try_new(
1298 schema,
1299 vec![Arc::new(StringArray::from_iter_values(
1300 (0..10).map(|v| format!("str-{}", v)),
1301 ))],
1302 )
1303 .unwrap();
1304 let predicates = physical_expr.evaluate(&batch).unwrap();
1305 assert_eq!(
1306 predicates.into_array(0).unwrap().as_ref(),
1307 &BooleanArray::from(vec![
1308 true, true, true, true, false, true, true, true, true, true
1309 ])
1310 );
1311 }
1312
1313 #[test]
1314 fn test_sql_is_in() {
1315 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1316
1317 let planner = Planner::new(schema.clone());
1318
1319 let expected = col("s").in_list(vec![lit("str-4"), lit("str-5")], false);
1320 let expr = planner.parse_filter("s IN ('str-4', 'str-5')").unwrap();
1322 assert_eq!(expr, expected);
1323 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1324
1325 let batch = RecordBatch::try_new(
1326 schema,
1327 vec![Arc::new(StringArray::from_iter_values(
1328 (0..10).map(|v| format!("str-{}", v)),
1329 ))],
1330 )
1331 .unwrap();
1332 let predicates = physical_expr.evaluate(&batch).unwrap();
1333 assert_eq!(
1334 predicates.into_array(0).unwrap().as_ref(),
1335 &BooleanArray::from(vec![
1336 false, false, false, false, true, true, false, false, false, false
1337 ])
1338 );
1339 }
1340
1341 #[test]
1342 fn test_sql_is_null() {
1343 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1344
1345 let planner = Planner::new(schema.clone());
1346
1347 let expected = col("s").is_null();
1348 let expr = planner.parse_filter("s IS NULL").unwrap();
1349 assert_eq!(expr, expected);
1350 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1351
1352 let batch = RecordBatch::try_new(
1353 schema,
1354 vec![Arc::new(StringArray::from_iter((0..10).map(|v| {
1355 if v % 3 == 0 {
1356 Some(format!("str-{}", v))
1357 } else {
1358 None
1359 }
1360 })))],
1361 )
1362 .unwrap();
1363 let predicates = physical_expr.evaluate(&batch).unwrap();
1364 assert_eq!(
1365 predicates.into_array(0).unwrap().as_ref(),
1366 &BooleanArray::from(vec![
1367 false, true, true, false, true, true, false, true, true, false
1368 ])
1369 );
1370
1371 let expr = planner.parse_filter("s IS NOT NULL").unwrap();
1372 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1373 let predicates = physical_expr.evaluate(&batch).unwrap();
1374 assert_eq!(
1375 predicates.into_array(0).unwrap().as_ref(),
1376 &BooleanArray::from(vec![
1377 true, false, false, true, false, false, true, false, false, true,
1378 ])
1379 );
1380 }
1381
1382 #[test]
1383 fn test_sql_invert() {
1384 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Boolean, true)]));
1385
1386 let planner = Planner::new(schema.clone());
1387
1388 let expr = planner.parse_filter("NOT s").unwrap();
1389 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1390
1391 let batch = RecordBatch::try_new(
1392 schema,
1393 vec![Arc::new(BooleanArray::from_iter(
1394 (0..10).map(|v| Some(v % 3 == 0)),
1395 ))],
1396 )
1397 .unwrap();
1398 let predicates = physical_expr.evaluate(&batch).unwrap();
1399 assert_eq!(
1400 predicates.into_array(0).unwrap().as_ref(),
1401 &BooleanArray::from(vec![
1402 false, true, true, false, true, true, false, true, true, false
1403 ])
1404 );
1405 }
1406
    #[test]
    fn test_sql_cast() {
        // Each case pairs a SQL filter containing an explicit CAST with the
        // Arrow type the cast target should map to. The column `x` is given
        // that same type so the comparison itself needs no extra coercion.
        let cases = &[
            (
                "x = cast('2021-01-01 00:00:00' as timestamp)",
                ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
            ),
            (
                // timestamp(0): seconds precision
                "x = cast('2021-01-01 00:00:00' as timestamp(0))",
                ArrowDataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                // timestamp(9): nanoseconds precision
                "x = cast('2021-01-01 00:00:00.123' as timestamp(9))",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            (
                // datetime is accepted as a synonym for timestamp
                "x = cast('2021-01-01 00:00:00.123' as datetime(9))",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            ("x = cast('2021-01-01' as date)", ArrowDataType::Date32),
            (
                "x = cast('1.238' as decimal(9,3))",
                ArrowDataType::Decimal128(9, 3),
            ),
            ("x = cast(1 as float)", ArrowDataType::Float32),
            ("x = cast(1 as double)", ArrowDataType::Float64),
            ("x = cast(1 as tinyint)", ArrowDataType::Int8),
            ("x = cast(1 as smallint)", ArrowDataType::Int16),
            ("x = cast(1 as int)", ArrowDataType::Int32),
            ("x = cast(1 as integer)", ArrowDataType::Int32),
            ("x = cast(1 as bigint)", ArrowDataType::Int64),
            ("x = cast(1 as tinyint unsigned)", ArrowDataType::UInt8),
            ("x = cast(1 as smallint unsigned)", ArrowDataType::UInt16),
            ("x = cast(1 as int unsigned)", ArrowDataType::UInt32),
            ("x = cast(1 as integer unsigned)", ArrowDataType::UInt32),
            ("x = cast(1 as bigint unsigned)", ArrowDataType::UInt64),
            ("x = cast(1 as boolean)", ArrowDataType::Boolean),
            ("x = cast(1 as string)", ArrowDataType::Utf8),
        ];

        for (sql, expected_data_type) in cases {
            let schema = Arc::new(Schema::new(vec![Field::new(
                "x",
                expected_data_type.clone(),
                true,
            )]));
            let planner = Planner::new(schema.clone());
            let expr = planner.parse_filter(sql).unwrap();

            // Recover the source literal from the SQL text itself: the token
            // between "cast(" and " as", with any surrounding quotes removed.
            let expected_value_str = sql
                .split("cast(")
                .nth(1)
                .unwrap()
                .split(" as")
                .next()
                .unwrap();
            let expected_value_str = expected_value_str.trim_matches('\'');

            // The parsed filter must be `x = CAST(<literal> AS <type>)`: the
            // cast wraps the untouched literal rather than being folded away.
            match expr {
                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
                    Expr::Cast(Cast { expr, data_type }) => {
                        match expr.as_ref() {
                            Expr::Literal(ScalarValue::Utf8(Some(value_str)), _) => {
                                assert_eq!(value_str, expected_value_str);
                            }
                            Expr::Literal(ScalarValue::Int64(Some(value)), _) => {
                                // Unquoted numeric sources parse as Int64 literals.
                                assert_eq!(*value, 1);
                            }
                            _ => panic!("Expected cast to be applied to literal"),
                        }
                        assert_eq!(data_type, expected_data_type);
                    }
                    _ => panic!("Expected right to be a cast"),
                },
                _ => panic!("Expected binary expression"),
            }
        }
    }
1487
    #[test]
    fn test_sql_literals() {
        // Typed literals (`timestamp '...'`, `date '...'`, `decimal(p,s) '...'`)
        // should plan as a CAST of the raw string literal to the expected type.
        let cases = &[
            (
                "x = timestamp '2021-01-01 00:00:00'",
                ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
            ),
            (
                // Explicit precision 0 selects seconds.
                "x = timestamp(0) '2021-01-01 00:00:00'",
                ArrowDataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                // Explicit precision 9 selects nanoseconds.
                "x = timestamp(9) '2021-01-01 00:00:00.123'",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            ("x = date '2021-01-01'", ArrowDataType::Date32),
            ("x = decimal(9,3) '1.238'", ArrowDataType::Decimal128(9, 3)),
        ];

        for (sql, expected_data_type) in cases {
            let schema = Arc::new(Schema::new(vec![Field::new(
                "x",
                expected_data_type.clone(),
                true,
            )]));
            let planner = Planner::new(schema.clone());
            let expr = planner.parse_filter(sql).unwrap();

            // The literal text is whatever sits between the first pair of quotes.
            let expected_value_str = sql.split('\'').nth(1).unwrap();

            // Expect `x = CAST('<literal>' AS <type>)` with the string untouched.
            match expr {
                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
                    Expr::Cast(Cast { expr, data_type }) => {
                        match expr.as_ref() {
                            Expr::Literal(ScalarValue::Utf8(Some(value_str)), _) => {
                                assert_eq!(value_str, expected_value_str);
                            }
                            _ => panic!("Expected cast to be applied to literal"),
                        }
                        assert_eq!(data_type, expected_data_type);
                    }
                    _ => panic!("Expected right to be a cast"),
                },
                _ => panic!("Expected binary expression"),
            }
        }
    }
1535
1536 #[test]
1537 fn test_sql_array_literals() {
1538 let cases = [
1539 (
1540 "x = [1, 2, 3]",
1541 ArrowDataType::List(Arc::new(Field::new("item", ArrowDataType::Int64, true))),
1542 ),
1543 (
1544 "x = [1, 2, 3]",
1545 ArrowDataType::FixedSizeList(
1546 Arc::new(Field::new("item", ArrowDataType::Int64, true)),
1547 3,
1548 ),
1549 ),
1550 ];
1551
1552 for (sql, expected_data_type) in cases {
1553 let schema = Arc::new(Schema::new(vec![Field::new(
1554 "x",
1555 expected_data_type.clone(),
1556 true,
1557 )]));
1558 let planner = Planner::new(schema.clone());
1559 let expr = planner.parse_filter(sql).unwrap();
1560 let expr = planner.optimize_expr(expr).unwrap();
1561
1562 match expr {
1563 Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
1564 Expr::Literal(value, _) => {
1565 assert_eq!(&value.data_type(), &expected_data_type);
1566 }
1567 _ => panic!("Expected right to be a literal"),
1568 },
1569 _ => panic!("Expected binary expression"),
1570 }
1571 }
1572 }
1573
1574 #[test]
1575 fn test_sql_between() {
1576 use arrow_array::{Float64Array, Int32Array, TimestampMicrosecondArray};
1577 use arrow_schema::{DataType, Field, Schema, TimeUnit};
1578 use std::sync::Arc;
1579
1580 let schema = Arc::new(Schema::new(vec![
1581 Field::new("x", DataType::Int32, false),
1582 Field::new("y", DataType::Float64, false),
1583 Field::new(
1584 "ts",
1585 DataType::Timestamp(TimeUnit::Microsecond, None),
1586 false,
1587 ),
1588 ]));
1589
1590 let planner = Planner::new(schema.clone());
1591
1592 let expr = planner
1594 .parse_filter("x BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
1595 .unwrap();
1596 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1597
1598 let base_ts = 1704067200000000_i64; let ts_array = TimestampMicrosecondArray::from_iter_values(
1602 (0..10).map(|i| base_ts + i * 1_000_000), );
1604
1605 let batch = RecordBatch::try_new(
1606 schema,
1607 vec![
1608 Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
1609 Arc::new(Float64Array::from_iter_values((0..10).map(|v| v as f64))),
1610 Arc::new(ts_array),
1611 ],
1612 )
1613 .unwrap();
1614
1615 let predicates = physical_expr.evaluate(&batch).unwrap();
1616 assert_eq!(
1617 predicates.into_array(0).unwrap().as_ref(),
1618 &BooleanArray::from(vec![
1619 false, false, false, true, true, true, true, true, false, false
1620 ])
1621 );
1622
1623 let expr = planner
1625 .parse_filter("x NOT BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
1626 .unwrap();
1627 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1628
1629 let predicates = physical_expr.evaluate(&batch).unwrap();
1630 assert_eq!(
1631 predicates.into_array(0).unwrap().as_ref(),
1632 &BooleanArray::from(vec![
1633 true, true, true, false, false, false, false, false, true, true
1634 ])
1635 );
1636
1637 let expr = planner.parse_filter("y BETWEEN 2.5 AND 6.5").unwrap();
1639 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1640
1641 let predicates = physical_expr.evaluate(&batch).unwrap();
1642 assert_eq!(
1643 predicates.into_array(0).unwrap().as_ref(),
1644 &BooleanArray::from(vec![
1645 false, false, false, true, true, true, true, false, false, false
1646 ])
1647 );
1648
1649 let expr = planner
1651 .parse_filter(
1652 "ts BETWEEN timestamp '2024-01-01 00:00:03' AND timestamp '2024-01-01 00:00:07'",
1653 )
1654 .unwrap();
1655 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1656
1657 let predicates = physical_expr.evaluate(&batch).unwrap();
1658 assert_eq!(
1659 predicates.into_array(0).unwrap().as_ref(),
1660 &BooleanArray::from(vec![
1661 false, false, false, true, true, true, true, true, false, false
1662 ])
1663 );
1664 }
1665
    #[test]
    fn test_sql_comparison() {
        // One timestamp column per unit. The second/milli/micro columns hold
        // raw values 0..10 in their own unit; the nanosecond column holds
        // 4995..5005 so its threshold (5000 ns) also splits it 5/5.
        let batch: Vec<(&str, ArrayRef)> = vec![
            (
                "timestamp_s",
                Arc::new(TimestampSecondArray::from_iter_values(0..10)),
            ),
            (
                "timestamp_ms",
                Arc::new(TimestampMillisecondArray::from_iter_values(0..10)),
            ),
            (
                "timestamp_us",
                Arc::new(TimestampMicrosecondArray::from_iter_values(0..10)),
            ),
            (
                "timestamp_ns",
                Arc::new(TimestampNanosecondArray::from_iter_values(4995..5005)),
            ),
        ];
        let batch = RecordBatch::try_from_iter(batch).unwrap();

        let planner = Planner::new(batch.schema());

        // Each filter compares its column against a literal written at that
        // column's resolution (5 s, 5 ms, 5 us, 5000 ns respectively).
        let expressions = &[
            "timestamp_s >= TIMESTAMP '1970-01-01 00:00:05'",
            "timestamp_ms >= TIMESTAMP '1970-01-01 00:00:00.005'",
            "timestamp_us >= TIMESTAMP '1970-01-01 00:00:00.000005'",
            "timestamp_ns >= TIMESTAMP '1970-01-01 00:00:00.000005'",
        ];

        // Every column: first five rows below the threshold, last five at or above.
        let expected: ArrayRef = Arc::new(BooleanArray::from_iter(
            std::iter::repeat_n(Some(false), 5).chain(std::iter::repeat_n(Some(true), 5)),
        ));
        for expression in expressions {
            let logical_expr = planner.parse_filter(expression).unwrap();
            // Optimization may fold literal casts; the result must not change.
            let logical_expr = planner.optimize_expr(logical_expr).unwrap();
            let physical_expr = planner.create_physical_expr(&logical_expr).unwrap();

            let result = physical_expr.evaluate(&batch).unwrap();
            let result = result.into_array(batch.num_rows()).unwrap();
            assert_eq!(&expected, &result, "unexpected result for {}", expression);
        }
    }
1714
1715 #[test]
1716 fn test_columns_in_expr() {
1717 let expr = col("s0").gt(lit("value")).and(
1718 col("st")
1719 .field("st")
1720 .field("s2")
1721 .eq(lit("value"))
1722 .or(col("st")
1723 .field("s1")
1724 .in_list(vec![lit("value 1"), lit("value 2")], false)),
1725 );
1726
1727 let columns = Planner::column_names_in_expr(&expr);
1728 assert_eq!(columns, vec!["s0", "st.s1", "st.st.s2"]);
1729 }
1730
1731 #[test]
1732 fn test_parse_binary_expr() {
1733 let bin_str = "x'616263'";
1734
1735 let schema = Arc::new(Schema::new(vec![Field::new(
1736 "binary",
1737 DataType::Binary,
1738 true,
1739 )]));
1740 let planner = Planner::new(schema);
1741 let expr = planner.parse_expr(bin_str).unwrap();
1742 assert_eq!(
1743 expr,
1744 Expr::Literal(ScalarValue::Binary(Some(vec![b'a', b'b', b'c'])), None)
1745 );
1746 }
1747
1748 #[test]
1749 fn test_lance_context_provider_expr_planners() {
1750 let ctx_provider = LanceContextProvider::default();
1751 assert!(!ctx_provider.get_expr_planners().is_empty());
1752 }
1753
    #[test]
    fn test_regexp_match_and_non_empty_captions() {
        // A regexp_match() call AND-ed with IS NOT NULL / <> '' guards must
        // plan into a single boolean physical filter and evaluate row-wise.
        let schema = Arc::new(Schema::new(vec![
            Field::new("keywords", DataType::Utf8, true),
            Field::new("natural_caption", DataType::Utf8, true),
            Field::new("poetic_caption", DataType::Utf8, true),
        ]));

        let planner = Planner::new(schema.clone());

        let expr = planner
            .parse_filter(
                "regexp_match(keywords, 'Liberty|revolution') AND \
                (natural_caption IS NOT NULL AND natural_caption <> '' AND \
                poetic_caption IS NOT NULL AND poetic_caption <> '')",
            )
            .unwrap();

        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        // Only row 0 matches the regex AND has both captions non-null and
        // non-empty; every other row fails at least one conjunct (row 4
        // matches the regex via the "revolution" prefix but has a null
        // poetic_caption).
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(vec![
                    Some("Liberty for all"),
                    Some("peace"),
                    Some("revolution now"),
                    Some("Liberty"),
                    Some("revolutionary"),
                    Some("none"),
                ])) as ArrayRef,
                Arc::new(StringArray::from(vec![
                    Some("a"),
                    Some("b"),
                    None,
                    Some(""),
                    Some("c"),
                    Some("d"),
                ])) as ArrayRef,
                Arc::new(StringArray::from(vec![
                    Some("x"),
                    Some(""),
                    Some("y"),
                    Some("z"),
                    None,
                    Some("w"),
                ])) as ArrayRef,
            ],
        )
        .unwrap();

        let result = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            result.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![true, false, false, false, false, false])
        );
    }
1813
    #[test]
    fn test_regexp_match_infer_error_without_boolean_coercion() {
        // Same filter as test_regexp_match_and_non_empty_captions, but this
        // test only checks that lowering to a physical expression succeeds.
        // NOTE(review): the name implies inference *would* error without the
        // planner's boolean coercion of regexp_match, yet the body asserts
        // success — presumably because the coercion is in place; confirm the
        // intent against the planner's coercion logic.
        let schema = Arc::new(Schema::new(vec![
            Field::new("keywords", DataType::Utf8, true),
            Field::new("natural_caption", DataType::Utf8, true),
            Field::new("poetic_caption", DataType::Utf8, true),
        ]));

        let planner = Planner::new(schema);

        let expr = planner
            .parse_filter(
                "regexp_match(keywords, 'Liberty|revolution') AND \
                (natural_caption IS NOT NULL AND natural_caption <> '' AND \
                poetic_caption IS NOT NULL AND poetic_caption <> '')",
            )
            .unwrap();

        // Must not panic: the coerced filter lowers cleanly.
        let _physical = planner.create_physical_expr(&expr).unwrap();
    }
1837
    #[test]
    fn test_jsonb_literals() {
        // JSONB values can be written three ways — a typed literal, an
        // explicit CAST, and the `::` cast shorthand — and each should yield
        // a LargeBinary scalar holding the encoded document.
        let schema = Arc::new(Schema::new(vec![Field::new(
            "j",
            DataType::LargeBinary,
            true,
        )]));
        let planner = Planner::new(schema);

        // (input SQL, expected decoded JSON text). Note the expected text has
        // no spaces: the encode/decode round trip normalizes whitespace.
        let cases = [
            ("jsonb '{\"key\": \"value\"}'", r#"{"key":"value"}"#),
            ("cast('{\"a\": 1}' as jsonb)", r#"{"a":1}"#),
            ("'{\"a\": 1}'::jsonb", r#"{"a":1}"#),
        ];
        for (sql, expected) in cases {
            let ast = parse_sql_expr(sql).unwrap();
            let expr = planner.parse_sql_expr(&ast).unwrap();
            match expr {
                Expr::Literal(ScalarValue::LargeBinary(Some(bytes)), _) => {
                    // Decode the JSONB bytes back to text for comparison.
                    assert_eq!(
                        lance_arrow::json::decode_json(&bytes),
                        expected,
                        "failed for: {sql}"
                    );
                }
                other => panic!("Expected LargeBinary literal for '{sql}', got: {other:?}"),
            }
        }
    }
1867
    #[test]
    fn test_jsonb_literal_errors() {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "j",
            DataType::LargeBinary,
            true,
        )]));
        let planner = Planner::new(schema);

        // A typed jsonb literal whose body is not valid JSON fails when the
        // planner tries to encode it.
        let ast = parse_sql_expr("jsonb 'not valid json'").unwrap();
        let err = planner.parse_sql_expr(&ast).unwrap_err();
        assert!(
            err.to_string().contains("Failed to encode JSONB"),
            "expected JSONB encoding error, got: {err}"
        );

        // CAST to JSONB is rejected for non-literal operands (here, column `j`).
        let ast = parse_sql_expr("cast(j as jsonb)").unwrap();
        let err = planner.parse_sql_expr(&ast).unwrap_err();
        assert!(
            err.to_string()
                .contains("CAST to JSONB only supports string literals"),
            "got: {err}"
        );
    }
1894}