use std::sync::Arc;
use datafusion::{
config::ConfigOptions,
datasource::source_as_provider,
error::{DataFusionError, Result},
logical_expr::{Expr, LogicalPlan, Projection, TableScan, TableSource},
optimizer::analyzer::AnalyzerRule,
};
use crate::{FederatedTableProviderAdaptor, FederatedTableSource, FederationProviderRef};
/// Analyzer rule that federates a logical plan.
///
/// The rule carries no state; construct it with [`Default::default`].
#[derive(Debug, Default)]
pub struct FederationAnalyzerRule {}
impl AnalyzerRule for FederationAnalyzerRule {
    /// Runs the federation pass over `plan`.
    ///
    /// Returns the rewritten plan when the recursive pass produced one, and
    /// the input plan unchanged otherwise.
    ///
    /// # Errors
    /// Propagates any error raised by the recursive pass.
    fn analyze(&self, plan: LogicalPlan, config: &ConfigOptions) -> Result<LogicalPlan> {
        // The provider half of the result is only meaningful to recursive
        // callers, so it is dropped at the root.
        let (optimized, _) = self.optimize_recursively(&plan, None, config)?;
        // `plan` is owned here: hand it back by move rather than cloning it.
        Ok(optimized.unwrap_or(plan))
    }

    /// Name under which this rule is reported.
    fn name(&self) -> &str {
        "federation_optimizer_rule"
    }
}
impl FederationAnalyzerRule {
pub fn new() -> Self {
Self::default()
}
fn optimize_recursively(
&self,
plan: &LogicalPlan,
parent: Option<&LogicalPlan>,
_config: &ConfigOptions,
) -> Result<(Option<LogicalPlan>, Option<FederationProviderRef>)> {
let sole_provider = self.get_federation_provider(plan)?;
if sole_provider.is_some() {
return Ok((None, sole_provider));
}
let inputs = plan.inputs();
if inputs.is_empty() {
return Ok((None, None));
}
let (new_inputs, providers): (Vec<_>, Vec<_>) = inputs
.iter()
.map(|i| self.optimize_recursively(i, Some(plan), _config))
.collect::<Result<Vec<_>>>()?
.into_iter()
.unzip();
let first_provider = providers.first().unwrap();
let is_singular = providers.iter().all(|p| p.is_some() && p == first_provider);
if is_singular {
if parent.is_none() {
if let Some(provider) = first_provider {
if let Some(optimizer) = provider.analyzer() {
let optimized = optimizer.execute_and_check(plan, _config, |_, _| {})?;
return Ok((Some(optimized), None));
}
return Ok((None, None));
}
return Ok((None, None));
}
return Ok((None, first_provider.clone()));
}
let new_inputs = new_inputs
.into_iter()
.enumerate()
.map(|(i, new_sub_plan)| {
if let Some(sub_plan) = new_sub_plan {
return Ok(sub_plan);
}
let sub_plan = inputs.get(i).unwrap();
if let Some(provider) = providers.get(i).unwrap() {
if let Some(optimizer) = provider.analyzer() {
let wrapped = wrap_projection((*sub_plan).clone())?;
let optimized =
optimizer.execute_and_check(&wrapped, _config, |_, _| {})?;
return Ok(optimized);
}
return Ok((*sub_plan).clone());
}
Ok((*sub_plan).clone())
})
.collect::<Result<Vec<_>>>()?;
let new_plan = plan.with_new_inputs(&new_inputs)?;
Ok((Some(new_plan), None))
}
fn get_federation_provider(&self, plan: &LogicalPlan) -> Result<Option<FederationProviderRef>> {
match plan {
LogicalPlan::TableScan(TableScan { ref source, .. }) => {
let federated_source = get_table_source(source.clone())?;
let provider = federated_source.federation_provider();
Ok(Some(provider))
}
_ => Ok(None),
}
}
}
/// Ensures `plan` has a projection at its root.
///
/// A plan that already is a `Projection` is returned untouched. Any other
/// plan is wrapped in a new `Projection` that selects every field of the
/// plan's schema as a qualified column.
///
/// # Errors
/// Propagates the error of `Projection::try_new`.
fn wrap_projection(plan: LogicalPlan) -> Result<LogicalPlan> {
    if let LogicalPlan::Projection(_) = plan {
        return Ok(plan);
    }

    // Collect the columns first: `plan` is moved into the projection below.
    let columns: Vec<Expr> = plan
        .schema()
        .fields()
        .iter()
        .map(|field| Expr::Column(field.qualified_column()))
        .collect();

    let projection = Projection::try_new(columns, Arc::new(plan))?;
    Ok(LogicalPlan::Projection(projection))
}
/// Extracts the [`FederatedTableSource`] behind a DataFusion [`TableSource`].
///
/// The source is first turned back into its table provider with
/// `source_as_provider`; that provider must be a
/// [`FederatedTableProviderAdaptor`], whose `source` field is returned.
///
/// # Errors
/// * Propagates the error of `source_as_provider`.
/// * Returns [`DataFusionError::Plan`] when the provider is not a
///   [`FederatedTableProviderAdaptor`].
pub fn get_table_source(source: Arc<dyn TableSource>) -> Result<Arc<dyn FederatedTableSource>> {
    let provider = source_as_provider(&source)?;
    let adaptor = provider
        .as_any()
        .downcast_ref::<FederatedTableProviderAdaptor>()
        // Built lazily so the success path does not allocate an error string.
        .ok_or_else(|| {
            DataFusionError::Plan("expected a FederatedTableProviderAdaptor".to_string())
        })?;
    Ok(adaptor.source.clone())
}