use crate::extensions::Extensions;
use crate::logical_plan::producer::{
from_aggregate, from_aggregate_function, from_alias, from_between, from_binary_expr,
from_case, from_cast, from_column, from_distinct, from_empty_relation, from_exists,
from_filter, from_in_list, from_in_subquery, from_join, from_like, from_limit,
from_literal, from_projection, from_repartition, from_scalar_function,
from_scalar_subquery, from_set_comparison, from_sort, from_subquery_alias,
from_table_scan, from_try_cast, from_unary_expr, from_union, from_values,
from_window, from_window_function, to_substrait_rel, to_substrait_rex,
};
use datafusion::common::{Column, DFSchemaRef, ScalarValue, substrait_err};
use datafusion::execution::SessionState;
use datafusion::execution::registry::SerializerRegistry;
use datafusion::logical_expr::Subquery;
use datafusion::logical_expr::expr::{
Alias, Exists, InList, InSubquery, SetComparison, WindowFunction,
};
use datafusion::logical_expr::{
Aggregate, Between, BinaryExpr, Case, Cast, Distinct, EmptyRelation, Expr, Extension,
Filter, Join, Like, Limit, LogicalPlan, Projection, Repartition, Sort, SubqueryAlias,
TableScan, TryCast, Union, Values, Window, expr,
};
use pbjson_types::Any as ProtoAny;
use substrait::proto::aggregate_rel::Measure;
use substrait::proto::rel::RelType;
use substrait::proto::{
Expression, ExtensionLeafRel, ExtensionMultiRel, ExtensionSingleRel, Rel,
};
/// Converts DataFusion [`LogicalPlan`]s and [`Expr`]s into Substrait protobuf messages.
///
/// Implementors must supply the three extension-bookkeeping methods
/// (`register_function`, `register_type`, `get_extensions`). Every `handle_*` method
/// has a default that forwards to the matching free function in
/// `crate::logical_plan::producer` (e.g. `handle_filter` -> `from_filter`); overriding
/// a `handle_*` method replaces that default for the implementing type.
///
/// The exception is [`handle_extension`](Self::handle_extension): its default returns
/// a Substrait error, so producers that need to encode `LogicalPlan::Extension` nodes
/// must override it.
pub trait SubstraitProducer: Send + Sync + Sized {
    // ---- Extension bookkeeping (no defaults) ----

    /// Registers a function `signature` and returns the `u32` reference assigned to it.
    fn register_function(&mut self, signature: String) -> u32;

    /// Registers a type `name` and returns the `u32` reference assigned to it.
    fn register_type(&mut self, name: String) -> u32;

    /// Consumes the producer and returns its [`Extensions`].
    fn get_extensions(self) -> Extensions;

    // ---- Logical plan nodes -> Substrait `Rel` ----

    /// Converts any [`LogicalPlan`]; the default forwards to `to_substrait_rel`.
    fn handle_plan(
        &mut self,
        plan: &LogicalPlan,
    ) -> datafusion::common::Result<Box<Rel>> {
        to_substrait_rel(self, plan)
    }

    /// Converts a [`Projection`]; the default forwards to `from_projection`.
    fn handle_projection(
        &mut self,
        projection: &Projection,
    ) -> datafusion::common::Result<Box<Rel>> {
        from_projection(self, projection)
    }

    /// Converts a [`Filter`]; the default forwards to `from_filter`.
    fn handle_filter(&mut self, filter: &Filter) -> datafusion::common::Result<Box<Rel>> {
        from_filter(self, filter)
    }

    /// Converts a [`Window`]; the default forwards to `from_window`.
    fn handle_window(&mut self, window: &Window) -> datafusion::common::Result<Box<Rel>> {
        from_window(self, window)
    }

    /// Converts an [`Aggregate`]; the default forwards to `from_aggregate`.
    fn handle_aggregate(
        &mut self,
        aggregate: &Aggregate,
    ) -> datafusion::common::Result<Box<Rel>> {
        from_aggregate(self, aggregate)
    }

    /// Converts a [`Sort`]; the default forwards to `from_sort`.
    fn handle_sort(&mut self, sort: &Sort) -> datafusion::common::Result<Box<Rel>> {
        from_sort(self, sort)
    }

    /// Converts a [`Join`]; the default forwards to `from_join`.
    fn handle_join(&mut self, join: &Join) -> datafusion::common::Result<Box<Rel>> {
        from_join(self, join)
    }

    /// Converts a [`Repartition`]; the default forwards to `from_repartition`.
    fn handle_repartition(
        &mut self,
        repartition: &Repartition,
    ) -> datafusion::common::Result<Box<Rel>> {
        from_repartition(self, repartition)
    }

    /// Converts a [`Union`]; the default forwards to `from_union`.
    fn handle_union(&mut self, union: &Union) -> datafusion::common::Result<Box<Rel>> {
        from_union(self, union)
    }

    /// Converts a [`TableScan`]; the default forwards to `from_table_scan`.
    fn handle_table_scan(
        &mut self,
        scan: &TableScan,
    ) -> datafusion::common::Result<Box<Rel>> {
        from_table_scan(self, scan)
    }

    /// Converts an [`EmptyRelation`]; the default forwards to `from_empty_relation`.
    fn handle_empty_relation(
        &mut self,
        empty: &EmptyRelation,
    ) -> datafusion::common::Result<Box<Rel>> {
        from_empty_relation(self, empty)
    }

    /// Converts a [`SubqueryAlias`]; the default forwards to `from_subquery_alias`.
    fn handle_subquery_alias(
        &mut self,
        alias: &SubqueryAlias,
    ) -> datafusion::common::Result<Box<Rel>> {
        from_subquery_alias(self, alias)
    }

    /// Converts a [`Limit`]; the default forwards to `from_limit`.
    fn handle_limit(&mut self, limit: &Limit) -> datafusion::common::Result<Box<Rel>> {
        from_limit(self, limit)
    }

    /// Converts a [`Values`] node; the default forwards to `from_values`.
    fn handle_values(&mut self, values: &Values) -> datafusion::common::Result<Box<Rel>> {
        from_values(self, values)
    }

    /// Converts a [`Distinct`]; the default forwards to `from_distinct`.
    fn handle_distinct(
        &mut self,
        distinct: &Distinct,
    ) -> datafusion::common::Result<Box<Rel>> {
        from_distinct(self, distinct)
    }

    /// Converts a user-defined [`Extension`] node.
    ///
    /// The default always returns a Substrait error; implementors that need to
    /// support extension nodes must override this method.
    fn handle_extension(
        &mut self,
        _extension: &Extension,
    ) -> datafusion::common::Result<Box<Rel>> {
        substrait_err!(
            "Specify handling for LogicalPlan::Extension by implementing the SubstraitProducer trait"
        )
    }

    // ---- Expressions -> Substrait `Expression` ----
    // In every expression handler, `schema` is passed through unchanged to the
    // default converter.

    /// Converts any [`Expr`]; the default forwards to `to_substrait_rex`.
    fn handle_expr(
        &mut self,
        expr: &Expr,
        schema: &DFSchemaRef,
    ) -> datafusion::common::Result<Expression> {
        to_substrait_rex(self, expr, schema)
    }

    /// Converts an [`Alias`]; the default forwards to `from_alias`.
    fn handle_alias(
        &mut self,
        alias: &Alias,
        schema: &DFSchemaRef,
    ) -> datafusion::common::Result<Expression> {
        from_alias(self, alias, schema)
    }

    /// Converts a [`Column`] reference; the default forwards to `from_column`
    /// and does not use `self`.
    fn handle_column(
        &mut self,
        column: &Column,
        schema: &DFSchemaRef,
    ) -> datafusion::common::Result<Expression> {
        from_column(column, schema)
    }

    /// Converts a literal [`ScalarValue`]; the default forwards to `from_literal`.
    /// Unlike the other expression handlers, no schema is supplied.
    fn handle_literal(
        &mut self,
        value: &ScalarValue,
    ) -> datafusion::common::Result<Expression> {
        from_literal(self, value)
    }

    /// Converts a [`BinaryExpr`]; the default forwards to `from_binary_expr`.
    fn handle_binary_expr(
        &mut self,
        binary: &BinaryExpr,
        schema: &DFSchemaRef,
    ) -> datafusion::common::Result<Expression> {
        from_binary_expr(self, binary, schema)
    }

    /// Converts a [`Like`] expression; the default forwards to `from_like`.
    fn handle_like(
        &mut self,
        like: &Like,
        schema: &DFSchemaRef,
    ) -> datafusion::common::Result<Expression> {
        from_like(self, like, schema)
    }

    /// Converts a unary expression; the default forwards to `from_unary_expr`.
    /// It receives the whole [`Expr`] rather than a dedicated node struct.
    fn handle_unary_expr(
        &mut self,
        expr: &Expr,
        schema: &DFSchemaRef,
    ) -> datafusion::common::Result<Expression> {
        from_unary_expr(self, expr, schema)
    }

    /// Converts a [`Between`] expression; the default forwards to `from_between`.
    fn handle_between(
        &mut self,
        between: &Between,
        schema: &DFSchemaRef,
    ) -> datafusion::common::Result<Expression> {
        from_between(self, between, schema)
    }

    /// Converts a [`Case`] expression; the default forwards to `from_case`.
    fn handle_case(
        &mut self,
        case: &Case,
        schema: &DFSchemaRef,
    ) -> datafusion::common::Result<Expression> {
        from_case(self, case, schema)
    }

    /// Converts a [`Cast`]; the default forwards to `from_cast`.
    fn handle_cast(
        &mut self,
        cast: &Cast,
        schema: &DFSchemaRef,
    ) -> datafusion::common::Result<Expression> {
        from_cast(self, cast, schema)
    }

    /// Converts a [`TryCast`]; the default forwards to `from_try_cast`.
    fn handle_try_cast(
        &mut self,
        cast: &TryCast,
        schema: &DFSchemaRef,
    ) -> datafusion::common::Result<Expression> {
        from_try_cast(self, cast, schema)
    }

    /// Converts a scalar function call; the default forwards to `from_scalar_function`.
    fn handle_scalar_function(
        &mut self,
        scalar_fn: &expr::ScalarFunction,
        schema: &DFSchemaRef,
    ) -> datafusion::common::Result<Expression> {
        from_scalar_function(self, scalar_fn, schema)
    }

    /// Converts an aggregate function call into a Substrait [`Measure`] (not an
    /// `Expression`); the default forwards to `from_aggregate_function`.
    fn handle_aggregate_function(
        &mut self,
        agg_fn: &expr::AggregateFunction,
        schema: &DFSchemaRef,
    ) -> datafusion::common::Result<Measure> {
        from_aggregate_function(self, agg_fn, schema)
    }

    /// Converts a [`WindowFunction`]; the default forwards to `from_window_function`.
    fn handle_window_function(
        &mut self,
        window_fn: &WindowFunction,
        schema: &DFSchemaRef,
    ) -> datafusion::common::Result<Expression> {
        from_window_function(self, window_fn, schema)
    }

    /// Converts an [`InList`]; the default forwards to `from_in_list`.
    fn handle_in_list(
        &mut self,
        in_list: &InList,
        schema: &DFSchemaRef,
    ) -> datafusion::common::Result<Expression> {
        from_in_list(self, in_list, schema)
    }

    /// Converts an [`InSubquery`]; the default forwards to `from_in_subquery`.
    fn handle_in_subquery(
        &mut self,
        in_subquery: &InSubquery,
        schema: &DFSchemaRef,
    ) -> datafusion::common::Result<Expression> {
        from_in_subquery(self, in_subquery, schema)
    }

    /// Converts a [`SetComparison`]; the default forwards to `from_set_comparison`.
    fn handle_set_comparison(
        &mut self,
        set_comparison: &SetComparison,
        schema: &DFSchemaRef,
    ) -> datafusion::common::Result<Expression> {
        from_set_comparison(self, set_comparison, schema)
    }

    /// Converts a scalar [`Subquery`]; the default forwards to `from_scalar_subquery`.
    fn handle_scalar_subquery(
        &mut self,
        subquery: &Subquery,
        schema: &DFSchemaRef,
    ) -> datafusion::common::Result<Expression> {
        from_scalar_subquery(self, subquery, schema)
    }

    /// Converts an [`Exists`] expression; the default forwards to `from_exists`.
    fn handle_exists(
        &mut self,
        exists: &Exists,
        schema: &DFSchemaRef,
    ) -> datafusion::common::Result<Expression> {
        from_exists(self, exists, schema)
    }
}
/// The stock [`SubstraitProducer`] implementation.
///
/// Collects function and type registrations in an [`Extensions`] value and encodes
/// `LogicalPlan::Extension` nodes with a borrowed [`SerializerRegistry`]. The `'a`
/// lifetime is the lifetime of that registry borrow.
pub struct DefaultSubstraitProducer<'a> {
    /// Registry used to serialize user-defined extension plan nodes.
    serializer_registry: &'a dyn SerializerRegistry,
    /// Function/type registrations accumulated while producing a plan.
    extensions: Extensions,
}
impl<'a> DefaultSubstraitProducer<'a> {
pub fn new(state: &'a SessionState) -> Self {
DefaultSubstraitProducer {
extensions: Extensions::default(),
serializer_registry: state.serializer_registry().as_ref(),
}
}
}
impl SubstraitProducer for DefaultSubstraitProducer<'_> {
    /// Delegates to `Extensions::register_function` on the owned extensions.
    fn register_function(&mut self, signature: String) -> u32 {
        self.extensions.register_function(&signature)
    }

    /// Delegates to `Extensions::register_type` on the owned extensions.
    fn register_type(&mut self, name: String) -> u32 {
        self.extensions.register_type(&name)
    }

    /// Moves the accumulated extensions out, dropping the registry borrow.
    fn get_extensions(self) -> Extensions {
        let Self { extensions, .. } = self;
        extensions
    }

    /// Encodes a user-defined plan node as a Substrait extension relation.
    ///
    /// The node is serialized through the serializer registry and wrapped in a
    /// `ProtoAny` whose `type_url` is the node's `name()`. Each input plan is then
    /// converted with `handle_plan`. The relation kind follows the input count:
    /// zero inputs -> `ExtensionLeaf`, one -> `ExtensionSingle`, more -> `ExtensionMulti`.
    ///
    /// # Errors
    /// Returns the serializer's error if serialization fails. Otherwise returns the
    /// first error produced while converting an input plan.
    fn handle_extension(
        &mut self,
        plan: &Extension,
    ) -> datafusion::common::Result<Box<Rel>> {
        // Serialization runs before any child conversion, so its error wins.
        let payload = self
            .serializer_registry
            .serialize_logical_plan(plan.node.as_ref())?;
        let detail = ProtoAny {
            type_url: plan.node.name().to_string(),
            value: payload.into(),
        };

        // Convert children in declaration order, stopping at the first failure.
        let mut children: Vec<Box<Rel>> = Vec::new();
        for child in plan.node.inputs() {
            children.push(self.handle_plan(child)?);
        }

        let rel_type = if children.is_empty() {
            RelType::ExtensionLeaf(ExtensionLeafRel {
                common: None,
                detail: Some(detail),
            })
        } else if children.len() == 1 {
            RelType::ExtensionSingle(Box::new(ExtensionSingleRel {
                common: None,
                detail: Some(detail),
                // Exactly one element, so this is always `Some`.
                input: children.into_iter().next(),
            }))
        } else {
            RelType::ExtensionMulti(ExtensionMultiRel {
                common: None,
                detail: Some(detail),
                // The multi-input relation stores unboxed `Rel`s.
                inputs: children.into_iter().map(|rel| *rel).collect(),
            })
        };

        Ok(Box::new(Rel {
            rel_type: Some(rel_type),
        }))
    }
}