// datafusion_materialized_views/rewrite/exploitation.rs

use std::collections::HashMap;
19use std::{collections::HashSet, sync::Arc};
20
21use async_trait::async_trait;
22use datafusion::catalog::TableProvider;
23use datafusion::datasource::provider_as_source;
24use datafusion::execution::context::SessionState;
25use datafusion::execution::{SendableRecordBatchStream, TaskContext};
26use datafusion::physical_expr::{LexRequirement, PhysicalSortExpr, PhysicalSortRequirement};
27use datafusion::physical_expr_common::sort_expr::format_physical_sort_requirement_list;
28use datafusion::physical_optimizer::PhysicalOptimizerRule;
29use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
30use datafusion::physical_planner::{ExtensionPlanner, PhysicalPlanner};
31use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter};
32use datafusion_common::{DataFusionError, Result, TableReference};
33use datafusion_expr::{Extension, LogicalPlan, UserDefinedLogicalNode, UserDefinedLogicalNodeCore};
34use datafusion_optimizer::OptimizerRule;
35use itertools::Itertools;
36use ordered_float::OrderedFloat;
37
38use crate::materialized::cast_to_materialized;
39
40use super::normal_form::SpjNormalForm;
41use super::QueryRewriteOptions;
42
/// Cost function used to rank candidate physical plans; lower cost is better.
pub type CostFn = Arc<dyn Fn(&dyn ExecutionPlan) -> f64 + Send + Sync>;
45
/// Optimizer rule that matches sub-plans of a query against the normal forms
/// of known materialized views, so equivalent view scans can be substituted.
#[derive(Debug)]
pub struct ViewMatcher {
    // Resolved view name -> (the view's table provider, the SPJ normal form
    // of its defining query). Only views whose query normalizes are stored.
    mv_plans: HashMap<TableReference, (Arc<dyn TableProvider>, SpjNormalForm)>,
}
51
52impl ViewMatcher {
53 pub async fn try_new_from_state(session_state: &SessionState) -> Result<Self> {
55 let mut mv_plans: HashMap<TableReference, _> = HashMap::default();
56 for (resolved_table_ref, table) in
57 super::util::list_tables(session_state.catalog_list().as_ref()).await?
58 {
59 let Some(mv) = cast_to_materialized(table.as_ref()) else {
60 continue;
61 };
62
63 let analyzed_plan = session_state.analyzer().execute_and_check(
65 mv.query(),
66 session_state.config_options(),
67 |_, _| {},
68 )?;
69
70 match SpjNormalForm::new(&analyzed_plan) {
71 Err(e) => {
72 log::trace!("can't support view matching for {resolved_table_ref}: {e}")
73 }
74 Ok(normal_form) => {
75 mv_plans.insert(resolved_table_ref.clone().into(), (table, normal_form));
76 }
77 }
78 }
79
80 Ok(ViewMatcher { mv_plans })
81 }
82}
83
84impl OptimizerRule for ViewMatcher {
85 fn rewrite(
86 &self,
87 plan: LogicalPlan,
88 config: &dyn datafusion_optimizer::OptimizerConfig,
89 ) -> Result<Transformed<LogicalPlan>> {
90 if !config
91 .options()
92 .extensions
93 .get::<QueryRewriteOptions>()
94 .cloned()
95 .unwrap_or_default()
96 .enabled
97 {
98 return Ok(Transformed::no(plan));
99 }
100
101 plan.rewrite(&mut ViewMatchingRewriter { parent: self })
102 }
103
104 fn supports_rewrite(&self) -> bool {
105 true
106 }
107
108 fn name(&self) -> &str {
109 "view_matcher"
110 }
111}
112
/// [`TreeNodeRewriter`] that walks a logical plan top-down and wraps
/// rewritable single-table sub-plans in a [`OneOf`] node containing the
/// original plan plus all view-based alternatives found in the parent
/// [`ViewMatcher`].
struct ViewMatchingRewriter<'a> {
    // The matcher holding the catalog of materialized-view normal forms.
    parent: &'a ViewMatcher,
}
117
impl TreeNodeRewriter for ViewMatchingRewriter<'_> {
    type Node = LogicalPlan;

    /// Top-down visit: when the sub-plan rooted at `node` scans exactly one
    /// table and normalizes to SPJ form, collect every materialized-view
    /// rewrite of it and bundle the original plus alternatives into a
    /// [`OneOf`] extension node for later cost-based selection.
    fn f_down(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
        // Skip subtrees we already wrapped (`Jump`) so we never rewrite our
        // own output.
        if matches!(&node, LogicalPlan::Extension(ext) if ext.node.as_any().is::<OneOf>()) {
            return Ok(Transformed::new(node, false, TreeNodeRecursion::Jump));
        }

        // Only sub-plans over exactly one table are candidates; otherwise
        // keep recursing into the children.
        let table_reference = match locate_single_table_scan(&node)? {
            None => return Ok(Transformed::no(node)),
            Some(table_reference) => table_reference,
        };

        log::trace!("rewriting node: {}", node.display());

        // If this sub-plan can't be normalized, descend further — a child
        // sub-plan might still be rewritable.
        let form = match SpjNormalForm::new(&node) {
            Err(e) => {
                log::trace!(
                    "couldn't generate rewrites: for {table_reference}, recursing deeper: {e}"
                );
                return Ok(Transformed::no(node));
            }
            Ok(form) => form,
        };

        // Try every materialized view that references the same table.
        // `rewrite_from` yields Result<Option<LogicalPlan>>: `transpose` +
        // `flatten` drop the views that produced no rewrite, and the
        // `flat_map` below logs and discards rewrites that errored.
        let candidates = self
            .parent
            .mv_plans
            .iter()
            .filter_map(|(table_ref, (table, plan))| {
                plan.referenced_tables()
                    .contains(&table_reference)
                    .then(|| {
                        form.rewrite_from(
                            plan,
                            table_ref.clone(),
                            provider_as_source(Arc::clone(table)),
                        )
                        .transpose()
                    })
                    .flatten()
            })
            .flat_map(|res| match res {
                Err(e) => {
                    log::trace!("error rewriting: {e}");
                    None
                }
                Ok(plan) => Some(plan),
            })
            .collect::<Vec<_>>();

        if candidates.is_empty() {
            log::trace!("no candidates");
            Ok(Transformed::no(node))
        } else {
            // The original plan becomes branch 0; `Jump` prevents the walker
            // from descending into the branches we just created.
            Ok(Transformed::new(
                LogicalPlan::Extension(Extension {
                    node: Arc::new(OneOf {
                        branches: Some(node).into_iter().chain(candidates).collect_vec(),
                    }),
                }),
                true,
                TreeNodeRecursion::Jump,
            ))
        }
    }
}
187
188fn locate_single_table_scan(node: &LogicalPlan) -> Result<Option<TableReference>> {
189 let mut table_reference = None;
190 node.apply(|plan| {
191 if let LogicalPlan::TableScan(scan) = plan {
192 match table_reference {
193 Some(_) => {
194 table_reference = None;
197 return Ok(TreeNodeRecursion::Stop);
198 }
199 None => table_reference = Some(scan.table_name.clone()),
200 }
201 }
202 Ok(TreeNodeRecursion::Continue)
203 })?;
204
205 Ok(table_reference)
207}
208
/// Physical planner extension that lowers [`OneOf`] logical nodes into
/// [`OneOfExec`] physical nodes, ranking branches with a user-supplied
/// cost function.
pub struct ViewExploitationPlanner {
    // Cost function used to pick the cheapest candidate plan.
    cost: CostFn,
}
213
214impl ViewExploitationPlanner {
215 pub fn new(cost: CostFn) -> Self {
217 Self { cost }
218 }
219}
220
221#[async_trait]
222impl ExtensionPlanner for ViewExploitationPlanner {
223 async fn plan_extension(
225 &self,
226 _planner: &dyn PhysicalPlanner,
227 node: &dyn UserDefinedLogicalNode,
228 logical_inputs: &[&LogicalPlan],
229 physical_inputs: &[Arc<dyn ExecutionPlan>],
230 _session_state: &SessionState,
231 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
232 if node.as_any().downcast_ref::<OneOf>().is_none() {
233 return Ok(None);
234 }
235
236 if logical_inputs
237 .iter()
238 .map(|plan| plan.schema())
239 .any(|schema| schema != logical_inputs[0].schema())
240 {
241 return Err(DataFusionError::Plan(
242 "candidate logical plans should have the same schema".to_string(),
243 ));
244 }
245
246 if physical_inputs
247 .iter()
248 .map(|plan| plan.schema())
249 .any(|schema| schema != physical_inputs[0].schema())
250 {
251 return Err(DataFusionError::Plan(
252 "candidate physical plans should have the same schema".to_string(),
253 ));
254 }
255
256 Ok(Some(Arc::new(OneOfExec::try_new(
257 physical_inputs.to_vec(),
258 None,
259 Arc::clone(&self.cost),
260 )?)))
261 }
262}
263
/// Logical extension node holding several semantically-equivalent plans
/// ("branches"); exactly one of them will ultimately be executed.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Hash)]
pub struct OneOf {
    // Equivalent candidate plans. As built by `ViewMatchingRewriter`, the
    // first branch is the original (unrewritten) sub-plan.
    branches: Vec<LogicalPlan>,
}
270
271impl UserDefinedLogicalNodeCore for OneOf {
272 fn name(&self) -> &str {
273 "OneOf"
274 }
275
276 fn inputs(&self) -> Vec<&LogicalPlan> {
277 self.branches.iter().collect_vec()
278 }
279
280 fn schema(&self) -> &datafusion_common::DFSchemaRef {
281 self.branches[0].schema()
282 }
283
284 fn expressions(&self) -> Vec<datafusion::prelude::Expr> {
285 vec![]
286 }
287
288 fn prevent_predicate_push_down_columns(&self) -> std::collections::HashSet<String> {
289 HashSet::new() }
291
292 fn fmt_for_explain(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
293 write!(f, "OneOf")
294 }
295
296 fn from_template(&self, _exprs: &[datafusion::prelude::Expr], inputs: &[LogicalPlan]) -> Self {
297 Self {
298 branches: inputs.to_vec(),
299 }
300 }
301
302 fn with_exprs_and_inputs(
303 &self,
304 _exprs: Vec<datafusion::prelude::Expr>,
305 inputs: Vec<LogicalPlan>,
306 ) -> Result<Self> {
307 Ok(Self { branches: inputs })
308 }
309}
310
/// Physical counterpart of [`OneOf`]: holds several equivalent candidate
/// execution plans and delegates execution to the cheapest one.
#[derive(Clone)]
pub struct OneOfExec {
    // Equivalent candidate plans; must be non-empty (enforced in `try_new`).
    candidates: Vec<Arc<dyn ExecutionPlan>>,
    // Ordering requirement reported for every child; `None` = no requirement.
    required_input_ordering: Option<LexRequirement>,
    // Index into `candidates` of the plan with the lowest cost.
    best: usize,
    // Cost function used to rank candidates; lower cost is better.
    cost: CostFn,
}
324
// Manual Debug impl: `cost` is an opaque closure and can't derive `Debug`,
// so it is omitted via `finish_non_exhaustive`.
impl std::fmt::Debug for OneOfExec {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OneOfExec")
            .field("candidates", &self.candidates)
            .field("required_input_ordering", &self.required_input_ordering)
            .field("best", &self.best)
            .finish_non_exhaustive()
    }
}
334
335impl OneOfExec {
336 pub fn try_new(
338 candidates: Vec<Arc<dyn ExecutionPlan>>,
339 required_input_ordering: Option<LexRequirement>,
340 cost: CostFn,
341 ) -> Result<Self> {
342 if candidates.is_empty() {
343 return Err(DataFusionError::Plan(
344 "can't create OneOfExec with empty children".to_string(),
345 ));
346 }
347 let best = candidates
348 .iter()
349 .position_min_by_key(|candidate| OrderedFloat(cost(candidate.as_ref())))
350 .unwrap();
351
352 Ok(Self {
353 candidates,
354 required_input_ordering,
355 best,
356 cost,
357 })
358 }
359
360 pub fn best(&self) -> Arc<dyn ExecutionPlan> {
363 Arc::clone(&self.candidates[self.best])
364 }
365
366 pub fn with_required_input_ordering(self, requirement: Option<LexRequirement>) -> Self {
369 Self {
370 required_input_ordering: requirement,
371 ..self
372 }
373 }
374}
375
376impl ExecutionPlan for OneOfExec {
377 fn name(&self) -> &str {
378 "OneOfExec"
379 }
380
381 fn as_any(&self) -> &dyn std::any::Any {
382 self
383 }
384
385 fn properties(&self) -> &PlanProperties {
386 self.candidates[self.best].properties()
387 }
388
389 fn required_input_ordering(&self) -> Vec<Option<LexRequirement>> {
390 vec![self.required_input_ordering.clone(); self.children().len()]
391 }
392
393 fn maintains_input_order(&self) -> Vec<bool> {
394 vec![true; self.candidates.len()]
395 }
396
397 fn benefits_from_input_partitioning(&self) -> Vec<bool> {
398 vec![false; self.candidates.len()]
399 }
400
401 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
402 self.candidates.iter().collect_vec()
403 }
404
405 fn with_new_children(
406 self: Arc<Self>,
407 children: Vec<Arc<dyn ExecutionPlan>>,
408 ) -> Result<Arc<dyn ExecutionPlan>> {
409 if children.len() == 1 {
410 return Ok(Arc::clone(&children[0]));
411 }
412
413 Ok(Arc::new(Self::try_new(
414 children,
415 self.required_input_ordering.clone(),
416 Arc::clone(&self.cost),
417 )?))
418 }
419
420 fn execute(
421 &self,
422 partition: usize,
423 context: Arc<TaskContext>,
424 ) -> Result<SendableRecordBatchStream> {
425 self.candidates[self.best].execute(partition, context)
426 }
427
428 fn statistics(&self) -> Result<datafusion_common::Statistics> {
429 self.candidates[self.best].statistics()
430 }
431}
432
433impl DisplayAs for OneOfExec {
434 fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
435 let costs = self
436 .children()
437 .iter()
438 .map(|c| (self.cost)(c.as_ref()))
439 .collect_vec();
440 match t {
441 DisplayFormatType::Default | DisplayFormatType::Verbose => {
442 write!(
443 f,
444 "OneOfExec(best={}), costs=[{}], required_input_ordering=[{}]",
445 self.best,
446 costs.into_iter().join(","),
447 format_physical_sort_requirement_list(
448 &self
449 .required_input_ordering
450 .clone()
451 .unwrap_or_default()
452 .into_iter()
453 .map(PhysicalSortExpr::from)
454 .map(PhysicalSortRequirement::from)
455 .collect_vec()
456 )
457 )
458 }
459 }
460 }
461}
462
/// Physical optimizer rule that collapses every [`OneOfExec`] into its
/// cheapest candidate, removing the alternative plans from the final plan.
#[derive(Debug, Clone, Default)]
pub struct PruneCandidates;
466
467impl PhysicalOptimizerRule for PruneCandidates {
468 fn optimize(
469 &self,
470 plan: Arc<dyn ExecutionPlan>,
471 _config: &datafusion::config::ConfigOptions,
472 ) -> Result<Arc<dyn ExecutionPlan>> {
473 plan.transform(&|plan: Arc<dyn ExecutionPlan>| {
475 if let Some(one_of_exec) = plan.as_any().downcast_ref::<OneOfExec>() {
476 Ok(Transformed::new(
477 one_of_exec.best(),
478 true,
479 TreeNodeRecursion::Jump,
480 ))
481 } else {
482 Ok(Transformed::no(plan))
483 }
484 })
485 .map(|t| t.data)
486 }
487
488 fn name(&self) -> &str {
489 "PruneCandidates"
490 }
491
492 fn schema_check(&self) -> bool {
493 true
494 }
495}