use std::borrow::Cow;
use std::collections::{BTreeSet, VecDeque};
use std::sync::Arc;

use crate::expr::safe_coerce_scalar;
use crate::logical_expr::{coerce_filter_type_to_boolean, get_as_string_scalar_opt, resolve_expr};
use crate::sql::{parse_sql_expr, parse_sql_filter};
use arrow::compute::CastOptions;
use arrow_array::ListArray;
use arrow_buffer::OffsetBuffer;
use arrow_schema::{DataType as ArrowDataType, Field, SchemaRef, TimeUnit};
use arrow_select::concat::concat;
use chrono::Utc;
use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
use datafusion::common::DFSchema;
use datafusion::config::ConfigOptions;
use datafusion::error::Result as DFResult;
use datafusion::execution::config::SessionConfig;
use datafusion::execution::context::{SessionContext, SessionState};
use datafusion::execution::runtime_env::RuntimeEnvBuilder;
use datafusion::execution::session_state::SessionStateBuilder;
use datafusion::logical_expr::expr::ScalarFunction;
use datafusion::logical_expr::planner::{ExprPlanner, PlannerResult, RawFieldAccessExpr};
use datafusion::logical_expr::{
    AggregateUDF, ColumnarValue, GetFieldAccess, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
    Signature, Volatility, WindowUDF,
};
use datafusion::optimizer::simplify_expressions::SimplifyContext;
use datafusion::sql::planner::{
    ContextProvider, NullOrdering, ParserOptions, PlannerContext, SqlToRel,
};
use datafusion::sql::sqlparser::ast::{
    AccessExpr, Array as SQLArray, BinaryOperator, DataType as SQLDataType, ExactNumberInfo,
    Expr as SQLExpr, Function, FunctionArg, FunctionArgExpr, FunctionArguments, Ident,
    ObjectNamePart, Subscript, TimezoneInfo, UnaryOperator, Value, ValueWithSpan,
};
use datafusion::{
    common::Column,
    logical_expr::{col, Between, BinaryExpr, Like, Operator},
    physical_expr::execution_props::ExecutionProps,
    physical_plan::PhysicalExpr,
    prelude::Expr,
    scalar::ScalarValue,
};
use datafusion_functions::core::getfield::GetFieldFunc;
use lance_arrow::cast::cast_with_options;
use lance_core::datatypes::Schema;
use lance_core::error::LanceOptionExt;
use lance_core::{Error, Result};
use snafu::location;

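/// Scalar UDF that casts a `List` or `FixedSizeList` of `Float32`/`Float16`
/// values to the equivalent list of `Float16`.
///
/// A minimal sketch of how it surfaces (the UDF is resolved by name through
/// `LanceContextProvider`; this snippet is illustrative only):
///
/// ```ignore
/// let udf = ScalarUDF::new_from_impl(CastListF16Udf::new());
/// assert_eq!(udf.name(), "_cast_list_f16");
/// ```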
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
struct CastListF16Udf {
    signature: Signature,
}

impl CastListF16Udf {
    pub fn new() -> Self {
        Self {
            signature: Signature::any(1, Volatility::Immutable),
        }
    }
}

impl ScalarUDFImpl for CastListF16Udf {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn name(&self) -> &str {
        "_cast_list_f16"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, arg_types: &[ArrowDataType]) -> DFResult<ArrowDataType> {
        let input = &arg_types[0];
        match input {
            ArrowDataType::FixedSizeList(field, size) => {
                if field.data_type() != &ArrowDataType::Float32
                    && field.data_type() != &ArrowDataType::Float16
                {
                    return Err(datafusion::error::DataFusionError::Execution(
                        "cast_list_f16 only supports list of float32 or float16".to_string(),
                    ));
                }
                Ok(ArrowDataType::FixedSizeList(
                    Arc::new(Field::new(
                        field.name(),
                        ArrowDataType::Float16,
                        field.is_nullable(),
                    )),
                    *size,
                ))
            }
            ArrowDataType::List(field) => {
                if field.data_type() != &ArrowDataType::Float32
                    && field.data_type() != &ArrowDataType::Float16
                {
                    return Err(datafusion::error::DataFusionError::Execution(
                        "cast_list_f16 only supports list of float32 or float16".to_string(),
                    ));
                }
                Ok(ArrowDataType::List(Arc::new(Field::new(
                    field.name(),
                    ArrowDataType::Float16,
                    field.is_nullable(),
                ))))
            }
            _ => Err(datafusion::error::DataFusionError::Execution(
                "cast_list_f16 only supports FixedSizeList/List arguments".to_string(),
            )),
        }
    }

    fn invoke_with_args(&self, func_args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
        let ColumnarValue::Array(arr) = &func_args.args[0] else {
            return Err(datafusion::error::DataFusionError::Execution(
                "cast_list_f16 only supports array arguments".to_string(),
            ));
        };

        let to_type = match arr.data_type() {
            ArrowDataType::FixedSizeList(field, size) => ArrowDataType::FixedSizeList(
                Arc::new(Field::new(
                    field.name(),
                    ArrowDataType::Float16,
                    field.is_nullable(),
                )),
                *size,
            ),
            ArrowDataType::List(field) => ArrowDataType::List(Arc::new(Field::new(
                field.name(),
                ArrowDataType::Float16,
                field.is_nullable(),
            ))),
            _ => {
                return Err(datafusion::error::DataFusionError::Execution(
                    "cast_list_f16 only supports array arguments".to_string(),
                ));
            }
        };

        let res = cast_with_options(arr.as_ref(), &to_type, &CastOptions::default())?;
        Ok(ColumnarValue::Array(res))
    }
}

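/// `ContextProvider` used when planning expressions against a single Lance
/// schema. Functions (including Lance UDFs such as `_cast_list_f16`) are
/// resolved from a default DataFusion session, while table references are
/// rejected since expressions are planned without a catalog.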
struct LanceContextProvider {
    options: datafusion::config::ConfigOptions,
    state: SessionState,
    expr_planners: Vec<Arc<dyn ExprPlanner>>,
}

impl Default for LanceContextProvider {
    fn default() -> Self {
        let config = SessionConfig::new();
        let runtime = RuntimeEnvBuilder::new().build_arc().unwrap();

        // Build a session with Lance's UDFs registered so they can be
        // resolved by name during planning.
        let ctx = SessionContext::new_with_config_rt(config.clone(), runtime.clone());
        crate::udf::register_functions(&ctx);

        let state = ctx.state();

        // A separate state builder is used only to harvest DataFusion's
        // default expression planners.
        let mut state_builder = SessionStateBuilder::new()
            .with_config(config)
            .with_runtime_env(runtime)
            .with_default_features();

        let expr_planners = state_builder.expr_planners().as_ref().unwrap().clone();

        Self {
            options: ConfigOptions::default(),
            state,
            expr_planners,
        }
    }
}

impl ContextProvider for LanceContextProvider {
    fn get_table_source(
        &self,
        name: datafusion::sql::TableReference,
    ) -> DFResult<Arc<dyn datafusion::logical_expr::TableSource>> {
        Err(datafusion::error::DataFusionError::NotImplemented(format!(
            "Attempt to reference inner table {} not supported",
            name
        )))
    }

    fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
        self.state.aggregate_functions().get(name).cloned()
    }

    fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
        self.state.window_functions().get(name).cloned()
    }

    fn get_function_meta(&self, f: &str) -> Option<Arc<ScalarUDF>> {
        match f {
            "_cast_list_f16" => Some(Arc::new(ScalarUDF::new_from_impl(CastListF16Udf::new()))),
            _ => self.state.scalar_functions().get(f).cloned(),
        }
    }

    fn get_variable_type(&self, _: &[String]) -> Option<ArrowDataType> {
        None
    }

    fn options(&self) -> &datafusion::config::ConfigOptions {
        &self.options
    }

    fn udf_names(&self) -> Vec<String> {
        self.state.scalar_functions().keys().cloned().collect()
    }

    fn udaf_names(&self) -> Vec<String> {
        self.state.aggregate_functions().keys().cloned().collect()
    }

    fn udwf_names(&self) -> Vec<String> {
        self.state.window_functions().keys().cloned().collect()
    }

    fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
        &self.expr_planners
    }
}

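/// Plans SQL filter and projection strings into DataFusion [`Expr`]s against
/// a fixed Arrow schema.
///
/// A minimal usage sketch (the schema and field names are illustrative):
///
/// ```ignore
/// let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)]));
/// let planner = Planner::new(schema);
/// let expr = planner.parse_filter("i > 3")?;
/// let physical = planner.create_physical_expr(&expr)?;
/// ```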
pub struct Planner {
    schema: SchemaRef,
    context_provider: LanceContextProvider,
    enable_relations: bool,
}

impl Planner {
    pub fn new(schema: SchemaRef) -> Self {
        Self {
            schema,
            context_provider: LanceContextProvider::default(),
            enable_relations: false,
        }
    }

    /// When enabled, the first part of a compound identifier such as `t.x`
    /// is interpreted as a relation name rather than a struct column, so
    /// `t.x` resolves to column `x` of relation `t`. Any remaining parts are
    /// still treated as struct field accesses.
    pub fn with_enable_relations(mut self, enable_relations: bool) -> Self {
        self.enable_relations = enable_relations;
        self
    }

    fn column(&self, idents: &[Ident]) -> Expr {
        // Each trailing identifier becomes a `get_field` call on the
        // expression built so far (struct field access).
        fn handle_remaining_idents(expr: &mut Expr, idents: &[Ident]) {
            for ident in idents {
                *expr = Expr::ScalarFunction(ScalarFunction {
                    args: vec![
                        std::mem::take(expr),
                        Expr::Literal(ScalarValue::Utf8(Some(ident.value.clone())), None),
                    ],
                    func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
                });
            }
        }

        if self.enable_relations && idents.len() > 1 {
            // The first identifier is the relation, the second the column.
            let relation = &idents[0].value;
            let column_name = &idents[1].value;
            let column = Expr::Column(Column::new(Some(relation.clone()), column_name.clone()));
            let mut result = column;
            handle_remaining_idents(&mut result, &idents[2..]);
            result
        } else {
            // The first identifier is the column; the rest are nested fields.
            let mut column = col(&idents[0].value);
            handle_remaining_idents(&mut column, &idents[1..]);
            column
        }
    }

    fn binary_op(&self, op: &BinaryOperator) -> Result<Operator> {
        Ok(match op {
            BinaryOperator::Plus => Operator::Plus,
            BinaryOperator::Minus => Operator::Minus,
            BinaryOperator::Multiply => Operator::Multiply,
            BinaryOperator::Divide => Operator::Divide,
            BinaryOperator::Modulo => Operator::Modulo,
            BinaryOperator::StringConcat => Operator::StringConcat,
            BinaryOperator::Gt => Operator::Gt,
            BinaryOperator::Lt => Operator::Lt,
            BinaryOperator::GtEq => Operator::GtEq,
            BinaryOperator::LtEq => Operator::LtEq,
            BinaryOperator::Eq => Operator::Eq,
            BinaryOperator::NotEq => Operator::NotEq,
            BinaryOperator::And => Operator::And,
            BinaryOperator::Or => Operator::Or,
            _ => {
                return Err(Error::invalid_input(
                    format!("Operator {op} is not supported"),
                    location!(),
                ));
            }
        })
    }

    fn binary_expr(&self, left: &SQLExpr, op: &BinaryOperator, right: &SQLExpr) -> Result<Expr> {
        Ok(Expr::BinaryExpr(BinaryExpr::new(
            Box::new(self.parse_sql_expr(left)?),
            self.binary_op(op)?,
            Box::new(self.parse_sql_expr(right)?),
        )))
    }

    fn unary_expr(&self, op: &UnaryOperator, expr: &SQLExpr) -> Result<Expr> {
        Ok(match op {
            UnaryOperator::Not | UnaryOperator::PGBitwiseNot => {
                Expr::Not(Box::new(self.parse_sql_expr(expr)?))
            }

            UnaryOperator::Minus => {
                use datafusion::logical_expr::lit;
                match expr {
                    // Fold a negated numeric literal directly into a literal
                    // instead of wrapping it in a `Negative` expression.
                    SQLExpr::Value(ValueWithSpan {
                        value: Value::Number(n, _),
                        ..
                    }) => match n.parse::<i64>() {
                        Ok(n) => lit(-n),
                        Err(_) => lit(-n.parse::<f64>().map_err(|_e| {
                            Error::invalid_input(
                                format!("negative operator can only be applied to integer and float operands, got: {n}"),
                                location!(),
                            )
                        })?),
                    },
                    _ => Expr::Negative(Box::new(self.parse_sql_expr(expr)?)),
                }
            }

            _ => {
                return Err(Error::invalid_input(
                    format!("Unary operator '{:?}' is not supported", op),
                    location!(),
                ));
            }
        })
    }

    fn number(&self, value: &str, negative: bool) -> Result<Expr> {
        use datafusion::logical_expr::lit;
        let value: Cow<str> = if negative {
            Cow::Owned(format!("-{}", value))
        } else {
            Cow::Borrowed(value)
        };
        // Prefer an integer literal; fall back to a float.
        if let Ok(n) = value.parse::<i64>() {
            Ok(lit(n))
        } else {
            value.parse::<f64>().map(lit).map_err(|_| {
                Error::invalid_input(
                    format!("'{value}' is not a supported number value."),
                    location!(),
                )
            })
        }
    }

    fn value(&self, value: &Value) -> Result<Expr> {
        Ok(match value {
            Value::Number(v, _) => self.number(v.as_str(), false)?,
            Value::SingleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone())), None),
            Value::HexStringLiteral(hsl) => {
                Expr::Literal(ScalarValue::Binary(Self::try_decode_hex_literal(hsl)), None)
            }
            Value::DoubleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone())), None),
            Value::Boolean(v) => Expr::Literal(ScalarValue::Boolean(Some(*v)), None),
            Value::Null => Expr::Literal(ScalarValue::Null, None),
            _ => {
                return Err(Error::invalid_input(
                    format!("Unsupported literal value: {:?}", value),
                    location!(),
                ))
            }
        })
    }

    fn parse_function_args(&self, func_args: &FunctionArg) -> Result<Expr> {
        match func_args {
            FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => self.parse_sql_expr(expr),
            _ => Err(Error::invalid_input(
                format!("Unsupported function args: {:?}", func_args),
                location!(),
            )),
        }
    }

    /// Handles the legacy `is_valid(col)` function, which is equivalent to
    /// `col IS NOT NULL`.
    fn legacy_parse_function(&self, func: &Function) -> Result<Expr> {
        match &func.args {
            FunctionArguments::List(args) => {
                if func.name.0.len() != 1 {
                    return Err(Error::invalid_input(
                        format!("Function name must have 1 part, got: {:?}", func.name.0),
                        location!(),
                    ));
                }
                if args.args.len() != 1 {
                    return Err(Error::invalid_input(
                        format!("is_valid takes exactly one argument, got: {:?}", args.args),
                        location!(),
                    ));
                }
                Ok(Expr::IsNotNull(Box::new(
                    self.parse_function_args(&args.args[0])?,
                )))
            }
            _ => Err(Error::invalid_input(
                format!("Unsupported function args: {:?}", &func.args),
                location!(),
            )),
        }
    }

    fn parse_function(&self, function: SQLExpr) -> Result<Expr> {
        // Special-case the legacy `is_valid` function; everything else is
        // delegated to DataFusion's SQL-to-expression planner.
        if let SQLExpr::Function(function) = &function {
            if let Some(ObjectNamePart::Identifier(name)) = &function.name.0.first() {
                if &name.value == "is_valid" {
                    return self.legacy_parse_function(function);
                }
            }
        }
        let sql_to_rel = SqlToRel::new_with_options(
            &self.context_provider,
            ParserOptions {
                parse_float_as_decimal: false,
                enable_ident_normalization: false,
                support_varchar_with_length: false,
                enable_options_value_normalization: false,
                collect_spans: false,
                map_string_types_to_utf8view: false,
                default_null_ordering: NullOrdering::NullsMax,
            },
        );

        let mut planner_context = PlannerContext::default();
        let schema = DFSchema::try_from(self.schema.as_ref().clone())?;
        sql_to_rel
            .sql_to_expr(function, &schema, &mut planner_context)
            .map_err(|e| Error::invalid_input(format!("Error parsing function: {e}"), location!()))
    }

    fn parse_type(&self, data_type: &SQLDataType) -> Result<ArrowDataType> {
        const SUPPORTED_TYPES: [&str; 13] = [
            "int [unsigned]",
            "tinyint [unsigned]",
            "smallint [unsigned]",
            "bigint [unsigned]",
            "float",
            "double",
            "string",
            "binary",
            "date",
            "timestamp(precision)",
            "datetime(precision)",
            "decimal(precision,scale)",
            "boolean",
        ];
        match data_type {
            SQLDataType::String(_) => Ok(ArrowDataType::Utf8),
            SQLDataType::Binary(_) => Ok(ArrowDataType::Binary),
            SQLDataType::Float(_) => Ok(ArrowDataType::Float32),
            SQLDataType::Double(_) => Ok(ArrowDataType::Float64),
            SQLDataType::Boolean => Ok(ArrowDataType::Boolean),
            SQLDataType::TinyInt(_) => Ok(ArrowDataType::Int8),
            SQLDataType::SmallInt(_) => Ok(ArrowDataType::Int16),
            SQLDataType::Int(_) | SQLDataType::Integer(_) => Ok(ArrowDataType::Int32),
            SQLDataType::BigInt(_) => Ok(ArrowDataType::Int64),
            SQLDataType::TinyIntUnsigned(_) => Ok(ArrowDataType::UInt8),
            SQLDataType::SmallIntUnsigned(_) => Ok(ArrowDataType::UInt16),
            SQLDataType::IntUnsigned(_) | SQLDataType::IntegerUnsigned(_) => {
                Ok(ArrowDataType::UInt32)
            }
            SQLDataType::BigIntUnsigned(_) => Ok(ArrowDataType::UInt64),
            SQLDataType::Date => Ok(ArrowDataType::Date32),
            SQLDataType::Timestamp(resolution, tz) => {
                match tz {
                    TimezoneInfo::None => {}
                    _ => {
                        return Err(Error::invalid_input(
                            "Timezone not supported in timestamp".to_string(),
                            location!(),
                        ));
                    }
                };
                // SQL precision maps to Arrow time units:
                // 0 => s, 3 => ms, 6 => us (the default), 9 => ns.
                let time_unit = match resolution {
                    None => TimeUnit::Microsecond,
                    Some(0) => TimeUnit::Second,
                    Some(3) => TimeUnit::Millisecond,
                    Some(6) => TimeUnit::Microsecond,
                    Some(9) => TimeUnit::Nanosecond,
                    _ => {
                        return Err(Error::invalid_input(
                            format!("Unsupported timestamp resolution: {:?}", resolution),
                            location!(),
                        ));
                    }
                };
                Ok(ArrowDataType::Timestamp(time_unit, None))
            }
            SQLDataType::Datetime(resolution) => {
                let time_unit = match resolution {
                    None => TimeUnit::Microsecond,
                    Some(0) => TimeUnit::Second,
                    Some(3) => TimeUnit::Millisecond,
                    Some(6) => TimeUnit::Microsecond,
                    Some(9) => TimeUnit::Nanosecond,
                    _ => {
                        return Err(Error::invalid_input(
                            format!("Unsupported datetime resolution: {:?}", resolution),
                            location!(),
                        ));
                    }
                };
                Ok(ArrowDataType::Timestamp(time_unit, None))
            }
            SQLDataType::Decimal(number_info) => match number_info {
                ExactNumberInfo::PrecisionAndScale(precision, scale) => {
                    Ok(ArrowDataType::Decimal128(*precision as u8, *scale as i8))
                }
                _ => Err(Error::invalid_input(
                    format!(
                        "Must provide precision and scale for decimal: {:?}",
                        number_info
                    ),
                    location!(),
                )),
            },
            _ => Err(Error::invalid_input(
                format!(
                    "Unsupported data type: {:?}. Supported types: {:?}",
                    data_type, SUPPORTED_TYPES
                ),
                location!(),
            )),
        }
    }

    fn plan_field_access(&self, mut field_access_expr: RawFieldAccessExpr) -> Result<Expr> {
        let df_schema = DFSchema::try_from(self.schema.as_ref().clone())?;
        for planner in self.context_provider.get_expr_planners() {
            match planner.plan_field_access(field_access_expr, &df_schema)? {
                PlannerResult::Planned(expr) => return Ok(expr),
                PlannerResult::Original(expr) => {
                    field_access_expr = expr;
                }
            }
        }
        Err(Error::invalid_input(
            "Field access could not be planned",
            location!(),
        ))
    }

    fn parse_sql_expr(&self, expr: &SQLExpr) -> Result<Expr> {
        match expr {
            SQLExpr::Identifier(id) => {
                if id.quote_style == Some('"') {
                    // Double-quoted identifiers are treated as string literals.
                    Ok(Expr::Literal(
                        ScalarValue::Utf8(Some(id.value.clone())),
                        None,
                    ))
                } else if id.quote_style == Some('`') {
                    // Backtick-quoted identifiers are column references.
                    Ok(Expr::Column(Column::from_name(id.value.clone())))
                } else {
                    Ok(self.column(vec![id.clone()].as_slice()))
                }
            }
            SQLExpr::CompoundIdentifier(ids) => Ok(self.column(ids.as_slice())),
            SQLExpr::BinaryOp { left, op, right } => self.binary_expr(left, op, right),
            SQLExpr::UnaryOp { op, expr } => self.unary_expr(op, expr),
            SQLExpr::Value(value) => self.value(&value.value),
            SQLExpr::Array(SQLArray { elem, .. }) => {
                let mut values = vec![];

                let array_literal_error = |pos: usize, value: &_| {
                    Err(Error::invalid_input(
                        format!(
                            "Expected a literal value in array, instead got {} at position {}",
                            value, pos
                        ),
                        location!(),
                    ))
                };

                for (pos, expr) in elem.iter().enumerate() {
                    match expr {
                        SQLExpr::Value(value) => {
                            if let Expr::Literal(value, _) = self.value(&value.value)? {
                                values.push(value);
                            } else {
                                return array_literal_error(pos, expr);
                            }
                        }
                        SQLExpr::UnaryOp {
                            op: UnaryOperator::Minus,
                            expr,
                        } => {
                            if let SQLExpr::Value(ValueWithSpan {
                                value: Value::Number(number, _),
                                ..
                            }) = expr.as_ref()
                            {
                                if let Expr::Literal(value, _) = self.number(number, true)? {
                                    values.push(value);
                                } else {
                                    return array_literal_error(pos, expr);
                                }
                            } else {
                                return array_literal_error(pos, expr);
                            }
                        }
                        _ => {
                            return array_literal_error(pos, expr);
                        }
                    }
                }

                let field = if !values.is_empty() {
                    let data_type = values[0].data_type();

                    for value in &mut values {
                        if value.data_type() != data_type {
                            *value = safe_coerce_scalar(value, &data_type).ok_or_else(|| Error::invalid_input(
                                format!("Array expressions must have a consistent datatype. Expected: {}, got: {}", data_type, value.data_type()),
                                location!()
                            ))?;
                        }
                    }
                    Field::new("item", data_type, true)
                } else {
                    Field::new("item", ArrowDataType::Null, true)
                };

                let values = values
                    .into_iter()
                    .map(|v| v.to_array().map_err(Error::from))
                    .collect::<Result<Vec<_>>>()?;
                let array_refs = values.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
                let values = concat(&array_refs)?;
                let values = ListArray::try_new(
                    field.into(),
                    OffsetBuffer::from_lengths([values.len()]),
                    values,
                    None,
                )?;

                Ok(Expr::Literal(ScalarValue::List(Arc::new(values)), None))
            }
            SQLExpr::TypedString { data_type, value } => {
                let value = value.clone().into_string().expect_ok()?;
                Ok(Expr::Cast(datafusion::logical_expr::Cast {
                    expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some(value)), None)),
                    data_type: self.parse_type(data_type)?,
                }))
            }
            SQLExpr::IsFalse(expr) => Ok(Expr::IsFalse(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotFalse(expr) => Ok(Expr::IsNotFalse(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsTrue(expr) => Ok(Expr::IsTrue(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotTrue(expr) => Ok(Expr::IsNotTrue(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNull(expr) => Ok(Expr::IsNull(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotNull(expr) => Ok(Expr::IsNotNull(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::InList {
                expr,
                list,
                negated,
            } => {
                let value_expr = self.parse_sql_expr(expr)?;
                let list_exprs = list
                    .iter()
                    .map(|e| self.parse_sql_expr(e))
                    .collect::<Result<Vec<_>>>()?;
                Ok(value_expr.in_list(list_exprs, *negated))
            }
            SQLExpr::Nested(inner) => self.parse_sql_expr(inner.as_ref()),
            SQLExpr::Function(_) => self.parse_function(expr.clone()),
            SQLExpr::ILike {
                negated,
                expr,
                pattern,
                escape_char,
                any: _,
            } => Ok(Expr::Like(Like::new(
                *negated,
                Box::new(self.parse_sql_expr(expr)?),
                Box::new(self.parse_sql_expr(pattern)?),
                match escape_char {
                    Some(Value::SingleQuotedString(char)) => char.chars().next(),
                    Some(value) => return Err(Error::invalid_input(
                        format!("Invalid escape character in LIKE expression. Expected a single character wrapped in single quotes, got {}", value),
                        location!()
                    )),
                    None => None,
                },
                true,
            ))),
            SQLExpr::Like {
                negated,
                expr,
                pattern,
                escape_char,
                any: _,
            } => Ok(Expr::Like(Like::new(
                *negated,
                Box::new(self.parse_sql_expr(expr)?),
                Box::new(self.parse_sql_expr(pattern)?),
                match escape_char {
                    Some(Value::SingleQuotedString(char)) => char.chars().next(),
                    Some(value) => return Err(Error::invalid_input(
                        format!("Invalid escape character in LIKE expression. Expected a single character wrapped in single quotes, got {}", value),
                        location!()
                    )),
                    None => None,
                },
                false,
            ))),
            SQLExpr::Cast {
                expr,
                data_type,
                kind,
                ..
            } => match kind {
                datafusion::sql::sqlparser::ast::CastKind::TryCast
                | datafusion::sql::sqlparser::ast::CastKind::SafeCast => {
                    Ok(Expr::TryCast(datafusion::logical_expr::TryCast {
                        expr: Box::new(self.parse_sql_expr(expr)?),
                        data_type: self.parse_type(data_type)?,
                    }))
                }
                _ => Ok(Expr::Cast(datafusion::logical_expr::Cast {
                    expr: Box::new(self.parse_sql_expr(expr)?),
                    data_type: self.parse_type(data_type)?,
                })),
            },
            SQLExpr::JsonAccess { .. } => Err(Error::invalid_input(
                "JSON access is not supported",
                location!(),
            )),
            SQLExpr::CompoundFieldAccess { root, access_chain } => {
                let mut expr = self.parse_sql_expr(root)?;

                for access in access_chain {
                    let field_access = match access {
                        // `expr.field`, `expr['field']`, or `expr["field"]`
                        AccessExpr::Dot(SQLExpr::Identifier(Ident { value: s, .. }))
                        | AccessExpr::Subscript(Subscript::Index {
                            index:
                                SQLExpr::Value(ValueWithSpan {
                                    value:
                                        Value::SingleQuotedString(s) | Value::DoubleQuotedString(s),
                                    ..
                                }),
                        }) => GetFieldAccess::NamedStructField {
                            name: ScalarValue::from(s.as_str()),
                        },
                        // `expr[index]`
                        AccessExpr::Subscript(Subscript::Index { index }) => {
                            let key = Box::new(self.parse_sql_expr(index)?);
                            GetFieldAccess::ListIndex { key }
                        }
                        AccessExpr::Subscript(Subscript::Slice { .. }) => {
                            return Err(Error::invalid_input(
                                "Slice subscript is not supported",
                                location!(),
                            ));
                        }
                        _ => {
                            return Err(Error::invalid_input(
                                "Only dot notation or index access is supported for field access",
                                location!(),
                            ));
                        }
                    };

                    let field_access_expr = RawFieldAccessExpr { expr, field_access };
                    expr = self.plan_field_access(field_access_expr)?;
                }

                Ok(expr)
            }
            SQLExpr::Between {
                expr,
                negated,
                low,
                high,
            } => {
                let expr = self.parse_sql_expr(expr)?;
                let low = self.parse_sql_expr(low)?;
                let high = self.parse_sql_expr(high)?;

                let between = Expr::Between(Between::new(
                    Box::new(expr),
                    *negated,
                    Box::new(low),
                    Box::new(high),
                ));
                Ok(between)
            }
            _ => Err(Error::invalid_input(
                format!("Expression '{expr}' is not supported SQL in Lance"),
                location!(),
            )),
        }
    }

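    /// Parses a SQL filter string into a logical [`Expr`], resolving columns
    /// against the planner's schema and coercing the result to a boolean
    /// predicate.
    ///
    /// ```ignore
    /// // Illustrative only; `i` and `s` are assumed schema fields.
    /// let predicate = planner.parse_filter("i > 3 AND s IS NOT NULL")?;
    /// ```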
    pub fn parse_filter(&self, filter: &str) -> Result<Expr> {
        let ast_expr = parse_sql_filter(filter)?;
        let expr = self.parse_sql_expr(&ast_expr)?;
        let schema = Schema::try_from(self.schema.as_ref())?;
        let resolved = resolve_expr(&expr, &schema).map_err(|e| {
            Error::invalid_input(
                format!("Error resolving filter expression {filter}: {e}"),
                location!(),
            )
        })?;

        Ok(coerce_filter_type_to_boolean(resolved))
    }

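    /// Parses a SQL scalar expression. If `expr` exactly names a top-level
    /// schema field, it is returned as a column reference without SQL
    /// parsing.
    ///
    /// ```ignore
    /// // Illustrative: index/field access over a list-of-structs column `l`.
    /// let expr = planner.parse_expr("l[0]['f1']")?;
    /// ```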
    pub fn parse_expr(&self, expr: &str) -> Result<Expr> {
        if self.schema.field_with_name(expr).is_ok() {
            return Ok(col(expr));
        }

        let ast_expr = parse_sql_expr(expr)?;
        let expr = self.parse_sql_expr(&ast_expr)?;
        let schema = Schema::try_from(self.schema.as_ref())?;
        let resolved = resolve_expr(&expr, &schema)?;
        Ok(resolved)
    }

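    /// Decodes the body of a hex literal (`x'...'`) into bytes, returning
    /// `None` on any non-hex character.
    ///
    /// ```ignore
    /// // Expected behavior (sketch):
    /// assert_eq!(Planner::try_decode_hex_literal("616263"), Some(b"abc".to_vec()));
    /// assert_eq!(Planner::try_decode_hex_literal("F"), Some(vec![0x0f]));
    /// assert_eq!(Planner::try_decode_hex_literal("zz"), None);
    /// ```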
    fn try_decode_hex_literal(s: &str) -> Option<Vec<u8>> {
        let hex_bytes = s.as_bytes();
        let mut decoded_bytes = Vec::with_capacity(hex_bytes.len().div_ceil(2));

        // An odd-length string begins with a lone low nibble.
        let start_idx = hex_bytes.len() % 2;
        if start_idx > 0 {
            decoded_bytes.push(Self::try_decode_hex_char(hex_bytes[0])?);
        }

        // Decode the remaining characters in pairs.
        for i in (start_idx..hex_bytes.len()).step_by(2) {
            let high = Self::try_decode_hex_char(hex_bytes[i])?;
            let low = Self::try_decode_hex_char(hex_bytes[i + 1])?;
            decoded_bytes.push((high << 4) | low);
        }

        Some(decoded_bytes)
    }

    /// Decodes a single ASCII hex digit to its value, or `None` if it is not
    /// a hex digit.
    const fn try_decode_hex_char(c: u8) -> Option<u8> {
        match c {
            b'A'..=b'F' => Some(c - b'A' + 10),
            b'a'..=b'f' => Some(c - b'a' + 10),
            b'0'..=b'9' => Some(c - b'0'),
            _ => None,
        }
    }

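    /// Simplifies and type-coerces an expression against the planner's
    /// schema using DataFusion's `ExprSimplifier` (constant folding, literal
    /// casts, etc.).
    ///
    /// ```ignore
    /// // Illustrative: the array literal is folded into a single list value.
    /// let expr = planner.optimize_expr(planner.parse_filter("x = [1, 2, 3]")?)?;
    /// ```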
    pub fn optimize_expr(&self, expr: Expr) -> Result<Expr> {
        let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);

        // Simplify (constant folding, etc.) and then coerce types against
        // the schema so the expression is ready for physical planning.
        let props = ExecutionProps::new().with_query_execution_start_time(Utc::now());
        let simplify_context = SimplifyContext::new(&props).with_schema(df_schema.clone());
        let simplifier =
            datafusion::optimizer::simplify_expressions::ExprSimplifier::new(simplify_context);

        let expr = simplifier.simplify(expr)?;
        let expr = simplifier.coerce(expr, &df_schema)?;

        Ok(expr)
    }

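    /// Lowers a logical [`Expr`] into an executable [`PhysicalExpr`] bound to
    /// the planner's schema.
    ///
    /// ```ignore
    /// // Illustrative: evaluate a planned predicate against a RecordBatch.
    /// let physical = planner.create_physical_expr(&expr)?;
    /// let mask = physical.evaluate(&batch)?;
    /// ```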
    pub fn create_physical_expr(&self, expr: &Expr) -> Result<Arc<dyn PhysicalExpr>> {
        let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
        Ok(datafusion::physical_expr::create_physical_expr(
            expr,
            df_schema.as_ref(),
            &Default::default(),
        )?)
    }

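    /// Collects the names of all columns referenced by `expr`, with nested
    /// struct fields rendered as dot-separated paths (parts containing `.`
    /// or a backtick are backtick-escaped).
    ///
    /// ```ignore
    /// // Illustrative; see also `test_columns_in_expr` below.
    /// let expr = col("st").field("s1").eq(lit("v"));
    /// assert_eq!(Planner::column_names_in_expr(&expr), vec!["st.s1"]);
    /// ```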
    pub fn column_names_in_expr(expr: &Expr) -> Vec<String> {
        let mut visitor = ColumnCapturingVisitor {
            current_path: VecDeque::new(),
            columns: BTreeSet::new(),
        };
        expr.visit(&mut visitor).unwrap();
        visitor.columns.into_iter().collect()
    }
}

struct ColumnCapturingVisitor {
    // Stack of field names accumulated from nested `get_field` calls.
    current_path: VecDeque<String>,
    columns: BTreeSet<String>,
}

impl TreeNodeVisitor<'_> for ColumnCapturingVisitor {
    type Node = Expr;

    fn f_down(&mut self, node: &Self::Node) -> DFResult<TreeNodeRecursion> {
        match node {
            Expr::Column(Column { name, .. }) => {
                // Emit the column name plus any accumulated field path,
                // escaping parts that contain `.` or a backtick.
                let mut path = name.clone();
                for part in self.current_path.drain(..) {
                    path.push('.');
                    if part.contains('.') || part.contains('`') {
                        let escaped = part.replace('`', "``");
                        path.push('`');
                        path.push_str(&escaped);
                        path.push('`');
                    } else {
                        path.push_str(&part);
                    }
                }
                self.columns.insert(path);
                self.current_path.clear();
            }
            Expr::ScalarFunction(udf) => {
                if udf.name() == GetFieldFunc::default().name() {
                    if let Some(name) = get_as_string_scalar_opt(&udf.args[1]) {
                        self.current_path.push_front(name.to_string())
                    } else {
                        self.current_path.clear();
                    }
                } else {
                    self.current_path.clear();
                }
            }
            _ => {
                self.current_path.clear();
            }
        }

        Ok(TreeNodeRecursion::Continue)
    }
}

#[cfg(test)]
mod tests {

    use crate::logical_expr::ExprExt;

    use super::*;

    use arrow::datatypes::Float64Type;
    use arrow_array::{
        ArrayRef, BooleanArray, Float32Array, Int32Array, Int64Array, RecordBatch, StringArray,
        StructArray, TimestampMicrosecondArray, TimestampMillisecondArray,
        TimestampNanosecondArray, TimestampSecondArray,
    };
    use arrow_schema::{DataType, Fields, Schema};
    use datafusion::{
        logical_expr::{lit, Cast},
        prelude::{array_element, get_field},
    };
    use datafusion_functions::core::expr_ext::FieldAccessor;

    #[test]
    fn test_parse_filter_simple() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("i", DataType::Int32, false),
            Field::new("s", DataType::Utf8, true),
            Field::new(
                "st",
                DataType::Struct(Fields::from(vec![
                    Field::new("x", DataType::Float32, false),
                    Field::new("y", DataType::Float32, false),
                ])),
                true,
            ),
        ]));

        let planner = Planner::new(schema.clone());

        let expected = col("i")
            .gt(lit(3_i32))
            .and(col("st").field_newstyle("x").lt_eq(lit(5.0_f32)))
            .and(
                col("s")
                    .eq(lit("str-4"))
                    .or(col("s").in_list(vec![lit("str-4"), lit("str-5")], false)),
            );

        let expr = planner
            .parse_filter("i > 3 AND st.x <= 5.0 AND (s == 'str-4' OR s in ('str-4', 'str-5'))")
            .unwrap();
        assert_eq!(expr, expected);

        let expr = planner
            .parse_filter("i > 3 AND st.x <= 5.0 AND (s = 'str-4' OR s in ('str-4', 'str-5'))")
            .unwrap();

        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
                Arc::new(StringArray::from_iter_values(
                    (0..10).map(|v| format!("str-{}", v)),
                )),
                Arc::new(StructArray::from(vec![
                    (
                        Arc::new(Field::new("x", DataType::Float32, false)),
                        Arc::new(Float32Array::from_iter_values((0..10).map(|v| v as f32)))
                            as ArrayRef,
                    ),
                    (
                        Arc::new(Field::new("y", DataType::Float32, false)),
                        Arc::new(Float32Array::from_iter_values(
                            (0..10).map(|v| (v * 10) as f32),
                        )),
                    ),
                ])),
            ],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, false, true, true, false, false, false, false
            ])
        );
    }

    #[test]
    fn test_nested_col_refs() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("s0", DataType::Utf8, true),
            Field::new(
                "st",
                DataType::Struct(Fields::from(vec![
                    Field::new("s1", DataType::Utf8, true),
                    Field::new(
                        "st",
                        DataType::Struct(Fields::from(vec![Field::new(
                            "s2",
                            DataType::Utf8,
                            true,
                        )])),
                        true,
                    ),
                ])),
                true,
            ),
        ]));

        let planner = Planner::new(schema);

        fn assert_column_eq(planner: &Planner, expr: &str, expected: &Expr) {
            let expr = planner.parse_filter(&format!("{expr} = 'val'")).unwrap();
            assert!(matches!(
                expr,
                Expr::BinaryExpr(BinaryExpr {
                    left: _,
                    op: Operator::Eq,
                    right: _
                })
            ));
            if let Expr::BinaryExpr(BinaryExpr { left, .. }) = expr {
                assert_eq!(left.as_ref(), expected);
            }
        }

        let expected = Expr::Column(Column::new_unqualified("s0"));
        assert_column_eq(&planner, "s0", &expected);
        assert_column_eq(&planner, "`s0`", &expected);

        let expected = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
            args: vec![
                Expr::Column(Column::new_unqualified("st")),
                Expr::Literal(ScalarValue::Utf8(Some("s1".to_string())), None),
            ],
        });
        assert_column_eq(&planner, "st.s1", &expected);
        assert_column_eq(&planner, "`st`.`s1`", &expected);
        assert_column_eq(&planner, "st.`s1`", &expected);

        let expected = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
            args: vec![
                Expr::ScalarFunction(ScalarFunction {
                    func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
                    args: vec![
                        Expr::Column(Column::new_unqualified("st")),
                        Expr::Literal(ScalarValue::Utf8(Some("st".to_string())), None),
                    ],
                }),
                Expr::Literal(ScalarValue::Utf8(Some("s2".to_string())), None),
            ],
        });

        assert_column_eq(&planner, "st.st.s2", &expected);
        assert_column_eq(&planner, "`st`.`st`.`s2`", &expected);
        assert_column_eq(&planner, "st.st.`s2`", &expected);
        assert_column_eq(&planner, "st['st'][\"s2\"]", &expected);
    }

    #[test]
    fn test_nested_list_refs() {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "l",
            DataType::List(Arc::new(Field::new(
                "item",
                DataType::Struct(Fields::from(vec![Field::new("f1", DataType::Utf8, true)])),
                true,
            ))),
            true,
        )]));

        let planner = Planner::new(schema);

        let expected = array_element(col("l"), lit(0_i64));
        let expr = planner.parse_expr("l[0]").unwrap();
        assert_eq!(expr, expected);

        let expected = get_field(array_element(col("l"), lit(0_i64)), "f1");
        let expr = planner.parse_expr("l[0]['f1']").unwrap();
        assert_eq!(expr, expected);
    }

    #[test]
    fn test_negative_expressions() {
        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));

        let planner = Planner::new(schema.clone());

        let expected = col("x")
            .gt(lit(-3_i64))
            .and(col("x").lt(-(lit(-5_i64) + lit(3_i64))));

        let expr = planner.parse_filter("x > -3 AND x < -(-5 + 3)").unwrap();

        assert_eq!(expr, expected);

        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(Int64Array::from_iter_values(-5..5)) as ArrayRef],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, true, true, true, true, false, false, false
            ])
        );
    }

    #[test]
    fn test_negative_array_expressions() {
        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));

        let planner = Planner::new(schema);

        let expected = Expr::Literal(
            ScalarValue::List(Arc::new(
                ListArray::from_iter_primitive::<Float64Type, _, _>(vec![Some(
                    [-1_f64, -2.0, -3.0, -4.0, -5.0].map(Some),
                )]),
            )),
            None,
        );

        let expr = planner
            .parse_expr("[-1.0, -2.0, -3.0, -4.0, -5.0]")
            .unwrap();

        assert_eq!(expr, expected);
    }

    #[test]
    fn test_sql_like() {
        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));

        let planner = Planner::new(schema.clone());

        let expected = col("s").like(lit("str-4"));
        let expr = planner.parse_filter("s LIKE 'str-4'").unwrap();
        assert_eq!(expr, expected);
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(StringArray::from_iter_values(
                (0..10).map(|v| format!("str-{}", v)),
            ))],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, false, true, false, false, false, false, false
            ])
        );
    }

    #[test]
    fn test_not_like() {
        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));

        let planner = Planner::new(schema.clone());

        let expected = col("s").not_like(lit("str-4"));
        let expr = planner.parse_filter("s NOT LIKE 'str-4'").unwrap();
        assert_eq!(expr, expected);
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(StringArray::from_iter_values(
                (0..10).map(|v| format!("str-{}", v)),
            ))],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                true, true, true, true, false, true, true, true, true, true
            ])
        );
    }

    #[test]
    fn test_sql_is_in() {
        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));

        let planner = Planner::new(schema.clone());

        let expected = col("s").in_list(vec![lit("str-4"), lit("str-5")], false);
        let expr = planner.parse_filter("s IN ('str-4', 'str-5')").unwrap();
        assert_eq!(expr, expected);
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(StringArray::from_iter_values(
                (0..10).map(|v| format!("str-{}", v)),
            ))],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, false, true, true, false, false, false, false
            ])
        );
    }

    #[test]
    fn test_sql_is_null() {
        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));

        let planner = Planner::new(schema.clone());

        let expected = col("s").is_null();
        let expr = planner.parse_filter("s IS NULL").unwrap();
        assert_eq!(expr, expected);
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(StringArray::from_iter((0..10).map(|v| {
                if v % 3 == 0 {
                    Some(format!("str-{}", v))
                } else {
                    None
                }
            })))],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, true, true, false, true, true, false, true, true, false
            ])
        );

        let expr = planner.parse_filter("s IS NOT NULL").unwrap();
        let physical_expr = planner.create_physical_expr(&expr).unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                true, false, false, true, false, false, true, false, false, true,
            ])
        );
    }

    #[test]
    fn test_sql_invert() {
        let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Boolean, true)]));

        let planner = Planner::new(schema.clone());

        let expr = planner.parse_filter("NOT s").unwrap();
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(BooleanArray::from_iter(
                (0..10).map(|v| Some(v % 3 == 0)),
            ))],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, true, true, false, true, true, false, true, true, false
            ])
        );
    }

    #[test]
    fn test_sql_cast() {
        let cases = &[
            (
                "x = cast('2021-01-01 00:00:00' as timestamp)",
                ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00' as timestamp(0))",
                ArrowDataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00.123' as timestamp(9))",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            (
                "x = cast('2021-01-01 00:00:00.123' as datetime(9))",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            ("x = cast('2021-01-01' as date)", ArrowDataType::Date32),
            (
                "x = cast('1.238' as decimal(9,3))",
                ArrowDataType::Decimal128(9, 3),
            ),
            ("x = cast(1 as float)", ArrowDataType::Float32),
            ("x = cast(1 as double)", ArrowDataType::Float64),
            ("x = cast(1 as tinyint)", ArrowDataType::Int8),
            ("x = cast(1 as smallint)", ArrowDataType::Int16),
            ("x = cast(1 as int)", ArrowDataType::Int32),
            ("x = cast(1 as integer)", ArrowDataType::Int32),
            ("x = cast(1 as bigint)", ArrowDataType::Int64),
            ("x = cast(1 as tinyint unsigned)", ArrowDataType::UInt8),
            ("x = cast(1 as smallint unsigned)", ArrowDataType::UInt16),
            ("x = cast(1 as int unsigned)", ArrowDataType::UInt32),
            ("x = cast(1 as integer unsigned)", ArrowDataType::UInt32),
            ("x = cast(1 as bigint unsigned)", ArrowDataType::UInt64),
            ("x = cast(1 as boolean)", ArrowDataType::Boolean),
            ("x = cast(1 as string)", ArrowDataType::Utf8),
        ];

        for (sql, expected_data_type) in cases {
            let schema = Arc::new(Schema::new(vec![Field::new(
                "x",
                expected_data_type.clone(),
                true,
            )]));
            let planner = Planner::new(schema.clone());
            let expr = planner.parse_filter(sql).unwrap();

            // Extract the value inside `cast(... as ...)` from the SQL string.
            let expected_value_str = sql
                .split("cast(")
                .nth(1)
                .unwrap()
                .split(" as")
                .next()
                .unwrap();
            let expected_value_str = expected_value_str.trim_matches('\'');

            match expr {
                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
                    Expr::Cast(Cast { expr, data_type }) => {
                        match expr.as_ref() {
                            Expr::Literal(ScalarValue::Utf8(Some(value_str)), _) => {
                                assert_eq!(value_str, expected_value_str);
                            }
                            Expr::Literal(ScalarValue::Int64(Some(value)), _) => {
                                assert_eq!(*value, 1);
                            }
                            _ => panic!("Expected cast to be applied to literal"),
                        }
                        assert_eq!(data_type, expected_data_type);
                    }
                    _ => panic!("Expected right to be a cast"),
                },
                _ => panic!("Expected binary expression"),
            }
        }
    }

    #[test]
    fn test_sql_literals() {
        let cases = &[
            (
                "x = timestamp '2021-01-01 00:00:00'",
                ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
            ),
            (
                "x = timestamp(0) '2021-01-01 00:00:00'",
                ArrowDataType::Timestamp(TimeUnit::Second, None),
            ),
            (
                "x = timestamp(9) '2021-01-01 00:00:00.123'",
                ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            ("x = date '2021-01-01'", ArrowDataType::Date32),
            ("x = decimal(9,3) '1.238'", ArrowDataType::Decimal128(9, 3)),
        ];

        for (sql, expected_data_type) in cases {
            let schema = Arc::new(Schema::new(vec![Field::new(
                "x",
                expected_data_type.clone(),
                true,
            )]));
            let planner = Planner::new(schema.clone());
            let expr = planner.parse_filter(sql).unwrap();

            let expected_value_str = sql.split('\'').nth(1).unwrap();

            match expr {
                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
                    Expr::Cast(Cast { expr, data_type }) => {
                        match expr.as_ref() {
                            Expr::Literal(ScalarValue::Utf8(Some(value_str)), _) => {
                                assert_eq!(value_str, expected_value_str);
                            }
                            _ => panic!("Expected cast to be applied to literal"),
                        }
                        assert_eq!(data_type, expected_data_type);
                    }
                    _ => panic!("Expected right to be a cast"),
                },
                _ => panic!("Expected binary expression"),
            }
        }
    }

    #[test]
    fn test_sql_array_literals() {
        let cases = [
            (
                "x = [1, 2, 3]",
                ArrowDataType::List(Arc::new(Field::new("item", ArrowDataType::Int64, true))),
            ),
            (
                "x = [1, 2, 3]",
                ArrowDataType::FixedSizeList(
                    Arc::new(Field::new("item", ArrowDataType::Int64, true)),
                    3,
                ),
            ),
        ];

        for (sql, expected_data_type) in cases {
            let schema = Arc::new(Schema::new(vec![Field::new(
                "x",
                expected_data_type.clone(),
                true,
            )]));
            let planner = Planner::new(schema.clone());
            let expr = planner.parse_filter(sql).unwrap();
            let expr = planner.optimize_expr(expr).unwrap();

            match expr {
                Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
                    Expr::Literal(value, _) => {
                        assert_eq!(&value.data_type(), &expected_data_type);
                    }
                    _ => panic!("Expected right to be a literal"),
                },
                _ => panic!("Expected binary expression"),
            }
        }
    }

    #[test]
    fn test_sql_between() {
        use arrow_array::{Float64Array, Int32Array, TimestampMicrosecondArray};
        use arrow_schema::{DataType, Field, Schema, TimeUnit};
        use std::sync::Arc;

        let schema = Arc::new(Schema::new(vec![
            Field::new("x", DataType::Int32, false),
            Field::new("y", DataType::Float64, false),
            Field::new(
                "ts",
                DataType::Timestamp(TimeUnit::Microsecond, None),
                false,
            ),
        ]));

        let planner = Planner::new(schema.clone());

        let expr = planner
            .parse_filter("x BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
            .unwrap();
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        // 2024-01-01 00:00:00 UTC in microseconds; rows are one second apart.
        let base_ts = 1_704_067_200_000_000_i64;
        let ts_array = TimestampMicrosecondArray::from_iter_values(
            (0..10).map(|i| base_ts + i * 1_000_000),
        );

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
                Arc::new(Float64Array::from_iter_values((0..10).map(|v| v as f64))),
                Arc::new(ts_array),
            ],
        )
        .unwrap();

        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, true, true, true, true, true, false, false
            ])
        );

        let expr = planner
            .parse_filter("x NOT BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
            .unwrap();
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                true, true, true, false, false, false, false, false, true, true
            ])
        );

        let expr = planner.parse_filter("y BETWEEN 2.5 AND 6.5").unwrap();
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, true, true, true, true, false, false, false
            ])
        );

        let expr = planner
            .parse_filter(
                "ts BETWEEN timestamp '2024-01-01 00:00:03' AND timestamp '2024-01-01 00:00:07'",
            )
            .unwrap();
        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        let predicates = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, true, true, true, true, true, false, false
            ])
        );
    }

    #[test]
    fn test_sql_comparison() {
        let batch: Vec<(&str, ArrayRef)> = vec![
            (
                "timestamp_s",
                Arc::new(TimestampSecondArray::from_iter_values(0..10)),
            ),
            (
                "timestamp_ms",
                Arc::new(TimestampMillisecondArray::from_iter_values(0..10)),
            ),
            (
                "timestamp_us",
                Arc::new(TimestampMicrosecondArray::from_iter_values(0..10)),
            ),
            (
                "timestamp_ns",
                Arc::new(TimestampNanosecondArray::from_iter_values(4995..5005)),
            ),
        ];
        let batch = RecordBatch::try_from_iter(batch).unwrap();

        let planner = Planner::new(batch.schema());

        let expressions = &[
            "timestamp_s >= TIMESTAMP '1970-01-01 00:00:05'",
            "timestamp_ms >= TIMESTAMP '1970-01-01 00:00:00.005'",
            "timestamp_us >= TIMESTAMP '1970-01-01 00:00:00.000005'",
            "timestamp_ns >= TIMESTAMP '1970-01-01 00:00:00.000005'",
        ];

        let expected: ArrayRef = Arc::new(BooleanArray::from_iter(
            std::iter::repeat_n(Some(false), 5).chain(std::iter::repeat_n(Some(true), 5)),
        ));
        for expression in expressions {
            let logical_expr = planner.parse_filter(expression).unwrap();
            let logical_expr = planner.optimize_expr(logical_expr).unwrap();
            let physical_expr = planner.create_physical_expr(&logical_expr).unwrap();

            let result = physical_expr.evaluate(&batch).unwrap();
            let result = result.into_array(batch.num_rows()).unwrap();
            assert_eq!(&expected, &result, "unexpected result for {}", expression);
        }
    }

    #[test]
    fn test_columns_in_expr() {
        let expr = col("s0").gt(lit("value")).and(
            col("st")
                .field("st")
                .field("s2")
                .eq(lit("value"))
                .or(col("st")
                    .field("s1")
                    .in_list(vec![lit("value 1"), lit("value 2")], false)),
        );

        let columns = Planner::column_names_in_expr(&expr);
        assert_eq!(columns, vec!["s0", "st.s1", "st.st.s2"]);
    }

    #[test]
    fn test_parse_binary_expr() {
        let bin_str = "x'616263'";

        let schema = Arc::new(Schema::new(vec![Field::new(
            "binary",
            DataType::Binary,
            true,
        )]));
        let planner = Planner::new(schema);
        let expr = planner.parse_expr(bin_str).unwrap();
        assert_eq!(
            expr,
            Expr::Literal(ScalarValue::Binary(Some(vec![b'a', b'b', b'c'])), None)
        );
    }

    #[test]
    fn test_lance_context_provider_expr_planners() {
        let ctx_provider = LanceContextProvider::default();
        assert!(!ctx_provider.get_expr_planners().is_empty());
    }

    #[test]
    fn test_regexp_match_and_non_empty_captions() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("keywords", DataType::Utf8, true),
            Field::new("natural_caption", DataType::Utf8, true),
            Field::new("poetic_caption", DataType::Utf8, true),
        ]));

        let planner = Planner::new(schema.clone());

        let expr = planner
            .parse_filter(
                "regexp_match(keywords, 'Liberty|revolution') AND \
                (natural_caption IS NOT NULL AND natural_caption <> '' AND \
                poetic_caption IS NOT NULL AND poetic_caption <> '')",
            )
            .unwrap();

        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(vec![
                    Some("Liberty for all"),
                    Some("peace"),
                    Some("revolution now"),
                    Some("Liberty"),
                    Some("revolutionary"),
                    Some("none"),
                ])) as ArrayRef,
                Arc::new(StringArray::from(vec![
                    Some("a"),
                    Some("b"),
                    None,
                    Some(""),
                    Some("c"),
                    Some("d"),
                ])) as ArrayRef,
                Arc::new(StringArray::from(vec![
                    Some("x"),
                    Some(""),
                    Some("y"),
                    Some("z"),
                    None,
                    Some("w"),
                ])) as ArrayRef,
            ],
        )
        .unwrap();

        let result = physical_expr.evaluate(&batch).unwrap();
        assert_eq!(
            result.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![true, false, false, false, false, false])
        );
    }

    #[test]
    fn test_regexp_match_infer_error_without_boolean_coercion() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("keywords", DataType::Utf8, true),
            Field::new("natural_caption", DataType::Utf8, true),
            Field::new("poetic_caption", DataType::Utf8, true),
        ]));

        let planner = Planner::new(schema);

        let expr = planner
            .parse_filter(
                "regexp_match(keywords, 'Liberty|revolution') AND \
                (natural_caption IS NOT NULL AND natural_caption <> '' AND \
                poetic_caption IS NOT NULL AND poetic_caption <> '')",
            )
            .unwrap();

        // `regexp_match` does not return a boolean, so physical planning
        // would fail unless `parse_filter` coerced the filter to a boolean
        // predicate; creating the physical expression must succeed here.
        let _physical = planner.create_physical_expr(&expr).unwrap();
    }
}