mod aggregate_function;
mod cast;
mod field_reference;
mod if_then;
mod literal;
mod scalar_function;
mod singular_or_list;
mod subquery;
mod window_function;
pub use aggregate_function::*;
pub use cast::*;
pub use field_reference::*;
pub use if_then::*;
pub use literal::*;
pub use scalar_function::*;
pub use singular_or_list::*;
pub use subquery::*;
pub use window_function::*;
use crate::logical_plan::producer::utils::flatten_names;
use crate::logical_plan::producer::{
DefaultSubstraitProducer, SubstraitProducer, to_substrait_named_struct,
};
use datafusion::arrow::datatypes::Field;
use datafusion::common::{DFSchemaRef, internal_err, not_impl_err};
use datafusion::execution::SessionState;
use datafusion::logical_expr::Expr;
use datafusion::logical_expr::expr::Alias;
use substrait::proto::expression_reference::ExprType;
use substrait::proto::{Expression, ExpressionReference, ExtendedExpression};
use substrait::version;
/// Serializes a list of expressions into a Substrait [`ExtendedExpression`] message.
///
/// Each entry of `exprs` pairs an expression with a [`Field`]. The expression is
/// converted through the producer, and the field is handed to `flatten_names` to
/// produce the `output_names` recorded next to that expression.
///
/// `schema` is the input schema the expressions are converted against; it is also
/// serialized with `to_substrait_named_struct` and stored as the message's
/// `base_schema`.
///
/// The conversion is driven by a [`DefaultSubstraitProducer`] created from `state`.
/// The extensions reported by that producer once all conversions are done are
/// attached to the returned message.
///
/// # Errors
///
/// Returns the first error raised while converting an expression, flattening its
/// output names, or converting `schema`. Expressions after the failing one are not
/// processed.
// NOTE(review): the `expect(deprecated)` presumably covers one of the
// `ExtendedExpression` fields populated below — confirm which one.
#[expect(deprecated)]
pub fn to_substrait_extended_expr(
    exprs: &[(&Expr, &Field)],
    schema: &DFSchemaRef,
    state: &SessionState,
) -> datafusion::common::Result<Box<ExtendedExpression>> {
    let mut producer = DefaultSubstraitProducer::new(state);

    // Convert the expressions in order, bailing out on the first failure.
    let mut referred_expr = Vec::with_capacity(exprs.len());
    for (expr, field) in exprs {
        let converted = producer.handle_expr(expr, schema)?;

        let mut output_names = Vec::new();
        flatten_names(field, false, &mut output_names)?;

        referred_expr.push(ExpressionReference {
            output_names,
            expr_type: Some(ExprType::Expression(converted)),
        });
    }

    // The schema goes through the same producer as the expressions, and only
    // afterwards are the producer's extensions read back.
    let base_schema = to_substrait_named_struct(&mut producer, schema)?;
    let extensions = producer.get_extensions();

    let message = ExtendedExpression {
        advanced_extensions: None,
        expected_type_urls: vec![],
        extension_uris: vec![],
        extension_urns: vec![],
        extensions: extensions.into(),
        version: Some(version::version_with_producer("datafusion")),
        referred_expr,
        base_schema: Some(base_schema),
    };
    Ok(Box::new(message))
}
pub fn to_substrait_rex(
producer: &mut impl SubstraitProducer,
expr: &Expr,
schema: &DFSchemaRef,
) -> datafusion::common::Result<Expression> {
match expr {
Expr::Alias(expr) => producer.handle_alias(expr, schema),
Expr::Column(expr) => producer.handle_column(expr, schema),
Expr::ScalarVariable(_, _) => {
not_impl_err!("Cannot convert {expr:?} to Substrait")
}
Expr::Literal(expr, _) => producer.handle_literal(expr),
Expr::BinaryExpr(expr) => producer.handle_binary_expr(expr, schema),
Expr::Like(expr) => producer.handle_like(expr, schema),
Expr::SimilarTo(_) => not_impl_err!("Cannot convert {expr:?} to Substrait"),
Expr::Not(_) => producer.handle_unary_expr(expr, schema),
Expr::IsNotNull(_) => producer.handle_unary_expr(expr, schema),
Expr::IsNull(_) => producer.handle_unary_expr(expr, schema),
Expr::IsTrue(_) => producer.handle_unary_expr(expr, schema),
Expr::IsFalse(_) => producer.handle_unary_expr(expr, schema),
Expr::IsUnknown(_) => producer.handle_unary_expr(expr, schema),
Expr::IsNotTrue(_) => producer.handle_unary_expr(expr, schema),
Expr::IsNotFalse(_) => producer.handle_unary_expr(expr, schema),
Expr::IsNotUnknown(_) => producer.handle_unary_expr(expr, schema),
Expr::Negative(_) => producer.handle_unary_expr(expr, schema),
Expr::Between(expr) => producer.handle_between(expr, schema),
Expr::Case(expr) => producer.handle_case(expr, schema),
Expr::Cast(expr) => producer.handle_cast(expr, schema),
Expr::TryCast(expr) => producer.handle_try_cast(expr, schema),
Expr::ScalarFunction(expr) => producer.handle_scalar_function(expr, schema),
Expr::AggregateFunction(_) => {
internal_err!(
"AggregateFunction should only be encountered as part of a LogicalPlan::Aggregate"
)
}
Expr::WindowFunction(expr) => producer.handle_window_function(expr, schema),
Expr::InList(expr) => producer.handle_in_list(expr, schema),
Expr::Exists(expr) => producer.handle_exists(expr, schema),
Expr::InSubquery(expr) => producer.handle_in_subquery(expr, schema),
Expr::SetComparison(expr) => producer.handle_set_comparison(expr, schema),
Expr::ScalarSubquery(expr) => producer.handle_scalar_subquery(expr, schema),
#[expect(deprecated)]
Expr::Wildcard { .. } => not_impl_err!("Cannot convert {expr:?} to Substrait"),
Expr::GroupingSet(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"),
Expr::Placeholder(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"),
Expr::OuterReferenceColumn(_, _) => {
not_impl_err!("Cannot convert {expr:?} to Substrait")
}
Expr::Unnest(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"),
}
}
/// Converts an aliased expression to Substrait by converting the expression the
/// alias wraps.
///
/// Only `alias.expr` is read here; the other parts of the [`Alias`] are not used
/// by this function. The result, including any error, is whatever
/// `producer.handle_expr` returns for the wrapped expression.
pub fn from_alias(
    producer: &mut impl SubstraitProducer,
    alias: &Alias,
    schema: &DFSchemaRef,
) -> datafusion::common::Result<Expression> {
    let wrapped = alias.expr.as_ref();
    producer.handle_expr(wrapped, schema)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::logical_plan::consumer::from_substrait_extended_expr;
    use datafusion::arrow::datatypes::{DataType, Schema};
    use datafusion::common::{DFSchema, DataFusionError, ScalarValue};
    use datafusion::execution::SessionStateBuilder;

    /// Round-trips expressions through `to_substrait_extended_expr` and
    /// `from_substrait_extended_expr`, checking that the input schema, the
    /// expressions and their output fields come back unchanged.
    #[tokio::test]
    async fn extended_expressions() -> datafusion::common::Result<()> {
        let state = SessionStateBuilder::default().build();

        // Scenario 1: a single literal expression over an empty input schema.
        let literal = Expr::Literal(ScalarValue::Int32(Some(42)), None);
        let literal_out = Field::new("out", DataType::Int32, false);
        let no_columns = DFSchemaRef::new(DFSchema::empty());

        let message =
            to_substrait_extended_expr(&[(&literal, &literal_out)], &no_columns, &state)?;
        let decoded = from_substrait_extended_expr(&state, &message).await?;

        assert_eq!(decoded.input_schema, no_columns);
        assert_eq!(decoded.exprs.len(), 1);
        let (rt_expr, rt_field) = &decoded.exprs[0];
        assert_eq!(rt_field, &literal_out);
        assert_eq!(rt_expr, &literal);

        // Scenario 2: two column references resolved against a two-column schema.
        let col0 = Expr::Column("c0".into());
        let col1 = Expr::Column("c1".into());
        let col0_out = Field::new("out1", DataType::Int32, true);
        let col1_out = Field::new("out2", DataType::Utf8, true);
        let two_columns = DFSchemaRef::new(DFSchema::try_from(Schema::new(vec![
            Field::new("c0", DataType::Int32, true),
            Field::new("c1", DataType::Utf8, true),
        ]))?);
        let expected = [(&col0, &col0_out), (&col1, &col1_out)];

        let message = to_substrait_extended_expr(&expected, &two_columns, &state)?;
        let decoded = from_substrait_extended_expr(&state, &message).await?;

        assert_eq!(decoded.input_schema, two_columns);
        // The length is checked first so the zip below cannot hide a missing entry.
        assert_eq!(decoded.exprs.len(), 2);
        for ((rt_expr, rt_field), (want_expr, want_field)) in
            decoded.exprs.iter().zip(expected)
        {
            assert_eq!(rt_field, want_field);
            assert_eq!(rt_expr, want_expr);
        }

        Ok(())
    }

    /// A column reference that does not exist in the (empty) input schema must be
    /// reported as a schema error.
    #[tokio::test]
    async fn invalid_extended_expression() {
        let state = SessionStateBuilder::default().build();

        let unknown_column = Expr::Column("missing".into());
        let out_field = Field::new("out", DataType::Int32, false);
        let no_columns = DFSchemaRef::new(DFSchema::empty());

        let result =
            to_substrait_extended_expr(&[(&unknown_column, &out_field)], &no_columns, &state);

        assert!(matches!(result, Err(DataFusionError::SchemaError(_, _))));
    }
}