// compio_driver/iour/mod.rs

1#[cfg_attr(all(doc, docsrs), doc(cfg(all())))]
2#[allow(unused_imports)]
3pub use std::os::fd::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd};
4use std::{io, os::fd::FromRawFd, pin::Pin, sync::Arc, task::Poll, time::Duration};
5
6use compio_log::{instrument, trace, warn};
7use crossbeam_queue::SegQueue;
8cfg_if::cfg_if! {
9    if #[cfg(feature = "io-uring-cqe32")] {
10        use io_uring::cqueue::Entry32 as CEntry;
11    } else {
12        use io_uring::cqueue::Entry as CEntry;
13    }
14}
15cfg_if::cfg_if! {
16    if #[cfg(feature = "io-uring-sqe128")] {
17        use io_uring::squeue::Entry128 as SEntry;
18    } else {
19        use io_uring::squeue::Entry as SEntry;
20    }
21}
22use io_uring::{
23    IoUring,
24    cqueue::more,
25    opcode::{AsyncCancel, PollAdd},
26    types::{Fd, SubmitArgs, Timespec},
27};
28use slab::Slab;
29
30use crate::{AsyncifyPool, BufferPool, DriverType, Entry, Key, ProactorBuilder, syscall};
31
32pub(crate) mod op;
33
/// Returns whether the running kernel supports the io-uring opcode `code`.
///
/// The probe is registered once against a small throwaway ring and cached in
/// a process-wide static; if probing itself fails (e.g. io-uring unavailable),
/// this returns `false` via `unwrap_or_default`.
pub(crate) fn is_op_supported(code: u8) -> bool {
    // With the `once_cell_try` feature, use std's `OnceLock` (whose
    // `get_or_try_init` is nightly-gated); otherwise fall back to the
    // `once_cell` crate's equivalent `OnceCell`.
    #[cfg(feature = "once_cell_try")]
    use std::sync::OnceLock;

    #[cfg(not(feature = "once_cell_try"))]
    use once_cell::sync::OnceCell as OnceLock;

    static PROBE: OnceLock<io_uring::Probe> = OnceLock::new();

    PROBE
        .get_or_try_init(|| {
            let mut probe = io_uring::Probe::new();

            // A tiny 2-entry ring is sufficient just to run the probe.
            io_uring::IoUring::new(2)?
                .submitter()
                .register_probe(&mut probe)?;

            std::io::Result::Ok(probe)
        })
        .map(|probe| probe.is_supported(code))
        .unwrap_or_default()
}
56
/// The created entry of [`OpCode`], telling the driver how to execute the
/// operation: submit it to the ring (plain or 128-byte entry) or dispatch it
/// to the blocking thread pool.
pub enum OpEntry {
    /// This operation creates an io-uring submission entry.
    Submission(io_uring::squeue::Entry),
    #[cfg(feature = "io-uring-sqe128")]
    /// This operation creates an 128-bit io-uring submission entry.
    Submission128(io_uring::squeue::Entry128),
    /// This operation is a blocking one; it will run on the driver's
    /// [`AsyncifyPool`] via [`OpCode::call_blocking`].
    Blocking,
}
67
68impl From<io_uring::squeue::Entry> for OpEntry {
69    fn from(value: io_uring::squeue::Entry) -> Self {
70        Self::Submission(value)
71    }
72}
73
#[cfg(feature = "io-uring-sqe128")]
impl From<io_uring::squeue::Entry128> for OpEntry {
    /// Wrap a 128-byte submission entry.
    fn from(entry: io_uring::squeue::Entry128) -> Self {
        OpEntry::Submission128(entry)
    }
}
80
/// Abstraction of io-uring operations.
pub trait OpCode {
    /// Create submission entry.
    fn create_entry(self: Pin<&mut Self>) -> OpEntry;

    /// Call the operation in a blocking way. This method will only be called if
    /// [`create_entry`] returns [`OpEntry::Blocking`].
    ///
    /// The default implementation panics, as asynchronous operations are never
    /// dispatched to the blocking pool.
    fn call_blocking(self: Pin<&mut Self>) -> io::Result<usize> {
        unreachable!("this operation is asynchronous")
    }

    /// Set the result when it successfully completes.
    /// The operation stores the result and is responsible to release it if the
    /// operation is cancelled.
    ///
    /// The default implementation is a no-op for operations that do not need
    /// to retain their result.
    ///
    /// # Safety
    ///
    /// Users should not call it.
    unsafe fn set_result(self: Pin<&mut Self>, _: usize) {}
}
101
/// Low-level driver of io-uring.
pub(crate) struct Driver {
    /// The io-uring instance; entry types are selected by the
    /// `io-uring-sqe128` / `io-uring-cqe32` features.
    inner: IoUring<SEntry, CEntry>,
    /// Eventfd used by [`NotifyHandle`]s to wake a blocked `poll`.
    notifier: Notifier,
    /// Thread pool that runs [`OpEntry::Blocking`] operations.
    pool: AsyncifyPool,
    /// Completion entries produced by pool threads, drained in `poll_blocking`.
    pool_completed: Arc<SegQueue<Entry>>,
    /// Allocator of buffer-ring group ids (slab index = group id).
    buffer_group_ids: Slab<()>,
    /// Whether the multishot `PollAdd` watching the notifier must be
    /// (re)submitted on the next `poll`.
    need_push_notifier: bool,
}
111
impl Driver {
    /// Reserved user-data for `AsyncCancel` submissions; their completions
    /// carry no operation and are discarded in `poll_entries`.
    const CANCEL: u64 = u64::MAX;
    /// Reserved user-data for the multishot `PollAdd` entry that watches the
    /// notifier eventfd.
    const NOTIFY: u64 = u64::MAX - 1;

    /// Create the driver: build the ring with the options from `builder`
    /// (sqpoll idle, coop-taskrun, taskrun-flag, capacity), create the
    /// eventfd notifier, and register the user-supplied eventfd if any.
    pub fn new(builder: &ProactorBuilder) -> io::Result<Self> {
        instrument!(compio_log::Level::TRACE, "new", ?builder);
        trace!("new iour driver");
        let notifier = Notifier::new()?;
        let mut io_uring_builder = IoUring::builder();
        if let Some(sqpoll_idle) = builder.sqpoll_idle {
            io_uring_builder.setup_sqpoll(sqpoll_idle.as_millis() as _);
        }
        if builder.coop_taskrun {
            io_uring_builder.setup_coop_taskrun();
        }
        if builder.taskrun_flag {
            io_uring_builder.setup_taskrun_flag();
        }

        let inner = io_uring_builder.build(builder.capacity)?;

        let submitter = inner.submitter();

        if let Some(fd) = builder.eventfd {
            submitter.register_eventfd(fd)?;
        }

        Ok(Self {
            inner,
            notifier,
            pool: builder.create_or_get_thread_pool(),
            pool_completed: Arc::new(SegQueue::new()),
            buffer_group_ids: Slab::new(),
            // Force the notifier PollAdd to be submitted on the first poll.
            need_push_notifier: true,
        })
    }

    /// This driver always reports [`DriverType::IoUring`].
    pub fn driver_type(&self) -> DriverType {
        DriverType::IoUring
    }

    /// Submit pending squeue entries, automatically choosing whether to wait
    /// for a completion.
    ///
    /// Maps `ETIME` to `TimedOut` and `EBUSY`/`EAGAIN` to `Interrupted` so
    /// callers can retry; an Ok submit with an empty completion queue is also
    /// reported as `TimedOut`.
    fn submit_auto(&mut self, timeout: Option<Duration>) -> io::Result<()> {
        instrument!(compio_log::Level::TRACE, "submit_auto", ?timeout);

        // when taskrun is true, there are completed cqes wait to handle, no need to
        // block the submit
        let want_sqe = if self.inner.submission().taskrun() {
            0
        } else {
            1
        };

        let res = {
            // Last part of submission queue, wait till timeout.
            if let Some(duration) = timeout {
                let timespec = timespec(duration);
                let args = SubmitArgs::new().timespec(&timespec);
                self.inner.submitter().submit_with_args(want_sqe, &args)
            } else {
                self.inner.submit_and_wait(want_sqe)
            }
        };
        trace!("submit result: {res:?}");
        match res {
            Ok(_) => {
                if self.inner.completion().is_empty() {
                    Err(io::ErrorKind::TimedOut.into())
                } else {
                    Ok(())
                }
            }
            Err(e) => match e.raw_os_error() {
                Some(libc::ETIME) => Err(io::ErrorKind::TimedOut.into()),
                Some(libc::EBUSY) | Some(libc::EAGAIN) => Err(io::ErrorKind::Interrupted.into()),
                _ => Err(e),
            },
        }
    }

    /// Drain and notify all completions produced by the blocking thread pool.
    fn poll_blocking(&mut self) {
        // Cheaper than pop.
        if !self.pool_completed.is_empty() {
            while let Some(entry) = self.pool_completed.pop() {
                unsafe {
                    entry.notify();
                }
            }
        }
    }

    /// Process pool completions and all available cqueue entries.
    ///
    /// Returns `true` if the completion queue held at least one entry.
    fn poll_entries(&mut self) -> bool {
        self.poll_blocking();

        let mut cqueue = self.inner.completion();
        cqueue.sync();
        let has_entry = !cqueue.is_empty();
        for entry in cqueue {
            match entry.user_data() {
                // Cancel completions carry no operation; drop them.
                Self::CANCEL => {}
                Self::NOTIFY => {
                    let flags = entry.flags();
                    // If the multishot PollAdd will not fire again, re-arm it
                    // on the next poll.
                    if !more(flags) {
                        self.need_push_notifier = true;
                    }
                    self.notifier.clear().expect("cannot clear notifier");
                }
                _ => unsafe {
                    create_entry(entry).notify();
                },
            }
        }
        has_entry
    }

    /// Wrap `op` into a [`Key`] bound to this driver's fd.
    pub fn create_op<T: crate::sys::OpCode + 'static>(&self, op: T) -> Key<T> {
        Key::new(self.as_raw_fd(), op)
    }

    /// io-uring needs no per-fd registration; always succeeds.
    pub fn attach(&mut self, _fd: RawFd) -> io::Result<()> {
        Ok(())
    }

    /// Request cancellation of `op` by pushing an `AsyncCancel` entry.
    ///
    /// Best-effort: if the squeue is full the request is only logged, since
    /// cancellation is advisory.
    pub fn cancel(&mut self, op: &mut Key<dyn crate::sys::OpCode>) {
        instrument!(compio_log::Level::TRACE, "cancel", ?op);
        trace!("cancel RawOp");
        unsafe {
            #[allow(clippy::useless_conversion)]
            if self
                .inner
                .submission()
                .push(
                    &AsyncCancel::new(op.user_data() as _)
                        .build()
                        .user_data(Self::CANCEL)
                        .into(),
                )
                .is_err()
            {
                warn!("could not push AsyncCancel entry");
            }
        }
    }

    /// Push `entry` onto the squeue, retrying until it fits.
    ///
    /// When the squeue is full, drains completions and submits with a zero
    /// timeout to make room; `TimedOut`/`Interrupted` from that submit are
    /// expected and swallowed.
    fn push_raw(&mut self, entry: SEntry) -> io::Result<()> {
        loop {
            let mut squeue = self.inner.submission();
            match unsafe { squeue.push(&entry) } {
                Ok(()) => {
                    squeue.sync();
                    break Ok(());
                }
                Err(_) => {
                    drop(squeue);
                    self.poll_entries();
                    match self.submit_auto(Some(Duration::ZERO)) {
                        Ok(()) => {}
                        Err(e)
                            if matches!(
                                e.kind(),
                                io::ErrorKind::TimedOut | io::ErrorKind::Interrupted
                            ) => {}
                        Err(e) => return Err(e),
                    }
                }
            }
        }
    }

    /// Start an operation: submit its entry to the ring, or dispatch it to
    /// the blocking pool. Returns `Poll::Pending` on success; the result is
    /// delivered later through the completion path.
    pub fn push(&mut self, op: &mut Key<dyn crate::sys::OpCode>) -> Poll<io::Result<usize>> {
        instrument!(compio_log::Level::TRACE, "push", ?op);
        let user_data = op.user_data();
        let op_pin = op.as_op_pin();
        trace!("push RawOp");
        match op_pin.create_entry() {
            OpEntry::Submission(entry) => {
                #[allow(clippy::useless_conversion)]
                self.push_raw(entry.user_data(user_data as _).into())?;
                Poll::Pending
            }
            #[cfg(feature = "io-uring-sqe128")]
            OpEntry::Submission128(entry) => {
                self.push_raw(entry.user_data(user_data as _))?;
                Poll::Pending
            }
            // If the pool rejects the dispatch, drain finished blocking ops
            // and retry.
            OpEntry::Blocking => loop {
                if self.push_blocking(user_data) {
                    break Poll::Pending;
                } else {
                    self.poll_blocking();
                }
            },
        }
    }

    /// Dispatch the blocking op identified by `user_data` to the pool.
    ///
    /// Returns `false` if the pool refused the task.
    fn push_blocking(&mut self, user_data: usize) -> bool {
        let handle = self.handle();
        let completed = self.pool_completed.clone();
        self.pool
            .dispatch(move || {
                let mut op = unsafe { Key::<dyn crate::sys::OpCode>::new_unchecked(user_data) };
                let op_pin = op.as_op_pin();
                let res = op_pin.call_blocking();
                completed.push(Entry::new(user_data, res));
                // Wake the driver so it drains `pool_completed`.
                handle.notify().ok();
            })
            .is_ok()
    }

    /// Drive the ring once: (re)arm the notifier poll if needed, then submit
    /// and process completions, waiting up to `timeout`.
    ///
    /// # Safety
    ///
    /// NOTE(review): declared `unsafe` with no stated contract here — the
    /// caller presumably must keep all pushed operations alive until their
    /// completions are delivered; confirm against the crate-level docs.
    pub unsafe fn poll(&mut self, timeout: Option<Duration>) -> io::Result<()> {
        instrument!(compio_log::Level::TRACE, "poll", ?timeout);
        // Anyway we need to submit once, no matter there are entries in squeue.
        trace!("start polling");

        if self.need_push_notifier {
            #[allow(clippy::useless_conversion)]
            self.push_raw(
                PollAdd::new(Fd(self.notifier.as_raw_fd()), libc::POLLIN as _)
                    .multi(true)
                    .build()
                    .user_data(Self::NOTIFY)
                    .into(),
            )?;
            self.need_push_notifier = false;
        }

        if !self.poll_entries() {
            self.submit_auto(timeout)?;
            self.poll_entries();
        }

        Ok(())
    }

    /// Get a thread-safe handle that can wake this driver.
    pub fn handle(&self) -> NotifyHandle {
        self.notifier.handle()
    }

    /// Create a registered buffer ring with `buffer_len` buffers of
    /// `buffer_size` bytes each, allocating a fresh buffer-group id.
    pub fn create_buffer_pool(
        &mut self,
        buffer_len: u16,
        buffer_size: usize,
    ) -> io::Result<BufferPool> {
        let buffer_group = self.buffer_group_ids.insert(());
        // Group ids are u16 on the kernel side; roll back the slab slot if we
        // ran out of ids.
        if buffer_group > u16::MAX as usize {
            self.buffer_group_ids.remove(buffer_group);

            return Err(io::Error::new(
                io::ErrorKind::OutOfMemory,
                "too many buffer pool allocated",
            ));
        }

        let buf_ring = io_uring_buf_ring::IoUringBufRing::new_with_flags(
            &self.inner,
            buffer_len,
            buffer_group as _,
            buffer_size,
            0,
        )?;

        #[cfg(fusion)]
        {
            Ok(BufferPool::new_io_uring(crate::IoUringBufferPool::new(
                buf_ring,
            )))
        }
        #[cfg(not(fusion))]
        {
            Ok(BufferPool::new(buf_ring))
        }
    }

    /// # Safety
    ///
    /// caller must make sure release the buffer pool with correct driver
    pub unsafe fn release_buffer_pool(&mut self, buffer_pool: BufferPool) -> io::Result<()> {
        #[cfg(fusion)]
        let buffer_pool = buffer_pool.into_io_uring();

        let buffer_group = buffer_pool.buffer_group();
        buffer_pool.into_inner().release(&self.inner)?;
        // Return the group id for reuse.
        self.buffer_group_ids.remove(buffer_group as _);

        Ok(())
    }
}
399
400impl AsRawFd for Driver {
401    fn as_raw_fd(&self) -> RawFd {
402        self.inner.as_raw_fd()
403    }
404}
405
406fn create_entry(cq_entry: CEntry) -> Entry {
407    let result = cq_entry.result();
408    let result = if result < 0 {
409        let result = if result == -libc::ECANCELED {
410            libc::ETIMEDOUT
411        } else {
412            -result
413        };
414        Err(io::Error::from_raw_os_error(result))
415    } else {
416        Ok(result as _)
417    };
418    let mut entry = Entry::new(cq_entry.user_data() as _, result);
419    entry.set_flags(cq_entry.flags());
420
421    entry
422}
423
424fn timespec(duration: std::time::Duration) -> Timespec {
425    Timespec::new()
426        .sec(duration.as_secs())
427        .nsec(duration.subsec_nanos())
428}
429
/// Driver-side wrapper around an eventfd used to wake a blocked `poll`.
/// The fd is shared (`Arc`) with the [`NotifyHandle`]s handed to other
/// threads.
#[derive(Debug)]
struct Notifier {
    fd: Arc<OwnedFd>,
}
434
impl Notifier {
    /// Create a new notifier.
    ///
    /// The eventfd is nonblocking so `clear` never stalls the driver, and
    /// close-on-exec so it does not leak into child processes.
    fn new() -> io::Result<Self> {
        let fd = syscall!(libc::eventfd(0, libc::EFD_CLOEXEC | libc::EFD_NONBLOCK))?;
        let fd = unsafe { OwnedFd::from_raw_fd(fd) };
        Ok(Self { fd: Arc::new(fd) })
    }

    /// Reset the eventfd counter after a notification completion.
    ///
    /// A single 8-byte read drains the whole counter. `WouldBlock` means the
    /// counter is already zero; `Interrupted` is retried.
    pub fn clear(&self) -> io::Result<()> {
        loop {
            let mut buffer = [0u64];
            let res = syscall!(libc::read(
                self.fd.as_raw_fd(),
                buffer.as_mut_ptr().cast(),
                std::mem::size_of::<u64>()
            ));
            match res {
                Ok(len) => {
                    // eventfd reads always transfer exactly 8 bytes.
                    debug_assert_eq!(len, std::mem::size_of::<u64>() as _);
                    break Ok(());
                }
                // Clear the next time:)
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => break Ok(()),
                // Just like read_exact
                Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
                Err(e) => break Err(e),
            }
        }
    }

    /// Create a [`NotifyHandle`] sharing this eventfd.
    pub fn handle(&self) -> NotifyHandle {
        NotifyHandle::new(self.fd.clone())
    }
}
469
470impl AsRawFd for Notifier {
471    fn as_raw_fd(&self) -> RawFd {
472        self.fd.as_raw_fd()
473    }
474}
475
/// A notify handle to the inner driver.
///
/// Cheap to clone the underlying fd via `Arc`; safe to use from other
/// threads to wake the driver's `poll`.
pub struct NotifyHandle {
    fd: Arc<OwnedFd>,
}
480
481impl NotifyHandle {
482    pub(crate) fn new(fd: Arc<OwnedFd>) -> Self {
483        Self { fd }
484    }
485
486    /// Notify the inner driver.
487    pub fn notify(&self) -> io::Result<()> {
488        let data = 1u64;
489        syscall!(libc::write(
490            self.fd.as_raw_fd(),
491            &data as *const _ as *const _,
492            std::mem::size_of::<u64>(),
493        ))?;
494        Ok(())
495    }
496}