1use std::borrow::Cow;
7use std::collections::{BTreeSet, VecDeque};
8use std::sync::Arc;
9
10use crate::expr::safe_coerce_scalar;
11use crate::logical_expr::{coerce_filter_type_to_boolean, get_as_string_scalar_opt, resolve_expr};
12use crate::sql::{parse_sql_expr, parse_sql_filter};
13use arrow::compute::CastOptions;
14use arrow_array::ListArray;
15use arrow_buffer::OffsetBuffer;
16use arrow_schema::{DataType as ArrowDataType, Field, SchemaRef, TimeUnit};
17use arrow_select::concat::concat;
18use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
19use datafusion::common::DFSchema;
20use datafusion::config::ConfigOptions;
21use datafusion::error::Result as DFResult;
22use datafusion::execution::config::SessionConfig;
23use datafusion::execution::context::SessionState;
24use datafusion::execution::runtime_env::RuntimeEnvBuilder;
25use datafusion::execution::session_state::SessionStateBuilder;
26use datafusion::logical_expr::expr::ScalarFunction;
27use datafusion::logical_expr::planner::{ExprPlanner, PlannerResult, RawFieldAccessExpr};
28use datafusion::logical_expr::{
29 AggregateUDF, ColumnarValue, GetFieldAccess, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
30 Signature, Volatility, WindowUDF,
31};
32use datafusion::optimizer::simplify_expressions::SimplifyContext;
33use datafusion::sql::planner::{ContextProvider, ParserOptions, PlannerContext, SqlToRel};
34use datafusion::sql::sqlparser::ast::{
35 AccessExpr, Array as SQLArray, BinaryOperator, DataType as SQLDataType, ExactNumberInfo,
36 Expr as SQLExpr, Function, FunctionArg, FunctionArgExpr, FunctionArguments, Ident,
37 ObjectNamePart, Subscript, TimezoneInfo, UnaryOperator, Value, ValueWithSpan,
38};
39use datafusion::{
40 common::Column,
41 logical_expr::{col, Between, BinaryExpr, Like, Operator},
42 physical_expr::execution_props::ExecutionProps,
43 physical_plan::PhysicalExpr,
44 prelude::Expr,
45 scalar::ScalarValue,
46};
47use datafusion_functions::core::getfield::GetFieldFunc;
48use lance_arrow::cast::cast_with_options;
49use lance_core::datatypes::Schema;
50use lance_core::error::LanceOptionExt;
51use snafu::location;
52
53use lance_core::{Error, Result};
54
/// Scalar UDF `_cast_list_f16` that casts a `List` / `FixedSizeList` of
/// `Float32` or `Float16` values to the equivalent list type with `Float16`
/// items.
#[derive(Debug, Clone)]
struct CastListF16Udf {
    // Accepts exactly one argument of any type; the list/element type is
    // validated in `return_type`.
    signature: Signature,
}
59
60impl CastListF16Udf {
61 pub fn new() -> Self {
62 Self {
63 signature: Signature::any(1, Volatility::Immutable),
64 }
65 }
66}
67
68impl ScalarUDFImpl for CastListF16Udf {
69 fn as_any(&self) -> &dyn std::any::Any {
70 self
71 }
72
73 fn name(&self) -> &str {
74 "_cast_list_f16"
75 }
76
77 fn signature(&self) -> &Signature {
78 &self.signature
79 }
80
81 fn return_type(&self, arg_types: &[ArrowDataType]) -> DFResult<ArrowDataType> {
82 let input = &arg_types[0];
83 match input {
84 ArrowDataType::FixedSizeList(field, size) => {
85 if field.data_type() != &ArrowDataType::Float32
86 && field.data_type() != &ArrowDataType::Float16
87 {
88 return Err(datafusion::error::DataFusionError::Execution(
89 "cast_list_f16 only supports list of float32 or float16".to_string(),
90 ));
91 }
92 Ok(ArrowDataType::FixedSizeList(
93 Arc::new(Field::new(
94 field.name(),
95 ArrowDataType::Float16,
96 field.is_nullable(),
97 )),
98 *size,
99 ))
100 }
101 ArrowDataType::List(field) => {
102 if field.data_type() != &ArrowDataType::Float32
103 && field.data_type() != &ArrowDataType::Float16
104 {
105 return Err(datafusion::error::DataFusionError::Execution(
106 "cast_list_f16 only supports list of float32 or float16".to_string(),
107 ));
108 }
109 Ok(ArrowDataType::List(Arc::new(Field::new(
110 field.name(),
111 ArrowDataType::Float16,
112 field.is_nullable(),
113 ))))
114 }
115 _ => Err(datafusion::error::DataFusionError::Execution(
116 "cast_list_f16 only supports FixedSizeList/List arguments".to_string(),
117 )),
118 }
119 }
120
121 fn invoke_with_args(&self, func_args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
122 let ColumnarValue::Array(arr) = &func_args.args[0] else {
123 return Err(datafusion::error::DataFusionError::Execution(
124 "cast_list_f16 only supports array arguments".to_string(),
125 ));
126 };
127
128 let to_type = match arr.data_type() {
129 ArrowDataType::FixedSizeList(field, size) => ArrowDataType::FixedSizeList(
130 Arc::new(Field::new(
131 field.name(),
132 ArrowDataType::Float16,
133 field.is_nullable(),
134 )),
135 *size,
136 ),
137 ArrowDataType::List(field) => ArrowDataType::List(Arc::new(Field::new(
138 field.name(),
139 ArrowDataType::Float16,
140 field.is_nullable(),
141 ))),
142 _ => {
143 return Err(datafusion::error::DataFusionError::Execution(
144 "cast_list_f16 only supports array arguments".to_string(),
145 ));
146 }
147 };
148
149 let res = cast_with_options(arr.as_ref(), &to_type, &CastOptions::default())?;
150 Ok(ColumnarValue::Array(res))
151 }
152}
153
/// `ContextProvider` implementation backing Lance's SQL-to-expression planning.
///
/// Wraps a default DataFusion `SessionState` so the standard function registry
/// and expression planners are available without a full session context.
struct LanceContextProvider {
    // Default DataFusion configuration, returned by `options()`.
    options: datafusion::config::ConfigOptions,
    // Session state used to look up scalar/aggregate/window functions.
    state: SessionState,
    // Expression planners captured from the state builder; used by
    // `get_expr_planners` (e.g. for struct field / list index access planning).
    expr_planners: Vec<Arc<dyn ExprPlanner>>,
}
160
161impl Default for LanceContextProvider {
162 fn default() -> Self {
163 let config = SessionConfig::new();
164 let runtime = RuntimeEnvBuilder::new().build_arc().unwrap();
165 let mut state_builder = SessionStateBuilder::new()
166 .with_config(config)
167 .with_runtime_env(runtime)
168 .with_default_features();
169
170 let expr_planners = state_builder.expr_planners().as_ref().unwrap().clone();
175
176 Self {
177 options: ConfigOptions::default(),
178 state: state_builder.build(),
179 expr_planners,
180 }
181 }
182}
183
/// Exposes the wrapped `SessionState`'s function registry to the SQL planner.
impl ContextProvider for LanceContextProvider {
    /// Table references are rejected: Lance plans scalar expressions and
    /// filters only, never queries over named tables.
    fn get_table_source(
        &self,
        name: datafusion::sql::TableReference,
    ) -> DFResult<Arc<dyn datafusion::logical_expr::TableSource>> {
        Err(datafusion::error::DataFusionError::NotImplemented(format!(
            "Attempt to reference inner table {} not supported",
            name
        )))
    }

    fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
        self.state.aggregate_functions().get(name).cloned()
    }

    fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
        self.state.window_functions().get(name).cloned()
    }

    /// Resolves scalar functions, special-casing the Lance-internal
    /// `_cast_list_f16` UDF before falling back to the session registry.
    fn get_function_meta(&self, f: &str) -> Option<Arc<ScalarUDF>> {
        match f {
            "_cast_list_f16" => Some(Arc::new(ScalarUDF::new_from_impl(CastListF16Udf::new()))),
            _ => self.state.scalar_functions().get(f).cloned(),
        }
    }

    // No session variables are supported.
    fn get_variable_type(&self, _: &[String]) -> Option<ArrowDataType> {
        None
    }

    fn options(&self) -> &datafusion::config::ConfigOptions {
        &self.options
    }

    // NOTE(review): this list omits "_cast_list_f16" even though
    // `get_function_meta` resolves it — confirm whether that is intentional.
    fn udf_names(&self) -> Vec<String> {
        self.state.scalar_functions().keys().cloned().collect()
    }

    fn udaf_names(&self) -> Vec<String> {
        self.state.aggregate_functions().keys().cloned().collect()
    }

    fn udwf_names(&self) -> Vec<String> {
        self.state.window_functions().keys().cloned().collect()
    }

    fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
        &self.expr_planners
    }
}
237
/// Converts SQL filter / expression strings into DataFusion logical
/// expressions (`Expr`) resolved against an Arrow schema.
pub struct Planner {
    // Arrow schema the parsed expressions are resolved and coerced against.
    schema: SchemaRef,
    // Supplies function lookup and expr planners during SQL-to-Expr conversion.
    context_provider: LanceContextProvider,
    // When true, a compound identifier like `t.c` is treated as relation `t`,
    // column `c`; when false it is struct field access on column `t`.
    enable_relations: bool,
}
243
244impl Planner {
245 pub fn new(schema: SchemaRef) -> Self {
246 Self {
247 schema,
248 context_provider: LanceContextProvider::default(),
249 enable_relations: false,
250 }
251 }
252
253 pub fn with_enable_relations(mut self, enable_relations: bool) -> Self {
259 self.enable_relations = enable_relations;
260 self
261 }
262
263 fn column(&self, idents: &[Ident]) -> Expr {
264 fn handle_remaining_idents(expr: &mut Expr, idents: &[Ident]) {
265 for ident in idents {
266 *expr = Expr::ScalarFunction(ScalarFunction {
267 args: vec![
268 std::mem::take(expr),
269 Expr::Literal(ScalarValue::Utf8(Some(ident.value.clone())), None),
270 ],
271 func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
272 });
273 }
274 }
275
276 if self.enable_relations && idents.len() > 1 {
277 let relation = &idents[0].value;
279 let column_name = &idents[1].value;
280 let column = Expr::Column(Column::new(Some(relation.clone()), column_name.clone()));
281 let mut result = column;
282 handle_remaining_idents(&mut result, &idents[2..]);
283 result
284 } else {
285 let mut column = col(&idents[0].value);
287 handle_remaining_idents(&mut column, &idents[1..]);
288 column
289 }
290 }
291
292 fn binary_op(&self, op: &BinaryOperator) -> Result<Operator> {
293 Ok(match op {
294 BinaryOperator::Plus => Operator::Plus,
295 BinaryOperator::Minus => Operator::Minus,
296 BinaryOperator::Multiply => Operator::Multiply,
297 BinaryOperator::Divide => Operator::Divide,
298 BinaryOperator::Modulo => Operator::Modulo,
299 BinaryOperator::StringConcat => Operator::StringConcat,
300 BinaryOperator::Gt => Operator::Gt,
301 BinaryOperator::Lt => Operator::Lt,
302 BinaryOperator::GtEq => Operator::GtEq,
303 BinaryOperator::LtEq => Operator::LtEq,
304 BinaryOperator::Eq => Operator::Eq,
305 BinaryOperator::NotEq => Operator::NotEq,
306 BinaryOperator::And => Operator::And,
307 BinaryOperator::Or => Operator::Or,
308 _ => {
309 return Err(Error::invalid_input(
310 format!("Operator {op} is not supported"),
311 location!(),
312 ));
313 }
314 })
315 }
316
317 fn binary_expr(&self, left: &SQLExpr, op: &BinaryOperator, right: &SQLExpr) -> Result<Expr> {
318 Ok(Expr::BinaryExpr(BinaryExpr::new(
319 Box::new(self.parse_sql_expr(left)?),
320 self.binary_op(op)?,
321 Box::new(self.parse_sql_expr(right)?),
322 )))
323 }
324
325 fn unary_expr(&self, op: &UnaryOperator, expr: &SQLExpr) -> Result<Expr> {
326 Ok(match op {
327 UnaryOperator::Not | UnaryOperator::PGBitwiseNot => {
328 Expr::Not(Box::new(self.parse_sql_expr(expr)?))
329 }
330
331 UnaryOperator::Minus => {
332 use datafusion::logical_expr::lit;
333 match expr {
334 SQLExpr::Value(ValueWithSpan { value: Value::Number(n, _), ..}) => match n.parse::<i64>() {
335 Ok(n) => lit(-n),
336 Err(_) => lit(-n
337 .parse::<f64>()
338 .map_err(|_e| {
339 Error::invalid_input(
340 format!("negative operator can be only applied to integer and float operands, got: {n}"),
341 location!(),
342 )
343 })?),
344 },
345 _ => {
346 Expr::Negative(Box::new(self.parse_sql_expr(expr)?))
347 }
348 }
349 }
350
351 _ => {
352 return Err(Error::invalid_input(
353 format!("Unary operator '{:?}' is not supported", op),
354 location!(),
355 ));
356 }
357 })
358 }
359
360 fn number(&self, value: &str, negative: bool) -> Result<Expr> {
362 use datafusion::logical_expr::lit;
363 let value: Cow<str> = if negative {
364 Cow::Owned(format!("-{}", value))
365 } else {
366 Cow::Borrowed(value)
367 };
368 if let Ok(n) = value.parse::<i64>() {
369 Ok(lit(n))
370 } else {
371 value.parse::<f64>().map(lit).map_err(|_| {
372 Error::invalid_input(
373 format!("'{value}' is not supported number value."),
374 location!(),
375 )
376 })
377 }
378 }
379
380 fn value(&self, value: &Value) -> Result<Expr> {
381 Ok(match value {
382 Value::Number(v, _) => self.number(v.as_str(), false)?,
383 Value::SingleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone())), None),
384 Value::HexStringLiteral(hsl) => {
385 Expr::Literal(ScalarValue::Binary(Self::try_decode_hex_literal(hsl)), None)
386 }
387 Value::DoubleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone())), None),
388 Value::Boolean(v) => Expr::Literal(ScalarValue::Boolean(Some(*v)), None),
389 Value::Null => Expr::Literal(ScalarValue::Null, None),
390 _ => todo!(),
391 })
392 }
393
394 fn parse_function_args(&self, func_args: &FunctionArg) -> Result<Expr> {
395 match func_args {
396 FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => self.parse_sql_expr(expr),
397 _ => Err(Error::invalid_input(
398 format!("Unsupported function args: {:?}", func_args),
399 location!(),
400 )),
401 }
402 }
403
404 fn legacy_parse_function(&self, func: &Function) -> Result<Expr> {
411 match &func.args {
412 FunctionArguments::List(args) => {
413 if func.name.0.len() != 1 {
414 return Err(Error::invalid_input(
415 format!("Function name must have 1 part, got: {:?}", func.name.0),
416 location!(),
417 ));
418 }
419 Ok(Expr::IsNotNull(Box::new(
420 self.parse_function_args(&args.args[0])?,
421 )))
422 }
423 _ => Err(Error::invalid_input(
424 format!("Unsupported function args: {:?}", &func.args),
425 location!(),
426 )),
427 }
428 }
429
    /// Plans a SQL function call.
    ///
    /// The legacy `is_valid(x)` function is rewritten to `x IS NOT NULL`;
    /// everything else is delegated to DataFusion's SQL-to-expression planner
    /// with identifier normalization disabled (column names stay
    /// case-sensitive).
    fn parse_function(&self, function: SQLExpr) -> Result<Expr> {
        if let SQLExpr::Function(function) = &function {
            if let Some(ObjectNamePart::Identifier(name)) = &function.name.0.first() {
                if &name.value == "is_valid" {
                    return self.legacy_parse_function(function);
                }
            }
        }
        let sql_to_rel = SqlToRel::new_with_options(
            &self.context_provider,
            ParserOptions {
                parse_float_as_decimal: false,
                // Keep identifiers as written; do not lowercase them.
                enable_ident_normalization: false,
                support_varchar_with_length: false,
                enable_options_value_normalization: false,
                collect_spans: false,
                map_varchar_to_utf8view: false,
            },
        );

        let mut planner_context = PlannerContext::default();
        let schema = DFSchema::try_from(self.schema.as_ref().clone())?;
        sql_to_rel
            .sql_to_expr(function, &schema, &mut planner_context)
            .map_err(|e| Error::invalid_input(format!("Error parsing function: {e}"), location!()))
    }
456
457 fn parse_type(&self, data_type: &SQLDataType) -> Result<ArrowDataType> {
458 const SUPPORTED_TYPES: [&str; 13] = [
459 "int [unsigned]",
460 "tinyint [unsigned]",
461 "smallint [unsigned]",
462 "bigint [unsigned]",
463 "float",
464 "double",
465 "string",
466 "binary",
467 "date",
468 "timestamp(precision)",
469 "datetime(precision)",
470 "decimal(precision,scale)",
471 "boolean",
472 ];
473 match data_type {
474 SQLDataType::String(_) => Ok(ArrowDataType::Utf8),
475 SQLDataType::Binary(_) => Ok(ArrowDataType::Binary),
476 SQLDataType::Float(_) => Ok(ArrowDataType::Float32),
477 SQLDataType::Double(_) => Ok(ArrowDataType::Float64),
478 SQLDataType::Boolean => Ok(ArrowDataType::Boolean),
479 SQLDataType::TinyInt(_) => Ok(ArrowDataType::Int8),
480 SQLDataType::SmallInt(_) => Ok(ArrowDataType::Int16),
481 SQLDataType::Int(_) | SQLDataType::Integer(_) => Ok(ArrowDataType::Int32),
482 SQLDataType::BigInt(_) => Ok(ArrowDataType::Int64),
483 SQLDataType::TinyIntUnsigned(_) => Ok(ArrowDataType::UInt8),
484 SQLDataType::SmallIntUnsigned(_) => Ok(ArrowDataType::UInt16),
485 SQLDataType::IntUnsigned(_) | SQLDataType::IntegerUnsigned(_) => {
486 Ok(ArrowDataType::UInt32)
487 }
488 SQLDataType::BigIntUnsigned(_) => Ok(ArrowDataType::UInt64),
489 SQLDataType::Date => Ok(ArrowDataType::Date32),
490 SQLDataType::Timestamp(resolution, tz) => {
491 match tz {
492 TimezoneInfo::None => {}
493 _ => {
494 return Err(Error::invalid_input(
495 "Timezone not supported in timestamp".to_string(),
496 location!(),
497 ));
498 }
499 };
500 let time_unit = match resolution {
501 None => TimeUnit::Microsecond,
503 Some(0) => TimeUnit::Second,
504 Some(3) => TimeUnit::Millisecond,
505 Some(6) => TimeUnit::Microsecond,
506 Some(9) => TimeUnit::Nanosecond,
507 _ => {
508 return Err(Error::invalid_input(
509 format!("Unsupported datetime resolution: {:?}", resolution),
510 location!(),
511 ));
512 }
513 };
514 Ok(ArrowDataType::Timestamp(time_unit, None))
515 }
516 SQLDataType::Datetime(resolution) => {
517 let time_unit = match resolution {
518 None => TimeUnit::Microsecond,
519 Some(0) => TimeUnit::Second,
520 Some(3) => TimeUnit::Millisecond,
521 Some(6) => TimeUnit::Microsecond,
522 Some(9) => TimeUnit::Nanosecond,
523 _ => {
524 return Err(Error::invalid_input(
525 format!("Unsupported datetime resolution: {:?}", resolution),
526 location!(),
527 ));
528 }
529 };
530 Ok(ArrowDataType::Timestamp(time_unit, None))
531 }
532 SQLDataType::Decimal(number_info) => match number_info {
533 ExactNumberInfo::PrecisionAndScale(precision, scale) => {
534 Ok(ArrowDataType::Decimal128(*precision as u8, *scale as i8))
535 }
536 _ => Err(Error::invalid_input(
537 format!(
538 "Must provide precision and scale for decimal: {:?}",
539 number_info
540 ),
541 location!(),
542 )),
543 },
544 _ => Err(Error::invalid_input(
545 format!(
546 "Unsupported data type: {:?}. Supported types: {:?}",
547 data_type, SUPPORTED_TYPES
548 ),
549 location!(),
550 )),
551 }
552 }
553
554 fn plan_field_access(&self, mut field_access_expr: RawFieldAccessExpr) -> Result<Expr> {
555 let df_schema = DFSchema::try_from(self.schema.as_ref().clone())?;
556 for planner in self.context_provider.get_expr_planners() {
557 match planner.plan_field_access(field_access_expr, &df_schema)? {
558 PlannerResult::Planned(expr) => return Ok(expr),
559 PlannerResult::Original(expr) => {
560 field_access_expr = expr;
561 }
562 }
563 }
564 Err(Error::invalid_input(
565 "Field access could not be planned",
566 location!(),
567 ))
568 }
569
    /// Recursively converts a parsed SQL AST expression into a DataFusion
    /// logical `Expr`.
    ///
    /// Dialect conventions worth knowing:
    /// - double-quoted identifiers are *string literals*, back-quoted ones are
    ///   column references
    /// - array literals may only contain (possibly negated) literal values,
    ///   which are coerced to the first element's datatype
    /// - `is_valid(x)` and compound field access go through dedicated helpers
    fn parse_sql_expr(&self, expr: &SQLExpr) -> Result<Expr> {
        match expr {
            SQLExpr::Identifier(id) => {
                if id.quote_style == Some('"') {
                    // Double-quoted: a string literal, not an identifier.
                    Ok(Expr::Literal(
                        ScalarValue::Utf8(Some(id.value.clone())),
                        None,
                    ))
                } else if id.quote_style == Some('`') {
                    // Back-quoted: an explicit column reference.
                    Ok(Expr::Column(Column::from_name(id.value.clone())))
                } else {
                    Ok(self.column(vec![id.clone()].as_slice()))
                }
            }
            SQLExpr::CompoundIdentifier(ids) => Ok(self.column(ids.as_slice())),
            SQLExpr::BinaryOp { left, op, right } => self.binary_expr(left, op, right),
            SQLExpr::UnaryOp { op, expr } => self.unary_expr(op, expr),
            SQLExpr::Value(value) => self.value(&value.value),
            // Array literal: collect element literals, coerce to a common
            // datatype, and materialize as a one-row ListArray literal.
            SQLExpr::Array(SQLArray { elem, .. }) => {
                let mut values = vec![];

                let array_literal_error = |pos: usize, value: &_| {
                    Err(Error::invalid_input(
                        format!(
                            "Expected a literal value in array, instead got {} at position {}",
                            value, pos
                        ),
                        location!(),
                    ))
                };

                for (pos, expr) in elem.iter().enumerate() {
                    match expr {
                        SQLExpr::Value(value) => {
                            if let Expr::Literal(value, _) = self.value(&value.value)? {
                                values.push(value);
                            } else {
                                return array_literal_error(pos, expr);
                            }
                        }
                        // Negative numeric literals arrive as unary minus over
                        // a number; parse via `number` with negative = true.
                        SQLExpr::UnaryOp {
                            op: UnaryOperator::Minus,
                            expr,
                        } => {
                            if let SQLExpr::Value(ValueWithSpan {
                                value: Value::Number(number, _),
                                ..
                            }) = expr.as_ref()
                            {
                                if let Expr::Literal(value, _) = self.number(number, true)? {
                                    values.push(value);
                                } else {
                                    return array_literal_error(pos, expr);
                                }
                            } else {
                                return array_literal_error(pos, expr);
                            }
                        }
                        _ => {
                            return array_literal_error(pos, expr);
                        }
                    }
                }

                // The first element fixes the target type; later elements are
                // coerced to it (or an error is raised if they cannot be).
                let field = if !values.is_empty() {
                    let data_type = values[0].data_type();

                    for value in &mut values {
                        if value.data_type() != data_type {
                            *value = safe_coerce_scalar(value, &data_type).ok_or_else(|| Error::invalid_input(
                                format!("Array expressions must have a consistent datatype. Expected: {}, got: {}", data_type, value.data_type()),
                                location!()
                            ))?;
                        }
                    }
                    Field::new("item", data_type, true)
                } else {
                    // Empty array literal: element type is Null.
                    Field::new("item", ArrowDataType::Null, true)
                };

                let values = values
                    .into_iter()
                    .map(|v| v.to_array().map_err(Error::from))
                    .collect::<Result<Vec<_>>>()?;
                let array_refs = values.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
                let values = concat(&array_refs)?;
                // Single offset range covering all values -> one list row.
                let values = ListArray::try_new(
                    field.into(),
                    OffsetBuffer::from_lengths([values.len()]),
                    values,
                    None,
                )?;

                Ok(Expr::Literal(ScalarValue::List(Arc::new(values)), None))
            }
            // e.g. `TIMESTAMP '2020-01-01'`: a string literal cast to the type.
            SQLExpr::TypedString { data_type, value } => {
                let value = value.clone().into_string().expect_ok()?;
                Ok(Expr::Cast(datafusion::logical_expr::Cast {
                    expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some(value)), None)),
                    data_type: self.parse_type(data_type)?,
                }))
            }
            SQLExpr::IsFalse(expr) => Ok(Expr::IsFalse(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotFalse(expr) => Ok(Expr::IsNotFalse(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsTrue(expr) => Ok(Expr::IsTrue(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotTrue(expr) => Ok(Expr::IsNotTrue(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNull(expr) => Ok(Expr::IsNull(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotNull(expr) => Ok(Expr::IsNotNull(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::InList {
                expr,
                list,
                negated,
            } => {
                let value_expr = self.parse_sql_expr(expr)?;
                let list_exprs = list
                    .iter()
                    .map(|e| self.parse_sql_expr(e))
                    .collect::<Result<Vec<_>>>()?;
                Ok(value_expr.in_list(list_exprs, *negated))
            }
            SQLExpr::Nested(inner) => self.parse_sql_expr(inner.as_ref()),
            SQLExpr::Function(_) => self.parse_function(expr.clone()),
            // ILIKE: case-insensitive LIKE (the trailing `true` flag).
            SQLExpr::ILike {
                negated,
                expr,
                pattern,
                escape_char,
                any: _,
            } => Ok(Expr::Like(Like::new(
                *negated,
                Box::new(self.parse_sql_expr(expr)?),
                Box::new(self.parse_sql_expr(pattern)?),
                escape_char.as_ref().and_then(|c| c.chars().next()),
                true,
            ))),
            SQLExpr::Like {
                negated,
                expr,
                pattern,
                escape_char,
                any: _,
            } => Ok(Expr::Like(Like::new(
                *negated,
                Box::new(self.parse_sql_expr(expr)?),
                Box::new(self.parse_sql_expr(pattern)?),
                escape_char.as_ref().and_then(|c| c.chars().next()),
                false,
            ))),
            SQLExpr::Cast {
                expr,
                data_type,
                kind,
                ..
            } => match kind {
                // TRY_CAST / SAFE_CAST yield NULL on failure instead of error.
                datafusion::sql::sqlparser::ast::CastKind::TryCast
                | datafusion::sql::sqlparser::ast::CastKind::SafeCast => {
                    Ok(Expr::TryCast(datafusion::logical_expr::TryCast {
                        expr: Box::new(self.parse_sql_expr(expr)?),
                        data_type: self.parse_type(data_type)?,
                    }))
                }
                _ => Ok(Expr::Cast(datafusion::logical_expr::Cast {
                    expr: Box::new(self.parse_sql_expr(expr)?),
                    data_type: self.parse_type(data_type)?,
                })),
            },
            SQLExpr::JsonAccess { .. } => Err(Error::invalid_input(
                "JSON access is not supported",
                location!(),
            )),
            // `root.field`, `root['field']`, `root[index]` chains, planned one
            // link at a time via the registered expr planners.
            SQLExpr::CompoundFieldAccess { root, access_chain } => {
                let mut expr = self.parse_sql_expr(root)?;

                for access in access_chain {
                    let field_access = match access {
                        // Dot access or string subscript -> named struct field.
                        AccessExpr::Dot(SQLExpr::Identifier(Ident { value: s, .. }))
                        | AccessExpr::Subscript(Subscript::Index {
                            index:
                                SQLExpr::Value(ValueWithSpan {
                                    value:
                                        Value::SingleQuotedString(s) | Value::DoubleQuotedString(s),
                                    ..
                                }),
                        }) => GetFieldAccess::NamedStructField {
                            name: ScalarValue::from(s.as_str()),
                        },
                        // Any other subscript index -> list element access.
                        AccessExpr::Subscript(Subscript::Index { index }) => {
                            let key = Box::new(self.parse_sql_expr(index)?);
                            GetFieldAccess::ListIndex { key }
                        }
                        AccessExpr::Subscript(Subscript::Slice { .. }) => {
                            return Err(Error::invalid_input(
                                "Slice subscript is not supported",
                                location!(),
                            ));
                        }
                        _ => {
                            return Err(Error::invalid_input(
                                "Only dot notation or index access is supported for field access",
                                location!(),
                            ));
                        }
                    };

                    let field_access_expr = RawFieldAccessExpr { expr, field_access };
                    expr = self.plan_field_access(field_access_expr)?;
                }

                Ok(expr)
            }
            SQLExpr::Between {
                expr,
                negated,
                low,
                high,
            } => {
                let expr = self.parse_sql_expr(expr)?;
                let low = self.parse_sql_expr(low)?;
                let high = self.parse_sql_expr(high)?;

                let between = Expr::Between(Between::new(
                    Box::new(expr),
                    *negated,
                    Box::new(low),
                    Box::new(high),
                ));
                Ok(between)
            }
            _ => Err(Error::invalid_input(
                format!("Expression '{expr}' is not supported SQL in lance"),
                location!(),
            )),
        }
    }
813
814 pub fn parse_filter(&self, filter: &str) -> Result<Expr> {
819 let ast_expr = parse_sql_filter(filter)?;
821 let expr = self.parse_sql_expr(&ast_expr)?;
822 let schema = Schema::try_from(self.schema.as_ref())?;
823 let resolved = resolve_expr(&expr, &schema).map_err(|e| {
824 Error::invalid_input(
825 format!("Error resolving filter expression {filter}: {e}"),
826 location!(),
827 )
828 })?;
829
830 Ok(coerce_filter_type_to_boolean(resolved))
831 }
832
833 pub fn parse_expr(&self, expr: &str) -> Result<Expr> {
838 let ast_expr = parse_sql_expr(expr)?;
839 let expr = self.parse_sql_expr(&ast_expr)?;
840 let schema = Schema::try_from(self.schema.as_ref())?;
841 let resolved = resolve_expr(&expr, &schema)?;
842 Ok(resolved)
843 }
844
845 fn try_decode_hex_literal(s: &str) -> Option<Vec<u8>> {
851 let hex_bytes = s.as_bytes();
852 let mut decoded_bytes = Vec::with_capacity(hex_bytes.len().div_ceil(2));
853
854 let start_idx = hex_bytes.len() % 2;
855 if start_idx > 0 {
856 decoded_bytes.push(Self::try_decode_hex_char(hex_bytes[0])?);
858 }
859
860 for i in (start_idx..hex_bytes.len()).step_by(2) {
861 let high = Self::try_decode_hex_char(hex_bytes[i])?;
862 let low = Self::try_decode_hex_char(hex_bytes[i + 1])?;
863 decoded_bytes.push((high << 4) | low);
864 }
865
866 Some(decoded_bytes)
867 }
868
869 const fn try_decode_hex_char(c: u8) -> Option<u8> {
873 match c {
874 b'A'..=b'F' => Some(c - b'A' + 10),
875 b'a'..=b'f' => Some(c - b'a' + 10),
876 b'0'..=b'9' => Some(c - b'0'),
877 _ => None,
878 }
879 }
880
881 pub fn optimize_expr(&self, expr: Expr) -> Result<Expr> {
883 let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
884
885 let props = ExecutionProps::default();
888 let simplify_context = SimplifyContext::new(&props).with_schema(df_schema.clone());
889 let simplifier =
890 datafusion::optimizer::simplify_expressions::ExprSimplifier::new(simplify_context);
891
892 let expr = simplifier.simplify(expr)?;
893 let expr = simplifier.coerce(expr, &df_schema)?;
894
895 Ok(expr)
896 }
897
898 pub fn create_physical_expr(&self, expr: &Expr) -> Result<Arc<dyn PhysicalExpr>> {
900 let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
901 Ok(datafusion::physical_expr::create_physical_expr(
902 expr,
903 df_schema.as_ref(),
904 &Default::default(),
905 )?)
906 }
907
908 pub fn column_names_in_expr(expr: &Expr) -> Vec<String> {
915 let mut visitor = ColumnCapturingVisitor {
916 current_path: VecDeque::new(),
917 columns: BTreeSet::new(),
918 };
919 expr.visit(&mut visitor).unwrap();
920 visitor.columns.into_iter().collect()
921 }
922}
923
/// Expression visitor that records every referenced column, including nested
/// struct fields accessed via `get_field`, as dot-separated paths.
struct ColumnCapturingVisitor {
    // Field names collected from enclosing `get_field` calls, innermost first;
    // flushed into `columns` when the root `Column` node is reached.
    current_path: VecDeque<String>,
    // Completed dotted paths (BTreeSet => sorted, de-duplicated output).
    columns: BTreeSet<String>,
}
929
930impl TreeNodeVisitor<'_> for ColumnCapturingVisitor {
931 type Node = Expr;
932
933 fn f_down(&mut self, node: &Self::Node) -> DFResult<TreeNodeRecursion> {
934 match node {
935 Expr::Column(Column { name, .. }) => {
936 let mut path = name.clone();
937 for part in self.current_path.drain(..) {
938 path.push('.');
939 path.push_str(&part);
940 }
941 self.columns.insert(path);
942 self.current_path.clear();
943 }
944 Expr::ScalarFunction(udf) => {
945 if udf.name() == GetFieldFunc::default().name() {
946 if let Some(name) = get_as_string_scalar_opt(&udf.args[1]) {
947 self.current_path.push_front(name.to_string())
948 } else {
949 self.current_path.clear();
950 }
951 } else {
952 self.current_path.clear();
953 }
954 }
955 _ => {
956 self.current_path.clear();
957 }
958 }
959
960 Ok(TreeNodeRecursion::Continue)
961 }
962}
963
964#[cfg(test)]
965mod tests {
966
967 use crate::logical_expr::ExprExt;
968
969 use super::*;
970
971 use arrow::datatypes::Float64Type;
972 use arrow_array::{
973 ArrayRef, BooleanArray, Float32Array, Int32Array, Int64Array, RecordBatch, StringArray,
974 StructArray, TimestampMicrosecondArray, TimestampMillisecondArray,
975 TimestampNanosecondArray, TimestampSecondArray,
976 };
977 use arrow_schema::{DataType, Fields, Schema};
978 use datafusion::{
979 logical_expr::{lit, Cast},
980 prelude::{array_element, get_field},
981 };
982 use datafusion_functions::core::expr_ext::FieldAccessor;
983
984 #[test]
985 fn test_parse_filter_simple() {
986 let schema = Arc::new(Schema::new(vec![
987 Field::new("i", DataType::Int32, false),
988 Field::new("s", DataType::Utf8, true),
989 Field::new(
990 "st",
991 DataType::Struct(Fields::from(vec![
992 Field::new("x", DataType::Float32, false),
993 Field::new("y", DataType::Float32, false),
994 ])),
995 true,
996 ),
997 ]));
998
999 let planner = Planner::new(schema.clone());
1000
1001 let expected = col("i")
1002 .gt(lit(3_i32))
1003 .and(col("st").field_newstyle("x").lt_eq(lit(5.0_f32)))
1004 .and(
1005 col("s")
1006 .eq(lit("str-4"))
1007 .or(col("s").in_list(vec![lit("str-4"), lit("str-5")], false)),
1008 );
1009
1010 let expr = planner
1012 .parse_filter("i > 3 AND st.x <= 5.0 AND (s == 'str-4' OR s in ('str-4', 'str-5'))")
1013 .unwrap();
1014 assert_eq!(expr, expected);
1015
1016 let expr = planner
1018 .parse_filter("i > 3 AND st.x <= 5.0 AND (s = 'str-4' OR s in ('str-4', 'str-5'))")
1019 .unwrap();
1020
1021 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1022
1023 let batch = RecordBatch::try_new(
1024 schema,
1025 vec![
1026 Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
1027 Arc::new(StringArray::from_iter_values(
1028 (0..10).map(|v| format!("str-{}", v)),
1029 )),
1030 Arc::new(StructArray::from(vec![
1031 (
1032 Arc::new(Field::new("x", DataType::Float32, false)),
1033 Arc::new(Float32Array::from_iter_values((0..10).map(|v| v as f32)))
1034 as ArrayRef,
1035 ),
1036 (
1037 Arc::new(Field::new("y", DataType::Float32, false)),
1038 Arc::new(Float32Array::from_iter_values(
1039 (0..10).map(|v| (v * 10) as f32),
1040 )),
1041 ),
1042 ])),
1043 ],
1044 )
1045 .unwrap();
1046 let predicates = physical_expr.evaluate(&batch).unwrap();
1047 assert_eq!(
1048 predicates.into_array(0).unwrap().as_ref(),
1049 &BooleanArray::from(vec![
1050 false, false, false, false, true, true, false, false, false, false
1051 ])
1052 );
1053 }
1054
1055 #[test]
1056 fn test_nested_col_refs() {
1057 let schema = Arc::new(Schema::new(vec![
1058 Field::new("s0", DataType::Utf8, true),
1059 Field::new(
1060 "st",
1061 DataType::Struct(Fields::from(vec![
1062 Field::new("s1", DataType::Utf8, true),
1063 Field::new(
1064 "st",
1065 DataType::Struct(Fields::from(vec![Field::new(
1066 "s2",
1067 DataType::Utf8,
1068 true,
1069 )])),
1070 true,
1071 ),
1072 ])),
1073 true,
1074 ),
1075 ]));
1076
1077 let planner = Planner::new(schema);
1078
1079 fn assert_column_eq(planner: &Planner, expr: &str, expected: &Expr) {
1080 let expr = planner.parse_filter(&format!("{expr} = 'val'")).unwrap();
1081 assert!(matches!(
1082 expr,
1083 Expr::BinaryExpr(BinaryExpr {
1084 left: _,
1085 op: Operator::Eq,
1086 right: _
1087 })
1088 ));
1089 if let Expr::BinaryExpr(BinaryExpr { left, .. }) = expr {
1090 assert_eq!(left.as_ref(), expected);
1091 }
1092 }
1093
1094 let expected = Expr::Column(Column::new_unqualified("s0"));
1095 assert_column_eq(&planner, "s0", &expected);
1096 assert_column_eq(&planner, "`s0`", &expected);
1097
1098 let expected = Expr::ScalarFunction(ScalarFunction {
1099 func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
1100 args: vec![
1101 Expr::Column(Column::new_unqualified("st")),
1102 Expr::Literal(ScalarValue::Utf8(Some("s1".to_string())), None),
1103 ],
1104 });
1105 assert_column_eq(&planner, "st.s1", &expected);
1106 assert_column_eq(&planner, "`st`.`s1`", &expected);
1107 assert_column_eq(&planner, "st.`s1`", &expected);
1108
1109 let expected = Expr::ScalarFunction(ScalarFunction {
1110 func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
1111 args: vec![
1112 Expr::ScalarFunction(ScalarFunction {
1113 func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
1114 args: vec![
1115 Expr::Column(Column::new_unqualified("st")),
1116 Expr::Literal(ScalarValue::Utf8(Some("st".to_string())), None),
1117 ],
1118 }),
1119 Expr::Literal(ScalarValue::Utf8(Some("s2".to_string())), None),
1120 ],
1121 });
1122
1123 assert_column_eq(&planner, "st.st.s2", &expected);
1124 assert_column_eq(&planner, "`st`.`st`.`s2`", &expected);
1125 assert_column_eq(&planner, "st.st.`s2`", &expected);
1126 assert_column_eq(&planner, "st['st'][\"s2\"]", &expected);
1127 }
1128
1129 #[test]
1130 fn test_nested_list_refs() {
1131 let schema = Arc::new(Schema::new(vec![Field::new(
1132 "l",
1133 DataType::List(Arc::new(Field::new(
1134 "item",
1135 DataType::Struct(Fields::from(vec![Field::new("f1", DataType::Utf8, true)])),
1136 true,
1137 ))),
1138 true,
1139 )]));
1140
1141 let planner = Planner::new(schema);
1142
1143 let expected = array_element(col("l"), lit(0_i64));
1144 let expr = planner.parse_expr("l[0]").unwrap();
1145 assert_eq!(expr, expected);
1146
1147 let expected = get_field(array_element(col("l"), lit(0_i64)), "f1");
1148 let expr = planner.parse_expr("l[0]['f1']").unwrap();
1149 assert_eq!(expr, expected);
1150
1151 }
1156
1157 #[test]
1158 fn test_negative_expressions() {
1159 let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
1160
1161 let planner = Planner::new(schema.clone());
1162
1163 let expected = col("x")
1164 .gt(lit(-3_i64))
1165 .and(col("x").lt(-(lit(-5_i64) + lit(3_i64))));
1166
1167 let expr = planner.parse_filter("x > -3 AND x < -(-5 + 3)").unwrap();
1168
1169 assert_eq!(expr, expected);
1170
1171 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1172
1173 let batch = RecordBatch::try_new(
1174 schema,
1175 vec![Arc::new(Int64Array::from_iter_values(-5..5)) as ArrayRef],
1176 )
1177 .unwrap();
1178 let predicates = physical_expr.evaluate(&batch).unwrap();
1179 assert_eq!(
1180 predicates.into_array(0).unwrap().as_ref(),
1181 &BooleanArray::from(vec![
1182 false, false, false, true, true, true, true, false, false, false
1183 ])
1184 );
1185 }
1186
1187 #[test]
1188 fn test_negative_array_expressions() {
1189 let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
1190
1191 let planner = Planner::new(schema);
1192
1193 let expected = Expr::Literal(
1194 ScalarValue::List(Arc::new(
1195 ListArray::from_iter_primitive::<Float64Type, _, _>(vec![Some(
1196 [-1_f64, -2.0, -3.0, -4.0, -5.0].map(Some),
1197 )]),
1198 )),
1199 None,
1200 );
1201
1202 let expr = planner
1203 .parse_expr("[-1.0, -2.0, -3.0, -4.0, -5.0]")
1204 .unwrap();
1205
1206 assert_eq!(expr, expected);
1207 }
1208
1209 #[test]
1210 fn test_sql_like() {
1211 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1212
1213 let planner = Planner::new(schema.clone());
1214
1215 let expected = col("s").like(lit("str-4"));
1216 let expr = planner.parse_filter("s LIKE 'str-4'").unwrap();
1218 assert_eq!(expr, expected);
1219 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1220
1221 let batch = RecordBatch::try_new(
1222 schema,
1223 vec![Arc::new(StringArray::from_iter_values(
1224 (0..10).map(|v| format!("str-{}", v)),
1225 ))],
1226 )
1227 .unwrap();
1228 let predicates = physical_expr.evaluate(&batch).unwrap();
1229 assert_eq!(
1230 predicates.into_array(0).unwrap().as_ref(),
1231 &BooleanArray::from(vec![
1232 false, false, false, false, true, false, false, false, false, false
1233 ])
1234 );
1235 }
1236
1237 #[test]
1238 fn test_not_like() {
1239 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1240
1241 let planner = Planner::new(schema.clone());
1242
1243 let expected = col("s").not_like(lit("str-4"));
1244 let expr = planner.parse_filter("s NOT LIKE 'str-4'").unwrap();
1246 assert_eq!(expr, expected);
1247 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1248
1249 let batch = RecordBatch::try_new(
1250 schema,
1251 vec![Arc::new(StringArray::from_iter_values(
1252 (0..10).map(|v| format!("str-{}", v)),
1253 ))],
1254 )
1255 .unwrap();
1256 let predicates = physical_expr.evaluate(&batch).unwrap();
1257 assert_eq!(
1258 predicates.into_array(0).unwrap().as_ref(),
1259 &BooleanArray::from(vec![
1260 true, true, true, true, false, true, true, true, true, true
1261 ])
1262 );
1263 }
1264
1265 #[test]
1266 fn test_sql_is_in() {
1267 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1268
1269 let planner = Planner::new(schema.clone());
1270
1271 let expected = col("s").in_list(vec![lit("str-4"), lit("str-5")], false);
1272 let expr = planner.parse_filter("s IN ('str-4', 'str-5')").unwrap();
1274 assert_eq!(expr, expected);
1275 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1276
1277 let batch = RecordBatch::try_new(
1278 schema,
1279 vec![Arc::new(StringArray::from_iter_values(
1280 (0..10).map(|v| format!("str-{}", v)),
1281 ))],
1282 )
1283 .unwrap();
1284 let predicates = physical_expr.evaluate(&batch).unwrap();
1285 assert_eq!(
1286 predicates.into_array(0).unwrap().as_ref(),
1287 &BooleanArray::from(vec![
1288 false, false, false, false, true, true, false, false, false, false
1289 ])
1290 );
1291 }
1292
1293 #[test]
1294 fn test_sql_is_null() {
1295 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1296
1297 let planner = Planner::new(schema.clone());
1298
1299 let expected = col("s").is_null();
1300 let expr = planner.parse_filter("s IS NULL").unwrap();
1301 assert_eq!(expr, expected);
1302 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1303
1304 let batch = RecordBatch::try_new(
1305 schema,
1306 vec![Arc::new(StringArray::from_iter((0..10).map(|v| {
1307 if v % 3 == 0 {
1308 Some(format!("str-{}", v))
1309 } else {
1310 None
1311 }
1312 })))],
1313 )
1314 .unwrap();
1315 let predicates = physical_expr.evaluate(&batch).unwrap();
1316 assert_eq!(
1317 predicates.into_array(0).unwrap().as_ref(),
1318 &BooleanArray::from(vec![
1319 false, true, true, false, true, true, false, true, true, false
1320 ])
1321 );
1322
1323 let expr = planner.parse_filter("s IS NOT NULL").unwrap();
1324 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1325 let predicates = physical_expr.evaluate(&batch).unwrap();
1326 assert_eq!(
1327 predicates.into_array(0).unwrap().as_ref(),
1328 &BooleanArray::from(vec![
1329 true, false, false, true, false, false, true, false, false, true,
1330 ])
1331 );
1332 }
1333
1334 #[test]
1335 fn test_sql_invert() {
1336 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Boolean, true)]));
1337
1338 let planner = Planner::new(schema.clone());
1339
1340 let expr = planner.parse_filter("NOT s").unwrap();
1341 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1342
1343 let batch = RecordBatch::try_new(
1344 schema,
1345 vec![Arc::new(BooleanArray::from_iter(
1346 (0..10).map(|v| Some(v % 3 == 0)),
1347 ))],
1348 )
1349 .unwrap();
1350 let predicates = physical_expr.evaluate(&batch).unwrap();
1351 assert_eq!(
1352 predicates.into_array(0).unwrap().as_ref(),
1353 &BooleanArray::from(vec![
1354 false, true, true, false, true, true, false, true, true, false
1355 ])
1356 );
1357 }
1358
    #[test]
    fn test_sql_cast() {
        // Each case pairs a filter containing CAST with the Arrow type that the
        // cast should resolve to. The column "x" is created with that same type
        // so no extra coercion is introduced around the comparison.
        let cases = &[
            (
                "x = cast('2021-01-01 00:00:00' as timestamp)",
                ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
            ),
            (
                // timestamp(N) precision selects the time unit: 0 -> seconds.
                "x = cast('2021-01-01 00:00:00' as timestamp(0))",
                ArrowDataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00.123' as timestamp(9))",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            (
                // datetime is accepted as a synonym for timestamp.
                "x = cast('2021-01-01 00:00:00.123' as datetime(9))",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            ("x = cast('2021-01-01' as date)", ArrowDataType::Date32),
            (
                "x = cast('1.238' as decimal(9,3))",
                ArrowDataType::Decimal128(9, 3),
            ),
            ("x = cast(1 as float)", ArrowDataType::Float32),
            ("x = cast(1 as double)", ArrowDataType::Float64),
            ("x = cast(1 as tinyint)", ArrowDataType::Int8),
            ("x = cast(1 as smallint)", ArrowDataType::Int16),
            ("x = cast(1 as int)", ArrowDataType::Int32),
            ("x = cast(1 as integer)", ArrowDataType::Int32),
            ("x = cast(1 as bigint)", ArrowDataType::Int64),
            ("x = cast(1 as tinyint unsigned)", ArrowDataType::UInt8),
            ("x = cast(1 as smallint unsigned)", ArrowDataType::UInt16),
            ("x = cast(1 as int unsigned)", ArrowDataType::UInt32),
            ("x = cast(1 as integer unsigned)", ArrowDataType::UInt32),
            ("x = cast(1 as bigint unsigned)", ArrowDataType::UInt64),
            ("x = cast(1 as boolean)", ArrowDataType::Boolean),
            ("x = cast(1 as string)", ArrowDataType::Utf8),
        ];

        for (sql, expected_data_type) in cases {
            let schema = Arc::new(Schema::new(vec![Field::new(
                "x",
                expected_data_type.clone(),
                true,
            )]));
            let planner = Planner::new(schema.clone());
            let expr = planner.parse_filter(sql).unwrap();

            // Recover the literal text between "cast(" and " as" from the SQL
            // itself, then strip the surrounding single quotes so it can be
            // compared against the parsed Utf8 literal below.
            let expected_value_str = sql
                .split("cast(")
                .nth(1)
                .unwrap()
                .split(" as")
                .next()
                .unwrap();
            let expected_value_str = expected_value_str.trim_matches('\'');

            // Expected shape: BinaryExpr { right: Cast(literal -> expected type) }.
            // String sources stay Utf8 literals; the bare `1` parses as Int64.
            match expr {
                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
                    Expr::Cast(Cast { expr, data_type }) => {
                        match expr.as_ref() {
                            Expr::Literal(ScalarValue::Utf8(Some(value_str)), _) => {
                                assert_eq!(value_str, expected_value_str);
                            }
                            Expr::Literal(ScalarValue::Int64(Some(value)), _) => {
                                assert_eq!(*value, 1);
                            }
                            _ => panic!("Expected cast to be applied to literal"),
                        }
                        assert_eq!(data_type, expected_data_type);
                    }
                    _ => panic!("Expected right to be a cast"),
                },
                _ => panic!("Expected binary expression"),
            }
        }
    }
1439
    #[test]
    fn test_sql_literals() {
        // Typed SQL literals of the form `TYPE 'value'` should plan as a cast
        // of the raw string to the corresponding Arrow type.
        let cases = &[
            (
                "x = timestamp '2021-01-01 00:00:00'",
                ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
            ),
            (
                // timestamp(N) precision selects the time unit: 0 -> seconds.
                "x = timestamp(0) '2021-01-01 00:00:00'",
                ArrowDataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                "x = timestamp(9) '2021-01-01 00:00:00.123'",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            ("x = date '2021-01-01'", ArrowDataType::Date32),
            ("x = decimal(9,3) '1.238'", ArrowDataType::Decimal128(9, 3)),
        ];

        for (sql, expected_data_type) in cases {
            // Column "x" is given the same type the literal should cast to.
            let schema = Arc::new(Schema::new(vec![Field::new(
                "x",
                expected_data_type.clone(),
                true,
            )]));
            let planner = Planner::new(schema.clone());
            let expr = planner.parse_filter(sql).unwrap();

            // The quoted text in the SQL is the expected inner Utf8 literal.
            let expected_value_str = sql.split('\'').nth(1).unwrap();

            // Expected shape: BinaryExpr { right: Cast(Utf8 literal -> type) }.
            match expr {
                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
                    Expr::Cast(Cast { expr, data_type }) => {
                        match expr.as_ref() {
                            Expr::Literal(ScalarValue::Utf8(Some(value_str)), _) => {
                                assert_eq!(value_str, expected_value_str);
                            }
                            _ => panic!("Expected cast to be applied to literal"),
                        }
                        assert_eq!(data_type, expected_data_type);
                    }
                    _ => panic!("Expected right to be a cast"),
                },
                _ => panic!("Expected binary expression"),
            }
        }
    }
1487
1488 #[test]
1489 fn test_sql_array_literals() {
1490 let cases = [
1491 (
1492 "x = [1, 2, 3]",
1493 ArrowDataType::List(Arc::new(Field::new("item", ArrowDataType::Int64, true))),
1494 ),
1495 (
1496 "x = [1, 2, 3]",
1497 ArrowDataType::FixedSizeList(
1498 Arc::new(Field::new("item", ArrowDataType::Int64, true)),
1499 3,
1500 ),
1501 ),
1502 ];
1503
1504 for (sql, expected_data_type) in cases {
1505 let schema = Arc::new(Schema::new(vec![Field::new(
1506 "x",
1507 expected_data_type.clone(),
1508 true,
1509 )]));
1510 let planner = Planner::new(schema.clone());
1511 let expr = planner.parse_filter(sql).unwrap();
1512 let expr = planner.optimize_expr(expr).unwrap();
1513
1514 match expr {
1515 Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
1516 Expr::Literal(value, _) => {
1517 assert_eq!(&value.data_type(), &expected_data_type);
1518 }
1519 _ => panic!("Expected right to be a literal"),
1520 },
1521 _ => panic!("Expected binary expression"),
1522 }
1523 }
1524 }
1525
1526 #[test]
1527 fn test_sql_between() {
1528 use arrow_array::{Float64Array, Int32Array, TimestampMicrosecondArray};
1529 use arrow_schema::{DataType, Field, Schema, TimeUnit};
1530 use std::sync::Arc;
1531
1532 let schema = Arc::new(Schema::new(vec![
1533 Field::new("x", DataType::Int32, false),
1534 Field::new("y", DataType::Float64, false),
1535 Field::new(
1536 "ts",
1537 DataType::Timestamp(TimeUnit::Microsecond, None),
1538 false,
1539 ),
1540 ]));
1541
1542 let planner = Planner::new(schema.clone());
1543
1544 let expr = planner
1546 .parse_filter("x BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
1547 .unwrap();
1548 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1549
1550 let base_ts = 1704067200000000_i64; let ts_array = TimestampMicrosecondArray::from_iter_values(
1554 (0..10).map(|i| base_ts + i * 1_000_000), );
1556
1557 let batch = RecordBatch::try_new(
1558 schema,
1559 vec![
1560 Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
1561 Arc::new(Float64Array::from_iter_values((0..10).map(|v| v as f64))),
1562 Arc::new(ts_array),
1563 ],
1564 )
1565 .unwrap();
1566
1567 let predicates = physical_expr.evaluate(&batch).unwrap();
1568 assert_eq!(
1569 predicates.into_array(0).unwrap().as_ref(),
1570 &BooleanArray::from(vec![
1571 false, false, false, true, true, true, true, true, false, false
1572 ])
1573 );
1574
1575 let expr = planner
1577 .parse_filter("x NOT BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
1578 .unwrap();
1579 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1580
1581 let predicates = physical_expr.evaluate(&batch).unwrap();
1582 assert_eq!(
1583 predicates.into_array(0).unwrap().as_ref(),
1584 &BooleanArray::from(vec![
1585 true, true, true, false, false, false, false, false, true, true
1586 ])
1587 );
1588
1589 let expr = planner.parse_filter("y BETWEEN 2.5 AND 6.5").unwrap();
1591 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1592
1593 let predicates = physical_expr.evaluate(&batch).unwrap();
1594 assert_eq!(
1595 predicates.into_array(0).unwrap().as_ref(),
1596 &BooleanArray::from(vec![
1597 false, false, false, true, true, true, true, false, false, false
1598 ])
1599 );
1600
1601 let expr = planner
1603 .parse_filter(
1604 "ts BETWEEN timestamp '2024-01-01 00:00:03' AND timestamp '2024-01-01 00:00:07'",
1605 )
1606 .unwrap();
1607 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1608
1609 let predicates = physical_expr.evaluate(&batch).unwrap();
1610 assert_eq!(
1611 predicates.into_array(0).unwrap().as_ref(),
1612 &BooleanArray::from(vec![
1613 false, false, false, true, true, true, true, true, false, false
1614 ])
1615 );
1616 }
1617
1618 #[test]
1619 fn test_sql_comparison() {
1620 let batch: Vec<(&str, ArrayRef)> = vec![
1622 (
1623 "timestamp_s",
1624 Arc::new(TimestampSecondArray::from_iter_values(0..10)),
1625 ),
1626 (
1627 "timestamp_ms",
1628 Arc::new(TimestampMillisecondArray::from_iter_values(0..10)),
1629 ),
1630 (
1631 "timestamp_us",
1632 Arc::new(TimestampMicrosecondArray::from_iter_values(0..10)),
1633 ),
1634 (
1635 "timestamp_ns",
1636 Arc::new(TimestampNanosecondArray::from_iter_values(4995..5005)),
1637 ),
1638 ];
1639 let batch = RecordBatch::try_from_iter(batch).unwrap();
1640
1641 let planner = Planner::new(batch.schema());
1642
1643 let expressions = &[
1645 "timestamp_s >= TIMESTAMP '1970-01-01 00:00:05'",
1646 "timestamp_ms >= TIMESTAMP '1970-01-01 00:00:00.005'",
1647 "timestamp_us >= TIMESTAMP '1970-01-01 00:00:00.000005'",
1648 "timestamp_ns >= TIMESTAMP '1970-01-01 00:00:00.000005'",
1649 ];
1650
1651 let expected: ArrayRef = Arc::new(BooleanArray::from_iter(
1652 std::iter::repeat_n(Some(false), 5).chain(std::iter::repeat_n(Some(true), 5)),
1653 ));
1654 for expression in expressions {
1655 let logical_expr = planner.parse_filter(expression).unwrap();
1657 let logical_expr = planner.optimize_expr(logical_expr).unwrap();
1658 let physical_expr = planner.create_physical_expr(&logical_expr).unwrap();
1659
1660 let result = physical_expr.evaluate(&batch).unwrap();
1662 let result = result.into_array(batch.num_rows()).unwrap();
1663 assert_eq!(&expected, &result, "unexpected result for {}", expression);
1664 }
1665 }
1666
1667 #[test]
1668 fn test_columns_in_expr() {
1669 let expr = col("s0").gt(lit("value")).and(
1670 col("st")
1671 .field("st")
1672 .field("s2")
1673 .eq(lit("value"))
1674 .or(col("st")
1675 .field("s1")
1676 .in_list(vec![lit("value 1"), lit("value 2")], false)),
1677 );
1678
1679 let columns = Planner::column_names_in_expr(&expr);
1680 assert_eq!(columns, vec!["s0", "st.s1", "st.st.s2"]);
1681 }
1682
1683 #[test]
1684 fn test_parse_binary_expr() {
1685 let bin_str = "x'616263'";
1686
1687 let schema = Arc::new(Schema::new(vec![Field::new(
1688 "binary",
1689 DataType::Binary,
1690 true,
1691 )]));
1692 let planner = Planner::new(schema);
1693 let expr = planner.parse_expr(bin_str).unwrap();
1694 assert_eq!(
1695 expr,
1696 Expr::Literal(ScalarValue::Binary(Some(vec![b'a', b'b', b'c'])), None)
1697 );
1698 }
1699
1700 #[test]
1701 fn test_lance_context_provider_expr_planners() {
1702 let ctx_provider = LanceContextProvider::default();
1703 assert!(!ctx_provider.get_expr_planners().is_empty());
1704 }
1705}