1use std::{
2 future::Future,
3 sync::{Arc, OnceLock},
4};
5
6use rand::Rng;
7use tokio::task::JoinHandle;
8
9use crate::worker::{LevelWorker, LevelWorkerHandle};
10
11static GLOBAL_RUNTIME: OnceLock<LevelRuntimeHandle> = OnceLock::new();
12
13#[track_caller]
19pub fn spawn_balanced<F>(future: F) -> JoinHandle<F::Output>
20where
21 F: Future + Send + 'static,
22 F::Output: Send + 'static,
23{
24 match GLOBAL_RUNTIME
25 .get()
26 .map(|worker| worker.spawn_balanced(future))
27 {
28 Some(handle) => handle,
29 None => {
30 panic!("spawn_balanced can only be called in processes running a default level worker")
31 }
32 }
33}
34
35#[track_caller]
39pub fn spawn_on_each<F>(future: impl Fn() -> F) -> Vec<JoinHandle<F::Output>>
40where
41 F: Future + Send + 'static,
42 F::Output: Send + 'static,
43{
44 match GLOBAL_RUNTIME
45 .get()
46 .map(|worker| worker.spawn_on_each(future))
47 {
48 Some(handle) => handle,
49 None => {
50 panic!("spawn_balanced can only be called in processes running a default level worker")
51 }
52 }
53}
54
55pub struct LevelRuntime {
59 workers: Vec<Arc<LevelWorker>>,
60 thread_name: std::sync::Arc<dyn Fn() -> String + Send + Sync + 'static>,
61}
62impl std::fmt::Debug for LevelRuntime {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 f.debug_struct("LevelRuntime")
65 .field("workers", &self.workers)
66 .finish()
67 }
68}
69impl LevelRuntime {
70 pub(crate) fn from_workers(
71 thread_name: std::sync::Arc<dyn Fn() -> String + Send + Sync + 'static>,
72 workers: Vec<LevelWorker>,
73 ) -> Self {
74 Self {
75 workers: workers.into_iter().map(Arc::new).collect(),
76 thread_name,
77 }
78 }
79
80 pub fn set_default(&self) {
83 if GLOBAL_RUNTIME.set(self.handle()).is_err() {
84 panic!("must only set one default level runtime per process")
85 }
86 }
87
88 pub fn handle(&self) -> LevelRuntimeHandle {
90 LevelRuntimeHandle {
91 workers: self.workers.iter().map(|w| w.handle()).collect(),
92 }
93 }
94
95 pub fn run(self) {
97 self.run_with_termination(std::future::pending());
98 }
99
100 pub fn start(&self) {
102 let _handles: Vec<_> = self
103 .workers
104 .iter()
105 .map(|worker| {
106 let name = (self.thread_name)();
107 let worker = worker.clone();
108 std::thread::Builder::new()
109 .name(name)
110 .spawn(move || {
111 worker.run(std::future::pending());
112 })
113 .expect("must be able to spawn level worker")
114 })
115 .collect();
116 }
117
118 pub fn run_with_termination(self, termination: impl Future<Output = ()> + Send + 'static) {
120 let termination = futures::FutureExt::shared(termination);
121 let handles: Vec<_> = self
122 .workers
123 .into_iter()
124 .map(|worker| {
125 let termination = termination.clone();
126 std::thread::Builder::new()
127 .name((self.thread_name)())
128 .spawn(move || {
129 worker.run(termination);
130 })
131 .expect("must be able to spawn level worker")
132 })
133 .collect();
134 for handle in handles {
135 let _ = handle.join();
136 }
137 }
138}
139
140#[derive(Clone, Debug)]
142pub struct LevelRuntimeHandle {
143 workers: Vec<LevelWorkerHandle>,
144}
145
146impl LevelRuntimeHandle {
147 #[track_caller]
150 pub fn spawn_balanced<F>(&self, future: F) -> JoinHandle<F::Output>
151 where
152 F: Future + Send + 'static,
153 F::Output: Send + 'static,
154 {
155 balance_workers(&self.workers).spawn_local(future)
156 }
157
158 #[track_caller]
160 pub fn pick_worker(&self) -> &LevelWorkerHandle {
161 balance_workers(&self.workers)
162 }
163
164 #[track_caller]
167 pub fn spawn_on_each<F>(&self, future: impl Fn() -> F) -> Vec<JoinHandle<F::Output>>
168 where
169 F: Future + Send + 'static,
170 F::Output: Send + 'static,
171 {
172 self.workers
173 .iter()
174 .map(|worker| worker.spawn_local(future()))
175 .collect()
176 }
177}
178
179#[track_caller]
180fn balance_workers(workers: &[LevelWorkerHandle]) -> &LevelWorkerHandle {
181 let mut rng = rand::rng();
182 let a = rng.random_range(0..(workers.len()));
183 let b = rng.random_range(0..(workers.len()));
184 log::info!(
185 "{a}: {}, {b}: {}",
186 workers[a].concurrency(),
187 workers[b].concurrency()
188 );
189 if workers[a].concurrency() < workers[b].concurrency() {
190 &workers[a]
191 } else {
192 &workers[b]
193 }
194}