kapot_executor/
execution_engine.rs

use async_trait::async_trait;
use kapot_core::execution_plans::ShuffleWriterExec;
use kapot_core::serde::protobuf::ShuffleWritePartition;
use kapot_core::utils;
use datafusion::error::{DataFusionError, Result};
use datafusion::execution::context::TaskContext;
use datafusion::physical_plan::metrics::MetricsSet;
use datafusion::physical_plan::ExecutionPlan;
use std::fmt::Debug;
use std::sync::Arc;

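/// Extension point for plugging a custom execution engine into the executor:
/// given the physical plan for one query stage, an implementation wraps it in
/// a [`QueryStageExecutor`] that can run the stage and write its shuffle
/// output under `work_dir`.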
pub trait ExecutionEngine: Sync + Send {
    fn create_query_stage_exec(
        &self,
        job_id: String,
        stage_id: usize,
        plan: Arc<dyn ExecutionPlan>,
        work_dir: &str,
    ) -> Result<Arc<dyn QueryStageExecutor>>;
}

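/// Runs a single query stage: each call executes one input partition of the
/// stage plan and returns metadata for the shuffle partitions that were
/// written; `collect_plan_metrics` exposes the plan's runtime metrics.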
#[async_trait]
pub trait QueryStageExecutor: Sync + Send + Debug {
    async fn execute_query_stage(
        &self,
        input_partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<Vec<ShuffleWritePartition>>;

    fn collect_plan_metrics(&self) -> Vec<MetricsSet>;
}

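/// Default [`ExecutionEngine`]: expects the stage plan to be a
/// [`ShuffleWriterExec`] and executes it directly with DataFusion.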
pub struct DefaultExecutionEngine {}

impl ExecutionEngine for DefaultExecutionEngine {
    fn create_query_stage_exec(
        &self,
        job_id: String,
        stage_id: usize,
        plan: Arc<dyn ExecutionPlan>,
        work_dir: &str,
    ) -> Result<Arc<dyn QueryStageExecutor>> {
        let exec = if let Some(shuffle_writer) =
            plan.as_any().downcast_ref::<ShuffleWriterExec>()
        {
            ShuffleWriterExec::try_new(
                job_id,
                stage_id,
                plan.children()[0].clone(),
                work_dir.to_string(),
                shuffle_writer.shuffle_output_partitioning().cloned(),
            )
        } else {
            Err(DataFusionError::Internal(
                "Plan passed to create_query_stage_exec is not a ShuffleWriterExec"
                    .to_string(),
            ))
        }?;
        Ok(Arc::new(DefaultQueryStageExec::new(exec)))
    }
}

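/// Default [`QueryStageExecutor`] wrapping a [`ShuffleWriterExec`]; stage
/// execution and metrics collection are delegated to the wrapped plan.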
#[derive(Debug)]
pub struct DefaultQueryStageExec {
    shuffle_writer: ShuffleWriterExec,
}

impl DefaultQueryStageExec {
    pub fn new(shuffle_writer: ShuffleWriterExec) -> Self {
        Self { shuffle_writer }
    }
}

#[async_trait]
impl QueryStageExecutor for DefaultQueryStageExec {
    async fn execute_query_stage(
        &self,
        input_partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<Vec<ShuffleWritePartition>> {
        self.shuffle_writer
            .execute_shuffle_write(input_partition, context)
            .await
    }

    fn collect_plan_metrics(&self) -> Vec<MetricsSet> {
        utils::collect_plan_metrics(&self.shuffle_writer)
    }
}
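
// Example (illustrative sketch): one way an executor could drive the default
// engine for a single stage. The job id, stage id and work directory below are
// placeholder values, and `plan` is assumed to be the stage's ShuffleWriterExec
// handed over by the scheduler.
//
//     async fn run_stage(
//         plan: Arc<dyn ExecutionPlan>,
//         ctx: Arc<TaskContext>,
//     ) -> Result<Vec<ShuffleWritePartition>> {
//         let engine = DefaultExecutionEngine {};
//         let stage = engine.create_query_stage_exec(
//             "job-1".to_string(),
//             1,
//             plan,
//             "/tmp/kapot-shuffle",
//         )?;
//         stage.execute_query_stage(0, ctx).await
//     }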