rhizomedb_runtime/
tokio.rs1use std::{
2 cell::RefCell,
3 io,
4 marker::PhantomData,
5 sync::{
6 atomic::{AtomicUsize, Ordering},
7 Arc,
8 },
9 thread,
10};
11
12use futures::{
13 channel::mpsc::{unbounded, UnboundedSender},
14 stream::StreamExt,
15 Future,
16};
17use once_cell::sync::Lazy;
18use tokio::task::{spawn_local, LocalSet};
19
20type SpawnTask = Box<dyn Send + FnOnce()>;
21
22static DEFAULT_WORKER_NAME: &str = "rhizomedb-runtime-worker";
23
24thread_local! {
25 static TASK_COUNT: RefCell<Option<Arc<AtomicUsize>>> = RefCell::new(None);
26 static LOCAL_SET: LocalSet = LocalSet::new()
27}
28
29#[derive(Clone)]
30pub struct LocalWorker {
31 task_count: Arc<AtomicUsize>,
32 tx: UnboundedSender<SpawnTask>,
33}
34
35impl LocalWorker {
36 pub fn new() -> io::Result<Self> {
37 let (tx, mut rx) = unbounded::<SpawnTask>();
38 let task_count: Arc<AtomicUsize> = Arc::default();
39
40 let rt = tokio::runtime::Builder::new_current_thread()
41 .enable_all()
42 .build()?;
43
44 {
45 let task_count = task_count.clone();
46 thread::Builder::new()
47 .name(DEFAULT_WORKER_NAME.into())
48 .spawn(move || {
49 TASK_COUNT.with(move |m| {
50 *m.borrow_mut() = Some(task_count);
51 });
52
53 LOCAL_SET.with(|local_set| {
54 local_set.block_on(&rt, async move {
55 while let Some(m) = rx.next().await {
56 m();
57 }
58 });
59 });
60 })?;
61 }
62
63 Ok(Self { task_count, tx })
64 }
65
66 pub fn task_count(&self) -> usize {
67 self.task_count.load(Ordering::Acquire)
68 }
69
70 pub fn spawn_pinned<F, Fut>(&self, f: F)
71 where
72 F: FnOnce() -> Fut,
73 F: Send + 'static,
74 Fut: 'static + Future<Output = ()>,
75 {
76 let guard = LocalJobCountGuard::new(self.task_count.clone());
77
78 let _ = self.tx.unbounded_send(Box::new(move || {
82 spawn_local(async move {
83 let _guard = guard;
84
85 f().await;
86 });
87 }));
88 }
89}
90
/// RAII guard that keeps a shared task counter incremented for as long as it lives.
pub struct LocalJobCountGuard(Arc<AtomicUsize>);

impl LocalJobCountGuard {
    /// Increments `counter` and returns a guard that decrements it again on drop.
    fn new(counter: Arc<AtomicUsize>) -> Self {
        counter.fetch_add(1, Ordering::AcqRel);
        Self(counter)
    }
}

impl Drop for LocalJobCountGuard {
    /// Releases the slot this guard reserved in the counter.
    fn drop(&mut self) {
        let Self(counter) = self;
        counter.fetch_sub(1, Ordering::AcqRel);
    }
}
106
107#[derive(Clone)]
108pub struct Runtime {
109 workers: Arc<Vec<LocalWorker>>,
110}
111
112impl Runtime {
113 pub fn new(num_workers: usize) -> io::Result<Self> {
114 assert!(num_workers > 0, "must have more than 1 worker.");
115
116 let mut workers = Vec::with_capacity(num_workers);
117
118 for _ in 0..num_workers {
119 let worker = LocalWorker::new()?;
120 workers.push(worker);
121 }
122
123 Ok(Self {
124 workers: workers.into(),
125 })
126 }
127
128 pub fn spawn_local<F>(f: F)
129 where
130 F: Future<Output = ()> + 'static,
131 {
132 match LocalHandle::try_current() {
133 Some(m) => {
134 m.spawn_local(f);
135 }
136 None => {
137 tokio::task::spawn_local(f);
138 }
139 }
140 }
141
142 pub fn spawn_pinned<F, Fut>(&self, create_task: F)
143 where
144 F: FnOnce() -> Fut,
145 F: Send + 'static,
146 Fut: futures::Future<Output = ()> + 'static,
147 {
148 let worker = self.find_least_busy_local_worker();
149 worker.spawn_pinned(create_task);
150 }
151
152 fn find_least_busy_local_worker(&self) -> &LocalWorker {
153 let mut workers = self.workers.iter();
154
155 let mut worker = workers.next().expect("must have more than 1 worker.");
156 let mut task_count = worker.task_count();
157
158 for current_worker in workers {
159 if task_count == 0 {
160 break;
161 }
162
163 let current_worker_task_count = current_worker.task_count();
164
165 if current_worker_task_count < task_count {
166 task_count = current_worker_task_count;
167 worker = current_worker;
168 }
169 }
170
171 worker
172 }
173}
174
175impl Default for Runtime {
176 fn default() -> Self {
177 static DEFAULT_RT: Lazy<Runtime> =
178 Lazy::new(|| Runtime::new(num_cpus::get()).expect("failed to create runtime."));
179
180 DEFAULT_RT.clone()
181 }
182}
183
184#[derive(Debug, Clone)]
185pub struct LocalHandle {
186 _marker: PhantomData<*const ()>,
187 task_count: Arc<AtomicUsize>,
188}
189
190impl LocalHandle {
191 pub fn current() -> Self {
192 Self::try_current().expect("outside of runtime.")
193 }
194
195 fn try_current() -> Option<Self> {
196 thread_local! {
198 static LOCAL_HANDLE: Option<LocalHandle> = TASK_COUNT
199 .with(|m| m.borrow().clone())
200 .map(|task_count| LocalHandle { task_count, _marker: PhantomData });
201 }
202
203 LOCAL_HANDLE.with(|m| m.clone())
204 }
205
206 pub fn spawn_local<F>(&self, f: F)
207 where
208 F: Future<Output = ()> + 'static,
209 {
210 let guard = LocalJobCountGuard::new(self.task_count.clone());
211
212 LOCAL_SET.with(move |local_set| {
213 local_set.spawn_local(async move {
214 let _guard = guard;
215
216 f.await
217 })
218 });
219 }
220}
221
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use futures::channel::oneshot;
    use tokio::{sync::Barrier, test, time::timeout};

    use super::*;

    /// Upper bound on how long a test waits for a spawned task to report back.
    const WAIT: Duration = Duration::from_secs(5);

    /// Awaits `rx` for at most `WAIT`, panicking on timeout or a dropped sender.
    async fn recv_within<T>(rx: oneshot::Receiver<T>) -> T {
        timeout(WAIT, rx)
            .await
            .expect("task timed out.")
            .expect("failed to receive.")
    }

    #[test]
    async fn test_spawn_pinned_least_busy() {
        let runtime = Runtime::new(2).expect("failed to create runtime.");
        let barrier = Arc::new(Barrier::new(2));

        let (tx1, rx1) = oneshot::channel();
        let (tx2, rx2) = oneshot::channel();

        // The barrier needs both tasks alive at once, and the second spawn must see the
        // first one counted, so the two tasks should land on different workers.
        for tx in [tx1, tx2] {
            let barrier = Arc::clone(&barrier);
            runtime.spawn_pinned(move || async move {
                barrier.wait().await;

                tx.send(std::thread::current().id())
                    .expect("failed to send!");
            });
        }

        let first = recv_within(rx1).await;
        let second = recv_within(rx2).await;

        assert_ne!(first, second);
    }

    #[test]
    async fn test_spawn_local_within_send() {
        let runtime = Runtime::default();
        let (tx, rx) = oneshot::channel();

        runtime.spawn_pinned(move || async move {
            // Hop through a `Send` task first, then back onto the worker's local set.
            tokio::task::spawn(async move {
                Runtime::spawn_local(async move {
                    tx.send(()).expect("failed to send!");
                });
            });
        });

        recv_within(rx).await;
    }
}