compio_driver/iour/
mod.rs

1#[cfg_attr(all(doc, docsrs), doc(cfg(all())))]
2#[allow(unused_imports)]
3pub use std::os::fd::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd};
4use std::{
5    io,
6    os::fd::FromRawFd,
7    pin::Pin,
8    sync::Arc,
9    task::{Poll, Wake, Waker},
10    time::Duration,
11};
12
13use compio_log::{instrument, trace, warn};
14use crossbeam_queue::SegQueue;
15cfg_if::cfg_if! {
16    if #[cfg(feature = "io-uring-cqe32")] {
17        use io_uring::cqueue::Entry32 as CEntry;
18    } else {
19        use io_uring::cqueue::Entry as CEntry;
20    }
21}
22cfg_if::cfg_if! {
23    if #[cfg(feature = "io-uring-sqe128")] {
24        use io_uring::squeue::Entry128 as SEntry;
25    } else {
26        use io_uring::squeue::Entry as SEntry;
27    }
28}
29use io_uring::{
30    IoUring,
31    cqueue::more,
32    opcode::{AsyncCancel, PollAdd},
33    types::{Fd, SubmitArgs, Timespec},
34};
35use slab::Slab;
36
37use crate::{AsyncifyPool, BufferPool, DriverType, Entry, Key, ProactorBuilder, syscall};
38
39pub(crate) mod op;
40
/// Whether the running kernel supports the io-uring opcode `code`.
///
/// The kernel is probed once per process; the probe result is cached in a
/// process-wide static for all subsequent queries.
pub(crate) fn is_op_supported(code: u8) -> bool {
    // `OnceLock::get_or_try_init` is gated behind the `once_cell_try` feature
    // (nightly std); on stable we fall back to the `once_cell` crate, whose
    // `OnceCell` exposes the same API.
    #[cfg(feature = "once_cell_try")]
    use std::sync::OnceLock;

    #[cfg(not(feature = "once_cell_try"))]
    use once_cell::sync::OnceCell as OnceLock;

    static PROBE: OnceLock<io_uring::Probe> = OnceLock::new();

    PROBE
        .get_or_try_init(|| {
            let mut probe = io_uring::Probe::new();

            // A tiny throwaway ring (2 entries) is enough to ask the kernel
            // which opcodes it supports.
            io_uring::IoUring::new(2)?
                .submitter()
                .register_probe(&mut probe)?;

            std::io::Result::Ok(probe)
        })
        .map(|probe| probe.is_supported(code))
        // If probing failed (e.g. io_uring is unavailable), report the opcode
        // as unsupported rather than propagating the error.
        .unwrap_or_default()
}
63
/// The created entry of [`OpCode`].
pub enum OpEntry {
    /// This operation creates an io-uring submission entry.
    Submission(io_uring::squeue::Entry),
    #[cfg(feature = "io-uring-sqe128")]
    /// This operation creates a large (128-byte) io-uring submission entry.
    Submission128(io_uring::squeue::Entry128),
    /// This operation is a blocking one and is dispatched to the thread pool
    /// instead of the ring.
    Blocking,
}
74
75impl From<io_uring::squeue::Entry> for OpEntry {
76    fn from(value: io_uring::squeue::Entry) -> Self {
77        Self::Submission(value)
78    }
79}
80
#[cfg(feature = "io-uring-sqe128")]
impl From<io_uring::squeue::Entry128> for OpEntry {
    /// A large submission entry maps directly to [`OpEntry::Submission128`].
    fn from(entry: io_uring::squeue::Entry128) -> Self {
        OpEntry::Submission128(entry)
    }
}
87
/// Abstraction of io-uring operations.
pub trait OpCode {
    /// Create submission entry.
    ///
    /// Returning [`OpEntry::Blocking`] routes the operation to the blocking
    /// thread pool instead of the ring.
    fn create_entry(self: Pin<&mut Self>) -> OpEntry;

    /// Call the operation in a blocking way. This method will only be called if
    /// [`OpCode::create_entry`] returns [`OpEntry::Blocking`].
    ///
    /// The default implementation panics: asynchronous operations must never
    /// be invoked through this path.
    fn call_blocking(self: Pin<&mut Self>) -> io::Result<usize> {
        unreachable!("this operation is asynchronous")
    }

    /// Set the result when it successfully completes.
    /// The operation stores the result and is responsible to release it if the
    /// operation is cancelled.
    ///
    /// # Safety
    ///
    /// Users should not call it.
    unsafe fn set_result(self: Pin<&mut Self>, _: usize) {}
}
108
/// Low-level driver of io-uring.
pub(crate) struct Driver {
    // The ring; SQE/CQE entry sizes are selected by the
    // `io-uring-sqe128`/`io-uring-cqe32` features via the `SEntry`/`CEntry`
    // aliases.
    inner: IoUring<SEntry, CEntry>,
    // Eventfd-based notifier used to wake a blocked `poll` from other threads.
    notifier: Notifier,
    // Thread pool that executes operations reporting `OpEntry::Blocking`.
    pool: AsyncifyPool,
    // Completions produced by the thread pool, drained in `poll_blocking`.
    pool_completed: Arc<SegQueue<Entry>>,
    // Allocator of 16-bit buffer-group ids for registered buffer rings.
    buffer_group_ids: Slab<()>,
    // Whether the multishot poll on the notifier must be (re-)armed on the
    // next `poll`.
    need_push_notifier: bool,
}
118
impl Driver {
    // Sentinel `user_data` for `AsyncCancel` SQEs; their CQEs carry no
    // operation and are discarded in `poll_entries`.
    const CANCEL: u64 = u64::MAX;
    // Sentinel `user_data` for the multishot poll armed on the notifier
    // eventfd.
    const NOTIFY: u64 = u64::MAX - 1;

    /// Build the driver from the proactor configuration: create the notifier,
    /// apply sqpoll/taskrun options, build the ring, and optionally register
    /// an external eventfd.
    pub fn new(builder: &ProactorBuilder) -> io::Result<Self> {
        instrument!(compio_log::Level::TRACE, "new", ?builder);
        trace!("new iour driver");
        let notifier = Notifier::new()?;
        let mut io_uring_builder = IoUring::builder();
        if let Some(sqpoll_idle) = builder.sqpoll_idle {
            io_uring_builder.setup_sqpoll(sqpoll_idle.as_millis() as _);
        }
        if builder.coop_taskrun {
            io_uring_builder.setup_coop_taskrun();
        }
        if builder.taskrun_flag {
            io_uring_builder.setup_taskrun_flag();
        }

        let inner = io_uring_builder.build(builder.capacity)?;

        let submitter = inner.submitter();

        if let Some(fd) = builder.eventfd {
            submitter.register_eventfd(fd)?;
        }

        Ok(Self {
            inner,
            notifier,
            pool: builder.create_or_get_thread_pool(),
            pool_completed: Arc::new(SegQueue::new()),
            buffer_group_ids: Slab::new(),
            // Arm the notifier poll on the first call to `poll`.
            need_push_notifier: true,
        })
    }

    pub fn driver_type(&self) -> DriverType {
        DriverType::IoUring
    }

    // Auto means that it choose to wait or not automatically.
    /// Submit pending SQEs and wait for at least one completion, bounded by
    /// `timeout` when given. Returns `TimedOut` when nothing completed.
    fn submit_auto(&mut self, timeout: Option<Duration>) -> io::Result<()> {
        instrument!(compio_log::Level::TRACE, "submit_auto", ?timeout);

        // when taskrun is true, there are completed cqes wait to handle, no need to
        // block the submit
        let want_sqe = if self.inner.submission().taskrun() {
            0
        } else {
            1
        };

        let res = {
            // Last part of submission queue, wait till timeout.
            if let Some(duration) = timeout {
                let timespec = timespec(duration);
                let args = SubmitArgs::new().timespec(&timespec);
                self.inner.submitter().submit_with_args(want_sqe, &args)
            } else {
                self.inner.submit_and_wait(want_sqe)
            }
        };
        trace!("submit result: {res:?}");
        match res {
            Ok(_) => {
                if self.inner.completion().is_empty() {
                    // Submit succeeded but nothing completed: surface as a
                    // timeout so callers can retry.
                    Err(io::ErrorKind::TimedOut.into())
                } else {
                    Ok(())
                }
            }
            Err(e) => match e.raw_os_error() {
                Some(libc::ETIME) => Err(io::ErrorKind::TimedOut.into()),
                // EBUSY/EAGAIN are transient; map them to `Interrupted` so
                // callers such as `push_raw` treat them as retryable.
                Some(libc::EBUSY) | Some(libc::EAGAIN) => Err(io::ErrorKind::Interrupted.into()),
                _ => Err(e),
            },
        }
    }

    /// Drain and dispatch completions produced by the blocking thread pool.
    fn poll_blocking(&mut self) {
        // Cheaper than pop.
        if !self.pool_completed.is_empty() {
            while let Some(entry) = self.pool_completed.pop() {
                // SAFETY(review): `Entry::notify` completes the op identified
                // by its user_data; contract defined on `Entry`.
                unsafe {
                    entry.notify();
                }
            }
        }
    }

    /// Drain thread-pool and ring completions; returns whether the completion
    /// queue contained any entries.
    fn poll_entries(&mut self) -> bool {
        self.poll_blocking();

        let mut cqueue = self.inner.completion();
        cqueue.sync();
        let has_entry = !cqueue.is_empty();
        for entry in cqueue {
            match entry.user_data() {
                // Completion of an `AsyncCancel`: nothing to dispatch.
                Self::CANCEL => {}
                Self::NOTIFY => {
                    let flags = entry.flags();
                    // Without the MORE flag the multishot poll has been
                    // terminated by the kernel and must be re-armed on the
                    // next `poll`.
                    if !more(flags) {
                        self.need_push_notifier = true;
                    }
                    self.notifier.clear().expect("cannot clear notifier");
                }
                _ => unsafe {
                    create_entry(entry).notify();
                },
            }
        }
        has_entry
    }

    pub fn create_op<T: crate::sys::OpCode + 'static>(&self, op: T) -> Key<T> {
        Key::new(self.as_raw_fd(), op)
    }

    /// io-uring needs no per-fd registration; attaching is a no-op.
    pub fn attach(&mut self, _fd: RawFd) -> io::Result<()> {
        Ok(())
    }

    /// Request cancellation of an in-flight operation. Best-effort: if the
    /// submission queue is full, the cancel entry is dropped with a warning.
    pub fn cancel(&mut self, op: &mut Key<dyn crate::sys::OpCode>) {
        instrument!(compio_log::Level::TRACE, "cancel", ?op);
        trace!("cancel RawOp");
        unsafe {
            #[allow(clippy::useless_conversion)]
            if self
                .inner
                .submission()
                .push(
                    &AsyncCancel::new(op.user_data() as _)
                        .build()
                        .user_data(Self::CANCEL)
                        .into(),
                )
                .is_err()
            {
                warn!("could not push AsyncCancel entry");
            }
        }
    }

    /// Push one SQE, retrying until the submission queue has room.
    fn push_raw(&mut self, entry: SEntry) -> io::Result<()> {
        loop {
            let mut squeue = self.inner.submission();
            match unsafe { squeue.push(&entry) } {
                Ok(()) => {
                    squeue.sync();
                    break Ok(());
                }
                Err(_) => {
                    // Queue full: drop the borrow, drain completions, and
                    // submit with a zero timeout to make room, then retry.
                    drop(squeue);
                    self.poll_entries();
                    match self.submit_auto(Some(Duration::ZERO)) {
                        Ok(()) => {}
                        // TimedOut/Interrupted are expected here and simply
                        // mean "try pushing again".
                        Err(e)
                            if matches!(
                                e.kind(),
                                io::ErrorKind::TimedOut | io::ErrorKind::Interrupted
                            ) => {}
                        Err(e) => return Err(e),
                    }
                }
            }
        }
    }

    /// Start an operation. Always returns `Poll::Pending` on success: the
    /// result arrives later via a CQE (or via the thread pool for blocking
    /// operations).
    pub fn push(&mut self, op: &mut Key<dyn crate::sys::OpCode>) -> Poll<io::Result<usize>> {
        instrument!(compio_log::Level::TRACE, "push", ?op);
        let user_data = op.user_data();
        let op_pin = op.as_op_pin();
        trace!("push RawOp");
        match op_pin.create_entry() {
            OpEntry::Submission(entry) => {
                #[allow(clippy::useless_conversion)]
                self.push_raw(entry.user_data(user_data as _).into())?;
                Poll::Pending
            }
            #[cfg(feature = "io-uring-sqe128")]
            OpEntry::Submission128(entry) => {
                self.push_raw(entry.user_data(user_data as _))?;
                Poll::Pending
            }
            OpEntry::Blocking => loop {
                // Retry until the thread pool accepts the task, draining pool
                // completions in between to free capacity.
                if self.push_blocking(user_data) {
                    break Poll::Pending;
                } else {
                    self.poll_blocking();
                }
            },
        }
    }

    /// Dispatch a blocking operation to the thread pool; returns `false` when
    /// the pool refuses the task.
    fn push_blocking(&mut self, user_data: usize) -> bool {
        let waker = self.waker();
        let completed = self.pool_completed.clone();
        self.pool
            .dispatch(move || {
                // SAFETY(review): rebuilds the op key from its user_data;
                // presumably the key stays alive until the pushed entry is
                // consumed — contract defined on `Key::new_unchecked`.
                let mut op = unsafe { Key::<dyn crate::sys::OpCode>::new_unchecked(user_data) };
                let op_pin = op.as_op_pin();
                let res = op_pin.call_blocking();
                completed.push(Entry::new(user_data, res));
                waker.wake();
            })
            .is_ok()
    }

    /// Drive the ring once: re-arm the notifier poll if needed, then wait (up
    /// to `timeout`) for completions and dispatch them.
    pub fn poll(&mut self, timeout: Option<Duration>) -> io::Result<()> {
        instrument!(compio_log::Level::TRACE, "poll", ?timeout);
        // Anyway we need to submit once, no matter there are entries in squeue.
        trace!("start polling");

        if self.need_push_notifier {
            // Multishot POLLIN on the eventfd so wakeups from other threads
            // surface as CQEs with the `NOTIFY` sentinel.
            #[allow(clippy::useless_conversion)]
            self.push_raw(
                PollAdd::new(Fd(self.notifier.as_raw_fd()), libc::POLLIN as _)
                    .multi(true)
                    .build()
                    .user_data(Self::NOTIFY)
                    .into(),
            )?;
            self.need_push_notifier = false;
        }

        if !self.poll_entries() {
            self.submit_auto(timeout)?;
            self.poll_entries();
        }

        Ok(())
    }

    /// A waker that wakes this driver via the notifier eventfd.
    pub fn waker(&self) -> Waker {
        self.notifier.waker()
    }

    /// Create a registered buffer pool, allocating a fresh buffer-group id.
    pub fn create_buffer_pool(
        &mut self,
        buffer_len: u16,
        buffer_size: usize,
    ) -> io::Result<BufferPool> {
        let buffer_group = self.buffer_group_ids.insert(());
        if buffer_group > u16::MAX as usize {
            // Buffer-group ids are 16-bit; roll the slab slot back before
            // failing.
            self.buffer_group_ids.remove(buffer_group);

            return Err(io::Error::new(
                io::ErrorKind::OutOfMemory,
                "too many buffer pool allocated",
            ));
        }

        let buf_ring = io_uring_buf_ring::IoUringBufRing::new_with_flags(
            &self.inner,
            buffer_len,
            buffer_group as _,
            buffer_size,
            0,
        )?;

        #[cfg(fusion)]
        {
            Ok(BufferPool::new_io_uring(crate::IoUringBufferPool::new(
                buf_ring,
            )))
        }
        #[cfg(not(fusion))]
        {
            Ok(BufferPool::new(buf_ring))
        }
    }

    /// # Safety
    ///
    /// caller must make sure release the buffer pool with correct driver
    pub unsafe fn release_buffer_pool(&mut self, buffer_pool: BufferPool) -> io::Result<()> {
        #[cfg(fusion)]
        let buffer_pool = buffer_pool.into_io_uring();

        let buffer_group = buffer_pool.buffer_group();
        // SAFETY: upheld by the caller (same driver that created the pool).
        unsafe { buffer_pool.into_inner().release(&self.inner)? };
        // Free the group id so it can be reused by a future pool.
        self.buffer_group_ids.remove(buffer_group as _);

        Ok(())
    }
}
406
407impl AsRawFd for Driver {
408    fn as_raw_fd(&self) -> RawFd {
409        self.inner.as_raw_fd()
410    }
411}
412
413fn create_entry(cq_entry: CEntry) -> Entry {
414    let result = cq_entry.result();
415    let result = if result < 0 {
416        let result = if result == -libc::ECANCELED {
417            libc::ETIMEDOUT
418        } else {
419            -result
420        };
421        Err(io::Error::from_raw_os_error(result))
422    } else {
423        Ok(result as _)
424    };
425    let mut entry = Entry::new(cq_entry.user_data() as _, result);
426    entry.set_flags(cq_entry.flags());
427
428    entry
429}
430
431fn timespec(duration: std::time::Duration) -> Timespec {
432    Timespec::new()
433        .sec(duration.as_secs())
434        .nsec(duration.subsec_nanos())
435}
436
/// Eventfd-backed wakeup channel; the inner [`Notify`] is shared via `Arc`
/// so [`Waker`]s can be handed out to other threads.
#[derive(Debug)]
struct Notifier {
    notify: Arc<Notify>,
}
441
impl Notifier {
    /// Create a new notifier.
    ///
    /// The eventfd is created non-blocking so `clear` never stalls the
    /// driver, and close-on-exec so it does not leak into children.
    fn new() -> io::Result<Self> {
        let fd = syscall!(libc::eventfd(0, libc::EFD_CLOEXEC | libc::EFD_NONBLOCK))?;
        // SAFETY: `eventfd` returned a fresh, exclusively-owned descriptor.
        let fd = unsafe { OwnedFd::from_raw_fd(fd) };
        Ok(Self {
            notify: Arc::new(Notify::new(fd)),
        })
    }

    /// Reset the eventfd counter so the next notification triggers again.
    pub fn clear(&self) -> io::Result<()> {
        loop {
            // Reading an eventfd consumes its whole 64-bit counter at once.
            let mut buffer = [0u64];
            let res = syscall!(libc::read(
                self.as_raw_fd(),
                buffer.as_mut_ptr().cast(),
                std::mem::size_of::<u64>()
            ));
            match res {
                Ok(len) => {
                    debug_assert_eq!(len, std::mem::size_of::<u64>() as _);
                    break Ok(());
                }
                // Counter already zero (fd is non-blocking). Clear the next time:)
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => break Ok(()),
                // Just like read_exact
                Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
                Err(e) => break Err(e),
            }
        }
    }

    /// Create a [`Waker`] that notifies this driver when woken.
    pub fn waker(&self) -> Waker {
        Waker::from(self.notify.clone())
    }
}
478
479impl AsRawFd for Notifier {
480    fn as_raw_fd(&self) -> RawFd {
481        self.notify.fd.as_raw_fd()
482    }
483}
484
/// A notify handle to the inner driver.
#[derive(Debug)]
struct Notify {
    // The owned eventfd; closed when the last `Arc<Notify>` is dropped.
    fd: OwnedFd,
}
490
491impl Notify {
492    pub(crate) fn new(fd: OwnedFd) -> Self {
493        Self { fd }
494    }
495
496    /// Notify the inner driver.
497    pub fn notify(&self) -> io::Result<()> {
498        let data = 1u64;
499        syscall!(libc::write(
500            self.fd.as_raw_fd(),
501            &data as *const _ as *const _,
502            std::mem::size_of::<u64>(),
503        ))?;
504        Ok(())
505    }
506}
507
508impl Wake for Notify {
509    fn wake(self: Arc<Self>) {
510        self.wake_by_ref();
511    }
512
513    fn wake_by_ref(self: &Arc<Self>) {
514        self.notify().ok();
515    }
516}