datafusion_federation/optimizer/
mod.rs1mod scan_result;
2
3use std::sync::Arc;
4
5use datafusion::{
6 common::not_impl_err,
7 common::tree_node::{Transformed, TreeNode, TreeNodeRecursion},
8 datasource::source_as_provider,
9 error::Result,
10 logical_expr::{Expr, Extension, LogicalPlan, Projection, TableScan, TableSource},
11 optimizer::optimizer::{Optimizer, OptimizerConfig, OptimizerRule},
12};
13
14use crate::{
15 FederatedTableProviderAdaptor, FederatedTableSource, FederationProvider, FederationProviderRef,
16};
17
18use scan_result::ScanResult;
19
/// DataFusion optimizer rule registered under the name
/// `"federation_optimizer_rule"`.
///
/// The rule carries no state; all of its behavior lives in its
/// `OptimizerRule` implementation and inherent helper methods.
#[derive(Debug, Default)]
pub struct FederationOptimizerRule {}
27
28impl OptimizerRule for FederationOptimizerRule {
29 fn rewrite(
35 &self,
36 plan: LogicalPlan,
37 config: &dyn OptimizerConfig,
38 ) -> Result<Transformed<LogicalPlan>> {
39 match self.optimize_plan_recursively(&plan, true, config)? {
40 (Some(optimized_plan), _) => Ok(Transformed::yes(optimized_plan)),
41 (None, _) => Ok(Transformed::no(plan)),
42 }
43 }
44
45 fn supports_rewrite(&self) -> bool {
47 true
48 }
49
50 fn name(&self) -> &str {
52 "federation_optimizer_rule"
53 }
54}
55
56impl FederationOptimizerRule {
57 pub fn new() -> Self {
59 Self::default()
60 }
61
62 fn scan_plan_recursively(&self, plan: &LogicalPlan) -> Result<ScanResult> {
64 let mut sole_provider: ScanResult = ScanResult::None;
65
66 plan.apply(&mut |p: &LogicalPlan| -> Result<TreeNodeRecursion> {
67 let exprs_provider = self.scan_plan_exprs(p)?;
68 sole_provider.merge(exprs_provider);
69
70 if sole_provider.is_ambiguous() {
71 return Ok(TreeNodeRecursion::Stop);
72 }
73
74 let sub_provider = get_leaf_provider(p)?;
75 sole_provider.add(sub_provider);
76
77 Ok(sole_provider.check_recursion())
78 })?;
79
80 Ok(sole_provider)
81 }
82
83 fn scan_plan_exprs(&self, plan: &LogicalPlan) -> Result<ScanResult> {
85 let mut sole_provider: ScanResult = ScanResult::None;
86
87 let exprs = plan.expressions();
88 for expr in &exprs {
89 let expr_result = self.scan_expr_recursively(expr)?;
90 sole_provider.merge(expr_result);
91
92 if sole_provider.is_ambiguous() {
93 return Ok(sole_provider);
94 }
95 }
96
97 Ok(sole_provider)
98 }
99
100 fn scan_expr_recursively(&self, expr: &Expr) -> Result<ScanResult> {
102 let mut sole_provider: ScanResult = ScanResult::None;
103
104 expr.apply(&mut |e: &Expr| -> Result<TreeNodeRecursion> {
105 match e {
107 Expr::ScalarSubquery(ref subquery) => {
108 let plan_result = self.scan_plan_recursively(&subquery.subquery)?;
109
110 sole_provider.merge(plan_result);
111 Ok(sole_provider.check_recursion())
112 }
113 Expr::InSubquery(_) => not_impl_err!("InSubquery"),
114 Expr::OuterReferenceColumn(..) => {
115 sole_provider = ScanResult::Ambiguous;
119 Ok(TreeNodeRecursion::Stop)
120 }
121 _ => Ok(TreeNodeRecursion::Continue),
122 }
123 })?;
124
125 Ok(sole_provider)
126 }
127
128 fn optimize_plan_recursively(
135 &self,
136 plan: &LogicalPlan,
137 is_root: bool,
138 _config: &dyn OptimizerConfig,
139 ) -> Result<(Option<LogicalPlan>, ScanResult)> {
140 let mut sole_provider: ScanResult = ScanResult::None;
141
142 if let LogicalPlan::Extension(Extension { ref node }) = plan {
143 if node.name() == "Federated" {
144 return Ok((None, ScanResult::Ambiguous));
146 }
147 }
148
149 let leaf_provider = get_leaf_provider(plan)?;
151
152 let exprs_result = self.scan_plan_exprs(plan)?;
154 let optimize_expressions = exprs_result.is_some();
155
156 if leaf_provider.is_some() && (exprs_result.is_none() || exprs_result == leaf_provider) {
158 return Ok((None, leaf_provider.into()));
159 }
160 sole_provider.add(leaf_provider);
162 sole_provider.merge(exprs_result);
163
164 let inputs = plan.inputs();
165 if inputs.is_empty() && sole_provider.is_none() {
167 return Ok((None, ScanResult::None));
168 }
169
170 let input_results = inputs
172 .iter()
173 .map(|i| self.optimize_plan_recursively(i, false, _config))
174 .collect::<Result<Vec<_>>>()?;
175
176 input_results.iter().for_each(|(_, scan_result)| {
178 sole_provider.merge(scan_result.clone());
179 });
180
181 if sole_provider.is_none() {
182 return Ok((None, ScanResult::None));
185 }
186
187 if let ScanResult::Distinct(provider) = sole_provider {
189 if !is_root {
190 return Ok((None, ScanResult::Distinct(provider)));
192 }
193
194 if matches!(plan, LogicalPlan::Analyze(_)) {
199 } else {
201 let Some(optimizer) = provider.optimizer() else {
202 return Ok((None, ScanResult::None));
204 };
205
206 let optimized = optimizer.optimize(plan.clone(), _config, |_, _| {})?;
208 return Ok((Some(optimized), ScanResult::None));
209 }
210 }
211
212 let new_inputs = input_results
218 .into_iter()
219 .enumerate()
220 .map(|(i, (input_plan, input_result))| {
221 if let Some(federated_plan) = input_plan {
222 return Ok(federated_plan);
224 }
225
226 let original_input = (*inputs.get(i).unwrap()).clone();
227 if input_result.is_ambiguous() {
228 return Ok(original_input);
231 }
232
233 let provider = input_result.unwrap();
234 let Some(provider) = provider else {
235 return Ok(original_input);
237 };
238
239 let Some(optimizer) = provider.optimizer() else {
240 return Ok(original_input);
242 };
243
244 let wrapped = wrap_projection(original_input)?;
246 let optimized = optimizer.optimize(wrapped, _config, |_, _| {})?;
247
248 Ok(optimized)
249 })
250 .collect::<Result<Vec<_>>>()?;
251
252 let new_expressions = if optimize_expressions {
254 self.optimize_plan_exprs(plan, _config)?
255 } else {
256 plan.expressions()
257 };
258
259 let new_plan = plan.with_new_exprs(new_expressions, new_inputs)?;
261
262 Ok((Some(new_plan), ScanResult::Ambiguous))
264 }
265
266 fn optimize_plan_exprs(
268 &self,
269 plan: &LogicalPlan,
270 _config: &dyn OptimizerConfig,
271 ) -> Result<Vec<Expr>> {
272 plan.expressions()
273 .iter()
274 .map(|expr| {
275 let transformed = expr
276 .clone()
277 .transform(&|e| self.optimize_expr_recursively(e, _config))?;
278 Ok(transformed.data)
279 })
280 .collect::<Result<Vec<_>>>()
281 }
282
283 fn optimize_expr_recursively(
286 &self,
287 expr: Expr,
288 _config: &dyn OptimizerConfig,
289 ) -> Result<Transformed<Expr>> {
290 match expr {
291 Expr::ScalarSubquery(ref subquery) => {
292 let (new_subquery, _) =
294 self.optimize_plan_recursively(&subquery.subquery, true, _config)?;
295 let Some(new_subquery) = new_subquery else {
296 return Ok(Transformed::no(expr));
297 };
298 Ok(Transformed::yes(Expr::ScalarSubquery(
299 subquery.with_plan(new_subquery.into()),
300 )))
301 }
302 Expr::InSubquery(_) => not_impl_err!("InSubquery"),
303 _ => Ok(Transformed::no(expr)),
304 }
305 }
306}
307
308struct NopFederationProvider {}
311
312impl FederationProvider for NopFederationProvider {
313 fn name(&self) -> &str {
314 "nop"
315 }
316
317 fn compute_context(&self) -> Option<String> {
318 None
319 }
320
321 fn optimizer(&self) -> Option<Arc<Optimizer>> {
322 None
323 }
324}
325
326fn get_leaf_provider(plan: &LogicalPlan) -> Result<Option<FederationProviderRef>> {
327 match plan {
328 LogicalPlan::TableScan(TableScan { ref source, .. }) => {
329 let Some(federated_source) = get_table_source(source)? else {
330 return Ok(Some(Arc::new(NopFederationProvider {})));
333 };
334 let provider = federated_source.federation_provider();
335 Ok(Some(provider))
336 }
337 _ => Ok(None),
338 }
339}
340
341fn wrap_projection(plan: LogicalPlan) -> Result<LogicalPlan> {
342 match plan {
344 LogicalPlan::Projection(_) => Ok(plan),
345 _ => {
346 let expr = plan
347 .schema()
348 .columns()
349 .iter()
350 .map(|c| Expr::Column(c.clone()))
351 .collect::<Vec<Expr>>();
352 Ok(LogicalPlan::Projection(Projection::try_new(
353 expr,
354 Arc::new(plan),
355 )?))
356 }
357 }
358}
359
360pub fn get_table_source(
361 source: &Arc<dyn TableSource>,
362) -> Result<Option<Arc<dyn FederatedTableSource>>> {
363 let source = source_as_provider(source)?;
365
366 let Some(wrapper) = source
368 .as_any()
369 .downcast_ref::<FederatedTableProviderAdaptor>()
370 else {
371 return Ok(None);
372 };
373
374 Ok(Some(Arc::clone(&wrapper.source)))
376}