lance_datafusion/
logical_expr.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
//! Extensions for working with DataFusion logical expressions against a Lance schema.
5
6use std::sync::Arc;
7
8use arrow_schema::DataType;
9
10use crate::expr::safe_coerce_scalar;
11use datafusion::logical_expr::{expr::ScalarFunction, BinaryExpr, Operator};
12use datafusion::logical_expr::{Between, ScalarUDF, ScalarUDFImpl};
13use datafusion::prelude::*;
14use datafusion::scalar::ScalarValue;
15use datafusion_functions::core::getfield::GetFieldFunc;
16use lance_arrow::DataTypeExt;
17
18use lance_core::datatypes::Schema;
19use lance_core::{Error, Result};
20use snafu::location;
21/// Resolve a Value
22fn resolve_value(expr: &Expr, data_type: &DataType) -> Result<Expr> {
23    match expr {
24        Expr::Literal(scalar_value, metadata) => {
25            Ok(Expr::Literal(safe_coerce_scalar(scalar_value, data_type).ok_or_else(|| Error::invalid_input(
26                format!("Received literal {expr} and could not convert to literal of type '{data_type:?}'"),
27                location!(),
28            ))?, metadata.clone()))
29        }
30        _ => Err(Error::invalid_input(
31            format!("Expected a literal of type '{data_type:?}' but received: {expr}"),
32            location!(),
33        )),
34    }
35}
36
37/// A simple helper function that interprets an Expr as a string scalar
38/// or returns None if it is not.
39pub fn get_as_string_scalar_opt(expr: &Expr) -> Option<&str> {
40    match expr {
41        Expr::Literal(ScalarValue::Utf8(Some(s)), _) => Some(s),
42        _ => None,
43    }
44}
45
46/// Given a Expr::Column or Expr::GetIndexedField, get the data type of referenced
47/// field in the schema.
48///
49/// If the column is not found in the schema, return None. If the expression is
50/// not a field reference, also returns None.
51pub fn resolve_column_type(expr: &Expr, schema: &Schema) -> Option<DataType> {
52    let mut field_path = Vec::new();
53    let mut current_expr = expr;
54    // We are looping from outer-most reference to inner-most.
55    loop {
56        match current_expr {
57            Expr::Column(c) => {
58                field_path.push(c.name.as_str());
59                break;
60            }
61            Expr::ScalarFunction(udf) => {
62                if udf.name() == GetFieldFunc::default().name() {
63                    let name = get_as_string_scalar_opt(&udf.args[1])?;
64                    field_path.push(name);
65                    current_expr = &udf.args[0];
66                } else {
67                    return None;
68                }
69            }
70            _ => return None,
71        }
72    }
73
74    let mut path_iter = field_path.iter().rev();
75    let mut field = schema.field(path_iter.next()?)?;
76    for name in path_iter {
77        if field.data_type().is_struct() {
78            field = field.children.iter().find(|f| &f.name == name)?;
79        } else {
80            return None;
81        }
82    }
83    Some(field.data_type())
84}
85
86/// Resolve logical expression `expr`.
87///
88/// Parameters
89///
90/// - *expr*: a datafusion logical expression
91/// - *schema*: lance schema.
92pub fn resolve_expr(expr: &Expr, schema: &Schema) -> Result<Expr> {
93    match expr {
94        Expr::Between(Between {
95            expr: inner_expr,
96            low,
97            high,
98            negated,
99        }) => {
100            if let Some(inner_expr_type) = resolve_column_type(inner_expr.as_ref(), schema) {
101                Ok(Expr::Between(Between {
102                    expr: inner_expr.clone(),
103                    low: Box::new(coerce_expr(low.as_ref(), &inner_expr_type)?),
104                    high: Box::new(coerce_expr(high.as_ref(), &inner_expr_type)?),
105                    negated: *negated,
106                }))
107            } else {
108                Ok(expr.clone())
109            }
110        }
111        Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
112            if matches!(op, Operator::And | Operator::Or) {
113                Ok(Expr::BinaryExpr(BinaryExpr {
114                    left: Box::new(resolve_expr(left.as_ref(), schema)?),
115                    op: *op,
116                    right: Box::new(resolve_expr(right.as_ref(), schema)?),
117                }))
118            } else if let Some(left_type) = resolve_column_type(left.as_ref(), schema) {
119                match right.as_ref() {
120                    Expr::Literal(..) => Ok(Expr::BinaryExpr(BinaryExpr {
121                        left: left.clone(),
122                        op: *op,
123                        right: Box::new(resolve_value(right.as_ref(), &left_type)?),
124                    })),
125                    // For cases complex expressions (not just literals) on right hand side like x = 1 + 1 + -2*2
126                    Expr::BinaryExpr(r) => Ok(Expr::BinaryExpr(BinaryExpr {
127                        left: left.clone(),
128                        op: *op,
129                        right: Box::new(Expr::BinaryExpr(BinaryExpr {
130                            left: coerce_expr(&r.left, &left_type).map(Box::new)?,
131                            op: r.op,
132                            right: coerce_expr(&r.right, &left_type).map(Box::new)?,
133                        })),
134                    })),
135                    _ => Ok(expr.clone()),
136                }
137            } else if let Some(right_type) = resolve_column_type(right.as_ref(), schema) {
138                match left.as_ref() {
139                    Expr::Literal(..) => Ok(Expr::BinaryExpr(BinaryExpr {
140                        left: Box::new(resolve_value(left.as_ref(), &right_type)?),
141                        op: *op,
142                        right: right.clone(),
143                    })),
144                    _ => Ok(expr.clone()),
145                }
146            } else {
147                Ok(expr.clone())
148            }
149        }
150        Expr::InList(in_list) => {
151            if matches!(in_list.expr.as_ref(), Expr::Column(_)) {
152                if let Some(resolved_type) = resolve_column_type(in_list.expr.as_ref(), schema) {
153                    let resolved_values = in_list
154                        .list
155                        .iter()
156                        .map(|val| coerce_expr(val, &resolved_type))
157                        .collect::<Result<Vec<_>>>()?;
158                    Ok(Expr::in_list(
159                        in_list.expr.as_ref().clone(),
160                        resolved_values,
161                        in_list.negated,
162                    ))
163                } else {
164                    Ok(expr.clone())
165                }
166            } else {
167                Ok(expr.clone())
168            }
169        }
170        _ => {
171            // Passthrough
172            Ok(expr.clone())
173        }
174    }
175}
176
177/// Coerce expression of literals to column type.
178///
179/// Parameters
180///
181/// - *expr*: a datafusion logical expression
182/// - *dtype*: a lance data type
183pub fn coerce_expr(expr: &Expr, dtype: &DataType) -> Result<Expr> {
184    match expr {
185        Expr::BinaryExpr(BinaryExpr { left, op, right }) => Ok(Expr::BinaryExpr(BinaryExpr {
186            left: Box::new(coerce_expr(left, dtype)?),
187            op: *op,
188            right: Box::new(coerce_expr(right, dtype)?),
189        })),
190        literal_expr @ Expr::Literal(..) => Ok(resolve_value(literal_expr, dtype)?),
191        _ => Ok(expr.clone()),
192    }
193}
194
195/// Coerce logical expression for filters to boolean.
196///
197/// Parameters
198///
199/// - *expr*: a datafusion logical expression
200pub fn coerce_filter_type_to_boolean(expr: Expr) -> Expr {
201    match &expr {
202        // TODO: consider making this dispatch more generic, i.e. fun.output_type -> coerce
203        // instead of hardcoding coerce method for each function
204        Expr::ScalarFunction(ScalarFunction { func, .. }) => {
205            if func.name() == "regexp_match" {
206                Expr::IsNotNull(Box::new(expr))
207            } else {
208                expr
209            }
210        }
211        _ => expr,
212    }
213}
214
/// Extension methods for building nested field accesses on [`Expr`].
///
/// As of the DF 37 release there are two different ways to represent a
/// nested field access in `Expr`. The old way is to use `Expr::field`, which
/// returns a `GetStructField`, and the new way is to use
/// `Expr::ScalarFunction` with a `GetFieldFunc` UDF.
///
/// Currently, the old path leads to bugs in DF. This is probably a bug in DF
/// itself and will probably be fixed in a future version. In the meantime we
/// need to make sure we always use the new way to avoid it. This trait adds
/// `field_newstyle`, which lets us easily create logical `Expr`s that use the
/// new style.
pub trait ExprExt {
    /// Helper function to replace `Expr::field` in DF 37, since DF confuses
    /// itself with the `GetStructField` returned by `Expr::field`.
    ///
    /// Returns an expression that accesses the field `name` of `self`.
    fn field_newstyle(&self, name: &str) -> Expr;
}
230
231impl ExprExt for Expr {
232    fn field_newstyle(&self, name: &str) -> Expr {
233        Self::ScalarFunction(ScalarFunction {
234            func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
235            args: vec![
236                self.clone(),
237                Self::Literal(ScalarValue::Utf8(Some(name.to_string())), None),
238            ],
239        })
240    }
241}
242
#[cfg(test)]
pub mod tests {
    use std::sync::Arc;

    use super::*;

    use arrow_schema::{Field, Schema as ArrowSchema};
    use datafusion_functions::core::expr_ext::FieldAccessor;

    /// Build a lance schema from a list of arrow fields.
    fn lance_schema(fields: Vec<Field>) -> Schema {
        Schema::try_from(&ArrowSchema::new(fields)).unwrap()
    }

    /// Build a column reference expression.
    fn col_ref(name: &str) -> Expr {
        Expr::Column(name.to_string().into())
    }

    /// Build a literal expression without metadata.
    fn literal_of(value: ScalarValue) -> Expr {
        Expr::Literal(value, None)
    }

    // A Utf8 literal compared with a LargeUtf8 column becomes LargeUtf8.
    #[test]
    fn test_resolve_large_utf8() {
        let schema = lance_schema(vec![Field::new("a", DataType::LargeUtf8, false)]);
        let expr = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(col_ref("a")),
            op: Operator::Eq,
            right: Box::new(literal_of(ScalarValue::Utf8(Some("a".to_string())))),
        });

        let Expr::BinaryExpr(resolved) = resolve_expr(&expr, &schema).unwrap() else {
            unreachable!("Expected BinaryExpr");
        };
        assert_eq!(
            resolved.right.as_ref(),
            &literal_of(ScalarValue::LargeUtf8(Some("a".to_string())))
        );
    }

    // Both literals of a binary expression on the right-hand side are coerced
    // to the column's type.
    #[test]
    fn test_resolve_binary_expr_on_right() {
        let schema = lance_schema(vec![Field::new("a", DataType::Float64, false)]);
        let right_hand_side = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(literal_of(ScalarValue::Int64(Some(2)))),
            op: Operator::Minus,
            right: Box::new(literal_of(ScalarValue::Int64(Some(-1)))),
        });
        let expr = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(col_ref("a")),
            op: Operator::Eq,
            right: Box::new(right_hand_side),
        });

        let resolved = resolve_expr(&expr, &schema).unwrap();

        let Expr::BinaryExpr(outer) = resolved else {
            panic!("Expected BinaryExpr");
        };
        let Expr::BinaryExpr(inner) = outer.right.as_ref() else {
            panic!("Expected BinaryExpr");
        };
        assert_eq!(
            inner.left.as_ref(),
            &literal_of(ScalarValue::Float64(Some(2.0)))
        );
        assert_eq!(
            inner.right.as_ref(),
            &literal_of(ScalarValue::Float64(Some(-1.0)))
        );
    }

    #[test]
    fn test_resolve_in_expr() {
        // Type coercion should apply for `A IN (0)` or `A NOT IN (0)`
        let schema = lance_schema(vec![Field::new("a", DataType::Float32, false)]);

        for negated in [false, true] {
            let expr = Expr::in_list(
                col_ref("a"),
                vec![literal_of(ScalarValue::Float64(Some(0.0)))],
                negated,
            );
            let expected = Expr::in_list(
                col_ref("a"),
                vec![literal_of(ScalarValue::Float32(Some(0.0)))],
                negated,
            );

            let resolved = resolve_expr(&expr, &schema).unwrap();
            assert_eq!(resolved, expected);
        }
    }

    #[test]
    fn test_resolve_column_type() {
        // Schema layout:
        //   int: Int32
        //   st: struct { str: Utf8, st: struct { float: Float64 } }
        let inner_struct =
            DataType::Struct(vec![Field::new("float", DataType::Float64, true)].into());
        let outer_struct = DataType::Struct(
            vec![
                Field::new("str", DataType::Utf8, true),
                Field::new("st", inner_struct, true),
            ]
            .into(),
        );
        let arrow_schema = Arc::new(ArrowSchema::new(vec![
            Field::new("int", DataType::Int32, true),
            Field::new("st", outer_struct, true),
        ]));
        let schema = Schema::try_from(arrow_schema.as_ref()).unwrap();

        let cases = vec![
            // Resolvable references, top-level and nested.
            (col("int"), Some(DataType::Int32)),
            (col("st").field("str"), Some(DataType::Utf8)),
            (col("st").field("st").field("float"), Some(DataType::Float64)),
            // Unknown columns, and nested names used as top-level columns.
            (col("x"), None),
            (col("str"), None),
            (col("float"), None),
            // Not a field reference at all.
            (col("st").field("str").eq(lit("x")), None),
        ];

        for (expr, expected) in cases {
            assert_eq!(
                resolve_column_type(&expr, &schema),
                expected,
                "unexpected resolved type for {expr}"
            );
        }
    }
}