mod aggregate_function;
mod cast;
mod field_reference;
mod function_arguments;
mod if_then;
mod literal;
mod scalar_function;
mod singular_or_list;
mod subquery;
mod window_function;
pub use aggregate_function::*;
pub use cast::*;
pub use field_reference::*;
pub use function_arguments::*;
pub use if_then::*;
pub use literal::*;
pub use scalar_function::*;
pub use singular_or_list::*;
pub use subquery::*;
pub use window_function::*;
use crate::extensions::Extensions;
use crate::logical_plan::consumer::{
DefaultSubstraitConsumer, SubstraitConsumer, from_substrait_named_struct,
rename_field,
};
use datafusion::arrow::datatypes::Field;
use datafusion::common::{DFSchema, DFSchemaRef, not_impl_err, plan_err, substrait_err};
use datafusion::execution::SessionState;
use datafusion::logical_expr::{Expr, ExprSchemable};
use substrait::proto::expression::RexType;
use substrait::proto::expression_reference::ExprType;
use substrait::proto::{Expression, ExtendedExpression};
/// Converts a single Substrait [`Expression`] into a DataFusion [`Expr`].
///
/// The expression's `rex_type` selects which `consume_*` method of `consumer`
/// handles it. Every variant except literals is given `input_schema` so that it
/// can resolve references against the input.
///
/// # Errors
///
/// Returns a Substrait error when `rex_type` is not set. Otherwise, propagates
/// whatever error the selected consumer method returns.
pub async fn from_substrait_rex(
    consumer: &impl SubstraitConsumer,
    expression: &Expression,
    input_schema: &DFSchema,
) -> datafusion::common::Result<Expr> {
    let Some(rex_type) = &expression.rex_type else {
        return substrait_err!("Expression must set rex_type: {expression:?}");
    };

    match rex_type {
        // Literals are self-contained and need no schema.
        RexType::Literal(literal) => consumer.consume_literal(literal).await,
        RexType::Selection(field_ref) => {
            consumer.consume_field_reference(field_ref, input_schema).await
        }
        RexType::ScalarFunction(scalar_fn) => {
            consumer.consume_scalar_function(scalar_fn, input_schema).await
        }
        RexType::WindowFunction(window_fn) => {
            consumer.consume_window_function(window_fn, input_schema).await
        }
        RexType::IfThen(if_then) => consumer.consume_if_then(if_then, input_schema).await,
        RexType::SwitchExpression(switch) => {
            consumer.consume_switch(switch, input_schema).await
        }
        RexType::SingularOrList(or_list) => {
            consumer.consume_singular_or_list(or_list, input_schema).await
        }
        RexType::MultiOrList(or_list) => {
            consumer.consume_multi_or_list(or_list, input_schema).await
        }
        // Cast and Subquery payloads are boxed in the protobuf types; pass the inner value.
        RexType::Cast(cast) => consumer.consume_cast(cast.as_ref(), input_schema).await,
        RexType::Subquery(subquery) => {
            consumer.consume_subquery(subquery.as_ref(), input_schema).await
        }
        RexType::Nested(nested) => consumer.consume_nested(nested, input_schema).await,
        // The `Enum` variant is deprecated upstream but still has to be handled.
        #[expect(deprecated)]
        RexType::Enum(enum_value) => consumer.consume_enum(enum_value, input_schema).await,
        RexType::DynamicParameter(param) => {
            consumer.consume_dynamic_parameter(param, input_schema).await
        }
    }
}
/// Converts a Substrait [`ExtendedExpression`] message into an [`ExprContainer`].
///
/// First, the message's `base_schema` is converted into the input schema. Then
/// each entry of `referred_expr` is consumed, in order, with a
/// [`DefaultSubstraitConsumer`] built from `state` and the message's extensions.
/// Each resulting expression is paired with its output field. That field is
/// inferred from the expression and renamed using the entry's `output_names`,
/// with the entry's position passed to `rename_field` as the suffix for unnamed
/// fields.
///
/// # Errors
///
/// * Not-implemented error when the message declares type variation
///   extensions, or when an entry is a measure expression.
/// * Plan error when `base_schema` is missing, or when an entry has no
///   `expr_type`.
/// * Any error from parsing extensions, converting the schema, consuming an
///   expression, inferring its field, or renaming that field.
pub async fn from_substrait_extended_expr(
    state: &SessionState,
    extended_expr: &ExtendedExpression,
) -> datafusion::common::Result<ExprContainer> {
    let extensions = Extensions::try_from(&extended_expr.extensions)?;
    if !extensions.type_variations.is_empty() {
        return not_impl_err!("Type variation extensions are not supported");
    }
    let consumer = DefaultSubstraitConsumer::new(&extensions, state);

    let Some(base_schema) = &extended_expr.base_schema else {
        return plan_err!(
            "required property `base_schema` missing from Substrait ExtendedExpression message"
        );
    };
    let input_schema = DFSchemaRef::new(from_substrait_named_struct(&consumer, base_schema)?);

    let mut exprs = Vec::with_capacity(extended_expr.referred_expr.len());
    for (expr_idx, expr_ref) in extended_expr.referred_expr.iter().enumerate() {
        let scalar_expr = match &expr_ref.expr_type {
            Some(ExprType::Expression(scalar_expr)) => scalar_expr,
            Some(ExprType::Measure(_)) => {
                return not_impl_err!("Measure expressions are not yet supported");
            }
            None => {
                return plan_err!(
                    "required property `expr_type` missing from Substrait ExpressionReference message"
                );
            }
        };

        let expr = consumer
            .consume_expression(scalar_expr, &input_schema)
            .await?;
        let (_, inferred_field) = expr.to_field(&input_schema)?;

        // Each reference carries its own names list, so the cursor starts fresh per entry.
        let mut consumed_names = 0;
        let renamed_field = rename_field(
            &inferred_field,
            &expr_ref.output_names,
            expr_idx,
            &mut consumed_names,
        )?;
        exprs.push((expr, renamed_field));
    }

    Ok(ExprContainer { input_schema, exprs })
}
/// The result of consuming a Substrait [`ExtendedExpression`] with
/// [`from_substrait_extended_expr`].
#[derive(Debug, Clone)]
pub struct ExprContainer {
    /// Schema built from the message's `base_schema`. The expressions in
    /// `exprs` were consumed against this schema.
    pub input_schema: DFSchemaRef,
    /// One entry per `referred_expr` in the message, in message order. Each
    /// entry pairs the consumed expression with its output field, renamed
    /// using that reference's `output_names`.
    pub exprs: Vec<(Expr, Field)>,
}
/// Converts each Substrait [`Expression`] in `exprs` into a DataFusion [`Expr`].
///
/// Each expression is passed to [`SubstraitConsumer::consume_expression`]
/// together with `input_schema`. Expressions are consumed one at a time, in
/// order, and the output preserves the input order.
///
/// Takes a slice rather than `&Vec`, so callers may pass any contiguous
/// sequence. Existing `&Vec<Expression>` arguments still coerce automatically.
///
/// # Errors
///
/// Returns the first error produced while consuming an expression. The
/// remaining expressions are not consumed.
pub async fn from_substrait_rex_vec(
    consumer: &impl SubstraitConsumer,
    exprs: &[Expression],
    input_schema: &DFSchema,
) -> datafusion::common::Result<Vec<Expr>> {
    // The output length is known up front, so allocate once.
    let mut expressions = Vec::with_capacity(exprs.len());
    for expr in exprs {
        expressions.push(consumer.consume_expression(expr, input_schema).await?);
    }
    Ok(expressions)
}
#[cfg(test)]
mod tests {
    use crate::extensions::Extensions;
    use crate::logical_plan::consumer::utils::tests::test_consumer;
    use crate::logical_plan::consumer::*;
    use datafusion::common::DFSchema;
    use datafusion::logical_expr::Expr;
    use substrait::proto::Expression;
    use substrait::proto::expression::RexType;
    use substrait::proto::expression::window_function::BoundsType;

    /// Wraps `window` in a Substrait expression and consumes it against an
    /// empty schema. The test consumer uses a fresh `Extensions` in which
    /// `count` has been registered.
    async fn consume_count_window(
        window: substrait::proto::expression::WindowFunction,
    ) -> datafusion::common::Result<Expr> {
        let mut extensions = Extensions::default();
        extensions.register_function("count");
        let mut consumer = test_consumer();
        consumer.extensions = &extensions;

        let substrait = Expression {
            rex_type: Some(RexType::WindowFunction(window)),
        };
        from_substrait_rex(&consumer, &substrait, &DFSchema::empty()).await
    }

    #[tokio::test]
    async fn window_function_with_range_unit_and_no_order_by()
    -> datafusion::common::Result<()> {
        let window = substrait::proto::expression::WindowFunction {
            function_reference: 0,
            bounds_type: BoundsType::Range as i32,
            sorts: vec![],
            ..Default::default()
        };

        let Expr::WindowFunction(window_function) = consume_count_window(window).await?
        else {
            panic!("expr was not a WindowFunction");
        };
        // No sort keys were supplied, yet exactly one ORDER BY entry is expected.
        assert_eq!(window_function.params.order_by.len(), 1);
        Ok(())
    }

    #[tokio::test]
    async fn window_function_with_count() -> datafusion::common::Result<()> {
        let window = substrait::proto::expression::WindowFunction {
            function_reference: 0,
            ..Default::default()
        };

        let Expr::WindowFunction(window_function) = consume_count_window(window).await?
        else {
            panic!("expr was not a WindowFunction");
        };
        // No arguments were supplied, yet the consumed window function is expected to have one.
        assert_eq!(window_function.params.args.len(), 1);
        Ok(())
    }
}