#[cfg(test)]
pub mod test {
use datafusion::common::{TableReference, substrait_datafusion_err, substrait_err};
use datafusion::datasource::TableProvider;
use datafusion::datasource::empty::EmptyTable;
use datafusion::error::Result;
use datafusion::prelude::SessionContext;
use datafusion_substrait::extensions::Extensions;
use datafusion_substrait::logical_plan::consumer::{
DefaultSubstraitConsumer, SubstraitConsumer, from_substrait_named_struct,
};
use std::collections::HashMap;
use std::fs::File;
use std::io::BufReader;
use std::sync::Arc;
use substrait::proto::exchange_rel::ExchangeKind;
use substrait::proto::expand_rel::expand_field::FieldType;
use substrait::proto::expression::RexType;
use substrait::proto::expression::nested::NestedType;
use substrait::proto::expression::subquery::SubqueryType;
use substrait::proto::function_argument::ArgType;
use substrait::proto::read_rel::{NamedTable, ReadType};
use substrait::proto::rel::RelType;
use substrait::proto::{Expression, FunctionArgument, Plan, ReadRel, Rel};
pub fn read_json(path: &str) -> Plan {
serde_json::from_reader::<_, Plan>(BufReader::new(
File::open(path).expect("file not found"),
))
.expect("failed to parse json")
}
pub fn add_plan_schemas_to_ctx(
ctx: SessionContext,
plan: &Plan,
) -> Result<SessionContext> {
let extensions = Extensions::default();
let state = ctx.state();
let consumer = DefaultSubstraitConsumer::new(&extensions, &state);
add_plan_schemas_to_ctx_with_consumer(&consumer, ctx, plan)
}
fn add_plan_schemas_to_ctx_with_consumer(
consumer: &impl SubstraitConsumer,
ctx: SessionContext,
plan: &Plan,
) -> Result<SessionContext> {
let schemas = TestSchemaCollector::collect_schemas(consumer, plan)?;
let mut schema_map: HashMap<TableReference, Arc<dyn TableProvider>> =
HashMap::new();
for (table_reference, table) in schemas.into_iter() {
let schema = table.schema();
if let Some(existing_table) =
schema_map.insert(table_reference.clone(), table)
&& existing_table.schema() != schema
{
return substrait_err!(
"Substrait plan contained the same table {} with different schemas.\nSchema 1: {}\nSchema 2: {}",
table_reference,
existing_table.schema(),
schema
);
}
}
for (table_reference, table) in schema_map.into_iter() {
ctx.register_table(table_reference, table)?;
}
Ok(ctx)
}
pub struct TestSchemaCollector<'a, T: SubstraitConsumer> {
consumer: &'a T,
schemas: Vec<(TableReference, Arc<dyn TableProvider>)>,
}
impl<'a, T: SubstraitConsumer> TestSchemaCollector<'a, T> {
fn new(consumer: &'a T) -> Self {
TestSchemaCollector {
schemas: Vec::new(),
consumer,
}
}
fn collect_schemas(
consumer: &'a T,
plan: &Plan,
) -> Result<Vec<(TableReference, Arc<dyn TableProvider>)>> {
let mut schema_collector = Self::new(consumer);
for plan_rel in plan.relations.iter() {
let rel_type = plan_rel
.rel_type
.as_ref()
.ok_or(substrait_datafusion_err!("PlanRel must set rel_type"))?;
match rel_type {
substrait::proto::plan_rel::RelType::Rel(r) => {
schema_collector.collect_schemas_from_rel(r)?
}
substrait::proto::plan_rel::RelType::Root(r) => {
let input = r
.input
.as_ref()
.ok_or(substrait_datafusion_err!("RelRoot must set input"))?;
schema_collector.collect_schemas_from_rel(input)?
}
}
}
Ok(schema_collector.schemas)
}
fn collect_named_table(&mut self, read: &ReadRel, nt: &NamedTable) -> Result<()> {
let table_reference = match nt.names.len() {
0 => {
panic!("No table name found in NamedTable");
}
1 => TableReference::Bare {
table: nt.names[0].clone().into(),
},
2 => TableReference::Partial {
schema: nt.names[0].clone().into(),
table: nt.names[1].clone().into(),
},
_ => TableReference::Full {
catalog: nt.names[0].clone().into(),
schema: nt.names[1].clone().into(),
table: nt.names[2].clone().into(),
},
};
let substrait_schema =
read.base_schema.as_ref().ok_or(substrait_datafusion_err!(
"No base schema found for NamedTable: {}",
table_reference
))?;
let df_schema = from_substrait_named_struct(self.consumer, substrait_schema)?
.replace_qualifier(table_reference.clone());
let table = EmptyTable::new(Arc::clone(df_schema.inner()));
self.schemas.push((table_reference, Arc::new(table)));
Ok(())
}
#[expect(deprecated)]
fn collect_schemas_from_rel(&mut self, rel: &Rel) -> Result<()> {
let rel_type = rel
.rel_type
.as_ref()
.ok_or(substrait_datafusion_err!("Rel must set rel_type"))?;
match rel_type {
RelType::Read(r) => {
let read_type = r
.read_type
.as_ref()
.ok_or(substrait_datafusion_err!("Read must set read_type"))?;
match read_type {
ReadType::VirtualTable(_) => (),
ReadType::LocalFiles(_) => todo!(),
ReadType::NamedTable(nt) => self.collect_named_table(r, nt)?,
ReadType::ExtensionTable(_) => todo!(),
ReadType::IcebergTable(_) => todo!(),
}
if let Some(expr) = r.filter.as_ref() {
self.collect_schemas_from_expr(expr)?
};
if let Some(expr) = r.best_effort_filter.as_ref() {
self.collect_schemas_from_expr(expr)?
};
}
RelType::Filter(f) => {
self.apply(f.input.as_ref().map(|b| b.as_ref()))?;
for expr in f.condition.iter() {
self.collect_schemas_from_expr(expr)?;
}
}
RelType::Fetch(f) => {
self.apply(f.input.as_ref().map(|b| b.as_ref()))?;
}
RelType::Aggregate(a) => {
self.apply(a.input.as_ref().map(|b| b.as_ref()))?;
for grouping in a.groupings.iter() {
for expr in grouping.grouping_expressions.iter() {
self.collect_schemas_from_expr(expr)?
}
}
for measure in a.measures.iter() {
if let Some(agg_fn) = measure.measure.as_ref() {
for arg in agg_fn.arguments.iter() {
self.collect_schemas_from_arg(arg)?
}
for sort in agg_fn.sorts.iter() {
if let Some(expr) = sort.expr.as_ref() {
self.collect_schemas_from_expr(expr)?
}
}
}
if let Some(expr) = measure.filter.as_ref() {
self.collect_schemas_from_expr(expr)?
}
}
}
RelType::Sort(s) => {
self.apply(s.input.as_ref().map(|b| b.as_ref()))?;
for sort_field in s.sorts.iter() {
if let Some(expr) = sort_field.expr.as_ref() {
self.collect_schemas_from_expr(expr)?
}
}
}
RelType::Join(j) => {
self.apply(j.left.as_ref().map(|b| b.as_ref()))?;
self.apply(j.right.as_ref().map(|b| b.as_ref()))?;
if let Some(expr) = j.expression.as_ref() {
self.collect_schemas_from_expr(expr)?;
}
if let Some(expr) = j.post_join_filter.as_ref() {
self.collect_schemas_from_expr(expr)?;
}
}
RelType::Project(p) => {
self.apply(p.input.as_ref().map(|b| b.as_ref()))?
}
RelType::Set(s) => {
for input in s.inputs.iter() {
self.collect_schemas_from_rel(input)?;
}
}
RelType::ExtensionSingle(s) => {
self.apply(s.input.as_ref().map(|b| b.as_ref()))?
}
RelType::ExtensionMulti(m) => {
for input in m.inputs.iter() {
self.collect_schemas_from_rel(input)?
}
}
RelType::ExtensionLeaf(_) => {}
RelType::Cross(c) => {
self.apply(c.left.as_ref().map(|b| b.as_ref()))?;
self.apply(c.right.as_ref().map(|b| b.as_ref()))?;
}
RelType::HashJoin(j) => {
self.apply(j.left.as_ref().map(|b| b.as_ref()))?;
self.apply(j.right.as_ref().map(|b| b.as_ref()))?;
if let Some(expr) = j.post_join_filter.as_ref() {
self.collect_schemas_from_expr(expr)?;
}
}
RelType::MergeJoin(j) => {
self.apply(j.left.as_ref().map(|b| b.as_ref()))?;
self.apply(j.right.as_ref().map(|b| b.as_ref()))?;
if let Some(expr) = j.post_join_filter.as_ref() {
self.collect_schemas_from_expr(expr)?;
}
}
RelType::NestedLoopJoin(j) => {
self.apply(j.left.as_ref().map(|b| b.as_ref()))?;
self.apply(j.right.as_ref().map(|b| b.as_ref()))?;
if let Some(expr) = j.expression.as_ref() {
self.collect_schemas_from_expr(expr)?;
}
}
RelType::Window(w) => {
self.apply(w.input.as_ref().map(|b| b.as_ref()))?;
for wf in w.window_functions.iter() {
for arg in wf.arguments.iter() {
self.collect_schemas_from_arg(arg)?;
}
}
for expr in w.partition_expressions.iter() {
self.collect_schemas_from_expr(expr)?;
}
for sort_field in w.sorts.iter() {
if let Some(expr) = sort_field.expr.as_ref() {
self.collect_schemas_from_expr(expr)?;
}
}
}
RelType::Exchange(e) => {
self.apply(e.input.as_ref().map(|b| b.as_ref()))?;
let exchange_kind = e.exchange_kind.as_ref().ok_or(
substrait_datafusion_err!("Exchange must set exchange_kind"),
)?;
match exchange_kind {
ExchangeKind::ScatterByFields(_) => {}
ExchangeKind::SingleTarget(st) => {
if let Some(expr) = st.expression.as_ref() {
self.collect_schemas_from_expr(expr)?
}
}
ExchangeKind::MultiTarget(mt) => {
if let Some(expr) = mt.expression.as_ref() {
self.collect_schemas_from_expr(expr)?
}
}
ExchangeKind::RoundRobin(_) => {}
ExchangeKind::Broadcast(_) => {}
}
}
RelType::Expand(e) => {
self.apply(e.input.as_ref().map(|b| b.as_ref()))?;
for expand_field in e.fields.iter() {
let expand_type = expand_field.field_type.as_ref().ok_or(
substrait_datafusion_err!("ExpandField must set field_type"),
)?;
match expand_type {
FieldType::SwitchingField(sf) => {
for expr in sf.duplicates.iter() {
self.collect_schemas_from_expr(expr)?;
}
}
FieldType::ConsistentField(expr) => {
self.collect_schemas_from_expr(expr)?
}
}
}
}
_ => todo!(),
}
Ok(())
}
fn apply(&mut self, input: Option<&Rel>) -> Result<()> {
match input {
None => Ok(()),
Some(rel) => self.collect_schemas_from_rel(rel),
}
}
fn collect_schemas_from_expr(&mut self, e: &Expression) -> Result<()> {
let rex_type = e.rex_type.as_ref().ok_or(substrait_datafusion_err!(
"rex_type must be set on Expression"
))?;
match rex_type {
RexType::Literal(_) => {}
RexType::Selection(_) => {}
RexType::ScalarFunction(sf) => {
for arg in sf.arguments.iter() {
self.collect_schemas_from_arg(arg)?
}
}
RexType::WindowFunction(wf) => {
for arg in wf.arguments.iter() {
self.collect_schemas_from_arg(arg)?
}
for sort_field in wf.sorts.iter() {
if let Some(expr) = sort_field.expr.as_ref() {
self.collect_schemas_from_expr(expr)?
}
}
for expr in wf.partitions.iter() {
self.collect_schemas_from_expr(expr)?
}
}
RexType::IfThen(it) => {
for if_clause in it.ifs.iter() {
if let Some(expr) = if_clause.r#if.as_ref() {
self.collect_schemas_from_expr(expr)?;
};
if let Some(expr) = if_clause.then.as_ref() {
self.collect_schemas_from_expr(expr)?;
};
}
if let Some(expr) = it.r#else.as_ref() {
self.collect_schemas_from_expr(expr)?;
};
}
RexType::SwitchExpression(se) => {
if let Some(expr) = se.r#match.as_ref() {
self.collect_schemas_from_expr(expr)?
}
for if_value in se.ifs.iter() {
if let Some(expr) = if_value.then.as_ref() {
self.collect_schemas_from_expr(expr)?
}
}
if let Some(expr) = se.r#else.as_ref() {
self.collect_schemas_from_expr(expr)?
}
}
RexType::SingularOrList(sol) => {
if let Some(expr) = sol.value.as_ref() {
self.collect_schemas_from_expr(expr)?
}
for expr in sol.options.iter() {
self.collect_schemas_from_expr(expr)?
}
}
RexType::MultiOrList(mol) => {
for expr in mol.value.iter() {
self.collect_schemas_from_expr(expr)?
}
for record in mol.options.iter() {
for expr in record.fields.iter() {
self.collect_schemas_from_expr(expr)?
}
}
}
RexType::Cast(c) => {
if let Some(expr) = c.input.as_ref() {
self.collect_schemas_from_expr(expr)?
}
}
RexType::Subquery(subquery) => {
let subquery_type = subquery
.subquery_type
.as_ref()
.ok_or(substrait_datafusion_err!("subquery_type must be set"))?;
match subquery_type {
SubqueryType::Scalar(s) => {
if let Some(rel) = s.input.as_ref() {
self.collect_schemas_from_rel(rel)?;
}
}
SubqueryType::InPredicate(ip) => {
for expr in ip.needles.iter() {
self.collect_schemas_from_expr(expr)?;
}
if let Some(rel) = ip.haystack.as_ref() {
self.collect_schemas_from_rel(rel)?;
}
}
SubqueryType::SetPredicate(sp) => {
if let Some(rel) = sp.tuples.as_ref() {
self.collect_schemas_from_rel(rel)?;
}
}
SubqueryType::SetComparison(sc) => {
if let Some(expr) = sc.left.as_ref() {
self.collect_schemas_from_expr(expr)?;
}
if let Some(rel) = sc.right.as_ref() {
self.collect_schemas_from_rel(rel)?;
}
}
}
}
RexType::Nested(n) => {
let nested_type = n.nested_type.as_ref().ok_or(
substrait_datafusion_err!("Nested must set nested_type"),
)?;
match nested_type {
NestedType::Struct(s) => {
for expr in s.fields.iter() {
self.collect_schemas_from_expr(expr)?;
}
}
NestedType::List(l) => {
for expr in l.values.iter() {
self.collect_schemas_from_expr(expr)?;
}
}
NestedType::Map(m) => {
for key_value in m.key_values.iter() {
if let Some(expr) = key_value.key.as_ref() {
self.collect_schemas_from_expr(expr)?;
}
if let Some(expr) = key_value.value.as_ref() {
self.collect_schemas_from_expr(expr)?;
}
}
}
}
}
RexType::DynamicParameter(_) => {}
#[expect(deprecated)]
RexType::Enum(_) => {}
}
Ok(())
}
fn collect_schemas_from_arg(&mut self, fa: &FunctionArgument) -> Result<()> {
let arg_type = fa.arg_type.as_ref().ok_or(substrait_datafusion_err!(
"FunctionArgument must set arg_type"
))?;
match arg_type {
ArgType::Enum(_) => {}
ArgType::Type(_) => {}
ArgType::Value(expr) => self.collect_schemas_from_expr(expr)?,
}
Ok(())
}
}
}