Skip to main content

datafusion_substrait/logical_plan/consumer/expr/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18mod aggregate_function;
19mod cast;
20mod field_reference;
21mod function_arguments;
22mod if_then;
23mod literal;
24mod scalar_function;
25mod singular_or_list;
26mod subquery;
27mod window_function;
28
29pub use aggregate_function::*;
30pub use cast::*;
31pub use field_reference::*;
32pub use function_arguments::*;
33pub use if_then::*;
34pub use literal::*;
35pub use scalar_function::*;
36pub use singular_or_list::*;
37pub use subquery::*;
38pub use window_function::*;
39
40use crate::extensions::Extensions;
41use crate::logical_plan::consumer::{
42    DefaultSubstraitConsumer, SubstraitConsumer, from_substrait_named_struct,
43    rename_field,
44};
45use datafusion::arrow::datatypes::Field;
46use datafusion::common::{DFSchema, DFSchemaRef, not_impl_err, plan_err, substrait_err};
47use datafusion::execution::SessionState;
48use datafusion::logical_expr::{Expr, ExprSchemable};
49use substrait::proto::expression::RexType;
50use substrait::proto::expression_reference::ExprType;
51use substrait::proto::{Expression, ExtendedExpression};
52
53/// Convert Substrait Rex to DataFusion Expr
54pub async fn from_substrait_rex(
55    consumer: &impl SubstraitConsumer,
56    expression: &Expression,
57    input_schema: &DFSchema,
58) -> datafusion::common::Result<Expr> {
59    match &expression.rex_type {
60        Some(t) => match t {
61            RexType::Literal(expr) => consumer.consume_literal(expr).await,
62            RexType::Selection(expr) => {
63                consumer.consume_field_reference(expr, input_schema).await
64            }
65            RexType::ScalarFunction(expr) => {
66                consumer.consume_scalar_function(expr, input_schema).await
67            }
68            RexType::WindowFunction(expr) => {
69                consumer.consume_window_function(expr, input_schema).await
70            }
71            RexType::IfThen(expr) => consumer.consume_if_then(expr, input_schema).await,
72            RexType::SwitchExpression(expr) => {
73                consumer.consume_switch(expr, input_schema).await
74            }
75            RexType::SingularOrList(expr) => {
76                consumer.consume_singular_or_list(expr, input_schema).await
77            }
78
79            RexType::MultiOrList(expr) => {
80                consumer.consume_multi_or_list(expr, input_schema).await
81            }
82
83            RexType::Cast(expr) => {
84                consumer.consume_cast(expr.as_ref(), input_schema).await
85            }
86
87            RexType::Subquery(expr) => {
88                consumer.consume_subquery(expr.as_ref(), input_schema).await
89            }
90            RexType::Nested(expr) => consumer.consume_nested(expr, input_schema).await,
91            RexType::Enum(expr) => consumer.consume_enum(expr, input_schema).await,
92            RexType::DynamicParameter(expr) => {
93                consumer.consume_dynamic_parameter(expr, input_schema).await
94            }
95        },
96        None => substrait_err!("Expression must set rex_type: {expression:?}"),
97    }
98}
99
100/// Convert Substrait ExtendedExpression to ExprContainer
101///
102/// A Substrait ExtendedExpression message contains one or more expressions,
103/// with names for the outputs, and an input schema.  These pieces are all included
104/// in the ExprContainer.
105///
106/// This is a top-level message and can be used to send expressions (not plans)
107/// between systems.  This is often useful for scenarios like pushdown where filter
108/// expressions need to be sent to remote systems.
109pub async fn from_substrait_extended_expr(
110    state: &SessionState,
111    extended_expr: &ExtendedExpression,
112) -> datafusion::common::Result<ExprContainer> {
113    // Register function extension
114    let extensions = Extensions::try_from(&extended_expr.extensions)?;
115    if !extensions.type_variations.is_empty() {
116        return not_impl_err!("Type variation extensions are not supported");
117    }
118
119    let consumer = DefaultSubstraitConsumer {
120        extensions: &extensions,
121        state,
122    };
123
124    let input_schema = DFSchemaRef::new(match &extended_expr.base_schema {
125        Some(base_schema) => from_substrait_named_struct(&consumer, base_schema),
126        None => {
127            plan_err!(
128                "required property `base_schema` missing from Substrait ExtendedExpression message"
129            )
130        }
131    }?);
132
133    // Parse expressions
134    let mut exprs = Vec::with_capacity(extended_expr.referred_expr.len());
135    for (expr_idx, substrait_expr) in extended_expr.referred_expr.iter().enumerate() {
136        let scalar_expr = match &substrait_expr.expr_type {
137            Some(ExprType::Expression(scalar_expr)) => Ok(scalar_expr),
138            Some(ExprType::Measure(_)) => {
139                not_impl_err!("Measure expressions are not yet supported")
140            }
141            None => {
142                plan_err!(
143                    "required property `expr_type` missing from Substrait ExpressionReference message"
144                )
145            }
146        }?;
147        let expr = consumer
148            .consume_expression(scalar_expr, &input_schema)
149            .await?;
150        let output_field = expr.to_field(&input_schema)?.1;
151        let mut names_idx = 0;
152        let output_field = rename_field(
153            &output_field,
154            &substrait_expr.output_names,
155            expr_idx,
156            &mut names_idx,
157        )?;
158        exprs.push((expr, output_field));
159    }
160
161    Ok(ExprContainer {
162        input_schema,
163        exprs,
164    })
165}
166
167/// An ExprContainer is a container for a collection of expressions with a common input schema
168///
169/// In addition, each expression is associated with a field, which defines the
170/// expression's output.  The data type and nullability of the field are calculated from the
171/// expression and the input schema.  However the names of the field (and its nested fields) are
172/// derived from the Substrait message.
173pub struct ExprContainer {
174    /// The input schema for the expressions
175    pub input_schema: DFSchemaRef,
176    /// The expressions
177    ///
178    /// Each item contains an expression and the field that defines the expected nullability and name of the expr's output
179    pub exprs: Vec<(Expr, Field)>,
180}
181
182/// Convert Substrait Expressions to DataFusion Exprs
183pub async fn from_substrait_rex_vec(
184    consumer: &impl SubstraitConsumer,
185    exprs: &Vec<Expression>,
186    input_schema: &DFSchema,
187) -> datafusion::common::Result<Vec<Expr>> {
188    let mut expressions: Vec<Expr> = vec![];
189    for expr in exprs {
190        let expression = consumer.consume_expression(expr, input_schema).await?;
191        expressions.push(expression);
192    }
193    Ok(expressions)
194}
195
#[cfg(test)]
mod tests {
    use crate::extensions::Extensions;
    use crate::logical_plan::consumer::utils::tests::test_consumer;
    use crate::logical_plan::consumer::*;
    use datafusion::common::DFSchema;
    use datafusion::logical_expr::Expr;
    use substrait::proto::Expression;
    use substrait::proto::expression::RexType;
    use substrait::proto::expression::window_function::BoundsType;

    /// Wraps `window` in a Substrait `Expression` and converts it with
    /// `from_substrait_rex` against an empty schema.
    ///
    /// A single function ("count", index 0) is registered so that the
    /// conversion does not fail with a "function not found" error.
    async fn consume_window_function(
        window: substrait::proto::expression::WindowFunction,
    ) -> datafusion::common::Result<Expr> {
        let substrait = Expression {
            rex_type: Some(RexType::WindowFunction(window)),
        };

        let mut extensions = Extensions::default();
        extensions.register_function("count");

        let mut consumer = test_consumer();
        consumer.extensions = &extensions;

        from_substrait_rex(&consumer, &substrait, &DFSchema::empty()).await
    }

    /// A RANGE-bounded window function without any sorts is expected to come
    /// out with exactly one `order_by` entry.
    #[tokio::test]
    async fn window_function_with_range_unit_and_no_order_by()
    -> datafusion::common::Result<()> {
        let window = substrait::proto::expression::WindowFunction {
            function_reference: 0,
            bounds_type: BoundsType::Range as i32,
            sorts: vec![],
            ..Default::default()
        };

        let Expr::WindowFunction(window_function) =
            consume_window_function(window).await?
        else {
            panic!("expr was not a WindowFunction")
        };
        assert_eq!(window_function.params.order_by.len(), 1);

        Ok(())
    }

    /// A `count` window function given no arguments is expected to come out
    /// with exactly one argument.
    #[tokio::test]
    async fn window_function_with_count() -> datafusion::common::Result<()> {
        let window = substrait::proto::expression::WindowFunction {
            function_reference: 0,
            ..Default::default()
        };

        let Expr::WindowFunction(window_function) =
            consume_window_function(window).await?
        else {
            panic!("expr was not a WindowFunction")
        };
        assert_eq!(window_function.params.args.len(), 1);

        Ok(())
    }
}