1use std::sync::Arc;
7
8use arrow_schema::DataType;
9
10use crate::expr::safe_coerce_scalar;
11use datafusion::logical_expr::{expr::ScalarFunction, BinaryExpr, Operator};
12use datafusion::logical_expr::{Between, ScalarUDF, ScalarUDFImpl};
13use datafusion::prelude::*;
14use datafusion::scalar::ScalarValue;
15use datafusion_functions::core::getfield::GetFieldFunc;
16use lance_arrow::DataTypeExt;
17
18use lance_core::datatypes::Schema;
19use lance_core::{Error, Result};
20use snafu::location;
21fn resolve_value(expr: &Expr, data_type: &DataType) -> Result<Expr> {
23 match expr {
24 Expr::Literal(scalar_value, metadata) => {
25 Ok(Expr::Literal(safe_coerce_scalar(scalar_value, data_type).ok_or_else(|| Error::invalid_input(
26 format!("Received literal {expr} and could not convert to literal of type '{data_type:?}'"),
27 location!(),
28 ))?, metadata.clone()))
29 }
30 _ => Err(Error::invalid_input(
31 format!("Expected a literal of type '{data_type:?}' but received: {expr}"),
32 location!(),
33 )),
34 }
35}
36
37pub fn get_as_string_scalar_opt(expr: &Expr) -> Option<&str> {
40 match expr {
41 Expr::Literal(ScalarValue::Utf8(Some(s)), _) => Some(s),
42 _ => None,
43 }
44}
45
46pub fn resolve_column_type(expr: &Expr, schema: &Schema) -> Option<DataType> {
52 let mut field_path = Vec::new();
53 let mut current_expr = expr;
54 loop {
56 match current_expr {
57 Expr::Column(c) => {
58 field_path.push(c.name.as_str());
59 break;
60 }
61 Expr::ScalarFunction(udf) => {
62 if udf.name() == GetFieldFunc::default().name() {
63 let name = get_as_string_scalar_opt(&udf.args[1])?;
64 field_path.push(name);
65 current_expr = &udf.args[0];
66 } else {
67 return None;
68 }
69 }
70 _ => return None,
71 }
72 }
73
74 let mut path_iter = field_path.iter().rev();
75 let mut field = schema.field(path_iter.next()?)?;
76 for name in path_iter {
77 if field.data_type().is_struct() {
78 field = field.children.iter().find(|f| &f.name == name)?;
79 } else {
80 return None;
81 }
82 }
83 Some(field.data_type())
84}
85
86pub fn resolve_expr(expr: &Expr, schema: &Schema) -> Result<Expr> {
93 match expr {
94 Expr::Between(Between {
95 expr: inner_expr,
96 low,
97 high,
98 negated,
99 }) => {
100 if let Some(inner_expr_type) = resolve_column_type(inner_expr.as_ref(), schema) {
101 Ok(Expr::Between(Between {
102 expr: inner_expr.clone(),
103 low: Box::new(coerce_expr(low.as_ref(), &inner_expr_type)?),
104 high: Box::new(coerce_expr(high.as_ref(), &inner_expr_type)?),
105 negated: *negated,
106 }))
107 } else {
108 Ok(expr.clone())
109 }
110 }
111 Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
112 if matches!(op, Operator::And | Operator::Or) {
113 Ok(Expr::BinaryExpr(BinaryExpr {
114 left: Box::new(resolve_expr(left.as_ref(), schema)?),
115 op: *op,
116 right: Box::new(resolve_expr(right.as_ref(), schema)?),
117 }))
118 } else if let Some(left_type) = resolve_column_type(left.as_ref(), schema) {
119 match right.as_ref() {
120 Expr::Literal(..) => Ok(Expr::BinaryExpr(BinaryExpr {
121 left: left.clone(),
122 op: *op,
123 right: Box::new(resolve_value(right.as_ref(), &left_type)?),
124 })),
125 Expr::BinaryExpr(r) => Ok(Expr::BinaryExpr(BinaryExpr {
127 left: left.clone(),
128 op: *op,
129 right: Box::new(Expr::BinaryExpr(BinaryExpr {
130 left: coerce_expr(&r.left, &left_type).map(Box::new)?,
131 op: r.op,
132 right: coerce_expr(&r.right, &left_type).map(Box::new)?,
133 })),
134 })),
135 _ => Ok(expr.clone()),
136 }
137 } else if let Some(right_type) = resolve_column_type(right.as_ref(), schema) {
138 match left.as_ref() {
139 Expr::Literal(..) => Ok(Expr::BinaryExpr(BinaryExpr {
140 left: Box::new(resolve_value(left.as_ref(), &right_type)?),
141 op: *op,
142 right: right.clone(),
143 })),
144 _ => Ok(expr.clone()),
145 }
146 } else {
147 Ok(expr.clone())
148 }
149 }
150 Expr::InList(in_list) => {
151 if matches!(in_list.expr.as_ref(), Expr::Column(_)) {
152 if let Some(resolved_type) = resolve_column_type(in_list.expr.as_ref(), schema) {
153 let resolved_values = in_list
154 .list
155 .iter()
156 .map(|val| coerce_expr(val, &resolved_type))
157 .collect::<Result<Vec<_>>>()?;
158 Ok(Expr::in_list(
159 in_list.expr.as_ref().clone(),
160 resolved_values,
161 in_list.negated,
162 ))
163 } else {
164 Ok(expr.clone())
165 }
166 } else {
167 Ok(expr.clone())
168 }
169 }
170 _ => {
171 Ok(expr.clone())
173 }
174 }
175}
176
177pub fn coerce_expr(expr: &Expr, dtype: &DataType) -> Result<Expr> {
184 match expr {
185 Expr::BinaryExpr(BinaryExpr { left, op, right }) => Ok(Expr::BinaryExpr(BinaryExpr {
186 left: Box::new(coerce_expr(left, dtype)?),
187 op: *op,
188 right: Box::new(coerce_expr(right, dtype)?),
189 })),
190 literal_expr @ Expr::Literal(..) => Ok(resolve_value(literal_expr, dtype)?),
191 _ => Ok(expr.clone()),
192 }
193}
194
195pub fn coerce_filter_type_to_boolean(expr: Expr) -> Expr {
201 match &expr {
202 Expr::ScalarFunction(ScalarFunction { func, .. }) => {
205 if func.name() == "regexp_match" {
206 Expr::IsNotNull(Box::new(expr))
207 } else {
208 expr
209 }
210 }
211 _ => expr,
212 }
213}
214
215pub trait ExprExt {
226 fn field_newstyle(&self, name: &str) -> Expr;
229}
230
231impl ExprExt for Expr {
232 fn field_newstyle(&self, name: &str) -> Expr {
233 Self::ScalarFunction(ScalarFunction {
234 func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())),
235 args: vec![
236 self.clone(),
237 Self::Literal(ScalarValue::Utf8(Some(name.to_string())), None),
238 ],
239 })
240 }
241}
242
#[cfg(test)]
pub mod tests {
    use std::sync::Arc;

    use super::*;

    use arrow_schema::{Field, Schema as ArrowSchema};
    use datafusion_functions::core::expr_ext::FieldAccessor;

    /// Builds a lance [`Schema`] with a single non-nullable column `a` of `data_type`.
    fn single_column_schema(data_type: DataType) -> Schema {
        let arrow_schema = ArrowSchema::new(vec![Field::new("a", data_type, false)]);
        Schema::try_from(&arrow_schema).unwrap()
    }

    /// Shorthand for a literal expression without metadata.
    fn literal(value: ScalarValue) -> Expr {
        Expr::Literal(value, None)
    }

    #[test]
    fn test_resolve_large_utf8() {
        let schema = single_column_schema(DataType::LargeUtf8);
        let expr = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(col("a")),
            op: Operator::Eq,
            right: Box::new(literal(ScalarValue::Utf8(Some("a".to_string())))),
        });

        let Expr::BinaryExpr(resolved) = resolve_expr(&expr, &schema).unwrap() else {
            unreachable!("Expected BinaryExpr");
        };
        assert_eq!(
            resolved.right.as_ref(),
            &literal(ScalarValue::LargeUtf8(Some("a".to_string())))
        );
    }

    #[test]
    fn test_resolve_binary_expr_on_right() {
        let schema = single_column_schema(DataType::Float64);
        let expr = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(col("a")),
            op: Operator::Eq,
            right: Box::new(Expr::BinaryExpr(BinaryExpr {
                left: Box::new(literal(ScalarValue::Int64(Some(2)))),
                op: Operator::Minus,
                right: Box::new(literal(ScalarValue::Int64(Some(-1)))),
            })),
        });
        let resolved = resolve_expr(&expr, &schema).unwrap();

        let Expr::BinaryExpr(outer) = resolved else {
            panic!("Expected BinaryExpr");
        };
        let Expr::BinaryExpr(inner) = outer.right.as_ref() else {
            panic!("Expected BinaryExpr");
        };
        assert_eq!(
            inner.left.as_ref(),
            &literal(ScalarValue::Float64(Some(2.0)))
        );
        assert_eq!(
            inner.right.as_ref(),
            &literal(ScalarValue::Float64(Some(-1.0)))
        );
    }

    #[test]
    fn test_resolve_in_expr() {
        let schema = single_column_schema(DataType::Float32);
        // Both `IN` and `NOT IN` must coerce the list values to the column type.
        for negated in [false, true] {
            let expr = Expr::in_list(
                col("a"),
                vec![literal(ScalarValue::Float64(Some(0.0)))],
                negated,
            );
            let expected = Expr::in_list(
                col("a"),
                vec![literal(ScalarValue::Float32(Some(0.0)))],
                negated,
            );
            assert_eq!(resolve_expr(&expr, &schema).unwrap(), expected);
        }
    }

    #[test]
    fn test_resolve_column_type() {
        let inner_struct =
            DataType::Struct(vec![Field::new("float", DataType::Float64, true)].into());
        let outer_struct = DataType::Struct(
            vec![
                Field::new("str", DataType::Utf8, true),
                Field::new("st", inner_struct, true),
            ]
            .into(),
        );
        let arrow_schema = Arc::new(ArrowSchema::new(vec![
            Field::new("int", DataType::Int32, true),
            Field::new("st", outer_struct, true),
        ]));
        let schema = Schema::try_from(arrow_schema.as_ref()).unwrap();

        let cases = [
            (col("int"), Some(DataType::Int32)),
            (col("st").field("str"), Some(DataType::Utf8)),
            (col("st").field("st").field("float"), Some(DataType::Float64)),
            // Unknown names and nested names used as top-level columns do not resolve.
            (col("x"), None),
            (col("str"), None),
            (col("float"), None),
            // Anything other than a column / field-access chain does not resolve.
            (col("st").field("str").eq(lit("x")), None),
        ];
        for (expr, expected) in cases {
            assert_eq!(
                resolve_column_type(&expr, &schema),
                expected,
                "expr: {expr}"
            );
        }
    }
}