Skip to main content

compio_driver/sys/iour/
mod.rs

1#[cfg_attr(all(doc, docsrs), doc(cfg(all())))]
2#[allow(unused_imports)]
3pub use std::os::fd::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd};
4use std::{
5    io,
6    os::fd::FromRawFd,
7    pin::Pin,
8    sync::Arc,
9    task::{Poll, Wake, Waker},
10    time::Duration,
11};
12
13use compio_log::{instrument, trace, warn};
14cfg_if::cfg_if! {
15    if #[cfg(feature = "io-uring-cqe32")] {
16        use io_uring::cqueue::Entry32 as CEntry;
17    } else {
18        use io_uring::cqueue::Entry as CEntry;
19    }
20}
21cfg_if::cfg_if! {
22    if #[cfg(feature = "io-uring-sqe128")] {
23        use io_uring::squeue::Entry128 as SEntry;
24    } else {
25        use io_uring::squeue::Entry as SEntry;
26    }
27}
28use flume::{Receiver, Sender};
29use io_uring::{
30    IoUring,
31    cqueue::more,
32    opcode::{AsyncCancel, PollAdd},
33    types::{Fd, SubmitArgs, Timespec},
34};
35use slab::Slab;
36
37use crate::{
38    AsyncifyPool, BufferPool, DriverType, Entry, ProactorBuilder,
39    key::{ErasedKey, Key, RefExt},
40    syscall,
41};
42
43mod extra;
44pub use extra::Extra;
45pub(crate) mod op;
46
/// Check whether the running kernel supports the given io-uring opcode,
/// by registering a probe against a small throwaway ring.
///
/// The probe is computed once and cached in a process-wide static; if
/// probing fails (e.g. io-uring unavailable), this conservatively returns
/// `false` via `unwrap_or_default`. NOTE(review): a probe *failure* is not
/// cached by `get_or_try_init`, so it is retried on the next call.
pub(crate) fn is_op_supported(code: u8) -> bool {
    // `OnceLock::get_or_try_init` is nightly-only (`once_cell_try`); fall
    // back to the `once_cell` crate when that feature is disabled.
    #[cfg(feature = "once_cell_try")]
    use std::sync::OnceLock;

    #[cfg(not(feature = "once_cell_try"))]
    use once_cell::sync::OnceCell as OnceLock;

    // Lazily-initialized, process-wide cache of the kernel's opcode support.
    static PROBE: OnceLock<io_uring::Probe> = OnceLock::new();

    PROBE
        .get_or_try_init(|| {
            let mut probe = io_uring::Probe::new();

            // A tiny temporary ring (2 entries) is enough to run the probe.
            io_uring::IoUring::new(2)?
                .submitter()
                .register_probe(&mut probe)?;

            std::io::Result::Ok(probe)
        })
        .map(|probe| probe.is_supported(code))
        .unwrap_or_default()
}
69
/// The created entry of [`OpCode`].
///
/// Produced by [`OpCode::create_entry`] and consumed by the driver's
/// `push`, which either submits it to the ring or dispatches it to the
/// blocking thread pool.
pub enum OpEntry {
    /// This operation creates an io-uring submission entry.
    Submission(io_uring::squeue::Entry),
    #[cfg(feature = "io-uring-sqe128")]
    /// This operation creates an 128-bit io-uring submission entry.
    Submission128(io_uring::squeue::Entry128),
    /// This operation is a blocking one.
    Blocking,
}
80
81impl OpEntry {
82    fn personality(self, personality: Option<u16>) -> Self {
83        let Some(personality) = personality else {
84            return self;
85        };
86
87        match self {
88            Self::Submission(entry) => Self::Submission(entry.personality(personality)),
89            #[cfg(feature = "io-uring-sqe128")]
90            Self::Submission128(entry) => Self::Submission128(entry.personality(personality)),
91            Self::Blocking => Self::Blocking,
92        }
93    }
94}
95
96impl From<io_uring::squeue::Entry> for OpEntry {
97    fn from(value: io_uring::squeue::Entry) -> Self {
98        Self::Submission(value)
99    }
100}
101
#[cfg(feature = "io-uring-sqe128")]
impl From<io_uring::squeue::Entry128> for OpEntry {
    /// Wrap a 128-bit submission entry.
    fn from(entry: io_uring::squeue::Entry128) -> Self {
        OpEntry::Submission128(entry)
    }
}
108
/// Abstraction of io-uring operations.
///
/// # Safety
///
/// The returned Entry from `create_entry` must be valid until the operation is
/// completed.
pub unsafe trait OpCode {
    /// Create submission entry.
    fn create_entry(self: Pin<&mut Self>) -> OpEntry;

    /// Call the operation in a blocking way. This method will only be called if
    /// [`create_entry`] returns [`OpEntry::Blocking`].
    ///
    /// The default implementation panics: asynchronous operations are never
    /// dispatched to the blocking pool.
    ///
    /// [`create_entry`]: OpCode::create_entry
    fn call_blocking(self: Pin<&mut Self>) -> io::Result<usize> {
        unreachable!("this operation is asynchronous")
    }

    /// Set the result when it successfully completes.
    /// The operation stores the result and is responsible to release it if the
    /// operation is cancelled.
    ///
    /// # Safety
    ///
    /// Users should not call it.
    unsafe fn set_result(self: Pin<&mut Self>, _: usize) {}
}
136
137pub use OpCode as IourOpCode;
138
/// Low-level driver of io-uring.
pub(crate) struct Driver {
    // The ring itself; SEntry/CEntry sizes depend on the sqe128/cqe32 features.
    inner: IoUring<SEntry, CEntry>,
    // Eventfd-based wakeup channel, armed via a multishot `PollAdd`.
    notifier: Notifier,
    // Thread pool for operations that must run blocking.
    pool: AsyncifyPool,
    // Sender side for results of blocking operations.
    completed_tx: Sender<Entry>,
    // Receiver side, drained by `poll_blocking`.
    completed_rx: Receiver<Entry>,
    // Allocator for buffer-group ids: the slab index is the group id.
    buffer_group_ids: Slab<()>,
    // Whether the notifier's multishot `PollAdd` must be (re-)pushed.
    need_push_notifier: bool,
}
149
150impl Driver {
    /// Sentinel user_data for `AsyncCancel` entries; their CQEs are discarded.
    const CANCEL: u64 = u64::MAX;
    /// Sentinel user_data for the notifier's `PollAdd` entry.
    const NOTIFY: u64 = u64::MAX - 1;
153
    /// Create a new io-uring driver from the proactor builder settings.
    ///
    /// Applies the sqpoll / coop-taskrun / taskrun-flag options, builds a
    /// ring of `builder.capacity` entries, and optionally registers an
    /// external eventfd for completion notification.
    pub fn new(builder: &ProactorBuilder) -> io::Result<Self> {
        instrument!(compio_log::Level::TRACE, "new", ?builder);
        trace!("new iour driver");
        let notifier = Notifier::new()?;
        let mut io_uring_builder = IoUring::builder();
        if let Some(sqpoll_idle) = builder.sqpoll_idle {
            io_uring_builder.setup_sqpoll(sqpoll_idle.as_millis() as _);
        }
        if builder.coop_taskrun {
            io_uring_builder.setup_coop_taskrun();
        }
        if builder.taskrun_flag {
            io_uring_builder.setup_taskrun_flag();
        }

        let inner = io_uring_builder.build(builder.capacity)?;

        let submitter = inner.submitter();

        // Eventfd registration can only happen once the ring exists.
        if let Some(fd) = builder.eventfd {
            submitter.register_eventfd(fd)?;
        }

        let (completed_tx, completed_rx) = flume::unbounded();

        Ok(Self {
            inner,
            notifier,
            completed_tx,
            completed_rx,
            pool: builder.create_or_get_thread_pool(),
            buffer_group_ids: Slab::new(),
            // Arm the notifier `PollAdd` on the first call to `poll`.
            need_push_notifier: true,
        })
    }
189
    /// Report the driver flavor (always [`DriverType::IoUring`] here).
    pub fn driver_type(&self) -> DriverType {
        DriverType::IoUring
    }
193
    /// Downcast-style accessor; always `Some` for this driver.
    // presumably used by the fusion driver to select the io-uring path — TODO confirm
    #[allow(dead_code)]
    pub fn as_iour(&self) -> Option<&Self> {
        Some(self)
    }
198
    /// Register an io-uring personality and return its id.
    pub fn register_personality(&self) -> io::Result<u16> {
        self.inner.submitter().register_personality()
    }
202
    /// Unregister a previously registered io-uring personality.
    pub fn unregister_personality(&self, personality: u16) -> io::Result<()> {
        self.inner.submitter().unregister_personality(personality)
    }
206
207    // Auto means that it choose to wait or not automatically.
208    fn submit_auto(&mut self, timeout: Option<Duration>) -> io::Result<()> {
209        instrument!(compio_log::Level::TRACE, "submit_auto", ?timeout);
210
211        // when taskrun is true, there are completed cqes wait to handle, no need to
212        // block the submit
213        let want_sqe = if self.inner.submission().taskrun() {
214            0
215        } else {
216            1
217        };
218
219        let res = {
220            // Last part of submission queue, wait till timeout.
221            if let Some(duration) = timeout {
222                let timespec = timespec(duration);
223                let args = SubmitArgs::new().timespec(&timespec);
224                self.inner.submitter().submit_with_args(want_sqe, &args)
225            } else {
226                self.inner.submit_and_wait(want_sqe)
227            }
228        };
229        trace!("submit result: {res:?}");
230        match res {
231            Ok(_) => {
232                if self.inner.completion().is_empty() {
233                    Err(io::ErrorKind::TimedOut.into())
234                } else {
235                    Ok(())
236                }
237            }
238            Err(e) => match e.raw_os_error() {
239                Some(libc::ETIME) => Err(io::ErrorKind::TimedOut.into()),
240                Some(libc::EBUSY) | Some(libc::EAGAIN) => Err(io::ErrorKind::Interrupted.into()),
241                _ => Err(e),
242            },
243        }
244    }
245
246    fn poll_blocking(&mut self) {
247        while let Ok(entry) = self.completed_rx.try_recv() {
248            entry.notify();
249        }
250    }
251
    /// Drain both blocking-op results and the ring's completion queue.
    ///
    /// Returns `true` if the completion queue contained any entries.
    /// NOTE(review): `need_push_notifier` and `notifier` are touched while
    /// iterating `cqueue` (which borrows `self.inner`); this compiles only
    /// because the fields are disjoint, so keep the structure as-is.
    fn poll_entries(&mut self) -> bool {
        self.poll_blocking();

        let mut cqueue = self.inner.completion();
        cqueue.sync();
        let has_entry = !cqueue.is_empty();
        for entry in cqueue {
            match entry.user_data() {
                // Completion of an `AsyncCancel`: nothing to do.
                Self::CANCEL => {}
                Self::NOTIFY => {
                    let flags = entry.flags();
                    // Without the MORE flag the multishot has ended; the
                    // `PollAdd` must be re-armed on the next `poll`.
                    if !more(flags) {
                        self.need_push_notifier = true;
                    }
                    self.notifier.clear().expect("cannot clear notifier");
                }
                _ => create_entry(entry).notify(),
            }
        }
        has_entry
    }
273
    /// Default per-operation extra data for this driver.
    pub fn default_extra(&self) -> Extra {
        Extra::new()
    }
277
    /// Attach an fd to the driver. io-uring needs no per-fd registration,
    /// so this is a no-op that always succeeds.
    pub fn attach(&mut self, _fd: RawFd) -> io::Result<()> {
        Ok(())
    }
281
282    pub fn cancel<T>(&mut self, key: Key<T>) {
283        instrument!(compio_log::Level::TRACE, "cancel", ?key);
284        trace!("cancel RawOp");
285        unsafe {
286            #[allow(clippy::useless_conversion)]
287            if self
288                .inner
289                .submission()
290                .push(
291                    &AsyncCancel::new(key.as_raw() as _)
292                        .build()
293                        .user_data(Self::CANCEL)
294                        .into(),
295                )
296                .is_err()
297            {
298                warn!("could not push AsyncCancel entry");
299            }
300        }
301    }
302
    /// Tag `entry` with the key's raw pointer and push it to the ring.
    ///
    /// Ownership of the key is handed to the kernel only on success
    /// (`into_raw` leaks it deliberately; `create_entry` reclaims it from
    /// the CQE's user_data).
    fn push_raw_with_key(&mut self, entry: SEntry, key: ErasedKey) -> io::Result<()> {
        let entry = entry.user_data(key.as_raw() as _);
        self.push_raw(entry)?; // if push failed, do not leak the key. Drop it upon return.
        key.into_raw();
        Ok(())
    }
309
310    fn push_raw(&mut self, entry: SEntry) -> io::Result<()> {
311        loop {
312            let mut squeue = self.inner.submission();
313            match unsafe { squeue.push(&entry) } {
314                Ok(()) => {
315                    squeue.sync();
316                    break Ok(());
317                }
318                Err(_) => {
319                    drop(squeue);
320                    self.poll_entries();
321                    match self.submit_auto(Some(Duration::ZERO)) {
322                        Ok(()) => {}
323                        Err(e)
324                            if matches!(
325                                e.kind(),
326                                io::ErrorKind::TimedOut | io::ErrorKind::Interrupted
327                            ) => {}
328                        Err(e) => return Err(e),
329                    }
330                }
331            }
332        }
333    }
334
335    pub fn push(&mut self, key: ErasedKey) -> Poll<io::Result<usize>> {
336        instrument!(compio_log::Level::TRACE, "push", ?key);
337        let personality = key.borrow().extra().as_iour().get_personality();
338        let entry = key
339            .borrow()
340            .pinned_op()
341            .create_entry()
342            .personality(personality);
343        trace!(?personality, "push Key");
344        match entry {
345            OpEntry::Submission(entry) => {
346                if is_op_supported(entry.get_opcode() as _) {
347                    #[allow(clippy::useless_conversion)]
348                    self.push_raw_with_key(entry.into(), key)?;
349                } else {
350                    self.push_blocking(key)
351                }
352            }
353            #[cfg(feature = "io-uring-sqe128")]
354            OpEntry::Submission128(entry) => {
355                self.push_raw_with_key(entry, key)?;
356            }
357            OpEntry::Blocking => self.push_blocking(key),
358        }
359        Poll::Pending
360    }
361
362    fn push_blocking(&mut self, key: ErasedKey) {
363        let waker = self.waker();
364        let completed = self.completed_tx.clone();
365        // SAFETY: we're submitting into the driver, so it's safe to freeze here.
366        let mut key = unsafe { key.freeze() };
367        let mut closure = move || {
368            let res = key.pinned_op().call_blocking();
369            let _ = completed.send(Entry::new(key.into_inner(), res));
370            waker.wake();
371        };
372        while let Err(e) = self.pool.dispatch(closure) {
373            closure = e.0;
374            // do something to avoid busy loop
375            self.poll_blocking();
376            std::thread::yield_now();
377        }
378        self.poll_blocking();
379    }
380
381    pub fn poll(&mut self, timeout: Option<Duration>) -> io::Result<()> {
382        instrument!(compio_log::Level::TRACE, "poll", ?timeout);
383        // Anyway we need to submit once, no matter if there are entries in squeue.
384        trace!("start polling");
385
386        if self.need_push_notifier {
387            #[allow(clippy::useless_conversion)]
388            self.push_raw(
389                PollAdd::new(Fd(self.notifier.as_raw_fd()), libc::POLLIN as _)
390                    .multi(true)
391                    .build()
392                    .user_data(Self::NOTIFY)
393                    .into(),
394            )?;
395            self.need_push_notifier = false;
396        }
397
398        if !self.poll_entries() {
399            self.submit_auto(timeout)?;
400            self.poll_entries();
401        }
402
403        Ok(())
404    }
405
    /// A waker that wakes the driver by signalling its eventfd notifier.
    pub fn waker(&self) -> Waker {
        self.notifier.waker()
    }
409
    /// Create an io-uring provided-buffer pool of `buffer_len` buffers,
    /// each `buffer_size` bytes.
    ///
    /// The buffer-group id is the slab index from `buffer_group_ids`; group
    /// ids must fit in a `u16`, so allocation fails with `OutOfMemory` once
    /// that id space is exhausted.
    pub fn create_buffer_pool(
        &mut self,
        buffer_len: u16,
        buffer_size: usize,
    ) -> io::Result<BufferPool> {
        let buffer_group = self.buffer_group_ids.insert(());
        if buffer_group > u16::MAX as usize {
            // Roll back the slab entry before failing.
            self.buffer_group_ids.remove(buffer_group);

            return Err(io::Error::new(
                io::ErrorKind::OutOfMemory,
                "too many buffer pool allocated",
            ));
        }

        let buf_ring = io_uring_buf_ring::IoUringBufRing::new_with_flags(
            &self.inner,
            buffer_len,
            buffer_group as _,
            buffer_size,
            0,
        )?;

        // Under the fusion driver the pool is wrapped so the io-uring
        // flavor can be distinguished at runtime.
        #[cfg(fusion)]
        {
            Ok(BufferPool::new_io_uring(crate::IoUringBufferPool::new(
                buf_ring,
            )))
        }
        #[cfg(not(fusion))]
        {
            Ok(BufferPool::new(buf_ring))
        }
    }
444
    /// Release a buffer pool created by [`Driver::create_buffer_pool`] and
    /// return its group id to the slab for reuse.
    ///
    /// # Safety
    ///
    /// caller must make sure release the buffer pool with correct driver
    pub unsafe fn release_buffer_pool(&mut self, buffer_pool: BufferPool) -> io::Result<()> {
        #[cfg(fusion)]
        let buffer_pool = buffer_pool.into_io_uring();

        let buffer_group = buffer_pool.buffer_group();
        // SAFETY: per this function's contract, the pool was created on
        // `self.inner`.
        unsafe { buffer_pool.into_inner().release(&self.inner)? };
        self.buffer_group_ids.remove(buffer_group as _);

        Ok(())
    }
458}
459
impl AsRawFd for Driver {
    /// The raw fd of the io-uring ring itself.
    fn as_raw_fd(&self) -> RawFd {
        self.inner.as_raw_fd()
    }
}
465
466fn create_entry(cq_entry: CEntry) -> Entry {
467    let result = cq_entry.result();
468    let result = if result < 0 {
469        let result = if result == -libc::ECANCELED {
470            libc::ETIMEDOUT
471        } else {
472            -result
473        };
474        Err(io::Error::from_raw_os_error(result))
475    } else {
476        Ok(result as _)
477    };
478    let key = unsafe { ErasedKey::from_raw(cq_entry.user_data() as _) };
479    let mut entry = Entry::new(key, result);
480    entry.set_flags(cq_entry.flags());
481
482    entry
483}
484
485fn timespec(duration: std::time::Duration) -> Timespec {
486    Timespec::new()
487        .sec(duration.as_secs())
488        .nsec(duration.subsec_nanos())
489}
490
/// An eventfd-backed wakeup mechanism for the driver.
#[derive(Debug)]
struct Notifier {
    // Shared with every waker handed out by `Notifier::waker`.
    notify: Arc<Notify>,
}
495
496impl Notifier {
497    /// Create a new notifier.
498    fn new() -> io::Result<Self> {
499        let fd = syscall!(libc::eventfd(0, libc::EFD_CLOEXEC | libc::EFD_NONBLOCK))?;
500        let fd = unsafe { OwnedFd::from_raw_fd(fd) };
501        Ok(Self {
502            notify: Arc::new(Notify::new(fd)),
503        })
504    }
505
506    pub fn clear(&self) -> io::Result<()> {
507        loop {
508            let mut buffer = [0u64];
509            let res = syscall!(libc::read(
510                self.as_raw_fd(),
511                buffer.as_mut_ptr().cast(),
512                std::mem::size_of::<u64>()
513            ));
514            match res {
515                Ok(len) => {
516                    debug_assert_eq!(len, std::mem::size_of::<u64>() as _);
517                    break Ok(());
518                }
519                // Clear the next time:)
520                Err(e) if e.kind() == io::ErrorKind::WouldBlock => break Ok(()),
521                // Just like read_exact
522                Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
523                Err(e) => break Err(e),
524            }
525        }
526    }
527
528    pub fn waker(&self) -> Waker {
529        Waker::from(self.notify.clone())
530    }
531}
532
impl AsRawFd for Notifier {
    /// The raw fd of the underlying eventfd.
    fn as_raw_fd(&self) -> RawFd {
        self.notify.fd.as_raw_fd()
    }
}
538
/// A notify handle to the inner driver.
#[derive(Debug)]
pub(crate) struct Notify {
    // Owned eventfd: written to by `notify`, drained by `Notifier::clear`.
    fd: OwnedFd,
}
544
545impl Notify {
546    pub(crate) fn new(fd: OwnedFd) -> Self {
547        Self { fd }
548    }
549
550    /// Notify the inner driver.
551    pub fn notify(&self) -> io::Result<()> {
552        let data = 1u64;
553        syscall!(libc::write(
554            self.fd.as_raw_fd(),
555            &data as *const _ as *const _,
556            std::mem::size_of::<u64>(),
557        ))?;
558        Ok(())
559    }
560}
561
562impl Wake for Notify {
563    fn wake(self: Arc<Self>) {
564        self.wake_by_ref();
565    }
566
567    fn wake_by_ref(self: &Arc<Self>) {
568        self.notify().ok();
569    }
570}