1use std::borrow::Cow;
7use std::collections::{BTreeSet, VecDeque};
8use std::sync::Arc;
9
10use crate::expr::safe_coerce_scalar;
11use crate::logical_expr::{coerce_filter_type_to_boolean, get_as_string_scalar_opt, resolve_expr};
12use crate::sql::{parse_sql_expr, parse_sql_filter};
13use arrow::compute::CastOptions;
14use arrow_array::ListArray;
15use arrow_buffer::OffsetBuffer;
16use arrow_schema::{DataType as ArrowDataType, Field, SchemaRef, TimeUnit};
17use arrow_select::concat::concat;
18use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
19use datafusion::common::DFSchema;
20use datafusion::config::ConfigOptions;
21use datafusion::error::Result as DFResult;
22use datafusion::execution::config::SessionConfig;
23use datafusion::execution::context::{SessionContext, SessionState};
24use datafusion::execution::runtime_env::RuntimeEnvBuilder;
25use datafusion::execution::session_state::SessionStateBuilder;
26use datafusion::logical_expr::expr::ScalarFunction;
27use datafusion::logical_expr::planner::{ExprPlanner, PlannerResult, RawFieldAccessExpr};
28use datafusion::logical_expr::{
29 AggregateUDF, ColumnarValue, GetFieldAccess, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
30 Signature, Volatility, WindowUDF,
31};
32use datafusion::optimizer::simplify_expressions::SimplifyContext;
33use datafusion::sql::planner::{ContextProvider, ParserOptions, PlannerContext, SqlToRel};
34use datafusion::sql::sqlparser::ast::{
35 AccessExpr, Array as SQLArray, BinaryOperator, DataType as SQLDataType, ExactNumberInfo,
36 Expr as SQLExpr, Function, FunctionArg, FunctionArgExpr, FunctionArguments, Ident,
37 ObjectNamePart, Subscript, TimezoneInfo, UnaryOperator, Value, ValueWithSpan,
38};
39use datafusion::{
40 common::Column,
41 logical_expr::{col, Between, BinaryExpr, Like, Operator},
42 physical_expr::execution_props::ExecutionProps,
43 physical_plan::PhysicalExpr,
44 prelude::Expr,
45 scalar::ScalarValue,
46};
47use datafusion_functions::core::getfield::GetFieldFunc;
48use lance_arrow::cast::cast_with_options;
49use lance_core::datatypes::Schema;
50use lance_core::error::LanceOptionExt;
51use snafu::location;
52
53use lance_core::{Error, Result};
54
#[derive(Debug, Clone)]
/// Scalar UDF (`_cast_list_f16`) that casts a `List`/`FixedSizeList` of
/// float32/float16 values to the same list shape with `Float16` elements.
struct CastListF16Udf {
    // Accepts any single argument; element-type validation happens in
    // `return_type`/`invoke_with_args`.
    signature: Signature,
}
59
60impl CastListF16Udf {
61 pub fn new() -> Self {
62 Self {
63 signature: Signature::any(1, Volatility::Immutable),
64 }
65 }
66}
67
68impl ScalarUDFImpl for CastListF16Udf {
69 fn as_any(&self) -> &dyn std::any::Any {
70 self
71 }
72
73 fn name(&self) -> &str {
74 "_cast_list_f16"
75 }
76
77 fn signature(&self) -> &Signature {
78 &self.signature
79 }
80
81 fn return_type(&self, arg_types: &[ArrowDataType]) -> DFResult<ArrowDataType> {
82 let input = &arg_types[0];
83 match input {
84 ArrowDataType::FixedSizeList(field, size) => {
85 if field.data_type() != &ArrowDataType::Float32
86 && field.data_type() != &ArrowDataType::Float16
87 {
88 return Err(datafusion::error::DataFusionError::Execution(
89 "cast_list_f16 only supports list of float32 or float16".to_string(),
90 ));
91 }
92 Ok(ArrowDataType::FixedSizeList(
93 Arc::new(Field::new(
94 field.name(),
95 ArrowDataType::Float16,
96 field.is_nullable(),
97 )),
98 *size,
99 ))
100 }
101 ArrowDataType::List(field) => {
102 if field.data_type() != &ArrowDataType::Float32
103 && field.data_type() != &ArrowDataType::Float16
104 {
105 return Err(datafusion::error::DataFusionError::Execution(
106 "cast_list_f16 only supports list of float32 or float16".to_string(),
107 ));
108 }
109 Ok(ArrowDataType::List(Arc::new(Field::new(
110 field.name(),
111 ArrowDataType::Float16,
112 field.is_nullable(),
113 ))))
114 }
115 _ => Err(datafusion::error::DataFusionError::Execution(
116 "cast_list_f16 only supports FixedSizeList/List arguments".to_string(),
117 )),
118 }
119 }
120
121 fn invoke_with_args(&self, func_args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
122 let ColumnarValue::Array(arr) = &func_args.args[0] else {
123 return Err(datafusion::error::DataFusionError::Execution(
124 "cast_list_f16 only supports array arguments".to_string(),
125 ));
126 };
127
128 let to_type = match arr.data_type() {
129 ArrowDataType::FixedSizeList(field, size) => ArrowDataType::FixedSizeList(
130 Arc::new(Field::new(
131 field.name(),
132 ArrowDataType::Float16,
133 field.is_nullable(),
134 )),
135 *size,
136 ),
137 ArrowDataType::List(field) => ArrowDataType::List(Arc::new(Field::new(
138 field.name(),
139 ArrowDataType::Float16,
140 field.is_nullable(),
141 ))),
142 _ => {
143 return Err(datafusion::error::DataFusionError::Execution(
144 "cast_list_f16 only supports array arguments".to_string(),
145 ));
146 }
147 };
148
149 let res = cast_with_options(arr.as_ref(), &to_type, &CastOptions::default())?;
150 Ok(ColumnarValue::Array(res))
151 }
152}
153
/// `ContextProvider` backing SQL-to-logical-expression planning.
///
/// Wraps a `SessionState` (for UDF/UDAF/UDWF lookup by name) plus the default
/// set of expression planners, without exposing any tables.
struct LanceContextProvider {
    // Default DataFusion configuration returned from `options()`.
    options: datafusion::config::ConfigOptions,
    // Session state used only as a registry of functions.
    state: SessionState,
    // Expression planners handed to SqlToRel for field access etc.
    expr_planners: Vec<Arc<dyn ExprPlanner>>,
}
160
161impl Default for LanceContextProvider {
162 fn default() -> Self {
163 let config = SessionConfig::new();
164 let runtime = RuntimeEnvBuilder::new().build_arc().unwrap();
165
166 let ctx = SessionContext::new_with_config_rt(config.clone(), runtime.clone());
167 crate::udf::register_functions(&ctx);
168
169 let state = ctx.state();
170
171 let mut state_builder = SessionStateBuilder::new()
173 .with_config(config)
174 .with_runtime_env(runtime)
175 .with_default_features();
176
177 let expr_planners = state_builder.expr_planners().as_ref().unwrap().clone();
179
180 Self {
181 options: ConfigOptions::default(),
182 state,
183 expr_planners,
184 }
185 }
186}
187
188impl ContextProvider for LanceContextProvider {
189 fn get_table_source(
190 &self,
191 name: datafusion::sql::TableReference,
192 ) -> DFResult<Arc<dyn datafusion::logical_expr::TableSource>> {
193 Err(datafusion::error::DataFusionError::NotImplemented(format!(
194 "Attempt to reference inner table {} not supported",
195 name
196 )))
197 }
198
199 fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
200 self.state.aggregate_functions().get(name).cloned()
201 }
202
203 fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
204 self.state.window_functions().get(name).cloned()
205 }
206
207 fn get_function_meta(&self, f: &str) -> Option<Arc<ScalarUDF>> {
208 match f {
209 "_cast_list_f16" => Some(Arc::new(ScalarUDF::new_from_impl(CastListF16Udf::new()))),
212 _ => self.state.scalar_functions().get(f).cloned(),
213 }
214 }
215
216 fn get_variable_type(&self, _: &[String]) -> Option<ArrowDataType> {
217 None
219 }
220
221 fn options(&self) -> &datafusion::config::ConfigOptions {
222 &self.options
223 }
224
225 fn udf_names(&self) -> Vec<String> {
226 self.state.scalar_functions().keys().cloned().collect()
227 }
228
229 fn udaf_names(&self) -> Vec<String> {
230 self.state.aggregate_functions().keys().cloned().collect()
231 }
232
233 fn udwf_names(&self) -> Vec<String> {
234 self.state.window_functions().keys().cloned().collect()
235 }
236
237 fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
238 &self.expr_planners
239 }
240}
241
/// Translates SQL filter/expression strings into DataFusion logical and
/// physical expressions against a fixed Arrow schema.
pub struct Planner {
    // Schema the expressions are resolved against.
    schema: SchemaRef,
    // Function registry + expression planners used during SQL planning.
    context_provider: LanceContextProvider,
    // When true, a leading identifier in a compound reference (`a.b`) is
    // treated as a relation qualifier rather than a struct field access.
    enable_relations: bool,
}
247
impl Planner {
    /// Build a planner over `schema`; relation-qualified references are
    /// disabled by default.
    pub fn new(schema: SchemaRef) -> Self {
        let context_provider = LanceContextProvider::default();
        Self {
            schema,
            context_provider,
            enable_relations: false,
        }
    }
256
257 pub fn with_enable_relations(mut self, enable_relations: bool) -> Self {
263 self.enable_relations = enable_relations;
264 self
265 }
266
    /// Build a column expression from a (possibly compound) identifier chain.
    ///
    /// The first identifier (or, with relations enabled, the first two) names
    /// the column; every remaining identifier becomes a nested `get_field`
    /// call, so `a.b.c` plans as `get_field(get_field(a, "b"), "c")`.
    fn column(&self, idents: &[Ident]) -> Expr {
        // Wrap `expr` in one `get_field` call per remaining identifier.
        fn handle_remaining_idents(expr: &mut Expr, idents: &[Ident]) {
            for ident in idents {
                *expr = Expr::ScalarFunction(ScalarFunction {
                    args: vec![
                        // mem::take leaves a placeholder so the current expr
                        // can be moved into the new function node.
                        std::mem::take(expr),
                        Expr::Literal(ScalarValue::Utf8(Some(ident.value.clone())), None),
                    ],
                    func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
                });
            }
        }

        if self.enable_relations && idents.len() > 1 {
            // First ident is the relation qualifier, second the column name.
            let relation = &idents[0].value;
            let column_name = &idents[1].value;
            let column = Expr::Column(Column::new(Some(relation.clone()), column_name.clone()));
            let mut result = column;
            handle_remaining_idents(&mut result, &idents[2..]);
            result
        } else {
            // NOTE(review): `col()` parses its argument as a possibly
            // qualified name — presumably intended here; confirm against
            // callers before changing to `Column::from_name`.
            let mut column = col(&idents[0].value);
            handle_remaining_idents(&mut column, &idents[1..]);
            column
        }
    }
295
296 fn binary_op(&self, op: &BinaryOperator) -> Result<Operator> {
297 Ok(match op {
298 BinaryOperator::Plus => Operator::Plus,
299 BinaryOperator::Minus => Operator::Minus,
300 BinaryOperator::Multiply => Operator::Multiply,
301 BinaryOperator::Divide => Operator::Divide,
302 BinaryOperator::Modulo => Operator::Modulo,
303 BinaryOperator::StringConcat => Operator::StringConcat,
304 BinaryOperator::Gt => Operator::Gt,
305 BinaryOperator::Lt => Operator::Lt,
306 BinaryOperator::GtEq => Operator::GtEq,
307 BinaryOperator::LtEq => Operator::LtEq,
308 BinaryOperator::Eq => Operator::Eq,
309 BinaryOperator::NotEq => Operator::NotEq,
310 BinaryOperator::And => Operator::And,
311 BinaryOperator::Or => Operator::Or,
312 _ => {
313 return Err(Error::invalid_input(
314 format!("Operator {op} is not supported"),
315 location!(),
316 ));
317 }
318 })
319 }
320
321 fn binary_expr(&self, left: &SQLExpr, op: &BinaryOperator, right: &SQLExpr) -> Result<Expr> {
322 Ok(Expr::BinaryExpr(BinaryExpr::new(
323 Box::new(self.parse_sql_expr(left)?),
324 self.binary_op(op)?,
325 Box::new(self.parse_sql_expr(right)?),
326 )))
327 }
328
329 fn unary_expr(&self, op: &UnaryOperator, expr: &SQLExpr) -> Result<Expr> {
330 Ok(match op {
331 UnaryOperator::Not | UnaryOperator::PGBitwiseNot => {
332 Expr::Not(Box::new(self.parse_sql_expr(expr)?))
333 }
334
335 UnaryOperator::Minus => {
336 use datafusion::logical_expr::lit;
337 match expr {
338 SQLExpr::Value(ValueWithSpan { value: Value::Number(n, _), ..}) => match n.parse::<i64>() {
339 Ok(n) => lit(-n),
340 Err(_) => lit(-n
341 .parse::<f64>()
342 .map_err(|_e| {
343 Error::invalid_input(
344 format!("negative operator can be only applied to integer and float operands, got: {n}"),
345 location!(),
346 )
347 })?),
348 },
349 _ => {
350 Expr::Negative(Box::new(self.parse_sql_expr(expr)?))
351 }
352 }
353 }
354
355 _ => {
356 return Err(Error::invalid_input(
357 format!("Unary operator '{:?}' is not supported", op),
358 location!(),
359 ));
360 }
361 })
362 }
363
364 fn number(&self, value: &str, negative: bool) -> Result<Expr> {
366 use datafusion::logical_expr::lit;
367 let value: Cow<str> = if negative {
368 Cow::Owned(format!("-{}", value))
369 } else {
370 Cow::Borrowed(value)
371 };
372 if let Ok(n) = value.parse::<i64>() {
373 Ok(lit(n))
374 } else {
375 value.parse::<f64>().map(lit).map_err(|_| {
376 Error::invalid_input(
377 format!("'{value}' is not supported number value."),
378 location!(),
379 )
380 })
381 }
382 }
383
384 fn value(&self, value: &Value) -> Result<Expr> {
385 Ok(match value {
386 Value::Number(v, _) => self.number(v.as_str(), false)?,
387 Value::SingleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone())), None),
388 Value::HexStringLiteral(hsl) => {
389 Expr::Literal(ScalarValue::Binary(Self::try_decode_hex_literal(hsl)), None)
390 }
391 Value::DoubleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone())), None),
392 Value::Boolean(v) => Expr::Literal(ScalarValue::Boolean(Some(*v)), None),
393 Value::Null => Expr::Literal(ScalarValue::Null, None),
394 _ => todo!(),
395 })
396 }
397
398 fn parse_function_args(&self, func_args: &FunctionArg) -> Result<Expr> {
399 match func_args {
400 FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => self.parse_sql_expr(expr),
401 _ => Err(Error::invalid_input(
402 format!("Unsupported function args: {:?}", func_args),
403 location!(),
404 )),
405 }
406 }
407
408 fn legacy_parse_function(&self, func: &Function) -> Result<Expr> {
415 match &func.args {
416 FunctionArguments::List(args) => {
417 if func.name.0.len() != 1 {
418 return Err(Error::invalid_input(
419 format!("Function name must have 1 part, got: {:?}", func.name.0),
420 location!(),
421 ));
422 }
423 Ok(Expr::IsNotNull(Box::new(
424 self.parse_function_args(&args.args[0])?,
425 )))
426 }
427 _ => Err(Error::invalid_input(
428 format!("Unsupported function args: {:?}", &func.args),
429 location!(),
430 )),
431 }
432 }
433
434 fn parse_function(&self, function: SQLExpr) -> Result<Expr> {
435 if let SQLExpr::Function(function) = &function {
436 if let Some(ObjectNamePart::Identifier(name)) = &function.name.0.first() {
437 if &name.value == "is_valid" {
438 return self.legacy_parse_function(function);
439 }
440 }
441 }
442 let sql_to_rel = SqlToRel::new_with_options(
443 &self.context_provider,
444 ParserOptions {
445 parse_float_as_decimal: false,
446 enable_ident_normalization: false,
447 support_varchar_with_length: false,
448 enable_options_value_normalization: false,
449 collect_spans: false,
450 map_varchar_to_utf8view: false,
451 },
452 );
453
454 let mut planner_context = PlannerContext::default();
455 let schema = DFSchema::try_from(self.schema.as_ref().clone())?;
456 sql_to_rel
457 .sql_to_expr(function, &schema, &mut planner_context)
458 .map_err(|e| Error::invalid_input(format!("Error parsing function: {e}"), location!()))
459 }
460
461 fn parse_type(&self, data_type: &SQLDataType) -> Result<ArrowDataType> {
462 const SUPPORTED_TYPES: [&str; 13] = [
463 "int [unsigned]",
464 "tinyint [unsigned]",
465 "smallint [unsigned]",
466 "bigint [unsigned]",
467 "float",
468 "double",
469 "string",
470 "binary",
471 "date",
472 "timestamp(precision)",
473 "datetime(precision)",
474 "decimal(precision,scale)",
475 "boolean",
476 ];
477 match data_type {
478 SQLDataType::String(_) => Ok(ArrowDataType::Utf8),
479 SQLDataType::Binary(_) => Ok(ArrowDataType::Binary),
480 SQLDataType::Float(_) => Ok(ArrowDataType::Float32),
481 SQLDataType::Double(_) => Ok(ArrowDataType::Float64),
482 SQLDataType::Boolean => Ok(ArrowDataType::Boolean),
483 SQLDataType::TinyInt(_) => Ok(ArrowDataType::Int8),
484 SQLDataType::SmallInt(_) => Ok(ArrowDataType::Int16),
485 SQLDataType::Int(_) | SQLDataType::Integer(_) => Ok(ArrowDataType::Int32),
486 SQLDataType::BigInt(_) => Ok(ArrowDataType::Int64),
487 SQLDataType::TinyIntUnsigned(_) => Ok(ArrowDataType::UInt8),
488 SQLDataType::SmallIntUnsigned(_) => Ok(ArrowDataType::UInt16),
489 SQLDataType::IntUnsigned(_) | SQLDataType::IntegerUnsigned(_) => {
490 Ok(ArrowDataType::UInt32)
491 }
492 SQLDataType::BigIntUnsigned(_) => Ok(ArrowDataType::UInt64),
493 SQLDataType::Date => Ok(ArrowDataType::Date32),
494 SQLDataType::Timestamp(resolution, tz) => {
495 match tz {
496 TimezoneInfo::None => {}
497 _ => {
498 return Err(Error::invalid_input(
499 "Timezone not supported in timestamp".to_string(),
500 location!(),
501 ));
502 }
503 };
504 let time_unit = match resolution {
505 None => TimeUnit::Microsecond,
507 Some(0) => TimeUnit::Second,
508 Some(3) => TimeUnit::Millisecond,
509 Some(6) => TimeUnit::Microsecond,
510 Some(9) => TimeUnit::Nanosecond,
511 _ => {
512 return Err(Error::invalid_input(
513 format!("Unsupported datetime resolution: {:?}", resolution),
514 location!(),
515 ));
516 }
517 };
518 Ok(ArrowDataType::Timestamp(time_unit, None))
519 }
520 SQLDataType::Datetime(resolution) => {
521 let time_unit = match resolution {
522 None => TimeUnit::Microsecond,
523 Some(0) => TimeUnit::Second,
524 Some(3) => TimeUnit::Millisecond,
525 Some(6) => TimeUnit::Microsecond,
526 Some(9) => TimeUnit::Nanosecond,
527 _ => {
528 return Err(Error::invalid_input(
529 format!("Unsupported datetime resolution: {:?}", resolution),
530 location!(),
531 ));
532 }
533 };
534 Ok(ArrowDataType::Timestamp(time_unit, None))
535 }
536 SQLDataType::Decimal(number_info) => match number_info {
537 ExactNumberInfo::PrecisionAndScale(precision, scale) => {
538 Ok(ArrowDataType::Decimal128(*precision as u8, *scale as i8))
539 }
540 _ => Err(Error::invalid_input(
541 format!(
542 "Must provide precision and scale for decimal: {:?}",
543 number_info
544 ),
545 location!(),
546 )),
547 },
548 _ => Err(Error::invalid_input(
549 format!(
550 "Unsupported data type: {:?}. Supported types: {:?}",
551 data_type, SUPPORTED_TYPES
552 ),
553 location!(),
554 )),
555 }
556 }
557
558 fn plan_field_access(&self, mut field_access_expr: RawFieldAccessExpr) -> Result<Expr> {
559 let df_schema = DFSchema::try_from(self.schema.as_ref().clone())?;
560 for planner in self.context_provider.get_expr_planners() {
561 match planner.plan_field_access(field_access_expr, &df_schema)? {
562 PlannerResult::Planned(expr) => return Ok(expr),
563 PlannerResult::Original(expr) => {
564 field_access_expr = expr;
565 }
566 }
567 }
568 Err(Error::invalid_input(
569 "Field access could not be planned",
570 location!(),
571 ))
572 }
573
    /// Recursively convert a parsed SQL AST expression into a DataFusion
    /// logical expression.
    ///
    /// This is the central dispatch for the planner: identifiers, literals,
    /// operators, arrays, casts, LIKE/BETWEEN/IN, and field access all route
    /// through here. Unsupported syntax yields an invalid-input error.
    fn parse_sql_expr(&self, expr: &SQLExpr) -> Result<Expr> {
        match expr {
            SQLExpr::Identifier(id) => {
                // Double-quoted identifiers are treated as string literals
                // (historical Lance behavior, unlike standard SQL).
                if id.quote_style == Some('"') {
                    Ok(Expr::Literal(
                        ScalarValue::Utf8(Some(id.value.clone())),
                        None,
                    ))
                // Backtick-quoted identifiers are taken verbatim as a single
                // column name (no qualifier parsing).
                } else if id.quote_style == Some('`') {
                    Ok(Expr::Column(Column::from_name(id.value.clone())))
                } else {
                    Ok(self.column(vec![id.clone()].as_slice()))
                }
            }
            // `a.b.c` — column plus nested field accesses.
            SQLExpr::CompoundIdentifier(ids) => Ok(self.column(ids.as_slice())),
            SQLExpr::BinaryOp { left, op, right } => self.binary_expr(left, op, right),
            SQLExpr::UnaryOp { op, expr } => self.unary_expr(op, expr),
            SQLExpr::Value(value) => self.value(&value.value),
            // Array literal: only literal (or negated numeric literal)
            // elements are accepted; all elements are coerced to the type of
            // the first element.
            SQLExpr::Array(SQLArray { elem, .. }) => {
                let mut values = vec![];

                let array_literal_error = |pos: usize, value: &_| {
                    Err(Error::invalid_input(
                        format!(
                            "Expected a literal value in array, instead got {} at position {}",
                            value, pos
                        ),
                        location!(),
                    ))
                };

                for (pos, expr) in elem.iter().enumerate() {
                    match expr {
                        SQLExpr::Value(value) => {
                            if let Expr::Literal(value, _) = self.value(&value.value)? {
                                values.push(value);
                            } else {
                                return array_literal_error(pos, expr);
                            }
                        }
                        // Negative numeric literal, e.g. `[-1, -2]`.
                        SQLExpr::UnaryOp {
                            op: UnaryOperator::Minus,
                            expr,
                        } => {
                            if let SQLExpr::Value(ValueWithSpan {
                                value: Value::Number(number, _),
                                ..
                            }) = expr.as_ref()
                            {
                                if let Expr::Literal(value, _) = self.number(number, true)? {
                                    values.push(value);
                                } else {
                                    return array_literal_error(pos, expr);
                                }
                            } else {
                                return array_literal_error(pos, expr);
                            }
                        }
                        _ => {
                            return array_literal_error(pos, expr);
                        }
                    }
                }

                // The first element fixes the list element type; later
                // elements are coerced to it or the whole array is rejected.
                let field = if !values.is_empty() {
                    let data_type = values[0].data_type();

                    for value in &mut values {
                        if value.data_type() != data_type {
                            *value = safe_coerce_scalar(value, &data_type).ok_or_else(|| Error::invalid_input(
                                format!("Array expressions must have a consistent datatype. Expected: {}, got: {}", data_type, value.data_type()),
                                location!()
                            ))?;
                        }
                    }
                    Field::new("item", data_type, true)
                } else {
                    // Empty array literal: element type Null.
                    Field::new("item", ArrowDataType::Null, true)
                };

                // Materialize the scalars into a single-row ListArray literal.
                let values = values
                    .into_iter()
                    .map(|v| v.to_array().map_err(Error::from))
                    .collect::<Result<Vec<_>>>()?;
                let array_refs = values.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
                let values = concat(&array_refs)?;
                let values = ListArray::try_new(
                    field.into(),
                    OffsetBuffer::from_lengths([values.len()]),
                    values,
                    None,
                )?;

                Ok(Expr::Literal(ScalarValue::List(Arc::new(values)), None))
            }
            // e.g. `DATE '2020-01-01'`: plan as CAST('2020-01-01' AS DATE).
            SQLExpr::TypedString { data_type, value } => {
                let value = value.clone().into_string().expect_ok()?;
                Ok(Expr::Cast(datafusion::logical_expr::Cast {
                    expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some(value)), None)),
                    data_type: self.parse_type(data_type)?,
                }))
            }
            SQLExpr::IsFalse(expr) => Ok(Expr::IsFalse(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotFalse(expr) => Ok(Expr::IsNotFalse(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsTrue(expr) => Ok(Expr::IsTrue(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotTrue(expr) => Ok(Expr::IsNotTrue(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNull(expr) => Ok(Expr::IsNull(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotNull(expr) => Ok(Expr::IsNotNull(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::InList {
                expr,
                list,
                negated,
            } => {
                let value_expr = self.parse_sql_expr(expr)?;
                let list_exprs = list
                    .iter()
                    .map(|e| self.parse_sql_expr(e))
                    .collect::<Result<Vec<_>>>()?;
                Ok(value_expr.in_list(list_exprs, *negated))
            }
            // Parenthesized expression: plan the inner expression directly.
            SQLExpr::Nested(inner) => self.parse_sql_expr(inner.as_ref()),
            SQLExpr::Function(_) => self.parse_function(expr.clone()),
            // ILIKE / LIKE: last argument of Like::new is case-insensitivity.
            SQLExpr::ILike {
                negated,
                expr,
                pattern,
                escape_char,
                any: _,
            } => Ok(Expr::Like(Like::new(
                *negated,
                Box::new(self.parse_sql_expr(expr)?),
                Box::new(self.parse_sql_expr(pattern)?),
                escape_char.as_ref().and_then(|c| c.chars().next()),
                true,
            ))),
            SQLExpr::Like {
                negated,
                expr,
                pattern,
                escape_char,
                any: _,
            } => Ok(Expr::Like(Like::new(
                *negated,
                Box::new(self.parse_sql_expr(expr)?),
                Box::new(self.parse_sql_expr(pattern)?),
                escape_char.as_ref().and_then(|c| c.chars().next()),
                false,
            ))),
            // TRY_CAST/SAFE_CAST yield NULL on failure; plain CAST errors.
            SQLExpr::Cast {
                expr,
                data_type,
                kind,
                ..
            } => match kind {
                datafusion::sql::sqlparser::ast::CastKind::TryCast
                | datafusion::sql::sqlparser::ast::CastKind::SafeCast => {
                    Ok(Expr::TryCast(datafusion::logical_expr::TryCast {
                        expr: Box::new(self.parse_sql_expr(expr)?),
                        data_type: self.parse_type(data_type)?,
                    }))
                }
                _ => Ok(Expr::Cast(datafusion::logical_expr::Cast {
                    expr: Box::new(self.parse_sql_expr(expr)?),
                    data_type: self.parse_type(data_type)?,
                })),
            },
            SQLExpr::JsonAccess { .. } => Err(Error::invalid_input(
                "JSON access is not supported",
                location!(),
            )),
            // `root.field`, `root['field']`, `root[index]` chains.
            SQLExpr::CompoundFieldAccess { root, access_chain } => {
                let mut expr = self.parse_sql_expr(root)?;

                for access in access_chain {
                    let field_access = match access {
                        // Dot access and quoted-string subscripts both name a
                        // struct field.
                        AccessExpr::Dot(SQLExpr::Identifier(Ident { value: s, .. }))
                        | AccessExpr::Subscript(Subscript::Index {
                            index:
                                SQLExpr::Value(ValueWithSpan {
                                    value:
                                        Value::SingleQuotedString(s) | Value::DoubleQuotedString(s),
                                    ..
                                }),
                        }) => GetFieldAccess::NamedStructField {
                            name: ScalarValue::from(s.as_str()),
                        },
                        // Any other subscript index is a list element lookup.
                        AccessExpr::Subscript(Subscript::Index { index }) => {
                            let key = Box::new(self.parse_sql_expr(index)?);
                            GetFieldAccess::ListIndex { key }
                        }
                        AccessExpr::Subscript(Subscript::Slice { .. }) => {
                            return Err(Error::invalid_input(
                                "Slice subscript is not supported",
                                location!(),
                            ));
                        }
                        _ => {
                            return Err(Error::invalid_input(
                                "Only dot notation or index access is supported for field access",
                                location!(),
                            ));
                        }
                    };

                    // Each step wraps the expression planned so far.
                    let field_access_expr = RawFieldAccessExpr { expr, field_access };
                    expr = self.plan_field_access(field_access_expr)?;
                }

                Ok(expr)
            }
            SQLExpr::Between {
                expr,
                negated,
                low,
                high,
            } => {
                let expr = self.parse_sql_expr(expr)?;
                let low = self.parse_sql_expr(low)?;
                let high = self.parse_sql_expr(high)?;

                let between = Expr::Between(Between::new(
                    Box::new(expr),
                    *negated,
                    Box::new(low),
                    Box::new(high),
                ));
                Ok(between)
            }
            _ => Err(Error::invalid_input(
                format!("Expression '{expr}' is not supported SQL in lance"),
                location!(),
            )),
        }
    }
817
818 pub fn parse_filter(&self, filter: &str) -> Result<Expr> {
823 let ast_expr = parse_sql_filter(filter)?;
825 let expr = self.parse_sql_expr(&ast_expr)?;
826 let schema = Schema::try_from(self.schema.as_ref())?;
827 let resolved = resolve_expr(&expr, &schema).map_err(|e| {
828 Error::invalid_input(
829 format!("Error resolving filter expression {filter}: {e}"),
830 location!(),
831 )
832 })?;
833
834 Ok(coerce_filter_type_to_boolean(resolved))
835 }
836
837 pub fn parse_expr(&self, expr: &str) -> Result<Expr> {
842 let ast_expr = parse_sql_expr(expr)?;
843 let expr = self.parse_sql_expr(&ast_expr)?;
844 let schema = Schema::try_from(self.schema.as_ref())?;
845 let resolved = resolve_expr(&expr, &schema)?;
846 Ok(resolved)
847 }
848
849 fn try_decode_hex_literal(s: &str) -> Option<Vec<u8>> {
855 let hex_bytes = s.as_bytes();
856 let mut decoded_bytes = Vec::with_capacity(hex_bytes.len().div_ceil(2));
857
858 let start_idx = hex_bytes.len() % 2;
859 if start_idx > 0 {
860 decoded_bytes.push(Self::try_decode_hex_char(hex_bytes[0])?);
862 }
863
864 for i in (start_idx..hex_bytes.len()).step_by(2) {
865 let high = Self::try_decode_hex_char(hex_bytes[i])?;
866 let low = Self::try_decode_hex_char(hex_bytes[i + 1])?;
867 decoded_bytes.push((high << 4) | low);
868 }
869
870 Some(decoded_bytes)
871 }
872
873 const fn try_decode_hex_char(c: u8) -> Option<u8> {
877 match c {
878 b'A'..=b'F' => Some(c - b'A' + 10),
879 b'a'..=b'f' => Some(c - b'a' + 10),
880 b'0'..=b'9' => Some(c - b'0'),
881 _ => None,
882 }
883 }
884
885 pub fn optimize_expr(&self, expr: Expr) -> Result<Expr> {
887 let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
888
889 let props = ExecutionProps::default();
892 let simplify_context = SimplifyContext::new(&props).with_schema(df_schema.clone());
893 let simplifier =
894 datafusion::optimizer::simplify_expressions::ExprSimplifier::new(simplify_context);
895
896 let expr = simplifier.simplify(expr)?;
897 let expr = simplifier.coerce(expr, &df_schema)?;
898
899 Ok(expr)
900 }
901
902 pub fn create_physical_expr(&self, expr: &Expr) -> Result<Arc<dyn PhysicalExpr>> {
904 let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
905 Ok(datafusion::physical_expr::create_physical_expr(
906 expr,
907 df_schema.as_ref(),
908 &Default::default(),
909 )?)
910 }
911
912 pub fn column_names_in_expr(expr: &Expr) -> Vec<String> {
919 let mut visitor = ColumnCapturingVisitor {
920 current_path: VecDeque::new(),
921 columns: BTreeSet::new(),
922 };
923 expr.visit(&mut visitor).unwrap();
924 visitor.columns.into_iter().collect()
925 }
926}
927
/// Tree visitor that records column references, flattening chains of
/// `get_field` calls into dotted paths (e.g. `st.s1.s2`).
struct ColumnCapturingVisitor {
    // Field names accumulated from enclosing `get_field` calls, innermost
    // last; flushed when the underlying column is reached.
    current_path: VecDeque<String>,
    // Completed dotted column paths (sorted, de-duplicated).
    columns: BTreeSet<String>,
}
933
934impl TreeNodeVisitor<'_> for ColumnCapturingVisitor {
935 type Node = Expr;
936
937 fn f_down(&mut self, node: &Self::Node) -> DFResult<TreeNodeRecursion> {
938 match node {
939 Expr::Column(Column { name, .. }) => {
940 let mut path = name.clone();
941 for part in self.current_path.drain(..) {
942 path.push('.');
943 path.push_str(&part);
944 }
945 self.columns.insert(path);
946 self.current_path.clear();
947 }
948 Expr::ScalarFunction(udf) => {
949 if udf.name() == GetFieldFunc::default().name() {
950 if let Some(name) = get_as_string_scalar_opt(&udf.args[1]) {
951 self.current_path.push_front(name.to_string())
952 } else {
953 self.current_path.clear();
954 }
955 } else {
956 self.current_path.clear();
957 }
958 }
959 _ => {
960 self.current_path.clear();
961 }
962 }
963
964 Ok(TreeNodeRecursion::Continue)
965 }
966}
967
968#[cfg(test)]
969mod tests {
970
971 use crate::logical_expr::ExprExt;
972
973 use super::*;
974
975 use arrow::datatypes::Float64Type;
976 use arrow_array::{
977 ArrayRef, BooleanArray, Float32Array, Int32Array, Int64Array, RecordBatch, StringArray,
978 StructArray, TimestampMicrosecondArray, TimestampMillisecondArray,
979 TimestampNanosecondArray, TimestampSecondArray,
980 };
981 use arrow_schema::{DataType, Fields, Schema};
982 use datafusion::{
983 logical_expr::{lit, Cast},
984 prelude::{array_element, get_field},
985 };
986 use datafusion_functions::core::expr_ext::FieldAccessor;
987
988 #[test]
989 fn test_parse_filter_simple() {
990 let schema = Arc::new(Schema::new(vec![
991 Field::new("i", DataType::Int32, false),
992 Field::new("s", DataType::Utf8, true),
993 Field::new(
994 "st",
995 DataType::Struct(Fields::from(vec![
996 Field::new("x", DataType::Float32, false),
997 Field::new("y", DataType::Float32, false),
998 ])),
999 true,
1000 ),
1001 ]));
1002
1003 let planner = Planner::new(schema.clone());
1004
1005 let expected = col("i")
1006 .gt(lit(3_i32))
1007 .and(col("st").field_newstyle("x").lt_eq(lit(5.0_f32)))
1008 .and(
1009 col("s")
1010 .eq(lit("str-4"))
1011 .or(col("s").in_list(vec![lit("str-4"), lit("str-5")], false)),
1012 );
1013
1014 let expr = planner
1016 .parse_filter("i > 3 AND st.x <= 5.0 AND (s == 'str-4' OR s in ('str-4', 'str-5'))")
1017 .unwrap();
1018 assert_eq!(expr, expected);
1019
1020 let expr = planner
1022 .parse_filter("i > 3 AND st.x <= 5.0 AND (s = 'str-4' OR s in ('str-4', 'str-5'))")
1023 .unwrap();
1024
1025 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1026
1027 let batch = RecordBatch::try_new(
1028 schema,
1029 vec![
1030 Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
1031 Arc::new(StringArray::from_iter_values(
1032 (0..10).map(|v| format!("str-{}", v)),
1033 )),
1034 Arc::new(StructArray::from(vec![
1035 (
1036 Arc::new(Field::new("x", DataType::Float32, false)),
1037 Arc::new(Float32Array::from_iter_values((0..10).map(|v| v as f32)))
1038 as ArrayRef,
1039 ),
1040 (
1041 Arc::new(Field::new("y", DataType::Float32, false)),
1042 Arc::new(Float32Array::from_iter_values(
1043 (0..10).map(|v| (v * 10) as f32),
1044 )),
1045 ),
1046 ])),
1047 ],
1048 )
1049 .unwrap();
1050 let predicates = physical_expr.evaluate(&batch).unwrap();
1051 assert_eq!(
1052 predicates.into_array(0).unwrap().as_ref(),
1053 &BooleanArray::from(vec![
1054 false, false, false, false, true, true, false, false, false, false
1055 ])
1056 );
1057 }
1058
1059 #[test]
1060 fn test_nested_col_refs() {
1061 let schema = Arc::new(Schema::new(vec![
1062 Field::new("s0", DataType::Utf8, true),
1063 Field::new(
1064 "st",
1065 DataType::Struct(Fields::from(vec![
1066 Field::new("s1", DataType::Utf8, true),
1067 Field::new(
1068 "st",
1069 DataType::Struct(Fields::from(vec![Field::new(
1070 "s2",
1071 DataType::Utf8,
1072 true,
1073 )])),
1074 true,
1075 ),
1076 ])),
1077 true,
1078 ),
1079 ]));
1080
1081 let planner = Planner::new(schema);
1082
1083 fn assert_column_eq(planner: &Planner, expr: &str, expected: &Expr) {
1084 let expr = planner.parse_filter(&format!("{expr} = 'val'")).unwrap();
1085 assert!(matches!(
1086 expr,
1087 Expr::BinaryExpr(BinaryExpr {
1088 left: _,
1089 op: Operator::Eq,
1090 right: _
1091 })
1092 ));
1093 if let Expr::BinaryExpr(BinaryExpr { left, .. }) = expr {
1094 assert_eq!(left.as_ref(), expected);
1095 }
1096 }
1097
1098 let expected = Expr::Column(Column::new_unqualified("s0"));
1099 assert_column_eq(&planner, "s0", &expected);
1100 assert_column_eq(&planner, "`s0`", &expected);
1101
1102 let expected = Expr::ScalarFunction(ScalarFunction {
1103 func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
1104 args: vec![
1105 Expr::Column(Column::new_unqualified("st")),
1106 Expr::Literal(ScalarValue::Utf8(Some("s1".to_string())), None),
1107 ],
1108 });
1109 assert_column_eq(&planner, "st.s1", &expected);
1110 assert_column_eq(&planner, "`st`.`s1`", &expected);
1111 assert_column_eq(&planner, "st.`s1`", &expected);
1112
1113 let expected = Expr::ScalarFunction(ScalarFunction {
1114 func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
1115 args: vec![
1116 Expr::ScalarFunction(ScalarFunction {
1117 func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
1118 args: vec![
1119 Expr::Column(Column::new_unqualified("st")),
1120 Expr::Literal(ScalarValue::Utf8(Some("st".to_string())), None),
1121 ],
1122 }),
1123 Expr::Literal(ScalarValue::Utf8(Some("s2".to_string())), None),
1124 ],
1125 });
1126
1127 assert_column_eq(&planner, "st.st.s2", &expected);
1128 assert_column_eq(&planner, "`st`.`st`.`s2`", &expected);
1129 assert_column_eq(&planner, "st.st.`s2`", &expected);
1130 assert_column_eq(&planner, "st['st'][\"s2\"]", &expected);
1131 }
1132
1133 #[test]
1134 fn test_nested_list_refs() {
1135 let schema = Arc::new(Schema::new(vec![Field::new(
1136 "l",
1137 DataType::List(Arc::new(Field::new(
1138 "item",
1139 DataType::Struct(Fields::from(vec![Field::new("f1", DataType::Utf8, true)])),
1140 true,
1141 ))),
1142 true,
1143 )]));
1144
1145 let planner = Planner::new(schema);
1146
1147 let expected = array_element(col("l"), lit(0_i64));
1148 let expr = planner.parse_expr("l[0]").unwrap();
1149 assert_eq!(expr, expected);
1150
1151 let expected = get_field(array_element(col("l"), lit(0_i64)), "f1");
1152 let expr = planner.parse_expr("l[0]['f1']").unwrap();
1153 assert_eq!(expr, expected);
1154
1155 }
1160
1161 #[test]
1162 fn test_negative_expressions() {
1163 let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
1164
1165 let planner = Planner::new(schema.clone());
1166
1167 let expected = col("x")
1168 .gt(lit(-3_i64))
1169 .and(col("x").lt(-(lit(-5_i64) + lit(3_i64))));
1170
1171 let expr = planner.parse_filter("x > -3 AND x < -(-5 + 3)").unwrap();
1172
1173 assert_eq!(expr, expected);
1174
1175 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1176
1177 let batch = RecordBatch::try_new(
1178 schema,
1179 vec![Arc::new(Int64Array::from_iter_values(-5..5)) as ArrayRef],
1180 )
1181 .unwrap();
1182 let predicates = physical_expr.evaluate(&batch).unwrap();
1183 assert_eq!(
1184 predicates.into_array(0).unwrap().as_ref(),
1185 &BooleanArray::from(vec![
1186 false, false, false, true, true, true, true, false, false, false
1187 ])
1188 );
1189 }
1190
1191 #[test]
1192 fn test_negative_array_expressions() {
1193 let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
1194
1195 let planner = Planner::new(schema);
1196
1197 let expected = Expr::Literal(
1198 ScalarValue::List(Arc::new(
1199 ListArray::from_iter_primitive::<Float64Type, _, _>(vec![Some(
1200 [-1_f64, -2.0, -3.0, -4.0, -5.0].map(Some),
1201 )]),
1202 )),
1203 None,
1204 );
1205
1206 let expr = planner
1207 .parse_expr("[-1.0, -2.0, -3.0, -4.0, -5.0]")
1208 .unwrap();
1209
1210 assert_eq!(expr, expected);
1211 }
1212
1213 #[test]
1214 fn test_sql_like() {
1215 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1216
1217 let planner = Planner::new(schema.clone());
1218
1219 let expected = col("s").like(lit("str-4"));
1220 let expr = planner.parse_filter("s LIKE 'str-4'").unwrap();
1222 assert_eq!(expr, expected);
1223 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1224
1225 let batch = RecordBatch::try_new(
1226 schema,
1227 vec![Arc::new(StringArray::from_iter_values(
1228 (0..10).map(|v| format!("str-{}", v)),
1229 ))],
1230 )
1231 .unwrap();
1232 let predicates = physical_expr.evaluate(&batch).unwrap();
1233 assert_eq!(
1234 predicates.into_array(0).unwrap().as_ref(),
1235 &BooleanArray::from(vec![
1236 false, false, false, false, true, false, false, false, false, false
1237 ])
1238 );
1239 }
1240
1241 #[test]
1242 fn test_not_like() {
1243 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1244
1245 let planner = Planner::new(schema.clone());
1246
1247 let expected = col("s").not_like(lit("str-4"));
1248 let expr = planner.parse_filter("s NOT LIKE 'str-4'").unwrap();
1250 assert_eq!(expr, expected);
1251 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1252
1253 let batch = RecordBatch::try_new(
1254 schema,
1255 vec![Arc::new(StringArray::from_iter_values(
1256 (0..10).map(|v| format!("str-{}", v)),
1257 ))],
1258 )
1259 .unwrap();
1260 let predicates = physical_expr.evaluate(&batch).unwrap();
1261 assert_eq!(
1262 predicates.into_array(0).unwrap().as_ref(),
1263 &BooleanArray::from(vec![
1264 true, true, true, true, false, true, true, true, true, true
1265 ])
1266 );
1267 }
1268
1269 #[test]
1270 fn test_sql_is_in() {
1271 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1272
1273 let planner = Planner::new(schema.clone());
1274
1275 let expected = col("s").in_list(vec![lit("str-4"), lit("str-5")], false);
1276 let expr = planner.parse_filter("s IN ('str-4', 'str-5')").unwrap();
1278 assert_eq!(expr, expected);
1279 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1280
1281 let batch = RecordBatch::try_new(
1282 schema,
1283 vec![Arc::new(StringArray::from_iter_values(
1284 (0..10).map(|v| format!("str-{}", v)),
1285 ))],
1286 )
1287 .unwrap();
1288 let predicates = physical_expr.evaluate(&batch).unwrap();
1289 assert_eq!(
1290 predicates.into_array(0).unwrap().as_ref(),
1291 &BooleanArray::from(vec![
1292 false, false, false, false, true, true, false, false, false, false
1293 ])
1294 );
1295 }
1296
1297 #[test]
1298 fn test_sql_is_null() {
1299 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1300
1301 let planner = Planner::new(schema.clone());
1302
1303 let expected = col("s").is_null();
1304 let expr = planner.parse_filter("s IS NULL").unwrap();
1305 assert_eq!(expr, expected);
1306 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1307
1308 let batch = RecordBatch::try_new(
1309 schema,
1310 vec![Arc::new(StringArray::from_iter((0..10).map(|v| {
1311 if v % 3 == 0 {
1312 Some(format!("str-{}", v))
1313 } else {
1314 None
1315 }
1316 })))],
1317 )
1318 .unwrap();
1319 let predicates = physical_expr.evaluate(&batch).unwrap();
1320 assert_eq!(
1321 predicates.into_array(0).unwrap().as_ref(),
1322 &BooleanArray::from(vec![
1323 false, true, true, false, true, true, false, true, true, false
1324 ])
1325 );
1326
1327 let expr = planner.parse_filter("s IS NOT NULL").unwrap();
1328 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1329 let predicates = physical_expr.evaluate(&batch).unwrap();
1330 assert_eq!(
1331 predicates.into_array(0).unwrap().as_ref(),
1332 &BooleanArray::from(vec![
1333 true, false, false, true, false, false, true, false, false, true,
1334 ])
1335 );
1336 }
1337
1338 #[test]
1339 fn test_sql_invert() {
1340 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Boolean, true)]));
1341
1342 let planner = Planner::new(schema.clone());
1343
1344 let expr = planner.parse_filter("NOT s").unwrap();
1345 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1346
1347 let batch = RecordBatch::try_new(
1348 schema,
1349 vec![Arc::new(BooleanArray::from_iter(
1350 (0..10).map(|v| Some(v % 3 == 0)),
1351 ))],
1352 )
1353 .unwrap();
1354 let predicates = physical_expr.evaluate(&batch).unwrap();
1355 assert_eq!(
1356 predicates.into_array(0).unwrap().as_ref(),
1357 &BooleanArray::from(vec![
1358 false, true, true, false, true, true, false, true, true, false
1359 ])
1360 );
1361 }
1362
1363 #[test]
1364 fn test_sql_cast() {
1365 let cases = &[
1366 (
1367 "x = cast('2021-01-01 00:00:00' as timestamp)",
1368 ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
1369 ),
1370 (
1371 "x = cast('2021-01-01 00:00:00' as timestamp(0))",
1372 ArrowDataType::Timestamp(TimeUnit::Second, None),
1373 ),
1374 (
1375 "x = cast('2021-01-01 00:00:00.123' as timestamp(9))",
1376 ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
1377 ),
1378 (
1379 "x = cast('2021-01-01 00:00:00.123' as datetime(9))",
1380 ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
1381 ),
1382 ("x = cast('2021-01-01' as date)", ArrowDataType::Date32),
1383 (
1384 "x = cast('1.238' as decimal(9,3))",
1385 ArrowDataType::Decimal128(9, 3),
1386 ),
1387 ("x = cast(1 as float)", ArrowDataType::Float32),
1388 ("x = cast(1 as double)", ArrowDataType::Float64),
1389 ("x = cast(1 as tinyint)", ArrowDataType::Int8),
1390 ("x = cast(1 as smallint)", ArrowDataType::Int16),
1391 ("x = cast(1 as int)", ArrowDataType::Int32),
1392 ("x = cast(1 as integer)", ArrowDataType::Int32),
1393 ("x = cast(1 as bigint)", ArrowDataType::Int64),
1394 ("x = cast(1 as tinyint unsigned)", ArrowDataType::UInt8),
1395 ("x = cast(1 as smallint unsigned)", ArrowDataType::UInt16),
1396 ("x = cast(1 as int unsigned)", ArrowDataType::UInt32),
1397 ("x = cast(1 as integer unsigned)", ArrowDataType::UInt32),
1398 ("x = cast(1 as bigint unsigned)", ArrowDataType::UInt64),
1399 ("x = cast(1 as boolean)", ArrowDataType::Boolean),
1400 ("x = cast(1 as string)", ArrowDataType::Utf8),
1401 ];
1402
1403 for (sql, expected_data_type) in cases {
1404 let schema = Arc::new(Schema::new(vec![Field::new(
1405 "x",
1406 expected_data_type.clone(),
1407 true,
1408 )]));
1409 let planner = Planner::new(schema.clone());
1410 let expr = planner.parse_filter(sql).unwrap();
1411
1412 let expected_value_str = sql
1414 .split("cast(")
1415 .nth(1)
1416 .unwrap()
1417 .split(" as")
1418 .next()
1419 .unwrap();
1420 let expected_value_str = expected_value_str.trim_matches('\'');
1422
1423 match expr {
1424 Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
1425 Expr::Cast(Cast { expr, data_type }) => {
1426 match expr.as_ref() {
1427 Expr::Literal(ScalarValue::Utf8(Some(value_str)), _) => {
1428 assert_eq!(value_str, expected_value_str);
1429 }
1430 Expr::Literal(ScalarValue::Int64(Some(value)), _) => {
1431 assert_eq!(*value, 1);
1432 }
1433 _ => panic!("Expected cast to be applied to literal"),
1434 }
1435 assert_eq!(data_type, expected_data_type);
1436 }
1437 _ => panic!("Expected right to be a cast"),
1438 },
1439 _ => panic!("Expected binary expression"),
1440 }
1441 }
1442 }
1443
1444 #[test]
1445 fn test_sql_literals() {
1446 let cases = &[
1447 (
1448 "x = timestamp '2021-01-01 00:00:00'",
1449 ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
1450 ),
1451 (
1452 "x = timestamp(0) '2021-01-01 00:00:00'",
1453 ArrowDataType::Timestamp(TimeUnit::Second, None),
1454 ),
1455 (
1456 "x = timestamp(9) '2021-01-01 00:00:00.123'",
1457 ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
1458 ),
1459 ("x = date '2021-01-01'", ArrowDataType::Date32),
1460 ("x = decimal(9,3) '1.238'", ArrowDataType::Decimal128(9, 3)),
1461 ];
1462
1463 for (sql, expected_data_type) in cases {
1464 let schema = Arc::new(Schema::new(vec![Field::new(
1465 "x",
1466 expected_data_type.clone(),
1467 true,
1468 )]));
1469 let planner = Planner::new(schema.clone());
1470 let expr = planner.parse_filter(sql).unwrap();
1471
1472 let expected_value_str = sql.split('\'').nth(1).unwrap();
1473
1474 match expr {
1475 Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
1476 Expr::Cast(Cast { expr, data_type }) => {
1477 match expr.as_ref() {
1478 Expr::Literal(ScalarValue::Utf8(Some(value_str)), _) => {
1479 assert_eq!(value_str, expected_value_str);
1480 }
1481 _ => panic!("Expected cast to be applied to literal"),
1482 }
1483 assert_eq!(data_type, expected_data_type);
1484 }
1485 _ => panic!("Expected right to be a cast"),
1486 },
1487 _ => panic!("Expected binary expression"),
1488 }
1489 }
1490 }
1491
1492 #[test]
1493 fn test_sql_array_literals() {
1494 let cases = [
1495 (
1496 "x = [1, 2, 3]",
1497 ArrowDataType::List(Arc::new(Field::new("item", ArrowDataType::Int64, true))),
1498 ),
1499 (
1500 "x = [1, 2, 3]",
1501 ArrowDataType::FixedSizeList(
1502 Arc::new(Field::new("item", ArrowDataType::Int64, true)),
1503 3,
1504 ),
1505 ),
1506 ];
1507
1508 for (sql, expected_data_type) in cases {
1509 let schema = Arc::new(Schema::new(vec![Field::new(
1510 "x",
1511 expected_data_type.clone(),
1512 true,
1513 )]));
1514 let planner = Planner::new(schema.clone());
1515 let expr = planner.parse_filter(sql).unwrap();
1516 let expr = planner.optimize_expr(expr).unwrap();
1517
1518 match expr {
1519 Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
1520 Expr::Literal(value, _) => {
1521 assert_eq!(&value.data_type(), &expected_data_type);
1522 }
1523 _ => panic!("Expected right to be a literal"),
1524 },
1525 _ => panic!("Expected binary expression"),
1526 }
1527 }
1528 }
1529
1530 #[test]
1531 fn test_sql_between() {
1532 use arrow_array::{Float64Array, Int32Array, TimestampMicrosecondArray};
1533 use arrow_schema::{DataType, Field, Schema, TimeUnit};
1534 use std::sync::Arc;
1535
1536 let schema = Arc::new(Schema::new(vec![
1537 Field::new("x", DataType::Int32, false),
1538 Field::new("y", DataType::Float64, false),
1539 Field::new(
1540 "ts",
1541 DataType::Timestamp(TimeUnit::Microsecond, None),
1542 false,
1543 ),
1544 ]));
1545
1546 let planner = Planner::new(schema.clone());
1547
1548 let expr = planner
1550 .parse_filter("x BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
1551 .unwrap();
1552 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1553
1554 let base_ts = 1704067200000000_i64; let ts_array = TimestampMicrosecondArray::from_iter_values(
1558 (0..10).map(|i| base_ts + i * 1_000_000), );
1560
1561 let batch = RecordBatch::try_new(
1562 schema,
1563 vec![
1564 Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
1565 Arc::new(Float64Array::from_iter_values((0..10).map(|v| v as f64))),
1566 Arc::new(ts_array),
1567 ],
1568 )
1569 .unwrap();
1570
1571 let predicates = physical_expr.evaluate(&batch).unwrap();
1572 assert_eq!(
1573 predicates.into_array(0).unwrap().as_ref(),
1574 &BooleanArray::from(vec![
1575 false, false, false, true, true, true, true, true, false, false
1576 ])
1577 );
1578
1579 let expr = planner
1581 .parse_filter("x NOT BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
1582 .unwrap();
1583 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1584
1585 let predicates = physical_expr.evaluate(&batch).unwrap();
1586 assert_eq!(
1587 predicates.into_array(0).unwrap().as_ref(),
1588 &BooleanArray::from(vec![
1589 true, true, true, false, false, false, false, false, true, true
1590 ])
1591 );
1592
1593 let expr = planner.parse_filter("y BETWEEN 2.5 AND 6.5").unwrap();
1595 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1596
1597 let predicates = physical_expr.evaluate(&batch).unwrap();
1598 assert_eq!(
1599 predicates.into_array(0).unwrap().as_ref(),
1600 &BooleanArray::from(vec![
1601 false, false, false, true, true, true, true, false, false, false
1602 ])
1603 );
1604
1605 let expr = planner
1607 .parse_filter(
1608 "ts BETWEEN timestamp '2024-01-01 00:00:03' AND timestamp '2024-01-01 00:00:07'",
1609 )
1610 .unwrap();
1611 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1612
1613 let predicates = physical_expr.evaluate(&batch).unwrap();
1614 assert_eq!(
1615 predicates.into_array(0).unwrap().as_ref(),
1616 &BooleanArray::from(vec![
1617 false, false, false, true, true, true, true, true, false, false
1618 ])
1619 );
1620 }
1621
1622 #[test]
1623 fn test_sql_comparison() {
1624 let batch: Vec<(&str, ArrayRef)> = vec![
1626 (
1627 "timestamp_s",
1628 Arc::new(TimestampSecondArray::from_iter_values(0..10)),
1629 ),
1630 (
1631 "timestamp_ms",
1632 Arc::new(TimestampMillisecondArray::from_iter_values(0..10)),
1633 ),
1634 (
1635 "timestamp_us",
1636 Arc::new(TimestampMicrosecondArray::from_iter_values(0..10)),
1637 ),
1638 (
1639 "timestamp_ns",
1640 Arc::new(TimestampNanosecondArray::from_iter_values(4995..5005)),
1641 ),
1642 ];
1643 let batch = RecordBatch::try_from_iter(batch).unwrap();
1644
1645 let planner = Planner::new(batch.schema());
1646
1647 let expressions = &[
1649 "timestamp_s >= TIMESTAMP '1970-01-01 00:00:05'",
1650 "timestamp_ms >= TIMESTAMP '1970-01-01 00:00:00.005'",
1651 "timestamp_us >= TIMESTAMP '1970-01-01 00:00:00.000005'",
1652 "timestamp_ns >= TIMESTAMP '1970-01-01 00:00:00.000005'",
1653 ];
1654
1655 let expected: ArrayRef = Arc::new(BooleanArray::from_iter(
1656 std::iter::repeat_n(Some(false), 5).chain(std::iter::repeat_n(Some(true), 5)),
1657 ));
1658 for expression in expressions {
1659 let logical_expr = planner.parse_filter(expression).unwrap();
1661 let logical_expr = planner.optimize_expr(logical_expr).unwrap();
1662 let physical_expr = planner.create_physical_expr(&logical_expr).unwrap();
1663
1664 let result = physical_expr.evaluate(&batch).unwrap();
1666 let result = result.into_array(batch.num_rows()).unwrap();
1667 assert_eq!(&expected, &result, "unexpected result for {}", expression);
1668 }
1669 }
1670
1671 #[test]
1672 fn test_columns_in_expr() {
1673 let expr = col("s0").gt(lit("value")).and(
1674 col("st")
1675 .field("st")
1676 .field("s2")
1677 .eq(lit("value"))
1678 .or(col("st")
1679 .field("s1")
1680 .in_list(vec![lit("value 1"), lit("value 2")], false)),
1681 );
1682
1683 let columns = Planner::column_names_in_expr(&expr);
1684 assert_eq!(columns, vec!["s0", "st.s1", "st.st.s2"]);
1685 }
1686
1687 #[test]
1688 fn test_parse_binary_expr() {
1689 let bin_str = "x'616263'";
1690
1691 let schema = Arc::new(Schema::new(vec![Field::new(
1692 "binary",
1693 DataType::Binary,
1694 true,
1695 )]));
1696 let planner = Planner::new(schema);
1697 let expr = planner.parse_expr(bin_str).unwrap();
1698 assert_eq!(
1699 expr,
1700 Expr::Literal(ScalarValue::Binary(Some(vec![b'a', b'b', b'c'])), None)
1701 );
1702 }
1703
1704 #[test]
1705 fn test_lance_context_provider_expr_planners() {
1706 let ctx_provider = LanceContextProvider::default();
1707 assert!(!ctx_provider.get_expr_planners().is_empty());
1708 }
1709}