1use std::borrow::Cow;
9use std::collections::{BTreeSet, VecDeque};
10use std::sync::Arc;
11
12use crate::expr::safe_coerce_scalar;
13use crate::logical_expr::{coerce_filter_type_to_boolean, get_as_string_scalar_opt, resolve_expr};
14use crate::sql::{parse_sql_expr, parse_sql_filter};
15use arrow::compute::CastOptions;
16use arrow_array::ListArray;
17use arrow_buffer::OffsetBuffer;
18use arrow_schema::{DataType as ArrowDataType, Field, SchemaRef, TimeUnit};
19use arrow_select::concat::concat;
20use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
21use datafusion::common::DFSchema;
22use datafusion::config::ConfigOptions;
23use datafusion::error::Result as DFResult;
24use datafusion::execution::config::SessionConfig;
25use datafusion::execution::context::SessionState;
26use datafusion::execution::runtime_env::RuntimeEnvBuilder;
27use datafusion::execution::session_state::SessionStateBuilder;
28use datafusion::logical_expr::expr::ScalarFunction;
29use datafusion::logical_expr::planner::{ExprPlanner, PlannerResult, RawFieldAccessExpr};
30use datafusion::logical_expr::{
31 AggregateUDF, ColumnarValue, GetFieldAccess, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
32 WindowUDF,
33};
34use datafusion::optimizer::simplify_expressions::SimplifyContext;
35use datafusion::sql::planner::{ContextProvider, ParserOptions, PlannerContext, SqlToRel};
36use datafusion::sql::sqlparser::ast::{
37 AccessExpr, Array as SQLArray, BinaryOperator, DataType as SQLDataType, ExactNumberInfo,
38 Expr as SQLExpr, Function, FunctionArg, FunctionArgExpr, FunctionArguments, Ident, Subscript,
39 TimezoneInfo, UnaryOperator, Value,
40};
41use datafusion::{
42 common::Column,
43 logical_expr::{col, Between, BinaryExpr, Like, Operator},
44 physical_expr::execution_props::ExecutionProps,
45 physical_plan::PhysicalExpr,
46 prelude::Expr,
47 scalar::ScalarValue,
48};
49use datafusion_functions::core::getfield::GetFieldFunc;
50use lance_arrow::cast::cast_with_options;
51use lance_core::datatypes::Schema;
52use snafu::location;
53
54use lance_core::{Error, Result};
55
/// Scalar UDF (`_cast_list_f16`) that casts a `List`/`FixedSizeList` of
/// float32 (or float16) values to the equivalent list with Float16 items.
#[derive(Debug, Clone)]
struct CastListF16Udf {
    // Accepts any single argument; actual type checking happens in
    // `return_type`/`invoke`.
    signature: Signature,
}
60
61impl CastListF16Udf {
62 pub fn new() -> Self {
63 Self {
64 signature: Signature::any(1, Volatility::Immutable),
65 }
66 }
67}
68
69impl ScalarUDFImpl for CastListF16Udf {
70 fn as_any(&self) -> &dyn std::any::Any {
71 self
72 }
73
74 fn name(&self) -> &str {
75 "_cast_list_f16"
76 }
77
78 fn signature(&self) -> &Signature {
79 &self.signature
80 }
81
82 fn return_type(&self, arg_types: &[ArrowDataType]) -> DFResult<ArrowDataType> {
83 let input = &arg_types[0];
84 match input {
85 ArrowDataType::FixedSizeList(field, size) => {
86 if field.data_type() != &ArrowDataType::Float32
87 && field.data_type() != &ArrowDataType::Float16
88 {
89 return Err(datafusion::error::DataFusionError::Execution(
90 "cast_list_f16 only supports list of float32 or float16".to_string(),
91 ));
92 }
93 Ok(ArrowDataType::FixedSizeList(
94 Arc::new(Field::new(
95 field.name(),
96 ArrowDataType::Float16,
97 field.is_nullable(),
98 )),
99 *size,
100 ))
101 }
102 ArrowDataType::List(field) => {
103 if field.data_type() != &ArrowDataType::Float32
104 && field.data_type() != &ArrowDataType::Float16
105 {
106 return Err(datafusion::error::DataFusionError::Execution(
107 "cast_list_f16 only supports list of float32 or float16".to_string(),
108 ));
109 }
110 Ok(ArrowDataType::List(Arc::new(Field::new(
111 field.name(),
112 ArrowDataType::Float16,
113 field.is_nullable(),
114 ))))
115 }
116 _ => Err(datafusion::error::DataFusionError::Execution(
117 "cast_list_f16 only supports FixedSizeList/List arguments".to_string(),
118 )),
119 }
120 }
121
122 fn invoke(&self, args: &[ColumnarValue]) -> DFResult<ColumnarValue> {
123 let ColumnarValue::Array(arr) = &args[0] else {
124 return Err(datafusion::error::DataFusionError::Execution(
125 "cast_list_f16 only supports array arguments".to_string(),
126 ));
127 };
128
129 let to_type = match arr.data_type() {
130 ArrowDataType::FixedSizeList(field, size) => ArrowDataType::FixedSizeList(
131 Arc::new(Field::new(
132 field.name(),
133 ArrowDataType::Float16,
134 field.is_nullable(),
135 )),
136 *size,
137 ),
138 ArrowDataType::List(field) => ArrowDataType::List(Arc::new(Field::new(
139 field.name(),
140 ArrowDataType::Float16,
141 field.is_nullable(),
142 ))),
143 _ => {
144 return Err(datafusion::error::DataFusionError::Execution(
145 "cast_list_f16 only supports array arguments".to_string(),
146 ));
147 }
148 };
149
150 let res = cast_with_options(arr.as_ref(), &to_type, &CastOptions::default())?;
151 Ok(ColumnarValue::Array(res))
152 }
153}
154
/// `ContextProvider` backed by a default DataFusion `SessionState`, used to
/// resolve function names and expression planners while planning SQL
/// expressions against a Lance schema.
struct LanceContextProvider {
    // Config exposed through `ContextProvider::options`.
    options: datafusion::config::ConfigOptions,
    // Default session state used to look up scalar/aggregate/window functions.
    state: SessionState,
    // Expression planners captured from the session state builder.
    expr_planners: Vec<Arc<dyn ExprPlanner>>,
}
161
162impl Default for LanceContextProvider {
163 fn default() -> Self {
164 let config = SessionConfig::new();
165 let runtime = RuntimeEnvBuilder::new().build_arc().unwrap();
166 let mut state_builder = SessionStateBuilder::new()
167 .with_config(config)
168 .with_runtime_env(runtime)
169 .with_default_features();
170
171 let expr_planners = state_builder.expr_planners().as_ref().unwrap().clone();
176
177 Self {
178 options: ConfigOptions::default(),
179 state: state_builder.build(),
180 expr_planners,
181 }
182 }
183}
184
185impl ContextProvider for LanceContextProvider {
186 fn get_table_source(
187 &self,
188 name: datafusion::sql::TableReference,
189 ) -> DFResult<Arc<dyn datafusion::logical_expr::TableSource>> {
190 Err(datafusion::error::DataFusionError::NotImplemented(format!(
191 "Attempt to reference inner table {} not supported",
192 name
193 )))
194 }
195
196 fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
197 self.state.aggregate_functions().get(name).cloned()
198 }
199
200 fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
201 self.state.window_functions().get(name).cloned()
202 }
203
204 fn get_function_meta(&self, f: &str) -> Option<Arc<ScalarUDF>> {
205 match f {
206 "_cast_list_f16" => Some(Arc::new(ScalarUDF::new_from_impl(CastListF16Udf::new()))),
209 _ => self.state.scalar_functions().get(f).cloned(),
210 }
211 }
212
213 fn get_variable_type(&self, _: &[String]) -> Option<ArrowDataType> {
214 None
216 }
217
218 fn options(&self) -> &datafusion::config::ConfigOptions {
219 &self.options
220 }
221
222 fn udf_names(&self) -> Vec<String> {
223 self.state.scalar_functions().keys().cloned().collect()
224 }
225
226 fn udaf_names(&self) -> Vec<String> {
227 self.state.aggregate_functions().keys().cloned().collect()
228 }
229
230 fn udwf_names(&self) -> Vec<String> {
231 self.state.window_functions().keys().cloned().collect()
232 }
233
234 fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
235 &self.expr_planners
236 }
237}
238
/// Planner for SQL filter/projection expressions over a fixed Arrow schema.
///
/// Parses SQL text into DataFusion logical expressions, resolves them against
/// the schema, and can lower them to physical expressions for evaluation.
pub struct Planner {
    schema: SchemaRef,
    context_provider: LanceContextProvider,
}
243
244impl Planner {
245 pub fn new(schema: SchemaRef) -> Self {
246 Self {
247 schema,
248 context_provider: LanceContextProvider::default(),
249 }
250 }
251
252 fn column(idents: &[Ident]) -> Expr {
253 let mut column = col(&idents[0].value);
254 for ident in &idents[1..] {
255 column = Expr::ScalarFunction(ScalarFunction {
256 args: vec![
257 column,
258 Expr::Literal(ScalarValue::Utf8(Some(ident.value.clone()))),
259 ],
260 func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
261 });
262 }
263 column
264 }
265
266 fn binary_op(&self, op: &BinaryOperator) -> Result<Operator> {
267 Ok(match op {
268 BinaryOperator::Plus => Operator::Plus,
269 BinaryOperator::Minus => Operator::Minus,
270 BinaryOperator::Multiply => Operator::Multiply,
271 BinaryOperator::Divide => Operator::Divide,
272 BinaryOperator::Modulo => Operator::Modulo,
273 BinaryOperator::StringConcat => Operator::StringConcat,
274 BinaryOperator::Gt => Operator::Gt,
275 BinaryOperator::Lt => Operator::Lt,
276 BinaryOperator::GtEq => Operator::GtEq,
277 BinaryOperator::LtEq => Operator::LtEq,
278 BinaryOperator::Eq => Operator::Eq,
279 BinaryOperator::NotEq => Operator::NotEq,
280 BinaryOperator::And => Operator::And,
281 BinaryOperator::Or => Operator::Or,
282 _ => {
283 return Err(Error::invalid_input(
284 format!("Operator {op} is not supported"),
285 location!(),
286 ));
287 }
288 })
289 }
290
291 fn binary_expr(&self, left: &SQLExpr, op: &BinaryOperator, right: &SQLExpr) -> Result<Expr> {
292 Ok(Expr::BinaryExpr(BinaryExpr::new(
293 Box::new(self.parse_sql_expr(left)?),
294 self.binary_op(op)?,
295 Box::new(self.parse_sql_expr(right)?),
296 )))
297 }
298
    /// Plan a SQL unary operator application.
    ///
    /// `NOT` (and Postgres `~`) becomes a logical negation. `-` applied to a
    /// numeric literal is folded into a negative literal (i64 first, f64
    /// fallback); `-` applied to any other expression becomes `Expr::Negative`.
    /// All other unary operators are rejected.
    fn unary_expr(&self, op: &UnaryOperator, expr: &SQLExpr) -> Result<Expr> {
        Ok(match op {
            UnaryOperator::Not | UnaryOperator::PGBitwiseNot => {
                Expr::Not(Box::new(self.parse_sql_expr(expr)?))
            }

            UnaryOperator::Minus => {
                use datafusion::logical_expr::lit;
                match expr {
                    // Fold `-<number literal>` directly; the literal text is
                    // always non-negative here, so `-n` cannot overflow i64.
                    SQLExpr::Value(Value::Number(n, _)) => match n.parse::<i64>() {
                        Ok(n) => lit(-n),
                        Err(_) => lit(-n
                            .parse::<f64>()
                            .map_err(|_e| {
                                Error::invalid_input(
                                    format!("negative operator can be only applied to integer and float operands, got: {n}"),
                                    location!(),
                                )
                            })?),
                    },
                    _ => {
                        Expr::Negative(Box::new(self.parse_sql_expr(expr)?))
                    }
                }
            }

            _ => {
                return Err(Error::invalid_input(
                    format!("Unary operator '{:?}' is not supported", op),
                    location!(),
                ));
            }
        })
    }
333
334 fn number(&self, value: &str, negative: bool) -> Result<Expr> {
336 use datafusion::logical_expr::lit;
337 let value: Cow<str> = if negative {
338 Cow::Owned(format!("-{}", value))
339 } else {
340 Cow::Borrowed(value)
341 };
342 if let Ok(n) = value.parse::<i64>() {
343 Ok(lit(n))
344 } else {
345 value.parse::<f64>().map(lit).map_err(|_| {
346 Error::invalid_input(
347 format!("'{value}' is not supported number value."),
348 location!(),
349 )
350 })
351 }
352 }
353
354 fn value(&self, value: &Value) -> Result<Expr> {
355 Ok(match value {
356 Value::Number(v, _) => self.number(v.as_str(), false)?,
357 Value::SingleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone()))),
358 Value::HexStringLiteral(hsl) => {
359 Expr::Literal(ScalarValue::Binary(Self::try_decode_hex_literal(hsl)))
360 }
361 Value::DoubleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone()))),
362 Value::Boolean(v) => Expr::Literal(ScalarValue::Boolean(Some(*v))),
363 Value::Null => Expr::Literal(ScalarValue::Null),
364 _ => todo!(),
365 })
366 }
367
368 fn parse_function_args(&self, func_args: &FunctionArg) -> Result<Expr> {
369 match func_args {
370 FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => self.parse_sql_expr(expr),
371 _ => Err(Error::invalid_input(
372 format!("Unsupported function args: {:?}", func_args),
373 location!(),
374 )),
375 }
376 }
377
378 fn legacy_parse_function(&self, func: &Function) -> Result<Expr> {
385 match &func.args {
386 FunctionArguments::List(args) => {
387 if func.name.0.len() != 1 {
388 return Err(Error::invalid_input(
389 format!("Function name must have 1 part, got: {:?}", func.name.0),
390 location!(),
391 ));
392 }
393 Ok(Expr::IsNotNull(Box::new(
394 self.parse_function_args(&args.args[0])?,
395 )))
396 }
397 _ => Err(Error::invalid_input(
398 format!("Unsupported function args: {:?}", &func.args),
399 location!(),
400 )),
401 }
402 }
403
    /// Plan a SQL function call.
    ///
    /// `is_valid(...)` is intercepted for backwards compatibility (mapped to
    /// `IS NOT NULL`); everything else is delegated to DataFusion's SQL
    /// planner against this planner's schema.
    fn parse_function(&self, function: SQLExpr) -> Result<Expr> {
        if let SQLExpr::Function(function) = &function {
            if !function.name.0.is_empty() && function.name.0[0].value == "is_valid" {
                return self.legacy_parse_function(function);
            }
        }
        let sql_to_rel = SqlToRel::new_with_options(
            &self.context_provider,
            ParserOptions {
                parse_float_as_decimal: false,
                // Keep identifiers exactly as written: Lance column names are
                // case-sensitive.
                enable_ident_normalization: false,
                support_varchar_with_length: false,
                enable_options_value_normalization: false,
                collect_spans: false,
            },
        );

        let mut planner_context = PlannerContext::default();
        let schema = DFSchema::try_from(self.schema.as_ref().clone())?;
        Ok(sql_to_rel.sql_to_expr(function, &schema, &mut planner_context)?)
    }
425
426 fn parse_type(&self, data_type: &SQLDataType) -> Result<ArrowDataType> {
427 const SUPPORTED_TYPES: [&str; 13] = [
428 "int [unsigned]",
429 "tinyint [unsigned]",
430 "smallint [unsigned]",
431 "bigint [unsigned]",
432 "float",
433 "double",
434 "string",
435 "binary",
436 "date",
437 "timestamp(precision)",
438 "datetime(precision)",
439 "decimal(precision,scale)",
440 "boolean",
441 ];
442 match data_type {
443 SQLDataType::String(_) => Ok(ArrowDataType::Utf8),
444 SQLDataType::Binary(_) => Ok(ArrowDataType::Binary),
445 SQLDataType::Float(_) => Ok(ArrowDataType::Float32),
446 SQLDataType::Double(_) => Ok(ArrowDataType::Float64),
447 SQLDataType::Boolean => Ok(ArrowDataType::Boolean),
448 SQLDataType::TinyInt(_) => Ok(ArrowDataType::Int8),
449 SQLDataType::SmallInt(_) => Ok(ArrowDataType::Int16),
450 SQLDataType::Int(_) | SQLDataType::Integer(_) => Ok(ArrowDataType::Int32),
451 SQLDataType::BigInt(_) => Ok(ArrowDataType::Int64),
452 SQLDataType::UnsignedTinyInt(_) => Ok(ArrowDataType::UInt8),
453 SQLDataType::UnsignedSmallInt(_) => Ok(ArrowDataType::UInt16),
454 SQLDataType::UnsignedInt(_) | SQLDataType::UnsignedInteger(_) => {
455 Ok(ArrowDataType::UInt32)
456 }
457 SQLDataType::UnsignedBigInt(_) => Ok(ArrowDataType::UInt64),
458 SQLDataType::Date => Ok(ArrowDataType::Date32),
459 SQLDataType::Timestamp(resolution, tz) => {
460 match tz {
461 TimezoneInfo::None => {}
462 _ => {
463 return Err(Error::invalid_input(
464 "Timezone not supported in timestamp".to_string(),
465 location!(),
466 ));
467 }
468 };
469 let time_unit = match resolution {
470 None => TimeUnit::Microsecond,
472 Some(0) => TimeUnit::Second,
473 Some(3) => TimeUnit::Millisecond,
474 Some(6) => TimeUnit::Microsecond,
475 Some(9) => TimeUnit::Nanosecond,
476 _ => {
477 return Err(Error::invalid_input(
478 format!("Unsupported datetime resolution: {:?}", resolution),
479 location!(),
480 ));
481 }
482 };
483 Ok(ArrowDataType::Timestamp(time_unit, None))
484 }
485 SQLDataType::Datetime(resolution) => {
486 let time_unit = match resolution {
487 None => TimeUnit::Microsecond,
488 Some(0) => TimeUnit::Second,
489 Some(3) => TimeUnit::Millisecond,
490 Some(6) => TimeUnit::Microsecond,
491 Some(9) => TimeUnit::Nanosecond,
492 _ => {
493 return Err(Error::invalid_input(
494 format!("Unsupported datetime resolution: {:?}", resolution),
495 location!(),
496 ));
497 }
498 };
499 Ok(ArrowDataType::Timestamp(time_unit, None))
500 }
501 SQLDataType::Decimal(number_info) => match number_info {
502 ExactNumberInfo::PrecisionAndScale(precision, scale) => {
503 Ok(ArrowDataType::Decimal128(*precision as u8, *scale as i8))
504 }
505 _ => Err(Error::invalid_input(
506 format!(
507 "Must provide precision and scale for decimal: {:?}",
508 number_info
509 ),
510 location!(),
511 )),
512 },
513 _ => Err(Error::invalid_input(
514 format!(
515 "Unsupported data type: {:?}. Supported types: {:?}",
516 data_type, SUPPORTED_TYPES
517 ),
518 location!(),
519 )),
520 }
521 }
522
523 fn plan_field_access(&self, mut field_access_expr: RawFieldAccessExpr) -> Result<Expr> {
524 let df_schema = DFSchema::try_from(self.schema.as_ref().clone())?;
525 for planner in self.context_provider.get_expr_planners() {
526 match planner.plan_field_access(field_access_expr, &df_schema)? {
527 PlannerResult::Planned(expr) => return Ok(expr),
528 PlannerResult::Original(expr) => {
529 field_access_expr = expr;
530 }
531 }
532 }
533 Err(Error::invalid_input(
534 "Field access could not be planned",
535 location!(),
536 ))
537 }
538
    /// Recursively convert a sqlparser AST expression into a DataFusion
    /// logical `Expr`. This is the core dispatch of the planner; unsupported
    /// constructs return an invalid-input error.
    fn parse_sql_expr(&self, expr: &SQLExpr) -> Result<Expr> {
        match expr {
            SQLExpr::Identifier(id) => {
                // Double-quoted identifiers are treated as string literals
                // (legacy behavior), backtick-quoted ones as verbatim column
                // names; unquoted identifiers go through `Self::column`.
                if id.quote_style == Some('"') {
                    Ok(Expr::Literal(ScalarValue::Utf8(Some(id.value.clone()))))
                } else if id.quote_style == Some('`') {
                    Ok(Expr::Column(Column::from_name(id.value.clone())))
                } else {
                    Ok(Self::column(vec![id.clone()].as_slice()))
                }
            }
            SQLExpr::CompoundIdentifier(ids) => Ok(Self::column(ids.as_slice())),
            SQLExpr::BinaryOp { left, op, right } => self.binary_expr(left, op, right),
            SQLExpr::UnaryOp { op, expr } => self.unary_expr(op, expr),
            SQLExpr::Value(value) => self.value(value),
            // Array literals: only literal elements (or negated numeric
            // literals) are allowed, and all elements are coerced to the
            // datatype of the first element.
            SQLExpr::Array(SQLArray { elem, .. }) => {
                let mut values = vec![];

                let array_literal_error = |pos: usize, value: &_| {
                    Err(Error::invalid_input(
                        format!(
                            "Expected a literal value in array, instead got {} at position {}",
                            value, pos
                        ),
                        location!(),
                    ))
                };

                for (pos, expr) in elem.iter().enumerate() {
                    match expr {
                        SQLExpr::Value(value) => {
                            if let Expr::Literal(value) = self.value(value)? {
                                values.push(value);
                            } else {
                                return array_literal_error(pos, expr);
                            }
                        }
                        // Negative numeric literals arrive as a unary minus
                        // wrapping a number.
                        SQLExpr::UnaryOp {
                            op: UnaryOperator::Minus,
                            expr,
                        } => {
                            if let SQLExpr::Value(Value::Number(number, _)) = expr.as_ref() {
                                if let Expr::Literal(value) = self.number(number, true)? {
                                    values.push(value);
                                } else {
                                    return array_literal_error(pos, expr);
                                }
                            } else {
                                return array_literal_error(pos, expr);
                            }
                        }
                        _ => {
                            return array_literal_error(pos, expr);
                        }
                    }
                }

                // Coerce every element to the first element's datatype; an
                // empty array gets a Null item type.
                let field = if !values.is_empty() {
                    let data_type = values[0].data_type();

                    for value in &mut values {
                        if value.data_type() != data_type {
                            *value = safe_coerce_scalar(value, &data_type).ok_or_else(|| Error::invalid_input(
                                format!("Array expressions must have a consistent datatype. Expected: {}, got: {}", data_type, value.data_type()),
                                location!()
                            ))?;
                        }
                    }
                    Field::new("item", data_type, true)
                } else {
                    Field::new("item", ArrowDataType::Null, true)
                };

                // Materialize the scalars into a single-row ListArray literal.
                let values = values
                    .into_iter()
                    .map(|v| v.to_array().map_err(Error::from))
                    .collect::<Result<Vec<_>>>()?;
                let array_refs = values.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
                let values = concat(&array_refs)?;
                let values = ListArray::try_new(
                    field.into(),
                    OffsetBuffer::from_lengths([values.len()]),
                    values,
                    None,
                )?;

                Ok(Expr::Literal(ScalarValue::List(Arc::new(values))))
            }
            // e.g. `TIMESTAMP '2020-01-01'`: plan as a cast of the string
            // literal to the named type.
            SQLExpr::TypedString { data_type, value } => {
                Ok(Expr::Cast(datafusion::logical_expr::Cast {
                    expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some(value.clone())))),
                    data_type: self.parse_type(data_type)?,
                }))
            }
            SQLExpr::IsFalse(expr) => Ok(Expr::IsFalse(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotFalse(expr) => Ok(Expr::IsNotFalse(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsTrue(expr) => Ok(Expr::IsTrue(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotTrue(expr) => Ok(Expr::IsNotTrue(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNull(expr) => Ok(Expr::IsNull(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotNull(expr) => Ok(Expr::IsNotNull(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::InList {
                expr,
                list,
                negated,
            } => {
                let value_expr = self.parse_sql_expr(expr)?;
                let list_exprs = list
                    .iter()
                    .map(|e| self.parse_sql_expr(e))
                    .collect::<Result<Vec<_>>>()?;
                Ok(value_expr.in_list(list_exprs, *negated))
            }
            // Parenthesized expression: grouping only, no semantics.
            SQLExpr::Nested(inner) => self.parse_sql_expr(inner.as_ref()),
            SQLExpr::Function(_) => self.parse_function(expr.clone()),
            // ILIKE / LIKE differ only in the case-insensitivity flag passed
            // to `Like::new`.
            SQLExpr::ILike {
                negated,
                expr,
                pattern,
                escape_char,
                any: _,
            } => Ok(Expr::Like(Like::new(
                *negated,
                Box::new(self.parse_sql_expr(expr)?),
                Box::new(self.parse_sql_expr(pattern)?),
                escape_char.as_ref().and_then(|c| c.chars().next()),
                true,
            ))),
            SQLExpr::Like {
                negated,
                expr,
                pattern,
                escape_char,
                any: _,
            } => Ok(Expr::Like(Like::new(
                *negated,
                Box::new(self.parse_sql_expr(expr)?),
                Box::new(self.parse_sql_expr(pattern)?),
                escape_char.as_ref().and_then(|c| c.chars().next()),
                false,
            ))),
            SQLExpr::Cast {
                expr, data_type, ..
            } => Ok(Expr::Cast(datafusion::logical_expr::Cast {
                expr: Box::new(self.parse_sql_expr(expr)?),
                data_type: self.parse_type(data_type)?,
            })),
            SQLExpr::JsonAccess { .. } => Err(Error::invalid_input(
                "JSON access is not supported",
                location!(),
            )),
            // Chained accesses like `st.x`, `l[0]`, `st['x']["y"]`: plan the
            // root, then fold each access step via the expression planners.
            SQLExpr::CompoundFieldAccess { root, access_chain } => {
                let mut expr = self.parse_sql_expr(root)?;

                for access in access_chain {
                    let field_access = match access {
                        // Dot access and string subscripts both name a struct
                        // field.
                        AccessExpr::Dot(SQLExpr::Identifier(Ident { value: s, .. }))
                        | AccessExpr::Subscript(Subscript::Index {
                            index:
                                SQLExpr::Value(
                                    Value::SingleQuotedString(s) | Value::DoubleQuotedString(s),
                                ),
                        }) => GetFieldAccess::NamedStructField {
                            name: ScalarValue::from(s.as_str()),
                        },
                        // Any other subscript index is a list index expression.
                        AccessExpr::Subscript(Subscript::Index { index }) => {
                            let key = Box::new(self.parse_sql_expr(index)?);
                            GetFieldAccess::ListIndex { key }
                        }
                        AccessExpr::Subscript(Subscript::Slice { .. }) => {
                            return Err(Error::invalid_input(
                                "Slice subscript is not supported",
                                location!(),
                            ));
                        }
                        _ => {
                            return Err(Error::invalid_input(
                                "Only dot notation or index access is supported for field access",
                                location!(),
                            ));
                        }
                    };

                    let field_access_expr = RawFieldAccessExpr { expr, field_access };
                    expr = self.plan_field_access(field_access_expr)?;
                }

                Ok(expr)
            }
            SQLExpr::Between {
                expr,
                negated,
                low,
                high,
            } => {
                let expr = self.parse_sql_expr(expr)?;
                let low = self.parse_sql_expr(low)?;
                let high = self.parse_sql_expr(high)?;

                let between = Expr::Between(Between::new(
                    Box::new(expr),
                    *negated,
                    Box::new(low),
                    Box::new(high),
                ));
                Ok(between)
            }
            _ => Err(Error::invalid_input(
                format!("Expression '{expr}' is not supported SQL in lance"),
                location!(),
            )),
        }
    }
760
761 pub fn parse_filter(&self, filter: &str) -> Result<Expr> {
766 let ast_expr = parse_sql_filter(filter)?;
768 let expr = self.parse_sql_expr(&ast_expr)?;
769 let schema = Schema::try_from(self.schema.as_ref())?;
770 let resolved = resolve_expr(&expr, &schema)?;
771 coerce_filter_type_to_boolean(resolved)
772 }
773
774 pub fn parse_expr(&self, expr: &str) -> Result<Expr> {
779 let ast_expr = parse_sql_expr(expr)?;
780 let expr = self.parse_sql_expr(&ast_expr)?;
781 let schema = Schema::try_from(self.schema.as_ref())?;
782 let resolved = resolve_expr(&expr, &schema)?;
783 Ok(resolved)
784 }
785
786 fn try_decode_hex_literal(s: &str) -> Option<Vec<u8>> {
792 let hex_bytes = s.as_bytes();
793 let mut decoded_bytes = Vec::with_capacity((hex_bytes.len() + 1) / 2);
794
795 let start_idx = hex_bytes.len() % 2;
796 if start_idx > 0 {
797 decoded_bytes.push(Self::try_decode_hex_char(hex_bytes[0])?);
799 }
800
801 for i in (start_idx..hex_bytes.len()).step_by(2) {
802 let high = Self::try_decode_hex_char(hex_bytes[i])?;
803 let low = Self::try_decode_hex_char(hex_bytes[i + 1])?;
804 decoded_bytes.push((high << 4) | low);
805 }
806
807 Some(decoded_bytes)
808 }
809
810 const fn try_decode_hex_char(c: u8) -> Option<u8> {
814 match c {
815 b'A'..=b'F' => Some(c - b'A' + 10),
816 b'a'..=b'f' => Some(c - b'a' + 10),
817 b'0'..=b'9' => Some(c - b'0'),
818 _ => None,
819 }
820 }
821
822 pub fn optimize_expr(&self, expr: Expr) -> Result<Expr> {
824 let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
825
826 let props = ExecutionProps::default();
829 let simplify_context = SimplifyContext::new(&props).with_schema(df_schema.clone());
830 let simplifier =
831 datafusion::optimizer::simplify_expressions::ExprSimplifier::new(simplify_context);
832
833 let expr = simplifier.simplify(expr)?;
834 let expr = simplifier.coerce(expr, &df_schema)?;
835
836 Ok(expr)
837 }
838
839 pub fn create_physical_expr(&self, expr: &Expr) -> Result<Arc<dyn PhysicalExpr>> {
841 let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
842
843 Ok(datafusion::physical_expr::create_physical_expr(
844 expr,
845 df_schema.as_ref(),
846 &Default::default(),
847 )?)
848 }
849
850 pub fn column_names_in_expr(expr: &Expr) -> Vec<String> {
854 let mut visitor = ColumnCapturingVisitor {
855 current_path: VecDeque::new(),
856 columns: BTreeSet::new(),
857 };
858 expr.visit(&mut visitor).unwrap();
859 visitor.columns.into_iter().collect()
860 }
861}
862
/// Tree visitor that collects the dotted names of all columns referenced by
/// an expression, including nested field accesses such as `st.x`.
struct ColumnCapturingVisitor {
    // Field-name parts accumulated while descending through `get_field`
    // calls; joined onto the root column name once it is reached.
    current_path: VecDeque<String>,
    // De-duplicated, sorted set of captured column paths.
    columns: BTreeSet<String>,
}
868
869impl TreeNodeVisitor<'_> for ColumnCapturingVisitor {
870 type Node = Expr;
871
872 fn f_down(&mut self, node: &Self::Node) -> DFResult<TreeNodeRecursion> {
873 match node {
874 Expr::Column(Column { name, .. }) => {
875 let mut path = name.clone();
876 for part in self.current_path.drain(..) {
877 path.push('.');
878 path.push_str(&part);
879 }
880 self.columns.insert(path);
881 self.current_path.clear();
882 }
883 Expr::ScalarFunction(udf) => {
884 if udf.name() == GetFieldFunc::default().name() {
885 if let Some(name) = get_as_string_scalar_opt(&udf.args[1]) {
886 self.current_path.push_front(name.to_string())
887 } else {
888 self.current_path.clear();
889 }
890 } else {
891 self.current_path.clear();
892 }
893 }
894 _ => {
895 self.current_path.clear();
896 }
897 }
898
899 Ok(TreeNodeRecursion::Continue)
900 }
901}
902
903#[cfg(test)]
904mod tests {
905
906 use crate::logical_expr::ExprExt;
907
908 use super::*;
909
910 use arrow::datatypes::Float64Type;
911 use arrow_array::{
912 ArrayRef, BooleanArray, Float32Array, Int32Array, Int64Array, RecordBatch, StringArray,
913 StructArray, TimestampMicrosecondArray, TimestampMillisecondArray,
914 TimestampNanosecondArray, TimestampSecondArray,
915 };
916 use arrow_schema::{DataType, Fields, Schema};
917 use datafusion::{
918 logical_expr::{lit, Cast},
919 prelude::{array_element, get_field},
920 };
921 use datafusion_functions::core::expr_ext::FieldAccessor;
922
    // Verifies end-to-end filter planning (comparison, struct field access,
    // OR/IN) and physical evaluation against a record batch.
    #[test]
    fn test_parse_filter_simple() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("i", DataType::Int32, false),
            Field::new("s", DataType::Utf8, true),
            Field::new(
                "st",
                DataType::Struct(Fields::from(vec![
                    Field::new("x", DataType::Float32, false),
                    Field::new("y", DataType::Float32, false),
                ])),
                true,
            ),
        ]));

        let planner = Planner::new(schema.clone());

        let expected = col("i")
            .gt(lit(3_i32))
            .and(col("st").field_newstyle("x").lt_eq(lit(5.0_f32)))
            .and(
                col("s")
                    .eq(lit("str-4"))
                    .or(col("s").in_list(vec![lit("str-4"), lit("str-5")], false)),
            );

        // `==` is accepted as an alias for `=`.
        let expr = planner
            .parse_filter("i > 3 AND st.x <= 5.0 AND (s == 'str-4' OR s in ('str-4', 'str-5'))")
            .unwrap();
        assert_eq!(expr, expected);

        let expr = planner
            .parse_filter("i > 3 AND st.x <= 5.0 AND (s = 'str-4' OR s in ('str-4', 'str-5'))")
            .unwrap();

        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
                Arc::new(StringArray::from_iter_values(
                    (0..10).map(|v| format!("str-{}", v)),
                )),
                Arc::new(StructArray::from(vec![
                    (
                        Arc::new(Field::new("x", DataType::Float32, false)),
                        Arc::new(Float32Array::from_iter_values((0..10).map(|v| v as f32)))
                            as ArrayRef,
                    ),
                    (
                        Arc::new(Field::new("y", DataType::Float32, false)),
                        Arc::new(Float32Array::from_iter_values(
                            (0..10).map(|v| (v * 10) as f32),
                        )),
                    ),
                ])),
            ],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        // Rows 4 and 5 satisfy: i > 3, st.x <= 5.0, and s in {str-4, str-5}.
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, false, true, true, false, false, false, false
            ])
        );
    }
993
    // Verifies that dotted, backticked, and subscript syntaxes for nested
    // struct fields all plan to the same `get_field` expression tree.
    #[test]
    fn test_nested_col_refs() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("s0", DataType::Utf8, true),
            Field::new(
                "st",
                DataType::Struct(Fields::from(vec![
                    Field::new("s1", DataType::Utf8, true),
                    Field::new(
                        "st",
                        DataType::Struct(Fields::from(vec![Field::new(
                            "s2",
                            DataType::Utf8,
                            true,
                        )])),
                        true,
                    ),
                ])),
                true,
            ),
        ]));

        let planner = Planner::new(schema);

        // Parse `<expr> = 'val'` and check the left-hand side of the planned
        // binary expression against the expected column expression.
        fn assert_column_eq(planner: &Planner, expr: &str, expected: &Expr) {
            let expr = planner.parse_filter(&format!("{expr} = 'val'")).unwrap();
            assert!(matches!(
                expr,
                Expr::BinaryExpr(BinaryExpr {
                    left: _,
                    op: Operator::Eq,
                    right: _
                })
            ));
            if let Expr::BinaryExpr(BinaryExpr { left, .. }) = expr {
                assert_eq!(left.as_ref(), expected);
            }
        }

        let expected = Expr::Column(Column::new_unqualified("s0"));
        assert_column_eq(&planner, "s0", &expected);
        assert_column_eq(&planner, "`s0`", &expected);

        let expected = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
            args: vec![
                Expr::Column(Column::new_unqualified("st")),
                Expr::Literal(ScalarValue::Utf8(Some("s1".to_string()))),
            ],
        });
        assert_column_eq(&planner, "st.s1", &expected);
        assert_column_eq(&planner, "`st`.`s1`", &expected);
        assert_column_eq(&planner, "st.`s1`", &expected);

        // Two levels of nesting: get_field(get_field(st, "st"), "s2").
        let expected = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
            args: vec![
                Expr::ScalarFunction(ScalarFunction {
                    func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
                    args: vec![
                        Expr::Column(Column::new_unqualified("st")),
                        Expr::Literal(ScalarValue::Utf8(Some("st".to_string()))),
                    ],
                }),
                Expr::Literal(ScalarValue::Utf8(Some("s2".to_string()))),
            ],
        });

        assert_column_eq(&planner, "st.st.s2", &expected);
        assert_column_eq(&planner, "`st`.`st`.`s2`", &expected);
        assert_column_eq(&planner, "st.st.`s2`", &expected);
        assert_column_eq(&planner, "st['st'][\"s2\"]", &expected);
    }
1067
    // Verifies list-index subscripts and chained index + field access.
    #[test]
    fn test_nested_list_refs() {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "l",
            DataType::List(Arc::new(Field::new(
                "item",
                DataType::Struct(Fields::from(vec![Field::new("f1", DataType::Utf8, true)])),
                true,
            ))),
            true,
        )]));

        let planner = Planner::new(schema);

        let expected = array_element(col("l"), lit(0_i64));
        let expr = planner.parse_expr("l[0]").unwrap();
        assert_eq!(expr, expected);

        let expected = get_field(array_element(col("l"), lit(0_i64)), "f1");
        let expr = planner.parse_expr("l[0]['f1']").unwrap();
        assert_eq!(expr, expected);

    }
1095
    // Verifies unary minus both on literals (folded) and on parenthesized
    // expressions (planned as Expr::Negative), including evaluation.
    #[test]
    fn test_negative_expressions() {
        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));

        let planner = Planner::new(schema.clone());

        let expected = col("x")
            .gt(lit(-3_i64))
            .and(col("x").lt(-(lit(-5_i64) + lit(3_i64))));

        let expr = planner.parse_filter("x > -3 AND x < -(-5 + 3)").unwrap();

        assert_eq!(expr, expected);

        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(Int64Array::from_iter_values(-5..5)) as ArrayRef],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        // Expect true exactly for -2..=1 (x > -3 AND x < 2).
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, true, true, true, true, false, false, false
            ])
        );
    }
1125
    // Verifies that arrays of negative float literals plan to a single
    // list-typed literal.
    #[test]
    fn test_negative_array_expressions() {
        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));

        let planner = Planner::new(schema);

        let expected = Expr::Literal(ScalarValue::List(Arc::new(
            ListArray::from_iter_primitive::<Float64Type, _, _>(vec![Some(
                [-1_f64, -2.0, -3.0, -4.0, -5.0].map(Some),
            )]),
        )));

        let expr = planner
            .parse_expr("[-1.0, -2.0, -3.0, -4.0, -5.0]")
            .unwrap();

        assert_eq!(expr, expected);
    }
1144
1145 #[test]
1146 fn test_sql_like() {
1147 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1148
1149 let planner = Planner::new(schema.clone());
1150
1151 let expected = col("s").like(lit("str-4"));
1152 let expr = planner.parse_filter("s LIKE 'str-4'").unwrap();
1154 assert_eq!(expr, expected);
1155 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1156
1157 let batch = RecordBatch::try_new(
1158 schema,
1159 vec![Arc::new(StringArray::from_iter_values(
1160 (0..10).map(|v| format!("str-{}", v)),
1161 ))],
1162 )
1163 .unwrap();
1164 let predicates = physical_expr.evaluate(&batch).unwrap();
1165 assert_eq!(
1166 predicates.into_array(0).unwrap().as_ref(),
1167 &BooleanArray::from(vec![
1168 false, false, false, false, true, false, false, false, false, false
1169 ])
1170 );
1171 }
1172
1173 #[test]
1174 fn test_not_like() {
1175 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1176
1177 let planner = Planner::new(schema.clone());
1178
1179 let expected = col("s").not_like(lit("str-4"));
1180 let expr = planner.parse_filter("s NOT LIKE 'str-4'").unwrap();
1182 assert_eq!(expr, expected);
1183 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1184
1185 let batch = RecordBatch::try_new(
1186 schema,
1187 vec![Arc::new(StringArray::from_iter_values(
1188 (0..10).map(|v| format!("str-{}", v)),
1189 ))],
1190 )
1191 .unwrap();
1192 let predicates = physical_expr.evaluate(&batch).unwrap();
1193 assert_eq!(
1194 predicates.into_array(0).unwrap().as_ref(),
1195 &BooleanArray::from(vec![
1196 true, true, true, true, false, true, true, true, true, true
1197 ])
1198 );
1199 }
1200
1201 #[test]
1202 fn test_sql_is_in() {
1203 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1204
1205 let planner = Planner::new(schema.clone());
1206
1207 let expected = col("s").in_list(vec![lit("str-4"), lit("str-5")], false);
1208 let expr = planner.parse_filter("s IN ('str-4', 'str-5')").unwrap();
1210 assert_eq!(expr, expected);
1211 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1212
1213 let batch = RecordBatch::try_new(
1214 schema,
1215 vec![Arc::new(StringArray::from_iter_values(
1216 (0..10).map(|v| format!("str-{}", v)),
1217 ))],
1218 )
1219 .unwrap();
1220 let predicates = physical_expr.evaluate(&batch).unwrap();
1221 assert_eq!(
1222 predicates.into_array(0).unwrap().as_ref(),
1223 &BooleanArray::from(vec![
1224 false, false, false, false, true, true, false, false, false, false
1225 ])
1226 );
1227 }
1228
1229 #[test]
1230 fn test_sql_is_null() {
1231 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1232
1233 let planner = Planner::new(schema.clone());
1234
1235 let expected = col("s").is_null();
1236 let expr = planner.parse_filter("s IS NULL").unwrap();
1237 assert_eq!(expr, expected);
1238 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1239
1240 let batch = RecordBatch::try_new(
1241 schema,
1242 vec![Arc::new(StringArray::from_iter((0..10).map(|v| {
1243 if v % 3 == 0 {
1244 Some(format!("str-{}", v))
1245 } else {
1246 None
1247 }
1248 })))],
1249 )
1250 .unwrap();
1251 let predicates = physical_expr.evaluate(&batch).unwrap();
1252 assert_eq!(
1253 predicates.into_array(0).unwrap().as_ref(),
1254 &BooleanArray::from(vec![
1255 false, true, true, false, true, true, false, true, true, false
1256 ])
1257 );
1258
1259 let expr = planner.parse_filter("s IS NOT NULL").unwrap();
1260 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1261 let predicates = physical_expr.evaluate(&batch).unwrap();
1262 assert_eq!(
1263 predicates.into_array(0).unwrap().as_ref(),
1264 &BooleanArray::from(vec![
1265 true, false, false, true, false, false, true, false, false, true,
1266 ])
1267 );
1268 }
1269
1270 #[test]
1271 fn test_sql_invert() {
1272 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Boolean, true)]));
1273
1274 let planner = Planner::new(schema.clone());
1275
1276 let expr = planner.parse_filter("NOT s").unwrap();
1277 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1278
1279 let batch = RecordBatch::try_new(
1280 schema,
1281 vec![Arc::new(BooleanArray::from_iter(
1282 (0..10).map(|v| Some(v % 3 == 0)),
1283 ))],
1284 )
1285 .unwrap();
1286 let predicates = physical_expr.evaluate(&batch).unwrap();
1287 assert_eq!(
1288 predicates.into_array(0).unwrap().as_ref(),
1289 &BooleanArray::from(vec![
1290 false, true, true, false, true, true, false, true, true, false
1291 ])
1292 );
1293 }
1294
1295 #[test]
1296 fn test_sql_cast() {
1297 let cases = &[
1298 (
1299 "x = cast('2021-01-01 00:00:00' as timestamp)",
1300 ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
1301 ),
1302 (
1303 "x = cast('2021-01-01 00:00:00' as timestamp(0))",
1304 ArrowDataType::Timestamp(TimeUnit::Second, None),
1305 ),
1306 (
1307 "x = cast('2021-01-01 00:00:00.123' as timestamp(9))",
1308 ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
1309 ),
1310 (
1311 "x = cast('2021-01-01 00:00:00.123' as datetime(9))",
1312 ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
1313 ),
1314 ("x = cast('2021-01-01' as date)", ArrowDataType::Date32),
1315 (
1316 "x = cast('1.238' as decimal(9,3))",
1317 ArrowDataType::Decimal128(9, 3),
1318 ),
1319 ("x = cast(1 as float)", ArrowDataType::Float32),
1320 ("x = cast(1 as double)", ArrowDataType::Float64),
1321 ("x = cast(1 as tinyint)", ArrowDataType::Int8),
1322 ("x = cast(1 as smallint)", ArrowDataType::Int16),
1323 ("x = cast(1 as int)", ArrowDataType::Int32),
1324 ("x = cast(1 as integer)", ArrowDataType::Int32),
1325 ("x = cast(1 as bigint)", ArrowDataType::Int64),
1326 ("x = cast(1 as tinyint unsigned)", ArrowDataType::UInt8),
1327 ("x = cast(1 as smallint unsigned)", ArrowDataType::UInt16),
1328 ("x = cast(1 as int unsigned)", ArrowDataType::UInt32),
1329 ("x = cast(1 as integer unsigned)", ArrowDataType::UInt32),
1330 ("x = cast(1 as bigint unsigned)", ArrowDataType::UInt64),
1331 ("x = cast(1 as boolean)", ArrowDataType::Boolean),
1332 ("x = cast(1 as string)", ArrowDataType::Utf8),
1333 ];
1334
1335 for (sql, expected_data_type) in cases {
1336 let schema = Arc::new(Schema::new(vec![Field::new(
1337 "x",
1338 expected_data_type.clone(),
1339 true,
1340 )]));
1341 let planner = Planner::new(schema.clone());
1342 let expr = planner.parse_filter(sql).unwrap();
1343
1344 let expected_value_str = sql
1346 .split("cast(")
1347 .nth(1)
1348 .unwrap()
1349 .split(" as")
1350 .next()
1351 .unwrap();
1352 let expected_value_str = expected_value_str.trim_matches('\'');
1354
1355 match expr {
1356 Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
1357 Expr::Cast(Cast { expr, data_type }) => {
1358 match expr.as_ref() {
1359 Expr::Literal(ScalarValue::Utf8(Some(value_str))) => {
1360 assert_eq!(value_str, expected_value_str);
1361 }
1362 Expr::Literal(ScalarValue::Int64(Some(value))) => {
1363 assert_eq!(*value, 1);
1364 }
1365 _ => panic!("Expected cast to be applied to literal"),
1366 }
1367 assert_eq!(data_type, expected_data_type);
1368 }
1369 _ => panic!("Expected right to be a cast"),
1370 },
1371 _ => panic!("Expected binary expression"),
1372 }
1373 }
1374 }
1375
1376 #[test]
1377 fn test_sql_literals() {
1378 let cases = &[
1379 (
1380 "x = timestamp '2021-01-01 00:00:00'",
1381 ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
1382 ),
1383 (
1384 "x = timestamp(0) '2021-01-01 00:00:00'",
1385 ArrowDataType::Timestamp(TimeUnit::Second, None),
1386 ),
1387 (
1388 "x = timestamp(9) '2021-01-01 00:00:00.123'",
1389 ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
1390 ),
1391 ("x = date '2021-01-01'", ArrowDataType::Date32),
1392 ("x = decimal(9,3) '1.238'", ArrowDataType::Decimal128(9, 3)),
1393 ];
1394
1395 for (sql, expected_data_type) in cases {
1396 let schema = Arc::new(Schema::new(vec![Field::new(
1397 "x",
1398 expected_data_type.clone(),
1399 true,
1400 )]));
1401 let planner = Planner::new(schema.clone());
1402 let expr = planner.parse_filter(sql).unwrap();
1403
1404 let expected_value_str = sql.split('\'').nth(1).unwrap();
1405
1406 match expr {
1407 Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
1408 Expr::Cast(Cast { expr, data_type }) => {
1409 match expr.as_ref() {
1410 Expr::Literal(ScalarValue::Utf8(Some(value_str))) => {
1411 assert_eq!(value_str, expected_value_str);
1412 }
1413 _ => panic!("Expected cast to be applied to literal"),
1414 }
1415 assert_eq!(data_type, expected_data_type);
1416 }
1417 _ => panic!("Expected right to be a cast"),
1418 },
1419 _ => panic!("Expected binary expression"),
1420 }
1421 }
1422 }
1423
1424 #[test]
1425 fn test_sql_array_literals() {
1426 let cases = [
1427 (
1428 "x = [1, 2, 3]",
1429 ArrowDataType::List(Arc::new(Field::new("item", ArrowDataType::Int64, true))),
1430 ),
1431 (
1432 "x = [1, 2, 3]",
1433 ArrowDataType::FixedSizeList(
1434 Arc::new(Field::new("item", ArrowDataType::Int64, true)),
1435 3,
1436 ),
1437 ),
1438 ];
1439
1440 for (sql, expected_data_type) in cases {
1441 let schema = Arc::new(Schema::new(vec![Field::new(
1442 "x",
1443 expected_data_type.clone(),
1444 true,
1445 )]));
1446 let planner = Planner::new(schema.clone());
1447 let expr = planner.parse_filter(sql).unwrap();
1448 let expr = planner.optimize_expr(expr).unwrap();
1449
1450 match expr {
1451 Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
1452 Expr::Literal(value) => {
1453 assert_eq!(&value.data_type(), &expected_data_type);
1454 }
1455 _ => panic!("Expected right to be a literal"),
1456 },
1457 _ => panic!("Expected binary expression"),
1458 }
1459 }
1460 }
1461
1462 #[test]
1463 fn test_sql_between() {
1464 use arrow_array::{Float64Array, Int32Array, TimestampMicrosecondArray};
1465 use arrow_schema::{DataType, Field, Schema, TimeUnit};
1466 use std::sync::Arc;
1467
1468 let schema = Arc::new(Schema::new(vec![
1469 Field::new("x", DataType::Int32, false),
1470 Field::new("y", DataType::Float64, false),
1471 Field::new(
1472 "ts",
1473 DataType::Timestamp(TimeUnit::Microsecond, None),
1474 false,
1475 ),
1476 ]));
1477
1478 let planner = Planner::new(schema.clone());
1479
1480 let expr = planner
1482 .parse_filter("x BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
1483 .unwrap();
1484 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1485
1486 let base_ts = 1704067200000000_i64; let ts_array = TimestampMicrosecondArray::from_iter_values(
1490 (0..10).map(|i| base_ts + i * 1_000_000), );
1492
1493 let batch = RecordBatch::try_new(
1494 schema,
1495 vec![
1496 Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
1497 Arc::new(Float64Array::from_iter_values((0..10).map(|v| v as f64))),
1498 Arc::new(ts_array),
1499 ],
1500 )
1501 .unwrap();
1502
1503 let predicates = physical_expr.evaluate(&batch).unwrap();
1504 assert_eq!(
1505 predicates.into_array(0).unwrap().as_ref(),
1506 &BooleanArray::from(vec![
1507 false, false, false, true, true, true, true, true, false, false
1508 ])
1509 );
1510
1511 let expr = planner
1513 .parse_filter("x NOT BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
1514 .unwrap();
1515 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1516
1517 let predicates = physical_expr.evaluate(&batch).unwrap();
1518 assert_eq!(
1519 predicates.into_array(0).unwrap().as_ref(),
1520 &BooleanArray::from(vec![
1521 true, true, true, false, false, false, false, false, true, true
1522 ])
1523 );
1524
1525 let expr = planner.parse_filter("y BETWEEN 2.5 AND 6.5").unwrap();
1527 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1528
1529 let predicates = physical_expr.evaluate(&batch).unwrap();
1530 assert_eq!(
1531 predicates.into_array(0).unwrap().as_ref(),
1532 &BooleanArray::from(vec![
1533 false, false, false, true, true, true, true, false, false, false
1534 ])
1535 );
1536
1537 let expr = planner
1539 .parse_filter(
1540 "ts BETWEEN timestamp '2024-01-01 00:00:03' AND timestamp '2024-01-01 00:00:07'",
1541 )
1542 .unwrap();
1543 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1544
1545 let predicates = physical_expr.evaluate(&batch).unwrap();
1546 assert_eq!(
1547 predicates.into_array(0).unwrap().as_ref(),
1548 &BooleanArray::from(vec![
1549 false, false, false, true, true, true, true, true, false, false
1550 ])
1551 );
1552 }
1553
1554 #[test]
1555 fn test_sql_comparison() {
1556 let batch: Vec<(&str, ArrayRef)> = vec![
1558 (
1559 "timestamp_s",
1560 Arc::new(TimestampSecondArray::from_iter_values(0..10)),
1561 ),
1562 (
1563 "timestamp_ms",
1564 Arc::new(TimestampMillisecondArray::from_iter_values(0..10)),
1565 ),
1566 (
1567 "timestamp_us",
1568 Arc::new(TimestampMicrosecondArray::from_iter_values(0..10)),
1569 ),
1570 (
1571 "timestamp_ns",
1572 Arc::new(TimestampNanosecondArray::from_iter_values(4995..5005)),
1573 ),
1574 ];
1575 let batch = RecordBatch::try_from_iter(batch).unwrap();
1576
1577 let planner = Planner::new(batch.schema());
1578
1579 let expressions = &[
1581 "timestamp_s >= TIMESTAMP '1970-01-01 00:00:05'",
1582 "timestamp_ms >= TIMESTAMP '1970-01-01 00:00:00.005'",
1583 "timestamp_us >= TIMESTAMP '1970-01-01 00:00:00.000005'",
1584 "timestamp_ns >= TIMESTAMP '1970-01-01 00:00:00.000005'",
1585 ];
1586
1587 let expected: ArrayRef = Arc::new(BooleanArray::from_iter(
1588 std::iter::repeat(Some(false))
1589 .take(5)
1590 .chain(std::iter::repeat(Some(true)).take(5)),
1591 ));
1592 for expression in expressions {
1593 let logical_expr = planner.parse_filter(expression).unwrap();
1595 let logical_expr = planner.optimize_expr(logical_expr).unwrap();
1596 let physical_expr = planner.create_physical_expr(&logical_expr).unwrap();
1597
1598 let result = physical_expr.evaluate(&batch).unwrap();
1600 let result = result.into_array(batch.num_rows()).unwrap();
1601 assert_eq!(&expected, &result, "unexpected result for {}", expression);
1602 }
1603 }
1604
1605 #[test]
1606 fn test_columns_in_expr() {
1607 let expr = col("s0").gt(lit("value")).and(
1608 col("st")
1609 .field("st")
1610 .field("s2")
1611 .eq(lit("value"))
1612 .or(col("st")
1613 .field("s1")
1614 .in_list(vec![lit("value 1"), lit("value 2")], false)),
1615 );
1616
1617 let columns = Planner::column_names_in_expr(&expr);
1618 assert_eq!(columns, vec!["s0", "st.s1", "st.st.s2"]);
1619 }
1620
1621 #[test]
1622 fn test_parse_binary_expr() {
1623 let bin_str = "x'616263'";
1624
1625 let schema = Arc::new(Schema::new(vec![Field::new(
1626 "binary",
1627 DataType::Binary,
1628 true,
1629 )]));
1630 let planner = Planner::new(schema);
1631 let expr = planner.parse_expr(bin_str).unwrap();
1632 assert_eq!(
1633 expr,
1634 Expr::Literal(ScalarValue::Binary(Some(vec![b'a', b'b', b'c'])))
1635 );
1636 }
1637
1638 #[test]
1639 fn test_lance_context_provider_expr_planners() {
1640 let ctx_provider = LanceContextProvider::default();
1641 assert!(!ctx_provider.get_expr_planners().is_empty());
1642 }
1643}