1use std::borrow::Cow;
7use std::collections::{BTreeSet, VecDeque};
8use std::sync::Arc;
9
10use crate::exec::{LanceExecutionOptions, get_session_context};
11use crate::expr::safe_coerce_scalar;
12use crate::logical_expr::{coerce_filter_type_to_boolean, get_as_string_scalar_opt, resolve_expr};
13use crate::sql::{parse_sql_expr, parse_sql_filter};
14use arrow::compute::CastOptions;
15use arrow_array::ListArray;
16use arrow_buffer::OffsetBuffer;
17use arrow_cast::cast_with_options;
18use arrow_schema::{DataType as ArrowDataType, Field, SchemaRef, TimeUnit};
19use arrow_select::concat::concat;
20use datafusion::common::DFSchema;
21use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
22use datafusion::config::ConfigOptions;
23use datafusion::error::Result as DFResult;
24use datafusion::execution::context::SessionState;
25use datafusion::logical_expr::expr::ScalarFunction;
26use datafusion::logical_expr::planner::{ExprPlanner, PlannerResult, RawFieldAccessExpr};
27use datafusion::logical_expr::{
28 AggregateUDF, ColumnarValue, GetFieldAccess, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
29 Signature, Volatility, WindowUDF,
30};
31use datafusion::optimizer::simplify_expressions::SimplifyContext;
32use datafusion::sql::planner::{
33 ContextProvider, NullOrdering, ParserOptions, PlannerContext, SqlToRel,
34};
35use datafusion::sql::sqlparser::ast::{
36 AccessExpr, Array as SQLArray, BinaryOperator, DataType as SQLDataType, ExactNumberInfo,
37 Expr as SQLExpr, Function, FunctionArg, FunctionArgExpr, FunctionArguments, Ident,
38 ObjectNamePart, Subscript, TimezoneInfo, TypedString, UnaryOperator, Value, ValueWithSpan,
39};
40use datafusion::{
41 common::Column,
42 logical_expr::{Between, BinaryExpr, Like, Operator},
43 physical_plan::PhysicalExpr,
44 prelude::Expr,
45 scalar::ScalarValue,
46};
47use datafusion_functions::core::getfield::GetFieldFunc;
48use lance_core::datatypes::Schema;
49use lance_core::error::LanceOptionExt;
50
51use chrono::Utc;
52use lance_core::{Error, Result};
53
54fn encode_jsonb(json_str: &str) -> Result<Expr> {
56 let bytes = lance_arrow::json::encode_json(json_str)
57 .map_err(|e| Error::invalid_input(format!("Failed to encode JSONB: {e}")))?;
58 Ok(Expr::Literal(ScalarValue::LargeBinary(Some(bytes)), None))
59}
60
/// Scalar UDF, exposed under the name `_cast_list_f16`, that casts the items of a
/// `List` / `FixedSizeList` array to `Float16`.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
struct CastListF16Udf {
    /// Argument signature reported to DataFusion via `ScalarUDFImpl::signature`.
    signature: Signature,
}
65
66impl CastListF16Udf {
67 pub fn new() -> Self {
68 Self {
69 signature: Signature::any(1, Volatility::Immutable),
70 }
71 }
72}
73
74impl ScalarUDFImpl for CastListF16Udf {
75 fn as_any(&self) -> &dyn std::any::Any {
76 self
77 }
78
79 fn name(&self) -> &str {
80 "_cast_list_f16"
81 }
82
83 fn signature(&self) -> &Signature {
84 &self.signature
85 }
86
87 fn return_type(&self, arg_types: &[ArrowDataType]) -> DFResult<ArrowDataType> {
88 let input = &arg_types[0];
89 match input {
90 ArrowDataType::FixedSizeList(field, size) => {
91 if field.data_type() != &ArrowDataType::Float32
92 && field.data_type() != &ArrowDataType::Float16
93 {
94 return Err(datafusion::error::DataFusionError::Execution(
95 "cast_list_f16 only supports list of float32 or float16".to_string(),
96 ));
97 }
98 Ok(ArrowDataType::FixedSizeList(
99 Arc::new(Field::new(
100 field.name(),
101 ArrowDataType::Float16,
102 field.is_nullable(),
103 )),
104 *size,
105 ))
106 }
107 ArrowDataType::List(field) => {
108 if field.data_type() != &ArrowDataType::Float32
109 && field.data_type() != &ArrowDataType::Float16
110 {
111 return Err(datafusion::error::DataFusionError::Execution(
112 "cast_list_f16 only supports list of float32 or float16".to_string(),
113 ));
114 }
115 Ok(ArrowDataType::List(Arc::new(Field::new(
116 field.name(),
117 ArrowDataType::Float16,
118 field.is_nullable(),
119 ))))
120 }
121 _ => Err(datafusion::error::DataFusionError::Execution(
122 "cast_list_f16 only supports FixedSizeList/List arguments".to_string(),
123 )),
124 }
125 }
126
127 fn invoke_with_args(&self, func_args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
128 let ColumnarValue::Array(arr) = &func_args.args[0] else {
129 return Err(datafusion::error::DataFusionError::Execution(
130 "cast_list_f16 only supports array arguments".to_string(),
131 ));
132 };
133
134 let to_type = match arr.data_type() {
135 ArrowDataType::FixedSizeList(field, size) => ArrowDataType::FixedSizeList(
136 Arc::new(Field::new(
137 field.name(),
138 ArrowDataType::Float16,
139 field.is_nullable(),
140 )),
141 *size,
142 ),
143 ArrowDataType::List(field) => ArrowDataType::List(Arc::new(Field::new(
144 field.name(),
145 ArrowDataType::Float16,
146 field.is_nullable(),
147 ))),
148 _ => {
149 return Err(datafusion::error::DataFusionError::Execution(
150 "cast_list_f16 only supports array arguments".to_string(),
151 ));
152 }
153 };
154
155 let res = cast_with_options(arr.as_ref(), &to_type, &CastOptions::default())?;
156 Ok(ColumnarValue::Array(res))
157 }
158}
159
/// `ContextProvider` implementation used when planning SQL expressions; function lookups
/// are answered from the wrapped session state.
struct LanceContextProvider {
    /// Configuration returned from `ContextProvider::options`.
    options: datafusion::config::ConfigOptions,
    /// Session state that supplies scalar, aggregate and window functions.
    state: SessionState,
    /// Expression planners copied out of `state` when the provider is built.
    expr_planners: Vec<Arc<dyn ExprPlanner>>,
}
166
167impl Default for LanceContextProvider {
168 fn default() -> Self {
169 let ctx = get_session_context(&LanceExecutionOptions::default());
170 let state = ctx.state();
171 let expr_planners = state.expr_planners().to_vec();
172
173 Self {
174 options: ConfigOptions::default(),
175 state,
176 expr_planners,
177 }
178 }
179}
180
181impl ContextProvider for LanceContextProvider {
182 fn get_table_source(
183 &self,
184 name: datafusion::sql::TableReference,
185 ) -> DFResult<Arc<dyn datafusion::logical_expr::TableSource>> {
186 Err(datafusion::error::DataFusionError::NotImplemented(format!(
187 "Attempt to reference inner table {} not supported",
188 name
189 )))
190 }
191
192 fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
193 self.state.aggregate_functions().get(name).cloned()
194 }
195
196 fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
197 self.state.window_functions().get(name).cloned()
198 }
199
200 fn get_function_meta(&self, f: &str) -> Option<Arc<ScalarUDF>> {
201 match f {
202 "_cast_list_f16" => Some(Arc::new(ScalarUDF::new_from_impl(CastListF16Udf::new()))),
205 _ => self.state.scalar_functions().get(f).cloned(),
206 }
207 }
208
209 fn get_variable_type(&self, _: &[String]) -> Option<ArrowDataType> {
210 None
212 }
213
214 fn options(&self) -> &datafusion::config::ConfigOptions {
215 &self.options
216 }
217
218 fn udf_names(&self) -> Vec<String> {
219 self.state.scalar_functions().keys().cloned().collect()
220 }
221
222 fn udaf_names(&self) -> Vec<String> {
223 self.state.aggregate_functions().keys().cloned().collect()
224 }
225
226 fn udwf_names(&self) -> Vec<String> {
227 self.state.window_functions().keys().cloned().collect()
228 }
229
230 fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
231 &self.expr_planners
232 }
233}
234
/// Turns SQL filter / expression strings into DataFusion expressions for a given schema.
pub struct Planner {
    /// Arrow schema that column names are resolved against.
    schema: SchemaRef,
    /// Function and planner lookup used when delegating to DataFusion's SQL planner.
    context_provider: LanceContextProvider,
    /// When true, the first part of a multi-part identifier is treated as a relation name.
    enable_relations: bool,
}
240
241impl Planner {
242 pub fn new(schema: SchemaRef) -> Self {
243 Self {
244 schema,
245 context_provider: LanceContextProvider::default(),
246 enable_relations: false,
247 }
248 }
249
250 pub fn with_enable_relations(mut self, enable_relations: bool) -> Self {
256 self.enable_relations = enable_relations;
257 self
258 }
259
260 fn resolve_column_name(&self, name: &str) -> String {
263 if self.schema.field_with_name(name).is_ok() {
265 return name.to_string();
266 }
267 for field in self.schema.fields() {
269 if field.name().eq_ignore_ascii_case(name) {
270 return field.name().clone();
271 }
272 }
273 name.to_string()
275 }
276
277 fn column(&self, idents: &[Ident]) -> Expr {
278 fn handle_remaining_idents(expr: &mut Expr, idents: &[Ident]) {
279 for ident in idents {
280 *expr = Expr::ScalarFunction(ScalarFunction {
281 args: vec![
282 std::mem::take(expr),
283 Expr::Literal(ScalarValue::Utf8(Some(ident.value.clone())), None),
284 ],
285 func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
286 });
287 }
288 }
289
290 if self.enable_relations && idents.len() > 1 {
291 let relation = &idents[0].value;
293 let column_name = self.resolve_column_name(&idents[1].value);
294 let column = Expr::Column(Column::new(Some(relation.clone()), column_name));
295 let mut result = column;
296 handle_remaining_idents(&mut result, &idents[2..]);
297 result
298 } else {
299 let resolved_name = self.resolve_column_name(&idents[0].value);
302 let mut column = Expr::Column(Column::from_name(resolved_name));
303 handle_remaining_idents(&mut column, &idents[1..]);
304 column
305 }
306 }
307
308 fn binary_op(&self, op: &BinaryOperator) -> Result<Operator> {
309 Ok(match op {
310 BinaryOperator::Plus => Operator::Plus,
311 BinaryOperator::Minus => Operator::Minus,
312 BinaryOperator::Multiply => Operator::Multiply,
313 BinaryOperator::Divide => Operator::Divide,
314 BinaryOperator::Modulo => Operator::Modulo,
315 BinaryOperator::StringConcat => Operator::StringConcat,
316 BinaryOperator::Gt => Operator::Gt,
317 BinaryOperator::Lt => Operator::Lt,
318 BinaryOperator::GtEq => Operator::GtEq,
319 BinaryOperator::LtEq => Operator::LtEq,
320 BinaryOperator::Eq => Operator::Eq,
321 BinaryOperator::NotEq => Operator::NotEq,
322 BinaryOperator::And => Operator::And,
323 BinaryOperator::Or => Operator::Or,
324 _ => {
325 return Err(Error::invalid_input(format!(
326 "Operator {op} is not supported"
327 )));
328 }
329 })
330 }
331
332 fn binary_expr(&self, left: &SQLExpr, op: &BinaryOperator, right: &SQLExpr) -> Result<Expr> {
333 Ok(Expr::BinaryExpr(BinaryExpr::new(
334 Box::new(self.parse_sql_expr(left)?),
335 self.binary_op(op)?,
336 Box::new(self.parse_sql_expr(right)?),
337 )))
338 }
339
340 fn unary_expr(&self, op: &UnaryOperator, expr: &SQLExpr) -> Result<Expr> {
341 Ok(match op {
342 UnaryOperator::Not | UnaryOperator::BitwiseNot => {
343 Expr::Not(Box::new(self.parse_sql_expr(expr)?))
344 }
345
346 UnaryOperator::Minus => {
347 use datafusion::logical_expr::lit;
348 match expr {
349 SQLExpr::Value(ValueWithSpan { value: Value::Number(n, _), ..}) => match n.parse::<i64>() {
350 Ok(n) => lit(-n),
351 Err(_) => lit(-n
352 .parse::<f64>()
353 .map_err(|_e| {
354 Error::invalid_input(format!("negative operator can be only applied to integer and float operands, got: {n}"))
355 })?),
356 },
357 _ => {
358 Expr::Negative(Box::new(self.parse_sql_expr(expr)?))
359 }
360 }
361 }
362
363 _ => {
364 return Err(Error::invalid_input(format!(
365 "Unary operator '{:?}' is not supported",
366 op
367 )));
368 }
369 })
370 }
371
372 fn number(&self, value: &str, negative: bool) -> Result<Expr> {
374 use datafusion::logical_expr::lit;
375 let value: Cow<str> = if negative {
376 Cow::Owned(format!("-{}", value))
377 } else {
378 Cow::Borrowed(value)
379 };
380 if let Ok(n) = value.parse::<i64>() {
381 Ok(lit(n))
382 } else {
383 value.parse::<f64>().map(lit).map_err(|_| {
384 Error::invalid_input(format!("'{value}' is not supported number value."))
385 })
386 }
387 }
388
389 fn value(&self, value: &Value) -> Result<Expr> {
390 Ok(match value {
391 Value::Number(v, _) => self.number(v.as_str(), false)?,
392 Value::SingleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone())), None),
393 Value::HexStringLiteral(hsl) => {
394 Expr::Literal(ScalarValue::Binary(Self::try_decode_hex_literal(hsl)), None)
395 }
396 Value::DoubleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone())), None),
397 Value::Boolean(v) => Expr::Literal(ScalarValue::Boolean(Some(*v)), None),
398 Value::Null => Expr::Literal(ScalarValue::Null, None),
399 _ => todo!(),
400 })
401 }
402
403 fn parse_function_args(&self, func_args: &FunctionArg) -> Result<Expr> {
404 match func_args {
405 FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => self.parse_sql_expr(expr),
406 _ => Err(Error::invalid_input(format!(
407 "Unsupported function args: {:?}",
408 func_args
409 ))),
410 }
411 }
412
413 fn legacy_parse_function(&self, func: &Function) -> Result<Expr> {
420 match &func.args {
421 FunctionArguments::List(args) => {
422 if func.name.0.len() != 1 {
423 return Err(Error::invalid_input(format!(
424 "Function name must have 1 part, got: {:?}",
425 func.name.0
426 )));
427 }
428 Ok(Expr::IsNotNull(Box::new(
429 self.parse_function_args(&args.args[0])?,
430 )))
431 }
432 _ => Err(Error::invalid_input(format!(
433 "Unsupported function args: {:?}",
434 &func.args
435 ))),
436 }
437 }
438
439 fn parse_function(&self, function: SQLExpr) -> Result<Expr> {
440 if let SQLExpr::Function(function) = &function
441 && let Some(ObjectNamePart::Identifier(name)) = &function.name.0.first()
442 && &name.value == "is_valid"
443 {
444 return self.legacy_parse_function(function);
445 }
446 let sql_to_rel = SqlToRel::new_with_options(
447 &self.context_provider,
448 ParserOptions {
449 parse_float_as_decimal: false,
450 enable_ident_normalization: false,
451 support_varchar_with_length: false,
452 enable_options_value_normalization: false,
453 collect_spans: false,
454 map_string_types_to_utf8view: false,
455 default_null_ordering: NullOrdering::NullsMax,
456 },
457 );
458
459 let mut planner_context = PlannerContext::default();
460 let schema = DFSchema::try_from(self.schema.as_ref().clone())?;
461 sql_to_rel
462 .sql_to_expr(function, &schema, &mut planner_context)
463 .map_err(|e| Error::invalid_input(format!("Error parsing function: {e}")))
464 }
465
466 fn parse_type(&self, data_type: &SQLDataType) -> Result<ArrowDataType> {
467 const SUPPORTED_TYPES: [&str; 13] = [
468 "int [unsigned]",
469 "tinyint [unsigned]",
470 "smallint [unsigned]",
471 "bigint [unsigned]",
472 "float",
473 "double",
474 "string",
475 "binary",
476 "date",
477 "timestamp(precision)",
478 "datetime(precision)",
479 "decimal(precision,scale)",
480 "boolean",
481 ];
482 match data_type {
483 SQLDataType::String(_) => Ok(ArrowDataType::Utf8),
484 SQLDataType::Binary(_) => Ok(ArrowDataType::Binary),
485 SQLDataType::Float(_) => Ok(ArrowDataType::Float32),
486 SQLDataType::Double(_) => Ok(ArrowDataType::Float64),
487 SQLDataType::Boolean => Ok(ArrowDataType::Boolean),
488 SQLDataType::TinyInt(_) => Ok(ArrowDataType::Int8),
489 SQLDataType::SmallInt(_) => Ok(ArrowDataType::Int16),
490 SQLDataType::Int(_) | SQLDataType::Integer(_) => Ok(ArrowDataType::Int32),
491 SQLDataType::BigInt(_) => Ok(ArrowDataType::Int64),
492 SQLDataType::TinyIntUnsigned(_) => Ok(ArrowDataType::UInt8),
493 SQLDataType::SmallIntUnsigned(_) => Ok(ArrowDataType::UInt16),
494 SQLDataType::IntUnsigned(_) | SQLDataType::IntegerUnsigned(_) => {
495 Ok(ArrowDataType::UInt32)
496 }
497 SQLDataType::BigIntUnsigned(_) => Ok(ArrowDataType::UInt64),
498 SQLDataType::Date => Ok(ArrowDataType::Date32),
499 SQLDataType::Timestamp(resolution, tz) => {
500 match tz {
501 TimezoneInfo::None => {}
502 _ => {
503 return Err(Error::invalid_input(
504 "Timezone not supported in timestamp".to_string(),
505 ));
506 }
507 };
508 let time_unit = match resolution {
509 None => TimeUnit::Microsecond,
511 Some(0) => TimeUnit::Second,
512 Some(3) => TimeUnit::Millisecond,
513 Some(6) => TimeUnit::Microsecond,
514 Some(9) => TimeUnit::Nanosecond,
515 _ => {
516 return Err(Error::invalid_input(format!(
517 "Unsupported datetime resolution: {:?}",
518 resolution
519 )));
520 }
521 };
522 Ok(ArrowDataType::Timestamp(time_unit, None))
523 }
524 SQLDataType::Datetime(resolution) => {
525 let time_unit = match resolution {
526 None => TimeUnit::Microsecond,
527 Some(0) => TimeUnit::Second,
528 Some(3) => TimeUnit::Millisecond,
529 Some(6) => TimeUnit::Microsecond,
530 Some(9) => TimeUnit::Nanosecond,
531 _ => {
532 return Err(Error::invalid_input(format!(
533 "Unsupported datetime resolution: {:?}",
534 resolution
535 )));
536 }
537 };
538 Ok(ArrowDataType::Timestamp(time_unit, None))
539 }
540 SQLDataType::Decimal(number_info) => match number_info {
541 ExactNumberInfo::PrecisionAndScale(precision, scale) => {
542 Ok(ArrowDataType::Decimal128(*precision as u8, *scale as i8))
543 }
544 _ => Err(Error::invalid_input(format!(
545 "Must provide precision and scale for decimal: {:?}",
546 number_info
547 ))),
548 },
549 _ => Err(Error::invalid_input(format!(
550 "Unsupported data type: {:?}. Supported types: {:?}",
551 data_type, SUPPORTED_TYPES
552 ))),
553 }
554 }
555
556 fn plan_field_access(&self, mut field_access_expr: RawFieldAccessExpr) -> Result<Expr> {
557 let df_schema = DFSchema::try_from(self.schema.as_ref().clone())?;
558 for planner in self.context_provider.get_expr_planners() {
559 match planner.plan_field_access(field_access_expr, &df_schema)? {
560 PlannerResult::Planned(expr) => return Ok(expr),
561 PlannerResult::Original(expr) => {
562 field_access_expr = expr;
563 }
564 }
565 }
566 Err(Error::invalid_input("Field access could not be planned"))
567 }
568
    /// Converts a parsed SQL expression into a DataFusion logical [`Expr`].
    ///
    /// Handles identifiers, binary and unary operators, literals, array literals, typed
    /// strings (including JSONB), `IS [NOT] TRUE/FALSE/NULL`, `IN` lists, parenthesised
    /// expressions, function calls, `LIKE` / `ILIKE`, casts, compound field access and
    /// `BETWEEN`.
    ///
    /// # Errors
    /// Returns an invalid-input error for any other expression kind, and propagates errors
    /// from planning sub-expressions.
    fn parse_sql_expr(&self, expr: &SQLExpr) -> Result<Expr> {
        match expr {
            SQLExpr::Identifier(id) => {
                if id.quote_style == Some('"') {
                    // Double-quoted identifiers are treated as string literals.
                    Ok(Expr::Literal(
                        ScalarValue::Utf8(Some(id.value.clone())),
                        None,
                    ))
                } else if id.quote_style == Some('`') {
                    // Backtick-quoted identifiers are used verbatim as the column name,
                    // without going through `resolve_column_name`.
                    Ok(Expr::Column(Column::from_name(id.value.clone())))
                } else {
                    Ok(self.column(vec![id.clone()].as_slice()))
                }
            }
            SQLExpr::CompoundIdentifier(ids) => Ok(self.column(ids.as_slice())),
            SQLExpr::BinaryOp { left, op, right } => self.binary_expr(left, op, right),
            SQLExpr::UnaryOp { op, expr } => self.unary_expr(op, expr),
            SQLExpr::Value(value) => self.value(&value.value),
            SQLExpr::Array(SQLArray { elem, .. }) => {
                // Array literals may only contain literal values (optionally negated
                // numbers); they are folded into a single one-row list scalar.
                let mut values = vec![];

                let array_literal_error = |pos: usize, value: &_| {
                    Err(Error::invalid_input(format!(
                        "Expected a literal value in array, instead got {} at position {}",
                        value, pos
                    )))
                };

                for (pos, expr) in elem.iter().enumerate() {
                    match expr {
                        SQLExpr::Value(value) => {
                            if let Expr::Literal(value, _) = self.value(&value.value)? {
                                values.push(value);
                            } else {
                                return array_literal_error(pos, expr);
                            }
                        }
                        SQLExpr::UnaryOp {
                            op: UnaryOperator::Minus,
                            expr,
                        } => {
                            if let SQLExpr::Value(ValueWithSpan {
                                value: Value::Number(number, _),
                                ..
                            }) = expr.as_ref()
                            {
                                if let Expr::Literal(value, _) = self.number(number, true)? {
                                    values.push(value);
                                } else {
                                    return array_literal_error(pos, expr);
                                }
                            } else {
                                return array_literal_error(pos, expr);
                            }
                        }
                        _ => {
                            return array_literal_error(pos, expr);
                        }
                    }
                }

                // The first element fixes the item type; every other element is coerced to
                // it, and a failed coercion is an error.
                let field = if !values.is_empty() {
                    let data_type = values[0].data_type();

                    for value in &mut values {
                        if value.data_type() != data_type {
                            *value = safe_coerce_scalar(value, &data_type).ok_or_else(|| Error::invalid_input(format!("Array expressions must have a consistent datatype. Expected: {}, got: {}", data_type, value.data_type())))?;
                        }
                    }
                    Field::new("item", data_type, true)
                } else {
                    // An empty array literal gets a Null item type.
                    Field::new("item", ArrowDataType::Null, true)
                };

                let values = values
                    .into_iter()
                    .map(|v| v.to_array().map_err(Error::from))
                    .collect::<Result<Vec<_>>>()?;
                let array_refs = values.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
                let values = concat(&array_refs)?;
                // A single offset range covering all items, i.e. one list row.
                let values = ListArray::try_new(
                    field.into(),
                    OffsetBuffer::from_lengths([values.len()]),
                    values,
                    None,
                )?;

                Ok(Expr::Literal(ScalarValue::List(Arc::new(values)), None))
            }
            // `JSONB '<text>'`: encode the string up front into a binary literal.
            SQLExpr::TypedString(TypedString {
                data_type: SQLDataType::JSONB,
                value,
                ..
            }) => match &value.value {
                Value::SingleQuotedString(s) | Value::DoubleQuotedString(s) => encode_jsonb(s),
                _ => Err(Error::invalid_input(
                    "Expected a string value for JSONB literal",
                )),
            },
            // Any other typed string is planned as a cast of the string literal.
            SQLExpr::TypedString(TypedString {
                data_type, value, ..
            }) => {
                let value = value.clone().into_string().expect_ok()?;
                Ok(Expr::Cast(datafusion::logical_expr::Cast {
                    expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some(value)), None)),
                    data_type: self.parse_type(data_type)?,
                }))
            }
            SQLExpr::IsFalse(expr) => Ok(Expr::IsFalse(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotFalse(expr) => Ok(Expr::IsNotFalse(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsTrue(expr) => Ok(Expr::IsTrue(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotTrue(expr) => Ok(Expr::IsNotTrue(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNull(expr) => Ok(Expr::IsNull(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotNull(expr) => Ok(Expr::IsNotNull(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::InList {
                expr,
                list,
                negated,
            } => {
                let value_expr = self.parse_sql_expr(expr)?;
                let list_exprs = list
                    .iter()
                    .map(|e| self.parse_sql_expr(e))
                    .collect::<Result<Vec<_>>>()?;
                Ok(value_expr.in_list(list_exprs, *negated))
            }
            SQLExpr::Nested(inner) => self.parse_sql_expr(inner.as_ref()),
            SQLExpr::Function(_) => self.parse_function(expr.clone()),
            // ILIKE and LIKE differ only in the final case-insensitivity flag.
            SQLExpr::ILike {
                negated,
                expr,
                pattern,
                escape_char,
                any: _,
            } => Ok(Expr::Like(Like::new(
                *negated,
                Box::new(self.parse_sql_expr(expr)?),
                Box::new(self.parse_sql_expr(pattern)?),
                match escape_char {
                    // Only the first character of the string is used; an empty string
                    // yields no escape character.
                    Some(Value::SingleQuotedString(char)) => char.chars().next(),
                    Some(value) => {
                        return Err(Error::invalid_input(format!(
                            "Invalid escape character in LIKE expression. Expected a single character wrapped with single quotes, got {}",
                            value
                        )));
                    }
                    None => None,
                },
                true,
            ))),
            SQLExpr::Like {
                negated,
                expr,
                pattern,
                escape_char,
                any: _,
            } => Ok(Expr::Like(Like::new(
                *negated,
                Box::new(self.parse_sql_expr(expr)?),
                Box::new(self.parse_sql_expr(pattern)?),
                match escape_char {
                    // Only the first character of the string is used; an empty string
                    // yields no escape character.
                    Some(Value::SingleQuotedString(char)) => char.chars().next(),
                    Some(value) => {
                        return Err(Error::invalid_input(format!(
                            "Invalid escape character in LIKE expression. Expected a single character wrapped with single quotes, got {}",
                            value
                        )));
                    }
                    None => None,
                },
                false,
            ))),
            // CAST(... AS JSONB) is only supported for string literals, which are encoded
            // immediately rather than planned as a cast.
            SQLExpr::Cast {
                data_type: SQLDataType::JSONB,
                expr: inner,
                ..
            } => match inner.as_ref() {
                SQLExpr::Value(ValueWithSpan {
                    value: Value::SingleQuotedString(s) | Value::DoubleQuotedString(s),
                    ..
                }) => encode_jsonb(s),
                _ => Err(Error::invalid_input(
                    "CAST to JSONB only supports string literals",
                )),
            },
            SQLExpr::Cast {
                expr,
                data_type,
                kind,
                ..
            } => match kind {
                // TRY_CAST / SAFE_CAST map to `TryCast`; every other cast kind maps to `Cast`.
                datafusion::sql::sqlparser::ast::CastKind::TryCast
                | datafusion::sql::sqlparser::ast::CastKind::SafeCast => {
                    Ok(Expr::TryCast(datafusion::logical_expr::TryCast {
                        expr: Box::new(self.parse_sql_expr(expr)?),
                        data_type: self.parse_type(data_type)?,
                    }))
                }
                _ => Ok(Expr::Cast(datafusion::logical_expr::Cast {
                    expr: Box::new(self.parse_sql_expr(expr)?),
                    data_type: self.parse_type(data_type)?,
                })),
            },
            SQLExpr::JsonAccess { .. } => Err(Error::invalid_input("JSON access is not supported")),
            SQLExpr::CompoundFieldAccess { root, access_chain } => {
                // Plan the root first, then apply each access in the chain on top of the
                // expression built so far.
                let mut expr = self.parse_sql_expr(root)?;

                for access in access_chain {
                    let field_access = match access {
                        // `a.b` and `a['b']` both address a named struct field.
                        AccessExpr::Dot(SQLExpr::Identifier(Ident { value: s, .. }))
                        | AccessExpr::Subscript(Subscript::Index {
                            index:
                                SQLExpr::Value(ValueWithSpan {
                                    value:
                                        Value::SingleQuotedString(s) | Value::DoubleQuotedString(s),
                                    ..
                                }),
                        }) => GetFieldAccess::NamedStructField {
                            name: ScalarValue::from(s.as_str()),
                        },
                        // Any other subscript index is planned as a list index.
                        AccessExpr::Subscript(Subscript::Index { index }) => {
                            let key = Box::new(self.parse_sql_expr(index)?);
                            GetFieldAccess::ListIndex { key }
                        }
                        AccessExpr::Subscript(Subscript::Slice { .. }) => {
                            return Err(Error::invalid_input("Slice subscript is not supported"));
                        }
                        _ => {
                            return Err(Error::invalid_input(
                                "Only dot notation or index access is supported for field access",
                            ));
                        }
                    };

                    let field_access_expr = RawFieldAccessExpr { expr, field_access };
                    expr = self.plan_field_access(field_access_expr)?;
                }

                Ok(expr)
            }
            SQLExpr::Between {
                expr,
                negated,
                low,
                high,
            } => {
                let expr = self.parse_sql_expr(expr)?;
                let low = self.parse_sql_expr(low)?;
                let high = self.parse_sql_expr(high)?;

                let between = Expr::Between(Between::new(
                    Box::new(expr),
                    *negated,
                    Box::new(low),
                    Box::new(high),
                ));
                Ok(between)
            }
            _ => Err(Error::invalid_input(format!(
                "Expression '{expr}' is not supported SQL in lance"
            ))),
        }
    }
843
844 pub fn parse_filter(&self, filter: &str) -> Result<Expr> {
849 let ast_expr = parse_sql_filter(filter)?;
851 let expr = self.parse_sql_expr(&ast_expr)?;
852 let schema = Schema::try_from(self.schema.as_ref())?;
853 let resolved = resolve_expr(&expr, &schema).map_err(|e| {
854 Error::invalid_input(format!("Error resolving filter expression {filter}: {e}"))
855 })?;
856
857 Ok(coerce_filter_type_to_boolean(resolved))
858 }
859
860 pub fn parse_expr(&self, expr: &str) -> Result<Expr> {
865 let resolved_name = self.resolve_column_name(expr);
868 if self.schema.field_with_name(&resolved_name).is_ok() {
869 return Ok(Expr::Column(Column::from_name(resolved_name)));
870 }
871
872 let ast_expr = parse_sql_expr(expr)?;
874 let expr = self.parse_sql_expr(&ast_expr)?;
875 let schema = Schema::try_from(self.schema.as_ref())?;
876 let resolved = resolve_expr(&expr, &schema)?;
877 Ok(resolved)
878 }
879
880 fn try_decode_hex_literal(s: &str) -> Option<Vec<u8>> {
886 let hex_bytes = s.as_bytes();
887 let mut decoded_bytes = Vec::with_capacity(hex_bytes.len().div_ceil(2));
888
889 let start_idx = hex_bytes.len() % 2;
890 if start_idx > 0 {
891 decoded_bytes.push(Self::try_decode_hex_char(hex_bytes[0])?);
893 }
894
895 for i in (start_idx..hex_bytes.len()).step_by(2) {
896 let high = Self::try_decode_hex_char(hex_bytes[i])?;
897 let low = Self::try_decode_hex_char(hex_bytes[i + 1])?;
898 decoded_bytes.push((high << 4) | low);
899 }
900
901 Some(decoded_bytes)
902 }
903
904 const fn try_decode_hex_char(c: u8) -> Option<u8> {
908 match c {
909 b'A'..=b'F' => Some(c - b'A' + 10),
910 b'a'..=b'f' => Some(c - b'a' + 10),
911 b'0'..=b'9' => Some(c - b'0'),
912 _ => None,
913 }
914 }
915
916 pub fn optimize_expr(&self, expr: Expr) -> Result<Expr> {
918 let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
919
920 let simplify_context = SimplifyContext::default()
923 .with_schema(df_schema.clone())
924 .with_query_execution_start_time(Some(Utc::now()));
925 let simplifier =
926 datafusion::optimizer::simplify_expressions::ExprSimplifier::new(simplify_context);
927
928 let expr = simplifier.coerce(expr, &df_schema)?;
930 let expr = simplifier.simplify(expr)?;
931
932 Ok(expr)
933 }
934
935 pub fn create_physical_expr(&self, expr: &Expr) -> Result<Arc<dyn PhysicalExpr>> {
937 let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
938 Ok(datafusion::physical_expr::create_physical_expr(
939 expr,
940 df_schema.as_ref(),
941 &Default::default(),
942 )?)
943 }
944
945 pub fn column_names_in_expr(expr: &Expr) -> Vec<String> {
952 let mut visitor = ColumnCapturingVisitor {
953 current_path: VecDeque::new(),
954 columns: BTreeSet::new(),
955 };
956 expr.visit(&mut visitor).unwrap();
957 visitor.columns.into_iter().collect()
958 }
959}
960
/// Expression visitor that records the column paths referenced in an expression tree.
struct ColumnCapturingVisitor {
    /// Field names from the enclosing `GetFieldFunc` calls seen on the way down, outermost
    /// field last; consumed when a column is reached.
    current_path: VecDeque<String>,
    /// Column paths collected so far, sorted and deduplicated by the set.
    columns: BTreeSet<String>,
}
966
967impl TreeNodeVisitor<'_> for ColumnCapturingVisitor {
968 type Node = Expr;
969
970 fn f_down(&mut self, node: &Self::Node) -> DFResult<TreeNodeRecursion> {
971 match node {
972 Expr::Column(Column { name, .. }) => {
973 let mut path = name.clone();
977 for part in self.current_path.drain(..) {
978 path.push('.');
979 if part.contains('.') || part.contains('`') {
981 let escaped = part.replace('`', "``");
983 path.push('`');
984 path.push_str(&escaped);
985 path.push('`');
986 } else {
987 path.push_str(&part);
988 }
989 }
990 self.columns.insert(path);
991 self.current_path.clear();
992 }
993 Expr::ScalarFunction(udf) => {
994 if udf.name() == GetFieldFunc::default().name() {
995 if let Some(name) = get_as_string_scalar_opt(&udf.args[1]) {
996 self.current_path.push_front(name.to_string())
997 } else {
998 self.current_path.clear();
999 }
1000 } else {
1001 self.current_path.clear();
1002 }
1003 }
1004 _ => {
1005 self.current_path.clear();
1006 }
1007 }
1008
1009 Ok(TreeNodeRecursion::Continue)
1010 }
1011}
1012
1013#[cfg(test)]
1014mod tests {
1015 use std::any::Any;
1016
1017 use crate::logical_expr::ExprExt;
1018
1019 use super::*;
1020
1021 use arrow::datatypes::Float64Type;
1022 use arrow_array::{
1023 ArrayRef, BooleanArray, Float32Array, Int32Array, Int64Array, RecordBatch, StringArray,
1024 StructArray, TimestampMicrosecondArray, TimestampMillisecondArray,
1025 TimestampNanosecondArray, TimestampSecondArray,
1026 };
1027 use arrow_schema::{DataType, Fields, Schema};
1028 use datafusion::{
1029 logical_expr::{Cast, col, lit},
1030 prelude::{array_element, get_field},
1031 };
1032 use datafusion_functions::core::expr_ext::FieldAccessor;
1033
    /// Parses a filter mixing comparisons, nested struct access, and an `IN` list, checks
    /// the planned expression, then evaluates it against a ten-row batch.
    #[test]
    fn test_parse_filter_simple() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("i", DataType::Int32, false),
            Field::new("s", DataType::Utf8, true),
            Field::new(
                "st",
                DataType::Struct(Fields::from(vec![
                    Field::new("x", DataType::Float32, false),
                    Field::new("y", DataType::Float32, false),
                ])),
                true,
            ),
        ]));

        let planner = Planner::new(schema.clone());

        let expected = col("i")
            .gt(lit(3_i32))
            .and(col("st").field_newstyle("x").lt_eq(lit(5.0_f32)))
            .and(
                col("s")
                    .eq(lit("str-4"))
                    .or(col("s").in_list(vec![lit("str-4"), lit("str-5")], false)),
            );

        // Equality spelled with a double `==`.
        let expr = planner
            .parse_filter("i > 3 AND st.x <= 5.0 AND (s == 'str-4' OR s in ('str-4', 'str-5'))")
            .unwrap();
        assert_eq!(expr, expected);

        // Same filter with a single `=`, used for the physical evaluation below.
        let expr = planner
            .parse_filter("i > 3 AND st.x <= 5.0 AND (s = 'str-4' OR s in ('str-4', 'str-5'))")
            .unwrap();

        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        // Rows 0..10 with i = row, s = "str-{row}", st.x = row, st.y = row * 10.
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
                Arc::new(StringArray::from_iter_values(
                    (0..10).map(|v| format!("str-{}", v)),
                )),
                Arc::new(StructArray::from(vec![
                    (
                        Arc::new(Field::new("x", DataType::Float32, false)),
                        Arc::new(Float32Array::from_iter_values((0..10).map(|v| v as f32)))
                            as ArrayRef,
                    ),
                    (
                        Arc::new(Field::new("y", DataType::Float32, false)),
                        Arc::new(Float32Array::from_iter_values(
                            (0..10).map(|v| (v * 10) as f32),
                        )),
                    ),
                ])),
            ],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        // Only rows 4 and 5 satisfy all three conjuncts.
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, false, true, true, false, false, false, false
            ])
        );
    }
1104
    /// Test-only UDF whose signature accepts exactly one `Float64` argument and which
    /// asserts at invocation time that this is what it received.
    #[derive(Debug, Eq, PartialEq, Hash)]
    struct StrictFloat64Udf {
        signature: Signature,
    }

    impl StrictFloat64Udf {
        /// Creates the UDF with an exact `[Float64]` signature.
        fn new() -> Self {
            Self {
                signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
            }
        }
    }

    impl ScalarUDFImpl for StrictFloat64Udf {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            "strict_float64"
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
            Ok(DataType::Float64)
        }

        /// Panics unless the first argument is `Float64`; otherwise returns the scalar 0.0.
        fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
            let data_type = args.args[0].data_type();
            assert_eq!(
                data_type,
                DataType::Float64,
                "strict_float64 expected Float64, got {data_type}"
            );
            Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(0.0))))
        }
    }
1145
/// Optimizes and then physically plans `strict_float64(0_i64) = 0.0_f64`.
///
/// The UDF only accepts `Float64` and panics when invoked with anything
/// else, so the `Int64` literal must have been coerced by the time the
/// function is evaluated (presumably during simplification, as the test
/// name suggests).
#[test]
fn test_coerce_before_simplify() {
    let planner = Planner::new(Arc::new(Schema::empty()));

    let udf = Arc::new(ScalarUDF::new_from_impl(StrictFloat64Udf::new()));
    let call = Expr::ScalarFunction(ScalarFunction::new_udf(udf, vec![lit(0_i64)]));
    let comparison = call.eq(lit(0.0_f64));

    let optimized = planner.optimize_expr(comparison).unwrap();

    // Only success matters here; the resulting physical expression is unused.
    planner.create_physical_expr(&optimized).unwrap();
}
1157
/// Verifies that the different spellings of a (nested) column reference —
/// bare, backtick-quoted, dotted, and subscripted — all parse to the same
/// logical expression.
#[test]
fn test_nested_col_refs() {
    // Schema: s0: Utf8, st: Struct { s1: Utf8, st: Struct { s2: Utf8 } }
    let inner_struct = DataType::Struct(Fields::from(vec![Field::new(
        "s2",
        DataType::Utf8,
        true,
    )]));
    let outer_struct = DataType::Struct(Fields::from(vec![
        Field::new("s1", DataType::Utf8, true),
        Field::new("st", inner_struct, true),
    ]));
    let schema = Arc::new(Schema::new(vec![
        Field::new("s0", DataType::Utf8, true),
        Field::new("st", outer_struct, true),
    ]));

    let planner = Planner::new(schema);

    /// Builds the `get_field(base, name)` call expected for a struct field access.
    fn field_access(base: Expr, name: &str) -> Expr {
        Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
            args: vec![
                base,
                Expr::Literal(ScalarValue::Utf8(Some(name.to_string())), None),
            ],
        })
    }

    /// Parses `{expr} = 'val'` and asserts that the left-hand side of the
    /// resulting equality is `expected`.
    fn assert_column_eq(planner: &Planner, expr: &str, expected: &Expr) {
        let filter = planner.parse_filter(&format!("{expr} = 'val'")).unwrap();
        // A single pattern match both validates the shape and extracts the
        // operand; on mismatch the parsed expression is reported.
        let Expr::BinaryExpr(BinaryExpr {
            left,
            op: Operator::Eq,
            ..
        }) = &filter
        else {
            panic!("expected `{expr} = 'val'` to parse as an equality, got: {filter:?}");
        };
        assert_eq!(
            left.as_ref(),
            expected,
            "unexpected column reference for `{expr}`"
        );
    }

    let expected = Expr::Column(Column::new_unqualified("s0"));
    assert_column_eq(&planner, "s0", &expected);
    assert_column_eq(&planner, "`s0`", &expected);

    let expected = field_access(Expr::Column(Column::new_unqualified("st")), "s1");
    assert_column_eq(&planner, "st.s1", &expected);
    assert_column_eq(&planner, "`st`.`s1`", &expected);
    assert_column_eq(&planner, "st.`s1`", &expected);

    let expected = field_access(
        field_access(Expr::Column(Column::new_unqualified("st")), "st"),
        "s2",
    );
    assert_column_eq(&planner, "st.st.s2", &expected);
    assert_column_eq(&planner, "`st`.`st`.`s2`", &expected);
    assert_column_eq(&planner, "st.st.`s2`", &expected);
    assert_column_eq(&planner, "st['st'][\"s2\"]", &expected);
}
1231
/// Checks that list subscripts, optionally followed by a struct field
/// subscript, parse to `array_element` / `get_field` expressions.
#[test]
fn test_nested_list_refs() {
    // Schema: l: List<Struct { f1: Utf8 }>
    let item_type = DataType::Struct(Fields::from(vec![Field::new("f1", DataType::Utf8, true)]));
    let list_type = DataType::List(Arc::new(Field::new("item", item_type, true)));
    let schema = Arc::new(Schema::new(vec![Field::new("l", list_type, true)]));

    let planner = Planner::new(schema);

    let parsed_element = planner.parse_expr("l[0]").unwrap();
    let expected_element = array_element(col("l"), lit(0_i64));
    assert_eq!(parsed_element, expected_element);

    let parsed_field = planner.parse_expr("l[0]['f1']").unwrap();
    let expected_field = get_field(array_element(col("l"), lit(0_i64)), "f1");
    assert_eq!(parsed_field, expected_field);
}
1259
/// Checks that unary minus on literals and on parenthesized expressions is
/// parsed as expected and evaluates correctly.
#[test]
fn test_negative_expressions() {
    let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));

    let planner = Planner::new(schema.clone());

    let lower_bound = col("x").gt(lit(-3_i64));
    let upper_bound = col("x").lt(-(lit(-5_i64) + lit(3_i64)));
    let expected = lower_bound.and(upper_bound);

    let expr = planner.parse_filter("x > -3 AND x < -(-5 + 3)").unwrap();
    assert_eq!(expr, expected);

    let physical_expr = planner.create_physical_expr(&expr).unwrap();

    // x takes the values -5..=4; the filter is equivalent to -3 < x < 2.
    let batch = RecordBatch::try_new(
        schema,
        vec![Arc::new(Int64Array::from_iter_values(-5..5)) as ArrayRef],
    )
    .unwrap();
    let predicates = physical_expr.evaluate(&batch).unwrap();
    // Materialize with the batch's row count so that a scalar result would be
    // expanded to one value per row rather than to an empty array.
    assert_eq!(
        predicates.into_array(batch.num_rows()).unwrap().as_ref(),
        &BooleanArray::from(vec![
            false, false, false, true, true, true, true, false, false, false
        ])
    );
}
1289
/// Checks that an array literal made of negative floats parses to a single
/// `Float64` list literal.
#[test]
fn test_negative_array_expressions() {
    let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
    let planner = Planner::new(schema);

    let parsed = planner
        .parse_expr("[-1.0, -2.0, -3.0, -4.0, -5.0]")
        .unwrap();

    // One list row holding the five (non-null) values.
    let values = [-1_f64, -2.0, -3.0, -4.0, -5.0].map(Some);
    let list = ListArray::from_iter_primitive::<Float64Type, _, _>(vec![Some(values)]);
    let expected = Expr::Literal(ScalarValue::List(Arc::new(list)), None);

    assert_eq!(parsed, expected);
}
1311
/// Checks that `LIKE` parses to the expected logical expression and that the
/// compiled predicate selects only the matching row.
#[test]
fn test_sql_like() {
    let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));

    let planner = Planner::new(schema.clone());

    let expr = planner.parse_filter("s LIKE 'str-4'").unwrap();
    assert_eq!(expr, col("s").like(lit("str-4")));
    let physical_expr = planner.create_physical_expr(&expr).unwrap();

    // Rows are "str-0" through "str-9"; only index 4 matches.
    let strings = StringArray::from_iter_values((0..10).map(|v| format!("str-{}", v)));
    let batch = RecordBatch::try_new(schema, vec![Arc::new(strings)]).unwrap();
    let predicates = physical_expr.evaluate(&batch).unwrap();
    // Materialize with the batch's row count so that a scalar result would be
    // expanded to one value per row rather than to an empty array.
    assert_eq!(
        predicates.into_array(batch.num_rows()).unwrap().as_ref(),
        &BooleanArray::from(vec![
            false, false, false, false, true, false, false, false, false, false
        ])
    );
}
1339
/// Checks that `NOT LIKE` parses to the expected logical expression and that
/// the compiled predicate rejects only the matching row.
#[test]
fn test_not_like() {
    let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));

    let planner = Planner::new(schema.clone());

    let expr = planner.parse_filter("s NOT LIKE 'str-4'").unwrap();
    assert_eq!(expr, col("s").not_like(lit("str-4")));
    let physical_expr = planner.create_physical_expr(&expr).unwrap();

    // Rows are "str-0" through "str-9"; every index except 4 passes.
    let strings = StringArray::from_iter_values((0..10).map(|v| format!("str-{}", v)));
    let batch = RecordBatch::try_new(schema, vec![Arc::new(strings)]).unwrap();
    let predicates = physical_expr.evaluate(&batch).unwrap();
    // Materialize with the batch's row count so that a scalar result would be
    // expanded to one value per row rather than to an empty array.
    assert_eq!(
        predicates.into_array(batch.num_rows()).unwrap().as_ref(),
        &BooleanArray::from(vec![
            true, true, true, true, false, true, true, true, true, true
        ])
    );
}
1367
/// Checks that `IN (...)` parses to a non-negated in-list expression and that
/// the compiled predicate selects exactly the listed values.
#[test]
fn test_sql_is_in() {
    let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));

    let planner = Planner::new(schema.clone());

    let expr = planner.parse_filter("s IN ('str-4', 'str-5')").unwrap();
    let expected = col("s").in_list(vec![lit("str-4"), lit("str-5")], false);
    assert_eq!(expr, expected);
    let physical_expr = planner.create_physical_expr(&expr).unwrap();

    // Rows are "str-0" through "str-9"; indices 4 and 5 are in the list.
    let strings = StringArray::from_iter_values((0..10).map(|v| format!("str-{}", v)));
    let batch = RecordBatch::try_new(schema, vec![Arc::new(strings)]).unwrap();
    let predicates = physical_expr.evaluate(&batch).unwrap();
    // Materialize with the batch's row count so that a scalar result would be
    // expanded to one value per row rather than to an empty array.
    assert_eq!(
        predicates.into_array(batch.num_rows()).unwrap().as_ref(),
        &BooleanArray::from(vec![
            false, false, false, false, true, true, false, false, false, false
        ])
    );
}
1395
/// Checks `IS NULL` (logical form and evaluation) and `IS NOT NULL`
/// (evaluation) against a string column with interleaved nulls.
#[test]
fn test_sql_is_null() {
    let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));

    let planner = Planner::new(schema.clone());

    let expr = planner.parse_filter("s IS NULL").unwrap();
    assert_eq!(expr, col("s").is_null());
    let physical_expr = planner.create_physical_expr(&expr).unwrap();

    // Only rows 0, 3, 6 and 9 hold a value; every other row is null.
    let strings = StringArray::from_iter(
        (0..10).map(|v| (v % 3 == 0).then(|| format!("str-{}", v))),
    );
    let batch = RecordBatch::try_new(schema, vec![Arc::new(strings)]).unwrap();
    let predicates = physical_expr.evaluate(&batch).unwrap();
    // Materialize with the batch's row count so that a scalar result would be
    // expanded to one value per row rather than to an empty array.
    assert_eq!(
        predicates.into_array(batch.num_rows()).unwrap().as_ref(),
        &BooleanArray::from(vec![
            false, true, true, false, true, true, false, true, true, false
        ])
    );

    let expr = planner.parse_filter("s IS NOT NULL").unwrap();
    let physical_expr = planner.create_physical_expr(&expr).unwrap();
    let predicates = physical_expr.evaluate(&batch).unwrap();
    assert_eq!(
        predicates.into_array(batch.num_rows()).unwrap().as_ref(),
        &BooleanArray::from(vec![
            true, false, false, true, false, false, true, false, false, true,
        ])
    );
}
1436
/// Checks that `NOT <boolean column>` compiles and inverts every row.
#[test]
fn test_sql_invert() {
    let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Boolean, true)]));

    let planner = Planner::new(schema.clone());

    let expr = planner.parse_filter("NOT s").unwrap();
    let physical_expr = planner.create_physical_expr(&expr).unwrap();

    // The column is true at rows 0, 3, 6 and 9 and false elsewhere.
    let flags = BooleanArray::from_iter((0..10).map(|v| Some(v % 3 == 0)));
    let batch = RecordBatch::try_new(schema, vec![Arc::new(flags)]).unwrap();
    let predicates = physical_expr.evaluate(&batch).unwrap();
    // Materialize with the batch's row count so that a scalar result would be
    // expanded to one value per row rather than to an empty array.
    assert_eq!(
        predicates.into_array(batch.num_rows()).unwrap().as_ref(),
        &BooleanArray::from(vec![
            false, true, true, false, true, true, false, true, true, false
        ])
    );
}
1461
/// Table-driven check that `cast(<literal> as <type>)` on the right-hand side
/// of a comparison parses to an `Expr::Cast` of the original literal with the
/// expected Arrow target type.
#[test]
fn test_sql_cast() {
    // (filter SQL, Arrow type the cast must target). The column `x` is given
    // the same type in each case.
    let cases = &[
        (
            "x = cast('2021-01-01 00:00:00' as timestamp)",
            ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
        ),
        (
            "x = cast('2021-01-01 00:00:00' as timestamp(0))",
            ArrowDataType::Timestamp(TimeUnit::Second, None),
        ),
        (
            "x = cast('2021-01-01 00:00:00.123' as timestamp(9))",
            ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
        ),
        (
            "x = cast('2021-01-01 00:00:00.123' as datetime(9))",
            ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
        ),
        ("x = cast('2021-01-01' as date)", ArrowDataType::Date32),
        (
            "x = cast('1.238' as decimal(9,3))",
            ArrowDataType::Decimal128(9, 3),
        ),
        ("x = cast(1 as float)", ArrowDataType::Float32),
        ("x = cast(1 as double)", ArrowDataType::Float64),
        ("x = cast(1 as tinyint)", ArrowDataType::Int8),
        ("x = cast(1 as smallint)", ArrowDataType::Int16),
        ("x = cast(1 as int)", ArrowDataType::Int32),
        ("x = cast(1 as integer)", ArrowDataType::Int32),
        ("x = cast(1 as bigint)", ArrowDataType::Int64),
        ("x = cast(1 as tinyint unsigned)", ArrowDataType::UInt8),
        ("x = cast(1 as smallint unsigned)", ArrowDataType::UInt16),
        ("x = cast(1 as int unsigned)", ArrowDataType::UInt32),
        ("x = cast(1 as integer unsigned)", ArrowDataType::UInt32),
        ("x = cast(1 as bigint unsigned)", ArrowDataType::UInt64),
        ("x = cast(1 as boolean)", ArrowDataType::Boolean),
        ("x = cast(1 as string)", ArrowDataType::Utf8),
    ];

    for (sql, expected_data_type) in cases {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "x",
            expected_data_type.clone(),
            true,
        )]));
        // The schema is not needed afterwards, so hand it over without cloning.
        let planner = Planner::new(schema);
        let expr = planner.parse_filter(sql).unwrap();

        // Recover the literal text between `cast(` and ` as`, minus quotes.
        let expected_value_str = sql
            .split("cast(")
            .nth(1)
            .unwrap()
            .split(" as")
            .next()
            .unwrap()
            .trim_matches('\'');

        let Expr::BinaryExpr(BinaryExpr { right, .. }) = &expr else {
            panic!("Expected binary expression for `{sql}`, got: {expr:?}");
        };
        let Expr::Cast(Cast {
            expr: cast_input,
            data_type,
        }) = right.as_ref()
        else {
            panic!("Expected right to be a cast for `{sql}`, got: {right:?}");
        };
        match cast_input.as_ref() {
            Expr::Literal(ScalarValue::Utf8(Some(value_str)), _) => {
                assert_eq!(
                    value_str, expected_value_str,
                    "unexpected string literal for `{sql}`"
                );
            }
            Expr::Literal(ScalarValue::Int64(Some(value)), _) => {
                assert_eq!(*value, 1, "unexpected integer literal for `{sql}`");
            }
            other => {
                panic!("Expected cast to be applied to literal for `{sql}`, got: {other:?}")
            }
        }
        assert_eq!(
            data_type, expected_data_type,
            "unexpected cast target type for `{sql}`"
        );
    }
}
1542
/// Table-driven check that typed string literals (`timestamp '...'`,
/// `date '...'`, `decimal(p,s) '...'`) parse to an `Expr::Cast` of the string
/// literal with the expected Arrow target type.
#[test]
fn test_sql_literals() {
    // (filter SQL, Arrow type the cast must target). The column `x` is given
    // the same type in each case.
    let cases = &[
        (
            "x = timestamp '2021-01-01 00:00:00'",
            ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
        ),
        (
            "x = timestamp(0) '2021-01-01 00:00:00'",
            ArrowDataType::Timestamp(TimeUnit::Second, None),
        ),
        (
            "x = timestamp(9) '2021-01-01 00:00:00.123'",
            ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
        ),
        ("x = date '2021-01-01'", ArrowDataType::Date32),
        ("x = decimal(9,3) '1.238'", ArrowDataType::Decimal128(9, 3)),
    ];

    for (sql, expected_data_type) in cases {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "x",
            expected_data_type.clone(),
            true,
        )]));
        // The schema is not needed afterwards, so hand it over without cloning.
        let planner = Planner::new(schema);
        let expr = planner.parse_filter(sql).unwrap();

        // The literal text is whatever sits between the first pair of quotes.
        let expected_value_str = sql.split('\'').nth(1).unwrap();

        let Expr::BinaryExpr(BinaryExpr { right, .. }) = &expr else {
            panic!("Expected binary expression for `{sql}`, got: {expr:?}");
        };
        let Expr::Cast(Cast {
            expr: cast_input,
            data_type,
        }) = right.as_ref()
        else {
            panic!("Expected right to be a cast for `{sql}`, got: {right:?}");
        };
        match cast_input.as_ref() {
            Expr::Literal(ScalarValue::Utf8(Some(value_str)), _) => {
                assert_eq!(
                    value_str, expected_value_str,
                    "unexpected string literal for `{sql}`"
                );
            }
            other => {
                panic!("Expected cast to be applied to literal for `{sql}`, got: {other:?}")
            }
        }
        assert_eq!(
            data_type, expected_data_type,
            "unexpected cast target type for `{sql}`"
        );
    }
}
1590
/// Checks that, after optimization, an array literal compared against a list
/// column becomes a literal of the column's own list type (variable-size or
/// fixed-size).
#[test]
fn test_sql_array_literals() {
    // (filter SQL, type of column `x` and thus of the optimized literal).
    let cases = [
        (
            "x = [1, 2, 3]",
            ArrowDataType::List(Arc::new(Field::new("item", ArrowDataType::Int64, true))),
        ),
        (
            "x = [1, 2, 3]",
            ArrowDataType::FixedSizeList(
                Arc::new(Field::new("item", ArrowDataType::Int64, true)),
                3,
            ),
        ),
    ];

    for (sql, expected_data_type) in cases {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "x",
            expected_data_type.clone(),
            true,
        )]));
        // The schema is not needed afterwards, so hand it over without cloning.
        let planner = Planner::new(schema);
        let expr = planner.parse_filter(sql).unwrap();
        let expr = planner.optimize_expr(expr).unwrap();

        // Both cases share the same SQL, so failure messages also carry the
        // column type to identify the failing case.
        let Expr::BinaryExpr(BinaryExpr { right, .. }) = &expr else {
            panic!(
                "Expected binary expression for `{sql}` ({expected_data_type:?}), got: {expr:?}"
            );
        };
        let Expr::Literal(value, _) = right.as_ref() else {
            panic!(
                "Expected right to be a literal for `{sql}` ({expected_data_type:?}), got: {right:?}"
            );
        };
        assert_eq!(
            &value.data_type(),
            &expected_data_type,
            "unexpected literal type for `{sql}`"
        );
    }
}
1628
/// Evaluates `BETWEEN` / `NOT BETWEEN` on integer, float and timestamp
/// columns of a ten-row batch and checks that both bounds are inclusive.
#[test]
fn test_sql_between() {
    use arrow_array::{Float64Array, Int32Array, TimestampMicrosecondArray};
    use arrow_schema::{DataType, Field, Schema, TimeUnit};
    use std::sync::Arc;

    let schema = Arc::new(Schema::new(vec![
        Field::new("x", DataType::Int32, false),
        Field::new("y", DataType::Float64, false),
        Field::new(
            "ts",
            DataType::Timestamp(TimeUnit::Microsecond, None),
            false,
        ),
    ]));

    let planner = Planner::new(schema.clone());

    let expr = planner
        .parse_filter("x BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
        .unwrap();
    let physical_expr = planner.create_physical_expr(&expr).unwrap();

    // 2024-01-01 00:00:00 UTC in microseconds since the epoch; row `i` is
    // `i` seconds later.
    let base_ts = 1_704_067_200_000_000_i64;
    let ts_array =
        TimestampMicrosecondArray::from_iter_values((0..10).map(|i| base_ts + i * 1_000_000));

    // x = 0..=9, y = 0.0..=9.0, ts = base_ts + 0..=9 seconds.
    let batch = RecordBatch::try_new(
        schema,
        vec![
            Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
            Arc::new(Float64Array::from_iter_values((0..10).map(|v| v as f64))),
            Arc::new(ts_array),
        ],
    )
    .unwrap();
    // Results are materialized with the batch's row count so that a scalar
    // result would be expanded to one value per row rather than to an empty
    // array.
    let num_rows = batch.num_rows();

    let predicates = physical_expr.evaluate(&batch).unwrap();
    assert_eq!(
        predicates.into_array(num_rows).unwrap().as_ref(),
        &BooleanArray::from(vec![
            false, false, false, true, true, true, true, true, false, false
        ])
    );

    let expr = planner
        .parse_filter("x NOT BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
        .unwrap();
    let physical_expr = planner.create_physical_expr(&expr).unwrap();

    let predicates = physical_expr.evaluate(&batch).unwrap();
    assert_eq!(
        predicates.into_array(num_rows).unwrap().as_ref(),
        &BooleanArray::from(vec![
            true, true, true, false, false, false, false, false, true, true
        ])
    );

    let expr = planner.parse_filter("y BETWEEN 2.5 AND 6.5").unwrap();
    let physical_expr = planner.create_physical_expr(&expr).unwrap();

    let predicates = physical_expr.evaluate(&batch).unwrap();
    assert_eq!(
        predicates.into_array(num_rows).unwrap().as_ref(),
        &BooleanArray::from(vec![
            false, false, false, true, true, true, true, false, false, false
        ])
    );

    let expr = planner
        .parse_filter(
            "ts BETWEEN timestamp '2024-01-01 00:00:03' AND timestamp '2024-01-01 00:00:07'",
        )
        .unwrap();
    let physical_expr = planner.create_physical_expr(&expr).unwrap();

    let predicates = physical_expr.evaluate(&batch).unwrap();
    assert_eq!(
        predicates.into_array(num_rows).unwrap().as_ref(),
        &BooleanArray::from(vec![
            false, false, false, true, true, true, true, true, false, false
        ])
    );
}
1720
/// Compares timestamp columns of every time unit against a `TIMESTAMP`
/// literal and expects the same result for each: the first five rows fail
/// the `>=` comparison and the last five pass.
#[test]
fn test_sql_comparison() {
    // Seconds, milliseconds and microseconds count 0..=9; the nanosecond
    // column counts 4995..=5004 so that its threshold also falls mid-batch.
    let columns: Vec<(&str, ArrayRef)> = vec![
        (
            "timestamp_s",
            Arc::new(TimestampSecondArray::from_iter_values(0..10)),
        ),
        (
            "timestamp_ms",
            Arc::new(TimestampMillisecondArray::from_iter_values(0..10)),
        ),
        (
            "timestamp_us",
            Arc::new(TimestampMicrosecondArray::from_iter_values(0..10)),
        ),
        (
            "timestamp_ns",
            Arc::new(TimestampNanosecondArray::from_iter_values(4995..5005)),
        ),
    ];
    let batch = RecordBatch::try_from_iter(columns).unwrap();

    let planner = Planner::new(batch.schema());

    let filters = &[
        "timestamp_s >= TIMESTAMP '1970-01-01 00:00:05'",
        "timestamp_ms >= TIMESTAMP '1970-01-01 00:00:00.005'",
        "timestamp_us >= TIMESTAMP '1970-01-01 00:00:00.000005'",
        "timestamp_ns >= TIMESTAMP '1970-01-01 00:00:00.000005'",
    ];

    // Rows 0..5 are false, rows 5..10 are true.
    let expected: ArrayRef = Arc::new(BooleanArray::from_iter(
        (0..10).map(|row| Some(row >= 5)),
    ));

    for expression in filters {
        let parsed = planner.parse_filter(expression).unwrap();
        let optimized = planner.optimize_expr(parsed).unwrap();
        let physical_expr = planner.create_physical_expr(&optimized).unwrap();

        let result = physical_expr
            .evaluate(&batch)
            .unwrap()
            .into_array(batch.num_rows())
            .unwrap();
        assert_eq!(&expected, &result, "unexpected result for {}", expression);
    }
}
1769
/// Checks that `Planner::column_names_in_expr` reports top-level and nested
/// field references as dotted paths, in the order asserted below.
#[test]
fn test_columns_in_expr() {
    // s0 > 'value' AND (st.st.s2 = 'value' OR st.s1 IN ('value 1', 'value 2'))
    let top_level = col("s0").gt(lit("value"));
    let deeply_nested = col("st").field("st").field("s2").eq(lit("value"));
    let nested_in_list = col("st")
        .field("s1")
        .in_list(vec![lit("value 1"), lit("value 2")], false);
    let expr = top_level.and(deeply_nested.or(nested_in_list));

    let columns = Planner::column_names_in_expr(&expr);
    assert_eq!(columns, vec!["s0", "st.s1", "st.st.s2"]);
}
1785
/// Checks that a hex string literal (`x'..'`) parses to a binary literal
/// holding the decoded bytes.
#[test]
fn test_parse_binary_expr() {
    let schema = Arc::new(Schema::new(vec![Field::new(
        "binary",
        DataType::Binary,
        true,
    )]));
    let planner = Planner::new(schema);

    // 0x61 0x62 0x63 are the ASCII bytes "abc".
    let hex_literal = "x'616263'";
    let parsed = planner.parse_expr(hex_literal).unwrap();

    let expected = Expr::Literal(ScalarValue::Binary(Some(vec![b'a', b'b', b'c'])), None);
    assert_eq!(parsed, expected);
}
1802
/// Checks that a default `LanceContextProvider` exposes at least one
/// expression planner.
#[test]
fn test_lance_context_provider_expr_planners() {
    let ctx_provider = LanceContextProvider::default();
    let planners = ctx_provider.get_expr_planners();
    assert!(!planners.is_empty());
}
1808
/// Evaluates a filter that ANDs `regexp_match(...)` with null / empty-string
/// checks on two caption columns, and expects only the row that satisfies
/// every condition to pass.
#[test]
fn test_regexp_match_and_non_empty_captions() {
    let schema = Arc::new(Schema::new(vec![
        Field::new("keywords", DataType::Utf8, true),
        Field::new("natural_caption", DataType::Utf8, true),
        Field::new("poetic_caption", DataType::Utf8, true),
    ]));

    let planner = Planner::new(schema.clone());

    let expr = planner
        .parse_filter(
            "regexp_match(keywords, 'Liberty|revolution') AND \
             (natural_caption IS NOT NULL AND natural_caption <> '' AND \
             poetic_caption IS NOT NULL AND poetic_caption <> '')",
        )
        .unwrap();

    let physical_expr = planner.create_physical_expr(&expr).unwrap();

    // Row 0 is the only one where the keywords match and both captions are
    // non-null and non-empty; every other row violates one condition.
    let keywords = StringArray::from(vec![
        Some("Liberty for all"),
        Some("peace"),
        Some("revolution now"),
        Some("Liberty"),
        Some("revolutionary"),
        Some("none"),
    ]);
    let natural_captions = StringArray::from(vec![
        Some("a"),
        Some("b"),
        None,
        Some(""),
        Some("c"),
        Some("d"),
    ]);
    let poetic_captions = StringArray::from(vec![
        Some("x"),
        Some(""),
        Some("y"),
        Some("z"),
        None,
        Some("w"),
    ]);

    let batch = RecordBatch::try_new(
        schema,
        vec![
            Arc::new(keywords) as ArrayRef,
            Arc::new(natural_captions) as ArrayRef,
            Arc::new(poetic_captions) as ArrayRef,
        ],
    )
    .unwrap();

    let result = physical_expr.evaluate(&batch).unwrap();
    // Materialize with the batch's row count so that a scalar result would be
    // expanded to one value per row rather than to an empty array.
    assert_eq!(
        result.into_array(batch.num_rows()).unwrap().as_ref(),
        &BooleanArray::from(vec![true, false, false, false, false, false])
    );
}
1868
/// Checks that a filter using `regexp_match(...)` as a boolean operand of
/// `AND` can be parsed and turned into a physical expression.
///
/// Despite the "error" in its name, the test asserts that both steps
/// succeed; nothing is evaluated.
#[test]
fn test_regexp_match_infer_error_without_boolean_coercion() {
    let fields = vec![
        Field::new("keywords", DataType::Utf8, true),
        Field::new("natural_caption", DataType::Utf8, true),
        Field::new("poetic_caption", DataType::Utf8, true),
    ];
    let planner = Planner::new(Arc::new(Schema::new(fields)));

    let filter = planner
        .parse_filter(
            "regexp_match(keywords, 'Liberty|revolution') AND \
             (natural_caption IS NOT NULL AND natural_caption <> '' AND \
             poetic_caption IS NOT NULL AND poetic_caption <> '')",
        )
        .unwrap();

    // Only successful planning matters; the physical expression is discarded.
    let _physical = planner.create_physical_expr(&filter).unwrap();
}
1892
/// Checks that the three JSONB literal spellings (`jsonb '...'`,
/// `cast('...' as jsonb)` and `'...'::jsonb`) each parse to a `LargeBinary`
/// literal that decodes back to the expected JSON text.
#[test]
fn test_jsonb_literals() {
    let schema = Arc::new(Schema::new(vec![Field::new(
        "j",
        DataType::LargeBinary,
        true,
    )]));
    let planner = Planner::new(schema);

    // (SQL expression, JSON text expected after decoding the literal).
    let cases = [
        ("jsonb '{\"key\": \"value\"}'", r#"{"key":"value"}"#),
        ("cast('{\"a\": 1}' as jsonb)", r#"{"a":1}"#),
        ("'{\"a\": 1}'::jsonb", r#"{"a":1}"#),
    ];
    for (sql, expected) in cases {
        let ast = parse_sql_expr(sql).unwrap();
        match planner.parse_sql_expr(&ast).unwrap() {
            Expr::Literal(ScalarValue::LargeBinary(Some(bytes)), _) => {
                let decoded = lance_arrow::json::decode_json(&bytes);
                assert_eq!(decoded, expected, "failed for: {sql}");
            }
            other => panic!("Expected LargeBinary literal for '{sql}', got: {other:?}"),
        }
    }
}
1922
/// Checks the two JSONB failure paths: a typed literal holding invalid JSON,
/// and a JSONB cast applied to something other than a string literal.
#[test]
fn test_jsonb_literal_errors() {
    let schema = Arc::new(Schema::new(vec![Field::new(
        "j",
        DataType::LargeBinary,
        true,
    )]));
    let planner = Planner::new(schema);

    // Invalid JSON text in a typed literal.
    let invalid_json = parse_sql_expr("jsonb 'not valid json'").unwrap();
    let err = planner.parse_sql_expr(&invalid_json).unwrap_err();
    let message = err.to_string();
    assert!(
        message.contains("Failed to encode JSONB"),
        "expected JSONB encoding error, got: {err}"
    );

    // Casting a column reference rather than a string literal.
    let column_cast = parse_sql_expr("cast(j as jsonb)").unwrap();
    let err = planner.parse_sql_expr(&column_cast).unwrap_err();
    let message = err.to_string();
    assert!(
        message.contains("CAST to JSONB only supports string literals"),
        "got: {err}"
    );
}
1949}