1use std::collections::HashMap;
21use std::sync::Arc;
22
23use kapot_core::error::{KapotError, Result};
24use kapot_core::{
25 execution_plans::{ShuffleReaderExec, ShuffleWriterExec, UnresolvedShuffleExec},
26 serde::scheduler::PartitionLocation,
27};
28use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
29use datafusion::physical_plan::repartition::RepartitionExec;
30use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
31use datafusion::physical_plan::windows::WindowAggExec;
32use datafusion::physical_plan::{
33 with_new_children_if_necessary, ExecutionPlan, Partitioning,
34};
35
36use log::{debug, info};
37
38type PartialQueryStageResult = (Arc<dyn ExecutionPlan>, Vec<Arc<ShuffleWriterExec>>);
39
/// Splits a physical execution plan into query stages.
///
/// The planner's only state is the counter from which stage ids are handed out.
#[derive(Debug)]
pub struct DistributedPlanner {
    /// Id most recently handed out by `next_stage_id`; `0` until the first
    /// stage has been created.
    next_stage_id: usize,
}
43
44impl DistributedPlanner {
45 pub fn new() -> Self {
46 Self { next_stage_id: 0 }
47 }
48}
49
50impl Default for DistributedPlanner {
51 fn default() -> Self {
52 Self::new()
53 }
54}
55
56impl DistributedPlanner {
57 pub fn plan_query_stages<'a>(
61 &'a mut self,
62 job_id: &'a str,
63 execution_plan: Arc<dyn ExecutionPlan>,
64 ) -> Result<Vec<Arc<ShuffleWriterExec>>> {
65 info!("planning query stages for job {}", job_id);
66 let (new_plan, mut stages) =
67 self.plan_query_stages_internal(job_id, execution_plan)?;
68 stages.push(create_shuffle_writer(
69 job_id,
70 self.next_stage_id(),
71 new_plan,
72 None,
73 )?);
74 Ok(stages)
75 }
76
77 fn plan_query_stages_internal<'a>(
81 &'a mut self,
82 job_id: &'a str,
83 execution_plan: Arc<dyn ExecutionPlan>,
84 ) -> Result<PartialQueryStageResult> {
85 if execution_plan.children().is_empty() {
87 return Ok((execution_plan, vec![]));
88 }
89
90 let mut stages = vec![];
91 let mut children = vec![];
92 for child in execution_plan.children() {
93 let (new_child, mut child_stages) =
94 self.plan_query_stages_internal(job_id, child.clone())?;
95 children.push(new_child);
96 stages.append(&mut child_stages);
97 }
98
99 if let Some(_coalesce) = execution_plan
100 .as_any()
101 .downcast_ref::<CoalescePartitionsExec>()
102 {
103 let shuffle_writer = create_shuffle_writer(
104 job_id,
105 self.next_stage_id(),
106 children[0].clone(),
107 None,
108 )?;
109 let unresolved_shuffle = create_unresolved_shuffle(&shuffle_writer);
110 stages.push(shuffle_writer);
111 Ok((
112 with_new_children_if_necessary(execution_plan, vec![unresolved_shuffle])?,
113 stages,
114 ))
115 } else if let Some(_sort_preserving_merge) = execution_plan
116 .as_any()
117 .downcast_ref::<SortPreservingMergeExec>(
118 ) {
119 let shuffle_writer = create_shuffle_writer(
120 job_id,
121 self.next_stage_id(),
122 children[0].clone(),
123 None,
124 )?;
125 let unresolved_shuffle = create_unresolved_shuffle(&shuffle_writer);
126 stages.push(shuffle_writer);
127 Ok((
128 with_new_children_if_necessary(execution_plan, vec![unresolved_shuffle])?,
129 stages,
130 ))
131 } else if let Some(repart) =
132 execution_plan.as_any().downcast_ref::<RepartitionExec>()
133 {
134 match repart.properties().output_partitioning() {
135 Partitioning::Hash(_, _) => {
136 let shuffle_writer = create_shuffle_writer(
137 job_id,
138 self.next_stage_id(),
139 children[0].clone(),
140 Some(repart.partitioning().to_owned()),
141 )?;
142 let unresolved_shuffle = create_unresolved_shuffle(&shuffle_writer);
143 stages.push(shuffle_writer);
144 Ok((unresolved_shuffle, stages))
145 }
146 _ => {
147 Ok((children[0].clone(), stages))
149 }
150 }
151 } else if let Some(window) =
152 execution_plan.as_any().downcast_ref::<WindowAggExec>()
153 {
154 Err(KapotError::NotImplemented(format!(
155 "WindowAggExec with window {window:?}"
156 )))
157 } else {
158 Ok((
159 with_new_children_if_necessary(execution_plan, children)?,
160 stages,
161 ))
162 }
163 }
164
165 fn next_stage_id(&mut self) -> usize {
167 self.next_stage_id += 1;
168 self.next_stage_id
169 }
170}
171
172fn create_unresolved_shuffle(
173 shuffle_writer: &ShuffleWriterExec,
174) -> Arc<UnresolvedShuffleExec> {
175 Arc::new(UnresolvedShuffleExec::new(
176 shuffle_writer.stage_id(),
177 shuffle_writer.schema(),
178 shuffle_writer
179 .properties()
180 .output_partitioning()
181 .partition_count(),
182 ))
183}
184
185pub fn find_unresolved_shuffles(
187 plan: &Arc<dyn ExecutionPlan>,
188) -> Result<Vec<UnresolvedShuffleExec>> {
189 if let Some(unresolved_shuffle) =
190 plan.as_any().downcast_ref::<UnresolvedShuffleExec>()
191 {
192 Ok(vec![unresolved_shuffle.clone()])
193 } else {
194 Ok(plan
195 .children()
196 .into_iter()
197 .map(find_unresolved_shuffles)
198 .collect::<Result<Vec<_>>>()?
199 .into_iter()
200 .flatten()
201 .collect())
202 }
203}
204
205pub fn remove_unresolved_shuffles(
206 stage: Arc<dyn ExecutionPlan>,
207 partition_locations: &HashMap<usize, HashMap<usize, Vec<PartitionLocation>>>,
208) -> Result<Arc<dyn ExecutionPlan>> {
209 let mut new_children: Vec<Arc<dyn ExecutionPlan>> = vec![];
210 for child in stage.children() {
211 if let Some(unresolved_shuffle) =
212 child.as_any().downcast_ref::<UnresolvedShuffleExec>()
213 {
214 let mut relevant_locations = vec![];
215 let p = partition_locations
216 .get(&unresolved_shuffle.stage_id)
217 .ok_or_else(|| {
218 KapotError::General(
219 "Missing partition location. Could not remove unresolved shuffles"
220 .to_owned(),
221 )
222 })?
223 .clone();
224
225 for i in 0..unresolved_shuffle.output_partition_count {
226 if let Some(x) = p.get(&i) {
227 relevant_locations.push(x.to_owned());
228 } else {
229 relevant_locations.push(vec![]);
230 }
231 }
232 debug!(
233 "Creating shuffle reader: {}",
234 relevant_locations
235 .iter()
236 .map(|c| c
237 .iter()
238 .filter(|l| !l.path.is_empty())
239 .map(|l| l.path.clone())
240 .collect::<Vec<_>>()
241 .join(", "))
242 .collect::<Vec<_>>()
243 .join("\n")
244 );
245 new_children.push(Arc::new(ShuffleReaderExec::try_new(
246 unresolved_shuffle.stage_id,
247 relevant_locations,
248 unresolved_shuffle.schema().clone(),
249 )?))
250 } else {
251 new_children.push(remove_unresolved_shuffles(
252 child.clone(),
253 partition_locations,
254 )?);
255 }
256 }
257 Ok(with_new_children_if_necessary(stage, new_children)?)
258}
259
260pub fn rollback_resolved_shuffles(
264 stage: Arc<dyn ExecutionPlan>,
265) -> Result<Arc<dyn ExecutionPlan>> {
266 let mut new_children: Vec<Arc<dyn ExecutionPlan>> = vec![];
267 for child in stage.children() {
268 if let Some(shuffle_reader) = child.as_any().downcast_ref::<ShuffleReaderExec>() {
269 let output_partition_count = shuffle_reader
270 .properties()
271 .output_partitioning()
272 .partition_count();
273 let stage_id = shuffle_reader.stage_id;
274
275 let unresolved_shuffle = Arc::new(UnresolvedShuffleExec::new(
276 stage_id,
277 shuffle_reader.schema(),
278 output_partition_count,
279 ));
280 new_children.push(unresolved_shuffle);
281 } else {
282 new_children.push(rollback_resolved_shuffles(child.clone())?);
283 }
284 }
285 Ok(with_new_children_if_necessary(stage, new_children)?)
286}
287
288fn create_shuffle_writer(
289 job_id: &str,
290 stage_id: usize,
291 plan: Arc<dyn ExecutionPlan>,
292 partitioning: Option<Partitioning>,
293) -> Result<Arc<ShuffleWriterExec>> {
294 Ok(Arc::new(ShuffleWriterExec::try_new(
295 job_id.to_owned(),
296 stage_id,
297 plan,
298 "".to_owned(), partitioning,
300 )?))
301}
302
303#[cfg(test)]
304mod test {
305 use crate::planner::DistributedPlanner;
306 use crate::test_utils::datafusion_test_context;
307 use kapot_core::error::KapotError;
308 use kapot_core::execution_plans::UnresolvedShuffleExec;
309 use kapot_core::serde::KapotCodec;
310 use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode};
311 use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
312 use datafusion::physical_plan::joins::HashJoinExec;
313 use datafusion::physical_plan::projection::ProjectionExec;
314 use datafusion::physical_plan::sorts::sort::SortExec;
315 use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
316 use datafusion::physical_plan::{displayable, ExecutionPlan};
317 use datafusion::prelude::SessionContext;
318 use datafusion_proto::physical_plan::AsExecutionPlan;
319 use datafusion_proto::protobuf::LogicalPlanNode;
320 use datafusion_proto::protobuf::PhysicalPlanNode;
321 use std::ops::Deref;
322 use std::sync::Arc;
323 use uuid::Uuid;
324
325 macro_rules! downcast_exec {
326 ($exec: expr, $ty: ty) => {
327 $exec.as_any().downcast_ref::<$ty>().expect(&format!(
328 "Downcast to {} failed. Got {:?}",
329 stringify!($ty),
330 $exec
331 ))
332 };
333 }
334
335 #[tokio::test]
336 async fn distributed_aggregate_plan() -> Result<(), KapotError> {
337 let ctx = datafusion_test_context("testdata").await?;
338 let session_state = ctx.state();
339
340 let df = ctx
342 .sql(
343 "select l_returnflag, sum(l_extendedprice * 1) as sum_disc_price
344 from lineitem
345 group by l_returnflag
346 order by l_returnflag",
347 )
348 .await?;
349
350 let plan = df.into_optimized_plan()?;
351 let plan = session_state.optimize(&plan)?;
352 let plan = session_state.create_physical_plan(&plan).await?;
353
354 let mut planner = DistributedPlanner::new();
355 let job_uuid = Uuid::new_v4();
356 let stages = planner.plan_query_stages(&job_uuid.to_string(), plan)?;
357 for (i, stage) in stages.iter().enumerate() {
358 println!("Stage {i}:\n{}", displayable(stage.as_ref()).indent(false));
359 }
360
361 assert_eq!(3, stages.len());
380
381 let stage0 = stages[0].children()[0].clone();
383 let partial_hash = downcast_exec!(stage0, AggregateExec);
384 assert!(*partial_hash.mode() == AggregateMode::Partial);
385
386 let stage1 = stages[1].children()[0].clone();
388 let sort = downcast_exec!(stage1, SortExec);
389 let projection = sort.children()[0].clone();
390 let projection = downcast_exec!(projection, ProjectionExec);
391 let final_hash = projection.children()[0].clone();
392 let final_hash = downcast_exec!(final_hash, AggregateExec);
393 assert!(*final_hash.mode() == AggregateMode::FinalPartitioned);
394 let coalesce = final_hash.children()[0].clone();
395 let coalesce = downcast_exec!(coalesce, CoalesceBatchesExec);
396 let unresolved_shuffle = coalesce.children()[0].clone();
397 let unresolved_shuffle =
398 downcast_exec!(unresolved_shuffle, UnresolvedShuffleExec);
399 assert_eq!(unresolved_shuffle.stage_id, 1);
400 assert_eq!(unresolved_shuffle.output_partition_count, 2);
401
402 let stage2 = stages[2].children()[0].clone();
404 let merge = downcast_exec!(stage2, SortPreservingMergeExec);
405 let unresolved_shuffle = merge.children()[0].clone();
406 let unresolved_shuffle =
407 downcast_exec!(unresolved_shuffle, UnresolvedShuffleExec);
408 assert_eq!(unresolved_shuffle.stage_id, 2);
409 assert_eq!(unresolved_shuffle.output_partition_count, 2);
410
411 Ok(())
412 }
413
414 #[tokio::test]
415 async fn distributed_join_plan() -> Result<(), KapotError> {
416 let ctx = datafusion_test_context("testdata").await?;
417 let session_state = ctx.state();
418
419 let df = ctx
421 .sql(
422 "select
423 l_shipmode,
424 sum(case
425 when o_orderpriority = '1-URGENT'
426 or o_orderpriority = '2-HIGH'
427 then 1
428 else 0
429 end) as high_line_count,
430 sum(case
431 when o_orderpriority <> '1-URGENT'
432 and o_orderpriority <> '2-HIGH'
433 then 1
434 else 0
435 end) as low_line_count
436from
437 lineitem
438 join
439 orders
440 on
441 l_orderkey = o_orderkey
442where
443 l_shipmode in ('MAIL', 'SHIP')
444 and l_commitdate < l_receiptdate
445 and l_shipdate < l_commitdate
446 and l_receiptdate >= date '1994-01-01'
447 and l_receiptdate < date '1995-01-01'
448group by
449 l_shipmode
450order by
451 l_shipmode;
452",
453 )
454 .await?;
455
456 let plan = df.into_optimized_plan()?;
457 let plan = session_state.optimize(&plan)?;
458 let plan = session_state.create_physical_plan(&plan).await?;
459
460 let mut planner = DistributedPlanner::new();
461 let job_uuid = Uuid::new_v4();
462 let stages = planner.plan_query_stages(&job_uuid.to_string(), plan)?;
463 for (i, stage) in stages.iter().enumerate() {
464 println!("Stage {i}:\n{}", displayable(stage.as_ref()).indent(false));
465 }
466
467 assert_eq!(5, stages.len());
500
501 assert_eq!(
505 2,
506 stages[0].children()[0]
507 .properties()
508 .output_partitioning()
509 .partition_count()
510 );
511 assert_eq!(
512 2,
513 stages[0]
514 .shuffle_output_partitioning()
515 .expect("stage 0")
516 .partition_count()
517 );
518
519 assert_eq!(
521 1,
522 stages[1].children()[0]
523 .properties()
524 .output_partitioning()
525 .partition_count()
526 );
527 assert_eq!(
528 2,
529 stages[1]
530 .shuffle_output_partitioning()
531 .expect("stage 1")
532 .partition_count()
533 );
534
535 let input = stages[2].children()[0].clone();
537 assert_eq!(
538 2,
539 input.properties().output_partitioning().partition_count()
540 );
541 assert_eq!(
542 2,
543 stages[2]
544 .shuffle_output_partitioning()
545 .expect("stage 2")
546 .partition_count()
547 );
548
549 let hash_agg = downcast_exec!(input, AggregateExec);
550
551 let coalesce_batches = hash_agg.children()[0].clone();
552 let coalesce_batches = downcast_exec!(coalesce_batches, CoalesceBatchesExec);
553
554 let join = coalesce_batches.children()[0].clone();
555 let join = downcast_exec!(join, HashJoinExec);
556 assert!(join.contain_projection());
557
558 let join_input_1 = join.children()[0].clone();
559 let join_input_1 = join_input_1.children()[0].clone();
561 let unresolved_shuffle_reader_1 =
562 downcast_exec!(join_input_1, UnresolvedShuffleExec);
563 assert_eq!(unresolved_shuffle_reader_1.output_partition_count, 2);
564
565 let join_input_2 = join.children()[1].clone();
566 let join_input_2 = join_input_2.children()[0].clone();
568 let unresolved_shuffle_reader_2 =
569 downcast_exec!(join_input_2, UnresolvedShuffleExec);
570 assert_eq!(unresolved_shuffle_reader_2.output_partition_count, 2);
571
572 assert_eq!(
574 2,
575 stages[3].children()[0]
576 .properties()
577 .output_partitioning()
578 .partition_count()
579 );
580 assert!(stages[3].shuffle_output_partitioning().is_none());
581
582 assert_eq!(
584 1,
585 stages[4].children()[0]
586 .properties()
587 .output_partitioning()
588 .partition_count()
589 );
590 assert!(stages[4].shuffle_output_partitioning().is_none());
591
592 Ok(())
593 }
594
595 #[ignore]
596 #[tokio::test]
598 async fn roundtrip_serde_aggregate() -> Result<(), KapotError> {
599 let ctx = datafusion_test_context("testdata").await?;
600 let session_state = ctx.state();
601
602 let df = ctx
604 .sql(
605 "select l_returnflag, sum(l_extendedprice * 1) as sum_disc_price
606 from lineitem
607 group by l_returnflag
608 order by l_returnflag",
609 )
610 .await?;
611
612 let plan = df.into_optimized_plan()?;
613 let plan = session_state.optimize(&plan)?;
614 let plan = session_state.create_physical_plan(&plan).await?;
615
616 let mut planner = DistributedPlanner::new();
617 let job_uuid = Uuid::new_v4();
618 let stages = planner.plan_query_stages(&job_uuid.to_string(), plan)?;
619
620 let partial_hash = stages[0].children()[0].clone();
621 let partial_hash_serde = roundtrip_operator(&ctx, partial_hash.clone())?;
622
623 let partial_hash = downcast_exec!(partial_hash, AggregateExec);
624 let partial_hash_serde = downcast_exec!(partial_hash_serde, AggregateExec);
625
626 assert_eq!(
627 format!("{partial_hash:?}"),
628 format!("{partial_hash_serde:?}")
629 );
630
631 Ok(())
632 }
633
634 fn roundtrip_operator(
635 ctx: &SessionContext,
636 plan: Arc<dyn ExecutionPlan>,
637 ) -> Result<Arc<dyn ExecutionPlan>, KapotError> {
638 let codec: KapotCodec<LogicalPlanNode, PhysicalPlanNode> =
639 KapotCodec::default();
640 let proto: datafusion_proto::protobuf::PhysicalPlanNode =
641 datafusion_proto::protobuf::PhysicalPlanNode::try_from_physical_plan(
642 plan.clone(),
643 codec.physical_extension_codec(),
644 )?;
645 let runtime = ctx.runtime_env();
646 let result_exec_plan: Arc<dyn ExecutionPlan> = (proto).try_into_physical_plan(
647 ctx,
648 runtime.deref(),
649 codec.physical_extension_codec(),
650 )?;
651 Ok(result_exec_plan)
652 }
653}