1use std::borrow::Cow;
7use std::collections::{BTreeSet, VecDeque};
8use std::sync::Arc;
9
10use crate::expr::safe_coerce_scalar;
11use crate::logical_expr::{coerce_filter_type_to_boolean, get_as_string_scalar_opt, resolve_expr};
12use crate::sql::{parse_sql_expr, parse_sql_filter};
13use arrow::compute::CastOptions;
14use arrow_array::ListArray;
15use arrow_buffer::OffsetBuffer;
16use arrow_schema::{DataType as ArrowDataType, Field, SchemaRef, TimeUnit};
17use arrow_select::concat::concat;
18use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
19use datafusion::common::DFSchema;
20use datafusion::config::ConfigOptions;
21use datafusion::error::Result as DFResult;
22use datafusion::execution::config::SessionConfig;
23use datafusion::execution::context::{SessionContext, SessionState};
24use datafusion::execution::runtime_env::RuntimeEnvBuilder;
25use datafusion::execution::session_state::SessionStateBuilder;
26use datafusion::logical_expr::expr::ScalarFunction;
27use datafusion::logical_expr::planner::{ExprPlanner, PlannerResult, RawFieldAccessExpr};
28use datafusion::logical_expr::{
29 AggregateUDF, ColumnarValue, GetFieldAccess, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
30 Signature, Volatility, WindowUDF,
31};
32use datafusion::optimizer::simplify_expressions::SimplifyContext;
33use datafusion::sql::planner::{ContextProvider, ParserOptions, PlannerContext, SqlToRel};
34use datafusion::sql::sqlparser::ast::{
35 AccessExpr, Array as SQLArray, BinaryOperator, DataType as SQLDataType, ExactNumberInfo,
36 Expr as SQLExpr, Function, FunctionArg, FunctionArgExpr, FunctionArguments, Ident,
37 ObjectNamePart, Subscript, TimezoneInfo, UnaryOperator, Value, ValueWithSpan,
38};
39use datafusion::{
40 common::Column,
41 logical_expr::{col, Between, BinaryExpr, Like, Operator},
42 physical_expr::execution_props::ExecutionProps,
43 physical_plan::PhysicalExpr,
44 prelude::Expr,
45 scalar::ScalarValue,
46};
47use datafusion_functions::core::getfield::GetFieldFunc;
48use lance_arrow::cast::cast_with_options;
49use lance_core::datatypes::Schema;
50use lance_core::error::LanceOptionExt;
51use snafu::location;
52
53use lance_core::{Error, Result};
54
/// Scalar UDF (`_cast_list_f16`) that casts a `List`/`FixedSizeList` of float
/// values into the same list shape with `Float16` items.
#[derive(Debug, Clone)]
struct CastListF16Udf {
    // Accepts any single argument; the actual list/element types are
    // validated in `return_type` / `invoke_with_args`.
    signature: Signature,
}
59
60impl CastListF16Udf {
61 pub fn new() -> Self {
62 Self {
63 signature: Signature::any(1, Volatility::Immutable),
64 }
65 }
66}
67
68impl ScalarUDFImpl for CastListF16Udf {
69 fn as_any(&self) -> &dyn std::any::Any {
70 self
71 }
72
73 fn name(&self) -> &str {
74 "_cast_list_f16"
75 }
76
77 fn signature(&self) -> &Signature {
78 &self.signature
79 }
80
81 fn return_type(&self, arg_types: &[ArrowDataType]) -> DFResult<ArrowDataType> {
82 let input = &arg_types[0];
83 match input {
84 ArrowDataType::FixedSizeList(field, size) => {
85 if field.data_type() != &ArrowDataType::Float32
86 && field.data_type() != &ArrowDataType::Float16
87 {
88 return Err(datafusion::error::DataFusionError::Execution(
89 "cast_list_f16 only supports list of float32 or float16".to_string(),
90 ));
91 }
92 Ok(ArrowDataType::FixedSizeList(
93 Arc::new(Field::new(
94 field.name(),
95 ArrowDataType::Float16,
96 field.is_nullable(),
97 )),
98 *size,
99 ))
100 }
101 ArrowDataType::List(field) => {
102 if field.data_type() != &ArrowDataType::Float32
103 && field.data_type() != &ArrowDataType::Float16
104 {
105 return Err(datafusion::error::DataFusionError::Execution(
106 "cast_list_f16 only supports list of float32 or float16".to_string(),
107 ));
108 }
109 Ok(ArrowDataType::List(Arc::new(Field::new(
110 field.name(),
111 ArrowDataType::Float16,
112 field.is_nullable(),
113 ))))
114 }
115 _ => Err(datafusion::error::DataFusionError::Execution(
116 "cast_list_f16 only supports FixedSizeList/List arguments".to_string(),
117 )),
118 }
119 }
120
121 fn invoke_with_args(&self, func_args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
122 let ColumnarValue::Array(arr) = &func_args.args[0] else {
123 return Err(datafusion::error::DataFusionError::Execution(
124 "cast_list_f16 only supports array arguments".to_string(),
125 ));
126 };
127
128 let to_type = match arr.data_type() {
129 ArrowDataType::FixedSizeList(field, size) => ArrowDataType::FixedSizeList(
130 Arc::new(Field::new(
131 field.name(),
132 ArrowDataType::Float16,
133 field.is_nullable(),
134 )),
135 *size,
136 ),
137 ArrowDataType::List(field) => ArrowDataType::List(Arc::new(Field::new(
138 field.name(),
139 ArrowDataType::Float16,
140 field.is_nullable(),
141 ))),
142 _ => {
143 return Err(datafusion::error::DataFusionError::Execution(
144 "cast_list_f16 only supports array arguments".to_string(),
145 ));
146 }
147 };
148
149 let res = cast_with_options(arr.as_ref(), &to_type, &CastOptions::default())?;
150 Ok(ColumnarValue::Array(res))
151 }
152}
153
/// `ContextProvider` backing SQL-to-logical-expression planning; exposes the
/// function registry of an otherwise table-less DataFusion session.
struct LanceContextProvider {
    // Default DataFusion configuration returned from `options()`.
    options: datafusion::config::ConfigOptions,
    // Session state used for scalar/aggregate/window function lookups.
    state: SessionState,
    // Planners used to plan struct-field / list-index access expressions.
    expr_planners: Vec<Arc<dyn ExprPlanner>>,
}
160
impl Default for LanceContextProvider {
    fn default() -> Self {
        let config = SessionConfig::new();
        let runtime = RuntimeEnvBuilder::new().build_arc().unwrap();

        // Build a full context so that Lance-specific UDFs get registered
        // into the session state we answer function lookups from.
        let ctx = SessionContext::new_with_config_rt(config.clone(), runtime.clone());
        crate::udf::register_functions(&ctx);

        let state = ctx.state();

        // A separate state builder (with default features) is used solely to
        // obtain DataFusion's stock expression planners.
        let mut state_builder = SessionStateBuilder::new()
            .with_config(config)
            .with_runtime_env(runtime)
            .with_default_features();

        // NOTE(review): this unwrap assumes `with_default_features()` always
        // populates the expr planners — confirm against the DataFusion
        // version in use.
        let expr_planners = state_builder.expr_planners().as_ref().unwrap().clone();

        Self {
            options: ConfigOptions::default(),
            state,
            expr_planners,
        }
    }
}
187
188impl ContextProvider for LanceContextProvider {
189 fn get_table_source(
190 &self,
191 name: datafusion::sql::TableReference,
192 ) -> DFResult<Arc<dyn datafusion::logical_expr::TableSource>> {
193 Err(datafusion::error::DataFusionError::NotImplemented(format!(
194 "Attempt to reference inner table {} not supported",
195 name
196 )))
197 }
198
199 fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
200 self.state.aggregate_functions().get(name).cloned()
201 }
202
203 fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
204 self.state.window_functions().get(name).cloned()
205 }
206
207 fn get_function_meta(&self, f: &str) -> Option<Arc<ScalarUDF>> {
208 match f {
209 "_cast_list_f16" => Some(Arc::new(ScalarUDF::new_from_impl(CastListF16Udf::new()))),
212 _ => self.state.scalar_functions().get(f).cloned(),
213 }
214 }
215
216 fn get_variable_type(&self, _: &[String]) -> Option<ArrowDataType> {
217 None
219 }
220
221 fn options(&self) -> &datafusion::config::ConfigOptions {
222 &self.options
223 }
224
225 fn udf_names(&self) -> Vec<String> {
226 self.state.scalar_functions().keys().cloned().collect()
227 }
228
229 fn udaf_names(&self) -> Vec<String> {
230 self.state.aggregate_functions().keys().cloned().collect()
231 }
232
233 fn udwf_names(&self) -> Vec<String> {
234 self.state.window_functions().keys().cloned().collect()
235 }
236
237 fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
238 &self.expr_planners
239 }
240}
241
/// Plans SQL filter/projection strings against a fixed Arrow schema,
/// producing DataFusion logical (and optionally physical) expressions.
pub struct Planner {
    schema: SchemaRef,
    context_provider: LanceContextProvider,
    // When true, the first identifier of a compound reference (`rel.col`) is
    // treated as a table/relation qualifier instead of a struct field.
    enable_relations: bool,
}
247
248impl Planner {
249 pub fn new(schema: SchemaRef) -> Self {
250 Self {
251 schema,
252 context_provider: LanceContextProvider::default(),
253 enable_relations: false,
254 }
255 }
256
257 pub fn with_enable_relations(mut self, enable_relations: bool) -> Self {
263 self.enable_relations = enable_relations;
264 self
265 }
266
    /// Resolves a (possibly compound) identifier into a column expression,
    /// turning trailing identifiers into nested `get_field` calls.
    fn column(&self, idents: &[Ident]) -> Expr {
        // Wraps `expr` in one `get_field(expr, ident)` call per remaining
        // identifier, building nested struct-field access.
        fn handle_remaining_idents(expr: &mut Expr, idents: &[Ident]) {
            for ident in idents {
                *expr = Expr::ScalarFunction(ScalarFunction {
                    args: vec![
                        std::mem::take(expr),
                        Expr::Literal(ScalarValue::Utf8(Some(ident.value.clone())), None),
                    ],
                    func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
                });
            }
        }

        if self.enable_relations && idents.len() > 1 {
            // First identifier names the relation, second the column; any
            // further identifiers are struct field accesses.
            let relation = &idents[0].value;
            let column_name = &idents[1].value;
            let column = Expr::Column(Column::new(Some(relation.clone()), column_name.clone()));
            let mut result = column;
            handle_remaining_idents(&mut result, &idents[2..]);
            result
        } else {
            let mut column = col(&idents[0].value);
            handle_remaining_idents(&mut column, &idents[1..]);
            column
        }
    }
295
    /// Maps a SQL binary operator to its DataFusion equivalent; operators
    /// outside this list are rejected as invalid input.
    fn binary_op(&self, op: &BinaryOperator) -> Result<Operator> {
        Ok(match op {
            BinaryOperator::Plus => Operator::Plus,
            BinaryOperator::Minus => Operator::Minus,
            BinaryOperator::Multiply => Operator::Multiply,
            BinaryOperator::Divide => Operator::Divide,
            BinaryOperator::Modulo => Operator::Modulo,
            BinaryOperator::StringConcat => Operator::StringConcat,
            BinaryOperator::Gt => Operator::Gt,
            BinaryOperator::Lt => Operator::Lt,
            BinaryOperator::GtEq => Operator::GtEq,
            BinaryOperator::LtEq => Operator::LtEq,
            BinaryOperator::Eq => Operator::Eq,
            BinaryOperator::NotEq => Operator::NotEq,
            BinaryOperator::And => Operator::And,
            BinaryOperator::Or => Operator::Or,
            _ => {
                return Err(Error::invalid_input(
                    format!("Operator {op} is not supported"),
                    location!(),
                ));
            }
        })
    }
320
321 fn binary_expr(&self, left: &SQLExpr, op: &BinaryOperator, right: &SQLExpr) -> Result<Expr> {
322 Ok(Expr::BinaryExpr(BinaryExpr::new(
323 Box::new(self.parse_sql_expr(left)?),
324 self.binary_op(op)?,
325 Box::new(self.parse_sql_expr(right)?),
326 )))
327 }
328
    /// Translates a SQL unary operator application.
    ///
    /// `-<number literal>` is folded directly into a negative literal (i64
    /// first, then f64) so the value parses like a signed literal; any other
    /// `-` operand becomes `Expr::Negative`.
    fn unary_expr(&self, op: &UnaryOperator, expr: &SQLExpr) -> Result<Expr> {
        Ok(match op {
            // NOTE(review): PGBitwiseNot (`~`) is treated as logical NOT
            // here — confirm this is intentional.
            UnaryOperator::Not | UnaryOperator::PGBitwiseNot => {
                Expr::Not(Box::new(self.parse_sql_expr(expr)?))
            }

            UnaryOperator::Minus => {
                use datafusion::logical_expr::lit;
                match expr {
                    SQLExpr::Value(ValueWithSpan { value: Value::Number(n, _), ..}) => match n.parse::<i64>() {
                        Ok(n) => lit(-n),
                        // Integer parse failed; fall back to f64.
                        Err(_) => lit(-n
                            .parse::<f64>()
                            .map_err(|_e| {
                                Error::invalid_input(
                                    format!("negative operator can be only applied to integer and float operands, got: {n}"),
                                    location!(),
                                )
                            })?),
                    },
                    _ => {
                        Expr::Negative(Box::new(self.parse_sql_expr(expr)?))
                    }
                }
            }

            _ => {
                return Err(Error::invalid_input(
                    format!("Unary operator '{:?}' is not supported", op),
                    location!(),
                ));
            }
        })
    }
363
364 fn number(&self, value: &str, negative: bool) -> Result<Expr> {
366 use datafusion::logical_expr::lit;
367 let value: Cow<str> = if negative {
368 Cow::Owned(format!("-{}", value))
369 } else {
370 Cow::Borrowed(value)
371 };
372 if let Ok(n) = value.parse::<i64>() {
373 Ok(lit(n))
374 } else {
375 value.parse::<f64>().map(lit).map_err(|_| {
376 Error::invalid_input(
377 format!("'{value}' is not supported number value."),
378 location!(),
379 )
380 })
381 }
382 }
383
384 fn value(&self, value: &Value) -> Result<Expr> {
385 Ok(match value {
386 Value::Number(v, _) => self.number(v.as_str(), false)?,
387 Value::SingleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone())), None),
388 Value::HexStringLiteral(hsl) => {
389 Expr::Literal(ScalarValue::Binary(Self::try_decode_hex_literal(hsl)), None)
390 }
391 Value::DoubleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone())), None),
392 Value::Boolean(v) => Expr::Literal(ScalarValue::Boolean(Some(*v)), None),
393 Value::Null => Expr::Literal(ScalarValue::Null, None),
394 _ => todo!(),
395 })
396 }
397
398 fn parse_function_args(&self, func_args: &FunctionArg) -> Result<Expr> {
399 match func_args {
400 FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => self.parse_sql_expr(expr),
401 _ => Err(Error::invalid_input(
402 format!("Unsupported function args: {:?}", func_args),
403 location!(),
404 )),
405 }
406 }
407
408 fn legacy_parse_function(&self, func: &Function) -> Result<Expr> {
415 match &func.args {
416 FunctionArguments::List(args) => {
417 if func.name.0.len() != 1 {
418 return Err(Error::invalid_input(
419 format!("Function name must have 1 part, got: {:?}", func.name.0),
420 location!(),
421 ));
422 }
423 Ok(Expr::IsNotNull(Box::new(
424 self.parse_function_args(&args.args[0])?,
425 )))
426 }
427 _ => Err(Error::invalid_input(
428 format!("Unsupported function args: {:?}", &func.args),
429 location!(),
430 )),
431 }
432 }
433
    /// Plans a SQL function call by delegating to DataFusion's `SqlToRel`,
    /// after special-casing the legacy `is_valid` function.
    fn parse_function(&self, function: SQLExpr) -> Result<Expr> {
        if let SQLExpr::Function(function) = &function {
            if let Some(ObjectNamePart::Identifier(name)) = &function.name.0.first() {
                if &name.value == "is_valid" {
                    return self.legacy_parse_function(function);
                }
            }
        }
        // Identifier normalization is disabled so column names keep their
        // original case.
        let sql_to_rel = SqlToRel::new_with_options(
            &self.context_provider,
            ParserOptions {
                parse_float_as_decimal: false,
                enable_ident_normalization: false,
                support_varchar_with_length: false,
                enable_options_value_normalization: false,
                collect_spans: false,
                map_string_types_to_utf8view: false,
            },
        );

        let mut planner_context = PlannerContext::default();
        let schema = DFSchema::try_from(self.schema.as_ref().clone())?;
        sql_to_rel
            .sql_to_expr(function, &schema, &mut planner_context)
            .map_err(|e| Error::invalid_input(format!("Error parsing function: {e}"), location!()))
    }
460
461 fn parse_type(&self, data_type: &SQLDataType) -> Result<ArrowDataType> {
462 const SUPPORTED_TYPES: [&str; 13] = [
463 "int [unsigned]",
464 "tinyint [unsigned]",
465 "smallint [unsigned]",
466 "bigint [unsigned]",
467 "float",
468 "double",
469 "string",
470 "binary",
471 "date",
472 "timestamp(precision)",
473 "datetime(precision)",
474 "decimal(precision,scale)",
475 "boolean",
476 ];
477 match data_type {
478 SQLDataType::String(_) => Ok(ArrowDataType::Utf8),
479 SQLDataType::Binary(_) => Ok(ArrowDataType::Binary),
480 SQLDataType::Float(_) => Ok(ArrowDataType::Float32),
481 SQLDataType::Double(_) => Ok(ArrowDataType::Float64),
482 SQLDataType::Boolean => Ok(ArrowDataType::Boolean),
483 SQLDataType::TinyInt(_) => Ok(ArrowDataType::Int8),
484 SQLDataType::SmallInt(_) => Ok(ArrowDataType::Int16),
485 SQLDataType::Int(_) | SQLDataType::Integer(_) => Ok(ArrowDataType::Int32),
486 SQLDataType::BigInt(_) => Ok(ArrowDataType::Int64),
487 SQLDataType::TinyIntUnsigned(_) => Ok(ArrowDataType::UInt8),
488 SQLDataType::SmallIntUnsigned(_) => Ok(ArrowDataType::UInt16),
489 SQLDataType::IntUnsigned(_) | SQLDataType::IntegerUnsigned(_) => {
490 Ok(ArrowDataType::UInt32)
491 }
492 SQLDataType::BigIntUnsigned(_) => Ok(ArrowDataType::UInt64),
493 SQLDataType::Date => Ok(ArrowDataType::Date32),
494 SQLDataType::Timestamp(resolution, tz) => {
495 match tz {
496 TimezoneInfo::None => {}
497 _ => {
498 return Err(Error::invalid_input(
499 "Timezone not supported in timestamp".to_string(),
500 location!(),
501 ));
502 }
503 };
504 let time_unit = match resolution {
505 None => TimeUnit::Microsecond,
507 Some(0) => TimeUnit::Second,
508 Some(3) => TimeUnit::Millisecond,
509 Some(6) => TimeUnit::Microsecond,
510 Some(9) => TimeUnit::Nanosecond,
511 _ => {
512 return Err(Error::invalid_input(
513 format!("Unsupported datetime resolution: {:?}", resolution),
514 location!(),
515 ));
516 }
517 };
518 Ok(ArrowDataType::Timestamp(time_unit, None))
519 }
520 SQLDataType::Datetime(resolution) => {
521 let time_unit = match resolution {
522 None => TimeUnit::Microsecond,
523 Some(0) => TimeUnit::Second,
524 Some(3) => TimeUnit::Millisecond,
525 Some(6) => TimeUnit::Microsecond,
526 Some(9) => TimeUnit::Nanosecond,
527 _ => {
528 return Err(Error::invalid_input(
529 format!("Unsupported datetime resolution: {:?}", resolution),
530 location!(),
531 ));
532 }
533 };
534 Ok(ArrowDataType::Timestamp(time_unit, None))
535 }
536 SQLDataType::Decimal(number_info) => match number_info {
537 ExactNumberInfo::PrecisionAndScale(precision, scale) => {
538 Ok(ArrowDataType::Decimal128(*precision as u8, *scale as i8))
539 }
540 _ => Err(Error::invalid_input(
541 format!(
542 "Must provide precision and scale for decimal: {:?}",
543 number_info
544 ),
545 location!(),
546 )),
547 },
548 _ => Err(Error::invalid_input(
549 format!(
550 "Unsupported data type: {:?}. Supported types: {:?}",
551 data_type, SUPPORTED_TYPES
552 ),
553 location!(),
554 )),
555 }
556 }
557
558 fn plan_field_access(&self, mut field_access_expr: RawFieldAccessExpr) -> Result<Expr> {
559 let df_schema = DFSchema::try_from(self.schema.as_ref().clone())?;
560 for planner in self.context_provider.get_expr_planners() {
561 match planner.plan_field_access(field_access_expr, &df_schema)? {
562 PlannerResult::Planned(expr) => return Ok(expr),
563 PlannerResult::Original(expr) => {
564 field_access_expr = expr;
565 }
566 }
567 }
568 Err(Error::invalid_input(
569 "Field access could not be planned",
570 location!(),
571 ))
572 }
573
    /// Recursively converts a sqlparser AST expression into a DataFusion
    /// logical expression, handling the subset of SQL Lance supports.
    fn parse_sql_expr(&self, expr: &SQLExpr) -> Result<Expr> {
        match expr {
            SQLExpr::Identifier(id) => {
                // Double-quoted identifiers are treated as string literals,
                // back-quoted ones as verbatim column references.
                if id.quote_style == Some('"') {
                    Ok(Expr::Literal(
                        ScalarValue::Utf8(Some(id.value.clone())),
                        None,
                    ))
                } else if id.quote_style == Some('`') {
                    Ok(Expr::Column(Column::from_name(id.value.clone())))
                } else {
                    Ok(self.column(vec![id.clone()].as_slice()))
                }
            }
            SQLExpr::CompoundIdentifier(ids) => Ok(self.column(ids.as_slice())),
            SQLExpr::BinaryOp { left, op, right } => self.binary_expr(left, op, right),
            SQLExpr::UnaryOp { op, expr } => self.unary_expr(op, expr),
            SQLExpr::Value(value) => self.value(&value.value),
            // Array literals: every element must be a literal (optionally a
            // negated number); elements are coerced to a common datatype.
            SQLExpr::Array(SQLArray { elem, .. }) => {
                let mut values = vec![];

                let array_literal_error = |pos: usize, value: &_| {
                    Err(Error::invalid_input(
                        format!(
                            "Expected a literal value in array, instead got {} at position {}",
                            value, pos
                        ),
                        location!(),
                    ))
                };

                for (pos, expr) in elem.iter().enumerate() {
                    match expr {
                        SQLExpr::Value(value) => {
                            if let Expr::Literal(value, _) = self.value(&value.value)? {
                                values.push(value);
                            } else {
                                return array_literal_error(pos, expr);
                            }
                        }
                        // `-42` arrives as UnaryOp(Minus, Number); fold the
                        // sign into the literal.
                        SQLExpr::UnaryOp {
                            op: UnaryOperator::Minus,
                            expr,
                        } => {
                            if let SQLExpr::Value(ValueWithSpan {
                                value: Value::Number(number, _),
                                ..
                            }) = expr.as_ref()
                            {
                                if let Expr::Literal(value, _) = self.number(number, true)? {
                                    values.push(value);
                                } else {
                                    return array_literal_error(pos, expr);
                                }
                            } else {
                                return array_literal_error(pos, expr);
                            }
                        }
                        _ => {
                            return array_literal_error(pos, expr);
                        }
                    }
                }

                // Coerce all elements to the first element's datatype; an
                // empty array gets a Null item type.
                let field = if !values.is_empty() {
                    let data_type = values[0].data_type();

                    for value in &mut values {
                        if value.data_type() != data_type {
                            *value = safe_coerce_scalar(value, &data_type).ok_or_else(|| Error::invalid_input(
                                format!("Array expressions must have a consistent datatype. Expected: {}, got: {}", data_type, value.data_type()),
                                location!()
                            ))?;
                        }
                    }
                    Field::new("item", data_type, true)
                } else {
                    Field::new("item", ArrowDataType::Null, true)
                };

                // Materialize the scalars into a single-row ListArray literal.
                let values = values
                    .into_iter()
                    .map(|v| v.to_array().map_err(Error::from))
                    .collect::<Result<Vec<_>>>()?;
                let array_refs = values.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
                let values = concat(&array_refs)?;
                let values = ListArray::try_new(
                    field.into(),
                    OffsetBuffer::from_lengths([values.len()]),
                    values,
                    None,
                )?;

                Ok(Expr::Literal(ScalarValue::List(Arc::new(values)), None))
            }
            // Typed strings (e.g. `timestamp '2020-01-01'`) become casts of a
            // Utf8 literal to the named type.
            SQLExpr::TypedString { data_type, value } => {
                let value = value.clone().into_string().expect_ok()?;
                Ok(Expr::Cast(datafusion::logical_expr::Cast {
                    expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some(value)), None)),
                    data_type: self.parse_type(data_type)?,
                }))
            }
            SQLExpr::IsFalse(expr) => Ok(Expr::IsFalse(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotFalse(expr) => Ok(Expr::IsNotFalse(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsTrue(expr) => Ok(Expr::IsTrue(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotTrue(expr) => Ok(Expr::IsNotTrue(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNull(expr) => Ok(Expr::IsNull(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotNull(expr) => Ok(Expr::IsNotNull(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::InList {
                expr,
                list,
                negated,
            } => {
                let value_expr = self.parse_sql_expr(expr)?;
                let list_exprs = list
                    .iter()
                    .map(|e| self.parse_sql_expr(e))
                    .collect::<Result<Vec<_>>>()?;
                Ok(value_expr.in_list(list_exprs, *negated))
            }
            SQLExpr::Nested(inner) => self.parse_sql_expr(inner.as_ref()),
            SQLExpr::Function(_) => self.parse_function(expr.clone()),
            // ILIKE: case-insensitive LIKE (final `true` flag).
            SQLExpr::ILike {
                negated,
                expr,
                pattern,
                escape_char,
                any: _,
            } => Ok(Expr::Like(Like::new(
                *negated,
                Box::new(self.parse_sql_expr(expr)?),
                Box::new(self.parse_sql_expr(pattern)?),
                escape_char.as_ref().and_then(|c| c.chars().next()),
                true,
            ))),
            SQLExpr::Like {
                negated,
                expr,
                pattern,
                escape_char,
                any: _,
            } => Ok(Expr::Like(Like::new(
                *negated,
                Box::new(self.parse_sql_expr(expr)?),
                Box::new(self.parse_sql_expr(pattern)?),
                escape_char.as_ref().and_then(|c| c.chars().next()),
                false,
            ))),
            SQLExpr::Cast {
                expr,
                data_type,
                kind,
                ..
            } => match kind {
                // TRY_CAST / SAFE_CAST return NULL on failure instead of
                // raising an error.
                datafusion::sql::sqlparser::ast::CastKind::TryCast
                | datafusion::sql::sqlparser::ast::CastKind::SafeCast => {
                    Ok(Expr::TryCast(datafusion::logical_expr::TryCast {
                        expr: Box::new(self.parse_sql_expr(expr)?),
                        data_type: self.parse_type(data_type)?,
                    }))
                }
                _ => Ok(Expr::Cast(datafusion::logical_expr::Cast {
                    expr: Box::new(self.parse_sql_expr(expr)?),
                    data_type: self.parse_type(data_type)?,
                })),
            },
            SQLExpr::JsonAccess { .. } => Err(Error::invalid_input(
                "JSON access is not supported",
                location!(),
            )),
            // `a.b['c'][0]`-style chains: each element becomes a named
            // struct-field access or a list index, planned via the
            // expression planners.
            SQLExpr::CompoundFieldAccess { root, access_chain } => {
                let mut expr = self.parse_sql_expr(root)?;

                for access in access_chain {
                    let field_access = match access {
                        // `.name`, `['name']` and `["name"]` are all named
                        // struct-field access.
                        AccessExpr::Dot(SQLExpr::Identifier(Ident { value: s, .. }))
                        | AccessExpr::Subscript(Subscript::Index {
                            index:
                                SQLExpr::Value(ValueWithSpan {
                                    value:
                                        Value::SingleQuotedString(s) | Value::DoubleQuotedString(s),
                                    ..
                                }),
                        }) => GetFieldAccess::NamedStructField {
                            name: ScalarValue::from(s.as_str()),
                        },
                        AccessExpr::Subscript(Subscript::Index { index }) => {
                            let key = Box::new(self.parse_sql_expr(index)?);
                            GetFieldAccess::ListIndex { key }
                        }
                        AccessExpr::Subscript(Subscript::Slice { .. }) => {
                            return Err(Error::invalid_input(
                                "Slice subscript is not supported",
                                location!(),
                            ));
                        }
                        _ => {
                            return Err(Error::invalid_input(
                                "Only dot notation or index access is supported for field access",
                                location!(),
                            ));
                        }
                    };

                    let field_access_expr = RawFieldAccessExpr { expr, field_access };
                    expr = self.plan_field_access(field_access_expr)?;
                }

                Ok(expr)
            }
            SQLExpr::Between {
                expr,
                negated,
                low,
                high,
            } => {
                let expr = self.parse_sql_expr(expr)?;
                let low = self.parse_sql_expr(low)?;
                let high = self.parse_sql_expr(high)?;

                let between = Expr::Between(Between::new(
                    Box::new(expr),
                    *negated,
                    Box::new(low),
                    Box::new(high),
                ));
                Ok(between)
            }
            _ => Err(Error::invalid_input(
                format!("Expression '{expr}' is not supported SQL in lance"),
                location!(),
            )),
        }
    }
817
818 pub fn parse_filter(&self, filter: &str) -> Result<Expr> {
823 let ast_expr = parse_sql_filter(filter)?;
825 let expr = self.parse_sql_expr(&ast_expr)?;
826 let schema = Schema::try_from(self.schema.as_ref())?;
827 let resolved = resolve_expr(&expr, &schema).map_err(|e| {
828 Error::invalid_input(
829 format!("Error resolving filter expression {filter}: {e}"),
830 location!(),
831 )
832 })?;
833
834 Ok(coerce_filter_type_to_boolean(resolved))
835 }
836
837 pub fn parse_expr(&self, expr: &str) -> Result<Expr> {
842 let ast_expr = parse_sql_expr(expr)?;
843 let expr = self.parse_sql_expr(&ast_expr)?;
844 let schema = Schema::try_from(self.schema.as_ref())?;
845 let resolved = resolve_expr(&expr, &schema)?;
846 Ok(resolved)
847 }
848
849 fn try_decode_hex_literal(s: &str) -> Option<Vec<u8>> {
855 let hex_bytes = s.as_bytes();
856 let mut decoded_bytes = Vec::with_capacity(hex_bytes.len().div_ceil(2));
857
858 let start_idx = hex_bytes.len() % 2;
859 if start_idx > 0 {
860 decoded_bytes.push(Self::try_decode_hex_char(hex_bytes[0])?);
862 }
863
864 for i in (start_idx..hex_bytes.len()).step_by(2) {
865 let high = Self::try_decode_hex_char(hex_bytes[i])?;
866 let low = Self::try_decode_hex_char(hex_bytes[i + 1])?;
867 decoded_bytes.push((high << 4) | low);
868 }
869
870 Some(decoded_bytes)
871 }
872
873 const fn try_decode_hex_char(c: u8) -> Option<u8> {
877 match c {
878 b'A'..=b'F' => Some(c - b'A' + 10),
879 b'a'..=b'f' => Some(c - b'a' + 10),
880 b'0'..=b'9' => Some(c - b'0'),
881 _ => None,
882 }
883 }
884
885 pub fn optimize_expr(&self, expr: Expr) -> Result<Expr> {
887 let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
888
889 let props = ExecutionProps::default();
892 let simplify_context = SimplifyContext::new(&props).with_schema(df_schema.clone());
893 let simplifier =
894 datafusion::optimizer::simplify_expressions::ExprSimplifier::new(simplify_context);
895
896 let expr = simplifier.simplify(expr)?;
897 let expr = simplifier.coerce(expr, &df_schema)?;
898
899 Ok(expr)
900 }
901
902 pub fn create_physical_expr(&self, expr: &Expr) -> Result<Arc<dyn PhysicalExpr>> {
904 let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
905 Ok(datafusion::physical_expr::create_physical_expr(
906 expr,
907 df_schema.as_ref(),
908 &Default::default(),
909 )?)
910 }
911
912 pub fn column_names_in_expr(expr: &Expr) -> Vec<String> {
919 let mut visitor = ColumnCapturingVisitor {
920 current_path: VecDeque::new(),
921 columns: BTreeSet::new(),
922 };
923 expr.visit(&mut visitor).unwrap();
924 visitor.columns.into_iter().collect()
925 }
926}
927
/// Tree visitor that records every column referenced by an expression,
/// flattening nested `get_field` chains into dot-separated paths.
struct ColumnCapturingVisitor {
    // Field names accumulated (innermost-first) while descending through
    // `get_field` calls, waiting for the base column to be reached.
    current_path: VecDeque<String>,
    // Unique, ordered set of fully-qualified column paths found so far.
    columns: BTreeSet<String>,
}
933
934impl TreeNodeVisitor<'_> for ColumnCapturingVisitor {
935 type Node = Expr;
936
937 fn f_down(&mut self, node: &Self::Node) -> DFResult<TreeNodeRecursion> {
938 match node {
939 Expr::Column(Column { name, .. }) => {
940 let mut path = name.clone();
944 for part in self.current_path.drain(..) {
945 path.push('.');
946 if part.contains('.') || part.contains('`') {
948 let escaped = part.replace('`', "``");
950 path.push('`');
951 path.push_str(&escaped);
952 path.push('`');
953 } else {
954 path.push_str(&part);
955 }
956 }
957 self.columns.insert(path);
958 self.current_path.clear();
959 }
960 Expr::ScalarFunction(udf) => {
961 if udf.name() == GetFieldFunc::default().name() {
962 if let Some(name) = get_as_string_scalar_opt(&udf.args[1]) {
963 self.current_path.push_front(name.to_string())
964 } else {
965 self.current_path.clear();
966 }
967 } else {
968 self.current_path.clear();
969 }
970 }
971 _ => {
972 self.current_path.clear();
973 }
974 }
975
976 Ok(TreeNodeRecursion::Continue)
977 }
978}
979
980#[cfg(test)]
981mod tests {
982
983 use crate::logical_expr::ExprExt;
984
985 use super::*;
986
987 use arrow::datatypes::Float64Type;
988 use arrow_array::{
989 ArrayRef, BooleanArray, Float32Array, Int32Array, Int64Array, RecordBatch, StringArray,
990 StructArray, TimestampMicrosecondArray, TimestampMillisecondArray,
991 TimestampNanosecondArray, TimestampSecondArray,
992 };
993 use arrow_schema::{DataType, Fields, Schema};
994 use datafusion::{
995 logical_expr::{lit, Cast},
996 prelude::{array_element, get_field},
997 };
998 use datafusion_functions::core::expr_ext::FieldAccessor;
999
    // End-to-end check: parse a filter mixing comparisons, struct field
    // access, and an IN list; then evaluate its physical form on a batch.
    #[test]
    fn test_parse_filter_simple() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("i", DataType::Int32, false),
            Field::new("s", DataType::Utf8, true),
            Field::new(
                "st",
                DataType::Struct(Fields::from(vec![
                    Field::new("x", DataType::Float32, false),
                    Field::new("y", DataType::Float32, false),
                ])),
                true,
            ),
        ]));

        let planner = Planner::new(schema.clone());

        let expected = col("i")
            .gt(lit(3_i32))
            .and(col("st").field_newstyle("x").lt_eq(lit(5.0_f32)))
            .and(
                col("s")
                    .eq(lit("str-4"))
                    .or(col("s").in_list(vec![lit("str-4"), lit("str-5")], false)),
            );

        // `==` and `=` should both parse as equality.
        let expr = planner
            .parse_filter("i > 3 AND st.x <= 5.0 AND (s == 'str-4' OR s in ('str-4', 'str-5'))")
            .unwrap();
        assert_eq!(expr, expected);

        let expr = planner
            .parse_filter("i > 3 AND st.x <= 5.0 AND (s = 'str-4' OR s in ('str-4', 'str-5'))")
            .unwrap();

        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
                Arc::new(StringArray::from_iter_values(
                    (0..10).map(|v| format!("str-{}", v)),
                )),
                Arc::new(StructArray::from(vec![
                    (
                        Arc::new(Field::new("x", DataType::Float32, false)),
                        Arc::new(Float32Array::from_iter_values((0..10).map(|v| v as f32)))
                            as ArrayRef,
                    ),
                    (
                        Arc::new(Field::new("y", DataType::Float32, false)),
                        Arc::new(Float32Array::from_iter_values(
                            (0..10).map(|v| (v * 10) as f32),
                        )),
                    ),
                ])),
            ],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        // Only rows 4 and 5 satisfy every conjunct.
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, false, true, true, false, false, false, false
            ])
        );
    }
1070
    // Verifies nested struct field references in every quoting/subscript
    // syntax resolve to the same chain of `get_field` calls.
    #[test]
    fn test_nested_col_refs() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("s0", DataType::Utf8, true),
            Field::new(
                "st",
                DataType::Struct(Fields::from(vec![
                    Field::new("s1", DataType::Utf8, true),
                    Field::new(
                        "st",
                        DataType::Struct(Fields::from(vec![Field::new(
                            "s2",
                            DataType::Utf8,
                            true,
                        )])),
                        true,
                    ),
                ])),
                true,
            ),
        ]));

        let planner = Planner::new(schema);

        // Parses `{expr} = 'val'` and asserts the left operand equals
        // `expected`.
        fn assert_column_eq(planner: &Planner, expr: &str, expected: &Expr) {
            let expr = planner.parse_filter(&format!("{expr} = 'val'")).unwrap();
            assert!(matches!(
                expr,
                Expr::BinaryExpr(BinaryExpr {
                    left: _,
                    op: Operator::Eq,
                    right: _
                })
            ));
            if let Expr::BinaryExpr(BinaryExpr { left, .. }) = expr {
                assert_eq!(left.as_ref(), expected);
            }
        }

        let expected = Expr::Column(Column::new_unqualified("s0"));
        assert_column_eq(&planner, "s0", &expected);
        assert_column_eq(&planner, "`s0`", &expected);

        let expected = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
            args: vec![
                Expr::Column(Column::new_unqualified("st")),
                Expr::Literal(ScalarValue::Utf8(Some("s1".to_string())), None),
            ],
        });
        assert_column_eq(&planner, "st.s1", &expected);
        assert_column_eq(&planner, "`st`.`s1`", &expected);
        assert_column_eq(&planner, "st.`s1`", &expected);

        let expected = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
            args: vec![
                Expr::ScalarFunction(ScalarFunction {
                    func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
                    args: vec![
                        Expr::Column(Column::new_unqualified("st")),
                        Expr::Literal(ScalarValue::Utf8(Some("st".to_string())), None),
                    ],
                }),
                Expr::Literal(ScalarValue::Utf8(Some("s2".to_string())), None),
            ],
        });

        assert_column_eq(&planner, "st.st.s2", &expected);
        assert_column_eq(&planner, "`st`.`st`.`s2`", &expected);
        assert_column_eq(&planner, "st.st.`s2`", &expected);
        assert_column_eq(&planner, "st['st'][\"s2\"]", &expected);
    }
1144
1145 #[test]
1146 fn test_nested_list_refs() {
1147 let schema = Arc::new(Schema::new(vec![Field::new(
1148 "l",
1149 DataType::List(Arc::new(Field::new(
1150 "item",
1151 DataType::Struct(Fields::from(vec![Field::new("f1", DataType::Utf8, true)])),
1152 true,
1153 ))),
1154 true,
1155 )]));
1156
1157 let planner = Planner::new(schema);
1158
1159 let expected = array_element(col("l"), lit(0_i64));
1160 let expr = planner.parse_expr("l[0]").unwrap();
1161 assert_eq!(expr, expected);
1162
1163 let expected = get_field(array_element(col("l"), lit(0_i64)), "f1");
1164 let expr = planner.parse_expr("l[0]['f1']").unwrap();
1165 assert_eq!(expr, expected);
1166
1167 }
1172
1173 #[test]
1174 fn test_negative_expressions() {
1175 let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
1176
1177 let planner = Planner::new(schema.clone());
1178
1179 let expected = col("x")
1180 .gt(lit(-3_i64))
1181 .and(col("x").lt(-(lit(-5_i64) + lit(3_i64))));
1182
1183 let expr = planner.parse_filter("x > -3 AND x < -(-5 + 3)").unwrap();
1184
1185 assert_eq!(expr, expected);
1186
1187 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1188
1189 let batch = RecordBatch::try_new(
1190 schema,
1191 vec![Arc::new(Int64Array::from_iter_values(-5..5)) as ArrayRef],
1192 )
1193 .unwrap();
1194 let predicates = physical_expr.evaluate(&batch).unwrap();
1195 assert_eq!(
1196 predicates.into_array(0).unwrap().as_ref(),
1197 &BooleanArray::from(vec![
1198 false, false, false, true, true, true, true, false, false, false
1199 ])
1200 );
1201 }
1202
1203 #[test]
1204 fn test_negative_array_expressions() {
1205 let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
1206
1207 let planner = Planner::new(schema);
1208
1209 let expected = Expr::Literal(
1210 ScalarValue::List(Arc::new(
1211 ListArray::from_iter_primitive::<Float64Type, _, _>(vec![Some(
1212 [-1_f64, -2.0, -3.0, -4.0, -5.0].map(Some),
1213 )]),
1214 )),
1215 None,
1216 );
1217
1218 let expr = planner
1219 .parse_expr("[-1.0, -2.0, -3.0, -4.0, -5.0]")
1220 .unwrap();
1221
1222 assert_eq!(expr, expected);
1223 }
1224
1225 #[test]
1226 fn test_sql_like() {
1227 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1228
1229 let planner = Planner::new(schema.clone());
1230
1231 let expected = col("s").like(lit("str-4"));
1232 let expr = planner.parse_filter("s LIKE 'str-4'").unwrap();
1234 assert_eq!(expr, expected);
1235 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1236
1237 let batch = RecordBatch::try_new(
1238 schema,
1239 vec![Arc::new(StringArray::from_iter_values(
1240 (0..10).map(|v| format!("str-{}", v)),
1241 ))],
1242 )
1243 .unwrap();
1244 let predicates = physical_expr.evaluate(&batch).unwrap();
1245 assert_eq!(
1246 predicates.into_array(0).unwrap().as_ref(),
1247 &BooleanArray::from(vec![
1248 false, false, false, false, true, false, false, false, false, false
1249 ])
1250 );
1251 }
1252
1253 #[test]
1254 fn test_not_like() {
1255 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1256
1257 let planner = Planner::new(schema.clone());
1258
1259 let expected = col("s").not_like(lit("str-4"));
1260 let expr = planner.parse_filter("s NOT LIKE 'str-4'").unwrap();
1262 assert_eq!(expr, expected);
1263 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1264
1265 let batch = RecordBatch::try_new(
1266 schema,
1267 vec![Arc::new(StringArray::from_iter_values(
1268 (0..10).map(|v| format!("str-{}", v)),
1269 ))],
1270 )
1271 .unwrap();
1272 let predicates = physical_expr.evaluate(&batch).unwrap();
1273 assert_eq!(
1274 predicates.into_array(0).unwrap().as_ref(),
1275 &BooleanArray::from(vec![
1276 true, true, true, true, false, true, true, true, true, true
1277 ])
1278 );
1279 }
1280
1281 #[test]
1282 fn test_sql_is_in() {
1283 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1284
1285 let planner = Planner::new(schema.clone());
1286
1287 let expected = col("s").in_list(vec![lit("str-4"), lit("str-5")], false);
1288 let expr = planner.parse_filter("s IN ('str-4', 'str-5')").unwrap();
1290 assert_eq!(expr, expected);
1291 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1292
1293 let batch = RecordBatch::try_new(
1294 schema,
1295 vec![Arc::new(StringArray::from_iter_values(
1296 (0..10).map(|v| format!("str-{}", v)),
1297 ))],
1298 )
1299 .unwrap();
1300 let predicates = physical_expr.evaluate(&batch).unwrap();
1301 assert_eq!(
1302 predicates.into_array(0).unwrap().as_ref(),
1303 &BooleanArray::from(vec![
1304 false, false, false, false, true, true, false, false, false, false
1305 ])
1306 );
1307 }
1308
1309 #[test]
1310 fn test_sql_is_null() {
1311 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1312
1313 let planner = Planner::new(schema.clone());
1314
1315 let expected = col("s").is_null();
1316 let expr = planner.parse_filter("s IS NULL").unwrap();
1317 assert_eq!(expr, expected);
1318 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1319
1320 let batch = RecordBatch::try_new(
1321 schema,
1322 vec![Arc::new(StringArray::from_iter((0..10).map(|v| {
1323 if v % 3 == 0 {
1324 Some(format!("str-{}", v))
1325 } else {
1326 None
1327 }
1328 })))],
1329 )
1330 .unwrap();
1331 let predicates = physical_expr.evaluate(&batch).unwrap();
1332 assert_eq!(
1333 predicates.into_array(0).unwrap().as_ref(),
1334 &BooleanArray::from(vec![
1335 false, true, true, false, true, true, false, true, true, false
1336 ])
1337 );
1338
1339 let expr = planner.parse_filter("s IS NOT NULL").unwrap();
1340 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1341 let predicates = physical_expr.evaluate(&batch).unwrap();
1342 assert_eq!(
1343 predicates.into_array(0).unwrap().as_ref(),
1344 &BooleanArray::from(vec![
1345 true, false, false, true, false, false, true, false, false, true,
1346 ])
1347 );
1348 }
1349
1350 #[test]
1351 fn test_sql_invert() {
1352 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Boolean, true)]));
1353
1354 let planner = Planner::new(schema.clone());
1355
1356 let expr = planner.parse_filter("NOT s").unwrap();
1357 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1358
1359 let batch = RecordBatch::try_new(
1360 schema,
1361 vec![Arc::new(BooleanArray::from_iter(
1362 (0..10).map(|v| Some(v % 3 == 0)),
1363 ))],
1364 )
1365 .unwrap();
1366 let predicates = physical_expr.evaluate(&batch).unwrap();
1367 assert_eq!(
1368 predicates.into_array(0).unwrap().as_ref(),
1369 &BooleanArray::from(vec![
1370 false, true, true, false, true, true, false, true, true, false
1371 ])
1372 );
1373 }
1374
    #[test]
    fn test_sql_cast() {
        // Each case pairs a `CAST` filter with the Arrow type the cast must
        // target.  The column `x` is declared with that same type so the
        // comparison itself needs no further coercion.
        let cases = &[
            (
                // Bare `timestamp` defaults to microsecond precision.
                "x = cast('2021-01-01 00:00:00' as timestamp)",
                ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00' as timestamp(0))",
                ArrowDataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00.123' as timestamp(9))",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            (
                // `datetime(9)` is accepted as a synonym for `timestamp(9)`.
                "x = cast('2021-01-01 00:00:00.123' as datetime(9))",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            ("x = cast('2021-01-01' as date)", ArrowDataType::Date32),
            (
                "x = cast('1.238' as decimal(9,3))",
                ArrowDataType::Decimal128(9, 3),
            ),
            ("x = cast(1 as float)", ArrowDataType::Float32),
            ("x = cast(1 as double)", ArrowDataType::Float64),
            ("x = cast(1 as tinyint)", ArrowDataType::Int8),
            ("x = cast(1 as smallint)", ArrowDataType::Int16),
            ("x = cast(1 as int)", ArrowDataType::Int32),
            ("x = cast(1 as integer)", ArrowDataType::Int32),
            ("x = cast(1 as bigint)", ArrowDataType::Int64),
            ("x = cast(1 as tinyint unsigned)", ArrowDataType::UInt8),
            ("x = cast(1 as smallint unsigned)", ArrowDataType::UInt16),
            ("x = cast(1 as int unsigned)", ArrowDataType::UInt32),
            ("x = cast(1 as integer unsigned)", ArrowDataType::UInt32),
            ("x = cast(1 as bigint unsigned)", ArrowDataType::UInt64),
            ("x = cast(1 as boolean)", ArrowDataType::Boolean),
            ("x = cast(1 as string)", ArrowDataType::Utf8),
        ];

        for (sql, expected_data_type) in cases {
            let schema = Arc::new(Schema::new(vec![Field::new(
                "x",
                expected_data_type.clone(),
                true,
            )]));
            let planner = Planner::new(schema.clone());
            let expr = planner.parse_filter(sql).unwrap();

            // Recover the literal being cast from the SQL text itself: the
            // token between "cast(" and " as", with surrounding quotes removed.
            let expected_value_str = sql
                .split("cast(")
                .nth(1)
                .unwrap()
                .split(" as")
                .next()
                .unwrap();
            let expected_value_str = expected_value_str.trim_matches('\'');

            // The plan must be `x = CAST(<literal> AS <expected type>)`, with
            // the literal preserved uncast (Utf8 or Int64) inside the Cast node.
            match expr {
                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
                    Expr::Cast(Cast { expr, data_type }) => {
                        match expr.as_ref() {
                            Expr::Literal(ScalarValue::Utf8(Some(value_str)), _) => {
                                assert_eq!(value_str, expected_value_str);
                            }
                            Expr::Literal(ScalarValue::Int64(Some(value)), _) => {
                                // All numeric cases cast the integer literal 1.
                                assert_eq!(*value, 1);
                            }
                            _ => panic!("Expected cast to be applied to literal"),
                        }
                        assert_eq!(data_type, expected_data_type);
                    }
                    _ => panic!("Expected right to be a cast"),
                },
                _ => panic!("Expected binary expression"),
            }
        }
    }
1455
    #[test]
    fn test_sql_literals() {
        // Typed literals (`timestamp '...'`, `date '...'`, `decimal(p,s) '...'`)
        // must plan as a cast of the quoted string to the corresponding Arrow
        // type, just like an explicit CAST.
        let cases = &[
            (
                // Bare `timestamp` defaults to microsecond precision.
                "x = timestamp '2021-01-01 00:00:00'",
                ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
            ),
            (
                "x = timestamp(0) '2021-01-01 00:00:00'",
                ArrowDataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                "x = timestamp(9) '2021-01-01 00:00:00.123'",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            ("x = date '2021-01-01'", ArrowDataType::Date32),
            ("x = decimal(9,3) '1.238'", ArrowDataType::Decimal128(9, 3)),
        ];

        for (sql, expected_data_type) in cases {
            let schema = Arc::new(Schema::new(vec![Field::new(
                "x",
                expected_data_type.clone(),
                true,
            )]));
            let planner = Planner::new(schema.clone());
            let expr = planner.parse_filter(sql).unwrap();

            // The quoted literal is the text between the first pair of quotes.
            let expected_value_str = sql.split('\'').nth(1).unwrap();

            // The plan must be `x = CAST('<literal>' AS <expected type>)` with
            // the literal preserved as an uncast Utf8 scalar inside the Cast.
            match expr {
                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
                    Expr::Cast(Cast { expr, data_type }) => {
                        match expr.as_ref() {
                            Expr::Literal(ScalarValue::Utf8(Some(value_str)), _) => {
                                assert_eq!(value_str, expected_value_str);
                            }
                            _ => panic!("Expected cast to be applied to literal"),
                        }
                        assert_eq!(data_type, expected_data_type);
                    }
                    _ => panic!("Expected right to be a cast"),
                },
                _ => panic!("Expected binary expression"),
            }
        }
    }
1503
    #[test]
    fn test_sql_array_literals() {
        // The same array literal must coerce to whichever list type the target
        // column has: a variable-size List or a FixedSizeList of matching length.
        let cases = [
            (
                "x = [1, 2, 3]",
                ArrowDataType::List(Arc::new(Field::new("item", ArrowDataType::Int64, true))),
            ),
            (
                "x = [1, 2, 3]",
                ArrowDataType::FixedSizeList(
                    Arc::new(Field::new("item", ArrowDataType::Int64, true)),
                    3,
                ),
            ),
        ];

        for (sql, expected_data_type) in cases {
            let schema = Arc::new(Schema::new(vec![Field::new(
                "x",
                expected_data_type.clone(),
                true,
            )]));
            let planner = Planner::new(schema.clone());
            let expr = planner.parse_filter(sql).unwrap();
            // Optimization is needed to fold the array literal (and any cast to
            // the column's type) into a single literal scalar.
            let expr = planner.optimize_expr(expr).unwrap();

            // After folding, the right-hand side is one literal whose data type
            // matches the column's list type exactly.
            match expr {
                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
                    Expr::Literal(value, _) => {
                        assert_eq!(&value.data_type(), &expected_data_type);
                    }
                    _ => panic!("Expected right to be a literal"),
                },
                _ => panic!("Expected binary expression"),
            }
        }
    }
1541
1542 #[test]
1543 fn test_sql_between() {
1544 use arrow_array::{Float64Array, Int32Array, TimestampMicrosecondArray};
1545 use arrow_schema::{DataType, Field, Schema, TimeUnit};
1546 use std::sync::Arc;
1547
1548 let schema = Arc::new(Schema::new(vec![
1549 Field::new("x", DataType::Int32, false),
1550 Field::new("y", DataType::Float64, false),
1551 Field::new(
1552 "ts",
1553 DataType::Timestamp(TimeUnit::Microsecond, None),
1554 false,
1555 ),
1556 ]));
1557
1558 let planner = Planner::new(schema.clone());
1559
1560 let expr = planner
1562 .parse_filter("x BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
1563 .unwrap();
1564 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1565
1566 let base_ts = 1704067200000000_i64; let ts_array = TimestampMicrosecondArray::from_iter_values(
1570 (0..10).map(|i| base_ts + i * 1_000_000), );
1572
1573 let batch = RecordBatch::try_new(
1574 schema,
1575 vec![
1576 Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
1577 Arc::new(Float64Array::from_iter_values((0..10).map(|v| v as f64))),
1578 Arc::new(ts_array),
1579 ],
1580 )
1581 .unwrap();
1582
1583 let predicates = physical_expr.evaluate(&batch).unwrap();
1584 assert_eq!(
1585 predicates.into_array(0).unwrap().as_ref(),
1586 &BooleanArray::from(vec![
1587 false, false, false, true, true, true, true, true, false, false
1588 ])
1589 );
1590
1591 let expr = planner
1593 .parse_filter("x NOT BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
1594 .unwrap();
1595 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1596
1597 let predicates = physical_expr.evaluate(&batch).unwrap();
1598 assert_eq!(
1599 predicates.into_array(0).unwrap().as_ref(),
1600 &BooleanArray::from(vec![
1601 true, true, true, false, false, false, false, false, true, true
1602 ])
1603 );
1604
1605 let expr = planner.parse_filter("y BETWEEN 2.5 AND 6.5").unwrap();
1607 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1608
1609 let predicates = physical_expr.evaluate(&batch).unwrap();
1610 assert_eq!(
1611 predicates.into_array(0).unwrap().as_ref(),
1612 &BooleanArray::from(vec![
1613 false, false, false, true, true, true, true, false, false, false
1614 ])
1615 );
1616
1617 let expr = planner
1619 .parse_filter(
1620 "ts BETWEEN timestamp '2024-01-01 00:00:03' AND timestamp '2024-01-01 00:00:07'",
1621 )
1622 .unwrap();
1623 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1624
1625 let predicates = physical_expr.evaluate(&batch).unwrap();
1626 assert_eq!(
1627 predicates.into_array(0).unwrap().as_ref(),
1628 &BooleanArray::from(vec![
1629 false, false, false, true, true, true, true, true, false, false
1630 ])
1631 );
1632 }
1633
    #[test]
    fn test_sql_comparison() {
        // One column per timestamp resolution.  The second-, milli- and
        // micro-second columns count 0..10 in their own unit; the nanosecond
        // column straddles the 5000 ns mark (4995..5005) so that every column
        // flips from false to true at index 5 for its expression below.
        let batch: Vec<(&str, ArrayRef)> = vec![
            (
                "timestamp_s",
                Arc::new(TimestampSecondArray::from_iter_values(0..10)),
            ),
            (
                "timestamp_ms",
                Arc::new(TimestampMillisecondArray::from_iter_values(0..10)),
            ),
            (
                "timestamp_us",
                Arc::new(TimestampMicrosecondArray::from_iter_values(0..10)),
            ),
            (
                "timestamp_ns",
                Arc::new(TimestampNanosecondArray::from_iter_values(4995..5005)),
            ),
        ];
        let batch = RecordBatch::try_from_iter(batch).unwrap();

        let planner = Planner::new(batch.schema());

        // Each literal is the column's midpoint expressed in wall-clock form
        // (e.g. '…00.000005' is 5 us, i.e. 5000 ns for the nanosecond column).
        let expressions = &[
            "timestamp_s >= TIMESTAMP '1970-01-01 00:00:05'",
            "timestamp_ms >= TIMESTAMP '1970-01-01 00:00:00.005'",
            "timestamp_us >= TIMESTAMP '1970-01-01 00:00:00.000005'",
            "timestamp_ns >= TIMESTAMP '1970-01-01 00:00:00.000005'",
        ];

        // All four comparisons share the same expected split: 5 false, 5 true.
        let expected: ArrayRef = Arc::new(BooleanArray::from_iter(
            std::iter::repeat_n(Some(false), 5).chain(std::iter::repeat_n(Some(true), 5)),
        ));
        for expression in expressions {
            let logical_expr = planner.parse_filter(expression).unwrap();
            // Optimize so the literal is coerced to the column's timestamp unit.
            let logical_expr = planner.optimize_expr(logical_expr).unwrap();
            let physical_expr = planner.create_physical_expr(&logical_expr).unwrap();

            let result = physical_expr.evaluate(&batch).unwrap();
            let result = result.into_array(batch.num_rows()).unwrap();
            assert_eq!(&expected, &result, "unexpected result for {}", expression);
        }
    }
1682
1683 #[test]
1684 fn test_columns_in_expr() {
1685 let expr = col("s0").gt(lit("value")).and(
1686 col("st")
1687 .field("st")
1688 .field("s2")
1689 .eq(lit("value"))
1690 .or(col("st")
1691 .field("s1")
1692 .in_list(vec![lit("value 1"), lit("value 2")], false)),
1693 );
1694
1695 let columns = Planner::column_names_in_expr(&expr);
1696 assert_eq!(columns, vec!["s0", "st.s1", "st.st.s2"]);
1697 }
1698
1699 #[test]
1700 fn test_parse_binary_expr() {
1701 let bin_str = "x'616263'";
1702
1703 let schema = Arc::new(Schema::new(vec![Field::new(
1704 "binary",
1705 DataType::Binary,
1706 true,
1707 )]));
1708 let planner = Planner::new(schema);
1709 let expr = planner.parse_expr(bin_str).unwrap();
1710 assert_eq!(
1711 expr,
1712 Expr::Literal(ScalarValue::Binary(Some(vec![b'a', b'b', b'c'])), None)
1713 );
1714 }
1715
1716 #[test]
1717 fn test_lance_context_provider_expr_planners() {
1718 let ctx_provider = LanceContextProvider::default();
1719 assert!(!ctx_provider.get_expr_planners().is_empty());
1720 }
1721}