1use std::borrow::Cow;
7use std::collections::{BTreeSet, VecDeque};
8use std::sync::Arc;
9
10use crate::expr::safe_coerce_scalar;
11use crate::logical_expr::{coerce_filter_type_to_boolean, get_as_string_scalar_opt, resolve_expr};
12use crate::sql::{parse_sql_expr, parse_sql_filter};
13use arrow::compute::CastOptions;
14use arrow_array::ListArray;
15use arrow_buffer::OffsetBuffer;
16use arrow_schema::{DataType as ArrowDataType, Field, SchemaRef, TimeUnit};
17use arrow_select::concat::concat;
18use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
19use datafusion::common::DFSchema;
20use datafusion::config::ConfigOptions;
21use datafusion::error::Result as DFResult;
22use datafusion::execution::config::SessionConfig;
23use datafusion::execution::context::{SessionContext, SessionState};
24use datafusion::execution::runtime_env::RuntimeEnvBuilder;
25use datafusion::execution::session_state::SessionStateBuilder;
26use datafusion::logical_expr::expr::ScalarFunction;
27use datafusion::logical_expr::planner::{ExprPlanner, PlannerResult, RawFieldAccessExpr};
28use datafusion::logical_expr::{
29 AggregateUDF, ColumnarValue, GetFieldAccess, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
30 Signature, Volatility, WindowUDF,
31};
32use datafusion::optimizer::simplify_expressions::SimplifyContext;
33use datafusion::sql::planner::{
34 ContextProvider, NullOrdering, ParserOptions, PlannerContext, SqlToRel,
35};
36use datafusion::sql::sqlparser::ast::{
37 AccessExpr, Array as SQLArray, BinaryOperator, DataType as SQLDataType, ExactNumberInfo,
38 Expr as SQLExpr, Function, FunctionArg, FunctionArgExpr, FunctionArguments, Ident,
39 ObjectNamePart, Subscript, TimezoneInfo, UnaryOperator, Value, ValueWithSpan,
40};
41use datafusion::{
42 common::Column,
43 logical_expr::{Between, BinaryExpr, Like, Operator},
44 physical_expr::execution_props::ExecutionProps,
45 physical_plan::PhysicalExpr,
46 prelude::Expr,
47 scalar::ScalarValue,
48};
49use datafusion_functions::core::getfield::GetFieldFunc;
50use lance_arrow::cast::cast_with_options;
51use lance_core::datatypes::Schema;
52use lance_core::error::LanceOptionExt;
53use snafu::location;
54
55use chrono::Utc;
56use lance_core::{Error, Result};
57
/// Scalar UDF `_cast_list_f16`: casts a `List`/`FixedSizeList` of float32 or
/// float16 elements to the same list shape with `Float16` elements.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
struct CastListF16Udf {
    // Accepts exactly one argument of any type; the element type is
    // validated later in `return_type` / `invoke_with_args`.
    signature: Signature,
}
62
impl CastListF16Udf {
    /// Create the UDF with a one-argument, immutable-volatility signature.
    pub fn new() -> Self {
        Self {
            signature: Signature::any(1, Volatility::Immutable),
        }
    }
}
70
71impl ScalarUDFImpl for CastListF16Udf {
72 fn as_any(&self) -> &dyn std::any::Any {
73 self
74 }
75
76 fn name(&self) -> &str {
77 "_cast_list_f16"
78 }
79
80 fn signature(&self) -> &Signature {
81 &self.signature
82 }
83
84 fn return_type(&self, arg_types: &[ArrowDataType]) -> DFResult<ArrowDataType> {
85 let input = &arg_types[0];
86 match input {
87 ArrowDataType::FixedSizeList(field, size) => {
88 if field.data_type() != &ArrowDataType::Float32
89 && field.data_type() != &ArrowDataType::Float16
90 {
91 return Err(datafusion::error::DataFusionError::Execution(
92 "cast_list_f16 only supports list of float32 or float16".to_string(),
93 ));
94 }
95 Ok(ArrowDataType::FixedSizeList(
96 Arc::new(Field::new(
97 field.name(),
98 ArrowDataType::Float16,
99 field.is_nullable(),
100 )),
101 *size,
102 ))
103 }
104 ArrowDataType::List(field) => {
105 if field.data_type() != &ArrowDataType::Float32
106 && field.data_type() != &ArrowDataType::Float16
107 {
108 return Err(datafusion::error::DataFusionError::Execution(
109 "cast_list_f16 only supports list of float32 or float16".to_string(),
110 ));
111 }
112 Ok(ArrowDataType::List(Arc::new(Field::new(
113 field.name(),
114 ArrowDataType::Float16,
115 field.is_nullable(),
116 ))))
117 }
118 _ => Err(datafusion::error::DataFusionError::Execution(
119 "cast_list_f16 only supports FixedSizeList/List arguments".to_string(),
120 )),
121 }
122 }
123
124 fn invoke_with_args(&self, func_args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
125 let ColumnarValue::Array(arr) = &func_args.args[0] else {
126 return Err(datafusion::error::DataFusionError::Execution(
127 "cast_list_f16 only supports array arguments".to_string(),
128 ));
129 };
130
131 let to_type = match arr.data_type() {
132 ArrowDataType::FixedSizeList(field, size) => ArrowDataType::FixedSizeList(
133 Arc::new(Field::new(
134 field.name(),
135 ArrowDataType::Float16,
136 field.is_nullable(),
137 )),
138 *size,
139 ),
140 ArrowDataType::List(field) => ArrowDataType::List(Arc::new(Field::new(
141 field.name(),
142 ArrowDataType::Float16,
143 field.is_nullable(),
144 ))),
145 _ => {
146 return Err(datafusion::error::DataFusionError::Execution(
147 "cast_list_f16 only supports array arguments".to_string(),
148 ));
149 }
150 };
151
152 let res = cast_with_options(arr.as_ref(), &to_type, &CastOptions::default())?;
153 Ok(ColumnarValue::Array(res))
154 }
155}
156
/// Minimal `ContextProvider` used to plan standalone SQL expressions against
/// a Lance schema: resolves functions and UDFs, but exposes no tables.
struct LanceContextProvider {
    options: datafusion::config::ConfigOptions,
    // Fully-initialized session state used to look up scalar/aggregate/window
    // functions, including Lance-registered UDFs.
    state: SessionState,
    // Expression planners captured from a default SessionStateBuilder.
    expr_planners: Vec<Arc<dyn ExprPlanner>>,
}
163
164impl Default for LanceContextProvider {
165 fn default() -> Self {
166 let config = SessionConfig::new();
167 let runtime = RuntimeEnvBuilder::new().build_arc().unwrap();
168
169 let ctx = SessionContext::new_with_config_rt(config.clone(), runtime.clone());
170 crate::udf::register_functions(&ctx);
171
172 let state = ctx.state();
173
174 let mut state_builder = SessionStateBuilder::new()
176 .with_config(config)
177 .with_runtime_env(runtime)
178 .with_default_features();
179
180 let expr_planners = state_builder.expr_planners().as_ref().unwrap().clone();
182
183 Self {
184 options: ConfigOptions::default(),
185 state,
186 expr_planners,
187 }
188 }
189}
190
191impl ContextProvider for LanceContextProvider {
192 fn get_table_source(
193 &self,
194 name: datafusion::sql::TableReference,
195 ) -> DFResult<Arc<dyn datafusion::logical_expr::TableSource>> {
196 Err(datafusion::error::DataFusionError::NotImplemented(format!(
197 "Attempt to reference inner table {} not supported",
198 name
199 )))
200 }
201
202 fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
203 self.state.aggregate_functions().get(name).cloned()
204 }
205
206 fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
207 self.state.window_functions().get(name).cloned()
208 }
209
210 fn get_function_meta(&self, f: &str) -> Option<Arc<ScalarUDF>> {
211 match f {
212 "_cast_list_f16" => Some(Arc::new(ScalarUDF::new_from_impl(CastListF16Udf::new()))),
215 _ => self.state.scalar_functions().get(f).cloned(),
216 }
217 }
218
219 fn get_variable_type(&self, _: &[String]) -> Option<ArrowDataType> {
220 None
222 }
223
224 fn options(&self) -> &datafusion::config::ConfigOptions {
225 &self.options
226 }
227
228 fn udf_names(&self) -> Vec<String> {
229 self.state.scalar_functions().keys().cloned().collect()
230 }
231
232 fn udaf_names(&self) -> Vec<String> {
233 self.state.aggregate_functions().keys().cloned().collect()
234 }
235
236 fn udwf_names(&self) -> Vec<String> {
237 self.state.window_functions().keys().cloned().collect()
238 }
239
240 fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
241 &self.expr_planners
242 }
243}
244
/// Plans SQL filter/expression strings into DataFusion logical and physical
/// expressions against a fixed Arrow schema.
pub struct Planner {
    schema: SchemaRef,
    context_provider: LanceContextProvider,
    // When true, the first part of a compound identifier (`rel.col`) is
    // treated as a relation qualifier rather than a struct field access.
    enable_relations: bool,
}
250
251impl Planner {
252 pub fn new(schema: SchemaRef) -> Self {
253 Self {
254 schema,
255 context_provider: LanceContextProvider::default(),
256 enable_relations: false,
257 }
258 }
259
    /// Builder-style toggle: when enabled, compound identifiers such as
    /// `rel.col` are interpreted with a leading relation qualifier.
    pub fn with_enable_relations(mut self, enable_relations: bool) -> Self {
        self.enable_relations = enable_relations;
        self
    }
269
270 fn resolve_column_name(&self, name: &str) -> String {
273 if self.schema.field_with_name(name).is_ok() {
275 return name.to_string();
276 }
277 for field in self.schema.fields() {
279 if field.name().eq_ignore_ascii_case(name) {
280 return field.name().clone();
281 }
282 }
283 name.to_string()
285 }
286
    /// Build a column expression from (possibly compound) identifiers.
    ///
    /// The leading identifier (or the second one when relations are enabled)
    /// is resolved against the schema; every remaining identifier becomes a
    /// nested `get_field` access on the struct column.
    fn column(&self, idents: &[Ident]) -> Expr {
        // Wrap `expr` in a chain of `get_field(expr, "name")` calls, one per
        // remaining identifier, innermost first.
        fn handle_remaining_idents(expr: &mut Expr, idents: &[Ident]) {
            for ident in idents {
                *expr = Expr::ScalarFunction(ScalarFunction {
                    args: vec![
                        // mem::take leaves a default placeholder behind; it is
                        // immediately overwritten by this assignment.
                        std::mem::take(expr),
                        Expr::Literal(ScalarValue::Utf8(Some(ident.value.clone())), None),
                    ],
                    func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
                });
            }
        }

        if self.enable_relations && idents.len() > 1 {
            // `rel.col[.nested...]`: the first part qualifies the relation.
            let relation = &idents[0].value;
            let column_name = self.resolve_column_name(&idents[1].value);
            let column = Expr::Column(Column::new(Some(relation.clone()), column_name));
            let mut result = column;
            handle_remaining_idents(&mut result, &idents[2..]);
            result
        } else {
            // `col[.nested...]`: the first part is the column itself.
            let resolved_name = self.resolve_column_name(&idents[0].value);
            let mut column = Expr::Column(Column::from_name(resolved_name));
            handle_remaining_idents(&mut column, &idents[1..]);
            column
        }
    }
317
    /// Map a supported sqlparser binary operator to its DataFusion
    /// counterpart; any other operator is rejected with an error.
    fn binary_op(&self, op: &BinaryOperator) -> Result<Operator> {
        Ok(match op {
            BinaryOperator::Plus => Operator::Plus,
            BinaryOperator::Minus => Operator::Minus,
            BinaryOperator::Multiply => Operator::Multiply,
            BinaryOperator::Divide => Operator::Divide,
            BinaryOperator::Modulo => Operator::Modulo,
            BinaryOperator::StringConcat => Operator::StringConcat,
            BinaryOperator::Gt => Operator::Gt,
            BinaryOperator::Lt => Operator::Lt,
            BinaryOperator::GtEq => Operator::GtEq,
            BinaryOperator::LtEq => Operator::LtEq,
            BinaryOperator::Eq => Operator::Eq,
            BinaryOperator::NotEq => Operator::NotEq,
            BinaryOperator::And => Operator::And,
            BinaryOperator::Or => Operator::Or,
            _ => {
                return Err(Error::invalid_input(
                    format!("Operator {op} is not supported"),
                    location!(),
                ));
            }
        })
    }
342
343 fn binary_expr(&self, left: &SQLExpr, op: &BinaryOperator, right: &SQLExpr) -> Result<Expr> {
344 Ok(Expr::BinaryExpr(BinaryExpr::new(
345 Box::new(self.parse_sql_expr(left)?),
346 self.binary_op(op)?,
347 Box::new(self.parse_sql_expr(right)?),
348 )))
349 }
350
    /// Convert a unary operation. `NOT` (and PG `~`) become logical negation;
    /// `-` applied to a numeric literal is folded directly into the literal,
    /// and otherwise becomes `Expr::Negative`.
    fn unary_expr(&self, op: &UnaryOperator, expr: &SQLExpr) -> Result<Expr> {
        Ok(match op {
            UnaryOperator::Not | UnaryOperator::PGBitwiseNot => {
                Expr::Not(Box::new(self.parse_sql_expr(expr)?))
            }

            UnaryOperator::Minus => {
                use datafusion::logical_expr::lit;
                match expr {
                    // Fold `-<number>` into a literal: try i64 first, then
                    // fall back to f64.
                    SQLExpr::Value(ValueWithSpan { value: Value::Number(n, _), ..}) => match n.parse::<i64>() {
                        Ok(n) => lit(-n),
                        Err(_) => lit(-n
                            .parse::<f64>()
                            .map_err(|_e| {
                                Error::invalid_input(
                                    format!("negative operator can be only applied to integer and float operands, got: {n}"),
                                    location!(),
                                )
                            })?),
                    },
                    // Non-literal operand: negate the parsed expression.
                    _ => {
                        Expr::Negative(Box::new(self.parse_sql_expr(expr)?))
                    }
                }
            }

            _ => {
                return Err(Error::invalid_input(
                    format!("Unary operator '{:?}' is not supported", op),
                    location!(),
                ));
            }
        })
    }
385
386 fn number(&self, value: &str, negative: bool) -> Result<Expr> {
388 use datafusion::logical_expr::lit;
389 let value: Cow<str> = if negative {
390 Cow::Owned(format!("-{}", value))
391 } else {
392 Cow::Borrowed(value)
393 };
394 if let Ok(n) = value.parse::<i64>() {
395 Ok(lit(n))
396 } else {
397 value.parse::<f64>().map(lit).map_err(|_| {
398 Error::invalid_input(
399 format!("'{value}' is not supported number value."),
400 location!(),
401 )
402 })
403 }
404 }
405
406 fn value(&self, value: &Value) -> Result<Expr> {
407 Ok(match value {
408 Value::Number(v, _) => self.number(v.as_str(), false)?,
409 Value::SingleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone())), None),
410 Value::HexStringLiteral(hsl) => {
411 Expr::Literal(ScalarValue::Binary(Self::try_decode_hex_literal(hsl)), None)
412 }
413 Value::DoubleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone())), None),
414 Value::Boolean(v) => Expr::Literal(ScalarValue::Boolean(Some(*v)), None),
415 Value::Null => Expr::Literal(ScalarValue::Null, None),
416 _ => todo!(),
417 })
418 }
419
420 fn parse_function_args(&self, func_args: &FunctionArg) -> Result<Expr> {
421 match func_args {
422 FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => self.parse_sql_expr(expr),
423 _ => Err(Error::invalid_input(
424 format!("Unsupported function args: {:?}", func_args),
425 location!(),
426 )),
427 }
428 }
429
430 fn legacy_parse_function(&self, func: &Function) -> Result<Expr> {
437 match &func.args {
438 FunctionArguments::List(args) => {
439 if func.name.0.len() != 1 {
440 return Err(Error::invalid_input(
441 format!("Function name must have 1 part, got: {:?}", func.name.0),
442 location!(),
443 ));
444 }
445 Ok(Expr::IsNotNull(Box::new(
446 self.parse_function_args(&args.args[0])?,
447 )))
448 }
449 _ => Err(Error::invalid_input(
450 format!("Unsupported function args: {:?}", &func.args),
451 location!(),
452 )),
453 }
454 }
455
    /// Parse a SQL function call.
    ///
    /// `is_valid(x)` is kept for backwards compatibility and translated to
    /// `x IS NOT NULL`; everything else is delegated to DataFusion's
    /// SQL-to-logical planner using Lance's parser options.
    fn parse_function(&self, function: SQLExpr) -> Result<Expr> {
        if let SQLExpr::Function(function) = &function {
            if let Some(ObjectNamePart::Identifier(name)) = &function.name.0.first() {
                if &name.value == "is_valid" {
                    return self.legacy_parse_function(function);
                }
            }
        }
        let sql_to_rel = SqlToRel::new_with_options(
            &self.context_provider,
            // Identifiers are kept as-written (no normalization) to match
            // Lance's case-sensitive column-name semantics.
            ParserOptions {
                parse_float_as_decimal: false,
                enable_ident_normalization: false,
                support_varchar_with_length: false,
                enable_options_value_normalization: false,
                collect_spans: false,
                map_string_types_to_utf8view: false,
                default_null_ordering: NullOrdering::NullsMax,
            },
        );

        let mut planner_context = PlannerContext::default();
        let schema = DFSchema::try_from(self.schema.as_ref().clone())?;
        sql_to_rel
            .sql_to_expr(function, &schema, &mut planner_context)
            .map_err(|e| Error::invalid_input(format!("Error parsing function: {e}"), location!()))
    }
483
484 fn parse_type(&self, data_type: &SQLDataType) -> Result<ArrowDataType> {
485 const SUPPORTED_TYPES: [&str; 13] = [
486 "int [unsigned]",
487 "tinyint [unsigned]",
488 "smallint [unsigned]",
489 "bigint [unsigned]",
490 "float",
491 "double",
492 "string",
493 "binary",
494 "date",
495 "timestamp(precision)",
496 "datetime(precision)",
497 "decimal(precision,scale)",
498 "boolean",
499 ];
500 match data_type {
501 SQLDataType::String(_) => Ok(ArrowDataType::Utf8),
502 SQLDataType::Binary(_) => Ok(ArrowDataType::Binary),
503 SQLDataType::Float(_) => Ok(ArrowDataType::Float32),
504 SQLDataType::Double(_) => Ok(ArrowDataType::Float64),
505 SQLDataType::Boolean => Ok(ArrowDataType::Boolean),
506 SQLDataType::TinyInt(_) => Ok(ArrowDataType::Int8),
507 SQLDataType::SmallInt(_) => Ok(ArrowDataType::Int16),
508 SQLDataType::Int(_) | SQLDataType::Integer(_) => Ok(ArrowDataType::Int32),
509 SQLDataType::BigInt(_) => Ok(ArrowDataType::Int64),
510 SQLDataType::TinyIntUnsigned(_) => Ok(ArrowDataType::UInt8),
511 SQLDataType::SmallIntUnsigned(_) => Ok(ArrowDataType::UInt16),
512 SQLDataType::IntUnsigned(_) | SQLDataType::IntegerUnsigned(_) => {
513 Ok(ArrowDataType::UInt32)
514 }
515 SQLDataType::BigIntUnsigned(_) => Ok(ArrowDataType::UInt64),
516 SQLDataType::Date => Ok(ArrowDataType::Date32),
517 SQLDataType::Timestamp(resolution, tz) => {
518 match tz {
519 TimezoneInfo::None => {}
520 _ => {
521 return Err(Error::invalid_input(
522 "Timezone not supported in timestamp".to_string(),
523 location!(),
524 ));
525 }
526 };
527 let time_unit = match resolution {
528 None => TimeUnit::Microsecond,
530 Some(0) => TimeUnit::Second,
531 Some(3) => TimeUnit::Millisecond,
532 Some(6) => TimeUnit::Microsecond,
533 Some(9) => TimeUnit::Nanosecond,
534 _ => {
535 return Err(Error::invalid_input(
536 format!("Unsupported datetime resolution: {:?}", resolution),
537 location!(),
538 ));
539 }
540 };
541 Ok(ArrowDataType::Timestamp(time_unit, None))
542 }
543 SQLDataType::Datetime(resolution) => {
544 let time_unit = match resolution {
545 None => TimeUnit::Microsecond,
546 Some(0) => TimeUnit::Second,
547 Some(3) => TimeUnit::Millisecond,
548 Some(6) => TimeUnit::Microsecond,
549 Some(9) => TimeUnit::Nanosecond,
550 _ => {
551 return Err(Error::invalid_input(
552 format!("Unsupported datetime resolution: {:?}", resolution),
553 location!(),
554 ));
555 }
556 };
557 Ok(ArrowDataType::Timestamp(time_unit, None))
558 }
559 SQLDataType::Decimal(number_info) => match number_info {
560 ExactNumberInfo::PrecisionAndScale(precision, scale) => {
561 Ok(ArrowDataType::Decimal128(*precision as u8, *scale as i8))
562 }
563 _ => Err(Error::invalid_input(
564 format!(
565 "Must provide precision and scale for decimal: {:?}",
566 number_info
567 ),
568 location!(),
569 )),
570 },
571 _ => Err(Error::invalid_input(
572 format!(
573 "Unsupported data type: {:?}. Supported types: {:?}",
574 data_type, SUPPORTED_TYPES
575 ),
576 location!(),
577 )),
578 }
579 }
580
    /// Plan a field access (dot or subscript) by delegating to the registered
    /// expression planners, trying each in order until one produces a plan.
    fn plan_field_access(&self, mut field_access_expr: RawFieldAccessExpr) -> Result<Expr> {
        let df_schema = DFSchema::try_from(self.schema.as_ref().clone())?;
        for planner in self.context_provider.get_expr_planners() {
            match planner.plan_field_access(field_access_expr, &df_schema)? {
                PlannerResult::Planned(expr) => return Ok(expr),
                // The planner declined and handed the expression back for the
                // next planner to try.
                PlannerResult::Original(expr) => {
                    field_access_expr = expr;
                }
            }
        }
        Err(Error::invalid_input(
            "Field access could not be planned",
            location!(),
        ))
    }
596
    /// Convert a sqlparser AST expression into a DataFusion logical `Expr`.
    ///
    /// Covers the SQL subset Lance supports: identifiers (with Lance-specific
    /// quoting rules), binary/unary operators, literals and literal arrays,
    /// typed strings, IS [NOT] TRUE/FALSE/NULL, IN lists, LIKE/ILIKE, CAST,
    /// nested field/subscript access, and BETWEEN.
    fn parse_sql_expr(&self, expr: &SQLExpr) -> Result<Expr> {
        match expr {
            SQLExpr::Identifier(id) => {
                // Lance quoting convention (differs from standard SQL):
                // double-quoted identifiers are string literals, while
                // backticked identifiers are column names taken verbatim.
                if id.quote_style == Some('"') {
                    Ok(Expr::Literal(
                        ScalarValue::Utf8(Some(id.value.clone())),
                        None,
                    ))
                } else if id.quote_style == Some('`') {
                    Ok(Expr::Column(Column::from_name(id.value.clone())))
                } else {
                    Ok(self.column(vec![id.clone()].as_slice()))
                }
            }
            SQLExpr::CompoundIdentifier(ids) => Ok(self.column(ids.as_slice())),
            SQLExpr::BinaryOp { left, op, right } => self.binary_expr(left, op, right),
            SQLExpr::UnaryOp { op, expr } => self.unary_expr(op, expr),
            SQLExpr::Value(value) => self.value(&value.value),
            // Array literal, e.g. `[1, 2, 3]`: every element must be a
            // literal (or a negated numeric literal) of a coercible type.
            SQLExpr::Array(SQLArray { elem, .. }) => {
                let mut values = vec![];

                let array_literal_error = |pos: usize, value: &_| {
                    Err(Error::invalid_input(
                        format!(
                            "Expected a literal value in array, instead got {} at position {}",
                            value, pos
                        ),
                        location!(),
                    ))
                };

                for (pos, expr) in elem.iter().enumerate() {
                    match expr {
                        SQLExpr::Value(value) => {
                            if let Expr::Literal(value, _) = self.value(&value.value)? {
                                values.push(value);
                            } else {
                                return array_literal_error(pos, expr);
                            }
                        }
                        // Negative numeric literals arrive as unary minus
                        // applied to a number.
                        SQLExpr::UnaryOp {
                            op: UnaryOperator::Minus,
                            expr,
                        } => {
                            if let SQLExpr::Value(ValueWithSpan {
                                value: Value::Number(number, _),
                                ..
                            }) = expr.as_ref()
                            {
                                if let Expr::Literal(value, _) = self.number(number, true)? {
                                    values.push(value);
                                } else {
                                    return array_literal_error(pos, expr);
                                }
                            } else {
                                return array_literal_error(pos, expr);
                            }
                        }
                        _ => {
                            return array_literal_error(pos, expr);
                        }
                    }
                }

                // Coerce all elements to the type of the first element; an
                // empty array gets a Null item type.
                let field = if !values.is_empty() {
                    let data_type = values[0].data_type();

                    for value in &mut values {
                        if value.data_type() != data_type {
                            *value = safe_coerce_scalar(value, &data_type).ok_or_else(|| Error::invalid_input(
                                format!("Array expressions must have a consistent datatype. Expected: {}, got: {}", data_type, value.data_type()),
                                location!()
                            ))?;
                        }
                    }
                    Field::new("item", data_type, true)
                } else {
                    Field::new("item", ArrowDataType::Null, true)
                };

                let values = values
                    .into_iter()
                    .map(|v| v.to_array().map_err(Error::from))
                    .collect::<Result<Vec<_>>>()?;
                let array_refs = values.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
                let values = concat(&array_refs)?;
                // Wrap the concatenated scalars as a single-row ListArray so
                // the whole array becomes one list literal.
                let values = ListArray::try_new(
                    field.into(),
                    OffsetBuffer::from_lengths([values.len()]),
                    values,
                    None,
                )?;

                Ok(Expr::Literal(ScalarValue::List(Arc::new(values)), None))
            }
            // Typed string, e.g. `DATE '2020-01-01'`: cast the string
            // literal to the requested type.
            SQLExpr::TypedString { data_type, value } => {
                let value = value.clone().into_string().expect_ok()?;
                Ok(Expr::Cast(datafusion::logical_expr::Cast {
                    expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some(value)), None)),
                    data_type: self.parse_type(data_type)?,
                }))
            }
            SQLExpr::IsFalse(expr) => Ok(Expr::IsFalse(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotFalse(expr) => Ok(Expr::IsNotFalse(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsTrue(expr) => Ok(Expr::IsTrue(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotTrue(expr) => Ok(Expr::IsNotTrue(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNull(expr) => Ok(Expr::IsNull(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotNull(expr) => Ok(Expr::IsNotNull(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::InList {
                expr,
                list,
                negated,
            } => {
                let value_expr = self.parse_sql_expr(expr)?;
                let list_exprs = list
                    .iter()
                    .map(|e| self.parse_sql_expr(e))
                    .collect::<Result<Vec<_>>>()?;
                Ok(value_expr.in_list(list_exprs, *negated))
            }
            // Parenthesized expression: parse the inner expression directly.
            SQLExpr::Nested(inner) => self.parse_sql_expr(inner.as_ref()),
            SQLExpr::Function(_) => self.parse_function(expr.clone()),
            SQLExpr::ILike {
                negated,
                expr,
                pattern,
                escape_char,
                any: _,
            } => Ok(Expr::Like(Like::new(
                *negated,
                Box::new(self.parse_sql_expr(expr)?),
                Box::new(self.parse_sql_expr(pattern)?),
                match escape_char {
                    Some(Value::SingleQuotedString(char)) => char.chars().next(),
                    Some(value) => return Err(Error::invalid_input(
                        format!("Invalid escape character in LIKE expression. Expected a single character wrapped with single quotes, got {}", value),
                        location!()
                    )),
                    None => None,
                },
                // true => case-insensitive (ILIKE)
                true,
            ))),
            SQLExpr::Like {
                negated,
                expr,
                pattern,
                escape_char,
                any: _,
            } => Ok(Expr::Like(Like::new(
                *negated,
                Box::new(self.parse_sql_expr(expr)?),
                Box::new(self.parse_sql_expr(pattern)?),
                match escape_char {
                    Some(Value::SingleQuotedString(char)) => char.chars().next(),
                    Some(value) => return Err(Error::invalid_input(
                        format!("Invalid escape character in LIKE expression. Expected a single character wrapped with single quotes, got {}", value),
                        location!()
                    )),
                    None => None,
                },
                // false => case-sensitive (LIKE)
                false,
            ))),
            SQLExpr::Cast {
                expr,
                data_type,
                kind,
                ..
            } => match kind {
                // TRY_CAST / SAFE_CAST yield NULL on failure rather than an
                // error.
                datafusion::sql::sqlparser::ast::CastKind::TryCast
                | datafusion::sql::sqlparser::ast::CastKind::SafeCast => {
                    Ok(Expr::TryCast(datafusion::logical_expr::TryCast {
                        expr: Box::new(self.parse_sql_expr(expr)?),
                        data_type: self.parse_type(data_type)?,
                    }))
                }
                _ => Ok(Expr::Cast(datafusion::logical_expr::Cast {
                    expr: Box::new(self.parse_sql_expr(expr)?),
                    data_type: self.parse_type(data_type)?,
                })),
            },
            SQLExpr::JsonAccess { .. } => Err(Error::invalid_input(
                "JSON access is not supported",
                location!(),
            )),
            // Chained accesses like `st.f1['x'][0]`: fold each access onto the
            // root expression via the expression planners.
            SQLExpr::CompoundFieldAccess { root, access_chain } => {
                let mut expr = self.parse_sql_expr(root)?;

                for access in access_chain {
                    let field_access = match access {
                        // `.name`, `['name']` and `["name"]` are all struct
                        // field accesses.
                        AccessExpr::Dot(SQLExpr::Identifier(Ident { value: s, .. }))
                        | AccessExpr::Subscript(Subscript::Index {
                            index:
                                SQLExpr::Value(ValueWithSpan {
                                    value:
                                        Value::SingleQuotedString(s) | Value::DoubleQuotedString(s),
                                    ..
                                }),
                        }) => GetFieldAccess::NamedStructField {
                            name: ScalarValue::from(s.as_str()),
                        },
                        // Any other subscript index is a list element access.
                        AccessExpr::Subscript(Subscript::Index { index }) => {
                            let key = Box::new(self.parse_sql_expr(index)?);
                            GetFieldAccess::ListIndex { key }
                        }
                        AccessExpr::Subscript(Subscript::Slice { .. }) => {
                            return Err(Error::invalid_input(
                                "Slice subscript is not supported",
                                location!(),
                            ));
                        }
                        _ => {
                            return Err(Error::invalid_input(
                                "Only dot notation or index access is supported for field access",
                                location!(),
                            ));
                        }
                    };

                    let field_access_expr = RawFieldAccessExpr { expr, field_access };
                    expr = self.plan_field_access(field_access_expr)?;
                }

                Ok(expr)
            }
            SQLExpr::Between {
                expr,
                negated,
                low,
                high,
            } => {
                let expr = self.parse_sql_expr(expr)?;
                let low = self.parse_sql_expr(low)?;
                let high = self.parse_sql_expr(high)?;

                let between = Expr::Between(Between::new(
                    Box::new(expr),
                    *negated,
                    Box::new(low),
                    Box::new(high),
                ));
                Ok(between)
            }
            _ => Err(Error::invalid_input(
                format!("Expression '{expr}' is not supported SQL in lance"),
                location!(),
            )),
        }
    }
854
    /// Parse a SQL filter string (e.g. `"x > 3 AND s = 'a'"`) into a boolean
    /// logical expression resolved against the planner's schema.
    pub fn parse_filter(&self, filter: &str) -> Result<Expr> {
        let ast_expr = parse_sql_filter(filter)?;
        let expr = self.parse_sql_expr(&ast_expr)?;
        let schema = Schema::try_from(self.schema.as_ref())?;
        let resolved = resolve_expr(&expr, &schema).map_err(|e| {
            Error::invalid_input(
                format!("Error resolving filter expression {filter}: {e}"),
                location!(),
            )
        })?;

        // Filters must evaluate to boolean; coerce the result if needed.
        Ok(coerce_filter_type_to_boolean(resolved))
    }
873
    /// Parse a SQL expression string into a logical expression.
    ///
    /// A bare column name (including a case-insensitive match) short-circuits
    /// to a column reference without invoking the SQL parser.
    pub fn parse_expr(&self, expr: &str) -> Result<Expr> {
        let resolved_name = self.resolve_column_name(expr);
        if self.schema.field_with_name(&resolved_name).is_ok() {
            return Ok(Expr::Column(Column::from_name(resolved_name)));
        }

        let ast_expr = parse_sql_expr(expr)?;
        let expr = self.parse_sql_expr(&ast_expr)?;
        let schema = Schema::try_from(self.schema.as_ref())?;
        let resolved = resolve_expr(&expr, &schema)?;
        Ok(resolved)
    }
893
    /// Decode a hex string literal (e.g. the body of `X'FFA0'`) into bytes,
    /// returning `None` if any character is not a hex digit.
    ///
    /// An odd-length string is handled by treating the first character as a
    /// lone low nibble, e.g. `"FFF"` decodes to `[0x0F, 0xFF]`.
    fn try_decode_hex_literal(s: &str) -> Option<Vec<u8>> {
        let hex_bytes = s.as_bytes();
        let mut decoded_bytes = Vec::with_capacity(hex_bytes.len().div_ceil(2));

        let start_idx = hex_bytes.len() % 2;
        if start_idx > 0 {
            // Odd number of digits: the leading digit stands alone.
            decoded_bytes.push(Self::try_decode_hex_char(hex_bytes[0])?);
        }

        // Remaining digits are consumed two at a time: high nibble, then low.
        for i in (start_idx..hex_bytes.len()).step_by(2) {
            let high = Self::try_decode_hex_char(hex_bytes[i])?;
            let low = Self::try_decode_hex_char(hex_bytes[i + 1])?;
            decoded_bytes.push((high << 4) | low);
        }

        Some(decoded_bytes)
    }
917
918 const fn try_decode_hex_char(c: u8) -> Option<u8> {
922 match c {
923 b'A'..=b'F' => Some(c - b'A' + 10),
924 b'a'..=b'f' => Some(c - b'a' + 10),
925 b'0'..=b'9' => Some(c - b'0'),
926 _ => None,
927 }
928 }
929
    /// Simplify and type-coerce a logical expression against the planner's
    /// schema using DataFusion's expression simplifier.
    pub fn optimize_expr(&self, expr: Expr) -> Result<Expr> {
        let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);

        // Pin the query start time so now()-style volatile functions fold to
        // a constant during simplification.
        let props = ExecutionProps::new().with_query_execution_start_time(Utc::now());
        let simplify_context = SimplifyContext::new(&props).with_schema(df_schema.clone());
        let simplifier =
            datafusion::optimizer::simplify_expressions::ExprSimplifier::new(simplify_context);

        let expr = simplifier.simplify(expr)?;
        let expr = simplifier.coerce(expr, &df_schema)?;

        Ok(expr)
    }
946
947 pub fn create_physical_expr(&self, expr: &Expr) -> Result<Arc<dyn PhysicalExpr>> {
949 let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
950 Ok(datafusion::physical_expr::create_physical_expr(
951 expr,
952 df_schema.as_ref(),
953 &Default::default(),
954 )?)
955 }
956
    /// Collect the (possibly nested, dot-joined) column paths referenced by
    /// `expr`, returned in sorted order.
    pub fn column_names_in_expr(expr: &Expr) -> Vec<String> {
        let mut visitor = ColumnCapturingVisitor {
            current_path: VecDeque::new(),
            columns: BTreeSet::new(),
        };
        // The visitor's f_down never returns an error, so this cannot panic.
        expr.visit(&mut visitor).unwrap();
        visitor.columns.into_iter().collect()
    }
971}
972
/// Tree visitor that collects referenced column paths, including nested
/// `get_field` struct accesses, from an expression.
struct ColumnCapturingVisitor {
    // Field-name parts of the get_field chain currently being walked,
    // ordered outermost-last (parts are pushed to the front on the way down).
    current_path: VecDeque<String>,
    // Completed, dot-joined column paths (BTreeSet keeps them sorted).
    columns: BTreeSet<String>,
}
978
impl TreeNodeVisitor<'_> for ColumnCapturingVisitor {
    type Node = Expr;

    fn f_down(&mut self, node: &Self::Node) -> DFResult<TreeNodeRecursion> {
        match node {
            Expr::Column(Column { name, .. }) => {
                // Reached the root column of a get_field chain: join the
                // accumulated field parts onto the column name.
                let mut path = name.clone();
                for part in self.current_path.drain(..) {
                    path.push('.');
                    // Parts containing `.` or a backtick are backtick-quoted
                    // (with embedded backticks doubled) so the joined path
                    // can be unambiguously parsed back into parts.
                    if part.contains('.') || part.contains('`') {
                        let escaped = part.replace('`', "``");
                        path.push('`');
                        path.push_str(&escaped);
                        path.push('`');
                    } else {
                        path.push_str(&part);
                    }
                }
                self.columns.insert(path);
                self.current_path.clear();
            }
            Expr::ScalarFunction(udf) => {
                // A `get_field(expr, "name")` call contributes one nested
                // field part; any other function breaks the chain.
                if udf.name() == GetFieldFunc::default().name() {
                    if let Some(name) = get_as_string_scalar_opt(&udf.args[1]) {
                        self.current_path.push_front(name.to_string())
                    } else {
                        self.current_path.clear();
                    }
                } else {
                    self.current_path.clear();
                }
            }
            _ => {
                // Any other node interrupts a pending get_field chain.
                self.current_path.clear();
            }
        }

        Ok(TreeNodeRecursion::Continue)
    }
}
1024
1025#[cfg(test)]
1026mod tests {
1027
1028 use crate::logical_expr::ExprExt;
1029
1030 use super::*;
1031
1032 use arrow::datatypes::Float64Type;
1033 use arrow_array::{
1034 ArrayRef, BooleanArray, Float32Array, Int32Array, Int64Array, RecordBatch, StringArray,
1035 StructArray, TimestampMicrosecondArray, TimestampMillisecondArray,
1036 TimestampNanosecondArray, TimestampSecondArray,
1037 };
1038 use arrow_schema::{DataType, Fields, Schema};
1039 use datafusion::{
1040 logical_expr::{col, lit, Cast},
1041 prelude::{array_element, get_field},
1042 };
1043 use datafusion_functions::core::expr_ext::FieldAccessor;
1044
1045 #[test]
1046 fn test_parse_filter_simple() {
1047 let schema = Arc::new(Schema::new(vec![
1048 Field::new("i", DataType::Int32, false),
1049 Field::new("s", DataType::Utf8, true),
1050 Field::new(
1051 "st",
1052 DataType::Struct(Fields::from(vec![
1053 Field::new("x", DataType::Float32, false),
1054 Field::new("y", DataType::Float32, false),
1055 ])),
1056 true,
1057 ),
1058 ]));
1059
1060 let planner = Planner::new(schema.clone());
1061
1062 let expected = col("i")
1063 .gt(lit(3_i32))
1064 .and(col("st").field_newstyle("x").lt_eq(lit(5.0_f32)))
1065 .and(
1066 col("s")
1067 .eq(lit("str-4"))
1068 .or(col("s").in_list(vec![lit("str-4"), lit("str-5")], false)),
1069 );
1070
1071 let expr = planner
1073 .parse_filter("i > 3 AND st.x <= 5.0 AND (s == 'str-4' OR s in ('str-4', 'str-5'))")
1074 .unwrap();
1075 assert_eq!(expr, expected);
1076
1077 let expr = planner
1079 .parse_filter("i > 3 AND st.x <= 5.0 AND (s = 'str-4' OR s in ('str-4', 'str-5'))")
1080 .unwrap();
1081
1082 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1083
1084 let batch = RecordBatch::try_new(
1085 schema,
1086 vec![
1087 Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
1088 Arc::new(StringArray::from_iter_values(
1089 (0..10).map(|v| format!("str-{}", v)),
1090 )),
1091 Arc::new(StructArray::from(vec![
1092 (
1093 Arc::new(Field::new("x", DataType::Float32, false)),
1094 Arc::new(Float32Array::from_iter_values((0..10).map(|v| v as f32)))
1095 as ArrayRef,
1096 ),
1097 (
1098 Arc::new(Field::new("y", DataType::Float32, false)),
1099 Arc::new(Float32Array::from_iter_values(
1100 (0..10).map(|v| (v * 10) as f32),
1101 )),
1102 ),
1103 ])),
1104 ],
1105 )
1106 .unwrap();
1107 let predicates = physical_expr.evaluate(&batch).unwrap();
1108 assert_eq!(
1109 predicates.into_array(0).unwrap().as_ref(),
1110 &BooleanArray::from(vec![
1111 false, false, false, false, true, true, false, false, false, false
1112 ])
1113 );
1114 }
1115
    #[test]
    fn test_nested_col_refs() {
        // Nested struct field access: every quoting / subscript spelling of
        // the same path must plan to the identical expression tree.
        let schema = Arc::new(Schema::new(vec![
            Field::new("s0", DataType::Utf8, true),
            Field::new(
                "st",
                DataType::Struct(Fields::from(vec![
                    Field::new("s1", DataType::Utf8, true),
                    Field::new(
                        "st",
                        DataType::Struct(Fields::from(vec![Field::new(
                            "s2",
                            DataType::Utf8,
                            true,
                        )])),
                        true,
                    ),
                ])),
                true,
            ),
        ]));

        let planner = Planner::new(schema);

        // Parses `<expr> = 'val'` as a filter and asserts the left-hand side
        // of the resulting equality equals `expected`.
        fn assert_column_eq(planner: &Planner, expr: &str, expected: &Expr) {
            let expr = planner.parse_filter(&format!("{expr} = 'val'")).unwrap();
            assert!(matches!(
                expr,
                Expr::BinaryExpr(BinaryExpr {
                    left: _,
                    op: Operator::Eq,
                    right: _
                })
            ));
            if let Expr::BinaryExpr(BinaryExpr { left, .. }) = expr {
                assert_eq!(left.as_ref(), expected);
            }
        }

        // A top-level column stays a plain column reference, quoted or not.
        let expected = Expr::Column(Column::new_unqualified("s0"));
        assert_column_eq(&planner, "s0", &expected);
        assert_column_eq(&planner, "`s0`", &expected);

        // One level deep: st.s1 plans as get_field(st, "s1").
        let expected = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
            args: vec![
                Expr::Column(Column::new_unqualified("st")),
                Expr::Literal(ScalarValue::Utf8(Some("s1".to_string())), None),
            ],
        });
        assert_column_eq(&planner, "st.s1", &expected);
        assert_column_eq(&planner, "`st`.`s1`", &expected);
        assert_column_eq(&planner, "st.`s1`", &expected);

        // Two levels deep: get_field(get_field(st, "st"), "s2").
        let expected = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
            args: vec![
                Expr::ScalarFunction(ScalarFunction {
                    func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
                    args: vec![
                        Expr::Column(Column::new_unqualified("st")),
                        Expr::Literal(ScalarValue::Utf8(Some("st".to_string())), None),
                    ],
                }),
                Expr::Literal(ScalarValue::Utf8(Some("s2".to_string())), None),
            ],
        });

        // Dot paths, backquoted paths, and bracket subscripts (single or
        // double quoted keys) must all be equivalent.
        assert_column_eq(&planner, "st.st.s2", &expected);
        assert_column_eq(&planner, "`st`.`st`.`s2`", &expected);
        assert_column_eq(&planner, "st.st.`s2`", &expected);
        assert_column_eq(&planner, "st['st'][\"s2\"]", &expected);
    }
1189
1190 #[test]
1191 fn test_nested_list_refs() {
1192 let schema = Arc::new(Schema::new(vec![Field::new(
1193 "l",
1194 DataType::List(Arc::new(Field::new(
1195 "item",
1196 DataType::Struct(Fields::from(vec![Field::new("f1", DataType::Utf8, true)])),
1197 true,
1198 ))),
1199 true,
1200 )]));
1201
1202 let planner = Planner::new(schema);
1203
1204 let expected = array_element(col("l"), lit(0_i64));
1205 let expr = planner.parse_expr("l[0]").unwrap();
1206 assert_eq!(expr, expected);
1207
1208 let expected = get_field(array_element(col("l"), lit(0_i64)), "f1");
1209 let expr = planner.parse_expr("l[0]['f1']").unwrap();
1210 assert_eq!(expr, expected);
1211
1212 }
1217
1218 #[test]
1219 fn test_negative_expressions() {
1220 let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
1221
1222 let planner = Planner::new(schema.clone());
1223
1224 let expected = col("x")
1225 .gt(lit(-3_i64))
1226 .and(col("x").lt(-(lit(-5_i64) + lit(3_i64))));
1227
1228 let expr = planner.parse_filter("x > -3 AND x < -(-5 + 3)").unwrap();
1229
1230 assert_eq!(expr, expected);
1231
1232 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1233
1234 let batch = RecordBatch::try_new(
1235 schema,
1236 vec![Arc::new(Int64Array::from_iter_values(-5..5)) as ArrayRef],
1237 )
1238 .unwrap();
1239 let predicates = physical_expr.evaluate(&batch).unwrap();
1240 assert_eq!(
1241 predicates.into_array(0).unwrap().as_ref(),
1242 &BooleanArray::from(vec![
1243 false, false, false, true, true, true, true, false, false, false
1244 ])
1245 );
1246 }
1247
    #[test]
    fn test_negative_array_expressions() {
        // An array literal of negative floats must parse to a single
        // List scalar whose elements carry the negation.
        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));

        let planner = Planner::new(schema);

        // One list with five non-null Float64 elements.
        let expected = Expr::Literal(
            ScalarValue::List(Arc::new(
                ListArray::from_iter_primitive::<Float64Type, _, _>(vec![Some(
                    [-1_f64, -2.0, -3.0, -4.0, -5.0].map(Some),
                )]),
            )),
            None,
        );

        let expr = planner
            .parse_expr("[-1.0, -2.0, -3.0, -4.0, -5.0]")
            .unwrap();

        assert_eq!(expr, expected);
    }
1269
1270 #[test]
1271 fn test_sql_like() {
1272 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1273
1274 let planner = Planner::new(schema.clone());
1275
1276 let expected = col("s").like(lit("str-4"));
1277 let expr = planner.parse_filter("s LIKE 'str-4'").unwrap();
1279 assert_eq!(expr, expected);
1280 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1281
1282 let batch = RecordBatch::try_new(
1283 schema,
1284 vec![Arc::new(StringArray::from_iter_values(
1285 (0..10).map(|v| format!("str-{}", v)),
1286 ))],
1287 )
1288 .unwrap();
1289 let predicates = physical_expr.evaluate(&batch).unwrap();
1290 assert_eq!(
1291 predicates.into_array(0).unwrap().as_ref(),
1292 &BooleanArray::from(vec![
1293 false, false, false, false, true, false, false, false, false, false
1294 ])
1295 );
1296 }
1297
1298 #[test]
1299 fn test_not_like() {
1300 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1301
1302 let planner = Planner::new(schema.clone());
1303
1304 let expected = col("s").not_like(lit("str-4"));
1305 let expr = planner.parse_filter("s NOT LIKE 'str-4'").unwrap();
1307 assert_eq!(expr, expected);
1308 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1309
1310 let batch = RecordBatch::try_new(
1311 schema,
1312 vec![Arc::new(StringArray::from_iter_values(
1313 (0..10).map(|v| format!("str-{}", v)),
1314 ))],
1315 )
1316 .unwrap();
1317 let predicates = physical_expr.evaluate(&batch).unwrap();
1318 assert_eq!(
1319 predicates.into_array(0).unwrap().as_ref(),
1320 &BooleanArray::from(vec![
1321 true, true, true, true, false, true, true, true, true, true
1322 ])
1323 );
1324 }
1325
1326 #[test]
1327 fn test_sql_is_in() {
1328 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1329
1330 let planner = Planner::new(schema.clone());
1331
1332 let expected = col("s").in_list(vec![lit("str-4"), lit("str-5")], false);
1333 let expr = planner.parse_filter("s IN ('str-4', 'str-5')").unwrap();
1335 assert_eq!(expr, expected);
1336 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1337
1338 let batch = RecordBatch::try_new(
1339 schema,
1340 vec![Arc::new(StringArray::from_iter_values(
1341 (0..10).map(|v| format!("str-{}", v)),
1342 ))],
1343 )
1344 .unwrap();
1345 let predicates = physical_expr.evaluate(&batch).unwrap();
1346 assert_eq!(
1347 predicates.into_array(0).unwrap().as_ref(),
1348 &BooleanArray::from(vec![
1349 false, false, false, false, true, true, false, false, false, false
1350 ])
1351 );
1352 }
1353
1354 #[test]
1355 fn test_sql_is_null() {
1356 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1357
1358 let planner = Planner::new(schema.clone());
1359
1360 let expected = col("s").is_null();
1361 let expr = planner.parse_filter("s IS NULL").unwrap();
1362 assert_eq!(expr, expected);
1363 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1364
1365 let batch = RecordBatch::try_new(
1366 schema,
1367 vec![Arc::new(StringArray::from_iter((0..10).map(|v| {
1368 if v % 3 == 0 {
1369 Some(format!("str-{}", v))
1370 } else {
1371 None
1372 }
1373 })))],
1374 )
1375 .unwrap();
1376 let predicates = physical_expr.evaluate(&batch).unwrap();
1377 assert_eq!(
1378 predicates.into_array(0).unwrap().as_ref(),
1379 &BooleanArray::from(vec![
1380 false, true, true, false, true, true, false, true, true, false
1381 ])
1382 );
1383
1384 let expr = planner.parse_filter("s IS NOT NULL").unwrap();
1385 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1386 let predicates = physical_expr.evaluate(&batch).unwrap();
1387 assert_eq!(
1388 predicates.into_array(0).unwrap().as_ref(),
1389 &BooleanArray::from(vec![
1390 true, false, false, true, false, false, true, false, false, true,
1391 ])
1392 );
1393 }
1394
1395 #[test]
1396 fn test_sql_invert() {
1397 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Boolean, true)]));
1398
1399 let planner = Planner::new(schema.clone());
1400
1401 let expr = planner.parse_filter("NOT s").unwrap();
1402 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1403
1404 let batch = RecordBatch::try_new(
1405 schema,
1406 vec![Arc::new(BooleanArray::from_iter(
1407 (0..10).map(|v| Some(v % 3 == 0)),
1408 ))],
1409 )
1410 .unwrap();
1411 let predicates = physical_expr.evaluate(&batch).unwrap();
1412 assert_eq!(
1413 predicates.into_array(0).unwrap().as_ref(),
1414 &BooleanArray::from(vec![
1415 false, true, true, false, true, true, false, true, true, false
1416 ])
1417 );
1418 }
1419
    #[test]
    fn test_sql_cast() {
        // Each SQL CAST target type must map to the expected Arrow type,
        // and the cast must wrap the literal operand directly.
        let cases = &[
            // Bare `timestamp` defaults to microsecond precision; the
            // parenthesized precision selects the time unit.
            (
                "x = cast('2021-01-01 00:00:00' as timestamp)",
                ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00' as timestamp(0))",
                ArrowDataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00.123' as timestamp(9))",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            // `datetime` is accepted as a synonym for timestamp.
            (
                "x = cast('2021-01-01 00:00:00.123' as datetime(9))",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            ("x = cast('2021-01-01' as date)", ArrowDataType::Date32),
            (
                "x = cast('1.238' as decimal(9,3))",
                ArrowDataType::Decimal128(9, 3),
            ),
            ("x = cast(1 as float)", ArrowDataType::Float32),
            ("x = cast(1 as double)", ArrowDataType::Float64),
            ("x = cast(1 as tinyint)", ArrowDataType::Int8),
            ("x = cast(1 as smallint)", ArrowDataType::Int16),
            ("x = cast(1 as int)", ArrowDataType::Int32),
            ("x = cast(1 as integer)", ArrowDataType::Int32),
            ("x = cast(1 as bigint)", ArrowDataType::Int64),
            ("x = cast(1 as tinyint unsigned)", ArrowDataType::UInt8),
            ("x = cast(1 as smallint unsigned)", ArrowDataType::UInt16),
            ("x = cast(1 as int unsigned)", ArrowDataType::UInt32),
            ("x = cast(1 as integer unsigned)", ArrowDataType::UInt32),
            ("x = cast(1 as bigint unsigned)", ArrowDataType::UInt64),
            ("x = cast(1 as boolean)", ArrowDataType::Boolean),
            ("x = cast(1 as string)", ArrowDataType::Utf8),
        ];

        for (sql, expected_data_type) in cases {
            // Column `x` gets the expected type so the filter plans cleanly.
            let schema = Arc::new(Schema::new(vec![Field::new(
                "x",
                expected_data_type.clone(),
                true,
            )]));
            let planner = Planner::new(schema.clone());
            let expr = planner.parse_filter(sql).unwrap();

            // Recover the literal text between "cast(" and " as" from the
            // SQL itself, stripping any surrounding single quotes.
            let expected_value_str = sql
                .split("cast(")
                .nth(1)
                .unwrap()
                .split(" as")
                .next()
                .unwrap();
            let expected_value_str = expected_value_str.trim_matches('\'');

            // Expect `x = Cast(literal)`: the right side of the equality is
            // a Cast whose inner expression is the untouched literal.
            match expr {
                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
                    Expr::Cast(Cast { expr, data_type }) => {
                        match expr.as_ref() {
                            Expr::Literal(ScalarValue::Utf8(Some(value_str)), _) => {
                                assert_eq!(value_str, expected_value_str);
                            }
                            // Numeric cases all use the literal `1`.
                            Expr::Literal(ScalarValue::Int64(Some(value)), _) => {
                                assert_eq!(*value, 1);
                            }
                            _ => panic!("Expected cast to be applied to literal"),
                        }
                        assert_eq!(data_type, expected_data_type);
                    }
                    _ => panic!("Expected right to be a cast"),
                },
                _ => panic!("Expected binary expression"),
            }
        }
    }
1500
    #[test]
    fn test_sql_literals() {
        // Typed literals (`timestamp '...'`, `date '...'`, `decimal(p,s)
        // '...'`) must plan as a cast of the quoted string to the expected
        // Arrow type.
        let cases = &[
            (
                "x = timestamp '2021-01-01 00:00:00'",
                ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
            ),
            (
                "x = timestamp(0) '2021-01-01 00:00:00'",
                ArrowDataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                "x = timestamp(9) '2021-01-01 00:00:00.123'",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            ("x = date '2021-01-01'", ArrowDataType::Date32),
            ("x = decimal(9,3) '1.238'", ArrowDataType::Decimal128(9, 3)),
        ];

        for (sql, expected_data_type) in cases {
            // Column `x` gets the expected type so the filter plans cleanly.
            let schema = Arc::new(Schema::new(vec![Field::new(
                "x",
                expected_data_type.clone(),
                true,
            )]));
            let planner = Planner::new(schema.clone());
            let expr = planner.parse_filter(sql).unwrap();

            // The quoted payload is the first single-quoted segment of the SQL.
            let expected_value_str = sql.split('\'').nth(1).unwrap();

            // Expect `x = Cast(Utf8 literal)` with the target Arrow type.
            match expr {
                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
                    Expr::Cast(Cast { expr, data_type }) => {
                        match expr.as_ref() {
                            Expr::Literal(ScalarValue::Utf8(Some(value_str)), _) => {
                                assert_eq!(value_str, expected_value_str);
                            }
                            _ => panic!("Expected cast to be applied to literal"),
                        }
                        assert_eq!(data_type, expected_data_type);
                    }
                    _ => panic!("Expected right to be a cast"),
                },
                _ => panic!("Expected binary expression"),
            }
        }
    }
1548
1549 #[test]
1550 fn test_sql_array_literals() {
1551 let cases = [
1552 (
1553 "x = [1, 2, 3]",
1554 ArrowDataType::List(Arc::new(Field::new("item", ArrowDataType::Int64, true))),
1555 ),
1556 (
1557 "x = [1, 2, 3]",
1558 ArrowDataType::FixedSizeList(
1559 Arc::new(Field::new("item", ArrowDataType::Int64, true)),
1560 3,
1561 ),
1562 ),
1563 ];
1564
1565 for (sql, expected_data_type) in cases {
1566 let schema = Arc::new(Schema::new(vec![Field::new(
1567 "x",
1568 expected_data_type.clone(),
1569 true,
1570 )]));
1571 let planner = Planner::new(schema.clone());
1572 let expr = planner.parse_filter(sql).unwrap();
1573 let expr = planner.optimize_expr(expr).unwrap();
1574
1575 match expr {
1576 Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
1577 Expr::Literal(value, _) => {
1578 assert_eq!(&value.data_type(), &expected_data_type);
1579 }
1580 _ => panic!("Expected right to be a literal"),
1581 },
1582 _ => panic!("Expected binary expression"),
1583 }
1584 }
1585 }
1586
1587 #[test]
1588 fn test_sql_between() {
1589 use arrow_array::{Float64Array, Int32Array, TimestampMicrosecondArray};
1590 use arrow_schema::{DataType, Field, Schema, TimeUnit};
1591 use std::sync::Arc;
1592
1593 let schema = Arc::new(Schema::new(vec![
1594 Field::new("x", DataType::Int32, false),
1595 Field::new("y", DataType::Float64, false),
1596 Field::new(
1597 "ts",
1598 DataType::Timestamp(TimeUnit::Microsecond, None),
1599 false,
1600 ),
1601 ]));
1602
1603 let planner = Planner::new(schema.clone());
1604
1605 let expr = planner
1607 .parse_filter("x BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
1608 .unwrap();
1609 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1610
1611 let base_ts = 1704067200000000_i64; let ts_array = TimestampMicrosecondArray::from_iter_values(
1615 (0..10).map(|i| base_ts + i * 1_000_000), );
1617
1618 let batch = RecordBatch::try_new(
1619 schema,
1620 vec![
1621 Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
1622 Arc::new(Float64Array::from_iter_values((0..10).map(|v| v as f64))),
1623 Arc::new(ts_array),
1624 ],
1625 )
1626 .unwrap();
1627
1628 let predicates = physical_expr.evaluate(&batch).unwrap();
1629 assert_eq!(
1630 predicates.into_array(0).unwrap().as_ref(),
1631 &BooleanArray::from(vec![
1632 false, false, false, true, true, true, true, true, false, false
1633 ])
1634 );
1635
1636 let expr = planner
1638 .parse_filter("x NOT BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
1639 .unwrap();
1640 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1641
1642 let predicates = physical_expr.evaluate(&batch).unwrap();
1643 assert_eq!(
1644 predicates.into_array(0).unwrap().as_ref(),
1645 &BooleanArray::from(vec![
1646 true, true, true, false, false, false, false, false, true, true
1647 ])
1648 );
1649
1650 let expr = planner.parse_filter("y BETWEEN 2.5 AND 6.5").unwrap();
1652 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1653
1654 let predicates = physical_expr.evaluate(&batch).unwrap();
1655 assert_eq!(
1656 predicates.into_array(0).unwrap().as_ref(),
1657 &BooleanArray::from(vec![
1658 false, false, false, true, true, true, true, false, false, false
1659 ])
1660 );
1661
1662 let expr = planner
1664 .parse_filter(
1665 "ts BETWEEN timestamp '2024-01-01 00:00:03' AND timestamp '2024-01-01 00:00:07'",
1666 )
1667 .unwrap();
1668 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1669
1670 let predicates = physical_expr.evaluate(&batch).unwrap();
1671 assert_eq!(
1672 predicates.into_array(0).unwrap().as_ref(),
1673 &BooleanArray::from(vec![
1674 false, false, false, true, true, true, true, true, false, false
1675 ])
1676 );
1677 }
1678
    #[test]
    fn test_sql_comparison() {
        // Timestamp comparisons across all four time units; the TIMESTAMP
        // literal must be coerced to each column's unit before comparing.
        let batch: Vec<(&str, ArrayRef)> = vec![
            (
                "timestamp_s",
                Arc::new(TimestampSecondArray::from_iter_values(0..10)),
            ),
            (
                "timestamp_ms",
                Arc::new(TimestampMillisecondArray::from_iter_values(0..10)),
            ),
            (
                "timestamp_us",
                Arc::new(TimestampMicrosecondArray::from_iter_values(0..10)),
            ),
            // The nanosecond column straddles the 5000ns boundary used below,
            // so the same false/true split applies to it.
            (
                "timestamp_ns",
                Arc::new(TimestampNanosecondArray::from_iter_values(4995..5005)),
            ),
        ];
        let batch = RecordBatch::try_from_iter(batch).unwrap();

        let planner = Planner::new(batch.schema());

        // One `>=` filter per unit; each should keep the last five rows.
        let expressions = &[
            "timestamp_s >= TIMESTAMP '1970-01-01 00:00:05'",
            "timestamp_ms >= TIMESTAMP '1970-01-01 00:00:00.005'",
            "timestamp_us >= TIMESTAMP '1970-01-01 00:00:00.000005'",
            "timestamp_ns >= TIMESTAMP '1970-01-01 00:00:00.000005'",
        ];

        // Shared expectation: first five rows false, last five true.
        let expected: ArrayRef = Arc::new(BooleanArray::from_iter(
            std::iter::repeat_n(Some(false), 5).chain(std::iter::repeat_n(Some(true), 5)),
        ));
        for expression in expressions {
            let logical_expr = planner.parse_filter(expression).unwrap();
            // optimize_expr runs before physical planning here.
            let logical_expr = planner.optimize_expr(logical_expr).unwrap();
            let physical_expr = planner.create_physical_expr(&logical_expr).unwrap();

            let result = physical_expr.evaluate(&batch).unwrap();
            let result = result.into_array(batch.num_rows()).unwrap();
            assert_eq!(&expected, &result, "unexpected result for {}", expression);
        }
    }
1727
1728 #[test]
1729 fn test_columns_in_expr() {
1730 let expr = col("s0").gt(lit("value")).and(
1731 col("st")
1732 .field("st")
1733 .field("s2")
1734 .eq(lit("value"))
1735 .or(col("st")
1736 .field("s1")
1737 .in_list(vec![lit("value 1"), lit("value 2")], false)),
1738 );
1739
1740 let columns = Planner::column_names_in_expr(&expr);
1741 assert_eq!(columns, vec!["s0", "st.s1", "st.st.s2"]);
1742 }
1743
1744 #[test]
1745 fn test_parse_binary_expr() {
1746 let bin_str = "x'616263'";
1747
1748 let schema = Arc::new(Schema::new(vec![Field::new(
1749 "binary",
1750 DataType::Binary,
1751 true,
1752 )]));
1753 let planner = Planner::new(schema);
1754 let expr = planner.parse_expr(bin_str).unwrap();
1755 assert_eq!(
1756 expr,
1757 Expr::Literal(ScalarValue::Binary(Some(vec![b'a', b'b', b'c'])), None)
1758 );
1759 }
1760
1761 #[test]
1762 fn test_lance_context_provider_expr_planners() {
1763 let ctx_provider = LanceContextProvider::default();
1764 assert!(!ctx_provider.get_expr_planners().is_empty());
1765 }
1766
    #[test]
    fn test_regexp_match_and_non_empty_captions() {
        // A regexp_match call combined with NULL / empty-string guards on
        // two caption columns; a row passes only if all conditions hold.
        let schema = Arc::new(Schema::new(vec![
            Field::new("keywords", DataType::Utf8, true),
            Field::new("natural_caption", DataType::Utf8, true),
            Field::new("poetic_caption", DataType::Utf8, true),
        ]));

        let planner = Planner::new(schema.clone());

        let expr = planner
            .parse_filter(
                "regexp_match(keywords, 'Liberty|revolution') AND \
                (natural_caption IS NOT NULL AND natural_caption <> '' AND \
                poetic_caption IS NOT NULL AND poetic_caption <> '')",
            )
            .unwrap();

        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        // Rows 0, 2, 3, 4 match the regex, but only row 0 also has both
        // captions non-null and non-empty.
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(vec![
                    Some("Liberty for all"),
                    Some("peace"),
                    Some("revolution now"),
                    Some("Liberty"),
                    Some("revolutionary"),
                    Some("none"),
                ])) as ArrayRef,
                Arc::new(StringArray::from(vec![
                    Some("a"),
                    Some("b"),
                    None,
                    Some(""),
                    Some("c"),
                    Some("d"),
                ])) as ArrayRef,
                Arc::new(StringArray::from(vec![
                    Some("x"),
                    Some(""),
                    Some("y"),
                    Some("z"),
                    None,
                    Some("w"),
                ])) as ArrayRef,
            ],
        )
        .unwrap();

        let result = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            result.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![true, false, false, false, false, false])
        );
    }
1826
    #[test]
    fn test_regexp_match_infer_error_without_boolean_coercion() {
        // Same filter as test_regexp_match_and_non_empty_captions, but this
        // only checks that physical planning succeeds.
        // NOTE(review): the name implies this input would hit a type
        // inference error if the regexp_match result were not coerced to
        // boolean — presumably coerce_filter_type_to_boolean handles that
        // during planning; confirm the intended failure mode.
        let schema = Arc::new(Schema::new(vec![
            Field::new("keywords", DataType::Utf8, true),
            Field::new("natural_caption", DataType::Utf8, true),
            Field::new("poetic_caption", DataType::Utf8, true),
        ]));

        let planner = Planner::new(schema);

        let expr = planner
            .parse_filter(
                "regexp_match(keywords, 'Liberty|revolution') AND \
                (natural_caption IS NOT NULL AND natural_caption <> '' AND \
                poetic_caption IS NOT NULL AND poetic_caption <> '')",
            )
            .unwrap();

        // Planning must not error; the physical expression is otherwise unused.
        let _physical = planner.create_physical_expr(&expr).unwrap();
    }
1850}