kapot_executor/execution_engine.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use async_trait::async_trait;
use kapot_core::execution_plans::ShuffleWriterExec;
use kapot_core::serde::protobuf::ShuffleWritePartition;
use kapot_core::utils;
use datafusion::error::{DataFusionError, Result};
use datafusion::execution::context::TaskContext;
use datafusion::physical_plan::metrics::MetricsSet;
use datafusion::physical_plan::ExecutionPlan;
use std::fmt::Debug;
use std::sync::Arc;

/// Execution engine extension point.
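///
/// A scheduler or executor can plug in an alternative engine. A minimal sketch
/// of a delegating implementation (`LoggingExecutionEngine` is a hypothetical
/// name, not part of this crate):
///
/// ```ignore
/// pub struct LoggingExecutionEngine {
///     inner: DefaultExecutionEngine,
/// }
///
/// impl ExecutionEngine for LoggingExecutionEngine {
///     fn create_query_stage_exec(
///         &self,
///         job_id: String,
///         stage_id: usize,
///         plan: Arc<dyn ExecutionPlan>,
///         work_dir: &str,
///     ) -> Result<Arc<dyn QueryStageExecutor>> {
///         // Add custom behaviour, then delegate to the default engine.
///         println!("creating query stage {stage_id} for job {job_id}");
///         self.inner
///             .create_query_stage_exec(job_id, stage_id, plan, work_dir)
///     }
/// }
/// ```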
pub trait ExecutionEngine: Sync + Send {
    fn create_query_stage_exec(
        &self,
        job_id: String,
        stage_id: usize,
        plan: Arc<dyn ExecutionPlan>,
        work_dir: &str,
    ) -> Result<Arc<dyn QueryStageExecutor>>;
}

/// A `QueryStageExecutor` executes a section of a query plan that has consistent
/// partitioning and can be executed as one unit, with each partition executed in
/// parallel. The output of each partition is re-partitioned and streamed to disk
/// in Arrow IPC format; later stages of the query read these results via
/// `ShuffleReaderExec`.
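///
/// A sketch of how an executor task might drive one input partition of a stage
/// (the `run_partition` helper is illustrative, not part of this crate, and
/// construction of the `TaskContext` is omitted):
///
/// ```ignore
/// async fn run_partition(
///     stage: Arc<dyn QueryStageExecutor>,
///     partition: usize,
///     ctx: Arc<TaskContext>,
/// ) -> Result<Vec<ShuffleWritePartition>> {
///     // Runs the stage for one input partition; the returned metadata
///     // describes the shuffle files written to disk.
///     stage.execute_query_stage(partition, ctx).await
/// }
/// ```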
#[async_trait]
pub trait QueryStageExecutor: Sync + Send + Debug {
    /// Execute one input partition of this stage, writing shuffle output to
    /// disk and returning metadata about the partitions written.
    async fn execute_query_stage(
        &self,
        input_partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<Vec<ShuffleWritePartition>>;

    /// Collect the metrics recorded by this stage's physical plan.
    fn collect_plan_metrics(&self) -> Vec<MetricsSet>;
}

/// Default [`ExecutionEngine`] that executes query stages using
/// [`ShuffleWriterExec`].
pub struct DefaultExecutionEngine {}

impl ExecutionEngine for DefaultExecutionEngine {
    fn create_query_stage_exec(
        &self,
        job_id: String,
        stage_id: usize,
        plan: Arc<dyn ExecutionPlan>,
        work_dir: &str,
    ) -> Result<Arc<dyn QueryStageExecutor>> {
        // The query plan created by the scheduler always starts with a ShuffleWriterExec.
        let exec = if let Some(shuffle_writer) =
            plan.as_any().downcast_ref::<ShuffleWriterExec>()
        {
            // Recreate the shuffle writer with the correct working directory.
            ShuffleWriterExec::try_new(
                job_id,
                stage_id,
                plan.children()[0].clone(),
                work_dir.to_string(),
                shuffle_writer.shuffle_output_partitioning().cloned(),
            )
        } else {
            Err(DataFusionError::Internal(
                "Plan passed to create_query_stage_exec is not a ShuffleWriterExec"
                    .to_string(),
            ))
        }?;
        Ok(Arc::new(DefaultQueryStageExec::new(exec)))
    }
}

/// Default [`QueryStageExecutor`] that wraps a [`ShuffleWriterExec`].
#[derive(Debug)]
pub struct DefaultQueryStageExec {
    shuffle_writer: ShuffleWriterExec,
}

impl DefaultQueryStageExec {
    pub fn new(shuffle_writer: ShuffleWriterExec) -> Self {
        Self { shuffle_writer }
    }
}

#[async_trait]
impl QueryStageExecutor for DefaultQueryStageExec {
    async fn execute_query_stage(
        &self,
        input_partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<Vec<ShuffleWritePartition>> {
        self.shuffle_writer
            .execute_shuffle_write(input_partition, context)
            .await
    }

    fn collect_plan_metrics(&self) -> Vec<MetricsSet> {
        utils::collect_plan_metrics(&self.shuffle_writer)
    }
}
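
// A minimal end-to-end sketch of how an executor might use the default engine
// for a single task. This helper is illustrative only, not part of the crate's
// API; `plan` must be the ShuffleWriterExec-rooted plan produced by the
// scheduler, and obtaining the `TaskContext` is left to the caller.
#[allow(dead_code)]
async fn run_stage_with_default_engine(
    job_id: String,
    stage_id: usize,
    plan: Arc<dyn ExecutionPlan>,
    work_dir: &str,
    input_partition: usize,
    context: Arc<TaskContext>,
) -> Result<Vec<ShuffleWritePartition>> {
    // Build the stage executor, then run one of its input partitions.
    let engine = DefaultExecutionEngine {};
    let stage = engine.create_query_stage_exec(job_id, stage_id, plan, work_dir)?;
    stage.execute_query_stage(input_partition, context).await
}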