Skip to main content

lance_datafusion/
logical_expr.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Extends logical expression.
5
6use std::sync::Arc;
7
8use arrow_schema::DataType;
9
10use crate::expr::safe_coerce_scalar;
11use datafusion::logical_expr::{Between, ScalarUDF, ScalarUDFImpl};
12use datafusion::logical_expr::{BinaryExpr, Operator, expr::ScalarFunction};
13use datafusion::prelude::*;
14use datafusion::scalar::ScalarValue;
15use datafusion_functions::core::getfield::GetFieldFunc;
16use lance_arrow::DataTypeExt;
17
18use lance_core::datatypes::Schema;
19use lance_core::{Error, Result};
20/// Resolve a Value
21fn resolve_value(expr: &Expr, data_type: &DataType) -> Result<Expr> {
22    match expr {
23        Expr::Literal(scalar_value, metadata) => {
24            Ok(Expr::Literal(safe_coerce_scalar(scalar_value, data_type).ok_or_else(|| Error::invalid_input(format!("Received literal {expr} and could not convert to literal of type '{data_type:?}'")))?, metadata.clone()))
25        }
26        _ => Err(Error::invalid_input(format!("Expected a literal of type '{data_type:?}' but received: {expr}"))),
27    }
28}
29
30/// A simple helper function that interprets an Expr as a string scalar
31/// or returns None if it is not.
32pub fn get_as_string_scalar_opt(expr: &Expr) -> Option<&str> {
33    match expr {
34        Expr::Literal(ScalarValue::Utf8(Some(s)), _) => Some(s),
35        _ => None,
36    }
37}
38
39/// Given a Expr::Column or Expr::GetIndexedField, get the data type of referenced
40/// field in the schema.
41///
42/// If the column is not found in the schema, return None. If the expression is
43/// not a field reference, also returns None.
44pub fn resolve_column_type(expr: &Expr, schema: &Schema) -> Option<DataType> {
45    let mut field_path = Vec::new();
46    let mut current_expr = expr;
47    // We are looping from outer-most reference to inner-most.
48    loop {
49        match current_expr {
50            Expr::Column(c) => {
51                field_path.push(c.name.as_str());
52                break;
53            }
54            Expr::ScalarFunction(udf) => {
55                if udf.name() == GetFieldFunc::default().name() {
56                    let name = get_as_string_scalar_opt(&udf.args[1])?;
57                    field_path.push(name);
58                    current_expr = &udf.args[0];
59                } else {
60                    return None;
61                }
62            }
63            _ => return None,
64        }
65    }
66
67    let mut path_iter = field_path.iter().rev();
68    let mut field = schema.field(path_iter.next()?)?;
69    for name in path_iter {
70        if field.data_type().is_struct() {
71            field = field.children.iter().find(|f| &f.name == name)?;
72        } else {
73            return None;
74        }
75    }
76    Some(field.data_type())
77}
78
79/// Resolve logical expression `expr`.
80///
81/// Parameters
82///
83/// - *expr*: a datafusion logical expression
84/// - *schema*: lance schema.
85pub fn resolve_expr(expr: &Expr, schema: &Schema) -> Result<Expr> {
86    match expr {
87        Expr::Between(Between {
88            expr: inner_expr,
89            low,
90            high,
91            negated,
92        }) => {
93            if let Some(inner_expr_type) = resolve_column_type(inner_expr.as_ref(), schema) {
94                Ok(Expr::Between(Between {
95                    expr: inner_expr.clone(),
96                    low: Box::new(coerce_expr(low.as_ref(), &inner_expr_type)?),
97                    high: Box::new(coerce_expr(high.as_ref(), &inner_expr_type)?),
98                    negated: *negated,
99                }))
100            } else {
101                Ok(expr.clone())
102            }
103        }
104        Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
105            if matches!(op, Operator::And | Operator::Or) {
106                Ok(Expr::BinaryExpr(BinaryExpr {
107                    left: Box::new(resolve_expr(left.as_ref(), schema)?),
108                    op: *op,
109                    right: Box::new(resolve_expr(right.as_ref(), schema)?),
110                }))
111            } else if let Some(left_type) = resolve_column_type(left.as_ref(), schema) {
112                match right.as_ref() {
113                    Expr::Literal(..) => Ok(Expr::BinaryExpr(BinaryExpr {
114                        left: left.clone(),
115                        op: *op,
116                        right: Box::new(resolve_value(right.as_ref(), &left_type)?),
117                    })),
118                    // For cases complex expressions (not just literals) on right hand side like x = 1 + 1 + -2*2
119                    Expr::BinaryExpr(r) => Ok(Expr::BinaryExpr(BinaryExpr {
120                        left: left.clone(),
121                        op: *op,
122                        right: Box::new(Expr::BinaryExpr(BinaryExpr {
123                            left: coerce_expr(&r.left, &left_type).map(Box::new)?,
124                            op: r.op,
125                            right: coerce_expr(&r.right, &left_type).map(Box::new)?,
126                        })),
127                    })),
128                    _ => Ok(expr.clone()),
129                }
130            } else if let Some(right_type) = resolve_column_type(right.as_ref(), schema) {
131                match left.as_ref() {
132                    Expr::Literal(..) => Ok(Expr::BinaryExpr(BinaryExpr {
133                        left: Box::new(resolve_value(left.as_ref(), &right_type)?),
134                        op: *op,
135                        right: right.clone(),
136                    })),
137                    _ => Ok(expr.clone()),
138                }
139            } else {
140                Ok(expr.clone())
141            }
142        }
143        Expr::InList(in_list) => {
144            if matches!(in_list.expr.as_ref(), Expr::Column(_)) {
145                if let Some(resolved_type) = resolve_column_type(in_list.expr.as_ref(), schema) {
146                    let resolved_values = in_list
147                        .list
148                        .iter()
149                        .map(|val| coerce_expr(val, &resolved_type))
150                        .collect::<Result<Vec<_>>>()?;
151                    Ok(Expr::in_list(
152                        in_list.expr.as_ref().clone(),
153                        resolved_values,
154                        in_list.negated,
155                    ))
156                } else {
157                    Ok(expr.clone())
158                }
159            } else {
160                Ok(expr.clone())
161            }
162        }
163        _ => {
164            // Passthrough
165            Ok(expr.clone())
166        }
167    }
168}
169
170/// Coerce expression of literals to column type.
171///
172/// Parameters
173///
174/// - *expr*: a datafusion logical expression
175/// - *dtype*: a lance data type
176pub fn coerce_expr(expr: &Expr, dtype: &DataType) -> Result<Expr> {
177    match expr {
178        Expr::BinaryExpr(BinaryExpr { left, op, right }) => Ok(Expr::BinaryExpr(BinaryExpr {
179            left: Box::new(coerce_expr(left, dtype)?),
180            op: *op,
181            right: Box::new(coerce_expr(right, dtype)?),
182        })),
183        literal_expr @ Expr::Literal(..) => Ok(resolve_value(literal_expr, dtype)?),
184        _ => Ok(expr.clone()),
185    }
186}
187
188/// Coerce logical expression for filters to boolean.
189///
190/// Parameters
191///
192/// - *expr*: a datafusion logical expression
193pub fn coerce_filter_type_to_boolean(expr: Expr) -> Expr {
194    match expr {
195        // Coerce regexp_match to boolean by checking for non-null
196        Expr::ScalarFunction(sf) if sf.func.name() == "regexp_match" => {
197            log::warn!(
198                "regexp_match now is coerced to boolean, this may be changed in the future, please use `regexp_like` instead"
199            );
200            Expr::IsNotNull(Box::new(Expr::ScalarFunction(sf)))
201        }
202
203        // Recurse into boolean contexts so nested regexp_match terms are also coerced
204        Expr::BinaryExpr(BinaryExpr { left, op, right }) => Expr::BinaryExpr(BinaryExpr {
205            left: Box::new(coerce_filter_type_to_boolean(*left)),
206            op,
207            right: Box::new(coerce_filter_type_to_boolean(*right)),
208        }),
209        Expr::Not(inner) => Expr::Not(Box::new(coerce_filter_type_to_boolean(*inner))),
210        Expr::IsNull(inner) => Expr::IsNull(Box::new(coerce_filter_type_to_boolean(*inner))),
211        Expr::IsNotNull(inner) => Expr::IsNotNull(Box::new(coerce_filter_type_to_boolean(*inner))),
212
213        // Pass-through for all other nodes
214        other => other,
215    }
216}
217
218// As part of the DF 37 release there are now two different ways to
219// represent a nested field access in `Expr`.  The old way is to use
220// `Expr::field` which returns a `GetStructField` and the new way is
221// to use `Expr::ScalarFunction` with a `GetFieldFunc` UDF.
222//
223// Currently, the old path leads to bugs in DF.  This is probably a
224// bug and will probably be fixed in a future version.  In the meantime
225// we need to make sure we are always using the new way to avoid this
226// bug.  This trait adds field_newstyle which lets us easily create
227// logical `Expr` that use the new style.
228pub trait ExprExt {
229    // Helper function to replace Expr::field in DF 37 since DF
230    // confuses itself with the GetStructField returned by Expr::field
231    fn field_newstyle(&self, name: &str) -> Expr;
232}
233
234impl ExprExt for Expr {
235    fn field_newstyle(&self, name: &str) -> Expr {
236        Self::ScalarFunction(ScalarFunction {
237            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
238            args: vec![
239                self.clone(),
240                Self::Literal(ScalarValue::Utf8(Some(name.to_string())), None),
241            ],
242        })
243    }
244}
245
246/// Convert a field path string into a DataFusion expression.
247///
248/// This function handles:
249/// - Simple column names: "column"
250/// - Nested paths: "parent.child" or "parent.child.grandchild"
251/// - Backtick-escaped field names: "parent.`field.with.dots`"
252///
253/// # Arguments
254///
255/// * `field_path` - The field path to convert. Supports simple columns, nested paths,
256///   and backtick-escaped field names.
257///
258/// # Returns
259///
260/// Returns `Result<Expr>` - Ok with the DataFusion expression, or Err if the path
261/// could not be parsed.
262///
263/// # Example
264///
265/// ```
266/// use lance_datafusion::logical_expr::field_path_to_expr;
267///
268/// // Simple column
269/// let expr = field_path_to_expr("column_name").unwrap();
270///
271/// // Nested field
272/// let expr = field_path_to_expr("parent.child").unwrap();
273///
274/// // Backtick-escaped field with dots
275/// let expr = field_path_to_expr("parent.`field.with.dots`").unwrap();
276/// ```
277pub fn field_path_to_expr(field_path: &str) -> Result<Expr> {
278    // Parse the field path to handle nested fields and backtick-escaped names
279    let parts = lance_core::datatypes::parse_field_path(field_path)?;
280
281    if parts.is_empty() {
282        return Err(Error::invalid_input(format!(
283            "Invalid empty field path: {}",
284            field_path
285        )));
286    }
287
288    // Build the column expression, handling nested fields
289    let mut expr = col(&parts[0]);
290    for part in &parts[1..] {
291        expr = expr.field_newstyle(part);
292    }
293
294    Ok(expr)
295}
296
#[cfg(test)]
pub mod tests {
    use std::sync::Arc;

    use super::*;

    use arrow_schema::{Field, Schema as ArrowSchema};
    use datafusion_functions::core::expr_ext::FieldAccessor;

    /// Build a lance schema from a list of arrow fields.
    fn lance_schema(fields: Vec<Field>) -> Schema {
        Schema::try_from(&ArrowSchema::new(fields)).unwrap()
    }

    /// Reference to the column named `a`.
    fn column_a() -> Expr {
        Expr::Column("a".to_string().into())
    }

    /// Literal expression without metadata.
    fn literal(value: ScalarValue) -> Expr {
        Expr::Literal(value, None)
    }

    #[test]
    fn test_resolve_large_utf8() {
        let schema = lance_schema(vec![Field::new("a", DataType::LargeUtf8, false)]);
        let expr = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(column_a()),
            op: Operator::Eq,
            right: Box::new(literal(ScalarValue::Utf8(Some("a".to_string())))),
        });

        let Expr::BinaryExpr(be) = resolve_expr(&expr, &schema).unwrap() else {
            unreachable!("Expected BinaryExpr")
        };
        assert_eq!(
            be.right.as_ref(),
            &literal(ScalarValue::LargeUtf8(Some("a".to_string())))
        );
    }

    #[test]
    fn test_resolve_binary_expr_on_right() {
        let schema = lance_schema(vec![Field::new("a", DataType::Float64, false)]);
        let arithmetic = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(literal(ScalarValue::Int64(Some(2)))),
            op: Operator::Minus,
            right: Box::new(literal(ScalarValue::Int64(Some(-1)))),
        });
        let expr = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(column_a()),
            op: Operator::Eq,
            right: Box::new(arithmetic),
        });

        let Expr::BinaryExpr(be) = resolve_expr(&expr, &schema).unwrap() else {
            panic!("Expected BinaryExpr")
        };
        let Expr::BinaryExpr(r_be) = be.right.as_ref() else {
            panic!("Expected BinaryExpr")
        };
        assert_eq!(
            r_be.left.as_ref(),
            &literal(ScalarValue::Float64(Some(2.0)))
        );
        assert_eq!(
            r_be.right.as_ref(),
            &literal(ScalarValue::Float64(Some(-1.0)))
        );
    }

    #[test]
    fn test_resolve_in_expr() {
        // Type coercion should apply for `A IN (0)` or `A NOT IN (0)`
        let schema = lance_schema(vec![Field::new("a", DataType::Float32, false)]);

        for negated in [false, true] {
            let expr = Expr::in_list(
                column_a(),
                vec![literal(ScalarValue::Float64(Some(0.0)))],
                negated,
            );
            let expected = Expr::in_list(
                column_a(),
                vec![literal(ScalarValue::Float32(Some(0.0)))],
                negated,
            );
            assert_eq!(resolve_expr(&expr, &schema).unwrap(), expected);
        }
    }

    #[test]
    fn test_resolve_column_type() {
        let inner_struct =
            DataType::Struct(vec![Field::new("float", DataType::Float64, true)].into());
        let outer_struct = DataType::Struct(
            vec![
                Field::new("str", DataType::Utf8, true),
                Field::new("st", inner_struct, true),
            ]
            .into(),
        );
        let arrow_schema = Arc::new(ArrowSchema::new(vec![
            Field::new("int", DataType::Int32, true),
            Field::new("st", outer_struct, true),
        ]));
        let schema = Schema::try_from(arrow_schema.as_ref()).unwrap();

        let cases = vec![
            // References that resolve to a field in the schema.
            (col("int"), Some(DataType::Int32)),
            (col("st").field("str"), Some(DataType::Utf8)),
            (col("st").field("st").field("float"), Some(DataType::Float64)),
            // Unknown top-level columns, and nested names used at the top level.
            (col("x"), None),
            (col("str"), None),
            (col("float"), None),
            // Not a field reference at all.
            (col("st").field("str").eq(lit("x")), None),
        ];

        for (expr, expected) in cases {
            assert_eq!(resolve_column_type(&expr, &schema), expected);
        }
    }
}