1use std::borrow::Cow;
7use std::collections::{BTreeSet, VecDeque};
8use std::sync::Arc;
9
10use crate::expr::safe_coerce_scalar;
11use crate::logical_expr::{coerce_filter_type_to_boolean, get_as_string_scalar_opt, resolve_expr};
12use crate::sql::{parse_sql_expr, parse_sql_filter};
13use arrow::compute::CastOptions;
14use arrow_array::ListArray;
15use arrow_buffer::OffsetBuffer;
16use arrow_schema::{DataType as ArrowDataType, Field, SchemaRef, TimeUnit};
17use arrow_select::concat::concat;
18use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
19use datafusion::common::DFSchema;
20use datafusion::config::ConfigOptions;
21use datafusion::error::Result as DFResult;
22use datafusion::execution::config::SessionConfig;
23use datafusion::execution::context::SessionState;
24use datafusion::execution::runtime_env::RuntimeEnvBuilder;
25use datafusion::execution::session_state::SessionStateBuilder;
26use datafusion::logical_expr::expr::ScalarFunction;
27use datafusion::logical_expr::planner::{ExprPlanner, PlannerResult, RawFieldAccessExpr};
28use datafusion::logical_expr::{
29 AggregateUDF, ColumnarValue, GetFieldAccess, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
30 WindowUDF,
31};
32use datafusion::optimizer::simplify_expressions::SimplifyContext;
33use datafusion::sql::planner::{ContextProvider, ParserOptions, PlannerContext, SqlToRel};
34use datafusion::sql::sqlparser::ast::{
35 AccessExpr, Array as SQLArray, BinaryOperator, DataType as SQLDataType, ExactNumberInfo,
36 Expr as SQLExpr, Function, FunctionArg, FunctionArgExpr, FunctionArguments, Ident, Subscript,
37 TimezoneInfo, UnaryOperator, Value,
38};
39use datafusion::{
40 common::Column,
41 logical_expr::{col, Between, BinaryExpr, Like, Operator},
42 physical_expr::execution_props::ExecutionProps,
43 physical_plan::PhysicalExpr,
44 prelude::Expr,
45 scalar::ScalarValue,
46};
47use datafusion_functions::core::getfield::GetFieldFunc;
48use lance_arrow::cast::cast_with_options;
49use lance_core::datatypes::Schema;
50use snafu::location;
51
52use lance_core::{Error, Result};
53
/// Scalar UDF (`_cast_list_f16`) that rewrites a `List`/`FixedSizeList` of
/// `Float32`/`Float16` values into the same list shape with `Float16`
/// elements (see `return_type`/`invoke` below).
#[derive(Debug, Clone)]
struct CastListF16Udf {
    // Accepts exactly one argument of any type; the list/element types are
    // validated later during planning and execution.
    signature: Signature,
}
58
59impl CastListF16Udf {
60 pub fn new() -> Self {
61 Self {
62 signature: Signature::any(1, Volatility::Immutable),
63 }
64 }
65}
66
67impl ScalarUDFImpl for CastListF16Udf {
68 fn as_any(&self) -> &dyn std::any::Any {
69 self
70 }
71
72 fn name(&self) -> &str {
73 "_cast_list_f16"
74 }
75
76 fn signature(&self) -> &Signature {
77 &self.signature
78 }
79
80 fn return_type(&self, arg_types: &[ArrowDataType]) -> DFResult<ArrowDataType> {
81 let input = &arg_types[0];
82 match input {
83 ArrowDataType::FixedSizeList(field, size) => {
84 if field.data_type() != &ArrowDataType::Float32
85 && field.data_type() != &ArrowDataType::Float16
86 {
87 return Err(datafusion::error::DataFusionError::Execution(
88 "cast_list_f16 only supports list of float32 or float16".to_string(),
89 ));
90 }
91 Ok(ArrowDataType::FixedSizeList(
92 Arc::new(Field::new(
93 field.name(),
94 ArrowDataType::Float16,
95 field.is_nullable(),
96 )),
97 *size,
98 ))
99 }
100 ArrowDataType::List(field) => {
101 if field.data_type() != &ArrowDataType::Float32
102 && field.data_type() != &ArrowDataType::Float16
103 {
104 return Err(datafusion::error::DataFusionError::Execution(
105 "cast_list_f16 only supports list of float32 or float16".to_string(),
106 ));
107 }
108 Ok(ArrowDataType::List(Arc::new(Field::new(
109 field.name(),
110 ArrowDataType::Float16,
111 field.is_nullable(),
112 ))))
113 }
114 _ => Err(datafusion::error::DataFusionError::Execution(
115 "cast_list_f16 only supports FixedSizeList/List arguments".to_string(),
116 )),
117 }
118 }
119
120 fn invoke(&self, args: &[ColumnarValue]) -> DFResult<ColumnarValue> {
121 let ColumnarValue::Array(arr) = &args[0] else {
122 return Err(datafusion::error::DataFusionError::Execution(
123 "cast_list_f16 only supports array arguments".to_string(),
124 ));
125 };
126
127 let to_type = match arr.data_type() {
128 ArrowDataType::FixedSizeList(field, size) => ArrowDataType::FixedSizeList(
129 Arc::new(Field::new(
130 field.name(),
131 ArrowDataType::Float16,
132 field.is_nullable(),
133 )),
134 *size,
135 ),
136 ArrowDataType::List(field) => ArrowDataType::List(Arc::new(Field::new(
137 field.name(),
138 ArrowDataType::Float16,
139 field.is_nullable(),
140 ))),
141 _ => {
142 return Err(datafusion::error::DataFusionError::Execution(
143 "cast_list_f16 only supports array arguments".to_string(),
144 ));
145 }
146 };
147
148 let res = cast_with_options(arr.as_ref(), &to_type, &CastOptions::default())?;
149 Ok(ColumnarValue::Array(res))
150 }
151}
152
/// Minimal DataFusion `ContextProvider` used to plan SQL expressions without
/// a full session: it resolves registered functions but rejects any table
/// reference.
struct LanceContextProvider {
    options: datafusion::config::ConfigOptions,
    state: SessionState,
    // Captured at build time so `get_expr_planners` can return a slice.
    expr_planners: Vec<Arc<dyn ExprPlanner>>,
}
159
160impl Default for LanceContextProvider {
161 fn default() -> Self {
162 let config = SessionConfig::new();
163 let runtime = RuntimeEnvBuilder::new().build_arc().unwrap();
164 let mut state_builder = SessionStateBuilder::new()
165 .with_config(config)
166 .with_runtime_env(runtime)
167 .with_default_features();
168
169 let expr_planners = state_builder.expr_planners().as_ref().unwrap().clone();
174
175 Self {
176 options: ConfigOptions::default(),
177 state: state_builder.build(),
178 expr_planners,
179 }
180 }
181}
182
183impl ContextProvider for LanceContextProvider {
184 fn get_table_source(
185 &self,
186 name: datafusion::sql::TableReference,
187 ) -> DFResult<Arc<dyn datafusion::logical_expr::TableSource>> {
188 Err(datafusion::error::DataFusionError::NotImplemented(format!(
189 "Attempt to reference inner table {} not supported",
190 name
191 )))
192 }
193
194 fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
195 self.state.aggregate_functions().get(name).cloned()
196 }
197
198 fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
199 self.state.window_functions().get(name).cloned()
200 }
201
202 fn get_function_meta(&self, f: &str) -> Option<Arc<ScalarUDF>> {
203 match f {
204 "_cast_list_f16" => Some(Arc::new(ScalarUDF::new_from_impl(CastListF16Udf::new()))),
207 _ => self.state.scalar_functions().get(f).cloned(),
208 }
209 }
210
211 fn get_variable_type(&self, _: &[String]) -> Option<ArrowDataType> {
212 None
214 }
215
216 fn options(&self) -> &datafusion::config::ConfigOptions {
217 &self.options
218 }
219
220 fn udf_names(&self) -> Vec<String> {
221 self.state.scalar_functions().keys().cloned().collect()
222 }
223
224 fn udaf_names(&self) -> Vec<String> {
225 self.state.aggregate_functions().keys().cloned().collect()
226 }
227
228 fn udwf_names(&self) -> Vec<String> {
229 self.state.window_functions().keys().cloned().collect()
230 }
231
232 fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
233 &self.expr_planners
234 }
235}
236
/// Plans SQL filter/expression strings into DataFusion logical and physical
/// expressions against a fixed Arrow schema.
pub struct Planner {
    schema: SchemaRef,
    context_provider: LanceContextProvider,
}
241
242impl Planner {
    /// Create a planner bound to `schema`; all parsed expressions are
    /// resolved against this schema.
    pub fn new(schema: SchemaRef) -> Self {
        Self {
            schema,
            context_provider: LanceContextProvider::default(),
        }
    }
249
250 fn column(idents: &[Ident]) -> Expr {
251 let mut column = col(&idents[0].value);
252 for ident in &idents[1..] {
253 column = Expr::ScalarFunction(ScalarFunction {
254 args: vec![
255 column,
256 Expr::Literal(ScalarValue::Utf8(Some(ident.value.clone()))),
257 ],
258 func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
259 });
260 }
261 column
262 }
263
    /// Map a SQL binary operator onto the DataFusion equivalent.
    ///
    /// Operators outside this list (e.g. bitwise ops) are rejected with an
    /// invalid-input error.
    fn binary_op(&self, op: &BinaryOperator) -> Result<Operator> {
        Ok(match op {
            BinaryOperator::Plus => Operator::Plus,
            BinaryOperator::Minus => Operator::Minus,
            BinaryOperator::Multiply => Operator::Multiply,
            BinaryOperator::Divide => Operator::Divide,
            BinaryOperator::Modulo => Operator::Modulo,
            BinaryOperator::StringConcat => Operator::StringConcat,
            BinaryOperator::Gt => Operator::Gt,
            BinaryOperator::Lt => Operator::Lt,
            BinaryOperator::GtEq => Operator::GtEq,
            BinaryOperator::LtEq => Operator::LtEq,
            BinaryOperator::Eq => Operator::Eq,
            BinaryOperator::NotEq => Operator::NotEq,
            BinaryOperator::And => Operator::And,
            BinaryOperator::Or => Operator::Or,
            _ => {
                return Err(Error::invalid_input(
                    format!("Operator {op} is not supported"),
                    location!(),
                ));
            }
        })
    }
288
289 fn binary_expr(&self, left: &SQLExpr, op: &BinaryOperator, right: &SQLExpr) -> Result<Expr> {
290 Ok(Expr::BinaryExpr(BinaryExpr::new(
291 Box::new(self.parse_sql_expr(left)?),
292 self.binary_op(op)?,
293 Box::new(self.parse_sql_expr(right)?),
294 )))
295 }
296
    /// Translate a SQL unary operator application.
    ///
    /// `-<number literal>` is folded into a negative literal at parse time
    /// (trying i64 first, then f64); any other negated operand becomes
    /// `Expr::Negative`. `NOT`/`~` become `Expr::Not`.
    fn unary_expr(&self, op: &UnaryOperator, expr: &SQLExpr) -> Result<Expr> {
        Ok(match op {
            UnaryOperator::Not | UnaryOperator::PGBitwiseNot => {
                Expr::Not(Box::new(self.parse_sql_expr(expr)?))
            }

            UnaryOperator::Minus => {
                use datafusion::logical_expr::lit;
                match expr {
                    // Fold the sign into the literal: parse as i64 first,
                    // fall back to f64 for fractional/out-of-range values.
                    SQLExpr::Value(Value::Number(n, _)) => match n.parse::<i64>() {
                        Ok(n) => lit(-n),
                        Err(_) => lit(-n
                            .parse::<f64>()
                            .map_err(|_e| {
                                Error::invalid_input(
                                    format!("negative operator can be only applied to integer and float operands, got: {n}"),
                                    location!(),
                                )
                            })?),
                    },
                    _ => {
                        Expr::Negative(Box::new(self.parse_sql_expr(expr)?))
                    }
                }
            }

            _ => {
                return Err(Error::invalid_input(
                    format!("Unary operator '{:?}' is not supported", op),
                    location!(),
                ));
            }
        })
    }
331
332 fn number(&self, value: &str, negative: bool) -> Result<Expr> {
334 use datafusion::logical_expr::lit;
335 let value: Cow<str> = if negative {
336 Cow::Owned(format!("-{}", value))
337 } else {
338 Cow::Borrowed(value)
339 };
340 if let Ok(n) = value.parse::<i64>() {
341 Ok(lit(n))
342 } else {
343 value.parse::<f64>().map(lit).map_err(|_| {
344 Error::invalid_input(
345 format!("'{value}' is not supported number value."),
346 location!(),
347 )
348 })
349 }
350 }
351
352 fn value(&self, value: &Value) -> Result<Expr> {
353 Ok(match value {
354 Value::Number(v, _) => self.number(v.as_str(), false)?,
355 Value::SingleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone()))),
356 Value::HexStringLiteral(hsl) => {
357 Expr::Literal(ScalarValue::Binary(Self::try_decode_hex_literal(hsl)))
358 }
359 Value::DoubleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone()))),
360 Value::Boolean(v) => Expr::Literal(ScalarValue::Boolean(Some(*v))),
361 Value::Null => Expr::Literal(ScalarValue::Null),
362 _ => todo!(),
363 })
364 }
365
366 fn parse_function_args(&self, func_args: &FunctionArg) -> Result<Expr> {
367 match func_args {
368 FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => self.parse_sql_expr(expr),
369 _ => Err(Error::invalid_input(
370 format!("Unsupported function args: {:?}", func_args),
371 location!(),
372 )),
373 }
374 }
375
376 fn legacy_parse_function(&self, func: &Function) -> Result<Expr> {
383 match &func.args {
384 FunctionArguments::List(args) => {
385 if func.name.0.len() != 1 {
386 return Err(Error::invalid_input(
387 format!("Function name must have 1 part, got: {:?}", func.name.0),
388 location!(),
389 ));
390 }
391 Ok(Expr::IsNotNull(Box::new(
392 self.parse_function_args(&args.args[0])?,
393 )))
394 }
395 _ => Err(Error::invalid_input(
396 format!("Unsupported function args: {:?}", &func.args),
397 location!(),
398 )),
399 }
400 }
401
    /// Plan a SQL function call by delegating to DataFusion's SQL-to-logical
    /// planner.
    ///
    /// The legacy `is_valid(...)` function is special-cased and rewritten to
    /// `IS NOT NULL` before delegation.
    fn parse_function(&self, function: SQLExpr) -> Result<Expr> {
        if let SQLExpr::Function(function) = &function {
            if !function.name.0.is_empty() && function.name.0[0].value == "is_valid" {
                return self.legacy_parse_function(function);
            }
        }
        // Identifier normalization is disabled so column names keep their
        // exact case, consistent with the rest of this planner.
        let sql_to_rel = SqlToRel::new_with_options(
            &self.context_provider,
            ParserOptions {
                parse_float_as_decimal: false,
                enable_ident_normalization: false,
                support_varchar_with_length: false,
                enable_options_value_normalization: false,
                collect_spans: false,
            },
        );

        let mut planner_context = PlannerContext::default();
        let schema = DFSchema::try_from(self.schema.as_ref().clone())?;
        Ok(sql_to_rel.sql_to_expr(function, &schema, &mut planner_context)?)
    }
423
    /// Map a SQL type name to the corresponding Arrow data type.
    ///
    /// Timestamps/datetimes accept precisions 0/3/6/9 (second/milli/micro/
    /// nano, defaulting to microseconds); decimals require explicit precision
    /// and scale; timestamp timezones are rejected.
    fn parse_type(&self, data_type: &SQLDataType) -> Result<ArrowDataType> {
        // Listed in error messages to help users pick a supported type.
        const SUPPORTED_TYPES: [&str; 13] = [
            "int [unsigned]",
            "tinyint [unsigned]",
            "smallint [unsigned]",
            "bigint [unsigned]",
            "float",
            "double",
            "string",
            "binary",
            "date",
            "timestamp(precision)",
            "datetime(precision)",
            "decimal(precision,scale)",
            "boolean",
        ];
        match data_type {
            SQLDataType::String(_) => Ok(ArrowDataType::Utf8),
            SQLDataType::Binary(_) => Ok(ArrowDataType::Binary),
            SQLDataType::Float(_) => Ok(ArrowDataType::Float32),
            SQLDataType::Double(_) => Ok(ArrowDataType::Float64),
            SQLDataType::Boolean => Ok(ArrowDataType::Boolean),
            SQLDataType::TinyInt(_) => Ok(ArrowDataType::Int8),
            SQLDataType::SmallInt(_) => Ok(ArrowDataType::Int16),
            SQLDataType::Int(_) | SQLDataType::Integer(_) => Ok(ArrowDataType::Int32),
            SQLDataType::BigInt(_) => Ok(ArrowDataType::Int64),
            SQLDataType::UnsignedTinyInt(_) => Ok(ArrowDataType::UInt8),
            SQLDataType::UnsignedSmallInt(_) => Ok(ArrowDataType::UInt16),
            SQLDataType::UnsignedInt(_) | SQLDataType::UnsignedInteger(_) => {
                Ok(ArrowDataType::UInt32)
            }
            SQLDataType::UnsignedBigInt(_) => Ok(ArrowDataType::UInt64),
            SQLDataType::Date => Ok(ArrowDataType::Date32),
            SQLDataType::Timestamp(resolution, tz) => {
                match tz {
                    TimezoneInfo::None => {}
                    _ => {
                        return Err(Error::invalid_input(
                            "Timezone not supported in timestamp".to_string(),
                            location!(),
                        ));
                    }
                };
                // SQL precision digits map onto Arrow time units.
                let time_unit = match resolution {
                    None => TimeUnit::Microsecond,
                    Some(0) => TimeUnit::Second,
                    Some(3) => TimeUnit::Millisecond,
                    Some(6) => TimeUnit::Microsecond,
                    Some(9) => TimeUnit::Nanosecond,
                    _ => {
                        return Err(Error::invalid_input(
                            format!("Unsupported datetime resolution: {:?}", resolution),
                            location!(),
                        ));
                    }
                };
                Ok(ArrowDataType::Timestamp(time_unit, None))
            }
            SQLDataType::Datetime(resolution) => {
                let time_unit = match resolution {
                    None => TimeUnit::Microsecond,
                    Some(0) => TimeUnit::Second,
                    Some(3) => TimeUnit::Millisecond,
                    Some(6) => TimeUnit::Microsecond,
                    Some(9) => TimeUnit::Nanosecond,
                    _ => {
                        return Err(Error::invalid_input(
                            format!("Unsupported datetime resolution: {:?}", resolution),
                            location!(),
                        ));
                    }
                };
                Ok(ArrowDataType::Timestamp(time_unit, None))
            }
            SQLDataType::Decimal(number_info) => match number_info {
                ExactNumberInfo::PrecisionAndScale(precision, scale) => {
                    Ok(ArrowDataType::Decimal128(*precision as u8, *scale as i8))
                }
                _ => Err(Error::invalid_input(
                    format!(
                        "Must provide precision and scale for decimal: {:?}",
                        number_info
                    ),
                    location!(),
                )),
            },
            _ => Err(Error::invalid_input(
                format!(
                    "Unsupported data type: {:?}. Supported types: {:?}",
                    data_type, SUPPORTED_TYPES
                ),
                location!(),
            )),
        }
    }
520
    /// Plan a struct-field / list-index access by trying each registered
    /// expression planner in turn until one claims it.
    fn plan_field_access(&self, mut field_access_expr: RawFieldAccessExpr) -> Result<Expr> {
        let df_schema = DFSchema::try_from(self.schema.as_ref().clone())?;
        for planner in self.context_provider.get_expr_planners() {
            match planner.plan_field_access(field_access_expr, &df_schema)? {
                PlannerResult::Planned(expr) => return Ok(expr),
                // Not handled by this planner: recover the expression and
                // offer it to the next one.
                PlannerResult::Original(expr) => {
                    field_access_expr = expr;
                }
            }
        }
        Err(Error::invalid_input(
            "Field access could not be planned",
            location!(),
        ))
    }
536
    /// Recursively convert a sqlparser AST expression into a DataFusion
    /// logical `Expr`.
    ///
    /// Quoting is Lance-specific: double-quoted identifiers are treated as
    /// string literals, backtick-quoted identifiers as column references.
    fn parse_sql_expr(&self, expr: &SQLExpr) -> Result<Expr> {
        match expr {
            SQLExpr::Identifier(id) => {
                // Double quotes yield a string literal, not an identifier as
                // in standard SQL.
                if id.quote_style == Some('"') {
                    Ok(Expr::Literal(ScalarValue::Utf8(Some(id.value.clone()))))
                } else if id.quote_style == Some('`') {
                    Ok(Expr::Column(Column::from_name(id.value.clone())))
                } else {
                    Ok(Self::column(vec![id.clone()].as_slice()))
                }
            }
            SQLExpr::CompoundIdentifier(ids) => Ok(Self::column(ids.as_slice())),
            SQLExpr::BinaryOp { left, op, right } => self.binary_expr(left, op, right),
            SQLExpr::UnaryOp { op, expr } => self.unary_expr(op, expr),
            SQLExpr::Value(value) => self.value(value),
            // Array literals: only literal elements (optionally negated
            // numbers) are allowed, and all elements must coerce to the type
            // of the first element.
            SQLExpr::Array(SQLArray { elem, .. }) => {
                let mut values = vec![];

                let array_literal_error = |pos: usize, value: &_| {
                    Err(Error::invalid_input(
                        format!(
                            "Expected a literal value in array, instead got {} at position {}",
                            value, pos
                        ),
                        location!(),
                    ))
                };

                for (pos, expr) in elem.iter().enumerate() {
                    match expr {
                        SQLExpr::Value(value) => {
                            if let Expr::Literal(value) = self.value(value)? {
                                values.push(value);
                            } else {
                                return array_literal_error(pos, expr);
                            }
                        }
                        // Negative number literals arrive as unary minus over
                        // a positive literal.
                        SQLExpr::UnaryOp {
                            op: UnaryOperator::Minus,
                            expr,
                        } => {
                            if let SQLExpr::Value(Value::Number(number, _)) = expr.as_ref() {
                                if let Expr::Literal(value) = self.number(number, true)? {
                                    values.push(value);
                                } else {
                                    return array_literal_error(pos, expr);
                                }
                            } else {
                                return array_literal_error(pos, expr);
                            }
                        }
                        _ => {
                            return array_literal_error(pos, expr);
                        }
                    }
                }

                let field = if !values.is_empty() {
                    let data_type = values[0].data_type();

                    // Coerce every element to the first element's type.
                    for value in &mut values {
                        if value.data_type() != data_type {
                            *value = safe_coerce_scalar(value, &data_type).ok_or_else(|| Error::invalid_input(
                                format!("Array expressions must have a consistent datatype. Expected: {}, got: {}", data_type, value.data_type()),
                                location!()
                            ))?;
                        }
                    }
                    Field::new("item", data_type, true)
                } else {
                    Field::new("item", ArrowDataType::Null, true)
                };

                // Materialize the scalars as a single-row ListArray literal.
                let values = values
                    .into_iter()
                    .map(|v| v.to_array().map_err(Error::from))
                    .collect::<Result<Vec<_>>>()?;
                let array_refs = values.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
                let values = concat(&array_refs)?;
                let values = ListArray::try_new(
                    field.into(),
                    OffsetBuffer::from_lengths([values.len()]),
                    values,
                    None,
                )?;

                Ok(Expr::Literal(ScalarValue::List(Arc::new(values))))
            }
            // e.g. `timestamp '2020-01-01'`: planned as a cast of the string
            // literal to the named type.
            SQLExpr::TypedString { data_type, value } => {
                Ok(Expr::Cast(datafusion::logical_expr::Cast {
                    expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some(value.clone())))),
                    data_type: self.parse_type(data_type)?,
                }))
            }
            SQLExpr::IsFalse(expr) => Ok(Expr::IsFalse(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotFalse(expr) => Ok(Expr::IsNotFalse(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsTrue(expr) => Ok(Expr::IsTrue(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotTrue(expr) => Ok(Expr::IsNotTrue(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNull(expr) => Ok(Expr::IsNull(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotNull(expr) => Ok(Expr::IsNotNull(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::InList {
                expr,
                list,
                negated,
            } => {
                let value_expr = self.parse_sql_expr(expr)?;
                let list_exprs = list
                    .iter()
                    .map(|e| self.parse_sql_expr(e))
                    .collect::<Result<Vec<_>>>()?;
                Ok(value_expr.in_list(list_exprs, *negated))
            }
            SQLExpr::Nested(inner) => self.parse_sql_expr(inner.as_ref()),
            SQLExpr::Function(_) => self.parse_function(expr.clone()),
            // ILIKE / LIKE differ only in the case-insensitivity flag.
            SQLExpr::ILike {
                negated,
                expr,
                pattern,
                escape_char,
                any: _,
            } => Ok(Expr::Like(Like::new(
                *negated,
                Box::new(self.parse_sql_expr(expr)?),
                Box::new(self.parse_sql_expr(pattern)?),
                escape_char.as_ref().and_then(|c| c.chars().next()),
                true,
            ))),
            SQLExpr::Like {
                negated,
                expr,
                pattern,
                escape_char,
                any: _,
            } => Ok(Expr::Like(Like::new(
                *negated,
                Box::new(self.parse_sql_expr(expr)?),
                Box::new(self.parse_sql_expr(pattern)?),
                escape_char.as_ref().and_then(|c| c.chars().next()),
                false,
            ))),
            SQLExpr::Cast {
                expr, data_type, ..
            } => Ok(Expr::Cast(datafusion::logical_expr::Cast {
                expr: Box::new(self.parse_sql_expr(expr)?),
                data_type: self.parse_type(data_type)?,
            })),
            SQLExpr::JsonAccess { .. } => Err(Error::invalid_input(
                "JSON access is not supported",
                location!(),
            )),
            // e.g. `a.b['c'][0]`: plan each access step left-to-right via the
            // registered expression planners.
            SQLExpr::CompoundFieldAccess { root, access_chain } => {
                let mut expr = self.parse_sql_expr(root)?;

                for access in access_chain {
                    let field_access = match access {
                        // Dot access or a string subscript selects a struct
                        // field by name.
                        AccessExpr::Dot(SQLExpr::Identifier(Ident { value: s, .. }))
                        | AccessExpr::Subscript(Subscript::Index {
                            index:
                                SQLExpr::Value(
                                    Value::SingleQuotedString(s) | Value::DoubleQuotedString(s),
                                ),
                        }) => GetFieldAccess::NamedStructField {
                            name: ScalarValue::from(s.as_str()),
                        },
                        AccessExpr::Subscript(Subscript::Index { index }) => {
                            let key = Box::new(self.parse_sql_expr(index)?);
                            GetFieldAccess::ListIndex { key }
                        }
                        AccessExpr::Subscript(Subscript::Slice { .. }) => {
                            return Err(Error::invalid_input(
                                "Slice subscript is not supported",
                                location!(),
                            ));
                        }
                        _ => {
                            return Err(Error::invalid_input(
                                "Only dot notation or index access is supported for field access",
                                location!(),
                            ));
                        }
                    };

                    let field_access_expr = RawFieldAccessExpr { expr, field_access };
                    expr = self.plan_field_access(field_access_expr)?;
                }

                Ok(expr)
            }
            SQLExpr::Between {
                expr,
                negated,
                low,
                high,
            } => {
                let expr = self.parse_sql_expr(expr)?;
                let low = self.parse_sql_expr(low)?;
                let high = self.parse_sql_expr(high)?;

                let between = Expr::Between(Between::new(
                    Box::new(expr),
                    *negated,
                    Box::new(low),
                    Box::new(high),
                ));
                Ok(between)
            }
            _ => Err(Error::invalid_input(
                format!("Expression '{expr}' is not supported SQL in lance"),
                location!(),
            )),
        }
    }
758
759 pub fn parse_filter(&self, filter: &str) -> Result<Expr> {
764 let ast_expr = parse_sql_filter(filter)?;
766 let expr = self.parse_sql_expr(&ast_expr)?;
767 let schema = Schema::try_from(self.schema.as_ref())?;
768 let resolved = resolve_expr(&expr, &schema)?;
769 coerce_filter_type_to_boolean(resolved)
770 }
771
772 pub fn parse_expr(&self, expr: &str) -> Result<Expr> {
777 let ast_expr = parse_sql_expr(expr)?;
778 let expr = self.parse_sql_expr(&ast_expr)?;
779 let schema = Schema::try_from(self.schema.as_ref())?;
780 let resolved = resolve_expr(&expr, &schema)?;
781 Ok(resolved)
782 }
783
784 fn try_decode_hex_literal(s: &str) -> Option<Vec<u8>> {
790 let hex_bytes = s.as_bytes();
791 let mut decoded_bytes = Vec::with_capacity(hex_bytes.len().div_ceil(2));
792
793 let start_idx = hex_bytes.len() % 2;
794 if start_idx > 0 {
795 decoded_bytes.push(Self::try_decode_hex_char(hex_bytes[0])?);
797 }
798
799 for i in (start_idx..hex_bytes.len()).step_by(2) {
800 let high = Self::try_decode_hex_char(hex_bytes[i])?;
801 let low = Self::try_decode_hex_char(hex_bytes[i + 1])?;
802 decoded_bytes.push((high << 4) | low);
803 }
804
805 Some(decoded_bytes)
806 }
807
808 const fn try_decode_hex_char(c: u8) -> Option<u8> {
812 match c {
813 b'A'..=b'F' => Some(c - b'A' + 10),
814 b'a'..=b'f' => Some(c - b'a' + 10),
815 b'0'..=b'9' => Some(c - b'0'),
816 _ => None,
817 }
818 }
819
820 pub fn optimize_expr(&self, expr: Expr) -> Result<Expr> {
822 let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
823
824 let props = ExecutionProps::default();
827 let simplify_context = SimplifyContext::new(&props).with_schema(df_schema.clone());
828 let simplifier =
829 datafusion::optimizer::simplify_expressions::ExprSimplifier::new(simplify_context);
830
831 let expr = simplifier.simplify(expr)?;
832 let expr = simplifier.coerce(expr, &df_schema)?;
833
834 Ok(expr)
835 }
836
837 pub fn create_physical_expr(&self, expr: &Expr) -> Result<Arc<dyn PhysicalExpr>> {
839 let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
840
841 Ok(datafusion::physical_expr::create_physical_expr(
842 expr,
843 df_schema.as_ref(),
844 &Default::default(),
845 )?)
846 }
847
848 pub fn column_names_in_expr(expr: &Expr) -> Vec<String> {
852 let mut visitor = ColumnCapturingVisitor {
853 current_path: VecDeque::new(),
854 columns: BTreeSet::new(),
855 };
856 expr.visit(&mut visitor).unwrap();
857 visitor.columns.into_iter().collect()
858 }
859}
860
/// Tree visitor that records every column referenced by an expression,
/// flattening nested `get_field` accesses into dotted paths (e.g. `st.x`).
struct ColumnCapturingVisitor {
    // Field-name parts collected from enclosing `get_field` calls, ordered
    // base-to-leaf; drained when the base column is reached.
    current_path: VecDeque<String>,
    columns: BTreeSet<String>,
}
866
impl TreeNodeVisitor<'_> for ColumnCapturingVisitor {
    type Node = Expr;

    /// Pre-order visit: accumulate `get_field` name parts on the way down
    /// and emit a dotted column path once the base `Expr::Column` is reached.
    fn f_down(&mut self, node: &Self::Node) -> DFResult<TreeNodeRecursion> {
        match node {
            Expr::Column(Column { name, .. }) => {
                // Base column reached: join accumulated parts into `col.a.b`.
                let mut path = name.clone();
                for part in self.current_path.drain(..) {
                    path.push('.');
                    path.push_str(&part);
                }
                self.columns.insert(path);
                self.current_path.clear();
            }
            Expr::ScalarFunction(udf) => {
                if udf.name() == GetFieldFunc::default().name() {
                    // Outer accesses are visited first, so push to the front
                    // to keep parts in base-to-leaf order. A non-literal
                    // field name breaks the chain.
                    if let Some(name) = get_as_string_scalar_opt(&udf.args[1]) {
                        self.current_path.push_front(name.to_string())
                    } else {
                        self.current_path.clear();
                    }
                } else {
                    self.current_path.clear();
                }
            }
            _ => {
                // Any other node interrupts a pending get_field chain.
                self.current_path.clear();
            }
        }

        Ok(TreeNodeRecursion::Continue)
    }
}
900
901#[cfg(test)]
902mod tests {
903
904 use crate::logical_expr::ExprExt;
905
906 use super::*;
907
908 use arrow::datatypes::Float64Type;
909 use arrow_array::{
910 ArrayRef, BooleanArray, Float32Array, Int32Array, Int64Array, RecordBatch, StringArray,
911 StructArray, TimestampMicrosecondArray, TimestampMillisecondArray,
912 TimestampNanosecondArray, TimestampSecondArray,
913 };
914 use arrow_schema::{DataType, Fields, Schema};
915 use datafusion::{
916 logical_expr::{lit, Cast},
917 prelude::{array_element, get_field},
918 };
919 use datafusion_functions::core::expr_ext::FieldAccessor;
920
    /// Filters combining comparisons, struct-field access, `==`/`=` equality
    /// and IN lists should plan correctly and evaluate to the expected mask.
    #[test]
    fn test_parse_filter_simple() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("i", DataType::Int32, false),
            Field::new("s", DataType::Utf8, true),
            Field::new(
                "st",
                DataType::Struct(Fields::from(vec![
                    Field::new("x", DataType::Float32, false),
                    Field::new("y", DataType::Float32, false),
                ])),
                true,
            ),
        ]));

        let planner = Planner::new(schema.clone());

        let expected = col("i")
            .gt(lit(3_i32))
            .and(col("st").field_newstyle("x").lt_eq(lit(5.0_f32)))
            .and(
                col("s")
                    .eq(lit("str-4"))
                    .or(col("s").in_list(vec![lit("str-4"), lit("str-5")], false)),
            );

        // `==` must parse the same as standard `=` (checked below).
        let expr = planner
            .parse_filter("i > 3 AND st.x <= 5.0 AND (s == 'str-4' OR s in ('str-4', 'str-5'))")
            .unwrap();
        assert_eq!(expr, expected);

        let expr = planner
            .parse_filter("i > 3 AND st.x <= 5.0 AND (s = 'str-4' OR s in ('str-4', 'str-5'))")
            .unwrap();

        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
                Arc::new(StringArray::from_iter_values(
                    (0..10).map(|v| format!("str-{}", v)),
                )),
                Arc::new(StructArray::from(vec![
                    (
                        Arc::new(Field::new("x", DataType::Float32, false)),
                        Arc::new(Float32Array::from_iter_values((0..10).map(|v| v as f32)))
                            as ArrayRef,
                    ),
                    (
                        Arc::new(Field::new("y", DataType::Float32, false)),
                        Arc::new(Float32Array::from_iter_values(
                            (0..10).map(|v| (v * 10) as f32),
                        )),
                    ),
                ])),
            ],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        // Only rows 4 and 5 satisfy i > 3, st.x <= 5.0 and the IN list.
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, false, true, true, false, false, false, false
            ])
        );
    }
991
    /// Nested struct columns should be addressable via dot notation,
    /// backticks, and bracket subscripts, all producing `get_field` chains.
    #[test]
    fn test_nested_col_refs() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("s0", DataType::Utf8, true),
            Field::new(
                "st",
                DataType::Struct(Fields::from(vec![
                    Field::new("s1", DataType::Utf8, true),
                    Field::new(
                        "st",
                        DataType::Struct(Fields::from(vec![Field::new(
                            "s2",
                            DataType::Utf8,
                            true,
                        )])),
                        true,
                    ),
                ])),
                true,
            ),
        ]));

        let planner = Planner::new(schema);

        // Parses `{expr} = 'val'` and asserts the left-hand side plans to
        // `expected`.
        fn assert_column_eq(planner: &Planner, expr: &str, expected: &Expr) {
            let expr = planner.parse_filter(&format!("{expr} = 'val'")).unwrap();
            assert!(matches!(
                expr,
                Expr::BinaryExpr(BinaryExpr {
                    left: _,
                    op: Operator::Eq,
                    right: _
                })
            ));
            if let Expr::BinaryExpr(BinaryExpr { left, .. }) = expr {
                assert_eq!(left.as_ref(), expected);
            }
        }

        let expected = Expr::Column(Column::new_unqualified("s0"));
        assert_column_eq(&planner, "s0", &expected);
        assert_column_eq(&planner, "`s0`", &expected);

        let expected = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
            args: vec![
                Expr::Column(Column::new_unqualified("st")),
                Expr::Literal(ScalarValue::Utf8(Some("s1".to_string()))),
            ],
        });
        assert_column_eq(&planner, "st.s1", &expected);
        assert_column_eq(&planner, "`st`.`s1`", &expected);
        assert_column_eq(&planner, "st.`s1`", &expected);

        let expected = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
            args: vec![
                Expr::ScalarFunction(ScalarFunction {
                    func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
                    args: vec![
                        Expr::Column(Column::new_unqualified("st")),
                        Expr::Literal(ScalarValue::Utf8(Some("st".to_string()))),
                    ],
                }),
                Expr::Literal(ScalarValue::Utf8(Some("s2".to_string()))),
            ],
        });

        assert_column_eq(&planner, "st.st.s2", &expected);
        assert_column_eq(&planner, "`st`.`st`.`s2`", &expected);
        assert_column_eq(&planner, "st.st.`s2`", &expected);
        assert_column_eq(&planner, "st['st'][\"s2\"]", &expected);
    }
1065
    /// List element access and struct-field access on the element should
    /// plan to `array_element` / `get_field` respectively.
    #[test]
    fn test_nested_list_refs() {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "l",
            DataType::List(Arc::new(Field::new(
                "item",
                DataType::Struct(Fields::from(vec![Field::new("f1", DataType::Utf8, true)])),
                true,
            ))),
            true,
        )]));

        let planner = Planner::new(schema);

        let expected = array_element(col("l"), lit(0_i64));
        let expr = planner.parse_expr("l[0]").unwrap();
        assert_eq!(expr, expected);

        let expected = get_field(array_element(col("l"), lit(0_i64)), "f1");
        let expr = planner.parse_expr("l[0]['f1']").unwrap();
        assert_eq!(expr, expected);

    }
1093
    /// Unary minus must fold into literals and compose with arithmetic, then
    /// evaluate correctly on a batch.
    #[test]
    fn test_negative_expressions() {
        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));

        let planner = Planner::new(schema.clone());

        let expected = col("x")
            .gt(lit(-3_i64))
            .and(col("x").lt(-(lit(-5_i64) + lit(3_i64))));

        let expr = planner.parse_filter("x > -3 AND x < -(-5 + 3)").unwrap();

        assert_eq!(expr, expected);

        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(Int64Array::from_iter_values(-5..5)) as ArrayRef],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        // Expect true exactly for -3 < x < 2.
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, true, true, true, true, false, false, false
            ])
        );
    }
1123
    /// Array literals containing negated numbers should become a single
    /// list literal of the coerced (f64) element type.
    #[test]
    fn test_negative_array_expressions() {
        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));

        let planner = Planner::new(schema);

        let expected = Expr::Literal(ScalarValue::List(Arc::new(
            ListArray::from_iter_primitive::<Float64Type, _, _>(vec![Some(
                [-1_f64, -2.0, -3.0, -4.0, -5.0].map(Some),
            )]),
        )));

        let expr = planner
            .parse_expr("[-1.0, -2.0, -3.0, -4.0, -5.0]")
            .unwrap();

        assert_eq!(expr, expected);
    }
1142
1143 #[test]
1144 fn test_sql_like() {
1145 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1146
1147 let planner = Planner::new(schema.clone());
1148
1149 let expected = col("s").like(lit("str-4"));
1150 let expr = planner.parse_filter("s LIKE 'str-4'").unwrap();
1152 assert_eq!(expr, expected);
1153 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1154
1155 let batch = RecordBatch::try_new(
1156 schema,
1157 vec![Arc::new(StringArray::from_iter_values(
1158 (0..10).map(|v| format!("str-{}", v)),
1159 ))],
1160 )
1161 .unwrap();
1162 let predicates = physical_expr.evaluate(&batch).unwrap();
1163 assert_eq!(
1164 predicates.into_array(0).unwrap().as_ref(),
1165 &BooleanArray::from(vec![
1166 false, false, false, false, true, false, false, false, false, false
1167 ])
1168 );
1169 }
1170
1171 #[test]
1172 fn test_not_like() {
1173 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1174
1175 let planner = Planner::new(schema.clone());
1176
1177 let expected = col("s").not_like(lit("str-4"));
1178 let expr = planner.parse_filter("s NOT LIKE 'str-4'").unwrap();
1180 assert_eq!(expr, expected);
1181 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1182
1183 let batch = RecordBatch::try_new(
1184 schema,
1185 vec![Arc::new(StringArray::from_iter_values(
1186 (0..10).map(|v| format!("str-{}", v)),
1187 ))],
1188 )
1189 .unwrap();
1190 let predicates = physical_expr.evaluate(&batch).unwrap();
1191 assert_eq!(
1192 predicates.into_array(0).unwrap().as_ref(),
1193 &BooleanArray::from(vec![
1194 true, true, true, true, false, true, true, true, true, true
1195 ])
1196 );
1197 }
1198
1199 #[test]
1200 fn test_sql_is_in() {
1201 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1202
1203 let planner = Planner::new(schema.clone());
1204
1205 let expected = col("s").in_list(vec![lit("str-4"), lit("str-5")], false);
1206 let expr = planner.parse_filter("s IN ('str-4', 'str-5')").unwrap();
1208 assert_eq!(expr, expected);
1209 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1210
1211 let batch = RecordBatch::try_new(
1212 schema,
1213 vec![Arc::new(StringArray::from_iter_values(
1214 (0..10).map(|v| format!("str-{}", v)),
1215 ))],
1216 )
1217 .unwrap();
1218 let predicates = physical_expr.evaluate(&batch).unwrap();
1219 assert_eq!(
1220 predicates.into_array(0).unwrap().as_ref(),
1221 &BooleanArray::from(vec![
1222 false, false, false, false, true, true, false, false, false, false
1223 ])
1224 );
1225 }
1226
1227 #[test]
1228 fn test_sql_is_null() {
1229 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1230
1231 let planner = Planner::new(schema.clone());
1232
1233 let expected = col("s").is_null();
1234 let expr = planner.parse_filter("s IS NULL").unwrap();
1235 assert_eq!(expr, expected);
1236 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1237
1238 let batch = RecordBatch::try_new(
1239 schema,
1240 vec![Arc::new(StringArray::from_iter((0..10).map(|v| {
1241 if v % 3 == 0 {
1242 Some(format!("str-{}", v))
1243 } else {
1244 None
1245 }
1246 })))],
1247 )
1248 .unwrap();
1249 let predicates = physical_expr.evaluate(&batch).unwrap();
1250 assert_eq!(
1251 predicates.into_array(0).unwrap().as_ref(),
1252 &BooleanArray::from(vec![
1253 false, true, true, false, true, true, false, true, true, false
1254 ])
1255 );
1256
1257 let expr = planner.parse_filter("s IS NOT NULL").unwrap();
1258 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1259 let predicates = physical_expr.evaluate(&batch).unwrap();
1260 assert_eq!(
1261 predicates.into_array(0).unwrap().as_ref(),
1262 &BooleanArray::from(vec![
1263 true, false, false, true, false, false, true, false, false, true,
1264 ])
1265 );
1266 }
1267
1268 #[test]
1269 fn test_sql_invert() {
1270 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Boolean, true)]));
1271
1272 let planner = Planner::new(schema.clone());
1273
1274 let expr = planner.parse_filter("NOT s").unwrap();
1275 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1276
1277 let batch = RecordBatch::try_new(
1278 schema,
1279 vec![Arc::new(BooleanArray::from_iter(
1280 (0..10).map(|v| Some(v % 3 == 0)),
1281 ))],
1282 )
1283 .unwrap();
1284 let predicates = physical_expr.evaluate(&batch).unwrap();
1285 assert_eq!(
1286 predicates.into_array(0).unwrap().as_ref(),
1287 &BooleanArray::from(vec![
1288 false, true, true, false, true, true, false, true, true, false
1289 ])
1290 );
1291 }
1292
    #[test]
    fn test_sql_cast() {
        // Each case pairs a filter containing an explicit CAST with the Arrow
        // type that the SQL type name should map to.
        let cases = &[
            (
                "x = cast('2021-01-01 00:00:00' as timestamp)",
                ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00' as timestamp(0))",
                ArrowDataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00.123' as timestamp(9))",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00.123' as datetime(9))",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            ("x = cast('2021-01-01' as date)", ArrowDataType::Date32),
            (
                "x = cast('1.238' as decimal(9,3))",
                ArrowDataType::Decimal128(9, 3),
            ),
            ("x = cast(1 as float)", ArrowDataType::Float32),
            ("x = cast(1 as double)", ArrowDataType::Float64),
            ("x = cast(1 as tinyint)", ArrowDataType::Int8),
            ("x = cast(1 as smallint)", ArrowDataType::Int16),
            ("x = cast(1 as int)", ArrowDataType::Int32),
            ("x = cast(1 as integer)", ArrowDataType::Int32),
            ("x = cast(1 as bigint)", ArrowDataType::Int64),
            ("x = cast(1 as tinyint unsigned)", ArrowDataType::UInt8),
            ("x = cast(1 as smallint unsigned)", ArrowDataType::UInt16),
            ("x = cast(1 as int unsigned)", ArrowDataType::UInt32),
            ("x = cast(1 as integer unsigned)", ArrowDataType::UInt32),
            ("x = cast(1 as bigint unsigned)", ArrowDataType::UInt64),
            ("x = cast(1 as boolean)", ArrowDataType::Boolean),
            ("x = cast(1 as string)", ArrowDataType::Utf8),
        ];

        for (sql, expected_data_type) in cases {
            // Column "x" is declared with the target type so the filter is
            // well-typed for the planner.
            let schema = Arc::new(Schema::new(vec![Field::new(
                "x",
                expected_data_type.clone(),
                true,
            )]));
            let planner = Planner::new(schema.clone());
            let expr = planner.parse_filter(sql).unwrap();

            // Recover the cast operand's text from the SQL string itself: the
            // slice between "cast(" and " as", with surrounding quotes stripped.
            let expected_value_str = sql
                .split("cast(")
                .nth(1)
                .unwrap()
                .split(" as")
                .next()
                .unwrap();
            let expected_value_str = expected_value_str.trim_matches('\'');

            // The planner should produce `x = CAST(<literal> AS <type>)`: a
            // binary expression whose right side is a Cast applied to a literal.
            match expr {
                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
                    Expr::Cast(Cast { expr, data_type }) => {
                        match expr.as_ref() {
                            // Quoted operands stay string literals until cast.
                            Expr::Literal(ScalarValue::Utf8(Some(value_str))) => {
                                assert_eq!(value_str, expected_value_str);
                            }
                            // Bare numeric operands parse as Int64 (always 1 here).
                            Expr::Literal(ScalarValue::Int64(Some(value))) => {
                                assert_eq!(*value, 1);
                            }
                            _ => panic!("Expected cast to be applied to literal"),
                        }
                        assert_eq!(data_type, expected_data_type);
                    }
                    _ => panic!("Expected right to be a cast"),
                },
                _ => panic!("Expected binary expression"),
            }
        }
    }
1373
1374 #[test]
1375 fn test_sql_literals() {
1376 let cases = &[
1377 (
1378 "x = timestamp '2021-01-01 00:00:00'",
1379 ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
1380 ),
1381 (
1382 "x = timestamp(0) '2021-01-01 00:00:00'",
1383 ArrowDataType::Timestamp(TimeUnit::Second, None),
1384 ),
1385 (
1386 "x = timestamp(9) '2021-01-01 00:00:00.123'",
1387 ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
1388 ),
1389 ("x = date '2021-01-01'", ArrowDataType::Date32),
1390 ("x = decimal(9,3) '1.238'", ArrowDataType::Decimal128(9, 3)),
1391 ];
1392
1393 for (sql, expected_data_type) in cases {
1394 let schema = Arc::new(Schema::new(vec![Field::new(
1395 "x",
1396 expected_data_type.clone(),
1397 true,
1398 )]));
1399 let planner = Planner::new(schema.clone());
1400 let expr = planner.parse_filter(sql).unwrap();
1401
1402 let expected_value_str = sql.split('\'').nth(1).unwrap();
1403
1404 match expr {
1405 Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
1406 Expr::Cast(Cast { expr, data_type }) => {
1407 match expr.as_ref() {
1408 Expr::Literal(ScalarValue::Utf8(Some(value_str))) => {
1409 assert_eq!(value_str, expected_value_str);
1410 }
1411 _ => panic!("Expected cast to be applied to literal"),
1412 }
1413 assert_eq!(data_type, expected_data_type);
1414 }
1415 _ => panic!("Expected right to be a cast"),
1416 },
1417 _ => panic!("Expected binary expression"),
1418 }
1419 }
1420 }
1421
1422 #[test]
1423 fn test_sql_array_literals() {
1424 let cases = [
1425 (
1426 "x = [1, 2, 3]",
1427 ArrowDataType::List(Arc::new(Field::new("item", ArrowDataType::Int64, true))),
1428 ),
1429 (
1430 "x = [1, 2, 3]",
1431 ArrowDataType::FixedSizeList(
1432 Arc::new(Field::new("item", ArrowDataType::Int64, true)),
1433 3,
1434 ),
1435 ),
1436 ];
1437
1438 for (sql, expected_data_type) in cases {
1439 let schema = Arc::new(Schema::new(vec![Field::new(
1440 "x",
1441 expected_data_type.clone(),
1442 true,
1443 )]));
1444 let planner = Planner::new(schema.clone());
1445 let expr = planner.parse_filter(sql).unwrap();
1446 let expr = planner.optimize_expr(expr).unwrap();
1447
1448 match expr {
1449 Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
1450 Expr::Literal(value) => {
1451 assert_eq!(&value.data_type(), &expected_data_type);
1452 }
1453 _ => panic!("Expected right to be a literal"),
1454 },
1455 _ => panic!("Expected binary expression"),
1456 }
1457 }
1458 }
1459
1460 #[test]
1461 fn test_sql_between() {
1462 use arrow_array::{Float64Array, Int32Array, TimestampMicrosecondArray};
1463 use arrow_schema::{DataType, Field, Schema, TimeUnit};
1464 use std::sync::Arc;
1465
1466 let schema = Arc::new(Schema::new(vec![
1467 Field::new("x", DataType::Int32, false),
1468 Field::new("y", DataType::Float64, false),
1469 Field::new(
1470 "ts",
1471 DataType::Timestamp(TimeUnit::Microsecond, None),
1472 false,
1473 ),
1474 ]));
1475
1476 let planner = Planner::new(schema.clone());
1477
1478 let expr = planner
1480 .parse_filter("x BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
1481 .unwrap();
1482 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1483
1484 let base_ts = 1704067200000000_i64; let ts_array = TimestampMicrosecondArray::from_iter_values(
1488 (0..10).map(|i| base_ts + i * 1_000_000), );
1490
1491 let batch = RecordBatch::try_new(
1492 schema,
1493 vec![
1494 Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
1495 Arc::new(Float64Array::from_iter_values((0..10).map(|v| v as f64))),
1496 Arc::new(ts_array),
1497 ],
1498 )
1499 .unwrap();
1500
1501 let predicates = physical_expr.evaluate(&batch).unwrap();
1502 assert_eq!(
1503 predicates.into_array(0).unwrap().as_ref(),
1504 &BooleanArray::from(vec![
1505 false, false, false, true, true, true, true, true, false, false
1506 ])
1507 );
1508
1509 let expr = planner
1511 .parse_filter("x NOT BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
1512 .unwrap();
1513 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1514
1515 let predicates = physical_expr.evaluate(&batch).unwrap();
1516 assert_eq!(
1517 predicates.into_array(0).unwrap().as_ref(),
1518 &BooleanArray::from(vec![
1519 true, true, true, false, false, false, false, false, true, true
1520 ])
1521 );
1522
1523 let expr = planner.parse_filter("y BETWEEN 2.5 AND 6.5").unwrap();
1525 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1526
1527 let predicates = physical_expr.evaluate(&batch).unwrap();
1528 assert_eq!(
1529 predicates.into_array(0).unwrap().as_ref(),
1530 &BooleanArray::from(vec![
1531 false, false, false, true, true, true, true, false, false, false
1532 ])
1533 );
1534
1535 let expr = planner
1537 .parse_filter(
1538 "ts BETWEEN timestamp '2024-01-01 00:00:03' AND timestamp '2024-01-01 00:00:07'",
1539 )
1540 .unwrap();
1541 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1542
1543 let predicates = physical_expr.evaluate(&batch).unwrap();
1544 assert_eq!(
1545 predicates.into_array(0).unwrap().as_ref(),
1546 &BooleanArray::from(vec![
1547 false, false, false, true, true, true, true, true, false, false
1548 ])
1549 );
1550 }
1551
1552 #[test]
1553 fn test_sql_comparison() {
1554 let batch: Vec<(&str, ArrayRef)> = vec![
1556 (
1557 "timestamp_s",
1558 Arc::new(TimestampSecondArray::from_iter_values(0..10)),
1559 ),
1560 (
1561 "timestamp_ms",
1562 Arc::new(TimestampMillisecondArray::from_iter_values(0..10)),
1563 ),
1564 (
1565 "timestamp_us",
1566 Arc::new(TimestampMicrosecondArray::from_iter_values(0..10)),
1567 ),
1568 (
1569 "timestamp_ns",
1570 Arc::new(TimestampNanosecondArray::from_iter_values(4995..5005)),
1571 ),
1572 ];
1573 let batch = RecordBatch::try_from_iter(batch).unwrap();
1574
1575 let planner = Planner::new(batch.schema());
1576
1577 let expressions = &[
1579 "timestamp_s >= TIMESTAMP '1970-01-01 00:00:05'",
1580 "timestamp_ms >= TIMESTAMP '1970-01-01 00:00:00.005'",
1581 "timestamp_us >= TIMESTAMP '1970-01-01 00:00:00.000005'",
1582 "timestamp_ns >= TIMESTAMP '1970-01-01 00:00:00.000005'",
1583 ];
1584
1585 let expected: ArrayRef = Arc::new(BooleanArray::from_iter(
1586 std::iter::repeat_n(Some(false), 5).chain(std::iter::repeat_n(Some(true), 5)),
1587 ));
1588 for expression in expressions {
1589 let logical_expr = planner.parse_filter(expression).unwrap();
1591 let logical_expr = planner.optimize_expr(logical_expr).unwrap();
1592 let physical_expr = planner.create_physical_expr(&logical_expr).unwrap();
1593
1594 let result = physical_expr.evaluate(&batch).unwrap();
1596 let result = result.into_array(batch.num_rows()).unwrap();
1597 assert_eq!(&expected, &result, "unexpected result for {}", expression);
1598 }
1599 }
1600
1601 #[test]
1602 fn test_columns_in_expr() {
1603 let expr = col("s0").gt(lit("value")).and(
1604 col("st")
1605 .field("st")
1606 .field("s2")
1607 .eq(lit("value"))
1608 .or(col("st")
1609 .field("s1")
1610 .in_list(vec![lit("value 1"), lit("value 2")], false)),
1611 );
1612
1613 let columns = Planner::column_names_in_expr(&expr);
1614 assert_eq!(columns, vec!["s0", "st.s1", "st.st.s2"]);
1615 }
1616
1617 #[test]
1618 fn test_parse_binary_expr() {
1619 let bin_str = "x'616263'";
1620
1621 let schema = Arc::new(Schema::new(vec![Field::new(
1622 "binary",
1623 DataType::Binary,
1624 true,
1625 )]));
1626 let planner = Planner::new(schema);
1627 let expr = planner.parse_expr(bin_str).unwrap();
1628 assert_eq!(
1629 expr,
1630 Expr::Literal(ScalarValue::Binary(Some(vec![b'a', b'b', b'c'])))
1631 );
1632 }
1633
1634 #[test]
1635 fn test_lance_context_provider_expr_planners() {
1636 let ctx_provider = LanceContextProvider::default();
1637 assert!(!ctx_provider.get_expr_planners().is_empty());
1638 }
1639}