1use std::{
2 cell::RefCell,
3 future::Future,
4 sync::{atomic::AtomicUsize, Arc},
5};
6
7use tokio::task::JoinHandle;
8
9use crate::concurrency_tracker::ConcurrencyTracker;
10
11thread_local! {
12 static LOCAL_RUNTIME: RefCell<Option<LevelWorkerHandle>> = const { RefCell::new(None) }
13}
14
15#[track_caller]
24pub fn spawn_local<F>(future: F) -> JoinHandle<F::Output>
25where
26 F: Future + Send + 'static,
27 F::Output: Send + 'static,
28{
29 match LOCAL_RUNTIME.try_with(|context| {
30 context
31 .borrow()
32 .as_ref()
33 .map(|worker| worker.spawn_local(future))
34 }) {
35 Ok(Some(handle)) => handle,
36 Ok(None) => panic!("spawn_local can only be called on threads running a level worker"),
37 Err(_access_error) => panic!("spawn_local called on a destroyed thread context"),
38 }
39}
40
41pub struct LevelWorker {
43 runtime: tokio::runtime::Runtime,
44 concurrency: Arc<AtomicUsize>,
45}
46impl std::fmt::Debug for LevelWorker {
47 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48 f.debug_struct("LevelWorker").finish()
49 }
50}
51impl LevelWorker {
52 pub(crate) fn from_runtime(runtime: tokio::runtime::Runtime) -> Self {
53 Self {
54 runtime,
55 concurrency: Default::default(),
56 }
57 }
58
59 pub(crate) fn run(self: Arc<Self>, termination: impl Future<Output = ()>) {
60 LOCAL_RUNTIME.with(|none| {
61 assert!(none.borrow().is_none(), "you can't run twice!");
62 none.replace(Some(self.handle()));
63 });
64 assert_eq!(
65 LOCAL_RUNTIME.try_with(|a| { a.borrow().is_some() }),
66 Ok(true)
67 );
68 self.runtime.block_on(termination);
69 log::debug!("runtime ended");
70 }
71
72 pub fn handle(&self) -> LevelWorkerHandle {
74 LevelWorkerHandle {
75 handle: self.runtime.handle().clone(),
76 concurrency: self.concurrency.clone(),
77 }
78 }
79}
80
81#[derive(Clone)]
83pub struct LevelWorkerHandle {
84 handle: tokio::runtime::Handle,
85 concurrency: Arc<AtomicUsize>,
86}
87impl std::fmt::Debug for LevelWorkerHandle {
88 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 f.debug_struct("LevelWorkerHandle")
90 .field("concurrency", &self.concurrency)
91 .finish()
92 }
93}
94
95impl LevelWorkerHandle {
96 #[track_caller]
98 pub fn spawn_local<F>(&self, future: F) -> JoinHandle<F::Output>
99 where
100 F: Future + Send + 'static,
101 F::Output: Send + 'static,
102 {
103 self.handle
104 .spawn(ConcurrencyTracker::wrap(self.concurrency.clone(), future))
105 }
106
107 #[track_caller]
109 pub fn block_on<F>(&self, future: F) -> F::Output
110 where
111 F: Future,
112 {
113 self.handle.block_on(future)
114 }
115
116 pub fn concurrency(&self) -> usize {
118 self.concurrency.load(std::sync::atomic::Ordering::Relaxed)
119 }
120}