Skip to main content

level_runtime/
runtime.rs

1use std::{
2    future::Future,
3    sync::{Arc, OnceLock},
4};
5
6use rand::Rng;
7use tokio::task::JoinHandle;
8
9use crate::worker::{LevelWorker, LevelWorkerHandle};
10
11static GLOBAL_RUNTIME: OnceLock<LevelRuntimeHandle> = OnceLock::new();
12
13/// Spawn a task on one of the local runtimes, with a work balance heuristic.
14///
15/// Use this when you can do your work on any runtime and it does not matter which.
16/// Be mindful of cross-runtime synchronization and await, which is more costly in
17/// a level runtime than a regular tokio multi-thread runtime.
18#[track_caller]
19pub fn spawn_balanced<F>(future: F) -> JoinHandle<F::Output>
20where
21    F: Future + Send + 'static,
22    F::Output: Send + 'static,
23{
24    match GLOBAL_RUNTIME
25        .get()
26        .map(|worker| worker.spawn_balanced(future))
27    {
28        Some(handle) => handle,
29        None => {
30            panic!("spawn_balanced can only be called in processes running a default level worker")
31        }
32    }
33}
34
35/// Spawn a copy of a task on each local runtime.
36///
37/// Use this to place a server listener on each of your thread local threads.
38#[track_caller]
39pub fn spawn_on_each<F>(future: impl Fn() -> F) -> Vec<JoinHandle<F::Output>>
40where
41    F: Future + Send + 'static,
42    F::Output: Send + 'static,
43{
44    match GLOBAL_RUNTIME
45        .get()
46        .map(|worker| worker.spawn_on_each(future))
47    {
48        Some(handle) => handle,
49        None => {
50            panic!("spawn_balanced can only be called in processes running a default level worker")
51        }
52    }
53}
54
55/// A wrapper for a collection of tokio current-thread runtimes.
56///
57/// It offers a load leveling heuristic, but not work stealing.
58pub struct LevelRuntime {
59    workers: Vec<Arc<LevelWorker>>,
60    thread_name: std::sync::Arc<dyn Fn() -> String + Send + Sync + 'static>,
61}
62impl std::fmt::Debug for LevelRuntime {
63    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64        f.debug_struct("LevelRuntime")
65            .field("workers", &self.workers)
66            .finish()
67    }
68}
69impl LevelRuntime {
70    pub(crate) fn from_workers(
71        thread_name: std::sync::Arc<dyn Fn() -> String + Send + Sync + 'static>,
72        workers: Vec<LevelWorker>,
73    ) -> Self {
74        Self {
75            workers: workers.into_iter().map(Arc::new).collect(),
76            thread_name,
77        }
78    }
79
80    /// Register this LevelRuntime as the default runtime for the static
81    /// `level_runtime::spawn*` family functions.
82    pub fn set_default(&self) {
83        if GLOBAL_RUNTIME.set(self.handle()).is_err() {
84            panic!("must only set one default level runtime per process")
85        }
86    }
87
88    /// Get a spawn handle for the runtime.
89    pub fn handle(&self) -> LevelRuntimeHandle {
90        LevelRuntimeHandle {
91            workers: self.workers.iter().map(|w| w.handle()).collect(),
92        }
93    }
94
95    /// Execute this runtime.
96    pub fn run(self) {
97        self.run_with_termination(std::future::pending());
98    }
99
100    /// Start this runtime in the background.
101    pub fn start(&self) {
102        let _handles: Vec<_> = self
103            .workers
104            .iter()
105            .map(|worker| {
106                let name = (self.thread_name)();
107                let worker = worker.clone();
108                std::thread::Builder::new()
109                    .name(name)
110                    .spawn(move || {
111                        worker.run(std::future::pending());
112                    })
113                    .expect("must be able to spawn level worker")
114            })
115            .collect();
116    }
117
118    /// Execute this runtime. When `termination` completes, the backing executors will close.
119    pub fn run_with_termination(self, termination: impl Future<Output = ()> + Send + 'static) {
120        let termination = futures::FutureExt::shared(termination);
121        let handles: Vec<_> = self
122            .workers
123            .into_iter()
124            .map(|worker| {
125                let termination = termination.clone();
126                std::thread::Builder::new()
127                    .name((self.thread_name)())
128                    .spawn(move || {
129                        worker.run(termination);
130                    })
131                    .expect("must be able to spawn level worker")
132            })
133            .collect();
134        for handle in handles {
135            let _ = handle.join();
136        }
137    }
138}
139
140/// A handle to a LevelRuntime for spawning tasks.
141#[derive(Clone, Debug)]
142pub struct LevelRuntimeHandle {
143    workers: Vec<LevelWorkerHandle>,
144}
145
146impl LevelRuntimeHandle {
147    /// Spawn this future on one of the workers; choose which one based
148    /// on a load heuristic.
149    #[track_caller]
150    pub fn spawn_balanced<F>(&self, future: F) -> JoinHandle<F::Output>
151    where
152        F: Future + Send + 'static,
153        F::Output: Send + 'static,
154    {
155        balance_workers(&self.workers).spawn_local(future)
156    }
157
158    /// Get one consistent worker handle
159    #[track_caller]
160    pub fn pick_worker(&self) -> &LevelWorkerHandle {
161        balance_workers(&self.workers)
162    }
163
164    /// Spawn a copy of this future on each runtime. Do this for server listeners
165    /// using SO_REUSEADDR.
166    #[track_caller]
167    pub fn spawn_on_each<F>(&self, future: impl Fn() -> F) -> Vec<JoinHandle<F::Output>>
168    where
169        F: Future + Send + 'static,
170        F::Output: Send + 'static,
171    {
172        self.workers
173            .iter()
174            .map(|worker| worker.spawn_local(future()))
175            .collect()
176    }
177}
178
179#[track_caller]
180fn balance_workers(workers: &[LevelWorkerHandle]) -> &LevelWorkerHandle {
181    let mut rng = rand::rng();
182    let a = rng.random_range(0..(workers.len()));
183    let b = rng.random_range(0..(workers.len()));
184    log::info!(
185        "{a}: {}, {b}: {}",
186        workers[a].concurrency(),
187        workers[b].concurrency()
188    );
189    if workers[a].concurrency() < workers[b].concurrency() {
190        &workers[a]
191    } else {
192        &workers[b]
193    }
194}