datafusion_materialized_views/rewrite/
exploitation.rs1use std::collections::HashMap;
47use std::{collections::HashSet, sync::Arc};
48
49use async_trait::async_trait;
50use datafusion::catalog::TableProvider;
51use datafusion::datasource::provider_as_source;
52use datafusion::execution::context::SessionState;
53use datafusion::execution::{SendableRecordBatchStream, TaskContext};
54use datafusion::physical_expr::{PhysicalSortExpr, PhysicalSortRequirement};
55use datafusion::physical_expr_common::sort_expr::format_physical_sort_requirement_list;
56use datafusion::physical_optimizer::PhysicalOptimizerRule;
57use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
58use datafusion::physical_planner::{ExtensionPlanner, PhysicalPlanner};
59use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter};
60use datafusion_common::{DataFusionError, Result, TableReference};
61use datafusion_expr::{Extension, LogicalPlan, UserDefinedLogicalNode, UserDefinedLogicalNodeCore};
62use datafusion_optimizer::OptimizerRule;
63use datafusion_physical_expr::OrderingRequirements;
64use itertools::Itertools;
65use ordered_float::OrderedFloat;
66
67use crate::materialized::cast_to_materialized;
68
69use super::normal_form::SpjNormalForm;
70use super::QueryRewriteOptions;
71
72pub type CostFn = Arc<
74 dyn for<'a> Fn(Box<dyn Iterator<Item = &'a dyn ExecutionPlan> + 'a>) -> Vec<f64> + Send + Sync,
75>;
76
77#[derive(Debug)]
79pub struct ViewMatcher {
80 mv_plans: HashMap<TableReference, (Arc<dyn TableProvider>, SpjNormalForm)>,
81}
82
83impl ViewMatcher {
84 pub async fn try_new_from_state(session_state: &SessionState) -> Result<Self> {
86 let mut mv_plans: HashMap<TableReference, _> = HashMap::default();
87 for (resolved_table_ref, table) in
88 super::util::list_tables(session_state.catalog_list().as_ref()).await?
89 {
90 let Some(mv) = cast_to_materialized(table.as_ref())? else {
91 continue;
92 };
93
94 if !mv.config().use_in_query_rewrite {
95 continue;
96 }
97
98 let analyzed_plan = session_state.analyzer().execute_and_check(
100 mv.query(),
101 session_state.config_options(),
102 |_, _| {},
103 )?;
104
105 match SpjNormalForm::new(&analyzed_plan) {
106 Err(e) => {
107 log::trace!("can't support view matching for {resolved_table_ref}: {e}")
108 }
109 Ok(normal_form) => {
110 mv_plans.insert(resolved_table_ref.clone().into(), (table, normal_form));
111 }
112 }
113 }
114
115 Ok(ViewMatcher { mv_plans })
116 }
117}
118
119impl OptimizerRule for ViewMatcher {
120 fn rewrite(
121 &self,
122 plan: LogicalPlan,
123 config: &dyn datafusion_optimizer::OptimizerConfig,
124 ) -> Result<Transformed<LogicalPlan>> {
125 if !config
126 .options()
127 .extensions
128 .get::<QueryRewriteOptions>()
129 .cloned()
130 .unwrap_or_default()
131 .enabled
132 {
133 return Ok(Transformed::no(plan));
134 }
135
136 plan.rewrite(&mut ViewMatchingRewriter { parent: self })
137 }
138
139 fn supports_rewrite(&self) -> bool {
140 true
141 }
142
143 fn name(&self) -> &str {
144 "view_matcher"
145 }
146}
147
148struct ViewMatchingRewriter<'a> {
150 parent: &'a ViewMatcher,
151}
152
153impl TreeNodeRewriter for ViewMatchingRewriter<'_> {
154 type Node = LogicalPlan;
155
156 fn f_down(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
157 if matches!(&node, LogicalPlan::Extension(ext) if ext.node.as_any().is::<OneOf>()) {
158 return Ok(Transformed::new(node, false, TreeNodeRecursion::Jump));
159 }
160
161 let table_reference = match locate_single_table_scan(&node)? {
162 None => return Ok(Transformed::no(node)),
163 Some(table_reference) => table_reference,
164 };
165
166 log::trace!("rewriting node: {}", node.display());
167
168 let form = match SpjNormalForm::new(&node) {
169 Err(e) => {
170 log::trace!(
171 "couldn't generate rewrites: for {table_reference}, recursing deeper: {e}"
172 );
173 return Ok(Transformed::no(node));
174 }
175 Ok(form) => form,
176 };
177
178 let candidates = self
180 .parent
181 .mv_plans
182 .iter()
183 .filter_map(|(table_ref, (table, plan))| {
184 plan.referenced_tables()
186 .contains(&table_reference)
187 .then(|| {
188 form.rewrite_from(
189 plan,
190 table_ref.clone(),
191 provider_as_source(Arc::clone(table)),
192 )
193 .transpose()
194 })
195 .flatten()
196 })
197 .flat_map(|res| match res {
198 Err(e) => {
199 log::trace!("error rewriting: {e}");
200 None
201 }
202 Ok(plan) => Some(plan),
203 })
204 .collect::<Vec<_>>();
205
206 if candidates.is_empty() {
207 log::trace!("no candidates");
208 Ok(Transformed::no(node))
209 } else {
210 Ok(Transformed::new(
211 LogicalPlan::Extension(Extension {
212 node: Arc::new(OneOf {
213 branches: Some(node).into_iter().chain(candidates).collect_vec(),
214 }),
215 }),
216 true,
217 TreeNodeRecursion::Jump,
218 ))
219 }
220 }
221}
222
223fn locate_single_table_scan(node: &LogicalPlan) -> Result<Option<TableReference>> {
224 let mut table_reference = None;
225 node.apply(|plan| {
226 if let LogicalPlan::TableScan(scan) = plan {
227 match table_reference {
228 Some(_) => {
229 table_reference = None;
232 return Ok(TreeNodeRecursion::Stop);
233 }
234 None => table_reference = Some(scan.table_name.clone()),
235 }
236 }
237 Ok(TreeNodeRecursion::Continue)
238 })?;
239
240 Ok(table_reference)
242}
243
244pub struct ViewExploitationPlanner {
246 cost: CostFn,
247}
248
249impl ViewExploitationPlanner {
250 pub fn new(cost: CostFn) -> Self {
252 Self { cost }
253 }
254}
255
256#[async_trait]
257impl ExtensionPlanner for ViewExploitationPlanner {
258 async fn plan_extension(
260 &self,
261 _planner: &dyn PhysicalPlanner,
262 node: &dyn UserDefinedLogicalNode,
263 logical_inputs: &[&LogicalPlan],
264 physical_inputs: &[Arc<dyn ExecutionPlan>],
265 _session_state: &SessionState,
266 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
267 if node.as_any().downcast_ref::<OneOf>().is_none() {
268 return Ok(None);
269 }
270
271 if logical_inputs
272 .iter()
273 .map(|plan| plan.schema())
274 .any(|schema| schema != logical_inputs[0].schema())
275 {
276 return Err(DataFusionError::Plan(
277 "candidate logical plans should have the same schema".to_string(),
278 ));
279 }
280
281 if physical_inputs
282 .iter()
283 .map(|plan| plan.schema())
284 .any(|schema| schema != physical_inputs[0].schema())
285 {
286 return Err(DataFusionError::Plan(
287 "candidate physical plans should have the same schema".to_string(),
288 ));
289 }
290
291 Ok(Some(Arc::new(OneOfExec::try_new(
292 physical_inputs.to_vec(),
293 None,
294 Arc::clone(&self.cost),
295 )?)))
296 }
297}
298
299#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Hash)]
302pub struct OneOf {
303 branches: Vec<LogicalPlan>,
304}
305
306impl UserDefinedLogicalNodeCore for OneOf {
307 fn name(&self) -> &str {
308 "OneOf"
309 }
310
311 fn inputs(&self) -> Vec<&LogicalPlan> {
312 self.branches
313 .iter()
314 .sorted_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
315 .collect_vec()
316 }
317
318 fn schema(&self) -> &datafusion_common::DFSchemaRef {
319 self.branches[0].schema()
320 }
321
322 fn expressions(&self) -> Vec<datafusion::prelude::Expr> {
323 vec![]
324 }
325
326 fn prevent_predicate_push_down_columns(&self) -> std::collections::HashSet<String> {
327 HashSet::new() }
329
330 fn fmt_for_explain(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
331 write!(f, "OneOf")
332 }
333
334 fn with_exprs_and_inputs(
335 &self,
336 _exprs: Vec<datafusion::prelude::Expr>,
337 inputs: Vec<LogicalPlan>,
338 ) -> Result<Self> {
339 Ok(Self { branches: inputs })
340 }
341}
342
343#[derive(Clone)]
345pub struct OneOfExec {
346 candidates: Vec<Arc<dyn ExecutionPlan>>,
347 required_input_ordering: Option<OrderingRequirements>,
351 best: usize,
353 cost: CostFn,
355}
356
357impl std::fmt::Debug for OneOfExec {
358 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
359 f.debug_struct("OneOfExec")
360 .field("candidates", &self.candidates)
361 .field("required_input_ordering", &self.required_input_ordering)
362 .field("best", &self.best)
363 .finish_non_exhaustive()
364 }
365}
366
367impl OneOfExec {
368 pub fn try_new(
370 candidates: Vec<Arc<dyn ExecutionPlan>>,
371 required_input_ordering: Option<OrderingRequirements>,
372 cost: CostFn,
373 ) -> Result<Self> {
374 if candidates.is_empty() {
375 return Err(DataFusionError::Plan(
376 "can't create OneOfExec with empty children".to_string(),
377 ));
378 }
379
380 let best = cost(Box::new(candidates.iter().map(|c| c.as_ref())))
381 .iter()
382 .position_min_by_key(|&cost| OrderedFloat(*cost))
383 .unwrap();
384
385 Ok(Self {
386 candidates,
387 required_input_ordering,
388 best,
389 cost,
390 })
391 }
392
393 pub fn best(&self) -> Arc<dyn ExecutionPlan> {
396 Arc::clone(&self.candidates[self.best])
397 }
398
399 pub fn with_required_input_ordering(self, requirement: Option<OrderingRequirements>) -> Self {
402 Self {
403 required_input_ordering: requirement,
404 ..self
405 }
406 }
407}
408
409impl ExecutionPlan for OneOfExec {
410 fn name(&self) -> &str {
411 "OneOfExec"
412 }
413
414 fn as_any(&self) -> &dyn std::any::Any {
415 self
416 }
417
418 fn properties(&self) -> &PlanProperties {
419 self.candidates[self.best].properties()
420 }
421
422 fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
423 vec![self.required_input_ordering.clone(); self.children().len()]
424 }
425
426 fn maintains_input_order(&self) -> Vec<bool> {
427 vec![true; self.candidates.len()]
428 }
429
430 fn benefits_from_input_partitioning(&self) -> Vec<bool> {
431 vec![false; self.candidates.len()]
432 }
433
434 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
435 self.candidates.iter().collect_vec()
436 }
437
438 fn with_new_children(
439 self: Arc<Self>,
440 children: Vec<Arc<dyn ExecutionPlan>>,
441 ) -> Result<Arc<dyn ExecutionPlan>> {
442 if children.len() == 1 {
443 return Ok(Arc::clone(&children[0]));
444 }
445
446 Ok(Arc::new(Self::try_new(
447 children,
448 self.required_input_ordering.clone(),
449 Arc::clone(&self.cost),
450 )?))
451 }
452
453 fn supports_limit_pushdown(&self) -> bool {
454 true
455 }
456
457 fn execute(
458 &self,
459 partition: usize,
460 context: Arc<TaskContext>,
461 ) -> Result<SendableRecordBatchStream> {
462 self.candidates[self.best].execute(partition, context)
463 }
464
465 fn statistics(&self) -> Result<datafusion_common::Statistics> {
466 self.candidates[self.best].partition_statistics(None)
467 }
468
469 fn partition_statistics(
470 &self,
471 partition: Option<usize>,
472 ) -> Result<datafusion_common::Statistics> {
473 self.candidates[self.best].partition_statistics(partition)
474 }
475}
476
477impl DisplayAs for OneOfExec {
478 fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
479 let costs = (self.cost)(Box::new(self.children().iter().map(|arc| arc.as_ref())));
480 match t {
481 DisplayFormatType::Default | DisplayFormatType::Verbose => {
482 write!(
483 f,
484 "OneOfExec(best={}), costs=[{}], required_input_ordering=[{}]",
485 self.best,
486 costs.into_iter().join(","),
487 format_physical_sort_requirement_list(
488 &self
489 .required_input_ordering
490 .as_ref()
491 .map(|req| {
492 req.clone()
493 .into_single()
494 .into_iter()
495 .map(PhysicalSortExpr::from)
496 .map(PhysicalSortRequirement::from)
497 .collect_vec()
498 })
499 .unwrap_or_default(),
500 )
501 )
502 }
503 DisplayFormatType::TreeRender => Ok(()),
504 }
505 }
506}
507
/// Physical optimizer rule that replaces [`OneOfExec`] nodes with their
/// cheapest candidate.
#[derive(Debug, Clone, Default)]
pub struct PruneCandidates;
511
512impl PhysicalOptimizerRule for PruneCandidates {
513 fn optimize(
514 &self,
515 plan: Arc<dyn ExecutionPlan>,
516 _config: &datafusion::config::ConfigOptions,
517 ) -> Result<Arc<dyn ExecutionPlan>> {
518 plan.transform(&|plan: Arc<dyn ExecutionPlan>| {
520 if let Some(one_of_exec) = plan.as_any().downcast_ref::<OneOfExec>() {
521 Ok(Transformed::new(
522 one_of_exec.best(),
523 true,
524 TreeNodeRecursion::Jump,
525 ))
526 } else {
527 Ok(Transformed::no(plan))
528 }
529 })
530 .map(|t| t.data)
531 }
532
533 fn name(&self) -> &str {
534 "PruneCandidates"
535 }
536
537 fn schema_check(&self) -> bool {
538 true
539 }
540}