use crate::logical_plan::consumer::SubstraitConsumer;
use datafusion::common::{DFSchema, Spans, substrait_datafusion_err, substrait_err};
use datafusion::logical_expr::expr::{Exists, InSubquery, SetComparison, SetQuantifier};
use datafusion::logical_expr::{Expr, LogicalPlan, Operator, Subquery};
use std::sync::Arc;
use substrait::proto::Rel;
use substrait::proto::expression as substrait_expression;
use substrait::proto::expression::subquery::SubqueryType;
use substrait::proto::expression::subquery::set_comparison::{ComparisonOp, ReductionOp};
use substrait::proto::expression::subquery::set_predicate::PredicateOp;
/// Converts `rel` through `consumer` with `outer_schema` registered as an
/// outer schema for the duration of the conversion.
///
/// The schema is pushed before `consume_rel` runs and popped right after it
/// completes. The conversion result is held unexamined until the pop has
/// happened, so the push/pop pair stays balanced on both the success and the
/// error path.
///
/// NOTE(review): if this future is dropped while `consume_rel` is still
/// pending, the pop below never runs — confirm callers never cancel it.
async fn consume_subquery_rel(
    consumer: &impl SubstraitConsumer,
    rel: &Rel,
    outer_schema: &DFSchema,
) -> datafusion::common::Result<LogicalPlan> {
    let enclosing = Arc::new(outer_schema.clone());
    consumer.push_outer_schema(enclosing);

    // Do not use `?` here: the schema must be popped even on failure.
    let converted = consumer.consume_rel(rel).await;

    consumer.pop_outer_schema();
    converted
}
pub async fn from_subquery(
consumer: &impl SubstraitConsumer,
subquery: &substrait_expression::Subquery,
input_schema: &DFSchema,
) -> datafusion::common::Result<Expr> {
match &subquery.subquery_type {
Some(subquery_type) => match subquery_type {
SubqueryType::InPredicate(in_predicate) => {
if in_predicate.needles.len() != 1 {
substrait_err!(
"InPredicate Subquery type must have exactly one Needle expression"
)
} else {
let needle_expr = &in_predicate.needles[0];
let haystack_expr = &in_predicate.haystack;
if let Some(haystack_expr) = haystack_expr {
let haystack_expr =
consume_subquery_rel(consumer, haystack_expr, input_schema)
.await?;
let outer_refs = haystack_expr.all_out_ref_exprs();
Ok(Expr::InSubquery(InSubquery {
expr: Box::new(
consumer
.consume_expression(needle_expr, input_schema)
.await?,
),
subquery: Subquery {
subquery: Arc::new(haystack_expr),
outer_ref_columns: outer_refs,
spans: Spans::new(),
},
negated: false,
}))
} else {
substrait_err!(
"InPredicate Subquery type must have a Haystack expression"
)
}
}
}
SubqueryType::Scalar(query) => {
let plan = consume_subquery_rel(
consumer,
&(query.input.clone()).unwrap_or_default(),
input_schema,
)
.await?;
let outer_ref_columns = plan.all_out_ref_exprs();
Ok(Expr::ScalarSubquery(Subquery {
subquery: Arc::new(plan),
outer_ref_columns,
spans: Spans::new(),
}))
}
SubqueryType::SetPredicate(predicate) => {
match predicate.predicate_op() {
PredicateOp::Exists => {
let relation = &predicate.tuples;
let plan = consume_subquery_rel(
consumer,
&relation.clone().unwrap_or_default(),
input_schema,
)
.await?;
let outer_ref_columns = plan.all_out_ref_exprs();
Ok(Expr::Exists(Exists::new(
Subquery {
subquery: Arc::new(plan),
outer_ref_columns,
spans: Spans::new(),
},
false,
)))
}
other_type => substrait_err!(
"unimplemented type {other_type:?} for set predicate"
),
}
}
SubqueryType::SetComparison(comparison) => {
let left = comparison.left.as_ref().ok_or_else(|| {
substrait_datafusion_err!("SetComparison requires a left expression")
})?;
let right = comparison.right.as_ref().ok_or_else(|| {
substrait_datafusion_err!("SetComparison requires a right relation")
})?;
let reduction_op = match ReductionOp::try_from(comparison.reduction_op) {
Ok(ReductionOp::Any) => SetQuantifier::Any,
Ok(ReductionOp::All) => SetQuantifier::All,
_ => {
return substrait_err!(
"Unsupported reduction op for SetComparison: {}",
comparison.reduction_op
);
}
};
let comparison_op = match ComparisonOp::try_from(comparison.comparison_op)
{
Ok(ComparisonOp::Eq) => Operator::Eq,
Ok(ComparisonOp::Ne) => Operator::NotEq,
Ok(ComparisonOp::Lt) => Operator::Lt,
Ok(ComparisonOp::Gt) => Operator::Gt,
Ok(ComparisonOp::Le) => Operator::LtEq,
Ok(ComparisonOp::Ge) => Operator::GtEq,
_ => {
return substrait_err!(
"Unsupported comparison op for SetComparison: {}",
comparison.comparison_op
);
}
};
let left_expr = consumer.consume_expression(left, input_schema).await?;
let plan = consume_subquery_rel(consumer, right, input_schema).await?;
let outer_ref_columns = plan.all_out_ref_exprs();
Ok(Expr::SetComparison(SetComparison::new(
Box::new(left_expr),
Subquery {
subquery: Arc::new(plan),
outer_ref_columns,
spans: Spans::new(),
},
comparison_op,
reduction_op,
)))
}
},
None => {
substrait_err!("Subquery expression without SubqueryType is not allowed")
}
}
}