use super::{
from_aggregate_rel, from_cast, from_cross_rel, from_exchange_rel, from_fetch_rel,
from_field_reference, from_filter_rel, from_if_then, from_join_rel, from_literal,
from_project_rel, from_read_rel, from_scalar_function, from_set_rel,
from_singular_or_list, from_sort_rel, from_subquery, from_substrait_rel,
from_substrait_rex, from_window_function,
};
use crate::extensions::Extensions;
use async_trait::async_trait;
use datafusion::arrow::datatypes::DataType;
use datafusion::catalog::TableProvider;
use datafusion::common::{
DFSchema, ScalarValue, TableReference, not_impl_err, substrait_err,
};
use datafusion::execution::{FunctionRegistry, SessionState};
use datafusion::logical_expr::{Expr, Extension, LogicalPlan};
use std::sync::{Arc, RwLock};
use substrait::proto;
use substrait::proto::expression as substrait_expression;
use substrait::proto::expression::{
Enum, FieldReference, IfThen, Literal, MultiOrList, Nested, ScalarFunction,
SingularOrList, SwitchExpression, WindowFunction,
};
use substrait::proto::{
AggregateRel, ConsistentPartitionWindowRel, CrossRel, DynamicParameter, ExchangeRel,
Expression, ExtensionLeafRel, ExtensionMultiRel, ExtensionSingleRel, FetchRel,
FilterRel, JoinRel, ProjectRel, ReadRel, Rel, SetRel, SortRel, r#type,
};
/// Converts Substrait relations, expressions, types and literals into their
/// DataFusion counterparts ([`LogicalPlan`], [`Expr`], [`DataType`], [`ScalarValue`]).
///
/// Implementors must provide [`Self::resolve_table_ref`], [`Self::get_extensions`]
/// and [`Self::get_function_registry`]. Every other method has a default body,
/// which either delegates to a free `from_*` function from the parent module or
/// returns an error, and can be overridden to customise handling of one node kind.
#[async_trait]
pub trait SubstraitConsumer: Send + Sync + Sized {
/// Looks up the table provider for `table_ref`.
///
/// Required method. The `Option` in the return type lets an implementation
/// report "no such table" separately from a lookup error.
async fn resolve_table_ref(
&self,
table_ref: &TableReference,
) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>>;
/// Returns the [`Extensions`] associated with this consumer. Required method.
fn get_extensions(&self) -> &Extensions;
/// Returns the function registry associated with this consumer. Required method.
fn get_function_registry(&self) -> &impl FunctionRegistry;
/// Converts a Substrait [`Rel`] into a [`LogicalPlan`].
///
/// Default: delegates to `from_substrait_rel`.
// NOTE(review): presumably `from_substrait_rel` dispatches on the relation kind
// back to the `consume_*` methods below — confirm against its definition.
async fn consume_rel(&self, rel: &Rel) -> datafusion::common::Result<LogicalPlan> {
from_substrait_rel(self, rel).await
}
/// Converts a [`ReadRel`]. Default: delegates to `from_read_rel`.
async fn consume_read(
&self,
rel: &ReadRel,
) -> datafusion::common::Result<LogicalPlan> {
from_read_rel(self, rel).await
}
/// Converts a [`FilterRel`]. Default: delegates to `from_filter_rel`.
async fn consume_filter(
&self,
rel: &FilterRel,
) -> datafusion::common::Result<LogicalPlan> {
from_filter_rel(self, rel).await
}
/// Converts a [`FetchRel`]. Default: delegates to `from_fetch_rel`.
async fn consume_fetch(
&self,
rel: &FetchRel,
) -> datafusion::common::Result<LogicalPlan> {
from_fetch_rel(self, rel).await
}
/// Converts an [`AggregateRel`]. Default: delegates to `from_aggregate_rel`.
async fn consume_aggregate(
&self,
rel: &AggregateRel,
) -> datafusion::common::Result<LogicalPlan> {
from_aggregate_rel(self, rel).await
}
/// Converts a [`SortRel`]. Default: delegates to `from_sort_rel`.
async fn consume_sort(
&self,
rel: &SortRel,
) -> datafusion::common::Result<LogicalPlan> {
from_sort_rel(self, rel).await
}
/// Converts a [`JoinRel`]. Default: delegates to `from_join_rel`.
async fn consume_join(
&self,
rel: &JoinRel,
) -> datafusion::common::Result<LogicalPlan> {
from_join_rel(self, rel).await
}
/// Converts a [`ProjectRel`]. Default: delegates to `from_project_rel`.
async fn consume_project(
&self,
rel: &ProjectRel,
) -> datafusion::common::Result<LogicalPlan> {
from_project_rel(self, rel).await
}
/// Converts a [`SetRel`]. Default: delegates to `from_set_rel`.
async fn consume_set(&self, rel: &SetRel) -> datafusion::common::Result<LogicalPlan> {
from_set_rel(self, rel).await
}
/// Converts a [`CrossRel`]. Default: delegates to `from_cross_rel`.
async fn consume_cross(
&self,
rel: &CrossRel,
) -> datafusion::common::Result<LogicalPlan> {
from_cross_rel(self, rel).await
}
/// Converts a [`ConsistentPartitionWindowRel`].
///
/// Default: unsupported; always returns the error built by `not_impl_err!`.
async fn consume_consistent_partition_window(
&self,
_rel: &ConsistentPartitionWindowRel,
) -> datafusion::common::Result<LogicalPlan> {
not_impl_err!("Consistent Partition Window Rel not supported")
}
/// Converts an [`ExchangeRel`]. Default: delegates to `from_exchange_rel`.
async fn consume_exchange(
&self,
rel: &ExchangeRel,
) -> datafusion::common::Result<LogicalPlan> {
from_exchange_rel(self, rel).await
}
/// Converts a Substrait [`Expression`] into an [`Expr`], given the schema of
/// the input the expression is evaluated against.
///
/// Default: delegates to `from_substrait_rex`.
async fn consume_expression(
&self,
expr: &Expression,
input_schema: &DFSchema,
) -> datafusion::common::Result<Expr> {
from_substrait_rex(self, expr, input_schema).await
}
/// Converts a [`Literal`]. Default: delegates to `from_literal`.
async fn consume_literal(&self, expr: &Literal) -> datafusion::common::Result<Expr> {
from_literal(self, expr).await
}
/// Converts a [`FieldReference`]. Default: delegates to `from_field_reference`.
async fn consume_field_reference(
&self,
expr: &FieldReference,
input_schema: &DFSchema,
) -> datafusion::common::Result<Expr> {
from_field_reference(self, expr, input_schema).await
}
/// Converts a [`ScalarFunction`]. Default: delegates to `from_scalar_function`.
async fn consume_scalar_function(
&self,
expr: &ScalarFunction,
input_schema: &DFSchema,
) -> datafusion::common::Result<Expr> {
from_scalar_function(self, expr, input_schema).await
}
/// Converts a [`WindowFunction`]. Default: delegates to `from_window_function`.
async fn consume_window_function(
&self,
expr: &WindowFunction,
input_schema: &DFSchema,
) -> datafusion::common::Result<Expr> {
from_window_function(self, expr, input_schema).await
}
/// Converts an [`IfThen`]. Default: delegates to `from_if_then`.
async fn consume_if_then(
&self,
expr: &IfThen,
input_schema: &DFSchema,
) -> datafusion::common::Result<Expr> {
from_if_then(self, expr, input_schema).await
}
/// Converts a [`SwitchExpression`].
///
/// Default: unsupported; always returns the error built by `not_impl_err!`.
async fn consume_switch(
&self,
_expr: &SwitchExpression,
_input_schema: &DFSchema,
) -> datafusion::common::Result<Expr> {
not_impl_err!("Switch expression not supported")
}
/// Converts a [`SingularOrList`]. Default: delegates to `from_singular_or_list`.
async fn consume_singular_or_list(
&self,
expr: &SingularOrList,
input_schema: &DFSchema,
) -> datafusion::common::Result<Expr> {
from_singular_or_list(self, expr, input_schema).await
}
/// Converts a [`MultiOrList`].
///
/// Default: unsupported; always returns the error built by `not_impl_err!`.
async fn consume_multi_or_list(
&self,
_expr: &MultiOrList,
_input_schema: &DFSchema,
) -> datafusion::common::Result<Expr> {
not_impl_err!("Multi Or List expression not supported")
}
/// Converts a Substrait cast expression. Default: delegates to `from_cast`.
async fn consume_cast(
&self,
expr: &substrait_expression::Cast,
input_schema: &DFSchema,
) -> datafusion::common::Result<Expr> {
from_cast(self, expr, input_schema).await
}
/// Converts a Substrait subquery expression. Default: delegates to `from_subquery`.
async fn consume_subquery(
&self,
expr: &substrait_expression::Subquery,
input_schema: &DFSchema,
) -> datafusion::common::Result<Expr> {
from_subquery(self, expr, input_schema).await
}
/// Converts a [`Nested`] expression.
///
/// Default: unsupported; always returns the error built by `not_impl_err!`.
async fn consume_nested(
&self,
_expr: &Nested,
_input_schema: &DFSchema,
) -> datafusion::common::Result<Expr> {
not_impl_err!("Nested expression not supported")
}
/// Converts an [`Enum`] expression.
///
/// Default: unsupported; always returns the error built by `not_impl_err!`.
async fn consume_enum(
&self,
_expr: &Enum,
_input_schema: &DFSchema,
) -> datafusion::common::Result<Expr> {
not_impl_err!("Enum expression not supported")
}
/// Converts a [`DynamicParameter`] expression.
///
/// Default: unsupported; always returns the error built by `not_impl_err!`.
async fn consume_dynamic_parameter(
&self,
_expr: &DynamicParameter,
_input_schema: &DFSchema,
) -> datafusion::common::Result<Expr> {
not_impl_err!("Dynamic Parameter expression not supported")
}
/// Records `_schema` as a new innermost outer schema.
///
/// Default: no-op, so the schema is discarded. `DefaultSubstraitConsumer`
/// overrides this with a stack.
fn push_outer_schema(&self, _schema: Arc<DFSchema>) {}
/// Discards the most recently pushed outer schema. Default: no-op.
fn pop_outer_schema(&self) {}
/// Returns the outer schema `_steps_out` levels out, if any.
///
/// Default: always `None`. In `DefaultSubstraitConsumer`, `steps_out == 1`
/// selects the most recently pushed schema.
fn get_outer_schema(&self, _steps_out: usize) -> Option<Arc<DFSchema>> {
None
}
/// Converts an [`ExtensionLeafRel`].
///
/// Default: there is no handler, so this always returns the error built by
/// `substrait_err!`; the message includes the detail's `type_url` when the
/// relation carries a detail.
async fn consume_extension_leaf(
&self,
rel: &ExtensionLeafRel,
) -> datafusion::common::Result<LogicalPlan> {
if let Some(detail) = rel.detail.as_ref() {
return substrait_err!(
"Missing handler for ExtensionLeafRel: {}",
detail.type_url
);
}
substrait_err!("Missing handler for ExtensionLeafRel")
}
/// Converts an [`ExtensionSingleRel`].
///
/// Default: there is no handler, so this always returns the error built by
/// `substrait_err!`; the message includes the detail's `type_url` when the
/// relation carries a detail.
async fn consume_extension_single(
&self,
rel: &ExtensionSingleRel,
) -> datafusion::common::Result<LogicalPlan> {
if let Some(detail) = rel.detail.as_ref() {
return substrait_err!(
"Missing handler for ExtensionSingleRel: {}",
detail.type_url
);
}
substrait_err!("Missing handler for ExtensionSingleRel")
}
/// Converts an [`ExtensionMultiRel`].
///
/// Default: there is no handler, so this always returns the error built by
/// `substrait_err!`; the message includes the detail's `type_url` when the
/// relation carries a detail.
async fn consume_extension_multi(
&self,
rel: &ExtensionMultiRel,
) -> datafusion::common::Result<LogicalPlan> {
if let Some(detail) = rel.detail.as_ref() {
return substrait_err!(
"Missing handler for ExtensionMultiRel: {}",
detail.type_url
);
}
substrait_err!("Missing handler for ExtensionMultiRel")
}
/// Maps a Substrait user-defined type to an Arrow [`DataType`].
///
/// Default: there is no handler, so this always returns the error built by
/// `substrait_err!`, naming the type's `type_reference`.
fn consume_user_defined_type(
&self,
user_defined_type: &r#type::UserDefined,
) -> datafusion::common::Result<DataType> {
substrait_err!(
"Missing handler for user-defined type: {}",
user_defined_type.type_reference
)
}
/// Maps a Substrait user-defined literal to a [`ScalarValue`].
///
/// Default: there is no handler, so this always returns the error built by
/// `substrait_err!`, naming the literal's `type_reference`.
fn consume_user_defined_literal(
&self,
user_defined_literal: &proto::expression::literal::UserDefined,
) -> datafusion::common::Result<ScalarValue> {
substrait_err!(
"Missing handler for user-defined literals {}",
user_defined_literal.type_reference
)
}
}
/// [`SubstraitConsumer`] implementation backed by a borrowed [`SessionState`].
///
/// The session state is used to resolve tables, to serve as the function
/// registry, and to deserialize extension relations through its serializer
/// registry.
pub struct DefaultSubstraitConsumer<'a> {
/// Extensions handed to the consumer at construction; returned by `get_extensions`.
pub(super) extensions: &'a Extensions,
/// Session state that tables, functions and extension nodes are looked up in.
pub(super) state: &'a SessionState,
/// Stack of outer schemas; the last element is the most recently pushed.
/// Wrapped in a lock because the trait's push/pop methods take `&self`.
outer_schemas: RwLock<Vec<Arc<DFSchema>>>,
}
impl<'a> DefaultSubstraitConsumer<'a> {
pub fn new(extensions: &'a Extensions, state: &'a SessionState) -> Self {
DefaultSubstraitConsumer {
extensions,
state,
outer_schemas: RwLock::new(Vec::new()),
}
}
}
#[async_trait]
impl SubstraitConsumer for DefaultSubstraitConsumer<'_> {
    /// Resolves `table_ref` through the session state: first finds the schema
    /// provider for the reference, then asks it for the table by name.
    async fn resolve_table_ref(
        &self,
        table_ref: &TableReference,
    ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>> {
        let table_name = table_ref.table().to_string();
        let schema_provider = self.state.schema_for_ref(table_ref.clone())?;
        schema_provider.table(&table_name).await
    }

    /// Returns the extensions supplied at construction.
    fn get_extensions(&self) -> &Extensions {
        self.extensions
    }

    /// Uses the session state as the function registry.
    fn get_function_registry(&self) -> &impl FunctionRegistry {
        self.state
    }

    /// Pushes `schema` onto the outer-schema stack.
    fn push_outer_schema(&self, schema: Arc<DFSchema>) {
        let mut stack = self.outer_schemas.write().unwrap();
        stack.push(schema);
    }

    /// Removes the most recently pushed outer schema; does nothing when the
    /// stack is already empty.
    fn pop_outer_schema(&self) {
        let mut stack = self.outer_schemas.write().unwrap();
        stack.pop();
    }

    /// Returns the schema `steps_out` levels out, counting from the top of the
    /// stack: `1` is the most recently pushed schema, `2` the one before it.
    /// Yields `None` for `0` and for anything deeper than the stack.
    fn get_outer_schema(&self, steps_out: usize) -> Option<Arc<DFSchema>> {
        let stack = self.outer_schemas.read().unwrap();
        // `checked_sub` rejects steps_out > len; steps_out == 0 gives idx == len,
        // which `get` then rejects as out of bounds.
        let idx = stack.len().checked_sub(steps_out)?;
        stack.get(idx).cloned()
    }

    /// Deserializes the relation's detail into a user-defined logical node via
    /// the session's serializer registry and wraps it as an extension plan.
    async fn consume_extension_leaf(
        &self,
        rel: &ExtensionLeafRel,
    ) -> datafusion::common::Result<LogicalPlan> {
        let detail = match rel.detail.as_ref() {
            Some(detail) => detail,
            None => {
                return substrait_err!("Unexpected empty detail in ExtensionLeafRel");
            }
        };
        let node = self
            .state
            .serializer_registry()
            .deserialize_logical_plan(&detail.type_url, &detail.value)?;
        Ok(LogicalPlan::Extension(Extension { node }))
    }

    /// Deserializes the relation's detail into a user-defined logical node,
    /// converts the single input relation, and attaches it as the node's input.
    async fn consume_extension_single(
        &self,
        rel: &ExtensionSingleRel,
    ) -> datafusion::common::Result<LogicalPlan> {
        let detail = match rel.detail.as_ref() {
            Some(detail) => detail,
            None => {
                return substrait_err!("Unexpected empty detail in ExtensionSingleRel");
            }
        };
        // Deserialization is attempted before the input check, so a bad detail
        // is reported ahead of a missing input.
        let node = self
            .state
            .serializer_registry()
            .deserialize_logical_plan(&detail.type_url, &detail.value)?;
        let input_rel = match rel.input.as_ref() {
            Some(input_rel) => input_rel,
            None => {
                return substrait_err!(
                    "ExtensionSingleRel missing input rel, try using ExtensionLeafRel instead"
                );
            }
        };
        let child = self.consume_rel(input_rel).await?;
        let node = node.with_exprs_and_inputs(node.expressions(), vec![child])?;
        Ok(LogicalPlan::Extension(Extension { node }))
    }

    /// Deserializes the relation's detail into a user-defined logical node,
    /// converts every input relation in order, and attaches them as the node's
    /// inputs.
    async fn consume_extension_multi(
        &self,
        rel: &ExtensionMultiRel,
    ) -> datafusion::common::Result<LogicalPlan> {
        let detail = match rel.detail.as_ref() {
            Some(detail) => detail,
            None => {
                return substrait_err!("Unexpected empty detail in ExtensionMultiRel");
            }
        };
        let node = self
            .state
            .serializer_registry()
            .deserialize_logical_plan(&detail.type_url, &detail.value)?;
        // Inputs are converted one at a time; the first failure aborts the rest.
        let mut children = Vec::with_capacity(rel.inputs.len());
        for child_rel in rel.inputs.iter() {
            children.push(self.consume_rel(child_rel).await?);
        }
        let node = node.with_exprs_and_inputs(node.expressions(), children)?;
        Ok(LogicalPlan::Extension(Extension { node }))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::logical_plan::consumer::utils::tests::test_consumer;
    use datafusion::arrow::datatypes::{DataType, Field, Schema};

    /// Builds a shared `DFSchema` of nullable fields from `(name, type)` pairs.
    fn make_schema(fields: &[(&str, DataType)]) -> Arc<DFSchema> {
        let mut columns: Vec<Field> = Vec::with_capacity(fields.len());
        for (name, data_type) in fields {
            columns.push(Field::new(*name, data_type.clone(), true));
        }
        let df_schema =
            DFSchema::try_from(Schema::new(columns)).expect("failed to create schema");
        Arc::new(df_schema)
    }

    /// Name of the first field of `schema`; panics if the schema has no fields.
    fn first_field_name(schema: &DFSchema) -> &str {
        schema.fields()[0].name()
    }

    /// With nothing pushed, every lookup depth yields `None`.
    #[test]
    fn test_get_outer_schema_empty_stack() {
        let consumer = test_consumer();
        for steps_out in 0..=2 {
            assert!(consumer.get_outer_schema(steps_out).is_none());
        }
    }

    /// One pushed schema is visible only at depth 1 and disappears after a pop.
    #[test]
    fn test_get_outer_schema_single_level() {
        let consumer = test_consumer();
        let only = make_schema(&[("a", DataType::Int64)]);
        consumer.push_outer_schema(Arc::clone(&only));

        let found = consumer.get_outer_schema(1).unwrap();
        assert_eq!(found.fields().len(), 1);
        assert_eq!(first_field_name(&found), "a");

        assert!(consumer.get_outer_schema(0).is_none());
        assert!(consumer.get_outer_schema(2).is_none());

        consumer.pop_outer_schema();
        assert!(consumer.get_outer_schema(1).is_none());
    }

    /// Depth 1 is the innermost (last pushed) schema, depth 2 the one outside it.
    #[test]
    fn test_get_outer_schema_nested() {
        let consumer = test_consumer();
        let outer = make_schema(&[("a", DataType::Int64)]);
        let inner = make_schema(&[("b", DataType::Utf8)]);
        consumer.push_outer_schema(Arc::clone(&outer));
        consumer.push_outer_schema(Arc::clone(&inner));

        let nearest = consumer.get_outer_schema(1).unwrap();
        assert_eq!(first_field_name(&nearest), "b");
        let farther = consumer.get_outer_schema(2).unwrap();
        assert_eq!(first_field_name(&farther), "a");
        assert!(consumer.get_outer_schema(3).is_none());

        // After popping the inner schema, the outer one becomes depth 1.
        consumer.pop_outer_schema();
        let nearest = consumer.get_outer_schema(1).unwrap();
        assert_eq!(first_field_name(&nearest), "a");
        assert!(consumer.get_outer_schema(2).is_none());
    }
}