rhizomedb_runtime/
tokio.rs

1use std::{
2    cell::RefCell,
3    io,
4    marker::PhantomData,
5    sync::{
6        atomic::{AtomicUsize, Ordering},
7        Arc,
8    },
9    thread,
10};
11
12use futures::{
13    channel::mpsc::{unbounded, UnboundedSender},
14    stream::StreamExt,
15    Future,
16};
17use once_cell::sync::Lazy;
18use tokio::task::{spawn_local, LocalSet};
19
/// A boxed, sendable one-shot job. `LocalWorker` sends these over its channel and the
/// worker thread invokes each one as it is received.
type SpawnTask = Box<dyn Send + FnOnce()>;

/// Name given to every worker thread started by `LocalWorker::new`.
static DEFAULT_WORKER_NAME: &str = "rhizomedb-runtime-worker";

thread_local! {
    /// Task counter associated with the current thread.
    ///
    /// Starts out as `None`; a worker thread stores its shared counter here as the first
    /// thing it does, and `LocalHandle::try_current` reads it back.
    static TASK_COUNT: RefCell<Option<Arc<AtomicUsize>>> = RefCell::new(None);
    /// Per-thread `LocalSet`. A worker thread drives it with `block_on`, and
    /// `LocalHandle::spawn_local` spawns futures onto it.
    static LOCAL_SET: LocalSet = LocalSet::new()
}
28
29#[derive(Clone)]
30pub struct LocalWorker {
31    task_count: Arc<AtomicUsize>,
32    tx: UnboundedSender<SpawnTask>,
33}
34
35impl LocalWorker {
36    pub fn new() -> io::Result<Self> {
37        let (tx, mut rx) = unbounded::<SpawnTask>();
38        let task_count: Arc<AtomicUsize> = Arc::default();
39
40        let rt = tokio::runtime::Builder::new_current_thread()
41            .enable_all()
42            .build()?;
43
44        {
45            let task_count = task_count.clone();
46            thread::Builder::new()
47                .name(DEFAULT_WORKER_NAME.into())
48                .spawn(move || {
49                    TASK_COUNT.with(move |m| {
50                        *m.borrow_mut() = Some(task_count);
51                    });
52
53                    LOCAL_SET.with(|local_set| {
54                        local_set.block_on(&rt, async move {
55                            while let Some(m) = rx.next().await {
56                                m();
57                            }
58                        });
59                    });
60                })?;
61        }
62
63        Ok(Self { task_count, tx })
64    }
65
66    pub fn task_count(&self) -> usize {
67        self.task_count.load(Ordering::Acquire)
68    }
69
70    pub fn spawn_pinned<F, Fut>(&self, f: F)
71    where
72        F: FnOnce() -> Fut,
73        F: Send + 'static,
74        Fut: 'static + Future<Output = ()>,
75    {
76        let guard = LocalJobCountGuard::new(self.task_count.clone());
77
78        // We ignore the result upon a failure, this can never happen unless the runtime is
79        // exiting which all instances of Runtime will be dropped at that time and hence cannot
80        // spawn pinned tasks.
81        let _ = self.tx.unbounded_send(Box::new(move || {
82            spawn_local(async move {
83                let _guard = guard;
84
85                f().await;
86            });
87        }));
88    }
89}
90
/// Guard that keeps a shared job counter incremented by one for as long as it is alive.
pub struct LocalJobCountGuard(Arc<AtomicUsize>);

impl LocalJobCountGuard {
    /// Wraps `inner` in a guard and bumps the counter by one.
    fn new(inner: Arc<AtomicUsize>) -> Self {
        let guard = Self(inner);
        guard.0.fetch_add(1, Ordering::AcqRel);

        guard
    }
}

impl Drop for LocalJobCountGuard {
    /// Gives back the unit that was added when the guard was created.
    fn drop(&mut self) {
        let Self(counter) = self;
        counter.fetch_sub(1, Ordering::AcqRel);
    }
}
106
107#[derive(Clone)]
108pub struct Runtime {
109    workers: Arc<Vec<LocalWorker>>,
110}
111
112impl Runtime {
113    pub fn new(num_workers: usize) -> io::Result<Self> {
114        assert!(num_workers > 0, "must have more than 1 worker.");
115
116        let mut workers = Vec::with_capacity(num_workers);
117
118        for _ in 0..num_workers {
119            let worker = LocalWorker::new()?;
120            workers.push(worker);
121        }
122
123        Ok(Self {
124            workers: workers.into(),
125        })
126    }
127
128    pub fn spawn_local<F>(f: F)
129    where
130        F: Future<Output = ()> + 'static,
131    {
132        match LocalHandle::try_current() {
133            Some(m) => {
134                m.spawn_local(f);
135            }
136            None => {
137                tokio::task::spawn_local(f);
138            }
139        }
140    }
141
142    pub fn spawn_pinned<F, Fut>(&self, create_task: F)
143    where
144        F: FnOnce() -> Fut,
145        F: Send + 'static,
146        Fut: futures::Future<Output = ()> + 'static,
147    {
148        let worker = self.find_least_busy_local_worker();
149        worker.spawn_pinned(create_task);
150    }
151
152    fn find_least_busy_local_worker(&self) -> &LocalWorker {
153        let mut workers = self.workers.iter();
154
155        let mut worker = workers.next().expect("must have more than 1 worker.");
156        let mut task_count = worker.task_count();
157
158        for current_worker in workers {
159            if task_count == 0 {
160                break;
161            }
162
163            let current_worker_task_count = current_worker.task_count();
164
165            if current_worker_task_count < task_count {
166                task_count = current_worker_task_count;
167                worker = current_worker;
168            }
169        }
170
171        worker
172    }
173}
174
175impl Default for Runtime {
176    fn default() -> Self {
177        static DEFAULT_RT: Lazy<Runtime> =
178            Lazy::new(|| Runtime::new(num_cpus::get()).expect("failed to create runtime."));
179
180        DEFAULT_RT.clone()
181    }
182}
183
184#[derive(Debug, Clone)]
185pub struct LocalHandle {
186    _marker: PhantomData<*const ()>,
187    task_count: Arc<AtomicUsize>,
188}
189
190impl LocalHandle {
191    pub fn current() -> Self {
192        Self::try_current().expect("outside of runtime.")
193    }
194
195    fn try_current() -> Option<Self> {
196        // We cache the handle to prevent borrowing RefCell.
197        thread_local! {
198            static LOCAL_HANDLE: Option<LocalHandle> = TASK_COUNT
199                .with(|m| m.borrow().clone())
200                .map(|task_count| LocalHandle { task_count, _marker: PhantomData });
201        }
202
203        LOCAL_HANDLE.with(|m| m.clone())
204    }
205
206    pub fn spawn_local<F>(&self, f: F)
207    where
208        F: Future<Output = ()> + 'static,
209    {
210        let guard = LocalJobCountGuard::new(self.task_count.clone());
211
212        LOCAL_SET.with(move |local_set| {
213            local_set.spawn_local(async move {
214                let _guard = guard;
215
216                f.await
217            })
218        });
219    }
220}
221
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use futures::channel::oneshot;
    use tokio::{sync::Barrier, test, time::timeout};

    use super::*;

    /// Waits up to five seconds for a value on `rx`, panicking on a timeout or if the sender
    /// was dropped.
    async fn recv_or_panic<T>(rx: oneshot::Receiver<T>) -> T {
        timeout(Duration::from_secs(5), rx)
            .await
            .expect("task timed out.")
            .expect("failed to receive.")
    }

    /// Pins a task that waits on `barrier` and then reports the id of the thread it runs on.
    fn spawn_thread_id_reporter(
        runtime: &Runtime,
        barrier: Arc<Barrier>,
        tx: oneshot::Sender<std::thread::ThreadId>,
    ) {
        runtime.spawn_pinned(move || async move {
            barrier.wait().await;

            tx.send(std::thread::current().id())
                .expect("failed to send!");
        });
    }

    /// Two tasks that are alive at the same time must land on different workers.
    #[test]
    async fn test_spawn_pinned_least_busy() {
        let runtime = Runtime::new(2).expect("failed to create runtime.");

        let (tx1, rx1) = oneshot::channel();
        let (tx2, rx2) = oneshot::channel();

        // Neither task can finish before both have started, so the first one still counts as
        // running when the second one is placed.
        let barrier = Arc::new(Barrier::new(2));

        spawn_thread_id_reporter(&runtime, Arc::clone(&barrier), tx1);
        spawn_thread_id_reporter(&runtime, barrier, tx2);

        let first_thread = recv_or_panic(rx1).await;
        let second_thread = recv_or_panic(rx2).await;

        // first task and second task are not on the same thread.
        assert_ne!(first_thread, second_thread);
    }

    /// `Runtime::spawn_local` must work from inside a `Send` task running on a worker.
    #[test]
    async fn test_spawn_local_within_send() {
        let runtime = Runtime::default();

        let (tx, rx) = oneshot::channel();

        runtime.spawn_pinned(move || async move {
            tokio::task::spawn(async move {
                let notify = async move {
                    tx.send(()).expect("failed to send!");
                };

                Runtime::spawn_local(notify)
            });
        });

        recv_or_panic(rx).await;
    }
}