lance_datafusion/
logical_expr.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
//! Extensions for working with DataFusion logical expressions.
5
6use std::sync::Arc;
7
8use arrow_schema::DataType;
9
10use crate::expr::safe_coerce_scalar;
11use datafusion::logical_expr::{expr::ScalarFunction, BinaryExpr, Operator};
12use datafusion::logical_expr::{Between, ScalarUDF, ScalarUDFImpl};
13use datafusion::prelude::*;
14use datafusion::scalar::ScalarValue;
15use datafusion_functions::core::getfield::GetFieldFunc;
16use lance_arrow::DataTypeExt;
17
18use lance_core::datatypes::Schema;
19use lance_core::{Error, Result};
20use snafu::location;
21/// Resolve a Value
22fn resolve_value(expr: &Expr, data_type: &DataType) -> Result<Expr> {
23    match expr {
24        Expr::Literal(scalar_value, metadata) => {
25            Ok(Expr::Literal(safe_coerce_scalar(scalar_value, data_type).ok_or_else(|| Error::invalid_input(
26                format!("Received literal {expr} and could not convert to literal of type '{data_type:?}'"),
27                location!(),
28            ))?, metadata.clone()))
29        }
30        _ => Err(Error::invalid_input(
31            format!("Expected a literal of type '{data_type:?}' but received: {expr}"),
32            location!(),
33        )),
34    }
35}
36
37/// A simple helper function that interprets an Expr as a string scalar
38/// or returns None if it is not.
39pub fn get_as_string_scalar_opt(expr: &Expr) -> Option<&str> {
40    match expr {
41        Expr::Literal(ScalarValue::Utf8(Some(s)), _) => Some(s),
42        _ => None,
43    }
44}
45
46/// Given a Expr::Column or Expr::GetIndexedField, get the data type of referenced
47/// field in the schema.
48///
49/// If the column is not found in the schema, return None. If the expression is
50/// not a field reference, also returns None.
51pub fn resolve_column_type(expr: &Expr, schema: &Schema) -> Option<DataType> {
52    let mut field_path = Vec::new();
53    let mut current_expr = expr;
54    // We are looping from outer-most reference to inner-most.
55    loop {
56        match current_expr {
57            Expr::Column(c) => {
58                field_path.push(c.name.as_str());
59                break;
60            }
61            Expr::ScalarFunction(udf) => {
62                if udf.name() == GetFieldFunc::default().name() {
63                    let name = get_as_string_scalar_opt(&udf.args[1])?;
64                    field_path.push(name);
65                    current_expr = &udf.args[0];
66                } else {
67                    return None;
68                }
69            }
70            _ => return None,
71        }
72    }
73
74    let mut path_iter = field_path.iter().rev();
75    let mut field = schema.field(path_iter.next()?)?;
76    for name in path_iter {
77        if field.data_type().is_struct() {
78            field = field.children.iter().find(|f| &f.name == name)?;
79        } else {
80            return None;
81        }
82    }
83    Some(field.data_type())
84}
85
86/// Resolve logical expression `expr`.
87///
88/// Parameters
89///
90/// - *expr*: a datafusion logical expression
91/// - *schema*: lance schema.
92pub fn resolve_expr(expr: &Expr, schema: &Schema) -> Result<Expr> {
93    match expr {
94        Expr::Between(Between {
95            expr: inner_expr,
96            low,
97            high,
98            negated,
99        }) => {
100            if let Some(inner_expr_type) = resolve_column_type(inner_expr.as_ref(), schema) {
101                Ok(Expr::Between(Between {
102                    expr: inner_expr.clone(),
103                    low: Box::new(coerce_expr(low.as_ref(), &inner_expr_type)?),
104                    high: Box::new(coerce_expr(high.as_ref(), &inner_expr_type)?),
105                    negated: *negated,
106                }))
107            } else {
108                Ok(expr.clone())
109            }
110        }
111        Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
112            if matches!(op, Operator::And | Operator::Or) {
113                Ok(Expr::BinaryExpr(BinaryExpr {
114                    left: Box::new(resolve_expr(left.as_ref(), schema)?),
115                    op: *op,
116                    right: Box::new(resolve_expr(right.as_ref(), schema)?),
117                }))
118            } else if let Some(left_type) = resolve_column_type(left.as_ref(), schema) {
119                match right.as_ref() {
120                    Expr::Literal(..) => Ok(Expr::BinaryExpr(BinaryExpr {
121                        left: left.clone(),
122                        op: *op,
123                        right: Box::new(resolve_value(right.as_ref(), &left_type)?),
124                    })),
125                    // For cases complex expressions (not just literals) on right hand side like x = 1 + 1 + -2*2
126                    Expr::BinaryExpr(r) => Ok(Expr::BinaryExpr(BinaryExpr {
127                        left: left.clone(),
128                        op: *op,
129                        right: Box::new(Expr::BinaryExpr(BinaryExpr {
130                            left: coerce_expr(&r.left, &left_type).map(Box::new)?,
131                            op: r.op,
132                            right: coerce_expr(&r.right, &left_type).map(Box::new)?,
133                        })),
134                    })),
135                    _ => Ok(expr.clone()),
136                }
137            } else if let Some(right_type) = resolve_column_type(right.as_ref(), schema) {
138                match left.as_ref() {
139                    Expr::Literal(..) => Ok(Expr::BinaryExpr(BinaryExpr {
140                        left: Box::new(resolve_value(left.as_ref(), &right_type)?),
141                        op: *op,
142                        right: right.clone(),
143                    })),
144                    _ => Ok(expr.clone()),
145                }
146            } else {
147                Ok(expr.clone())
148            }
149        }
150        Expr::InList(in_list) => {
151            if matches!(in_list.expr.as_ref(), Expr::Column(_)) {
152                if let Some(resolved_type) = resolve_column_type(in_list.expr.as_ref(), schema) {
153                    let resolved_values = in_list
154                        .list
155                        .iter()
156                        .map(|val| coerce_expr(val, &resolved_type))
157                        .collect::<Result<Vec<_>>>()?;
158                    Ok(Expr::in_list(
159                        in_list.expr.as_ref().clone(),
160                        resolved_values,
161                        in_list.negated,
162                    ))
163                } else {
164                    Ok(expr.clone())
165                }
166            } else {
167                Ok(expr.clone())
168            }
169        }
170        _ => {
171            // Passthrough
172            Ok(expr.clone())
173        }
174    }
175}
176
177/// Coerce expression of literals to column type.
178///
179/// Parameters
180///
181/// - *expr*: a datafusion logical expression
182/// - *dtype*: a lance data type
183pub fn coerce_expr(expr: &Expr, dtype: &DataType) -> Result<Expr> {
184    match expr {
185        Expr::BinaryExpr(BinaryExpr { left, op, right }) => Ok(Expr::BinaryExpr(BinaryExpr {
186            left: Box::new(coerce_expr(left, dtype)?),
187            op: *op,
188            right: Box::new(coerce_expr(right, dtype)?),
189        })),
190        literal_expr @ Expr::Literal(..) => Ok(resolve_value(literal_expr, dtype)?),
191        _ => Ok(expr.clone()),
192    }
193}
194
195/// Coerce logical expression for filters to boolean.
196///
197/// Parameters
198///
199/// - *expr*: a datafusion logical expression
200pub fn coerce_filter_type_to_boolean(expr: Expr) -> Expr {
201    match expr {
202        // Coerce regexp_match to boolean by checking for non-null
203        Expr::ScalarFunction(sf) if sf.func.name() == "regexp_match" => {
204            log::warn!("regexp_match now is coerced to boolean, this may be changed in the future, please use `regexp_like` instead");
205            Expr::IsNotNull(Box::new(Expr::ScalarFunction(sf)))
206        }
207
208        // Recurse into boolean contexts so nested regexp_match terms are also coerced
209        Expr::BinaryExpr(BinaryExpr { left, op, right }) => Expr::BinaryExpr(BinaryExpr {
210            left: Box::new(coerce_filter_type_to_boolean(*left)),
211            op,
212            right: Box::new(coerce_filter_type_to_boolean(*right)),
213        }),
214        Expr::Not(inner) => Expr::Not(Box::new(coerce_filter_type_to_boolean(*inner))),
215        Expr::IsNull(inner) => Expr::IsNull(Box::new(coerce_filter_type_to_boolean(*inner))),
216        Expr::IsNotNull(inner) => Expr::IsNotNull(Box::new(coerce_filter_type_to_boolean(*inner))),
217
218        // Pass-through for all other nodes
219        other => other,
220    }
221}
222
/// Extension trait for building nested field accesses on logical `Expr`s.
///
/// As part of the DF 37 release there are now two different ways to
/// represent a nested field access in `Expr`.  The old way is to use
/// `Expr::field` which returns a `GetStructField` and the new way is
/// to use `Expr::ScalarFunction` with a `GetFieldFunc` UDF.
///
/// Currently, the old path leads to bugs in DF.  This is probably a
/// bug and will probably be fixed in a future version.  In the meantime
/// we need to make sure we are always using the new way to avoid this
/// bug.  This trait adds `field_newstyle` which lets us easily create
/// logical `Expr` that use the new style.
pub trait ExprExt {
    /// Helper function to replace `Expr::field` in DF 37 since DF
    /// confuses itself with the `GetStructField` returned by `Expr::field`.
    ///
    /// - *name*: name of the nested field to access on `self`
    ///
    /// Returns a new expression; `self` is only borrowed.
    fn field_newstyle(&self, name: &str) -> Expr;
}
238
239impl ExprExt for Expr {
240    fn field_newstyle(&self, name: &str) -> Expr {
241        Self::ScalarFunction(ScalarFunction {
242            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
243            args: vec![
244                self.clone(),
245                Self::Literal(ScalarValue::Utf8(Some(name.to_string())), None),
246            ],
247        })
248    }
249}
250
251/// Convert a field path string into a DataFusion expression.
252///
253/// This function handles:
254/// - Simple column names: "column"
255/// - Nested paths: "parent.child" or "parent.child.grandchild"
256/// - Backtick-escaped field names: "parent.`field.with.dots`"
257///
258/// # Arguments
259///
260/// * `field_path` - The field path to convert. Supports simple columns, nested paths,
261///   and backtick-escaped field names.
262///
263/// # Returns
264///
265/// Returns `Result<Expr>` - Ok with the DataFusion expression, or Err if the path
266/// could not be parsed.
267///
268/// # Example
269///
270/// ```
271/// use lance_datafusion::logical_expr::field_path_to_expr;
272///
273/// // Simple column
274/// let expr = field_path_to_expr("column_name").unwrap();
275///
276/// // Nested field
277/// let expr = field_path_to_expr("parent.child").unwrap();
278///
279/// // Backtick-escaped field with dots
280/// let expr = field_path_to_expr("parent.`field.with.dots`").unwrap();
281/// ```
282pub fn field_path_to_expr(field_path: &str) -> Result<Expr> {
283    // Parse the field path to handle nested fields and backtick-escaped names
284    let parts = lance_core::datatypes::parse_field_path(field_path)?;
285
286    if parts.is_empty() {
287        return Err(Error::invalid_input(
288            format!("Invalid empty field path: {}", field_path),
289            location!(),
290        ));
291    }
292
293    // Build the column expression, handling nested fields
294    let mut expr = col(&parts[0]);
295    for part in &parts[1..] {
296        expr = expr.field_newstyle(part);
297    }
298
299    Ok(expr)
300}
301
#[cfg(test)]
pub mod tests {
    use std::sync::Arc;

    use super::*;

    use arrow_schema::{Field, Schema as ArrowSchema};
    use datafusion_functions::core::expr_ext::FieldAccessor;

    /// Lance schema with a single non-nullable column `a` of the given type.
    fn single_column_schema(data_type: DataType) -> Schema {
        let arrow_schema = ArrowSchema::new(vec![Field::new("a", data_type, false)]);
        Schema::try_from(&arrow_schema).unwrap()
    }

    /// Reference to the column `a`.
    fn column_a() -> Expr {
        Expr::Column("a".to_string().into())
    }

    /// Literal expression without metadata.
    fn literal(value: ScalarValue) -> Expr {
        Expr::Literal(value, None)
    }

    #[test]
    fn test_resolve_large_utf8() {
        let schema = single_column_schema(DataType::LargeUtf8);
        let expr = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(column_a()),
            op: Operator::Eq,
            right: Box::new(literal(ScalarValue::Utf8(Some("a".to_string())))),
        });

        let resolved = resolve_expr(&expr, &schema).unwrap();
        let Expr::BinaryExpr(be) = resolved else {
            unreachable!("Expected BinaryExpr")
        };
        // The Utf8 literal must have been widened to the column's LargeUtf8.
        assert_eq!(
            be.right.as_ref(),
            &literal(ScalarValue::LargeUtf8(Some("a".to_string())))
        );
    }

    #[test]
    fn test_resolve_binary_expr_on_right() {
        let schema = single_column_schema(DataType::Float64);
        let arithmetic = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(literal(ScalarValue::Int64(Some(2)))),
            op: Operator::Minus,
            right: Box::new(literal(ScalarValue::Int64(Some(-1)))),
        });
        let expr = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(column_a()),
            op: Operator::Eq,
            right: Box::new(arithmetic),
        });

        let resolved = resolve_expr(&expr, &schema).unwrap();
        let Expr::BinaryExpr(be) = resolved else {
            panic!("Expected BinaryExpr")
        };
        let Expr::BinaryExpr(r_be) = be.right.as_ref() else {
            panic!("Expected BinaryExpr")
        };
        // Both integer operands must have been coerced to Float64.
        assert_eq!(
            r_be.left.as_ref(),
            &literal(ScalarValue::Float64(Some(2.0)))
        );
        assert_eq!(
            r_be.right.as_ref(),
            &literal(ScalarValue::Float64(Some(-1.0)))
        );
    }

    #[test]
    fn test_resolve_in_expr() {
        // Type coercion should apply for `A IN (0)` or `A NOT IN (0)`
        let schema = single_column_schema(DataType::Float32);
        for negated in [false, true] {
            let expr = Expr::in_list(
                column_a(),
                vec![literal(ScalarValue::Float64(Some(0.0)))],
                negated,
            );
            let expected = Expr::in_list(
                column_a(),
                vec![literal(ScalarValue::Float32(Some(0.0)))],
                negated,
            );

            let resolved = resolve_expr(&expr, &schema).unwrap();
            assert_eq!(resolved, expected);
        }
    }

    #[test]
    fn test_resolve_column_type() {
        // Layout: int: Int32, st: { str: Utf8, st: { float: Float64 } }
        let inner_struct =
            DataType::Struct(vec![Field::new("float", DataType::Float64, true)].into());
        let outer_struct = DataType::Struct(
            vec![
                Field::new("str", DataType::Utf8, true),
                Field::new("st", inner_struct, true),
            ]
            .into(),
        );
        let arrow_schema = Arc::new(ArrowSchema::new(vec![
            Field::new("int", DataType::Int32, true),
            Field::new("st", outer_struct, true),
        ]));
        let schema = Schema::try_from(arrow_schema.as_ref()).unwrap();

        let resolvable = [
            (col("int"), DataType::Int32),
            (col("st").field("str"), DataType::Utf8),
            (col("st").field("st").field("float"), DataType::Float64),
        ];
        for (expr, expected) in resolvable {
            assert_eq!(resolve_column_type(&expr, &schema), Some(expected));
        }

        // Unknown columns, nested names used at top level, and expressions
        // that are not field references all resolve to `None`.
        let unresolvable = [
            col("x"),
            col("str"),
            col("float"),
            col("st").field("str").eq(lit("x")),
        ];
        for expr in unresolvable {
            assert_eq!(resolve_column_type(&expr, &schema), None);
        }
    }
}