kapot_executor/
executor.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! kapot executor logic
19
20use crate::execution_engine::DefaultExecutionEngine;
21use crate::execution_engine::ExecutionEngine;
22use crate::execution_engine::QueryStageExecutor;
23use crate::metrics::ExecutorMetricsCollector;
24use kapot_core::error::KapotError;
25use kapot_core::serde::protobuf;
26use kapot_core::serde::protobuf::ExecutorRegistration;
27use kapot_core::serde::scheduler::PartitionId;
28use dashmap::DashMap;
29use datafusion::execution::context::TaskContext;
30use datafusion::execution::runtime_env::RuntimeEnv;
31use datafusion::functions::all_default_functions;
32use datafusion::functions_aggregate::all_default_aggregate_functions;
33use datafusion::logical_expr::{AggregateUDF, ScalarUDF, WindowUDF};
34use futures::future::AbortHandle;
35use std::collections::HashMap;
36use std::future::Future;
37use std::pin::Pin;
38use std::sync::Arc;
39use std::task::{Context, Poll};
40
41pub struct TasksDrainedFuture(pub Arc<Executor>);
42
43impl Future for TasksDrainedFuture {
44    type Output = ();
45
46    fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
47        if self.0.abort_handles.len() > 0 {
48            Poll::Pending
49        } else {
50            Poll::Ready(())
51        }
52    }
53}
54
/// Shared map from `(task id, partition)` to the [`AbortHandle`] of the task
/// currently executing that partition. Wrapped in an `Arc` so that clones of
/// an [`Executor`] operate on the same map.
type AbortHandles = Arc<DashMap<(usize, PartitionId), AbortHandle>>;

/// kapot executor
///
/// Cloning an `Executor` is cheap with respect to task tracking: the
/// abort-handle map is behind an `Arc`, so every clone sees the same set of
/// running tasks.
#[derive(Clone)]
pub struct Executor {
    /// Registration metadata describing this executor
    pub metadata: ExecutorRegistration,

    /// Directory for storing partial results
    pub work_dir: String,

    /// Scalar functions registered in the executor, keyed by function name
    pub scalar_functions: HashMap<String, Arc<ScalarUDF>>,

    /// Aggregate functions registered in the executor, keyed by function name
    pub aggregate_functions: HashMap<String, Arc<AggregateUDF>>,

    /// Window functions registered in the executor, keyed by function name
    pub window_functions: HashMap<String, Arc<WindowUDF>>,

    /// Runtime environment for the executor
    runtime: Arc<RuntimeEnv>,

    /// Runtime environment for the executor with data cache.
    /// It differs from `runtime` in that it uses a different
    /// `object_store_registry`; everything else is shared with `runtime`.
    runtime_with_data_cache: Option<Arc<RuntimeEnv>>,

    /// Collector for runtime execution metrics
    pub metrics_collector: Arc<dyn ExecutorMetricsCollector>,

    /// Number of tasks that can run concurrently in the executor
    pub concurrent_tasks: usize,

    /// Handles used to abort executing tasks; an entry exists while the
    /// corresponding task is being executed
    abort_handles: AbortHandles,

    /// Execution engine that the executor will delegate to
    /// for executing query stages
    pub(crate) execution_engine: Arc<dyn ExecutionEngine>,
}
96
97impl Executor {
98    /// Create a new executor instance
99    pub fn new(
100        metadata: ExecutorRegistration,
101        work_dir: &str,
102        runtime: Arc<RuntimeEnv>,
103        runtime_with_data_cache: Option<Arc<RuntimeEnv>>,
104        metrics_collector: Arc<dyn ExecutorMetricsCollector>,
105        concurrent_tasks: usize,
106        execution_engine: Option<Arc<dyn ExecutionEngine>>,
107    ) -> Self {
108        let scalar_functions = all_default_functions()
109            .into_iter()
110            .map(|f| (f.name().to_string(), f))
111            .collect();
112
113        let aggregate_functions = all_default_aggregate_functions()
114            .into_iter()
115            .map(|f| (f.name().to_string(), f))
116            .collect();
117
118        Self {
119            metadata,
120            work_dir: work_dir.to_owned(),
121            scalar_functions,
122            aggregate_functions,
123            // TODO: set to default window functions when they are moved to udwf
124            window_functions: HashMap::new(),
125            runtime,
126            runtime_with_data_cache,
127            metrics_collector,
128            concurrent_tasks,
129            abort_handles: Default::default(),
130            execution_engine: execution_engine
131                .unwrap_or_else(|| Arc::new(DefaultExecutionEngine {})),
132        }
133    }
134}
135
136impl Executor {
137    pub fn get_runtime(&self, data_cache: bool) -> Arc<RuntimeEnv> {
138        if data_cache {
139            if let Some(runtime) = self.runtime_with_data_cache.clone() {
140                runtime
141            } else {
142                self.runtime.clone()
143            }
144        } else {
145            self.runtime.clone()
146        }
147    }
148
149    /// Execute one partition of a query stage and persist the result to disk in IPC format. On
150    /// success, return a RecordBatch containing metadata about the results, including path
151    /// and statistics.
152    pub async fn execute_query_stage(
153        &self,
154        task_id: usize,
155        partition: PartitionId,
156        query_stage_exec: Arc<dyn QueryStageExecutor>,
157        task_ctx: Arc<TaskContext>,
158    ) -> Result<Vec<protobuf::ShuffleWritePartition>, KapotError> {
159        let (task, abort_handle) = futures::future::abortable(
160            query_stage_exec.execute_query_stage(partition.partition_id, task_ctx),
161        );
162
163        self.abort_handles
164            .insert((task_id, partition.clone()), abort_handle);
165
166        let partitions = task.await??;
167
168        self.abort_handles.remove(&(task_id, partition.clone()));
169
170        self.metrics_collector.record_stage(
171            &partition.job_id,
172            partition.stage_id,
173            partition.partition_id,
174            query_stage_exec,
175        );
176
177        Ok(partitions)
178    }
179
180    pub async fn cancel_task(
181        &self,
182        task_id: usize,
183        job_id: String,
184        stage_id: usize,
185        partition_id: usize,
186    ) -> Result<bool, KapotError> {
187        if let Some((_, handle)) = self.abort_handles.remove(&(
188            task_id,
189            PartitionId {
190                job_id,
191                stage_id,
192                partition_id,
193            },
194        )) {
195            handle.abort();
196            Ok(true)
197        } else {
198            Ok(false)
199        }
200    }
201
202    pub fn work_dir(&self) -> &str {
203        &self.work_dir
204    }
205
206    pub fn active_task_count(&self) -> usize {
207        self.abort_handles.len()
208    }
209}
210
211#[cfg(test)]
212mod test {
213    use crate::execution_engine::DefaultQueryStageExec;
214    use crate::executor::Executor;
215    use crate::metrics::LoggingMetricsCollector;
216    use arrow::datatypes::{Schema, SchemaRef};
217    use arrow::record_batch::RecordBatch;
218    use datafusion::physical_expr::EquivalenceProperties;
219    use datafusion::physical_plan::execution_plan::Boundedness;
220    use kapot_core::execution_plans::ShuffleWriterExec;
221    use kapot_core::serde::protobuf::ExecutorRegistration;
222    use kapot_core::serde::scheduler::PartitionId;
223    use datafusion::error::{DataFusionError, Result};
224    use datafusion::execution::context::TaskContext;
225
226    use datafusion::physical_plan::{
227        DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties,
228        RecordBatchStream, SendableRecordBatchStream, Statistics,
229    };
230    use datafusion::prelude::SessionContext;
231    use futures::Stream;
232    use std::any::Any;
233    use std::pin::Pin;
234    use std::sync::Arc;
235    use std::task::{Context, Poll};
236    use std::time::Duration;
237    use tempfile::TempDir;
238
239    /// A RecordBatchStream that will never terminate
240    struct NeverendingRecordBatchStream;
241
242    impl RecordBatchStream for NeverendingRecordBatchStream {
243        fn schema(&self) -> SchemaRef {
244            Arc::new(Schema::empty())
245        }
246    }
247
248    impl Stream for NeverendingRecordBatchStream {
249        type Item = Result<RecordBatch, DataFusionError>;
250
251        fn poll_next(
252            self: Pin<&mut Self>,
253            _cx: &mut Context<'_>,
254        ) -> Poll<Option<Self::Item>> {
255            Poll::Pending
256        }
257    }
258
259    /// An ExecutionPlan which will never terminate
260    #[derive(Debug)]
261    pub struct NeverendingOperator {
262        properties: PlanProperties,
263    }
264
265    impl NeverendingOperator {
266        fn new() -> Self {
267            let equivalence_properties = EquivalenceProperties::new(Arc::new(
268                Schema::empty(),
269            ));
270
271
272            NeverendingOperator {
273                properties: PlanProperties::new(
274                    equivalence_properties,
275                    Partitioning::UnknownPartitioning(1),
276                    datafusion::physical_plan::execution_plan::EmissionType::Both,
277                    Boundedness::Bounded,
278                ),
279            }
280        }
281    }
282
283    impl DisplayAs for NeverendingOperator {
284        fn fmt_as(
285            &self,
286            t: DisplayFormatType,
287            f: &mut std::fmt::Formatter,
288        ) -> std::fmt::Result {
289            match t {
290                DisplayFormatType::Default | DisplayFormatType::Verbose => {
291                    write!(f, "NeverendingOperator")
292                }
293            }
294        }
295    }
296
297    impl ExecutionPlan for NeverendingOperator {
298        fn name(&self) -> &str {
299            "NeverendingOperator"
300        }
301
302        fn as_any(&self) -> &dyn Any {
303            self
304        }
305
306        fn schema(&self) -> SchemaRef {
307            Arc::new(Schema::empty())
308        }
309
310        fn properties(&self) -> &PlanProperties {
311            &self.properties
312        }
313
314        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
315            vec![]
316        }
317
318        fn with_new_children(
319            self: Arc<Self>,
320            _children: Vec<Arc<dyn ExecutionPlan>>,
321        ) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
322            Ok(self)
323        }
324
325        fn execute(
326            &self,
327            _partition: usize,
328            _context: Arc<TaskContext>,
329        ) -> datafusion::common::Result<SendableRecordBatchStream> {
330            Ok(Box::pin(NeverendingRecordBatchStream))
331        }
332
333        fn statistics(&self) -> Result<Statistics> {
334            Ok(Statistics::new_unknown(&self.schema()))
335        }
336    }
337
338    #[tokio::test]
339    async fn test_task_cancellation() {
340        let work_dir = TempDir::new()
341            .unwrap()
342            .into_path()
343            .into_os_string()
344            .into_string()
345            .unwrap();
346
347        let shuffle_write = ShuffleWriterExec::try_new(
348            "job-id".to_owned(),
349            1,
350            Arc::new(NeverendingOperator::new()),
351            work_dir.clone(),
352            None,
353        )
354        .expect("creating shuffle writer");
355
356        let query_stage_exec = DefaultQueryStageExec::new(shuffle_write);
357
358        let executor_registration = ExecutorRegistration {
359            id: "executor".to_string(),
360            port: 0,
361            grpc_port: 0,
362            specification: None,
363            optional_host: None,
364        };
365
366        let ctx = SessionContext::new();
367
368        let executor = Executor::new(
369            executor_registration,
370            &work_dir,
371            ctx.runtime_env(),
372            None,
373            Arc::new(LoggingMetricsCollector {}),
374            2,
375            None,
376        );
377
378        let (sender, receiver) = tokio::sync::oneshot::channel();
379
380        // Spawn our non-terminating task on a separate fiber.
381        let executor_clone = executor.clone();
382        tokio::task::spawn(async move {
383            let part = PartitionId {
384                job_id: "job-id".to_owned(),
385                stage_id: 1,
386                partition_id: 0,
387            };
388            let task_result = executor_clone
389                .execute_query_stage(1, part, Arc::new(query_stage_exec), ctx.task_ctx())
390                .await;
391            sender.send(task_result).expect("sending result");
392        });
393
394        // Now cancel the task. We can only cancel once the task has been executed and has an `AbortHandle` registered, so
395        // poll until that happens.
396        for _ in 0..20 {
397            if executor
398                .cancel_task(1, "job-id".to_owned(), 1, 0)
399                .await
400                .expect("cancelling task")
401            {
402                break;
403            } else {
404                tokio::time::sleep(Duration::from_millis(50)).await;
405            }
406        }
407
408        // Wait for our task to complete
409        let result = tokio::time::timeout(Duration::from_secs(5), receiver).await;
410
411        // Make sure the task didn't timeout
412        assert!(result.is_ok());
413
414        // Make sure the actual task failed
415        let inner_result = result.unwrap().unwrap();
416        assert!(inner_result.is_err());
417    }
418}