// trale/reactor/uring.rs

1pub(crate) use io::{multishot::MultishotUringIo, oneshot::OneshotUringIo};
2use io_uring::{cqueue, squeue, CompletionQueue, IoUring};
3use result::RingResults;
4use slab::Slab;
5use std::{
6    cell::{RefCell, RefMut},
7    rc::Rc,
8};
9
10mod io;
11mod result;
12
13pub struct ReactorUring<T: Clone> {
14    inner: Rc<RefCell<ReactorInner<T>>>,
15}
16
17impl<T: Clone> ReactorUring<T> {
18    pub fn new() -> Self {
19        Self {
20            inner: Rc::new(RefCell::new(ReactorInner::new())),
21        }
22    }
23
24    pub fn new_oneshot_io(&self) -> OneshotUringIo<T> {
25        OneshotUringIo::new(self.inner.clone())
26    }
27
28    pub fn new_multishot_io(&self) -> MultishotUringIo<T> {
29        MultishotUringIo::new(self.inner.clone())
30    }
31
32    pub fn react(&self) -> IoCompletionIter<'_, T> {
33        let mut borrow = self.inner.borrow_mut();
34
35        borrow.uring.submit_and_wait(1).unwrap();
36
37        // SAFETY: This object lives along side both the `objs` and `results`
38        // RefMuts. Therefore, `borrow` will remained borrowed for the lifetime
39        // of both `objs` and `results` making the change to `'a` safe.
40        let compl_queue = unsafe {
41            std::mem::transmute::<io_uring::CompletionQueue<'_>, io_uring::CompletionQueue<'_>>(
42                borrow.uring.completion(),
43            )
44        };
45
46        IoCompletionIter {
47            compl_queue,
48            ring: borrow,
49        }
50    }
51}
52
53pub(crate) struct ReactorInner<T> {
54    uring: IoUring,
55    pending: Slab<PendingIo<T>>,
56    results: RingResults,
57}
58
/// Selects which result store a submitted operation uses and how its
/// completions are handled.
#[derive(Copy, Clone)]
enum IoKind {
    /// The result is stored once and the pending record is removed when the
    /// completion is processed.
    Oneshot,
    /// Each completion pushes another result; the operation is marked
    /// finished once a completion arrives without the "more" flag.
    Multi,
}
64
65#[derive(Clone)]
66struct PendingIo<T> {
67    assoc_obj: T,
68    result_slab_idx: usize,
69    kind: IoKind,
70}
71
72impl<T> ReactorInner<T> {
73    fn new() -> Self {
74        Self {
75            uring: IoUring::new(1024).unwrap(),
76            pending: Slab::new(),
77            results: RingResults::new(),
78        }
79    }
80
81    fn submit_io(&mut self, entry: squeue::Entry, obj: T, kind: IoKind) -> (u64, usize) {
82        let result_slab_idx = match kind {
83            IoKind::Oneshot => self.results.get_oneshot().create_slot(),
84            IoKind::Multi => self.results.get_multishot().create_slot(),
85        };
86
87        let slot = self.pending.insert(PendingIo {
88            assoc_obj: obj,
89            result_slab_idx,
90            kind,
91        });
92
93        unsafe {
94            self.uring
95                .submission()
96                .push(&entry.user_data(slot as u64))
97                .unwrap();
98        }
99
100        (slot as u64, result_slab_idx)
101    }
102}
103
104pub struct IoCompletionIter<'a, T: Clone> {
105    compl_queue: CompletionQueue<'a>,
106    ring: RefMut<'a, ReactorInner<T>>,
107}
108
109impl<T: Clone> Iterator for IoCompletionIter<'_, T> {
110    type Item = T;
111
112    fn next(&mut self) -> Option<Self::Item> {
113        let entry = self.compl_queue.next()?;
114
115        let pending_io = self
116            .ring
117            .pending
118            .get_mut(entry.user_data() as usize)
119            .unwrap()
120            .clone();
121
122        match pending_io.kind {
123            IoKind::Oneshot => {
124                self.ring
125                    .results
126                    .get_oneshot()
127                    .set_result(entry.result(), pending_io.result_slab_idx);
128                self.ring.pending.remove(entry.user_data() as usize);
129            }
130            IoKind::Multi => {
131                let results = self.ring.results.get_multishot();
132                results.push_result(entry.result(), pending_io.result_slab_idx);
133                if !cqueue::more(entry.flags()) {
134                    results.set_finished(pending_io.result_slab_idx);
135                }
136            }
137        }
138
139        Some(pending_io.assoc_obj)
140    }
141}
142
#[cfg(test)]
mod tests {
    use std::{
        os::fd::{AsFd, AsRawFd, FromRawFd, OwnedFd},
        task::Poll,
    };

    use io_uring::{opcode, types};
    use libc::{AF_LOCAL, SOCK_NONBLOCK, SOCK_STREAM};

    use super::ReactorUring;

    /// Writes `buf` to `fd` with a single `write(2)` call, panicking on
    /// failure.
    fn write(fd: impl AsFd, buf: &[u8]) {
        let ret = unsafe {
            libc::write(
                fd.as_fd().as_raw_fd(),
                buf.as_ptr() as *const _,
                buf.len() as _,
            )
        };

        if ret == -1 {
            panic!("write failed: {}", std::io::Error::last_os_error());
        }
    }

    /// Reads into `buf` from `fd` with a single `read(2)` call, panicking on
    /// failure.
    fn read(fd: impl AsFd, buf: &mut [u8]) {
        let ret = unsafe {
            libc::read(
                fd.as_fd().as_raw_fd(),
                buf.as_mut_ptr() as *mut _,
                buf.len() as _,
            )
        };

        if ret == -1 {
            panic!("read failed: {}", std::io::Error::last_os_error());
        }
    }

    /// Runs `f` with both ends of a non-blocking socket pair and a fresh
    /// reactor, then asserts that the reactor's result storage reports empty.
    fn run_test(f: impl FnOnce(OwnedFd, OwnedFd, &mut ReactorUring<u32>)) {
        let mut fds = [0, 0];
        let ret =
            unsafe { libc::socketpair(AF_LOCAL, SOCK_STREAM | SOCK_NONBLOCK, 0, fds.as_mut_ptr()) };

        if ret == -1 {
            panic!("socketpair failed: {}", std::io::Error::last_os_error());
        }

        let a = unsafe { OwnedFd::from_raw_fd(fds[0]) };
        let b = unsafe { OwnedFd::from_raw_fd(fds[1]) };
        let mut uring = ReactorUring::new();

        f(a, b, &mut uring);

        assert!(uring.inner.borrow().results.is_empty());
    }

    /// A single pending read completes once the peer writes one byte.
    #[test]
    fn single_wakeup_read() {
        run_test(|a, b, uring| {
            let mut buf = [0];

            let mut io = uring.new_oneshot_io();
            let result = io.submit_or_get_result(|| {
                (
                    opcode::Read::new(types::Fd(a.as_raw_fd()), buf.as_mut_ptr(), 1).build(),
                    10,
                )
            });

            assert!(matches!(result, Poll::Pending));

            let t1 = std::thread::spawn(move || {
                write(b, &[2]);
            });

            let mut objs = uring.react();

            assert_eq!(objs.next(), Some(10));
            assert_eq!(objs.next(), None);

            // Release the reactor borrow before polling the IO object again.
            drop(objs);

            let result =
                io.submit_or_get_result(|| panic!("Should not be called, as result will be ready"));

            assert!(matches!(result, Poll::Ready(Ok(1))));

            t1.join().unwrap();
        });
    }

    /// Dropping the IO object before its completion is reaped must still
    /// leave the reactor's result storage empty afterwards.
    #[test]
    fn io_dropped_before_react_cleanup() {
        run_test(|a, b, uring| {
            let mut buf = [0];

            let mut io = uring.new_oneshot_io();
            assert!(matches!(
                io.submit_or_get_result(|| {
                    (
                        opcode::Read::new(types::Fd(a.as_raw_fd()), buf.as_mut_ptr(), 1).build(),
                        10,
                    )
                }),
                Poll::Pending
            ));

            drop(io);

            let t1 = std::thread::spawn(move || {
                write(b, &[2]);
            });

            let mut objs = uring.react();

            assert_eq!(objs.next(), Some(10));
            assert_eq!(objs.next(), None);

            t1.join().unwrap();
        });
    }

    /// A single pending write completes and the peer observes the byte.
    #[test]
    fn single_wakeup_write() {
        run_test(|a, b, uring| {
            let buf = [0];

            let mut io = uring.new_oneshot_io();
            let result = io.submit_or_get_result(|| {
                (
                    opcode::Write::new(types::Fd(a.as_raw_fd()), buf.as_ptr(), buf.len() as _)
                        .build(),
                    20,
                )
            });

            assert!(matches!(result, Poll::Pending));

            let t1 = std::thread::spawn(move || {
                let mut buf = [10];
                read(b, &mut buf);
                assert_eq!(buf, [0]);
            });

            let mut objs = uring.react();

            assert_eq!(objs.next(), Some(20));
            assert_eq!(objs.next(), None);

            // Release the reactor borrow before polling the IO object again.
            drop(objs);

            let result =
                io.submit_or_get_result(|| panic!("Should not be called, as result will be ready"));

            assert!(matches!(result, Poll::Ready(Ok(1))));

            t1.join().unwrap();
        });
    }

    /// Two reads queued on the same fd both complete from one two-byte write.
    #[test]
    fn multi_events_same_fd_read() {
        run_test(|a, b, uring| {
            let mut buf = [0, 0];

            // Both reads target `buf[0]`; only one byte is requested by each.
            let mut io1 = uring.new_oneshot_io();
            assert!(matches!(
                io1.submit_or_get_result(|| {
                    (
                        opcode::Read::new(types::Fd(a.as_raw_fd()), buf.as_mut_ptr(), 1).build(),
                        10,
                    )
                }),
                Poll::Pending
            ));

            let mut io2 = uring.new_oneshot_io();
            assert!(matches!(
                io2.submit_or_get_result(|| {
                    (
                        opcode::Read::new(types::Fd(a.as_raw_fd()), buf.as_mut_ptr(), 1).build(),
                        20,
                    )
                }),
                Poll::Pending
            ));

            let t1 = std::thread::spawn(move || {
                write(b, &[0xde, 0xad]);
            });

            let objs: Vec<_> = uring.react().collect();

            assert_eq!(objs.len(), 2);
            assert!(objs.contains(&10));
            assert!(objs.contains(&20));

            assert!(matches!(
                io1.submit_or_get_result(|| panic!("Should not be called")),
                Poll::Ready(Ok(1))
            ));
            assert!(matches!(
                io2.submit_or_get_result(|| panic!("Should not be called")),
                Poll::Ready(Ok(1))
            ));
            // The second byte written is expected to be the last one stored
            // in `buf[0]`; `buf[1]` is never written.
            assert_eq!(buf, [0xad, 0]);

            t1.join().unwrap();
        });
    }

    /// Two writes queued on the same fd both complete and report two bytes
    /// written each.
    #[test]
    fn multi_events_same_fd_write() {
        run_test(|a, b, uring| {
            let buf = [0xbe, 0xef];

            let mut io1 = uring.new_oneshot_io();
            assert!(matches!(
                io1.submit_or_get_result(|| {
                    (
                        opcode::Write::new(types::Fd(a.as_raw_fd()), buf.as_ptr(), 2).build(),
                        10,
                    )
                }),
                Poll::Pending
            ));

            let mut io2 = uring.new_oneshot_io();
            assert!(matches!(
                io2.submit_or_get_result(|| {
                    (
                        opcode::Write::new(types::Fd(a.as_raw_fd()), buf.as_ptr(), 2).build(),
                        20,
                    )
                }),
                Poll::Pending
            ));

            let t1 = std::thread::spawn(move || {
                let mut buf = [0, 0];
                read(b.as_fd(), &mut buf);
                assert_eq!(buf, [0xbe, 0xef]);
                // Second read picks up the data from the other queued write.
                read(b, &mut buf);
            });

            let objs: Vec<_> = uring.react().collect();

            assert_eq!(objs.len(), 2);
            assert!(objs.contains(&10));
            assert!(objs.contains(&20));

            assert!(matches!(
                io1.submit_or_get_result(|| panic!("Should not be called")),
                Poll::Ready(Ok(2))
            ));
            assert!(matches!(
                io2.submit_or_get_result(|| panic!("Should not be called")),
                Poll::Ready(Ok(2))
            ));

            t1.join().unwrap();
        });
    }
}