1use std::sync::Arc;
7
8use arrow_schema::DataType;
9
10use crate::expr::safe_coerce_scalar;
11use datafusion::logical_expr::{Between, ScalarUDF, ScalarUDFImpl};
12use datafusion::logical_expr::{BinaryExpr, Operator, expr::ScalarFunction};
13use datafusion::prelude::*;
14use datafusion::scalar::ScalarValue;
15use datafusion_functions::core::getfield::GetFieldFunc;
16use lance_arrow::DataTypeExt;
17
18use lance_core::datatypes::Schema;
19use lance_core::{Error, Result};
20fn resolve_value(expr: &Expr, data_type: &DataType) -> Result<Expr> {
22 match expr {
23 Expr::Literal(scalar_value, metadata) => {
24 Ok(Expr::Literal(safe_coerce_scalar(scalar_value, data_type).ok_or_else(|| Error::invalid_input(format!("Received literal {expr} and could not convert to literal of type '{data_type:?}'")))?, metadata.clone()))
25 }
26 _ => Err(Error::invalid_input(format!("Expected a literal of type '{data_type:?}' but received: {expr}"))),
27 }
28}
29
30pub fn get_as_string_scalar_opt(expr: &Expr) -> Option<&str> {
33 match expr {
34 Expr::Literal(ScalarValue::Utf8(Some(s)), _) => Some(s),
35 _ => None,
36 }
37}
38
39pub fn resolve_column_type(expr: &Expr, schema: &Schema) -> Option<DataType> {
45 let mut field_path = Vec::new();
46 let mut current_expr = expr;
47 loop {
49 match current_expr {
50 Expr::Column(c) => {
51 field_path.push(c.name.as_str());
52 break;
53 }
54 Expr::ScalarFunction(udf) => {
55 if udf.name() == GetFieldFunc::default().name() {
56 let name = get_as_string_scalar_opt(&udf.args[1])?;
57 field_path.push(name);
58 current_expr = &udf.args[0];
59 } else {
60 return None;
61 }
62 }
63 _ => return None,
64 }
65 }
66
67 let mut path_iter = field_path.iter().rev();
68 let mut field = schema.field(path_iter.next()?)?;
69 for name in path_iter {
70 if field.data_type().is_struct() {
71 field = field.children.iter().find(|f| &f.name == name)?;
72 } else {
73 return None;
74 }
75 }
76 Some(field.data_type())
77}
78
79pub fn resolve_expr(expr: &Expr, schema: &Schema) -> Result<Expr> {
86 match expr {
87 Expr::Between(Between {
88 expr: inner_expr,
89 low,
90 high,
91 negated,
92 }) => {
93 if let Some(inner_expr_type) = resolve_column_type(inner_expr.as_ref(), schema) {
94 Ok(Expr::Between(Between {
95 expr: inner_expr.clone(),
96 low: Box::new(coerce_expr(low.as_ref(), &inner_expr_type)?),
97 high: Box::new(coerce_expr(high.as_ref(), &inner_expr_type)?),
98 negated: *negated,
99 }))
100 } else {
101 Ok(expr.clone())
102 }
103 }
104 Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
105 if matches!(op, Operator::And | Operator::Or) {
106 Ok(Expr::BinaryExpr(BinaryExpr {
107 left: Box::new(resolve_expr(left.as_ref(), schema)?),
108 op: *op,
109 right: Box::new(resolve_expr(right.as_ref(), schema)?),
110 }))
111 } else if let Some(left_type) = resolve_column_type(left.as_ref(), schema) {
112 match right.as_ref() {
113 Expr::Literal(..) => Ok(Expr::BinaryExpr(BinaryExpr {
114 left: left.clone(),
115 op: *op,
116 right: Box::new(resolve_value(right.as_ref(), &left_type)?),
117 })),
118 Expr::BinaryExpr(r) => Ok(Expr::BinaryExpr(BinaryExpr {
120 left: left.clone(),
121 op: *op,
122 right: Box::new(Expr::BinaryExpr(BinaryExpr {
123 left: coerce_expr(&r.left, &left_type).map(Box::new)?,
124 op: r.op,
125 right: coerce_expr(&r.right, &left_type).map(Box::new)?,
126 })),
127 })),
128 _ => Ok(expr.clone()),
129 }
130 } else if let Some(right_type) = resolve_column_type(right.as_ref(), schema) {
131 match left.as_ref() {
132 Expr::Literal(..) => Ok(Expr::BinaryExpr(BinaryExpr {
133 left: Box::new(resolve_value(left.as_ref(), &right_type)?),
134 op: *op,
135 right: right.clone(),
136 })),
137 _ => Ok(expr.clone()),
138 }
139 } else {
140 Ok(expr.clone())
141 }
142 }
143 Expr::InList(in_list) => {
144 if matches!(in_list.expr.as_ref(), Expr::Column(_)) {
145 if let Some(resolved_type) = resolve_column_type(in_list.expr.as_ref(), schema) {
146 let resolved_values = in_list
147 .list
148 .iter()
149 .map(|val| coerce_expr(val, &resolved_type))
150 .collect::<Result<Vec<_>>>()?;
151 Ok(Expr::in_list(
152 in_list.expr.as_ref().clone(),
153 resolved_values,
154 in_list.negated,
155 ))
156 } else {
157 Ok(expr.clone())
158 }
159 } else {
160 Ok(expr.clone())
161 }
162 }
163 _ => {
164 Ok(expr.clone())
166 }
167 }
168}
169
170pub fn coerce_expr(expr: &Expr, dtype: &DataType) -> Result<Expr> {
177 match expr {
178 Expr::BinaryExpr(BinaryExpr { left, op, right }) => Ok(Expr::BinaryExpr(BinaryExpr {
179 left: Box::new(coerce_expr(left, dtype)?),
180 op: *op,
181 right: Box::new(coerce_expr(right, dtype)?),
182 })),
183 literal_expr @ Expr::Literal(..) => Ok(resolve_value(literal_expr, dtype)?),
184 _ => Ok(expr.clone()),
185 }
186}
187
188pub fn coerce_filter_type_to_boolean(expr: Expr) -> Expr {
194 match expr {
195 Expr::ScalarFunction(sf) if sf.func.name() == "regexp_match" => {
197 log::warn!(
198 "regexp_match now is coerced to boolean, this may be changed in the future, please use `regexp_like` instead"
199 );
200 Expr::IsNotNull(Box::new(Expr::ScalarFunction(sf)))
201 }
202
203 Expr::BinaryExpr(BinaryExpr { left, op, right }) => Expr::BinaryExpr(BinaryExpr {
205 left: Box::new(coerce_filter_type_to_boolean(*left)),
206 op,
207 right: Box::new(coerce_filter_type_to_boolean(*right)),
208 }),
209 Expr::Not(inner) => Expr::Not(Box::new(coerce_filter_type_to_boolean(*inner))),
210 Expr::IsNull(inner) => Expr::IsNull(Box::new(coerce_filter_type_to_boolean(*inner))),
211 Expr::IsNotNull(inner) => Expr::IsNotNull(Box::new(coerce_filter_type_to_boolean(*inner))),
212
213 other => other,
215 }
216}
217
218pub trait ExprExt {
229 fn field_newstyle(&self, name: &str) -> Expr;
232}
233
234impl ExprExt for Expr {
235 fn field_newstyle(&self, name: &str) -> Expr {
236 Self::ScalarFunction(ScalarFunction {
237 func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
238 args: vec![
239 self.clone(),
240 Self::Literal(ScalarValue::Utf8(Some(name.to_string())), None),
241 ],
242 })
243 }
244}
245
246pub fn field_path_to_expr(field_path: &str) -> Result<Expr> {
278 let parts = lance_core::datatypes::parse_field_path(field_path)?;
280
281 if parts.is_empty() {
282 return Err(Error::invalid_input(format!(
283 "Invalid empty field path: {}",
284 field_path
285 )));
286 }
287
288 let mut expr = col(&parts[0]);
290 for part in &parts[1..] {
291 expr = expr.field_newstyle(part);
292 }
293
294 Ok(expr)
295}
296
#[cfg(test)]
pub mod tests {
    use std::sync::Arc;

    use super::*;

    use arrow_schema::{Field, Schema as ArrowSchema};
    use datafusion_functions::core::expr_ext::FieldAccessor;

    /// Convert a flat list of Arrow fields into a Lance [`Schema`].
    fn lance_schema(fields: Vec<Field>) -> Schema {
        Schema::try_from(&ArrowSchema::new(fields)).unwrap()
    }

    /// Unqualified column reference built the same way the original tests did.
    fn column_ref(name: &str) -> Expr {
        Expr::Column(name.to_string().into())
    }

    /// Literal expression with no metadata attached.
    fn scalar_lit(value: ScalarValue) -> Expr {
        Expr::Literal(value, None)
    }

    #[test]
    fn test_resolve_large_utf8() {
        let schema = lance_schema(vec![Field::new("a", DataType::LargeUtf8, false)]);
        let filter = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(column_ref("a")),
            op: Operator::Eq,
            right: Box::new(scalar_lit(ScalarValue::Utf8(Some("a".to_string())))),
        });

        let Expr::BinaryExpr(resolved) = resolve_expr(&filter, &schema).unwrap() else {
            unreachable!("Expected BinaryExpr");
        };
        assert_eq!(
            *resolved.right,
            scalar_lit(ScalarValue::LargeUtf8(Some("a".to_string())))
        );
    }

    #[test]
    fn test_resolve_binary_expr_on_right() {
        let schema = lance_schema(vec![Field::new("a", DataType::Float64, false)]);
        let arithmetic = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(scalar_lit(ScalarValue::Int64(Some(2)))),
            op: Operator::Minus,
            right: Box::new(scalar_lit(ScalarValue::Int64(Some(-1)))),
        });
        let filter = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(column_ref("a")),
            op: Operator::Eq,
            right: Box::new(arithmetic),
        });

        let Expr::BinaryExpr(outer) = resolve_expr(&filter, &schema).unwrap() else {
            panic!("Expected BinaryExpr");
        };
        let Expr::BinaryExpr(inner) = outer.right.as_ref() else {
            panic!("Expected BinaryExpr");
        };
        assert_eq!(*inner.left, scalar_lit(ScalarValue::Float64(Some(2.0))));
        assert_eq!(*inner.right, scalar_lit(ScalarValue::Float64(Some(-1.0))));
    }

    #[test]
    fn test_resolve_in_expr() {
        let schema = lance_schema(vec![Field::new("a", DataType::Float32, false)]);

        // Both IN and NOT IN must coerce the list items to the column type.
        for negated in [false, true] {
            let filter = Expr::in_list(
                column_ref("a"),
                vec![scalar_lit(ScalarValue::Float64(Some(0.0)))],
                negated,
            );
            let expected = Expr::in_list(
                column_ref("a"),
                vec![scalar_lit(ScalarValue::Float32(Some(0.0)))],
                negated,
            );
            assert_eq!(resolve_expr(&filter, &schema).unwrap(), expected);
        }
    }

    #[test]
    fn test_resolve_column_type() {
        let inner_struct = Field::new(
            "st",
            DataType::Struct(vec![Field::new("float", DataType::Float64, true)].into()),
            true,
        );
        let outer_struct = Field::new(
            "st",
            DataType::Struct(vec![Field::new("str", DataType::Utf8, true), inner_struct].into()),
            true,
        );
        let arrow_schema = Arc::new(ArrowSchema::new(vec![
            Field::new("int", DataType::Int32, true),
            outer_struct,
        ]));
        let schema = Schema::try_from(arrow_schema.as_ref()).unwrap();

        let resolvable = [
            (col("int"), DataType::Int32),
            (col("st").field("str"), DataType::Utf8),
            (col("st").field("st").field("float"), DataType::Float64),
        ];
        for (expr, expected) in resolvable {
            assert_eq!(resolve_column_type(&expr, &schema), Some(expected));
        }

        let unresolvable = [
            col("x"),
            col("str"),
            col("float"),
            col("st").field("str").eq(lit("x")),
        ];
        for expr in unresolvable {
            assert_eq!(resolve_column_type(&expr, &schema), None);
        }
    }
}