datafusion_substrait/logical_plan/consumer/expr/
mod.rs1mod aggregate_function;
19mod cast;
20mod field_reference;
21mod function_arguments;
22mod if_then;
23mod literal;
24mod scalar_function;
25mod singular_or_list;
26mod subquery;
27mod window_function;
28
29pub use aggregate_function::*;
30pub use cast::*;
31pub use field_reference::*;
32pub use function_arguments::*;
33pub use if_then::*;
34pub use literal::*;
35pub use scalar_function::*;
36pub use singular_or_list::*;
37pub use subquery::*;
38pub use window_function::*;
39
40use crate::extensions::Extensions;
41use crate::logical_plan::consumer::{
42 DefaultSubstraitConsumer, SubstraitConsumer, from_substrait_named_struct,
43 rename_field,
44};
45use datafusion::arrow::datatypes::Field;
46use datafusion::common::{DFSchema, DFSchemaRef, not_impl_err, plan_err, substrait_err};
47use datafusion::execution::SessionState;
48use datafusion::logical_expr::{Expr, ExprSchemable};
49use substrait::proto::expression::RexType;
50use substrait::proto::expression_reference::ExprType;
51use substrait::proto::{Expression, ExtendedExpression};
52
53pub async fn from_substrait_rex(
55 consumer: &impl SubstraitConsumer,
56 expression: &Expression,
57 input_schema: &DFSchema,
58) -> datafusion::common::Result<Expr> {
59 match &expression.rex_type {
60 Some(t) => match t {
61 RexType::Literal(expr) => consumer.consume_literal(expr).await,
62 RexType::Selection(expr) => {
63 consumer.consume_field_reference(expr, input_schema).await
64 }
65 RexType::ScalarFunction(expr) => {
66 consumer.consume_scalar_function(expr, input_schema).await
67 }
68 RexType::WindowFunction(expr) => {
69 consumer.consume_window_function(expr, input_schema).await
70 }
71 RexType::IfThen(expr) => consumer.consume_if_then(expr, input_schema).await,
72 RexType::SwitchExpression(expr) => {
73 consumer.consume_switch(expr, input_schema).await
74 }
75 RexType::SingularOrList(expr) => {
76 consumer.consume_singular_or_list(expr, input_schema).await
77 }
78
79 RexType::MultiOrList(expr) => {
80 consumer.consume_multi_or_list(expr, input_schema).await
81 }
82
83 RexType::Cast(expr) => {
84 consumer.consume_cast(expr.as_ref(), input_schema).await
85 }
86
87 RexType::Subquery(expr) => {
88 consumer.consume_subquery(expr.as_ref(), input_schema).await
89 }
90 RexType::Nested(expr) => consumer.consume_nested(expr, input_schema).await,
91 RexType::Enum(expr) => consumer.consume_enum(expr, input_schema).await,
92 RexType::DynamicParameter(expr) => {
93 consumer.consume_dynamic_parameter(expr, input_schema).await
94 }
95 },
96 None => substrait_err!("Expression must set rex_type: {expression:?}"),
97 }
98}
99
100pub async fn from_substrait_extended_expr(
110 state: &SessionState,
111 extended_expr: &ExtendedExpression,
112) -> datafusion::common::Result<ExprContainer> {
113 let extensions = Extensions::try_from(&extended_expr.extensions)?;
115 if !extensions.type_variations.is_empty() {
116 return not_impl_err!("Type variation extensions are not supported");
117 }
118
119 let consumer = DefaultSubstraitConsumer {
120 extensions: &extensions,
121 state,
122 };
123
124 let input_schema = DFSchemaRef::new(match &extended_expr.base_schema {
125 Some(base_schema) => from_substrait_named_struct(&consumer, base_schema),
126 None => {
127 plan_err!(
128 "required property `base_schema` missing from Substrait ExtendedExpression message"
129 )
130 }
131 }?);
132
133 let mut exprs = Vec::with_capacity(extended_expr.referred_expr.len());
135 for (expr_idx, substrait_expr) in extended_expr.referred_expr.iter().enumerate() {
136 let scalar_expr = match &substrait_expr.expr_type {
137 Some(ExprType::Expression(scalar_expr)) => Ok(scalar_expr),
138 Some(ExprType::Measure(_)) => {
139 not_impl_err!("Measure expressions are not yet supported")
140 }
141 None => {
142 plan_err!(
143 "required property `expr_type` missing from Substrait ExpressionReference message"
144 )
145 }
146 }?;
147 let expr = consumer
148 .consume_expression(scalar_expr, &input_schema)
149 .await?;
150 let output_field = expr.to_field(&input_schema)?.1;
151 let mut names_idx = 0;
152 let output_field = rename_field(
153 &output_field,
154 &substrait_expr.output_names,
155 expr_idx,
156 &mut names_idx,
157 )?;
158 exprs.push((expr, output_field));
159 }
160
161 Ok(ExprContainer {
162 input_schema,
163 exprs,
164 })
165}
166
/// A collection of expressions that share a single input schema.
///
/// This is the value returned by [`from_substrait_extended_expr`], which
/// fills it from a Substrait `ExtendedExpression` message.
pub struct ExprContainer {
    /// The schema the expressions in `exprs` were converted against.
    pub input_schema: DFSchemaRef,
    /// The converted expressions, each paired with the field describing its
    /// output.
    ///
    /// As built by [`from_substrait_extended_expr`], the field is derived from
    /// the expression and the input schema and then renamed using the
    /// `output_names` carried by the Substrait message.
    pub exprs: Vec<(Expr, Field)>,
}
181
182pub async fn from_substrait_rex_vec(
184 consumer: &impl SubstraitConsumer,
185 exprs: &Vec<Expression>,
186 input_schema: &DFSchema,
187) -> datafusion::common::Result<Vec<Expr>> {
188 let mut expressions: Vec<Expr> = vec![];
189 for expr in exprs {
190 let expression = consumer.consume_expression(expr, input_schema).await?;
191 expressions.push(expression);
192 }
193 Ok(expressions)
194}
195
#[cfg(test)]
mod tests {
    use crate::extensions::Extensions;
    use crate::logical_plan::consumer::utils::tests::test_consumer;
    use crate::logical_plan::consumer::*;
    use datafusion::common::DFSchema;
    use datafusion::logical_expr::Expr;
    use substrait::proto::Expression;
    use substrait::proto::expression::RexType;
    use substrait::proto::expression::window_function::BoundsType;

    /// Wraps `window_function` in a Substrait [`Expression`] and converts it
    /// with [`from_substrait_rex`] against an empty schema.
    ///
    /// The test consumer is pointed at a fresh set of extensions in which only
    /// `count` is registered.
    async fn convert_window_function(
        window_function: substrait::proto::expression::WindowFunction,
    ) -> datafusion::common::Result<Expr> {
        let substrait = Expression {
            rex_type: Some(RexType::WindowFunction(window_function)),
        };

        // NOTE(review): presumably the first registered function receives
        // anchor 0, which is what the `function_reference: 0` used by the
        // tests below relies on. Confirm against `Extensions`.
        let mut extensions = Extensions::default();
        extensions.register_function("count");

        let mut consumer = test_consumer();
        consumer.extensions = &extensions;

        let empty_schema = DFSchema::empty();
        let converted = from_substrait_rex(&consumer, &substrait, &empty_schema).await?;
        Ok(converted)
    }

    /// A RANGE-bounded window function with no sorts is expected to come out
    /// with exactly one `order_by` entry.
    #[tokio::test]
    async fn window_function_with_range_unit_and_no_order_by()
    -> datafusion::common::Result<()> {
        let converted =
            convert_window_function(substrait::proto::expression::WindowFunction {
                function_reference: 0,
                bounds_type: BoundsType::Range as i32,
                sorts: vec![],
                ..Default::default()
            })
            .await?;

        let Expr::WindowFunction(window_function) = converted else {
            panic!("expr was not a WindowFunction");
        };
        assert_eq!(window_function.params.order_by.len(), 1);

        Ok(())
    }

    /// A `count` window function given no arguments is expected to come out
    /// with exactly one argument.
    #[tokio::test]
    async fn window_function_with_count() -> datafusion::common::Result<()> {
        let converted =
            convert_window_function(substrait::proto::expression::WindowFunction {
                function_reference: 0,
                ..Default::default()
            })
            .await?;

        let Expr::WindowFunction(window_function) = converted else {
            panic!("expr was not a WindowFunction");
        };
        assert_eq!(window_function.params.args.len(), 1);

        Ok(())
    }
}