datafusion_substrait/logical_plan/consumer/expr/
mod.rs1mod aggregate_function;
19mod cast;
20mod field_reference;
21mod function_arguments;
22mod if_then;
23mod literal;
24mod scalar_function;
25mod singular_or_list;
26mod subquery;
27mod window_function;
28
29pub use aggregate_function::*;
30pub use cast::*;
31pub use field_reference::*;
32pub use function_arguments::*;
33pub use if_then::*;
34pub use literal::*;
35pub use scalar_function::*;
36pub use singular_or_list::*;
37pub use subquery::*;
38pub use window_function::*;
39
40use crate::extensions::Extensions;
41use crate::logical_plan::consumer::{
42 from_substrait_named_struct, rename_field, DefaultSubstraitConsumer,
43 SubstraitConsumer,
44};
45use datafusion::arrow::datatypes::Field;
46use datafusion::common::{not_impl_err, plan_err, substrait_err, DFSchema, DFSchemaRef};
47use datafusion::execution::SessionState;
48use datafusion::logical_expr::{Expr, ExprSchemable};
49use substrait::proto::expression::RexType;
50use substrait::proto::expression_reference::ExprType;
51use substrait::proto::{Expression, ExtendedExpression};
52
53pub async fn from_substrait_rex(
55 consumer: &impl SubstraitConsumer,
56 expression: &Expression,
57 input_schema: &DFSchema,
58) -> datafusion::common::Result<Expr> {
59 match &expression.rex_type {
60 Some(t) => match t {
61 RexType::Literal(expr) => consumer.consume_literal(expr).await,
62 RexType::Selection(expr) => {
63 consumer.consume_field_reference(expr, input_schema).await
64 }
65 RexType::ScalarFunction(expr) => {
66 consumer.consume_scalar_function(expr, input_schema).await
67 }
68 RexType::WindowFunction(expr) => {
69 consumer.consume_window_function(expr, input_schema).await
70 }
71 RexType::IfThen(expr) => consumer.consume_if_then(expr, input_schema).await,
72 RexType::SwitchExpression(expr) => {
73 consumer.consume_switch(expr, input_schema).await
74 }
75 RexType::SingularOrList(expr) => {
76 consumer.consume_singular_or_list(expr, input_schema).await
77 }
78
79 RexType::MultiOrList(expr) => {
80 consumer.consume_multi_or_list(expr, input_schema).await
81 }
82
83 RexType::Cast(expr) => {
84 consumer.consume_cast(expr.as_ref(), input_schema).await
85 }
86
87 RexType::Subquery(expr) => {
88 consumer.consume_subquery(expr.as_ref(), input_schema).await
89 }
90 RexType::Nested(expr) => consumer.consume_nested(expr, input_schema).await,
91 RexType::Enum(expr) => consumer.consume_enum(expr, input_schema).await,
92 RexType::DynamicParameter(expr) => {
93 consumer.consume_dynamic_parameter(expr, input_schema).await
94 }
95 },
96 None => substrait_err!("Expression must set rex_type: {expression:?}"),
97 }
98}
99
100pub async fn from_substrait_extended_expr(
110 state: &SessionState,
111 extended_expr: &ExtendedExpression,
112) -> datafusion::common::Result<ExprContainer> {
113 let extensions = Extensions::try_from(&extended_expr.extensions)?;
115 if !extensions.type_variations.is_empty() {
116 return not_impl_err!("Type variation extensions are not supported");
117 }
118
119 let consumer = DefaultSubstraitConsumer {
120 extensions: &extensions,
121 state,
122 };
123
124 let input_schema = DFSchemaRef::new(match &extended_expr.base_schema {
125 Some(base_schema) => from_substrait_named_struct(&consumer, base_schema),
126 None => {
127 plan_err!("required property `base_schema` missing from Substrait ExtendedExpression message")
128 }
129 }?);
130
131 let mut exprs = Vec::with_capacity(extended_expr.referred_expr.len());
133 for (expr_idx, substrait_expr) in extended_expr.referred_expr.iter().enumerate() {
134 let scalar_expr = match &substrait_expr.expr_type {
135 Some(ExprType::Expression(scalar_expr)) => Ok(scalar_expr),
136 Some(ExprType::Measure(_)) => {
137 not_impl_err!("Measure expressions are not yet supported")
138 }
139 None => {
140 plan_err!("required property `expr_type` missing from Substrait ExpressionReference message")
141 }
142 }?;
143 let expr = consumer
144 .consume_expression(scalar_expr, &input_schema)
145 .await?;
146 let (output_type, expected_nullability) =
147 expr.data_type_and_nullable(&input_schema)?;
148 let output_field = Field::new("", output_type, expected_nullability);
149 let mut names_idx = 0;
150 let output_field = rename_field(
151 &output_field,
152 &substrait_expr.output_names,
153 expr_idx,
154 &mut names_idx,
155 )?;
156 exprs.push((expr, output_field));
157 }
158
159 Ok(ExprContainer {
160 input_schema,
161 exprs,
162 })
163}
164
165pub struct ExprContainer {
172 pub input_schema: DFSchemaRef,
174 pub exprs: Vec<(Expr, Field)>,
178}
179
180pub async fn from_substrait_rex_vec(
182 consumer: &impl SubstraitConsumer,
183 exprs: &Vec<Expression>,
184 input_schema: &DFSchema,
185) -> datafusion::common::Result<Vec<Expr>> {
186 let mut expressions: Vec<Expr> = vec![];
187 for expr in exprs {
188 let expression = consumer.consume_expression(expr, input_schema).await?;
189 expressions.push(expression);
190 }
191 Ok(expressions)
192}
193
#[cfg(test)]
mod tests {
    use crate::extensions::Extensions;
    use crate::logical_plan::consumer::utils::tests::test_consumer;
    use crate::logical_plan::consumer::*;
    use datafusion::common::DFSchema;
    use datafusion::logical_expr::Expr;
    use substrait::proto::expression::window_function::BoundsType;
    use substrait::proto::expression::RexType;
    use substrait::proto::Expression;

    /// A window function with `Range` bounds and an empty `sorts` list is
    /// expected to come back with exactly one `order_by` entry.
    #[tokio::test]
    async fn window_function_with_range_unit_and_no_order_by(
    ) -> datafusion::common::Result<()> {
        let window = substrait::proto::expression::WindowFunction {
            function_reference: 0,
            bounds_type: BoundsType::Range as i32,
            sorts: vec![],
            ..Default::default()
        };
        let substrait = Expression {
            rex_type: Some(RexType::WindowFunction(window)),
        };

        let mut consumer = test_consumer();

        // "count" is the only function registered; the message points at it
        // through `function_reference: 0` (presumably the anchor given to the
        // first registration; confirm against `Extensions`).
        let mut extensions = Extensions::default();
        extensions.register_function("count".to_string());
        consumer.extensions = &extensions;

        let converted =
            from_substrait_rex(&consumer, &substrait, &DFSchema::empty()).await?;
        let Expr::WindowFunction(window_function) = converted else {
            panic!("expr was not a WindowFunction");
        };
        assert_eq!(window_function.params.order_by.len(), 1);

        Ok(())
    }

    /// A `count` window function supplied without arguments is expected to
    /// come back with exactly one argument.
    #[tokio::test]
    async fn window_function_with_count() -> datafusion::common::Result<()> {
        let window = substrait::proto::expression::WindowFunction {
            function_reference: 0,
            ..Default::default()
        };
        let substrait = Expression {
            rex_type: Some(RexType::WindowFunction(window)),
        };

        let mut consumer = test_consumer();

        // Same setup as above: "count" is the single registered function.
        let mut extensions = Extensions::default();
        extensions.register_function("count".to_string());
        consumer.extensions = &extensions;

        let converted =
            from_substrait_rex(&consumer, &substrait, &DFSchema::empty()).await?;
        let Expr::WindowFunction(window_function) = converted else {
            panic!("expr was not a WindowFunction");
        };
        assert_eq!(window_function.params.args.len(), 1);

        Ok(())
    }
}