Skip to main content

level_runtime/
worker.rs

1use std::{
2    cell::RefCell,
3    future::Future,
4    sync::{atomic::AtomicUsize, Arc},
5};
6
7use tokio::task::JoinHandle;
8
9use crate::concurrency_tracker::ConcurrencyTracker;
10
11thread_local! {
12    static LOCAL_RUNTIME: RefCell<Option<LevelWorkerHandle>> = const { RefCell::new(None) }
13}
14
15/// Spawn this future on this local thread runtime.
16///
17/// You should generally replace your `tokio::spawn` with this. By using
18/// `spawn_local` instead of `tokio::spawn`, you give the load heuristic
19/// more information.
20///
21/// If your work truly does not have a thread affinity consideration, use
22/// `spawn_balanced` instead.
23#[track_caller]
24pub fn spawn_local<F>(future: F) -> JoinHandle<F::Output>
25where
26    F: Future + Send + 'static,
27    F::Output: Send + 'static,
28{
29    match LOCAL_RUNTIME.try_with(|context| {
30        context
31            .borrow()
32            .as_ref()
33            .map(|worker| worker.spawn_local(future))
34    }) {
35        Ok(Some(handle)) => handle,
36        Ok(None) => panic!("spawn_local can only be called on threads running a level worker"),
37        Err(_access_error) => panic!("spawn_local called on a destroyed thread context"),
38    }
39}
40
41/// A thread-local runtime wrapper
42pub struct LevelWorker {
43    runtime: tokio::runtime::Runtime,
44    concurrency: Arc<AtomicUsize>,
45}
46impl std::fmt::Debug for LevelWorker {
47    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48        f.debug_struct("LevelWorker").finish()
49    }
50}
51impl LevelWorker {
52    pub(crate) fn from_runtime(runtime: tokio::runtime::Runtime) -> Self {
53        Self {
54            runtime,
55            concurrency: Default::default(),
56        }
57    }
58
59    pub(crate) fn run(self: Arc<Self>, termination: impl Future<Output = ()>) {
60        LOCAL_RUNTIME.with(|none| {
61            assert!(none.borrow().is_none(), "you can't run twice!");
62            none.replace(Some(self.handle()));
63        });
64        assert_eq!(
65            LOCAL_RUNTIME.try_with(|a| { a.borrow().is_some() }),
66            Ok(true)
67        );
68        self.runtime.block_on(termination);
69        log::debug!("runtime ended");
70    }
71
72    /// A spawn handle for the local runtime
73    pub fn handle(&self) -> LevelWorkerHandle {
74        LevelWorkerHandle {
75            handle: self.runtime.handle().clone(),
76            concurrency: self.concurrency.clone(),
77        }
78    }
79}
80
81/// A spawn handle for the local runtime
82#[derive(Clone)]
83pub struct LevelWorkerHandle {
84    handle: tokio::runtime::Handle,
85    concurrency: Arc<AtomicUsize>,
86}
87impl std::fmt::Debug for LevelWorkerHandle {
88    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89        f.debug_struct("LevelWorkerHandle")
90            .field("concurrency", &self.concurrency)
91            .finish()
92    }
93}
94
95impl LevelWorkerHandle {
96    /// Spawn the future on this thread's local runtime
97    #[track_caller]
98    pub fn spawn_local<F>(&self, future: F) -> JoinHandle<F::Output>
99    where
100        F: Future + Send + 'static,
101        F::Output: Send + 'static,
102    {
103        self.handle
104            .spawn(ConcurrencyTracker::wrap(self.concurrency.clone(), future))
105    }
106
107    /// Run the future on this thread's local runtime
108    #[track_caller]
109    pub fn block_on<F>(&self, future: F) -> F::Output
110    where
111        F: Future,
112    {
113        self.handle.block_on(future)
114    }
115
116    /// How many tasks are currently spawned on this level worker?
117    pub fn concurrency(&self) -> usize {
118        self.concurrency.load(std::sync::atomic::Ordering::Relaxed)
119    }
120}