kioto_uring_executor/
runtime.rs1use std::cell::RefCell;
2use std::future::Future;
3use std::num::NonZeroUsize;
4use std::pin::Pin;
5use std::sync::mpsc as std_mpsc;
6use std::sync::Arc;
7
8use parking_lot::RwLock;
9
10use rand::Rng;
11
12use tokio::sync::mpsc;
13
/// A type-erased, heap-pinned future queued for execution on a worker thread.
pub struct Task {
    // `Box<dyn Trait>` already implies a `'static` object lifetime.
    future: Pin<Box<dyn Future<Output = ()>>>,
}

// SAFETY: `Task` erases the `Send` bound of the future it wraps so that tasks can be
// moved onto worker threads. The safe spawn APIs in this file only build tasks from
// `Send` futures. The `unsafe_*` spawn APIs accept non-`Send` futures and shift the
// obligation to their callers, so soundness depends on those callers honoring their
// documented contract.
unsafe impl Send for Task {}
19
20pub type TaskSender = mpsc::UnboundedSender<Task>;
21
22thread_local! {
23 pub(super) static ACTIVE_RUNTIME: RefCell<Option<Arc<RuntimeInner>>> = const { RefCell::new(None) };
24}
25
26pub(super) struct RuntimeInner {
27 task_senders: RwLock<Vec<TaskSender>>,
28}
29
30pub struct Runtime {
31 inner: Arc<RuntimeInner>,
32}
33
34impl Default for Runtime {
35 fn default() -> Self {
36 Self::new()
37 }
38}
39
40pub struct SpawnRing {
41 inner: Arc<RuntimeInner>,
42 thread_idx: usize,
43}
44
45impl SpawnRing {
46 pub(super) fn new(inner: Arc<RuntimeInner>) -> Self {
47 Self {
48 inner,
49 thread_idx: 0,
50 }
51 }
52
53 pub fn get(&self) -> usize {
54 self.thread_idx
55 }
56
57 pub fn advance(&mut self) {
58 let num_worker_threads = self.inner.get_num_threads();
59 self.thread_idx = (self.thread_idx + 1) % num_worker_threads;
60 }
61}
62
63impl Runtime {
64 pub fn new() -> Self {
65 let thread_count = std::thread::available_parallelism().unwrap();
66 Self::new_with_threads(thread_count)
67 }
68
69 pub fn new_with_threads(num_os_threads: NonZeroUsize) -> Self {
70 let num_os_threads = num_os_threads.get();
71 log::info!("Initialized tokio runtime with {num_os_threads} worker thread(s)");
72 let inner = Arc::new(RuntimeInner {
73 task_senders: Default::default(),
74 });
75
76 for _ in 0..num_os_threads {
77 let (sender, mut receiver) = mpsc::unbounded_channel::<Task>();
78 inner.task_senders.write().push(sender);
79 let inner = inner.clone();
80
81 std::thread::spawn(move || {
82 ACTIVE_RUNTIME.with_borrow_mut(|r| {
83 *r = Some(inner);
84 });
85 tokio_uring::start(async {
86 while let Some(task) = receiver.recv().await {
87 tokio_uring::spawn(task.future);
88 }
89 });
90 });
91 }
92
93 Self { inner }
94 }
95
96 pub fn block_on<T: Send + 'static, F: Future<Output = T> + Send + 'static>(
98 &self,
99 task: F,
100 ) -> T {
101 self.inner.block_on(task)
102 }
103
104 pub unsafe fn unsafe_block_on<T: Send + 'static, F: Future<Output = T> + 'static>(
110 &self,
111 task: F,
112 ) -> T {
113 self.inner.unsafe_block_on(task)
114 }
115
116 pub fn spawn<F: Future<Output = ()> + Send + 'static>(&self, task: F) {
118 self.inner.spawn(task)
119 }
120
121 pub fn get_num_threads(&self) -> usize {
123 self.inner.get_num_threads()
124 }
125
126 pub fn spawn_at<F: Future<Output = ()> + Send + 'static>(&self, offset: usize, task: F) {
128 self.inner.spawn_at(offset, task)
129 }
130
131 pub unsafe fn unsafe_spawn_at<F: Future<Output = ()> + 'static>(&self, offset: usize, task: F) {
136 self.inner.unsafe_spawn_at(offset, task)
137 }
138
139 pub unsafe fn unsafe_spawn<F: Future<Output = ()> + 'static>(&self, task: F) {
144 self.inner.unsafe_spawn(task)
145 }
146
147 pub fn new_spawn_ring(&self) -> SpawnRing {
150 SpawnRing::new(self.inner.clone())
151 }
152}
153
154impl Drop for Runtime {
155 fn drop(&mut self) {
156 *self.inner.task_senders.write() = vec![];
157 }
158}
159
160impl RuntimeInner {
161 pub fn spawn<F: Future<Output = ()> + Send + 'static>(&self, task: F) {
162 let task = Task {
163 future: Box::pin(task),
164 };
165
166 let senders = self.task_senders.read();
167 if senders.is_empty() {
168 panic!("Executor not set up yet!");
169 }
170
171 let idx = rand::thread_rng().gen_range(0..senders.len());
172 if let Err(err) = senders[idx].send(task) {
173 panic!("Failed to spawn task: {err}");
174 }
175 }
176
177 pub fn spawn_at<F: Future<Output = ()> + Send + 'static>(&self, offset: usize, task: F) {
179 let task = Task {
180 future: Box::pin(task),
181 };
182
183 let senders = self.task_senders.read();
184 if senders.is_empty() {
185 panic!("Executor not set up yet!");
186 }
187
188 let idx = offset % senders.len();
189 if let Err(err) = senders[idx].send(task) {
190 panic!("Failed to spawn task: {err}");
191 }
192 }
193
194 pub fn block_on<T: Send + 'static, F: Future<Output = T> + Send + 'static>(
196 &self,
197 task: F,
198 ) -> T {
199 let (sender, receiver) = std_mpsc::channel();
200
201 self.spawn(async move {
202 let res = task.await;
203 sender.send(res).expect("Notification failed");
204 });
205
206 receiver.recv().expect("Failed to wait for task")
207 }
208
209 pub unsafe fn unsafe_block_on<T: Send + 'static, F: Future<Output = T> + 'static>(
215 &self,
216 task: F,
217 ) -> T {
218 let (sender, receiver) = std_mpsc::channel();
219
220 self.unsafe_spawn(async move {
221 let res = task.await;
222 sender.send(res).expect("Notification failed");
223 });
224
225 receiver.recv().expect("Failed to wait for task")
226 }
227
228 pub unsafe fn unsafe_spawn_at<F: Future<Output = ()> + 'static>(&self, offset: usize, task: F) {
233 let task = Task {
234 future: Box::pin(task),
235 };
236
237 let senders = self.task_senders.read();
238 if senders.is_empty() {
239 panic!("Executor not set up yet!");
240 }
241
242 let idx = offset % senders.len();
243 if let Err(err) = senders[idx].send(task) {
244 panic!("Failed to spawn task: {err}");
245 }
246 }
247
248 pub unsafe fn unsafe_spawn<F: Future<Output = ()> + 'static>(&self, task: F) {
253 let task = Task {
254 future: Box::pin(task),
255 };
256
257 let senders = self.task_senders.read();
258 if senders.is_empty() {
259 panic!("Executor not set up yet!");
260 }
261
262 let idx = rand::thread_rng().gen_range(0..senders.len());
263 if let Err(err) = senders[idx].send(task) {
264 panic!("Failed to spawn task: {err}");
265 }
266 }
267
268 pub fn get_num_threads(&self) -> usize {
269 let senders = self.task_senders.read();
270 if senders.is_empty() {
271 panic!("No active kioto runtime")
272 }
273
274 senders.len()
275 }
276}