compio_driver/iour/
mod.rs

1#[cfg_attr(all(doc, docsrs), doc(cfg(all())))]
2#[allow(unused_imports)]
3pub use std::os::fd::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd};
4use std::{io, os::fd::FromRawFd, pin::Pin, sync::Arc, task::Poll, time::Duration};
5
6use compio_log::{instrument, trace, warn};
7use crossbeam_queue::SegQueue;
8cfg_if::cfg_if! {
9    if #[cfg(feature = "io-uring-cqe32")] {
10        use io_uring::cqueue::Entry32 as CEntry;
11    } else {
12        use io_uring::cqueue::Entry as CEntry;
13    }
14}
15cfg_if::cfg_if! {
16    if #[cfg(feature = "io-uring-sqe128")] {
17        use io_uring::squeue::Entry128 as SEntry;
18    } else {
19        use io_uring::squeue::Entry as SEntry;
20    }
21}
22use io_uring::{
23    IoUring,
24    cqueue::more,
25    opcode::{AsyncCancel, PollAdd},
26    types::{Fd, SubmitArgs, Timespec},
27};
28pub(crate) use libc::{sockaddr_storage, socklen_t};
29#[cfg(io_uring)]
30use slab::Slab;
31
32use crate::{AsyncifyPool, BufferPool, Entry, Key, ProactorBuilder, syscall};
33
34pub(crate) mod op;
35
/// The created entry of [`OpCode`].
pub enum OpEntry {
    /// This operation creates an io-uring submission entry.
    Submission(io_uring::squeue::Entry),
    #[cfg(feature = "io-uring-sqe128")]
    /// This operation creates a 128-bit io-uring submission entry.
    Submission128(io_uring::squeue::Entry128),
    /// This operation is a blocking one.
    Blocking,
}
46
47impl From<io_uring::squeue::Entry> for OpEntry {
48    fn from(value: io_uring::squeue::Entry) -> Self {
49        Self::Submission(value)
50    }
51}
52
#[cfg(feature = "io-uring-sqe128")]
impl From<io_uring::squeue::Entry128> for OpEntry {
    /// Wrap a 128-byte submission entry.
    fn from(value: io_uring::squeue::Entry128) -> Self {
        OpEntry::Submission128(value)
    }
}
59
/// Abstraction of io-uring operations.
pub trait OpCode {
    /// Create submission entry.
    fn create_entry(self: Pin<&mut Self>) -> OpEntry;

    /// Call the operation in a blocking way. This method will only be called
    /// if [`OpCode::create_entry`] returns [`OpEntry::Blocking`]; the default
    /// implementation panics for purely asynchronous operations.
    fn call_blocking(self: Pin<&mut Self>) -> io::Result<usize> {
        unreachable!("this operation is asynchronous")
    }

    /// Set the result when it successfully completes.
    /// The operation stores the result and is responsible to release it if the
    /// operation is cancelled.
    ///
    /// # Safety
    ///
    /// Users should not call it.
    unsafe fn set_result(self: Pin<&mut Self>, _: usize) {}
}
80
/// Low-level driver of io-uring.
pub(crate) struct Driver {
    /// The io-uring instance; SQE/CQE entry sizes are chosen by crate features.
    inner: IoUring<SEntry, CEntry>,
    /// eventfd used to wake the ring from other threads.
    notifier: Notifier,
    /// Thread pool running operations that report [`OpEntry::Blocking`].
    pool: AsyncifyPool,
    /// Completions produced by the blocking thread pool, drained on poll.
    pool_completed: Arc<SegQueue<Entry>>,
    #[cfg(io_uring)]
    /// Allocator of buffer-group ids handed to registered buffer rings.
    buffer_group_ids: Slab<()>,
}
90
91impl Driver {
92    const CANCEL: u64 = u64::MAX;
93    const NOTIFY: u64 = u64::MAX - 1;
94
95    pub fn new(builder: &ProactorBuilder) -> io::Result<Self> {
96        instrument!(compio_log::Level::TRACE, "new", ?builder);
97        trace!("new iour driver");
98        let notifier = Notifier::new()?;
99        let mut io_uring_builder = IoUring::builder();
100        if let Some(sqpoll_idle) = builder.sqpoll_idle {
101            io_uring_builder.setup_sqpoll(sqpoll_idle.as_millis() as _);
102        }
103        if builder.coop_taskrun {
104            io_uring_builder.setup_coop_taskrun();
105        }
106        if builder.taskrun_flag {
107            io_uring_builder.setup_taskrun_flag();
108        }
109
110        let mut inner = io_uring_builder.build(builder.capacity)?;
111
112        if let Some(fd) = builder.eventfd {
113            inner.submitter().register_eventfd(fd)?;
114        }
115
116        #[allow(clippy::useless_conversion)]
117        unsafe {
118            inner
119                .submission()
120                .push(
121                    &PollAdd::new(Fd(notifier.as_raw_fd()), libc::POLLIN as _)
122                        .multi(true)
123                        .build()
124                        .user_data(Self::NOTIFY)
125                        .into(),
126                )
127                .expect("the squeue sould not be full");
128        }
129        Ok(Self {
130            inner,
131            notifier,
132            pool: builder.create_or_get_thread_pool(),
133            pool_completed: Arc::new(SegQueue::new()),
134            #[cfg(io_uring)]
135            buffer_group_ids: Slab::new(),
136        })
137    }
138
139    // Auto means that it choose to wait or not automatically.
140    fn submit_auto(&mut self, timeout: Option<Duration>) -> io::Result<()> {
141        instrument!(compio_log::Level::TRACE, "submit_auto", ?timeout);
142
143        // when taskrun is true, there are completed cqes wait to handle, no need to
144        // block the submit
145        let want_sqe = if self.inner.submission().taskrun() {
146            0
147        } else {
148            1
149        };
150
151        let res = {
152            // Last part of submission queue, wait till timeout.
153            if let Some(duration) = timeout {
154                let timespec = timespec(duration);
155                let args = SubmitArgs::new().timespec(&timespec);
156                self.inner.submitter().submit_with_args(want_sqe, &args)
157            } else {
158                self.inner.submit_and_wait(want_sqe)
159            }
160        };
161        trace!("submit result: {res:?}");
162        match res {
163            Ok(_) => {
164                if self.inner.completion().is_empty() {
165                    Err(io::ErrorKind::TimedOut.into())
166                } else {
167                    Ok(())
168                }
169            }
170            Err(e) => match e.raw_os_error() {
171                Some(libc::ETIME) => Err(io::ErrorKind::TimedOut.into()),
172                Some(libc::EBUSY) | Some(libc::EAGAIN) => Err(io::ErrorKind::Interrupted.into()),
173                _ => Err(e),
174            },
175        }
176    }
177
178    fn poll_blocking(&mut self) {
179        // Cheaper than pop.
180        if !self.pool_completed.is_empty() {
181            while let Some(entry) = self.pool_completed.pop() {
182                unsafe {
183                    entry.notify();
184                }
185            }
186        }
187    }
188
189    fn poll_entries(&mut self) -> bool {
190        self.poll_blocking();
191
192        let mut cqueue = self.inner.completion();
193        cqueue.sync();
194        let has_entry = !cqueue.is_empty();
195        for entry in cqueue {
196            match entry.user_data() {
197                Self::CANCEL => {}
198                Self::NOTIFY => {
199                    let flags = entry.flags();
200                    debug_assert!(more(flags));
201                    self.notifier.clear().expect("cannot clear notifier");
202                }
203                _ => unsafe {
204                    create_entry(entry).notify();
205                },
206            }
207        }
208        has_entry
209    }
210
211    pub fn create_op<T: crate::sys::OpCode + 'static>(&self, op: T) -> Key<T> {
212        Key::new(self.as_raw_fd(), op)
213    }
214
215    pub fn attach(&mut self, _fd: RawFd) -> io::Result<()> {
216        Ok(())
217    }
218
219    pub fn cancel(&mut self, op: &mut Key<dyn crate::sys::OpCode>) {
220        instrument!(compio_log::Level::TRACE, "cancel", ?op);
221        trace!("cancel RawOp");
222        unsafe {
223            #[allow(clippy::useless_conversion)]
224            if self
225                .inner
226                .submission()
227                .push(
228                    &AsyncCancel::new(op.user_data() as _)
229                        .build()
230                        .user_data(Self::CANCEL)
231                        .into(),
232                )
233                .is_err()
234            {
235                warn!("could not push AsyncCancel entry");
236            }
237        }
238    }
239
240    fn push_raw(&mut self, entry: SEntry) -> io::Result<()> {
241        loop {
242            let mut squeue = self.inner.submission();
243            match unsafe { squeue.push(&entry) } {
244                Ok(()) => {
245                    squeue.sync();
246                    break Ok(());
247                }
248                Err(_) => {
249                    drop(squeue);
250                    self.poll_entries();
251                    match self.submit_auto(Some(Duration::ZERO)) {
252                        Ok(()) => {}
253                        Err(e)
254                            if matches!(
255                                e.kind(),
256                                io::ErrorKind::TimedOut | io::ErrorKind::Interrupted
257                            ) => {}
258                        Err(e) => return Err(e),
259                    }
260                }
261            }
262        }
263    }
264
265    pub fn push(&mut self, op: &mut Key<dyn crate::sys::OpCode>) -> Poll<io::Result<usize>> {
266        instrument!(compio_log::Level::TRACE, "push", ?op);
267        let user_data = op.user_data();
268        let op_pin = op.as_op_pin();
269        trace!("push RawOp");
270        match op_pin.create_entry() {
271            OpEntry::Submission(entry) => {
272                #[allow(clippy::useless_conversion)]
273                self.push_raw(entry.user_data(user_data as _).into())?;
274                Poll::Pending
275            }
276            #[cfg(feature = "io-uring-sqe128")]
277            OpEntry::Submission128(entry) => {
278                self.push_raw(entry.user_data(user_data as _))?;
279                Poll::Pending
280            }
281            OpEntry::Blocking => loop {
282                if self.push_blocking(user_data) {
283                    break Poll::Pending;
284                } else {
285                    self.poll_blocking();
286                }
287            },
288        }
289    }
290
291    fn push_blocking(&mut self, user_data: usize) -> bool {
292        let handle = self.handle();
293        let completed = self.pool_completed.clone();
294        self.pool
295            .dispatch(move || {
296                let mut op = unsafe { Key::<dyn crate::sys::OpCode>::new_unchecked(user_data) };
297                let op_pin = op.as_op_pin();
298                let res = op_pin.call_blocking();
299                completed.push(Entry::new(user_data, res));
300                handle.notify().ok();
301            })
302            .is_ok()
303    }
304
305    pub unsafe fn poll(&mut self, timeout: Option<Duration>) -> io::Result<()> {
306        instrument!(compio_log::Level::TRACE, "poll", ?timeout);
307        // Anyway we need to submit once, no matter there are entries in squeue.
308        trace!("start polling");
309
310        if !self.poll_entries() {
311            self.submit_auto(timeout)?;
312            self.poll_entries();
313        }
314
315        Ok(())
316    }
317
318    pub fn handle(&self) -> NotifyHandle {
319        self.notifier.handle()
320    }
321
322    #[cfg(io_uring)]
323    pub fn create_buffer_pool(
324        &mut self,
325        buffer_len: u16,
326        buffer_size: usize,
327    ) -> io::Result<BufferPool> {
328        let buffer_group = self.buffer_group_ids.insert(());
329        if buffer_group > u16::MAX as usize {
330            self.buffer_group_ids.remove(buffer_group);
331
332            return Err(io::Error::new(
333                io::ErrorKind::OutOfMemory,
334                "too many buffer pool allocated",
335            ));
336        }
337
338        let buf_ring = io_uring_buf_ring::IoUringBufRing::new(
339            &self.inner,
340            buffer_len,
341            buffer_group as _,
342            buffer_size,
343        )?;
344
345        #[cfg(fusion)]
346        {
347            Ok(BufferPool::new_io_uring(crate::IoUringBufferPool::new(
348                buf_ring,
349            )))
350        }
351        #[cfg(not(fusion))]
352        {
353            Ok(BufferPool::new(buf_ring))
354        }
355    }
356
357    #[cfg(not(io_uring))]
358    pub fn create_buffer_pool(
359        &mut self,
360        buffer_len: u16,
361        buffer_size: usize,
362    ) -> io::Result<BufferPool> {
363        Ok(BufferPool::new(buffer_len, buffer_size))
364    }
365
366    /// # Safety
367    ///
368    /// caller must make sure release the buffer pool with correct driver
369    #[cfg(io_uring)]
370    pub unsafe fn release_buffer_pool(&mut self, buffer_pool: BufferPool) -> io::Result<()> {
371        #[cfg(fusion)]
372        let buffer_pool = buffer_pool.into_io_uring();
373
374        let buffer_group = buffer_pool.buffer_group();
375        buffer_pool.into_inner().release(&self.inner)?;
376        self.buffer_group_ids.remove(buffer_group as _);
377
378        Ok(())
379    }
380
381    /// # Safety
382    ///
383    /// caller must make sure release the buffer pool with correct driver
384    #[cfg(not(io_uring))]
385    pub unsafe fn release_buffer_pool(&mut self, _: BufferPool) -> io::Result<()> {
386        Ok(())
387    }
388}
389
390impl AsRawFd for Driver {
391    fn as_raw_fd(&self) -> RawFd {
392        self.inner.as_raw_fd()
393    }
394}
395
396fn create_entry(cq_entry: CEntry) -> Entry {
397    let result = cq_entry.result();
398    let result = if result < 0 {
399        let result = if result == -libc::ECANCELED {
400            libc::ETIMEDOUT
401        } else {
402            -result
403        };
404        Err(io::Error::from_raw_os_error(result))
405    } else {
406        Ok(result as _)
407    };
408    let mut entry = Entry::new(cq_entry.user_data() as _, result);
409    entry.set_flags(cq_entry.flags());
410
411    entry
412}
413
414fn timespec(duration: std::time::Duration) -> Timespec {
415    Timespec::new()
416        .sec(duration.as_secs())
417        .nsec(duration.subsec_nanos())
418}
419
/// Wakes the driver's wait loop via an eventfd registered with the ring.
#[derive(Debug)]
struct Notifier {
    // Shared so NotifyHandle clones can outlive the driver borrow.
    fd: Arc<OwnedFd>,
}
424
425impl Notifier {
426    /// Create a new notifier.
427    fn new() -> io::Result<Self> {
428        let fd = syscall!(libc::eventfd(0, libc::EFD_CLOEXEC | libc::EFD_NONBLOCK))?;
429        let fd = unsafe { OwnedFd::from_raw_fd(fd) };
430        Ok(Self { fd: Arc::new(fd) })
431    }
432
433    pub fn clear(&self) -> io::Result<()> {
434        loop {
435            let mut buffer = [0u64];
436            let res = syscall!(libc::read(
437                self.fd.as_raw_fd(),
438                buffer.as_mut_ptr().cast(),
439                std::mem::size_of::<u64>()
440            ));
441            match res {
442                Ok(len) => {
443                    debug_assert_eq!(len, std::mem::size_of::<u64>() as _);
444                    break Ok(());
445                }
446                // Clear the next time:)
447                Err(e) if e.kind() == io::ErrorKind::WouldBlock => break Ok(()),
448                // Just like read_exact
449                Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
450                Err(e) => break Err(e),
451            }
452        }
453    }
454
455    pub fn handle(&self) -> NotifyHandle {
456        NotifyHandle::new(self.fd.clone())
457    }
458}
459
460impl AsRawFd for Notifier {
461    fn as_raw_fd(&self) -> RawFd {
462        self.fd.as_raw_fd()
463    }
464}
465
/// A notify handle to the inner driver.
pub struct NotifyHandle {
    // Shares the driver notifier's eventfd.
    fd: Arc<OwnedFd>,
}
470
471impl NotifyHandle {
472    pub(crate) fn new(fd: Arc<OwnedFd>) -> Self {
473        Self { fd }
474    }
475
476    /// Notify the inner driver.
477    pub fn notify(&self) -> io::Result<()> {
478        let data = 1u64;
479        syscall!(libc::write(
480            self.fd.as_raw_fd(),
481            &data as *const _ as *const _,
482            std::mem::size_of::<u64>(),
483        ))?;
484        Ok(())
485    }
486}