1use std::borrow::Cow;
7use std::collections::{BTreeSet, VecDeque};
8use std::sync::Arc;
9
10use crate::expr::safe_coerce_scalar;
11use crate::logical_expr::{coerce_filter_type_to_boolean, get_as_string_scalar_opt, resolve_expr};
12use crate::sql::{parse_sql_expr, parse_sql_filter};
13use arrow::compute::CastOptions;
14use arrow_array::ListArray;
15use arrow_buffer::OffsetBuffer;
16use arrow_schema::{DataType as ArrowDataType, Field, SchemaRef, TimeUnit};
17use arrow_select::concat::concat;
18use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
19use datafusion::common::DFSchema;
20use datafusion::config::ConfigOptions;
21use datafusion::error::Result as DFResult;
22use datafusion::execution::config::SessionConfig;
23use datafusion::execution::context::{SessionContext, SessionState};
24use datafusion::execution::runtime_env::RuntimeEnvBuilder;
25use datafusion::execution::session_state::SessionStateBuilder;
26use datafusion::logical_expr::expr::ScalarFunction;
27use datafusion::logical_expr::planner::{ExprPlanner, PlannerResult, RawFieldAccessExpr};
28use datafusion::logical_expr::{
29 AggregateUDF, ColumnarValue, GetFieldAccess, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
30 Signature, Volatility, WindowUDF,
31};
32use datafusion::optimizer::simplify_expressions::SimplifyContext;
33use datafusion::sql::planner::{
34 ContextProvider, NullOrdering, ParserOptions, PlannerContext, SqlToRel,
35};
36use datafusion::sql::sqlparser::ast::{
37 AccessExpr, Array as SQLArray, BinaryOperator, DataType as SQLDataType, ExactNumberInfo,
38 Expr as SQLExpr, Function, FunctionArg, FunctionArgExpr, FunctionArguments, Ident,
39 ObjectNamePart, Subscript, TimezoneInfo, UnaryOperator, Value, ValueWithSpan,
40};
41use datafusion::{
42 common::Column,
43 logical_expr::{col, Between, BinaryExpr, Like, Operator},
44 physical_expr::execution_props::ExecutionProps,
45 physical_plan::PhysicalExpr,
46 prelude::Expr,
47 scalar::ScalarValue,
48};
49use datafusion_functions::core::getfield::GetFieldFunc;
50use lance_arrow::cast::cast_with_options;
51use lance_core::datatypes::Schema;
52use lance_core::error::LanceOptionExt;
53use snafu::location;
54
55use lance_core::{Error, Result};
56
/// Scalar UDF (`_cast_list_f16`) that casts a `List`/`FixedSizeList` of
/// `Float32`/`Float16` values to the same list shape with `Float16` items.
///
/// Resolved by name in `LanceContextProvider::get_function_meta`.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
struct CastListF16Udf {
    // Accepts exactly one argument of any type; the actual list/element type
    // is validated in `return_type` / `invoke_with_args`.
    signature: Signature,
}
61
62impl CastListF16Udf {
63 pub fn new() -> Self {
64 Self {
65 signature: Signature::any(1, Volatility::Immutable),
66 }
67 }
68}
69
70impl ScalarUDFImpl for CastListF16Udf {
71 fn as_any(&self) -> &dyn std::any::Any {
72 self
73 }
74
75 fn name(&self) -> &str {
76 "_cast_list_f16"
77 }
78
79 fn signature(&self) -> &Signature {
80 &self.signature
81 }
82
83 fn return_type(&self, arg_types: &[ArrowDataType]) -> DFResult<ArrowDataType> {
84 let input = &arg_types[0];
85 match input {
86 ArrowDataType::FixedSizeList(field, size) => {
87 if field.data_type() != &ArrowDataType::Float32
88 && field.data_type() != &ArrowDataType::Float16
89 {
90 return Err(datafusion::error::DataFusionError::Execution(
91 "cast_list_f16 only supports list of float32 or float16".to_string(),
92 ));
93 }
94 Ok(ArrowDataType::FixedSizeList(
95 Arc::new(Field::new(
96 field.name(),
97 ArrowDataType::Float16,
98 field.is_nullable(),
99 )),
100 *size,
101 ))
102 }
103 ArrowDataType::List(field) => {
104 if field.data_type() != &ArrowDataType::Float32
105 && field.data_type() != &ArrowDataType::Float16
106 {
107 return Err(datafusion::error::DataFusionError::Execution(
108 "cast_list_f16 only supports list of float32 or float16".to_string(),
109 ));
110 }
111 Ok(ArrowDataType::List(Arc::new(Field::new(
112 field.name(),
113 ArrowDataType::Float16,
114 field.is_nullable(),
115 ))))
116 }
117 _ => Err(datafusion::error::DataFusionError::Execution(
118 "cast_list_f16 only supports FixedSizeList/List arguments".to_string(),
119 )),
120 }
121 }
122
123 fn invoke_with_args(&self, func_args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
124 let ColumnarValue::Array(arr) = &func_args.args[0] else {
125 return Err(datafusion::error::DataFusionError::Execution(
126 "cast_list_f16 only supports array arguments".to_string(),
127 ));
128 };
129
130 let to_type = match arr.data_type() {
131 ArrowDataType::FixedSizeList(field, size) => ArrowDataType::FixedSizeList(
132 Arc::new(Field::new(
133 field.name(),
134 ArrowDataType::Float16,
135 field.is_nullable(),
136 )),
137 *size,
138 ),
139 ArrowDataType::List(field) => ArrowDataType::List(Arc::new(Field::new(
140 field.name(),
141 ArrowDataType::Float16,
142 field.is_nullable(),
143 ))),
144 _ => {
145 return Err(datafusion::error::DataFusionError::Execution(
146 "cast_list_f16 only supports array arguments".to_string(),
147 ));
148 }
149 };
150
151 let res = cast_with_options(arr.as_ref(), &to_type, &CastOptions::default())?;
152 Ok(ColumnarValue::Array(res))
153 }
154}
155
/// `ContextProvider` used when planning SQL expressions against a Lance
/// schema without a user-supplied DataFusion session.
struct LanceContextProvider {
    // Options handed back verbatim from `ContextProvider::options`.
    options: datafusion::config::ConfigOptions,
    // Session state used for scalar/aggregate/window function lookups.
    state: SessionState,
    // Planners used to lower field-access (struct member / list index) exprs.
    expr_planners: Vec<Arc<dyn ExprPlanner>>,
}
162
impl Default for LanceContextProvider {
    fn default() -> Self {
        let config = SessionConfig::new();
        let runtime = RuntimeEnvBuilder::new().build_arc().unwrap();

        // Build a real session context so Lance's custom functions are
        // registered into the state we hand out for function lookups.
        let ctx = SessionContext::new_with_config_rt(config.clone(), runtime.clone());
        crate::udf::register_functions(&ctx);

        let state = ctx.state();

        // A separate state builder is used only to obtain the default set of
        // expression planners (`with_default_features` installs them).
        let mut state_builder = SessionStateBuilder::new()
            .with_config(config)
            .with_runtime_env(runtime)
            .with_default_features();

        // NOTE(review): `unwrap` assumes `with_default_features()` always
        // populates the planners — appears to hold, but verify on DataFusion
        // upgrades.
        let expr_planners = state_builder.expr_planners().as_ref().unwrap().clone();

        Self {
            // NOTE(review): fresh defaults rather than options derived from
            // `config` — presumably intentional; confirm if session-level
            // options should be reflected here.
            options: ConfigOptions::default(),
            state,
            expr_planners,
        }
    }
}
189
190impl ContextProvider for LanceContextProvider {
191 fn get_table_source(
192 &self,
193 name: datafusion::sql::TableReference,
194 ) -> DFResult<Arc<dyn datafusion::logical_expr::TableSource>> {
195 Err(datafusion::error::DataFusionError::NotImplemented(format!(
196 "Attempt to reference inner table {} not supported",
197 name
198 )))
199 }
200
201 fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
202 self.state.aggregate_functions().get(name).cloned()
203 }
204
205 fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
206 self.state.window_functions().get(name).cloned()
207 }
208
209 fn get_function_meta(&self, f: &str) -> Option<Arc<ScalarUDF>> {
210 match f {
211 "_cast_list_f16" => Some(Arc::new(ScalarUDF::new_from_impl(CastListF16Udf::new()))),
214 _ => self.state.scalar_functions().get(f).cloned(),
215 }
216 }
217
218 fn get_variable_type(&self, _: &[String]) -> Option<ArrowDataType> {
219 None
221 }
222
223 fn options(&self) -> &datafusion::config::ConfigOptions {
224 &self.options
225 }
226
227 fn udf_names(&self) -> Vec<String> {
228 self.state.scalar_functions().keys().cloned().collect()
229 }
230
231 fn udaf_names(&self) -> Vec<String> {
232 self.state.aggregate_functions().keys().cloned().collect()
233 }
234
235 fn udwf_names(&self) -> Vec<String> {
236 self.state.window_functions().keys().cloned().collect()
237 }
238
239 fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
240 &self.expr_planners
241 }
242}
243
/// Parses SQL filter/scalar expression strings into DataFusion logical
/// expressions resolved against a Lance schema.
pub struct Planner {
    // Arrow schema that column references are resolved against.
    schema: SchemaRef,
    // Function lookup + expression planners for SQL planning.
    context_provider: LanceContextProvider,
    // When true, the first identifier of a compound name (`rel.col`) is
    // treated as a relation qualifier rather than a struct field access.
    enable_relations: bool,
}
249
250impl Planner {
251 pub fn new(schema: SchemaRef) -> Self {
252 Self {
253 schema,
254 context_provider: LanceContextProvider::default(),
255 enable_relations: false,
256 }
257 }
258
259 pub fn with_enable_relations(mut self, enable_relations: bool) -> Self {
265 self.enable_relations = enable_relations;
266 self
267 }
268
269 fn column(&self, idents: &[Ident]) -> Expr {
270 fn handle_remaining_idents(expr: &mut Expr, idents: &[Ident]) {
271 for ident in idents {
272 *expr = Expr::ScalarFunction(ScalarFunction {
273 args: vec![
274 std::mem::take(expr),
275 Expr::Literal(ScalarValue::Utf8(Some(ident.value.clone())), None),
276 ],
277 func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
278 });
279 }
280 }
281
282 if self.enable_relations && idents.len() > 1 {
283 let relation = &idents[0].value;
285 let column_name = &idents[1].value;
286 let column = Expr::Column(Column::new(Some(relation.clone()), column_name.clone()));
287 let mut result = column;
288 handle_remaining_idents(&mut result, &idents[2..]);
289 result
290 } else {
291 let mut column = col(&idents[0].value);
293 handle_remaining_idents(&mut column, &idents[1..]);
294 column
295 }
296 }
297
298 fn binary_op(&self, op: &BinaryOperator) -> Result<Operator> {
299 Ok(match op {
300 BinaryOperator::Plus => Operator::Plus,
301 BinaryOperator::Minus => Operator::Minus,
302 BinaryOperator::Multiply => Operator::Multiply,
303 BinaryOperator::Divide => Operator::Divide,
304 BinaryOperator::Modulo => Operator::Modulo,
305 BinaryOperator::StringConcat => Operator::StringConcat,
306 BinaryOperator::Gt => Operator::Gt,
307 BinaryOperator::Lt => Operator::Lt,
308 BinaryOperator::GtEq => Operator::GtEq,
309 BinaryOperator::LtEq => Operator::LtEq,
310 BinaryOperator::Eq => Operator::Eq,
311 BinaryOperator::NotEq => Operator::NotEq,
312 BinaryOperator::And => Operator::And,
313 BinaryOperator::Or => Operator::Or,
314 _ => {
315 return Err(Error::invalid_input(
316 format!("Operator {op} is not supported"),
317 location!(),
318 ));
319 }
320 })
321 }
322
323 fn binary_expr(&self, left: &SQLExpr, op: &BinaryOperator, right: &SQLExpr) -> Result<Expr> {
324 Ok(Expr::BinaryExpr(BinaryExpr::new(
325 Box::new(self.parse_sql_expr(left)?),
326 self.binary_op(op)?,
327 Box::new(self.parse_sql_expr(right)?),
328 )))
329 }
330
331 fn unary_expr(&self, op: &UnaryOperator, expr: &SQLExpr) -> Result<Expr> {
332 Ok(match op {
333 UnaryOperator::Not | UnaryOperator::PGBitwiseNot => {
334 Expr::Not(Box::new(self.parse_sql_expr(expr)?))
335 }
336
337 UnaryOperator::Minus => {
338 use datafusion::logical_expr::lit;
339 match expr {
340 SQLExpr::Value(ValueWithSpan { value: Value::Number(n, _), ..}) => match n.parse::<i64>() {
341 Ok(n) => lit(-n),
342 Err(_) => lit(-n
343 .parse::<f64>()
344 .map_err(|_e| {
345 Error::invalid_input(
346 format!("negative operator can be only applied to integer and float operands, got: {n}"),
347 location!(),
348 )
349 })?),
350 },
351 _ => {
352 Expr::Negative(Box::new(self.parse_sql_expr(expr)?))
353 }
354 }
355 }
356
357 _ => {
358 return Err(Error::invalid_input(
359 format!("Unary operator '{:?}' is not supported", op),
360 location!(),
361 ));
362 }
363 })
364 }
365
366 fn number(&self, value: &str, negative: bool) -> Result<Expr> {
368 use datafusion::logical_expr::lit;
369 let value: Cow<str> = if negative {
370 Cow::Owned(format!("-{}", value))
371 } else {
372 Cow::Borrowed(value)
373 };
374 if let Ok(n) = value.parse::<i64>() {
375 Ok(lit(n))
376 } else {
377 value.parse::<f64>().map(lit).map_err(|_| {
378 Error::invalid_input(
379 format!("'{value}' is not supported number value."),
380 location!(),
381 )
382 })
383 }
384 }
385
386 fn value(&self, value: &Value) -> Result<Expr> {
387 Ok(match value {
388 Value::Number(v, _) => self.number(v.as_str(), false)?,
389 Value::SingleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone())), None),
390 Value::HexStringLiteral(hsl) => {
391 Expr::Literal(ScalarValue::Binary(Self::try_decode_hex_literal(hsl)), None)
392 }
393 Value::DoubleQuotedString(s) => Expr::Literal(ScalarValue::Utf8(Some(s.clone())), None),
394 Value::Boolean(v) => Expr::Literal(ScalarValue::Boolean(Some(*v)), None),
395 Value::Null => Expr::Literal(ScalarValue::Null, None),
396 _ => todo!(),
397 })
398 }
399
400 fn parse_function_args(&self, func_args: &FunctionArg) -> Result<Expr> {
401 match func_args {
402 FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => self.parse_sql_expr(expr),
403 _ => Err(Error::invalid_input(
404 format!("Unsupported function args: {:?}", func_args),
405 location!(),
406 )),
407 }
408 }
409
410 fn legacy_parse_function(&self, func: &Function) -> Result<Expr> {
417 match &func.args {
418 FunctionArguments::List(args) => {
419 if func.name.0.len() != 1 {
420 return Err(Error::invalid_input(
421 format!("Function name must have 1 part, got: {:?}", func.name.0),
422 location!(),
423 ));
424 }
425 Ok(Expr::IsNotNull(Box::new(
426 self.parse_function_args(&args.args[0])?,
427 )))
428 }
429 _ => Err(Error::invalid_input(
430 format!("Unsupported function args: {:?}", &func.args),
431 location!(),
432 )),
433 }
434 }
435
436 fn parse_function(&self, function: SQLExpr) -> Result<Expr> {
437 if let SQLExpr::Function(function) = &function {
438 if let Some(ObjectNamePart::Identifier(name)) = &function.name.0.first() {
439 if &name.value == "is_valid" {
440 return self.legacy_parse_function(function);
441 }
442 }
443 }
444 let sql_to_rel = SqlToRel::new_with_options(
445 &self.context_provider,
446 ParserOptions {
447 parse_float_as_decimal: false,
448 enable_ident_normalization: false,
449 support_varchar_with_length: false,
450 enable_options_value_normalization: false,
451 collect_spans: false,
452 map_string_types_to_utf8view: false,
453 default_null_ordering: NullOrdering::NullsMax,
454 },
455 );
456
457 let mut planner_context = PlannerContext::default();
458 let schema = DFSchema::try_from(self.schema.as_ref().clone())?;
459 sql_to_rel
460 .sql_to_expr(function, &schema, &mut planner_context)
461 .map_err(|e| Error::invalid_input(format!("Error parsing function: {e}"), location!()))
462 }
463
464 fn parse_type(&self, data_type: &SQLDataType) -> Result<ArrowDataType> {
465 const SUPPORTED_TYPES: [&str; 13] = [
466 "int [unsigned]",
467 "tinyint [unsigned]",
468 "smallint [unsigned]",
469 "bigint [unsigned]",
470 "float",
471 "double",
472 "string",
473 "binary",
474 "date",
475 "timestamp(precision)",
476 "datetime(precision)",
477 "decimal(precision,scale)",
478 "boolean",
479 ];
480 match data_type {
481 SQLDataType::String(_) => Ok(ArrowDataType::Utf8),
482 SQLDataType::Binary(_) => Ok(ArrowDataType::Binary),
483 SQLDataType::Float(_) => Ok(ArrowDataType::Float32),
484 SQLDataType::Double(_) => Ok(ArrowDataType::Float64),
485 SQLDataType::Boolean => Ok(ArrowDataType::Boolean),
486 SQLDataType::TinyInt(_) => Ok(ArrowDataType::Int8),
487 SQLDataType::SmallInt(_) => Ok(ArrowDataType::Int16),
488 SQLDataType::Int(_) | SQLDataType::Integer(_) => Ok(ArrowDataType::Int32),
489 SQLDataType::BigInt(_) => Ok(ArrowDataType::Int64),
490 SQLDataType::TinyIntUnsigned(_) => Ok(ArrowDataType::UInt8),
491 SQLDataType::SmallIntUnsigned(_) => Ok(ArrowDataType::UInt16),
492 SQLDataType::IntUnsigned(_) | SQLDataType::IntegerUnsigned(_) => {
493 Ok(ArrowDataType::UInt32)
494 }
495 SQLDataType::BigIntUnsigned(_) => Ok(ArrowDataType::UInt64),
496 SQLDataType::Date => Ok(ArrowDataType::Date32),
497 SQLDataType::Timestamp(resolution, tz) => {
498 match tz {
499 TimezoneInfo::None => {}
500 _ => {
501 return Err(Error::invalid_input(
502 "Timezone not supported in timestamp".to_string(),
503 location!(),
504 ));
505 }
506 };
507 let time_unit = match resolution {
508 None => TimeUnit::Microsecond,
510 Some(0) => TimeUnit::Second,
511 Some(3) => TimeUnit::Millisecond,
512 Some(6) => TimeUnit::Microsecond,
513 Some(9) => TimeUnit::Nanosecond,
514 _ => {
515 return Err(Error::invalid_input(
516 format!("Unsupported datetime resolution: {:?}", resolution),
517 location!(),
518 ));
519 }
520 };
521 Ok(ArrowDataType::Timestamp(time_unit, None))
522 }
523 SQLDataType::Datetime(resolution) => {
524 let time_unit = match resolution {
525 None => TimeUnit::Microsecond,
526 Some(0) => TimeUnit::Second,
527 Some(3) => TimeUnit::Millisecond,
528 Some(6) => TimeUnit::Microsecond,
529 Some(9) => TimeUnit::Nanosecond,
530 _ => {
531 return Err(Error::invalid_input(
532 format!("Unsupported datetime resolution: {:?}", resolution),
533 location!(),
534 ));
535 }
536 };
537 Ok(ArrowDataType::Timestamp(time_unit, None))
538 }
539 SQLDataType::Decimal(number_info) => match number_info {
540 ExactNumberInfo::PrecisionAndScale(precision, scale) => {
541 Ok(ArrowDataType::Decimal128(*precision as u8, *scale as i8))
542 }
543 _ => Err(Error::invalid_input(
544 format!(
545 "Must provide precision and scale for decimal: {:?}",
546 number_info
547 ),
548 location!(),
549 )),
550 },
551 _ => Err(Error::invalid_input(
552 format!(
553 "Unsupported data type: {:?}. Supported types: {:?}",
554 data_type, SUPPORTED_TYPES
555 ),
556 location!(),
557 )),
558 }
559 }
560
561 fn plan_field_access(&self, mut field_access_expr: RawFieldAccessExpr) -> Result<Expr> {
562 let df_schema = DFSchema::try_from(self.schema.as_ref().clone())?;
563 for planner in self.context_provider.get_expr_planners() {
564 match planner.plan_field_access(field_access_expr, &df_schema)? {
565 PlannerResult::Planned(expr) => return Ok(expr),
566 PlannerResult::Original(expr) => {
567 field_access_expr = expr;
568 }
569 }
570 }
571 Err(Error::invalid_input(
572 "Field access could not be planned",
573 location!(),
574 ))
575 }
576
    /// Recursively converts a sqlparser AST expression into a DataFusion
    /// logical [`Expr`], applying Lance-specific conventions (quoted
    /// identifiers, array literals, legacy functions).
    ///
    /// Returns an invalid-input error for any SQL construct Lance does not
    /// support.
    fn parse_sql_expr(&self, expr: &SQLExpr) -> Result<Expr> {
        match expr {
            SQLExpr::Identifier(id) => {
                // Lance convention: double-quoted identifiers are string
                // literals; backtick-quoted identifiers are verbatim column
                // names (no struct-field splitting on '.'); bare identifiers
                // go through `column` (which handles field access).
                if id.quote_style == Some('"') {
                    Ok(Expr::Literal(
                        ScalarValue::Utf8(Some(id.value.clone())),
                        None,
                    ))
                } else if id.quote_style == Some('`') {
                    Ok(Expr::Column(Column::from_name(id.value.clone())))
                } else {
                    Ok(self.column(vec![id.clone()].as_slice()))
                }
            }
            SQLExpr::CompoundIdentifier(ids) => Ok(self.column(ids.as_slice())),
            SQLExpr::BinaryOp { left, op, right } => self.binary_expr(left, op, right),
            SQLExpr::UnaryOp { op, expr } => self.unary_expr(op, expr),
            SQLExpr::Value(value) => self.value(&value.value),
            // Array literals: every element must be a literal (optionally a
            // negated number); elements are coerced to the datatype of the
            // first element and materialized as a single-row ListArray.
            SQLExpr::Array(SQLArray { elem, .. }) => {
                let mut values = vec![];

                let array_literal_error = |pos: usize, value: &_| {
                    Err(Error::invalid_input(
                        format!(
                            "Expected a literal value in array, instead got {} at position {}",
                            value, pos
                        ),
                        location!(),
                    ))
                };

                for (pos, expr) in elem.iter().enumerate() {
                    match expr {
                        SQLExpr::Value(value) => {
                            if let Expr::Literal(value, _) = self.value(&value.value)? {
                                values.push(value);
                            } else {
                                return array_literal_error(pos, expr);
                            }
                        }
                        // Negative numeric literals arrive as unary minus
                        // wrapped around a number.
                        SQLExpr::UnaryOp {
                            op: UnaryOperator::Minus,
                            expr,
                        } => {
                            if let SQLExpr::Value(ValueWithSpan {
                                value: Value::Number(number, _),
                                ..
                            }) = expr.as_ref()
                            {
                                if let Expr::Literal(value, _) = self.number(number, true)? {
                                    values.push(value);
                                } else {
                                    return array_literal_error(pos, expr);
                                }
                            } else {
                                return array_literal_error(pos, expr);
                            }
                        }
                        _ => {
                            return array_literal_error(pos, expr);
                        }
                    }
                }

                // Determine the item type and coerce all elements to it.
                let field = if !values.is_empty() {
                    let data_type = values[0].data_type();

                    for value in &mut values {
                        if value.data_type() != data_type {
                            *value = safe_coerce_scalar(value, &data_type).ok_or_else(|| Error::invalid_input(
                                format!("Array expressions must have a consistent datatype. Expected: {}, got: {}", data_type, value.data_type()),
                                location!()
                            ))?;
                        }
                    }
                    Field::new("item", data_type, true)
                } else {
                    // Empty array literal: item type is Null.
                    Field::new("item", ArrowDataType::Null, true)
                };

                // Build one flat array from the scalars, then wrap it in a
                // single-offset ListArray to form the list literal.
                let values = values
                    .into_iter()
                    .map(|v| v.to_array().map_err(Error::from))
                    .collect::<Result<Vec<_>>>()?;
                let array_refs = values.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
                let values = concat(&array_refs)?;
                let values = ListArray::try_new(
                    field.into(),
                    OffsetBuffer::from_lengths([values.len()]),
                    values,
                    None,
                )?;

                Ok(Expr::Literal(ScalarValue::List(Arc::new(values)), None))
            }
            // e.g. `timestamp '2020-01-01'` becomes CAST('2020-01-01' AS ...).
            SQLExpr::TypedString { data_type, value } => {
                let value = value.clone().into_string().expect_ok()?;
                Ok(Expr::Cast(datafusion::logical_expr::Cast {
                    expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some(value)), None)),
                    data_type: self.parse_type(data_type)?,
                }))
            }
            SQLExpr::IsFalse(expr) => Ok(Expr::IsFalse(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotFalse(expr) => Ok(Expr::IsNotFalse(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsTrue(expr) => Ok(Expr::IsTrue(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotTrue(expr) => Ok(Expr::IsNotTrue(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNull(expr) => Ok(Expr::IsNull(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::IsNotNull(expr) => Ok(Expr::IsNotNull(Box::new(self.parse_sql_expr(expr)?))),
            SQLExpr::InList {
                expr,
                list,
                negated,
            } => {
                let value_expr = self.parse_sql_expr(expr)?;
                let list_exprs = list
                    .iter()
                    .map(|e| self.parse_sql_expr(e))
                    .collect::<Result<Vec<_>>>()?;
                Ok(value_expr.in_list(list_exprs, *negated))
            }
            SQLExpr::Nested(inner) => self.parse_sql_expr(inner.as_ref()),
            SQLExpr::Function(_) => self.parse_function(expr.clone()),
            // ILIKE and LIKE share the same structure; they differ only in
            // the final case-insensitivity flag passed to `Like::new`.
            SQLExpr::ILike {
                negated,
                expr,
                pattern,
                escape_char,
                any: _,
            } => Ok(Expr::Like(Like::new(
                *negated,
                Box::new(self.parse_sql_expr(expr)?),
                Box::new(self.parse_sql_expr(pattern)?),
                match escape_char {
                    Some(Value::SingleQuotedString(char)) => char.chars().next(),
                    Some(value) => return Err(Error::invalid_input(
                        format!("Invalid escape character in LIKE expression. Expected a single character wrapped with single quotes, got {}", value),
                        location!()
                    )),
                    None => None,
                },
                true,
            ))),
            SQLExpr::Like {
                negated,
                expr,
                pattern,
                escape_char,
                any: _,
            } => Ok(Expr::Like(Like::new(
                *negated,
                Box::new(self.parse_sql_expr(expr)?),
                Box::new(self.parse_sql_expr(pattern)?),
                match escape_char {
                    Some(Value::SingleQuotedString(char)) => char.chars().next(),
                    Some(value) => return Err(Error::invalid_input(
                        format!("Invalid escape character in LIKE expression. Expected a single character wrapped with single quotes, got {}", value),
                        location!()
                    )),
                    None => None,
                },
                false,
            ))),
            // TRY_CAST/SAFE_CAST produce a TryCast (NULL on failure); plain
            // CAST errors at runtime on failure.
            SQLExpr::Cast {
                expr,
                data_type,
                kind,
                ..
            } => match kind {
                datafusion::sql::sqlparser::ast::CastKind::TryCast
                | datafusion::sql::sqlparser::ast::CastKind::SafeCast => {
                    Ok(Expr::TryCast(datafusion::logical_expr::TryCast {
                        expr: Box::new(self.parse_sql_expr(expr)?),
                        data_type: self.parse_type(data_type)?,
                    }))
                }
                _ => Ok(Expr::Cast(datafusion::logical_expr::Cast {
                    expr: Box::new(self.parse_sql_expr(expr)?),
                    data_type: self.parse_type(data_type)?,
                })),
            },
            SQLExpr::JsonAccess { .. } => Err(Error::invalid_input(
                "JSON access is not supported",
                location!(),
            )),
            // Chained struct-field / list-index access, e.g. `a.b['c'][0]`.
            // Each access step is planned via the registered expr planners.
            SQLExpr::CompoundFieldAccess { root, access_chain } => {
                let mut expr = self.parse_sql_expr(root)?;

                for access in access_chain {
                    let field_access = match access {
                        // Dot access and string subscripts both name a struct
                        // field.
                        AccessExpr::Dot(SQLExpr::Identifier(Ident { value: s, .. }))
                        | AccessExpr::Subscript(Subscript::Index {
                            index:
                                SQLExpr::Value(ValueWithSpan {
                                    value:
                                        Value::SingleQuotedString(s) | Value::DoubleQuotedString(s),
                                    ..
                                }),
                        }) => GetFieldAccess::NamedStructField {
                            name: ScalarValue::from(s.as_str()),
                        },
                        // Any other subscript expression is a list index.
                        AccessExpr::Subscript(Subscript::Index { index }) => {
                            let key = Box::new(self.parse_sql_expr(index)?);
                            GetFieldAccess::ListIndex { key }
                        }
                        AccessExpr::Subscript(Subscript::Slice { .. }) => {
                            return Err(Error::invalid_input(
                                "Slice subscript is not supported",
                                location!(),
                            ));
                        }
                        _ => {
                            return Err(Error::invalid_input(
                                "Only dot notation or index access is supported for field access",
                                location!(),
                            ));
                        }
                    };

                    let field_access_expr = RawFieldAccessExpr { expr, field_access };
                    expr = self.plan_field_access(field_access_expr)?;
                }

                Ok(expr)
            }
            SQLExpr::Between {
                expr,
                negated,
                low,
                high,
            } => {
                let expr = self.parse_sql_expr(expr)?;
                let low = self.parse_sql_expr(low)?;
                let high = self.parse_sql_expr(high)?;

                let between = Expr::Between(Between::new(
                    Box::new(expr),
                    *negated,
                    Box::new(low),
                    Box::new(high),
                ));
                Ok(between)
            }
            _ => Err(Error::invalid_input(
                format!("Expression '{expr}' is not supported SQL in lance"),
                location!(),
            )),
        }
    }
834
835 pub fn parse_filter(&self, filter: &str) -> Result<Expr> {
840 let ast_expr = parse_sql_filter(filter)?;
842 let expr = self.parse_sql_expr(&ast_expr)?;
843 let schema = Schema::try_from(self.schema.as_ref())?;
844 let resolved = resolve_expr(&expr, &schema).map_err(|e| {
845 Error::invalid_input(
846 format!("Error resolving filter expression {filter}: {e}"),
847 location!(),
848 )
849 })?;
850
851 Ok(coerce_filter_type_to_boolean(resolved))
852 }
853
854 pub fn parse_expr(&self, expr: &str) -> Result<Expr> {
859 if self.schema.field_with_name(expr).is_ok() {
860 return Ok(col(expr));
861 }
862
863 let ast_expr = parse_sql_expr(expr)?;
864 let expr = self.parse_sql_expr(&ast_expr)?;
865 let schema = Schema::try_from(self.schema.as_ref())?;
866 let resolved = resolve_expr(&expr, &schema)?;
867 Ok(resolved)
868 }
869
870 fn try_decode_hex_literal(s: &str) -> Option<Vec<u8>> {
876 let hex_bytes = s.as_bytes();
877 let mut decoded_bytes = Vec::with_capacity(hex_bytes.len().div_ceil(2));
878
879 let start_idx = hex_bytes.len() % 2;
880 if start_idx > 0 {
881 decoded_bytes.push(Self::try_decode_hex_char(hex_bytes[0])?);
883 }
884
885 for i in (start_idx..hex_bytes.len()).step_by(2) {
886 let high = Self::try_decode_hex_char(hex_bytes[i])?;
887 let low = Self::try_decode_hex_char(hex_bytes[i + 1])?;
888 decoded_bytes.push((high << 4) | low);
889 }
890
891 Some(decoded_bytes)
892 }
893
894 const fn try_decode_hex_char(c: u8) -> Option<u8> {
898 match c {
899 b'A'..=b'F' => Some(c - b'A' + 10),
900 b'a'..=b'f' => Some(c - b'a' + 10),
901 b'0'..=b'9' => Some(c - b'0'),
902 _ => None,
903 }
904 }
905
906 pub fn optimize_expr(&self, expr: Expr) -> Result<Expr> {
908 let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
909
910 let props = ExecutionProps::default();
913 let simplify_context = SimplifyContext::new(&props).with_schema(df_schema.clone());
914 let simplifier =
915 datafusion::optimizer::simplify_expressions::ExprSimplifier::new(simplify_context);
916
917 let expr = simplifier.simplify(expr)?;
918 let expr = simplifier.coerce(expr, &df_schema)?;
919
920 Ok(expr)
921 }
922
923 pub fn create_physical_expr(&self, expr: &Expr) -> Result<Arc<dyn PhysicalExpr>> {
925 let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?);
926 Ok(datafusion::physical_expr::create_physical_expr(
927 expr,
928 df_schema.as_ref(),
929 &Default::default(),
930 )?)
931 }
932
933 pub fn column_names_in_expr(expr: &Expr) -> Vec<String> {
940 let mut visitor = ColumnCapturingVisitor {
941 current_path: VecDeque::new(),
942 columns: BTreeSet::new(),
943 };
944 expr.visit(&mut visitor).unwrap();
945 visitor.columns.into_iter().collect()
946 }
947}
948
/// Tree visitor that records every column referenced by an expression,
/// reassembling `get_field` chains into dotted paths (e.g. `st.x`).
struct ColumnCapturingVisitor {
    // Field-name segments accumulated while descending through nested
    // `get_field` calls, front = outermost field.
    current_path: VecDeque<String>,
    // Sorted, de-duplicated set of captured column paths.
    columns: BTreeSet<String>,
}
954
impl TreeNodeVisitor<'_> for ColumnCapturingVisitor {
    type Node = Expr;

    fn f_down(&mut self, node: &Self::Node) -> DFResult<TreeNodeRecursion> {
        match node {
            // Reached the column at the bottom of a (possibly empty) chain of
            // `get_field` calls: join the pending segments onto the column
            // name to form a dotted path, then record it.
            Expr::Column(Column { name, .. }) => {
                let mut path = name.clone();
                for part in self.current_path.drain(..) {
                    path.push('.');
                    // Segments containing '.' or '`' are backtick-quoted,
                    // with embedded backticks doubled, so the dotted path
                    // stays unambiguous.
                    if part.contains('.') || part.contains('`') {
                        let escaped = part.replace('`', "``");
                        path.push('`');
                        path.push_str(&escaped);
                        path.push('`');
                    } else {
                        path.push_str(&part);
                    }
                }
                self.columns.insert(path);
                self.current_path.clear();
            }
            // A `get_field(child, "name")` call: prepend the field name so
            // outer accesses end up after inner ones in the final path.
            Expr::ScalarFunction(udf) => {
                if udf.name() == GetFieldFunc::default().name() {
                    if let Some(name) = get_as_string_scalar_opt(&udf.args[1]) {
                        self.current_path.push_front(name.to_string())
                    } else {
                        // Non-literal field name: abandon the pending path.
                        self.current_path.clear();
                    }
                } else {
                    self.current_path.clear();
                }
            }
            // Any other node breaks a `get_field` chain.
            _ => {
                self.current_path.clear();
            }
        }

        Ok(TreeNodeRecursion::Continue)
    }
}
1000
1001#[cfg(test)]
1002mod tests {
1003
1004 use crate::logical_expr::ExprExt;
1005
1006 use super::*;
1007
1008 use arrow::datatypes::Float64Type;
1009 use arrow_array::{
1010 ArrayRef, BooleanArray, Float32Array, Int32Array, Int64Array, RecordBatch, StringArray,
1011 StructArray, TimestampMicrosecondArray, TimestampMillisecondArray,
1012 TimestampNanosecondArray, TimestampSecondArray,
1013 };
1014 use arrow_schema::{DataType, Fields, Schema};
1015 use datafusion::{
1016 logical_expr::{lit, Cast},
1017 prelude::{array_element, get_field},
1018 };
1019 use datafusion_functions::core::expr_ext::FieldAccessor;
1020
    // Verifies filter parsing (comparisons, struct field access, IN lists,
    // `==` vs `=`) and end-to-end physical evaluation on a record batch.
    #[test]
    fn test_parse_filter_simple() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("i", DataType::Int32, false),
            Field::new("s", DataType::Utf8, true),
            Field::new(
                "st",
                DataType::Struct(Fields::from(vec![
                    Field::new("x", DataType::Float32, false),
                    Field::new("y", DataType::Float32, false),
                ])),
                true,
            ),
        ]));

        let planner = Planner::new(schema.clone());

        let expected = col("i")
            .gt(lit(3_i32))
            .and(col("st").field_newstyle("x").lt_eq(lit(5.0_f32)))
            .and(
                col("s")
                    .eq(lit("str-4"))
                    .or(col("s").in_list(vec![lit("str-4"), lit("str-5")], false)),
            );

        // `==` should parse as equality.
        let expr = planner
            .parse_filter("i > 3 AND st.x <= 5.0 AND (s == 'str-4' OR s in ('str-4', 'str-5'))")
            .unwrap();
        assert_eq!(expr, expected);

        // `=` should behave identically to `==`.
        let expr = planner
            .parse_filter("i > 3 AND st.x <= 5.0 AND (s = 'str-4' OR s in ('str-4', 'str-5'))")
            .unwrap();

        let physical_expr = planner.create_physical_expr(&expr).unwrap();

        // Rows 0..10 with i = n, s = "str-n", st.x = n, st.y = 10n.
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
                Arc::new(StringArray::from_iter_values(
                    (0..10).map(|v| format!("str-{}", v)),
                )),
                Arc::new(StructArray::from(vec![
                    (
                        Arc::new(Field::new("x", DataType::Float32, false)),
                        Arc::new(Float32Array::from_iter_values((0..10).map(|v| v as f32)))
                            as ArrayRef,
                    ),
                    (
                        Arc::new(Field::new("y", DataType::Float32, false)),
                        Arc::new(Float32Array::from_iter_values(
                            (0..10).map(|v| (v * 10) as f32),
                        )),
                    ),
                ])),
            ],
        )
        .unwrap();
        let predicates = physical_expr.evaluate(&batch).unwrap();
        // Only rows 4 and 5 satisfy i > 3, st.x <= 5.0, s in {str-4, str-5}.
        assert_eq!(
            predicates.into_array(0).unwrap().as_ref(),
            &BooleanArray::from(vec![
                false, false, false, false, true, true, false, false, false, false
            ])
        );
    }
1091
    // Verifies that bare, backticked, and subscript syntaxes for nested
    // struct columns all parse to the same `get_field` expression tree.
    #[test]
    fn test_nested_col_refs() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("s0", DataType::Utf8, true),
            Field::new(
                "st",
                DataType::Struct(Fields::from(vec![
                    Field::new("s1", DataType::Utf8, true),
                    Field::new(
                        "st",
                        DataType::Struct(Fields::from(vec![Field::new(
                            "s2",
                            DataType::Utf8,
                            true,
                        )])),
                        true,
                    ),
                ])),
                true,
            ),
        ]));

        let planner = Planner::new(schema);

        // Parses `{expr} = 'val'` and asserts the left side matches
        // `expected`.
        fn assert_column_eq(planner: &Planner, expr: &str, expected: &Expr) {
            let expr = planner.parse_filter(&format!("{expr} = 'val'")).unwrap();
            assert!(matches!(
                expr,
                Expr::BinaryExpr(BinaryExpr {
                    left: _,
                    op: Operator::Eq,
                    right: _
                })
            ));
            if let Expr::BinaryExpr(BinaryExpr { left, .. }) = expr {
                assert_eq!(left.as_ref(), expected);
            }
        }

        // Top-level column, bare and backticked.
        let expected = Expr::Column(Column::new_unqualified("s0"));
        assert_column_eq(&planner, "s0", &expected);
        assert_column_eq(&planner, "`s0`", &expected);

        // One level of struct field access.
        let expected = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
            args: vec![
                Expr::Column(Column::new_unqualified("st")),
                Expr::Literal(ScalarValue::Utf8(Some("s1".to_string())), None),
            ],
        });
        assert_column_eq(&planner, "st.s1", &expected);
        assert_column_eq(&planner, "`st`.`s1`", &expected);
        assert_column_eq(&planner, "st.`s1`", &expected);

        // Two levels of nesting: get_field(get_field(st, "st"), "s2").
        let expected = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
            args: vec![
                Expr::ScalarFunction(ScalarFunction {
                    func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
                    args: vec![
                        Expr::Column(Column::new_unqualified("st")),
                        Expr::Literal(ScalarValue::Utf8(Some("st".to_string())), None),
                    ],
                }),
                Expr::Literal(ScalarValue::Utf8(Some("s2".to_string())), None),
            ],
        });

        assert_column_eq(&planner, "st.st.s2", &expected);
        assert_column_eq(&planner, "`st`.`st`.`s2`", &expected);
        assert_column_eq(&planner, "st.st.`s2`", &expected);
        // Subscript syntax should plan to the same expression.
        assert_column_eq(&planner, "st['st'][\"s2\"]", &expected);
    }
1165
1166 #[test]
1167 fn test_nested_list_refs() {
1168 let schema = Arc::new(Schema::new(vec![Field::new(
1169 "l",
1170 DataType::List(Arc::new(Field::new(
1171 "item",
1172 DataType::Struct(Fields::from(vec![Field::new("f1", DataType::Utf8, true)])),
1173 true,
1174 ))),
1175 true,
1176 )]));
1177
1178 let planner = Planner::new(schema);
1179
1180 let expected = array_element(col("l"), lit(0_i64));
1181 let expr = planner.parse_expr("l[0]").unwrap();
1182 assert_eq!(expr, expected);
1183
1184 let expected = get_field(array_element(col("l"), lit(0_i64)), "f1");
1185 let expr = planner.parse_expr("l[0]['f1']").unwrap();
1186 assert_eq!(expr, expected);
1187
1188 }
1193
1194 #[test]
1195 fn test_negative_expressions() {
1196 let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
1197
1198 let planner = Planner::new(schema.clone());
1199
1200 let expected = col("x")
1201 .gt(lit(-3_i64))
1202 .and(col("x").lt(-(lit(-5_i64) + lit(3_i64))));
1203
1204 let expr = planner.parse_filter("x > -3 AND x < -(-5 + 3)").unwrap();
1205
1206 assert_eq!(expr, expected);
1207
1208 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1209
1210 let batch = RecordBatch::try_new(
1211 schema,
1212 vec![Arc::new(Int64Array::from_iter_values(-5..5)) as ArrayRef],
1213 )
1214 .unwrap();
1215 let predicates = physical_expr.evaluate(&batch).unwrap();
1216 assert_eq!(
1217 predicates.into_array(0).unwrap().as_ref(),
1218 &BooleanArray::from(vec![
1219 false, false, false, true, true, true, true, false, false, false
1220 ])
1221 );
1222 }
1223
1224 #[test]
1225 fn test_negative_array_expressions() {
1226 let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
1227
1228 let planner = Planner::new(schema);
1229
1230 let expected = Expr::Literal(
1231 ScalarValue::List(Arc::new(
1232 ListArray::from_iter_primitive::<Float64Type, _, _>(vec![Some(
1233 [-1_f64, -2.0, -3.0, -4.0, -5.0].map(Some),
1234 )]),
1235 )),
1236 None,
1237 );
1238
1239 let expr = planner
1240 .parse_expr("[-1.0, -2.0, -3.0, -4.0, -5.0]")
1241 .unwrap();
1242
1243 assert_eq!(expr, expected);
1244 }
1245
1246 #[test]
1247 fn test_sql_like() {
1248 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1249
1250 let planner = Planner::new(schema.clone());
1251
1252 let expected = col("s").like(lit("str-4"));
1253 let expr = planner.parse_filter("s LIKE 'str-4'").unwrap();
1255 assert_eq!(expr, expected);
1256 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1257
1258 let batch = RecordBatch::try_new(
1259 schema,
1260 vec![Arc::new(StringArray::from_iter_values(
1261 (0..10).map(|v| format!("str-{}", v)),
1262 ))],
1263 )
1264 .unwrap();
1265 let predicates = physical_expr.evaluate(&batch).unwrap();
1266 assert_eq!(
1267 predicates.into_array(0).unwrap().as_ref(),
1268 &BooleanArray::from(vec![
1269 false, false, false, false, true, false, false, false, false, false
1270 ])
1271 );
1272 }
1273
1274 #[test]
1275 fn test_not_like() {
1276 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1277
1278 let planner = Planner::new(schema.clone());
1279
1280 let expected = col("s").not_like(lit("str-4"));
1281 let expr = planner.parse_filter("s NOT LIKE 'str-4'").unwrap();
1283 assert_eq!(expr, expected);
1284 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1285
1286 let batch = RecordBatch::try_new(
1287 schema,
1288 vec![Arc::new(StringArray::from_iter_values(
1289 (0..10).map(|v| format!("str-{}", v)),
1290 ))],
1291 )
1292 .unwrap();
1293 let predicates = physical_expr.evaluate(&batch).unwrap();
1294 assert_eq!(
1295 predicates.into_array(0).unwrap().as_ref(),
1296 &BooleanArray::from(vec![
1297 true, true, true, true, false, true, true, true, true, true
1298 ])
1299 );
1300 }
1301
1302 #[test]
1303 fn test_sql_is_in() {
1304 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1305
1306 let planner = Planner::new(schema.clone());
1307
1308 let expected = col("s").in_list(vec![lit("str-4"), lit("str-5")], false);
1309 let expr = planner.parse_filter("s IN ('str-4', 'str-5')").unwrap();
1311 assert_eq!(expr, expected);
1312 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1313
1314 let batch = RecordBatch::try_new(
1315 schema,
1316 vec![Arc::new(StringArray::from_iter_values(
1317 (0..10).map(|v| format!("str-{}", v)),
1318 ))],
1319 )
1320 .unwrap();
1321 let predicates = physical_expr.evaluate(&batch).unwrap();
1322 assert_eq!(
1323 predicates.into_array(0).unwrap().as_ref(),
1324 &BooleanArray::from(vec![
1325 false, false, false, false, true, true, false, false, false, false
1326 ])
1327 );
1328 }
1329
1330 #[test]
1331 fn test_sql_is_null() {
1332 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
1333
1334 let planner = Planner::new(schema.clone());
1335
1336 let expected = col("s").is_null();
1337 let expr = planner.parse_filter("s IS NULL").unwrap();
1338 assert_eq!(expr, expected);
1339 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1340
1341 let batch = RecordBatch::try_new(
1342 schema,
1343 vec![Arc::new(StringArray::from_iter((0..10).map(|v| {
1344 if v % 3 == 0 {
1345 Some(format!("str-{}", v))
1346 } else {
1347 None
1348 }
1349 })))],
1350 )
1351 .unwrap();
1352 let predicates = physical_expr.evaluate(&batch).unwrap();
1353 assert_eq!(
1354 predicates.into_array(0).unwrap().as_ref(),
1355 &BooleanArray::from(vec![
1356 false, true, true, false, true, true, false, true, true, false
1357 ])
1358 );
1359
1360 let expr = planner.parse_filter("s IS NOT NULL").unwrap();
1361 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1362 let predicates = physical_expr.evaluate(&batch).unwrap();
1363 assert_eq!(
1364 predicates.into_array(0).unwrap().as_ref(),
1365 &BooleanArray::from(vec![
1366 true, false, false, true, false, false, true, false, false, true,
1367 ])
1368 );
1369 }
1370
1371 #[test]
1372 fn test_sql_invert() {
1373 let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Boolean, true)]));
1374
1375 let planner = Planner::new(schema.clone());
1376
1377 let expr = planner.parse_filter("NOT s").unwrap();
1378 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1379
1380 let batch = RecordBatch::try_new(
1381 schema,
1382 vec![Arc::new(BooleanArray::from_iter(
1383 (0..10).map(|v| Some(v % 3 == 0)),
1384 ))],
1385 )
1386 .unwrap();
1387 let predicates = physical_expr.evaluate(&batch).unwrap();
1388 assert_eq!(
1389 predicates.into_array(0).unwrap().as_ref(),
1390 &BooleanArray::from(vec![
1391 false, true, true, false, true, true, false, true, true, false
1392 ])
1393 );
1394 }
1395
1396 #[test]
1397 fn test_sql_cast() {
1398 let cases = &[
1399 (
1400 "x = cast('2021-01-01 00:00:00' as timestamp)",
1401 ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
1402 ),
1403 (
1404 "x = cast('2021-01-01 00:00:00' as timestamp(0))",
1405 ArrowDataType::Timestamp(TimeUnit::Second, None),
1406 ),
1407 (
1408 "x = cast('2021-01-01 00:00:00.123' as timestamp(9))",
1409 ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
1410 ),
1411 (
1412 "x = cast('2021-01-01 00:00:00.123' as datetime(9))",
1413 ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
1414 ),
1415 ("x = cast('2021-01-01' as date)", ArrowDataType::Date32),
1416 (
1417 "x = cast('1.238' as decimal(9,3))",
1418 ArrowDataType::Decimal128(9, 3),
1419 ),
1420 ("x = cast(1 as float)", ArrowDataType::Float32),
1421 ("x = cast(1 as double)", ArrowDataType::Float64),
1422 ("x = cast(1 as tinyint)", ArrowDataType::Int8),
1423 ("x = cast(1 as smallint)", ArrowDataType::Int16),
1424 ("x = cast(1 as int)", ArrowDataType::Int32),
1425 ("x = cast(1 as integer)", ArrowDataType::Int32),
1426 ("x = cast(1 as bigint)", ArrowDataType::Int64),
1427 ("x = cast(1 as tinyint unsigned)", ArrowDataType::UInt8),
1428 ("x = cast(1 as smallint unsigned)", ArrowDataType::UInt16),
1429 ("x = cast(1 as int unsigned)", ArrowDataType::UInt32),
1430 ("x = cast(1 as integer unsigned)", ArrowDataType::UInt32),
1431 ("x = cast(1 as bigint unsigned)", ArrowDataType::UInt64),
1432 ("x = cast(1 as boolean)", ArrowDataType::Boolean),
1433 ("x = cast(1 as string)", ArrowDataType::Utf8),
1434 ];
1435
1436 for (sql, expected_data_type) in cases {
1437 let schema = Arc::new(Schema::new(vec![Field::new(
1438 "x",
1439 expected_data_type.clone(),
1440 true,
1441 )]));
1442 let planner = Planner::new(schema.clone());
1443 let expr = planner.parse_filter(sql).unwrap();
1444
1445 let expected_value_str = sql
1447 .split("cast(")
1448 .nth(1)
1449 .unwrap()
1450 .split(" as")
1451 .next()
1452 .unwrap();
1453 let expected_value_str = expected_value_str.trim_matches('\'');
1455
1456 match expr {
1457 Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
1458 Expr::Cast(Cast { expr, data_type }) => {
1459 match expr.as_ref() {
1460 Expr::Literal(ScalarValue::Utf8(Some(value_str)), _) => {
1461 assert_eq!(value_str, expected_value_str);
1462 }
1463 Expr::Literal(ScalarValue::Int64(Some(value)), _) => {
1464 assert_eq!(*value, 1);
1465 }
1466 _ => panic!("Expected cast to be applied to literal"),
1467 }
1468 assert_eq!(data_type, expected_data_type);
1469 }
1470 _ => panic!("Expected right to be a cast"),
1471 },
1472 _ => panic!("Expected binary expression"),
1473 }
1474 }
1475 }
1476
1477 #[test]
1478 fn test_sql_literals() {
1479 let cases = &[
1480 (
1481 "x = timestamp '2021-01-01 00:00:00'",
1482 ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
1483 ),
1484 (
1485 "x = timestamp(0) '2021-01-01 00:00:00'",
1486 ArrowDataType::Timestamp(TimeUnit::Second, None),
1487 ),
1488 (
1489 "x = timestamp(9) '2021-01-01 00:00:00.123'",
1490 ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
1491 ),
1492 ("x = date '2021-01-01'", ArrowDataType::Date32),
1493 ("x = decimal(9,3) '1.238'", ArrowDataType::Decimal128(9, 3)),
1494 ];
1495
1496 for (sql, expected_data_type) in cases {
1497 let schema = Arc::new(Schema::new(vec![Field::new(
1498 "x",
1499 expected_data_type.clone(),
1500 true,
1501 )]));
1502 let planner = Planner::new(schema.clone());
1503 let expr = planner.parse_filter(sql).unwrap();
1504
1505 let expected_value_str = sql.split('\'').nth(1).unwrap();
1506
1507 match expr {
1508 Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
1509 Expr::Cast(Cast { expr, data_type }) => {
1510 match expr.as_ref() {
1511 Expr::Literal(ScalarValue::Utf8(Some(value_str)), _) => {
1512 assert_eq!(value_str, expected_value_str);
1513 }
1514 _ => panic!("Expected cast to be applied to literal"),
1515 }
1516 assert_eq!(data_type, expected_data_type);
1517 }
1518 _ => panic!("Expected right to be a cast"),
1519 },
1520 _ => panic!("Expected binary expression"),
1521 }
1522 }
1523 }
1524
1525 #[test]
1526 fn test_sql_array_literals() {
1527 let cases = [
1528 (
1529 "x = [1, 2, 3]",
1530 ArrowDataType::List(Arc::new(Field::new("item", ArrowDataType::Int64, true))),
1531 ),
1532 (
1533 "x = [1, 2, 3]",
1534 ArrowDataType::FixedSizeList(
1535 Arc::new(Field::new("item", ArrowDataType::Int64, true)),
1536 3,
1537 ),
1538 ),
1539 ];
1540
1541 for (sql, expected_data_type) in cases {
1542 let schema = Arc::new(Schema::new(vec![Field::new(
1543 "x",
1544 expected_data_type.clone(),
1545 true,
1546 )]));
1547 let planner = Planner::new(schema.clone());
1548 let expr = planner.parse_filter(sql).unwrap();
1549 let expr = planner.optimize_expr(expr).unwrap();
1550
1551 match expr {
1552 Expr::BinaryExpr(BinaryExpr { right, .. }) => match right.as_ref() {
1553 Expr::Literal(value, _) => {
1554 assert_eq!(&value.data_type(), &expected_data_type);
1555 }
1556 _ => panic!("Expected right to be a literal"),
1557 },
1558 _ => panic!("Expected binary expression"),
1559 }
1560 }
1561 }
1562
1563 #[test]
1564 fn test_sql_between() {
1565 use arrow_array::{Float64Array, Int32Array, TimestampMicrosecondArray};
1566 use arrow_schema::{DataType, Field, Schema, TimeUnit};
1567 use std::sync::Arc;
1568
1569 let schema = Arc::new(Schema::new(vec![
1570 Field::new("x", DataType::Int32, false),
1571 Field::new("y", DataType::Float64, false),
1572 Field::new(
1573 "ts",
1574 DataType::Timestamp(TimeUnit::Microsecond, None),
1575 false,
1576 ),
1577 ]));
1578
1579 let planner = Planner::new(schema.clone());
1580
1581 let expr = planner
1583 .parse_filter("x BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
1584 .unwrap();
1585 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1586
1587 let base_ts = 1704067200000000_i64; let ts_array = TimestampMicrosecondArray::from_iter_values(
1591 (0..10).map(|i| base_ts + i * 1_000_000), );
1593
1594 let batch = RecordBatch::try_new(
1595 schema,
1596 vec![
1597 Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
1598 Arc::new(Float64Array::from_iter_values((0..10).map(|v| v as f64))),
1599 Arc::new(ts_array),
1600 ],
1601 )
1602 .unwrap();
1603
1604 let predicates = physical_expr.evaluate(&batch).unwrap();
1605 assert_eq!(
1606 predicates.into_array(0).unwrap().as_ref(),
1607 &BooleanArray::from(vec![
1608 false, false, false, true, true, true, true, true, false, false
1609 ])
1610 );
1611
1612 let expr = planner
1614 .parse_filter("x NOT BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)")
1615 .unwrap();
1616 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1617
1618 let predicates = physical_expr.evaluate(&batch).unwrap();
1619 assert_eq!(
1620 predicates.into_array(0).unwrap().as_ref(),
1621 &BooleanArray::from(vec![
1622 true, true, true, false, false, false, false, false, true, true
1623 ])
1624 );
1625
1626 let expr = planner.parse_filter("y BETWEEN 2.5 AND 6.5").unwrap();
1628 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1629
1630 let predicates = physical_expr.evaluate(&batch).unwrap();
1631 assert_eq!(
1632 predicates.into_array(0).unwrap().as_ref(),
1633 &BooleanArray::from(vec![
1634 false, false, false, true, true, true, true, false, false, false
1635 ])
1636 );
1637
1638 let expr = planner
1640 .parse_filter(
1641 "ts BETWEEN timestamp '2024-01-01 00:00:03' AND timestamp '2024-01-01 00:00:07'",
1642 )
1643 .unwrap();
1644 let physical_expr = planner.create_physical_expr(&expr).unwrap();
1645
1646 let predicates = physical_expr.evaluate(&batch).unwrap();
1647 assert_eq!(
1648 predicates.into_array(0).unwrap().as_ref(),
1649 &BooleanArray::from(vec![
1650 false, false, false, true, true, true, true, true, false, false
1651 ])
1652 );
1653 }
1654
1655 #[test]
1656 fn test_sql_comparison() {
1657 let batch: Vec<(&str, ArrayRef)> = vec![
1659 (
1660 "timestamp_s",
1661 Arc::new(TimestampSecondArray::from_iter_values(0..10)),
1662 ),
1663 (
1664 "timestamp_ms",
1665 Arc::new(TimestampMillisecondArray::from_iter_values(0..10)),
1666 ),
1667 (
1668 "timestamp_us",
1669 Arc::new(TimestampMicrosecondArray::from_iter_values(0..10)),
1670 ),
1671 (
1672 "timestamp_ns",
1673 Arc::new(TimestampNanosecondArray::from_iter_values(4995..5005)),
1674 ),
1675 ];
1676 let batch = RecordBatch::try_from_iter(batch).unwrap();
1677
1678 let planner = Planner::new(batch.schema());
1679
1680 let expressions = &[
1682 "timestamp_s >= TIMESTAMP '1970-01-01 00:00:05'",
1683 "timestamp_ms >= TIMESTAMP '1970-01-01 00:00:00.005'",
1684 "timestamp_us >= TIMESTAMP '1970-01-01 00:00:00.000005'",
1685 "timestamp_ns >= TIMESTAMP '1970-01-01 00:00:00.000005'",
1686 ];
1687
1688 let expected: ArrayRef = Arc::new(BooleanArray::from_iter(
1689 std::iter::repeat_n(Some(false), 5).chain(std::iter::repeat_n(Some(true), 5)),
1690 ));
1691 for expression in expressions {
1692 let logical_expr = planner.parse_filter(expression).unwrap();
1694 let logical_expr = planner.optimize_expr(logical_expr).unwrap();
1695 let physical_expr = planner.create_physical_expr(&logical_expr).unwrap();
1696
1697 let result = physical_expr.evaluate(&batch).unwrap();
1699 let result = result.into_array(batch.num_rows()).unwrap();
1700 assert_eq!(&expected, &result, "unexpected result for {}", expression);
1701 }
1702 }
1703
1704 #[test]
1705 fn test_columns_in_expr() {
1706 let expr = col("s0").gt(lit("value")).and(
1707 col("st")
1708 .field("st")
1709 .field("s2")
1710 .eq(lit("value"))
1711 .or(col("st")
1712 .field("s1")
1713 .in_list(vec![lit("value 1"), lit("value 2")], false)),
1714 );
1715
1716 let columns = Planner::column_names_in_expr(&expr);
1717 assert_eq!(columns, vec!["s0", "st.s1", "st.st.s2"]);
1718 }
1719
1720 #[test]
1721 fn test_parse_binary_expr() {
1722 let bin_str = "x'616263'";
1723
1724 let schema = Arc::new(Schema::new(vec![Field::new(
1725 "binary",
1726 DataType::Binary,
1727 true,
1728 )]));
1729 let planner = Planner::new(schema);
1730 let expr = planner.parse_expr(bin_str).unwrap();
1731 assert_eq!(
1732 expr,
1733 Expr::Literal(ScalarValue::Binary(Some(vec![b'a', b'b', b'c'])), None)
1734 );
1735 }
1736
1737 #[test]
1738 fn test_lance_context_provider_expr_planners() {
1739 let ctx_provider = LanceContextProvider::default();
1740 assert!(!ctx_provider.get_expr_planners().is_empty());
1741 }
1742}