datafusion_dft/execution/executor/
dedicated.rs1use std::{
19 fmt::Display,
20 sync::{Arc, OnceLock},
21 time::Duration,
22};
23
24use futures::{
25 future::{BoxFuture, Shared},
26 Future, FutureExt, TryFutureExt,
27};
28use log::{info, warn};
29use parking_lot::RwLock;
30use tokio::{
31 runtime::Handle,
32 sync::{oneshot::error::RecvError, Notify},
33 task::JoinSet,
34};
35
36use crate::config::ExecutionConfig;
37
38use super::io::register_io_runtime;
39
/// Upper bound (five minutes) handed to `Runtime::shutdown_timeout` when the driver
/// thread tears the executor's runtime down.
const SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5 * 60);
41
/// Errors reported by the future returned from [`DedicatedExecutor::spawn`].
#[derive(Debug)]
pub enum JobError {
    /// The task could not be run or did not finish: the executor has no runtime
    /// handle any more, or the task ended without producing a value or a panic.
    WorkerGone,
    /// The task panicked while running on the executor.
    Panic {
        /// Panic message, or `"unknown internal error"` when the payload was neither
        /// a `String` nor a `&str`.
        msg: String,
    },
}

impl Display for JobError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            JobError::WorkerGone => {
                write!(f, "Worker thread gone, executor was likely shut down")
            }
            JobError::Panic { msg } => write!(f, "Panic: {msg}"),
        }
    }
}

// Implementing `Error` lets callers box this type as `dyn Error` and propagate it
// with `?` into generic error types.
impl std::error::Error for JobError {}
60
61#[derive(Clone)]
112pub struct DedicatedExecutor {
113 state: Arc<RwLock<State>>,
114
115 testing: bool,
119}
120
121static TESTING_EXECUTOR: OnceLock<DedicatedExecutor> = OnceLock::new();
123
124impl DedicatedExecutor {
125 pub fn new(
137 name: &str,
138 config: ExecutionConfig,
139 runtime_builder: tokio::runtime::Builder,
140 ) -> Self {
141 Self::new_inner(name, config, runtime_builder, false)
142 }
143
144 fn new_inner(
145 name: &str,
146 config: ExecutionConfig,
147 runtime_builder: tokio::runtime::Builder,
148 testing: bool,
149 ) -> Self {
150 let name = name.to_owned();
151
152 let notify_shutdown = Arc::new(Notify::new());
153 let notify_shutdown_captured = Arc::clone(¬ify_shutdown);
154
155 let (tx_shutdown, rx_shutdown) = tokio::sync::oneshot::channel();
156 let (tx_handle, rx_handle) = std::sync::mpsc::channel();
157
158 let io_handle = tokio::runtime::Handle::try_current().ok();
159 let thread = std::thread::Builder::new()
160 .name(format!("{name} driver"))
161 .spawn(move || {
162 register_io_runtime(io_handle.clone());
165
166 info!(
167 "Creating DedicatedExecutor with {} threads",
168 config.dedicated_executor_threads
169 );
170
171 let mut runtime_builder = runtime_builder;
172 let runtime = runtime_builder
173 .worker_threads(config.dedicated_executor_threads)
174 .on_thread_start(move || register_io_runtime(io_handle.clone()))
175 .build()
176 .expect("Creating tokio runtime");
177
178 runtime.block_on(async move {
179 let shutdown = notify_shutdown_captured.notified();
187 let mut shutdown = std::pin::pin!(shutdown);
188 shutdown.as_mut().enable();
189
190 if tx_handle.send(Handle::current()).is_err() {
191 return;
192 }
193 shutdown.await;
194 });
195
196 runtime.shutdown_timeout(SHUTDOWN_TIMEOUT);
197
198 tx_shutdown.send(()).ok();
200 })
201 .expect("executor setup");
202
203 let handle = rx_handle.recv().expect("driver started");
204
205 let state = State {
206 handle: Some(handle),
207 start_shutdown: notify_shutdown,
208 completed_shutdown: rx_shutdown.map_err(Arc::new).boxed().shared(),
209 thread: Some(thread),
210 };
211
212 Self {
213 state: Arc::new(RwLock::new(state)),
214 testing,
215 }
216 }
217
218 pub fn new_testing() -> Self {
222 TESTING_EXECUTOR
223 .get_or_init(|| {
224 let mut runtime_builder = tokio::runtime::Builder::new_current_thread();
225
226 runtime_builder.enable_time();
231
232 Self::new_inner("testing", ExecutionConfig::default(), runtime_builder, true)
233 })
234 .clone()
235 }
236
237 pub fn spawn<T>(&self, task: T) -> impl Future<Output = Result<T::Output, JobError>>
249 where
250 T: Future + Send + 'static,
251 T::Output: Send + 'static,
252 {
253 let handle = {
254 let state = self.state.read();
255 state.handle.clone()
256 };
257
258 let Some(handle) = handle else {
259 return futures::future::err(JobError::WorkerGone).boxed();
260 };
261
262 let mut join_set = JoinSet::new();
264 join_set.spawn_on(task, &handle);
265 async move {
266 join_set
267 .join_next()
268 .await
269 .expect("just spawned task")
270 .map_err(|e| match e.try_into_panic() {
271 Ok(e) => {
272 let s = if let Some(s) = e.downcast_ref::<String>() {
273 s.clone()
274 } else if let Some(s) = e.downcast_ref::<&str>() {
275 s.to_string()
276 } else {
277 "unknown internal error".to_string()
278 };
279
280 JobError::Panic { msg: s }
281 }
282 Err(_) => JobError::WorkerGone,
283 })
284 }
285 .boxed()
286 }
287
288 pub fn shutdown(&self) {
290 if self.testing {
291 return;
292 }
293
294 let mut state = self.state.write();
297 state.handle = None;
298 state.start_shutdown.notify_one();
299 }
300
301 pub async fn join(&self) {
314 if self.testing {
315 return;
316 }
317
318 self.shutdown();
319
320 let handle = {
322 let state = self.state.read();
323 state.completed_shutdown.clone()
324 };
325
326 handle.await.expect("Thread died?")
329 }
330}
331
332struct State {
339 handle: Option<Handle>,
343
344 start_shutdown: Arc<Notify>,
349
350 completed_shutdown: Shared<BoxFuture<'static, Result<(), Arc<RecvError>>>>,
352
353 thread: Option<std::thread::JoinHandle<()>>,
355}
356
357impl Drop for State {
360 fn drop(&mut self) {
361 if self.handle.is_some() {
362 warn!("DedicatedExecutor dropped without calling shutdown()");
363 self.handle = None;
364 self.start_shutdown.notify_one();
365 }
366
367 if !std::thread::panicking() && self.completed_shutdown.clone().now_or_never().is_none() {
369 warn!("DedicatedExecutor dropped without waiting for worker termination",);
370 }
371
372 self.thread.take().expect("not dropped yet").join().ok();
374 }
375}
376
377#[cfg(test)]
378mod tests {
379 use crate::execution::executor::io::spawn_io;
380
381 use super::*;
382 use std::{
383 panic::panic_any,
384 sync::{Arc, Barrier},
385 time::Duration,
386 };
387 use tokio::{net::TcpListener, sync::Barrier as AsyncBarrier};
388
/// Blocks the current thread on the std `barrier`, then yields `result`.
async fn do_work(result: usize, barrier: Arc<Barrier>) -> usize {
    // Deliberately a blocking wait: the task occupies its thread until released.
    let _ = barrier.wait();
    result
}
394
395 async fn do_work_async(result: usize, barrier: Arc<AsyncBarrier>) -> usize {
397 barrier.wait().await;
398 result
399 }
400
401 fn exec() -> DedicatedExecutor {
402 exec_with_threads(1)
403 }
404
405 fn exec2() -> DedicatedExecutor {
406 exec_with_threads(2)
407 }
408
409 fn exec_with_threads(threads: usize) -> DedicatedExecutor {
410 let mut runtime_builder = tokio::runtime::Builder::new_multi_thread();
411 runtime_builder.worker_threads(threads);
412 runtime_builder.enable_all();
413
414 DedicatedExecutor::new(
415 "Test DedicatedExecutor",
416 ExecutionConfig::default(),
417 runtime_builder,
418 )
419 }
420
421 async fn test_io_runtime_multi_thread_impl(dedicated: DedicatedExecutor) {
422 let io_runtime_id = std::thread::current().id();
423 dedicated
424 .spawn(async move {
425 let dedicated_id = std::thread::current().id();
426 let spawned = spawn_io(async move { std::thread::current().id() }).await;
427
428 assert_ne!(dedicated_id, spawned);
429 assert_eq!(io_runtime_id, spawned);
430 })
431 .await
432 .unwrap();
433 }
434
435 #[tokio::test]
436 async fn basic() {
437 let barrier = Arc::new(Barrier::new(2));
438
439 let exec = exec();
440 let dedicated_task = exec.spawn(do_work(42, Arc::clone(&barrier)));
441
442 barrier.wait();
447
448 assert_eq!(dedicated_task.await.unwrap(), 42);
450
451 exec.join().await;
452 }
453
454 #[tokio::test]
455 async fn basic_clone() {
456 let barrier = Arc::new(Barrier::new(2));
457 let exec = exec();
458 let dedicated_task = exec.clone().spawn(do_work(42, Arc::clone(&barrier)));
460 barrier.wait();
461 assert_eq!(dedicated_task.await.unwrap(), 42);
462
463 exec.join().await;
464 }
465
466 #[tokio::test]
467 async fn drop_empty_exec() {
468 exec();
469 }
470
471 #[tokio::test]
472 async fn drop_clone() {
473 let barrier = Arc::new(Barrier::new(2));
474 let exec = exec();
475
476 drop(exec.clone());
477
478 let task = exec.spawn(do_work(42, Arc::clone(&barrier)));
479 barrier.wait();
480 assert_eq!(task.await.unwrap(), 42);
481
482 exec.join().await;
483 }
484
485 #[tokio::test]
486 #[should_panic(expected = "foo")]
487 async fn just_panic() {
488 struct S(DedicatedExecutor);
489
490 impl Drop for S {
491 fn drop(&mut self) {
492 self.0.join().now_or_never();
493 }
494 }
495
496 let exec = exec();
497 let _s = S(exec);
498
499 panic!("foo")
501 }
502
503 #[tokio::test]
504 async fn multi_task() {
505 let barrier = Arc::new(Barrier::new(3));
506
507 let exec = exec2();
509 let dedicated_task1 = exec.spawn(do_work(11, Arc::clone(&barrier)));
510 let dedicated_task2 = exec.spawn(do_work(42, Arc::clone(&barrier)));
511
512 barrier.wait();
514
515 assert_eq!(dedicated_task1.await.unwrap(), 11);
517 assert_eq!(dedicated_task2.await.unwrap(), 42);
518
519 exec.join().await;
520 }
521
522 #[tokio::test]
523 async fn tokio_spawn() {
524 let exec = exec2();
525
526 let dedicated_task = exec.spawn(async move {
529 let t1 = tokio::task::spawn(async { 25usize });
531 t1.await.unwrap()
532 });
533
534 assert_eq!(dedicated_task.await.unwrap(), 25);
536
537 exec.join().await;
538 }
539
540 #[tokio::test]
541 async fn panic_on_executor_str() {
542 let exec = exec();
543 let dedicated_task = exec.spawn(async move {
544 if true {
545 panic!("At the disco, on the dedicated task scheduler");
546 } else {
547 42
548 }
549 });
550
551 let err = dedicated_task.await.unwrap_err();
553 assert_eq!(
554 err.to_string(),
555 "Panic: At the disco, on the dedicated task scheduler",
556 );
557
558 exec.join().await;
559 }
560
561 #[tokio::test]
562 async fn panic_on_executor_string() {
563 let exec = exec();
564 let dedicated_task = exec.spawn(async move {
565 if true {
566 panic!("{} {}", 1, 2);
567 } else {
568 42
569 }
570 });
571
572 let err = dedicated_task.await.unwrap_err();
574 assert_eq!(err.to_string(), "Panic: 1 2",);
575
576 exec.join().await;
577 }
578
579 #[tokio::test]
580 async fn panic_on_executor_other() {
581 let exec = exec();
582 let dedicated_task = exec.spawn(async move {
583 if true {
584 panic_any(1)
585 } else {
586 42
587 }
588 });
589
590 let err = dedicated_task.await.unwrap_err();
592 assert_eq!(err.to_string(), "Panic: unknown internal error",);
593
594 exec.join().await;
595 }
596
597 #[tokio::test]
598 async fn executor_shutdown_while_task_running() {
599 let barrier_1 = Arc::new(Barrier::new(2));
600 let captured_1 = Arc::clone(&barrier_1);
601 let barrier_2 = Arc::new(Barrier::new(2));
602 let captured_2 = Arc::clone(&barrier_2);
603
604 let exec = exec();
605 let dedicated_task = exec.spawn(async move {
606 captured_1.wait();
607 do_work(42, captured_2).await
608 });
609 barrier_1.wait();
610
611 exec.shutdown();
612 barrier_2.wait();
614
615 assert_eq!(dedicated_task.await.unwrap(), 42);
617
618 exec.join().await;
619 }
620
621 #[tokio::test]
622 async fn executor_submit_task_after_shutdown() {
623 let exec = exec();
624
625 exec.shutdown();
627 let dedicated_task = exec.spawn(async { 11 });
628
629 let err = dedicated_task.await.unwrap_err();
631 assert_eq!(
632 err.to_string(),
633 "Worker thread gone, executor was likely shut down"
634 );
635
636 exec.join().await;
637 }
638
639 #[tokio::test]
640 async fn executor_submit_task_after_clone_shutdown() {
641 let exec = exec();
642
643 exec.clone().join().await;
645
646 let dedicated_task = exec.spawn(async { 11 });
648
649 let err = dedicated_task.await.unwrap_err();
651 assert_eq!(
652 err.to_string(),
653 "Worker thread gone, executor was likely shut down"
654 );
655
656 exec.join().await;
657 }
658
659 #[tokio::test]
660 async fn executor_join() {
661 let exec = exec();
662 exec.join().await;
664 }
665
666 #[tokio::test]
667 async fn executor_join2() {
668 let exec = exec();
669 exec.join().await;
671 exec.join().await;
672 }
673
674 #[tokio::test]
675 #[allow(clippy::redundant_clone)]
676 async fn executor_clone_join() {
677 let exec = exec();
678 exec.clone().join().await;
680 exec.clone().join().await;
681 exec.join().await;
682 }
683
684 #[tokio::test]
685 async fn drop_receiver() {
686 let exec = exec();
688
689 let barrier1_pre = Arc::new(AsyncBarrier::new(2));
691 let barrier1_pre_captured = Arc::clone(&barrier1_pre);
692 let barrier1_post = Arc::new(AsyncBarrier::new(2));
693 let barrier1_post_captured = Arc::clone(&barrier1_post);
694 let dedicated_task1 = exec.spawn(async move {
695 barrier1_pre_captured.wait().await;
696 do_work_async(11, barrier1_post_captured).await
697 });
698 barrier1_pre.wait().await;
699
700 let barrier2_pre = Arc::new(AsyncBarrier::new(2));
702 let barrier2_pre_captured = Arc::clone(&barrier2_pre);
703 let barrier2_post = Arc::new(AsyncBarrier::new(2));
704 let barrier2_post_captured = Arc::clone(&barrier2_post);
705 let dedicated_task2 = exec.spawn(async move {
706 barrier2_pre_captured.wait().await;
707 do_work_async(22, barrier2_post_captured).await
708 });
709 barrier2_pre.wait().await;
710
711 drop(dedicated_task1);
713
714 tokio::time::timeout(Duration::from_secs(1), async {
716 loop {
717 if Arc::strong_count(&barrier1_post) == 1 {
718 return;
719 }
720 tokio::time::sleep(Duration::from_millis(10)).await
721 }
722 })
723 .await
724 .unwrap();
725
726 barrier2_post.wait().await;
728 assert_eq!(dedicated_task2.await.unwrap(), 22);
729 tokio::time::timeout(Duration::from_secs(1), async {
730 loop {
731 if Arc::strong_count(&barrier2_post) == 1 {
732 return;
733 }
734 tokio::time::sleep(Duration::from_millis(10)).await
735 }
736 })
737 .await
738 .unwrap();
739
740 exec.join().await;
741 }
742
743 #[tokio::test]
744 async fn test_io_runtime_multi_thread() {
745 let mut runtime_builder = tokio::runtime::Builder::new_multi_thread();
746 runtime_builder.worker_threads(1);
747
748 let dedicated = DedicatedExecutor::new(
749 "Test DedicatedExecutor",
750 ExecutionConfig::default(),
751 runtime_builder,
752 );
753 test_io_runtime_multi_thread_impl(dedicated).await;
754 }
755
756 #[tokio::test]
757 async fn test_io_runtime_current_thread() {
758 let runtime_builder = tokio::runtime::Builder::new_current_thread();
759
760 let dedicated = DedicatedExecutor::new(
761 "Test DedicatedExecutor",
762 ExecutionConfig::default(),
763 runtime_builder,
764 );
765 test_io_runtime_multi_thread_impl(dedicated).await;
766 }
767
768 #[tokio::test]
769 async fn test_that_testing_executor_prevents_io() {
770 let exec = DedicatedExecutor::new_testing();
771
772 let io_disabled = exec
773 .spawn(async move {
774 TcpListener::bind("127.0.0.1:0")
776 .catch_unwind()
777 .await
778 .is_err()
779 })
780 .await
781 .unwrap();
782
783 assert!(io_disabled)
784 }
785}