Skip to main content

datafusion_substrait/logical_plan/consumer/expr/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18mod aggregate_function;
19mod cast;
20mod field_reference;
21mod function_arguments;
22mod if_then;
23mod literal;
24mod scalar_function;
25mod singular_or_list;
26mod subquery;
27mod window_function;
28
29pub use aggregate_function::*;
30pub use cast::*;
31pub use field_reference::*;
32pub use function_arguments::*;
33pub use if_then::*;
34pub use literal::*;
35pub use scalar_function::*;
36pub use singular_or_list::*;
37pub use subquery::*;
38pub use window_function::*;
39
40use crate::extensions::Extensions;
41use crate::logical_plan::consumer::{
42    DefaultSubstraitConsumer, SubstraitConsumer, from_substrait_named_struct,
43    rename_field,
44};
45use datafusion::arrow::datatypes::Field;
46use datafusion::common::{DFSchema, DFSchemaRef, not_impl_err, plan_err, substrait_err};
47use datafusion::execution::SessionState;
48use datafusion::logical_expr::{Expr, ExprSchemable};
49use substrait::proto::expression::RexType;
50use substrait::proto::expression_reference::ExprType;
51use substrait::proto::{Expression, ExtendedExpression};
52
53/// Convert Substrait Rex to DataFusion Expr
54pub async fn from_substrait_rex(
55    consumer: &impl SubstraitConsumer,
56    expression: &Expression,
57    input_schema: &DFSchema,
58) -> datafusion::common::Result<Expr> {
59    match &expression.rex_type {
60        Some(t) => match t {
61            RexType::Literal(expr) => consumer.consume_literal(expr).await,
62            RexType::Selection(expr) => {
63                consumer.consume_field_reference(expr, input_schema).await
64            }
65            RexType::ScalarFunction(expr) => {
66                consumer.consume_scalar_function(expr, input_schema).await
67            }
68            RexType::WindowFunction(expr) => {
69                consumer.consume_window_function(expr, input_schema).await
70            }
71            RexType::IfThen(expr) => consumer.consume_if_then(expr, input_schema).await,
72            RexType::SwitchExpression(expr) => {
73                consumer.consume_switch(expr, input_schema).await
74            }
75            RexType::SingularOrList(expr) => {
76                consumer.consume_singular_or_list(expr, input_schema).await
77            }
78
79            RexType::MultiOrList(expr) => {
80                consumer.consume_multi_or_list(expr, input_schema).await
81            }
82
83            RexType::Cast(expr) => {
84                consumer.consume_cast(expr.as_ref(), input_schema).await
85            }
86
87            RexType::Subquery(expr) => {
88                consumer.consume_subquery(expr.as_ref(), input_schema).await
89            }
90            RexType::Nested(expr) => consumer.consume_nested(expr, input_schema).await,
91            #[expect(deprecated)]
92            RexType::Enum(expr) => consumer.consume_enum(expr, input_schema).await,
93            RexType::DynamicParameter(expr) => {
94                consumer.consume_dynamic_parameter(expr, input_schema).await
95            }
96        },
97        None => substrait_err!("Expression must set rex_type: {expression:?}"),
98    }
99}
100
101/// Convert Substrait ExtendedExpression to ExprContainer
102///
103/// A Substrait ExtendedExpression message contains one or more expressions,
104/// with names for the outputs, and an input schema.  These pieces are all included
105/// in the ExprContainer.
106///
107/// This is a top-level message and can be used to send expressions (not plans)
108/// between systems.  This is often useful for scenarios like pushdown where filter
109/// expressions need to be sent to remote systems.
110pub async fn from_substrait_extended_expr(
111    state: &SessionState,
112    extended_expr: &ExtendedExpression,
113) -> datafusion::common::Result<ExprContainer> {
114    // Register function extension
115    let extensions = Extensions::try_from(&extended_expr.extensions)?;
116    if !extensions.type_variations.is_empty() {
117        return not_impl_err!("Type variation extensions are not supported");
118    }
119
120    let consumer = DefaultSubstraitConsumer::new(&extensions, state);
121
122    let input_schema = DFSchemaRef::new(match &extended_expr.base_schema {
123        Some(base_schema) => from_substrait_named_struct(&consumer, base_schema),
124        None => {
125            plan_err!(
126                "required property `base_schema` missing from Substrait ExtendedExpression message"
127            )
128        }
129    }?);
130
131    // Parse expressions
132    let mut exprs = Vec::with_capacity(extended_expr.referred_expr.len());
133    for (expr_idx, substrait_expr) in extended_expr.referred_expr.iter().enumerate() {
134        let scalar_expr = match &substrait_expr.expr_type {
135            Some(ExprType::Expression(scalar_expr)) => Ok(scalar_expr),
136            Some(ExprType::Measure(_)) => {
137                not_impl_err!("Measure expressions are not yet supported")
138            }
139            None => {
140                plan_err!(
141                    "required property `expr_type` missing from Substrait ExpressionReference message"
142                )
143            }
144        }?;
145        let expr = consumer
146            .consume_expression(scalar_expr, &input_schema)
147            .await?;
148        let output_field = expr.to_field(&input_schema)?.1;
149        let mut names_idx = 0;
150        let output_field = rename_field(
151            &output_field,
152            &substrait_expr.output_names,
153            expr_idx,
154            &mut names_idx,
155        )?;
156        exprs.push((expr, output_field));
157    }
158
159    Ok(ExprContainer {
160        input_schema,
161        exprs,
162    })
163}
164
165/// An ExprContainer is a container for a collection of expressions with a common input schema
166///
167/// In addition, each expression is associated with a field, which defines the
168/// expression's output.  The data type and nullability of the field are calculated from the
169/// expression and the input schema.  However the names of the field (and its nested fields) are
170/// derived from the Substrait message.
171pub struct ExprContainer {
172    /// The input schema for the expressions
173    pub input_schema: DFSchemaRef,
174    /// The expressions
175    ///
176    /// Each item contains an expression and the field that defines the expected nullability and name of the expr's output
177    pub exprs: Vec<(Expr, Field)>,
178}
179
180/// Convert Substrait Expressions to DataFusion Exprs
181pub async fn from_substrait_rex_vec(
182    consumer: &impl SubstraitConsumer,
183    exprs: &Vec<Expression>,
184    input_schema: &DFSchema,
185) -> datafusion::common::Result<Vec<Expr>> {
186    let mut expressions: Vec<Expr> = vec![];
187    for expr in exprs {
188        let expression = consumer.consume_expression(expr, input_schema).await?;
189        expressions.push(expression);
190    }
191    Ok(expressions)
192}
193
#[cfg(test)]
mod tests {
    use crate::extensions::Extensions;
    use crate::logical_plan::consumer::utils::tests::test_consumer;
    use crate::logical_plan::consumer::*;
    use datafusion::common::DFSchema;
    use datafusion::logical_expr::Expr;
    use substrait::proto::Expression;
    use substrait::proto::expression::RexType;
    use substrait::proto::expression::window_function::BoundsType;

    /// Wraps `window_function` in a Substrait [`Expression`] and converts it
    /// against an empty schema with the shared test consumer.
    async fn convert_window_function(
        window_function: substrait::proto::expression::WindowFunction,
    ) -> datafusion::common::Result<Expr> {
        let substrait = Expression {
            rex_type: Some(RexType::WindowFunction(window_function)),
        };

        let mut consumer = test_consumer();

        // Just registering a single function (index 0) so that the plan
        // does not throw a "function not found" error.
        let mut extensions = Extensions::default();
        extensions.register_function("count");
        consumer.extensions = &extensions;

        let converted =
            from_substrait_rex(&consumer, &substrait, &DFSchema::empty()).await?;
        Ok(converted)
    }

    /// A RANGE-bounded window function with no sorts is expected to come
    /// back with exactly one `order_by` entry.
    #[tokio::test]
    async fn window_function_with_range_unit_and_no_order_by()
    -> datafusion::common::Result<()> {
        let converted =
            convert_window_function(substrait::proto::expression::WindowFunction {
                function_reference: 0,
                bounds_type: BoundsType::Range as i32,
                sorts: vec![],
                ..Default::default()
            })
            .await?;

        let Expr::WindowFunction(window_function) = converted else {
            panic!("expr was not a WindowFunction");
        };
        assert_eq!(window_function.params.order_by.len(), 1);

        Ok(())
    }

    /// A `count` window function given no arguments is expected to come back
    /// with exactly one argument.
    #[tokio::test]
    async fn window_function_with_count() -> datafusion::common::Result<()> {
        let converted =
            convert_window_function(substrait::proto::expression::WindowFunction {
                function_reference: 0,
                ..Default::default()
            })
            .await?;

        let Expr::WindowFunction(window_function) = converted else {
            panic!("expr was not a WindowFunction");
        };
        assert_eq!(window_function.params.args.len(), 1);

        Ok(())
    }
}