datafusion_substrait/logical_plan/consumer/expr/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18mod aggregate_function;
19mod cast;
20mod field_reference;
21mod function_arguments;
22mod if_then;
23mod literal;
24mod scalar_function;
25mod singular_or_list;
26mod subquery;
27mod window_function;
28
29pub use aggregate_function::*;
30pub use cast::*;
31pub use field_reference::*;
32pub use function_arguments::*;
33pub use if_then::*;
34pub use literal::*;
35pub use scalar_function::*;
36pub use singular_or_list::*;
37pub use subquery::*;
38pub use window_function::*;
39
40use crate::extensions::Extensions;
41use crate::logical_plan::consumer::{
42    from_substrait_named_struct, rename_field, DefaultSubstraitConsumer,
43    SubstraitConsumer,
44};
45use datafusion::arrow::datatypes::Field;
46use datafusion::common::{not_impl_err, plan_err, substrait_err, DFSchema, DFSchemaRef};
47use datafusion::execution::SessionState;
48use datafusion::logical_expr::{Expr, ExprSchemable};
49use substrait::proto::expression::RexType;
50use substrait::proto::expression_reference::ExprType;
51use substrait::proto::{Expression, ExtendedExpression};
52
53/// Convert Substrait Rex to DataFusion Expr
54pub async fn from_substrait_rex(
55    consumer: &impl SubstraitConsumer,
56    expression: &Expression,
57    input_schema: &DFSchema,
58) -> datafusion::common::Result<Expr> {
59    match &expression.rex_type {
60        Some(t) => match t {
61            RexType::Literal(expr) => consumer.consume_literal(expr).await,
62            RexType::Selection(expr) => {
63                consumer.consume_field_reference(expr, input_schema).await
64            }
65            RexType::ScalarFunction(expr) => {
66                consumer.consume_scalar_function(expr, input_schema).await
67            }
68            RexType::WindowFunction(expr) => {
69                consumer.consume_window_function(expr, input_schema).await
70            }
71            RexType::IfThen(expr) => consumer.consume_if_then(expr, input_schema).await,
72            RexType::SwitchExpression(expr) => {
73                consumer.consume_switch(expr, input_schema).await
74            }
75            RexType::SingularOrList(expr) => {
76                consumer.consume_singular_or_list(expr, input_schema).await
77            }
78
79            RexType::MultiOrList(expr) => {
80                consumer.consume_multi_or_list(expr, input_schema).await
81            }
82
83            RexType::Cast(expr) => {
84                consumer.consume_cast(expr.as_ref(), input_schema).await
85            }
86
87            RexType::Subquery(expr) => {
88                consumer.consume_subquery(expr.as_ref(), input_schema).await
89            }
90            RexType::Nested(expr) => consumer.consume_nested(expr, input_schema).await,
91            RexType::Enum(expr) => consumer.consume_enum(expr, input_schema).await,
92            RexType::DynamicParameter(expr) => {
93                consumer.consume_dynamic_parameter(expr, input_schema).await
94            }
95        },
96        None => substrait_err!("Expression must set rex_type: {expression:?}"),
97    }
98}
99
100/// Convert Substrait ExtendedExpression to ExprContainer
101///
102/// A Substrait ExtendedExpression message contains one or more expressions,
103/// with names for the outputs, and an input schema.  These pieces are all included
104/// in the ExprContainer.
105///
106/// This is a top-level message and can be used to send expressions (not plans)
107/// between systems.  This is often useful for scenarios like pushdown where filter
108/// expressions need to be sent to remote systems.
109pub async fn from_substrait_extended_expr(
110    state: &SessionState,
111    extended_expr: &ExtendedExpression,
112) -> datafusion::common::Result<ExprContainer> {
113    // Register function extension
114    let extensions = Extensions::try_from(&extended_expr.extensions)?;
115    if !extensions.type_variations.is_empty() {
116        return not_impl_err!("Type variation extensions are not supported");
117    }
118
119    let consumer = DefaultSubstraitConsumer {
120        extensions: &extensions,
121        state,
122    };
123
124    let input_schema = DFSchemaRef::new(match &extended_expr.base_schema {
125        Some(base_schema) => from_substrait_named_struct(&consumer, base_schema),
126        None => {
127            plan_err!("required property `base_schema` missing from Substrait ExtendedExpression message")
128        }
129    }?);
130
131    // Parse expressions
132    let mut exprs = Vec::with_capacity(extended_expr.referred_expr.len());
133    for (expr_idx, substrait_expr) in extended_expr.referred_expr.iter().enumerate() {
134        let scalar_expr = match &substrait_expr.expr_type {
135            Some(ExprType::Expression(scalar_expr)) => Ok(scalar_expr),
136            Some(ExprType::Measure(_)) => {
137                not_impl_err!("Measure expressions are not yet supported")
138            }
139            None => {
140                plan_err!("required property `expr_type` missing from Substrait ExpressionReference message")
141            }
142        }?;
143        let expr = consumer
144            .consume_expression(scalar_expr, &input_schema)
145            .await?;
146        let (output_type, expected_nullability) =
147            expr.data_type_and_nullable(&input_schema)?;
148        let output_field = Field::new("", output_type, expected_nullability);
149        let mut names_idx = 0;
150        let output_field = rename_field(
151            &output_field,
152            &substrait_expr.output_names,
153            expr_idx,
154            &mut names_idx,
155        )?;
156        exprs.push((expr, output_field));
157    }
158
159    Ok(ExprContainer {
160        input_schema,
161        exprs,
162    })
163}
164
165/// An ExprContainer is a container for a collection of expressions with a common input schema
166///
167/// In addition, each expression is associated with a field, which defines the
168/// expression's output.  The data type and nullability of the field are calculated from the
169/// expression and the input schema.  However the names of the field (and its nested fields) are
170/// derived from the Substrait message.
171pub struct ExprContainer {
172    /// The input schema for the expressions
173    pub input_schema: DFSchemaRef,
174    /// The expressions
175    ///
176    /// Each item contains an expression and the field that defines the expected nullability and name of the expr's output
177    pub exprs: Vec<(Expr, Field)>,
178}
179
180/// Convert Substrait Expressions to DataFusion Exprs
181pub async fn from_substrait_rex_vec(
182    consumer: &impl SubstraitConsumer,
183    exprs: &Vec<Expression>,
184    input_schema: &DFSchema,
185) -> datafusion::common::Result<Vec<Expr>> {
186    let mut expressions: Vec<Expr> = vec![];
187    for expr in exprs {
188        let expression = consumer.consume_expression(expr, input_schema).await?;
189        expressions.push(expression);
190    }
191    Ok(expressions)
192}
193
#[cfg(test)]
mod tests {
    use crate::extensions::Extensions;
    use crate::logical_plan::consumer::utils::tests::test_consumer;
    use crate::logical_plan::consumer::*;
    use datafusion::common::DFSchema;
    use datafusion::logical_expr::Expr;
    use substrait::proto::expression::window_function::BoundsType;
    use substrait::proto::expression::RexType;
    use substrait::proto::Expression;

    /// Builds an extension set in which `count` is the only registered
    /// function, so that function reference 0 used by the tests below does
    /// not fail with a "function not found" error.
    fn count_extensions() -> Extensions {
        let mut extensions = Extensions::default();
        extensions.register_function("count".to_string());
        extensions
    }

    /// A RANGE-bounded window function without sort keys is expected to come
    /// out of the conversion with exactly one ORDER BY expression.
    #[tokio::test]
    async fn window_function_with_range_unit_and_no_order_by(
    ) -> datafusion::common::Result<()> {
        let window = substrait::proto::expression::WindowFunction {
            function_reference: 0,
            bounds_type: BoundsType::Range as i32,
            sorts: vec![],
            ..Default::default()
        };
        let rex = Expression {
            rex_type: Some(RexType::WindowFunction(window)),
        };

        let extensions = count_extensions();
        let mut consumer = test_consumer();
        consumer.extensions = &extensions;

        let converted = from_substrait_rex(&consumer, &rex, &DFSchema::empty()).await?;
        let Expr::WindowFunction(window_function) = converted else {
            panic!("expr was not a WindowFunction");
        };
        assert_eq!(window_function.params.order_by.len(), 1);

        Ok(())
    }

    /// A `count` window function given no arguments is expected to come out
    /// of the conversion with exactly one argument.
    #[tokio::test]
    async fn window_function_with_count() -> datafusion::common::Result<()> {
        let window = substrait::proto::expression::WindowFunction {
            function_reference: 0,
            ..Default::default()
        };
        let rex = Expression {
            rex_type: Some(RexType::WindowFunction(window)),
        };

        let extensions = count_extensions();
        let mut consumer = test_consumer();
        consumer.extensions = &extensions;

        let converted = from_substrait_rex(&consumer, &rex, &DFSchema::empty()).await?;
        let Expr::WindowFunction(window_function) = converted else {
            panic!("expr was not a WindowFunction");
        };
        assert_eq!(window_function.params.args.len(), 1);

        Ok(())
    }
}