use async_trait::async_trait;
use kapot_core::execution_plans::ShuffleWriterExec;
use kapot_core::serde::protobuf::ShuffleWritePartition;
use kapot_core::utils;
use datafusion::error::{DataFusionError, Result};
use datafusion::execution::context::TaskContext;
use datafusion::physical_plan::metrics::MetricsSet;
use datafusion::physical_plan::ExecutionPlan;
use std::fmt::Debug;
use std::sync::Arc;
/// Factory that turns the physical plan of a single query stage into a
/// [`QueryStageExecutor`] that can run it.
pub trait ExecutionEngine
where
    Self: Send + Sync,
{
    /// Builds an executor for stage `stage_id` of job `job_id` from `plan`.
    ///
    /// `work_dir` is a directory path handed on to the executor being built.
    ///
    /// # Errors
    ///
    /// Returns an error when no executor can be built from `plan`.
    fn create_query_stage_exec(
        &self,
        job_id: String,
        stage_id: usize,
        plan: Arc<dyn ExecutionPlan>,
        work_dir: &str,
    ) -> Result<Arc<dyn QueryStageExecutor>>;
}
/// Executor for a single query stage, produced by an `ExecutionEngine`.
///
/// Implementations must be `Sync + Send + Debug` so they can be shared behind
/// `Arc<dyn QueryStageExecutor>`.
#[async_trait]
pub trait QueryStageExecutor: Sync + Send + Debug {
    /// Runs the stage for the input partition `input_partition` using the
    /// given task `context`, returning the `ShuffleWritePartition` entries the
    /// run produced.
    // NOTE(review): whether partial output is cleaned up on error is left to
    // the implementation — confirm per implementor.
    async fn execute_query_stage(
        &self,
        input_partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<Vec<ShuffleWritePartition>>;
    /// Returns the metric sets gathered from the stage's physical plan.
    fn collect_plan_metrics(&self) -> Vec<MetricsSet>;
}
/// Stateless [`ExecutionEngine`] that runs every query stage with a
/// `ShuffleWriterExec`.
///
/// It carries no configuration. It can be built either as
/// `DefaultExecutionEngine {}` or with `DefaultExecutionEngine::default()`.
#[derive(Debug, Default)]
pub struct DefaultExecutionEngine {}
impl ExecutionEngine for DefaultExecutionEngine {
    /// Builds a [`DefaultQueryStageExec`] for stage `stage_id` of job `job_id`.
    ///
    /// `plan` must be a [`ShuffleWriterExec`]. A fresh shuffle writer is built
    /// from four inputs:
    /// - `job_id` and `stage_id`;
    /// - the first child of `plan`;
    /// - `work_dir`;
    /// - the original writer's shuffle output partitioning, if any.
    ///
    /// # Errors
    ///
    /// Returns [`DataFusionError::Internal`] in two cases:
    /// - `plan` is not a `ShuffleWriterExec`;
    /// - `plan` has no child plan.
    ///
    /// Any error from `ShuffleWriterExec::try_new` is propagated as-is.
    fn create_query_stage_exec(
        &self,
        job_id: String,
        stage_id: usize,
        plan: Arc<dyn ExecutionPlan>,
        work_dir: &str,
    ) -> Result<Arc<dyn QueryStageExecutor>> {
        let shuffle_writer = plan
            .as_any()
            .downcast_ref::<ShuffleWriterExec>()
            .ok_or_else(|| {
                DataFusionError::Internal(
                    "Plan passed to create_query_stage_exec is not a ShuffleWriterExec"
                        .to_string(),
                )
            })?;

        // Checked access instead of `children()[0]`: a malformed plan now
        // surfaces as an error rather than a panic inside the executor.
        let input = plan
            .children()
            .first()
            .map(|child| Arc::clone(child))
            .ok_or_else(|| {
                DataFusionError::Internal(
                    "ShuffleWriterExec passed to create_query_stage_exec has no input plan"
                        .to_string(),
                )
            })?;

        let exec = ShuffleWriterExec::try_new(
            job_id,
            stage_id,
            input,
            work_dir.to_string(),
            shuffle_writer.shuffle_output_partitioning().cloned(),
        )?;
        Ok(Arc::new(DefaultQueryStageExec::new(exec)))
    }
}
/// [`QueryStageExecutor`] that runs a query stage by delegating to a wrapped
/// [`ShuffleWriterExec`].
#[derive(Debug)]
pub struct DefaultQueryStageExec {
    /// Shuffle writer that does the actual work of the stage.
    shuffle_writer: ShuffleWriterExec,
}

impl DefaultQueryStageExec {
    /// Wraps `shuffle_writer` so it can be driven through the
    /// [`QueryStageExecutor`] interface.
    pub fn new(shuffle_writer: ShuffleWriterExec) -> Self {
        DefaultQueryStageExec { shuffle_writer }
    }
}
#[async_trait]
impl QueryStageExecutor for DefaultQueryStageExec {
    /// Runs the wrapped shuffle writer on partition `input_partition` with the
    /// supplied task `context`, forwarding its result unchanged.
    async fn execute_query_stage(
        &self,
        input_partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<Vec<ShuffleWritePartition>> {
        let writer = &self.shuffle_writer;
        writer.execute_shuffle_write(input_partition, context).await
    }

    /// Collects the metric sets of the wrapped shuffle writer's plan via
    /// `utils::collect_plan_metrics`.
    fn collect_plan_metrics(&self) -> Vec<MetricsSet> {
        let writer = &self.shuffle_writer;
        utils::collect_plan_metrics(writer)
    }
}