1use std::borrow::Cow;
7use std::collections::{BTreeSet, VecDeque};
8use std::sync::Arc;
9
10use crate::expr::safe_coerce_scalar;
11use crate::logical_expr::{coerce_filter_type_to_boolean, get_as_string_scalar_opt, resolve_expr};
12use crate::sql::{parse_sql_expr, parse_sql_filter};
13use arrow::compute::CastOptions;
14use arrow_array::ListArray;
15use arrow_buffer::OffsetBuffer;
16use arrow_schema::{DataType as ArrowDataType, Field, SchemaRef, TimeUnit};
17use arrow_select::concat::concat;
18use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
19use datafusion::common::DFSchema;
20use datafusion::config::ConfigOptions;
21use datafusion::error::Result as DFResult;
22use datafusion::execution::config::SessionConfig;
23use datafusion::execution::context::SessionState;
24use datafusion::execution::runtime_env::RuntimeEnvBuilder;
25use datafusion::execution::session_state::SessionStateBuilder;
26use datafusion::logical_expr::expr::ScalarFunction;
27use datafusion::logical_expr::planner::{ExprPlanner, PlannerResult, RawFieldAccessExpr};
28use datafusion::logical_expr::{
29 AggregateUDF, ColumnarValue, GetFieldAccess, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
30 Signature, Volatility, WindowUDF,
31};
32use datafusion::optimizer::simplify_expressions::SimplifyContext;
33use datafusion::sql::planner::{ContextProvider, ParserOptions, PlannerContext, SqlToRel};
34use datafusion::sql::sqlparser::ast::{
35 AccessExpr, Array as SQLArray, BinaryOperator, DataType as SQLDataType, ExactNumberInfo,
36 Expr as SQLExpr, Function, FunctionArg, FunctionArgExpr, FunctionArguments, Ident,
37 ObjectNamePart, Subscript, TimezoneInfo, UnaryOperator, Value, ValueWithSpan,
38};
39use datafusion::{
40 common::Column,
41 logical_expr::{col, Between, BinaryExpr, Like, Operator},
42 physical_expr::execution_props::ExecutionProps,
43 physical_plan::PhysicalExpr,
44 prelude::Expr,
45 scalar::ScalarValue,
46};
47use datafusion_functions::core::getfield::GetFieldFunc;
48use lance_arrow::cast::cast_with_options;
49use lance_core::datatypes::Schema;
50use lance_core::error::LanceOptionExt;
51use snafu::location;
52
53use lance_core::{Error, Result};
54
/// Scalar UDF (`_cast_list_f16`) that casts a `List`/`FixedSizeList` of
/// float32/float16 values into the same list shape with float16 items.
#[derive(Debug, Clone)]
struct CastListF16Udf {
    // Accepts exactly one argument of any type; the actual list/item type is
    // validated in `return_type` / `invoke_with_args`.
    signature: Signature,
}
59
impl CastListF16Udf {
    /// Create the UDF with a one-argument, `Immutable` (pure/deterministic)
    /// signature so DataFusion may constant-fold calls to it.
    pub fn new() -> Self {
        Self {
            signature: Signature::any(1, Volatility::Immutable),
        }
    }
}
67
impl ScalarUDFImpl for CastListF16Udf {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn name(&self) -> &str {
        "_cast_list_f16"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The output type mirrors the input list shape (name, nullability, and
    /// fixed length if any) but with `Float16` items. Only `List` and
    /// `FixedSizeList` of `Float32`/`Float16` are accepted.
    fn return_type(&self, arg_types: &[ArrowDataType]) -> DFResult<ArrowDataType> {
        let input = &arg_types[0];
        match input {
            ArrowDataType::FixedSizeList(field, size) => {
                if field.data_type() != &ArrowDataType::Float32
                    && field.data_type() != &ArrowDataType::Float16
                {
                    return Err(datafusion::error::DataFusionError::Execution(
                        "cast_list_f16 only supports list of float32 or float16".to_string(),
                    ));
                }
                // Preserve field name, nullability, and the list length.
                Ok(ArrowDataType::FixedSizeList(
                    Arc::new(Field::new(
                        field.name(),
                        ArrowDataType::Float16,
                        field.is_nullable(),
                    )),
                    *size,
                ))
            }
            ArrowDataType::List(field) => {
                if field.data_type() != &ArrowDataType::Float32
                    && field.data_type() != &ArrowDataType::Float16
                {
                    return Err(datafusion::error::DataFusionError::Execution(
                        "cast_list_f16 only supports list of float32 or float16".to_string(),
                    ));
                }
                Ok(ArrowDataType::List(Arc::new(Field::new(
                    field.name(),
                    ArrowDataType::Float16,
                    field.is_nullable(),
                ))))
            }
            _ => Err(datafusion::error::DataFusionError::Execution(
                "cast_list_f16 only supports FixedSizeList/List arguments".to_string(),
            )),
        }
    }

    /// Execute the cast: rebuild the target list type with `Float16` items
    /// (same shape logic as `return_type`) and delegate to the Arrow cast
    /// kernel. Scalar inputs are rejected — only array arguments are handled.
    fn invoke_with_args(&self, func_args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
        let ColumnarValue::Array(arr) = &func_args.args[0] else {
            return Err(datafusion::error::DataFusionError::Execution(
                "cast_list_f16 only supports array arguments".to_string(),
            ));
        };

        let to_type = match arr.data_type() {
            ArrowDataType::FixedSizeList(field, size) => ArrowDataType::FixedSizeList(
                Arc::new(Field::new(
                    field.name(),
                    ArrowDataType::Float16,
                    field.is_nullable(),
                )),
                *size,
            ),
            ArrowDataType::List(field) => ArrowDataType::List(Arc::new(Field::new(
                field.name(),
                ArrowDataType::Float16,
                field.is_nullable(),
            ))),
            _ => {
                return Err(datafusion::error::DataFusionError::Execution(
                    "cast_list_f16 only supports array arguments".to_string(),
                ));
            }
        };

        let res = cast_with_options(arr.as_ref(), &to_type, &CastOptions::default())?;
        Ok(ColumnarValue::Array(res))
    }
}
153
// Minimal DataFusion `ContextProvider` used for planning standalone
// expressions: exposes the default session functions (plus Lance's
// `_cast_list_f16`) but deliberately provides no tables.
struct LanceContextProvider {
    options: datafusion::config::ConfigOptions,
    // Session state holding the registered scalar/aggregate/window functions.
    state: SessionState,
    // Expression planners captured from the session builder (used for
    // field-access planning in `Planner::plan_field_access`).
    expr_planners: Vec<Arc<dyn ExprPlanner>>,
}
160
impl Default for LanceContextProvider {
    fn default() -> Self {
        let config = SessionConfig::new();
        let runtime = RuntimeEnvBuilder::new().build_arc().unwrap();
        let mut state_builder = SessionStateBuilder::new()
            .with_config(config)
            .with_runtime_env(runtime)
            .with_default_features();

        // Capture the expr planners off the builder before `build()` consumes
        // it; `with_default_features` is expected to populate them, hence the
        // unwrap — TODO confirm this invariant against the datafusion version.
        let expr_planners = state_builder.expr_planners().as_ref().unwrap().clone();

        Self {
            options: ConfigOptions::default(),
            state: state_builder.build(),
            expr_planners,
        }
    }
}
183
184impl ContextProvider for LanceContextProvider {
185 fn get_table_source(
186 &self,
187 name: datafusion::sql::TableReference,
188 ) -> DFResult<Arc<dyn datafusion::logical_expr::TableSource>> {
189 Err(datafusion::error::DataFusionError::NotImplemented(format!(
190 "Attempt to reference inner table {} not supported",
191 name
192 )))
193 }
194
195 fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
196 self.state.aggregate_functions().get(name).cloned()
197 }
198
199 fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
200 self.state.window_functions().get(name).cloned()
201 }
202
203 fn get_function_meta(&self, f: &str) -> Option<Arc<ScalarUDF>> {
204 match f {
205 "_cast_list_f16" => Some(Arc::new(ScalarUDF::new_from_impl(CastListF16Udf::new()))),
208 _ => self.state.scalar_functions().get(f).cloned(),
209 }
210 }
211
212 fn get_variable_type(&self, _: &[String]) -> Option<ArrowDataType> {
213 None
215 }
216
217 fn options(&self) -> &datafusion::config::ConfigOptions {
218 &self.options
219 }
220
221 fn udf_names(&self) -> Vec<String> {
222 self.state.scalar_functions().keys().cloned().collect()
223 }
224
225 fn udaf_names(&self) -> Vec<String> {
226 self.state.aggregate_functions().keys().cloned().collect()
227 }
228
229 fn udwf_names(&self) -> Vec<String> {
230 self.state.window_functions().keys().cloned().collect()
231 }
232
233 fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
234 &self.expr_planners
235 }
236}
237
/// Converts SQL filter / expression strings into DataFusion logical and
/// physical expressions, resolved against a fixed Arrow schema.
pub struct Planner {
    schema: SchemaRef,
    // Supplies functions, UDFs, and expr planners during SQL planning.
    context_provider: LanceContextProvider,
}
242
243impl Planner {
244 pub fn new(schema: SchemaRef) -> Self {
245 Self {
246 schema,
247 context_provider: LanceContextProvider::default(),
248 }
249 }
250
251 fn column(idents: &[Ident]) -> Expr {
252 let mut column = col(&idents[0].value);
253 for ident in &idents[1..] {
254 column = Expr::ScalarFunction(ScalarFunction {
255 args: vec![
256 column,
257 Expr::Literal(ScalarValue::Utf8(Some(ident.value.clone())), None),
258 ],
259 func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
260 });
261 }
262 column
263 }
264
265 fn binary_op(&self, op: &BinaryOperator) -> Result<Operator> {
266 Ok(match op {
267 BinaryOperator::Plus => Operator::Plus,
268 BinaryOperator::Minus => Operator::Minus,
269 BinaryOperator::Multiply => Operator::Multiply,
270 BinaryOperator::Divide => Operator::Divide,
271 BinaryOperator::Modulo => Operator::Modulo,
272 BinaryOperator::StringConcat => Operator::StringConcat,
273 BinaryOperator::Gt => Operator::Gt,
274 BinaryOperator::Lt => Operator::Lt,
275 BinaryOperator::GtEq => Operator::GtEq,
276 BinaryOperator::LtEq => Operator::LtEq,
277 BinaryOperator::Eq => Operator::Eq,
278 BinaryOperator::NotEq => Operator::NotEq,
279 BinaryOperator::And => Operator::And,
280 BinaryOperator::Or => Operator::Or,
281 _ => {
282 return Err(Error::invalid_input(
283 format!("Operator {op} is not supported"),
284 location!(),
285 ));
286 }
287 })
288 }
289
290 fn binary_expr(&self, left: &SQLExpr, op: &BinaryOperator, right: &SQLExpr) -> Result<Expr> {
291 Ok(Expr::BinaryExpr(BinaryExpr::new(
292 Box::new(self.parse_sql_expr(left)?),
293 self.binary_op(op)?,
294 Box::new(self.parse_sql_expr(right)?),
295 )))
296 }
297
    /// Parse a unary operator expression.
    ///
    /// `NOT` / `!` become logical negation. Unary minus on a numeric literal
    /// is folded directly into a negative literal (integer preferred, float
    /// fallback); on any other operand it becomes `Expr::Negative`.
    fn unary_expr(&self, op: &UnaryOperator, expr: &SQLExpr) -> Result<Expr> {
        Ok(match op {
            UnaryOperator::Not | UnaryOperator::PGBitwiseNot => {
                Expr::Not(Box::new(self.parse_sql_expr(expr)?))
            }

            UnaryOperator::Minus => {
                use datafusion::logical_expr::lit;
                match expr {
                    // Fold `-<number>` into a literal: try i64 first, then f64.
                    SQLExpr::Value(ValueWithSpan { value: Value::Number(n, _), ..}) => match n.parse::<i64>() {
                        Ok(n) => lit(-n),
                        Err(_) => lit(-n
                            .parse::<f64>()
                            .map_err(|_e| {
                                Error::invalid_input(
                                    format!("negative operator can be only applied to integer and float operands, got: {n}"),
                                    location!(),
                                )
                            })?),
                    },
                    _ => {
                        // Non-literal operand: defer negation to runtime.
                        Expr::Negative(Box::new(self.parse_sql_expr(expr)?))
                    }
                }
            }

            _ => {
                return Err(Error::invalid_input(
                    format!("Unary operator '{:?}' is not supported", op),
                    location!(),
                ));
            }
        })
    }
332
333 fn number(&self, value: &str, negative: bool) -> Result<Expr> {
335 use datafusion::logical_expr::lit;
336 let value: Cow<str> = if negative {
337 Cow::Owned(format!("-{}", value))
338 } else {
339 Cow::Borrowed(value)
340 };
341 if let Ok(n) = value.parse::<i64>() {
342 Ok(lit(n))
343 } else {
344 value.parse::<f64>().map(lit).map_err(|_| {
345 Error::invalid_input(
346 format!("'{value}' is not supported number value."),
347 location!(),
348 )
349 })
350 }
351 }
352
353 fn value(&self, value: &Value) -> Result<Expr> {
354 Ok(match value {
355 Value::Number(v, _) => self.number(v.as_str(), false)?,
356 Value::SingleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone())), None),
357 Value::HexStringLiteral(hsl) => {
358 Expr::Literal(ScalarValue::Binary(Self::try_decode_hex_literal(hsl)), None)
359 }
360 Value::DoubleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone())), None),
361 Value::Boolean(v) => Expr::Literal(ScalarValue::Boolean(Some(*v)), None),
362 Value::Null => Expr::Literal(ScalarValue::Null, None),
363 _ => todo!(),
364 })
365 }
366
367 fn parse_function_args(&self, func_args: &FunctionArg) -> Result<Expr> {
368 match func_args {
369 FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => self.parse_sql_expr(expr),
370 _ => Err(Error::invalid_input(
371 format!("Unsupported function args: {:?}", func_args),
372 location!(),
373 )),
374 }
375 }
376
377 fn legacy_parse_function(&self, func: &Function) -> Result<Expr> {
384 match &func.args {
385 FunctionArguments::List(args) => {
386 if func.name.0.len() != 1 {
387 return Err(Error::invalid_input(
388 format!("Function name must have 1 part, got: {:?}", func.name.0),
389 location!(),
390 ));
391 }
392 Ok(Expr::IsNotNull(Box::new(
393 self.parse_function_args(&args.args[0])?,
394 )))
395 }
396 _ => Err(Error::invalid_input(
397 format!("Unsupported function args: {:?}", &func.args),
398 location!(),
399 )),
400 }
401 }
402
    /// Parse a SQL function call expression.
    ///
    /// The legacy `is_valid(x)` function is special-cased (it becomes
    /// `x IS NOT NULL`); every other function is delegated to DataFusion's
    /// SQL planner with identifier normalization disabled so column and
    /// function names stay case-sensitive.
    fn parse_function(&self, function: SQLExpr) -> Result<Expr> {
        if let SQLExpr::Function(function) = &function {
            if let Some(ObjectNamePart::Identifier(name)) = &function.name.0.first() {
                if &name.value == "is_valid" {
                    return self.legacy_parse_function(function);
                }
            }
        }
        let sql_to_rel = SqlToRel::new_with_options(
            &self.context_provider,
            ParserOptions {
                parse_float_as_decimal: false,
                enable_ident_normalization: false,
                support_varchar_with_length: false,
                enable_options_value_normalization: false,
                collect_spans: false,
                map_varchar_to_utf8view: false,
            },
        );

        let mut planner_context = PlannerContext::default();
        let schema = DFSchema::try_from(self.schema.as_ref().clone())?;
        sql_to_rel
            .sql_to_expr(function, &schema, &mut planner_context)
            .map_err(|e| Error::invalid_input(format!("Error parsing function: {e}"), location!()))
    }
429
430 fn parse_type(&self, data_type: &SQLDataType) -> Result<ArrowDataType> {
431 const SUPPORTED_TYPES: [&str; 13] = [
432 "int [unsigned]",
433 "tinyint [unsigned]",
434 "smallint [unsigned]",
435 "bigint [unsigned]",
436 "float",
437 "double",
438 "string",
439 "binary",
440 "date",
441 "timestamp(precision)",
442 "datetime(precision)",
443 "decimal(precision,scale)",
444 "boolean",
445 ];
446 match data_type {
447 SQLDataType::String(_) => Ok(ArrowDataType::Utf8),
448 SQLDataType::Binary(_) => Ok(ArrowDataType::Binary),
449 SQLDataType::Float(_) => Ok(ArrowDataType::Float32),
450 SQLDataType::Double(_) => Ok(ArrowDataType::Float64),
451 SQLDataType::Boolean => Ok(ArrowDataType::Boolean),
452 SQLDataType::TinyInt(_) => Ok(ArrowDataType::Int8),
453 SQLDataType::SmallInt(_) => Ok(ArrowDataType::Int16),
454 SQLDataType::Int(_) | SQLDataType::Integer(_) => Ok(ArrowDataType::Int32),
455 SQLDataType::BigInt(_) => Ok(ArrowDataType::Int64),
456 SQLDataType::TinyIntUnsigned(_) => Ok(ArrowDataType::UInt8),
457 SQLDataType::SmallIntUnsigned(_) => Ok(ArrowDataType::UInt16),
458 SQLDataType::IntUnsigned(_) | SQLDataType::IntegerUnsigned(_) => {
459 Ok(ArrowDataType::UInt32)
460 }
461 SQLDataType::BigIntUnsigned(_) => Ok(ArrowDataType::UInt64),
462 SQLDataType::Date => Ok(ArrowDataType::Date32),
463 SQLDataType::Timestamp(resolution, tz) => {
464 match tz {
465 TimezoneInfo::None => {}
466 _ => {
467 return Err(Error::invalid_input(
468 "Timezone not supported in timestamp".to_string(),
469 location!(),
470 ));
471 }
472 };
473 let time_unit = match resolution {
474 None => TimeUnit::Microsecond,
476 Some(0) => TimeUnit::Second,
477 Some(3) => TimeUnit::Millisecond,
478 Some(6) => TimeUnit::Microsecond,
479 Some(9) => TimeUnit::Nanosecond,
480 _ => {
481 return Err(Error::invalid_input(
482 format!("Unsupported datetime resolution: {:?}", resolution),
483 location!(),
484 ));
485 }
486 };
487 Ok(ArrowDataType::Timestamp(time_unit, None))
488 }
489 SQLDataType::Datetime(resolution) => {
490 let time_unit = match resolution {
491 None => TimeUnit::Microsecond,
492 Some(0) => TimeUnit::Second,
493 Some(3) => TimeUnit::Millisecond,
494 Some(6) => TimeUnit::Microsecond,
495 Some(9) => TimeUnit::Nanosecond,
496 _ => {
497 return Err(Error::invalid_input(
498 format!("Unsupported datetime resolution: {:?}", resolution),
499 location!(),
500 ));
501 }
502 };
503 Ok(ArrowDataType::Timestamp(time_unit, None))
504 }
505 SQLDataType::Decimal(number_info) => match number_info {
506 ExactNumberInfo::PrecisionAndScale(precision, scale) => {
507 Ok(ArrowDataType::Decimal128(*precision as u8, *scale as i8))
508 }
509 _ => Err(Error::invalid_input(
510 format!(
511 "Must provide precision and scale for decimal: {:?}",
512 number_info
513 ),
514 location!(),
515 )),
516 },
517 _ => Err(Error::invalid_input(
518 format!(
519 "Unsupported data type: {:?}. Supported types: {:?}",
520 data_type, SUPPORTED_TYPES
521 ),
522 location!(),
523 )),
524 }
525 }
526
    /// Ask each registered expression planner in turn to plan a field access
    /// (struct field or list index). The first planner returning `Planned`
    /// wins; `Original` hands ownership of the expression back so the next
    /// planner can try.
    fn plan_field_access(&self, mut field_access_expr: RawFieldAccessExpr) -> Result<Expr> {
        let df_schema = DFSchema::try_from(self.schema.as_ref().clone())?;
        for planner in self.context_provider.get_expr_planners() {
            match planner.plan_field_access(field_access_expr, &df_schema)? {
                PlannerResult::Planned(expr) => return Ok(expr),
                PlannerResult::Original(expr) => {
                    // Planner declined; recover the expression and continue.
                    field_access_expr = expr;
                }
            }
        }
        Err(Error::invalid_input(
            "Field access could not be planned",
            location!(),
        ))
    }
542
    /// Recursively convert a sqlparser AST expression into a DataFusion
    /// logical [`Expr`].
    ///
    /// Lance-specific conventions:
    /// * double-quoted identifiers are treated as string *literals*
    /// * backtick-quoted identifiers are column references
    /// * array literals must contain only literals and are coerced to a
    ///   single consistent datatype
    fn parse_sql_expr(&self, expr: &SQLExpr) -> Result<Expr> {
        match expr {
            SQLExpr::Identifier(id) => {
                // Double quotes produce a string literal, not an identifier.
                if id.quote_style == Some('"') {
                    Ok(Expr::Literal(
                        ScalarValue::Utf8(Some(id.value.clone())),
                        None,
                    ))
                } else if id.quote_style == Some('`') {
                    // Backticks always denote a column name, taken verbatim.
                    Ok(Expr::Column(Column::from_name(id.value.clone())))
                } else {
                    Ok(Self::column(vec![id.clone()].as_slice()))
                }
            }
            SQLExpr::CompoundIdentifier(ids) => Ok(Self::column(ids.as_slice())),
            SQLExpr::BinaryOp { left, op, right } => self.binary_expr(left, op, right),
            SQLExpr::UnaryOp { op, expr } => self.unary_expr(op, expr),
            SQLExpr::Value(value) => self.value(&value.value),
            SQLExpr::Array(SQLArray { elem, .. }) => {
                let mut values = vec![];

                let array_literal_error = |pos: usize, value: &_| {
                    Err(Error::invalid_input(
                        format!(
                            "Expected a literal value in array, instead got {} at position {}",
                            value, pos
                        ),
                        location!(),
                    ))
                };

                // Collect the elements as scalar values; only plain literals
                // and negated numeric literals are accepted.
                for (pos, expr) in elem.iter().enumerate() {
                    match expr {
                        SQLExpr::Value(value) => {
                            if let Expr::Literal(value, _) = self.value(&value.value)? {
                                values.push(value);
                            } else {
                                return array_literal_error(pos, expr);
                            }
                        }
                        SQLExpr::UnaryOp {
                            op: UnaryOperator::Minus,
                            expr,
                        } => {
                            if let SQLExpr::Value(ValueWithSpan {
                                value: Value::Number(number, _),
                                ..
                            }) = expr.as_ref()
                            {
                                if let Expr::Literal(value, _) = self.number(number, true)? {
                                    values.push(value);
                                } else {
                                    return array_literal_error(pos, expr);
                                }
                            } else {
                                return array_literal_error(pos, expr);
                            }
                        }
                        _ => {
                            return array_literal_error(pos, expr);
                        }
                    }
                }

                // Coerce every element to the first element's datatype so the
                // resulting list is homogeneous; empty arrays get a Null item
                // type.
                let field = if !values.is_empty() {
                    let data_type = values[0].data_type();

                    for value in &mut values {
                        if value.data_type() != data_type {
                            *value = safe_coerce_scalar(value, &data_type).ok_or_else(|| Error::invalid_input(
                                format!("Array expressions must have a consistent datatype. Expected: {}, got: {}", data_type, value.data_type()),
                                location!()
                            ))?;
                        }
                    }
                    Field::new("item", data_type, true)
                } else {
                    Field::new("item", ArrowDataType::Null, true)
                };

                // Materialize the scalars as a single-row ListArray literal.
                let values = values
                    .into_iter()
                    .map(|v| v.to_array().map_err(Error::from))
                    .collect::<Result<Vec<_>>>()?;
                let array_refs = values.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
                let values = concat(&array_refs)?;
                let values = ListArray::try_new(
                    field.into(),
                    OffsetBuffer::from_lengths([values.len()]),
                    values,
                    None,
                )?;

                Ok(Expr::Literal(ScalarValue::List(Arc::new(values)), None))
            }
            SQLExpr::TypedString { data_type, value } => {
                // e.g. `DATE '2020-01-01'`: a cast applied to a string literal.
                let value = value.clone().into_string().expect_ok()?;
                Ok(Expr::Cast(datafusion::logical_expr::Cast {
                    expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some(value)), None)),
                    data_type: self.parse_type(data_type)?,
                }))
            }
            SQLExpr::IsFalse(expr) => Ok(Expr::IsFalse(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotFalse(expr) => Ok(Expr::IsNotFalse(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsTrue(expr) => Ok(Expr::IsTrue(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotTrue(expr) => Ok(Expr::IsNotTrue(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNull(expr) => Ok(Expr::IsNull(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotNull(expr) => Ok(Expr::IsNotNull(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::InList {
                expr,
                list,
                negated,
            } => {
                let value_expr = self.parse_sql_expr(expr)?;
                let list_exprs = list
                    .iter()
                    .map(|e| self.parse_sql_expr(e))
                    .collect::<Result<Vec<_>>>()?;
                Ok(value_expr.in_list(list_exprs, *negated))
            }
            SQLExpr::Nested(inner) => self.parse_sql_expr(inner.as_ref()),
            SQLExpr::Function(_) => self.parse_function(expr.clone()),
            SQLExpr::ILike {
                negated,
                expr,
                pattern,
                escape_char,
                any: _,
            } => Ok(Expr::Like(Like::new(
                *negated,
                Box::new(self.parse_sql_expr(expr)?),
                Box::new(self.parse_sql_expr(pattern)?),
                escape_char.as_ref().and_then(|c| c.chars().next()),
                true,
            ))),
            SQLExpr::Like {
                negated,
                expr,
                pattern,
                escape_char,
                any: _,
            } => Ok(Expr::Like(Like::new(
                *negated,
                Box::new(self.parse_sql_expr(expr)?),
                Box::new(self.parse_sql_expr(pattern)?),
                escape_char.as_ref().and_then(|c| c.chars().next()),
                false,
            ))),
            SQLExpr::Cast {
                expr, data_type, ..
            } => Ok(Expr::Cast(datafusion::logical_expr::Cast {
                expr: Box::new(self.parse_sql_expr(expr)?),
                data_type: self.parse_type(data_type)?,
            })),
            SQLExpr::JsonAccess { .. } => Err(Error::invalid_input(
                "JSON access is not supported",
                location!(),
            )),
            SQLExpr::CompoundFieldAccess { root, access_chain } => {
                // Plan chained accesses like `st.a['b'][0]` left-to-right,
                // threading the planned expression through each step.
                let mut expr = self.parse_sql_expr(root)?;

                for access in access_chain {
                    let field_access = match access {
                        // `.name` or a quoted-string subscript selects a
                        // struct field by name.
                        AccessExpr::Dot(SQLExpr::Identifier(Ident { value: s, .. }))
                        | AccessExpr::Subscript(Subscript::Index {
                            index:
                                SQLExpr::Value(ValueWithSpan {
                                    value:
                                        Value::SingleQuotedString(s) | Value::DoubleQuotedString(s),
                                    ..
                                }),
                        }) => GetFieldAccess::NamedStructField {
                            name: ScalarValue::from(s.as_str()),
                        },
                        // Any other subscript index is a list element access.
                        AccessExpr::Subscript(Subscript::Index { index }) => {
                            let key = Box::new(self.parse_sql_expr(index)?);
                            GetFieldAccess::ListIndex { key }
                        }
                        AccessExpr::Subscript(Subscript::Slice { .. }) => {
                            return Err(Error::invalid_input(
                                "Slice subscript is not supported",
                                location!(),
                            ));
                        }
                        _ => {
                            return Err(Error::invalid_input(
                                "Only dot notation or index access is supported for field access",
                                location!(),
                            ));
                        }
                    };

                    let field_access_expr = RawFieldAccessExpr { expr, field_access };
                    expr = self.plan_field_access(field_access_expr)?;
                }

                Ok(expr)
            }
            SQLExpr::Between {
                expr,
                negated,
                low,
                high,
            } => {
                let expr = self.parse_sql_expr(expr)?;
                let low = self.parse_sql_expr(low)?;
                let high = self.parse_sql_expr(high)?;

                let between = Expr::Between(Between::new(
                    Box::new(expr),
                    *negated,
                    Box::new(low),
                    Box::new(high),
                ));
                Ok(between)
            }
            _ => Err(Error::invalid_input(
                format!("Expression '{expr}' is not supported SQL in lance"),
                location!(),
            )),
        }
    }
774
775 pub fn parse_filter(&self, filter: &str) -> Result<Expr> {
780 let ast_expr = parse_sql_filter(filter)?;
782 let expr = self.parse_sql_expr(&ast_expr)?;
783 let schema = Schema::try_from(self.schema.as_ref())?;
784 let resolved = resolve_expr(&expr, &schema).map_err(|e| {
785 Error::invalid_input(
786 format!("Error resolving filter expression {filter}: {e}"),
787 location!(),
788 )
789 })?;
790 Ok(coerce_filter_type_to_boolean(resolved))
791 }
792
793 pub fn parse_expr(&self, expr: &str) -> Result<Expr> {
798 let ast_expr = parse_sql_expr(expr)?;
799 let expr = self.parse_sql_expr(&ast_expr)?;
800 let schema = Schema::try_from(self.schema.as_ref())?;
801 let resolved = resolve_expr(&expr, &schema)?;
802 Ok(resolved)
803 }
804
805 fn try_decode_hex_literal(s: &str) -> Option<Vec<u8>> {
811 let hex_bytes = s.as_bytes();
812 let mut decoded_bytes = Vec::with_capacity(hex_bytes.len().div_ceil(2));
813
814 let start_idx = hex_bytes.len() % 2;
815 if start_idx > 0 {
816 decoded_bytes.push(Self::try_decode_hex_char(hex_bytes[0])?);
818 }
819
820 for i in (start_idx..hex_bytes.len()).step_by(2) {
821 let high = Self::try_decode_hex_char(hex_bytes[i])?;
822 let low = Self::try_decode_hex_char(hex_bytes[i + 1])?;
823 decoded_bytes.push((high << 4) | low);
824 }
825
826 Some(decoded_bytes)
827 }
828
829 const fn try_decode_hex_char(c: u8) -> Option<u8> {
833 match c {
834 b'A'..=b'F' => Some(c - b'A' + 10),
835 b'a'..=b'f' => Some(c - b'a' + 10),
836 b'0'..=b'9' => Some(c - b'0'),
837 _ => None,
838 }
839 }
840
841 pub fn optimize_expr(&self, expr: Expr) -> Result<Expr> {
843 let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
844
845 let props = ExecutionProps::default();
848 let simplify_context = SimplifyContext::new(&props).with_schema(df_schema.clone());
849 let simplifier =
850 datafusion::optimizer::simplify_expressions::ExprSimplifier::new(simplify_context);
851
852 let expr = simplifier.simplify(expr)?;
853 let expr = simplifier.coerce(expr, &df_schema)?;
854
855 Ok(expr)
856 }
857
858 pub fn create_physical_expr(&self, expr: &Expr) -> Result<Arc<dyn PhysicalExpr>> {
860 let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
861
862 Ok(datafusion::physical_expr::create_physical_expr(
863 expr,
864 df_schema.as_ref(),
865 &Default::default(),
866 )?)
867 }
868
869 pub fn column_names_in_expr(expr: &Expr) -> Vec<String> {
873 let mut visitor = ColumnCapturingVisitor {
874 current_path: VecDeque::new(),
875 columns: BTreeSet::new(),
876 };
877 expr.visit(&mut visitor).unwrap();
878 visitor.columns.into_iter().collect()
879 }
880}
881
// Tree visitor that records every column referenced by an expression,
// flattening nested `get_field` accesses into dotted paths.
struct ColumnCapturingVisitor {
    // Field names accumulated (innermost-first) while descending through
    // `get_field` calls; joined onto the column name once it is reached.
    current_path: VecDeque<String>,
    // Dot-joined column paths, kept sorted and unique.
    columns: BTreeSet<String>,
}
887
impl TreeNodeVisitor<'_> for ColumnCapturingVisitor {
    type Node = Expr;

    /// Pre-order visit. Because traversal is top-down, a chain like
    /// `get_field(get_field(col, "a"), "b")` is seen outermost-first, so
    /// field names are pushed to the *front* of `current_path` to yield the
    /// path `col.a.b` when the base column is finally reached.
    fn f_down(&mut self, node: &Self::Node) -> DFResult<TreeNodeRecursion> {
        match node {
            Expr::Column(Column { name, .. }) => {
                // Base column reached: join accumulated field names onto it
                // and record the full dotted path.
                let mut path = name.clone();
                for part in self.current_path.drain(..) {
                    path.push('.');
                    path.push_str(&part);
                }
                self.columns.insert(path);
                self.current_path.clear();
            }
            Expr::ScalarFunction(udf) => {
                if udf.name() == GetFieldFunc::default().name() {
                    // `get_field(expr, "name")`: remember the field name and
                    // keep descending toward the base column.
                    if let Some(name) = get_as_string_scalar_opt(&udf.args[1]) {
                        self.current_path.push_front(name.to_string())
                    } else {
                        // Non-literal field name: no static path can be built.
                        self.current_path.clear();
                    }
                } else {
                    self.current_path.clear();
                }
            }
            _ => {
                // Any other node breaks a `get_field` chain.
                self.current_path.clear();
            }
        }

        Ok(TreeNodeRecursion::Continue)
    }
}
921
922#[cfg(test)]
923mod tests {
924
925 use crate::logical_expr::ExprExt;
926
927 use super::*;
928
929 use arrow::datatypes::Float64Type;
930 use arrow_array::{
931 ArrayRef, BooleanArray, Float32Array, Int32Array, Int64Array, RecordBatch, StringArray,
932 StructArray, TimestampMicrosecondArray, TimestampMillisecondArray,
933 TimestampNanosecondArray, TimestampSecondArray,
934 };
935 use arrow_schema::{DataType, Fields, Schema};
936 use datafusion::{
937 logical_expr::{lit, Cast},
938 prelude::{array_element, get_field},
939 };
940 use datafusion_functions::core::expr_ext::FieldAccessor;
941
942 #[test]
943 fn test_parse_filter_simple() {
944 let schema = Arc::new(Schema::new(vec![
945 Field::new("i", DataType::Int32, false),
946 Field::new("s", DataType::Utf8, true),
947 Field::new(
948 "st",
949 DataType::Struct(Fields::from(vec![
950 Field::new("x", DataType::Float32, false),
951 Field::new("y", DataType::Float32, false),
952 ])),
953 true,
954 ),
955 ]));
956
957 let planner = Planner::new(schema.clone());
958
959 let expected = col("i")
960 .gt(lit(3_i32))
961 .and(col("st").field_newstyle("x").lt_eq(lit(5.0_f32)))
962 .and(
963 col("s")
964 .eq(lit("str-4"))
965 .or(col("s").in_list(vec![lit("str-4"), lit("str-5")], false)),
966 );
967
968 let expr = planner
970 .parse_filter("i > 3 AND st.x <= 5.0 AND (s == 'str-4' OR s in ('str-4', 'str-5'))")
971 .unwrap();
972 assert_eq!(expr, expected);
973
974 let expr = planner
976 .parse_filter("i > 3 AND st.x <= 5.0 AND (s = 'str-4' OR s in ('str-4', 'str-5'))")
977 .unwrap();
978
979 let physical_expr = planner.create_physical_expr(&expr).unwrap();
980
981 let batch = RecordBatch::try_new(
982 schema,
983 vec![
984 Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
985 Arc::new(StringArray::from_iter_values(
986 (0..10).map(|v| format!("str-{}", v)),
987 )),
988 Arc::new(StructArray::from(vec![
989 (
990 Arc::new(Field::new("x", DataType::Float32, false)),
991 Arc::new(Float32Array::from_iter_values((0..10).map(|v| v as f32)))
992 as ArrayRef,
993 ),
994 (
995 Arc::new(Field::new("y", DataType::Float32, false)),
996 Arc::new(Float32Array::from_iter_values(
997 (0..10).map(|v| (v * 10) as f32),
998 )),
999 ),
1000 ])),
1001 ],
1002 )
1003 .unwrap();
1004 let predicates = physical_expr.evaluate(&batch).unwrap();
1005 assert_eq!(
1006 predicates.into_array(0).unwrap().as_ref(),
1007 &BooleanArray::from(vec![
1008 false, false, false, false, true, true, false, false, false, false
1009 ])
1010 );
1011 }
1012
1013 #[test]
1014 fn test_nested_col_refs() {
1015 let schema = Arc::new(Schema::new(vec![
1016 Field::new("s0", DataType::Utf8, true),
1017 Field::new(
1018 "st",
1019 DataType::Struct(Fields::from(vec![
1020 Field::new("s1", DataType::Utf8, true),
1021 Field::new(
1022 "st",
1023 DataType::Struct(Fields::from(vec![Field::new(
1024 "s2",
1025 DataType::Utf8,
1026 true,
1027 )])),
1028 true,
1029 ),
1030 ])),
1031 true,
1032 ),
1033 ]));
1034
1035 let planner = Planner::new(schema);
1036
1037 fn assert_column_eq(planner: &Planner, expr: &str, expected: &Expr) {
1038 let expr = planner.parse_filter(&format!("{expr} = 'val'")).unwrap();
1039 assert!(matches!(
1040 expr,
1041 Expr::BinaryExpr(BinaryExpr {
1042 left: _,
1043 op: Operator::Eq,
1044 right: _
1045 })
1046 ));
1047 if let Expr::BinaryExpr(BinaryExpr { left, .. }) = expr {
1048 assert_eq!(left.as_ref(), expected);
1049 }
1050 }
1051
1052 let expected = Expr::Column(Column::new_unqualified("s0"));
1053 assert_column_eq(&planner, "s0", &expected);
1054 assert_column_eq(&planner, "`s0`", &expected);
1055
1056 let expected = Expr::ScalarFunction(ScalarFunction {
1057 func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
1058 args: vec![
1059 Expr::Column(Column::new_unqualified("st")),
1060 Expr::Literal(ScalarValue::Utf8(Some("s1".to_string())), None),
1061 ],
1062 });
1063 assert_column_eq(&planner, "st.s1", &expected);
1064 assert_column_eq(&planner, "`st`.`s1`", &expected);
1065 assert_column_eq(&planner, "st.`s1`", &expected);
1066
1067 let expected = Expr::ScalarFunction(ScalarFunction {
1068 func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
1069 args: vec![
1070 Expr::ScalarFunction(ScalarFunction {
1071 func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
1072 args: vec![
1073 Expr::Column(Column::new_unqualified("st")),
1074 Expr::Literal(ScalarValue::Utf8(Some("st".to_string())), None),
1075 ],
1076 }),
1077 Expr::Literal(ScalarValue::Utf8(Some("s2".to_string())), None),
1078 ],
1079 });
1080
1081 assert_column_eq(&planner, "st.st.s2", &expected);
1082 assert_column_eq(&planner, "`st`.`st`.`s2`", &expected);
1083 assert_column_eq(&planner, "st.st.`s2`", &expected);
1084 assert_column_eq(&planner, "st['st'][\"s2\"]", &expected);
1085 }
1086
1087 #[test]
1088 fn test_nested_list_refs() {
1089 let schema = Arc::new(Schema::new(vec![Field::new(
1090 "l",
1091 DataType::List(Arc::new(Field::new(
1092 "item",
1093 DataType::Struct(Fields::from(vec![Field::new("f1", DataType::Utf8, true)])),
1094 true,
1095 ))),
1096 true,
1097 )]));
1098
1099 let planner = Planner::new(schema);
1100
1101 let expected = array_element(col("l"), lit(0_i64));
1102 let expr = planner.parse_expr("l[0]").unwrap();
1103 assert_eq!(expr, expected);
1104
1105 let expected = get_field(array_element(col("l"), lit(0_i64)), "f1");
1106 let expr = planner.parse_expr("l[0]['f1']").unwrap();
1107 assert_eq!(expr, expected);
1108
1109 }
1114
1115 #[test]
1116 fn test_negative_expressions() {
1117 let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
1118
1119 let planner = Planner::new(schema.clone());
1120
1121 let expected = col("x")
1122 .gt(lit(-3_i64))
1123 .and(col("x").lt(-(lit(-5_i64) + lit(3_i64))));
1124
1125 let expr = planner.parse_filter("x > -3 AND x < -(-5 + 3)").unwrap();
1126
1127 assert_eq!(expr, expected);
1128
1129 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1130
1131 let batch = RecordBatch::try_new(
1132 schema,
1133 vec![Arc::new(Int64Array::from_iter_values(-5..5)) as ArrayRef],
1134 )
1135 .unwrap();
1136 let predicates = physical_expr.evaluate(&batch).unwrap();
1137 assert_eq!(
1138 predicates.into_array(0).unwrap().as_ref(),
1139 &BooleanArray::from(vec![
1140 false, false, false, true, true, true, true, false, false, false
1141 ])
1142 );
1143 }
1144
1145 #[test]
1146 fn test_negative_array_expressions() {
1147 let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
1148
1149 let planner = Planner::new(schema);
1150
1151 let expected = Expr::Literal(
1152 ScalarValue::List(Arc::new(
1153 ListArray::from_iter_primitive::<Float64Type, _, _>(vec![Some(
1154 [-1_f64, -2.0, -3.0, -4.0, -5.0].map(Some),
1155 )]),
1156 )),
1157 None,
1158 );
1159
1160 let expr = planner
1161 .parse_expr("[-1.0, -2.0, -3.0, -4.0, -5.0]")
1162 .unwrap();
1163
1164 assert_eq!(expr, expected);
1165 }
1166
1167 #[test]
1168 fn test_sql_like() {
1169 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1170
1171 let planner = Planner::new(schema.clone());
1172
1173 let expected = col("s").like(lit("str-4"));
1174 let expr = planner.parse_filter("s LIKE 'str-4'").unwrap();
1176 assert_eq!(expr, expected);
1177 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1178
1179 let batch = RecordBatch::try_new(
1180 schema,
1181 vec![Arc::new(StringArray::from_iter_values(
1182 (0..10).map(|v| format!("str-{}", v)),
1183 ))],
1184 )
1185 .unwrap();
1186 let predicates = physical_expr.evaluate(&batch).unwrap();
1187 assert_eq!(
1188 predicates.into_array(0).unwrap().as_ref(),
1189 &BooleanArray::from(vec![
1190 false, false, false, false, true, false, false, false, false, false
1191 ])
1192 );
1193 }
1194
1195 #[test]
1196 fn test_not_like() {
1197 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1198
1199 let planner = Planner::new(schema.clone());
1200
1201 let expected = col("s").not_like(lit("str-4"));
1202 let expr = planner.parse_filter("s NOT LIKE 'str-4'").unwrap();
1204 assert_eq!(expr, expected);
1205 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1206
1207 let batch = RecordBatch::try_new(
1208 schema,
1209 vec![Arc::new(StringArray::from_iter_values(
1210 (0..10).map(|v| format!("str-{}", v)),
1211 ))],
1212 )
1213 .unwrap();
1214 let predicates = physical_expr.evaluate(&batch).unwrap();
1215 assert_eq!(
1216 predicates.into_array(0).unwrap().as_ref(),
1217 &BooleanArray::from(vec![
1218 true, true, true, true, false, true, true, true, true, true
1219 ])
1220 );
1221 }
1222
1223 #[test]
1224 fn test_sql_is_in() {
1225 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1226
1227 let planner = Planner::new(schema.clone());
1228
1229 let expected = col("s").in_list(vec![lit("str-4"), lit("str-5")], false);
1230 let expr = planner.parse_filter("s IN ('str-4', 'str-5')").unwrap();
1232 assert_eq!(expr, expected);
1233 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1234
1235 let batch = RecordBatch::try_new(
1236 schema,
1237 vec![Arc::new(StringArray::from_iter_values(
1238 (0..10).map(|v| format!("str-{}", v)),
1239 ))],
1240 )
1241 .unwrap();
1242 let predicates = physical_expr.evaluate(&batch).unwrap();
1243 assert_eq!(
1244 predicates.into_array(0).unwrap().as_ref(),
1245 &BooleanArray::from(vec![
1246 false, false, false, false, true, true, false, false, false, false
1247 ])
1248 );
1249 }
1250
1251 #[test]
1252 fn test_sql_is_null() {
1253 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1254
1255 let planner = Planner::new(schema.clone());
1256
1257 let expected = col("s").is_null();
1258 let expr = planner.parse_filter("s IS NULL").unwrap();
1259 assert_eq!(expr, expected);
1260 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1261
1262 let batch = RecordBatch::try_new(
1263 schema,
1264 vec![Arc::new(StringArray::from_iter((0..10).map(|v| {
1265 if v % 3 == 0 {
1266 Some(format!("str-{}", v))
1267 } else {
1268 None
1269 }
1270 })))],
1271 )
1272 .unwrap();
1273 let predicates = physical_expr.evaluate(&batch).unwrap();
1274 assert_eq!(
1275 predicates.into_array(0).unwrap().as_ref(),
1276 &BooleanArray::from(vec![
1277 false, true, true, false, true, true, false, true, true, false
1278 ])
1279 );
1280
1281 let expr = planner.parse_filter("s IS NOT NULL").unwrap();
1282 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1283 let predicates = physical_expr.evaluate(&batch).unwrap();
1284 assert_eq!(
1285 predicates.into_array(0).unwrap().as_ref(),
1286 &BooleanArray::from(vec![
1287 true, false, false, true, false, false, true, false, false, true,
1288 ])
1289 );
1290 }
1291
1292 #[test]
1293 fn test_sql_invert() {
1294 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Boolean, true)]));
1295
1296 let planner = Planner::new(schema.clone());
1297
1298 let expr = planner.parse_filter("NOT s").unwrap();
1299 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1300
1301 let batch = RecordBatch::try_new(
1302 schema,
1303 vec![Arc::new(BooleanArray::from_iter(
1304 (0..10).map(|v| Some(v % 3 == 0)),
1305 ))],
1306 )
1307 .unwrap();
1308 let predicates = physical_expr.evaluate(&batch).unwrap();
1309 assert_eq!(
1310 predicates.into_array(0).unwrap().as_ref(),
1311 &BooleanArray::from(vec![
1312 false, true, true, false, true, true, false, true, true, false
1313 ])
1314 );
1315 }
1316
    #[test]
    fn test_sql_cast() {
        // Each case pairs a filter using SQL CAST syntax with the Arrow type
        // the cast should resolve to. The column "x" is created with that same
        // type so the comparison is well-typed.
        let cases = &[
            (
                "x = cast('2021-01-01 00:00:00' as timestamp)",
                ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00' as timestamp(0))",
                ArrowDataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00.123' as timestamp(9))",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            (
                // datetime is accepted as a synonym for timestamp.
                "x = cast('2021-01-01 00:00:00.123' as datetime(9))",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            ("x = cast('2021-01-01' as date)", ArrowDataType::Date32),
            (
                "x = cast('1.238' as decimal(9,3))",
                ArrowDataType::Decimal128(9, 3),
            ),
            ("x = cast(1 as float)", ArrowDataType::Float32),
            ("x = cast(1 as double)", ArrowDataType::Float64),
            ("x = cast(1 as tinyint)", ArrowDataType::Int8),
            ("x = cast(1 as smallint)", ArrowDataType::Int16),
            ("x = cast(1 as int)", ArrowDataType::Int32),
            ("x = cast(1 as integer)", ArrowDataType::Int32),
            ("x = cast(1 as bigint)", ArrowDataType::Int64),
            ("x = cast(1 as tinyint unsigned)", ArrowDataType::UInt8),
            ("x = cast(1 as smallint unsigned)", ArrowDataType::UInt16),
            ("x = cast(1 as int unsigned)", ArrowDataType::UInt32),
            ("x = cast(1 as integer unsigned)", ArrowDataType::UInt32),
            ("x = cast(1 as bigint unsigned)", ArrowDataType::UInt64),
            ("x = cast(1 as boolean)", ArrowDataType::Boolean),
            ("x = cast(1 as string)", ArrowDataType::Utf8),
        ];

        for (sql, expected_data_type) in cases {
            let schema = Arc::new(Schema::new(vec![Field::new(
                "x",
                expected_data_type.clone(),
                true,
            )]));
            let planner = Planner::new(schema.clone());
            let expr = planner.parse_filter(sql).unwrap();

            // Recover the literal text inside cast(...): everything between
            // "cast(" and the following " as", then strip surrounding quotes
            // (a no-op for the bare-integer cases).
            let expected_value_str = sql
                .split("cast(")
                .nth(1)
                .unwrap()
                .split(" as")
                .next()
                .unwrap();
            let expected_value_str = expected_value_str.trim_matches('\'');

            // The parsed filter must be `column = Cast(literal)` with the cast
            // applied to the raw literal (not pre-evaluated by the planner).
            match expr {
                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
                    Expr::Cast(Cast { expr, data_type }) => {
                        match expr.as_ref() {
                            // Quoted literals arrive as Utf8 scalars.
                            Expr::Literal(ScalarValue::Utf8(Some(value_str)), _) => {
                                assert_eq!(value_str, expected_value_str);
                            }
                            // Bare numeric literals arrive as Int64.
                            Expr::Literal(ScalarValue::Int64(Some(value)), _) => {
                                assert_eq!(*value, 1);
                            }
                            _ => panic!("Expected cast to be applied to literal"),
                        }
                        assert_eq!(data_type, expected_data_type);
                    }
                    _ => panic!("Expected right to be a cast"),
                },
                _ => panic!("Expected binary expression"),
            }
        }
    }
1397
    #[test]
    fn test_sql_literals() {
        // Typed literal syntax (`TYPE 'value'`) should plan to a cast of the
        // raw string to the corresponding Arrow type. The column "x" is
        // created with that same type so the comparison is well-typed.
        let cases = &[
            (
                "x = timestamp '2021-01-01 00:00:00'",
                ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
            ),
            (
                "x = timestamp(0) '2021-01-01 00:00:00'",
                ArrowDataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                "x = timestamp(9) '2021-01-01 00:00:00.123'",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            ("x = date '2021-01-01'", ArrowDataType::Date32),
            ("x = decimal(9,3) '1.238'", ArrowDataType::Decimal128(9, 3)),
        ];

        for (sql, expected_data_type) in cases {
            let schema = Arc::new(Schema::new(vec![Field::new(
                "x",
                expected_data_type.clone(),
                true,
            )]));
            let planner = Planner::new(schema.clone());
            let expr = planner.parse_filter(sql).unwrap();

            // The literal text is the first single-quoted segment of the SQL.
            let expected_value_str = sql.split('\'').nth(1).unwrap();

            // The parsed filter must be `column = Cast(Utf8 literal)` with the
            // original string preserved (not pre-evaluated by the planner).
            match expr {
                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
                    Expr::Cast(Cast { expr, data_type }) => {
                        match expr.as_ref() {
                            Expr::Literal(ScalarValue::Utf8(Some(value_str)), _) => {
                                assert_eq!(value_str, expected_value_str);
                            }
                            _ => panic!("Expected cast to be applied to literal"),
                        }
                        assert_eq!(data_type, expected_data_type);
                    }
                    _ => panic!("Expected right to be a cast"),
                },
                _ => panic!("Expected binary expression"),
            }
        }
    }
1445
1446 #[test]
1447 fn test_sql_array_literals() {
1448 let cases = [
1449 (
1450 "x = [1, 2, 3]",
1451 ArrowDataType::List(Arc::new(Field::new("item", ArrowDataType::Int64, true))),
1452 ),
1453 (
1454 "x = [1, 2, 3]",
1455 ArrowDataType::FixedSizeList(
1456 Arc::new(Field::new("item", ArrowDataType::Int64, true)),
1457 3,
1458 ),
1459 ),
1460 ];
1461
1462 for (sql, expected_data_type) in cases {
1463 let schema = Arc::new(Schema::new(vec![Field::new(
1464 "x",
1465 expected_data_type.clone(),
1466 true,
1467 )]));
1468 let planner = Planner::new(schema.clone());
1469 let expr = planner.parse_filter(sql).unwrap();
1470 let expr = planner.optimize_expr(expr).unwrap();
1471
1472 match expr {
1473 Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
1474 Expr::Literal(value, _) => {
1475 assert_eq!(&value.data_type(), &expected_data_type);
1476 }
1477 _ => panic!("Expected right to be a literal"),
1478 },
1479 _ => panic!("Expected binary expression"),
1480 }
1481 }
1482 }
1483
1484 #[test]
1485 fn test_sql_between() {
1486 use arrow_array::{Float64Array, Int32Array, TimestampMicrosecondArray};
1487 use arrow_schema::{DataType, Field, Schema, TimeUnit};
1488 use std::sync::Arc;
1489
1490 let schema = Arc::new(Schema::new(vec![
1491 Field::new("x", DataType::Int32, false),
1492 Field::new("y", DataType::Float64, false),
1493 Field::new(
1494 "ts",
1495 DataType::Timestamp(TimeUnit::Microsecond, None),
1496 false,
1497 ),
1498 ]));
1499
1500 let planner = Planner::new(schema.clone());
1501
1502 let expr = planner
1504 .parse_filter("x BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
1505 .unwrap();
1506 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1507
1508 let base_ts = 1704067200000000_i64; let ts_array = TimestampMicrosecondArray::from_iter_values(
1512 (0..10).map(|i| base_ts + i * 1_000_000), );
1514
1515 let batch = RecordBatch::try_new(
1516 schema,
1517 vec![
1518 Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
1519 Arc::new(Float64Array::from_iter_values((0..10).map(|v| v as f64))),
1520 Arc::new(ts_array),
1521 ],
1522 )
1523 .unwrap();
1524
1525 let predicates = physical_expr.evaluate(&batch).unwrap();
1526 assert_eq!(
1527 predicates.into_array(0).unwrap().as_ref(),
1528 &BooleanArray::from(vec![
1529 false, false, false, true, true, true, true, true, false, false
1530 ])
1531 );
1532
1533 let expr = planner
1535 .parse_filter("x NOT BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
1536 .unwrap();
1537 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1538
1539 let predicates = physical_expr.evaluate(&batch).unwrap();
1540 assert_eq!(
1541 predicates.into_array(0).unwrap().as_ref(),
1542 &BooleanArray::from(vec![
1543 true, true, true, false, false, false, false, false, true, true
1544 ])
1545 );
1546
1547 let expr = planner.parse_filter("y BETWEEN 2.5 AND 6.5").unwrap();
1549 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1550
1551 let predicates = physical_expr.evaluate(&batch).unwrap();
1552 assert_eq!(
1553 predicates.into_array(0).unwrap().as_ref(),
1554 &BooleanArray::from(vec![
1555 false, false, false, true, true, true, true, false, false, false
1556 ])
1557 );
1558
1559 let expr = planner
1561 .parse_filter(
1562 "ts BETWEEN timestamp '2024-01-01 00:00:03' AND timestamp '2024-01-01 00:00:07'",
1563 )
1564 .unwrap();
1565 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1566
1567 let predicates = physical_expr.evaluate(&batch).unwrap();
1568 assert_eq!(
1569 predicates.into_array(0).unwrap().as_ref(),
1570 &BooleanArray::from(vec![
1571 false, false, false, true, true, true, true, true, false, false
1572 ])
1573 );
1574 }
1575
1576 #[test]
1577 fn test_sql_comparison() {
1578 let batch: Vec<(&str, ArrayRef)> = vec![
1580 (
1581 "timestamp_s",
1582 Arc::new(TimestampSecondArray::from_iter_values(0..10)),
1583 ),
1584 (
1585 "timestamp_ms",
1586 Arc::new(TimestampMillisecondArray::from_iter_values(0..10)),
1587 ),
1588 (
1589 "timestamp_us",
1590 Arc::new(TimestampMicrosecondArray::from_iter_values(0..10)),
1591 ),
1592 (
1593 "timestamp_ns",
1594 Arc::new(TimestampNanosecondArray::from_iter_values(4995..5005)),
1595 ),
1596 ];
1597 let batch = RecordBatch::try_from_iter(batch).unwrap();
1598
1599 let planner = Planner::new(batch.schema());
1600
1601 let expressions = &[
1603 "timestamp_s >= TIMESTAMP '1970-01-01 00:00:05'",
1604 "timestamp_ms >= TIMESTAMP '1970-01-01 00:00:00.005'",
1605 "timestamp_us >= TIMESTAMP '1970-01-01 00:00:00.000005'",
1606 "timestamp_ns >= TIMESTAMP '1970-01-01 00:00:00.000005'",
1607 ];
1608
1609 let expected: ArrayRef = Arc::new(BooleanArray::from_iter(
1610 std::iter::repeat_n(Some(false), 5).chain(std::iter::repeat_n(Some(true), 5)),
1611 ));
1612 for expression in expressions {
1613 let logical_expr = planner.parse_filter(expression).unwrap();
1615 let logical_expr = planner.optimize_expr(logical_expr).unwrap();
1616 let physical_expr = planner.create_physical_expr(&logical_expr).unwrap();
1617
1618 let result = physical_expr.evaluate(&batch).unwrap();
1620 let result = result.into_array(batch.num_rows()).unwrap();
1621 assert_eq!(&expected, &result, "unexpected result for {}", expression);
1622 }
1623 }
1624
1625 #[test]
1626 fn test_columns_in_expr() {
1627 let expr = col("s0").gt(lit("value")).and(
1628 col("st")
1629 .field("st")
1630 .field("s2")
1631 .eq(lit("value"))
1632 .or(col("st")
1633 .field("s1")
1634 .in_list(vec![lit("value 1"), lit("value 2")], false)),
1635 );
1636
1637 let columns = Planner::column_names_in_expr(&expr);
1638 assert_eq!(columns, vec!["s0", "st.s1", "st.st.s2"]);
1639 }
1640
1641 #[test]
1642 fn test_parse_binary_expr() {
1643 let bin_str = "x'616263'";
1644
1645 let schema = Arc::new(Schema::new(vec![Field::new(
1646 "binary",
1647 DataType::Binary,
1648 true,
1649 )]));
1650 let planner = Planner::new(schema);
1651 let expr = planner.parse_expr(bin_str).unwrap();
1652 assert_eq!(
1653 expr,
1654 Expr::Literal(ScalarValue::Binary(Some(vec![b'a', b'b', b'c'])), None)
1655 );
1656 }
1657
1658 #[test]
1659 fn test_lance_context_provider_expr_planners() {
1660 let ctx_provider = LanceContextProvider::default();
1661 assert!(!ctx_provider.get_expr_planners().is_empty());
1662 }
1663}