1use std::sync::Arc;
7
8use arrow_schema::DataType;
9
10use crate::expr::safe_coerce_scalar;
11use datafusion::logical_expr::{expr::ScalarFunction, BinaryExpr, Operator};
12use datafusion::logical_expr::{Between, ScalarUDF, ScalarUDFImpl};
13use datafusion::prelude::*;
14use datafusion::scalar::ScalarValue;
15use datafusion_functions::core::getfield::GetFieldFunc;
16use lance_arrow::DataTypeExt;
17
18use lance_core::datatypes::Schema;
19use lance_core::{Error, Result};
20use snafu::location;
21fn resolve_value(expr: &Expr, data_type: &DataType) -> Result<Expr> {
23 match expr {
24 Expr::Literal(scalar_value, metadata) => {
25 Ok(Expr::Literal(safe_coerce_scalar(scalar_value, data_type).ok_or_else(|| Error::invalid_input(
26 format!("Received literal {expr} and could not convert to literal of type '{data_type:?}'"),
27 location!(),
28 ))?, metadata.clone()))
29 }
30 _ => Err(Error::invalid_input(
31 format!("Expected a literal of type '{data_type:?}' but received: {expr}"),
32 location!(),
33 )),
34 }
35}
36
37pub fn get_as_string_scalar_opt(expr: &Expr) -> Option<&str> {
40 match expr {
41 Expr::Literal(ScalarValue::Utf8(Some(s)), _) => Some(s),
42 _ => None,
43 }
44}
45
46pub fn resolve_column_type(expr: &Expr, schema: &Schema) -> Option<DataType> {
52 let mut field_path = Vec::new();
53 let mut current_expr = expr;
54 loop {
56 match current_expr {
57 Expr::Column(c) => {
58 field_path.push(c.name.as_str());
59 break;
60 }
61 Expr::ScalarFunction(udf) => {
62 if udf.name() == GetFieldFunc::default().name() {
63 let name = get_as_string_scalar_opt(&udf.args[1])?;
64 field_path.push(name);
65 current_expr = &udf.args[0];
66 } else {
67 return None;
68 }
69 }
70 _ => return None,
71 }
72 }
73
74 let mut path_iter = field_path.iter().rev();
75 let mut field = schema.field(path_iter.next()?)?;
76 for name in path_iter {
77 if field.data_type().is_struct() {
78 field = field.children.iter().find(|f| &f.name == name)?;
79 } else {
80 return None;
81 }
82 }
83 Some(field.data_type())
84}
85
86pub fn resolve_expr(expr: &Expr, schema: &Schema) -> Result<Expr> {
93 match expr {
94 Expr::Between(Between {
95 expr: inner_expr,
96 low,
97 high,
98 negated,
99 }) => {
100 if let Some(inner_expr_type) = resolve_column_type(inner_expr.as_ref(), schema) {
101 Ok(Expr::Between(Between {
102 expr: inner_expr.clone(),
103 low: Box::new(coerce_expr(low.as_ref(), &inner_expr_type)?),
104 high: Box::new(coerce_expr(high.as_ref(), &inner_expr_type)?),
105 negated: *negated,
106 }))
107 } else {
108 Ok(expr.clone())
109 }
110 }
111 Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
112 if matches!(op, Operator::And | Operator::Or) {
113 Ok(Expr::BinaryExpr(BinaryExpr {
114 left: Box::new(resolve_expr(left.as_ref(), schema)?),
115 op: *op,
116 right: Box::new(resolve_expr(right.as_ref(), schema)?),
117 }))
118 } else if let Some(left_type) = resolve_column_type(left.as_ref(), schema) {
119 match right.as_ref() {
120 Expr::Literal(..) => Ok(Expr::BinaryExpr(BinaryExpr {
121 left: left.clone(),
122 op: *op,
123 right: Box::new(resolve_value(right.as_ref(), &left_type)?),
124 })),
125 Expr::BinaryExpr(r) => Ok(Expr::BinaryExpr(BinaryExpr {
127 left: left.clone(),
128 op: *op,
129 right: Box::new(Expr::BinaryExpr(BinaryExpr {
130 left: coerce_expr(&r.left, &left_type).map(Box::new)?,
131 op: r.op,
132 right: coerce_expr(&r.right, &left_type).map(Box::new)?,
133 })),
134 })),
135 _ => Ok(expr.clone()),
136 }
137 } else if let Some(right_type) = resolve_column_type(right.as_ref(), schema) {
138 match left.as_ref() {
139 Expr::Literal(..) => Ok(Expr::BinaryExpr(BinaryExpr {
140 left: Box::new(resolve_value(left.as_ref(), &right_type)?),
141 op: *op,
142 right: right.clone(),
143 })),
144 _ => Ok(expr.clone()),
145 }
146 } else {
147 Ok(expr.clone())
148 }
149 }
150 Expr::InList(in_list) => {
151 if matches!(in_list.expr.as_ref(), Expr::Column(_)) {
152 if let Some(resolved_type) = resolve_column_type(in_list.expr.as_ref(), schema) {
153 let resolved_values = in_list
154 .list
155 .iter()
156 .map(|val| coerce_expr(val, &resolved_type))
157 .collect::<Result<Vec<_>>>()?;
158 Ok(Expr::in_list(
159 in_list.expr.as_ref().clone(),
160 resolved_values,
161 in_list.negated,
162 ))
163 } else {
164 Ok(expr.clone())
165 }
166 } else {
167 Ok(expr.clone())
168 }
169 }
170 _ => {
171 Ok(expr.clone())
173 }
174 }
175}
176
177pub fn coerce_expr(expr: &Expr, dtype: &DataType) -> Result<Expr> {
184 match expr {
185 Expr::BinaryExpr(BinaryExpr { left, op, right }) => Ok(Expr::BinaryExpr(BinaryExpr {
186 left: Box::new(coerce_expr(left, dtype)?),
187 op: *op,
188 right: Box::new(coerce_expr(right, dtype)?),
189 })),
190 literal_expr @ Expr::Literal(..) => Ok(resolve_value(literal_expr, dtype)?),
191 _ => Ok(expr.clone()),
192 }
193}
194
195pub fn coerce_filter_type_to_boolean(expr: Expr) -> Expr {
201 match expr {
202 Expr::ScalarFunction(sf) if sf.func.name() == "regexp_match" => {
204 log::warn!("regexp_match now is coerced to boolean, this may be changed in the future, please use `regexp_like` instead");
205 Expr::IsNotNull(Box::new(Expr::ScalarFunction(sf)))
206 }
207
208 Expr::BinaryExpr(BinaryExpr { left, op, right }) => Expr::BinaryExpr(BinaryExpr {
210 left: Box::new(coerce_filter_type_to_boolean(*left)),
211 op,
212 right: Box::new(coerce_filter_type_to_boolean(*right)),
213 }),
214 Expr::Not(inner) => Expr::Not(Box::new(coerce_filter_type_to_boolean(*inner))),
215 Expr::IsNull(inner) => Expr::IsNull(Box::new(coerce_filter_type_to_boolean(*inner))),
216 Expr::IsNotNull(inner) => Expr::IsNotNull(Box::new(coerce_filter_type_to_boolean(*inner))),
217
218 other => other,
220 }
221}
222
223pub trait ExprExt {
234 fn field_newstyle(&self, name: &str) -> Expr;
237}
238
239impl ExprExt for Expr {
240 fn field_newstyle(&self, name: &str) -> Expr {
241 Self::ScalarFunction(ScalarFunction {
242 func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
243 args: vec![
244 self.clone(),
245 Self::Literal(ScalarValue::Utf8(Some(name.to_string())), None),
246 ],
247 })
248 }
249}
250
251pub fn field_path_to_expr(field_path: &str) -> Result<Expr> {
283 let parts = lance_core::datatypes::parse_field_path(field_path)?;
285
286 if parts.is_empty() {
287 return Err(Error::invalid_input(
288 format!("Invalid empty field path: {}", field_path),
289 location!(),
290 ));
291 }
292
293 let mut expr = col(&parts[0]);
295 for part in &parts[1..] {
296 expr = expr.field_newstyle(part);
297 }
298
299 Ok(expr)
300}
301
#[cfg(test)]
pub mod tests {
    use std::sync::Arc;

    use super::*;

    use arrow_schema::{Field, Schema as ArrowSchema};
    use datafusion_functions::core::expr_ext::FieldAccessor;

    /// Builds a lance [`Schema`] from a flat list of arrow fields.
    fn lance_schema(fields: Vec<Field>) -> Schema {
        Schema::try_from(&ArrowSchema::new(fields)).unwrap()
    }

    #[test]
    fn test_resolve_large_utf8() {
        let schema = lance_schema(vec![Field::new("a", DataType::LargeUtf8, false)]);
        let filter = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(col("a")),
            op: Operator::Eq,
            right: Box::new(Expr::Literal(ScalarValue::Utf8(Some("a".to_string())), None)),
        });

        let Expr::BinaryExpr(resolved) = resolve_expr(&filter, &schema).unwrap() else {
            unreachable!("Expected BinaryExpr");
        };
        assert_eq!(
            resolved.right.as_ref(),
            &Expr::Literal(ScalarValue::LargeUtf8(Some("a".to_string())), None)
        );
    }

    #[test]
    fn test_resolve_binary_expr_on_right() {
        let schema = lance_schema(vec![Field::new("a", DataType::Float64, false)]);
        let int_lit = |v: i64| Expr::Literal(ScalarValue::Int64(Some(v)), None);
        let float_lit = |v: f64| Expr::Literal(ScalarValue::Float64(Some(v)), None);

        let filter = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(col("a")),
            op: Operator::Eq,
            right: Box::new(Expr::BinaryExpr(BinaryExpr {
                left: Box::new(int_lit(2)),
                op: Operator::Minus,
                right: Box::new(int_lit(-1)),
            })),
        });

        let Expr::BinaryExpr(outer) = resolve_expr(&filter, &schema).unwrap() else {
            panic!("Expected BinaryExpr");
        };
        let Expr::BinaryExpr(inner) = outer.right.as_ref() else {
            panic!("Expected BinaryExpr");
        };
        assert_eq!(inner.left.as_ref(), &float_lit(2.0));
        assert_eq!(inner.right.as_ref(), &float_lit(-1.0));
    }

    #[test]
    fn test_resolve_in_expr() {
        let schema = lance_schema(vec![Field::new("a", DataType::Float32, false)]);

        // Both plain and negated IN lists must have their values coerced.
        for negated in [false, true] {
            let filter = Expr::in_list(
                col("a"),
                vec![Expr::Literal(ScalarValue::Float64(Some(0.0)), None)],
                negated,
            );
            let expected = Expr::in_list(
                col("a"),
                vec![Expr::Literal(ScalarValue::Float32(Some(0.0)), None)],
                negated,
            );
            assert_eq!(resolve_expr(&filter, &schema).unwrap(), expected);
        }
    }

    #[test]
    fn test_resolve_column_type() {
        let inner_struct =
            DataType::Struct(vec![Field::new("float", DataType::Float64, true)].into());
        let outer_struct = DataType::Struct(
            vec![
                Field::new("str", DataType::Utf8, true),
                Field::new("st", inner_struct, true),
            ]
            .into(),
        );
        let arrow_schema = Arc::new(ArrowSchema::new(vec![
            Field::new("int", DataType::Int32, true),
            Field::new("st", outer_struct, true),
        ]));
        let schema = Schema::try_from(arrow_schema.as_ref()).unwrap();

        let resolvable = [
            (col("int"), DataType::Int32),
            (col("st").field("str"), DataType::Utf8),
            (col("st").field("st").field("float"), DataType::Float64),
        ];
        for (expr, expected) in resolvable {
            assert_eq!(resolve_column_type(&expr, &schema), Some(expected));
        }

        // Unknown columns, nested names used as top-level columns, and non-column
        // expressions are all unresolvable.
        let unresolvable = [
            col("x"),
            col("str"),
            col("float"),
            col("st").field("str").eq(lit("x")),
        ];
        for expr in unresolvable {
            assert_eq!(resolve_column_type(&expr, &schema), None);
        }
    }
}