Skip to main content

datafusion_substrait/logical_plan/consumer/expr/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18mod aggregate_function;
19mod cast;
20mod field_reference;
21mod function_arguments;
22mod if_then;
23mod literal;
24mod nested;
25mod scalar_function;
26mod singular_or_list;
27mod subquery;
28mod window_function;
29
30pub use aggregate_function::*;
31pub use cast::*;
32pub use field_reference::*;
33pub use function_arguments::*;
34pub use if_then::*;
35pub use literal::*;
36pub use nested::*;
37pub use scalar_function::*;
38pub use singular_or_list::*;
39pub use subquery::*;
40pub use window_function::*;
41
42use crate::extensions::Extensions;
43use crate::logical_plan::consumer::{
44    DefaultSubstraitConsumer, SubstraitConsumer, from_substrait_named_struct,
45    rename_field,
46};
47use datafusion::arrow::datatypes::Field;
48use datafusion::common::{DFSchema, DFSchemaRef, not_impl_err, plan_err, substrait_err};
49use datafusion::execution::SessionState;
50use datafusion::logical_expr::{Expr, ExprSchemable};
51use substrait::proto::expression::RexType;
52use substrait::proto::expression_reference::ExprType;
53use substrait::proto::{Expression, ExtendedExpression};
54
55/// Convert Substrait Rex to DataFusion Expr
56pub async fn from_substrait_rex(
57    consumer: &impl SubstraitConsumer,
58    expression: &Expression,
59    input_schema: &DFSchema,
60) -> datafusion::common::Result<Expr> {
61    match &expression.rex_type {
62        Some(t) => match t {
63            RexType::Literal(expr) => consumer.consume_literal(expr).await,
64            RexType::Selection(expr) => {
65                consumer.consume_field_reference(expr, input_schema).await
66            }
67            RexType::ScalarFunction(expr) => {
68                consumer.consume_scalar_function(expr, input_schema).await
69            }
70            RexType::WindowFunction(expr) => {
71                consumer.consume_window_function(expr, input_schema).await
72            }
73            RexType::IfThen(expr) => consumer.consume_if_then(expr, input_schema).await,
74            RexType::SwitchExpression(expr) => {
75                consumer.consume_switch(expr, input_schema).await
76            }
77            RexType::SingularOrList(expr) => {
78                consumer.consume_singular_or_list(expr, input_schema).await
79            }
80
81            RexType::MultiOrList(expr) => {
82                consumer.consume_multi_or_list(expr, input_schema).await
83            }
84
85            RexType::Cast(expr) => {
86                consumer.consume_cast(expr.as_ref(), input_schema).await
87            }
88
89            RexType::Subquery(expr) => {
90                consumer.consume_subquery(expr.as_ref(), input_schema).await
91            }
92            RexType::Nested(expr) => consumer.consume_nested(expr, input_schema).await,
93            #[expect(deprecated)]
94            RexType::Enum(expr) => consumer.consume_enum(expr, input_schema).await,
95            RexType::DynamicParameter(expr) => {
96                consumer.consume_dynamic_parameter(expr, input_schema).await
97            }
98            RexType::Lambda(_) | RexType::LambdaInvocation(_) => {
99                not_impl_err!("Lambda expressions are not yet supported")
100            }
101        },
102        None => substrait_err!("Expression must set rex_type: {expression:?}"),
103    }
104}
105
106/// Convert Substrait ExtendedExpression to ExprContainer
107///
108/// A Substrait ExtendedExpression message contains one or more expressions,
109/// with names for the outputs, and an input schema.  These pieces are all included
110/// in the ExprContainer.
111///
112/// This is a top-level message and can be used to send expressions (not plans)
113/// between systems.  This is often useful for scenarios like pushdown where filter
114/// expressions need to be sent to remote systems.
115pub async fn from_substrait_extended_expr(
116    state: &SessionState,
117    extended_expr: &ExtendedExpression,
118) -> datafusion::common::Result<ExprContainer> {
119    // Register function extension
120    let extensions = Extensions::try_from(&extended_expr.extensions)?;
121    if !extensions.type_variations.is_empty() {
122        return not_impl_err!("Type variation extensions are not supported");
123    }
124
125    let consumer = DefaultSubstraitConsumer::new(&extensions, state);
126
127    let input_schema = DFSchemaRef::new(match &extended_expr.base_schema {
128        Some(base_schema) => from_substrait_named_struct(&consumer, base_schema),
129        None => {
130            plan_err!(
131                "required property `base_schema` missing from Substrait ExtendedExpression message"
132            )
133        }
134    }?);
135
136    // Parse expressions
137    let mut exprs = Vec::with_capacity(extended_expr.referred_expr.len());
138    for (expr_idx, substrait_expr) in extended_expr.referred_expr.iter().enumerate() {
139        let scalar_expr = match &substrait_expr.expr_type {
140            Some(ExprType::Expression(scalar_expr)) => Ok(scalar_expr),
141            Some(ExprType::Measure(_)) => {
142                not_impl_err!("Measure expressions are not yet supported")
143            }
144            None => {
145                plan_err!(
146                    "required property `expr_type` missing from Substrait ExpressionReference message"
147                )
148            }
149        }?;
150        let expr = consumer
151            .consume_expression(scalar_expr, &input_schema)
152            .await?;
153        let output_field = expr.to_field(&input_schema)?.1;
154        let mut names_idx = 0;
155        let output_field = rename_field(
156            &output_field,
157            &substrait_expr.output_names,
158            expr_idx,
159            &mut names_idx,
160        )?;
161        exprs.push((expr, output_field));
162    }
163
164    Ok(ExprContainer {
165        input_schema,
166        exprs,
167    })
168}
169
/// An ExprContainer is a container for a collection of expressions with a common input schema
///
/// In addition, each expression is associated with a field, which defines the
/// expression's output.  The data type and nullability of the field are calculated from the
/// expression and the input schema.  However the names of the field (and its nested fields) are
/// derived from the Substrait message.
///
/// Values of this type are produced by [`from_substrait_extended_expr`].
pub struct ExprContainer {
    /// The input schema shared by all expressions in `exprs`
    pub input_schema: DFSchemaRef,
    /// The expressions, in the order they were listed in the Substrait message
    ///
    /// Each item pairs an expression with the field that defines the expected
    /// nullability and name of that expression's output
    pub exprs: Vec<(Expr, Field)>,
}
184
185/// Convert Substrait Expressions to DataFusion Exprs
186pub async fn from_substrait_rex_vec(
187    consumer: &impl SubstraitConsumer,
188    exprs: &Vec<Expression>,
189    input_schema: &DFSchema,
190) -> datafusion::common::Result<Vec<Expr>> {
191    let mut expressions: Vec<Expr> = vec![];
192    for expr in exprs {
193        let expression = consumer.consume_expression(expr, input_schema).await?;
194        expressions.push(expression);
195    }
196    Ok(expressions)
197}
198
#[cfg(test)]
mod tests {
    use crate::extensions::Extensions;
    use crate::logical_plan::consumer::utils::tests::test_consumer;
    use crate::logical_plan::consumer::*;
    use datafusion::common::DFSchema;
    use datafusion::logical_expr::Expr;
    use substrait::proto::Expression;
    use substrait::proto::expression::RexType;
    use substrait::proto::expression::window_function::BoundsType;

    /// Wrap a Substrait window function in an [`Expression`].
    fn window_function_rex(
        window_function: substrait::proto::expression::WindowFunction,
    ) -> Expression {
        Expression {
            rex_type: Some(RexType::WindowFunction(window_function)),
        }
    }

    /// A RANGE-bounded window function without sorts is expected to be converted
    /// into a window function carrying exactly one ORDER BY entry.
    #[tokio::test]
    async fn window_function_with_range_unit_and_no_order_by()
    -> datafusion::common::Result<()> {
        let substrait =
            window_function_rex(substrait::proto::expression::WindowFunction {
                function_reference: 0,
                bounds_type: BoundsType::Range as i32,
                sorts: vec![],
                ..Default::default()
            });

        let mut consumer = test_consumer();

        // Register a single function (index 0) so that the conversion does not
        // fail with a "function not found" error.
        let mut extensions = Extensions::default();
        extensions.register_function("count");
        consumer.extensions = &extensions;

        let converted =
            from_substrait_rex(&consumer, &substrait, &DFSchema::empty()).await?;
        let Expr::WindowFunction(window_function) = converted else {
            panic!("expr was not a WindowFunction");
        };
        assert_eq!(window_function.params.order_by.len(), 1);

        Ok(())
    }

    /// A `count` window function given no arguments is expected to be converted
    /// into a window function carrying exactly one argument.
    #[tokio::test]
    async fn window_function_with_count() -> datafusion::common::Result<()> {
        let substrait =
            window_function_rex(substrait::proto::expression::WindowFunction {
                function_reference: 0,
                ..Default::default()
            });

        let mut consumer = test_consumer();

        // Function reference 0 has to resolve to `count`.
        let mut extensions = Extensions::default();
        extensions.register_function("count");
        consumer.extensions = &extensions;

        let converted =
            from_substrait_rex(&consumer, &substrait, &DFSchema::empty()).await?;
        let Expr::WindowFunction(window_function) = converted else {
            panic!("expr was not a WindowFunction");
        };
        assert_eq!(window_function.params.args.len(), 1);

        Ok(())
    }
}