use crate::execution_engine::DefaultExecutionEngine;
use crate::execution_engine::ExecutionEngine;
use crate::execution_engine::QueryStageExecutor;
use crate::metrics::ExecutorMetricsCollector;
use dashmap::DashMap;
use datafusion::execution::context::TaskContext;
use datafusion::execution::runtime_env::RuntimeEnv;
use datafusion::functions::all_default_functions;
use datafusion::functions_aggregate::all_default_aggregate_functions;
use datafusion::logical_expr::{AggregateUDF, ScalarUDF, WindowUDF};
use futures::future::AbortHandle;
use kapot_core::error::KapotError;
use kapot_core::serde::protobuf;
use kapot_core::serde::protobuf::ExecutorRegistration;
use kapot_core::serde::scheduler::PartitionId;
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

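/// A future that resolves once the executor no longer has any tasks in
/// flight.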
pub struct TasksDrainedFuture(pub Arc<Executor>);

impl Future for TasksDrainedFuture {
    type Output = ();

    fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
        if self.0.abort_handles.is_empty() {
            Poll::Ready(())
        } else {
            // No waker is registered here, so this future relies on being
            // polled again by its runtime rather than on an explicit wake-up.
            Poll::Pending
        }
    }
}

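/// Shared map from (task id, partition) to the abort handle for that
/// running task.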
type AbortHandles = Arc<DashMap<(usize, PartitionId), AbortHandle>>;

/// A kapot executor, responsible for running query stages on behalf of the
/// scheduler and for tracking the tasks it has in flight.
#[derive(Clone)]
pub struct Executor {
    /// Metadata used to register this executor with the scheduler
    pub metadata: ExecutorRegistration,

    /// Directory for storing task-local data such as shuffle files
    pub work_dir: String,

    /// Scalar functions available to tasks on this executor
    pub scalar_functions: HashMap<String, Arc<ScalarUDF>>,

    /// Aggregate functions available to tasks on this executor
    pub aggregate_functions: HashMap<String, Arc<AggregateUDF>>,

    /// Window functions available to tasks on this executor
    pub window_functions: HashMap<String, Arc<WindowUDF>>,

    /// DataFusion runtime environment used when executing tasks
    runtime: Arc<RuntimeEnv>,

    /// Optional runtime environment with a data cache configured, returned
    /// by `get_runtime` when a task requests cached execution
    runtime_with_data_cache: Option<Arc<RuntimeEnv>>,

    /// Collector for metrics reported by completed query stages
    pub metrics_collector: Arc<dyn ExecutorMetricsCollector>,

    /// Number of tasks this executor can run concurrently
    pub concurrent_tasks: usize,

    /// Abort handles for in-flight tasks, keyed by task id and partition
    abort_handles: AbortHandles,

    /// Engine used to execute query stages
    pub(crate) execution_engine: Arc<dyn ExecutionEngine>,
}

impl Executor {
    /// Creates a new executor. The executor starts with all of DataFusion's
    /// built-in scalar and aggregate functions registered; if no execution
    /// engine is provided, the `DefaultExecutionEngine` is used.
    pub fn new(
        metadata: ExecutorRegistration,
        work_dir: &str,
        runtime: Arc<RuntimeEnv>,
        runtime_with_data_cache: Option<Arc<RuntimeEnv>>,
        metrics_collector: Arc<dyn ExecutorMetricsCollector>,
        concurrent_tasks: usize,
        execution_engine: Option<Arc<dyn ExecutionEngine>>,
    ) -> Self {
        // Index the built-in functions by name.
        let scalar_functions = all_default_functions()
            .into_iter()
            .map(|f| (f.name().to_string(), f))
            .collect();

        let aggregate_functions = all_default_aggregate_functions()
            .into_iter()
            .map(|f| (f.name().to_string(), f))
            .collect();

        Self {
            metadata,
            work_dir: work_dir.to_owned(),
            scalar_functions,
            aggregate_functions,
            window_functions: HashMap::new(),
            runtime,
            runtime_with_data_cache,
            metrics_collector,
            concurrent_tasks,
            abort_handles: Default::default(),
            execution_engine: execution_engine
                .unwrap_or_else(|| Arc::new(DefaultExecutionEngine {})),
        }
    }
}

impl Executor {
    /// Returns the runtime environment to use for a task, preferring the
    /// data-cache-enabled runtime when one is configured and requested.
    pub fn get_runtime(&self, data_cache: bool) -> Arc<RuntimeEnv> {
        if data_cache {
            self.runtime_with_data_cache
                .clone()
                .unwrap_or_else(|| self.runtime.clone())
        } else {
            self.runtime.clone()
        }
    }

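    /// Executes one query stage for the given partition, returning metadata
    /// for the shuffle partitions it wrote. While the stage runs, its abort
    /// handle is tracked so the task can be cancelled via `cancel_task`.
    ///
    /// A minimal sketch of the call shape (assuming an `executor`, an
    /// `Arc`'d `query_stage_exec`, and a DataFusion `task_ctx` are in scope):
    ///
    /// ```ignore
    /// let partition = PartitionId {
    ///     job_id: "job-id".to_owned(),
    ///     stage_id: 1,
    ///     partition_id: 0,
    /// };
    /// let shuffle_partitions = executor
    ///     .execute_query_stage(1, partition, query_stage_exec, task_ctx)
    ///     .await?;
    /// ```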
    pub async fn execute_query_stage(
        &self,
        task_id: usize,
        partition: PartitionId,
        query_stage_exec: Arc<dyn QueryStageExecutor>,
        task_ctx: Arc<TaskContext>,
    ) -> Result<Vec<protobuf::ShuffleWritePartition>, KapotError> {
        // Wrap the stage execution in an abortable future so that
        // `cancel_task` can terminate it while it is still running.
        let (task, abort_handle) = futures::future::abortable(
            query_stage_exec.execute_query_stage(partition.partition_id, task_ctx),
        );

        self.abort_handles
            .insert((task_id, partition.clone()), abort_handle);

        let partitions = task.await??;

        self.abort_handles.remove(&(task_id, partition.clone()));

        self.metrics_collector.record_stage(
            &partition.job_id,
            partition.stage_id,
            partition.partition_id,
            query_stage_exec,
        );

        Ok(partitions)
    }

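    /// Cancels a running task by removing and aborting its abort handle.
    /// Returns `Ok(true)` if a matching task was found and aborted, and
    /// `Ok(false)` if no task was registered under the given ids.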
    pub async fn cancel_task(
        &self,
        task_id: usize,
        job_id: String,
        stage_id: usize,
        partition_id: usize,
    ) -> Result<bool, KapotError> {
        if let Some((_, handle)) = self.abort_handles.remove(&(
            task_id,
            PartitionId {
                job_id,
                stage_id,
                partition_id,
            },
        )) {
            handle.abort();
            Ok(true)
        } else {
            Ok(false)
        }
    }

    pub fn work_dir(&self) -> &str {
        &self.work_dir
    }

    /// Number of tasks currently registered as in flight on this executor.
    pub fn active_task_count(&self) -> usize {
        self.abort_handles.len()
    }
}

#[cfg(test)]
mod test {
    use crate::execution_engine::DefaultQueryStageExec;
    use crate::executor::Executor;
    use crate::metrics::LoggingMetricsCollector;
    use arrow::datatypes::{Schema, SchemaRef};
    use arrow::record_batch::RecordBatch;
    use datafusion::error::{DataFusionError, Result};
    use datafusion::execution::context::TaskContext;
    use datafusion::physical_expr::EquivalenceProperties;
    use datafusion::physical_plan::execution_plan::Boundedness;
    use datafusion::physical_plan::{
        DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties,
        RecordBatchStream, SendableRecordBatchStream, Statistics,
    };
    use datafusion::prelude::SessionContext;
    use futures::Stream;
    use kapot_core::execution_plans::ShuffleWriterExec;
    use kapot_core::serde::protobuf::ExecutorRegistration;
    use kapot_core::serde::scheduler::PartitionId;
    use std::any::Any;
    use std::pin::Pin;
    use std::sync::Arc;
    use std::task::{Context, Poll};
    use std::time::Duration;
    use tempfile::TempDir;

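    /// A `RecordBatchStream` that never yields a batch and never terminates,
    /// used to keep a query stage running until it is cancelled.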
    struct NeverendingRecordBatchStream;

    impl RecordBatchStream for NeverendingRecordBatchStream {
        fn schema(&self) -> SchemaRef {
            Arc::new(Schema::empty())
        }
    }

    impl Stream for NeverendingRecordBatchStream {
        type Item = Result<RecordBatch, DataFusionError>;

        fn poll_next(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Self::Item>> {
            // Never yields an item and never completes.
            Poll::Pending
        }
    }

    /// An `ExecutionPlan` whose single partition yields a
    /// `NeverendingRecordBatchStream`, so executing it blocks forever.
    #[derive(Debug)]
    pub struct NeverendingOperator {
        properties: PlanProperties,
    }

    impl NeverendingOperator {
        fn new() -> Self {
            let equivalence_properties =
                EquivalenceProperties::new(Arc::new(Schema::empty()));

            NeverendingOperator {
                properties: PlanProperties::new(
                    equivalence_properties,
                    Partitioning::UnknownPartitioning(1),
                    datafusion::physical_plan::execution_plan::EmissionType::Both,
                    Boundedness::Bounded,
                ),
            }
        }
    }

    impl DisplayAs for NeverendingOperator {
        fn fmt_as(
            &self,
            t: DisplayFormatType,
            f: &mut std::fmt::Formatter,
        ) -> std::fmt::Result {
            match t {
                DisplayFormatType::Default | DisplayFormatType::Verbose => {
                    write!(f, "NeverendingOperator")
                }
            }
        }
    }

    impl ExecutionPlan for NeverendingOperator {
        fn name(&self) -> &str {
            "NeverendingOperator"
        }

        fn as_any(&self) -> &dyn Any {
            self
        }

        fn schema(&self) -> SchemaRef {
            Arc::new(Schema::empty())
        }

        fn properties(&self) -> &PlanProperties {
            &self.properties
        }

        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
            vec![]
        }

        fn with_new_children(
            self: Arc<Self>,
            _children: Vec<Arc<dyn ExecutionPlan>>,
        ) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
            Ok(self)
        }

        fn execute(
            &self,
            _partition: usize,
            _context: Arc<TaskContext>,
        ) -> datafusion::common::Result<SendableRecordBatchStream> {
            Ok(Box::pin(NeverendingRecordBatchStream))
        }

        fn statistics(&self) -> Result<Statistics> {
            Ok(Statistics::new_unknown(&self.schema()))
        }
    }

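    // Starts a query stage that can never finish, cancels it through the
    // executor, and checks that the spawned task reports an error.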
    #[tokio::test]
    async fn test_task_cancellation() {
        let work_dir = TempDir::new()
            .unwrap()
            .into_path()
            .into_os_string()
            .into_string()
            .unwrap();

        let shuffle_write = ShuffleWriterExec::try_new(
            "job-id".to_owned(),
            1,
            Arc::new(NeverendingOperator::new()),
            work_dir.clone(),
            None,
        )
        .expect("creating shuffle writer");

        let query_stage_exec = DefaultQueryStageExec::new(shuffle_write);

        let executor_registration = ExecutorRegistration {
            id: "executor".to_string(),
            port: 0,
            grpc_port: 0,
            specification: None,
            optional_host: None,
        };

        let ctx = SessionContext::new();

        let executor = Executor::new(
            executor_registration,
            &work_dir,
            ctx.runtime_env(),
            None,
            Arc::new(LoggingMetricsCollector {}),
            2,
            None,
        );

        let (sender, receiver) = tokio::sync::oneshot::channel();

        let executor_clone = executor.clone();
        // Run the never-ending task in the background and report its result
        // through the oneshot channel.
        tokio::task::spawn(async move {
            let part = PartitionId {
                job_id: "job-id".to_owned(),
                stage_id: 1,
                partition_id: 0,
            };
            let task_result = executor_clone
                .execute_query_stage(1, part, Arc::new(query_stage_exec), ctx.task_ctx())
                .await;
            sender.send(task_result).expect("sending result");
        });

        for _ in 0..20 {
            // The task may not have been registered yet, so retry until the
            // abort handle is found, sleeping briefly between attempts.
            if executor
                .cancel_task(1, "job-id".to_owned(), 1, 0)
                .await
                .expect("cancelling task")
            {
                break;
            } else {
                tokio::time::sleep(Duration::from_millis(50)).await;
            }
        }

        let result = tokio::time::timeout(Duration::from_secs(5), receiver).await;

        // The spawned task should report a result before the timeout ...
        assert!(result.is_ok());

        // ... and the aborted task should surface as an error.
        let inner_result = result.unwrap().unwrap();
        assert!(inner_result.is_err());
    }
}