kioto_uring_executor/
runtime.rs

1use std::cell::RefCell;
2use std::future::Future;
3use std::num::NonZeroUsize;
4use std::pin::Pin;
5use std::sync::mpsc as std_mpsc;
6use std::sync::Arc;
7
8use parking_lot::RwLock;
9
10use rand::Rng;
11
12use tokio::sync::mpsc;
13
/// A type-erased unit of work that is queued onto a single worker thread.
///
/// The boxed future is deliberately *not* required to be `Send`; the `unsafe impl`
/// right after this struct states the contract that makes moving it acceptable.
pub struct Task {
    future: Pin<Box<dyn Future<Output = ()>>>,
}

// SAFETY: a `Task` is built on the spawning thread and moved once, through a worker's
// channel, before its future is ever polled; the receiving worker then hands the
// future to `tokio_uring::spawn` on its own thread. The safe spawn entry points only
// accept `Send` futures, and the `unsafe_*` entry points make the caller promise the
// future is `Send` up until its first poll.
unsafe impl Send for Task {}
19
20pub type TaskSender = mpsc::UnboundedSender<Task>;
21
22thread_local! {
23    pub(super) static ACTIVE_RUNTIME: RefCell<Option<Arc<RuntimeInner>>> = const { RefCell::new(None) };
24}
25
26pub(super) struct RuntimeInner {
27    task_senders: RwLock<Vec<TaskSender>>,
28}
29
30pub struct Runtime {
31    inner: Arc<RuntimeInner>,
32}
33
34impl Default for Runtime {
35    fn default() -> Self {
36        Self::new()
37    }
38}
39
40pub struct SpawnRing {
41    inner: Arc<RuntimeInner>,
42    thread_idx: usize,
43}
44
45impl SpawnRing {
46    pub(super) fn new(inner: Arc<RuntimeInner>) -> Self {
47        Self {
48            inner,
49            thread_idx: 0,
50        }
51    }
52
53    pub fn get(&self) -> usize {
54        self.thread_idx
55    }
56
57    pub fn advance(&mut self) {
58        let num_worker_threads = self.inner.get_num_threads();
59        self.thread_idx = (self.thread_idx + 1) % num_worker_threads;
60    }
61}
62
63impl Runtime {
64    pub fn new() -> Self {
65        let thread_count = std::thread::available_parallelism().unwrap();
66        Self::new_with_threads(thread_count)
67    }
68
69    pub fn new_with_threads(num_os_threads: NonZeroUsize) -> Self {
70        let num_os_threads = num_os_threads.get();
71        log::info!("Initialized tokio runtime with {num_os_threads} worker thread(s)");
72        let inner = Arc::new(RuntimeInner {
73            task_senders: Default::default(),
74        });
75
76        for _ in 0..num_os_threads {
77            let (sender, mut receiver) = mpsc::unbounded_channel::<Task>();
78            inner.task_senders.write().push(sender);
79            let inner = inner.clone();
80
81            std::thread::spawn(move || {
82                ACTIVE_RUNTIME.with_borrow_mut(|r| {
83                    *r = Some(inner);
84                });
85                tokio_uring::start(async {
86                    while let Some(task) = receiver.recv().await {
87                        tokio_uring::spawn(task.future);
88                    }
89                });
90            });
91        }
92
93        Self { inner }
94    }
95
96    /// Blocks the current thread until the runtime has finished th task
97    pub fn block_on<T: Send + 'static, F: Future<Output = T> + Send + 'static>(
98        &self,
99        task: F,
100    ) -> T {
101        self.inner.block_on(task)
102    }
103
104    /// Blocks the current thread until the runtime has finished th task (unsafe version)
105    ///
106    /// # Safety
107    /// Make sure task is Send before polled for the first time
108    /// (Can be not Send afterwards)
109    pub unsafe fn unsafe_block_on<T: Send + 'static, F: Future<Output = T> + 'static>(
110        &self,
111        task: F,
112    ) -> T {
113        self.inner.unsafe_block_on(task)
114    }
115
116    /// Spawns the task on a random thread
117    pub fn spawn<F: Future<Output = ()> + Send + 'static>(&self, task: F) {
118        self.inner.spawn(task)
119    }
120
121    /// How many worker threads are there?
122    pub fn get_num_threads(&self) -> usize {
123        self.inner.get_num_threads()
124    }
125
126    /// Spawns the task on a specific thread
127    pub fn spawn_at<F: Future<Output = ()> + Send + 'static>(&self, offset: usize, task: F) {
128        self.inner.spawn_at(offset, task)
129    }
130
131    /// # Safety
132    ///
133    /// Make sure task is Send before polled for the first time
134    /// (Can be not Send afterwards)
135    pub unsafe fn unsafe_spawn_at<F: Future<Output = ()> + 'static>(&self, offset: usize, task: F) {
136        self.inner.unsafe_spawn_at(offset, task)
137    }
138
139    /// # Safety
140    ///
141    /// Make sure task is Send before polled for the first time
142    /// (Can be not Send afterwards)
143    pub unsafe fn unsafe_spawn<F: Future<Output = ()> + 'static>(&self, task: F) {
144        self.inner.unsafe_spawn(task)
145    }
146
147    /// Create a primitive that lets you distribute tasks
148    /// across worker threads in a round-robin fashion
149    pub fn new_spawn_ring(&self) -> SpawnRing {
150        SpawnRing::new(self.inner.clone())
151    }
152}
153
154impl Drop for Runtime {
155    fn drop(&mut self) {
156        *self.inner.task_senders.write() = vec![];
157    }
158}
159
160impl RuntimeInner {
161    pub fn spawn<F: Future<Output = ()> + Send + 'static>(&self, task: F) {
162        let task = Task {
163            future: Box::pin(task),
164        };
165
166        let senders = self.task_senders.read();
167        if senders.is_empty() {
168            panic!("Executor not set up yet!");
169        }
170
171        let idx = rand::thread_rng().gen_range(0..senders.len());
172        if let Err(err) = senders[idx].send(task) {
173            panic!("Failed to spawn task: {err}");
174        }
175    }
176
177    /// Spawns the task on a specific thread
178    pub fn spawn_at<F: Future<Output = ()> + Send + 'static>(&self, offset: usize, task: F) {
179        let task = Task {
180            future: Box::pin(task),
181        };
182
183        let senders = self.task_senders.read();
184        if senders.is_empty() {
185            panic!("Executor not set up yet!");
186        }
187
188        let idx = offset % senders.len();
189        if let Err(err) = senders[idx].send(task) {
190            panic!("Failed to spawn task: {err}");
191        }
192    }
193
194    /// Blocks the current thread until the runtime has finished th task
195    pub fn block_on<T: Send + 'static, F: Future<Output = T> + Send + 'static>(
196        &self,
197        task: F,
198    ) -> T {
199        let (sender, receiver) = std_mpsc::channel();
200
201        self.spawn(async move {
202            let res = task.await;
203            sender.send(res).expect("Notification failed");
204        });
205
206        receiver.recv().expect("Failed to wait for task")
207    }
208
209    /// Blocks the current thread until the runtime has finished th task (unsafe version)
210    ///
211    /// # Safety
212    /// Make sure task is Send before polled for the first time
213    /// (Can be not Send afterwards)
214    pub unsafe fn unsafe_block_on<T: Send + 'static, F: Future<Output = T> + 'static>(
215        &self,
216        task: F,
217    ) -> T {
218        let (sender, receiver) = std_mpsc::channel();
219
220        self.unsafe_spawn(async move {
221            let res = task.await;
222            sender.send(res).expect("Notification failed");
223        });
224
225        receiver.recv().expect("Failed to wait for task")
226    }
227
228    /// # Safety
229    ///
230    /// Make sure task is Send before polled for the first time
231    /// (Can be not Send afterwards)
232    pub unsafe fn unsafe_spawn_at<F: Future<Output = ()> + 'static>(&self, offset: usize, task: F) {
233        let task = Task {
234            future: Box::pin(task),
235        };
236
237        let senders = self.task_senders.read();
238        if senders.is_empty() {
239            panic!("Executor not set up yet!");
240        }
241
242        let idx = offset % senders.len();
243        if let Err(err) = senders[idx].send(task) {
244            panic!("Failed to spawn task: {err}");
245        }
246    }
247
248    /// # Safety
249    ///
250    /// Make sure task is Send before polled for the first time
251    /// (Can be not Send afterwards)
252    pub unsafe fn unsafe_spawn<F: Future<Output = ()> + 'static>(&self, task: F) {
253        let task = Task {
254            future: Box::pin(task),
255        };
256
257        let senders = self.task_senders.read();
258        if senders.is_empty() {
259            panic!("Executor not set up yet!");
260        }
261
262        let idx = rand::thread_rng().gen_range(0..senders.len());
263        if let Err(err) = senders[idx].send(task) {
264            panic!("Failed to spawn task: {err}");
265        }
266    }
267
268    pub fn get_num_threads(&self) -> usize {
269        let senders = self.task_senders.read();
270        if senders.is_empty() {
271            panic!("No active kioto runtime")
272        }
273
274        senders.len()
275    }
276}