Skip to main content

indexlake_datafusion/
expr.rs

1use arrow::datatypes::SchemaRef;
2use datafusion::prelude::SessionContext;
3use datafusion_common::DataFusionError;
4use datafusion_common::tree_node::TreeNode;
5use datafusion_common::{DFSchema, ScalarValue};
6use datafusion_expr::Expr;
7use datafusion_expr::{ExprSchemable, Operator};
8use datafusion_optimizer::analyzer::type_coercion::TypeCoercionRewriter;
9use indexlake::catalog::Scalar as ILScalar;
10use indexlake::expr::{BinaryOp as ILOperator, Expr as ILExpr};
11
12pub fn datafusion_expr_to_indexlake_expr(
13    expr: &Expr,
14    schema: &DFSchema,
15) -> Result<ILExpr, DataFusionError> {
16    match expr {
17        Expr::Alias(alias) => datafusion_expr_to_indexlake_expr(&alias.expr, schema),
18        Expr::Column(col) => Ok(ILExpr::Column(col.name.clone())),
19        Expr::Literal(lit, _) => Ok(datafusion_scalar_to_indexlake_scalar(lit)?.into()),
20        Expr::BinaryExpr(binary) => {
21            let left = Box::new(datafusion_expr_to_indexlake_expr(&binary.left, schema)?);
22            let op = datafusion_operator_to_indexlake_operator(&binary.op)?;
23            let right = Box::new(datafusion_expr_to_indexlake_expr(&binary.right, schema)?);
24            Ok(ILExpr::BinaryExpr(indexlake::expr::BinaryExpr {
25                left,
26                op,
27                right,
28            }))
29        }
30        Expr::Not(expr) => {
31            let expr = Box::new(datafusion_expr_to_indexlake_expr(expr, schema)?);
32            Ok(ILExpr::Not(expr))
33        }
34        Expr::IsNull(expr) => {
35            let expr = Box::new(datafusion_expr_to_indexlake_expr(expr, schema)?);
36            Ok(ILExpr::IsNull(expr))
37        }
38        Expr::IsNotNull(expr) => {
39            let expr = Box::new(datafusion_expr_to_indexlake_expr(expr, schema)?);
40            Ok(ILExpr::IsNotNull(expr))
41        }
42        Expr::IsTrue(expr) => {
43            let expr = Box::new(datafusion_expr_to_indexlake_expr(expr, schema)?);
44            Ok(ILExpr::BinaryExpr(indexlake::expr::BinaryExpr {
45                left: expr,
46                op: ILOperator::IsNotDistinctFrom,
47                right: Box::new(indexlake::expr::lit(true)),
48            }))
49        }
50        Expr::IsNotTrue(expr) => {
51            let expr = Box::new(datafusion_expr_to_indexlake_expr(expr, schema)?);
52            Ok(ILExpr::BinaryExpr(indexlake::expr::BinaryExpr {
53                left: expr,
54                op: ILOperator::IsDistinctFrom,
55                right: Box::new(indexlake::expr::lit(true)),
56            }))
57        }
58        Expr::IsFalse(expr) => {
59            let expr = Box::new(datafusion_expr_to_indexlake_expr(expr, schema)?);
60            Ok(ILExpr::BinaryExpr(indexlake::expr::BinaryExpr {
61                left: expr,
62                op: ILOperator::IsNotDistinctFrom,
63                right: Box::new(indexlake::expr::lit(false)),
64            }))
65        }
66        Expr::IsNotFalse(expr) => {
67            let expr = Box::new(datafusion_expr_to_indexlake_expr(expr, schema)?);
68            Ok(ILExpr::BinaryExpr(indexlake::expr::BinaryExpr {
69                left: expr,
70                op: ILOperator::IsDistinctFrom,
71                right: Box::new(indexlake::expr::lit(false)),
72            }))
73        }
74        Expr::IsUnknown(expr) => {
75            let expr = Box::new(datafusion_expr_to_indexlake_expr(expr, schema)?);
76            Ok(ILExpr::BinaryExpr(indexlake::expr::BinaryExpr {
77                left: expr,
78                op: ILOperator::IsNotDistinctFrom,
79                right: Box::new(ILScalar::Boolean(None).into()),
80            }))
81        }
82        Expr::IsNotUnknown(expr) => {
83            let expr = Box::new(datafusion_expr_to_indexlake_expr(expr, schema)?);
84            Ok(ILExpr::BinaryExpr(indexlake::expr::BinaryExpr {
85                left: expr,
86                op: ILOperator::IsDistinctFrom,
87                right: Box::new(ILScalar::Boolean(None).into()),
88            }))
89        }
90        Expr::Between(between) => {
91            let expr = datafusion_expr_to_indexlake_expr(&between.expr, schema)?;
92            let low = datafusion_expr_to_indexlake_expr(&between.low, schema)?;
93            let high = datafusion_expr_to_indexlake_expr(&between.high, schema)?;
94
95            let left_expr = expr.clone().gteq(low);
96            let right_expr = expr.lteq(high);
97            let and_expr = left_expr.and(right_expr);
98            if between.negated {
99                Ok(ILExpr::Not(Box::new(and_expr)))
100            } else {
101                Ok(and_expr)
102            }
103        }
104        Expr::InList(in_list) => {
105            let expr = Box::new(datafusion_expr_to_indexlake_expr(&in_list.expr, schema)?);
106            let list = in_list
107                .list
108                .iter()
109                .map(|expr| datafusion_expr_to_indexlake_expr(expr, schema))
110                .collect::<Result<Vec<_>, _>>()?;
111            Ok(ILExpr::InList(indexlake::expr::InList {
112                expr,
113                list,
114                negated: in_list.negated,
115            }))
116        }
117        Expr::Like(like) => {
118            let expr = Box::new(datafusion_expr_to_indexlake_expr(&like.expr, schema)?);
119            let pattern = Box::new(datafusion_expr_to_indexlake_expr(&like.pattern, schema)?);
120            Ok(ILExpr::Like(indexlake::expr::Like {
121                expr,
122                pattern,
123                negated: like.negated,
124                case_insensitive: like.case_insensitive,
125            }))
126        }
127        Expr::Cast(cast) => {
128            let expr = Box::new(datafusion_expr_to_indexlake_expr(&cast.expr, schema)?);
129            Ok(ILExpr::Cast(indexlake::expr::Cast {
130                expr,
131                cast_type: cast.data_type.clone(),
132            }))
133        }
134        Expr::TryCast(try_cast) => {
135            let expr = Box::new(datafusion_expr_to_indexlake_expr(&try_cast.expr, schema)?);
136            Ok(ILExpr::TryCast(indexlake::expr::TryCast {
137                expr,
138                cast_type: try_cast.data_type.clone(),
139            }))
140        }
141        Expr::Negative(expr) => {
142            let expr = Box::new(datafusion_expr_to_indexlake_expr(expr, schema)?);
143            Ok(ILExpr::Negative(expr))
144        }
145        Expr::Case(case) => {
146            let when_then = match &case.expr {
147                Some(expr) => {
148                    let expr = datafusion_expr_to_indexlake_expr(expr, schema)?;
149                    case.when_then_expr
150                        .iter()
151                        .map(|(when, then)| {
152                            let when = expr
153                                .clone()
154                                .eq(datafusion_expr_to_indexlake_expr(when, schema)?);
155                            let then = datafusion_expr_to_indexlake_expr(then, schema)?;
156                            Ok::<_, DataFusionError>((Box::new(when), Box::new(then)))
157                        })
158                        .collect::<Result<Vec<_>, _>>()?
159                }
160                None => case
161                    .when_then_expr
162                    .iter()
163                    .map(|(when, then)| {
164                        Ok::<_, DataFusionError>((
165                            Box::new(datafusion_expr_to_indexlake_expr(when, schema)?),
166                            Box::new(datafusion_expr_to_indexlake_expr(then, schema)?),
167                        ))
168                    })
169                    .collect::<Result<Vec<_>, _>>()?,
170            };
171            let else_expr = match &case.else_expr {
172                Some(expr) => Some(Box::new(datafusion_expr_to_indexlake_expr(expr, schema)?)),
173                None => None,
174            };
175            Ok(ILExpr::Case(indexlake::expr::Case {
176                when_then,
177                else_expr,
178            }))
179        }
180        Expr::ScalarFunction(func) => {
181            let mut arg_types = Vec::with_capacity(func.args.len());
182            for arg in func.args.iter() {
183                let (data_type, _) = arg.data_type_and_nullable(schema)?;
184                arg_types.push(data_type);
185            }
186            let args = func
187                .args
188                .iter()
189                .map(|arg| datafusion_expr_to_indexlake_expr(arg, schema))
190                .collect::<Result<Vec<_>, _>>()?;
191            let name = func.name().to_string();
192            let return_type = func.func.return_type(&arg_types)?;
193            Ok(ILExpr::Function(indexlake::expr::Function {
194                name,
195                args,
196                return_type,
197            }))
198        }
199        _ => Err(DataFusionError::NotImplemented(format!(
200            "Unsupported expr: {expr}"
201        ))),
202    }
203}
204
205pub fn datafusion_scalar_to_indexlake_scalar(
206    scalar: &ScalarValue,
207) -> Result<ILScalar, DataFusionError> {
208    match scalar {
209        ScalarValue::Boolean(v) => Ok(ILScalar::Boolean(*v)),
210        ScalarValue::Int8(v) => Ok(ILScalar::Int8(*v)),
211        ScalarValue::Int16(v) => Ok(ILScalar::Int16(*v)),
212        ScalarValue::Int32(v) => Ok(ILScalar::Int32(*v)),
213        ScalarValue::Int64(v) => Ok(ILScalar::Int64(*v)),
214        ScalarValue::UInt8(v) => Ok(ILScalar::UInt8(*v)),
215        ScalarValue::UInt16(v) => Ok(ILScalar::UInt16(*v)),
216        ScalarValue::UInt32(v) => Ok(ILScalar::UInt32(*v)),
217        ScalarValue::UInt64(v) => Ok(ILScalar::UInt64(*v)),
218        ScalarValue::Float32(v) => Ok(ILScalar::Float32(*v)),
219        ScalarValue::Float64(v) => Ok(ILScalar::Float64(*v)),
220        ScalarValue::Utf8(v) => Ok(ILScalar::Utf8(v.clone())),
221        ScalarValue::Utf8View(v) => Ok(ILScalar::Utf8View(v.clone())),
222        ScalarValue::LargeUtf8(v) => Ok(ILScalar::LargeUtf8(v.clone())),
223        ScalarValue::Binary(v) => Ok(ILScalar::Binary(v.clone())),
224        ScalarValue::BinaryView(v) => Ok(ILScalar::BinaryView(v.clone())),
225        ScalarValue::FixedSizeBinary(s, v) => Ok(ILScalar::FixedSizeBinary(*s, v.clone())),
226        ScalarValue::LargeBinary(v) => Ok(ILScalar::LargeBinary(v.clone())),
227        ScalarValue::TimestampSecond(v, tz) => Ok(ILScalar::TimestampSecond(*v, tz.clone())),
228        ScalarValue::TimestampMillisecond(v, tz) => {
229            Ok(ILScalar::TimestampMillisecond(*v, tz.clone()))
230        }
231        ScalarValue::TimestampMicrosecond(v, tz) => {
232            Ok(ILScalar::TimestampMicrosecond(*v, tz.clone()))
233        }
234        ScalarValue::TimestampNanosecond(v, tz) => {
235            Ok(ILScalar::TimestampNanosecond(*v, tz.clone()))
236        }
237        ScalarValue::Date32(v) => Ok(ILScalar::Date32(*v)),
238        ScalarValue::Date64(v) => Ok(ILScalar::Date64(*v)),
239        ScalarValue::Time32Second(v) => Ok(ILScalar::Time32Second(*v)),
240        ScalarValue::Time32Millisecond(v) => Ok(ILScalar::Time32Millisecond(*v)),
241        ScalarValue::Time64Microsecond(v) => Ok(ILScalar::Time64Microsecond(*v)),
242        ScalarValue::Time64Nanosecond(v) => Ok(ILScalar::Time64Nanosecond(*v)),
243        ScalarValue::List(v) => Ok(ILScalar::List(v.clone())),
244        ScalarValue::FixedSizeList(v) => Ok(ILScalar::FixedSizeList(v.clone())),
245        ScalarValue::LargeList(v) => Ok(ILScalar::LargeList(v.clone())),
246        ScalarValue::Decimal128(v, p, s) => Ok(ILScalar::Decimal128(*v, *p, *s)),
247        ScalarValue::Decimal256(v, p, s) => Ok(ILScalar::Decimal256(*v, *p, *s)),
248        _ => Err(DataFusionError::NotImplemented(format!(
249            "Unsupported scalar: {scalar}"
250        ))),
251    }
252}
253
254pub fn indexlake_scalar_to_datafusion_scalar(
255    scalar: &ILScalar,
256) -> Result<ScalarValue, DataFusionError> {
257    match scalar {
258        ILScalar::Boolean(v) => Ok(ScalarValue::Boolean(*v)),
259        ILScalar::Int8(v) => Ok(ScalarValue::Int8(*v)),
260        ILScalar::Int16(v) => Ok(ScalarValue::Int16(*v)),
261        ILScalar::Int32(v) => Ok(ScalarValue::Int32(*v)),
262        ILScalar::Int64(v) => Ok(ScalarValue::Int64(*v)),
263        ILScalar::UInt8(v) => Ok(ScalarValue::UInt8(*v)),
264        ILScalar::UInt16(v) => Ok(ScalarValue::UInt16(*v)),
265        ILScalar::UInt32(v) => Ok(ScalarValue::UInt32(*v)),
266        ILScalar::UInt64(v) => Ok(ScalarValue::UInt64(*v)),
267        ILScalar::Float32(v) => Ok(ScalarValue::Float32(*v)),
268        ILScalar::Float64(v) => Ok(ScalarValue::Float64(*v)),
269        ILScalar::Utf8(v) => Ok(ScalarValue::Utf8(v.clone())),
270        ILScalar::Utf8View(v) => Ok(ScalarValue::Utf8View(v.clone())),
271        ILScalar::LargeUtf8(v) => Ok(ScalarValue::LargeUtf8(v.clone())),
272        ILScalar::Binary(v) => Ok(ScalarValue::Binary(v.clone())),
273        ILScalar::BinaryView(v) => Ok(ScalarValue::BinaryView(v.clone())),
274        ILScalar::FixedSizeBinary(s, v) => Ok(ScalarValue::FixedSizeBinary(*s, v.clone())),
275        ILScalar::LargeBinary(v) => Ok(ScalarValue::LargeBinary(v.clone())),
276        ILScalar::TimestampSecond(v, tz) => Ok(ScalarValue::TimestampSecond(*v, tz.clone())),
277        ILScalar::TimestampMillisecond(v, tz) => {
278            Ok(ScalarValue::TimestampMillisecond(*v, tz.clone()))
279        }
280        ILScalar::TimestampMicrosecond(v, tz) => {
281            Ok(ScalarValue::TimestampMicrosecond(*v, tz.clone()))
282        }
283        ILScalar::TimestampNanosecond(v, tz) => {
284            Ok(ScalarValue::TimestampNanosecond(*v, tz.clone()))
285        }
286        ILScalar::Date32(v) => Ok(ScalarValue::Date32(*v)),
287        ILScalar::Date64(v) => Ok(ScalarValue::Date64(*v)),
288        ILScalar::Time32Second(v) => Ok(ScalarValue::Time32Second(*v)),
289        ILScalar::Time32Millisecond(v) => Ok(ScalarValue::Time32Millisecond(*v)),
290        ILScalar::Time64Microsecond(v) => Ok(ScalarValue::Time64Microsecond(*v)),
291        ILScalar::Time64Nanosecond(v) => Ok(ScalarValue::Time64Nanosecond(*v)),
292        ILScalar::List(v) => Ok(ScalarValue::List(v.clone())),
293        ILScalar::FixedSizeList(v) => Ok(ScalarValue::FixedSizeList(v.clone())),
294        ILScalar::LargeList(v) => Ok(ScalarValue::LargeList(v.clone())),
295        ILScalar::Decimal128(v, p, s) => Ok(ScalarValue::Decimal128(*v, *p, *s)),
296        ILScalar::Decimal256(v, p, s) => Ok(ScalarValue::Decimal256(*v, *p, *s)),
297    }
298}
299
300pub fn datafusion_operator_to_indexlake_operator(
301    operator: &Operator,
302) -> Result<ILOperator, DataFusionError> {
303    match operator {
304        Operator::Eq => Ok(ILOperator::Eq),
305        Operator::NotEq => Ok(ILOperator::NotEq),
306        Operator::Lt => Ok(ILOperator::Lt),
307        Operator::LtEq => Ok(ILOperator::LtEq),
308        Operator::Gt => Ok(ILOperator::Gt),
309        Operator::GtEq => Ok(ILOperator::GtEq),
310        Operator::Plus => Ok(ILOperator::Plus),
311        Operator::Minus => Ok(ILOperator::Minus),
312        Operator::Multiply => Ok(ILOperator::Multiply),
313        Operator::Divide => Ok(ILOperator::Divide),
314        Operator::Modulo => Ok(ILOperator::Modulo),
315        Operator::And => Ok(ILOperator::And),
316        Operator::Or => Ok(ILOperator::Or),
317        Operator::IsDistinctFrom => Ok(ILOperator::IsDistinctFrom),
318        Operator::IsNotDistinctFrom => Ok(ILOperator::IsNotDistinctFrom),
319        _ => Err(DataFusionError::NotImplemented(format!(
320            "Unsupported operator: {operator}",
321        ))),
322    }
323}
324
325pub fn parse_expr(expr: &str, schema: SchemaRef) -> Result<ILExpr, DataFusionError> {
326    let ctx = SessionContext::new();
327    let df_schema = DFSchema::try_from(schema)?;
328    let expr = ctx.parse_sql_expr(expr, &df_schema)?;
329
330    let mut rewriter = TypeCoercionRewriter::new(&df_schema);
331    let coerced_expr = expr.rewrite(&mut rewriter)?;
332
333    datafusion_expr_to_indexlake_expr(&coerced_expr.data, &df_schema)
334}
335
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use std::sync::Arc;

    /// `id + 1` should come back with the `Int32` column cast to `Int64` so that it matches
    /// the `Int64` literal, joined by `Plus`.
    #[test]
    fn test_parse_expr() {
        let fields = vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
        ];
        let schema = Arc::new(Schema::new(fields));

        let coerced_id = ILExpr::Cast(indexlake::expr::Cast {
            expr: Box::new(indexlake::expr::col("id")),
            cast_type: DataType::Int64,
        });
        let expected = ILExpr::BinaryExpr(indexlake::expr::BinaryExpr {
            left: Box::new(coerced_id),
            op: ILOperator::Plus,
            right: Box::new(indexlake::expr::lit(1i64)),
        });

        let actual = parse_expr("id + 1", schema).expect("`id + 1` should parse");
        assert_eq!(actual, expected);
    }
}