use async_recursion::async_recursion;
use datafusion::common::{DFField, DFSchema, DFSchemaRef};
use datafusion::logical_expr::expr;
use datafusion::logical_expr::{
aggregate_function, BinaryExpr, Case, Expr, LogicalPlan, Operator,
};
use datafusion::logical_expr::{build_join_schema, LogicalPlanBuilder};
use datafusion::prelude::JoinType;
use datafusion::sql::TableReference;
use datafusion::{
error::{DataFusionError, Result},
optimizer::utils::split_conjunction,
prelude::{Column, SessionContext},
scalar::ScalarValue,
};
use substrait::proto::{
aggregate_function::AggregationInvocation,
expression::{
field_reference::ReferenceType::DirectReference, literal::LiteralType,
reference_segment::ReferenceType::StructField, MaskExpression, RexType,
},
extensions::simple_extension_declaration::MappingType,
function_argument::ArgType,
join_rel, plan_rel, r#type,
read_rel::ReadType,
rel::RelType,
sort_field::{SortDirection, SortKind::*},
AggregateFunction, Expression, Plan, Rel, Type,
};
use datafusion::logical_expr::expr::Sort;
use std::collections::HashMap;
use std::str::FromStr;
use std::sync::Arc;
/// Maps a Substrait scalar function name to the corresponding DataFusion
/// binary [`Operator`].
///
/// # Errors
///
/// Returns `DataFusionError::NotImplemented` when `name` is not one of the
/// names listed in the table below.
pub fn name_to_op(name: &str) -> Result<Operator> {
    let operator = match name {
        // Comparison operators.
        "equal" => Operator::Eq,
        "not_equal" => Operator::NotEq,
        "lt" => Operator::Lt,
        "lte" => Operator::LtEq,
        "gt" => Operator::Gt,
        "gte" => Operator::GtEq,
        "is_distinct_from" => Operator::IsDistinctFrom,
        "is_not_distinct_from" => Operator::IsNotDistinctFrom,
        // Arithmetic operators.
        "add" => Operator::Plus,
        "subtract" => Operator::Minus,
        "multiply" => Operator::Multiply,
        "divide" => Operator::Divide,
        "mod" => Operator::Modulo,
        // Boolean connectives.
        "and" => Operator::And,
        "or" => Operator::Or,
        // Regular-expression matching.
        "regex_match" => Operator::RegexMatch,
        "regex_imatch" => Operator::RegexIMatch,
        "regex_not_match" => Operator::RegexNotMatch,
        "regex_not_imatch" => Operator::RegexNotIMatch,
        // Bitwise operators.
        "bitwise_and" => Operator::BitwiseAnd,
        "bitwise_or" => Operator::BitwiseOr,
        "bitwise_xor" => Operator::BitwiseXor,
        "bitwise_shift_right" => Operator::BitwiseShiftRight,
        "bitwise_shift_left" => Operator::BitwiseShiftLeft,
        // String operators.
        "str_concat" => Operator::StringConcat,
        _ => {
            return Err(DataFusionError::NotImplemented(format!(
                "Unsupported function name: {name:?}"
            )))
        }
    };
    Ok(operator)
}
/// Converts a Substrait [`Plan`] into a DataFusion [`LogicalPlan`].
///
/// The plan's extension declarations are first collected into a map from
/// function anchor to function name; that map is then used while converting
/// the plan's single relation tree with [`from_substrait_rel`].
///
/// # Errors
///
/// * `NotImplemented` if an extension declaration is empty or is not an
///   extension function.
/// * `NotImplemented` if the plan does not contain exactly one relation tree.
/// * `Internal` if the relation has no `rel_type`, or if it is a root relation
///   without an input.
/// * Any error produced while converting the relation tree.
pub async fn from_substrait_plan(
    ctx: &mut SessionContext,
    plan: &Plan,
) -> Result<LogicalPlan> {
    // Register function extensions: function anchor -> function name.
    let function_extension = plan
        .extensions
        .iter()
        .map(|e| match &e.mapping_type {
            Some(MappingType::ExtensionFunction(ext_f)) => {
                Ok((ext_f.function_anchor, &ext_f.name))
            }
            Some(ext) => Err(DataFusionError::NotImplemented(format!(
                "Extension type not supported: {ext:?}"
            ))),
            None => Err(DataFusionError::NotImplemented(
                "Cannot parse empty extension".to_string(),
            )),
        })
        .collect::<Result<HashMap<_, _>>>()?;

    // Only plans with exactly one relation tree are handled.
    let relation = match plan.relations.as_slice() {
        [single] => single,
        _ => {
            return Err(DataFusionError::NotImplemented(format!(
                "Substrait plan with more than 1 relation trees not supported. Number of relation trees: {:?}",
                plan.relations.len()
            )))
        }
    };

    match relation.rel_type.as_ref() {
        Some(plan_rel::RelType::Rel(rel)) => {
            from_substrait_rel(ctx, rel, &function_extension).await
        }
        Some(plan_rel::RelType::Root(root)) => {
            // A root without an input is reported as an error rather than
            // panicking on `unwrap`.
            let input = root.input.as_ref().ok_or_else(|| {
                DataFusionError::Internal(
                    "Cannot parse plan relation root: missing input".to_string(),
                )
            })?;
            from_substrait_rel(ctx, input, &function_extension).await
        }
        None => Err(DataFusionError::Internal(
            "Cannot parse plan relation: None".to_string(),
        )),
    }
}
#[async_recursion]
pub async fn from_substrait_rel(
ctx: &mut SessionContext,
rel: &Rel,
extensions: &HashMap<u32, &String>,
) -> Result<LogicalPlan> {
match &rel.rel_type {
Some(RelType::Project(p)) => {
if let Some(input) = p.input.as_ref() {
let input = LogicalPlanBuilder::from(
from_substrait_rel(ctx, input, extensions).await?,
);
let mut exprs: Vec<Expr> = vec![];
for e in &p.expressions {
let x = from_substrait_rex(e, input.schema(), extensions).await?;
exprs.push(x.as_ref().clone());
}
input.project(exprs)?.build()
} else {
Err(DataFusionError::NotImplemented(
"Projection without an input is not supported".to_string(),
))
}
}
Some(RelType::Filter(filter)) => {
if let Some(input) = filter.input.as_ref() {
let input = LogicalPlanBuilder::from(
from_substrait_rel(ctx, input, extensions).await?,
);
if let Some(condition) = filter.condition.as_ref() {
let expr =
from_substrait_rex(condition, input.schema(), extensions).await?;
input.filter(expr.as_ref().clone())?.build()
} else {
Err(DataFusionError::NotImplemented(
"Filter without an condition is not valid".to_string(),
))
}
} else {
Err(DataFusionError::NotImplemented(
"Filter without an input is not valid".to_string(),
))
}
}
Some(RelType::Fetch(fetch)) => {
if let Some(input) = fetch.input.as_ref() {
let input = LogicalPlanBuilder::from(
from_substrait_rel(ctx, input, extensions).await?,
);
let offset = fetch.offset as usize;
let count = fetch.count as usize;
input.limit(offset, Some(count))?.build()
} else {
Err(DataFusionError::NotImplemented(
"Fetch without an input is not valid".to_string(),
))
}
}
Some(RelType::Sort(sort)) => {
if let Some(input) = sort.input.as_ref() {
let input = LogicalPlanBuilder::from(
from_substrait_rel(ctx, input, extensions).await?,
);
let mut sorts: Vec<Expr> = vec![];
for s in &sort.sorts {
let expr = from_substrait_rex(
s.expr.as_ref().unwrap(),
input.schema(),
extensions,
)
.await?;
let asc_nullfirst = match &s.sort_kind {
Some(k) => match k {
Direction(d) => {
let direction : SortDirection = unsafe {
::std::mem::transmute(*d)
};
match direction {
SortDirection::AscNullsFirst => Ok((true, true)),
SortDirection::AscNullsLast => Ok((true, false)),
SortDirection::DescNullsFirst => Ok((false, true)),
SortDirection::DescNullsLast => Ok((false, false)),
SortDirection::Clustered =>
Err(DataFusionError::NotImplemented("Sort with direction clustered is not yet supported".to_string()))
,
SortDirection::Unspecified =>
Err(DataFusionError::NotImplemented("Unspecified sort direction is invalid".to_string()))
}
}
ComparisonFunctionReference(_) => {
Err(DataFusionError::NotImplemented("Sort using comparison function reference is not supported".to_string()))
},
},
None => Err(DataFusionError::NotImplemented("Sort without sort kind is invalid".to_string()))
};
let (asc, nulls_first) = asc_nullfirst.unwrap();
sorts.push(Expr::Sort(Sort {
expr: Box::new(expr.as_ref().clone()),
asc,
nulls_first,
}));
}
input.sort(sorts)?.build()
} else {
Err(DataFusionError::NotImplemented(
"Sort without an input is not valid".to_string(),
))
}
}
Some(RelType::Aggregate(agg)) => {
if let Some(input) = agg.input.as_ref() {
let input = LogicalPlanBuilder::from(
from_substrait_rel(ctx, input, extensions).await?,
);
let mut group_expr = vec![];
let mut aggr_expr = vec![];
let groupings = match agg.groupings.len() {
1 => Ok(&agg.groupings[0]),
_ => Err(DataFusionError::NotImplemented(
"Aggregate with multiple grouping sets is not supported"
.to_string(),
)),
};
for e in &groupings?.grouping_expressions {
let x = from_substrait_rex(e, input.schema(), extensions).await?;
group_expr.push(x.as_ref().clone());
}
for m in &agg.measures {
let filter = match &m.filter {
Some(fil) => Some(Box::new(
from_substrait_rex(fil, input.schema(), extensions)
.await?
.as_ref()
.clone(),
)),
None => None,
};
let agg_func = match &m.measure {
Some(f) => {
let distinct = match f.invocation {
_ if f.invocation
== AggregationInvocation::Distinct as i32 =>
{
true
}
_ if f.invocation
== AggregationInvocation::All as i32 =>
{
false
}
_ => false,
};
from_substrait_agg_func(
f,
input.schema(),
extensions,
filter,
distinct,
)
.await
}
None => Err(DataFusionError::NotImplemented(
"Aggregate without aggregate function is not supported"
.to_string(),
)),
};
aggr_expr.push(agg_func?.as_ref().clone());
}
input.aggregate(group_expr, aggr_expr)?.build()
} else {
Err(DataFusionError::NotImplemented(
"Aggregate without an input is not valid".to_string(),
))
}
}
Some(RelType::Join(join)) => {
let left = LogicalPlanBuilder::from(
from_substrait_rel(ctx, join.left.as_ref().unwrap(), extensions).await?,
);
let right = LogicalPlanBuilder::from(
from_substrait_rel(ctx, join.right.as_ref().unwrap(), extensions).await?,
);
let join_type = from_substrait_jointype(join.r#type)?;
let schema =
build_join_schema(left.schema(), right.schema(), &JoinType::Inner)?;
let on = from_substrait_rex(
join.expression.as_ref().unwrap(),
&schema,
extensions,
)
.await?;
let predicates = split_conjunction(&on);
let join_exprs: Vec<(Column, Column, bool)> = predicates
.iter()
.map(|p| match p {
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
match (left.as_ref(), right.as_ref()) {
(Expr::Column(l), Expr::Column(r)) => match op {
Operator::Eq => Ok((l.clone(), r.clone(), false)),
Operator::IsNotDistinctFrom => {
Ok((l.clone(), r.clone(), true))
}
_ => Err(DataFusionError::Internal(
"invalid join condition op".to_string(),
)),
},
_ => Err(DataFusionError::Internal(
"invalid join condition expresssion".to_string(),
)),
}
}
_ => Err(DataFusionError::Internal(
"Non-binary expression is not supported in join condition"
.to_string(),
)),
})
.collect::<Result<Vec<_>>>()?;
let (left_cols, right_cols, null_eq_nulls): (Vec<_>, Vec<_>, Vec<_>) =
itertools::multiunzip(join_exprs);
left.join_detailed(
right.build()?,
join_type,
(left_cols, right_cols),
None,
null_eq_nulls[0],
)?
.build()
}
Some(RelType::Read(read)) => match &read.as_ref().read_type {
Some(ReadType::NamedTable(nt)) => {
let table_reference = match nt.names.len() {
0 => {
return Err(DataFusionError::Internal(
"No table name found in NamedTable".to_string(),
));
}
1 => TableReference::Bare {
table: &nt.names[0],
},
2 => TableReference::Partial {
schema: &nt.names[0],
table: &nt.names[1],
},
_ => TableReference::Full {
catalog: &nt.names[0],
schema: &nt.names[1],
table: &nt.names[2],
},
};
let t = ctx.table(table_reference).await?;
let t = t.into_optimized_plan()?;
match &read.projection {
Some(MaskExpression { select, .. }) => match &select.as_ref() {
Some(projection) => {
let column_indices: Vec<usize> = projection
.struct_items
.iter()
.map(|item| item.field as usize)
.collect();
match &t {
LogicalPlan::TableScan(scan) => {
let fields: Vec<DFField> = column_indices
.iter()
.map(|i| scan.projected_schema.field(*i).clone())
.collect();
#[allow(clippy::redundant_clone)]
let mut scan = scan.clone();
scan.projection = Some(column_indices);
scan.projected_schema =
DFSchemaRef::new(DFSchema::new_with_metadata(
fields,
HashMap::new(),
)?);
Ok(LogicalPlan::TableScan(scan))
}
_ => Err(DataFusionError::Internal(
"unexpected plan for table".to_string(),
)),
}
}
_ => Ok(t),
},
_ => Ok(t),
}
}
_ => Err(DataFusionError::NotImplemented(
"Only NamedTable reads are supported".to_string(),
)),
},
_ => Err(DataFusionError::NotImplemented(format!(
"Unsupported RelType: {:?}",
rel.rel_type
))),
}
}
/// Translates the raw Substrait join type value into a DataFusion [`JoinType`].
///
/// # Errors
///
/// Returns `DataFusionError::Internal` when `join_type` is not a known
/// Substrait join type, or when it is one without a mapping below.
fn from_substrait_jointype(join_type: i32) -> Result<JoinType> {
    match join_rel::JoinType::from_i32(join_type) {
        Some(join_rel::JoinType::Inner) => Ok(JoinType::Inner),
        Some(join_rel::JoinType::Left) => Ok(JoinType::Left),
        Some(join_rel::JoinType::Right) => Ok(JoinType::Right),
        Some(join_rel::JoinType::Outer) => Ok(JoinType::Full),
        Some(join_rel::JoinType::Anti) => Ok(JoinType::LeftAnti),
        Some(join_rel::JoinType::Semi) => Ok(JoinType::LeftSemi),
        // A recognised Substrait variant that has no mapping above.
        Some(substrait_join_type) => Err(DataFusionError::Internal(format!(
            "unsupported join type {substrait_join_type:?}"
        ))),
        // The integer does not correspond to any Substrait join type.
        None => Err(DataFusionError::Internal(format!(
            "invalid join type variant {join_type:?}"
        ))),
    }
}
/// Converts a Substrait [`AggregateFunction`] into a DataFusion aggregate
/// expression.
///
/// Each argument is converted with [`from_substrait_rex`] against
/// `input_schema`. The function name is looked up in `extensions` by the
/// function's anchor and then parsed into a DataFusion aggregate function.
/// `filter` and `distinct` are stored on the resulting expression unchanged.
///
/// # Errors
///
/// * `NotImplemented` if an argument is not a value expression.
/// * `NotImplemented` if the function anchor is not present in `extensions`.
/// * Any error from converting an argument or from parsing the function name.
pub async fn from_substrait_agg_func(
    f: &AggregateFunction,
    input_schema: &DFSchema,
    extensions: &HashMap<u32, &String>,
    filter: Option<Box<Expr>>,
    distinct: bool,
) -> Result<Arc<Expr>> {
    let mut args: Vec<Expr> = Vec::with_capacity(f.arguments.len());
    for arg in &f.arguments {
        let arg_expr = match &arg.arg_type {
            Some(ArgType::Value(e)) => {
                from_substrait_rex(e, input_schema, extensions).await
            }
            _ => Err(DataFusionError::NotImplemented(
                "Aggregated function argument non-Value type not supported".to_string(),
            )),
        };
        args.push(arg_expr?.as_ref().clone());
    }

    // Propagate lookup and parse failures to the caller instead of panicking.
    let fun = match extensions.get(&f.function_reference) {
        Some(function_name) => {
            aggregate_function::AggregateFunction::from_str(function_name)
        }
        None => Err(DataFusionError::NotImplemented(format!(
            "Aggregated function not found: function anchor = {:?}",
            f.function_reference
        ))),
    }?;

    Ok(Arc::new(Expr::AggregateFunction(expr::AggregateFunction {
        fun,
        args,
        distinct,
        filter,
    })))
}
/// Recursively converts a Substrait [`Expression`] into a DataFusion [`Expr`].
///
/// Handled expression kinds:
/// * direct struct-field references without a child, resolved to an
///   unqualified column named after the field at that index in `input_schema`;
/// * `IfThen`, converted to a `CASE` expression (a first clause without a
///   `then` becomes the `CASE` base expression);
/// * scalar functions with exactly two value arguments, converted to a binary
///   expression whose operator is resolved through `extensions` and
///   [`name_to_op`];
/// * literals of the types matched below.
///
/// # Errors
///
/// Returns `NotImplemented` for unsupported expression kinds, reference
/// types, literal types, function names or argument shapes, and `Substrait`
/// for malformed input (missing `if`/`then` expressions, out-of-range integer
/// literals, unparsable decimals).
#[async_recursion]
pub async fn from_substrait_rex(
    e: &Expression,
    input_schema: &DFSchema,
    extensions: &HashMap<u32, &String>,
) -> Result<Arc<Expr>> {
    match &e.rex_type {
        Some(RexType::Selection(field_ref)) => match &field_ref.reference_type {
            Some(DirectReference(direct)) => match &direct.reference_type.as_ref() {
                Some(StructField(x)) => match &x.child.as_ref() {
                    Some(_) => Err(DataFusionError::NotImplemented(
                        "Direct reference StructField with child is not supported"
                            .to_string(),
                    )),
                    None => Ok(Arc::new(Expr::Column(Column {
                        relation: None,
                        name: input_schema.field(x.field as usize).name().to_string(),
                    }))),
                },
                _ => Err(DataFusionError::NotImplemented(
                    "Direct reference with types other than StructField is not supported"
                        .to_string(),
                )),
            },
            _ => Err(DataFusionError::NotImplemented(
                "unsupported field ref type".to_string(),
            )),
        },
        Some(RexType::IfThen(if_then)) => {
            let mut when_then_expr: Vec<(Box<Expr>, Box<Expr>)> = vec![];
            let mut expr = None;
            for (i, if_expr) in if_then.ifs.iter().enumerate() {
                let if_rex = if_expr.r#if.as_ref().ok_or_else(|| {
                    DataFusionError::Substrait(
                        "IfThen clause without an if expression is invalid".to_string(),
                    )
                })?;
                let condition = Box::new(
                    from_substrait_rex(if_rex, input_schema, extensions)
                        .await?
                        .as_ref()
                        .clone(),
                );
                match if_expr.then.as_ref() {
                    Some(then_rex) => {
                        let then = Box::new(
                            from_substrait_rex(then_rex, input_schema, extensions)
                                .await?
                                .as_ref()
                                .clone(),
                        );
                        when_then_expr.push((condition, then));
                    }
                    // A first clause without `then` carries the base
                    // expression of `CASE <expr> WHEN ...`.
                    None if i == 0 => expr = Some(condition),
                    None => {
                        return Err(DataFusionError::Substrait(
                            "IfThen clause without a then expression is invalid"
                                .to_string(),
                        ))
                    }
                }
            }
            let else_expr = match &if_then.r#else {
                Some(e) => Some(Box::new(
                    from_substrait_rex(e, input_schema, extensions)
                        .await?
                        .as_ref()
                        .clone(),
                )),
                None => None,
            };
            Ok(Arc::new(Expr::Case(Case {
                expr,
                when_then_expr,
                else_expr,
            })))
        }
        Some(RexType::ScalarFunction(f)) => {
            let op = match extensions.get(&f.function_reference) {
                Some(fname) => name_to_op(fname),
                None => Err(DataFusionError::NotImplemented(format!(
                    "Scalar function not found: function reference = {:?}",
                    f.function_reference
                ))),
            }?;
            // Only binary functions are handled; any other arity is reported
            // as an error instead of aborting on an assertion.
            match f.arguments.as_slice() {
                [lhs, rhs] => match (&lhs.arg_type, &rhs.arg_type) {
                    (Some(ArgType::Value(l)), Some(ArgType::Value(r))) => {
                        let left = Box::new(
                            from_substrait_rex(l, input_schema, extensions)
                                .await?
                                .as_ref()
                                .clone(),
                        );
                        let right = Box::new(
                            from_substrait_rex(r, input_schema, extensions)
                                .await?
                                .as_ref()
                                .clone(),
                        );
                        Ok(Arc::new(Expr::BinaryExpr(BinaryExpr { left, op, right })))
                    }
                    (l, r) => Err(DataFusionError::NotImplemented(format!(
                        "Invalid arguments for binary expression: {l:?} and {r:?}"
                    ))),
                },
                args => Err(DataFusionError::NotImplemented(format!(
                    "Scalar function with {} arguments is not supported; expected 2",
                    args.len()
                ))),
            }
        }
        Some(RexType::Literal(lit)) => match &lit.literal_type {
            Some(LiteralType::I8(n)) => {
                // Checked narrowing: out-of-range values are rejected
                // rather than silently truncated.
                let v = i8::try_from(*n).map_err(|e| {
                    DataFusionError::Substrait(format!(
                        "Failed to parse i8 literal: {e}"
                    ))
                })?;
                Ok(Arc::new(Expr::Literal(ScalarValue::Int8(Some(v)))))
            }
            Some(LiteralType::I16(n)) => {
                let v = i16::try_from(*n).map_err(|e| {
                    DataFusionError::Substrait(format!(
                        "Failed to parse i16 literal: {e}"
                    ))
                })?;
                Ok(Arc::new(Expr::Literal(ScalarValue::Int16(Some(v)))))
            }
            Some(LiteralType::I32(n)) => {
                Ok(Arc::new(Expr::Literal(ScalarValue::Int32(Some(*n)))))
            }
            Some(LiteralType::I64(n)) => {
                Ok(Arc::new(Expr::Literal(ScalarValue::Int64(Some(*n)))))
            }
            Some(LiteralType::Boolean(b)) => {
                Ok(Arc::new(Expr::Literal(ScalarValue::Boolean(Some(*b)))))
            }
            Some(LiteralType::Date(d)) => {
                Ok(Arc::new(Expr::Literal(ScalarValue::Date32(Some(*d)))))
            }
            Some(LiteralType::Fp32(f)) => {
                Ok(Arc::new(Expr::Literal(ScalarValue::Float32(Some(*f)))))
            }
            Some(LiteralType::Fp64(f)) => {
                Ok(Arc::new(Expr::Literal(ScalarValue::Float64(Some(*f)))))
            }
            Some(LiteralType::Decimal(d)) => {
                // The value must be exactly 16 bytes; it is decoded as a
                // little-endian i128.
                let value: [u8; 16] = d.value.clone().try_into().map_err(|_| {
                    DataFusionError::Substrait(
                        "Failed to parse decimal value".to_string(),
                    )
                })?;
                let p = d.precision.try_into().map_err(|e| {
                    DataFusionError::Substrait(format!(
                        "Failed to parse decimal precision: {e}"
                    ))
                })?;
                let s = d.scale.try_into().map_err(|e| {
                    DataFusionError::Substrait(format!(
                        "Failed to parse decimal scale: {e}"
                    ))
                })?;
                Ok(Arc::new(Expr::Literal(ScalarValue::Decimal128(
                    Some(std::primitive::i128::from_le_bytes(value)),
                    p,
                    s,
                ))))
            }
            Some(LiteralType::String(s)) => {
                Ok(Arc::new(Expr::Literal(ScalarValue::Utf8(Some(s.clone())))))
            }
            Some(LiteralType::Binary(b)) => Ok(Arc::new(Expr::Literal(
                ScalarValue::Binary(Some(b.clone())),
            ))),
            Some(LiteralType::Null(ntype)) => {
                Ok(Arc::new(Expr::Literal(from_substrait_null(ntype)?)))
            }
            _ => Err(DataFusionError::NotImplemented(format!(
                "Unsupported literal_type: {:?}",
                lit.literal_type
            ))),
        },
        _ => Err(DataFusionError::NotImplemented(
            "unsupported rex_type".to_string(),
        )),
    }
}
/// Builds the typed null [`ScalarValue`] that corresponds to a Substrait
/// [`Type`].
///
/// Supported kinds are `I8`, `I16`, `I32`, `I64` and `Decimal`.
///
/// # Errors
///
/// * `NotImplemented` if the type has no kind or the kind is not supported.
/// * `Substrait` if a decimal precision or scale does not fit the target
///   integer type (checked conversion, matching the decimal literal path).
fn from_substrait_null(null_type: &Type) -> Result<ScalarValue> {
    match &null_type.kind {
        Some(r#type::Kind::I8(_)) => Ok(ScalarValue::Int8(None)),
        Some(r#type::Kind::I16(_)) => Ok(ScalarValue::Int16(None)),
        Some(r#type::Kind::I32(_)) => Ok(ScalarValue::Int32(None)),
        Some(r#type::Kind::I64(_)) => Ok(ScalarValue::Int64(None)),
        Some(r#type::Kind::Decimal(d)) => {
            // Checked narrowing instead of `as`, which would silently wrap.
            let precision = u8::try_from(d.precision).map_err(|e| {
                DataFusionError::Substrait(format!(
                    "Failed to parse decimal precision: {e}"
                ))
            })?;
            let scale = i8::try_from(d.scale).map_err(|e| {
                DataFusionError::Substrait(format!(
                    "Failed to parse decimal scale: {e}"
                ))
            })?;
            Ok(ScalarValue::Decimal128(None, precision, scale))
        }
        Some(kind) => Err(DataFusionError::NotImplemented(format!(
            "Unsupported null kind: {kind:?}"
        ))),
        None => Err(DataFusionError::NotImplemented(
            "Null type without kind is not supported".to_string(),
        )),
    }
}