1pub(crate) use io::{multishot::MultishotUringIo, oneshot::OneshotUringIo};
2use io_uring::{cqueue, squeue, CompletionQueue, IoUring};
3use result::RingResults;
4use slab::Slab;
5use std::{
6 cell::{RefCell, RefMut},
7 rc::Rc,
8};
9
10mod io;
11mod result;
12
13pub struct ReactorUring<T: Clone> {
14 inner: Rc<RefCell<ReactorInner<T>>>,
15}
16
17impl<T: Clone> ReactorUring<T> {
18 pub fn new() -> Self {
19 Self {
20 inner: Rc::new(RefCell::new(ReactorInner::new())),
21 }
22 }
23
24 pub fn new_oneshot_io(&self) -> OneshotUringIo<T> {
25 OneshotUringIo::new(self.inner.clone())
26 }
27
28 pub fn new_multishot_io(&self) -> MultishotUringIo<T> {
29 MultishotUringIo::new(self.inner.clone())
30 }
31
32 pub fn react(&self) -> IoCompletionIter<'_, T> {
33 let mut borrow = self.inner.borrow_mut();
34
35 borrow.uring.submit_and_wait(1).unwrap();
36
37 let compl_queue = unsafe {
41 std::mem::transmute::<io_uring::CompletionQueue<'_>, io_uring::CompletionQueue<'_>>(
42 borrow.uring.completion(),
43 )
44 };
45
46 IoCompletionIter {
47 compl_queue,
48 ring: borrow,
49 }
50 }
51}
52
/// Shared reactor state.
///
/// It is held through `Rc<RefCell<_>>` by [`ReactorUring`] and by the IO handles it
/// creates.
// NOTE: field order determines drop order (the ring is dropped before the pending
// objects and result storage); keep it unless that change is intended.
pub(crate) struct ReactorInner<T> {
    /// The io_uring instance providing the submission and completion queues.
    uring: IoUring,
    /// In-flight operations, keyed by the `user_data` attached to their SQE.
    pending: Slab<PendingIo<T>>,
    /// Result storage for operations, with separate oneshot and multishot parts.
    results: RingResults,
}
58
/// Completion style of a submitted operation.
#[derive(Copy, Clone)]
enum IoKind {
    /// Produces exactly one completion.
    Oneshot,
    /// Produces completions flagged with `IORING_CQE_F_MORE` until the final one.
    Multi,
}
64
65#[derive(Clone)]
66struct PendingIo<T> {
67 assoc_obj: T,
68 result_slab_idx: usize,
69 kind: IoKind,
70}
71
72impl<T> ReactorInner<T> {
73 fn new() -> Self {
74 Self {
75 uring: IoUring::new(1024).unwrap(),
76 pending: Slab::new(),
77 results: RingResults::new(),
78 }
79 }
80
81 fn submit_io(&mut self, entry: squeue::Entry, obj: T, kind: IoKind) -> (u64, usize) {
82 let result_slab_idx = match kind {
83 IoKind::Oneshot => self.results.get_oneshot().create_slot(),
84 IoKind::Multi => self.results.get_multishot().create_slot(),
85 };
86
87 let slot = self.pending.insert(PendingIo {
88 assoc_obj: obj,
89 result_slab_idx,
90 kind,
91 });
92
93 unsafe {
94 self.uring
95 .submission()
96 .push(&entry.user_data(slot as u64))
97 .unwrap();
98 }
99
100 (slot as u64, result_slab_idx)
101 }
102}
103
104pub struct IoCompletionIter<'a, T: Clone> {
105 compl_queue: CompletionQueue<'a>,
106 ring: RefMut<'a, ReactorInner<T>>,
107}
108
109impl<T: Clone> Iterator for IoCompletionIter<'_, T> {
110 type Item = T;
111
112 fn next(&mut self) -> Option<Self::Item> {
113 let entry = self.compl_queue.next()?;
114
115 let pending_io = self
116 .ring
117 .pending
118 .get_mut(entry.user_data() as usize)
119 .unwrap()
120 .clone();
121
122 match pending_io.kind {
123 IoKind::Oneshot => {
124 self.ring
125 .results
126 .get_oneshot()
127 .set_result(entry.result(), pending_io.result_slab_idx);
128 self.ring.pending.remove(entry.user_data() as usize);
129 }
130 IoKind::Multi => {
131 let results = self.ring.results.get_multishot();
132 results.push_result(entry.result(), pending_io.result_slab_idx);
133 if !cqueue::more(entry.flags()) {
134 results.set_finished(pending_io.result_slab_idx);
135 }
136 }
137 }
138
139 Some(pending_io.assoc_obj)
140 }
141}
142
#[cfg(test)]
mod tests {
    use std::{
        io,
        os::fd::{AsFd, AsRawFd, BorrowedFd, FromRawFd, OwnedFd},
        task::Poll,
    };

    use io_uring::{opcode, types};
    use libc::{AF_LOCAL, SOCK_NONBLOCK, SOCK_STREAM};

    use super::ReactorUring;

    /// Upper bound on how long a helper thread waits for its socket to become readable.
    const POLL_TIMEOUT_MS: libc::c_int = 5_000;

    /// Writes `buf` to `fd` with a single `write(2)` call, panicking on failure.
    fn write(fd: impl AsFd, buf: &[u8]) {
        let ret = unsafe {
            libc::write(
                fd.as_fd().as_raw_fd(),
                buf.as_ptr() as *const _,
                buf.len() as _,
            )
        };

        if ret == -1 {
            panic!("write failed: {}", io::Error::last_os_error());
        }
    }

    /// Blocks until `fd` is readable.
    ///
    /// Panics on a `poll` error other than `EINTR`, or after `POLL_TIMEOUT_MS` without
    /// becoming readable.
    ///
    /// The sockets are non-blocking. The data a helper thread reads is produced by an
    /// io_uring write that is only submitted once the main thread calls `react()`, so
    /// reading immediately would race with that submission and could fail with `EAGAIN`.
    fn wait_readable(fd: BorrowedFd<'_>) {
        let mut pfd = libc::pollfd {
            fd: fd.as_raw_fd(),
            events: libc::POLLIN,
            revents: 0,
        };

        loop {
            match unsafe { libc::poll(&mut pfd, 1, POLL_TIMEOUT_MS) } {
                -1 => {
                    let err = io::Error::last_os_error();
                    if err.kind() != io::ErrorKind::Interrupted {
                        panic!("poll failed: {err}");
                    }
                }
                0 => panic!("timed out waiting for fd to become readable"),
                _ => return,
            }
        }
    }

    /// Waits for `fd` to become readable, then reads into `buf` with one `read(2)` call.
    fn read(fd: impl AsFd, buf: &mut [u8]) {
        let borrowed = fd.as_fd();
        wait_readable(borrowed);

        let ret = unsafe {
            libc::read(
                borrowed.as_raw_fd(),
                buf.as_mut_ptr() as *mut _,
                buf.len() as _,
            )
        };

        if ret == -1 {
            panic!("read failed: {}", io::Error::last_os_error());
        }
    }

    /// Runs `f` with a fresh non-blocking socket pair and reactor.
    ///
    /// Afterwards it asserts that every result slot has been released.
    fn run_test(f: impl FnOnce(OwnedFd, OwnedFd, &mut ReactorUring<u32>)) {
        let mut fds = [0, 0];
        let ret =
            unsafe { libc::socketpair(AF_LOCAL, SOCK_STREAM | SOCK_NONBLOCK, 0, fds.as_mut_ptr()) };

        if ret == -1 {
            panic!("socketpair failed: {}", io::Error::last_os_error());
        }

        // SAFETY: `socketpair` succeeded, so both descriptors are open and owned solely here.
        let a = unsafe { OwnedFd::from_raw_fd(fds[0]) };
        let b = unsafe { OwnedFd::from_raw_fd(fds[1]) };
        let mut uring = ReactorUring::new();

        f(a, b, &mut uring);

        assert!(uring.inner.borrow().results.is_empty());
    }

    #[test]
    fn single_wakeup_read() {
        run_test(|a, b, uring| {
            let mut buf = [0];

            let mut io = uring.new_oneshot_io();
            let result = io.submit_or_get_result(|| {
                (
                    opcode::Read::new(types::Fd(a.as_raw_fd()), buf.as_mut_ptr(), 1).build(),
                    10,
                )
            });

            assert!(matches!(result, Poll::Pending));

            let t1 = std::thread::spawn(move || {
                write(b, &[2]);
            });

            let mut objs = uring.react();

            assert_eq!(objs.next(), Some(10));
            assert_eq!(objs.next(), None);

            drop(objs);

            let result =
                io.submit_or_get_result(|| panic!("Should not be called, as result will be ready"));

            assert!(matches!(result, Poll::Ready(Ok(1))));

            t1.join().unwrap();
        });
    }

    #[test]
    fn io_dropped_before_react_cleanup() {
        run_test(|a, b, uring| {
            let mut buf = [0];

            let mut io = uring.new_oneshot_io();
            assert!(matches!(
                io.submit_or_get_result(|| {
                    (
                        opcode::Read::new(types::Fd(a.as_raw_fd()), buf.as_mut_ptr(), 1).build(),
                        10,
                    )
                }),
                Poll::Pending
            ));

            drop(io);

            let t1 = std::thread::spawn(move || {
                write(b, &[2]);
            });

            let mut objs = uring.react();

            assert_eq!(objs.next(), Some(10));
            assert_eq!(objs.next(), None);

            t1.join().unwrap();
        });
    }

    #[test]
    fn single_wakeup_write() {
        run_test(|a, b, uring| {
            let buf = [0];

            let mut io = uring.new_oneshot_io();
            let result = io.submit_or_get_result(|| {
                (
                    opcode::Write::new(types::Fd(a.as_raw_fd()), buf.as_ptr(), buf.len() as _)
                        .build(),
                    20,
                )
            });

            assert!(matches!(result, Poll::Pending));

            let t1 = std::thread::spawn(move || {
                let mut buf = [10];
                read(b, &mut buf);
                assert_eq!(buf, [0]);
            });

            let mut objs = uring.react();

            assert_eq!(objs.next(), Some(20));
            assert_eq!(objs.next(), None);

            drop(objs);

            let result =
                io.submit_or_get_result(|| panic!("Should not be called, as result will be ready"));

            assert!(matches!(result, Poll::Ready(Ok(1))));

            t1.join().unwrap();
        });
    }

    #[test]
    fn multi_events_same_fd_read() {
        run_test(|a, b, uring| {
            let mut buf = [0, 0];

            let mut io1 = uring.new_oneshot_io();
            assert!(matches!(
                io1.submit_or_get_result(|| {
                    (
                        opcode::Read::new(types::Fd(a.as_raw_fd()), buf.as_mut_ptr(), 1).build(),
                        10,
                    )
                }),
                Poll::Pending
            ));

            let mut io2 = uring.new_oneshot_io();
            assert!(matches!(
                io2.submit_or_get_result(|| {
                    (
                        opcode::Read::new(types::Fd(a.as_raw_fd()), buf.as_mut_ptr(), 1).build(),
                        20,
                    )
                }),
                Poll::Pending
            ));

            let t1 = std::thread::spawn(move || {
                write(b, &[0xde, 0xad]);
            });

            let objs: Vec<_> = uring.react().collect();

            assert_eq!(objs.len(), 2);
            assert!(objs.contains(&10));
            assert!(objs.contains(&20));

            assert!(matches!(
                io1.submit_or_get_result(|| panic!("Should not be called")),
                Poll::Ready(Ok(1))
            ));
            assert!(matches!(
                io2.submit_or_get_result(|| panic!("Should not be called")),
                Poll::Ready(Ok(1))
            ));
            assert_eq!(buf, [0xad, 0]);

            t1.join().unwrap();
        });
    }

    #[test]
    fn multi_events_same_fd_write() {
        run_test(|a, b, uring| {
            let buf = [0xbe, 0xef];

            let mut io1 = uring.new_oneshot_io();
            assert!(matches!(
                io1.submit_or_get_result(|| {
                    (
                        opcode::Write::new(types::Fd(a.as_raw_fd()), buf.as_ptr(), 2).build(),
                        10,
                    )
                }),
                Poll::Pending
            ));

            let mut io2 = uring.new_oneshot_io();
            assert!(matches!(
                io2.submit_or_get_result(|| {
                    (
                        opcode::Write::new(types::Fd(a.as_raw_fd()), buf.as_ptr(), 2).build(),
                        20,
                    )
                }),
                Poll::Pending
            ));

            let t1 = std::thread::spawn(move || {
                let mut buf = [0, 0];
                read(b.as_fd(), &mut buf);
                assert_eq!(buf, [0xbe, 0xef]);
                read(b, &mut buf);
                assert_eq!(buf, [0xbe, 0xef]);
            });

            let objs: Vec<_> = uring.react().collect();

            assert_eq!(objs.len(), 2);
            assert!(objs.contains(&10));
            assert!(objs.contains(&20));

            assert!(matches!(
                io1.submit_or_get_result(|| panic!("Should not be called")),
                Poll::Ready(Ok(2))
            ));
            assert!(matches!(
                io2.submit_or_get_result(|| panic!("Should not be called")),
                Poll::Ready(Ok(2))
            ));

            t1.join().unwrap();
        });
    }
}