use std::sync::Arc;
use datafusion::common::not_impl_err;
use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
use datafusion::logical_expr::Extension;
use datafusion::optimizer::optimizer::Optimizer;
use datafusion::optimizer::{OptimizerConfig, OptimizerRule};
use datafusion::{
datasource::source_as_provider,
error::Result,
logical_expr::{Expr, LogicalPlan, Projection, TableScan, TableSource},
};
use crate::{
FederatedTableProviderAdaptor, FederatedTableSource, FederationProvider, FederationProviderRef,
};
/// Logical optimizer rule that looks for (sub-)plans whose table scans all
/// resolve to a single federation provider and hands those plans to that
/// provider's own optimizer.
///
/// The rule carries no state; construct it with [`FederationOptimizerRule::new`]
/// or [`Default::default`].
#[derive(Clone, Debug, Default)]
pub struct FederationOptimizerRule {}
impl OptimizerRule for FederationOptimizerRule {
    /// Runs the federation pass over `plan`, treating it as the root of the
    /// query.
    ///
    /// Returns `Ok(Some(new_plan))` when the recursive pass produced a
    /// rewritten plan and `Ok(None)` when it did not. Errors raised while
    /// scanning or rewriting are propagated unchanged.
    fn try_optimize(
        &self,
        plan: &LogicalPlan,
        config: &dyn OptimizerConfig,
    ) -> Result<Option<LogicalPlan>> {
        // The provider summary returned alongside the plan only matters to
        // recursive callers; at the top level it is discarded.
        self.optimize_plan_recursively(plan, true, config)
            .map(|(rewritten, _providers)| rewritten)
    }

    /// Name under which this rule identifies itself.
    fn name(&self) -> &str {
        "federation_optimizer_rule"
    }

    /// Always returns `false`.
    fn supports_rewrite(&self) -> bool {
        false
    }
}
/// Summary of which federation provider(s) a plan or expression tree refers to.
///
/// Values are combined with `ScanResult::merge` while walking a tree.
enum ScanResult {
/// No federation provider has been encountered.
None,
/// Every provider encountered so far compares equal to this one.
Distinct(FederationProviderRef),
/// The tree cannot be attributed to a single provider: two unequal providers
/// were merged, or the scanning code marked it ambiguous explicitly (this
/// file does so for outer-reference columns and for "Federated" extension
/// nodes).
Ambiguous,
}
impl ScanResult {
    /// Folds `other` into `self`.
    ///
    /// * Merging `None` leaves `self` untouched.
    /// * Merging anything into `None` replaces `self` with it.
    /// * `Ambiguous` on either side yields `Ambiguous`.
    /// * Two `Distinct` values stay `Distinct` only when their providers
    ///   compare equal; otherwise the result is `Ambiguous`.
    fn merge(&mut self, other: Self) {
        let combined = match (&*self, other) {
            // Nothing new to record.
            (_, ScanResult::None) => return,
            (ScanResult::None, incoming) => incoming,
            (ScanResult::Ambiguous, _) | (_, ScanResult::Ambiguous) => ScanResult::Ambiguous,
            (ScanResult::Distinct(current), ScanResult::Distinct(incoming)) => {
                if current != &incoming {
                    ScanResult::Ambiguous
                } else {
                    // Same provider on both sides: keep what we already have.
                    return;
                }
            }
        };
        *self = combined;
    }

    /// Merges an optional provider, treating `None` as "no provider seen".
    fn add(&mut self, provider: Option<FederationProviderRef>) {
        self.merge(provider.into());
    }

    /// `true` only for [`ScanResult::Ambiguous`].
    fn is_ambiguous(&self) -> bool {
        matches!(self, ScanResult::Ambiguous)
    }

    /// `true` only for [`ScanResult::None`].
    fn is_none(&self) -> bool {
        matches!(self, ScanResult::None)
    }

    /// `true` for both `Distinct` and `Ambiguous`, i.e. whenever at least one
    /// provider (or an explicit ambiguity) has been recorded.
    fn is_some(&self) -> bool {
        !matches!(self, ScanResult::None)
    }

    /// Converts into an optional provider.
    ///
    /// # Panics
    ///
    /// Panics when called on [`ScanResult::Ambiguous`]; callers must check
    /// [`ScanResult::is_ambiguous`] first.
    fn unwrap(self) -> Option<FederationProviderRef> {
        match self {
            ScanResult::Ambiguous => {
                panic!("called `ScanResult::unwrap()` on a `Ambiguous` value")
            }
            ScanResult::Distinct(provider) => Some(provider),
            ScanResult::None => None,
        }
    }

    /// Maps the current state to a tree-walk instruction: once the result is
    /// ambiguous further scanning cannot change it, so the walk is stopped;
    /// otherwise it continues.
    fn check_recursion(&self) -> TreeNodeRecursion {
        match self {
            ScanResult::Ambiguous => TreeNodeRecursion::Stop,
            ScanResult::None | ScanResult::Distinct(_) => TreeNodeRecursion::Continue,
        }
    }
}
impl From<Option<FederationProviderRef>> for ScanResult {
    /// `Some(provider)` becomes `Distinct(provider)`; `None` becomes
    /// `ScanResult::None`.
    fn from(provider: Option<FederationProviderRef>) -> Self {
        provider.map_or(ScanResult::None, ScanResult::Distinct)
    }
}
impl PartialEq<Option<FederationProviderRef>> for ScanResult {
    /// Compares a scan result with an optional provider.
    ///
    /// `ScanResult::None` equals `None`, `Distinct(p)` equals `Some(q)` when
    /// `p == q`, and `Ambiguous` never equals anything.
    fn eq(&self, other: &Option<FederationProviderRef>) -> bool {
        match self {
            ScanResult::None => other.is_none(),
            ScanResult::Distinct(provider) => {
                matches!(other, Some(other_provider) if provider == other_provider)
            }
            ScanResult::Ambiguous => false,
        }
    }
}
impl Clone for ScanResult {
fn clone(&self) -> Self {
match self {
ScanResult::None => ScanResult::None,
ScanResult::Distinct(provider) => ScanResult::Distinct(provider.clone()),
ScanResult::Ambiguous => ScanResult::Ambiguous,
}
}
}
impl FederationOptimizerRule {
pub fn new() -> Self {
Self::default()
}
fn scan_plan_recursively(&self, plan: &LogicalPlan) -> Result<ScanResult> {
let mut sole_provider: ScanResult = ScanResult::None;
plan.apply(&mut |p: &LogicalPlan| -> Result<TreeNodeRecursion> {
let exprs_provider = self.scan_plan_exprs(p)?;
sole_provider.merge(exprs_provider);
if sole_provider.is_ambiguous() {
return Ok(TreeNodeRecursion::Stop);
}
let sub_provider = get_leaf_provider(p)?;
sole_provider.add(sub_provider);
Ok(sole_provider.check_recursion())
})?;
Ok(sole_provider)
}
fn scan_plan_exprs(&self, plan: &LogicalPlan) -> Result<ScanResult> {
let mut sole_provider: ScanResult = ScanResult::None;
let exprs = plan.expressions();
for expr in &exprs {
let expr_result = self.scan_expr_recursively(expr)?;
sole_provider.merge(expr_result);
if sole_provider.is_ambiguous() {
return Ok(sole_provider);
}
}
Ok(sole_provider)
}
fn scan_expr_recursively(&self, expr: &Expr) -> Result<ScanResult> {
let mut sole_provider: ScanResult = ScanResult::None;
expr.apply(&mut |e: &Expr| -> Result<TreeNodeRecursion> {
match e {
Expr::ScalarSubquery(ref subquery) => {
let plan_result = self.scan_plan_recursively(&subquery.subquery)?;
sole_provider.merge(plan_result);
Ok(sole_provider.check_recursion())
}
Expr::InSubquery(_) => not_impl_err!("InSubquery"),
Expr::OuterReferenceColumn(..) => {
sole_provider = ScanResult::Ambiguous;
Ok(TreeNodeRecursion::Stop)
}
_ => Ok(TreeNodeRecursion::Continue),
}
})?;
Ok(sole_provider)
}
fn optimize_plan_recursively(
&self,
plan: &LogicalPlan,
is_root: bool,
_config: &dyn OptimizerConfig,
) -> Result<(Option<LogicalPlan>, ScanResult)> {
let mut sole_provider: ScanResult = ScanResult::None;
if let LogicalPlan::Extension(Extension { ref node }) = plan {
if node.name() == "Federated" {
return Ok((None, ScanResult::Ambiguous));
}
}
let leaf_provider = get_leaf_provider(plan)?;
let exprs_result = self.scan_plan_exprs(plan)?;
let optimize_expressions = exprs_result.is_some();
if leaf_provider.is_some() && (exprs_result.is_none() || exprs_result == leaf_provider) {
return Ok((None, leaf_provider.into()));
}
sole_provider.add(leaf_provider);
sole_provider.merge(exprs_result);
let inputs = plan.inputs();
if inputs.is_empty() && sole_provider.is_none() {
return Ok((None, ScanResult::None));
}
let input_results = inputs
.iter()
.map(|i| self.optimize_plan_recursively(i, false, _config))
.collect::<Result<Vec<_>>>()?;
input_results.iter().for_each(|(_, scan_result)| {
sole_provider.merge(scan_result.clone());
});
if sole_provider.is_none() {
return Ok((None, ScanResult::None));
}
if let ScanResult::Distinct(provider) = sole_provider {
if !is_root {
return Ok((None, ScanResult::Distinct(provider)));
}
let Some(optimizer) = provider.optimizer() else {
return Ok((None, ScanResult::None));
};
let optimized = optimizer.optimize(plan.clone(), _config, |_, _| {})?;
return Ok((Some(optimized), ScanResult::None));
}
let new_inputs = input_results
.into_iter()
.enumerate()
.map(|(i, (input_plan, input_result))| {
if let Some(federated_plan) = input_plan {
return Ok(federated_plan);
}
let original_input = (*inputs.get(i).unwrap()).clone();
if input_result.is_ambiguous() {
return Ok(original_input);
}
let provider = input_result.unwrap();
let Some(provider) = provider else {
return Ok(original_input);
};
let Some(optimizer) = provider.optimizer() else {
return Ok(original_input);
};
let wrapped = wrap_projection(original_input)?;
let optimized = optimizer.optimize(wrapped, _config, |_, _| {})?;
Ok(optimized)
})
.collect::<Result<Vec<_>>>()?;
let new_expressions = if optimize_expressions {
self.optimize_plan_exprs(plan, _config)?
} else {
plan.expressions()
};
let new_plan = plan.with_new_exprs(new_expressions, new_inputs)?;
Ok((Some(new_plan), ScanResult::Ambiguous))
}
fn optimize_plan_exprs(
&self,
plan: &LogicalPlan,
_config: &dyn OptimizerConfig,
) -> Result<Vec<Expr>> {
plan.expressions()
.iter()
.map(|expr| {
let transformed = expr
.clone()
.transform(&|e| self.optimize_expr_recursively(e, _config))?;
Ok(transformed.data)
})
.collect::<Result<Vec<_>>>()
}
fn optimize_expr_recursively(
&self,
expr: Expr,
_config: &dyn OptimizerConfig,
) -> Result<Transformed<Expr>> {
match expr {
Expr::ScalarSubquery(ref subquery) => {
let (new_subquery, _) =
self.optimize_plan_recursively(&subquery.subquery, true, _config)?;
let Some(new_subquery) = new_subquery else {
return Ok(Transformed::no(expr));
};
Ok(Transformed::yes(Expr::ScalarSubquery(
subquery.with_plan(new_subquery.into()),
)))
}
Expr::InSubquery(_) => not_impl_err!("InSubquery"),
_ => Ok(Transformed::no(expr)),
}
}
}
/// Placeholder provider that `get_leaf_provider` assigns to table scans whose
/// source is not a `FederatedTableProviderAdaptor`. It supplies no optimizer,
/// so plans attributed to it are never handed off for rewriting.
struct NopFederationProvider {}
impl FederationProvider for NopFederationProvider {
    /// Fixed provider name, `"nop"`.
    fn name(&self) -> &str {
        "nop"
    }

    /// This provider has no compute context.
    fn compute_context(&self) -> Option<String> {
        Option::None
    }

    /// This provider has no optimizer.
    fn optimizer(&self) -> Option<Arc<Optimizer>> {
        Option::None
    }
}
/// Returns the federation provider responsible for `plan` when it is a table
/// scan, and `None` for every other kind of plan node.
///
/// A scan whose source is not federated is attributed to a
/// [`NopFederationProvider`], so it still counts as a provider when results
/// are merged.
fn get_leaf_provider(plan: &LogicalPlan) -> Result<Option<FederationProviderRef>> {
    let LogicalPlan::TableScan(TableScan { source, .. }) = plan else {
        return Ok(None);
    };

    let provider: FederationProviderRef = match get_table_source(source)? {
        Some(federated_source) => federated_source.federation_provider(),
        None => Arc::new(NopFederationProvider {}),
    };
    Ok(Some(provider))
}
/// Ensures `plan` has a projection as its top node.
///
/// A plan that already is a projection is returned as is; any other plan is
/// wrapped in a projection that selects every column of its schema.
fn wrap_projection(plan: LogicalPlan) -> Result<LogicalPlan> {
    if matches!(plan, LogicalPlan::Projection(_)) {
        return Ok(plan);
    }

    // Collect the column expressions before `plan` is moved into the `Arc`.
    let all_columns: Vec<Expr> = plan
        .schema()
        .columns()
        .iter()
        .cloned()
        .map(Expr::Column)
        .collect();

    let projection = Projection::try_new(all_columns, Arc::new(plan))?;
    Ok(LogicalPlan::Projection(projection))
}
/// Extracts the federated table source behind a logical `TableSource`.
///
/// Returns `Ok(None)` when the underlying provider is not a
/// [`FederatedTableProviderAdaptor`].
///
/// # Errors
///
/// Propagates the error from `source_as_provider` when `source` cannot be
/// converted to a table provider.
pub fn get_table_source(
    source: &Arc<dyn TableSource>,
) -> Result<Option<Arc<dyn FederatedTableSource>>> {
    let provider = source_as_provider(source)?;
    let federated_source = provider
        .as_any()
        .downcast_ref::<FederatedTableProviderAdaptor>()
        .map(|adaptor| Arc::clone(&adaptor.source));
    Ok(federated_source)
}