datafusion_federation/optimizer/
mod.rs1mod scan_result;
2
3use std::sync::Arc;
4
5use datafusion::{
6 common::not_impl_err,
7 common::tree_node::{Transformed, TreeNode, TreeNodeRecursion},
8 datasource::source_as_provider,
9 error::Result,
10 logical_expr::{Expr, Extension, LogicalPlan, Projection, TableScan, TableSource},
11 optimizer::optimizer::{Optimizer, OptimizerConfig, OptimizerRule},
12};
13
14use crate::{
15 FederatedTableProviderAdaptor, FederatedTableSource, FederationProvider, FederationProviderRef,
16};
17
18use scan_result::ScanResult;
19
/// Logical optimizer rule that hands every sub-plan whose tables all belong to
/// a single federation provider over to that provider's own [`Optimizer`].
#[derive(Debug, Default)]
pub struct FederationOptimizerRule {}
27
28impl OptimizerRule for FederationOptimizerRule {
29 fn rewrite(
35 &self,
36 plan: LogicalPlan,
37 config: &dyn OptimizerConfig,
38 ) -> Result<Transformed<LogicalPlan>> {
39 match self.optimize_plan_recursively(&plan, true, config)? {
40 (Some(optimized_plan), _) => Ok(Transformed::yes(optimized_plan)),
41 (None, _) => Ok(Transformed::no(plan)),
42 }
43 }
44
45 fn supports_rewrite(&self) -> bool {
47 true
48 }
49
50 fn name(&self) -> &str {
52 "federation_optimizer_rule"
53 }
54}
55
56impl FederationOptimizerRule {
57 pub fn new() -> Self {
59 Self::default()
60 }
61
62 fn scan_plan_recursively(&self, plan: &LogicalPlan) -> Result<ScanResult> {
64 let mut sole_provider: ScanResult = ScanResult::None;
65
66 plan.apply(&mut |p: &LogicalPlan| -> Result<TreeNodeRecursion> {
67 let exprs_provider = self.scan_plan_exprs(p)?;
68 sole_provider.merge(exprs_provider);
69
70 if sole_provider.is_ambiguous() {
71 return Ok(TreeNodeRecursion::Stop);
72 }
73
74 let sub_provider = get_leaf_provider(p)?;
75 sole_provider.add(sub_provider);
76
77 Ok(sole_provider.check_recursion())
78 })?;
79
80 Ok(sole_provider)
81 }
82
83 fn scan_plan_exprs(&self, plan: &LogicalPlan) -> Result<ScanResult> {
85 let mut sole_provider: ScanResult = ScanResult::None;
86
87 let exprs = plan.expressions();
88 for expr in &exprs {
89 let expr_result = self.scan_expr_recursively(expr)?;
90 sole_provider.merge(expr_result);
91
92 if sole_provider.is_ambiguous() {
93 return Ok(sole_provider);
94 }
95 }
96
97 Ok(sole_provider)
98 }
99
100 fn scan_expr_recursively(&self, expr: &Expr) -> Result<ScanResult> {
102 let mut sole_provider: ScanResult = ScanResult::None;
103
104 expr.apply(&mut |e: &Expr| -> Result<TreeNodeRecursion> {
105 match e {
107 Expr::ScalarSubquery(ref subquery) => {
108 let plan_result = self.scan_plan_recursively(&subquery.subquery)?;
109
110 sole_provider.merge(plan_result);
111 Ok(sole_provider.check_recursion())
112 }
113 Expr::InSubquery(_) => not_impl_err!("InSubquery"),
114 Expr::OuterReferenceColumn(..) => {
115 sole_provider = ScanResult::Ambiguous;
119 Ok(TreeNodeRecursion::Stop)
120 }
121 _ => Ok(TreeNodeRecursion::Continue),
122 }
123 })?;
124
125 Ok(sole_provider)
126 }
127
128 fn optimize_plan_recursively(
135 &self,
136 plan: &LogicalPlan,
137 is_root: bool,
138 _config: &dyn OptimizerConfig,
139 ) -> Result<(Option<LogicalPlan>, ScanResult)> {
140 let mut sole_provider: ScanResult = ScanResult::None;
141
142 if let LogicalPlan::Extension(Extension { ref node }) = plan {
143 if node.name() == "Federated" {
144 return Ok((None, ScanResult::Ambiguous));
146 }
147 }
148
149 let leaf_provider = get_leaf_provider(plan)?;
151
152 let exprs_result = self.scan_plan_exprs(plan)?;
154 let optimize_expressions = exprs_result.is_some();
155
156 if leaf_provider.is_some() && (exprs_result.is_none() || exprs_result == leaf_provider) {
158 return Ok((None, leaf_provider.into()));
159 }
160 sole_provider.add(leaf_provider);
162 sole_provider.merge(exprs_result);
163
164 let inputs = plan.inputs();
165 if inputs.is_empty() && sole_provider.is_none() {
167 return Ok((None, ScanResult::None));
168 }
169
170 let input_results = inputs
172 .iter()
173 .map(|i| self.optimize_plan_recursively(i, false, _config))
174 .collect::<Result<Vec<_>>>()?;
175
176 input_results.iter().for_each(|(_, scan_result)| {
178 sole_provider.merge(scan_result.clone());
179 });
180
181 if sole_provider.is_none() {
182 return Ok((None, ScanResult::None));
185 }
186
187 if let ScanResult::Distinct(provider) = sole_provider {
189 if !is_root {
190 return Ok((None, ScanResult::Distinct(provider)));
192 }
193
194 let Some(optimizer) = provider.optimizer() else {
195 return Ok((None, ScanResult::None));
197 };
198
199 let optimized = optimizer.optimize(plan.clone(), _config, |_, _| {})?;
201 return Ok((Some(optimized), ScanResult::None));
202 }
203
204 let new_inputs = input_results
210 .into_iter()
211 .enumerate()
212 .map(|(i, (input_plan, input_result))| {
213 if let Some(federated_plan) = input_plan {
214 return Ok(federated_plan);
216 }
217
218 let original_input = (*inputs.get(i).unwrap()).clone();
219 if input_result.is_ambiguous() {
220 return Ok(original_input);
223 }
224
225 let provider = input_result.unwrap();
226 let Some(provider) = provider else {
227 return Ok(original_input);
229 };
230
231 let Some(optimizer) = provider.optimizer() else {
232 return Ok(original_input);
234 };
235
236 let wrapped = wrap_projection(original_input)?;
238 let optimized = optimizer.optimize(wrapped, _config, |_, _| {})?;
239
240 Ok(optimized)
241 })
242 .collect::<Result<Vec<_>>>()?;
243
244 let new_expressions = if optimize_expressions {
246 self.optimize_plan_exprs(plan, _config)?
247 } else {
248 plan.expressions()
249 };
250
251 let new_plan = plan.with_new_exprs(new_expressions, new_inputs)?;
253
254 Ok((Some(new_plan), ScanResult::Ambiguous))
256 }
257
258 fn optimize_plan_exprs(
260 &self,
261 plan: &LogicalPlan,
262 _config: &dyn OptimizerConfig,
263 ) -> Result<Vec<Expr>> {
264 plan.expressions()
265 .iter()
266 .map(|expr| {
267 let transformed = expr
268 .clone()
269 .transform(&|e| self.optimize_expr_recursively(e, _config))?;
270 Ok(transformed.data)
271 })
272 .collect::<Result<Vec<_>>>()
273 }
274
275 fn optimize_expr_recursively(
278 &self,
279 expr: Expr,
280 _config: &dyn OptimizerConfig,
281 ) -> Result<Transformed<Expr>> {
282 match expr {
283 Expr::ScalarSubquery(ref subquery) => {
284 let (new_subquery, _) =
286 self.optimize_plan_recursively(&subquery.subquery, true, _config)?;
287 let Some(new_subquery) = new_subquery else {
288 return Ok(Transformed::no(expr));
289 };
290 Ok(Transformed::yes(Expr::ScalarSubquery(
291 subquery.with_plan(new_subquery.into()),
292 )))
293 }
294 Expr::InSubquery(_) => not_impl_err!("InSubquery"),
295 _ => Ok(Transformed::no(expr)),
296 }
297 }
298}
299
300struct NopFederationProvider {}
303
304impl FederationProvider for NopFederationProvider {
305 fn name(&self) -> &str {
306 "nop"
307 }
308
309 fn compute_context(&self) -> Option<String> {
310 None
311 }
312
313 fn optimizer(&self) -> Option<Arc<Optimizer>> {
314 None
315 }
316}
317
318fn get_leaf_provider(plan: &LogicalPlan) -> Result<Option<FederationProviderRef>> {
319 match plan {
320 LogicalPlan::TableScan(TableScan { ref source, .. }) => {
321 let Some(federated_source) = get_table_source(source)? else {
322 return Ok(Some(Arc::new(NopFederationProvider {})));
325 };
326 let provider = federated_source.federation_provider();
327 Ok(Some(provider))
328 }
329 _ => Ok(None),
330 }
331}
332
333fn wrap_projection(plan: LogicalPlan) -> Result<LogicalPlan> {
334 match plan {
336 LogicalPlan::Projection(_) => Ok(plan),
337 _ => {
338 let expr = plan
339 .schema()
340 .columns()
341 .iter()
342 .map(|c| Expr::Column(c.clone()))
343 .collect::<Vec<Expr>>();
344 Ok(LogicalPlan::Projection(Projection::try_new(
345 expr,
346 Arc::new(plan),
347 )?))
348 }
349 }
350}
351
352pub fn get_table_source(
353 source: &Arc<dyn TableSource>,
354) -> Result<Option<Arc<dyn FederatedTableSource>>> {
355 let source = source_as_provider(source)?;
357
358 let Some(wrapper) = source
360 .as_any()
361 .downcast_ref::<FederatedTableProviderAdaptor>()
362 else {
363 return Ok(None);
364 };
365
366 Ok(Some(Arc::clone(&wrapper.source)))
368}