datafusion_substrait/logical_plan/consumer/expr/
mod.rs1mod aggregate_function;
19mod cast;
20mod field_reference;
21mod function_arguments;
22mod if_then;
23mod literal;
24mod nested;
25mod scalar_function;
26mod singular_or_list;
27mod subquery;
28mod window_function;
29
30pub use aggregate_function::*;
31pub use cast::*;
32pub use field_reference::*;
33pub use function_arguments::*;
34pub use if_then::*;
35pub use literal::*;
36pub use nested::*;
37pub use scalar_function::*;
38pub use singular_or_list::*;
39pub use subquery::*;
40pub use window_function::*;
41
42use crate::extensions::Extensions;
43use crate::logical_plan::consumer::{
44 DefaultSubstraitConsumer, SubstraitConsumer, from_substrait_named_struct,
45 rename_field,
46};
47use datafusion::arrow::datatypes::Field;
48use datafusion::common::{DFSchema, DFSchemaRef, not_impl_err, plan_err, substrait_err};
49use datafusion::execution::SessionState;
50use datafusion::logical_expr::{Expr, ExprSchemable};
51use substrait::proto::expression::RexType;
52use substrait::proto::expression_reference::ExprType;
53use substrait::proto::{Expression, ExtendedExpression};
54
55pub async fn from_substrait_rex(
57 consumer: &impl SubstraitConsumer,
58 expression: &Expression,
59 input_schema: &DFSchema,
60) -> datafusion::common::Result<Expr> {
61 match &expression.rex_type {
62 Some(t) => match t {
63 RexType::Literal(expr) => consumer.consume_literal(expr).await,
64 RexType::Selection(expr) => {
65 consumer.consume_field_reference(expr, input_schema).await
66 }
67 RexType::ScalarFunction(expr) => {
68 consumer.consume_scalar_function(expr, input_schema).await
69 }
70 RexType::WindowFunction(expr) => {
71 consumer.consume_window_function(expr, input_schema).await
72 }
73 RexType::IfThen(expr) => consumer.consume_if_then(expr, input_schema).await,
74 RexType::SwitchExpression(expr) => {
75 consumer.consume_switch(expr, input_schema).await
76 }
77 RexType::SingularOrList(expr) => {
78 consumer.consume_singular_or_list(expr, input_schema).await
79 }
80
81 RexType::MultiOrList(expr) => {
82 consumer.consume_multi_or_list(expr, input_schema).await
83 }
84
85 RexType::Cast(expr) => {
86 consumer.consume_cast(expr.as_ref(), input_schema).await
87 }
88
89 RexType::Subquery(expr) => {
90 consumer.consume_subquery(expr.as_ref(), input_schema).await
91 }
92 RexType::Nested(expr) => consumer.consume_nested(expr, input_schema).await,
93 #[expect(deprecated)]
94 RexType::Enum(expr) => consumer.consume_enum(expr, input_schema).await,
95 RexType::DynamicParameter(expr) => {
96 consumer.consume_dynamic_parameter(expr, input_schema).await
97 }
98 RexType::Lambda(_) | RexType::LambdaInvocation(_) => {
99 not_impl_err!("Lambda expressions are not yet supported")
100 }
101 },
102 None => substrait_err!("Expression must set rex_type: {expression:?}"),
103 }
104}
105
106pub async fn from_substrait_extended_expr(
116 state: &SessionState,
117 extended_expr: &ExtendedExpression,
118) -> datafusion::common::Result<ExprContainer> {
119 let extensions = Extensions::try_from(&extended_expr.extensions)?;
121 if !extensions.type_variations.is_empty() {
122 return not_impl_err!("Type variation extensions are not supported");
123 }
124
125 let consumer = DefaultSubstraitConsumer::new(&extensions, state);
126
127 let input_schema = DFSchemaRef::new(match &extended_expr.base_schema {
128 Some(base_schema) => from_substrait_named_struct(&consumer, base_schema),
129 None => {
130 plan_err!(
131 "required property `base_schema` missing from Substrait ExtendedExpression message"
132 )
133 }
134 }?);
135
136 let mut exprs = Vec::with_capacity(extended_expr.referred_expr.len());
138 for (expr_idx, substrait_expr) in extended_expr.referred_expr.iter().enumerate() {
139 let scalar_expr = match &substrait_expr.expr_type {
140 Some(ExprType::Expression(scalar_expr)) => Ok(scalar_expr),
141 Some(ExprType::Measure(_)) => {
142 not_impl_err!("Measure expressions are not yet supported")
143 }
144 None => {
145 plan_err!(
146 "required property `expr_type` missing from Substrait ExpressionReference message"
147 )
148 }
149 }?;
150 let expr = consumer
151 .consume_expression(scalar_expr, &input_schema)
152 .await?;
153 let output_field = expr.to_field(&input_schema)?.1;
154 let mut names_idx = 0;
155 let output_field = rename_field(
156 &output_field,
157 &substrait_expr.output_names,
158 expr_idx,
159 &mut names_idx,
160 )?;
161 exprs.push((expr, output_field));
162 }
163
164 Ok(ExprContainer {
165 input_schema,
166 exprs,
167 })
168}
169
170pub struct ExprContainer {
177 pub input_schema: DFSchemaRef,
179 pub exprs: Vec<(Expr, Field)>,
183}
184
185pub async fn from_substrait_rex_vec(
187 consumer: &impl SubstraitConsumer,
188 exprs: &Vec<Expression>,
189 input_schema: &DFSchema,
190) -> datafusion::common::Result<Vec<Expr>> {
191 let mut expressions: Vec<Expr> = vec![];
192 for expr in exprs {
193 let expression = consumer.consume_expression(expr, input_schema).await?;
194 expressions.push(expression);
195 }
196 Ok(expressions)
197}
198
#[cfg(test)]
mod tests {
    use crate::extensions::Extensions;
    use crate::logical_plan::consumer::utils::tests::test_consumer;
    use crate::logical_plan::consumer::*;
    use datafusion::common::DFSchema;
    use datafusion::logical_expr::Expr;
    use substrait::proto::Expression;
    use substrait::proto::expression::RexType;
    use substrait::proto::expression::window_function::BoundsType;

    /// Wraps a Substrait window function in a top-level [`Expression`].
    fn window_expression(
        window_function: substrait::proto::expression::WindowFunction,
    ) -> Expression {
        Expression {
            rex_type: Some(RexType::WindowFunction(window_function)),
        }
    }

    /// Builds extensions in which `count` is the only registered function.
    ///
    /// The tests reference function anchor 0. Presumably `count`, as the first
    /// registration, receives that anchor, so the lookup does not fail with
    /// "function not found" (TODO confirm against `Extensions::register_function`).
    fn count_only_extensions() -> Extensions {
        let mut extensions = Extensions::default();
        extensions.register_function("count");
        extensions
    }

    #[tokio::test]
    async fn window_function_with_range_unit_and_no_order_by()
    -> datafusion::common::Result<()> {
        let substrait = window_expression(substrait::proto::expression::WindowFunction {
            function_reference: 0,
            bounds_type: BoundsType::Range as i32,
            sorts: vec![],
            ..Default::default()
        });

        // Declared before the consumer so it outlives the borrow stored in it.
        let extensions = count_only_extensions();
        let mut consumer = test_consumer();
        consumer.extensions = &extensions;

        let converted = from_substrait_rex(&consumer, &substrait, &DFSchema::empty()).await?;
        let Expr::WindowFunction(window_function) = converted else {
            panic!("expr was not a WindowFunction");
        };
        // RANGE bounds were requested with no sort fields; the conversion is
        // expected to still produce exactly one ORDER BY expression.
        assert_eq!(window_function.params.order_by.len(), 1);

        Ok(())
    }

    #[tokio::test]
    async fn window_function_with_count() -> datafusion::common::Result<()> {
        let substrait = window_expression(substrait::proto::expression::WindowFunction {
            function_reference: 0,
            ..Default::default()
        });

        // Declared before the consumer so it outlives the borrow stored in it.
        let extensions = count_only_extensions();
        let mut consumer = test_consumer();
        consumer.extensions = &extensions;

        let converted = from_substrait_rex(&consumer, &substrait, &DFSchema::empty()).await?;
        let Expr::WindowFunction(window_function) = converted else {
            panic!("expr was not a WindowFunction");
        };
        // No arguments were supplied; the converted `count` is expected to
        // carry exactly one argument.
        assert_eq!(window_function.params.args.len(), 1);

        Ok(())
    }
}