use std::{collections::HashMap, sync::Arc};
use datafusion::{
error::{DataFusionError, Result},
prelude::JoinType,
scalar::ScalarValue,
};
use datafusion::common::DFSchemaRef;
#[allow(unused_imports)]
use datafusion::logical_expr::aggregate_function;
use datafusion::logical_expr::expr::{BinaryExpr, Case, Sort};
use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator};
use datafusion::prelude::{binary_expr, Expr};
use substrait::proto::{
aggregate_function::AggregationInvocation,
aggregate_rel::{Grouping, Measure},
expression::{
field_reference::ReferenceType,
if_then::IfClause,
literal::{Decimal, LiteralType},
mask_expression::{StructItem, StructSelect},
reference_segment, FieldReference, IfThen, Literal, MaskExpression,
ReferenceSegment, RexType, ScalarFunction,
},
extensions::{
self,
simple_extension_declaration::{ExtensionFunction, MappingType},
},
function_argument::ArgType,
join_rel, plan_rel, r#type,
read_rel::{NamedTable, ReadType},
rel::RelType,
sort_field::{SortDirection, SortKind},
AggregateFunction, AggregateRel, AggregationPhase, Expression, FetchRel, FilterRel,
FunctionArgument, JoinRel, NamedStruct, Plan, PlanRel, ProjectRel, ReadRel, Rel,
RelRoot, SortField, SortRel,
};
pub fn to_substrait_plan(plan: &LogicalPlan) -> Result<Box<Plan>> {
let mut extension_info: (
Vec<extensions::SimpleExtensionDeclaration>,
HashMap<String, u32>,
) = (vec![], HashMap::new());
let plan_rels = vec![PlanRel {
rel_type: Some(plan_rel::RelType::Root(RelRoot {
input: Some(*to_substrait_rel(plan, &mut extension_info)?),
names: plan.schema().field_names(),
})),
}];
let (function_extensions, _) = extension_info;
Ok(Box::new(Plan {
version: None, extension_uris: vec![],
extensions: function_extensions,
relations: plan_rels,
advanced_extensions: None,
expected_type_urls: vec![],
}))
}
pub fn to_substrait_rel(
plan: &LogicalPlan,
extension_info: &mut (
Vec<extensions::SimpleExtensionDeclaration>,
HashMap<String, u32>,
),
) -> Result<Box<Rel>> {
match plan {
LogicalPlan::TableScan(scan) => {
let projection = scan.projection.as_ref().map(|p| {
p.iter()
.map(|i| StructItem {
field: *i as i32,
child: None,
})
.collect()
});
if let Some(struct_items) = projection {
Ok(Box::new(Rel {
rel_type: Some(RelType::Read(Box::new(ReadRel {
common: None,
base_schema: Some(NamedStruct {
names: scan
.source
.schema()
.fields()
.iter()
.map(|f| f.name().to_owned())
.collect(),
r#struct: None,
}),
filter: None,
best_effort_filter: None,
projection: Some(MaskExpression {
select: Some(StructSelect { struct_items }),
maintain_singular_struct: false,
}),
advanced_extension: None,
read_type: Some(ReadType::NamedTable(NamedTable {
names: vec![scan.table_name.clone()],
advanced_extension: None,
})),
}))),
}))
} else {
Err(DataFusionError::NotImplemented(
"TableScan without projection is not supported".to_string(),
))
}
}
LogicalPlan::Projection(p) => {
let expressions = p
.expr
.iter()
.map(|e| to_substrait_rex(e, p.input.schema(), extension_info))
.collect::<Result<Vec<_>>>()?;
Ok(Box::new(Rel {
rel_type: Some(RelType::Project(Box::new(ProjectRel {
common: None,
input: Some(to_substrait_rel(p.input.as_ref(), extension_info)?),
expressions,
advanced_extension: None,
}))),
}))
}
LogicalPlan::Filter(filter) => {
let input = to_substrait_rel(filter.input.as_ref(), extension_info)?;
let filter_expr = to_substrait_rex(
&filter.predicate,
filter.input.schema(),
extension_info,
)?;
Ok(Box::new(Rel {
rel_type: Some(RelType::Filter(Box::new(FilterRel {
common: None,
input: Some(input),
condition: Some(Box::new(filter_expr)),
advanced_extension: None,
}))),
}))
}
LogicalPlan::Limit(limit) => {
let input = to_substrait_rel(limit.input.as_ref(), extension_info)?;
let limit_fetch = limit.fetch.unwrap_or(0);
Ok(Box::new(Rel {
rel_type: Some(RelType::Fetch(Box::new(FetchRel {
common: None,
input: Some(input),
offset: limit.skip as i64,
count: limit_fetch as i64,
advanced_extension: None,
}))),
}))
}
LogicalPlan::Sort(sort) => {
let input = to_substrait_rel(sort.input.as_ref(), extension_info)?;
let sort_fields = sort
.expr
.iter()
.map(|e| substrait_sort_field(e, sort.input.schema(), extension_info))
.collect::<Result<Vec<_>>>()?;
Ok(Box::new(Rel {
rel_type: Some(RelType::Sort(Box::new(SortRel {
common: None,
input: Some(input),
sorts: sort_fields,
advanced_extension: None,
}))),
}))
}
LogicalPlan::Aggregate(agg) => {
let input = to_substrait_rel(agg.input.as_ref(), extension_info)?;
let grouping = agg
.group_expr
.iter()
.map(|e| to_substrait_rex(e, agg.input.schema(), extension_info))
.collect::<Result<Vec<_>>>()?;
let measures = agg
.aggr_expr
.iter()
.map(|e| to_substrait_agg_measure(e, agg.input.schema(), extension_info))
.collect::<Result<Vec<_>>>()?;
Ok(Box::new(Rel {
rel_type: Some(RelType::Aggregate(Box::new(AggregateRel {
common: None,
input: Some(input),
groupings: vec![Grouping {
grouping_expressions: grouping,
}], measures,
advanced_extension: None,
}))),
}))
}
LogicalPlan::Distinct(distinct) => {
let input = to_substrait_rel(distinct.input.as_ref(), extension_info)?;
let grouping = (0..distinct.input.schema().fields().len())
.map(substrait_field_ref)
.collect::<Result<Vec<_>>>()?;
Ok(Box::new(Rel {
rel_type: Some(RelType::Aggregate(Box::new(AggregateRel {
common: None,
input: Some(input),
groupings: vec![Grouping {
grouping_expressions: grouping,
}],
measures: vec![],
advanced_extension: None,
}))),
}))
}
LogicalPlan::Join(join) => {
let left = to_substrait_rel(join.left.as_ref(), extension_info)?;
let right = to_substrait_rel(join.right.as_ref(), extension_info)?;
let join_type = to_substrait_jointype(join.join_type);
if join.filter.is_some() {
return Err(DataFusionError::NotImplemented("join filter".to_string()));
}
match join.join_constraint {
JoinConstraint::On => {}
_ => {
return Err(DataFusionError::NotImplemented(
"join constraint".to_string(),
))
}
}
let eq_op = if join.null_equals_null {
Operator::IsNotDistinctFrom
} else {
Operator::Eq
};
let join_expression = join
.on
.iter()
.map(|(l, r)| binary_expr(l.clone(), eq_op, r.clone()))
.reduce(|acc: Expr, expr: Expr| acc.and(expr));
let join_schema = join.left.schema().join(join.right.schema());
if let Some(e) = join_expression {
Ok(Box::new(Rel {
rel_type: Some(RelType::Join(Box::new(JoinRel {
common: None,
left: Some(left),
right: Some(right),
r#type: join_type as i32,
expression: Some(Box::new(to_substrait_rex(
&e,
&Arc::new(join_schema?),
extension_info,
)?)),
post_join_filter: None,
advanced_extension: None,
}))),
}))
} else {
Err(DataFusionError::NotImplemented(
"Empty join condition".to_string(),
))
}
}
LogicalPlan::SubqueryAlias(alias) => {
to_substrait_rel(alias.input.as_ref(), extension_info)
}
_ => Err(DataFusionError::NotImplemented(format!(
"Unsupported operator: {plan:?}"
))),
}
}
fn to_substrait_jointype(join_type: JoinType) -> join_rel::JoinType {
match join_type {
JoinType::Inner => join_rel::JoinType::Inner,
JoinType::Left => join_rel::JoinType::Left,
JoinType::Right => join_rel::JoinType::Right,
JoinType::Full => join_rel::JoinType::Outer,
JoinType::LeftAnti => join_rel::JoinType::Anti,
JoinType::LeftSemi => join_rel::JoinType::Semi,
JoinType::RightAnti | JoinType::RightSemi => unimplemented!(),
}
}
pub fn operator_to_name(op: Operator) -> &'static str {
match op {
Operator::Eq => "equal",
Operator::NotEq => "not_equal",
Operator::Lt => "lt",
Operator::LtEq => "lte",
Operator::Gt => "gt",
Operator::GtEq => "gte",
Operator::Plus => "add",
Operator::Minus => "substract",
Operator::Multiply => "multiply",
Operator::Divide => "divide",
Operator::Modulo => "mod",
Operator::And => "and",
Operator::Or => "or",
Operator::IsDistinctFrom => "is_distinct_from",
Operator::IsNotDistinctFrom => "is_not_distinct_from",
Operator::RegexMatch => "regex_match",
Operator::RegexIMatch => "regex_imatch",
Operator::RegexNotMatch => "regex_not_match",
Operator::RegexNotIMatch => "regex_not_imatch",
Operator::BitwiseAnd => "bitwise_and",
Operator::BitwiseOr => "bitwise_or",
Operator::StringConcat => "str_concat",
Operator::BitwiseXor => "bitwise_xor",
Operator::BitwiseShiftRight => "bitwise_shift_right",
Operator::BitwiseShiftLeft => "bitwise_shift_left",
}
}
#[allow(deprecated)]
pub fn to_substrait_agg_measure(
expr: &Expr,
schema: &DFSchemaRef,
extension_info: &mut (
Vec<extensions::SimpleExtensionDeclaration>,
HashMap<String, u32>,
),
) -> Result<Measure> {
match expr {
Expr::AggregateFunction(expr::AggregateFunction { fun, args, distinct, filter }) => {
let mut arguments: Vec<FunctionArgument> = vec![];
for arg in args {
arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, extension_info)?)) });
}
let function_name = fun.to_string().to_lowercase();
let function_anchor = _register_function(function_name, extension_info);
Ok(Measure {
measure: Some(AggregateFunction {
function_reference: function_anchor,
arguments,
sorts: vec![],
output_type: None,
invocation: match distinct {
true => AggregationInvocation::Distinct as i32,
false => AggregationInvocation::All as i32,
},
phase: AggregationPhase::Unspecified as i32,
args: vec![],
options: vec![],
}),
filter: match filter {
Some(f) => Some(to_substrait_rex(f, schema, extension_info)?),
None => None
}
})
}
Expr::Alias(expr, _name) => {
to_substrait_agg_measure(expr, schema, extension_info)
}
_ => Err(DataFusionError::Internal(format!(
"Expression must be compatible with aggregation. Unsupported expression: {:?}. ExpressionType: {:?}",
expr,
expr.variant_name()
))),
}
}
fn _register_function(
function_name: String,
extension_info: &mut (
Vec<extensions::SimpleExtensionDeclaration>,
HashMap<String, u32>,
),
) -> u32 {
let (function_extensions, function_set) = extension_info;
let function_name = function_name.to_lowercase();
let function_anchor = match function_set.get(&function_name) {
Some(function_anchor) => {
*function_anchor
}
None => {
let function_anchor = function_set.len() as u32;
function_set.insert(function_name.clone(), function_anchor);
let function_extension = ExtensionFunction {
extension_uri_reference: u32::MAX,
function_anchor,
name: function_name,
};
let simple_extension = extensions::SimpleExtensionDeclaration {
mapping_type: Some(MappingType::ExtensionFunction(function_extension)),
};
function_extensions.push(simple_extension);
function_anchor
}
};
function_anchor
}
#[allow(deprecated)]
pub fn make_binary_op_scalar_func(
lhs: &Expression,
rhs: &Expression,
op: Operator,
extension_info: &mut (
Vec<extensions::SimpleExtensionDeclaration>,
HashMap<String, u32>,
),
) -> Expression {
let function_name = operator_to_name(op).to_string().to_lowercase();
let function_anchor = _register_function(function_name, extension_info);
Expression {
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
function_reference: function_anchor,
arguments: vec![
FunctionArgument {
arg_type: Some(ArgType::Value(lhs.clone())),
},
FunctionArgument {
arg_type: Some(ArgType::Value(rhs.clone())),
},
],
output_type: None,
args: vec![],
options: vec![],
})),
}
}
pub fn to_substrait_rex(
expr: &Expr,
schema: &DFSchemaRef,
extension_info: &mut (
Vec<extensions::SimpleExtensionDeclaration>,
HashMap<String, u32>,
),
) -> Result<Expression> {
match expr {
Expr::Between(Between {
expr,
negated,
low,
high,
}) => {
if *negated {
let substrait_expr = to_substrait_rex(expr, schema, extension_info)?;
let substrait_low = to_substrait_rex(low, schema, extension_info)?;
let substrait_high = to_substrait_rex(high, schema, extension_info)?;
let l_expr = make_binary_op_scalar_func(
&substrait_expr,
&substrait_low,
Operator::Lt,
extension_info,
);
let r_expr = make_binary_op_scalar_func(
&substrait_high,
&substrait_expr,
Operator::Lt,
extension_info,
);
Ok(make_binary_op_scalar_func(
&l_expr,
&r_expr,
Operator::Or,
extension_info,
))
} else {
let substrait_expr = to_substrait_rex(expr, schema, extension_info)?;
let substrait_low = to_substrait_rex(low, schema, extension_info)?;
let substrait_high = to_substrait_rex(high, schema, extension_info)?;
let l_expr = make_binary_op_scalar_func(
&substrait_low,
&substrait_expr,
Operator::LtEq,
extension_info,
);
let r_expr = make_binary_op_scalar_func(
&substrait_expr,
&substrait_high,
Operator::LtEq,
extension_info,
);
Ok(make_binary_op_scalar_func(
&l_expr,
&r_expr,
Operator::And,
extension_info,
))
}
}
Expr::Column(col) => {
let index = schema.index_of_column(col)?;
substrait_field_ref(index)
}
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
let l = to_substrait_rex(left, schema, extension_info)?;
let r = to_substrait_rex(right, schema, extension_info)?;
Ok(make_binary_op_scalar_func(&l, &r, *op, extension_info))
}
Expr::Case(Case {
expr,
when_then_expr,
else_expr,
}) => {
let mut ifs: Vec<IfClause> = vec![];
if let Some(e) = expr {
ifs.push(IfClause {
r#if: Some(to_substrait_rex(e, schema, extension_info)?),
then: None,
});
}
for (r#if, then) in when_then_expr {
ifs.push(IfClause {
r#if: Some(to_substrait_rex(r#if, schema, extension_info)?),
then: Some(to_substrait_rex(then, schema, extension_info)?),
});
}
let r#else: Option<Box<Expression>> = match else_expr {
Some(e) => Some(Box::new(to_substrait_rex(e, schema, extension_info)?)),
None => None,
};
Ok(Expression {
rex_type: Some(RexType::IfThen(Box::new(IfThen { ifs, r#else }))),
})
}
Expr::Literal(value) => {
let literal_type = match value {
ScalarValue::Int8(Some(n)) => Some(LiteralType::I8(*n as i32)),
ScalarValue::Int16(Some(n)) => Some(LiteralType::I16(*n as i32)),
ScalarValue::Int32(Some(n)) => Some(LiteralType::I32(*n)),
ScalarValue::Int64(Some(n)) => Some(LiteralType::I64(*n)),
ScalarValue::UInt8(Some(n)) => Some(LiteralType::I16(*n as i32)), ScalarValue::Boolean(Some(b)) => Some(LiteralType::Boolean(*b)),
ScalarValue::Float32(Some(f)) => Some(LiteralType::Fp32(*f)),
ScalarValue::Float64(Some(f)) => Some(LiteralType::Fp64(*f)),
ScalarValue::Decimal128(v, p, s) if v.is_some() => {
Some(LiteralType::Decimal(Decimal {
value: v.unwrap().to_le_bytes().to_vec(),
precision: *p as i32,
scale: *s as i32,
}))
}
ScalarValue::Utf8(Some(s)) => Some(LiteralType::String(s.clone())),
ScalarValue::LargeUtf8(Some(s)) => Some(LiteralType::String(s.clone())),
ScalarValue::Binary(Some(b)) => Some(LiteralType::Binary(b.clone())),
ScalarValue::LargeBinary(Some(b)) => Some(LiteralType::Binary(b.clone())),
ScalarValue::Date32(Some(d)) => Some(LiteralType::Date(*d)),
_ => Some(try_to_substrait_null(value)?),
};
Ok(Expression {
rex_type: Some(RexType::Literal(Literal {
nullable: true,
type_variation_reference: 0,
literal_type,
})),
})
}
Expr::Alias(expr, _alias) => to_substrait_rex(expr, schema, extension_info),
_ => Err(DataFusionError::NotImplemented(format!(
"Unsupported expression: {expr:?}"
))),
}
}
fn try_to_substrait_null(v: &ScalarValue) -> Result<LiteralType> {
let default_type_ref = 0;
let default_nullability = r#type::Nullability::Nullable as i32;
match v {
ScalarValue::Int8(None) => Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::I8(r#type::I8 {
type_variation_reference: default_type_ref,
nullability: default_nullability,
})),
})),
ScalarValue::Int16(None) => Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::I16(r#type::I16 {
type_variation_reference: default_type_ref,
nullability: default_nullability,
})),
})),
ScalarValue::Int32(None) => Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::I32(r#type::I32 {
type_variation_reference: default_type_ref,
nullability: default_nullability,
})),
})),
ScalarValue::Int64(None) => Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::I64(r#type::I64 {
type_variation_reference: default_type_ref,
nullability: default_nullability,
})),
})),
ScalarValue::Decimal128(None, p, s) => {
Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::Decimal(r#type::Decimal {
scale: *s as i32,
precision: *p as i32,
type_variation_reference: default_type_ref,
nullability: default_nullability,
})),
}))
}
_ => Err(DataFusionError::NotImplemented(format!(
"Unsupported literal: {v:?}"
))),
}
}
fn substrait_sort_field(
expr: &Expr,
schema: &DFSchemaRef,
extension_info: &mut (
Vec<extensions::SimpleExtensionDeclaration>,
HashMap<String, u32>,
),
) -> Result<SortField> {
match expr {
Expr::Sort(Sort {
expr,
asc,
nulls_first,
}) => {
let e = to_substrait_rex(expr, schema, extension_info)?;
let d = match (asc, nulls_first) {
(true, true) => SortDirection::AscNullsFirst,
(true, false) => SortDirection::AscNullsLast,
(false, true) => SortDirection::DescNullsFirst,
(false, false) => SortDirection::DescNullsLast,
};
Ok(SortField {
expr: Some(e),
sort_kind: Some(SortKind::Direction(d as i32)),
})
}
_ => Err(DataFusionError::NotImplemented(format!(
"Expecting sort expression but got {expr:?}"
))),
}
}
fn substrait_field_ref(index: usize) -> Result<Expression> {
Ok(Expression {
rex_type: Some(RexType::Selection(Box::new(FieldReference {
reference_type: Some(ReferenceType::DirectReference(ReferenceSegment {
reference_type: Some(reference_segment::ReferenceType::StructField(
Box::new(reference_segment::StructField {
field: index as i32,
child: None,
}),
)),
})),
root_type: None,
}))),
})
}