1use std::{
19 fmt::Display,
20 sync::{Arc, OnceLock},
21 time::Duration,
22};
23
24use futures::{
25 future::{BoxFuture, Shared},
26 Future, FutureExt, TryFutureExt,
27};
28use log::{info, warn};
29use parking_lot::RwLock;
30use tokio::{
31 runtime::Handle,
32 sync::{oneshot::error::RecvError, Notify},
33 task::JoinSet,
34};
35
36use crate::config::ExecutionConfig;
37
38use super::io::register_io_runtime;
39
/// Maximum time the driver thread gives the runtime to shut down (five minutes).
const SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5 * 60);
41
/// Error produced when awaiting a job submitted through `DedicatedExecutor::spawn`.
#[derive(Debug)]
pub enum JobError {
    /// The job could not be run to completion because the executor has no
    /// usable runtime handle any more, or the job was cancelled.
    WorkerGone,
    /// The job panicked while it was being executed.
    Panic {
        /// Textual form of the panic payload, or a generic placeholder when
        /// the payload was neither a `String` nor a `&str`.
        msg: String,
    },
}

impl Display for JobError {
    /// Renders the human-readable message for the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::WorkerGone => {
                f.write_str("Worker thread gone, executor was likely shut down")
            }
            Self::Panic { msg } => write!(f, "Panic: {msg}"),
        }
    }
}

// Implementing `Error` lets callers box the error or propagate it with `?`
// into any error type that accepts `std::error::Error`.
impl std::error::Error for JobError {}
60
61#[derive(Clone)]
112pub struct DedicatedExecutor {
113 state: Arc<RwLock<State>>,
114
115 testing: bool,
119}
120
121static TESTING_EXECUTOR: OnceLock<DedicatedExecutor> = OnceLock::new();
123
124impl DedicatedExecutor {
125 pub fn new(
137 name: &str,
138 config: ExecutionConfig,
139 runtime_builder: tokio::runtime::Builder,
140 ) -> Self {
141 Self::new_inner(name, config, runtime_builder, false)
142 }
143
144 fn new_inner(
145 name: &str,
146 config: ExecutionConfig,
147 runtime_builder: tokio::runtime::Builder,
148 testing: bool,
149 ) -> Self {
150 let name = name.to_owned();
151
152 let notify_shutdown = Arc::new(Notify::new());
153 let notify_shutdown_captured = Arc::clone(¬ify_shutdown);
154
155 let (tx_shutdown, rx_shutdown) = tokio::sync::oneshot::channel();
156 let (tx_handle, rx_handle) = std::sync::mpsc::channel();
157
158 let io_handle = tokio::runtime::Handle::try_current().ok();
159 let thread = std::thread::Builder::new()
160 .name(format!("{name} driver"))
161 .spawn(move || {
162 register_io_runtime(io_handle.clone());
165
166 info!(
167 "Creating DedicatedExecutor with {} threads",
168 config.dedicated_executor_threads
169 );
170
171 let mut runtime_builder = runtime_builder;
172 let runtime = runtime_builder
173 .worker_threads(config.dedicated_executor_threads)
174 .on_thread_start(move || register_io_runtime(io_handle.clone()))
175 .build()
176 .expect("Creating tokio runtime");
177
178 runtime.block_on(async move {
179 let shutdown = notify_shutdown_captured.notified();
187 let mut shutdown = std::pin::pin!(shutdown);
188 shutdown.as_mut().enable();
189
190 if tx_handle.send(Handle::current()).is_err() {
191 return;
192 }
193 shutdown.await;
194 });
195
196 runtime.shutdown_timeout(SHUTDOWN_TIMEOUT);
197
198 tx_shutdown.send(()).ok();
200 })
201 .expect("executor setup");
202
203 let handle = rx_handle.recv().expect("driver started");
204
205 #[cfg(feature = "observability")]
207 let metrics_collector = {
208 use crate::observability::TokioMetricsCollector;
209 use std::time::Duration;
210
211 Some(TokioMetricsCollector::start(
212 handle.clone(),
213 "cpu_runtime".to_string(),
214 Duration::from_secs(10),
215 ))
216 };
217
218 let state = State {
219 handle: Some(handle),
220 start_shutdown: notify_shutdown,
221 completed_shutdown: rx_shutdown.map_err(Arc::new).boxed().shared(),
222 thread: Some(thread),
223 #[cfg(feature = "observability")]
224 _metrics_collector: metrics_collector,
225 };
226
227 Self {
228 state: Arc::new(RwLock::new(state)),
229 testing,
230 }
231 }
232
233 pub fn new_testing() -> Self {
237 TESTING_EXECUTOR
238 .get_or_init(|| {
239 let mut runtime_builder = tokio::runtime::Builder::new_current_thread();
240
241 runtime_builder.enable_time();
246
247 Self::new_inner("testing", ExecutionConfig::default(), runtime_builder, true)
248 })
249 .clone()
250 }
251
252 pub fn spawn<T>(&self, task: T) -> impl Future<Output = Result<T::Output, JobError>>
264 where
265 T: Future + Send + 'static,
266 T::Output: Send + 'static,
267 {
268 let handle = {
269 let state = self.state.read();
270 state.handle.clone()
271 };
272
273 let Some(handle) = handle else {
274 return futures::future::err(JobError::WorkerGone).boxed();
275 };
276
277 let mut join_set = JoinSet::new();
279 join_set.spawn_on(task, &handle);
280 async move {
281 join_set
282 .join_next()
283 .await
284 .expect("just spawned task")
285 .map_err(|e| match e.try_into_panic() {
286 Ok(e) => {
287 let s = if let Some(s) = e.downcast_ref::<String>() {
288 s.clone()
289 } else if let Some(s) = e.downcast_ref::<&str>() {
290 s.to_string()
291 } else {
292 "unknown internal error".to_string()
293 };
294
295 JobError::Panic { msg: s }
296 }
297 Err(_) => JobError::WorkerGone,
298 })
299 }
300 .boxed()
301 }
302
303 pub fn shutdown(&self) {
305 if self.testing {
306 return;
307 }
308
309 let mut state = self.state.write();
312 state.handle = None;
313 state.start_shutdown.notify_one();
314 }
315
316 pub async fn join(&self) {
329 if self.testing {
330 return;
331 }
332
333 self.shutdown();
334
335 let handle = {
337 let state = self.state.read();
338 state.completed_shutdown.clone()
339 };
340
341 handle.await.expect("Thread died?")
344 }
345}
346
347struct State {
354 handle: Option<Handle>,
358
359 start_shutdown: Arc<Notify>,
364
365 completed_shutdown: Shared<BoxFuture<'static, Result<(), Arc<RecvError>>>>,
367
368 thread: Option<std::thread::JoinHandle<()>>,
370
371 #[cfg(feature = "observability")]
373 _metrics_collector: Option<crate::observability::TokioMetricsCollector>,
374}
375
376impl Drop for State {
379 fn drop(&mut self) {
380 if self.handle.is_some() {
381 warn!("DedicatedExecutor dropped without calling shutdown()");
382 self.handle = None;
383 self.start_shutdown.notify_one();
384 }
385
386 if !std::thread::panicking() && self.completed_shutdown.clone().now_or_never().is_none() {
388 warn!("DedicatedExecutor dropped without waiting for worker termination",);
389 }
390
391 self.thread.take().expect("not dropped yet").join().ok();
393 }
394}
395
396#[cfg(test)]
397mod tests {
398 use crate::executor::io::spawn_io;
399
400 use super::*;
401 use std::{
402 panic::panic_any,
403 sync::{Arc, Barrier},
404 time::Duration,
405 };
406 use tokio::{net::TcpListener, sync::Barrier as AsyncBarrier};
407
/// Blocks on the synchronous `barrier`, then yields `result`.
async fn do_work(result: usize, barrier: Arc<Barrier>) -> usize {
    let _wait_result = barrier.wait();
    result
}
413
414 async fn do_work_async(result: usize, barrier: Arc<AsyncBarrier>) -> usize {
416 barrier.wait().await;
417 result
418 }
419
420 fn exec() -> DedicatedExecutor {
421 exec_with_threads(1)
422 }
423
424 fn exec2() -> DedicatedExecutor {
425 exec_with_threads(2)
426 }
427
428 fn exec_with_threads(threads: usize) -> DedicatedExecutor {
429 let mut runtime_builder = tokio::runtime::Builder::new_multi_thread();
430 runtime_builder.worker_threads(threads);
431 runtime_builder.enable_all();
432
433 DedicatedExecutor::new(
434 "Test DedicatedExecutor",
435 ExecutionConfig::default(),
436 runtime_builder,
437 )
438 }
439
440 async fn test_io_runtime_multi_thread_impl(dedicated: DedicatedExecutor) {
441 let io_runtime_id = std::thread::current().id();
442 dedicated
443 .spawn(async move {
444 let dedicated_id = std::thread::current().id();
445 let spawned = spawn_io(async move { std::thread::current().id() }).await;
446
447 assert_ne!(dedicated_id, spawned);
448 assert_eq!(io_runtime_id, spawned);
449 })
450 .await
451 .unwrap();
452 }
453
454 #[tokio::test]
455 async fn basic() {
456 let barrier = Arc::new(Barrier::new(2));
457
458 let exec = exec();
459 let dedicated_task = exec.spawn(do_work(42, Arc::clone(&barrier)));
460
461 barrier.wait();
466
467 assert_eq!(dedicated_task.await.unwrap(), 42);
469
470 exec.join().await;
471 }
472
473 #[tokio::test]
474 async fn basic_clone() {
475 let barrier = Arc::new(Barrier::new(2));
476 let exec = exec();
477 let dedicated_task = exec.clone().spawn(do_work(42, Arc::clone(&barrier)));
479 barrier.wait();
480 assert_eq!(dedicated_task.await.unwrap(), 42);
481
482 exec.join().await;
483 }
484
485 #[tokio::test]
486 async fn drop_empty_exec() {
487 exec();
488 }
489
490 #[tokio::test]
491 async fn drop_clone() {
492 let barrier = Arc::new(Barrier::new(2));
493 let exec = exec();
494
495 drop(exec.clone());
496
497 let task = exec.spawn(do_work(42, Arc::clone(&barrier)));
498 barrier.wait();
499 assert_eq!(task.await.unwrap(), 42);
500
501 exec.join().await;
502 }
503
504 #[tokio::test]
505 #[should_panic(expected = "foo")]
506 async fn just_panic() {
507 struct S(DedicatedExecutor);
508
509 impl Drop for S {
510 fn drop(&mut self) {
511 self.0.join().now_or_never();
512 }
513 }
514
515 let exec = exec();
516 let _s = S(exec);
517
518 panic!("foo")
520 }
521
522 #[tokio::test]
523 async fn multi_task() {
524 let barrier = Arc::new(Barrier::new(3));
525
526 let exec = exec2();
528 let dedicated_task1 = exec.spawn(do_work(11, Arc::clone(&barrier)));
529 let dedicated_task2 = exec.spawn(do_work(42, Arc::clone(&barrier)));
530
531 barrier.wait();
533
534 assert_eq!(dedicated_task1.await.unwrap(), 11);
536 assert_eq!(dedicated_task2.await.unwrap(), 42);
537
538 exec.join().await;
539 }
540
541 #[tokio::test]
542 async fn tokio_spawn() {
543 let exec = exec2();
544
545 let dedicated_task = exec.spawn(async move {
548 let t1 = tokio::task::spawn(async { 25usize });
550 t1.await.unwrap()
551 });
552
553 assert_eq!(dedicated_task.await.unwrap(), 25);
555
556 exec.join().await;
557 }
558
559 #[tokio::test]
560 async fn panic_on_executor_str() {
561 let exec = exec();
562 let dedicated_task = exec.spawn(async move {
563 if true {
564 panic!("At the disco, on the dedicated task scheduler");
565 } else {
566 42
567 }
568 });
569
570 let err = dedicated_task.await.unwrap_err();
572 assert_eq!(
573 err.to_string(),
574 "Panic: At the disco, on the dedicated task scheduler",
575 );
576
577 exec.join().await;
578 }
579
580 #[tokio::test]
581 async fn panic_on_executor_string() {
582 let exec = exec();
583 let dedicated_task = exec.spawn(async move {
584 if true {
585 panic!("{} {}", 1, 2);
586 } else {
587 42
588 }
589 });
590
591 let err = dedicated_task.await.unwrap_err();
593 assert_eq!(err.to_string(), "Panic: 1 2",);
594
595 exec.join().await;
596 }
597
598 #[tokio::test]
599 async fn panic_on_executor_other() {
600 let exec = exec();
601 let dedicated_task = exec.spawn(async move {
602 if true {
603 panic_any(1)
604 } else {
605 42
606 }
607 });
608
609 let err = dedicated_task.await.unwrap_err();
611 assert_eq!(err.to_string(), "Panic: unknown internal error",);
612
613 exec.join().await;
614 }
615
616 #[tokio::test]
617 async fn executor_shutdown_while_task_running() {
618 let barrier_1 = Arc::new(Barrier::new(2));
619 let captured_1 = Arc::clone(&barrier_1);
620 let barrier_2 = Arc::new(Barrier::new(2));
621 let captured_2 = Arc::clone(&barrier_2);
622
623 let exec = exec();
624 let dedicated_task = exec.spawn(async move {
625 captured_1.wait();
626 do_work(42, captured_2).await
627 });
628 barrier_1.wait();
629
630 exec.shutdown();
631 barrier_2.wait();
633
634 assert_eq!(dedicated_task.await.unwrap(), 42);
636
637 exec.join().await;
638 }
639
640 #[tokio::test]
641 async fn executor_submit_task_after_shutdown() {
642 let exec = exec();
643
644 exec.shutdown();
646 let dedicated_task = exec.spawn(async { 11 });
647
648 let err = dedicated_task.await.unwrap_err();
650 assert_eq!(
651 err.to_string(),
652 "Worker thread gone, executor was likely shut down"
653 );
654
655 exec.join().await;
656 }
657
658 #[tokio::test]
659 async fn executor_submit_task_after_clone_shutdown() {
660 let exec = exec();
661
662 exec.clone().join().await;
664
665 let dedicated_task = exec.spawn(async { 11 });
667
668 let err = dedicated_task.await.unwrap_err();
670 assert_eq!(
671 err.to_string(),
672 "Worker thread gone, executor was likely shut down"
673 );
674
675 exec.join().await;
676 }
677
678 #[tokio::test]
679 async fn executor_join() {
680 let exec = exec();
681 exec.join().await;
683 }
684
685 #[tokio::test]
686 async fn executor_join2() {
687 let exec = exec();
688 exec.join().await;
690 exec.join().await;
691 }
692
693 #[tokio::test]
694 #[allow(clippy::redundant_clone)]
695 async fn executor_clone_join() {
696 let exec = exec();
697 exec.clone().join().await;
699 exec.clone().join().await;
700 exec.join().await;
701 }
702
703 #[tokio::test]
704 async fn drop_receiver() {
705 let exec = exec();
707
708 let barrier1_pre = Arc::new(AsyncBarrier::new(2));
710 let barrier1_pre_captured = Arc::clone(&barrier1_pre);
711 let barrier1_post = Arc::new(AsyncBarrier::new(2));
712 let barrier1_post_captured = Arc::clone(&barrier1_post);
713 let dedicated_task1 = exec.spawn(async move {
714 barrier1_pre_captured.wait().await;
715 do_work_async(11, barrier1_post_captured).await
716 });
717 barrier1_pre.wait().await;
718
719 let barrier2_pre = Arc::new(AsyncBarrier::new(2));
721 let barrier2_pre_captured = Arc::clone(&barrier2_pre);
722 let barrier2_post = Arc::new(AsyncBarrier::new(2));
723 let barrier2_post_captured = Arc::clone(&barrier2_post);
724 let dedicated_task2 = exec.spawn(async move {
725 barrier2_pre_captured.wait().await;
726 do_work_async(22, barrier2_post_captured).await
727 });
728 barrier2_pre.wait().await;
729
730 drop(dedicated_task1);
732
733 tokio::time::timeout(Duration::from_secs(1), async {
735 loop {
736 if Arc::strong_count(&barrier1_post) == 1 {
737 return;
738 }
739 tokio::time::sleep(Duration::from_millis(10)).await
740 }
741 })
742 .await
743 .unwrap();
744
745 barrier2_post.wait().await;
747 assert_eq!(dedicated_task2.await.unwrap(), 22);
748 tokio::time::timeout(Duration::from_secs(1), async {
749 loop {
750 if Arc::strong_count(&barrier2_post) == 1 {
751 return;
752 }
753 tokio::time::sleep(Duration::from_millis(10)).await
754 }
755 })
756 .await
757 .unwrap();
758
759 exec.join().await;
760 }
761
762 #[tokio::test]
763 async fn test_io_runtime_multi_thread() {
764 let mut runtime_builder = tokio::runtime::Builder::new_multi_thread();
765 runtime_builder.worker_threads(1);
766
767 let dedicated = DedicatedExecutor::new(
768 "Test DedicatedExecutor",
769 ExecutionConfig::default(),
770 runtime_builder,
771 );
772 test_io_runtime_multi_thread_impl(dedicated).await;
773 }
774
775 #[tokio::test]
776 async fn test_io_runtime_current_thread() {
777 let runtime_builder = tokio::runtime::Builder::new_current_thread();
778
779 let dedicated = DedicatedExecutor::new(
780 "Test DedicatedExecutor",
781 ExecutionConfig::default(),
782 runtime_builder,
783 );
784 test_io_runtime_multi_thread_impl(dedicated).await;
785 }
786
787 #[tokio::test]
788 async fn test_that_testing_executor_prevents_io() {
789 let exec = DedicatedExecutor::new_testing();
790
791 let io_disabled = exec
792 .spawn(async move {
793 TcpListener::bind("127.0.0.1:0")
795 .catch_unwind()
796 .await
797 .is_err()
798 })
799 .await
800 .unwrap();
801
802 assert!(io_disabled)
803 }
804}